- CameraConfig.location now uses CameraLocation enum (Pydantic v2 coerces TOML strings) - Wildlife classifier returns ThreatLevel enum values with correct return type annotation - Model backup path fixed: pet_id_backup.pt instead of pet_id.backup.pt - Dashboard submitLabel JS now posts to /pets/<sighting_id>/label matching Flask route - Pet status API computes status field (safe/unknown) based on last-seen recency - digest.py comment explains timestamp unit difference between tables Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
144 lines
5.4 KiB
Python
"""Pet ID model trainer using MobileNetV3-Small with transfer learning."""
|
|
|
|
import logging
|
|
import shutil
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
# Module-level logger; handlers and level are configured by the application.
log = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class TrainingStatus:
    """Mutable snapshot of a training run's progress, suitable for polling."""

    # True while train() is executing; reset on completion or failure.
    is_training: bool = False
    # Fraction of epochs completed, in [0.0, 1.0].
    progress: float = 0.0
    # 1-based index of the most recently completed epoch (0 before training).
    epoch: int = 0
    # Target number of epochs for the current run (0 when idle).
    total_epochs: int = 0
    # Training-set accuracy from the latest completed epoch.
    accuracy: float = 0.0
    # Human-readable failure reason, or None if no error has occurred.
    error: str | None = None
|
|
|
|
|
|
class PetTrainer:
    """Trains a per-pet image classifier via transfer learning.

    The training directory must follow the torchvision ImageFolder layout:
    one subdirectory per pet, each holding that pet's example images.
    Progress is exposed through the mutable ``status`` attribute so callers
    (e.g. a web dashboard) can poll it while ``train()`` runs.
    """

    def __init__(self, training_dir: str, model_output_path: str):
        """
        Args:
            training_dir: Root directory containing one subdirectory per pet.
            model_output_path: Destination path for the saved ``.pt`` model.
        """
        self._training_dir = Path(training_dir)
        self._model_output_path = Path(model_output_path)
        self.status = TrainingStatus()

    def get_class_names(self) -> list[str]:
        """Return sorted pet (class) names: non-hidden subdirectory names."""
        if not self._training_dir.exists():
            return []
        return sorted(
            d.name for d in self._training_dir.iterdir()
            if d.is_dir() and not d.name.startswith(".")
        )

    def check_readiness(self, min_images: int = 20) -> tuple[bool, str]:
        """Check whether every pet has enough training images.

        Args:
            min_images: Minimum image count required per pet directory.

        Returns:
            ``(ready, message)`` — ``ready`` is True only when every pet
            directory holds at least ``min_images`` jpg/jpeg/png files.
        """
        class_names = self.get_class_names()
        if not class_names:
            return False, "No pet directories found in training directory."

        insufficient = []
        for name in class_names:
            pet_dir = self._training_dir / name
            # Count only regular files with image extensions (case-insensitive);
            # a stray directory named like "foo.jpg" must not inflate the count.
            image_count = sum(1 for f in pet_dir.iterdir()
                              if f.is_file()
                              and f.suffix.lower() in (".jpg", ".jpeg", ".png"))
            if image_count < min_images:
                insufficient.append(f"{name}: {image_count}/{min_images}")

        if insufficient:
            return False, f"Insufficient training images: {', '.join(insufficient)}"

        return True, f"Ready to train with {len(class_names)} classes."

    def train(self, epochs: int = 30, batch_size: int = 16) -> bool:
        """Run transfer learning on MobileNetV3-Small and save the model.

        The pretrained backbone is frozen for the first 5 epochs (head-only
        training), then unfrozen for full fine-tuning at a lower learning
        rate. On success, any pre-existing model file is backed up to
        ``<stem>_backup.pt`` before being overwritten.

        Args:
            epochs: Total number of training epochs.
            batch_size: Mini-batch size for the DataLoader.

        Returns:
            True on success (model saved); False on failure, in which case
            ``status.error`` describes what went wrong.
        """
        try:
            # Heavy imports are deferred so this module stays importable on
            # hosts without torch/torchvision installed.
            import torch
            import torch.nn as nn
            from torch.utils.data import DataLoader
            from torchvision import datasets, models, transforms

            self.status = TrainingStatus(is_training=True, total_epochs=epochs)
            class_names = self.get_class_names()
            num_classes = len(class_names)

            if num_classes < 2:
                self.status.error = "Need at least 2 pets to train."
                self.status.is_training = False
                return False

            # Standard ImageNet-style augmentation + normalization.
            train_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

            dataset = datasets.ImageFolder(str(self._training_dir), transform=train_transform)
            # ImageFolder discovers *all* subdirectories, including hidden
            # ones that get_class_names() excludes. A mismatch would make the
            # classifier head size disagree with the dataset's label indices,
            # so fail fast instead of training a silently broken model.
            if dataset.classes != class_names:
                self.status.error = (
                    f"Class mismatch: dataset found {dataset.classes} but "
                    f"expected {class_names}; remove stray directories."
                )
                self.status.is_training = False
                return False

            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

            # Pretrained backbone; swap the final classifier layer to emit
            # one logit per pet.
            model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
            model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = model.to(device)

            # Phase 1: freeze the feature extractor; train the head only.
            for param in model.features.parameters():
                param.requires_grad = False

            optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)
            criterion = nn.CrossEntropyLoss()

            for epoch in range(epochs):
                model.train()
                running_loss = 0.0
                correct = 0
                total = 0

                # Phase 2: after 5 head-only epochs, unfreeze the backbone
                # and fine-tune everything at a 10x lower learning rate.
                if epoch == 5:
                    for param in model.features.parameters():
                        param.requires_grad = True
                    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

                for inputs, labels in loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                accuracy = correct / total if total > 0 else 0
                self.status.epoch = epoch + 1
                self.status.progress = (epoch + 1) / epochs
                self.status.accuracy = accuracy
                # max(..., 1) guards against a zero-length loader so logging
                # never raises ZeroDivisionError mid-run.
                log.info("Epoch %d/%d — loss: %.4f, accuracy: %.4f",
                         epoch + 1, epochs, running_loss / max(len(loader), 1), accuracy)

            # Keep a single rolling backup of the previous model so a bad
            # training run can be reverted by hand.
            if self._model_output_path.exists():
                backup_path = self._model_output_path.with_name(
                    self._model_output_path.stem + "_backup.pt"
                )
                shutil.copy2(self._model_output_path, backup_path)
                log.info("Backed up previous model to %s", backup_path)

            # Save on CPU so the model loads on hosts without a GPU.
            model = model.to("cpu")
            torch.save(model, self._model_output_path)
            log.info("Pet ID model saved to %s (accuracy: %.2f%%)",
                     self._model_output_path, self.status.accuracy * 100)

            self.status.is_training = False
            return True

        except Exception as e:
            # Broad catch is deliberate: training runs in the background and
            # any failure (missing torch, corrupt image, OOM) must surface
            # through status.error rather than crash the host process.
            log.exception("Training failed: %s", e)
            self.status.error = str(e)
            self.status.is_training = False
            return False
|