"""Calcul du score de mouvement basé sur variations temporelles des features."""

import numpy as np
import time
import logging
from typing import List, Optional, Tuple
from dataclasses import dataclass
from .tracker import TrackedFace

logger = logging.getLogger(__name__)


@dataclass
class MotionState:
    """Snapshot of the motion-detection state produced by one evaluation."""

    # True when motion is considered active. The calculators in this module
    # derive it by comparing the score with their threshold, and always set
    # it to False when no face is present.
    active: bool
    # Motion score attached to this state.
    score: float
    # Short label explaining the state; the calculators in this module emit
    # "static", "no_face" and "face_motion".
    reason: str
    # Value of time.time() captured when the state was produced.
    timestamp: float


class MotionScoreCalculator:
    """Smoothed motion score derived from frame-to-frame changes of tracked faces.

    Each call to :meth:`calculate_motion_score` compares the supplied faces
    with those kept from the previous call, converts the differences into a
    raw score clamped to ``[0, 1]``, smooths it with an exponential moving
    average (EMA) and compares the smoothed value with ``threshold``.
    """

    # Raw motion contributed by a face whose ID was absent from the previous call.
    _NEW_FACE_MOTION = 0.3
    # Raw motion contributed by a previous face ID missing from the current call.
    _VANISHED_FACE_MOTION = 0.4
    # Weight and gain applied to each per-face motion component.
    _CENTER_WEIGHT, _CENTER_GAIN = 0.5, 10.0
    _SIZE_WEIGHT, _SIZE_GAIN = 0.3, 5.0
    _ANGLE_WEIGHT, _ANGLE_GAIN = 0.2, 2.0

    def __init__(self, threshold: float = 0.35, ema_alpha: float = 0.4):
        """Create a calculator.

        Args:
            threshold: Smoothed score at or above which motion is active.
            ema_alpha: EMA coefficient; weight given to the newest raw score.
                Stored as-is here (only ``update_ema_alpha`` clamps it).
        """
        self.threshold = threshold
        self.ema_alpha = ema_alpha

        # History used to compute variations between consecutive calls.
        self._prev_faces: List[TrackedFace] = []
        # Recorded on every call; not read anywhere inside this class.
        self._prev_timestamp = time.time()

        # Smoothed score and the last state handed back to the caller.
        self._smoothed_score = 0.0
        self._last_motion_state = MotionState(
            active=False, score=0.0, reason="static", timestamp=time.time()
        )

    def calculate_motion_score(self, tracked_faces: List[TrackedFace]) -> MotionState:
        """Compute, store and return the motion state for the current faces.

        Args:
            tracked_faces: Faces tracked in the current frame; may be empty.

        Returns:
            The new ``MotionState``. The same object is afterwards available
            through ``current_motion_state``, including when no face is present.
        """
        current_time = time.time()

        if not tracked_faces:
            # No face: decay the smoothed score toward zero and drop the
            # history, so the next call with faces starts from a raw score of 0.
            self._smoothed_score = self._apply_ema(0.0, self._smoothed_score)
            self._prev_faces = []
            self._prev_timestamp = current_time
            return self._record_state(
                MotionState(
                    active=False,
                    score=self._smoothed_score,
                    reason="no_face",
                    timestamp=current_time,
                )
            )

        # Yields 0.0 by itself when there is no history to compare against.
        raw_score = self._compute_raw_motion_score(tracked_faces)
        self._smoothed_score = self._apply_ema(raw_score, self._smoothed_score)

        active = self._smoothed_score >= self.threshold
        reason = self._determine_reason(tracked_faces, self._smoothed_score, active)

        # Keep a snapshot for the next call.
        self._prev_faces = self._copy_faces(tracked_faces)
        self._prev_timestamp = current_time

        return self._record_state(
            MotionState(
                active=active,
                score=self._smoothed_score,
                reason=reason,
                timestamp=current_time,
            )
        )

    def _record_state(self, state: MotionState) -> MotionState:
        """Remember ``state`` as the last computed state and return it."""
        self._last_motion_state = state
        return state

    def _compute_raw_motion_score(self, current_faces: List[TrackedFace]) -> float:
        """Return the unsmoothed score in ``[0, 1]`` for ``current_faces``.

        The score is the mean of one contribution per current face (its own
        motion, or a fixed value when its ID is new) plus one fixed
        contribution per previous ID that is no longer present. Returns 0.0
        when there is no previous history.
        """
        if not self._prev_faces:
            return 0.0

        prev_by_id = {face.id: face for face in self._prev_faces}

        total_motion = 0.0
        contributions = 0
        current_ids = set()

        for face in current_faces:
            current_ids.add(face.id)
            contributions += 1
            if face.id in prev_by_id:
                total_motion += self._compute_face_motion(face, prev_by_id[face.id])
            else:
                total_motion += self._NEW_FACE_MOTION

        # Set lookup keeps the vanished-face scan linear in the number of faces.
        vanished = sum(1 for prev_id in prev_by_id if prev_id not in current_ids)
        total_motion += self._VANISHED_FACE_MOTION * vanished
        contributions += vanished

        if contributions == 0:
            return 0.0
        return min(1.0, total_motion / contributions)

    def _compute_face_motion(self, current: TrackedFace, prev: TrackedFace) -> float:
        """Return the motion of one face between two observations, in ``[0, 1]``.

        ``detection.bbox`` is indexed as ``[0], [1]`` for the center and
        ``[2], [3]`` for the size (cx, cy, w, h).
        NOTE(review): the gains suggest normalized coordinates - confirm.
        Angles only contribute when yaw and pitch are available on both
        observations; roll is optional on top of that.
        NOTE(review): angle differences are plain absolute differences with no
        wrap-around handling - confirm the angle range makes that acceptable.
        """
        curr_bbox = current.detection.bbox
        prev_bbox = prev.detection.bbox

        # Euclidean displacement of the box center.
        center_motion = np.sqrt(
            (curr_bbox[0] - prev_bbox[0]) ** 2 + (curr_bbox[1] - prev_bbox[1]) ** 2
        )

        # Summed absolute change of width and height.
        size_motion = abs(curr_bbox[2] - prev_bbox[2]) + abs(
            curr_bbox[3] - prev_bbox[3]
        )

        angle_motion = 0.0
        curr_det = current.detection
        prev_det = prev.detection
        if (
            curr_det.yaw is not None
            and prev_det.yaw is not None
            and curr_det.pitch is not None
            and prev_det.pitch is not None
        ):
            yaw_change = abs(curr_det.yaw - prev_det.yaw)
            pitch_change = abs(curr_det.pitch - prev_det.pitch)
            roll_change = 0.0
            if curr_det.roll is not None and prev_det.roll is not None:
                roll_change = abs(curr_det.roll - prev_det.roll)

            # Always averaged over three axes, even when roll is missing.
            angle_motion = (yaw_change + pitch_change + roll_change) / 3.0

        motion_score = (
            self._CENTER_WEIGHT * center_motion * self._CENTER_GAIN
            + self._SIZE_WEIGHT * size_motion * self._SIZE_GAIN
            + self._ANGLE_WEIGHT * angle_motion * self._ANGLE_GAIN
        )

        # np.sqrt yields a NumPy scalar; hand back a built-in float as annotated.
        return float(min(1.0, motion_score))

    def _apply_ema(self, new_value: float, prev_smoothed: float) -> float:
        """Blend ``new_value`` into ``prev_smoothed`` using ``ema_alpha``."""
        return self.ema_alpha * new_value + (1.0 - self.ema_alpha) * prev_smoothed

    def _determine_reason(
        self, faces: List[TrackedFace], score: float, active: bool
    ) -> str:
        """Return "no_face", "face_motion" or "static" for the given state."""
        if not faces:
            return "no_face"
        if active and score >= self.threshold:
            return "face_motion"
        return "static"

    def _copy_faces(self, faces: List[TrackedFace]) -> List[TrackedFace]:
        """Return new ``TrackedFace`` objects mirroring ``faces`` for the history.

        ``history_bbox`` is copied via ``.copy()``; ``detection`` is shared by
        reference.
        NOTE(review): if detections are mutated in place between calls, the
        history would alias the current data - confirm they are not.
        """
        return [
            TrackedFace(
                id=face.id,
                detection=face.detection,
                last_seen=face.last_seen,
                age_frames=face.age_frames,
                history_bbox=face.history_bbox.copy(),
            )
            for face in faces
        ]

    def update_threshold(self, threshold: float) -> None:
        """Replace the activation threshold (no validation is applied)."""
        self.threshold = threshold
        logger.debug(f"Seuil mouvement mis à jour: {threshold}")

    def update_ema_alpha(self, alpha: float) -> None:
        """Replace the EMA coefficient, clamped to ``[0.1, 0.9]``."""
        self.ema_alpha = max(0.1, min(0.9, alpha))
        logger.debug(f"Alpha EMA mis à jour: {self.ema_alpha}")

    @property
    def current_motion_state(self) -> MotionState:
        """Last state returned by ``calculate_motion_score`` (initially static)."""
        return self._last_motion_state

    def get_debug_info(self) -> dict:
        """Return a dict of the current settings and internal values for debugging."""
        return {
            "threshold": self.threshold,
            "ema_alpha": self.ema_alpha,
            "smoothed_score": self._smoothed_score,
            "num_prev_faces": len(self._prev_faces),
            "last_reason": self._last_motion_state.reason,
        }


class StaticMotionCalculator:
    """Simplified motion calculator intended for tests.

    The score depends only on how many faces are supplied; no frame-to-frame
    comparison or smoothing takes place.
    """

    def __init__(self, threshold: float = 0.35):
        # Score at or above which the returned state is flagged active.
        self.threshold = threshold

    def calculate_motion_score(self, tracked_faces: List[TrackedFace]) -> MotionState:
        """Return a state whose score is 0.3 per face, capped at 0.8.

        With no faces the state is inactive, scored 0.0 and labelled "no_face".
        """
        now = time.time()

        if tracked_faces:
            score = min(0.8, len(tracked_faces) * 0.3)
            is_active = score >= self.threshold
            label = "face_motion" if is_active else "static"
        else:
            score = 0.0
            is_active = False
            label = "no_face"

        return MotionState(
            active=is_active,
            score=score,
            reason=label,
            timestamp=now,
        )
