Add pet ID model trainer with MobileNetV3-Small transfer learning

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee 2026-04-03 13:19:12 -04:00
parent c7f9304f2a
commit e48ba305ea
2 changed files with 196 additions and 0 deletions

View File

@ -0,0 +1,55 @@
"""Tests for pet ID model trainer."""
import os
from pathlib import Path
import numpy as np
import pytest
from vigilar.detection.trainer import PetTrainer, TrainingStatus
class TestTrainingStatus:
    """Checks on the default values of a freshly created TrainingStatus."""

    def test_initial_status(self):
        """A new status reports no training in progress, zero progress, no error."""
        fresh = TrainingStatus()

        assert fresh.error is None
        assert fresh.progress == 0.0
        assert fresh.is_training is False
class TestPetTrainer:
    """Filesystem-backed checks for PetTrainer's directory inspection methods."""

    @staticmethod
    def _trainer_for(root):
        """Build a trainer rooted at *root*, with its model file inside *root*."""
        return PetTrainer(
            training_dir=str(root),
            model_output_path=str(root / "model.pt"),
        )

    @staticmethod
    def _add_pet(root, name, image_count):
        """Create ``root/name`` and fill it with *image_count* placeholder JPEGs."""
        folder = root / name
        folder.mkdir()
        for index in range(image_count):
            (folder / f"{index}.jpg").write_bytes(b"fake")

    def test_check_readiness_no_pets(self, tmp_path):
        """An empty training directory is reported as not ready."""
        ready, msg = self._trainer_for(tmp_path).check_readiness(min_images=20)

        assert not ready
        assert "No pet" in msg

    def test_check_readiness_insufficient_images(self, tmp_path):
        """A pet with too few images is named in the failure message."""
        self._add_pet(tmp_path, "angel", 5)

        ready, msg = self._trainer_for(tmp_path).check_readiness(min_images=20)

        assert not ready
        assert "angel" in msg.lower()

    def test_check_readiness_sufficient_images(self, tmp_path):
        """Every pet meeting the minimum image count makes the trainer ready."""
        for pet_name in ("angel", "taquito"):
            self._add_pet(tmp_path, pet_name, 25)

        ready, _ = self._trainer_for(tmp_path).check_readiness(min_images=20)

        assert ready

    def test_get_class_names(self, tmp_path):
        """Class names are the pet directory names in sorted order."""
        for pet_name in ("angel", "milo", "taquito"):
            self._add_pet(tmp_path, pet_name, 0)

        names = self._trainer_for(tmp_path).get_class_names()

        assert names == ["angel", "milo", "taquito"]  # sorted

View File

@ -0,0 +1,141 @@
"""Pet ID model trainer using MobileNetV3-Small with transfer learning."""
import logging
import shutil
from dataclasses import dataclass, field
from pathlib import Path
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
@dataclass
class TrainingStatus:
    """Mutable snapshot of a training run, exposed as ``PetTrainer.status``.

    ``PetTrainer.train`` replaces this object when a run starts and updates
    it after every epoch.
    """

    # True while train() is running; set back to False when it finishes or fails.
    is_training: bool = False
    # Fraction of epochs completed, computed as (epoch + 1) / epochs.
    progress: float = 0.0
    # 1-based number of the most recently completed epoch (0 before the first).
    epoch: int = 0
    # Number of epochs requested for the current run.
    total_epochs: int = 0
    # correct / total over the training batches of the most recent epoch.
    accuracy: float = 0.0
    # Failure message from the last run, or None when no error was recorded.
    error: str | None = None
class PetTrainer:
    """Fine-tune a MobileNetV3-Small classifier on per-pet image folders.

    The training directory holds one sub-directory per pet; each
    sub-directory name is a class label. Directories whose names start with
    "." are ignored. Progress and failures are published on ``self.status``.
    """

    # File suffixes (lower-case) that check_readiness() counts as images.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")
    # Zero-based epoch at which the frozen backbone is unfrozen for fine-tuning.
    _UNFREEZE_EPOCH = 5

    def __init__(self, training_dir: str, model_output_path: str):
        """Remember where training images live and where the model is written.

        Args:
            training_dir: Directory containing one sub-directory per pet.
            model_output_path: File path the trained model is saved to.
        """
        self._training_dir = Path(training_dir)
        self._model_output_path = Path(model_output_path)
        self.status = TrainingStatus()

    def get_class_names(self) -> list[str]:
        """Return the sorted names of the non-hidden pet sub-directories.

        Returns an empty list when the training path is missing or is not a
        directory.
        """
        # is_dir() also covers the "path is a regular file" case, where
        # iterdir() would raise NotADirectoryError.
        if not self._training_dir.is_dir():
            return []
        return sorted(
            entry.name
            for entry in self._training_dir.iterdir()
            if entry.is_dir() and not entry.name.startswith(".")
        )

    def check_readiness(self, min_images: int = 20) -> tuple[bool, str]:
        """Check that every pet directory holds at least *min_images* images.

        Only regular files directly inside a pet directory with a ``.jpg``,
        ``.jpeg`` or ``.png`` suffix (case-insensitive) are counted.

        Returns:
            ``(ready, message)`` where *message* explains the verdict.
        """
        class_names = self.get_class_names()
        if not class_names:
            return False, "No pet directories found in training directory."

        insufficient = []
        for name in class_names:
            pet_dir = self._training_dir / name
            image_count = sum(
                1
                for f in pet_dir.iterdir()
                # Require a real file so a sub-directory called e.g. "x.png"
                # is not mistaken for an image.
                if f.is_file() and f.suffix.lower() in self._IMAGE_SUFFIXES
            )
            if image_count < min_images:
                insufficient.append(f"{name}: {image_count}/{min_images}")

        if insufficient:
            return False, f"Insufficient training images: {', '.join(insufficient)}"
        return True, f"Ready to train with {len(class_names)} classes."

    def _backup_existing_model(self) -> None:
        """Copy a previously saved model to ``<stem>.backup.pt`` if one exists."""
        if not self._model_output_path.exists():
            return
        backup_path = self._model_output_path.with_suffix(".backup.pt")
        shutil.copy2(self._model_output_path, backup_path)
        log.info("Backed up previous model to %s", backup_path)

    def train(self, epochs: int = 30, batch_size: int = 16) -> bool:
        """Train the classifier and save it to the configured output path.

        The pretrained backbone stays frozen for the first epochs so only the
        new classifier head is trained; from epoch index ``_UNFREEZE_EPOCH``
        onward the whole network is fine-tuned at a lower learning rate.

        Args:
            epochs: Number of passes over the training set (must be >= 1).
            batch_size: Mini-batch size for the data loader (must be >= 1).

        Returns:
            True when a model was trained and saved; False otherwise, with the
            reason stored in ``self.status.error``. Exceptions are logged and
            reported through the status rather than raised.

        Note:
            ``status.accuracy`` is measured on the training batches; no
            validation split is held out.
        """
        # Reject degenerate arguments up front: with epochs <= 0 the loop body
        # never runs and an untrained model would be saved as a "success".
        if epochs < 1 or batch_size < 1:
            self.status = TrainingStatus(
                error="epochs and batch_size must be at least 1."
            )
            return False

        # Reset before the lazy imports so an import failure is recorded on
        # this run's status rather than on a stale one.
        self.status = TrainingStatus(is_training=True, total_epochs=epochs)
        try:
            # Heavy dependencies are imported lazily so the module can be
            # imported (and the readiness checks used) without them.
            import torch
            import torch.nn as nn
            from torch.utils.data import DataLoader
            from torchvision import datasets, models, transforms

            class_names = self.get_class_names()
            num_classes = len(class_names)
            if num_classes < 2:
                self.status.error = "Need at least 2 pets to train."
                return False

            train_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                # ImageNet channel statistics, matching the pretrained weights.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
            dataset = datasets.ImageFolder(str(self._training_dir), transform=train_transform)
            # ImageFolder discovers classes on its own and does not skip
            # hidden directories; a mismatch would shift label indices past
            # the size of the classifier head, so fail with a clear message.
            if list(dataset.classes) != class_names:
                raise RuntimeError(
                    "Training directory classes do not match pet directories: "
                    f"dataset has {list(dataset.classes)}, expected {class_names}"
                )
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

            model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
            # Swap the final layer for one sized to the number of pets.
            model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = model.to(device)

            # Phase 1: freeze the backbone and train only the classifier head.
            for param in model.features.parameters():
                param.requires_grad = False
            optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)
            criterion = nn.CrossEntropyLoss()

            for epoch in range(epochs):
                model.train()
                running_loss = 0.0
                correct = 0
                total = 0

                if epoch == self._UNFREEZE_EPOCH:
                    # Phase 2: unfreeze everything and fine-tune more gently.
                    for param in model.features.parameters():
                        param.requires_grad = True
                    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

                for inputs, labels in loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                accuracy = correct / total if total > 0 else 0.0
                self.status.epoch = epoch + 1
                self.status.progress = (epoch + 1) / epochs
                self.status.accuracy = accuracy
                log.info("Epoch %d/%d — loss: %.4f, accuracy: %.4f",
                         epoch + 1, epochs, running_loss / len(loader), accuracy)

            self._model_output_path.parent.mkdir(parents=True, exist_ok=True)
            self._backup_existing_model()

            model = model.to("cpu")
            # Write to a sibling temp file and rename it into place so a
            # failed save never leaves a truncated model at the final path.
            # NOTE: the whole model object is saved (not a state_dict), as the
            # loader is expected to read that format.
            tmp_path = self._model_output_path.with_name(
                self._model_output_path.name + ".tmp"
            )
            try:
                torch.save(model, tmp_path)
                tmp_path.replace(self._model_output_path)
            finally:
                tmp_path.unlink(missing_ok=True)
            log.info("Pet ID model saved to %s (accuracy: %.2f%%)",
                     self._model_output_path, self.status.accuracy * 100)
            return True
        except Exception as e:
            # Boundary handler: training typically runs in the background, so
            # failures are logged and surfaced through status, not raised.
            log.exception("Training failed: %s", e)
            self.status.error = str(e)
            return False
        finally:
            self.status.is_training = False