"""
Rate limiting middleware
Simple IP-based rate limiting (4 requests per IP like DeepSite)
Uses in-memory storage for rate limiting (no database required)
"""

import math
import sys
import threading
import time
from collections import defaultdict
from functools import wraps
from pathlib import Path
from typing import Any, Dict, Optional

from fastapi import HTTPException, Request

# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from core.config import settings

# In-memory, per-process storage for rate limiting: maps client IP ->
# {'count': requests counted in the current window, 'window_start': epoch seconds}.
# Unseen IPs start with window_start=0, so their first request opens a new window.
_rate_limit_storage = defaultdict(lambda: {'count': 0, 'window_start': 0})
# Reentrant lock: _cleanup_expired_entries() acquires this lock itself and may be
# invoked by code that already holds it; a plain Lock would deadlock in that case.
_storage_lock = threading.RLock()

# Read configuration from settings (which loads from .env in root)
DEFAULT_MAX_REQUESTS = settings.max_requests_per_ip
DEFAULT_WINDOW_SECONDS = settings.max_requests_window

def _cleanup_expired_entries(max_age_seconds: float = 86400) -> None:
    """
    Remove rate limit entries whose window started more than ``max_age_seconds``
    ago, so the in-memory storage does not grow without bound.

    Acquires ``_storage_lock`` itself, so it must not be called while that lock
    is already held unless the lock is reentrant.

    Args:
        max_age_seconds: Age threshold in seconds (default: 86400 = 24 hours)
    """
    current_time = time.time()
    with _storage_lock:
        # Collect first, then delete: a dict must not change size while iterated
        expired_ips = [
            ip
            for ip, data in _rate_limit_storage.items()
            if current_time - data['window_start'] > max_age_seconds
        ]
        for ip in expired_ips:
            del _rate_limit_storage[ip]

def _check_rate_limit_in_memory(ip_address: str, max_requests: int = 4, window_seconds: int = 3600) -> Dict[str, Any]:
    """
    Check whether an IP address is within its rate limit and record the request.

    Uses fixed windows: the first request after a window has expired opens a new
    window of ``window_seconds``. Each allowed request increments the count;
    denied requests are not counted.

    Args:
        ip_address: Client IP address (the storage key)
        max_requests: Maximum requests allowed per window (default: 4)
        window_seconds: Time window in seconds (default: 3600 = 1 hour)

    Returns:
        Dict with keys ``allowed`` (bool), ``count`` (requests counted in the
        current window), ``remaining`` (requests left in the window, 0 when
        denied) and ``reset_time`` (epoch seconds at which the window ends).
    """
    current_time = time.time()

    with _storage_lock:
        # Prune stale entries while we already hold the lock. Calling
        # _cleanup_expired_entries() here would acquire _storage_lock a second
        # time, which deadlocks when the lock is a plain (non-reentrant) Lock.
        if len(_rate_limit_storage) > 1000:  # Clean up when we have many entries
            # Never drop an entry whose current window may still be running
            horizon = max(86400, window_seconds)
            stale_ips = [
                ip
                for ip, entry in _rate_limit_storage.items()
                if current_time - entry['window_start'] > horizon
            ]
            for ip in stale_ips:
                del _rate_limit_storage[ip]

        data = _rate_limit_storage[ip_address]

        # Open a fresh window once the previous one has expired
        if data['window_start'] < current_time - window_seconds:
            data['count'] = 0
            data['window_start'] = current_time

        reset_time = data['window_start'] + window_seconds

        # Checked after any reset, so max_requests <= 0 denies every request
        if data['count'] >= max_requests:
            return {
                'allowed': False,
                'count': data['count'],
                'remaining': 0,
                'reset_time': reset_time,
            }

        data['count'] += 1
        return {
            'allowed': True,
            'count': data['count'],
            'remaining': max_requests - data['count'],
            'reset_time': reset_time,
        }

def rate_limit(requests: Optional[int] = None, per: Optional[int] = None):
    """
    Rate limiting decorator for async endpoints, using in-memory storage.

    The decorated coroutine must receive a ``Request`` (positionally or by
    keyword); if none is passed, the call is not rate limited.

    Counters are keyed only by client IP, so every endpoint decorated with
    ``rate_limit`` shares the same per-IP counter, and state is local to this
    process.

    Args:
        requests: Max requests per window; defaults to DEFAULT_MAX_REQUESTS
            (settings.max_requests_per_ip)
        per: Window length in seconds; defaults to DEFAULT_WINDOW_SECONDS
            (settings.max_requests_window)

    Raises:
        HTTPException: 429 with a ``Retry-After`` header when the limit is exceeded
    """
    # Resolve defaults at decoration time from the values loaded from settings
    if requests is None:
        requests = DEFAULT_MAX_REQUESTS
    if per is None:
        per = DEFAULT_WINDOW_SECONDS

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Locate the incoming Request: positional args first, then keyword args
            request = next(
                (value for value in (*args, *kwargs.values()) if isinstance(value, Request)),
                None,
            )
            if request is None:
                # If no request found, skip rate limiting
                return await func(*args, **kwargs)

            # Starlette types request.client as optional (e.g. some test
            # transports), so avoid an AttributeError and share a fallback bucket
            client_ip = request.client.host if request.client else "unknown"

            rate_limit_result = _check_rate_limit_in_memory(client_ip, requests, per)

            if not rate_limit_result['allowed']:
                # Round up and clamp so a still-blocked client never sees 0 or a
                # negative wait
                retry_after = max(1, math.ceil(rate_limit_result['reset_time'] - time.time()))
                raise HTTPException(
                    status_code=429,
                    detail=f"Rate limit exceeded. Maximum {requests} requests per {per} seconds. Try again in {retry_after} seconds.",
                    headers={"Retry-After": str(retry_after)},
                )

            # Call the original function
            return await func(*args, **kwargs)

        return wrapper
    return decorator