diff --git a/tests/unit/test_pet_id.py b/tests/unit/test_pet_id.py new file mode 100644 index 0000000..afefe02 --- /dev/null +++ b/tests/unit/test_pet_id.py @@ -0,0 +1,49 @@ +"""Tests for pet ID classifier.""" + +import numpy as np +import pytest + +from vigilar.detection.pet_id import PetIDClassifier, PetIDResult + + +class TestPetIDResult: + def test_identified(self): + r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.9) + assert r.is_identified + assert not r.is_low_confidence + + def test_low_confidence(self): + r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.6) + assert r.is_identified + assert r.is_low_confidence + + def test_unknown(self): + r = PetIDResult(pet_id=None, pet_name=None, confidence=0.3) + assert not r.is_identified + + +class TestPetIDClassifier: + def test_not_loaded_returns_unknown(self): + classifier = PetIDClassifier(model_path="nonexistent.pt") + assert not classifier.is_loaded + crop = np.zeros((224, 224, 3), dtype=np.uint8) + result = classifier.identify(crop, species="cat") + assert not result.is_identified + + def test_no_pets_registered_returns_unknown(self): + classifier = PetIDClassifier(model_path="nonexistent.pt") + assert classifier.pet_count == 0 + + def test_register_pet(self): + classifier = PetIDClassifier(model_path="nonexistent.pt") + classifier.register_pet("pet-1", "Angel", "cat") + classifier.register_pet("pet-2", "Milo", "dog") + assert classifier.pet_count == 2 + + def test_species_filter(self): + classifier = PetIDClassifier(model_path="nonexistent.pt") + classifier.register_pet("pet-1", "Angel", "cat") + classifier.register_pet("pet-2", "Taquito", "cat") + classifier.register_pet("pet-3", "Milo", "dog") + assert len(classifier.get_pets_by_species("cat")) == 2 + assert len(classifier.get_pets_by_species("dog")) == 1 diff --git a/vigilar/detection/pet_id.py b/vigilar/detection/pet_id.py new file mode 100644 index 0000000..f758561 --- /dev/null +++ b/vigilar/detection/pet_id.py @@ -0,0 +1,132 @@ +"""Pet identification classifier using MobileNetV3-Small.""" + +import logging +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + +log = logging.getLogger(__name__) + +DEFAULT_HIGH_THRESHOLD = 0.7 +DEFAULT_LOW_THRESHOLD = 0.5 + + +@dataclass +class PetIDResult: + pet_id: str | None + pet_name: str | None + confidence: float + high_threshold: float = DEFAULT_HIGH_THRESHOLD + low_threshold: float = DEFAULT_LOW_THRESHOLD + + @property + def is_identified(self) -> bool: + return self.pet_id is not None and self.confidence >= self.low_threshold + + @property + def is_low_confidence(self) -> bool: + return ( + self.pet_id is not None + and self.low_threshold <= self.confidence < self.high_threshold + ) + + +@dataclass +class RegisteredPet: + pet_id: str + name: str + species: str + class_index: int + + +class PetIDClassifier: + def __init__( + self, + model_path: str, + high_threshold: float = DEFAULT_HIGH_THRESHOLD, + low_threshold: float = DEFAULT_LOW_THRESHOLD, + ): + self._model_path = model_path + self._high_threshold = high_threshold + self._low_threshold = low_threshold + self._model = None + self.is_loaded = False + self._pets: list[RegisteredPet] = [] + + if Path(model_path).exists(): + try: + import torch + self._model = torch.load(model_path, map_location="cpu", weights_only=False) + self._model.eval() + self.is_loaded = True + log.info("Pet ID model loaded from %s", model_path) + except Exception as e: + log.error("Failed to load pet ID model: %s", e) + else: + log.info( + "Pet ID model not found at %s — identification disabled until trained", + model_path, + ) + + @property + def pet_count(self) -> int: + return len(self._pets) + + def register_pet(self, pet_id: str, name: str, species: str) -> None: + idx = len(self._pets) + self._pets.append(RegisteredPet(pet_id=pet_id, name=name, species=species, + class_index=idx)) + + def get_pets_by_species(self, species: str) -> list[RegisteredPet]: + return [p for p in self._pets if p.species == species] + + def identify(self, crop: np.ndarray, species: str) -> PetIDResult: + if not self.is_loaded or self._model is None: + return PetIDResult(pet_id=None, pet_name=None, confidence=0.0) + + candidates = self.get_pets_by_species(species) + if not candidates: + return PetIDResult(pet_id=None, pet_name=None, confidence=0.0) + + try: + import cv2 + import torch + from torchvision import transforms + + resized = cv2.resize(crop, (224, 224)) + rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + tensor = transform(rgb).unsqueeze(0) + + with torch.no_grad(): + output = self._model(tensor) + probs = torch.softmax(output, dim=1)[0] + + best_conf = 0.0 + best_pet = None + for pet in candidates: + if pet.class_index < len(probs): + conf = float(probs[pet.class_index]) + if conf > best_conf: + best_conf = conf + best_pet = pet + + if best_pet and best_conf >= self._low_threshold: + return PetIDResult( + pet_id=best_pet.pet_id, + pet_name=best_pet.name, + confidence=best_conf, + high_threshold=self._high_threshold, + low_threshold=self._low_threshold, + ) + + return PetIDResult(pet_id=None, pet_name=None, confidence=best_conf) + + except Exception as e: + log.error("Pet ID inference failed: %s", e) + return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)