Add pet ID classifier with species-filtered identification
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
13b7c2a219
commit
c7f9304f2a
49
tests/unit/test_pet_id.py
Normal file
49
tests/unit/test_pet_id.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Tests for pet ID classifier."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vigilar.detection.pet_id import PetIDClassifier, PetIDResult
|
||||
|
||||
|
||||
class TestPetIDResult:
|
||||
def test_identified(self):
|
||||
r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.9)
|
||||
assert r.is_identified
|
||||
assert not r.is_low_confidence
|
||||
|
||||
def test_low_confidence(self):
|
||||
r = PetIDResult(pet_id="pet-1", pet_name="Angel", confidence=0.6)
|
||||
assert r.is_identified
|
||||
assert r.is_low_confidence
|
||||
|
||||
def test_unknown(self):
|
||||
r = PetIDResult(pet_id=None, pet_name=None, confidence=0.3)
|
||||
assert not r.is_identified
|
||||
|
||||
|
||||
class TestPetIDClassifier:
|
||||
def test_not_loaded_returns_unknown(self):
|
||||
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||
assert not classifier.is_loaded
|
||||
crop = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||
result = classifier.identify(crop, species="cat")
|
||||
assert not result.is_identified
|
||||
|
||||
def test_no_pets_registered_returns_unknown(self):
|
||||
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||
assert classifier.pet_count == 0
|
||||
|
||||
def test_register_pet(self):
|
||||
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||
classifier.register_pet("pet-1", "Angel", "cat")
|
||||
classifier.register_pet("pet-2", "Milo", "dog")
|
||||
assert classifier.pet_count == 2
|
||||
|
||||
def test_species_filter(self):
|
||||
classifier = PetIDClassifier(model_path="nonexistent.pt")
|
||||
classifier.register_pet("pet-1", "Angel", "cat")
|
||||
classifier.register_pet("pet-2", "Taquito", "cat")
|
||||
classifier.register_pet("pet-3", "Milo", "dog")
|
||||
assert len(classifier.get_pets_by_species("cat")) == 2
|
||||
assert len(classifier.get_pets_by_species("dog")) == 1
|
||||
132
vigilar/detection/pet_id.py
Normal file
132
vigilar/detection/pet_id.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""Pet identification classifier using MobileNetV3-Small."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HIGH_THRESHOLD = 0.7
|
||||
DEFAULT_LOW_THRESHOLD = 0.5
|
||||
|
||||
|
||||
@dataclass
|
||||
class PetIDResult:
|
||||
pet_id: str | None
|
||||
pet_name: str | None
|
||||
confidence: float
|
||||
high_threshold: float = DEFAULT_HIGH_THRESHOLD
|
||||
low_threshold: float = DEFAULT_LOW_THRESHOLD
|
||||
|
||||
@property
|
||||
def is_identified(self) -> bool:
|
||||
return self.pet_id is not None and self.confidence >= self.low_threshold
|
||||
|
||||
@property
|
||||
def is_low_confidence(self) -> bool:
|
||||
return (
|
||||
self.pet_id is not None
|
||||
and self.low_threshold <= self.confidence < self.high_threshold
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredPet:
|
||||
pet_id: str
|
||||
name: str
|
||||
species: str
|
||||
class_index: int
|
||||
|
||||
|
||||
class PetIDClassifier:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
high_threshold: float = DEFAULT_HIGH_THRESHOLD,
|
||||
low_threshold: float = DEFAULT_LOW_THRESHOLD,
|
||||
):
|
||||
self._model_path = model_path
|
||||
self._high_threshold = high_threshold
|
||||
self._low_threshold = low_threshold
|
||||
self._model = None
|
||||
self.is_loaded = False
|
||||
self._pets: list[RegisteredPet] = []
|
||||
|
||||
if Path(model_path).exists():
|
||||
try:
|
||||
import torch
|
||||
self._model = torch.load(model_path, map_location="cpu", weights_only=False)
|
||||
self._model.eval()
|
||||
self.is_loaded = True
|
||||
log.info("Pet ID model loaded from %s", model_path)
|
||||
except Exception as e:
|
||||
log.error("Failed to load pet ID model: %s", e)
|
||||
else:
|
||||
log.info(
|
||||
"Pet ID model not found at %s — identification disabled until trained",
|
||||
model_path,
|
||||
)
|
||||
|
||||
@property
|
||||
def pet_count(self) -> int:
|
||||
return len(self._pets)
|
||||
|
||||
def register_pet(self, pet_id: str, name: str, species: str) -> None:
|
||||
idx = len(self._pets)
|
||||
self._pets.append(RegisteredPet(pet_id=pet_id, name=name, species=species,
|
||||
class_index=idx))
|
||||
|
||||
def get_pets_by_species(self, species: str) -> list[RegisteredPet]:
|
||||
return [p for p in self._pets if p.species == species]
|
||||
|
||||
def identify(self, crop: np.ndarray, species: str) -> PetIDResult:
|
||||
if not self.is_loaded or self._model is None:
|
||||
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
|
||||
|
||||
candidates = self.get_pets_by_species(species)
|
||||
if not candidates:
|
||||
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
resized = cv2.resize(crop, (224, 224))
|
||||
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
tensor = transform(rgb).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = self._model(tensor)
|
||||
probs = torch.softmax(output, dim=1)[0]
|
||||
|
||||
best_conf = 0.0
|
||||
best_pet = None
|
||||
for pet in candidates:
|
||||
if pet.class_index < len(probs):
|
||||
conf = float(probs[pet.class_index])
|
||||
if conf > best_conf:
|
||||
best_conf = conf
|
||||
best_pet = pet
|
||||
|
||||
if best_pet and best_conf >= self._low_threshold:
|
||||
return PetIDResult(
|
||||
pet_id=best_pet.pet_id,
|
||||
pet_name=best_pet.name,
|
||||
confidence=best_conf,
|
||||
high_threshold=self._high_threshold,
|
||||
low_threshold=self._low_threshold,
|
||||
)
|
||||
|
||||
return PetIDResult(pet_id=None, pet_name=None, confidence=best_conf)
|
||||
|
||||
except Exception as e:
|
||||
log.error("Pet ID inference failed: %s", e)
|
||||
return PetIDResult(pet_id=None, pet_name=None, confidence=0.0)
|
||||
Loading…
Reference in New Issue
Block a user