# diff --git a/tests/unit/test_trainer.py b/tests/unit/test_trainer.py new file mode 100644 index 0000000..aad7dd7 --- /dev/null +++ b/tests/unit/test_trainer.py @@ -0,0 +1,55 @@
"""Tests for pet ID model trainer."""

import os
from pathlib import Path

import numpy as np
import pytest

from vigilar.detection.trainer import PetTrainer, TrainingStatus


def _new_trainer(root):
    """Build a PetTrainer rooted at *root*, writing its model inside *root*."""
    return PetTrainer(
        training_dir=str(root),
        model_output_path=str(root / "model.pt"),
    )


def _add_pet_images(root, pet_name, count):
    """Create ``root/pet_name`` and fill it with *count* placeholder .jpg files."""
    folder = root / pet_name
    folder.mkdir()
    for index in range(count):
        (folder / f"{index}.jpg").write_bytes(b"fake")


class TestTrainingStatus:
    """A freshly constructed TrainingStatus reports an idle, error-free state."""

    def test_initial_status(self):
        fresh = TrainingStatus()
        assert fresh.is_training is False
        assert fresh.progress == 0.0
        assert fresh.error is None


class TestPetTrainer:
    """Readiness checks and class discovery against a temporary directory."""

    def test_check_readiness_no_pets(self, tmp_path):
        # An empty training directory has no pet sub-directories at all.
        ready, msg = _new_trainer(tmp_path).check_readiness(min_images=20)
        assert not ready
        assert "No pet" in msg

    def test_check_readiness_insufficient_images(self, tmp_path):
        # 5 images is below the 20-image threshold; the pet must be named.
        _add_pet_images(tmp_path, "angel", 5)

        ready, msg = _new_trainer(tmp_path).check_readiness(min_images=20)
        assert not ready
        assert "angel" in msg.lower()

    def test_check_readiness_sufficient_images(self, tmp_path):
        # Every pet has 25 images, which clears the 20-image threshold.
        for pet_name in ("angel", "taquito"):
            _add_pet_images(tmp_path, pet_name, 25)

        ready, _ = _new_trainer(tmp_path).check_readiness(min_images=20)
        assert ready

    def test_get_class_names(self, tmp_path):
        for pet_name in ("angel", "milo", "taquito"):
            (tmp_path / pet_name).mkdir()

        discovered = _new_trainer(tmp_path).get_class_names()
        # Class names are expected in sorted order.
        assert discovered == ["angel", "milo", "taquito"]

# diff --git a/vigilar/detection/trainer.py b/vigilar/detection/trainer.py new file mode 100644
index 0000000..5a68fcc --- /dev/null +++ b/vigilar/detection/trainer.py @@ -0,0 +1,141 @@ +"""Pet ID model trainer using MobileNetV3-Small with transfer learning.""" + +import logging +import shutil +from dataclasses import dataclass, field +from pathlib import Path + +log = logging.getLogger(__name__) + + +@dataclass +class TrainingStatus: + is_training: bool = False + progress: float = 0.0 + epoch: int = 0 + total_epochs: int = 0 + accuracy: float = 0.0 + error: str | None = None + + +class PetTrainer: + def __init__(self, training_dir: str, model_output_path: str): + self._training_dir = Path(training_dir) + self._model_output_path = Path(model_output_path) + self.status = TrainingStatus() + + def get_class_names(self) -> list[str]: + if not self._training_dir.exists(): + return [] + return sorted([ + d.name for d in self._training_dir.iterdir() + if d.is_dir() and not d.name.startswith(".") + ]) + + def check_readiness(self, min_images: int = 20) -> tuple[bool, str]: + class_names = self.get_class_names() + if not class_names: + return False, "No pet directories found in training directory." + + insufficient = [] + for name in class_names: + pet_dir = self._training_dir / name + image_count = sum(1 for f in pet_dir.iterdir() + if f.suffix.lower() in (".jpg", ".jpeg", ".png")) + if image_count < min_images: + insufficient.append(f"{name}: {image_count}/{min_images}") + + if insufficient: + return False, f"Insufficient training images: {', '.join(insufficient)}" + + return True, f"Ready to train with {len(class_names)} classes." + + def train(self, epochs: int = 30, batch_size: int = 16) -> bool: + try: + import torch + import torch.nn as nn + from torch.utils.data import DataLoader + from torchvision import datasets, models, transforms + + self.status = TrainingStatus(is_training=True, total_epochs=epochs) + class_names = self.get_class_names() + num_classes = len(class_names) + + if num_classes < 2: + self.status.error = "Need at least 2 pets to train." 
+ self.status.is_training = False + return False + + train_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + dataset = datasets.ImageFolder(str(self._training_dir), transform=train_transform) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2) + + model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT) + model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + for param in model.features.parameters(): + param.requires_grad = False + + optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3) + criterion = nn.CrossEntropyLoss() + + for epoch in range(epochs): + model.train() + running_loss = 0.0 + correct = 0 + total = 0 + + if epoch == 5: + for param in model.features.parameters(): + param.requires_grad = True + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + + for inputs, labels in loader: + inputs, labels = inputs.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + accuracy = correct / total if total > 0 else 0 + self.status.epoch = epoch + 1 + self.status.progress = (epoch + 1) / epochs + self.status.accuracy = accuracy + log.info("Epoch %d/%d — loss: %.4f, accuracy: %.4f", + epoch + 1, epochs, running_loss / len(loader), accuracy) + + if self._model_output_path.exists(): + backup_path = self._model_output_path.with_suffix(".backup.pt") + 
shutil.copy2(self._model_output_path, backup_path) + log.info("Backed up previous model to %s", backup_path) + + model = model.to("cpu") + torch.save(model, self._model_output_path) + log.info("Pet ID model saved to %s (accuracy: %.2f%%)", + self._model_output_path, self.status.accuracy * 100) + + self.status.is_training = False + return True + + except Exception as e: + log.exception("Training failed: %s", e) + self.status.error = str(e) + self.status.is_training = False + return False