"""LLM client with retries, timeouts and circuit breaker"""

import asyncio
import time
import random
import logging
from typing import Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum

# Mock imports - replace with actual LLM client in production
# from openai import AsyncOpenAI


logger = logging.getLogger(__name__)


@dataclass
class LLMResult:
    """Outcome of one LLM generation call.

    Only ``cached`` has a default; every other field must be supplied.
    """

    # Generated completion text.
    text: str
    # Number of prompt tokens.
    input_tokens: int
    # Number of generated tokens.
    output_tokens: int
    # Measured call latency, in milliseconds.
    latency_ms: int
    # Name of the model that produced ``text``.
    model: str
    # True when the result was served from a cache.
    cached: bool = False


class CircuitState(Enum):
    """The three states a circuit breaker can be in (values are lowercase names)."""

    # Calls flow normally.
    CLOSED = "closed"
    # Calls are blocked after too many failures.
    OPEN = "open"
    # Trial state between OPEN and CLOSED.
    HALF_OPEN = "half_open"


class CircuitBreaker:
    """Consecutive-failure circuit breaker for LLM calls.

    States:
      * CLOSED    -- calls are allowed; consecutive failures are counted.
      * OPEN      -- entered once ``failure_threshold`` consecutive failures
                     have been recorded; ``can_execute`` returns False until
                     more than ``recovery_timeout`` seconds have passed since
                     the most recent failure.
      * HALF_OPEN -- trial state after the recovery window; the next recorded
                     success closes the breaker and the next recorded failure
                     re-opens it immediately.

    Elapsed time is measured with ``time.monotonic()`` so that wall-clock
    adjustments cannot shorten or stretch the recovery window. No locking is
    performed.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 30):
        """
        Args:
            failure_threshold: Consecutive failures that trip the breaker OPEN.
            recovery_timeout: Seconds to stay OPEN after the most recent
                failure before a HALF_OPEN trial call is allowed.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failure_count = 0
        # time.monotonic() timestamp of the most recent failure (0.0 = none yet).
        self.last_failure_time = 0.0
        self.state = CircuitState.CLOSED

    def call_succeeded(self) -> None:
        """Record a successful call: reset the failure count and close the breaker."""
        if self.state != CircuitState.CLOSED:
            logger.info("Circuit breaker CLOSED after successful call")
        self.failure_count = 0
        self.state = CircuitState.CLOSED

    def call_failed(self) -> None:
        """Record a failed call and trip the breaker when appropriate.

        A failure while HALF_OPEN re-opens the breaker immediately; otherwise
        the breaker opens once ``failure_threshold`` consecutive failures have
        been recorded. Every failure refreshes ``last_failure_time``, which
        restarts the recovery window.
        """
        self.failure_count += 1
        self.last_failure_time = time.monotonic()

        should_open = (
            self.state == CircuitState.HALF_OPEN
            or self.failure_count >= self.failure_threshold
        )
        # Only log on the transition, not on every failure while already OPEN.
        if should_open and self.state != CircuitState.OPEN:
            self.state = CircuitState.OPEN
            logger.warning(
                "Circuit breaker OPEN after %d failures", self.failure_count
            )

    def can_execute(self) -> bool:
        """Return True if a call may be attempted right now.

        Side effect: when OPEN and more than ``recovery_timeout`` seconds have
        elapsed since the last failure, the breaker moves to HALF_OPEN and the
        call is allowed.
        """
        if self.state == CircuitState.OPEN:
            elapsed = time.monotonic() - self.last_failure_time
            if elapsed <= self.recovery_timeout:
                return False
            self.state = CircuitState.HALF_OPEN
            logger.info("Circuit breaker HALF_OPEN, trying recovery")

        # CLOSED and HALF_OPEN both let calls through.
        return True

    def should_use_fallback(self) -> bool:
        """Return True while the breaker is OPEN (the guarded call should be avoided)."""
        return self.state == CircuitState.OPEN


class LLMClient:
    """LLM client with retries, per-attempt timeouts and a circuit breaker.

    While the circuit breaker is OPEN, attempts are routed to
    ``FALLBACK_MODEL`` instead of the requested model. Only attempts against
    the requested (primary) model are reported to the breaker, so fallback
    traffic can neither close nor hold open the primary model's circuit.
    """

    # Model used while the primary model's circuit is OPEN.
    FALLBACK_MODEL = "gpt-5-nano"
    # Upper bound (seconds) on the per-attempt timeout for the fallback model.
    FALLBACK_TIMEOUT_S = 6.0

    def __init__(self, circuit_breaker: Optional[CircuitBreaker] = None):
        """
        Args:
            circuit_breaker: Breaker guarding the primary model. A default
                ``CircuitBreaker()`` is created when omitted.
        """
        self.circuit_breaker = (
            circuit_breaker if circuit_breaker is not None else CircuitBreaker()
        )
        # TODO: create the real async client here, e.g. AsyncOpenAI().

    async def generate(
        self,
        prompt: str,
        *,
        model: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        timeout_s: float,
        max_retries: int = 3,
        base_delay: float = 0.1,
    ) -> LLMResult:
        """Generate a completion with retries, timeouts and circuit breaking.

        The circuit breaker is consulted before every attempt. While it is
        OPEN the attempt goes to ``FALLBACK_MODEL`` with a timeout of at most
        ``FALLBACK_TIMEOUT_S`` seconds. Failed attempts are retried after an
        exponential backoff with jitter, plus an extra second when the error
        message mentions ``rate_limit``.

        Args:
            prompt: Prompt text.
            model: Primary model name.
            max_tokens: Maximum tokens to generate (forwarded to the backend).
            temperature: Sampling temperature (forwarded to the backend).
            top_p: Nucleus-sampling parameter (forwarded to the backend).
            timeout_s: Per-attempt timeout in seconds.
            max_retries: Total number of attempts; must be >= 1.
            base_delay: Base backoff delay in seconds, doubled per attempt.

        Returns:
            An ``LLMResult`` whose ``model`` is the model that actually
            produced the text (the fallback model if it was used).

        Raises:
            ValueError: if ``max_retries`` < 1.
            Exception: if every attempt fails; the last underlying error is
                chained as ``__cause__``.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")

        last_error: Optional[BaseException] = None

        for attempt in range(max_retries):
            # Re-check every attempt: the breaker may open mid-loop, or become
            # eligible for a HALF_OPEN trial between retries.
            use_fallback = not self.circuit_breaker.can_execute()
            if use_fallback:
                attempt_model = self.FALLBACK_MODEL
                # Never give the fallback *more* time than the caller allowed.
                attempt_timeout = min(timeout_s, self.FALLBACK_TIMEOUT_S)
                logger.info("Using fallback model: %s", attempt_model)
            else:
                attempt_model = model
                attempt_timeout = timeout_s

            rate_limited = False
            start = time.perf_counter()
            try:
                # Enforce the timeout here so it holds for any backend, not
                # only the mock. Replace the mock with the real client call.
                result = await asyncio.wait_for(
                    self._mock_llm_call(
                        prompt,
                        attempt_model,
                        max_tokens,
                        temperature,
                        top_p,
                        attempt_timeout,
                    ),
                    timeout=attempt_timeout,
                )
            except asyncio.TimeoutError as exc:
                last_error = exc
                logger.warning(
                    "LLM timeout on attempt %d (model=%s)", attempt + 1, attempt_model
                )
                if not use_fallback:
                    self.circuit_breaker.call_failed()
            except Exception as exc:
                last_error = exc
                logger.warning("LLM error on attempt %d: %s", attempt + 1, exc)
                if not use_fallback:
                    self.circuit_breaker.call_failed()
                # Rate-limit errors are still retried, just after a longer pause.
                rate_limited = "rate_limit" in str(exc).lower()
            else:
                latency_ms = int((time.perf_counter() - start) * 1000)
                if not use_fallback:
                    self.circuit_breaker.call_succeeded()
                return LLMResult(
                    text=result["text"],
                    input_tokens=int(result.get("input_tokens", 0)),
                    output_tokens=int(result.get("output_tokens", 0)),
                    latency_ms=latency_ms,
                    model=attempt_model,
                    cached=result.get("cached", False),
                )

            # Back off before the next attempt (no pointless sleep after the last).
            if attempt < max_retries - 1:
                delay = base_delay * (2**attempt) + random.uniform(0.1, 0.4)
                if rate_limited:
                    delay += 1.0
                await asyncio.sleep(delay)

        raise Exception("LLM generation failed after all retries") from last_error

    async def _mock_llm_call(
        self,
        prompt: str,
        model: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        timeout_s: float,
    ) -> Dict[str, Any]:
        """Simulate a backend call (placeholder for the real client).

        Sleeps for a random latency (shorter for ``FALLBACK_MODEL``), raises
        ``asyncio.TimeoutError`` if that exceeds ``timeout_s``, and returns a
        canned French reply chosen by keywords in the prompt. ``max_tokens``,
        ``temperature`` and ``top_p`` are accepted for interface parity but
        ignored.

        Returns:
            Dict with ``text``, integer ``input_tokens``/``output_tokens``
            estimates (about 1.3 tokens per whitespace-separated word) and
            ``cached`` (always False).
        """
        if model == self.FALLBACK_MODEL:
            delay = random.uniform(0.3, 0.8)
        else:
            delay = random.uniform(0.8, 1.5)

        try:
            await asyncio.wait_for(asyncio.sleep(delay), timeout=timeout_s)
        except asyncio.TimeoutError as exc:
            raise asyncio.TimeoutError("Mock LLM timeout") from exc

        lowered = prompt.lower()
        if "blague" in lowered:
            text = "Pourquoi les plongeurs plongent-ils toujours en arrière ? Parce que sinon ils tombent dans le bateau !"
        elif "présente" in lowered:
            text = "Salut ! Je suis Skull, ton crâne loquace et joyeux. Que puis-je faire pour toi aujourd'hui ?"
        elif "devinette" in lowered:
            text = "Qu'est-ce qui a des dents mais ne mord jamais ? Un peigne ! Facile celle-là, non ?"
        else:
            text = "Je suis là pour t'aider ! Pose-moi une autre question ou raconte-moi quelque chose d'amusant."

        # Cap the reply at three ". "-separated sentences.
        sentences = text.split(". ")
        if len(sentences) > 3:
            text = ". ".join(sentences[:3]) + "."

        return {
            "text": text,
            # Rough estimates; rounded because LLMResult stores ints.
            "input_tokens": round(len(prompt.split()) * 1.3),
            "output_tokens": round(len(text.split()) * 1.3),
            "cached": False,
        }


# Shared module-level client: every module importing ``llm_client`` gets this
# same instance, and therefore the same circuit-breaker state.
llm_client: LLMClient = LLMClient()
