"""Main AI service orchestrator"""

import asyncio
import json
import logging
import signal
import sys
import time
from typing import Dict, Any, Optional
from pathlib import Path

from .config import ConfigManager
from .llm import llm_client
from .context import context_manager
from .prompt import prompt_builder
from .filters import response_filter
from .mqtt import MQTTHandler, MQTTConfig, debounce_handler
from .health import service_health


# Setup logging. FileHandler opens its file immediately, i.e. at import time,
# so the log directory must exist before basicConfig() runs.
_LOG_DIR = Path("/opt/Skull/logs")
_log_handlers = []
_log_file_error = None
try:
    _LOG_DIR.mkdir(parents=True, exist_ok=True)
    _log_handlers.append(logging.FileHandler(_LOG_DIR / "ai.log"))
except OSError as exc:
    # Keep the service importable (e.g. read-only FS / missing permissions);
    # the failure is reported once logging is configured below.
    _log_file_error = exc
_log_handlers.append(logging.StreamHandler(sys.stdout))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

if _log_file_error is not None:
    logger.warning(
        "File logging disabled, could not open %s: %s",
        _LOG_DIR / "ai.log",
        _log_file_error,
    )


class AIService:
    """Main AI service class.

    Wires MQTT topics to the LLM pipeline: text arriving on ``asr/text`` or
    ``ai/request`` is placed on a small bounded queue, processed one request
    at a time by ``_process_requests``, and the filtered answer is published
    on ``ai/response`` with usage metrics on ``ai/usage``.
    """

    def __init__(self):
        self.config_manager = ConfigManager()
        self.mqtt_handler: Optional[MQTTHandler] = None
        # Set True by start(); the request loop exits once it becomes False
        self.running = False
        # True while a generation is in progress
        self.busy = False
        # Bounded queue: requests beyond maxsize are rejected (back-pressure)
        self.request_queue = asyncio.Queue(maxsize=3)
        self.shutdown_event = asyncio.Event()
        # Task running the current generation, so ai/cancel can interrupt it
        self._current_generation: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the AI service and run until shutdown is requested.

        Connects to MQTT, registers topic callbacks, publishes capabilities,
        starts health monitoring, then runs the request loop concurrently
        with a wait on ``shutdown_event``.

        Raises:
            Exception: any startup/runtime failure is logged and re-raised.
        """
        logger.info("Starting Skull AI service...")

        # An instance may be built at import time (this module does so), i.e.
        # before asyncio.run() creates its loop. On Python < 3.10 asyncio
        # Queue/Event bind to the loop current at construction, so recreate
        # them inside the running loop before anything can use them.
        self.request_queue = asyncio.Queue(maxsize=self.request_queue.maxsize)
        self.shutdown_event = asyncio.Event()

        try:
            # Initialize MQTT
            mqtt_config = MQTTConfig(
                host=self.config_manager.config.mqtt_host,
                port=self.config_manager.config.mqtt_port,
            )
            self.mqtt_handler = MQTTHandler(mqtt_config)

            # Register MQTT callbacks
            self._register_mqtt_callbacks()

            # Connect to MQTT
            await self.mqtt_handler.connect()

            # Publish initial capabilities
            capabilities = self.config_manager.get_capabilities()
            self.mqtt_handler.publish_capabilities(capabilities)

            # Start health monitoring
            await service_health.start_monitoring(self.mqtt_handler)

            # Start request processing loop
            self.running = True
            await asyncio.gather(self._process_requests(), self._wait_for_shutdown())

        except Exception as e:
            logger.exception(f"Failed to start AI service: {e}")
            raise

    async def stop(self) -> None:
        """Stop the AI service.

        Lets the request loop and ``start()`` finish, stops health
        monitoring, cancels pending debounced callbacks and disconnects MQTT.
        It is invoked both by the signal handler and by ``main()``'s cleanup,
        so it may run more than once.
        """
        logger.info("Stopping Skull AI service...")

        self.running = False
        self.shutdown_event.set()

        # Stop health monitoring
        await service_health.stop_monitoring()

        # Cancel debounced tasks
        debounce_handler.cancel_all()

        # Disconnect MQTT
        if self.mqtt_handler:
            self.mqtt_handler.disconnect()

        logger.info("AI service stopped")

    def _register_mqtt_callbacks(self) -> None:
        """Register the coroutine handlers for every subscribed MQTT topic."""
        self.mqtt_handler.register_callback("asr/text", self._handle_asr_text)
        self.mqtt_handler.register_callback("ai/request", self._handle_ai_request)
        self.mqtt_handler.register_callback("ai/cancel", self._handle_ai_cancel)
        self.mqtt_handler.register_callback(
            "skull/config/ia", self._handle_config_update
        )
        self.mqtt_handler.register_callback(
            "skull/config/prompt_system", self._handle_prompt_update
        )

    @staticmethod
    def _extract_text(data: Any) -> str:
        """Return the stripped ``text`` field of an MQTT payload, or "".

        Payloads without a ``.get`` method, or whose ``text`` is missing,
        null or not a string, yield "" so callers ignore them instead of
        raising inside the MQTT callback.
        """
        try:
            text = data.get("text")
        except AttributeError:
            return ""
        return text.strip() if isinstance(text, str) else ""

    async def _handle_asr_text(self, topic: str, data: Dict[str, Any]) -> None:
        """Handle ASR text input with debouncing"""
        text = self._extract_text(data)
        if not text:
            return

        # Debounce ASR messages to avoid duplicates
        await debounce_handler.debounce(
            "asr_text", self._process_user_input, text, "asr"
        )

    async def _handle_ai_request(self, topic: str, data: Dict[str, Any]) -> None:
        """Handle a direct AI request (no debouncing)."""
        text = self._extract_text(data)
        if not text:
            return

        await self._process_user_input(text, "request")

    async def _handle_ai_cancel(self, topic: str, data: Dict[str, Any]) -> None:
        """Handle AI request cancellation.

        Drops every pending request and cancels the generation task that is
        in flight (if any), so its answer is never published. When a
        generation was running, publishes the "interrupted" fallback text on
        ``ai/response``.
        """
        logger.info("AI request cancellation received")

        # Clear request queue
        while True:
            try:
                self.request_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

        # Actually stop the running generation; otherwise its response would
        # still be published after the "interrupted" message below.
        task = self._current_generation
        if task is not None and not task.done():
            task.cancel()

        # Publish interruption message
        if self.busy:
            response = {
                "ts_ms": int(time.time() * 1000),
                "text": response_filter.get_fallback_response("interrupted"),
            }
            self.mqtt_handler.publish_json("ai/response", response)
            self.busy = False

    async def _handle_config_update(self, topic: str, data: Dict[str, Any]) -> None:
        """Apply an AI config update; republish capabilities if the model changed."""
        logger.info("Received AI config update")
        self.config_manager.config.update_from_mqtt(data)

        # Republish capabilities if model changed
        if "model" in data:
            capabilities = self.config_manager.get_capabilities()
            self.mqtt_handler.publish_capabilities(capabilities)

    async def _handle_prompt_update(self, topic: str, data: Dict[str, Any]) -> None:
        """Handle system prompt update"""
        logger.info("Received prompt system update")
        self.config_manager.update_prompt_system(data)

    async def _process_user_input(self, text: str, source: str) -> None:
        """Queue user input for AI processing.

        ASR input is dropped while a generation is running. If the queue is
        full the request is dropped and the "rate_limit" fallback text is
        published on ``ai/response``.
        """

        # Skip if busy and this is from ASR (avoid flooding)
        if self.busy and source == "asr":
            logger.debug("Skipping ASR input while busy")
            return

        # Add to request queue with back-pressure
        try:
            request_data = {"text": text, "source": source, "timestamp": time.time()}
            self.request_queue.put_nowait(request_data)
            logger.debug(f"Queued request from {source}: {text[:50]}...")

        except asyncio.QueueFull:
            logger.warning("Request queue full, dropping request")
            # Publish error response
            error_response = {
                "ts_ms": int(time.time() * 1000),
                "text": response_filter.get_fallback_response("rate_limit"),
            }
            self.mqtt_handler.publish_json("ai/response", error_response)

    async def _process_requests(self) -> None:
        """Main request processing loop.

        Takes one request at a time from ``request_queue``. The 1 s poll
        timeout lets the loop notice ``running`` turning False. Unexpected
        errors are logged and followed by a 1 s back-off.
        """
        while self.running:
            try:
                # Get request from queue (with timeout)
                try:
                    request = await asyncio.wait_for(
                        self.request_queue.get(), timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                # Process the request (cancellable via ai/cancel)
                await self._run_generation(request)

            except Exception as e:
                logger.exception(f"Error in request processing loop: {e}")
                await asyncio.sleep(1.0)

    async def _run_generation(self, request: Dict[str, Any]) -> None:
        """Run one generation as a cancellable task and wait for it.

        A cancellation of that task (from ``_handle_ai_cancel``) is absorbed
        here. Any other exception raised by the task is re-raised.
        """
        task = asyncio.create_task(self._handle_ai_generation(request))
        self._current_generation = task
        try:
            # asyncio.wait() does not raise when *task* is cancelled, which
            # separates a user cancel from this coroutine being cancelled.
            await asyncio.wait({task})
        except asyncio.CancelledError:
            # The processing loop itself is being cancelled: take the
            # in-flight generation down with it.
            task.cancel()
            raise
        finally:
            self._current_generation = None

        if task.cancelled():
            logger.info("In-flight AI generation was cancelled")
            return
        task.result()  # re-raise unexpected errors to the caller

    @staticmethod
    def _classify_error(exc: BaseException) -> str:
        """Map a generation failure to a fallback-response key.

        Returns "timeout" for timeout exception types (``asyncio.TimeoutError``
        has an empty message, so a substring test alone misses it) or for any
        error whose message mentions "timeout"; "model_error" otherwise.
        """
        if isinstance(exc, (asyncio.TimeoutError, TimeoutError)):
            return "timeout"
        return "timeout" if "timeout" in str(exc).lower() else "model_error"

    async def _handle_ai_generation(self, request: Dict[str, Any]) -> None:
        """Generate, filter and publish the AI answer for one queued request.

        On success, publishes the filtered answer on ``ai/response`` and the
        usage metrics on ``ai/usage``, and records the turn in the
        conversation context. On failure, publishes a fallback text chosen by
        ``_classify_error``. On cancellation, publishes nothing and re-raises.
        """
        start_time = service_health.health_monitor.record_request_start()
        self.busy = True

        try:
            user_text = request["text"]
            source = request["source"]

            logger.info(f"Processing AI request from {source}: {user_text[:100]}...")

            # Extract user intent
            user_intent = prompt_builder.extract_user_intent(user_text)

            # Build prompt
            prompt = prompt_builder.build_prompt(
                user_text,
                self.config_manager.prompt_system,
                self.config_manager.config.model,
                self.config_manager.config.max_tokens,
            )

            # Validate prompt safety
            if not prompt_builder.validate_prompt_safety(prompt):
                raise ValueError("Unsafe prompt detected")

            # Generate AI response
            result = await llm_client.generate(
                prompt,
                model=self.config_manager.config.model,
                max_tokens=self.config_manager.config.max_tokens,
                temperature=self.config_manager.config.temperature,
                top_p=self.config_manager.config.top_p,
                timeout_s=self.config_manager.config.timeout_s,
            )

            # Filter and post-process response
            filtered_text = response_filter.process_response(result.text, user_intent)

            # Update conversation context
            context_manager.add_turn(user_text, filtered_text)

            # Publish AI response
            response = {"ts_ms": int(time.time() * 1000), "text": filtered_text}
            self.mqtt_handler.publish_json("ai/response", response)

            # Publish usage metrics
            usage = {
                "ts_ms": int(time.time() * 1000),
                "latency_ms": result.latency_ms,
                "input_tokens": int(result.input_tokens),
                "output_tokens": int(result.output_tokens),
                "model": result.model,
                "cached": result.cached,
            }
            self.mqtt_handler.publish_json("ai/usage", usage)

            # Record success
            service_health.health_monitor.record_request_success(start_time)

            logger.info(f"AI response generated in {result.latency_ms}ms")

        except asyncio.CancelledError:
            # Cancelled via ai/cancel or shutdown: publish nothing and let the
            # cancellation propagate (it must never be reported as an error).
            # NOTE(review): no success/failure is recorded for a cancelled
            # request; confirm the health monitor does not need one.
            logger.info("AI generation cancelled")
            raise

        except Exception as e:
            logger.exception(f"AI generation failed: {e}")

            # Record failure
            service_health.health_monitor.record_request_failure(start_time)

            # Publish error response
            error_type = self._classify_error(e)
            error_response = {
                "ts_ms": int(time.time() * 1000),
                "text": response_filter.get_fallback_response(error_type),
            }
            self.mqtt_handler.publish_json("ai/response", error_response)

        finally:
            self.busy = False

            # Clean old context periodically
            context_manager.clear_old_context()

            # Purge context if inactive
            if context_manager.should_purge():
                context_manager.purge()
                logger.debug("Purged conversation context due to inactivity")

    async def _wait_for_shutdown(self) -> None:
        """Block until ``stop()`` sets the shutdown event."""
        await self.shutdown_event.wait()


# Single process-wide service instance, shared by signal_handler() and main()
ai_service: AIService = AIService()


# Strong references to pending stop() tasks: the event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_shutdown_tasks = set()


def _schedule_stop(loop: asyncio.AbstractEventLoop) -> None:
    """Start ``ai_service.stop()`` as a task on *loop* and keep it referenced."""
    task = loop.create_task(ai_service.stop())
    _shutdown_tasks.add(task)
    task.add_done_callback(_shutdown_tasks.discard)


def signal_handler(signum, frame):
    """Handle shutdown signals by scheduling a graceful ``ai_service.stop()``.

    Works both as a ``signal.signal()`` handler and as a
    ``loop.add_signal_handler()`` callback. Plain signal handlers may not
    touch the event loop directly, so task creation is deferred with
    ``call_soon_threadsafe``, which also wakes a loop blocked on I/O.
    """
    logger.info(f"Received signal {signum}")
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread: nothing to stop gracefully
        logger.warning(f"No running event loop; ignoring signal {signum}")
        return
    loop.call_soon_threadsafe(_schedule_stop, loop)


async def main():
    """Main entry point.

    Installs SIGINT/SIGTERM handlers that request a graceful stop and runs
    the service until it returns. ``ai_service.stop()`` is always awaited on
    the way out. If the service fails, the process exits with status 1.
    """

    # Setup signal handlers. Callbacks registered with loop.add_signal_handler
    # run inside the event loop and may schedule tasks, unlike handlers
    # installed with signal.signal().
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, signal_handler, int(sig), None)
        except NotImplementedError:
            # Event loops without add_signal_handler support (e.g. Windows)
            signal.signal(sig, signal_handler)

    try:
        await ai_service.start()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
    except Exception as e:
        logger.exception(f"Service error: {e}")
        sys.exit(1)
    finally:
        await ai_service.stop()


if __name__ == "__main__":
    # The directory holding ai.log has to be present before anything logs
    log_directory = Path("/opt/Skull/logs")
    log_directory.mkdir(exist_ok=True, parents=True)

    # Hand control to the asyncio event loop until the service exits
    asyncio.run(main())
