"""Tracking de visages avec IDs stables et politiques de sélection."""

import time
import numpy as np
import logging
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from .mp import FaceDetection

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


@dataclass
class TrackedFace:
    """Visage avec ID stable et historique."""

    id: int
    detection: FaceDetection
    last_seen: float
    age_frames: int
    history_bbox: List[Tuple[float, float, float, float]]  # Historique pour lissage

    def update(self, detection: FaceDetection) -> None:
        """Met à jour avec nouvelle détection."""
        self.detection = detection
        self.last_seen = time.time()
        self.age_frames += 1

        # Garder historique limité
        self.history_bbox.append(detection.bbox)
        if len(self.history_bbox) > 10:
            self.history_bbox.pop(0)


class FaceTracker:
    """Tracker de visages avec IDs stables."""

    def __init__(self, max_targets: int = 3, max_age_seconds: float = 2.0):
        self.max_targets = max_targets
        self.max_age_seconds = max_age_seconds
        self.tracked_faces: Dict[int, TrackedFace] = {}
        self._next_id = 1
        self._last_update = time.time()

    def update(self, detections: List[FaceDetection]) -> List[TrackedFace]:
        """Met à jour tracking avec nouvelles détections."""
        current_time = time.time()

        # Nettoyer visages trop anciens
        self._cleanup_old_faces(current_time)

        if not detections:
            return list(self.tracked_faces.values())

        # Association détections -> tracked faces existantes
        matches = self._associate_detections(detections)

        # Mettre à jour faces existantes
        updated_ids = set()
        for det_idx, track_id in matches.items():
            if track_id in self.tracked_faces:
                self.tracked_faces[track_id].update(detections[det_idx])
                updated_ids.add(track_id)

        # Créer nouvelles faces pour détections non associées
        for i, detection in enumerate(detections):
            if i not in matches:
                if len(self.tracked_faces) < self.max_targets:
                    new_face = TrackedFace(
                        id=self._next_id,
                        detection=detection,
                        last_seen=current_time,
                        age_frames=1,
                        history_bbox=[detection.bbox],
                    )
                    self.tracked_faces[self._next_id] = new_face
                    self._next_id += 1

        self._last_update = current_time
        return list(self.tracked_faces.values())

    def _cleanup_old_faces(self, current_time: float) -> None:
        """Supprime les visages non vus depuis trop longtemps."""
        to_remove = []
        for face_id, face in self.tracked_faces.items():
            if current_time - face.last_seen > self.max_age_seconds:
                to_remove.append(face_id)

        for face_id in to_remove:
            del self.tracked_faces[face_id]
            logger.debug(f"Visage {face_id} supprimé (trop ancien)")

    def _associate_detections(self, detections: List[FaceDetection]) -> Dict[int, int]:
        """Associe détections aux visages trackés par IoU + distance."""
        if not self.tracked_faces or not detections:
            return {}

        matches = {}
        used_tracks = set()

        # Calcul matrice de coûts (1 - score_association)
        costs = []
        track_ids = list(self.tracked_faces.keys())

        for i, detection in enumerate(detections):
            det_costs = []
            for track_id in track_ids:
                if track_id in used_tracks:
                    det_costs.append(1.0)  # Coût max si déjà utilisé
                else:
                    score = self._association_score(
                        detection, self.tracked_faces[track_id]
                    )
                    det_costs.append(1.0 - score)
            costs.append(det_costs)

        # Association gloutonne simple (Hungarian serait mieux mais plus lourd)
        for _ in range(min(len(detections), len(track_ids))):
            best_score = 0.0
            best_det = -1
            best_track = -1

            for i, detection in enumerate(detections):
                if i in matches:
                    continue

                for j, track_id in enumerate(track_ids):
                    if track_id in used_tracks:
                        continue

                    score = self._association_score(
                        detection, self.tracked_faces[track_id]
                    )
                    if score > best_score and score > 0.3:  # Seuil minimum
                        best_score = score
                        best_det = i
                        best_track = track_id

            if best_det >= 0 and best_track >= 0:
                matches[best_det] = best_track
                used_tracks.add(best_track)

        return matches

    def _association_score(
        self, detection: FaceDetection, tracked: TrackedFace
    ) -> float:
        """Calcule score d'association [0,1] entre détection et visage tracké."""
        # IoU des bounding boxes
        iou = self._compute_iou(detection.bbox, tracked.detection.bbox)

        # Distance entre centres
        det_cx, det_cy = detection.bbox[0], detection.bbox[1]
        track_cx, track_cy = tracked.detection.bbox[0], tracked.detection.bbox[1]
        dist = np.sqrt((det_cx - track_cx) ** 2 + (det_cy - track_cy) ** 2)
        dist_score = max(0.0, 1.0 - dist * 3.0)  # Pénaliser distance > 0.33

        # Similarité taille
        det_area = detection.bbox[2] * detection.bbox[3]
        track_area = tracked.detection.bbox[2] * tracked.detection.bbox[3]
        if det_area > 0 and track_area > 0:
            size_ratio = min(det_area, track_area) / max(det_area, track_area)
        else:
            size_ratio = 0.0

        # Score combiné
        score = 0.5 * iou + 0.3 * dist_score + 0.2 * size_ratio
        return min(1.0, score)

    def _compute_iou(
        self,
        bbox1: Tuple[float, float, float, float],
        bbox2: Tuple[float, float, float, float],
    ) -> float:
        """Calcule IoU entre deux bboxes (cx, cy, w, h)."""
        cx1, cy1, w1, h1 = bbox1
        cx2, cy2, w2, h2 = bbox2

        # Conversion vers (x1, y1, x2, y2)
        x1_1, y1_1 = cx1 - w1 / 2, cy1 - h1 / 2
        x2_1, y2_1 = cx1 + w1 / 2, cy1 + h1 / 2

        x1_2, y1_2 = cx2 - w2 / 2, cy2 - h2 / 2
        x2_2, y2_2 = cx2 + w2 / 2, cy2 + h2 / 2

        # Intersection
        xi1 = max(x1_1, x1_2)
        yi1 = max(y1_1, y1_2)
        xi2 = min(x2_1, x2_2)
        yi2 = min(y2_1, y2_2)

        if xi2 <= xi1 or yi2 <= yi1:
            return 0.0

        inter_area = (xi2 - xi1) * (yi2 - yi1)

        # Union
        area1 = w1 * h1
        area2 = w2 * h2
        union_area = area1 + area2 - inter_area

        if union_area <= 0:
            return 0.0

        return inter_area / union_area

    def get_tracked_faces(self) -> List[TrackedFace]:
        """Retourne la liste des visages trackés."""
        return list(self.tracked_faces.values())


class TargetSelector:
    """Sélecteur de cible selon différentes politiques."""

    def __init__(self, policy: str = "round_robin", dwell_ms: int = 1200):
        self.policy = policy
        self.dwell_ms = dwell_ms
        self._current_target_id: Optional[int] = None
        self._last_switch_time = time.time()
        self._round_robin_idx = 0

    def select_target(self, tracked_faces: List[TrackedFace]) -> Optional[TrackedFace]:
        """Sélectionne la cible courante selon la politique."""
        if not tracked_faces:
            self._current_target_id = None
            return None

        if self.policy == "closest":
            return self._select_closest(tracked_faces)
        elif self.policy == "largest":
            return self._select_largest(tracked_faces)
        elif self.policy == "round_robin":
            return self._select_round_robin(tracked_faces)
        else:
            logger.warning(f"Politique inconnue: {self.policy}, utilisation 'largest'")
            return self._select_largest(tracked_faces)

    def _select_closest(self, faces: List[TrackedFace]) -> TrackedFace:
        """Sélectionne le visage le plus proche du centre."""
        best_face = None
        best_dist = float("inf")

        for face in faces:
            cx, cy = face.detection.bbox[0], face.detection.bbox[1]
            # Distance au centre de l'image (0.5, 0.5)
            dist = np.sqrt((cx - 0.5) ** 2 + (cy - 0.5) ** 2)
            if dist < best_dist:
                best_dist = dist
                best_face = face

        if best_face:
            self._current_target_id = best_face.id

        return best_face

    def _select_largest(self, faces: List[TrackedFace]) -> TrackedFace:
        """Sélectionne le visage avec la plus grande bbox."""
        best_face = None
        best_area = 0.0

        for face in faces:
            w, h = face.detection.bbox[2], face.detection.bbox[3]
            area = w * h
            if area > best_area:
                best_area = area
                best_face = face

        if best_face:
            self._current_target_id = best_face.id

        return best_face

    def _select_round_robin(self, faces: List[TrackedFace]) -> TrackedFace:
        """Rotation entre visages toutes les dwell_ms."""
        current_time = time.time()

        # Trier par ID pour ordre stable
        sorted_faces = sorted(faces, key=lambda f: f.id)

        # Vérifier si il faut changer de cible
        should_switch = (
            self._current_target_id is None
            or current_time - self._last_switch_time >= self.dwell_ms / 1000.0
            or not any(f.id == self._current_target_id for f in sorted_faces)
        )

        if should_switch:
            # Passer au suivant
            if len(sorted_faces) > 0:
                self._round_robin_idx = (self._round_robin_idx + 1) % len(sorted_faces)
                selected_face = sorted_faces[self._round_robin_idx]
                self._current_target_id = selected_face.id
                self._last_switch_time = current_time
                logger.debug(f"Round robin: cible {self._current_target_id}")
                return selected_face

        # Garder cible courante si elle existe encore
        for face in sorted_faces:
            if face.id == self._current_target_id:
                return face

        # Fallback: premier visage
        if sorted_faces:
            self._current_target_id = sorted_faces[0].id
            return sorted_faces[0]

        return None

    def update_config(self, policy: str, dwell_ms: int) -> None:
        """Met à jour la configuration."""
        if self.policy != policy:
            logger.info(f"Politique changée: {self.policy} -> {policy}")
            self.policy = policy
            self._current_target_id = None  # Reset sélection
            self._round_robin_idx = 0

        self.dwell_ms = dwell_ms

    @property
    def current_target_id(self) -> Optional[int]:
        return self._current_target_id
