"""
Simplified Database Utilities
IP-based rate limiting only (by default 4 requests per IP per hour, like DeepSite),
plus optional basic usage logging.
"""

import logging
import os
import time
from typing import Any, Dict, Optional

from .db import get_connection

logger = logging.getLogger(__name__)

def init_simple_db() -> None:
    """
    Create the ip_rate_limits and usage_logs tables if they do not exist yet.

    Raises:
        Exception: any error from the database driver while creating the tables
            is logged and re-raised.
    """
    conn = get_connection()
    cur = None
    try:
        cur = conn.cursor()

        # Per-IP request counters for fixed-window rate limiting. The UNIQUE
        # constraint already gives ip_address an index, so no separate INDEX on
        # that column is declared (it would only duplicate the unique index).
        # CREATE TABLE IF NOT EXISTS leaves already-existing tables untouched.
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS ip_rate_limits (
                id INT PRIMARY KEY AUTO_INCREMENT,
                ip_address VARCHAR(64) NOT NULL UNIQUE,
                request_count INT DEFAULT 0,
                window_start DOUBLE NOT NULL,
                created_at DOUBLE NOT NULL,
                updated_at DOUBLE NOT NULL,
                INDEX idx_window_start (window_start)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """
        )

        # Simple logging table (optional, for basic usage tracking)
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS usage_logs (
                id INT PRIMARY KEY AUTO_INCREMENT,
                ip_address VARCHAR(64) NOT NULL,
                endpoint VARCHAR(255) NOT NULL,
                created_at DOUBLE NOT NULL,
                user_agent TEXT,
                INDEX idx_ip_address (ip_address),
                INDEX idx_created_at (created_at)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """
        )

        conn.commit()
        logger.info("Simplified database initialized with IP-based rate limiting only")

    except Exception as e:
        logger.error("Failed to initialize simplified database: %s", e)
        raise
    finally:
        # cur stays None if conn.cursor() itself failed.
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass

def _rate_limit_status(allowed: bool, count: int, max_requests: int, reset_time: float) -> Dict[str, Any]:
    """Build the status dict returned by check_rate_limit ('remaining' is 0 when denied)."""
    return {
        'allowed': allowed,
        'count': count,
        'remaining': max_requests - count if allowed else 0,
        'reset_time': reset_time
    }


def check_rate_limit(ip_address: str, max_requests: int = 4, window_seconds: int = 3600) -> Dict[str, Any]:
    """
    Check whether an IP address is within its rate limit and, if so, count the request.

    Each IP gets a fixed window that opens on its first counted request and lasts
    ``window_seconds``. At most ``max_requests`` requests are allowed per window,
    and denied requests are not counted.

    Concurrent requests from the same IP are handled optimistically. Every write
    repeats, in its WHERE clause, the condition it was based on. If another
    request changed the row first (rowcount 0), the check is retried on a fresh
    read. This stops simultaneous requests from slipping past the limit. It also
    means two simultaneous first requests no longer collide on a duplicate key.

    Args:
        ip_address: Client IP address. Only its first 64 characters (the width
            of the ip_address column) are used as the key.
        max_requests: Maximum requests allowed per window (default: 4)
        window_seconds: Time window in seconds (default: 3600 = 1 hour)

    Returns:
        Dict with 'allowed', 'count', 'remaining' and 'reset_time' (epoch
        seconds at which the current window ends). If the database call
        fails, the check fails open: 'allowed' is True and 'error' holds the
        exception message.
    """
    max_attempts = 3
    conn = None
    cur = None
    try:
        current_time = time.time()
        # A window that started at or before this moment has fully elapsed.
        window_start = current_time - window_seconds
        # Use one column-width key for both lookups and inserts. Otherwise an
        # over-long value would be stored truncated and never matched again.
        ip_key = ip_address[:64]

        conn = get_connection()
        cur = conn.cursor()

        row = None
        for _ in range(max_attempts):
            cur.execute(
                "SELECT request_count, window_start FROM ip_rate_limits WHERE ip_address = %s",
                (ip_key,)
            )
            row = cur.fetchone()

            if row is None:
                # First request from this IP. IGNORE turns the duplicate-key
                # error from a concurrent first request into rowcount 0, and
                # we then retry against the row that request created.
                cur.execute(
                    "INSERT IGNORE INTO ip_rate_limits (ip_address, request_count, window_start, created_at, updated_at) VALUES (%s, 1, %s, %s, %s)",
                    (ip_key, current_time, current_time, current_time)
                )
                if cur.rowcount == 1:
                    conn.commit()
                    return _rate_limit_status(True, 1, max_requests, current_time + window_seconds)
            else:
                stored_count, stored_window = row['request_count'], row['window_start']

                if stored_window <= window_start:
                    # Window expired: open a new one, unless another request already did.
                    cur.execute(
                        "UPDATE ip_rate_limits SET request_count = 1, window_start = %s, updated_at = %s WHERE ip_address = %s AND window_start <= %s",
                        (current_time, current_time, ip_key, window_start)
                    )
                    if cur.rowcount == 1:
                        conn.commit()
                        return _rate_limit_status(True, 1, max_requests, current_time + window_seconds)
                elif stored_count >= max_requests:
                    # Limit reached. Nothing was written; the commit only ends
                    # the read transaction so later reads are not stale.
                    conn.commit()
                    return _rate_limit_status(False, stored_count, max_requests, stored_window + window_seconds)
                else:
                    # Increment only if no other request has counted or reset
                    # this row since our SELECT.
                    cur.execute(
                        "UPDATE ip_rate_limits SET request_count = request_count + 1, updated_at = %s WHERE ip_address = %s AND request_count = %s AND window_start > %s",
                        (current_time, ip_key, stored_count, window_start)
                    )
                    if cur.rowcount == 1:
                        conn.commit()
                        return _rate_limit_status(True, stored_count + 1, max_requests, stored_window + window_seconds)

            # Lost a race with a concurrent request for this IP. End the
            # transaction so the retried SELECT sees the latest committed row
            # rather than a REPEATABLE READ snapshot (when autocommit is off).
            conn.commit()

        # Still contended after max_attempts. A burst from a single IP is
        # exactly what the limiter exists to stop, so deny instead of failing open.
        logger.warning(
            "Rate limit check for %s gave up after %d concurrent-update retries",
            ip_key, max_attempts
        )
        if row is not None:
            return _rate_limit_status(False, row['request_count'], max_requests, row['window_start'] + window_seconds)
        return _rate_limit_status(False, 0, max_requests, current_time + window_seconds)

    except Exception as e:
        logger.error("Rate limit check failed: %s", e)
        # Best effort: don't leave the connection in the middle of a transaction.
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
        # Fail open - allow request if rate limiting fails
        return {
            'allowed': True,
            'count': 0,
            'remaining': max_requests,
            'reset_time': time.time() + window_seconds,
            'error': str(e)
        }
    finally:
        # cur stays None if the connection or cursor could not be obtained.
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass

def log_usage(ip_address: str, endpoint: str, user_agent: Optional[str] = None) -> None:
    """
    Record one request in usage_logs for basic monitoring (best effort).

    Failures are logged as warnings and never raised, so a logging problem
    cannot break the request being logged.

    Args:
        ip_address: Client IP address (clipped to the 64-char column width)
        endpoint: API endpoint accessed (clipped to the 255-char column width)
        user_agent: User agent string, or None if unknown
    """
    conn = None
    cur = None
    try:
        conn = get_connection()
        cur = conn.cursor()

        # Clip to the VARCHAR widths declared for usage_logs. Under strict SQL
        # mode MySQL rejects over-long values, which would drop the whole row.
        if isinstance(ip_address, str):
            ip_address = ip_address[:64]
        if isinstance(endpoint, str):
            endpoint = endpoint[:255]

        cur.execute(
            "INSERT INTO usage_logs (ip_address, endpoint, created_at, user_agent) VALUES (%s, %s, %s, %s)",
            (ip_address, endpoint, time.time(), user_agent)
        )
        conn.commit()

    except Exception as e:
        logger.warning("Failed to log usage: %s", e)
        # Best effort: don't leave the connection in the middle of a transaction.
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
    finally:
        # cur stays None if the connection or cursor could not be obtained.
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass

def get_usage_stats(limit: int = 100) -> Dict[str, Any]:
    """
    Get basic usage statistics.

    Args:
        limit: Maximum number of recent log rows and of rate-limit rows to
            return. It is converted with int(), and negative values are
            treated as 0.

    Returns:
        Dict with:
          - 'recent_logs': newest usage_logs rows (ip_address, endpoint, created_at)
          - 'rate_limits': most recently updated ip_rate_limits rows
          - 'total_logs': total row count of usage_logs
          - 'total_ips': total row count of ip_rate_limits
          - 'generated_at': epoch seconds when the result was built
        Rows are whatever the cursor's fetchall() returns. The ['total_logs']
        indexing below assumes a dict-style cursor. On failure the lists are
        empty, the totals are 0, and 'error' holds the exception message.
    """
    cur = None
    try:
        # A negative LIMIT is a MySQL syntax error; clamp it instead of failing.
        row_limit = max(0, int(limit))

        conn = get_connection()
        cur = conn.cursor()

        # Recent usage logs, newest first
        cur.execute(
            "SELECT ip_address, endpoint, created_at FROM usage_logs ORDER BY created_at DESC LIMIT %s",
            (row_limit,)
        )
        recent_logs = cur.fetchall()

        # Per-IP rate limit state, most recently touched first
        cur.execute(
            "SELECT ip_address, request_count, window_start FROM ip_rate_limits ORDER BY updated_at DESC LIMIT %s",
            (row_limit,)
        )
        rate_limits = cur.fetchall()

        # Total counts
        cur.execute("SELECT COUNT(*) as total_logs FROM usage_logs")
        total_logs = cur.fetchone()['total_logs']

        cur.execute("SELECT COUNT(*) as total_ips FROM ip_rate_limits")
        total_ips = cur.fetchone()['total_ips']

        return {
            'recent_logs': recent_logs,
            'rate_limits': rate_limits,
            'total_logs': total_logs,
            'total_ips': total_ips,
            'generated_at': time.time()
        }

    except Exception as e:
        logger.error("Failed to get usage stats: %s", e)
        # Same keys as the success result, so callers can read 'generated_at'
        # unconditionally.
        return {
            'recent_logs': [],
            'rate_limits': [],
            'total_logs': 0,
            'total_ips': 0,
            'generated_at': time.time(),
            'error': str(e)
        }
    finally:
        # cur stays None if the connection or cursor could not be obtained.
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass
