"""Calculateur RMS pour signaux audio."""

import numpy as np
from typing import Union


class RMSCalculator:
    """Normalized RMS energy calculator for 16-bit signed PCM audio frames.

    Frames may be given either as raw little-endian/native 16-bit PCM in a
    bytes-like object (``bytes``, ``bytearray``, ``memoryview``) or as a
    NumPy array (or array-like) of sample values on the 16-bit scale.
    """

    # Lowest level ever reported by ``calculate_db``; also the value
    # returned for pure digital silence.
    DB_FLOOR = -80.0

    def __init__(self, sample_rate: int = 16000, frame_size: int = 320):
        """
        Initialize the RMS calculator.

        Args:
            sample_rate: Sampling rate (stored; not used by the computations
                in this class).
            frame_size: Expected frame size in samples, used by
                ``is_valid_frame_size``.
        """
        self.sample_rate = sample_rate
        self.frame_size = frame_size
        self.max_amplitude = 32767.0  # Full scale for signed 16-bit samples

    @staticmethod
    def _to_samples(frame) -> np.ndarray:
        """Return the frame's samples as a float64 array."""
        if isinstance(frame, (bytes, bytearray, memoryview)):
            # Raw PCM: reinterpret the buffer as 16-bit signed samples.
            return np.frombuffer(frame, dtype=np.int16).astype(np.float64)
        # Convert straight to float64. Going through int16 first would wrap
        # out-of-range values (e.g. int32 40000 -> -25536) and truncate
        # fractional samples, under-reporting the energy of loud frames.
        return np.asarray(frame, dtype=np.float64)

    def calculate(
        self, frame: Union[np.ndarray, bytes, bytearray, memoryview]
    ) -> float:
        """
        Compute the normalized RMS of an audio frame.

        Args:
            frame: Audio frame (NumPy array of 16-bit-scale samples, or a
                bytes-like object holding 16-bit signed PCM).

        Returns:
            RMS divided by ``max_amplitude`` and capped at 1.0. An empty
            frame yields 0.0.

        Raises:
            ValueError: If a bytes-like frame's size is not a multiple of
                two bytes (raised by ``np.frombuffer``).
        """
        samples = self._to_samples(frame)

        # ``size`` (not ``len``) so that any array without elements,
        # whatever its shape, is treated as empty instead of producing NaN.
        if samples.size == 0:
            return 0.0

        rms_raw = np.sqrt(np.mean(np.square(samples)))

        # Cap at 1.0: -32768 or over-range input can exceed max_amplitude.
        return float(min(rms_raw / self.max_amplitude, 1.0))

    def calculate_db(
        self, frame: Union[np.ndarray, bytes, bytearray, memoryview]
    ) -> float:
        """
        Compute the RMS level in decibels relative to full scale.

        Args:
            frame: Audio frame (same accepted types as ``calculate``).

        Returns:
            RMS in dB (0 dB = full scale), never lower than ``DB_FLOOR``.
        """
        rms_norm = self.calculate(frame)

        if rms_norm <= 0.0:
            # log10(0) is undefined; report silence at the floor.
            return self.DB_FLOOR

        # Clamp so that near-silent frames cannot read lower than the value
        # reported for complete silence.
        return max(float(20.0 * np.log10(rms_norm)), self.DB_FLOOR)

    def is_valid_frame_size(
        self, frame: Union[np.ndarray, bytes, bytearray, memoryview]
    ) -> bool:
        """Return True if the frame holds exactly ``frame_size`` samples."""
        if isinstance(frame, (bytes, bytearray, memoryview)):
            # 16-bit samples occupy 2 bytes each; ``nbytes`` gives the byte
            # count even for memoryviews with multi-byte items.
            return memoryview(frame).nbytes == self.frame_size * 2
        return len(frame) == self.frame_size


def calculate_rms_simple(audio_data: np.ndarray) -> float:
    """Return the raw (un-normalized) RMS of an array of samples.

    Args:
        audio_data: Array of sample values.

    Returns:
        Square root of the mean of the squared samples, computed in
        float64; 0.0 when the array has length zero.
    """
    if not len(audio_data):
        return 0.0

    # Promote to float64 before squaring so integer inputs cannot overflow.
    as_float = audio_data.astype(np.float64)
    mean_square = np.mean(np.square(as_float))
    return float(np.sqrt(mean_square))


def normalize_rms(rms_value: float, max_amplitude: float = 32767.0) -> float:
    """Scale an RMS value by ``max_amplitude``, capping the result at 1.0.

    Args:
        rms_value: Raw RMS value.
        max_amplitude: Full-scale amplitude to divide by (default is the
            maximum of a signed 16-bit sample).

    Returns:
        ``rms_value / max_amplitude``, or 1.0 if that ratio exceeds 1.0.
    """
    ratio = rms_value / max_amplitude
    if 1.0 < ratio:
        return 1.0
    return ratio
