diff --git a/tests/unit/test_wildlife.py b/tests/unit/test_wildlife.py new file mode 100644 index 0000000..0abdd35 --- /dev/null +++ b/tests/unit/test_wildlife.py @@ -0,0 +1,50 @@ +"""Tests for wildlife threat classification.""" + +from vigilar.config import WildlifeConfig, WildlifeThreatMap, WildlifeSizeHeuristics +from vigilar.detection.person import Detection +from vigilar.detection.wildlife import classify_wildlife_threat + + +def _make_config(**kwargs): + return WildlifeConfig(**kwargs) + + +class TestWildlifeThreatClassification: + def test_bear_is_predator(self): + cfg = _make_config() + d = Detection(class_name="bear", class_id=21, confidence=0.9, bbox=(100, 100, 200, 300)) + level, species = classify_wildlife_threat(d, cfg, frame_area=1920 * 1080) + assert level == "PREDATOR" + assert species == "bear" + + def test_bird_is_passive(self): + cfg = _make_config() + d = Detection(class_name="bird", class_id=14, confidence=0.8, bbox=(10, 10, 30, 20)) + level, species = classify_wildlife_threat(d, cfg, frame_area=1920 * 1080) + assert level == "PASSIVE" + assert species == "bird" + + def test_unknown_small_is_nuisance(self): + cfg = _make_config() + d = Detection(class_name="unknown", class_id=99, confidence=0.7, bbox=(100, 100, 30, 30)) + level, species = classify_wildlife_threat(d, cfg, frame_area=1920 * 1080) + assert level == "NUISANCE" + assert species == "unknown" + + def test_unknown_medium_is_predator(self): + cfg = _make_config() + d = Detection(class_name="unknown", class_id=99, confidence=0.7, bbox=(100, 100, 300, 300)) + level, species = classify_wildlife_threat(d, cfg, frame_area=1920 * 1080) + assert level == "PREDATOR" + + def test_unknown_large_is_passive(self): + cfg = _make_config() + d = Detection(class_name="unknown", class_id=99, confidence=0.7, bbox=(100, 100, 600, 500)) + level, species = classify_wildlife_threat(d, cfg, frame_area=1920 * 1080) + assert level == "PASSIVE" + + def test_custom_threat_map(self): + cfg = _make_config(threat_map=WildlifeThreatMap(predator=["bear", "wolf"])) + d = Detection(class_name="wolf", class_id=99, confidence=0.85, bbox=(100, 100, 200, 200)) + level, _ = classify_wildlife_threat(d, cfg, frame_area=1920 * 1080) + assert level == "PREDATOR" diff --git a/vigilar/detection/wildlife.py b/vigilar/detection/wildlife.py new file mode 100644 index 0000000..85fb175 --- /dev/null +++ b/vigilar/detection/wildlife.py @@ -0,0 +1,38 @@ +"""Wildlife threat level classification.""" + +from vigilar.config import WildlifeConfig +from vigilar.detection.person import Detection + + +def classify_wildlife_threat( + detection: Detection, + config: WildlifeConfig, + frame_area: int, +) -> tuple[str, str]: + """Classify a wildlife detection into threat level and species. + + Returns (threat_level, species_name). + """ + species = detection.class_name + threat_map = config.threat_map + + # Direct COCO class mapping first + if species in threat_map.predator: + return "PREDATOR", species + if species in threat_map.nuisance: + return "NUISANCE", species + if species in threat_map.passive: + return "PASSIVE", species + + # Fallback to size heuristics for unknown species + _, _, w, h = detection.bbox + bbox_area = w * h + area_ratio = bbox_area / frame_area if frame_area > 0 else 0 + + heuristics = config.size_heuristics + if area_ratio < heuristics.small: + return "NUISANCE", species + elif area_ratio < heuristics.medium: + return "PREDATOR", species + else: + return "PASSIVE", species