Add pet ID classifier with species-filtered identification

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee 2026-04-03 13:18:42 -04:00
parent 13b7c2a219
commit c7f9304f2a
2 changed files with 181 additions and 0 deletions

49
tests/unit/test_pet_id.py Normal file
View File

@ -0,0 +1,49 @@
"""Tests for pet ID classifier."""
import numpy as np
import pytest
from vigilar.detection.pet_id import PetIDClassifier, PetIDResult
class TestPetIDResult:
def test_identified(self):
r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.9)
assert r.is_identified
assert not r.is_low_confidence
def test_low_confidence(self):
r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.6)
assert r.is_identified
assert r.is_low_confidence
def test_unknown(self):
r = PetIDResult(pet_id=None, pet_name=None, confidence=0.3)
assert not r.is_identified
class TestPetIDClassifier:
def test_not_loaded_returns_unknown(self):
classifier = PetIDClassifier(model_path="nonexistent.pt")
assert not classifier.is_loaded
crop = np.zeros((224, 224, 3), dtype=np.uint8)
result = classifier.identify(crop, species="cat")
assert not result.is_identified
def test_no_pets_registered_returns_unknown(self):
classifier = PetIDClassifier(model_path="nonexistent.pt")
assert classifier.pet_count == 0
def test_register_pet(self):
classifier = PetIDClassifier(model_path="nonexistent.pt")
classifier.register_pet("pet-1", "Angel", "cat")
classifier.register_pet("pet-2", "Milo", "dog")
assert classifier.pet_count == 2
def test_species_filter(self):
classifier = PetIDClassifier(model_path="nonexistent.pt")
classifier.register_pet("pet-1", "Angel", "cat")
classifier.register_pet("pet-2", "Taquito", "cat")
classifier.register_pet("pet-3", "Milo", "dog")
assert len(classifier.get_pets_by_species("cat")) == 2
assert len(classifier.get_pets_by_species("dog")) == 1

132
vigilar/detection/pet_id.py Normal file
View File

@ -0,0 +1,132 @@
"""Pet identification classifier using MobileNetV3-Small."""
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
log = logging.getLogger(__name__)
DEFAULT_HIGH_THRESHOLD = 0.7
DEFAULT_LOW_THRESHOLD = 0.5
@dataclass
class PetIDResult:
pet_id: str | None
pet_name: str | None
confidence: float
high_threshold: float = DEFAULT_HIGH_THRESHOLD
low_threshold: float = DEFAULT_LOW_THRESHOLD
@property
def is_identified(self) -> bool:
return self.pet_id is not None and self.confidence >= self.low_threshold
@property
def is_low_confidence(self) -> bool:
return (
self.pet_id is not None
and self.low_threshold <= self.confidence < self.high_threshold
)
@dataclass
class RegisteredPet:
pet_id: str
name: str
species: str
class_index: int
class PetIDClassifier:
def __init__(
self,
model_path: str,
high_threshold: float = DEFAULT_HIGH_THRESHOLD,
low_threshold: float = DEFAULT_LOW_THRESHOLD,
):
self._model_path = model_path
self._high_threshold = high_threshold
self._low_threshold = low_threshold
self._model = None
self.is_loaded = False
self._pets: list[RegisteredPet] = []
if Path(model_path).exists():
try:
import torch
self._model = torch.load(model_path, map_location="cpu", weights_only=False)
self._model.eval()
self.is_loaded = True
log.info("Pet ID model loaded from %s", model_path)
except Exception as e:
log.error("Failed to load pet ID model: %s", e)
else:
log.info(
"Pet ID model not found at %s — identification disabled until trained",
model_path,
)
@property
def pet_count(self) -> int:
return len(self._pets)
def register_pet(self, pet_id: str, name: str, species: str) -> None:
idx = len(self._pets)
self._pets.append(RegisteredPet(pet_id=pet_id, name=name, species=species,
class_index=idx))
def get_pets_by_species(self, species: str) -> list[RegisteredPet]:
return [p for p in self._pets if p.species == species]
def identify(self, crop: np.ndarray, species: str) -> PetIDResult:
if not self.is_loaded or self._model is None:
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
candidates = self.get_pets_by_species(species)
if not candidates:
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
try:
import cv2
import torch
from torchvision import transforms
resized = cv2.resize(crop, (224, 224))
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
tensor = transform(rgb).unsqueeze(0)
with torch.no_grad():
output = self._model(tensor)
probs = torch.softmax(output, dim=1)[0]
best_conf = 0.0
best_pet = None
for pet in candidates:
if pet.class_index < len(probs):
conf = float(probs[pet.class_index])
if conf > best_conf:
best_conf = conf
best_pet = pet
if best_pet and best_conf >= self._low_threshold:
return PetIDResult(
pet_id=best_pet.pet_id,
pet_name=best_pet.name,
confidence=best_conf,
high_threshold=self._high_threshold,
low_threshold=self._low_threshold,
)
return PetIDResult(pet_id=None, pet_name=None, confidence=best_conf)
except Exception as e:
log.error("Pet ID inference failed: %s", e)
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)