"""Load testing for AI service"""

import asyncio
import time
import statistics
from typing import List, Dict, Any
from unittest.mock import Mock

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from ai.main import AIService
from ai.llm import llm_client


class LoadTestRunner:
    """Simple load test runner for the AI service.

    Drives ``llm_client.generate`` either as one concurrent burst
    (:meth:`run_burst_test`) or at a steady request rate
    (:meth:`run_sustained_test`), and summarizes success counts and
    latencies into plain dicts that :meth:`print_results` can render.
    """

    def __init__(self):
        # Not populated by any method of this class.
        self.results: List[Dict[str, Any]] = []

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start`` (a perf_counter value)."""
        return int((time.perf_counter() - start) * 1000)

    @staticmethod
    def _latency_stats(latencies: List[int]) -> Dict[str, Any]:
        """Summarize latencies (ms) as min / max / mean / median / p95.

        Returns all-zero stats for an empty list. The p95 value is read from
        the *sorted* samples at index ``int(n * 0.95)`` (clamped to the last
        element), so it is a genuine upper percentile rather than whatever
        request happened to sit at that position in submission order.
        """
        if not latencies:
            return {"min": 0, "max": 0, "mean": 0, "median": 0, "p95": 0}
        ordered = sorted(latencies)
        p95_index = min(int(len(ordered) * 0.95), len(ordered) - 1)
        return {
            "min": ordered[0],
            "max": ordered[-1],
            "mean": statistics.mean(ordered),
            "median": statistics.median(ordered),
            "p95": ordered[p95_index],
        }

    async def simulate_request(self, request_id: int, text: str) -> Dict[str, Any]:
        """Send one prompt through ``llm_client.generate`` and time it.

        Any ``Exception`` raised while generating (or while reading the
        result) is captured into a result dict rather than propagated.

        Args:
            request_id: Identifier echoed back in the result dict.
            text: User text embedded in the prompt.

        Returns:
            A dict with ``request_id``, ``success``, ``latency_ms`` and
            ``model``, plus ``text_length`` on success or ``error`` on failure.
        """
        # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
        start = time.perf_counter()

        try:
            result = await llm_client.generate(
                prompt=f"[SYSTEM] Tu es un crâne joyeux.\n[USER] {text}",
                model="gpt-5-nano",  # Use faster model for load test
                max_tokens=100,
                temperature=0.7,
                top_p=0.9,
                timeout_s=6,
            )

            return {
                "request_id": request_id,
                "success": True,
                "latency_ms": self._elapsed_ms(start),
                "text_length": len(result.text),
                "model": result.model,
            }

        except Exception as e:
            return {
                "request_id": request_id,
                "success": False,
                "latency_ms": self._elapsed_ms(start),
                "error": str(e),
                "model": "unknown",
            }

    async def run_burst_test(self, num_requests: int = 5) -> Dict[str, Any]:
        """Fire ``num_requests`` requests concurrently and summarize them.

        Prompts are drawn round-robin from a fixed list. A zero request count
        yields zero rates instead of raising ZeroDivisionError.

        Returns:
            A summary dict (``test_type == "burst"``) with counts, rates,
            latency stats and up to three sample successes / failures.
        """

        print(f"Starting burst test with {num_requests} concurrent requests...")

        test_requests = [
            "Raconte une blague courte",
            "Présente-toi en une phrase",
            "Dis bonjour",
            "Quelle heure est-il ?",
            "Merci beaucoup",
        ]

        start_time = time.perf_counter()

        # Modulo indexing reuses prompts when num_requests exceeds the list.
        results = await asyncio.gather(
            *(
                self.simulate_request(i, test_requests[i % len(test_requests)])
                for i in range(num_requests)
            ),
            return_exceptions=True,
        )

        total_time = time.perf_counter() - start_time

        successful_results = [
            r for r in results if isinstance(r, dict) and r.get("success", False)
        ]
        failed_results = [
            r for r in results if isinstance(r, dict) and not r.get("success", True)
        ]
        # Anything gather() returned that is not a result dict is a raised exception.
        exception_results = [r for r in results if not isinstance(r, dict)]

        return {
            "test_type": "burst",
            "num_requests": num_requests,
            "total_time_s": round(total_time, 2),
            "requests_per_second": (
                round(num_requests / total_time, 2) if total_time > 0 else 0.0
            ),
            "success_count": len(successful_results),
            "failure_count": len(failed_results),
            "exception_count": len(exception_results),
            "success_rate": (
                len(successful_results) / num_requests if num_requests > 0 else 0
            ),
            "latency_stats": self._latency_stats(
                [r["latency_ms"] for r in successful_results]
            ),
            "successful_results": successful_results[:3],  # Sample of successes
            "failed_results": failed_results[:3],  # Sample of failures
        }

    async def run_sustained_test(
        self, duration_s: int = 10, rate_per_s: float = 1.0
    ) -> Dict[str, Any]:
        """Issue requests at ``rate_per_s`` for ``duration_s`` seconds.

        Each request runs as its own task. After the send window closes, the
        tasks are awaited in order, each with a 10 s timeout. A timed-out
        task is recorded as a failure that keeps its real request id.

        Raises:
            ValueError: if ``rate_per_s`` is not positive.

        Returns:
            A summary dict (``test_type == "sustained"``) with counts, rates
            and latency stats.
        """
        if rate_per_s <= 0:
            raise ValueError(f"rate_per_s must be positive, got {rate_per_s}")

        print(f"Starting sustained test for {duration_s}s at {rate_per_s} req/s...")

        start_time = time.perf_counter()
        end_time = start_time + duration_s
        interval = 1.0 / rate_per_s

        tasks: List[asyncio.Task] = []

        while time.perf_counter() < end_time:
            request_id = len(tasks)
            tasks.append(
                asyncio.create_task(
                    self.simulate_request(
                        request_id, f"Request {request_id} - raconte quelque chose"
                    )
                )
            )
            # Wait for next interval
            await asyncio.sleep(interval)

        print("Waiting for remaining requests to complete...")
        completed_results = []

        # Task i was created for request id i, so enumerate recovers the id.
        for request_id, task in enumerate(tasks):
            try:
                completed_results.append(await asyncio.wait_for(task, timeout=10.0))
            except asyncio.TimeoutError:
                completed_results.append(
                    {
                        "request_id": request_id,
                        "success": False,
                        "latency_ms": 10000,
                        "error": "timeout",
                    }
                )

        total_time = time.perf_counter() - start_time

        successful = [r for r in completed_results if r.get("success", False)]
        failed = [r for r in completed_results if not r.get("success", True)]

        return {
            "test_type": "sustained",
            "duration_s": round(total_time, 2),
            "target_rate": rate_per_s,
            "actual_rate": (
                round(len(completed_results) / total_time, 2) if total_time > 0 else 0.0
            ),
            "total_requests": len(completed_results),
            "success_count": len(successful),
            "failure_count": len(failed),
            "success_rate": (
                len(successful) / len(completed_results) if completed_results else 0
            ),
            "latency_stats": self._latency_stats([r["latency_ms"] for r in successful]),
        }

    def print_results(self, results: Dict[str, Any]) -> None:
        """Print a summary dict from either test in a readable format."""

        print(f"\n{'='*50}")
        print(f"Load Test Results - {results['test_type'].upper()}")
        print(f"{'='*50}")

        if results["test_type"] == "burst":
            print(f"Requests: {results['num_requests']}")
            print(f"Total Time: {results['total_time_s']}s")
            print(f"Rate: {results['requests_per_second']} req/s")
        else:
            print(f"Duration: {results['duration_s']}s")
            print(f"Target Rate: {results['target_rate']} req/s")
            print(f"Actual Rate: {results['actual_rate']} req/s")
            print(f"Total Requests: {results['total_requests']}")

        print(f"\nSuccess Rate: {results['success_rate']:.1%}")
        print(f"Successful: {results['success_count']}")
        print(f"Failed: {results['failure_count']}")

        # Only burst summaries carry an exception count.
        if "exception_count" in results:
            print(f"Exceptions: {results['exception_count']}")

        print("\nLatency Statistics (ms):")
        stats = results["latency_stats"]
        print(f"  Min: {stats['min']}")
        print(f"  Max: {stats['max']}")
        print(f"  Mean: {stats['mean']:.1f}")
        print(f"  Median: {stats['median']:.1f}")
        print(f"  P95: {stats['p95']:.1f}")

        # Show sample successful results (burst summaries only)
        if results.get("successful_results"):
            print("\nSample Successful Responses:")
            for i, result in enumerate(results["successful_results"][:2]):
                print(
                    f"  {i+1}. Latency: {result['latency_ms']}ms, Length: {result['text_length']} chars"
                )

        # Show failures if any
        if results.get("failed_results"):
            print("\nSample Failures:")
            for i, result in enumerate(results["failed_results"][:2]):
                print(
                    f"  {i+1}. Error: {result.get('error', 'Unknown')}, Latency: {result['latency_ms']}ms"
                )

        print(f"{'='*50}\n")


async def main():
    """Run the burst and sustained load tests and report pass/fail.

    Returns:
        True if both tests ran and met their thresholds, False if a threshold
        was missed or an unexpected error occurred.
    """
    import traceback

    def _require(condition: bool, message: str) -> None:
        # Explicit check instead of ``assert``: assert statements are removed
        # under ``python -O``, which would let every threshold silently pass.
        if not condition:
            raise AssertionError(message)

    print("AI Service Load Testing")
    print(
        "Note: Using mock LLM for testing - replace with real implementation for production testing"
    )

    runner = LoadTestRunner()

    try:
        # Test 1: Burst test (5 concurrent requests)
        burst_results = await runner.run_burst_test(num_requests=5)
        runner.print_results(burst_results)

        # Validate burst test results
        _require(
            burst_results["success_rate"] >= 0.8,
            f"Burst test success rate too low: {burst_results['success_rate']}",
        )
        _require(
            burst_results["latency_stats"]["p95"] <= 2000,
            f"Burst test P95 latency too high: {burst_results['latency_stats']['p95']}ms",
        )

        print("✅ Burst test passed!")

        # Wait between tests
        await asyncio.sleep(2)

        # Test 2: Sustained test (10 seconds at 1 req/s)
        sustained_results = await runner.run_sustained_test(
            duration_s=10, rate_per_s=1.0
        )
        runner.print_results(sustained_results)

        # Validate sustained test results
        _require(
            sustained_results["success_rate"] >= 0.8,
            f"Sustained test success rate too low: {sustained_results['success_rate']}",
        )
        _require(
            sustained_results["latency_stats"]["mean"] <= 1500,
            f"Sustained test mean latency too high: {sustained_results['latency_stats']['mean']}ms",
        )

        print("✅ Sustained test passed!")

        print("\n🎉 All load tests passed!")

        # Performance summary
        print("\nPerformance Summary:")
        print(
            f"Burst Test - P95 Latency: {burst_results['latency_stats']['p95']:.0f}ms"
        )
        print(
            f"Sustained Test - Mean Latency: {sustained_results['latency_stats']['mean']:.0f}ms"
        )
        # "Overall" is the worse of the two success rates.
        print(
            f"Overall Success Rate: {min(burst_results['success_rate'], sustained_results['success_rate']):.1%}"
        )

    except AssertionError as e:
        print(f"❌ Load test failed: {e}")
        return False
    except Exception as e:
        print(f"❌ Load test error: {e}")
        # Keep the stack trace: this is the script's top-level error boundary.
        traceback.print_exc()
        return False

    return True


if __name__ == "__main__":
    success = asyncio.run(main())
    # sys.exit rather than the site-provided ``exit`` helper, which is meant
    # for the interactive shell and is absent when Python runs with ``-S``.
    sys.exit(0 if success else 1)
