"""MQTT client for AI service"""

import asyncio
import json
import logging
import time
from typing import Dict, Any, Callable, Optional
from dataclasses import dataclass

try:
    import paho.mqtt.client as mqtt_client
    from paho.mqtt.client import MQTTMessage
except ImportError:
    # Fallback for testing without paho-mqtt: a no-op client exposing every
    # paho client method this module calls, so the module still imports.
    from types import SimpleNamespace

    class MockMQTTClient:
        """No-op stand-in for ``paho.mqtt.client.Client``."""

        def __init__(self, *args, **kwargs):
            pass

        def connect(self, *args, **kwargs):
            pass

        def disconnect(self, *args, **kwargs):
            pass

        def subscribe(self, *args, **kwargs):
            pass

        def publish(self, *args, **kwargs):
            pass

        def loop_start(self, *args, **kwargs):
            pass

        def loop_stop(self, *args, **kwargs):
            pass

    class MQTTMessage:
        """Placeholder so ``MQTTMessage`` annotations resolve without paho-mqtt."""

        topic: str = ""
        payload: bytes = b""

    # Mimic the ``mqtt_client.Client`` attribute access of the real module.
    mqtt_client = SimpleNamespace(Client=MockMQTTClient)


logger = logging.getLogger(__name__)


@dataclass
class MQTTConfig:
    """Settings used by ``MQTTHandler`` to reach the MQTT broker.

    All fields have defaults, so ``MQTTConfig()`` targets a broker on the
    local host. Field order is also the positional constructor order.
    """

    # Broker hostname / IP address and TCP port passed to client.connect().
    host: str = "127.0.0.1"
    port: int = 1883
    # Keep-alive value forwarded unchanged to client.connect().
    keepalive: int = 60
    # Client identifier given to the MQTT client constructor.
    client_id: str = "skull-ai"


class MQTTHandler:
    """MQTT client handler for the AI service.

    Connects to the broker described by an ``MQTTConfig``, subscribes to the
    service topics every time the connection comes up, publishes JSON
    payloads, and routes incoming JSON messages to registered callbacks.

    Threading: after ``loop_start()`` paho runs its network loop -- and with
    it ``_on_connect`` / ``_on_message`` / ``_on_disconnect`` -- on a
    background thread that has no asyncio event loop. Coroutines returned by
    callbacks are therefore handed to the event loop captured in
    :meth:`connect` instead of being created on that thread.
    """

    def __init__(self, config: "MQTTConfig"):
        self.config = config
        self.client = mqtt_client.Client(client_id=config.client_id)
        # topic pattern -> callback(topic, data); the first matching pattern wins
        self.message_callbacks: Dict[str, Callable] = {}
        # Written by _on_connect / _on_disconnect (paho's network thread).
        self.connected = False
        # time.monotonic() of the last capabilities publish; 0 means "never".
        self.last_capabilities_publish = 0
        # Event loop that callback coroutines are scheduled on (set by connect()).
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to in-flight callback futures until they finish.
        self._pending_callbacks: set = set()

        # Setup callbacks
        self.client.on_connect = self._on_connect
        self.client.on_message = self._on_message
        self.client.on_disconnect = self._on_disconnect

    async def connect(self) -> None:
        """Connect to the broker and wait until ``_on_connect`` reports success.

        Polls ``self.connected`` 50 times at 0.1 s intervals (about 5 s).

        Raises:
            TimeoutError: the connection was not confirmed in time. It is an
                ``Exception`` subclass, so existing broad handlers still catch it.
            Exception: whatever ``client.connect`` raises.
            All failures are logged before being re-raised.
        """
        # Captured here because message callbacks fire on paho's network
        # thread, which has no running event loop to schedule coroutines on.
        self._loop = asyncio.get_running_loop()
        try:
            # NOTE(review): client.connect() is a blocking socket call made on
            # the event-loop thread; consider run_in_executor for slow brokers.
            self.client.connect(
                self.config.host, self.config.port, self.config.keepalive
            )
            self.client.loop_start()

            # Wait for connection
            for _ in range(50):  # 5 seconds timeout
                if self.connected:
                    break
                await asyncio.sleep(0.1)

            if not self.connected:
                raise TimeoutError("MQTT connection timeout")

            logger.info(
                f"Connected to MQTT broker at {self.config.host}:{self.config.port}"
            )

        except Exception as e:
            logger.error(f"Failed to connect to MQTT broker: {e}")
            raise

    def disconnect(self) -> None:
        """Disconnect from the broker and stop paho's network thread.

        The thread is stopped even when ``self.connected`` is already False
        (e.g. after an unexpected drop). Otherwise it would keep running, and
        paho's threaded loop would keep retrying the connection, after shutdown.
        """
        was_connected = self.connected
        if was_connected:
            # Queue DISCONNECT while the network loop can still send it...
            self.client.disconnect()
        # ...then stop the loop thread unconditionally.
        self.client.loop_stop()
        self.connected = False
        if was_connected:
            logger.info("Disconnected from MQTT broker")

    def subscribe_to_topics(self) -> None:
        """Subscribe (QoS 0) to every topic the AI service consumes."""
        topics = [
            ("asr/text", 0),
            ("ai/request", 0),
            ("ai/cancel", 0),
            ("skull/config/ia", 0),
            ("skull/config/prompt_system", 0),
        ]

        for topic, qos in topics:
            self.client.subscribe(topic, qos)
            logger.info(f"Subscribed to topic: {topic}")

    def register_callback(self, topic_pattern: str, callback: Callable) -> None:
        """Register ``callback(topic, data)`` for an MQTT topic pattern.

        Patterns support the MQTT ``+`` and ``#`` wildcards (see
        :meth:`_topic_matches`). Registering the same pattern again replaces
        the previous callback.
        """
        self.message_callbacks[topic_pattern] = callback

    def publish_json(
        self, topic: str, data: Dict[str, Any], retain: bool = False
    ) -> None:
        """Serialize ``data`` to JSON and publish it on ``topic`` at QoS 0.

        Serialization or publish errors are logged, not raised.
        """
        try:
            payload = json.dumps(data, ensure_ascii=False)
            self.client.publish(topic, payload, qos=0, retain=retain)
            logger.debug(f"Published to {topic}: {payload}")
        except Exception as e:
            logger.error(f"Failed to publish to {topic}: {e}")

    def publish_capabilities(self, capabilities: Dict[str, Any]) -> None:
        """Publish ``capabilities`` to ``ai/capabilities`` as a retained message.

        Rate-limited to at most one publish per 5 seconds. A monotonic clock
        is used so wall-clock adjustments cannot stall the limiter.
        """
        now = time.monotonic()
        last = self.last_capabilities_publish

        # 0 is the "never published" sentinel set in __init__.
        # NOTE(review): calls inside the window are dropped, not queued, so the
        # retained value can lag behind the latest capabilities.
        if last and now - last < 5:
            return

        self.publish_json("ai/capabilities", capabilities, retain=True)
        self.last_capabilities_publish = now

    def _on_connect(self, client, userdata, flags, rc) -> None:
        """paho connect callback: mark connected and (re)subscribe on success."""
        if rc == 0:
            self.connected = True
            logger.info("MQTT connected successfully")
            self.subscribe_to_topics()
        else:
            logger.error(f"MQTT connection failed with code {rc}")

    def _on_disconnect(self, client, userdata, rc) -> None:
        """paho disconnect callback: clear the connected flag."""
        self.connected = False
        logger.warning(f"MQTT disconnected with code {rc}")

    def _on_message(self, client, userdata, msg: "MQTTMessage") -> None:
        """Decode a JSON message and dispatch it to the first matching callback.

        Runs on paho's network thread. An empty payload becomes ``{}``. Invalid
        JSON is logged and dropped. Every error is logged here so that nothing
        propagates back into paho's loop.
        """
        topic = "<unknown>"  # keeps the error log below valid if msg.topic fails
        try:
            topic = msg.topic
            payload = msg.payload.decode("utf-8")

            logger.debug(f"Received message on {topic}: {payload}")

            # Parse JSON payload
            try:
                data = json.loads(payload) if payload else {}
            except json.JSONDecodeError as e:
                logger.warning(f"Invalid JSON on {topic}: {e}")
                return

            # Snapshot the callbacks: register_callback may run on another thread.
            for topic_pattern, callback in list(self.message_callbacks.items()):
                if self._topic_matches(topic, topic_pattern):
                    try:
                        self._schedule_callback_result(callback(topic, data), topic)
                    except Exception as e:
                        logger.error(f"Callback error for {topic}: {e}")
                    break

        except Exception as e:
            logger.error(f"Error processing message on {topic}: {e}")

    def _schedule_callback_result(self, result: Any, topic: str) -> None:
        """Run a coroutine returned by a callback on the service's event loop.

        A non-coroutine result means a plain synchronous callback, which has
        already run, so there is nothing to schedule.
        """
        if not asyncio.iscoroutine(result):
            return

        loop = self._loop
        if loop is None or loop.is_closed():
            try:
                # Succeeds only when this runs on a thread with a running loop.
                loop = asyncio.get_running_loop()
            except RuntimeError:
                result.close()  # avoid a "coroutine was never awaited" warning
                logger.error(f"Callback error for {topic}: no running event loop")
                return

        # Thread-safe hand-off from paho's network thread to the asyncio loop.
        future = asyncio.run_coroutine_threadsafe(result, loop)
        self._pending_callbacks.add(future)
        future.add_done_callback(
            lambda done: self._on_callback_done(done, topic)
        )

    def _on_callback_done(self, future, topic: str) -> None:
        """Release a finished callback future and log any exception it raised."""
        self._pending_callbacks.discard(future)
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.error(f"Callback error for {topic}: {error}")

    def _topic_matches(self, topic: str, pattern: str) -> bool:
        """Return True if ``topic`` matches the MQTT subscription ``pattern``.

        Matching is level by level on ``/``. ``+`` matches exactly one level
        (possibly empty). ``#`` may appear only as the last level and matches
        the parent level plus any number of sub-levels: ``a/#`` matches ``a``
        and ``a/b/c`` but not ``ab``.

        NOTE(review): the MQTT rule that wildcards do not match topics
        starting with ``$`` is not applied here.
        """
        if pattern == topic:
            return True

        pattern_levels = pattern.split("/")
        topic_levels = topic.split("/")

        for index, level in enumerate(pattern_levels):
            if level == "#":
                # Valid only as the final level; then it matches the rest.
                return index == len(pattern_levels) - 1
            if index >= len(topic_levels):
                return False
            if level != "+" and level != topic_levels[index]:
                return False

        return len(pattern_levels) == len(topic_levels)


class DebounceHandler:
    """Debounce handler for ASR messages.

    For each key only the most recent :meth:`debounce` call runs. A new call
    for the same key cancels the previous task, whether it is still waiting
    out the delay or already running its callback, and starts a fresh delay.
    """

    def __init__(self, delay_ms: int = 300):
        # Delay before the callback runs, in milliseconds.
        self.delay_ms = delay_ms
        # key -> task currently scheduled (or running) for that key
        self.pending_tasks: Dict[str, asyncio.Task] = {}

    async def debounce(self, key: str, callback: Callable, *args, **kwargs) -> None:
        """Schedule ``await callback(*args, **kwargs)`` after ``delay_ms``.

        Must be awaited inside a running event loop (it uses
        ``asyncio.create_task``). Returns immediately without waiting for
        the callback.
        """
        # Cancel existing task for this key
        previous = self.pending_tasks.get(key)
        if previous is not None:
            previous.cancel()

        async def delayed_callback() -> None:
            try:
                await asyncio.sleep(self.delay_ms / 1000.0)
                await callback(*args, **kwargs)
            finally:
                # Only drop our own entry. If a newer debounce() cancelled us,
                # the slot now belongs to that newer task and must stay tracked.
                if self.pending_tasks.get(key) is asyncio.current_task():
                    del self.pending_tasks[key]

        self.pending_tasks[key] = asyncio.create_task(delayed_callback())

    def cancel_all(self) -> None:
        """Cancel every pending task and forget all keys."""
        for task in self.pending_tasks.values():
            task.cancel()
        self.pending_tasks.clear()


# Module-wide singletons.
# The MQTT handler starts out unset; presumably the service entry point
# assigns it once a config is available (not visible here) -- verify.
mqtt_handler: Optional[MQTTHandler] = None
# Shared ASR debouncer; 300 ms is the same as DebounceHandler's default delay.
debounce_handler = DebounceHandler(delay_ms=300)
