"""Tests pour le calculateur RMS."""

import numpy as np
import pytest

from audioin.rms import RMSCalculator, calculate_rms_simple, normalize_rms


class TestRMSCalculator:
    """Tests for RMSCalculator on 320-sample int16 frames."""

    def setup_method(self):
        """Create a fresh calculator (16000 Hz, 320-sample frames) per test."""
        self.rms_calc = RMSCalculator(sample_rate=16000, frame_size=320)

    def test_rms_silence(self):
        """An all-zero frame must give an RMS of exactly 0.0."""
        frame = np.zeros(320, dtype=np.int16)
        assert self.rms_calc.calculate(frame) == 0.0

    def test_rms_constant_amplitude(self):
        """A sine of fixed peak amplitude must give peak / sqrt(2), normalized."""
        # 1 kHz sine at a 16 kHz rate, peak at half of the int16 maximum.
        peak = 16383  # 0.5 * 32767
        phase = 2 * np.pi * 1000 * np.arange(320) / 16000
        tone = (peak * np.sin(phase)).astype(np.int16)

        measured = self.rms_calc.calculate(tone)

        # Theoretical RMS of a sinusoid, divided by the int16 full scale.
        theoretical = (peak / np.sqrt(2)) / 32767.0

        # Absolute tolerance of 0.01 on the normalized value.
        deviation = abs(measured - theoretical)
        assert deviation < 0.01

    def test_rms_max_amplitude(self):
        """A frame held at 32767 must give exactly 1.0."""
        full_scale = np.full(320, 32767, dtype=np.int16)
        measured = self.rms_calc.calculate(full_scale)
        # Expected to be normalized to 1.0.
        assert measured == 1.0

    def test_rms_from_bytes(self):
        """The same samples as an array and as raw bytes must agree."""
        pattern = [1000, -1000, 2000, -2000] * 80
        as_array = np.array(pattern, dtype=np.int16)
        as_bytes = as_array.tobytes()

        # Array input is evaluated first, then the byte input.
        from_array = self.rms_calc.calculate(as_array)
        from_bytes = self.rms_calc.calculate(as_bytes)

        difference = abs(from_array - from_bytes)
        assert difference < 1e-6

    def test_rms_empty_frame(self):
        """A zero-length frame must give an RMS of exactly 0.0."""
        no_samples = np.array([], dtype=np.int16)
        assert self.rms_calc.calculate(no_samples) == 0.0

    def test_frame_size_validation(self):
        """Only a 320-sample (640-byte) buffer passes the frame-size check."""
        two_bytes = b"\x00\x01"

        # 320 samples * 2 bytes each: the expected frame length.
        full_frame = two_bytes * 320
        assert self.rms_calc.is_valid_frame_size(full_frame)

        # 100 samples only: must be rejected.
        short_frame = two_bytes * 100
        assert not self.rms_calc.is_valid_frame_size(short_frame)

    def test_rms_db_conversion(self):
        """A constant frame at half scale must read about -6 dB."""
        half_scale = np.full(320, 16383, dtype=np.int16)
        measured_db = self.rms_calc.calculate_db(half_scale)

        # 20 * log10(0.5) is roughly -6.02 dB; allow 0.1 dB of slack.
        reference_db = 20 * np.log10(0.5)
        assert abs(measured_db - reference_db) < 0.1

    def test_rms_db_silence(self):
        """An all-zero frame must read exactly -80.0 dB."""
        frame = np.zeros(320, dtype=np.int16)
        measured_db = self.rms_calc.calculate_db(frame)
        # -80.0 is the noise floor expected for silence.
        assert measured_db == -80.0


class TestRMSUtilities:
    """Tests for the module-level RMS helper functions."""

    def test_calculate_rms_simple(self):
        """calculate_rms_simple must match sqrt(mean(x**2)) computed here."""
        samples = np.array([1000, -1000, 2000, -2000], dtype=np.float64)
        measured = calculate_rms_simple(samples)

        reference = np.sqrt(np.mean(samples**2))
        difference = abs(measured - reference)
        assert difference < 1e-6

    def test_normalize_rms(self):
        """normalize_rms must divide the raw RMS by max_amplitude."""
        full_scale = 32767.0
        raw_value = 16383.5  # 0.5 * max_amplitude
        measured = normalize_rms(raw_value, max_amplitude=full_scale)

        reference = raw_value / full_scale
        assert abs(measured - reference) < 1e-6

    def test_normalize_rms_clipping(self):
        """A raw RMS above max_amplitude must be clipped to 1.0."""
        # A raw RMS above the maximum (the original notes noise can cause it).
        over_range = 40000.0
        measured = normalize_rms(over_range, max_amplitude=32767.0)

        # Clipped to 1.0.
        assert measured == 1.0


@pytest.fixture
def sine_wave_1khz():
    """Fixture: 1 kHz sine, 320 int16 samples, peak amplitude 16000."""
    # Phase of a 1000 Hz tone at a 16000 Hz sample rate.
    phase = 2 * np.pi * 1000 * np.arange(320) / 16000
    return (16000 * np.sin(phase)).astype(np.int16)


@pytest.fixture
def noise_signal():
    """Fixture: reproducible white noise, 320 int16 samples.

    Draws from a private ``RandomState`` seeded with 42 rather than calling
    ``np.random.seed``, so building the fixture does not reseed NumPy's
    global random generator for other tests in the session. The legacy
    ``RandomState`` stream is kept so the samples are the same values the
    globally seeded ``np.random.randn`` call produced.
    """
    rng = np.random.RandomState(42)
    # Gaussian samples scaled by 5000, then cast to int16.
    noise = (rng.randn(320) * 5000).astype(np.int16)
    return noise


def test_rms_sine_wave(sine_wave_1khz):
    """RMS of the known sine fixture must match its theoretical value."""
    calculator = RMSCalculator()
    measured = calculator.calculate(sine_wave_1khz)

    # Theoretical RMS for a peak of 16000, divided by the int16 full scale.
    expected = (16000 / np.sqrt(2)) / 32767.0

    deviation = abs(measured - expected)
    assert deviation < 0.01


def test_rms_noise_properties(noise_signal):
    """RMS of the noise fixture must lie strictly between 0 and 1."""
    calculator = RMSCalculator()
    measured = calculator.calculate(noise_signal)

    # Gaussian noise must give a non-zero RMS that stays below 1.0.
    assert 0.0 < measured < 1.0
