fieldwitness/src/soosef/federation/chain.py
Aaron D. Lee 0d8c94bf82 Fix 6 security issues from post-FR audit
- Fix 3 missing CSRF tokens on admin user delete/reset and account
  key delete forms (were broken — CSRFProtect rejected submissions)
- Fix trust store path traversal: untrust_key() now validates
  fingerprint format ([0-9a-f]{32}) and checks resolved path
- Fix chain key rotation: old key is now revoked after rotation
  record, preventing compromised old keys from appending records
- Fix SSRF in deadman webhook: block private/internal IP targets
- Fix logout CSRF: /logout is now POST-only with CSRF token,
  preventing cross-site forced logout via img tags

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 19:44:15 -04:00

509 lines
20 KiB
Python

"""
Append-only hash chain store for attestation records.
Storage format:
- chain.bin: length-prefixed CBOR records (uint32 BE + serialized record)
- state.cbor: chain state checkpoint (performance optimization)
The canonical state is always derivable from chain.bin. If state.cbor is
corrupted or missing, it is rebuilt by scanning the log.
"""
from __future__ import annotations
import fcntl
import hashlib
import os
import struct
import time
from collections.abc import Iterator
from pathlib import Path
import cbor2
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from soosef.exceptions import ChainAppendError, ChainError, ChainIntegrityError
from soosef.federation.entropy import collect_entropy_witnesses
from soosef.federation.models import AttestationChainRecord, ChainState
from soosef.federation.serialization import (
canonical_bytes,
compute_record_hash,
deserialize_record,
serialize_record,
)
# Length prefix: 4 bytes, big-endian unsigned 32-bit.
_LEN_STRUCT = struct.Struct(">I")
# Maximum record size: 1 MiB. Far larger than any valid record (~200-500 bytes
# typically). Enforced at every length-prefix read site so a corrupted prefix
# in chain.bin cannot trigger a giant allocation (OOM).
MAX_RECORD_SIZE = 1_048_576
# Content type for key rotation events. A rotation record is signed by the OLD
# key and carries the new public key in metadata["new_pubkey"] (hex-encoded).
# During verification, a rotation authorizes the new key and revokes the old.
CONTENT_TYPE_KEY_ROTATION = "soosef/key-rotation-v1"
def _now_us() -> int:
"""Current time as Unix microseconds."""
return int(time.time() * 1_000_000)
class ChainStore:
"""Manages an append-only hash chain of attestation records.
Thread safety: single-writer via fcntl.flock. Multiple readers are safe.
Offset index: ``_offsets`` maps chain_index (int) to the byte offset of
that record's length prefix in chain.bin. It is built lazily during
``_rebuild_state()`` and kept up-to-date by ``append()``. The index is
in-memory only — it is reconstructed on every cold load, which is fast
because it is done in the same single pass that already must read every
record to compute the chain state.
"""
def __init__(self, chain_dir: Path):
self._dir = chain_dir
self._chain_file = chain_dir / "chain.bin"
self._state_file = chain_dir / "state.cbor"
self._dir.mkdir(parents=True, exist_ok=True)
self._state: ChainState | None = None
# chain_index → byte offset of the record's 4-byte length prefix.
# None means the index has not been built yet (cold start).
self._offsets: dict[int, int] | None = None
def _load_state(self) -> ChainState | None:
"""Load cached state from state.cbor."""
if self._state is not None:
return self._state
if self._state_file.exists():
data = self._state_file.read_bytes()
m = cbor2.loads(data)
self._state = ChainState(
chain_id=m["chain_id"],
head_index=m["head_index"],
head_hash=m["head_hash"],
record_count=m["record_count"],
created_at=m["created_at"],
last_append_at=m["last_append_at"],
)
return self._state
# No state file — rebuild if chain.bin exists
if self._chain_file.exists() and self._chain_file.stat().st_size > 0:
return self._rebuild_state()
return None
def _save_state(self, state: ChainState) -> None:
    """Atomically persist *state* to state.cbor and cache it in memory.

    Written via temp file + atomic replace so readers never observe a
    partially-written checkpoint. The file is fsync'd before the rename
    publishes it; otherwise a crash right after the rename could leave an
    empty/truncated checkpoint on some filesystems, contradicting the
    "atomic" guarantee. ``replace()`` (not ``rename()``) is used because
    ``Path.rename`` raises FileExistsError on Windows when the target
    already exists; ``replace`` overwrites atomically on all platforms.
    """
    m = {
        "chain_id": state.chain_id,
        "head_index": state.head_index,
        "head_hash": state.head_hash,
        "record_count": state.record_count,
        "created_at": state.created_at,
        "last_append_at": state.last_append_at,
    }
    payload = cbor2.dumps(m, canonical=True)
    tmp = self._state_file.with_suffix(".tmp")
    with open(tmp, "wb") as f:
        f.write(payload)
        f.flush()
        os.fsync(f.fileno())
    tmp.replace(self._state_file)
    self._state = state
def _rebuild_state(self) -> ChainState:
"""Rebuild state by scanning chain.bin. Used on corruption or first load.
Also builds the in-memory offset index in the same pass so that no
second scan is ever needed.
"""
genesis = None
last = None
count = 0
offsets: dict[int, int] = {}
for offset, record in self._iter_raw_with_offsets():
offsets[record.chain_index] = offset
if count == 0:
genesis = record
last = record
count += 1
if genesis is None or last is None:
raise ChainError("Chain file exists but contains no valid records.")
self._offsets = offsets
state = ChainState(
chain_id=hashlib.sha256(canonical_bytes(genesis)).digest(),
head_index=last.chain_index,
head_hash=compute_record_hash(last),
record_count=count,
created_at=genesis.claimed_ts,
last_append_at=last.claimed_ts,
)
self._save_state(state)
return state
def _iter_raw_with_offsets(self) -> Iterator[tuple[int, AttestationChainRecord]]:
"""Iterate all records, yielding (byte_offset, record) pairs.
``byte_offset`` is the position of the record's 4-byte length prefix
within chain.bin. Used internally to build and exploit the offset index.
"""
if not self._chain_file.exists():
return
with open(self._chain_file, "rb") as f:
while True:
offset = f.tell()
len_bytes = f.read(4)
if len(len_bytes) < 4:
break
(record_len,) = _LEN_STRUCT.unpack(len_bytes)
if record_len > MAX_RECORD_SIZE:
raise ChainError(
f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE}"
f"chain file may be corrupted"
)
record_bytes = f.read(record_len)
if len(record_bytes) < record_len:
break
yield offset, deserialize_record(record_bytes)
def _iter_raw(self) -> Iterator[AttestationChainRecord]:
"""Iterate all records from chain.bin without state checks."""
for _offset, record in self._iter_raw_with_offsets():
yield record
def _ensure_offsets(self) -> dict[int, int]:
"""Return the offset index, building it if necessary."""
if self._offsets is None:
# Trigger a full scan; _rebuild_state populates self._offsets.
if self._chain_file.exists() and self._chain_file.stat().st_size > 0:
self._rebuild_state()
else:
self._offsets = {}
return self._offsets # type: ignore[return-value]
def _read_record_at(self, offset: int) -> AttestationChainRecord:
"""Read and deserialize the single record whose length prefix is at *offset*."""
with open(self._chain_file, "rb") as f:
f.seek(offset)
len_bytes = f.read(4)
if len(len_bytes) < 4:
raise ChainError(f"Truncated length prefix at offset {offset}.")
(record_len,) = _LEN_STRUCT.unpack(len_bytes)
if record_len > MAX_RECORD_SIZE:
raise ChainError(
f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE}"
f"chain file may be corrupted"
)
record_bytes = f.read(record_len)
if len(record_bytes) < record_len:
raise ChainError(f"Truncated record body at offset {offset}.")
return deserialize_record(record_bytes)
def state(self) -> ChainState | None:
"""Get current chain state, or None if chain is empty."""
return self._load_state()
def is_empty(self) -> bool:
"""True if the chain has no records."""
return self._load_state() is None
def head(self) -> AttestationChainRecord | None:
"""Return the most recent record, or None if chain is empty."""
state = self._load_state()
if state is None:
return None
return self.get(state.head_index)
def get(self, index: int) -> AttestationChainRecord:
"""Get a record by chain index. O(1) via offset index. Raises ChainError if not found."""
offsets = self._ensure_offsets()
if index not in offsets:
raise ChainError(f"Record at index {index} not found.")
return self._read_record_at(offsets[index])
def iter_records(
self, start: int = 0, end: int | None = None
) -> Iterator[AttestationChainRecord]:
"""Iterate records in [start, end] range (inclusive).
Seeks directly to the first record in range via the offset index, so
records before *start* are never read or deserialized.
"""
offsets = self._ensure_offsets()
if not offsets:
return
# Determine the byte offset to start reading from.
if start in offsets:
seek_offset = offsets[start]
elif start == 0:
seek_offset = 0
else:
# start index not in chain — find the nearest offset above start.
candidates = [off for idx, off in offsets.items() if idx >= start]
if not candidates:
return
seek_offset = min(candidates)
with open(self._chain_file, "rb") as f:
f.seek(seek_offset)
while True:
len_bytes = f.read(4)
if len(len_bytes) < 4:
break
(record_len,) = _LEN_STRUCT.unpack(len_bytes)
if record_len > MAX_RECORD_SIZE:
raise ChainError(
f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE}"
f"chain file may be corrupted"
)
record_bytes = f.read(record_len)
if len(record_bytes) < record_len:
break
record = deserialize_record(record_bytes)
if end is not None and record.chain_index > end:
break
yield record
def append(
    self,
    content_hash: bytes,
    content_type: str,
    private_key: Ed25519PrivateKey,
    metadata: dict | None = None,
) -> AttestationChainRecord:
    """Create, sign, and append a new record to the chain.

    The entire read-compute-write cycle runs under an exclusive file lock
    to prevent concurrent writers from forking the chain (TOCTOU defense).

    Args:
        content_hash: SHA-256 of the content being attested.
        content_type: MIME-like type identifier for the content.
        private_key: Ed25519 private key for signing.
        metadata: Optional extensible key-value metadata.

    Returns:
        The newly created and appended AttestationChainRecord.

    Raises:
        ChainAppendError: If the chain file cannot be written.
    """
    from uuid_utils import uuid7

    # Pre-compute values that don't depend on chain state.
    public_key = private_key.public_key()
    pub_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
    try:
        with open(self._chain_file, "ab") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            try:
                # Re-read state INSIDE the lock to prevent TOCTOU races.
                # Also invalidate the offset index so that any records
                # written by another process since our last read are picked
                # up during the ensuing offset rebuild.
                self._state = None
                self._offsets = None
                state = self._load_state()
                # Ensure the offset index reflects the current file contents
                # (including any records appended by other processes). This
                # is a full scan only when state.cbor exists and the index
                # was not already built by _rebuild_state() above.
                self._ensure_offsets()
                now = _now_us()
                if state is None:
                    chain_index = 0
                    prev_hash = ChainState.GENESIS_PREV_HASH
                else:
                    chain_index = state.head_index + 1
                    prev_hash = state.head_hash
                entropy = collect_entropy_witnesses(self._chain_file)
                # Build the unsigned record; the signature covers its
                # canonical serialization (with an empty signature field).
                record = AttestationChainRecord(
                    version=1,
                    record_id=uuid7().bytes,
                    chain_index=chain_index,
                    prev_hash=prev_hash,
                    content_hash=content_hash,
                    content_type=content_type,
                    metadata=metadata or {},
                    claimed_ts=now,
                    entropy_witnesses=entropy,
                    signer_pubkey=pub_bytes,
                    signature=b"",  # placeholder
                )
                # Sign canonical bytes.
                sig = private_key.sign(canonical_bytes(record))
                # The dataclass is frozen, so rebuild it with the real
                # signature instead of mutating in place.
                record = AttestationChainRecord(
                    version=record.version,
                    record_id=record.record_id,
                    chain_index=record.chain_index,
                    prev_hash=record.prev_hash,
                    content_hash=record.content_hash,
                    content_type=record.content_type,
                    metadata=record.metadata,
                    claimed_ts=record.claimed_ts,
                    entropy_witnesses=record.entropy_witnesses,
                    signer_pubkey=record.signer_pubkey,
                    signature=sig,
                )
                # Serialize and write.
                record_bytes = serialize_record(record)
                length_prefix = _LEN_STRUCT.pack(len(record_bytes))
                # Record the TRUE end-of-file offset before writing so it
                # can be added to the in-memory index without rescanning.
                # SEEK_END (not SEEK_CUR) is required here: in append mode
                # our buffered position can be stale if another process
                # appended between our open() and our flock(), while the
                # O_APPEND write itself always lands at the real EOF — a
                # stale SEEK_CUR value would poison the offset index.
                new_record_offset = f.seek(0, os.SEEK_END)
                f.write(length_prefix)
                f.write(record_bytes)
                f.flush()
                os.fsync(f.fileno())
                # Update state inside the lock.
                record_hash = compute_record_hash(record)
                if state is None:
                    chain_id = hashlib.sha256(canonical_bytes(record)).digest()
                    new_state = ChainState(
                        chain_id=chain_id,
                        head_index=0,
                        head_hash=record_hash,
                        record_count=1,
                        created_at=now,
                        last_append_at=now,
                    )
                else:
                    new_state = ChainState(
                        chain_id=state.chain_id,
                        head_index=chain_index,
                        head_hash=record_hash,
                        record_count=state.record_count + 1,
                        created_at=state.created_at,
                        last_append_at=now,
                    )
                self._save_state(new_state)
                # Keep the offset index consistent so subsequent get() /
                # iter_records() calls on this instance remain O(1).
                if self._offsets is not None:
                    self._offsets[chain_index] = new_record_offset
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)
    except OSError as e:
        raise ChainAppendError(f"Failed to write to chain: {e}") from e
    return record
def append_key_rotation(
    self,
    old_private_key: Ed25519PrivateKey,
    new_private_key: Ed25519PrivateKey,
) -> AttestationChainRecord:
    """Record a key rotation event in the chain.

    The rotation record is signed by the OLD key and carries the new
    public key in its metadata. This creates a cryptographic trust chain:
    anyone who trusts the old key can verify the transition to the new one.

    Args:
        old_private_key: The current (soon-to-be-archived) signing key.
        new_private_key: The newly generated signing key.

    Returns:
        The rotation record appended to the chain.
    """
    raw_new_pub = new_private_key.public_key().public_bytes(
        Encoding.Raw, PublicFormat.Raw
    )
    # The attested content is the SHA-256 of the new public key itself.
    return self.append(
        content_hash=hashlib.sha256(raw_new_pub).digest(),
        content_type=CONTENT_TYPE_KEY_ROTATION,
        private_key=old_private_key,
        metadata={"new_pubkey": raw_new_pub.hex()},
    )
def verify_chain(self, start: int = 0, end: int | None = None) -> bool:
    """Verify hash-chain integrity and signatures over a range.

    Per record, checks: index continuity, prev_hash linkage, Ed25519
    signature over the canonical bytes, and signer authorization. The
    authorized set starts with the first verified signer in the range and
    grows via key-rotation records; a rotation also revokes the old key,
    the rotation record being its last authorized action.

    Args:
        start: First record index to verify (default 0).
        end: Last record index to verify (default: head).

    Returns:
        True if the chain is valid (vacuously True for an empty range).

    Raises:
        ChainIntegrityError: If any integrity check fails. Malformed hex
            in a rotation's new_pubkey now raises this too, instead of
            letting a bare ValueError escape the documented contract.
    """
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

    prev_record: AttestationChainRecord | None = None
    expected_index = start
    authorized_signers: set[bytes] = set()
    # If starting from 0, the first record must carry the genesis prev_hash;
    # otherwise try to anchor linkage on the record just before the range.
    if start > 0:
        try:
            prev_record = self.get(start - 1)
        except ChainError:
            pass  # can't verify prev_hash of the first record in range
    for record in self.iter_records(start, end):
        # Check index continuity.
        if record.chain_index != expected_index:
            raise ChainIntegrityError(
                f"Expected index {expected_index}, got {record.chain_index}"
            )
        # Check prev_hash linkage.
        if prev_record is not None:
            expected_hash = compute_record_hash(prev_record)
            if record.prev_hash != expected_hash:
                raise ChainIntegrityError(f"Record {record.chain_index}: prev_hash mismatch")
        elif record.chain_index == 0:
            if record.prev_hash != ChainState.GENESIS_PREV_HASH:
                raise ChainIntegrityError("Genesis record has non-zero prev_hash")
        # Check the signature over the canonical serialization.
        try:
            pub = Ed25519PublicKey.from_public_bytes(record.signer_pubkey)
            pub.verify(record.signature, canonical_bytes(record))
        except Exception as e:
            raise ChainIntegrityError(
                f"Record {record.chain_index}: signature verification failed: {e}"
            ) from e
        # Track authorized signers: the genesis signer plus any keys
        # introduced by valid key-rotation records.
        if not authorized_signers:
            authorized_signers.add(record.signer_pubkey)
        elif record.signer_pubkey not in authorized_signers:
            raise ChainIntegrityError(
                f"Record {record.chain_index}: signer "
                f"{record.signer_pubkey.hex()[:16]}... is not authorized"
            )
        # If this is a key rotation record, authorize the new key.
        if record.content_type == CONTENT_TYPE_KEY_ROTATION:
            new_pubkey_hex = record.metadata.get("new_pubkey")
            if not new_pubkey_hex:
                raise ChainIntegrityError(
                    f"Record {record.chain_index}: key rotation missing new_pubkey"
                )
            try:
                new_pubkey = bytes.fromhex(new_pubkey_hex)
            except (ValueError, TypeError) as e:
                raise ChainIntegrityError(
                    f"Record {record.chain_index}: key rotation new_pubkey is not valid hex"
                ) from e
            authorized_signers.add(new_pubkey)
            # Revoke the old key — the rotation record was its last
            # authorized action.
            authorized_signers.discard(record.signer_pubkey)
        prev_record = record
        expected_index += 1
    return True