diff --git a/tests/unit/test_visitor_queries.py b/tests/unit/test_visitor_queries.py new file mode 100644 index 0000000..1db29b4 --- /dev/null +++ b/tests/unit/test_visitor_queries.py @@ -0,0 +1,57 @@ +import time +import pytest +from vigilar.storage.queries import ( + create_face_profile, get_face_profile, get_all_profiles, update_face_profile, + delete_face_profile_cascade, insert_face_embedding, get_embeddings_for_profile, + insert_visit, get_visits, get_active_visits, +) + + +def test_create_and_get_profile(test_db): + now = time.time() + pid = create_face_profile(test_db, first_seen_at=now, last_seen_at=now) + assert pid > 0 + profile = get_face_profile(test_db, pid) + assert profile is not None + assert profile["name"] is None + + +def test_get_all_profiles(test_db): + now = time.time() + create_face_profile(test_db, first_seen_at=now, last_seen_at=now) + create_face_profile(test_db, name="Bob", first_seen_at=now, last_seen_at=now) + assert len(get_all_profiles(test_db)) == 2 + + +def test_update_profile(test_db): + pid = create_face_profile(test_db, first_seen_at=0, last_seen_at=0) + update_face_profile(test_db, pid, name="Alice") + assert get_face_profile(test_db, pid)["name"] == "Alice" + + +def test_insert_embedding(test_db): + pid = create_face_profile(test_db, first_seen_at=0, last_seen_at=0) + eid = insert_face_embedding(test_db, pid, "AAAA", "front", time.time()) + assert eid > 0 + assert len(get_embeddings_for_profile(test_db, pid)) == 1 + + +def test_insert_and_get_visit(test_db): + pid = create_face_profile(test_db, first_seen_at=0, last_seen_at=0) + insert_visit(test_db, pid, "front", time.time()) + assert len(get_visits(test_db)) == 1 + + +def test_get_active_visits(test_db): + pid = create_face_profile(test_db, first_seen_at=0, last_seen_at=0) + insert_visit(test_db, pid, "front", time.time()) + assert len(get_active_visits(test_db)) == 1 + + +def test_delete_cascade(test_db): + pid = create_face_profile(test_db, first_seen_at=0, last_seen_at=0) + insert_face_embedding(test_db, pid, "AAAA", "front", time.time()) + insert_visit(test_db, pid, "front", time.time()) + delete_face_profile_cascade(test_db, pid) + assert get_face_profile(test_db, pid) is None + assert len(get_embeddings_for_profile(test_db, pid)) == 0 diff --git a/vigilar/storage/queries.py b/vigilar/storage/queries.py index c606769..9fa815e 100644 --- a/vigilar/storage/queries.py +++ b/vigilar/storage/queries.py @@ -595,3 +595,99 @@ def count_pet_rules(engine, pet_id) -> int: return conn.execute( select(func.count()).select_from(pet_rules).where(pet_rules.c.pet_id == pet_id) ).scalar() or 0 + + +# --- Face Profiles --- + +def create_face_profile(engine, name=None, first_seen_at=0, last_seen_at=0) -> int: + from vigilar.storage.schema import face_profiles + with engine.begin() as conn: + result = conn.execute(face_profiles.insert().values( + name=name, is_household=0, visit_count=0, + first_seen_at=first_seen_at, last_seen_at=last_seen_at, + ignored=0, created_at=time.time())) + return result.inserted_primary_key[0] + + +def get_face_profile(engine, profile_id) -> dict | None: + from vigilar.storage.schema import face_profiles + with engine.connect() as conn: + row = conn.execute(select(face_profiles).where( + face_profiles.c.id == profile_id)).mappings().first() + return dict(row) if row else None + + +def get_all_profiles(engine, named_only=False, include_ignored=True) -> list[dict]: + from vigilar.storage.schema import face_profiles + query = select(face_profiles).order_by(desc(face_profiles.c.last_seen_at)) + if named_only: + query = query.where(face_profiles.c.name.isnot(None)) + if not include_ignored: + query = query.where(face_profiles.c.ignored == 0) + with engine.connect() as conn: + return [dict(r) for r in conn.execute(query).mappings().all()] + + +def update_face_profile(engine, profile_id, **updates) -> None: + from vigilar.storage.schema import face_profiles + with engine.begin() as conn: + conn.execute(face_profiles.update().where( + face_profiles.c.id == profile_id).values(**updates)) + + +def delete_face_profile_cascade(engine, profile_id) -> None: + from vigilar.storage.schema import face_embeddings, face_profiles, visits + with engine.begin() as conn: + conn.execute(face_embeddings.delete().where(face_embeddings.c.profile_id == profile_id)) + conn.execute(visits.delete().where(visits.c.profile_id == profile_id)) + conn.execute(face_profiles.delete().where(face_profiles.c.id == profile_id)) + + +# --- Face Embeddings --- + +def insert_face_embedding( + engine, profile_id, embedding_b64, camera_id, captured_at, crop_path=None +) -> int: + from vigilar.storage.schema import face_embeddings + with engine.begin() as conn: + result = conn.execute(face_embeddings.insert().values( + profile_id=profile_id, embedding=embedding_b64, + crop_path=crop_path, camera_id=camera_id, captured_at=captured_at)) + return result.inserted_primary_key[0] + + +def get_embeddings_for_profile(engine, profile_id) -> list[dict]: + from vigilar.storage.schema import face_embeddings + with engine.connect() as conn: + return [dict(r) for r in conn.execute( + select(face_embeddings).where( + face_embeddings.c.profile_id == profile_id)).mappings().all()] + + +# --- Visits --- + +def insert_visit(engine, profile_id, camera_id, arrived_at, event_id=None) -> int: + from vigilar.storage.schema import visits + with engine.begin() as conn: + result = conn.execute(visits.insert().values( + profile_id=profile_id, camera_id=camera_id, + arrived_at=arrived_at, event_id=event_id)) + return result.inserted_primary_key[0] + + +def get_visits(engine, profile_id=None, camera_id=None, limit=50) -> list[dict]: + from vigilar.storage.schema import visits + query = select(visits).order_by(desc(visits.c.arrived_at)).limit(limit) + if profile_id: + query = query.where(visits.c.profile_id == profile_id) + if camera_id: + query = query.where(visits.c.camera_id == camera_id) + with engine.connect() as conn: + return [dict(r) for r in conn.execute(query).mappings().all()] + + +def get_active_visits(engine) -> list[dict]: + from vigilar.storage.schema import visits + with engine.connect() as conn: + return [dict(r) for r in conn.execute( + select(visits).where(visits.c.departed_at.is_(None))).mappings().all()]