"""
Deep Analysis Agent nodes and routers.
Class-based nodes; executor logic split into QueryBatchExecutor for easier debugging.
"""

import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Tuple, Optional

try:
    from langsmith.utils import ContextThreadPoolExecutor
except ImportError:
    # langsmith is an optional dependency; fall back to the stdlib executor below.
    ContextThreadPoolExecutor = None

# Use ContextThreadPoolExecutor when available so LangSmith traces nest under the LangGraph run
_Executor = ContextThreadPoolExecutor if ContextThreadPoolExecutor is not None else ThreadPoolExecutor

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage

from .state import DeepAnalysisState
from .constants import REASONER_QUERIES_MAX
from .models import ReasonerQueriesBatch
from .tools import execute_sql_query
from .helpers import serialize_sql_result, normalize_markdown_output
from .prompts import (
    INVESTIGATION_PROMPT,
    REPORTING_PROMPT,
    REASONER_SYSTEM_PROMPT,
    REASONER_PROMPT,
    QUERY_FIXER_PROMPT,
    FINALIZER_PROMPT,
    FINDING_FROM_QUERY_RESULT_PROMPT,
    DECIDE_CONTINUE_OR_FINALIZE_PROMPT,
)
from core.logger_config import get_logger
from core.llm_provider import get_llm
from core.context_manager import (
    cap_result_content,
    MAX_CHARS_PER_RESULT,
    count_tokens,
    TARGET_TOKENS,
    invoke_with_retry,
    rewrite_query_reduced_window,
)

logger = get_logger(__name__)


# --- Shared: simple LLM invocation (used by query_fixer, finalizer, decide) ---

def _invoke_llm(prompt: str, temperature: float = 0) -> str:
    """Send *prompt* as one HumanMessage and return the stripped content; never returns None."""
    llm = get_llm(temperature=temperature)
    response = llm.invoke([HumanMessage(content=prompt)])
    text = getattr(response, "content", None) if response else None
    return (text or "").strip()


# --- Helper: build ToolMessage content for one query (used by context retry) ---


def _run_query_and_build_tool_content(hypothesis: str, query: str) -> str:
    """
    Runs the given SQL query and returns full ToolMessage content string
    (hypothesis + query + answer). Used when retrying after context overflow
    with a reduced-date-window query.
    """
    def wrap(answer: str) -> str:
        # Uniform ToolMessage layout: hypothesis, query, then the answer payload.
        return f"hypothesis: {hypothesis or ''}\nquery: {query}\nanswer: {answer}"

    try:
        result = execute_sql_query.invoke({"query": query})
    except Exception as e:
        return wrap(f"SYSTEM_ERROR: {e}")

    # Error sentinels: dict-shaped results carry system errors, list-shaped carry SQL errors.
    if isinstance(result, dict):
        if result.get("_sys_error"):
            return wrap(f"SYSTEM_ERROR: {result['_sys_error']}")
        if result.get("error"):
            return wrap(f"SYSTEM_ERROR: {result['error']}")
    elif isinstance(result, list) and result and isinstance(result[0], dict) and result[0].get("error"):
        return wrap(f"SQL_ERROR: {result[0]['error']}")

    raw = serialize_sql_result(result) if result else "No rows."
    capped = cap_result_content(raw, max_chars=MAX_CHARS_PER_RESULT, query_label="retry reduced window")
    return wrap(capped)


# --- QueryBatchExecutor: executor logic split into testable steps ---

class QueryBatchExecutor:
    """
    Runs a batch of SQL queries in parallel and builds tool messages + findings.
    Each step is a separate method for easier debugging and testing.

    A "batch" is a list of {"hypothesis": str, "query": str} dicts produced by
    the Reasoner. ``run`` is the single entry point used by ExecutorNode and
    returns a LangGraph state update dict.
    """

    def resolve_batch(self, state: DeepAnalysisState) -> List[Dict[str, str]]:
        """Resolve batch from state; apply fixed_query when set (reserved for a future QueryFixer loop)."""
        batch = state.get("last_queries_batch") or []
        if state.get("fixed_query") is not None and state.get("failed_query_index") is not None:
            idx = state["failed_query_index"]
            # Copy before patching so the list stored in state is never mutated in place.
            batch = list(batch)
            if 0 <= idx < len(batch):
                batch[idx] = {
                    "hypothesis": batch[idx].get("hypothesis", ""),
                    "query": state["fixed_query"],
                }
        return batch

    def run_queries_parallel(self, batch: List[Dict[str, str]]) -> List[Any]:
        """Run all queries in parallel; returns list of results in same order as valid items.

        Items with a blank/missing "query" are skipped, mirroring the filter
        used by ``run`` when it builds ``valid_items`` — the two must agree so
        that results line up one-to-one.
        """
        valid = [{"query": (item.get("query") or "").strip(), "hypothesis": (item.get("hypothesis") or "")} for item in batch if (item.get("query") or "").strip()]

        def run_one(payload: dict) -> Any:
            # Convert exceptions into a sentinel dict so ex.map never raises and
            # one failing query cannot abort the whole batch.
            try:
                return execute_sql_query.invoke({"query": payload["query"]})
            except Exception as e:
                return {"_sys_error": str(e)}

        with _Executor(max_workers=min(len(valid), REASONER_QUERIES_MAX)) as ex:
            return list(ex.map(run_one, valid))

    def classify_result(self, result: Any) -> str:
        """One of: 'success', 'sys_exception', 'sql_error', 'system_error'.

        - 'sys_exception': run_one caught a Python exception ({"_sys_error": ...}).
        - 'sql_error':     tool returned a list whose first row is {"error": ...} (bad SQL).
        - 'system_error':  tool returned a dict {"error": ...} (infrastructure failure).
        """
        if isinstance(result, dict) and result.get("_sys_error"):
            return "sys_exception"
        if isinstance(result, list) and len(result) > 0 and isinstance(result[0], dict) and result[0].get("error"):
            return "sql_error"
        if isinstance(result, dict) and result.get("error"):
            return "system_error"
        return "success"

    def _maybe_reduce_result_on_size(self, query: str, hypothesis: str, result: Any) -> Tuple[str, Any]:
        """
        Backup path: if result is too large (would blow context), ask LLM for reduced-date-window
        query, re-run, and return smaller (query_to_use, result_to_use). Otherwise return (query, result).
        Rare in practice; keeps context safe when a query returns unexpectedly large data.

        NOTE(review): ``hypothesis`` is currently unused here; kept so the
        signature matches the tuple unpacked in _process_one_result.
        """
        if not result or not isinstance(result, list):
            return (query, result)
        raw_answer = serialize_sql_result(result)
        if len(raw_answer) <= MAX_CHARS_PER_RESULT:
            return (query, result)
        logger.warning(
            "Executor: result size %s chars exceeds threshold; re-running with reduced date window.",
            len(raw_answer),
        )
        new_query = rewrite_query_reduced_window(query)
        if new_query.strip() == query.strip():
            # Rewriter made no change; nothing to gain from a re-run.
            return (query, result)
        try:
            new_result = execute_sql_query.invoke({"query": new_query})
        except Exception as e:
            logger.warning("Executor: reduced-window re-run failed: %s; using original result.", e)
            return (query, result)
        if self.classify_result(new_result) != "success":
            # Reduced query errored; keep the original (large but valid) result.
            return (query, result)
        logger.info("Executor: using reduced-window result for this query.")
        return (new_query, new_result)

    def _process_one_result(
        self, item: Tuple[int, str, str, Any]
    ) -> Tuple[int, str, Any, str, Optional[str]]:
        """Run query → check size → maybe reduce. Returns (index, query_used, result_used, kind, error_msg). Used in parallel."""
        i, query, hypothesis, result = item
        kind = self.classify_result(result)
        if kind == "sys_exception":
            return (i, query, result, kind, (result.get("_sys_error") or ""))
        if kind == "sql_error":
            err = (result[0].get("error") or "") if (result and isinstance(result[0], dict)) else ""
            return (i, query, result, kind, err)
        if kind == "system_error":
            return (i, query, result, kind, (result.get("error") or ""))
        query_used, result_used = self._maybe_reduce_result_on_size(query, hypothesis, result)
        return (i, query_used, result_used, "success", None)

    def build_success_message(
        self,
        query: str,
        hypothesis: str,
        result: Any,
        index: int,
        batch_len: int,
    ) -> Tuple[BaseMessage, str]:
        """Build ToolMessage and short finding snippet for a successful result.

        Returns (tool_message, finding) where finding is a compact
        "[hypothesis] answer..." line kept under ~135 chars.
        """
        tool_call_id = f"batch_{index}"
        raw_answer = serialize_sql_result(result) if result else "No rows."
        answer_text = "No data rows." if not result else cap_result_content(
            raw_answer,
            max_chars=MAX_CHARS_PER_RESULT,
            query_label=f"query {index + 1} of {batch_len}",
        )
        content = f"hypothesis: {hypothesis or ''}\nquery: {query}\nanswer: {answer_text}"
        # Defensive second cap: long hypothesis/query text could push the full
        # message past the per-result budget even after the answer was capped.
        if len(content) > MAX_CHARS_PER_RESULT + 500:
            content = cap_result_content(content, max_chars=MAX_CHARS_PER_RESULT, query_label=f"query {index + 1} of {batch_len}")
        msg = ToolMessage(content=content, tool_call_id=tool_call_id)
        # Keep findings short so combined_findings does not blow up context
        hyp = (hypothesis or f"Query {index+1}")[:50]
        ans = answer_text[:80].replace("\n", " ").strip()
        finding = f"[{hyp}] {ans}" + ("..." if len(answer_text) > 80 else "")
        return msg, finding

    def run(self, state: DeepAnalysisState) -> Dict[str, Any]:
        """Full executor step: resolve batch, run queries, process results, return state update."""
        batch = self.resolve_batch(state)
        if not batch:
            return {"last_error": "SYSTEM_ERROR: No queries to execute."}

        turn_display = state["tool_call_count"] + 1
        max_turns = state.get("max_turns", 2)
        logger.info(f"Executor running batch of {len(batch)} queries (turn {turn_display} of {max_turns})")
        # Keep the original batch index alongside each valid item so messages
        # and error reports refer to positions in the Reasoner's batch.
        valid_items = [
            (i, (item.get("query") or "").strip(), (item.get("hypothesis") or ""))
            for i, item in enumerate(batch)
            if (item.get("query") or "").strip()
        ]
        if not valid_items:
            return {"last_error": "SYSTEM_ERROR: No queries to execute."}

        results_list = self.run_queries_parallel(batch)
        items_for_process = [
            (valid_items[j][0], valid_items[j][1], valid_items[j][2], results_list[j])
            for j in range(len(valid_items))
        ]
        with _Executor(max_workers=min(len(items_for_process), REASONER_QUERIES_MAX)) as ex:
            processed = list(ex.map(self._process_one_result, items_for_process))
        # Restore batch order; parallel processing may interleave completions.
        processed.sort(key=lambda x: x[0])

        tool_messages: List[BaseMessage] = []
        findings_this_turn: List[str] = []
        results_empty: List[bool] = []
        last_error_out: Optional[str] = None
        failed_query_index_out: Optional[int] = None

        for (i, query_used, result_used, kind, error_msg) in processed:
            item = batch[i]
            hypothesis = (item.get("hypothesis") or "")
            tool_call_id = f"batch_{i}"
            logger.info("-----------------")
            logger.info("HYPOTHESIS : %s", hypothesis)
            logger.info("LLM query : %s", query_used)
            logger.info("-----------------------")

            if kind == "sys_exception":
                logger.error(f"Executor system failure for query {i}: {error_msg}")
                last_error_out = f"SYSTEM_ERROR: {error_msg}"
                failed_query_index_out = i
                tool_messages.append(
                    ToolMessage(
                        content=(
                            f"hypothesis: {hypothesis}\nquery: {query_used}\nanswer: SYSTEM_ERROR: {error_msg}"
                        ),
                        tool_call_id=tool_call_id,
                    )
                )
                findings_this_turn.append(f"[Query {i+1} failed] {error_msg[:100]}")
                results_empty.append(True)
                continue

            if kind == "sql_error":
                logger.error(
                    "SQL Error (query %s): %s | HYPOTHESIS: %s | QUERY (first 200 chars): %s",
                    i + 1,
                    error_msg,
                    hypothesis[:200] if hypothesis else "(none)",
                    query_used[:200] + ("..." if len(query_used) > 200 else ""),
                )
                last_error_out = f"SQL_ERROR: {error_msg} | QUERY: {query_used}"
                failed_query_index_out = i
                tool_messages.append(
                    ToolMessage(
                        content=f"hypothesis: {hypothesis}\nquery: {query_used}\nanswer: SQL_ERROR: {error_msg}",
                        tool_call_id=tool_call_id,
                    )
                )
                findings_this_turn.append(f"[Query {i+1} failed] {error_msg[:100]}")
                results_empty.append(True)
                continue

            if kind == "system_error":
                # Infrastructure failure: abort the whole batch immediately and
                # surface the error to the graph (remaining results are dropped).
                tool_messages.append(
                    ToolMessage(content=f"SYSTEM_ERROR: {error_msg}", tool_call_id=tool_call_id)
                )
                return {
                    "messages": tool_messages,
                    "last_error": f"SYSTEM_ERROR: {error_msg}",
                    "last_queries_batch": None,
                    "failed_query_index": None,
                    "fixed_query": None,
                    "tool_call_count": state["tool_call_count"] + 1,
                }

            msg, finding = self.build_success_message(
                query_used, hypothesis, result_used, i, len(batch)
            )
            tool_messages.append(msg)
            findings_this_turn.append(finding)
            results_empty.append(not result_used or len(result_used) == 0)

        # Failed queries count as "empty" here, so a fully-failing batch also
        # triggers the continue-or-finalize decision.
        batch_all_empty = len(valid_items) > 0 and len(results_empty) == len(valid_items) and all(results_empty)
        if batch_all_empty:
            logger.info(f"All {len(batch)} queries in this turn returned no data; will ask LLM whether to continue or finalize.")

        new_messages = tool_messages
        total_tokens = count_tokens(state["messages"] + new_messages) if state.get("messages") else 0
        context_near_limit = total_tokens > TARGET_TOKENS * 0.9

        # Clear last_error so routing always goes to findings → finalizer (QueryFixer is not in the graph).
        # Per-query SQL failures are already in ToolMessage content for FindingsNode.
        return {
            "messages": new_messages,
            "findings": findings_this_turn,
            "last_query": None,
            "last_queries_batch": None,
            "last_error": None,
            "failed_query_index": None,
            "fixed_query": None,
            "retry_count": 0,
            "tool_call_count": state["tool_call_count"] + 1,
            "context_near_limit": context_near_limit,
            "batch_all_empty": batch_all_empty,
        }


# --- Node classes (each callable as node(state)) ---

class ReasonerNode:
    """Generates 3–5 investigative SQL queries (single turn) using structured LLM output. Does not run queries; Executor does that."""

    def __init__(self, analytical_schema: str):
        # Database schema description injected into every system prompt.
        self.analytical_schema = analytical_schema

    def __call__(self, state: DeepAnalysisState) -> Dict[str, Any]:
        user_id = state["user_id"]
        logger.info(f"Reasoner generating 3–{REASONER_QUERIES_MAX} queries for User {user_id}")

        sys_text = REASONER_SYSTEM_PROMPT.format(
            user_id=user_id,
            analytical_schema=self.analytical_schema,
            profile_json=state["profile_json"],
            investigation_prompt=INVESTIGATION_PROMPT,
        )

        # Prefer OpenAI for the reasoner when a key is configured; otherwise use the default provider.
        provider = "openai" if os.getenv("OPENAI_API_KEY") else None
        structured_llm = get_llm(temperature=0.2, provider=provider).with_structured_output(ReasonerQueriesBatch)
        batch: ReasonerQueriesBatch = invoke_with_retry(
            structured_llm,
            [
                SystemMessage(content=sys_text),
                HumanMessage(content=REASONER_PROMPT.format(user_id=user_id)),
            ],
        )
        queries_list = [{"hypothesis": q.hypothesis, "query": q.query} for q in batch.queries]
        logger.info(f"Reasoner produced {len(queries_list)} queries.")

        return {
            "messages": [AIMessage(content=f"Generated {len(queries_list)} investigative queries.")],
            "last_query": None,
            "last_queries_batch": queries_list,
            "last_error": None,
            "failed_query_index": None,
            "fixed_query": None,
            "retry_count": 0,
            "context_near_limit": False,
        }


class ExecutorNode:
    """Runs the batch of SQL queries (from Reasoner, 3–5) in parallel and appends tool messages + findings."""

    def __init__(self):
        # All execution logic lives in QueryBatchExecutor; this node is a thin graph adapter.
        self._executor = QueryBatchExecutor()

    def __call__(self, state: DeepAnalysisState) -> Dict[str, Any]:
        # Delegate the whole executor step; the runner returns the state update.
        return self._executor.run(state)


class QueryFixerNode:
    """Repairs one failed SQL query using LLM."""

    def __init__(self, analytical_schema: str):
        # Schema is handed to the fixer prompt so the LLM can correct table/column references.
        self.analytical_schema = analytical_schema

    def __call__(self, state: DeepAnalysisState) -> Dict[str, Any]:
        error_log = state["last_error"]
        idx = state.get("failed_query_index", 0)
        logger.info(f"Query Fixer attempting SQL repair for query index {idx}...")

        prompt_text = QUERY_FIXER_PROMPT.format(error_log=error_log, analytical_schema=self.analytical_schema)
        repaired = _invoke_llm(prompt_text, temperature=0).strip()
        # Unwrap a ```sql ... ``` code fence if the model wrapped its answer in one.
        fenced = re.search(r"```(?:sql)?\s*([\s\S]*?)```", repaired, flags=re.IGNORECASE)
        if fenced:
            repaired = fenced.group(1).strip()

        return {
            "fixed_query": repaired,
            "retry_count": state["retry_count"] + 1,
            "last_error": None,
        }


class FinalizerNode:
    """Produces the final behavioral report from profile + findings."""

    def __call__(self, state: DeepAnalysisState) -> Dict[str, Any]:
        logger.info("Starting finalize node (synthesizing report from profile + findings).")
        logger.info("Finalizing Behavioral Analysis for User %s", state["user_id"])

        llm_findings = state.get("llm_findings") or []
        executor_findings = state.get("findings") or []
        # Executor marks per-query failures with "[Query N failed]" snippets.
        failure_notes = [
            note
            for note in executor_findings
            if isinstance(note, str) and "[Query" in note and "failed]" in note
        ]
        if not llm_findings:
            # No LLM interpretations: fall back to the raw executor snippets.
            combined_findings = "\n\n".join(executor_findings)
        else:
            sections = ["\n\n".join(llm_findings)]
            if failure_notes:
                sections.append(
                    "---\nQueries that failed before analysis (no LLM interpretation):\n"
                    + "\n".join(failure_notes)
                )
            combined_findings = "\n\n".join(sections)
        if not combined_findings:
            combined_findings = "The investigation concludes that the initial profile is consistent with raw event logs, with no major supplemental anomalies found."

        report = _invoke_llm(
            FINALIZER_PROMPT.format(
                profile_json=state["profile_json"],
                combined_findings=combined_findings,
                reporting_prompt=REPORTING_PROMPT,
            ),
            temperature=0.1,
        )
        report = normalize_markdown_output(report)
        # Normalize to a stripped string even if the helper hands back a non-string.
        report = (report or "").strip() if isinstance(report, str) else str(report or "")

        if report and ("```sql" in report.lower() or "select " in report.lower()):
            logger.warning("SQL leakage detected in Finalizer output. Stripping technical blocks.")
            report = re.sub(r"```sql.*?```", "", report, flags=re.DOTALL | re.IGNORECASE)
            report = re.sub(r"SELECT\s+.*?\s+FROM\s+.*?;", "[Technical Detail Redacted]", report, flags=re.DOTALL | re.IGNORECASE)

        return {"final_analysis": report}


def _parse_tool_message_for_finding(content: str) -> Optional[Tuple[str, str, str]]:
    """Parse ToolMessage content into (hypothesis, query, result). Returns None if parse fails."""
    if not content or "answer:" not in content:
        return None
    parts = content.split("answer:", 1)
    if len(parts) != 2:
        return None
    head, result = parts[0].strip(), parts[1].strip()
    result = result[:500] + ("..." if len(result) > 500 else "")
    if "hypothesis:" in head and "query:" in head:
        h_part = head.split("query:", 1)[0].replace("hypothesis:", "").strip()
        q_part = head.split("query:", 1)[1].strip()
        return (h_part, q_part, result)
    if "query:" in head:
        q_part = head.split("query:", 1)[1].strip()
        return ("", q_part, result)
    return ("", head, result)


def _answer_should_skip_finding_llm(answer: str) -> bool:
    """True when the executor/tool layer failed — do not call the findings LLM for this row."""
    a = (answer or "").strip()
    u = a.upper()
    return u.startswith("SQL_ERROR:") or u.startswith("SYSTEM_ERROR:")


# Must match ReasonerNode AIMessage text so we only parse ToolMessages from the latest batch.
# ReasonerNode emits "Generated {n} investigative queries." after each batch; the
# marker + "Generated" check in _tool_messages_from_latest_reasoner_batch relies on it.
_REASONER_BATCH_AIMSG_MARKER = "investigative queries."


def _tool_messages_from_latest_reasoner_batch(messages: List[BaseMessage]) -> List[ToolMessage]:
    """
    When the graph loops (e.g. reasoner_abort -> reasoner -> executor), ``messages`` accumulates
    ToolMessages from earlier turns. Only analyze tools after the most recent Reasoner batch marker.
    """
    marker_idx = -1
    for idx, msg in enumerate(messages):
        if not isinstance(msg, AIMessage):
            continue
        text = getattr(msg, "content", "") or ""
        # Keep scanning so marker_idx ends up at the LAST matching marker.
        if "Generated" in text and _REASONER_BATCH_AIMSG_MARKER in text:
            marker_idx = idx
    tail = messages if marker_idx < 0 else messages[marker_idx + 1 :]
    return [m for m in tail if isinstance(m, ToolMessage)]


class FindingsNode:
    """Runs parallel LLM calls only for successful tool rows: each (hypothesis, query, result) -> one finding.
    Skips SQL/system errors (syntax, invalid object, etc.) — no LLM spend on those rows."""

    def __call__(self, state: DeepAnalysisState) -> Dict[str, Any]:
        all_messages = state.get("messages") or []
        triples: List[Tuple[str, str, str]] = []
        skipped_failures = 0
        for tool_msg in _tool_messages_from_latest_reasoner_batch(all_messages):
            parsed = _parse_tool_message_for_finding(getattr(tool_msg, "content", "") or "")
            if parsed is None:
                continue
            # parsed = (hypothesis, query, result); only the result decides skipping.
            if _answer_should_skip_finding_llm(parsed[2]):
                skipped_failures += 1
            else:
                triples.append(parsed)
        if skipped_failures:
            logger.info(
                "FindingsNode: skipping %s tool message(s) with SQL/SYSTEM errors (no findings LLM).",
                skipped_failures,
            )
        if not triples:
            logger.warning("FindingsNode: no successful (hypothesis, query, result) triples; returning empty llm_findings.")
            return {"llm_findings": []}

        total = len(triples)
        logger.info("FindingsNode: starting %s parallel LLM analyses (hypothesis+query+result -> finding).", total)

        def analyze_one(indexed: Tuple[int, Tuple[str, str, str]]) -> str:
            pos, (hyp, q, res) = indexed
            logger.info("Analyzing finding %s/%s...", pos + 1, total)
            prompt = FINDING_FROM_QUERY_RESULT_PROMPT.format(
                hypothesis=hyp,
                query=q,
                result=res,
            )
            try:
                return _invoke_llm(prompt, temperature=0.1)
            except Exception as e:
                logger.warning(f"FindingsNode LLM call failed: {e}")
                return f"[Finding unavailable: {str(e)[:100]}]"

        with _Executor(max_workers=min(total, REASONER_QUERIES_MAX)) as pool:
            findings_list = list(pool.map(analyze_one, list(enumerate(triples))))
        logger.info("FindingsNode: all %s findings complete; produced %s findings.", total, len(findings_list))
        return {"llm_findings": findings_list}


class DecideNode:
    """After each batch, asks LLM whether to continue or finalize.

    Returns a state update with ``empty_turn_decision`` ('continue' | 'finalize')
    and resets ``batch_all_empty``. Any LLM failure defaults to 'finalize'.
    """

    def __call__(self, state: DeepAnalysisState) -> Dict[str, Any]:
        user_id = state["user_id"]
        # Truncate profile/findings aggressively: this is a cheap yes/no call.
        profile_snippet = (state.get("profile_json") or "")[:800]
        raw_findings = state.get("findings") or []
        findings_text = "\n".join(raw_findings) if raw_findings else "No findings yet."
        findings_text = findings_text[:1500] + ("..." if len(findings_text) > 1500 else "")
        turns_remaining = max(0, state.get("max_turns", 2) - state["tool_call_count"])
        batch_all_empty = state.get("batch_all_empty") is True
        empty_note = " This batch returned no data for any query." if batch_all_empty else ""

        prompt = DECIDE_CONTINUE_OR_FINALIZE_PROMPT.format(
            user_id=user_id,
            profile_snippet=profile_snippet,
            findings_text=findings_text,
            turns_remaining=turns_remaining,
            empty_note=empty_note,
        )
        try:
            content = _invoke_llm(prompt, temperature=0)
            stripped = (content or "").strip()
            # Fix: tolerate trailing punctuation (e.g. "FINALIZE.") — the previous
            # exact first-word equality check misread it as "continue".
            first_word = stripped.upper().split()[0].strip(".,:;!?") if stripped else ""
            decision = "finalize" if first_word == "FINALIZE" else "continue"
        except Exception as e:
            logger.warning(f"Decide node LLM failed: {e}; defaulting to finalize.")
            decision = "finalize"

        logger.info(f"After-batch decision: {decision} (turns_remaining={turns_remaining})")
        return {"batch_all_empty": False, "empty_turn_decision": decision}


class ReasonerAbortNode:
    """Cleans up state after investigation abort (e.g. system error or retry limit).

    Emits a ToolMessage recording the abort, appends to ``failed_investigations``,
    and resets all per-query fields so the graph can proceed cleanly.
    """

    def __call__(self, state: DeepAnalysisState) -> Dict[str, Any]:
        # Fix: use .get() like every other node — direct subscript raised KeyError
        # when "messages" was absent from state on an early abort.
        messages = state.get("messages") or []
        last_ai_msg = messages[-1] if messages else None
        tool_call_id = "abort"
        # Reuse the pending tool_call id when the last AIMessage issued tool calls,
        # so the abort ToolMessage pairs correctly in the transcript.
        if last_ai_msg and hasattr(last_ai_msg, "tool_calls") and last_ai_msg.tool_calls:
            tool_call_id = last_ai_msg.tool_calls[0]["id"]
        error_summary = f"INVESTIGATION ABORTED: {state['last_error']}"
        tool_msg = ToolMessage(content=error_summary, tool_call_id=tool_call_id)
        # Describe what failed: a single query, a batch index, or a generic label.
        failed_desc = state.get("last_query") or (state.get("last_queries_batch") and f"Batch query index {state.get('failed_query_index')}") or "Query"
        return {
            "messages": [tool_msg],
            "failed_investigations": [f"{failed_desc} | Error: {state['last_error']}"],
            "last_query": None,
            "last_queries_batch": None,
            "last_error": None,
            "failed_query_index": None,
            "fixed_query": None,
            "retry_count": 0,
            "tool_call_count": state["tool_call_count"] + 1,
        }


# --- Routers (pure functions for LangGraph conditional_edges) ---

def should_investigate(state: DeepAnalysisState) -> str:
    """Route after a turn: 'reasoner_abort', 'finalize', 'findings', 'execute', or 'reasoner'."""
    turn_limit = state.get("max_turns", 1)
    error = state.get("last_error")
    if error:
        # System failures and exhausted retries abort outright.
        if "SYSTEM_ERROR" in error or state.get("retry_count", 0) >= 1:
            return "reasoner_abort"
        # Out of turns after this one: skip straight to the report; else abort.
        return "finalize" if state["tool_call_count"] + 1 >= turn_limit else "reasoner_abort"
    if state.get("tool_call_count", 0) > 0:
        # At least one clean turn completed: interpret its results.
        return "findings"
    if state["tool_call_count"] >= turn_limit:
        return "finalize"
    # Anything runnable pending? Hand it to the executor.
    if any(state.get(key) for key in ("last_queries_batch", "fixed_query", "last_query")):
        return "execute"
    return "reasoner"


def should_continue_after_decide(state: DeepAnalysisState) -> str:
    """Route on DecideNode's verdict: 'finalize' ends the loop, anything else re-enters the reasoner."""
    verdict = (state.get("empty_turn_decision") or "").strip().upper()
    return "finalize" if verdict == "FINALIZE" else "reasoner"
