133 lines
4.2 KiB
Python
133 lines
4.2 KiB
Python
"""Pet identification classifier using MobileNetV3-Small."""
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
# Logger shared by everything in this module.
log = logging.getLogger(__name__)

# Minimum confidence for a match to be reported at all.
DEFAULT_LOW_THRESHOLD = 0.5
# Matches at or above this confidence are not flagged as low-confidence.
DEFAULT_HIGH_THRESHOLD = 0.7
|
|
|
|
|
|
@dataclass
class PetIDResult:
    """Outcome of a single pet-identification attempt.

    ``pet_id`` and ``pet_name`` are None when no registered pet was matched.
    The two thresholds are stored on the result itself so that the
    convenience properties can classify it without any outside context.
    """

    pet_id: str | None
    pet_name: str | None
    confidence: float
    # Confidence at or above which a match is not considered low-confidence.
    high_threshold: float = DEFAULT_HIGH_THRESHOLD
    # Minimum confidence for a match to count as identified.
    low_threshold: float = DEFAULT_LOW_THRESHOLD

    @property
    def is_identified(self) -> bool:
        """True when a pet is present and ``confidence >= low_threshold``."""
        if self.pet_id is None:
            return False
        return self.confidence >= self.low_threshold

    @property
    def is_low_confidence(self) -> bool:
        """True when a pet is present and confidence is in ``[low, high)``."""
        if self.pet_id is None:
            return False
        return self.low_threshold <= self.confidence < self.high_threshold
|
|
|
|
|
|
@dataclass
class RegisteredPet:
    """A pet known to the classifier, tied to one model output index."""

    # Identifier reported back in PetIDResult.pet_id.
    pet_id: str
    # Display name reported back in PetIDResult.pet_name.
    name: str
    # Species label; only pets of the queried species are candidates.
    species: str
    # Position of this pet in the model's output probability vector.
    class_index: int
|
|
|
|
|
|
class PetIDClassifier:
    """Identify which registered pet appears in an image crop.

    Wraps a PyTorch classifier loaded from ``model_path``. Each registered pet
    is assigned the next model output index in registration order, so pets
    must be registered in the same order as the model's output classes.

    Loading is best-effort: if the model file is missing or fails to load, the
    object is still usable but :meth:`identify` always reports "no match".
    """

    # (width, height) the crop is resized to before inference.
    _INPUT_SIZE = (224, 224)
    # Per-channel RGB normalisation statistics (the common ImageNet values).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(
        self,
        model_path: str,
        high_threshold: float = DEFAULT_HIGH_THRESHOLD,
        low_threshold: float = DEFAULT_LOW_THRESHOLD,
    ):
        """Store configuration and try to load the model from ``model_path``.

        Args:
            model_path: Path to a pickled model object (callable and exposing
                ``.eval()``), not a bare state dict.
            high_threshold: Confidence at or above which a match is not
                flagged as low-confidence.
            low_threshold: Minimum confidence for a match to be reported.
        """
        self._model_path = model_path
        self._high_threshold = high_threshold
        self._low_threshold = low_threshold
        self._model = None
        # Public flag: True only after the model loaded and was put in eval mode.
        self.is_loaded = False
        self._pets: list[RegisteredPet] = []
        # torchvision preprocessing pipeline, built lazily on first inference.
        self._transform = None

        if not Path(model_path).exists():
            log.info(
                "Pet ID model not found at %s — identification disabled until trained",
                model_path,
            )
            return

        try:
            import torch

            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # and can execute code embedded in the file. Only load model files
            # that come from a trusted source (e.g. this project's training).
            self._model = torch.load(model_path, map_location="cpu", weights_only=False)
            self._model.eval()
            self.is_loaded = True
            log.info("Pet ID model loaded from %s", model_path)
        except Exception as e:
            # Best-effort: stay usable with identification disabled instead of
            # crashing the caller. Drop any half-initialised model object.
            self._model = None
            log.exception("Failed to load pet ID model: %s", e)

    @property
    def pet_count(self) -> int:
        """Number of registered pets."""
        return len(self._pets)

    def register_pet(self, pet_id: str, name: str, species: str) -> None:
        """Register a pet and map it to the next model output index."""
        idx = len(self._pets)
        self._pets.append(
            RegisteredPet(pet_id=pet_id, name=name, species=species, class_index=idx)
        )

    def get_pets_by_species(self, species: str) -> list[RegisteredPet]:
        """Return registered pets whose species equals ``species``."""
        return [p for p in self._pets if p.species == species]

    def identify(self, crop: np.ndarray, species: str) -> PetIDResult:
        """Identify which registered pet of ``species`` appears in ``crop``.

        Args:
            crop: Image crop in OpenCV channel order: 3-channel BGR,
                4-channel BGRA, or single-channel grayscale.
            species: Only pets registered with this species are candidates.

        Returns:
            A PetIDResult for the best candidate if its softmax probability is
            at least the low threshold. Otherwise ``pet_id``/``pet_name`` are
            None and ``confidence`` is the best candidate probability seen
            (0.0 when inference did not run or failed). The result always
            carries this classifier's thresholds.
        """
        if not self.is_loaded or self._model is None:
            return self._no_match()

        candidates = self.get_pets_by_species(species)
        if not candidates:
            return self._no_match()

        # An empty crop cannot be resized; report "no match" rather than
        # logging it as an inference failure.
        if crop is None or np.size(crop) == 0:
            return self._no_match()

        try:
            import torch

            tensor = self._preprocess(crop)
            with torch.no_grad():
                output = self._model(tensor)
                probs = torch.softmax(output, dim=1)[0]

            best_pet, best_conf = self._best_candidate(probs, candidates)
        except Exception as e:
            # Best-effort: a bad frame or model error must not crash the caller.
            log.error("Pet ID inference failed: %s", e)
            return self._no_match()

        if best_pet is not None and best_conf >= self._low_threshold:
            return self._make_result(best_pet.pet_id, best_pet.name, best_conf)
        return self._no_match(best_conf)

    def _preprocess(self, crop: np.ndarray):
        """Convert an OpenCV crop into a normalised, batched (1, 3, H, W) tensor."""
        import cv2

        resized = cv2.resize(crop, self._INPUT_SIZE)
        # Pick the conversion matching the channel layout; BGR2RGB alone
        # rejects grayscale and BGRA input.
        if resized.ndim == 2 or resized.shape[2] == 1:
            rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
        elif resized.shape[2] == 4:
            rgb = cv2.cvtColor(resized, cv2.COLOR_BGRA2RGB)
        else:
            rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        return self._get_transform()(rgb).unsqueeze(0)

    def _get_transform(self):
        """Return the torchvision preprocessing pipeline, building it once."""
        if self._transform is None:
            from torchvision import transforms

            self._transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            ])
        return self._transform

    @staticmethod
    def _best_candidate(probs, candidates):
        """Return ``(pet, probability)`` for the highest-probability candidate.

        Candidates whose class index is beyond the model's output size are
        skipped. Returns ``(None, 0.0)`` if no candidate scores above zero.
        """
        best_pet = None
        best_conf = 0.0
        num_classes = len(probs)
        for pet in candidates:
            if pet.class_index >= num_classes:
                continue
            conf = float(probs[pet.class_index])
            if conf > best_conf:
                best_pet, best_conf = pet, conf
        return best_pet, best_conf

    def _make_result(self, pet_id, pet_name, confidence: float) -> PetIDResult:
        """Build a PetIDResult that carries this classifier's thresholds."""
        return PetIDResult(
            pet_id=pet_id,
            pet_name=pet_name,
            confidence=confidence,
            high_threshold=self._high_threshold,
            low_threshold=self._low_threshold,
        )

    def _no_match(self, confidence: float = 0.0) -> PetIDResult:
        """Build a "no pet identified" result with the given confidence."""
        return self._make_result(None, None, confidence)
|