Add pet ID model trainer with MobileNetV3-Small transfer learning
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
55
tests/unit/test_trainer.py
Normal file
55
tests/unit/test_trainer.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Tests for pet ID model trainer."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vigilar.detection.trainer import PetTrainer, TrainingStatus
|
||||
|
||||
|
||||
class TestTrainingStatus:
|
||||
def test_initial_status(self):
|
||||
status = TrainingStatus()
|
||||
assert status.is_training is False
|
||||
assert status.progress == 0.0
|
||||
assert status.error is None
|
||||
|
||||
|
||||
class TestPetTrainer:
|
||||
def test_check_readiness_no_pets(self, tmp_path):
|
||||
trainer = PetTrainer(training_dir=str(tmp_path), model_output_path=str(tmp_path / "model.pt"))
|
||||
ready, msg = trainer.check_readiness(min_images=20)
|
||||
assert not ready
|
||||
assert "No pet" in msg
|
||||
|
||||
def test_check_readiness_insufficient_images(self, tmp_path):
|
||||
pet_dir = tmp_path / "angel"
|
||||
pet_dir.mkdir()
|
||||
for i in range(5):
|
||||
(pet_dir / f"{i}.jpg").write_bytes(b"fake")
|
||||
|
||||
trainer = PetTrainer(training_dir=str(tmp_path), model_output_path=str(tmp_path / "model.pt"))
|
||||
ready, msg = trainer.check_readiness(min_images=20)
|
||||
assert not ready
|
||||
assert "angel" in msg.lower()
|
||||
|
||||
def test_check_readiness_sufficient_images(self, tmp_path):
|
||||
for name in ["angel", "taquito"]:
|
||||
pet_dir = tmp_path / name
|
||||
pet_dir.mkdir()
|
||||
for i in range(25):
|
||||
(pet_dir / f"{i}.jpg").write_bytes(b"fake")
|
||||
|
||||
trainer = PetTrainer(training_dir=str(tmp_path), model_output_path=str(tmp_path / "model.pt"))
|
||||
ready, msg = trainer.check_readiness(min_images=20)
|
||||
assert ready
|
||||
|
||||
def test_get_class_names(self, tmp_path):
|
||||
for name in ["angel", "milo", "taquito"]:
|
||||
(tmp_path / name).mkdir()
|
||||
|
||||
trainer = PetTrainer(training_dir=str(tmp_path), model_output_path=str(tmp_path / "model.pt"))
|
||||
names = trainer.get_class_names()
|
||||
assert names == ["angel", "milo", "taquito"] # sorted
|
||||
Reference in New Issue
Block a user