diff --git a/tests/unit/test_pet_queries.py b/tests/unit/test_pet_queries.py new file mode 100644 index 0000000..8890394 --- /dev/null +++ b/tests/unit/test_pet_queries.py @@ -0,0 +1,101 @@ +"""Tests for pet and wildlife query functions.""" + +import time + +from vigilar.storage.queries import ( + insert_pet, + get_pet, + get_all_pets, + insert_pet_sighting, + get_pet_sightings, + get_pet_last_location, + insert_wildlife_sighting, + get_wildlife_sightings, + insert_training_image, + get_training_images, + get_unlabeled_sightings, + label_sighting, +) + + +class TestPetCRUD: + def test_insert_and_get_pet(self, test_db): + pet_id = insert_pet(test_db, name="Angel", species="cat", breed="DSH", + color_description="black") + pet = get_pet(test_db, pet_id) + assert pet is not None + assert pet["name"] == "Angel" + assert pet["species"] == "cat" + assert pet["training_count"] == 0 + + def test_get_all_pets(self, test_db): + insert_pet(test_db, name="Angel", species="cat") + insert_pet(test_db, name="Milo", species="dog") + all_pets = get_all_pets(test_db) + assert len(all_pets) == 2 + + +class TestPetSightings: + def test_insert_and_query(self, test_db): + pet_id = insert_pet(test_db, name="Angel", species="cat") + insert_pet_sighting(test_db, pet_id=pet_id, species="cat", + camera_id="kitchen", confidence=0.92) + sightings = get_pet_sightings(test_db, limit=10) + assert len(sightings) == 1 + assert sightings[0]["camera_id"] == "kitchen" + + def test_last_location(self, test_db): + pet_id = insert_pet(test_db, name="Angel", species="cat") + insert_pet_sighting(test_db, pet_id=pet_id, species="cat", + camera_id="kitchen", confidence=0.9) + insert_pet_sighting(test_db, pet_id=pet_id, species="cat", + camera_id="living_room", confidence=0.95) + loc = get_pet_last_location(test_db, pet_id) + assert loc is not None + assert loc["camera_id"] == "living_room" + + def test_unlabeled_sightings(self, test_db): + insert_pet_sighting(test_db, pet_id=None, species="cat", + camera_id="kitchen", confidence=0.6, crop_path="/tmp/crop.jpg") + unlabeled = get_unlabeled_sightings(test_db, limit=10) + assert len(unlabeled) == 1 + assert unlabeled[0]["labeled"] == 0 + + def test_label_sighting(self, test_db): + pet_id = insert_pet(test_db, name="Angel", species="cat") + insert_pet_sighting(test_db, pet_id=None, species="cat", + camera_id="kitchen", confidence=0.6) + sightings = get_unlabeled_sightings(test_db, limit=10) + sighting_id = sightings[0]["id"] + label_sighting(test_db, sighting_id, pet_id) + updated = get_pet_sightings(test_db, pet_id=pet_id) + assert len(updated) == 1 + assert updated[0]["labeled"] == 1 + + +class TestWildlifeSightings: + def test_insert_and_query(self, test_db): + insert_wildlife_sighting(test_db, species="bear", threat_level="PREDATOR", + camera_id="front", confidence=0.88) + sightings = get_wildlife_sightings(test_db, limit=10) + assert len(sightings) == 1 + assert sightings[0]["threat_level"] == "PREDATOR" + + def test_filter_by_threat(self, test_db): + insert_wildlife_sighting(test_db, species="bear", threat_level="PREDATOR", + camera_id="front", confidence=0.88) + insert_wildlife_sighting(test_db, species="deer", threat_level="PASSIVE", + camera_id="back", confidence=0.75) + predators = get_wildlife_sightings(test_db, threat_level="PREDATOR") + assert len(predators) == 1 + + +class TestTrainingImages: + def test_insert_and_query(self, test_db): + pet_id = insert_pet(test_db, name="Angel", species="cat") + insert_training_image(test_db, pet_id=pet_id, + image_path="/var/vigilar/pets/training/angel/001.jpg", + source="upload") + images = get_training_images(test_db, pet_id) + assert len(images) == 1 + assert images[0]["source"] == "upload" diff --git a/vigilar/storage/queries.py b/vigilar/storage/queries.py index 1f87050..001f8a4 100644 --- a/vigilar/storage/queries.py +++ b/vigilar/storage/queries.py @@ -11,10 +11,14 @@ from vigilar.storage.schema import ( alert_log, arm_state_log, events, + pet_sightings, + pet_training_images, + pets, push_subscriptions, recordings, sensor_states, system_events, + wildlife_sightings, ) @@ -271,3 +275,186 @@ def delete_push_subscription(engine: Engine, endpoint: str) -> bool: push_subscriptions.delete().where(push_subscriptions.c.endpoint == endpoint) ) return result.rowcount > 0 + + +# --- Pets --- + +def insert_pet( + engine: Engine, + name: str, + species: str, + breed: str | None = None, + color_description: str | None = None, + photo_path: str | None = None, +) -> str: + import uuid + pet_id = str(uuid.uuid4()) + with engine.begin() as conn: + conn.execute(pets.insert().values( + id=pet_id, name=name, species=species, breed=breed, + color_description=color_description, photo_path=photo_path, + training_count=0, created_at=time.time(), + )) + return pet_id + + +def get_pet(engine: Engine, pet_id: str) -> dict[str, Any] | None: + with engine.connect() as conn: + row = conn.execute(pets.select().where(pets.c.id == pet_id)).first() + return dict(row._mapping) if row else None + + +def get_all_pets(engine: Engine) -> list[dict[str, Any]]: + with engine.connect() as conn: + rows = conn.execute(pets.select().order_by(pets.c.name)).fetchall() + return [dict(r._mapping) for r in rows] + + +# --- Pet Sightings --- + +def insert_pet_sighting( + engine: Engine, + species: str, + camera_id: str, + confidence: float, + pet_id: str | None = None, + crop_path: str | None = None, + event_id: int | None = None, +) -> int: + with engine.begin() as conn: + result = conn.execute(pet_sightings.insert().values( + ts=time.time(), pet_id=pet_id, species=species, + camera_id=camera_id, confidence=confidence, + crop_path=crop_path, labeled=1 if pet_id else 0, + event_id=event_id, + )) + return result.inserted_primary_key[0] + + +def get_pet_sightings( + engine: Engine, + pet_id: str | None = None, + camera_id: str | None = None, + since_ts: float | None = None, + limit: int = 100, +) -> list[dict[str, Any]]: + query = select(pet_sightings).order_by(desc(pet_sightings.c.ts)).limit(limit) + if pet_id: + query = query.where(pet_sightings.c.pet_id == pet_id) + if camera_id: + query = query.where(pet_sightings.c.camera_id == camera_id) + if since_ts: + query = query.where(pet_sightings.c.ts >= since_ts) + with engine.connect() as conn: + rows = conn.execute(query).fetchall() + return [dict(r._mapping) for r in rows] + + +def get_pet_last_location(engine: Engine, pet_id: str) -> dict[str, Any] | None: + with engine.connect() as conn: + row = conn.execute( + select(pet_sightings) + .where(pet_sightings.c.pet_id == pet_id) + .order_by(desc(pet_sightings.c.ts)) + .limit(1) + ).first() + return dict(row._mapping) if row else None + + +def get_unlabeled_sightings( + engine: Engine, + species: str | None = None, + limit: int = 50, +) -> list[dict[str, Any]]: + query = ( + select(pet_sightings) + .where(pet_sightings.c.labeled == 0) + .order_by(desc(pet_sightings.c.ts)) + .limit(limit) + ) + if species: + query = query.where(pet_sightings.c.species == species) + with engine.connect() as conn: + rows = conn.execute(query).fetchall() + return [dict(r._mapping) for r in rows] + + +def label_sighting(engine: Engine, sighting_id: int, pet_id: str) -> None: + with engine.begin() as conn: + conn.execute( + pet_sightings.update() + .where(pet_sightings.c.id == sighting_id) + .values(pet_id=pet_id, labeled=1) + ) + + +# --- Wildlife Sightings --- + +def insert_wildlife_sighting( + engine: Engine, + species: str, + threat_level: str, + camera_id: str, + confidence: float, + crop_path: str | None = None, + event_id: int | None = None, +) -> int: + with engine.begin() as conn: + result = conn.execute(wildlife_sightings.insert().values( + ts=time.time(), species=species, threat_level=threat_level, + camera_id=camera_id, confidence=confidence, + crop_path=crop_path, event_id=event_id, + )) + return result.inserted_primary_key[0] + + +def get_wildlife_sightings( + engine: Engine, + threat_level: str | None = None, + camera_id: str | None = None, + since_ts: float | None = None, + limit: int = 100, +) -> list[dict[str, Any]]: + query = select(wildlife_sightings).order_by(desc(wildlife_sightings.c.ts)).limit(limit) + if threat_level: + query = query.where(wildlife_sightings.c.threat_level == threat_level) + if camera_id: + query = query.where(wildlife_sightings.c.camera_id == camera_id) + if since_ts: + query = query.where(wildlife_sightings.c.ts >= since_ts) + with engine.connect() as conn: + rows = conn.execute(query).fetchall() + return [dict(r._mapping) for r in rows] + + +# --- Training Images --- + +def insert_training_image( + engine: Engine, + pet_id: str, + image_path: str, + source: str, +) -> int: + with engine.begin() as conn: + result = conn.execute(pet_training_images.insert().values( + pet_id=pet_id, image_path=image_path, + source=source, created_at=time.time(), + )) + conn.execute( + pets.update().where(pets.c.id == pet_id) + .values(training_count=pets.c.training_count + 1) + ) + return result.inserted_primary_key[0] + + +def get_training_images( + engine: Engine, + pet_id: str, +) -> list[dict[str, Any]]: + with engine.connect() as conn: + rows = conn.execute( + select(pet_training_images) + .where(pet_training_images.c.pet_id == pet_id) + .order_by(desc(pet_training_images.c.created_at)) + ).fetchall() + return [dict(r._mapping) for r in rows]