"""
Planificateur de trajectoires S-curve avec limitation de jerk
"""

import math
from typing import List, NamedTuple, Optional
from dataclasses import dataclass
from .utils import get_monotonic_ms


@dataclass()
class TrajectoryParams:
    """Kinematic limits used to shape an S-curve trajectory.

    The three fields are, in order: peak velocity, peak acceleration and
    peak jerk, all expressed in degrees.
    """

    v_max: float = 90.0  # velocity limit, deg/s
    a_max: float = 180.0  # acceleration limit, deg/s^2
    j_max: float = 720.0  # jerk limit, deg/s^3


class TrajectoryPoint(NamedTuple):
    """A single timestamped sample of a planned trajectory."""

    timestamp_ms: int  # sample time in milliseconds
    position: float  # planned position at that time
    velocity: float  # planned velocity at that time
    acceleration: float  # planned acceleration at that time


class SCurvePlanner:
    """Jerk-limited (S-curve) point-to-point trajectory planner.

    ``generate_trajectory`` builds a symmetric rest-to-rest "double S"
    profile sampled every ``dt_ms`` milliseconds. ``start_movement`` stores
    such a trajectory and anchors it to the monotonic clock so that
    ``get_current_position`` and ``is_movement_complete`` can replay it.
    """

    def __init__(self, dt_ms: int = 20):
        """Create a planner that samples trajectories every ``dt_ms`` ms.

        Raises:
            ValueError: if ``dt_ms`` is not strictly positive (a zero step
                would never advance the sampling clock).
        """
        if not dt_ms > 0:
            raise ValueError("dt_ms must be strictly positive")
        self.dt_ms = dt_ms
        self.dt_s = dt_ms / 1000.0
        self._current_trajectory: List[TrajectoryPoint] = []
        # NOTE(review): reset by start_movement but never read in this file.
        self._trajectory_index = 0
        # Monotonic time (ms) at which the current trajectory starts playing.
        self._start_time_ms = 0

    def generate_trajectory(
        self,
        start_pos: float,
        target_pos: float,
        speed_factor: float = 1.0,
        params: Optional[TrajectoryParams] = None,
    ) -> List[TrajectoryPoint]:
        """Generate a jerk-limited S-curve trajectory from start to target.

        The profile is the classic symmetric double S:
        1. The acceleration ramps up with +jerk.
        2. It holds at its peak, if ``a_max`` is reachable.
        3. It ramps back to zero.
        4. The velocity cruises, if ``v_max`` is reachable.
        5. The mirror image of the acceleration phase brings the axis to rest
           exactly at ``target_pos``.

        Velocity and acceleration are zero at both ends and continuous in
        between. For short moves the peak velocity (and, if needed, the peak
        acceleration) is reduced so that the move still ends exactly on
        target.

        Args:
            start_pos: initial position.
            target_pos: final position.
            speed_factor: scales ``v_max`` and ``a_max``; clamped to
                [0.1, 1.0]. The jerk limit is not scaled.
            params: kinematic limits; defaults to ``TrajectoryParams()``.

        Returns:
            Samples every ``dt_s`` seconds starting at the current monotonic
            time, followed by an exact final point ``(target_pos, 0, 0)``.
            A single resting point is returned when ``|target - start| < 1e-6``.

        Raises:
            ValueError: if any effective limit is not strictly positive.
        """
        if params is None:
            params = TrajectoryParams()

        factor = clamp(speed_factor, 0.1, 1.0)
        v_max = params.v_max * factor
        a_max = params.a_max * factor
        j_max = params.j_max

        timestamp_start = get_monotonic_ms()
        distance = target_pos - start_pos
        if abs(distance) < 1e-6:
            # No motion: a single resting point at the start position.
            return [TrajectoryPoint(timestamp_start, start_pos, 0.0, 0.0)]

        # Written as "not (x > 0)" so NaN limits are rejected too.
        if not (v_max > 0 and a_max > 0 and j_max > 0):
            raise ValueError("v_max, a_max and j_max must be strictly positive")

        direction = 1.0 if distance > 0 else -1.0
        total_distance = abs(distance)

        t_jerk, t_accel, t_cruise, v_peak = self._plan_profile(
            total_distance, v_max, a_max, j_max
        )
        t_total = 2.0 * t_accel + t_cruise
        t_cruise_end = t_accel + t_cruise

        trajectory: List[TrajectoryPoint] = []
        step = 0
        t = 0.0
        while t < t_total:
            if t < t_accel:
                offset, speed, accel = self._accel_state(
                    t, t_jerk, t_accel, v_peak, j_max
                )
            elif t < t_cruise_end:
                offset = 0.5 * v_peak * t_accel + v_peak * (t - t_accel)
                speed, accel = v_peak, 0.0
            else:
                # Deceleration is the acceleration phase mirrored in
                # time-to-go: remaining distance == distance covered so far
                # by the acceleration phase at that same time.
                mirror_offset, speed, mirror_accel = self._accel_state(
                    t_total - t, t_jerk, t_accel, v_peak, j_max
                )
                offset = total_distance - mirror_offset
                accel = -mirror_accel

            trajectory.append(
                TrajectoryPoint(
                    int(round(timestamp_start + t * 1000.0)),
                    start_pos + direction * offset,
                    direction * speed,
                    direction * accel,
                )
            )
            # Index-based time avoids accumulating floating-point drift.
            step += 1
            t = step * self.dt_s

        # Exact final point, at rest on the target.
        trajectory.append(
            TrajectoryPoint(
                int(round(timestamp_start + t_total * 1000.0)),
                target_pos,
                0.0,
                0.0,
            )
        )
        return trajectory

    @staticmethod
    def _plan_profile(
        distance: float, v_max: float, a_max: float, j_max: float
    ) -> tuple:
        """Compute the timing of a symmetric rest-to-rest double-S profile.

        Returns ``(t_jerk, t_accel, t_cruise, v_peak)``:
        - ``t_jerk``: duration of each jerk ramp.
        - ``t_accel``: duration of the whole acceleration phase, which equals
          the deceleration phase.
        - ``t_cruise``: duration of the constant-velocity phase.
        - ``v_peak``: the peak velocity.
        """
        # Speed gained by a jerk-up/jerk-down pair that just touches a_max.
        v_knee = a_max * a_max / j_max

        if v_max >= v_knee:
            t_jerk = a_max / j_max
            t_accel = t_jerk + v_max / a_max
        else:
            t_jerk = math.sqrt(v_max / j_max)
            t_accel = 2.0 * t_jerk

        # Each of accel/decel covers v_peak * t_accel / 2 (average speed v/2).
        t_cruise = distance / v_max - t_accel
        if t_cruise >= 0.0:
            return t_jerk, t_accel, t_cruise, v_max

        # Too short to reach v_max: solve v_peak * t_accel(v_peak) == distance.
        if distance >= 2.0 * a_max**3 / (j_max * j_max):
            # a_max is still reached: v^2 + v_knee * v - a_max * distance = 0.
            v_peak = 0.5 * (
                -v_knee + math.sqrt(v_knee * v_knee + 4.0 * a_max * distance)
            )
            t_jerk = a_max / j_max
            t_accel = t_jerk + v_peak / a_max
        else:
            # Pure jerk ramps (peak accel < a_max): distance = 2 v^1.5 / sqrt(j).
            v_peak = (0.25 * distance * distance * j_max) ** (1.0 / 3.0)
            t_jerk = math.sqrt(v_peak / j_max)
            t_accel = 2.0 * t_jerk
        return t_jerk, t_accel, 0.0, v_peak

    @staticmethod
    def _accel_state(
        t: float, t_jerk: float, t_accel: float, v_peak: float, j_max: float
    ) -> tuple:
        """Return ``(offset, velocity, acceleration)`` at ``t`` in [0, t_accel].

        Values are for the acceleration phase of a rest-to-rest profile in
        the positive direction, starting from offset 0.
        """
        if t < t_jerk:
            # Jerk-up: acceleration grows linearly from 0.
            return j_max * t**3 / 6.0, 0.5 * j_max * t * t, j_max * t
        if t < t_accel - t_jerk:
            # Constant acceleration at the peak reached after the jerk-up.
            a_peak = j_max * t_jerk
            return (
                a_peak / 6.0 * (3.0 * t * t - 3.0 * t_jerk * t + t_jerk * t_jerk),
                a_peak * (t - 0.5 * t_jerk),
                a_peak,
            )
        # Jerk-down: acceleration decays linearly to 0 at t_accel.
        remaining = t_accel - t
        return (
            0.5 * v_peak * t_accel - v_peak * remaining + j_max * remaining**3 / 6.0,
            v_peak - 0.5 * j_max * remaining * remaining,
            j_max * remaining,
        )

    def start_movement(
        self,
        start_pos: float,
        target_pos: float,
        speed_factor: float = 1.0,
        delay_ms: int = 0,
    ) -> None:
        """Plan a move and start replaying it ``delay_ms`` ms from now."""
        trajectory = self.generate_trajectory(start_pos, target_pos, speed_factor)
        self._current_trajectory = trajectory
        self._trajectory_index = 0
        self._start_time_ms = get_monotonic_ms() + delay_ms

    def get_current_position(self) -> Optional[float]:
        """Return the planned position at the current monotonic time.

        Returns None when no trajectory has been started. During the start
        delay the first sample's position is returned. After the end, the
        last sample's position is returned. In between, positions are
        linearly interpolated between adjacent samples.
        """
        trajectory = self._current_trajectory
        if not trajectory:
            return None

        current_time = get_monotonic_ms()
        if current_time < self._start_time_ms:
            # Still inside the start delay: hold the initial position.
            return trajectory[0].position

        elapsed_ms = current_time - self._start_time_ms
        origin_ms = trajectory[0].timestamp_ms

        # First sample at or after the elapsed time; interpolate from the
        # previous one.
        for i, point in enumerate(trajectory):
            point_time = point.timestamp_ms - origin_ms
            if point_time < elapsed_ms:
                continue
            if i == 0:
                return point.position
            prev_point = trajectory[i - 1]
            prev_time = prev_point.timestamp_ms - origin_ms
            # The previous sample failed the test above, so
            # prev_time < elapsed_ms <= point_time and the span is non-zero.
            alpha = clamp(
                (elapsed_ms - prev_time) / (point_time - prev_time), 0.0, 1.0
            )
            return prev_point.position + alpha * (point.position - prev_point.position)

        # Past the last sample: the movement is finished.
        return trajectory[-1].position

    def is_movement_complete(self) -> bool:
        """Return True once the current trajectory's duration has elapsed.

        Also True when no trajectory has been started.
        """
        trajectory = self._current_trajectory
        if not trajectory:
            return True

        duration_ms = trajectory[-1].timestamp_ms - trajectory[0].timestamp_ms
        return get_monotonic_ms() >= self._start_time_ms + duration_ms


def clamp(value: float, min_val: float, max_val: float) -> float:
    """Limit ``value`` to the interval [min_val, max_val].

    The upper bound is applied first and the lower bound second, so when
    ``min_val > max_val`` the result is ``min_val``.
    """
    capped = min(max_val, value)
    return max(min_val, capped)
