"""Tests pour le Voice Activity Detector."""

import pytest
import time

from audioin.vad import VoiceActivityDetector, VADState, VADResult, create_vad
from audioin.config import VADConfig


class TestVADConfig:
    """Tests for the VADConfig configuration object."""

    def test_vad_config_creation(self):
        """Every field passed to the constructor is stored unchanged."""
        config = VADConfig(threshold=0.4, attack_ms=100, release_ms=300, eos_ms=1000)

        assert config.threshold == 0.4
        assert config.attack_ms == 100
        assert config.release_ms == 300
        assert config.eos_ms == 1000

    def test_vad_config_from_dict(self):
        """from_dict maps every key of a complete dict onto the matching field."""
        data = {"threshold": 0.25, "attack_ms": 60, "release_ms": 150, "eos_ms": 600}

        config = VADConfig.from_dict(data)
        # Check all four fields: a from_dict that silently ignored
        # release_ms / eos_ms would pass if only the first two were checked.
        assert config.threshold == 0.25
        assert config.attack_ms == 60
        assert config.release_ms == 150
        assert config.eos_ms == 600

    def test_vad_config_partial_dict(self):
        """Keys missing from the dict fall back to the field defaults."""
        data = {"threshold": 0.35}
        config = VADConfig.from_dict(data)

        assert config.threshold == 0.35
        # Default values; the same defaults are expected by
        # test_vad_config_to_dict below (attack 80, release 200, eos 800).
        assert config.attack_ms == 80
        assert config.release_ms == 200
        assert config.eos_ms == 800

    def test_vad_config_to_dict(self):
        """to_dict returns all four fields, defaults included."""
        config = VADConfig(threshold=0.3, attack_ms=80)
        data = config.to_dict()

        expected = {"threshold": 0.3, "attack_ms": 80, "release_ms": 200, "eos_ms": 800}
        assert data == expected


class TestVoiceActivityDetector:
    """Tests for the VoiceActivityDetector state machine."""

    def setup_method(self):
        """Build a fresh detector with 20 ms frames before each test."""
        self.config = VADConfig(
            threshold=0.3,
            attack_ms=80,  # 4 frames of 20 ms
            release_ms=200,  # 10 frames
            eos_ms=800,  # 40 frames
        )
        self.vad = VoiceActivityDetector(self.config, frame_duration_ms=20.0)

    def test_initial_state(self):
        """A new detector is IDLE with zeroed frame counters."""
        assert self.vad.state == VADState.IDLE
        assert self.vad.total_frames == 0
        assert self.vad.active_frames == 0

    def test_rms_below_threshold(self):
        """A quiet frame is reported inactive and leaves the detector IDLE."""
        result = self.vad.process(0.2)  # below the 0.3 threshold

        assert result.active is False
        assert result.rms == 0.2
        assert result.end_of_speech is False
        assert self.vad.state == VADState.IDLE

    def test_single_frame_above_threshold(self):
        """One loud frame is not enough to activate (attack not reached)."""
        result = self.vad.process(0.4)

        assert result.active is False  # not active yet
        assert self.vad.state == VADState.IDLE

    def test_attack_transition(self):
        """IDLE -> ACTIVE only after attack_ms worth of loud frames."""
        # 80 ms / 20 ms = 4 frames are needed; the first 3 must not activate.
        for _ in range(3):
            result = self.vad.process(0.4)
            assert result.active is False

        # The 4th frame activates.
        result = self.vad.process(0.4)
        assert result.active is True
        assert self.vad.state == VADState.ACTIVE

    def test_stay_active(self):
        """Loud frames after activation keep the detector ACTIVE."""
        # Activation
        for _ in range(4):
            self.vad.process(0.4)

        # Sustained activity
        for _ in range(10):
            result = self.vad.process(0.5)
            assert result.active is True
            assert self.vad.state == VADState.ACTIVE

    def test_release_transition(self):
        """ACTIVE -> RELEASE: output drops after release_ms of quiet frames.

        After the drop the detector stays in RELEASE (waiting for end of
        speech); the RELEASE -> IDLE step is covered by
        test_end_of_speech_detection.
        """
        # Activation
        for _ in range(4):
            self.vad.process(0.4)

        # 200 ms / 20 ms = 10 quiet frames are needed to release.
        for _ in range(9):
            result = self.vad.process(0.2)
            assert result.active is True  # still active while releasing

        # 10th quiet frame -> output goes inactive.
        result = self.vad.process(0.2)
        assert result.active is False
        assert self.vad.state == VADState.RELEASE

    def test_re_activation_during_release(self):
        """A loud frame before release completes returns to ACTIVE."""
        # Activation, then start of release
        for _ in range(4):
            self.vad.process(0.4)

        for _ in range(5):  # 5 of the 10 release frames
            self.vad.process(0.2)

        # Re-activation before release completes
        result = self.vad.process(0.4)
        assert result.active is True
        assert self.vad.state == VADState.ACTIVE

    def test_end_of_speech_detection(self):
        """End of speech fires once eos_ms has elapsed since speech stopped."""
        # Activation
        for _ in range(4):
            self.vad.process(0.4)

        # Release
        for _ in range(10):
            self.vad.process(0.2)

        # Wait for EOS (800 ms - 200 ms release = 600 ms = 30 frames)
        for _ in range(29):
            result = self.vad.process(0.1)
            assert result.end_of_speech is False

        # 30th frame -> EOS
        result = self.vad.process(0.1)
        assert result.end_of_speech is True
        assert self.vad.state == VADState.IDLE

    def test_config_update(self):
        """update_config stores the new config and uses it for detection."""
        new_config = VADConfig(
            threshold=0.5,  # higher threshold
            attack_ms=40,  # faster attack
            release_ms=100,
            eos_ms=400,
        )

        self.vad.update_config(new_config)

        # The new values are stored...
        assert self.vad.config.threshold == 0.5
        assert self.vad.config.attack_ms == 40

        # ...and actually drive detection. 0.4 was loud under the old
        # threshold (0.3), and four such frames would have activated the old
        # detector; under the new threshold (0.5) they must stay inactive.
        for _ in range(4):
            assert self.vad.process(0.4).active is False

        # attack_ms=40 at 20 ms/frame -> 2 loud frames are enough now.
        assert self.vad.process(0.6).active is False
        assert self.vad.process(0.6).active is True
        assert self.vad.state == VADState.ACTIVE

    def test_statistics(self):
        """get_stats reports frame counts, active ratio and current state."""
        self.vad.process(0.2)  # inactive

        for _ in range(4):  # activation (only the 4th frame is active)
            self.vad.process(0.4)

        self.vad.process(0.5)  # active
        self.vad.process(0.5)  # active

        stats = self.vad.get_stats()

        assert stats["total_frames"] == 7
        assert stats["active_frames"] == 3  # 3 active frames
        # Same absolute tolerance as before, expressed with pytest.approx.
        assert stats["active_ratio"] == pytest.approx(3 / 7, abs=0.01)
        assert stats["current_state"] == "active"

    def test_reset_functionality(self):
        """reset returns the detector to a pristine IDLE state."""
        # Generate some activity
        for _ in range(10):
            self.vad.process(0.4)

        self.vad.reset()

        assert self.vad.state == VADState.IDLE
        assert self.vad.total_frames == 0
        assert self.vad.active_frames == 0
        # Internal attack counting must be cleared too: a single loud frame
        # after reset must not re-activate immediately.
        assert self.vad.process(0.4).active is False


class TestVADUtilities:
    """Tests for the module-level VAD helper functions."""

    def test_create_vad_function(self):
        """create_vad forwards every keyword argument into the detector config."""
        expected = {"threshold": 0.25, "attack_ms": 60, "release_ms": 150, "eos_ms": 600}
        detector = create_vad(**expected)

        cfg = detector.config
        for field_name, value in expected.items():
            assert getattr(cfg, field_name) == value


@pytest.fixture
def speech_pattern():
    """Simulated RMS sequence: silence -> speech -> silence.

    Speech frames (0.5) are above the 0.3 threshold used by the tests;
    silence frames (0.1) are below it.
    """
    leading_silence = [0.1] * 5
    speech = [0.5] * 20
    trailing_silence = [0.1] * 50
    return leading_silence + speech + trailing_silence


def test_speech_pattern_detection(speech_pattern):
    """A full silence -> speech -> silence pattern yields activation, then EOS."""
    vad = create_vad(threshold=0.3, attack_ms=80, release_ms=200, eos_ms=800)

    results = [vad.process(rms_value) for rms_value in speech_pattern]

    active_indices = [i for i, r in enumerate(results) if r.active]
    eos_indices = [i for i, r in enumerate(results) if r.end_of_speech]

    # Assertion messages are only shown on failure, so they describe the failure.
    assert active_indices, "Aucune parole détectée"
    assert eos_indices, "Fin de parole non détectée"
    # Order matters: end of speech must come after speech was first detected.
    assert eos_indices[0] > active_indices[0], "Fin de parole détectée avant la parole"


def test_vad_timing_precision():
    """Activation happens exactly after attack_ms, with non-stale timestamps."""
    config = VADConfig(threshold=0.3, attack_ms=60, release_ms=100, eos_ms=400)
    vad = VoiceActivityDetector(config, frame_duration_ms=20.0)

    # NOTE(review): assumes timestamp_ms is wall-clock time in milliseconds
    # (comparable to time.time() * 1000) — confirm against the implementation.
    start_time = time.time() * 1000

    # 60 ms / 20 ms = 3 frames for activation.
    results = [vad.process(0.4) for _ in range(3)]

    for result in results:
        assert result.timestamp_ms >= start_time

    # Precision: not active before the 3rd frame, active exactly on it.
    assert [r.active for r in results] == [False, False, True]


def test_vad_hysteresis():
    """Single-frame dips below threshold must not toggle an active detector off."""
    vad = create_vad(threshold=0.3, attack_ms=40, release_ms=100)

    # Signal oscillating around the threshold.
    oscillating_signal = [0.35, 0.25, 0.35, 0.25, 0.35, 0.25]

    results = [vad.process(rms).active for rms in oscillating_signal]

    assert any(results), "Aucune activité détectée"

    # Once active, each one-frame dip is much shorter than release_ms, so the
    # output must stay active for the rest of the signal (no rapid toggling).
    first_active = results.index(True)
    assert all(results[first_active:]), f"Commutation rapide détectée: {results}"
