"""
Moniteur de santé pour les services Skull Pi.
Vérifie la santé des services et les redémarre si nécessaire.
"""

import asyncio
import logging
import time
from typing import Dict, Optional

from orchestrator.mqtt import MqttClient
from orchestrator.services import ServiceManager


class HealthMonitor:
    """Watch managed services and restart the ones that look unhealthy.

    A service is considered failed when the service manager reports it as
    not active, or when no heartbeat has been recorded for more than
    ``service_timeout`` seconds.  Failed services are restarted up to
    ``restart_attempts`` times; past that limit a critical alert is
    published on every check instead of another restart.

    The failure counter of a service is only cleared once a *real* MQTT
    heartbeat has been received after a restart.  The placeholder heartbeat
    written right after a restart merely grants the service a fresh timeout
    window; it does not count as proof of health.
    """

    # First MQTT topic segment -> name of the managed service.  Used both
    # to build the heartbeat subscriptions and to route incoming messages.
    _TOPIC_PREFIX_TO_SERVICE: Dict[str, str] = {
        "motion": "skull-motion",
        "vision": "skull-vision",
        "voice": "skull-voice",
        "audioin": "skull-audioin",
        "asr": "skull-asr",
        "ai": "skull-ai",
    }

    def __init__(
        self,
        service_manager: "ServiceManager",
        mqtt_client: "MqttClient",
        logger: logging.Logger,
    ):
        """Store collaborators and initialise the monitoring state.

        Args:
            service_manager: Object used to list, query and restart services.
            mqtt_client: Client used to subscribe to heartbeats and publish
                notifications.
            logger: Logger receiving all monitor messages.
        """
        self.service_manager = service_manager
        self.mqtt_client = mqtt_client
        self.logger = logger

        # Configuration
        self.check_interval = 30  # seconds between two check passes
        self.service_timeout = 60  # max heartbeat age in seconds
        self.restart_attempts = 3  # restarts allowed before critical alert

        # Per-service state.  Timestamps come from time.time() (wall clock),
        # so a system clock jump shifts the computed heartbeat ages.
        self.service_heartbeats: Dict[str, float] = {}
        self.service_failures: Dict[str, int] = {}
        self.service_last_status: Dict[str, Dict] = {}

        # Services restarted by the monitor that have not yet sent a real
        # heartbeat; their failure counter must not be cleared meanwhile.
        self._awaiting_heartbeat: set = set()

        # Monitor state
        self.running = False
        self.monitor_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Subscribe to heartbeats and launch the background check loop.

        Calling ``start`` while the monitor is already running is a no-op.

        Raises:
            Exception: Whatever the MQTT client or service manager raises
                during set-up; the monitor is left stopped in that case so
                that ``start`` can be retried.
        """
        if self.running:
            return

        # Set early so that a concurrent start() call returns immediately.
        self.running = True
        try:
            await self._subscribe_heartbeats()

            # Give every managed service a full timeout window from now.
            now = time.time()
            for service in self.service_manager.get_managed_services():
                self.service_heartbeats[service] = now
                self.service_failures[service] = 0
                self._awaiting_heartbeat.discard(service)

            self.monitor_task = asyncio.create_task(self._monitor_loop())
        except BaseException:
            # Without this rollback a failed set-up would leave running=True
            # with no task, and every later start() would silently return.
            self.running = False
            raise

        self.logger.info("Moniteur de santé démarré")

    async def stop(self) -> None:
        """Stop the background check loop and wait for it to finish."""
        self.running = False

        if self.monitor_task:
            self.monitor_task.cancel()
            try:
                await self.monitor_task
            except asyncio.CancelledError:
                pass
            self.monitor_task = None

        self.logger.info("Moniteur de santé arrêté")

    async def _subscribe_heartbeats(self) -> None:
        """Subscribe ``_on_heartbeat`` to ``<prefix>/heartbeat`` for each service."""
        for prefix in self._TOPIC_PREFIX_TO_SERVICE:
            await self.mqtt_client.subscribe(f"{prefix}/heartbeat", self._on_heartbeat)

    async def _on_heartbeat(self, topic: str, payload: str) -> None:
        """Record the reception time of a heartbeat message.

        Args:
            topic: MQTT topic the message arrived on; its first segment
                selects the service.
            payload: Message body (not inspected).
        """
        try:
            # Same matching as startswith("<prefix>/"): a topic without any
            # "/" never matches.
            prefix, separator, _ = topic.partition("/")
            service_name = (
                self._TOPIC_PREFIX_TO_SERVICE.get(prefix) if separator else None
            )

            if service_name:
                self.service_heartbeats[service_name] = time.time()
                # A real heartbeat proves the restarted service is alive.
                self._awaiting_heartbeat.discard(service_name)
                self.logger.debug(f"Heartbeat reçu de {service_name}")

        except Exception as e:
            self.logger.error(f"Erreur traitement heartbeat {topic}: {e}")

    async def _monitor_loop(self) -> None:
        """Run a check pass every ``check_interval`` seconds until stopped.

        Unexpected errors are logged and followed by a short 5-second pause;
        cancellation ends the loop.
        """
        while self.running:
            try:
                await self._check_all_services()
                await asyncio.sleep(self.check_interval)

            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Erreur dans la boucle de monitoring: {e}")
                await asyncio.sleep(5)

    async def _check_all_services(self) -> None:
        """Check every managed service against a single time reference."""
        current_time = time.time()

        for service_name in self.service_manager.get_managed_services():
            await self._check_service_health(service_name, current_time)

    async def _check_service_health(
        self, service_name: str, current_time: float
    ) -> None:
        """Check one service and trigger failure handling when needed.

        Args:
            service_name: Name of the managed service to check.
            current_time: Timestamp used to compute the heartbeat age.

        Errors raised while checking or restarting are logged, not propagated.
        """
        try:
            status = await self.service_manager.get_service_status(service_name)
            self.service_last_status[service_name] = status

            last_heartbeat = self.service_heartbeats.get(service_name, 0)
            heartbeat_age = current_time - last_heartbeat

            failure_reason = ""
            if not status.get("active", False):
                failure_reason = "Service inactif"
            elif heartbeat_age > self.service_timeout:
                failure_reason = f"Pas de heartbeat depuis {heartbeat_age:.1f}s"

            if failure_reason:
                await self._handle_service_failure(service_name, failure_reason)
            elif (
                service_name in self.service_failures
                and service_name not in self._awaiting_heartbeat
            ):
                # Only clear the counter once the service has really proven
                # it is alive; otherwise a service that restarts but never
                # sends a heartbeat would be restarted forever without ever
                # reaching the critical alert.
                self.service_failures[service_name] = 0

        except Exception as e:
            self.logger.error(f"Erreur vérification santé {service_name}: {e}")

    async def _handle_service_failure(self, service_name: str, reason: str) -> None:
        """Count a failure, then restart the service or raise a critical alert.

        Args:
            service_name: Name of the failed service.
            reason: Human-readable failure cause, forwarded in notifications.
        """
        failures = self.service_failures.get(service_name, 0) + 1
        self.service_failures[service_name] = failures

        self.logger.warning(
            f"Service {service_name} en panne: {reason} (tentative {failures}/{self.restart_attempts})"
        )

        if failures > self.restart_attempts:
            self.logger.error(
                f"Service {service_name} défaillant après {self.restart_attempts} tentatives"
            )

            # Restart budget exhausted: publish a critical alert instead.
            await self.mqtt_client.publish_json(
                "orchestrator/service_critical",
                {
                    "service": service_name,
                    "reason": reason,
                    "failures": failures,
                    "max_attempts": self.restart_attempts,
                },
            )
            return

        self.logger.info(f"Redémarrage de {service_name}...")

        success = await self.service_manager.restart_service(service_name)

        if not success:
            self.logger.error(f"Échec redémarrage de {service_name}")
            return

        self.logger.info(f"Service {service_name} redémarré avec succès")
        # Placeholder heartbeat: grants a fresh timeout window.  The service
        # stays flagged until a real heartbeat arrives.
        self.service_heartbeats[service_name] = time.time()
        self._awaiting_heartbeat.add(service_name)

        await self.mqtt_client.publish_json(
            "orchestrator/service_restarted",
            {"service": service_name, "reason": reason, "attempt": failures},
        )

    async def force_service_check(self, service_name: str) -> Dict:
        """Run an immediate health check for one service.

        Args:
            service_name: Name of the service to check.

        Returns:
            The health status dictionary of the service after the check, or
            ``{"error": "Service inconnu"}`` when the service is not managed.
        """
        if service_name not in self.service_manager.get_managed_services():
            return {"error": "Service inconnu"}

        current_time = time.time()
        await self._check_service_health(service_name, current_time)

        return self.get_service_health_status(service_name)

    def get_service_health_status(self, service_name: str) -> Dict:
        """Build the health summary of one service from the recorded state.

        Args:
            service_name: Name of the service to describe.

        Returns:
            A dictionary with the last heartbeat, its age, the failure
            count, the ``healthy`` flag and the last known status, or
            ``{"error": "Service inconnu"}`` when the service is not managed.
        """
        if service_name not in self.service_manager.get_managed_services():
            return {"error": "Service inconnu"}

        current_time = time.time()
        last_heartbeat = self.service_heartbeats.get(service_name, 0)
        heartbeat_age = current_time - last_heartbeat
        failures = self.service_failures.get(service_name, 0)
        last_status = self.service_last_status.get(service_name, {})

        return {
            "service": service_name,
            "last_heartbeat": last_heartbeat,
            "heartbeat_age_seconds": heartbeat_age,
            "failures": failures,
            "max_failures": self.restart_attempts,
            "healthy": heartbeat_age < self.service_timeout and failures == 0,
            "last_status": last_status,
        }

    def get_all_health_status(self) -> Dict:
        """Return the monitor settings and the health summary of every service."""
        status = {
            "monitor_running": self.running,
            "check_interval": self.check_interval,
            "service_timeout": self.service_timeout,
            "services": {},
        }

        for service_name in self.service_manager.get_managed_services():
            status["services"][service_name] = self.get_service_health_status(
                service_name
            )

        return status

    def is_service_healthy(self, service_name: str) -> bool:
        """Return the ``healthy`` flag of a service (``False`` if unknown)."""
        health_status = self.get_service_health_status(service_name)
        return health_status.get("healthy", False)

    async def reset_service_failures(self, service_name: str) -> bool:
        """Clear the failure counter of a service and refresh its heartbeat.

        Args:
            service_name: Name of the service to reset.

        Returns:
            ``True`` when the service is managed and was reset, else ``False``.
        """
        if service_name not in self.service_manager.get_managed_services():
            return False

        self.service_failures[service_name] = 0
        self.service_heartbeats[service_name] = time.time()
        # An explicit reset also lifts the "awaiting real heartbeat" flag.
        self._awaiting_heartbeat.discard(service_name)

        self.logger.info(f"Compteur d'échecs réinitialisé pour {service_name}")
        return True
