import pyodbc
import os
import threading
import queue
from contextlib import contextmanager
from typing import List, Dict, Any
from dotenv import load_dotenv
from .logger_config import get_logger

load_dotenv()
logger = get_logger(__name__)


class ConnectionPool:
    """
    Thread-safe connection pool for pyodbc.
    Reuses connections to avoid connect/teardown overhead per query.

    Invariant: ``_created`` counts live connections attributed to this pool
    (checked out + idle in the queue). Every code path that destroys a
    connection without handing the caller a replacement must decrement it,
    otherwise the pool permanently shrinks.
    """

    def __init__(self, connection_string: str, pool_size: int = 5):
        self.connection_string = connection_string
        self.pool_size = pool_size
        self._pool: queue.Queue = queue.Queue(maxsize=pool_size)
        self._lock = threading.Lock()
        # Number of live connections attributed to the pool (see class docstring).
        self._created = 0

    def _create_connection(self):
        return pyodbc.connect(self.connection_string)

    @staticmethod
    def _close_quietly(conn) -> None:
        """Close a connection, swallowing any error from an already-dead link."""
        try:
            conn.close()
        except Exception:
            pass

    def _check_connection_alive(self, conn):
        """Verify connection is alive; replace with new one if not. Cursor is closed to avoid idle state."""
        cur = None
        try:
            cur = conn.cursor()
            cur.execute("SELECT 1")
            return conn
        except Exception:
            self._close_quietly(conn)
            # May raise if the server is unreachable; callers must release
            # the pool slot in that case (see _checked_or_release).
            return self._create_connection()
        finally:
            if cur is not None:
                try:
                    cur.close()
                except Exception:
                    pass

    def _checked_or_release(self, conn):
        """
        Health-check a pooled connection. If the dead connection could not be
        replaced, release its slot before propagating — previously the slot
        leaked and the pool capacity shrank permanently.
        """
        try:
            return self._check_connection_alive(conn)
        except Exception:
            with self._lock:
                self._created -= 1
            raise

    def get_connection(self):
        """Return a pooled connection, creating one if under capacity; blocks when exhausted."""
        try:
            conn = self._pool.get_nowait()
        except queue.Empty:
            conn = None
        if conn is not None:
            return self._checked_or_release(conn)
        should_create = False
        with self._lock:
            if self._created < self.pool_size:
                self._created += 1
                should_create = True
        if should_create:
            try:
                return self._create_connection()
            except Exception:
                # Creation failed: give the reserved slot back.
                with self._lock:
                    self._created -= 1
                raise
        # Pool at capacity: block until a connection is returned.
        return self._checked_or_release(self._pool.get())

    def return_connection(self, conn):
        """Return a connection to the pool; surplus connections are closed."""
        if conn is None:
            return
        try:
            self._pool.put_nowait(conn)
        except queue.Full:
            # Close quietly so a failing close() cannot propagate to the
            # caller or skip the accounting decrement (previous bug).
            self._close_quietly(conn)
            with self._lock:
                self._created -= 1

class DatabaseConnector:
    """
    Centralized MS SQL Database Connector with connection pooling.
    Uses pyodbc; connections are taken from and returned to a pool for reuse.

    Requires ``DB_USER`` and ``DB_PASSWORD`` (no defaults). Optional:
    ``DB_TRUST_SERVER_CERTIFICATE=true`` to set ``TrustServerCertificate=yes`` (dev/self-signed only).
    """

    # Process-wide singleton instance.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(DatabaseConnector, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Singleton: skip re-initialization on repeated constructions.
        if self._initialized:
            return

        self.server = os.getenv("DB_SERVER", "localhost")
        self.database = os.getenv("DB_NAME", "profiling_db")
        self.driver = os.getenv("DB_DRIVER", "{ODBC Driver 18 for SQL Server}")
        # Default 7: matches parallel fetch tasks in ProfilingOrchestrator._fetch_all_raw_data
        pool_size = int(os.getenv("DB_POOL_SIZE", "7"))

        raw_user = os.getenv("DB_USER")
        raw_password = os.getenv("DB_PASSWORD")
        if raw_user is None or not str(raw_user).strip():
            raise ValueError(
                "DB_USER must be set to a non-empty value (no default credentials). "
                "Set it in the environment or .env."
            )
        if raw_password is None:
            raise ValueError(
                "DB_PASSWORD must be set (no default). Set it in the environment or .env."
            )
        self.username = str(raw_user).strip()
        self.password = raw_password

        trust_raw = os.getenv("DB_TRUST_SERVER_CERTIFICATE", "").strip().lower()
        trust_cert = trust_raw in ("1", "true", "yes", "on")

        self.connection_string = (
            f"DRIVER={self.driver};"
            f"SERVER={self.server};"
            f"DATABASE={self.database};"
            f"UID={self.username};"
            f"PWD={self.password};"
        )
        # Bug fix: TrustServerCertificate=yes was previously appended
        # unconditionally, making the documented DB_TRUST_SERVER_CERTIFICATE
        # opt-in dead code and disabling certificate validation everywhere.
        # Append it only when explicitly opted in (dev/self-signed only).
        if trust_cert:
            self.connection_string += "TrustServerCertificate=yes;"
        self._pool = ConnectionPool(self.connection_string, pool_size=pool_size)
        self._initialized = True

    @contextmanager
    def get_connection(self):
        """Yield a pooled connection; on exit roll back any open transaction and return it to the pool."""
        conn = self._pool.get_connection()
        try:
            yield conn
        finally:
            # Discard uncommitted work so a failed caller cannot leak a
            # half-open transaction to the next user of this connection.
            try:
                conn.rollback()
            except Exception:
                pass
            self._pool.return_connection(conn)

    def execute(self, query: str, params: tuple = ()) -> List[Dict[str, Any]]:
        """
        Executes a query and returns results as a list of dictionaries.
        Uses a connection from the pool. Statements that produce no result
        set (e.g. INSERT/UPDATE) are committed and return an empty list.
        """
        # Local debug: uncomment to log full SQL, bound params, and row samples (may contain PII).
        # logger.debug("Executing SQL: %s | params=%r", query, params)
        _q = query.strip()
        _stmt = _q.split(None, 1)[0].upper() if _q else "?"
        logger.debug("SQL execute: statement=%s param_count=%s", _stmt, len(params))
        with self.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(query, params)

                # cursor.description is set only when the statement produced a
                # result set (e.g. SELECT); otherwise commit the write.
                if cursor.description:
                    columns = [column[0] for column in cursor.description]
                    results = [dict(zip(columns, row)) for row in cursor.fetchall()]

                    logger.info("Fetched %s rows", len(results))
                    # Local debug: uncomment for row payload sample (PII).
                    # if results:
                    #     logger.debug("Data snippet: %s...", str(results[:2])[:200])
                    return results
                else:
                    conn.commit()
                    logger.debug("Query executed (no results returned), connection committed.")
                    return []
            finally:
                cursor.close()


# Centralized provider
# Module-level singleton, constructed at import time: DatabaseConnector.__init__
# reads env vars and raises ValueError if DB_USER/DB_PASSWORD are unset.
db_session = DatabaseConnector()


def verify_database_connectivity() -> None:
    """
    Ensures both the pyodbc pool (orchestrator / raw SQL) and the SQLAlchemy engine
    (profiling tables) can reach the database. Raises on failure so the process can exit
    during application startup.
    """
    # Leg 1: raw pyodbc path used by the orchestrator.
    rows = db_session.execute("SELECT 1 AS connectivity_check")
    pyodbc_ok = bool(rows) and rows[0].get("connectivity_check") == 1
    if not pyodbc_ok:
        raise RuntimeError(
            "Database connectivity check failed (pyodbc): unexpected SELECT 1 result"
        )

    # Leg 2: SQLAlchemy engine used by the profiling tables.
    # Imported lazily to avoid a hard dependency at module import time.
    from sqlalchemy import text

    from storage.db_repo import get_engine

    engine = get_engine()
    with engine.connect() as conn:
        value = conn.execute(text("SELECT 1")).scalar_one()
    if value != 1:
        raise RuntimeError(
            "Database connectivity check failed (SQLAlchemy): unexpected SELECT 1 result"
        )


def get_db_session():
    """Return the module-level DatabaseConnector singleton (``db_session``)."""
    return db_session


if __name__ == "__main__":
    print("Testing DB Connection (with pool)...")
    try:
        results = db_session.execute("SELECT 1 AS test")
        if results and results[0]['test'] == 1:
            print("✅ Connection Successful!")
            print(f"Connected to: {db_session.database} on {db_session.server}")
        else:
            print("❌ Connection failed (Unexpected result)")
    except Exception as e:
        print(f"❌ Connection Error: {str(e)}")
        print("\nPossible issues:")
        print("1. DB_USER and DB_PASSWORD must be set (no defaults)")
        print("2. For dev/self-signed SQL: set DB_TRUST_SERVER_CERTIFICATE=true if required")
        print("3. Check DB_SERVER, DB_NAME, and .env")
        print("4. Ensure the ODBC Driver is installed")
        print("5. Verify network/firewall access to the SQL Server")
