import logging
import sys
import os
from typing import Dict, Optional

# Console handler threshold used when setup_logging() is called without ``level``.
# setup_logging() rebinds this global whenever an explicit ``level`` is passed,
# so the new value persists for later calls.
DEFAULT_CONSOLE_LEVEL: int = logging.INFO

# File handler threshold used when setup_logging() is called without
# ``file_level``; DEBUG so the log file captures full detail by default.
DEFAULT_FILE_LEVEL: int = logging.DEBUG

# Log file written by the file handler. A relative path is resolved against the
# current working directory at the time setup_logging() runs.
LOG_FILE_PATH: str = "run.log"

# Logger-name -> level overrides shared by setup_logging() and get_logger();
# setup_logging(overrides=...) merges new entries into this dict in place.
MODULE_OVERRIDES: Dict[str, int] = {}

def _find_reusable_file_handler(
    root_logger: logging.Logger, log_abs: str
) -> Optional[logging.FileHandler]:
    """Return the first FileHandler on *root_logger* writing to *log_abs*, else None."""
    for handler in root_logger.handlers:
        if not isinstance(handler, logging.FileHandler):
            continue
        try:
            if os.path.abspath(getattr(handler, "baseFilename", "")) == log_abs:
                return handler
        except (OSError, TypeError, ValueError):
            # A handler with an unusable baseFilename is simply not a match.
            continue
    return None


def setup_logging(
    level: Optional[int] = None,
    *,
    file_level: Optional[int] = None,
    overrides: Optional[Dict[str, int]] = None,
):
    """
    Configure the root logger with a console handler and a file handler.

    The first call creates ``LOG_FILE_PATH`` (default ``run.log``), truncating it.
    Later calls in the same process replace the console handler and update the
    file handler's level. They **reuse** the existing file handler, so the log file
    is not truncated again (e.g. tests or multiple entry points). Every other
    handler on the root logger is removed and closed.

    Args:
        level: If given, becomes the new ``DEFAULT_CONSOLE_LEVEL`` (persisting for
            later calls) and is applied to the console handler.
        file_level: If given, the file handler's level for this call (e.g.
            ``logging.ERROR`` for a quiet log file); otherwise ``DEFAULT_FILE_LEVEL``.
        overrides: Logger-name -> level entries merged into ``MODULE_OVERRIDES``.
            All accumulated overrides are (re)applied on every call.
    """
    global DEFAULT_CONSOLE_LEVEL

    if level is not None:
        DEFAULT_CONSOLE_LEVEL = level
    if overrides is not None:
        # Mutated in place, so no ``global`` rebinding is required.
        MODULE_OVERRIDES.update(overrides)

    root_logger = logging.getLogger()
    existing_file_handler = _find_reusable_file_handler(
        root_logger, os.path.abspath(LOG_FILE_PATH)
    )

    # Drop every handler except the reusable file handler (this includes the
    # console handler from a previous call, which is recreated below).
    for handler in list(root_logger.handlers):
        if handler is not existing_file_handler:
            root_logger.removeHandler(handler)
            handler.close()

    # Root logger stays at DEBUG so that the handlers' own levels do the filtering.
    root_logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Console handler (respects DEFAULT_CONSOLE_LEVEL).
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(DEFAULT_CONSOLE_LEVEL)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    # File handler (default DEBUG; batch jobs may pass file_level=ERROR).
    fh_level = DEFAULT_FILE_LEVEL if file_level is None else file_level
    if existing_file_handler is not None:
        existing_file_handler.setLevel(fh_level)
        if existing_file_handler.formatter is None:
            existing_file_handler.setFormatter(formatter)
    else:
        # Explicit UTF-8 so non-ASCII messages do not depend on the platform's
        # locale encoding.
        file_handler = logging.FileHandler(LOG_FILE_PATH, mode="w", encoding="utf-8")
        file_handler.setLevel(fh_level)
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)

    # Apply overrides. A logger's level gates records *before* any handler sees
    # them. An override above DEBUG (e.g. WARNING) therefore silences that module,
    # and descendants with no level of their own, on both console and file.
    # Handlers still apply their own levels afterwards, so an override cannot make
    # a record appear on a handler whose level is higher.
    for module_name, module_level in MODULE_OVERRIDES.items():
        logging.getLogger(module_name).setLevel(module_level)

def get_logger(name: str) -> logging.Logger:
    """
    Return ``logging.getLogger(name)``, applying the most specific matching override.

    An override named ``foo`` matches ``foo`` itself and any dotted descendant such
    as ``foo.bar`` (but not ``foobar``). When several overrides match (e.g. ``foo``
    and ``foo.bar`` for ``foo.bar.baz``), the longest, most specific name wins,
    regardless of the order in which overrides were registered. The matched level
    is set explicitly on the returned logger. If nothing matches, the logger's
    level is left untouched.
    """
    logger = logging.getLogger(name)

    best_match = max(
        (
            override_name
            for override_name in MODULE_OVERRIDES
            if name == override_name or name.startswith(override_name + ".")
        ),
        key=len,
        default=None,
    )
    if best_match is not None:
        logger.setLevel(MODULE_OVERRIDES[best_match])

    return logger

# Initialize with defaults on import. Side effect: importing this module
# reconfigures the root logger and creates/truncates LOG_FILE_PATH (opened with
# mode "w" by setup_logging). Later setup_logging() calls reuse that file handler
# instead of truncating the file again.
setup_logging()
