diff --git a/tests/unit/test_yolo_detector.py b/tests/unit/test_yolo_detector.py new file mode 100644 index 0000000..0d67ad2 --- /dev/null +++ b/tests/unit/test_yolo_detector.py @@ -0,0 +1,52 @@ +"""Tests for YOLOv8 unified detector.""" + +import numpy as np +import pytest + +from vigilar.detection.person import Detection +from vigilar.detection.yolo import YOLODetector, ANIMAL_CLASSES, WILDLIFE_CLASSES + + +class TestYOLOConstants: + def test_animal_classes(self): + assert "cat" in ANIMAL_CLASSES + assert "dog" in ANIMAL_CLASSES + + def test_wildlife_classes(self): + assert "bear" in WILDLIFE_CLASSES + assert "bird" in WILDLIFE_CLASSES + + def test_no_overlap_animal_wildlife(self): + assert not ANIMAL_CLASSES.intersection(WILDLIFE_CLASSES) + + +class TestYOLODetector: + def test_initializes_without_model(self): + detector = YOLODetector(model_path="nonexistent.pt", confidence_threshold=0.5) + assert not detector.is_loaded + + def test_detect_returns_empty_when_not_loaded(self): + detector = YOLODetector(model_path="nonexistent.pt") + frame = np.zeros((480, 640, 3), dtype=np.uint8) + detections = detector.detect(frame) + assert detections == [] + + def test_classify_detection_person(self): + d = Detection(class_name="person", class_id=0, confidence=0.9, bbox=(10, 20, 100, 200)) + assert YOLODetector.classify(d) == "person" + + def test_classify_detection_vehicle(self): + d = Detection(class_name="car", class_id=2, confidence=0.85, bbox=(10, 20, 100, 200)) + assert YOLODetector.classify(d) == "vehicle" + + def test_classify_detection_domestic_animal(self): + d = Detection(class_name="cat", class_id=15, confidence=0.9, bbox=(10, 20, 100, 200)) + assert YOLODetector.classify(d) == "domestic_animal" + + def test_classify_detection_wildlife(self): + d = Detection(class_name="bear", class_id=21, confidence=0.8, bbox=(10, 20, 100, 200)) + assert YOLODetector.classify(d) == "wildlife" + + def test_classify_detection_other(self): + d = Detection(class_name="chair", class_id=56, confidence=0.7, bbox=(10, 20, 100, 200)) + assert YOLODetector.classify(d) == "other" diff --git a/vigilar/detection/yolo.py b/vigilar/detection/yolo.py new file mode 100644 index 0000000..93baabe --- /dev/null +++ b/vigilar/detection/yolo.py @@ -0,0 +1,78 @@ +"""Unified object detection using YOLOv8 via ultralytics.""" + +import logging +from pathlib import Path + +import numpy as np + +from vigilar.detection.person import Detection + +log = logging.getLogger(__name__) + +# COCO class names for domestic animals +ANIMAL_CLASSES = {"cat", "dog"} + +# COCO class names for wildlife (subset that YOLO can detect) +WILDLIFE_CLASSES = {"bear", "bird", "horse", "cow", "sheep", "elephant", "zebra", "giraffe"} + +# Vehicle class names from COCO +VEHICLE_CLASSES = {"car", "motorcycle", "bus", "truck", "boat"} + + +class YOLODetector: + def __init__(self, model_path: str, confidence_threshold: float = 0.5): + self._threshold = confidence_threshold + self._model = None + self.is_loaded = False + + if Path(model_path).exists(): + try: + from ultralytics import YOLO + self._model = YOLO(model_path) + self.is_loaded = True + log.info("YOLO model loaded from %s", model_path) + except Exception as e: + log.error("Failed to load YOLO model: %s", e) + else: + log.warning("YOLO model not found at %s — detection disabled", model_path) + + def detect(self, frame: np.ndarray) -> list[Detection]: + if not self.is_loaded or self._model is None: + return [] + + results = self._model(frame, conf=self._threshold, verbose=False) + detections = [] + + for result in results: + for box in result.boxes: + class_id = int(box.cls[0]) + confidence = float(box.conf[0]) + class_name = result.names[class_id] + + x1, y1, x2, y2 = box.xyxy[0].tolist() + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + bw, bh = x2 - x1, y2 - y1 + if bw <= 0 or bh <= 0: + continue + + detections.append(Detection( + class_name=class_name, + class_id=class_id, + confidence=confidence, + bbox=(x1, y1, bw, bh), + )) + + return detections + + @staticmethod + def classify(detection: Detection) -> str: + name = detection.class_name + if name == "person": + return "person" + if name in VEHICLE_CLASSES: + return "vehicle" + if name in ANIMAL_CLASSES: + return "domestic_animal" + if name in WILDLIFE_CLASSES: + return "wildlife" + return "other"