"""Security-focused tests for the attestation chain. Tests concurrent access, oversized records, and edge cases that could compromise chain integrity. """ from __future__ import annotations import hashlib import struct import threading from pathlib import Path import pytest from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from fieldwitness.exceptions import ChainError from fieldwitness.federation.chain import MAX_RECORD_SIZE, ChainStore def test_concurrent_append_no_fork(chain_dir: Path): """Concurrent appends must not fork the chain — indices must be unique.""" private_key = Ed25519PrivateKey.generate() num_threads = 8 records_per_thread = 5 results: list[list] = [[] for _ in range(num_threads)] errors: list[Exception] = [] def worker(thread_id: int): try: store = ChainStore(chain_dir) for i in range(records_per_thread): content = hashlib.sha256(f"t{thread_id}-r{i}".encode()).digest() record = store.append(content, "test/plain", private_key) results[thread_id].append(record.chain_index) except Exception as e: errors.append(e) threads = [threading.Thread(target=worker, args=(t,)) for t in range(num_threads)] for t in threads: t.start() for t in threads: t.join() assert not errors, f"Thread errors: {errors}" # Collect all indices all_indices = [] for r in results: all_indices.extend(r) # Every index must be unique (no fork) assert len(all_indices) == len(set(all_indices)), ( f"Duplicate chain indices detected — chain forked! " f"Indices: {sorted(all_indices)}" ) # Indices should be 0..N-1 contiguous total = num_threads * records_per_thread assert sorted(all_indices) == list(range(total)) # Full chain verification should pass store = ChainStore(chain_dir) assert store.verify_chain() is True assert store.state().record_count == total def test_oversized_record_rejected(chain_dir: Path): """A corrupted length prefix exceeding MAX_RECORD_SIZE must raise ChainError.""" chain_file = chain_dir / "chain.bin" # Write a length prefix claiming a 100 MB record bogus_length = 100 * 1024 * 1024 chain_file.write_bytes(struct.pack(">I", bogus_length) + b"\x00" * 100) store = ChainStore(chain_dir) with pytest.raises(ChainError, match="exceeds maximum"): list(store._iter_raw()) def test_max_record_size_boundary(chain_dir: Path): """Records at exactly MAX_RECORD_SIZE should be rejected (real records are <1KB).""" chain_file = chain_dir / "chain.bin" chain_file.write_bytes(struct.pack(">I", MAX_RECORD_SIZE + 1) + b"\x00" * 100) store = ChainStore(chain_dir) with pytest.raises(ChainError, match="exceeds maximum"): list(store._iter_raw()) def test_truncated_chain_file(chain_dir: Path, private_key: Ed25519PrivateKey): """A truncated chain.bin still yields complete records before the truncation.""" store = ChainStore(chain_dir) for i in range(3): store.append(hashlib.sha256(f"c-{i}".encode()).digest(), "test/plain", private_key) # Truncate the file mid-record chain_file = chain_dir / "chain.bin" data = chain_file.read_bytes() chain_file.write_bytes(data[: len(data) - 50]) store2 = ChainStore(chain_dir) records = list(store2._iter_raw()) # Should get at least the first 2 complete records assert len(records) >= 2 assert records[0].chain_index == 0 assert records[1].chain_index == 1 def test_empty_chain_file(chain_dir: Path): """An empty chain.bin (0 bytes) yields no records without error.""" chain_file = chain_dir / "chain.bin" chain_file.write_bytes(b"") store = ChainStore(chain_dir) records = list(store._iter_raw()) assert records == [] def test_concurrent_read_during_write(chain_dir: Path): """Reading the chain while appending should not crash.""" private_key = Ed25519PrivateKey.generate() store = ChainStore(chain_dir) # Seed with some records for i in range(5): store.append(hashlib.sha256(f"seed-{i}".encode()).digest(), "test/plain", private_key) read_errors: list[Exception] = [] write_errors: list[Exception] = [] def reader(): try: s = ChainStore(chain_dir) for _ in range(20): list(s.iter_records()) except Exception as e: read_errors.append(e) def writer(): try: s = ChainStore(chain_dir) for i in range(10): s.append(hashlib.sha256(f"w-{i}".encode()).digest(), "test/plain", private_key) except Exception as e: write_errors.append(e) threads = [ threading.Thread(target=reader), threading.Thread(target=reader), threading.Thread(target=writer), ] for t in threads: t.start() for t in threads: t.join() assert not read_errors, f"Read errors during concurrent access: {read_errors}" assert not write_errors, f"Write errors during concurrent access: {write_errors}"