"""Unit tests for AI service components"""

import asyncio
import pytest
import json
import time
from unittest.mock import Mock, AsyncMock, patch
from typing import Dict, Any

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from ai.config import AIConfig, ConfigManager
from ai.context import ContextManager, ConversationTurn
from ai.prompt import PromptBuilder
from ai.filters import ResponseFilter
from ai.llm import LLMClient, LLMResult, CircuitBreaker
from ai.health import HealthMonitor


class TestPromptBuilder:
    """Tests for PromptBuilder prompt assembly and user-intent extraction."""

    def test_build_basic_prompt(self):
        """A basic prompt exposes the system, rules and user sections."""
        builder = PromptBuilder()

        prompt = builder.build_prompt(
            user_input="raconte une blague",
            system_prompt="Tu es un crâne joyeux",
            model="gpt-5-mini",
            max_tokens=256,
        )

        assert "[SYSTEM] Tu es un crâne joyeux" in prompt
        assert "[RULES] Réponds en français" in prompt
        assert "[USER] raconte une blague" in prompt
        assert "politique" in prompt  # Should contain avoidance rules

    def test_prompt_with_context(self):
        """A recorded conversation turn shows up in a [CONTEXT] section."""
        builder = PromptBuilder()

        # This test drives the module-level context_manager object directly
        # rather than a private ContextManager instance.
        from ai.context import context_manager

        try:
            context_manager.add_turn("Bonjour", "Salut ! Comment ça va ?")

            prompt = builder.build_prompt(
                user_input="raconte une blague",
                system_prompt="Tu es un crâne joyeux",
                model="gpt-5-mini",
                max_tokens=256,
            )

            assert "[CONTEXT]" in prompt
            assert "Bonjour" in prompt
        finally:
            # The context manager is shared state: always reset it, even when
            # an assertion above fails, so leftover turns cannot leak into
            # other tests.
            context_manager.purge()

    def test_prompt_truncation(self):
        """An over-long user input still yields a bounded, complete prompt."""
        builder = PromptBuilder()

        # Create very long user input
        long_input = "raconte une histoire " * 100

        prompt = builder.build_prompt(
            user_input=long_input,
            system_prompt="Tu es un crâne joyeux",
            model="gpt-5-mini",
            max_tokens=100,  # Very small token limit
        )

        # Should still contain essential sections
        assert "[SYSTEM]" in prompt
        assert "[RULES]" in prompt
        assert "[USER]" in prompt

        # Should be reasonably sized
        word_count = len(prompt.split())
        assert word_count < 200  # Rough estimate

    def test_sensitive_intent_detection(self):
        """extract_user_intent flags sensitive inputs and names their topic."""
        builder = PromptBuilder()

        # (user input, expected "sensitive" flag, expected topic substring)
        test_cases = [
            ("Combien coûte cette maison ?", True, "argent"),
            ("Qui voter aux élections ?", True, "politique"),
            ("J'ai mal au dos", True, "santé"),
            ("Raconte une blague", False, None),
        ]

        for user_input, should_be_sensitive, expected_topic in test_cases:
            intent = builder.extract_user_intent(user_input)

            # The messages name the failing row of the table.
            assert (
                intent["sensitive"] == should_be_sensitive
            ), f"Unexpected sensitivity flag for {user_input!r}"
            if should_be_sensitive:
                assert expected_topic in intent.get(
                    "sensitive_topic", ""
                ), f"Unexpected sensitive topic for {user_input!r}"


class TestResponseFilter:
    """Tests for ResponseFilter post-processing and fallback messages."""

    def test_sentence_limiting(self):
        """A five-sentence response is cut down to at most three sentences."""
        response_filter = ResponseFilter()

        sentences = [
            "Voici une longue réponse.",
            "Elle contient plusieurs phrases.",
            "Peut-être trop de phrases.",
            "Certainement trop de phrases.",
            "Vraiment beaucoup de phrases.",
        ]
        processed = response_filter.process_response(
            " ".join(sentences), {"sensitive": False}
        )

        # Count the non-empty chunks between full stops.
        remaining = sum(1 for chunk in processed.split(".") if chunk.strip())
        assert remaining <= 3

        # The shortened text must still finish on terminal punctuation.
        assert processed.endswith((".", "!", "?"))

    def test_sensitive_topic_redirect(self):
        """Sensitive intents receive a redirect instead of the raw response."""
        response_filter = ResponseFilter()

        intents = [
            {"sensitive": True, "sensitive_topic": topic}
            for topic in ("argent", "politique", "santé")
        ]

        for intent in intents:
            redirected = response_filter.process_response("Réponse normale", intent)
            lowered = redirected.lower()

            # The redirect is recognised by one of two marker words.
            assert any(marker in lowered for marker in ("sujet", "changeons"))
            # And it must be more than a token reply.
            assert len(redirected) > 10

    def test_text_cleaning(self):
        """Extra whitespace and HTML tags are removed from the response."""
        response_filter = ResponseFilter()

        messy = (
            "  Voici une réponse  avec    espaces   ! Et des balises <div>test</div>  "
        )

        cleaned = response_filter.process_response(messy, {"sensitive": False})

        assert "  " not in cleaned  # runs of spaces collapsed
        assert "<div>" not in cleaned  # markup stripped
        assert cleaned == cleaned.strip()  # no surrounding whitespace

    def test_fallback_responses(self):
        """Every known error type maps to a usable fallback sentence."""
        response_filter = ResponseFilter()

        for error_type in ("timeout", "rate_limit", "model_error", "general"):
            fallback = response_filter.get_fallback_response(error_type)

            assert len(fallback) > 5  # long enough to be meaningful
            assert fallback.endswith((".", "!", "?"))  # ends like a sentence


class TestContextManager:
    """Tests for ContextManager turn storage, model adaptation and purging."""

    def test_add_and_retrieve_context(self):
        """Stored turns are visible in the context rendered for a model."""
        manager = ContextManager(max_turns=2)

        exchanges = [
            ("Bonjour", "Salut !"),
            ("Comment ça va ?", "Ça va bien, merci !"),
        ]
        for question, answer in exchanges:
            manager.add_turn(question, answer)

        rendered = manager.get_context_for_model("gpt-5-mini")

        for fragment in ("Bonjour", "Comment ça va", "Salut"):
            assert fragment in rendered

    def test_context_model_adaptation(self):
        """The mini model gets at least as much context text as nano."""
        manager = ContextManager(max_turns=3)

        for index in range(3):
            manager.add_turn(f"Question {index}", f"Réponse {index}")

        for_mini = manager.get_context_for_model("gpt-5-mini")
        for_nano = manager.get_context_for_model("gpt-5-nano")

        # Compared by rendered length only.
        assert len(for_nano) <= len(for_mini)

    def test_context_purge(self):
        """A turn stamped 700 seconds in the past triggers and allows a purge."""
        manager = ContextManager()

        # Inject a stale turn directly into the turn list.
        stale_timestamp = time.time() - 700
        manager.turns.append(
            ConversationTurn(
                timestamp=stale_timestamp,
                user_input="Test",
                ai_response="Response",
            )
        )

        # The manager should now report that purging is due.
        assert manager.should_purge()

        # Purging must leave no turns behind.
        manager.purge()
        assert len(manager.turns) == 0


class TestCircuitBreaker:
    """Tests for the CircuitBreaker state machine."""

    def test_circuit_breaker_states(self):
        """Walk the breaker through closed -> open -> recovered -> closed."""
        failures_to_trip = 3
        breaker = CircuitBreaker(
            failure_threshold=failures_to_trip, recovery_timeout=1
        )

        # A fresh breaker lets calls through and needs no fallback.
        assert breaker.can_execute()
        assert not breaker.should_use_fallback()

        # Report exactly as many failures as the threshold.
        for _ in range(failures_to_trip):
            breaker.call_failed()

        # Now calls are blocked and the fallback is requested.
        assert not breaker.can_execute()
        assert breaker.should_use_fallback()

        # Sleep just past the recovery timeout; calls are allowed again.
        time.sleep(1.1)
        assert breaker.can_execute()  # Should be half-open

        # A reported success returns the breaker to normal operation.
        breaker.call_succeeded()
        assert breaker.can_execute()
        assert not breaker.should_use_fallback()


class TestLLMClient:
    """Tests for LLMClient generation and fallback model selection."""

    @pytest.mark.asyncio
    async def test_successful_generation(self):
        """generate() returns a populated LLMResult for a simple request."""
        llm = LLMClient()

        request = {
            "prompt": "Test prompt",
            "model": "gpt-5-nano",
            "max_tokens": 50,
            "temperature": 0.7,
            "top_p": 0.9,
            "timeout_s": 5,
        }
        outcome = await llm.generate(**request)

        assert isinstance(outcome, LLMResult)
        assert len(outcome.text) > 0
        assert outcome.latency_ms > 0
        assert outcome.model in ("gpt-5-nano", "gpt-5-mini")

    @pytest.mark.asyncio
    async def test_fallback_model_switching(self):
        """With five recorded failures, a mini request is answered by nano."""
        llm = LLMClient()

        # Record five failures on the client's breaker to force it open.
        for _ in range(5):
            llm.circuit_breaker.call_failed()

        canned_payload = {
            "text": "Fallback response",
            "input_tokens": 10,
            "output_tokens": 5,
        }

        # Replace the client's low-level call with an AsyncMock.
        with patch.object(
            llm, "_mock_llm_call", new_callable=AsyncMock
        ) as stubbed_call:
            stubbed_call.return_value = canned_payload

            outcome = await llm.generate(
                prompt="Test prompt",
                model="gpt-5-mini",  # mini is requested on purpose
                max_tokens=50,
                temperature=0.7,
                top_p=0.9,
                timeout_s=5,
            )

            # The result must report the nano model instead.
            assert outcome.model == "gpt-5-nano"


class TestConfigManager:
    """Tests for AIConfig loading/updating and ConfigManager capabilities."""

    def test_config_from_env(self):
        """AIConfig.from_env picks up AI_* environment variables."""
        # patch.dict restores os.environ when the block exits.
        with patch.dict(
            os.environ,
            {"AI_MODEL": "gpt-5-nano", "AI_MAX_TOKENS": "128", "AI_TEMPERATURE": "0.8"},
        ):
            config = AIConfig.from_env()

            # Numeric settings are expected back as numbers, not strings.
            assert config.model == "gpt-5-nano"
            assert config.max_tokens == 128
            assert config.temperature == 0.8

    def test_mqtt_config_update(self):
        """update_from_mqtt applies model, max_tokens and temperature."""
        config = AIConfig()

        # Update via MQTT
        mqtt_data = {"model": "gpt-5-nano", "max_tokens": 128, "temperature": 0.9}

        config.update_from_mqtt(mqtt_data)

        assert config.model == "gpt-5-nano"
        assert config.max_tokens == 128
        assert config.temperature == 0.9

    def test_capabilities_generation(self):
        """get_capabilities lists both models and reports French."""
        manager = ConfigManager()
        capabilities = manager.get_capabilities()

        assert "models" in capabilities
        assert "gpt-5-mini" in capabilities["models"]
        assert "gpt-5-nano" in capabilities["models"]
        assert capabilities["lang"] == "fr"


class TestHealthMonitor:
    """Tests for HealthMonitor counters and overall health verdict."""

    def test_health_metrics_recording(self):
        """One successful request is reflected in every reported counter."""
        health = HealthMonitor()

        # Time a single request with a tiny artificial delay.
        started = health.record_request_start()
        time.sleep(0.01)
        health.record_request_success(started)

        report = health.get_health_status()

        expected_counters = {
            "requests_total": 1,
            "requests_success": 1,
            "requests_failed": 0,
            "success_rate": 1.0,
        }
        for key, expected in expected_counters.items():
            assert report[key] == expected
        assert report["latency_p50"] > 0

    def test_health_assessment(self):
        """Ten failures out of ten requests yield an unhealthy status."""
        health = HealthMonitor()

        # Every request is recorded as a failure.
        for _ in range(10):
            started = health.record_request_start()
            health.record_request_failure(started)

        report = health.get_health_status()

        # No successes: the monitor must not report "ok".
        assert not report["ok"]
        assert report["success_rate"] == 0.0


class TestIntegration:
    """Cross-component tests: debouncing, token budget and soft policy."""

    @pytest.mark.asyncio
    async def test_debounce_functionality(self):
        """Three rapid calls on one key fire the callback exactly once."""
        from ai.mqtt import DebounceHandler

        handler = DebounceHandler(delay_ms=100)
        invocations = []

        async def record_call():
            invocations.append(None)

        # Fire three calls back to back on the same key.
        for _ in range(3):
            await handler.debounce("test_key", record_call)

        # Sleep longer than the 100 ms debounce delay.
        await asyncio.sleep(0.15)

        # Only a single invocation is expected.
        assert len(invocations) == 1

    def test_prompt_assembly_with_budget(self):
        """Prompts stay within a rough token budget at several limits."""
        prompt_builder = PromptBuilder()

        for token_limit in (50, 100, 200):
            assembled = prompt_builder.build_prompt(
                user_input="Raconte-moi une histoire amusante",
                system_prompt="Tu es un crâne joyeux et drôle",
                model="gpt-5-mini",
                max_tokens=token_limit,
            )

            # Rough estimate: 1.3 tokens per whitespace-separated word.
            approx_tokens = len(assembled.split()) * 1.3
            # 70% of the limit goes to the prompt, 30% is kept for output.
            prompt_budget = token_limit * 0.7

            # A 20% overshoot is tolerated.
            assert approx_tokens <= prompt_budget * 1.2

    def test_soft_policy_enforcement(self):
        """Sensitive inputs are detected and their responses redirected."""
        prompt_builder = PromptBuilder()
        response_filter = ResponseFilter()

        for user_input in (
            "Combien d'argent tu gagnes ?",
            "Que penses-tu de Macron ?",
            "J'ai des douleurs thoraciques",
        ):
            # Step 1: the builder must flag the input as sensitive.
            intent = prompt_builder.extract_user_intent(user_input)
            assert intent["sensitive"], f"Should detect {user_input} as sensitive"

            # Step 2: filter a stand-in for the LLM output with that intent.
            outcome = response_filter.process_response("Réponse normale", intent)
            lowered = outcome.lower()

            # Step 3: the filtered text must carry a redirect marker word.
            assert any(marker in lowered for marker in ("sujet", "changeons"))


# Test fixtures and utilities
@pytest.fixture
def mock_mqtt_handler():
    """Provide a Mock MQTT handler with explicit publish method mocks."""
    mqtt_mock = Mock()
    # Give each publish method its own Mock object.
    for method_name in ("publish_json", "publish_capabilities"):
        setattr(mqtt_mock, method_name, Mock())
    return mqtt_mock


@pytest.fixture
def sample_config():
    """Provide a plain dict of sample AI settings for tests."""
    return dict(
        model="gpt-5-mini",
        max_tokens=256,
        temperature=0.7,
        top_p=0.9,
        lang="fr",
    )


if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # test failures to the shell/CI instead of always exiting with status 0.
    sys.exit(pytest.main([__file__, "-v"]))
