Add pet ID model trainer with MobileNetV3-Small transfer learning
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c7f9304f2a
commit
e48ba305ea
55
tests/unit/test_trainer.py
Normal file
55
tests/unit/test_trainer.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Tests for pet ID model trainer."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vigilar.detection.trainer import PetTrainer, TrainingStatus
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainingStatus:
    """Unit tests for the TrainingStatus dataclass defaults."""

    def test_initial_status(self):
        """A freshly constructed status reports an idle, error-free state."""
        fresh = TrainingStatus()
        assert fresh.error is None
        assert fresh.progress == 0.0
        assert fresh.is_training is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestPetTrainer:
    """Unit tests for PetTrainer readiness checks and class discovery."""

    def _build(self, root):
        """Construct a PetTrainer rooted at *root*, writing its model inside it."""
        return PetTrainer(training_dir=str(root), model_output_path=str(root / "model.pt"))

    def test_check_readiness_no_pets(self, tmp_path):
        ready, msg = self._build(tmp_path).check_readiness(min_images=20)
        assert not ready
        assert "No pet" in msg

    def test_check_readiness_insufficient_images(self, tmp_path):
        angel_dir = tmp_path / "angel"
        angel_dir.mkdir()
        for idx in range(5):
            (angel_dir / f"{idx}.jpg").write_bytes(b"fake")

        ready, msg = self._build(tmp_path).check_readiness(min_images=20)
        assert not ready
        assert "angel" in msg.lower()

    def test_check_readiness_sufficient_images(self, tmp_path):
        for pet in ("angel", "taquito"):
            pet_dir = tmp_path / pet
            pet_dir.mkdir()
            for idx in range(25):
                (pet_dir / f"{idx}.jpg").write_bytes(b"fake")

        ready, _msg = self._build(tmp_path).check_readiness(min_images=20)
        assert ready

    def test_get_class_names(self, tmp_path):
        for pet in ("angel", "milo", "taquito"):
            (tmp_path / pet).mkdir()

        names = self._build(tmp_path).get_class_names()
        assert names == ["angel", "milo", "taquito"]  # sorted
|
||||||
141
vigilar/detection/trainer.py
Normal file
141
vigilar/detection/trainer.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
"""Pet ID model trainer using MobileNetV3-Small with transfer learning."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainingStatus:
    """Snapshot of a training run's progress, polled by callers via PetTrainer.status."""

    # True while PetTrainer.train() is actively running.
    is_training: bool = False
    # Fraction of requested epochs completed, in [0.0, 1.0].
    progress: float = 0.0
    # 1-based index of the most recently completed epoch (0 before training starts).
    epoch: int = 0
    # Total number of epochs requested for the current run.
    total_epochs: int = 0
    # Accuracy after the last completed epoch (measured on the training set).
    accuracy: float = 0.0
    # Human-readable failure message, or None if no error has occurred.
    error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PetTrainer:
    """Trains a pet-identification classifier from per-pet image directories.

    The training directory is expected to hold one subdirectory per pet, each
    containing that pet's images. Training fine-tunes an ImageNet-pretrained
    MobileNetV3-Small (torchvision) via transfer learning and writes the
    resulting model to ``model_output_path``. Progress is exposed through the
    mutable ``self.status`` object.
    """

    def __init__(self, training_dir: str, model_output_path: str):
        """Store paths; no filesystem access or ML imports happen here.

        Args:
            training_dir: Root directory containing one subdirectory per pet.
            model_output_path: File path where the trained model is saved.
        """
        self._training_dir = Path(training_dir)
        self._model_output_path = Path(model_output_path)
        # Mutable status object polled by callers to observe progress/errors.
        self.status = TrainingStatus()

    def get_class_names(self) -> list[str]:
        """Return sorted pet (class) names, one per non-hidden subdirectory.

        Returns an empty list when the training directory does not exist.
        """
        if not self._training_dir.exists():
            return []
        return sorted([
            d.name for d in self._training_dir.iterdir()
            # Skip files and dot-prefixed (hidden) directories.
            if d.is_dir() and not d.name.startswith(".")
        ])

    def check_readiness(self, min_images: int = 20) -> tuple[bool, str]:
        """Check whether every pet directory has enough images to train.

        Args:
            min_images: Minimum image count required per pet.

        Returns:
            (ready, message) — ``ready`` is True only when at least one pet
            directory exists and every pet has >= ``min_images`` images.
        """
        class_names = self.get_class_names()
        if not class_names:
            return False, "No pet directories found in training directory."

        insufficient = []
        for name in class_names:
            pet_dir = self._training_dir / name
            # Count only common image extensions; other files are ignored.
            image_count = sum(1 for f in pet_dir.iterdir()
                              if f.suffix.lower() in (".jpg", ".jpeg", ".png"))
            if image_count < min_images:
                insufficient.append(f"{name}: {image_count}/{min_images}")

        if insufficient:
            return False, f"Insufficient training images: {', '.join(insufficient)}"

        return True, f"Ready to train with {len(class_names)} classes."

    def train(self, epochs: int = 30, batch_size: int = 16) -> bool:
        """Run two-phase transfer-learning training; return True on success.

        Phase 1 (epochs 0-4) trains only the classifier head with the
        convolutional backbone frozen; phase 2 (epoch 5 onward) unfreezes the
        backbone and fine-tunes the whole network at a lower learning rate.
        On failure the exception text is stored in ``self.status.error`` and
        False is returned; this method does not raise.

        Args:
            epochs: Total number of training epochs.
            batch_size: Mini-batch size for the DataLoader.
        """
        try:
            # Imported lazily: torch/torchvision are heavyweight dependencies,
            # and an ImportError here is reported via status.error below.
            import torch
            import torch.nn as nn
            from torch.utils.data import DataLoader
            from torchvision import datasets, models, transforms

            # Fresh status for this run; replaces any previous run's status.
            self.status = TrainingStatus(is_training=True, total_epochs=epochs)
            class_names = self.get_class_names()
            num_classes = len(class_names)

            # A classifier needs at least two classes to be meaningful.
            if num_classes < 2:
                self.status.error = "Need at least 2 pets to train."
                self.status.is_training = False
                return False

            # Augmentation plus the ImageNet normalization statistics that the
            # pretrained MobileNetV3 weights were trained with.
            train_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

            # NOTE(review): ImageFolder derives labels from ALL subdirectories,
            # while get_class_names() skips dot-prefixed ones — confirm hidden
            # directories cannot appear here, or the label ordering may differ.
            dataset = datasets.ImageFolder(str(self._training_dir), transform=train_transform)
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

            # Start from ImageNet-pretrained weights and replace the final
            # classifier layer with one emitting a logit per pet.
            model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
            model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = model.to(device)

            # Phase 1: freeze the convolutional backbone so only the classifier
            # head receives gradient updates during warm-up.
            for param in model.features.parameters():
                param.requires_grad = False

            optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)
            criterion = nn.CrossEntropyLoss()

            for epoch in range(epochs):
                model.train()
                running_loss = 0.0
                correct = 0
                total = 0

                # Phase 2: after 5 warm-up epochs, unfreeze the backbone and
                # rebuild the optimizer over ALL parameters at a lower LR.
                if epoch == 5:
                    for param in model.features.parameters():
                        param.requires_grad = True
                    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

                for inputs, labels in loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                # NOTE(review): accuracy is computed on the training set itself
                # (no held-out validation split), so it overestimates real-world
                # performance.
                accuracy = correct / total if total > 0 else 0
                self.status.epoch = epoch + 1
                self.status.progress = (epoch + 1) / epochs
                self.status.accuracy = accuracy
                log.info("Epoch %d/%d — loss: %.4f, accuracy: %.4f",
                         epoch + 1, epochs, running_loss / len(loader), accuracy)

            # Keep a backup of any previously trained model before overwriting,
            # e.g. "model.pt" -> "model.backup.pt".
            if self._model_output_path.exists():
                backup_path = self._model_output_path.with_suffix(".backup.pt")
                shutil.copy2(self._model_output_path, backup_path)
                log.info("Backed up previous model to %s", backup_path)

            # Move to CPU before saving so the file loads on GPU-less machines.
            # NOTE(review): this pickles the entire model object rather than a
            # state_dict, and does not persist the class-name ordering — confirm
            # the inference-side loader expects exactly this format.
            model = model.to("cpu")
            torch.save(model, self._model_output_path)
            log.info("Pet ID model saved to %s (accuracy: %.2f%%)",
                     self._model_output_path, self.status.accuracy * 100)

            self.status.is_training = False
            return True

        except Exception as e:
            # Broad catch is deliberate: a training failure (missing torch, bad
            # images, OOM, ...) must surface through status.error rather than
            # crash the caller.
            log.exception("Training failed: %s", e)
            self.status.error = str(e)
            self.status.is_training = False
            return False
|
||||||
Loading…
Reference in New Issue
Block a user