import logging
import os
import threading
from typing import Optional

logger = logging.getLogger(__name__)

# Per-thread storage for the cached DB connection (set by _connect_mysql,
# cleared by close_connection), so each thread reuses its own connection.
_local = threading.local()

# MySQL connection settings, read once at import time.  Each DB_* variable
# falls back to its MYSQL_* counterpart, then to a hard-coded default.
DB_NAME = os.getenv('DB_NAME', os.getenv('MYSQL_DB', 'eatance_website_builder'))
DB_USER = os.getenv('DB_USER', os.getenv('MYSQL_USER', 'root'))
DB_PASSWORD = os.getenv('DB_PASSWORD', os.getenv('MYSQL_PASSWORD', ''))
DB_HOST = os.getenv('DB_HOST', os.getenv('MYSQL_HOST', 'localhost'))
# Kept as a string here; _connect_mysql converts it to int.
DB_PORT = os.getenv('DB_PORT', os.getenv('MYSQL_PORT', '3306'))


class DBAdapter:
    """Light adapter over a DB-API connection (sqlite3, psycopg2 or pymysql).

    Exposes execute(), executemany(), cursor(), commit(), fetchone()/fetchall()
    and context-manager support; any other attribute is forwarded to the
    wrapped connection.  ``backend`` is a label such as ``'sqlite'``,
    ``'postgres'`` or ``'mysql'``; for the latter two ('%s'-placeholder
    drivers) sqlite-style '?' placeholders are rewritten to '%s'.
    """

    # Backends whose drivers expect '%s' placeholders instead of '?'.
    _PERCENT_S_BACKENDS = ('postgres', 'mysql')

    # Parameter types the drivers bind natively.  The tolerant retry in
    # execute() passes these through unchanged: stringifying them would bind
    # the wrong value (str(True) == 'True', str(b'ab') == "b'ab'") or can break
    # SQL such as ``LIMIT %s`` once the integer arrives quoted.
    _NATIVE_PARAM_TYPES = (str, bytes, bytearray, bool, int, float)

    def __init__(self, conn, backend: str):
        self._conn = conn
        self.backend = backend
        # Cursor used by the most recent successful execute()/executemany();
        # fetchone()/fetchall() read from it.
        self._last_cursor = None

    def cursor(self):
        """Return a new cursor from the wrapped connection (same call for every backend)."""
        return self._conn.cursor()

    def _adapt_placeholders(self, query):
        """Rewrite '?' placeholders to '%s' for '%s'-placeholder backends.

        Non-string queries are returned unchanged.  NOTE: the rewrite is purely
        textual, so a literal '?' inside a quoted SQL string is rewritten too.
        """
        if self.backend in self._PERCENT_S_BACKENDS and isinstance(query, str):
            return query.replace('?', '%s')
        return query

    @classmethod
    def _coerce_retry_param(cls, value):
        """Stringify *value* for the tolerant retry unless the driver binds it natively."""
        if value is None or isinstance(value, cls._NATIVE_PARAM_TYPES):
            return value
        return str(value)

    def execute(self, query, params=None):
        """Execute *query* on a new cursor and return that cursor.

        Args:
            query: SQL text.  When *params* is given and the backend is
                'postgres' or 'mysql', '?' placeholders become '%s'.
            params: None, a list/tuple, a dict, or a single scalar (wrapped in
                a 1-tuple so the driver does not iterate a bare string).

        Returns:
            The cursor the statement ran on; it is also remembered for
            fetchone()/fetchall().

        Raises:
            The driver's original exception if the statement fails and the
            tolerant retry (sequence params only) does not succeed.
        """
        cur = self.cursor()
        q = query
        p = params
        if p is not None:
            q = self._adapt_placeholders(q)
            if not isinstance(p, (list, tuple, dict)):
                p = (p,)

        try:
            if p is None:
                cur.execute(q)
            else:
                cur.execute(q, p)
        except Exception as e:
            logger.debug(
                "DB execute failed, attempting tolerant retry: backend=%s; error=%s",
                self.backend, e,
            )
            # Tolerant retry for driver param-formatting issues (MySQLdb/pymysql):
            # stringify only values the driver cannot bind natively.
            if isinstance(p, (list, tuple)):
                alt = tuple(self._coerce_retry_param(x) for x in p)
                try:
                    cur.execute(q, alt)
                except Exception as retry_err:
                    logger.debug("DB execute tolerant retry failed: %s", retry_err)
                else:
                    self._last_cursor = cur
                    return cur
            # NOTE: params are logged verbatim; avoid routing secrets through here.
            logger.error(
                "DB execute error (backend=%s): %s; query=%s; params=%s",
                self.backend, e, q, p,
            )
            raise  # re-raises the first error, not the retry's
        self._last_cursor = cur
        return cur

    def executemany(self, query, seq_of_params):
        """Run *query* once per parameter set in *seq_of_params*.

        Placeholders are rewritten exactly as in execute().  Returns whatever
        the underlying cursor's ``executemany`` returns (driver-specific).
        """
        cur = self.cursor()
        q = query if seq_of_params is None else self._adapt_placeholders(query)
        res = cur.executemany(q, seq_of_params)
        self._last_cursor = cur
        return res

    def _fetch_source(self):
        """Return the cursor to fetch from: the last executed one, else a fresh cursor.

        NOTE(review): with no prior execute() this fetches from a fresh cursor,
        which most drivers presumably reject; confirm this fallback is wanted.
        """
        if self._last_cursor is None:
            return self.cursor()
        return self._last_cursor

    def fetchone(self):
        """Fetch the next row from the most recently executed cursor."""
        return self._fetch_source().fetchone()

    def fetchall(self):
        """Fetch all remaining rows from the most recently executed cursor."""
        return self._fetch_source().fetchall()

    def commit(self):
        """Commit the current transaction on the wrapped connection.

        Kept non-raising for backward compatibility, but a failure is now
        logged with its traceback instead of vanishing silently.
        """
        try:
            self._conn.commit()
        except Exception:
            logger.exception("DB commit failed (backend=%s)", self.backend)

    def close(self):
        """Close the wrapped connection; errors are logged at debug level and ignored."""
        try:
            self._conn.close()
        except Exception:
            logger.debug("Ignoring error while closing DB connection", exc_info=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        """Commit on a clean exit, roll back if the body raised.

        Never suppresses the body's exception (returns False).
        """
        if exc_type is None:
            self.commit()
        else:
            try:
                self._conn.rollback()
            except Exception:
                # Rollback is best-effort (the connection may lack it or be dead);
                # the body's exception is what propagates to the caller.
                logger.debug("DB rollback failed (backend=%s)", self.backend, exc_info=True)
        return False

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped connection.

        Reads ``_conn`` from the instance dict directly: on an instance whose
        __init__ never ran (``__new__``, copy/pickle reconstruction), going
        through ``self._conn`` would re-enter __getattr__ and recurse forever.
        """
        try:
            conn = self.__dict__['_conn']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(conn, name)


def _connect_mysql():
    """Return this thread's cached MySQL connection, creating it on first use.

    The connection is a DBAdapter (backend 'mysql') around a pymysql
    connection built from the module-level DB_* settings, using the
    'utf8mb4' charset and pymysql's DictCursor.  It is stored in thread-local
    storage so later calls on the same thread reuse it.

    Raises:
        ImportError: if pymysql is not installed.
        Whatever pymysql.connect raises when the connection cannot be made.
    """
    try:
        import pymysql
    except ImportError:  # pragma: no cover - optional dependency
        logger.error('pymysql is required for mysql backend but is not installed')
        raise

    cached = getattr(_local, 'db_connection', None)
    if cached is not None:
        return cached

    conn_info = {
        'host': DB_HOST,
        'user': DB_USER,
        'password': DB_PASSWORD,
        # 'database' is the current pymysql keyword; 'db' is a deprecated alias.
        'database': DB_NAME,
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.DictCursor,
    }
    if DB_PORT:
        try:
            conn_info['port'] = int(DB_PORT)
        except (TypeError, ValueError):
            # Keep connecting on the driver's default port, but don't hide the misconfiguration.
            logger.warning("Ignoring invalid DB_PORT %r; using the driver's default port", DB_PORT)

    conn = pymysql.connect(**conn_info)
    logger.info(
        "Connected to MySQL %s@%s:%s",
        conn_info.get('database'), conn_info.get('host'), conn_info.get('port'),
    )
    _local.db_connection = DBAdapter(conn, 'mysql')
    return _local.db_connection


def get_connection():
    """Return the calling thread's database connection (MySQL is the only backend).

    Delegates to _connect_mysql, which creates the connection on the first call
    from a thread and hands back the same cached object afterwards.
    """
    connection = _connect_mysql()
    return connection


def close_connection():
    """Close and forget the calling thread's cached connection, if any.

    Safe to call when no connection is cached.  The cached slot is cleared even
    if closing fails, so the next get_connection() builds a fresh connection
    instead of returning a broken one.  Close errors are logged, not raised.
    """
    conn = getattr(_local, 'db_connection', None)
    if conn is None:
        return
    try:
        conn.close()
    except Exception:
        logger.warning("Error while closing DB connection", exc_info=True)
    finally:
        _local.db_connection = None


