"""Gestionnaire MQTT pour le service voice."""

import json
import time
import threading
from typing import Optional, Callable, Dict, Any
import paho.mqtt.client as mqtt

from .voice_types import TTSRequest, MP3Request, JawConfig, VoiceState, Capabilities
from .config import VoiceConfig
from .logger import setup_logger, log_with_metrics


class MQTTHandler:
    """MQTT communication manager for the voice service.

    Subscribes to the input topics listed in ``_INPUT_ROUTES`` (resolved through
    ``VoiceConfig.TOPICS``), turns incoming payloads into request objects and
    hands them to the callbacks registered with ``set_callback``. Also publishes
    outgoing JSON messages (viseme, RMS, state, capabilities).
    """

    # Input topic keys (looked up in VoiceConfig.TOPICS) mapped to the type
    # built from their JSON payload. None marks control topics whose callback
    # takes no argument. The order matters: the first key whose topic matches
    # an incoming message wins.
    _INPUT_ROUTES: Dict[str, Optional[type]] = {
        "tts": TTSRequest,
        "mp3_play": MP3Request,
        "mp3_stop": None,
        "mp3_pause": None,
        "mp3_resume": None,
        "jaw_config": JawConfig,
    }

    def __init__(self):
        """Initialise the handler from the environment configuration.

        No network activity happens here; call ``connect()`` to open the
        connection.
        """
        self.logger = setup_logger(__name__)
        self.client: Optional[mqtt.Client] = None

        # Callbacks keyed by topic key ("tts", "mp3_play", ...).
        self.callbacks: Dict[str, Callable] = {}

        # Connection state. Written by the paho callbacks below (run on the
        # network thread started by loop_start()) and read by caller threads.
        self._connected = False
        # Set by _on_connect whatever the broker answered, so connect() can
        # stop waiting as soon as the outcome is known.
        self._connack_received = threading.Event()
        # Serialises publish calls made from different threads.
        self._lock = threading.Lock()

        # Broker address from the environment.
        config = VoiceConfig.get_env_config()
        self.host = config["mqtt_host"]
        self.port = config["mqtt_port"]

    def set_callback(self, topic_key: str, callback: Callable) -> None:
        """Register ``callback`` for the input topic identified by ``topic_key``.

        ``topic_key`` is one of the keys of ``_INPUT_ROUTES``. JSON topics pass
        the built request object to the callback; control topics call it with
        no argument.
        """
        self.callbacks[topic_key] = callback

    def connect(self, timeout: float = 5.0) -> bool:
        """Connect to the MQTT broker and start the background network loop.

        Any client left over from a previous call is shut down first, so retries
        do not accumulate network threads. On failure the half-open client is
        torn down as well, so a False return leaves nothing running.

        Args:
            timeout: Seconds to wait for the broker's answer to CONNECT.

        Returns:
            True if the broker accepted the connection within ``timeout``,
            False otherwise. Never raises; errors are logged.
        """
        self._shutdown_client()
        self._connack_received.clear()

        try:
            client = mqtt.Client()
            client.on_connect = self._on_connect
            client.on_disconnect = self._on_disconnect
            client.on_message = self._on_message
            # Assign before connecting so the failure path can clean it up.
            self.client = client

            client.connect(self.host, self.port, VoiceConfig.MQTT_KEEPALIVE)
            client.loop_start()
        except Exception as e:
            self.logger.error(f"Erreur connexion MQTT: {e}")
            self._shutdown_client()
            return False

        # Wait for _on_connect to report the outcome (accepted or refused).
        self._connack_received.wait(timeout)

        if self._connected:
            self.logger.info(f"Connecté à MQTT: {self.host}:{self.port}")
            return True

        if not self._connack_received.is_set():
            self.logger.error("Timeout connexion MQTT")
        # Otherwise the broker refused and _on_connect already logged the code.
        self._shutdown_client()
        return False

    def disconnect(self) -> None:
        """Disconnect from the broker and stop the network loop, if a client exists."""
        if self.client:
            self._shutdown_client()
            self.logger.info("Déconnecté de MQTT")

    def _shutdown_client(self) -> None:
        """Best-effort: disconnect the current client and join its network thread."""
        client = self.client
        if client is None:
            return
        try:
            # disconnect() before loop_stop() (the order paho documents) so the
            # network thread is still alive to send DISCONNECT.
            client.disconnect()
        except Exception as e:
            self.logger.warning(f"Erreur déconnexion MQTT: {e}")
        finally:
            # Always stop the thread and reset the state, even if disconnect failed.
            client.loop_stop()
            self._connected = False

    def _on_connect(self, client, userdata, flags, rc):
        """paho ``on_connect`` callback: record the outcome and subscribe.

        ``rc == 0`` means the broker accepted the connection. Subscriptions are
        (re)issued here, so they are restored on every successful connect.
        """
        try:
            if rc == 0:
                self._connected = True
                self.logger.info("Connexion MQTT établie")
                self._subscribe_to_topics()
            else:
                self._connected = False
                self.logger.error(f"Échec connexion MQTT: {rc}")
        finally:
            # Always release connect(), even if subscribing raised.
            self._connack_received.set()

    def _on_disconnect(self, client, userdata, rc):
        """paho ``on_disconnect`` callback: a non-zero ``rc`` means unexpected loss."""
        self._connected = False
        if rc != 0:
            self.logger.warning(f"Déconnexion MQTT inattendue: {rc}")
        else:
            self.logger.info("Déconnexion MQTT normale")

    def _subscribe_to_topics(self):
        """Subscribe to every input topic declared in ``_INPUT_ROUTES``."""
        topics = [VoiceConfig.TOPICS[key] for key in self._INPUT_ROUTES]
        for topic in topics:
            self.client.subscribe(topic)
            self.logger.debug(f"Abonné au topic: {topic}")

    def _on_message(self, client, userdata, msg):
        """paho ``on_message`` callback: log the message, decode it and route it.

        Errors are logged rather than raised, so a bad message does not reach
        paho's network thread.
        """
        try:
            topic = msg.topic

            # Log before decoding, so undecodable payloads are still traced.
            # The size is reported in bytes, not decoded characters.
            log_with_metrics(
                self.logger,
                "DEBUG",
                f"Message MQTT reçu: {topic}",
                payload_size=len(msg.payload),
            )

            payload = msg.payload.decode("utf-8")
            self._route_message(topic, payload)

        except Exception as e:
            self.logger.error(f"Erreur traitement message MQTT: {e}")

    def _route_message(self, topic: str, payload: str):
        """Build the request for ``topic`` from ``payload`` and invoke its callback.

        Messages on unknown topics, or on topics without a registered callback,
        are ignored (the payload is not parsed). Parsing and construction errors
        are logged separately from errors raised by the callback itself, so a
        failing callback is not misreported as a bad payload.
        """
        try:
            key = next(
                (k for k in self._INPUT_ROUTES if VoiceConfig.TOPICS[k] == topic),
                None,
            )
            if key is None:
                return
            callback = self.callbacks.get(key)
            if callback is None:
                return

            request_type = self._INPUT_ROUTES[key]
            if request_type is None:
                args = ()  # control topic: callback takes no argument
            else:
                data = json.loads(payload)
                args = (request_type(**data),)
        except json.JSONDecodeError as e:
            self.logger.error(f"Erreur parsing JSON: {e}")
            return
        except TypeError as e:
            self.logger.error(f"Erreur construction objet: {e}")
            return
        except Exception as e:
            self.logger.error(f"Erreur routage message: {e}")
            return

        try:
            callback(*args)
        except Exception as e:
            self.logger.error(f"Erreur routage message: {e}")

    def publish_viseme(self, viseme_data: dict):
        """Publish viseme data on the ``viseme`` topic."""
        self._publish(VoiceConfig.TOPICS["viseme"], viseme_data)

    def publish_rms(self, rms_data: dict):
        """Publish RMS data on the ``rms`` topic."""
        self._publish(VoiceConfig.TOPICS["rms"], rms_data)

    def publish_state(self, state: VoiceState):
        """Publish the service state (``state.to_dict()``) on the ``state`` topic."""
        self._publish(VoiceConfig.TOPICS["state"], state.to_dict())

    def publish_capabilities(self, capabilities: Capabilities):
        """Publish capabilities as a retained message on the ``capabilities`` topic."""
        self._publish(
            VoiceConfig.TOPICS["capabilities"], capabilities.to_dict(), retain=True
        )

    def _publish(self, topic: str, data: dict, retain: bool = False):
        """Serialise ``data`` as JSON and publish it on ``topic``.

        Only logs a warning when not connected. Serialisation and publish
        failures are logged, never raised.
        """
        client = self.client  # local capture: connect() may replace self.client
        if not self._connected or client is None:
            self.logger.warning(f"Publication impossible - non connecté: {topic}")
            return

        try:
            # ensure_ascii=False keeps non-ASCII text unescaped in the payload.
            payload = json.dumps(data, ensure_ascii=False)
            with self._lock:
                result = client.publish(topic, payload, retain=retain)
        except Exception as e:
            self.logger.error(f"Erreur publication: {topic} - {e}")
            return

        if result.rc != mqtt.MQTT_ERR_SUCCESS:
            self.logger.error(f"Erreur publication MQTT: {topic} - {result.rc}")
        else:
            self.logger.debug(f"Message publié: {topic}")

    def is_connected(self) -> bool:
        """Return True while the broker connection is established."""
        return self._connected
