diff --git a/tests/unit/test_face.py b/tests/unit/test_face.py new file mode 100644 index 0000000..a45ab86 --- /dev/null +++ b/tests/unit/test_face.py @@ -0,0 +1,48 @@ +import numpy as np +import pytest +from vigilar.detection.face import FaceRecognizer, FaceResult + + +def test_face_recognizer_init(): + fr = FaceRecognizer(match_threshold=0.6) + assert fr._threshold == 0.6 + assert fr.is_loaded is False + + +def test_load_profiles_empty(test_db): + fr = FaceRecognizer() + fr.load_profiles(test_db) + assert fr.is_loaded is True + assert len(fr._known_encodings) == 0 + + +def test_add_encoding(): + fr = FaceRecognizer() + fr.is_loaded = True + enc = np.random.rand(128).astype(np.float32) + fr.add_encoding(1, enc) + assert len(fr._known_encodings) == 1 + + +def test_find_match(): + fr = FaceRecognizer(match_threshold=0.6) + fr.is_loaded = True + enc = np.random.rand(128).astype(np.float32) + fr.add_encoding(1, enc) + result = fr._find_match(enc) + assert result is not None + assert result[0] == 1 + + +def test_no_match(): + fr = FaceRecognizer(match_threshold=0.6) + fr.is_loaded = True + fr.add_encoding(1, np.ones(128, dtype=np.float32)) + result = fr._find_match(-np.ones(128, dtype=np.float32)) + assert result is None + + +def test_face_result_dataclass(): + crop = np.zeros((100, 100, 3), dtype=np.uint8) + r = FaceResult(1, "Bob", 0.85, crop, (10, 20, 50, 50)) + assert r.profile_id == 1 diff --git a/vigilar/detection/face.py b/vigilar/detection/face.py new file mode 100644 index 0000000..0d4fc51 --- /dev/null +++ b/vigilar/detection/face.py @@ -0,0 +1,103 @@ +"""Local face recognition using face_recognition library (dlib-based).""" + +import base64 +import logging +from dataclasses import dataclass + +import numpy as np +from sqlalchemy.engine import Engine + +log = logging.getLogger(__name__) + + +@dataclass +class FaceResult: + profile_id: int | None + name: str | None + confidence: float + face_crop: np.ndarray + bbox: tuple[int, int, int, int] + + +class FaceRecognizer: + def __init__(self, match_threshold: float = 0.6): + self._threshold = match_threshold + self._known_encodings: list[np.ndarray] = [] + self._known_profile_ids: list[int] = [] + self._known_names: list[str | None] = [] + self.is_loaded = False + + def load_profiles(self, engine: Engine) -> None: + from vigilar.storage.queries import get_all_profiles, get_embeddings_for_profile + self._known_encodings = [] + self._known_profile_ids = [] + self._known_names = [] + profiles = get_all_profiles(engine) + for profile in profiles: + embeddings = get_embeddings_for_profile(engine, profile["id"]) + for emb in embeddings: + try: + raw = base64.b64decode(emb["embedding"]) + encoding = np.frombuffer(raw, dtype=np.float32) + if len(encoding) == 128: + self._known_encodings.append(encoding) + self._known_profile_ids.append(profile["id"]) + self._known_names.append(profile.get("name")) + except Exception: + continue + self.is_loaded = True + log.info( + "Loaded %d face encodings from %d profiles", + len(self._known_encodings), + len(profiles), + ) + + def add_encoding(self, profile_id: int, encoding: np.ndarray, name: str | None = None) -> None: + self._known_encodings.append(encoding) + self._known_profile_ids.append(profile_id) + self._known_names.append(name) + + def identify(self, frame: np.ndarray) -> list[FaceResult]: + if not self.is_loaded: + return [] + try: + import face_recognition + except ImportError: + log.warning("face_recognition not installed") + return [] + face_locations = face_recognition.face_locations(frame) + if not face_locations: + return [] + face_encodings = face_recognition.face_encodings(frame, face_locations) + results = [] + for (top, right, bottom, left), encoding in zip(face_locations, face_encodings): + face_crop = frame[top:bottom, left:right].copy() + bbox = (left, top, right - left, bottom - top) + match = self._find_match(encoding) + if match: + pid, conf, name = match + results.append(FaceResult(pid, name, conf, face_crop, bbox)) + else: + results.append(FaceResult(None, None, 0.0, face_crop, bbox)) + return results + + def _find_match(self, encoding: np.ndarray) -> tuple[int, float, str | None] | None: + if not self._known_encodings: + return None + distances = [float(np.linalg.norm(encoding - k)) for k in self._known_encodings] + min_idx = int(np.argmin(distances)) + if distances[min_idx] < self._threshold: + return ( + self._known_profile_ids[min_idx], + 1.0 - distances[min_idx], + self._known_names[min_idx], + ) + return None + + +def encoding_to_b64(encoding: np.ndarray) -> str: + return base64.b64encode(encoding.astype(np.float32).tobytes()).decode() + + +def b64_to_encoding(b64_str: str) -> np.ndarray: + return np.frombuffer(base64.b64decode(b64_str), dtype=np.float32)