"""
Client MQTT pour l'orchestrateur Skull Pi.
Gère la communication avec tous les autres services.
"""

import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, Dict, Optional

import paho.mqtt.client as mqtt

from orchestrator.config import Config


# Signature of an asynchronous message handler: (topic, payload) -> awaitable
AsyncHandler = Callable[[str, str], Awaitable[None]]


class MqttClient:
    """Asynchronous MQTT client used by the orchestrator.

    Wraps a paho-mqtt client whose network loop runs in a background thread
    (started with ``loop_start``) and forwards incoming messages to ``async``
    handlers on the asyncio event loop that called :meth:`connect`.
    """

    def __init__(self, config: Config):
        """Read the ``mqtt`` configuration section and prepare the paho client.

        No network activity happens here; call :meth:`connect` to open the
        connection.

        Args:
            config: Orchestrator configuration; its ``mqtt`` entry is read
                through ``config.get("mqtt")``.
        """
        self.config = config
        self.logger = logging.getLogger("orchestrator.mqtt")

        # --- MQTT settings ---
        # Anything other than a dict (missing section, wrong type) falls back
        # to the defaults below.
        mqtt_cfg = self.config.get("mqtt")
        if not isinstance(mqtt_cfg, dict):
            mqtt_cfg = {}

        # Both 'broker' and 'host' are accepted as the server address key.
        self.broker: str = mqtt_cfg.get("broker") or mqtt_cfg.get("host") or "localhost"
        self.port: int = int(mqtt_cfg.get("port", 1883))
        self.username: Optional[str] = mqtt_cfg.get("username") or None
        self.password: Optional[str] = mqtt_cfg.get("password") or None
        self.client_id: str = str(mqtt_cfg.get("client_id", "orchestrator"))
        self.keepalive: int = int(mqtt_cfg.get("keepalive", 60))
        self._use_tls = bool(mqtt_cfg.get("tls", False))

        # --- paho client ---
        self.client = self._create_client()
        self.client.on_connect = self._on_connect
        self.client.on_message = self._on_message
        self.client.on_disconnect = self._on_disconnect

        # Credentials are only sent when both username and password are set.
        if self.username and self.password:
            self.client.username_pw_set(self.username, self.password)

        # Basic TLS with library defaults when requested.
        if self._use_tls:
            try:
                self.client.tls_set()
            except Exception as e:
                # NOTE(review): deliberate best-effort, but this downgrades to
                # a plaintext connection even though TLS was requested.
                self.logger.warning(
                    f"Activation TLS MQTT échouée, connexion sans TLS: {e}"
                )

        # Connection state, updated by the paho callbacks.
        self.connected = False
        self.reconnect_interval = 5

        # Topic (or wildcard pattern) -> async handler.
        self.message_handlers: Dict[str, AsyncHandler] = {}

        # Kept for compatibility; not used by this class itself.
        self.message_queue: asyncio.Queue = asyncio.Queue()

        # Event loop that owns the handlers; captured in connect() so that the
        # paho network thread can hand messages over to it.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def _create_client(self):
        """Build the paho client with MQTT 3.1.1 and v1-style callbacks.

        paho-mqtt 2.x exposes ``CallbackAPIVersion``; passing ``VERSION1``
        keeps the callback signatures used by this class. paho-mqtt 1.x has no
        such argument, so it is only passed when available.
        """
        kwargs = {"client_id": self.client_id, "protocol": mqtt.MQTTv311}
        api_versions = getattr(mqtt, "CallbackAPIVersion", None)
        if api_versions is not None:
            return mqtt.Client(api_versions.VERSION1, **kwargs)
        return mqtt.Client(**kwargs)

    async def connect(self) -> bool:
        """Connect to the broker and wait (about 5 s) for the CONNACK.

        Returns:
            True once ``_on_connect`` reported success, False on any error or
            timeout (the error is logged, never raised).
        """
        try:
            self.logger.info(f"Connexion MQTT vers {self.broker}:{self.port}")

            # Remember the running loop: paho callbacks fire in another thread
            # and need it to schedule the async handlers.
            loop = asyncio.get_running_loop()
            self._loop = loop

            # The paho connect call is blocking, so run it in the executor.
            await loop.run_in_executor(None, self._connect_sync)

            # Poll the flag set by _on_connect: 50 x 0.1 s, roughly 5 seconds.
            for _ in range(50):
                if self.connected:
                    break
                await asyncio.sleep(0.1)

            if not self.connected:
                raise TimeoutError("Timeout de connexion MQTT")

            self.logger.info("Connexion MQTT établie")
            return True

        except Exception as e:
            self.logger.error(f"Erreur connexion MQTT: {e}")
            return False

    def _connect_sync(self) -> None:
        """Blocking connect followed by the start of paho's network thread."""
        self.client.connect(self.broker, self.port, self.keepalive)
        self.client.loop_start()

    async def disconnect(self) -> None:
        """Stop the network loop and disconnect, if currently connected."""
        if self.connected:
            self.logger.info("Déconnexion MQTT")
            self.client.loop_stop()
            self.client.disconnect()
            self.connected = False

    # --- paho callbacks (MQTT 3.1.1 / callback API v1 signatures) ---
    def _on_connect(self, client, userdata, flags, rc):
        """Record the connection result and restore subscriptions.

        Runs in the paho network thread. On success every topic already
        registered in ``message_handlers`` is subscribed again, because a
        reconnect with a clean session makes the broker forget them.
        """
        if rc == 0:
            self.connected = True
            self.logger.info("Connecté au broker MQTT")
            # Snapshot the keys: the dict is mutated from the asyncio thread.
            for topic in list(self.message_handlers):
                client.subscribe(topic)
        else:
            self.logger.error(f"Échec connexion MQTT: rc={rc}")

    def _on_disconnect(self, client, userdata, rc):
        """Clear the connected flag; a non-zero ``rc`` is logged as unexpected."""
        self.connected = False
        if rc != 0:
            self.logger.warning(f"Déconnexion MQTT inattendue (rc={rc})")

    def _on_message(self, client, userdata, msg):
        """Hand a received message over to the asyncio event loop.

        Runs in the paho network thread, where no event loop is running, so
        the handler coroutine is scheduled thread-safely on the loop captured
        by :meth:`connect` instead of using ``asyncio.create_task``.
        """
        try:
            topic = msg.topic
            payload = msg.payload.decode("utf-8", errors="replace")
            loop = self._loop
            if loop is None or loop.is_closed():
                self.logger.error(
                    f"Message MQTT ignoré (aucune boucle asyncio active): {topic}"
                )
                return
            asyncio.run_coroutine_threadsafe(
                self._handle_message(topic, payload), loop
            )
        except Exception as e:
            self.logger.error(f"Erreur traitement message MQTT: {e}")

    async def _handle_message(self, topic: str, payload: str) -> None:
        """Dispatch a message to its handler.

        An exact topic registration wins; otherwise the first registered
        pattern that matches the topic is used. Handler exceptions are logged
        and swallowed so one faulty handler cannot break message processing.
        """
        try:
            handler = self.message_handlers.get(topic)
            if handler is None:
                for pattern, candidate in list(self.message_handlers.items()):
                    if self._topic_matches(topic, pattern):
                        handler = candidate
                        break

            if handler is None:
                self.logger.debug(f"Message sans handler pour {topic}: {payload[:120]}")
                return

            await handler(topic, payload)

        except Exception as e:
            self.logger.error(f"Erreur handler message {topic}: {e}")

    def _topic_matches(self, topic: str, pattern: str) -> bool:
        """Return True if ``topic`` matches the MQTT subscription ``pattern``.

        Matching is done level by level (levels are separated by ``/``):
        ``+`` matches exactly one level, and ``#`` as the last level matches
        the parent level and any number of child levels. A ``#`` anywhere
        else is invalid and never matches.
        """
        if pattern == topic:
            return True

        pattern_levels = pattern.split("/")
        topic_levels = topic.split("/")

        # Per the MQTT spec, a leading wildcard never matches '$' topics
        # (e.g. '#' must not match '$SYS/...').
        if topic.startswith("$") and pattern_levels[0] in ("#", "+"):
            return False

        last_index = len(pattern_levels) - 1
        for index, level in enumerate(pattern_levels):
            if level == "#":
                return index == last_index
            if index >= len(topic_levels):
                return False
            if level != "+" and level != topic_levels[index]:
                return False

        # No '#': every level was consumed, so depths must be identical.
        return len(pattern_levels) == len(topic_levels)

    async def subscribe(self, topic: str, handler: AsyncHandler) -> bool:
        """Subscribe to ``topic`` and register ``handler`` for its messages.

        Returns:
            True if the subscribe request was accepted by the client, False
            when not connected or on failure (logged).
        """
        if not self.connected:
            self.logger.error("Pas de connexion MQTT pour subscription")
            return False

        try:
            result, _mid = self.client.subscribe(topic)
            if result == mqtt.MQTT_ERR_SUCCESS:
                self.message_handlers[topic] = handler
                self.logger.debug(f"Abonnement à {topic}")
                return True
            else:
                self.logger.error(f"Échec abonnement à {topic}: {result}")
                return False
        except Exception as e:
            self.logger.error(f"Erreur abonnement {topic}: {e}")
            return False

    async def unsubscribe(self, topic: str) -> bool:
        """Unsubscribe from ``topic`` and drop its handler.

        Returns:
            True on success, False when not connected or on failure (logged).
        """
        if not self.connected:
            return False
        try:
            result, _mid = self.client.unsubscribe(topic)
            if result == mqtt.MQTT_ERR_SUCCESS:
                self.message_handlers.pop(topic, None)
                self.logger.debug(f"Désabonnement de {topic}")
                return True
            else:
                self.logger.error(f"Échec désabonnement de {topic}: {result}")
                return False
        except Exception as e:
            self.logger.error(f"Erreur désabonnement {topic}: {e}")
            return False

    @staticmethod
    def _wait_for_publish(info) -> None:
        """Block until the message is sent (1 s cap when paho supports it)."""
        try:
            info.wait_for_publish(timeout=1.0)
        except TypeError:
            # Older paho releases have no 'timeout' parameter.
            info.wait_for_publish()

    async def publish(self, topic: str, payload: str, retain: bool = False) -> bool:
        """Publish ``payload`` on ``topic``.

        Returns:
            True if the message was handed to the client without error, False
            when not connected, when paho rejected the message, or on any
            exception (logged).
        """
        if not self.connected:
            self.logger.error("Pas de connexion MQTT pour publication")
            return False
        try:
            info = self.client.publish(topic, payload, retain=retain)
            if info.rc != mqtt.MQTT_ERR_SUCCESS:
                self.logger.error(f"Échec publication sur {topic}: rc={info.rc}")
                return False

            # wait_for_publish blocks on a threading primitive; keep it off
            # the event loop so handlers stay responsive meanwhile.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._wait_for_publish, info)

            self.logger.debug(
                f"Publié sur {topic}: {payload[:100]}{'...' if len(payload) > 100 else ''}"
            )
            return True
        except Exception as e:
            self.logger.error(f"Erreur publication {topic}: {e}")
            return False

    async def publish_json(self, topic: str, data: Any, retain: bool = False) -> bool:
        """Serialize ``data`` with ``json.dumps`` and publish it on ``topic``.

        Returns:
            The result of :meth:`publish`, or False if serialization fails.
        """
        try:
            payload = json.dumps(data)
            return await self.publish(topic, payload, retain)
        except Exception as e:
            self.logger.error(f"Erreur sérialisation JSON pour {topic}: {e}")
            return False

    async def process_messages(self) -> None:
        """Trigger a reconnection attempt when the connection is down.

        Incoming messages themselves are handled by the paho callbacks.
        """
        if not self.connected:
            await self._reconnect()

    async def _reconnect(self) -> None:
        """Wait ``reconnect_interval`` seconds, then try to connect again."""
        try:
            await asyncio.sleep(self.reconnect_interval)
            await self.connect()
        except Exception as e:
            self.logger.error(f"Erreur reconnexion: {e}")

    async def republish_retained_configs(self) -> None:
        """Republish all retained configurations (placeholder, does nothing)."""
        pass
