"""Gestion du jaw sync - visèmes et RMS."""

import numpy as np
from scipy import signal
from typing import List, Optional, Generator
import time
import threading
from collections import deque

from .voice_types import JawConfig, RMSData, VisemeData
from .config import VoiceConfig
from .logger import setup_logger


class JawSyncProcessor:
    """Convert audio frames into smoothed, optionally delayed RMS values.

    Each frame is band-pass filtered, reduced to an RMS level, scaled by
    ``config.gain`` and smoothed with an exponential moving average driven by
    ``config.smoothing``. When ``config.latency_ms`` is positive, values go
    through a small FIFO so that the returned level lags the input.

    Internal state (config, EMA value, delay buffer) is guarded by a lock.
    """

    def __init__(self, config: JawConfig):
        """Initialise the processor.

        Args:
            config: Settings object; ``latency_ms``, ``gain`` and ``smoothing``
                are the attributes read by this class.
        """
        self.config = config
        self.logger = setup_logger(__name__)

        # Guards config, smoothed_rms and rms_buffer.
        self._lock = threading.Lock()

        # Delay line holding one smoothed RMS value per processed frame.
        self.buffer_size = self._buffer_size_for(config.latency_ms)
        self.rms_buffer: deque = deque(maxlen=self.buffer_size)

        # Current value of the exponential moving average.
        self.smoothed_rms = 0.0

        # Band-pass coefficients applied before the RMS computation.
        self._setup_bandpass_filter()

    @staticmethod
    def _buffer_size_for(latency_ms: float) -> int:
        """Return the delay-buffer length, in frames, for ``latency_ms``.

        The buffer receives one entry per frame, so the latency has to be
        expressed in frames of ``VoiceConfig.FRAME_SIZE`` milliseconds rather
        than in audio samples. A full buffer of length ``k`` yields the value
        from ``k - 1`` frames ago, hence the ``+ 1``. A length of 1 means
        "no delay".
        """
        frame_ms = VoiceConfig.FRAME_SIZE
        if latency_ms <= 0 or frame_ms <= 0:
            return 1
        return int(round(latency_ms / frame_ms)) + 1

    def update_config(self, config: JawConfig):
        """Replace the configuration.

        If the new latency maps to a different buffer length, the delay buffer
        is recreated empty; the smoothed RMS value is left untouched.

        Args:
            config: New settings object.
        """
        with self._lock:
            self.config = config
            new_buffer_size = self._buffer_size_for(config.latency_ms)
            if new_buffer_size != self.buffer_size:
                self.buffer_size = new_buffer_size
                # Start from an empty buffer so stale values sized for the old
                # latency are never returned.
                self.rms_buffer = deque(maxlen=new_buffer_size)

    def reset_state(self):
        """Reset the EMA value and empty the delay buffer (used by tests)."""
        with self._lock:
            self.smoothed_rms = 0.0
            self.rms_buffer.clear()

    def _setup_bandpass_filter(self):
        """Design the band-pass filter used before the RMS computation.

        Cut-off frequencies come from ``VoiceConfig.RMS_FILTER_LOW`` and
        ``VoiceConfig.RMS_FILTER_HIGH``, normalised by the Nyquist frequency
        of ``VoiceConfig.SAMPLE_RATE``. The coefficients are stored in
        ``self.b`` and ``self.a``.
        """
        nyquist = VoiceConfig.SAMPLE_RATE / 2
        low = VoiceConfig.RMS_FILTER_LOW / nyquist
        high = VoiceConfig.RMS_FILTER_HIGH / nyquist

        self.b, self.a = signal.butter(4, [low, high], btype="band")

    def process_audio_frames(
        self, audio_data: np.ndarray, sample_rate: int = VoiceConfig.SAMPLE_RATE
    ) -> Generator[RMSData, None, None]:
        """Split ``audio_data`` into frames and yield one ``RMSData`` per frame.

        The last frame is zero-padded to the full frame length. Each item is
        stamped with the wall-clock time (ms) at which it is produced.

        Note: ``sample_rate`` only determines the frame length; the band-pass
        coefficients are always those designed for ``VoiceConfig.SAMPLE_RATE``.

        Args:
            audio_data: Audio samples to analyse.
            sample_rate: Sample rate of ``audio_data`` in Hz.

        Yields:
            ``RMSData`` with the timestamp and the processed RMS level.

        Raises:
            ValueError: If the computed frame length is not positive.
        """
        frame_samples = int(VoiceConfig.FRAME_SIZE * sample_rate / 1000)
        if frame_samples <= 0:
            # range() would otherwise fail with an obscure "arg 3 must not be
            # zero" error (or silently yield nothing for a negative step).
            raise ValueError(
                f"frame length must be positive, got {frame_samples} samples"
            )

        for start in range(0, len(audio_data), frame_samples):
            frame = audio_data[start : start + frame_samples]
            missing = frame_samples - len(frame)
            if missing > 0:
                # Zero-pad the trailing partial frame.
                frame = np.pad(frame, (0, missing))

            rms = self._calculate_rms(frame)
            timestamp_ms = int(time.time() * 1000)

            yield RMSData(ts_ms=timestamp_ms, rms=rms)

    def _calculate_rms(self, frame: np.ndarray) -> float:
        """Return the processed RMS level for one frame.

        Steps: band-pass filter, RMS, gain, EMA smoothing, then the optional
        delay buffer. While the delay buffer is still filling, half of the
        current smoothed value is returned as a ramp-up.

        Args:
            frame: Samples of a single frame.

        Returns:
            The (possibly delayed) smoothed RMS value as a Python float.
        """
        with self._lock:
            # Work in float64 so squaring integer PCM samples cannot overflow
            # when the unfiltered fallback path is taken.
            samples = np.asarray(frame, dtype=np.float64)

            if samples.size == 0:
                rms = 0.0
            else:
                try:
                    filtered = signal.filtfilt(self.b, self.a, samples)
                except ValueError:
                    # filtfilt rejects inputs shorter than its padding length;
                    # fall back to the unfiltered samples in that case.
                    filtered = samples

                rms = float(np.sqrt(np.mean(np.square(filtered))))
                if not np.isfinite(rms):
                    # A NaN/inf here would stick in the EMA state for good.
                    rms = 0.0

            rms *= self.config.gain

            # Exponential moving average: smoothing is the weight of the past.
            smoothing = self.config.smoothing
            alpha = 1.0 - smoothing
            self.smoothed_rms = alpha * rms + smoothing * self.smoothed_rms

            if self.buffer_size == 1:
                # No latency configured: bypass the delay buffer.
                return float(self.smoothed_rms)

            self.rms_buffer.append(self.smoothed_rms)

            if len(self.rms_buffer) == self.buffer_size:
                # Peek at the oldest value; the bounded deque evicts it on the
                # next append, so no explicit popleft is needed.
                return float(self.rms_buffer[0])

            # Buffer still filling: progressive ramp-up at half amplitude.
            return float(self.smoothed_rms * 0.5)

    def process_silence(self, duration_ms: int) -> Generator[RMSData, None, None]:
        """Yield zero-level ``RMSData`` items covering ``duration_ms``.

        One item is produced per whole frame of ``VoiceConfig.FRAME_SIZE``
        milliseconds; the internal smoothing state is not modified.

        Args:
            duration_ms: Length of the silence in milliseconds.

        Yields:
            ``RMSData`` with ``rms=0.0`` and the current wall-clock time (ms).
        """
        frame_duration = VoiceConfig.FRAME_SIZE
        num_frames = duration_ms // frame_duration

        for _ in range(num_frames):
            timestamp_ms = int(time.time() * 1000)
            yield RMSData(ts_ms=timestamp_ms, rms=0.0)


class VisemeGenerator:
    """Build timed viseme sequences from phonemes or, as a fallback, from text."""

    def __init__(self):
        """Initialise the generator."""
        self.logger = setup_logger(__name__)

    @staticmethod
    def _split_evenly(total: int, parts: int) -> List[int]:
        """Split ``total`` into ``parts`` integers that sum exactly to ``total``.

        The shares differ by at most 1; the remainder of the integer division
        goes to the first shares instead of being dropped.
        """
        base, extra = divmod(total, parts)
        return [base + 1 if index < extra else base for index in range(parts)]

    def generate_visemes(
        self, phonemes: List[str], start_time_ms: int, total_duration_ms: int
    ) -> List[VisemeData]:
        """Generate visemes from a list of phonemes.

        Every phoneme is mapped to a viseme id by ``VisemeMapper``; its
        duration is whatever ``VisemeMapper.get_phoneme_duration`` returns for
        the average per-phoneme duration. Visemes are laid out back to back
        starting at ``start_time_ms``.

        Args:
            phonemes: Phoneme symbols, in speaking order.
            start_time_ms: Timestamp of the first viseme, in milliseconds.
            total_duration_ms: Duration of the whole utterance, in milliseconds.

        Returns:
            One ``VisemeData`` per phoneme; an empty list if ``phonemes`` is
            empty.
        """
        visemes: List[VisemeData] = []
        if not phonemes:
            return visemes

        # Imported at call time, as in the rest of this class.
        from .visemes import VisemeMapper

        mapper = VisemeMapper()

        # Average slot handed to the mapper as the base duration.
        phoneme_duration = total_duration_ms // len(phonemes)
        current_time = start_time_ms

        for phoneme in phonemes:
            viseme_id = mapper.phoneme_to_viseme_id(phoneme)
            duration = mapper.get_phoneme_duration(phoneme, phoneme_duration)

            visemes.append(
                VisemeData(ts_ms=current_time, id=viseme_id, dur_ms=duration)
            )
            current_time += duration

        return visemes

    def generate_from_text(
        self, text: str, start_time_ms: int, estimated_duration_ms: int
    ) -> List[VisemeData]:
        """Generate approximate visemes straight from text (fallback path).

        The text is lower-cased and split on whitespace. The estimated
        duration is shared between the words, and each word's share between
        its characters, which stand in for phonemes (a very rough
        approximation). Division remainders are distributed rather than
        discarded, so the durations add up to ``estimated_duration_ms`` and
        the timeline does not drift ahead of the audio.

        Args:
            text: Text being spoken.
            start_time_ms: Timestamp of the first viseme, in milliseconds.
            estimated_duration_ms: Estimated utterance duration, in
                milliseconds (expected to be an integer).

        Returns:
            One ``VisemeData`` per character of every word; an empty list if
            the text contains no words.
        """
        visemes: List[VisemeData] = []

        words = text.lower().split()
        if not words:
            return visemes

        # Imported at call time, as in the rest of this class.
        from .visemes import VisemeMapper

        mapper = VisemeMapper()

        word_durations = self._split_evenly(estimated_duration_ms, len(words))
        current_time = start_time_ms

        for word, word_duration in zip(words, word_durations):
            # str.split() never yields empty words, so len(word) >= 1.
            letter_durations = self._split_evenly(word_duration, len(word))

            for letter, duration in zip(word, letter_durations):
                viseme_id = mapper.phoneme_to_viseme_id(letter)
                visemes.append(
                    VisemeData(ts_ms=current_time, id=viseme_id, dur_ms=duration)
                )
                current_time += duration

        return visemes
