""" Append-only hash chain store for attestation records. Storage format: - chain.bin: length-prefixed CBOR records (uint32 BE + serialized record) - state.cbor: chain state checkpoint (performance optimization) The canonical state is always derivable from chain.bin. If state.cbor is corrupted or missing, it is rebuilt by scanning the log. """ from __future__ import annotations import fcntl import hashlib import os import struct import time from collections.abc import Iterator from pathlib import Path import cbor2 from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from soosef.exceptions import ChainAppendError, ChainError, ChainIntegrityError from soosef.federation.entropy import collect_entropy_witnesses from soosef.federation.models import AttestationChainRecord, ChainState from soosef.federation.serialization import ( canonical_bytes, compute_record_hash, deserialize_record, serialize_record, ) # Length prefix: 4 bytes, big-endian unsigned 32-bit _LEN_STRUCT = struct.Struct(">I") # Maximum record size: 1 MiB. Far larger than any valid record (~200-500 bytes # typically). Prevents OOM from corrupted length prefixes in chain.bin. MAX_RECORD_SIZE = 1_048_576 # Content type for key rotation events. A rotation record is signed by the OLD # key and carries the new public key in metadata["new_pubkey"] (hex-encoded). CONTENT_TYPE_KEY_ROTATION = "soosef/key-rotation-v1" def _now_us() -> int: """Current time as Unix microseconds.""" return int(time.time() * 1_000_000) class ChainStore: """Manages an append-only hash chain of attestation records. Thread safety: single-writer via fcntl.flock. Multiple readers are safe. Offset index: ``_offsets`` maps chain_index (int) to the byte offset of that record's length prefix in chain.bin. It is built lazily during ``_rebuild_state()`` and kept up-to-date by ``append()``. The index is in-memory only — it is reconstructed on every cold load, which is fast because it is done in the same single pass that already must read every record to compute the chain state. """ def __init__(self, chain_dir: Path): self._dir = chain_dir self._chain_file = chain_dir / "chain.bin" self._state_file = chain_dir / "state.cbor" self._dir.mkdir(parents=True, exist_ok=True) self._state: ChainState | None = None # chain_index → byte offset of the record's 4-byte length prefix. # None means the index has not been built yet (cold start). self._offsets: dict[int, int] | None = None def _load_state(self) -> ChainState | None: """Load cached state from state.cbor.""" if self._state is not None: return self._state if self._state_file.exists(): data = self._state_file.read_bytes() m = cbor2.loads(data) self._state = ChainState( chain_id=m["chain_id"], head_index=m["head_index"], head_hash=m["head_hash"], record_count=m["record_count"], created_at=m["created_at"], last_append_at=m["last_append_at"], ) return self._state # No state file — rebuild if chain.bin exists if self._chain_file.exists() and self._chain_file.stat().st_size > 0: return self._rebuild_state() return None def _save_state(self, state: ChainState) -> None: """Atomically write state checkpoint.""" m = { "chain_id": state.chain_id, "head_index": state.head_index, "head_hash": state.head_hash, "record_count": state.record_count, "created_at": state.created_at, "last_append_at": state.last_append_at, } tmp = self._state_file.with_suffix(".tmp") tmp.write_bytes(cbor2.dumps(m, canonical=True)) tmp.rename(self._state_file) self._state = state def _rebuild_state(self) -> ChainState: """Rebuild state by scanning chain.bin. Used on corruption or first load. Also builds the in-memory offset index in the same pass so that no second scan is ever needed. """ genesis = None last = None count = 0 offsets: dict[int, int] = {} for offset, record in self._iter_raw_with_offsets(): offsets[record.chain_index] = offset if count == 0: genesis = record last = record count += 1 if genesis is None or last is None: raise ChainError("Chain file exists but contains no valid records.") self._offsets = offsets state = ChainState( chain_id=hashlib.sha256(canonical_bytes(genesis)).digest(), head_index=last.chain_index, head_hash=compute_record_hash(last), record_count=count, created_at=genesis.claimed_ts, last_append_at=last.claimed_ts, ) self._save_state(state) return state def _iter_raw_with_offsets(self) -> Iterator[tuple[int, AttestationChainRecord]]: """Iterate all records, yielding (byte_offset, record) pairs. ``byte_offset`` is the position of the record's 4-byte length prefix within chain.bin. Used internally to build and exploit the offset index. """ if not self._chain_file.exists(): return with open(self._chain_file, "rb") as f: while True: offset = f.tell() len_bytes = f.read(4) if len(len_bytes) < 4: break (record_len,) = _LEN_STRUCT.unpack(len_bytes) if record_len > MAX_RECORD_SIZE: raise ChainError( f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE} — " f"chain file may be corrupted" ) record_bytes = f.read(record_len) if len(record_bytes) < record_len: break yield offset, deserialize_record(record_bytes) def _iter_raw(self) -> Iterator[AttestationChainRecord]: """Iterate all records from chain.bin without state checks.""" for _offset, record in self._iter_raw_with_offsets(): yield record def _ensure_offsets(self) -> dict[int, int]: """Return the offset index, building it if necessary.""" if self._offsets is None: # Trigger a full scan; _rebuild_state populates self._offsets. if self._chain_file.exists() and self._chain_file.stat().st_size > 0: self._rebuild_state() else: self._offsets = {} return self._offsets # type: ignore[return-value] def _read_record_at(self, offset: int) -> AttestationChainRecord: """Read and deserialize the single record whose length prefix is at *offset*.""" with open(self._chain_file, "rb") as f: f.seek(offset) len_bytes = f.read(4) if len(len_bytes) < 4: raise ChainError(f"Truncated length prefix at offset {offset}.") (record_len,) = _LEN_STRUCT.unpack(len_bytes) if record_len > MAX_RECORD_SIZE: raise ChainError( f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE} — " f"chain file may be corrupted" ) record_bytes = f.read(record_len) if len(record_bytes) < record_len: raise ChainError(f"Truncated record body at offset {offset}.") return deserialize_record(record_bytes) def state(self) -> ChainState | None: """Get current chain state, or None if chain is empty.""" return self._load_state() def is_empty(self) -> bool: """True if the chain has no records.""" return self._load_state() is None def head(self) -> AttestationChainRecord | None: """Return the most recent record, or None if chain is empty.""" state = self._load_state() if state is None: return None return self.get(state.head_index) def get(self, index: int) -> AttestationChainRecord: """Get a record by chain index. O(1) via offset index. Raises ChainError if not found.""" offsets = self._ensure_offsets() if index not in offsets: raise ChainError(f"Record at index {index} not found.") return self._read_record_at(offsets[index]) def iter_records( self, start: int = 0, end: int | None = None ) -> Iterator[AttestationChainRecord]: """Iterate records in [start, end] range (inclusive). Seeks directly to the first record in range via the offset index, so records before *start* are never read or deserialized. """ offsets = self._ensure_offsets() if not offsets: return # Determine the byte offset to start reading from. if start in offsets: seek_offset = offsets[start] elif start == 0: seek_offset = 0 else: # start index not in chain — find the nearest offset above start. candidates = [off for idx, off in offsets.items() if idx >= start] if not candidates: return seek_offset = min(candidates) with open(self._chain_file, "rb") as f: f.seek(seek_offset) while True: len_bytes = f.read(4) if len(len_bytes) < 4: break (record_len,) = _LEN_STRUCT.unpack(len_bytes) if record_len > MAX_RECORD_SIZE: raise ChainError( f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE} — " f"chain file may be corrupted" ) record_bytes = f.read(record_len) if len(record_bytes) < record_len: break record = deserialize_record(record_bytes) if end is not None and record.chain_index > end: break yield record def append( self, content_hash: bytes, content_type: str, private_key: Ed25519PrivateKey, metadata: dict | None = None, ) -> AttestationChainRecord: """Create, sign, and append a new record to the chain. The entire read-compute-write cycle runs under an exclusive file lock to prevent concurrent writers from forking the chain (TOCTOU defense). Args: content_hash: SHA-256 of the content being attested. content_type: MIME-like type identifier for the content. private_key: Ed25519 private key for signing. metadata: Optional extensible key-value metadata. Returns: The newly created and appended AttestationChainRecord. """ from uuid_utils import uuid7 # Pre-compute values that don't depend on chain state public_key = private_key.public_key() pub_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw) try: with open(self._chain_file, "ab") as f: fcntl.flock(f, fcntl.LOCK_EX) try: # Re-read state INSIDE the lock to prevent TOCTOU races. # Also invalidate the offset index so that any records # written by another process since our last read are picked # up during the ensuing offset rebuild. self._state = None self._offsets = None state = self._load_state() # Ensure the offset index reflects the current file contents # (including any records appended by other processes). This # is a full scan only when state.cbor exists and the index # was not already built by _rebuild_state() above. self._ensure_offsets() now = _now_us() if state is None: chain_index = 0 prev_hash = ChainState.GENESIS_PREV_HASH else: chain_index = state.head_index + 1 prev_hash = state.head_hash entropy = collect_entropy_witnesses(self._chain_file) # Build unsigned record record = AttestationChainRecord( version=1, record_id=uuid7().bytes, chain_index=chain_index, prev_hash=prev_hash, content_hash=content_hash, content_type=content_type, metadata=metadata or {}, claimed_ts=now, entropy_witnesses=entropy, signer_pubkey=pub_bytes, signature=b"", # placeholder ) # Sign canonical bytes sig = private_key.sign(canonical_bytes(record)) # Replace with signed record (frozen dataclass) record = AttestationChainRecord( version=record.version, record_id=record.record_id, chain_index=record.chain_index, prev_hash=record.prev_hash, content_hash=record.content_hash, content_type=record.content_type, metadata=record.metadata, claimed_ts=record.claimed_ts, entropy_witnesses=record.entropy_witnesses, signer_pubkey=record.signer_pubkey, signature=sig, ) # Serialize and write record_bytes = serialize_record(record) length_prefix = _LEN_STRUCT.pack(len(record_bytes)) # Record the byte offset before writing so it can be added # to the in-memory offset index without a second file scan. new_record_offset = f.seek(0, os.SEEK_CUR) f.write(length_prefix) f.write(record_bytes) f.flush() os.fsync(f.fileno()) # Update state inside the lock record_hash = compute_record_hash(record) if state is None: chain_id = hashlib.sha256(canonical_bytes(record)).digest() new_state = ChainState( chain_id=chain_id, head_index=0, head_hash=record_hash, record_count=1, created_at=now, last_append_at=now, ) else: new_state = ChainState( chain_id=state.chain_id, head_index=chain_index, head_hash=record_hash, record_count=state.record_count + 1, created_at=state.created_at, last_append_at=now, ) self._save_state(new_state) # Keep the offset index consistent so subsequent get() / # iter_records() calls on this instance remain O(1). if self._offsets is not None: self._offsets[chain_index] = new_record_offset finally: fcntl.flock(f, fcntl.LOCK_UN) except OSError as e: raise ChainAppendError(f"Failed to write to chain: {e}") from e return record def append_key_rotation( self, old_private_key: Ed25519PrivateKey, new_private_key: Ed25519PrivateKey, ) -> AttestationChainRecord: """Record a key rotation event in the chain. The rotation record is signed by the OLD key and carries the new public key in metadata. This creates a cryptographic trust chain: anyone who trusts the old key can verify the transition to the new one. Args: old_private_key: The current (soon-to-be-archived) signing key. new_private_key: The newly generated signing key. Returns: The rotation record appended to the chain. """ new_pub = new_private_key.public_key() new_pub_bytes = new_pub.public_bytes(Encoding.Raw, PublicFormat.Raw) # Content hash is the SHA-256 of the new public key content_hash = hashlib.sha256(new_pub_bytes).digest() return self.append( content_hash=content_hash, content_type=CONTENT_TYPE_KEY_ROTATION, private_key=old_private_key, metadata={"new_pubkey": new_pub_bytes.hex()}, ) def verify_chain(self, start: int = 0, end: int | None = None) -> bool: """Verify hash chain integrity and signatures over a range. Args: start: First record index to verify (default 0). end: Last record index to verify (default: head). Returns: True if the chain is valid. Raises: ChainIntegrityError: If any integrity check fails. """ from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey prev_record: AttestationChainRecord | None = None expected_index = start authorized_signers: set[bytes] = set() # If starting from 0, first record must have genesis prev_hash if start > 0: # Load the record before start to check the first prev_hash try: prev_record = self.get(start - 1) except ChainError: pass # Can't verify prev_hash of first record in range for record in self.iter_records(start, end): # Check index continuity if record.chain_index != expected_index: raise ChainIntegrityError( f"Expected index {expected_index}, got {record.chain_index}" ) # Check prev_hash linkage if prev_record is not None: expected_hash = compute_record_hash(prev_record) if record.prev_hash != expected_hash: raise ChainIntegrityError(f"Record {record.chain_index}: prev_hash mismatch") elif record.chain_index == 0: if record.prev_hash != ChainState.GENESIS_PREV_HASH: raise ChainIntegrityError("Genesis record has non-zero prev_hash") # Check signature try: pub = Ed25519PublicKey.from_public_bytes(record.signer_pubkey) pub.verify(record.signature, canonical_bytes(record)) except Exception as e: raise ChainIntegrityError( f"Record {record.chain_index}: signature verification failed: {e}" ) from e # Track authorized signers: the genesis signer plus any keys # introduced by valid key-rotation records. if not authorized_signers: authorized_signers.add(record.signer_pubkey) elif record.signer_pubkey not in authorized_signers: raise ChainIntegrityError( f"Record {record.chain_index}: signer " f"{record.signer_pubkey.hex()[:16]}... is not authorized" ) # If this is a key rotation record, authorize the new key if record.content_type == CONTENT_TYPE_KEY_ROTATION: new_pubkey_hex = record.metadata.get("new_pubkey") if not new_pubkey_hex: raise ChainIntegrityError( f"Record {record.chain_index}: key rotation missing new_pubkey" ) authorized_signers.add(bytes.fromhex(new_pubkey_hex)) # Revoke the old key — the rotation record was its last authorized action authorized_signers.discard(record.signer_pubkey) prev_record = record expected_index += 1 return True