56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
"""Tests for pet ID model trainer."""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from vigilar.detection.trainer import PetTrainer, TrainingStatus
|
|
|
|
|
|
class TestTrainingStatus:
    """Checks the default state of a freshly constructed ``TrainingStatus``."""

    def test_initial_status(self):
        """A new status is not training, has zero progress, and has no error."""
        fresh = TrainingStatus()
        # Identity check on purpose: must be the ``False`` singleton, not just falsy.
        assert fresh.is_training is False
        assert fresh.progress == 0.0
        assert fresh.error is None
|
|
|
|
|
|
class TestPetTrainer:
    """Exercises ``PetTrainer`` readiness checks and class-name discovery.

    Each test lays out pet sub-directories under pytest's ``tmp_path`` and
    points the trainer at that directory.
    """

    @staticmethod
    def _trainer_for(root):
        """Build a ``PetTrainer`` over *root*, with ``model.pt`` placed inside it."""
        return PetTrainer(
            training_dir=str(root),
            model_output_path=str(root / "model.pt"),
        )

    @staticmethod
    def _populate(pet_dir, count):
        """Create *pet_dir* and write *count* placeholder ``<n>.jpg`` files into it."""
        pet_dir.mkdir()
        for idx in range(count):
            # The content is not a real image; only the file's presence matters here.
            (pet_dir / f"{idx}.jpg").write_bytes(b"fake")

    def test_check_readiness_no_pets(self, tmp_path):
        """An empty training directory is reported as not ready, mentioning "No pet"."""
        ready, msg = self._trainer_for(tmp_path).check_readiness(min_images=20)
        assert not ready
        assert "No pet" in msg

    def test_check_readiness_insufficient_images(self, tmp_path):
        """A pet with fewer images than the threshold blocks readiness and is named."""
        self._populate(tmp_path / "angel", 5)
        ready, msg = self._trainer_for(tmp_path).check_readiness(min_images=20)
        assert not ready
        assert "angel" in msg.lower()

    def test_check_readiness_sufficient_images(self, tmp_path):
        """Pets that each have more images than the threshold make the trainer ready."""
        for pet in ("angel", "taquito"):
            self._populate(tmp_path / pet, 25)
        ready, _msg = self._trainer_for(tmp_path).check_readiness(min_images=20)
        assert ready

    def test_get_class_names(self, tmp_path):
        """Class names are the pet sub-directory names, returned in sorted order."""
        for pet in ("angel", "milo", "taquito"):
            (tmp_path / pet).mkdir()
        names = self._trainer_for(tmp_path).get_class_names()
        assert names == ["angel", "milo", "taquito"]  # sorted
|