"""Wrapper MediaPipe pour détection et analyse de visages."""

import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np

try:
    import mediapipe as mp

    MEDIAPIPE_AVAILABLE = True
except ImportError:
    MEDIAPIPE_AVAILABLE = False
    mp = None

logger = logging.getLogger(__name__)


@dataclass
class FaceDetection:
    """Detection result for a single face.

    ``bbox`` is (cx, cy, w, h) normalised to [0, 1]. ``yaw``, ``pitch`` and
    ``roll`` stay ``None`` when no pose estimate is available.
    """

    bbox: Tuple[float, float, float, float]  # cx, cy, w, h normalised to [0, 1]
    confidence: float  # detector score
    yaw: Optional[float] = None
    pitch: Optional[float] = None
    roll: Optional[float] = None
    # Landmark (x, y) coordinates as an (N, 2) array, or None.
    # Excluded from ==: dataclass equality compares field tuples, and
    # truth-testing an element-wise ndarray comparison raises ValueError.
    landmarks: Optional[np.ndarray] = field(default=None, compare=False)


class MediaPipeProcessor:
    """MediaPipe processor: face detection plus optional Face Mesh landmarks.

    Face Detection supplies bounding boxes and scores. When Face Mesh could be
    created, it is run once per frame and each mesh face is paired with the
    closest detection to attach a rough yaw/pitch/roll estimate and six key
    landmarks. Call ``close()`` to release the MediaPipe graphs.
    """

    # Face Mesh landmark indices used for pose estimation, in the row order of
    # the ``points`` array: nose tip, left eye corner, right eye corner, chin,
    # left mouth corner, right mouth corner.
    _POSE_LANDMARK_IDS = (1, 33, 263, 18, 61, 291)
    # Maximum number of faces Face Mesh tracks per frame.
    _MESH_MAX_FACES = 3

    def __init__(
        self,
        min_detection_confidence: float = 0.6,
        min_tracking_confidence: float = 0.5,
    ):
        """Create the Face Detection graph and, if possible, the Face Mesh graph.

        Args:
            min_detection_confidence: detection threshold passed to both
                Face Detection and Face Mesh.
            min_tracking_confidence: tracking threshold passed to Face Mesh.

        Raises:
            RuntimeError: if the ``mediapipe`` package could not be imported.
        """
        if not MEDIAPIPE_AVAILABLE:
            raise RuntimeError("MediaPipe non disponible")

        self.min_detection_confidence = min_detection_confidence
        self.min_tracking_confidence = min_tracking_confidence

        self.mp_face_detection = mp.solutions.face_detection
        self.face_detection = self._create_face_detection(min_detection_confidence)

        # Face Mesh is optional: failing to build it only disables pose estimation.
        try:
            self.mp_face_mesh = mp.solutions.face_mesh
            self.face_mesh = self._create_face_mesh(
                min_detection_confidence, min_tracking_confidence
            )
            self.mesh_available = True
            logger.info("Face Mesh initialisé")
        except Exception as e:
            logger.warning("Face Mesh non disponible: %s", e)
            self.mesh_available = False
            self.face_mesh = None

    # ------------------------------------------------------------------
    # Graph lifecycle
    # ------------------------------------------------------------------

    def _create_face_detection(self, min_detection: float) -> Any:
        """Build a Face Detection graph (model_selection=0) with the given threshold."""
        return self.mp_face_detection.FaceDetection(
            model_selection=0,  # short-range model
            min_detection_confidence=min_detection,
        )

    def _create_face_mesh(self, min_detection: float, min_tracking: float) -> Any:
        """Build a Face Mesh graph (static_image_mode=False, no refined landmarks)."""
        return self.mp_face_mesh.FaceMesh(
            static_image_mode=False,
            max_num_faces=self._MESH_MAX_FACES,
            refine_landmarks=False,  # lighter model
            min_detection_confidence=min_detection,
            min_tracking_confidence=min_tracking,
        )

    @staticmethod
    def _close_quietly(solution: Any) -> None:
        """Call ``solution.close()`` if it exists, logging (not raising) failures."""
        close = getattr(solution, "close", None)
        if close is None:
            return
        try:
            close()
        except Exception as e:
            logger.debug("Erreur fermeture MediaPipe: %s", e)

    def close(self) -> None:
        """Release the MediaPipe graphs held by this processor.

        Calling ``process_frame`` after ``close()`` is not supported.
        """
        self._close_quietly(self.face_detection)
        if self.face_mesh is not None:
            self._close_quietly(self.face_mesh)

    # ------------------------------------------------------------------
    # Frame processing
    # ------------------------------------------------------------------

    def process_frame(self, frame_rgb: np.ndarray) -> List[FaceDetection]:
        """Detect faces in an RGB frame.

        Args:
            frame_rgb: RGB image handed as-is to MediaPipe.

        Returns:
            One FaceDetection per MediaPipe detection. Pose fields are filled
            when a Face Mesh face was paired with the detection. Returns
            ``[]`` for ``None`` input, when nothing is detected, or when
            MediaPipe raises (the error is logged).
        """
        if frame_rgb is None:
            return []

        try:
            results = self.face_detection.process(frame_rgb)
            if not results.detections:
                return []

            detections = [self._to_face_detection(d) for d in results.detections]

            # Pose estimation is best-effort: a failure keeps the bare detections.
            if self.mesh_available and self.face_mesh:
                try:
                    self._attach_poses(frame_rgb, detections)
                except Exception as e:
                    logger.debug("Erreur pose estimation: %s", e)

            return detections

        except Exception as e:
            logger.error("Erreur processing MediaPipe: %s", e)
            return []

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp ``value`` to the closed interval [0, 1]."""
        return max(0.0, min(1.0, value))

    def _to_face_detection(self, detection: Any) -> FaceDetection:
        """Convert one MediaPipe detection into a FaceDetection without pose."""
        box = detection.location_data.relative_bounding_box
        # MediaPipe gives (xmin, ymin, width, height). Compute the centre from
        # the raw box first, then clamp each value to [0, 1].
        cx = box.xmin + box.width / 2.0
        cy = box.ymin + box.height / 2.0
        bbox = (
            self._clamp01(cx),
            self._clamp01(cy),
            self._clamp01(box.width),
            self._clamp01(box.height),
        )
        confidence = detection.score[0] if detection.score else 0.0
        return FaceDetection(bbox=bbox, confidence=confidence)

    def _attach_poses(
        self, frame_rgb: np.ndarray, detections: List[FaceDetection]
    ) -> None:
        """Run Face Mesh once on the frame and attach pose data to detections.

        Mesh faces and detections are paired one-to-one, closest centres
        first. Detections left without a mesh face keep ``None`` angles.
        Detections are modified in place.
        """
        results = self.face_mesh.process(frame_rgb)
        mesh_faces = results.multi_face_landmarks or []
        if not mesh_faces:
            return

        centres = [self._mesh_face_centre(face) for face in mesh_faces]
        # Every (detection, mesh face) pair ranked by squared centre distance.
        candidates = sorted(
            (
                (det.bbox[0] - mx) ** 2 + (det.bbox[1] - my) ** 2,
                det_idx,
                mesh_idx,
            )
            for det_idx, det in enumerate(detections)
            for mesh_idx, (mx, my) in enumerate(centres)
        )

        matched_dets, matched_meshes = set(), set()
        for _, det_idx, mesh_idx in candidates:
            if det_idx in matched_dets or mesh_idx in matched_meshes:
                continue
            matched_dets.add(det_idx)
            matched_meshes.add(mesh_idx)
            try:
                self._apply_pose(detections[det_idx], mesh_faces[mesh_idx])
            except Exception as e:
                logger.debug("Erreur pose estimation: %s", e)

    @staticmethod
    def _mesh_face_centre(face_landmarks: Any) -> Tuple[float, float]:
        """Return the centre of the box spanned by one mesh face's landmarks."""
        xs = [p.x for p in face_landmarks.landmark]
        ys = [p.y for p in face_landmarks.landmark]
        return (min(xs) + max(xs)) / 2.0, (min(ys) + max(ys)) / 2.0

    def _apply_pose(
        self, face_det: FaceDetection, face_landmarks: Any
    ) -> FaceDetection:
        """Fill yaw/pitch/roll/landmarks of ``face_det`` from one mesh face.

        All three angles are computed before any field is written, so an
        exception leaves ``face_det`` unchanged.
        """
        lm = face_landmarks.landmark
        points = np.array([[lm[i].x, lm[i].y] for i in self._POSE_LANDMARK_IDS])

        yaw = self._estimate_yaw(points)
        pitch = self._estimate_pitch(points)
        roll = self._estimate_roll(points)

        face_det.yaw = yaw
        face_det.pitch = pitch
        face_det.roll = roll
        face_det.landmarks = points
        return face_det

    def _estimate_pose_angles(
        self, frame_rgb: np.ndarray, face_det: FaceDetection
    ) -> FaceDetection:
        """Estimate pose for one detection from the first Face Mesh face.

        Runs Face Mesh on the whole frame and does not check which face it
        used. ``process_frame`` uses ``_attach_poses`` instead, which runs the
        mesh once and pairs faces with detections. Errors are logged at debug
        level and ``face_det`` is returned unchanged.
        """
        try:
            results = self.face_mesh.process(frame_rgb)
            if not results.multi_face_landmarks:
                return face_det
            return self._apply_pose(face_det, results.multi_face_landmarks[0])
        except Exception as e:
            logger.debug("Erreur estimation pose: %s", e)
            return face_det

    # ------------------------------------------------------------------
    # Heuristic angle estimates (points in normalised image coordinates)
    # ------------------------------------------------------------------

    def _estimate_yaw(self, points: np.ndarray) -> float:
        """Estimate yaw (horizontal rotation) in radians, within [-0.8, 0.8].

        Uses the asymmetry of the nose-to-eye-corner distances.
        """
        nose, left_eye, right_eye = points[0], points[1], points[2]

        d_left = np.linalg.norm(nose - left_eye)
        d_right = np.linalg.norm(nose - right_eye)

        total = d_left + d_right
        if total <= 0:
            return 0.0
        # ratio is in [-1, 1] by the triangle inequality; 0.8 is an empirical
        # scale (0.8 rad ~ 46 deg). The clamp is only a safety net.
        yaw = (d_right - d_left) / total * 0.8
        return max(-0.8, min(0.8, yaw))

    def _estimate_pitch(self, points: np.ndarray) -> float:
        """Estimate pitch (vertical rotation) in radians, within [-0.6, 0.6].

        Uses the angle of the nose-to-chin vector relative to the horizontal
        (image y points down).
        """
        nose, chin = points[0], points[3]
        vec = chin - nose

        # NOTE(review): for an upright face the nose->chin vector is nearly
        # vertical, so arctan2 is close to pi/2 and the 0.5 factor puts the
        # result at the +0.6 clamp. It only drops below the clamp once the
        # vector tilts more than ~21 deg from vertical, i.e. it tracks
        # in-plane tilt more than true pitch. A calibrated ratio-based
        # estimate would be more meaningful.
        if abs(vec[1]) > 1e-6:
            pitch = np.arctan2(vec[1], abs(vec[0])) * 0.5  # reduction factor
            return max(-0.6, min(0.6, pitch))  # clamp ~ +/-34 deg
        return 0.0

    def _estimate_roll(self, points: np.ndarray) -> float:
        """Estimate roll (in-plane rotation) in radians, within [-0.5, 0.5].

        Uses the angle of the left-eye -> right-eye vector relative to the
        horizontal.
        """
        left_eye, right_eye = points[1], points[2]
        vec = right_eye - left_eye

        # Guard on the vector length rather than |dx| only: a near-vertical eye
        # line is a valid (saturating) roll, not a degenerate input.
        if np.hypot(vec[0], vec[1]) > 1e-6:
            roll = np.arctan2(vec[1], vec[0])
            return max(-0.5, min(0.5, roll))  # clamp ~ +/-29 deg
        return 0.0

    def update_confidence_thresholds(
        self, min_detection: float, min_tracking: float
    ) -> None:
        """Update thresholds and rebuild the MediaPipe graphs with them.

        Old graphs are closed once their replacement has been built, so they
        are not leaked. Errors are logged and leave the remaining graphs in
        place.
        """
        self.min_detection_confidence = min_detection
        self.min_tracking_confidence = min_tracking

        try:
            new_detection = self._create_face_detection(min_detection)
            self._close_quietly(self.face_detection)
            self.face_detection = new_detection

            if self.mesh_available and self.face_mesh:
                new_mesh = self._create_face_mesh(min_detection, min_tracking)
                self._close_quietly(self.face_mesh)
                self.face_mesh = new_mesh
        except Exception as e:
            logger.error("Erreur mise à jour seuils MediaPipe: %s", e)


class MockMediaPipeProcessor:
    """Stand-in for the MediaPipe processor, used in tests.

    Mirrors ``process_frame`` and ``update_confidence_thresholds`` but returns
    hard-coded detections and never touches MediaPipe.
    """

    def __init__(
        self,
        min_detection_confidence: float = 0.6,
        min_tracking_confidence: float = 0.5,
    ):
        """Store the thresholds; they do not affect the simulated output."""
        self.min_detection_confidence, self.min_tracking_confidence = (
            min_detection_confidence,
            min_tracking_confidence,
        )
        self.mesh_available = True
        logger.info("Mock MediaPipe processor initialisé")

    def process_frame(self, frame_rgb: np.ndarray) -> List[FaceDetection]:
        """Return canned detections for ``frame_rgb`` (``[]`` for ``None``).

        One central face is always returned. A second, smaller face is added
        when the frame is at least 300 pixels wide.
        """
        if frame_rgb is None:
            return []

        _, frame_width = frame_rgb.shape[:2]

        # Main face: horizontally centred, slightly above the middle.
        faces = [
            FaceDetection(
                bbox=(0.5, 0.45, 0.25, 0.3),
                confidence=0.92,
                yaw=-0.1,
                pitch=0.05,
                roll=0.02,
            )
        ]

        # Secondary face, only for frames wide enough.
        if frame_width >= 300:
            faces.append(
                FaceDetection(
                    bbox=(0.25, 0.6, 0.18, 0.22),
                    confidence=0.78,
                    yaw=0.15,
                    pitch=-0.08,
                    roll=-0.05,
                )
            )

        return faces

    def update_confidence_thresholds(
        self, min_detection: float, min_tracking: float
    ) -> None:
        """Record new thresholds (no effect on the simulated detections)."""
        self.min_detection_confidence, self.min_tracking_confidence = (
            min_detection,
            min_tracking,
        )


def create_processor(
    min_detection_confidence: float = 0.6,
    min_tracking_confidence: float = 0.5,
    mock: bool = False,
) -> Union[MediaPipeProcessor, MockMediaPipeProcessor]:
    """Build the real MediaPipe processor or the mock one.

    Args:
        min_detection_confidence: detection threshold forwarded to the processor.
        min_tracking_confidence: tracking threshold forwarded to the processor.
        mock: force the mock processor.

    Returns:
        A MockMediaPipeProcessor when ``mock`` is True or mediapipe could not be
        imported, otherwise a MediaPipeProcessor.
    """
    if mock or not MEDIAPIPE_AVAILABLE:
        if not mock:
            # The caller asked for the real detector: make the fallback visible.
            logger.warning("MediaPipe non disponible, utilisation du processeur mock")
        return MockMediaPipeProcessor(min_detection_confidence, min_tracking_confidence)
    return MediaPipeProcessor(min_detection_confidence, min_tracking_confidence)
