import re
import tiktoken
from langchain_core.messages import ToolMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
from .logger_config import get_logger
from typing import List, Optional, Tuple, Callable
from .llm_provider import get_llm
import openai

# Module-level logger, created through the project's logger_config helper and named after this module.
logger = get_logger(__name__)

# Constants
MAX_TOKENS = 128000
SAFE_THRESHOLD = 0.7
TARGET_TOKENS = int(MAX_TOKENS * SAFE_THRESHOLD)  # ~89,600 tokens
MAX_CHARS_PER_RESULT = 10_000  # Per (query, answer) message cap for multi-query batch


def cap_result_content(text: str, max_chars: int = MAX_CHARS_PER_RESULT, query_label: str = "") -> str:
    """
    Truncates result text to max_chars and appends advice for context safety.
    Used for multi-query batch: each (query, answer) message is capped so no single result fills context.

    Args:
        text: The result text to cap.
        max_chars: Upper bound on the length of the returned string. Negative values are treated as 0.
        query_label: Optional label; when given, the truncation note names it instead of
            reporting the original length.

    Returns:
        `text` unchanged when it already fits; otherwise a prefix of `text` followed by a
        truncation note, never longer than max_chars. If max_chars is too small to hold the
        note, the note itself is cut to max_chars.
    """
    if len(text) <= max_chars:
        return text

    if query_label:
        suffix = f"\n... [TRUNCATED – {query_label}; consider DATEADD/DATEDIFF or TOP to reduce rows.]"
    else:
        suffix = f"\n... [TRUNCATED – original {len(text)} chars. Consider smaller window_days or TOP N to reduce rows.]"

    limit = max(max_chars, 0)
    if len(suffix) >= limit:
        # Not even the note fits. A negative slice bound here would keep almost all of
        # `text` (slicing from the end), so cut the note to the cap instead.
        return suffix[:limit]
    return text[: limit - len(suffix)] + suffix


def count_tokens(messages: List[BaseMessage]) -> int:
    """
    Estimates token count for a list of messages using tiktoken.

    Each message contributes the cl100k_base token count of str(msg.content) plus a
    fixed overhead of 4 tokens for message structure (approximate).

    Args:
        messages: Messages whose `content` is counted.

    Returns:
        The estimated token count. If tokenization fails for any reason the error is
        logged and a rough estimate of one token per 4 characters is returned instead.
    """
    try:
        encoding = tiktoken.get_encoding("cl100k_base")
        # disallowed_special=() makes tiktoken treat special-token strings such as
        # "<|endoftext|>" as ordinary text. With the default it raises ValueError when
        # they appear in the input, which would push the whole count onto the fallback.
        return sum(
            len(encoding.encode(str(msg.content), disallowed_special=())) + 4
            for msg in messages
        )
    except Exception as e:
        logger.error(f"Error counting tokens: {e}")
        # Fallback: 1 token ~= 4 chars
        return sum(len(str(m.content)) for m in messages) // 4

def prune_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
    """
    Proactively prunes messages once their estimated size reaches TARGET_TOKENS.

    Strategy:
    1. Truncate large ToolMessage contents first (they are usually the culprit).
    2. If still over, summarize older history with the LLM and keep every
       SystemMessage plus the most recent messages.

    Args:
        messages: The conversation so far.

    Returns:
        The same list object when it is under the threshold. Otherwise a new list:
        either the truncated messages, or [system messages..., summary, recent messages...].
        If summarization raises, the summary is omitted and only system + recent
        messages are returned.
    """
    keep_recent = 5  # minimum number of trailing non-system messages kept verbatim

    current_tokens = count_tokens(messages)

    if current_tokens < TARGET_TOKENS:
        return messages

    logger.warning(f"Context Usage High ({current_tokens} tokens). Pruning to target {TARGET_TOKENS}...")

    # Strategy 1: Truncate oversized ToolMessage content
    pruned_messages = []
    for msg in messages:
        if isinstance(msg, ToolMessage) and len(str(msg.content)) > MAX_CHARS_PER_RESULT:
            original_text = str(msg.content)
            truncated_content = (
                original_text[:MAX_CHARS_PER_RESULT]
                + f"\n... [TRUNCATED SYSTEM: Original size {len(original_text)} chars]"
            )
            # Create new instance with same ID to update if using add_messages,
            # or just to be safe in the list
            pruned_messages.append(
                ToolMessage(
                    content=truncated_content,
                    tool_call_id=msg.tool_call_id,
                    name=msg.name,
                    id=msg.id,
                )
            )
        else:
            pruned_messages.append(msg)

    # Re-check count; only report success when truncation actually got us under target
    current_tokens = count_tokens(pruned_messages)
    if current_tokens < TARGET_TOKENS:
        logger.info(f"Pruning successful (Truncation). New count: {current_tokens}")
        return pruned_messages

    # Strategy 2: Summarize History (Keep recent messages + System)
    logger.warning(f"Truncation insufficient ({current_tokens} tokens). Summarizing older history.")

    system_msgs = [m for m in pruned_messages if isinstance(m, SystemMessage)]
    non_system = [m for m in pruned_messages if not isinstance(m, SystemMessage)]

    # Start with the last `keep_recent` messages, then move the boundary back while it
    # sits on a ToolMessage. A tool result must follow the assistant message that
    # requested it; cutting between them leaves an orphaned tool message, which the
    # chat API rejects.
    split_at = len(non_system) - keep_recent
    while split_at > 0 and isinstance(non_system[split_at], ToolMessage):
        split_at -= 1

    if split_at <= 0:
        logger.warning("Not enough history to summarize. Returning truncated list.")
        return pruned_messages

    to_summarize = non_system[:split_at]
    recent_history = non_system[split_at:]

    try:
        summary_text = summarize_history(to_summarize)

        # Create new history: [System, Summary, ...Recent]
        summary_msg = SystemMessage(content=f"PREVIOUS CONTEXT SUMMARY: {summary_text}")

        final_messages = system_msgs + [summary_msg] + recent_history
        logger.info(f"Summarization complete. Final message count: {len(final_messages)}")
        return final_messages

    except Exception as e:
        logger.error(f"Summarization failed: {e}. Falling back to drop-oldest strategy.")
        # Fallback: Just keep system + recent
        return system_msgs + recent_history

def summarize_history(messages: List[BaseMessage]) -> str:
    """
    Condenses a list of messages into a short narrative by asking the LLM.

    Every message is rendered as "<TYPE>: <content>" on its own line, with the content
    clipped to 500 characters so the summarization prompt itself stays small.

    Returns "No history." for an empty list, and "Error generating summary." when the
    LLM call raises (the error is logged).
    """
    if not messages:
        return "No history."

    # Temperature 0.0 keeps the summary factual
    summarizer = get_llm(temperature=0.0)

    def _render(entry) -> str:
        speaker = entry.type.upper()
        body = str(entry.content)
        # Clip long bodies so the summary request cannot itself hit the context limit
        clipped = body if len(body) <= 500 else body[:500] + "..."
        return f"{speaker}: {clipped}\n"

    conversation_text = "".join(_render(entry) for entry in messages)

    # Whitespace is spelled out explicitly so the prompt text is exactly what the
    # model receives, including the indentation of each line.
    prompt = (
        "\n"
        "    Summarize the following conversation history into a concise paragraph. \n"
        "    Focus on the SQL queries run and the key insights found. \n"
        "    Ignore detailed JSON data, just capture the high-level findings.\n"
        "    \n"
        "    HISTORY:\n"
        f"    {conversation_text}\n"
        "    "
    )

    try:
        reply = summarizer.invoke([HumanMessage(content=prompt)])
        return reply.content
    except Exception as e:
        logger.error(f"Error calling LLM for summary: {e}")
        return "Error generating summary."


def get_volatile_cleanup_updates(messages: List[BaseMessage]) -> List[ToolMessage]:
    """
    Builds placeholder ToolMessages that blank out 'old' tool results (Volatile Memory).

    A ToolMessage is considered old when its tool_call_id is not among the tool calls
    of the last message in `messages`. Only results whose content is longer than 100
    characters are replaced; each placeholder reuses the original tool_call_id, id and
    name.

    Usage: Call this when entering a NEW tool execution phase.
    """
    if not messages:
        return []

    # The newest message is expected to be the AIMessage that triggers the new tools.
    # If it carries no tool calls, the active set is empty and every ToolMessage
    # counts as old.
    pending_calls = getattr(messages[-1], 'tool_calls', None)
    active_ids = {call['id'] for call in pending_calls} if pending_calls else set()

    placeholders = []
    for candidate in messages:
        if not isinstance(candidate, ToolMessage):
            continue
        if candidate.tool_call_id in active_ids:
            continue
        # Leave short contents alone; only substantial payloads are wiped
        if len(str(candidate.content)) <= 100:
            continue

        logger.info(f"Volatile Memory: Wiping ToolMessage {candidate.tool_call_id} (Data processed)")
        placeholders.append(
            ToolMessage(
                content="[Data Processed & Removed from History]",
                tool_call_id=candidate.tool_call_id,
                id=candidate.id,
                name=candidate.name,
            )
        )

    return placeholders


def _parse_tool_message_content(content: str) -> Optional[Tuple[str, str]]:
    """
    Parse ToolMessage content (hypothesis: ...\\nquery: ...\\nanswer: ...) into (hypothesis, query).
    Returns None if the format is not recognized.
    """
    if not content or "query:" not in content or "answer:" not in content:
        return None
    parts = content.split("answer:", 1)
    if len(parts) != 2:
        return None
    head = parts[0].strip()
    if "hypothesis:" in head and "query:" in head:
        h_part = head.split("query:", 1)[0].replace("hypothesis:", "").strip()
        q_part = head.split("query:", 1)[1].strip()
        return (h_part, q_part)
    if "query:" in head:
        q_part = head.split("query:", 1)[1].strip()
        return ("", q_part)
    return None


# Instruction text that rewrite_query_reduced_window() places in front of the query:
# it asks the LLM for the same SQL with a stricter date filter and for SQL-only output.
REDUCE_WINDOW_PROMPT = """This SQL query returned too much data and caused a context limit error.
Return the SAME query with a STRICTER date filter (e.g. last 30 days or 7 days) to reduce rows.
Change only the date range; do not change other logic.
Output ONLY the SQL statement, no explanation or markdown."""


def rewrite_query_reduced_window(query: str) -> str:
    """
    Calls the LLM to return the same query with a reduced date window.
    On failure returns the original query so the pipeline can continue.

    Args:
        query: The SQL statement to narrow. None or blank input is returned as-is
            (None becomes "") without calling the LLM.

    Returns:
        The rewritten SQL with any surrounding markdown code fence removed, or the
        original query when the LLM call raises or yields no usable SQL.
    """
    if not (query or "").strip():
        return query or ""
    try:
        llm = get_llm(temperature=0.0)
        prompt = f"{REDUCE_WINDOW_PROMPT}\n\nQuery:\n{query}"
        response = llm.invoke([HumanMessage(content=prompt)])
        new_sql = (getattr(response, "content", None) or "").strip()
        if "```" in new_sql:
            # The model sometimes wraps the statement in a markdown fence despite the prompt
            fenced = re.search(r"```(?:\w*)\s*([\s\S]*?)```", new_sql)
            if fenced:
                new_sql = fenced.group(1).strip()
        # Checked after unwrapping the fence: an empty fenced block must not be
        # returned as an empty SQL statement.
        if new_sql:
            return new_sql
        logger.warning("rewrite_query_reduced_window got an empty rewrite; using original query.")
    except Exception as e:
        logger.warning(f"rewrite_query_reduced_window failed: {e}; using original query.")
    return query


def _is_context_overflow_error(exc: Exception) -> bool:
    """Return True when exc looks like a context-overflow (400) or payload-too-large (413) error."""
    text = str(exc).lower()
    mentions_limit = "context" in text or "token" in text
    # BadRequestError must be tested first: it is the more specific of the two openai types.
    if isinstance(exc, openai.BadRequestError):
        return mentions_limit
    if isinstance(exc, openai.APIStatusError):
        return exc.status_code == 413
    # Non-openai exceptions: fall back to matching on the message text
    return mentions_limit or "413" in text


def invoke_with_retry(
    model,
    messages,
    max_retries: int = 3,
    execute_sql_fn: Optional[Callable[[str, str], str]] = None,
):
    """
    Invokes the model with retry logic for context overflow (400) or payload too large (413).
    When the last message is a ToolMessage and execute_sql_fn is provided: parses the query,
    asks the LLM for a reduced-date-window version, runs it via execute_sql_fn, replaces the
    ToolMessage content with the smaller result, and retries. Otherwise replaces with a short
    error message and retries.

    Args:
        model: Object exposing invoke(messages).
        messages: Messages to send; the caller's sequence is copied, never mutated.
        max_retries: Number of retries after the first attempt. Negative values are
            treated as 0, so the model is always invoked at least once.
        execute_sql_fn: Optional callable (hypothesis, query) -> new tool output text.

    Returns:
        Whatever model.invoke returns on the first successful attempt.

    Raises:
        The exception raised by model.invoke when it is not a context/payload error,
        when retries are exhausted, when the last message is not a ToolMessage, or when
        the error persists after the tool output was already replaced by the short
        error message (a further retry would resend the same payload).
    """
    current_messages = list(messages)
    total_attempts = max(max_retries, 0) + 1
    fallback_applied = False

    for attempt in range(total_attempts):
        try:
            return model.invoke(current_messages)
        except Exception as e:
            can_retry = attempt < total_attempts - 1
            if not (_is_context_overflow_error(e) and can_retry):
                logger.error(f"Unrecoverable LLM error: {e}")
                raise

            logger.warning(f"Context/Payload error on attempt {attempt + 1}: {e}")

            # Only a trailing tool result can be shrunk; anything else is not recoverable here
            if not current_messages or not isinstance(current_messages[-1], ToolMessage):
                raise

            if fallback_applied:
                # The tool output is already the short error text, so the overflow comes
                # from elsewhere and another retry would send an identical payload.
                logger.error("Context/Payload error persists after replacing the tool output; giving up.")
                raise

            last_tool = current_messages[-1]
            content = str(last_tool.content or "")
            parsed = _parse_tool_message_content(content)
            new_content = None

            if execute_sql_fn and parsed:
                hypothesis, query = parsed
                new_query = rewrite_query_reduced_window(query)
                if new_query.strip() == query.strip():
                    # The rewrite hands back the original query on failure; re-running
                    # it would only reproduce the oversized result.
                    logger.warning("Query rewrite returned the same SQL; falling back to error message.")
                else:
                    try:
                        new_content = execute_sql_fn(hypothesis, new_query)
                        logger.info("Re-ran query with reduced date window and replaced ToolMessage content.")
                    except Exception as run_err:
                        logger.warning(f"execute_sql_fn failed: {run_err}; falling back to error message.")

            if new_content is None:
                fallback_applied = True
                new_content = (
                    f"SYSTEM ERROR: The tool output was too large and exceeded the model's context limit. "
                    f"Result size was {len(content)} chars. "
                    f"PLEASE RE-TRY with a DATE FILTER (e.g., last 30 or 7 days) to reduce data volume."
                )

            # Keep id alongside tool_call_id/name, as the other ToolMessage rebuilds in this module do
            current_messages[-1] = ToolMessage(
                content=new_content,
                tool_call_id=last_tool.tool_call_id,
                name=last_tool.name,
                id=last_tool.id,
            )
