"""Tests pour le module jaw sync."""

import pytest
import numpy as np
import time

from voice.jaw_sync import JawSyncProcessor, VisemeGenerator
from voice.voice_types import JawConfig, VisemeData
from voice.config import VoiceConfig


class TestJawSyncProcessor:
    """Tests for JawSyncProcessor."""

    def test_initialization(self, jaw_config):
        """A new processor mirrors the given config and starts with zero smoothed RMS."""
        processor = JawSyncProcessor(jaw_config)

        assert processor.config.gain == jaw_config.gain
        assert processor.config.smoothing == jaw_config.smoothing
        assert processor.config.latency_ms == jaw_config.latency_ms
        assert processor.smoothed_rms == 0.0

    def test_config_update(self, jaw_config):
        """update_config replaces the active gain, smoothing and latency."""
        processor = JawSyncProcessor(jaw_config)

        new_config = JawConfig(gain=2.0, smoothing=0.5, latency_ms=200)
        processor.update_config(new_config)

        assert processor.config.gain == 2.0
        assert processor.config.smoothing == 0.5
        assert processor.config.latency_ms == 200

    def test_rms_calculation_sine_wave(self, jaw_config):
        """RMS of a constant 440 Hz sine settles near amplitude / sqrt(2)."""
        processor = JawSyncProcessor(jaw_config)

        sample_rate = VoiceConfig.SAMPLE_RATE
        duration = 1.0  # seconds
        frequency = 440.0

        # endpoint=False keeps the sample spacing at exactly 1 / sample_rate;
        # including the endpoint would stretch the grid and detune the tone.
        t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
        sine_wave = 0.5 * np.sin(2 * np.pi * frequency * t)  # Amplitude 0.5

        rms_values = [
            rms_data.rms for rms_data in processor.process_audio_frames(sine_wave, sample_rate)
        ]
        # Fail with a clear message instead of averaging an empty list (NaN).
        assert rms_values, "process_audio_frames yielded no frames"

        # For a sine of amplitude 0.5 the theoretical RMS is 0.5/sqrt(2) ~= 0.353.
        expected_rms = 0.5 / np.sqrt(2)

        # Average the last few frames so the start-up transient is ignored.
        final_rms = np.mean(rms_values[-5:])
        assert abs(final_rms - expected_rms) < 0.1, f"RMS {final_rms} != {expected_rms}"

    def test_gain_application(self, jaw_config):
        """Doubling the gain doubles the reported RMS (within 10%).

        Skipped when the processor yields no frames for the test signal.
        """
        processor = JawSyncProcessor(jaw_config)

        signal_amplitude = 0.3
        test_signal = signal_amplitude * np.ones(1000)

        # Baseline: pin gain to 1.0 with smoothing and latency disabled so the
        # two runs differ only in gain, whatever the fixture's settings are.
        processor.update_config(JawConfig(gain=1.0, smoothing=0.0, latency_ms=0))
        rms_values_1x = list(processor.process_audio_frames(test_signal))

        # Same signal with gain 2.0.
        processor.update_config(JawConfig(gain=2.0, smoothing=0.0, latency_ms=0))
        processor.smoothed_rms = 0.0  # Reset state left over from the first run
        rms_values_2x = list(processor.process_audio_frames(test_signal))

        if not (rms_values_1x and rms_values_2x):
            # Report the missing coverage instead of passing without asserting.
            pytest.skip("processor yielded no frames for the test signal")

        baseline_rms = rms_values_1x[-1].rms
        assert baseline_rms > 0, "baseline RMS is zero; gain ratio is undefined"

        ratio = rms_values_2x[-1].rms / baseline_rms
        assert abs(ratio - 2.0) < 0.2, f"Ratio gain {ratio} != 2.0"

    def test_smoothing_effect(self):
        """Stronger EMA smoothing reacts more slowly to a step input.

        Skipped when either run yields 10 frames or fewer, since the
        comparison index would not be meaningful.
        """
        # Heavy smoothing
        config_smooth = JawConfig(gain=1.0, smoothing=0.8, latency_ms=0)
        processor_smooth = JawSyncProcessor(config_smooth)

        # Light smoothing
        config_sharp = JawConfig(gain=1.0, smoothing=0.1, latency_ms=0)
        processor_sharp = JawSyncProcessor(config_sharp)

        # Step signal (0 then 1)
        step_signal = np.concatenate([np.zeros(500), np.ones(500)])

        smooth_values = [rms.rms for rms in processor_smooth.process_audio_frames(step_signal)]
        sharp_values = [rms.rms for rms in processor_sharp.process_audio_frames(step_signal)]

        # NOTE(review): a 1000-sample signal must produce more than 10 frames for
        # this check to run -- confirm against the default sample rate / frame size.
        if len(smooth_values) <= 10 or len(sharp_values) <= 10:
            pytest.skip("too few frames to observe the smoothing transition")

        # Compare shortly after the step, which sits at the middle of the signal.
        mid_idx = len(smooth_values) // 2 + 5

        smooth_transition = smooth_values[mid_idx]
        sharp_transition = sharp_values[mid_idx]

        # Heavy smoothing should still be closer to 0 during the transition.
        assert smooth_transition < sharp_transition

    def test_latency_buffer(self):
        """A longer latency yields at least as many leading zero-RMS frames."""
        # Short latency
        config_short = JawConfig(gain=1.0, smoothing=0.0, latency_ms=50)
        processor_short = JawSyncProcessor(config_short)

        # Long latency
        config_long = JawConfig(gain=1.0, smoothing=0.0, latency_ms=200)
        processor_long = JawSyncProcessor(config_long)

        test_signal = np.ones(1000)  # Constant signal

        short_values = list(processor_short.process_audio_frames(test_signal))
        long_values = list(processor_long.process_audio_frames(test_signal))

        # Count zero-RMS frames among the first ten of each run.
        short_zeros = sum(1 for rms in short_values[:10] if rms.rms == 0.0)
        long_zeros = sum(1 for rms in long_values[:10] if rms.rms == 0.0)

        assert long_zeros >= short_zeros

    def test_silence_processing(self, jaw_config):
        """process_silence yields one zero-RMS frame per FRAME_SIZE of silence."""
        processor = JawSyncProcessor(jaw_config)

        silence_duration = 500  # ms
        silence_data = list(processor.process_silence(silence_duration))

        # Frame count
        expected_frames = silence_duration // VoiceConfig.FRAME_SIZE
        assert len(silence_data) == expected_frames

        # Every frame reports zero RMS and a positive timestamp.
        for rms_data in silence_data:
            assert rms_data.rms == 0.0
            assert rms_data.ts_ms > 0

    def test_threading_safety(self, jaw_config):
        """Concurrent update_config calls finish without deadlock or exception."""
        processor = JawSyncProcessor(jaw_config)

        import threading

        # Exceptions raised inside a thread do not reach the test on their own,
        # so each worker records them here for the final assertion.
        errors = []

        def update_config():
            try:
                for i in range(10):
                    new_config = JawConfig(gain=1.0 + i * 0.1, smoothing=0.3, latency_ms=120)
                    processor.update_config(new_config)
                    time.sleep(0.001)
            except Exception as exc:  # surfaced through the assertion below
                errors.append(exc)

        # Daemon threads: a deadlocked worker must not keep the interpreter alive.
        threads = [threading.Thread(target=update_config, daemon=True) for _ in range(3)]
        for thread in threads:
            thread.start()

        # Bounded join: a deadlock becomes a test failure instead of a hang.
        for thread in threads:
            thread.join(timeout=5.0)

        assert not any(thread.is_alive() for thread in threads), "update_config deadlocked"
        assert not errors, f"update_config raised in worker thread: {errors!r}"

    def test_bandpass_filter_setup(self, jaw_config):
        """The processor exposes non-empty `b` and `a` filter coefficients."""
        processor = JawSyncProcessor(jaw_config)

        assert hasattr(processor, "b")
        assert hasattr(processor, "a")
        assert len(processor.b) > 0
        assert len(processor.a) > 0

    def test_frame_size_consistency(self, jaw_config):
        """One second of audio yields the expected number of frames (+/- 2)."""
        processor = JawSyncProcessor(jaw_config)

        # Signal of known length
        sample_rate = VoiceConfig.SAMPLE_RATE
        duration = 1.0  # seconds
        signal_length = int(sample_rate * duration)
        # Seeded generator keeps the test input reproducible between runs.
        rng = np.random.default_rng(0)
        test_signal = rng.standard_normal(signal_length) * 0.1

        rms_count = sum(1 for _ in processor.process_audio_frames(test_signal, sample_rate))

        # FRAME_SIZE is used as a frame duration in milliseconds here.
        expected_frames = duration * 1000 / VoiceConfig.FRAME_SIZE
        assert abs(rms_count - expected_frames) <= 2  # Rounding tolerance


class TestVisemeGenerator:
    """Tests for VisemeGenerator."""

    # Highest viseme ID these tests accept as valid.
    MAX_VISEME_ID = 15

    def test_viseme_generation_from_phonemes(self):
        """One viseme per phoneme, time-ordered and within the time budget."""
        generator = VisemeGenerator()

        phonemes = ["s", "a", "l", "u", "t"]
        start_time = 1000
        duration = 500

        visemes = generator.generate_visemes(phonemes, start_time, duration)

        assert len(visemes) == len(phonemes)

        # Timestamps must be non-decreasing.
        for i in range(1, len(visemes)):
            assert visemes[i].ts_ms >= visemes[i - 1].ts_ms

        # Time range
        first_time = visemes[0].ts_ms
        last_time = visemes[-1].ts_ms + visemes[-1].dur_ms

        assert first_time >= start_time
        assert (last_time - start_time) <= duration * 1.5  # 50% tolerance

    def test_viseme_generation_from_text(self):
        """The text fallback yields visemes with valid IDs and bounded duration."""
        generator = VisemeGenerator()

        text = "bonjour monde"
        start_time = 2000
        duration = 1000

        visemes = generator.generate_from_text(text, start_time, duration)

        assert len(visemes) > 0

        # The summed viseme durations must not overrun the requested duration.
        total_duration = sum(v.dur_ms for v in visemes)
        assert total_duration <= duration * 1.2  # Tolerance

        for viseme in visemes:
            assert 0 <= viseme.id <= self.MAX_VISEME_ID

    def test_empty_input_handling(self):
        """Empty phoneme lists and empty text both produce no visemes."""
        generator = VisemeGenerator()

        # Empty phonemes
        empty_visemes = generator.generate_visemes([], 0, 100)
        assert len(empty_visemes) == 0

        # Empty text
        empty_text_visemes = generator.generate_from_text("", 0, 100)
        assert len(empty_text_visemes) == 0

    def test_viseme_data_structure(self):
        """Generated items are VisemeData with ts_ms/id/dur_ms, also in to_dict()."""
        generator = VisemeGenerator()

        phonemes = ["a", "e", "i"]
        visemes = generator.generate_visemes(phonemes, 1000, 300)

        # Without this check an empty result would skip the loop below and the
        # test would pass without verifying anything.
        assert len(visemes) == len(phonemes)

        for viseme in visemes:
            assert isinstance(viseme, VisemeData)
            assert hasattr(viseme, "ts_ms")
            assert hasattr(viseme, "id")
            assert hasattr(viseme, "dur_ms")

            # Dict conversion exposes the same three fields.
            viseme_dict = viseme.to_dict()
            assert "ts_ms" in viseme_dict
            assert "id" in viseme_dict
            assert "dur_ms" in viseme_dict

    def test_long_text_handling(self):
        """A long sentence yields many visemes spanning at most ~the duration."""
        generator = VisemeGenerator()

        long_text = "Ceci est un texte assez long pour tester la génération de visèmes sur une phrase complète avec plusieurs mots et syllabes différentes."
        start_time = 0
        duration = 5000  # 5 seconds

        visemes = generator.generate_from_text(long_text, start_time, duration)

        assert len(visemes) > 10  # Long text should give many visemes

        # Span from the first viseme's start to the last viseme's end.
        total_duration = visemes[-1].ts_ms + visemes[-1].dur_ms - visemes[0].ts_ms
        assert total_duration <= duration * 1.1  # 10% tolerance

    def test_special_characters_handling(self):
        """Punctuation, digits and accents still yield visemes with valid IDs."""
        generator = VisemeGenerator()

        text_with_special = "Hello! Comment ça va? Très bien... 123 OK."
        visemes = generator.generate_from_text(text_with_special, 0, 1000)

        # The alphabetic content must still produce output.
        assert len(visemes) > 0

        for viseme in visemes:
            assert 0 <= viseme.id <= self.MAX_VISEME_ID


class TestVisemeMapping:
    """Tests for the phoneme-to-viseme mapping."""

    def test_phoneme_to_viseme_mapping(self):
        """Known phonemes map to fixed viseme IDs; unknown ones map to 0."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        expected_ids = (
            # Vowels
            ("a", 1),
            ("e", 2),
            ("i", 3),
            # Labial consonants
            ("p", 7),
            ("b", 7),
            ("m", 7),
            # Silence
            ("_", 0),
            (" ", 0),
            # Unknown phoneme
            ("xyz", 0),
        )
        for phoneme, viseme_id in expected_ids:
            assert mapper.phoneme_to_viseme_id(phoneme) == viseme_id

    def test_phoneme_duration_estimation(self):
        """Vowels outlast consonants; silence uses the fixed table duration."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        base_ms = 100
        vowel_ms = mapper.get_phoneme_duration("a", base_ms)
        consonant_ms = mapper.get_phoneme_duration("t", base_ms)
        silence_ms = mapper.get_phoneme_duration("_", base_ms)

        # Vowels should be longer than consonants.
        assert vowel_ms > consonant_ms

        # Silence has a fixed duration.
        assert silence_ms == mapper.PHONEME_DURATIONS["silence"]

    def test_espeak_phoneme_parsing(self):
        """Plain eSpeak output splits into phonemes; digit marks are removed."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        # Typical space-separated eSpeak output
        plain = mapper.parse_espeak_phonemes("s a l y t")

        assert len(plain) == 5
        assert plain == ["s", "a", "l", "y", "t"]

        # Output carrying stress/length/digit annotations
        annotated = mapper.parse_espeak_phonemes("s'a:l3y2t")

        assert len(annotated) > 0
        # No digit may survive in any parsed phoneme.
        for parsed in annotated:
            assert all(not ch.isdigit() for ch in parsed)

    def test_phoneme_cleaning(self):
        """_clean_phoneme strips annotations and applies eSpeak-specific aliases."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        cleaned_forms = (
            # Annotation removal
            ("a'", "a"),
            ("e:", "e"),
            ("o3", "o"),
            # eSpeak-specific mapping
            ("aa", "a"),
            ("ch", "j"),
            ("sh", "j"),
        )
        for raw, cleaned in cleaned_forms:
            assert mapper._clean_phoneme(raw) == cleaned

    def test_word_timing_estimation(self):
        """Word timings are contiguous from 0 and fit within the total duration."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        # Word with identifiable vowels
        budget_ms = 600
        timings = mapper.estimate_word_timing("bonjour", budget_ms)

        assert len(timings) > 0

        # The summed durations must not exceed the budget.
        assert sum(length for _, length in timings) <= budget_ms

        # Each segment starts exactly where the previous one ended.
        expected_start = 0
        for start, length in timings:
            assert start == expected_start
            expected_start += length

    def test_vowel_identification(self):
        """Vowels are in mapper.vowels and get stretched; consonants are not in it."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        base_ms = 100
        for vowel in ("a", "e", "i", "o", "u", "y"):
            assert vowel in mapper.vowels
            # Vowels come out longer than the base duration.
            assert mapper.get_phoneme_duration(vowel, base_ms) > 100

        for consonant in ("p", "t", "k", "s", "r"):
            assert consonant not in mapper.vowels

    def test_complex_espeak_output(self):
        """Stress marks and symbols in eSpeak output are cleaned from every phoneme."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        # eSpeak output with stress marks and symbols
        parsed = mapper.parse_espeak_phonemes("b'o~:Z'u:r m'o~d@")

        assert len(parsed) > 0
        # Every phoneme must be purely alphabetic or the "_" marker.
        for item in parsed:
            assert item == "_" or item.isalpha()

    def test_mapping_coverage(self):
        """Common French phonemes all resolve to a viseme ID in the range 0..15."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        vowels = ["a", "e", "i", "o", "u", "y"]
        plosives = ["p", "b", "t", "d", "k", "g"]
        others = ["f", "v", "s", "z", "r", "l", "j", "m", "n"]

        for phoneme in vowels + plosives + others:
            # Lower bound is silence (0); upper bound is the highest viseme ID.
            assert 0 <= mapper.phoneme_to_viseme_id(phoneme) <= 15

    def test_edge_cases(self):
        """Empty, blank, oversized and zero-duration inputs are handled."""
        from voice.visemes import VisemeMapper

        mapper = VisemeMapper()

        # Empty / whitespace-only inputs
        assert mapper.phoneme_to_viseme_id("") == 0
        assert mapper.parse_espeak_phonemes("") == []
        assert mapper.parse_espeak_phonemes("   ") == []

        # Very long phoneme string
        assert mapper.phoneme_to_viseme_id("a" * 100) >= 0

        # Zero duration should still produce timing entries.
        assert len(mapper.estimate_word_timing("test", 0)) > 0
