"""Tests d'intégration pour le service audioin."""

import json
import time
import threading
import wave
import numpy as np
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
import tempfile
import pytest

from audioin.service import AudioInService
from audioin.config import ServiceConfig, VADConfig
from audioin.vad import VoiceActivityDetector
from audioin.rms import RMSCalculator


class TestAudioProcessingPipeline:
    """Tests of the complete audio processing pipeline (RMS -> VAD)."""

    def setup_method(self):
        """Build a fresh default config, RMS calculator and VAD before each test."""
        self.config = ServiceConfig()
        self.rms_calc = RMSCalculator()
        self.vad = VoiceActivityDetector(self.config.vad)

    def test_rms_to_vad_pipeline(self):
        """Feed silence -> loud speech -> silence through RMS then VAD."""
        # One silent frame, four loud constant frames, one silent frame.
        test_frames = (
            [np.zeros(320, dtype=np.int16)]
            + [np.full(320, 10000, dtype=np.int16) for _ in range(4)]
            + [np.zeros(320, dtype=np.int16)]
        )

        activations = []
        for frame in test_frames:
            rms = self.rms_calc.calculate(frame)
            activations.append(self.vad.process(rms).active)

        # The VAD must have become active at least once during the speech run.
        assert any(activations), "VAD devrait détecter de l'activité"

    def test_config_update_propagation(self):
        """update_config() must replace the VAD's active configuration."""
        initial_threshold = self.vad.config.threshold
        # Pick a value guaranteed to differ from the current one so the
        # "changed" assertion cannot fail just because the default is 0.5.
        new_threshold = 0.4 if initial_threshold == 0.5 else 0.5

        self.vad.update_config(VADConfig(threshold=new_threshold))

        assert self.vad.config.threshold == new_threshold
        assert self.vad.config.threshold != initial_threshold


class MockAudioCapture:
    """Stand-in for AudioCapture that replays a fixed list of frames."""

    def __init__(self, frames_data):
        # Frames handed back, in order, by get_frames().
        self.frames_data = frames_data
        # Index of the next frame to yield; persists across get_frames() calls.
        self.frame_index = 0
        self.running = False

    def start(self):
        """Mark the fake capture as running."""
        self.running = True

    def stop(self):
        """Mark the fake capture as stopped; get_frames() stops yielding."""
        self.running = False

    def get_frames(self):
        """Yield the remaining frames while running, sleeping 20 ms after each."""
        while self.running:
            if self.frame_index >= len(self.frames_data):
                return
            yield self.frames_data[self.frame_index]
            self.frame_index += 1
            time.sleep(0.02)  # mimic a 20 ms capture period

    def get_queue_size(self):
        """Report an always-empty capture queue."""
        return 0


class MockMQTTClient:
    """In-memory stand-in for the MQTT client that records every publish."""

    def __init__(self):
        self.connected = False
        # (kind, payload) tuples, in publish order.
        self.published_messages = []
        self.config_callback = None

    def connect(self, timeout=10.0):
        """Pretend to connect; always succeeds (``timeout`` is ignored)."""
        self.connected = True
        return True

    def disconnect(self):
        """Mark the fake client as disconnected."""
        self.connected = False

    def _record(self, kind, payload):
        """Append one (kind, payload) entry to the publish log."""
        self.published_messages.append((kind, payload))

    def publish_rms(self, rms):
        """Record an RMS publish."""
        self._record("rms", rms)

    def publish_vad(self, vad_result):
        """Record a VAD publish."""
        self._record("vad", vad_result)

    def publish_capabilities(self):
        """Record a capabilities publish with an empty payload."""
        self._record("capabilities", {})

    def set_config_callback(self, callback):
        """Register the callable invoked by simulate_config_update()."""
        self.config_callback = callback

    def simulate_config_update(self, new_config):
        """Deliver ``new_config`` to the registered callback, if any."""
        callback = self.config_callback
        if not callback:
            return
        callback(new_config)

    def is_connected(self):
        """Return the fake connection state."""
        return self.connected


class TestServiceIntegration:
    """Integration tests for the complete service."""

    def test_service_initialization(self):
        """Constructing the service keeps the config, sets a logger, and does not start it."""
        with tempfile.TemporaryDirectory() as temp_dir:
            config = ServiceConfig()
            config.log_file = Path(temp_dir) / "test.log"

            service = AudioInService(config)

            assert service.config == config
            assert service.logger is not None
            assert not service.running

    @patch("audioin.service.AudioCapture")
    @patch("audioin.service.MQTTAudioClient")
    def test_service_start_stop(self, mock_mqtt_class, mock_audio_class):
        """start()/stop() must drive the audio capture and MQTT client lifecycles."""
        # Stacked @patch decorators apply bottom-up, so the innermost patch
        # (MQTTAudioClient) is passed as the first mock argument.
        mock_audio = Mock()
        mock_mqtt = Mock()
        mock_mqtt.connect.return_value = True
        mock_mqtt.is_connected.return_value = True

        mock_audio_class.return_value = mock_audio
        mock_mqtt_class.return_value = mock_mqtt

        service = AudioInService()
        try:
            # Start
            assert service.start()
            assert service.running

            mock_audio.start.assert_called_once()
            mock_mqtt.connect.assert_called_once()

            # Stop
            service.stop()
            assert not service.running

            mock_audio.stop.assert_called_once()
            mock_mqtt.disconnect.assert_called_once()
        finally:
            # If an assertion fails between start() and stop(), still shut the
            # service down instead of leaving it running for later tests.
            if service.running:
                service.stop()

    @patch("audioin.service.AudioCapture")
    @patch("audioin.service.MQTTAudioClient")
    def test_audio_processing_loop(self, mock_mqtt_class, mock_audio_class):
        """Expect one RMS and one VAD publish per processed frame."""
        test_frames = [
            np.zeros(320, dtype=np.int16).tobytes(),  # silence
            np.full(320, 15000, dtype=np.int16).tobytes(),  # loud tone
            np.zeros(320, dtype=np.int16).tobytes(),  # silence
        ]

        mock_audio = MockAudioCapture(test_frames)
        mock_mqtt = MockMQTTClient()

        mock_audio_class.return_value = mock_audio
        mock_mqtt_class.return_value = mock_mqtt

        service = AudioInService()
        service._initialize_components()
        service.audio_capture = mock_audio
        service.mqtt_client = mock_mqtt
        service.mqtt_client.connected = True

        # Feed frames directly rather than running the capture loop.
        for frame_bytes in test_frames:
            service._process_audio_frame(frame_bytes)

        # At least one RMS + one VAD message per frame (other kinds may appear too).
        assert len(mock_mqtt.published_messages) >= 2 * len(test_frames)

        rms_messages = [
            msg for msg_type, msg in mock_mqtt.published_messages if msg_type == "rms"
        ]
        vad_messages = [
            msg for msg_type, msg in mock_mqtt.published_messages if msg_type == "vad"
        ]

        assert len(rms_messages) == len(test_frames)
        assert len(vad_messages) == len(test_frames)

    def test_config_dynamic_update(self):
        """A VAD config pushed through _on_config_update is applied to the live VAD."""
        service = AudioInService()
        service._initialize_components()

        initial_threshold = service.vad.config.threshold
        # Choose a value guaranteed to differ from the current one so the
        # "changed" assertion cannot fail just because the default is 0.5.
        new_threshold = 0.4 if initial_threshold == 0.5 else 0.5

        # Simulate an update received over MQTT.
        new_config = VADConfig(threshold=new_threshold, attack_ms=60)
        service._on_config_update(new_config)

        assert service.vad.config.threshold == new_threshold
        assert service.vad.config.attack_ms == 60
        assert service.vad.config.threshold != initial_threshold


class TestFileBasedTesting:
    """Tests driven by synthesised WAV files."""

    def create_test_wav(self, filename, duration_ms, frequency, amplitude=0.5):
        """Write a mono, 16-bit, 16 kHz sine-wave WAV file.

        Args:
            filename: Destination accepted by ``wave.open`` (a ``str`` path or
                a writable binary file object).
            duration_ms: Length of the signal in milliseconds.
            frequency: Sine frequency in Hz; 0 produces an all-zero signal.
            amplitude: Peak amplitude as a fraction of int16 full scale.
                Values beyond [-1, 1] are clipped to full scale.
        """
        sample_rate = 16000
        samples = int(sample_rate * duration_ms / 1000)

        t = np.linspace(0, duration_ms / 1000, samples, False)
        signal = amplitude * np.sin(2 * np.pi * frequency * t)
        # Clip before scaling: casting out-of-range floats to int16 does not
        # saturate, so amplitudes > 1 would otherwise corrupt the waveform.
        signal_int16 = (np.clip(signal, -1.0, 1.0) * 32767).astype(np.int16)

        with wave.open(filename, "w") as wav_file:
            wav_file.setnchannels(1)  # mono
            wav_file.setsampwidth(2)  # 2 bytes per sample = int16
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(signal_int16.tobytes())

    def test_sine_wave_processing(self):
        """Every full 320-sample chunk of a 0.6-amplitude 1 kHz sine has RMS > 0.3."""
        # A temporary directory (instead of NamedTemporaryFile(delete=False))
        # avoids reopening a still-open file by name, which fails on Windows,
        # and is removed even when an assertion fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            wav_path = str(Path(temp_dir) / "sine.wav")
            # 1 kHz tone, 1 second long.
            self.create_test_wav(wav_path, 1000, 1000, amplitude=0.6)

            rms_calc = RMSCalculator()
            vad = VoiceActivityDetector(VADConfig(threshold=0.3))

            with wave.open(wav_path, "r") as wav_file:
                frames = wav_file.readframes(wav_file.getnframes())
            audio_data = np.frombuffer(frames, dtype=np.int16)

            # Only whole 320-sample chunks are analysed; a trailing partial
            # chunk is ignored. Guard against the loop running zero times.
            usable = len(audio_data) - len(audio_data) % 320
            assert usable > 0

            for start in range(0, usable, 320):
                chunk = audio_data[start : start + 320]
                rms = rms_calc.calculate(chunk)
                vad.process(rms)  # keep the VAD fed across the whole stream

                # A steady tone must keep the RMS above the VAD threshold.
                assert rms > 0.3

    def test_silence_detection(self):
        """A zero-amplitude signal yields an RMS of exactly 0.0."""
        with tempfile.TemporaryDirectory() as temp_dir:
            wav_path = str(Path(temp_dir) / "silence.wav")
            # Frequency 0 and amplitude 0: 500 ms of digital silence.
            self.create_test_wav(wav_path, 500, 0, amplitude=0.0)

            rms_calc = RMSCalculator()

            with wave.open(wav_path, "r") as wav_file:
                frames = wav_file.readframes(wav_file.getnframes())
            audio_data = np.frombuffer(frames, dtype=np.int16)

            # Analyse only the first chunk.
            rms = rms_calc.calculate(audio_data[:320])

            assert rms == 0.0  # perfect silence


class TestErrorHandling:
    """Error-handling tests."""

    @patch("audioin.service.AudioCapture")
    @patch("audioin.service.MQTTAudioClient")
    def test_audio_capture_failure(self, mock_mqtt_class, mock_audio_class):
        """start() returns False and leaves the service stopped if capture init fails."""
        mock_audio_class.side_effect = Exception("Device not found")
        # Also stub the MQTT client so the test never reaches a real broker,
        # whatever order the service builds its components in.
        mock_mqtt = Mock()
        mock_mqtt.connect.return_value = True
        mock_mqtt_class.return_value = mock_mqtt

        service = AudioInService()
        success = service.start()

        assert not success
        assert not service.running

    @patch("audioin.service.AudioCapture")
    @patch("audioin.service.MQTTAudioClient")
    def test_mqtt_connection_failure(self, mock_mqtt_class, mock_audio_class):
        """start() returns False and leaves the service stopped if MQTT connect fails."""
        mock_mqtt = Mock()
        mock_mqtt.connect.return_value = False
        mock_mqtt_class.return_value = mock_mqtt
        mock_audio_class.return_value = Mock()

        service = AudioInService()
        success = service.start()

        assert not success
        assert not service.running

    def test_invalid_audio_frame(self):
        """_process_audio_frame must not raise on empty or malformed frames."""
        service = AudioInService()
        service._initialize_components()

        # An empty frame, then 7 bytes: not a multiple of the 2-byte int16
        # sample size the other tests in this file feed to the service.
        for frame in (b"", b"invalid"):
            try:
                service._process_audio_frame(frame)
            except Exception as e:
                pytest.fail(f"Service ne devrait pas crash sur frame invalide: {e}")


@pytest.fixture
def integration_config():
    """Provide a ServiceConfig with a sensitive, fast-reacting VAD for integration tests."""
    config = ServiceConfig()
    vad_settings = config.vad
    vad_settings.threshold = 0.2  # low threshold so synthetic speech trips the VAD
    vad_settings.attack_ms = 40  # quick activation
    vad_settings.release_ms = 100
    return config


def test_end_to_end_speech_detection(integration_config):
    """Silence -> speech -> long silence must activate the VAD, then let it settle."""
    service = AudioInService(integration_config)
    service._initialize_components()

    silence = np.zeros(320, dtype=np.int16).tobytes()
    speech = np.full(320, 12000, dtype=np.int16).tobytes()
    # 3 silent frames, 10 speech frames, then a long silent tail so the VAD
    # has time to release (end of speech).
    speech_frames = [silence] * 3 + [speech] * 10 + [silence] * 50

    # Record, after each frame, whether the VAD state is "active".
    vad_states = []
    for frame_bytes in speech_frames:
        service._process_audio_frame(frame_bytes)
        vad_states.append(service.vad.state.value == "active")

    # The message is shown on failure, i.e. when no activity was detected.
    assert any(vad_states), "VAD devrait détecter de l'activité"

    # After the silent tail the VAD must be back to idle (or releasing).
    final_stats = service.vad.get_stats()
    assert final_stats["current_state"] in ["idle", "release"]
