# Changelog (from the fixing commit):
# - Fix 3 missing CSRF tokens on admin user delete/reset and account
#   key delete forms (were broken — CSRFProtect rejected submissions)
# - Fix trust store path traversal: untrust_key() now validates
#   fingerprint format ([0-9a-f]{32}) and checks resolved path
# - Fix chain key rotation: old key is now revoked after rotation
#   record, preventing compromised old keys from appending records
# - Fix SSRF in deadman webhook: block private/internal IP targets
# - Fix logout CSRF: /logout is now POST-only with CSRF token,
#   preventing cross-site forced logout via img tags
# Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
"""
|
|
Append-only hash chain store for attestation records.
|
|
|
|
Storage format:
|
|
- chain.bin: length-prefixed CBOR records (uint32 BE + serialized record)
|
|
- state.cbor: chain state checkpoint (performance optimization)
|
|
|
|
The canonical state is always derivable from chain.bin. If state.cbor is
|
|
corrupted or missing, it is rebuilt by scanning the log.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import fcntl
|
|
import hashlib
|
|
import os
|
|
import struct
|
|
import time
|
|
from collections.abc import Iterator
|
|
from pathlib import Path
|
|
|
|
import cbor2
|
|
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
|
|
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
|
|
|
|
from soosef.exceptions import ChainAppendError, ChainError, ChainIntegrityError
|
|
from soosef.federation.entropy import collect_entropy_witnesses
|
|
from soosef.federation.models import AttestationChainRecord, ChainState
|
|
from soosef.federation.serialization import (
|
|
canonical_bytes,
|
|
compute_record_hash,
|
|
deserialize_record,
|
|
serialize_record,
|
|
)
|
|
|
|
# Length prefix: 4 bytes, big-endian unsigned 32-bit
|
|
_LEN_STRUCT = struct.Struct(">I")
|
|
|
|
# Maximum record size: 1 MiB. Far larger than any valid record (~200-500 bytes
|
|
# typically). Prevents OOM from corrupted length prefixes in chain.bin.
|
|
MAX_RECORD_SIZE = 1_048_576
|
|
|
|
# Content type for key rotation events. A rotation record is signed by the OLD
|
|
# key and carries the new public key in metadata["new_pubkey"] (hex-encoded).
|
|
CONTENT_TYPE_KEY_ROTATION = "soosef/key-rotation-v1"
|
|
|
|
|
|
def _now_us() -> int:
    """Return the current wall-clock time as Unix microseconds."""
    seconds = time.time()
    return int(seconds * 1_000_000)
class ChainStore:
    """Manages an append-only hash chain of attestation records.

    Thread safety: single-writer via fcntl.flock. Multiple readers are safe.

    Offset index: ``_offsets`` maps chain_index (int) to the byte offset of
    that record's length prefix in chain.bin. It is built lazily during
    ``_rebuild_state()`` and kept up-to-date by ``append()``. The index is
    in-memory only — it is reconstructed on every cold load, which is fast
    because it is done in the same single pass that already must read every
    record to compute the chain state.
    """

    def __init__(self, chain_dir: Path):
        """Create a store rooted at *chain_dir* (directory is created if absent).

        Args:
            chain_dir: Directory that will hold chain.bin and state.cbor.
        """
        self._dir = chain_dir
        self._chain_file = chain_dir / "chain.bin"
        self._state_file = chain_dir / "state.cbor"
        # Create the directory up front so later open()/write calls are simple.
        self._dir.mkdir(parents=True, exist_ok=True)
        # Cached ChainState; None until _load_state() populates it.
        self._state: ChainState | None = None
        # chain_index → byte offset of the record's 4-byte length prefix.
        # None means the index has not been built yet (cold start).
        self._offsets: dict[int, int] | None = None
def _load_state(self) -> ChainState | None:
|
|
"""Load cached state from state.cbor."""
|
|
if self._state is not None:
|
|
return self._state
|
|
if self._state_file.exists():
|
|
data = self._state_file.read_bytes()
|
|
m = cbor2.loads(data)
|
|
self._state = ChainState(
|
|
chain_id=m["chain_id"],
|
|
head_index=m["head_index"],
|
|
head_hash=m["head_hash"],
|
|
record_count=m["record_count"],
|
|
created_at=m["created_at"],
|
|
last_append_at=m["last_append_at"],
|
|
)
|
|
return self._state
|
|
# No state file — rebuild if chain.bin exists
|
|
if self._chain_file.exists() and self._chain_file.stat().st_size > 0:
|
|
return self._rebuild_state()
|
|
return None
|
|
|
|
def _save_state(self, state: ChainState) -> None:
|
|
"""Atomically write state checkpoint."""
|
|
m = {
|
|
"chain_id": state.chain_id,
|
|
"head_index": state.head_index,
|
|
"head_hash": state.head_hash,
|
|
"record_count": state.record_count,
|
|
"created_at": state.created_at,
|
|
"last_append_at": state.last_append_at,
|
|
}
|
|
tmp = self._state_file.with_suffix(".tmp")
|
|
tmp.write_bytes(cbor2.dumps(m, canonical=True))
|
|
tmp.rename(self._state_file)
|
|
self._state = state
|
|
|
|
def _rebuild_state(self) -> ChainState:
|
|
"""Rebuild state by scanning chain.bin. Used on corruption or first load.
|
|
|
|
Also builds the in-memory offset index in the same pass so that no
|
|
second scan is ever needed.
|
|
"""
|
|
genesis = None
|
|
last = None
|
|
count = 0
|
|
offsets: dict[int, int] = {}
|
|
for offset, record in self._iter_raw_with_offsets():
|
|
offsets[record.chain_index] = offset
|
|
if count == 0:
|
|
genesis = record
|
|
last = record
|
|
count += 1
|
|
|
|
if genesis is None or last is None:
|
|
raise ChainError("Chain file exists but contains no valid records.")
|
|
|
|
self._offsets = offsets
|
|
|
|
state = ChainState(
|
|
chain_id=hashlib.sha256(canonical_bytes(genesis)).digest(),
|
|
head_index=last.chain_index,
|
|
head_hash=compute_record_hash(last),
|
|
record_count=count,
|
|
created_at=genesis.claimed_ts,
|
|
last_append_at=last.claimed_ts,
|
|
)
|
|
self._save_state(state)
|
|
return state
|
|
|
|
def _iter_raw_with_offsets(self) -> Iterator[tuple[int, AttestationChainRecord]]:
|
|
"""Iterate all records, yielding (byte_offset, record) pairs.
|
|
|
|
``byte_offset`` is the position of the record's 4-byte length prefix
|
|
within chain.bin. Used internally to build and exploit the offset index.
|
|
"""
|
|
if not self._chain_file.exists():
|
|
return
|
|
with open(self._chain_file, "rb") as f:
|
|
while True:
|
|
offset = f.tell()
|
|
len_bytes = f.read(4)
|
|
if len(len_bytes) < 4:
|
|
break
|
|
(record_len,) = _LEN_STRUCT.unpack(len_bytes)
|
|
if record_len > MAX_RECORD_SIZE:
|
|
raise ChainError(
|
|
f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE} — "
|
|
f"chain file may be corrupted"
|
|
)
|
|
record_bytes = f.read(record_len)
|
|
if len(record_bytes) < record_len:
|
|
break
|
|
yield offset, deserialize_record(record_bytes)
|
|
|
|
def _iter_raw(self) -> Iterator[AttestationChainRecord]:
|
|
"""Iterate all records from chain.bin without state checks."""
|
|
for _offset, record in self._iter_raw_with_offsets():
|
|
yield record
|
|
|
|
def _ensure_offsets(self) -> dict[int, int]:
|
|
"""Return the offset index, building it if necessary."""
|
|
if self._offsets is None:
|
|
# Trigger a full scan; _rebuild_state populates self._offsets.
|
|
if self._chain_file.exists() and self._chain_file.stat().st_size > 0:
|
|
self._rebuild_state()
|
|
else:
|
|
self._offsets = {}
|
|
return self._offsets # type: ignore[return-value]
|
|
|
|
def _read_record_at(self, offset: int) -> AttestationChainRecord:
|
|
"""Read and deserialize the single record whose length prefix is at *offset*."""
|
|
with open(self._chain_file, "rb") as f:
|
|
f.seek(offset)
|
|
len_bytes = f.read(4)
|
|
if len(len_bytes) < 4:
|
|
raise ChainError(f"Truncated length prefix at offset {offset}.")
|
|
(record_len,) = _LEN_STRUCT.unpack(len_bytes)
|
|
if record_len > MAX_RECORD_SIZE:
|
|
raise ChainError(
|
|
f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE} — "
|
|
f"chain file may be corrupted"
|
|
)
|
|
record_bytes = f.read(record_len)
|
|
if len(record_bytes) < record_len:
|
|
raise ChainError(f"Truncated record body at offset {offset}.")
|
|
return deserialize_record(record_bytes)
|
|
|
|
def state(self) -> ChainState | None:
|
|
"""Get current chain state, or None if chain is empty."""
|
|
return self._load_state()
|
|
|
|
def is_empty(self) -> bool:
|
|
"""True if the chain has no records."""
|
|
return self._load_state() is None
|
|
|
|
def head(self) -> AttestationChainRecord | None:
|
|
"""Return the most recent record, or None if chain is empty."""
|
|
state = self._load_state()
|
|
if state is None:
|
|
return None
|
|
return self.get(state.head_index)
|
|
|
|
def get(self, index: int) -> AttestationChainRecord:
|
|
"""Get a record by chain index. O(1) via offset index. Raises ChainError if not found."""
|
|
offsets = self._ensure_offsets()
|
|
if index not in offsets:
|
|
raise ChainError(f"Record at index {index} not found.")
|
|
return self._read_record_at(offsets[index])
|
|
|
|
    def iter_records(
        self, start: int = 0, end: int | None = None
    ) -> Iterator[AttestationChainRecord]:
        """Iterate records in [start, end] range (inclusive).

        Seeks directly to the first record in range via the offset index, so
        records before *start* are never read or deserialized.

        Args:
            start: First chain index to yield (default 0).
            end: Last chain index to yield, inclusive; None means through head.

        Yields:
            AttestationChainRecord instances in ascending chain-index order.

        Raises:
            ChainError: If a length prefix exceeds MAX_RECORD_SIZE.
        """
        offsets = self._ensure_offsets()
        if not offsets:
            # Empty chain — nothing to yield.
            return

        # Determine the byte offset to start reading from.
        if start in offsets:
            seek_offset = offsets[start]
        elif start == 0:
            # Index 0 missing from the map: fall back to the file start.
            seek_offset = 0
        else:
            # start index not in chain — find the nearest offset above start.
            candidates = [off for idx, off in offsets.items() if idx >= start]
            if not candidates:
                # start is beyond the head — nothing in range.
                return
            seek_offset = min(candidates)

        with open(self._chain_file, "rb") as f:
            f.seek(seek_offset)
            while True:
                len_bytes = f.read(4)
                if len(len_bytes) < 4:
                    # Truncated length prefix: end of log.
                    break
                (record_len,) = _LEN_STRUCT.unpack(len_bytes)
                if record_len > MAX_RECORD_SIZE:
                    raise ChainError(
                        f"Record length {record_len} exceeds maximum {MAX_RECORD_SIZE} — "
                        f"chain file may be corrupted"
                    )
                record_bytes = f.read(record_len)
                if len(record_bytes) < record_len:
                    # Truncated body (partial final write): stop.
                    break
                record = deserialize_record(record_bytes)
                if end is not None and record.chain_index > end:
                    # Past the requested range — done.
                    break
                yield record
def append(
|
|
self,
|
|
content_hash: bytes,
|
|
content_type: str,
|
|
private_key: Ed25519PrivateKey,
|
|
metadata: dict | None = None,
|
|
) -> AttestationChainRecord:
|
|
"""Create, sign, and append a new record to the chain.
|
|
|
|
The entire read-compute-write cycle runs under an exclusive file lock
|
|
to prevent concurrent writers from forking the chain (TOCTOU defense).
|
|
|
|
Args:
|
|
content_hash: SHA-256 of the content being attested.
|
|
content_type: MIME-like type identifier for the content.
|
|
private_key: Ed25519 private key for signing.
|
|
metadata: Optional extensible key-value metadata.
|
|
|
|
Returns:
|
|
The newly created and appended AttestationChainRecord.
|
|
"""
|
|
from uuid_utils import uuid7
|
|
|
|
# Pre-compute values that don't depend on chain state
|
|
public_key = private_key.public_key()
|
|
pub_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
|
|
|
|
try:
|
|
with open(self._chain_file, "ab") as f:
|
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
try:
|
|
# Re-read state INSIDE the lock to prevent TOCTOU races.
|
|
# Also invalidate the offset index so that any records
|
|
# written by another process since our last read are picked
|
|
# up during the ensuing offset rebuild.
|
|
self._state = None
|
|
self._offsets = None
|
|
state = self._load_state()
|
|
# Ensure the offset index reflects the current file contents
|
|
# (including any records appended by other processes). This
|
|
# is a full scan only when state.cbor exists and the index
|
|
# was not already built by _rebuild_state() above.
|
|
self._ensure_offsets()
|
|
now = _now_us()
|
|
|
|
if state is None:
|
|
chain_index = 0
|
|
prev_hash = ChainState.GENESIS_PREV_HASH
|
|
else:
|
|
chain_index = state.head_index + 1
|
|
prev_hash = state.head_hash
|
|
|
|
entropy = collect_entropy_witnesses(self._chain_file)
|
|
|
|
# Build unsigned record
|
|
record = AttestationChainRecord(
|
|
version=1,
|
|
record_id=uuid7().bytes,
|
|
chain_index=chain_index,
|
|
prev_hash=prev_hash,
|
|
content_hash=content_hash,
|
|
content_type=content_type,
|
|
metadata=metadata or {},
|
|
claimed_ts=now,
|
|
entropy_witnesses=entropy,
|
|
signer_pubkey=pub_bytes,
|
|
signature=b"", # placeholder
|
|
)
|
|
|
|
# Sign canonical bytes
|
|
sig = private_key.sign(canonical_bytes(record))
|
|
|
|
# Replace with signed record (frozen dataclass)
|
|
record = AttestationChainRecord(
|
|
version=record.version,
|
|
record_id=record.record_id,
|
|
chain_index=record.chain_index,
|
|
prev_hash=record.prev_hash,
|
|
content_hash=record.content_hash,
|
|
content_type=record.content_type,
|
|
metadata=record.metadata,
|
|
claimed_ts=record.claimed_ts,
|
|
entropy_witnesses=record.entropy_witnesses,
|
|
signer_pubkey=record.signer_pubkey,
|
|
signature=sig,
|
|
)
|
|
|
|
# Serialize and write
|
|
record_bytes = serialize_record(record)
|
|
length_prefix = _LEN_STRUCT.pack(len(record_bytes))
|
|
# Record the byte offset before writing so it can be added
|
|
# to the in-memory offset index without a second file scan.
|
|
new_record_offset = f.seek(0, os.SEEK_CUR)
|
|
f.write(length_prefix)
|
|
f.write(record_bytes)
|
|
f.flush()
|
|
os.fsync(f.fileno())
|
|
|
|
# Update state inside the lock
|
|
record_hash = compute_record_hash(record)
|
|
if state is None:
|
|
chain_id = hashlib.sha256(canonical_bytes(record)).digest()
|
|
new_state = ChainState(
|
|
chain_id=chain_id,
|
|
head_index=0,
|
|
head_hash=record_hash,
|
|
record_count=1,
|
|
created_at=now,
|
|
last_append_at=now,
|
|
)
|
|
else:
|
|
new_state = ChainState(
|
|
chain_id=state.chain_id,
|
|
head_index=chain_index,
|
|
head_hash=record_hash,
|
|
record_count=state.record_count + 1,
|
|
created_at=state.created_at,
|
|
last_append_at=now,
|
|
)
|
|
self._save_state(new_state)
|
|
|
|
# Keep the offset index consistent so subsequent get() /
|
|
# iter_records() calls on this instance remain O(1).
|
|
if self._offsets is not None:
|
|
self._offsets[chain_index] = new_record_offset
|
|
finally:
|
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
except OSError as e:
|
|
raise ChainAppendError(f"Failed to write to chain: {e}") from e
|
|
|
|
return record
|
|
|
|
def append_key_rotation(
|
|
self,
|
|
old_private_key: Ed25519PrivateKey,
|
|
new_private_key: Ed25519PrivateKey,
|
|
) -> AttestationChainRecord:
|
|
"""Record a key rotation event in the chain.
|
|
|
|
The rotation record is signed by the OLD key and carries the new
|
|
public key in metadata. This creates a cryptographic trust chain:
|
|
anyone who trusts the old key can verify the transition to the new one.
|
|
|
|
Args:
|
|
old_private_key: The current (soon-to-be-archived) signing key.
|
|
new_private_key: The newly generated signing key.
|
|
|
|
Returns:
|
|
The rotation record appended to the chain.
|
|
"""
|
|
new_pub = new_private_key.public_key()
|
|
new_pub_bytes = new_pub.public_bytes(Encoding.Raw, PublicFormat.Raw)
|
|
|
|
# Content hash is the SHA-256 of the new public key
|
|
content_hash = hashlib.sha256(new_pub_bytes).digest()
|
|
|
|
return self.append(
|
|
content_hash=content_hash,
|
|
content_type=CONTENT_TYPE_KEY_ROTATION,
|
|
private_key=old_private_key,
|
|
metadata={"new_pubkey": new_pub_bytes.hex()},
|
|
)
|
|
|
|
def verify_chain(self, start: int = 0, end: int | None = None) -> bool:
|
|
"""Verify hash chain integrity and signatures over a range.
|
|
|
|
Args:
|
|
start: First record index to verify (default 0).
|
|
end: Last record index to verify (default: head).
|
|
|
|
Returns:
|
|
True if the chain is valid.
|
|
|
|
Raises:
|
|
ChainIntegrityError: If any integrity check fails.
|
|
"""
|
|
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
|
|
|
|
prev_record: AttestationChainRecord | None = None
|
|
expected_index = start
|
|
authorized_signers: set[bytes] = set()
|
|
|
|
# If starting from 0, first record must have genesis prev_hash
|
|
if start > 0:
|
|
# Load the record before start to check the first prev_hash
|
|
try:
|
|
prev_record = self.get(start - 1)
|
|
except ChainError:
|
|
pass # Can't verify prev_hash of first record in range
|
|
|
|
for record in self.iter_records(start, end):
|
|
# Check index continuity
|
|
if record.chain_index != expected_index:
|
|
raise ChainIntegrityError(
|
|
f"Expected index {expected_index}, got {record.chain_index}"
|
|
)
|
|
|
|
# Check prev_hash linkage
|
|
if prev_record is not None:
|
|
expected_hash = compute_record_hash(prev_record)
|
|
if record.prev_hash != expected_hash:
|
|
raise ChainIntegrityError(f"Record {record.chain_index}: prev_hash mismatch")
|
|
elif record.chain_index == 0:
|
|
if record.prev_hash != ChainState.GENESIS_PREV_HASH:
|
|
raise ChainIntegrityError("Genesis record has non-zero prev_hash")
|
|
|
|
# Check signature
|
|
try:
|
|
pub = Ed25519PublicKey.from_public_bytes(record.signer_pubkey)
|
|
pub.verify(record.signature, canonical_bytes(record))
|
|
except Exception as e:
|
|
raise ChainIntegrityError(
|
|
f"Record {record.chain_index}: signature verification failed: {e}"
|
|
) from e
|
|
|
|
# Track authorized signers: the genesis signer plus any keys
|
|
# introduced by valid key-rotation records.
|
|
if not authorized_signers:
|
|
authorized_signers.add(record.signer_pubkey)
|
|
elif record.signer_pubkey not in authorized_signers:
|
|
raise ChainIntegrityError(
|
|
f"Record {record.chain_index}: signer "
|
|
f"{record.signer_pubkey.hex()[:16]}... is not authorized"
|
|
)
|
|
|
|
# If this is a key rotation record, authorize the new key
|
|
if record.content_type == CONTENT_TYPE_KEY_ROTATION:
|
|
new_pubkey_hex = record.metadata.get("new_pubkey")
|
|
if not new_pubkey_hex:
|
|
raise ChainIntegrityError(
|
|
f"Record {record.chain_index}: key rotation missing new_pubkey"
|
|
)
|
|
authorized_signers.add(bytes.fromhex(new_pubkey_hex))
|
|
# Revoke the old key — the rotation record was its last authorized action
|
|
authorized_signers.discard(record.signer_pubkey)
|
|
|
|
prev_record = record
|
|
expected_index += 1
|
|
|
|
return True
|