Add pet ID classifier with species-filtered identification
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
13b7c2a219
commit
c7f9304f2a
49
tests/unit/test_pet_id.py
Normal file
49
tests/unit/test_pet_id.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
"""Tests for pet ID classifier."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vigilar.detection.pet_id import PetIDClassifier, PetIDResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestPetIDResult:
|
||||||
|
def test_identified(self):
|
||||||
|
r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.9)
|
||||||
|
assert r.is_identified
|
||||||
|
assert not r.is_low_confidence
|
||||||
|
|
||||||
|
def test_low_confidence(self):
|
||||||
|
r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.6)
|
||||||
|
assert r.is_identified
|
||||||
|
assert r.is_low_confidence
|
||||||
|
|
||||||
|
def test_unknown(self):
|
||||||
|
r = PetIDResult(pet_id=None, pet_name=None, confidence=0.3)
|
||||||
|
assert not r.is_identified
|
||||||
|
|
||||||
|
|
||||||
|
class TestPetIDClassifier:
|
||||||
|
def test_not_loaded_returns_unknown(self):
|
||||||
|
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||||
|
assert not classifier.is_loaded
|
||||||
|
crop = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||||
|
result = classifier.identify(crop, species="cat")
|
||||||
|
assert not result.is_identified
|
||||||
|
|
||||||
|
def test_no_pets_registered_returns_unknown(self):
|
||||||
|
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||||
|
assert classifier.pet_count == 0
|
||||||
|
|
||||||
|
def test_register_pet(self):
|
||||||
|
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||||
|
classifier.register_pet("pet-1", "Angel", "cat")
|
||||||
|
classifier.register_pet("pet-2", "Milo", "dog")
|
||||||
|
assert classifier.pet_count == 2
|
||||||
|
|
||||||
|
def test_species_filter(self):
|
||||||
|
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||||
|
classifier.register_pet("pet-1", "Angel", "cat")
|
||||||
|
classifier.register_pet("pet-2", "Taquito", "cat")
|
||||||
|
classifier.register_pet("pet-3", "Milo", "dog")
|
||||||
|
assert len(classifier.get_pets_by_species("cat")) == 2
|
||||||
|
assert len(classifier.get_pets_by_species("dog")) == 1
|
||||||
132
vigilar/detection/pet_id.py
Normal file
132
vigilar/detection/pet_id.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
"""Pet identification classifier using MobileNetV3-Small."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_HIGH_THRESHOLD = 0.7
|
||||||
|
DEFAULT_LOW_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PetIDResult:
|
||||||
|
pet_id: str | None
|
||||||
|
pet_name: str | None
|
||||||
|
confidence: float
|
||||||
|
high_threshold: float = DEFAULT_HIGH_THRESHOLD
|
||||||
|
low_threshold: float = DEFAULT_LOW_THRESHOLD
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_identified(self) -> bool:
|
||||||
|
return self.pet_id is not None and self.confidence >= self.low_threshold
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_low_confidence(self) -> bool:
|
||||||
|
return (
|
||||||
|
self.pet_id is not None
|
||||||
|
and self.low_threshold <= self.confidence < self.high_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RegisteredPet:
|
||||||
|
pet_id: str
|
||||||
|
name: str
|
||||||
|
species: str
|
||||||
|
class_index: int
|
||||||
|
|
||||||
|
|
||||||
|
class PetIDClassifier:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
high_threshold: float = DEFAULT_HIGH_THRESHOLD,
|
||||||
|
low_threshold: float = DEFAULT_LOW_THRESHOLD,
|
||||||
|
):
|
||||||
|
self._model_path = model_path
|
||||||
|
self._high_threshold = high_threshold
|
||||||
|
self._low_threshold = low_threshold
|
||||||
|
self._model = None
|
||||||
|
self.is_loaded = False
|
||||||
|
self._pets: list[RegisteredPet] = []
|
||||||
|
|
||||||
|
if Path(model_path).exists():
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
self._model = torch.load(model_path, map_location="cpu", weights_only=False)
|
||||||
|
self._model.eval()
|
||||||
|
self.is_loaded = True
|
||||||
|
log.info("Pet ID model loaded from %s", model_path)
|
||||||
|
except Exception as e:
|
||||||
|
log.error("Failed to load pet ID model: %s", e)
|
||||||
|
else:
|
||||||
|
log.info(
|
||||||
|
"Pet ID model not found at %s — identification disabled until trained",
|
||||||
|
model_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pet_count(self) -> int:
|
||||||
|
return len(self._pets)
|
||||||
|
|
||||||
|
def register_pet(self, pet_id: str, name: str, species: str) -> None:
|
||||||
|
idx = len(self._pets)
|
||||||
|
self._pets.append(RegisteredPet(pet_id=pet_id, name=name, species=species,
|
||||||
|
class_index=idx))
|
||||||
|
|
||||||
|
def get_pets_by_species(self, species: str) -> list[RegisteredPet]:
|
||||||
|
return [p for p in self._pets if p.species == species]
|
||||||
|
|
||||||
|
def identify(self, crop: np.ndarray, species: str) -> PetIDResult:
|
||||||
|
if not self.is_loaded or self._model is None:
|
||||||
|
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
|
||||||
|
|
||||||
|
candidates = self.get_pets_by_species(species)
|
||||||
|
if not candidates:
|
||||||
|
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
resized = cv2.resize(crop, (224, 224))
|
||||||
|
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
tensor = transform(rgb).unsqueeze(0)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = self._model(tensor)
|
||||||
|
probs = torch.softmax(output, dim=1)[0]
|
||||||
|
|
||||||
|
best_conf = 0.0
|
||||||
|
best_pet = None
|
||||||
|
for pet in candidates:
|
||||||
|
if pet.class_index < len(probs):
|
||||||
|
conf = float(probs[pet.class_index])
|
||||||
|
if conf > best_conf:
|
||||||
|
best_conf = conf
|
||||||
|
best_pet = pet
|
||||||
|
|
||||||
|
if best_pet and best_conf >= self._low_threshold:
|
||||||
|
return PetIDResult(
|
||||||
|
pet_id=best_pet.pet_id,
|
||||||
|
pet_name=best_pet.name,
|
||||||
|
confidence=best_conf,
|
||||||
|
high_threshold=self._high_threshold,
|
||||||
|
low_threshold=self._low_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
return PetIDResult(pet_id=None, pet_name=None, confidence=best_conf)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error("Pet ID inference failed: %s", e)
|
||||||
|
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
|
||||||
Loading…
Reference in New Issue
Block a user