import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from data_fetcher import DataFetcher
from signal_processing import SignalProcessor
from statistical_analysis import StatisticalAnalyzer
from deep_analysis import DeepAnalysisAgent
from deep_analysis.helpers import normalize_markdown_output
from core.logger_config import get_logger
from storage.user_profiling import (
    insert_user_profiling_row,
    insert_user_profiling_staging_row,
)

logger = get_logger(__name__)

# Parallel sub-fetches in _fetch_all_raw_data; must not exceed pyodbc DB_POOL_SIZE (same env var).
_FETCH_TASKS = 7


def _inner_fetch_pool_workers() -> int:
    """
    Return the thread-pool size used for the parallel sub-fetches.

    ``DB_POOL_SIZE`` is read from the environment (7 when unset or not an
    integer) and the result is clamped to the range ``1 .. _FETCH_TASKS``.
    """
    configured_raw = os.getenv("DB_POOL_SIZE", "7")
    try:
        configured = int(configured_raw)
    except ValueError:
        configured = 7
    # Never more workers than there are tasks, never fewer than one.
    capped = configured if configured < _FETCH_TASKS else _FETCH_TASKS
    return capped if capped > 1 else 1


class ProfilingOrchestrator:
    """
    Coordinates the profiling pipeline for a single user.
    Manages time windows, data fetching, and signal processing.

    Call :func:`storage.user_profiling.ensure_profiling_storage_tables` once per process
    before constructing orchestrators (API / batch startup), not in ``__init__``.
    """

    def __init__(self, db_connection) -> None:
        """
        Build the pipeline collaborators around one DB handle.

        Args:
            db_connection: handle stored as ``self._db`` for the persistence
                helpers, and passed to ``DataFetcher`` and (as ``db_session``)
                to ``StatisticalAnalyzer``.
        """
        self._db = db_connection
        self.fetcher = DataFetcher(db_connection)
        self.processor = SignalProcessor()
        self.statistical_analyzer = StatisticalAnalyzer(db_session=db_connection)
        self.deep_analyzer = DeepAnalysisAgent()

        # Default rolling window: 60 days (used when a caller passes no window_days)
        self.default_window_days = 60
        # Milestone timings (seconds) from last run; cleared at start of each run
        self._last_timings: Dict[str, float] = {}

    def get_last_timings(self) -> Dict[str, float]:
        """
        Snapshot of the milestone timings (seconds) recorded by the last run.

        A shallow copy is returned, so callers may mutate it without affecting
        the orchestrator's own record.
        """
        snapshot = self._last_timings.copy()
        return snapshot

    def generate_user_profile_data(
        self,
        user_id: int,
        window_days: Optional[int] = None,
        activity_gate_days: Optional[int] = 90,
    ) -> Dict[str, Any]:
        """
        Build the deterministic profile (signals + statistics) for one user.

        Args:
            user_id: user to profile.
            window_days: look-back in days; ``None`` selects
                ``self.default_window_days``.
            activity_gate_days: when a positive int, the user must have rows in
                ``user_activity_logs`` within that many days; otherwise a minimal
                **inactive** profile is returned and no heavy fetch runs. Pass
                ``None`` to disable the gate (e.g. batch already filtered by
                activity).

        Returns:
            Dict with ``metadata``, ``signals``, ``data_completeness`` and
            ``statistics``. Milestone timings of the run are kept on the instance
            and exposed through ``get_last_timings``.
        """
        window_days = self.default_window_days if window_days is None else window_days

        # Activity gate: short-circuit to the placeholder profile before any heavy query.
        gate_enabled = activity_gate_days is not None and activity_gate_days > 0
        if gate_enabled and not self.fetcher.user_has_recent_activity_logs(user_id, activity_gate_days):
            logger.info(
                "User %s: no user_activity_logs in last %s days — inactive placeholder profile",
                user_id,
                activity_gate_days,
            )
            return self._inactive_user_profile(user_id, window_days, activity_gate_days)

        logger.info(f"Starting profile generation for User {user_id} (Window: {window_days} days)")
        self._last_timings = {}
        t_profile_start = time.perf_counter()

        # Window bounds are passed to the fetcher as formatted timestamp strings.
        window_end = datetime.now()
        window_start = window_end - timedelta(days=window_days)
        timestamp_format = "%Y-%m-%d %H:%M:%S"

        # Stage 1: raw rows.
        raw_data = self._fetch_all_raw_data(
            user_id,
            window_start.strftime(timestamp_format),
            window_end.strftime(timestamp_format),
        )
        t_fetch_done = time.perf_counter()
        self._last_timings["fetch_s"] = round(t_fetch_done - t_profile_start, 2)
        logger.info(f"Milestone: fetch done in {self._last_timings['fetch_s']}s")

        # Stage 2: deterministic signals, then the user's row (for the name in the report).
        signals = self._process_all_signals(raw_data)
        basic_details = self.fetcher.fetch_user_basic_details(user_id)

        # Stage 3: base profile; statistics are computed from it and attached afterwards.
        profile_data: Dict[str, Any] = {
            "metadata": self._generate_metadata(user_id, raw_data, window_days, basic_details),
            "signals": signals,
            "data_completeness": self._check_data_completeness(signals),
        }
        profile_data["statistics"] = self.statistical_analyzer.analyze(profile_data)

        t_profile_done = time.perf_counter()
        self._last_timings.update(
            signals_stats_s=round(t_profile_done - t_fetch_done, 2),
            profile_total_s=round(t_profile_done - t_profile_start, 2),
        )
        logger.info(f"Milestone: signals+stats done in {self._last_timings['signals_stats_s']}s | profile total {self._last_timings['profile_total_s']}s")

        return profile_data

    def _inactive_user_profile(
        self,
        user_id: int,
        window_days: int,
        activity_days: int,
    ) -> Dict[str, Any]:
        """
        Placeholder profile for a user with no recent activity logs.

        Fetches only the user's basic details (for the display name), resets the
        run timings to zero and returns a profile marked ``"inactive"`` that
        carries a ready-made markdown sentence under ``_inactive_markdown``.
        """
        user_row = self.fetcher.fetch_user_basic_details(user_id)

        end = datetime.now()
        start = end - timedelta(days=activity_days)
        range_label = f"{start.strftime('%Y-%m-%d')} – {end.strftime('%Y-%m-%d')}"
        md = (
            f"This user is not active in the last {activity_days} days "
            f"(no rows in `user_activity_logs` between {range_label})."
        )

        # No pipeline stage ran, so every milestone is reported as zero seconds.
        self._last_timings = dict.fromkeys(
            ("fetch_s", "signals_stats_s", "profile_total_s"), 0.0
        )

        metadata = {
            "user_id": user_id,
            "name": _display_name_from_user_row(user_row),
            "profile_date": datetime.now().strftime("%Y-%m-%d"),
            "data_window_days": window_days,
        }
        placeholder: Dict[str, Any] = {"metadata": metadata, "signals": {}}
        placeholder["data_completeness"] = "inactive"
        placeholder["_inactive_markdown"] = md
        placeholder["statistics"] = {"confidence": 0.0}
        return placeholder

    def persist_user_profiling(
        self,
        user_id: int,
        profile_data: Dict[str, Any],
        deep_analysis_markdown: Optional[str] = None,
    ) -> None:
        """
        Upsert **published** dbo.user_profiling (+ tracker). Used by API / forced runs.

        A shallow copy of ``profile_data`` is written, extended with a
        ``pipeline_timings`` entry copied from the last run's timings; the
        caller's dict is left untouched.
        """
        snapshot = dict(profile_data, pipeline_timings=dict(self._last_timings))
        insert_user_profiling_row(self._db, user_id, snapshot, deep_analysis_markdown)

    def persist_user_profiling_staging(
        self,
        user_id: int,
        profile_data: Dict[str, Any],
        deep_analysis_markdown: Optional[str] = None,
    ) -> None:
        """
        Upsert **staging** only (scheduled batch). Call publish after full scan.

        A shallow copy of ``profile_data`` is written, extended with a
        ``pipeline_timings`` entry copied from the last run's timings; the
        caller's dict is left untouched.
        """
        snapshot = dict(profile_data, pipeline_timings=dict(self._last_timings))
        insert_user_profiling_staging_row(self._db, user_id, snapshot, deep_analysis_markdown)

    def run_deep_analysis(
        self,
        user_id: int,
        profile_data: Dict[str, Any],
        max_turns: Optional[int] = None,
        persist: bool = True,
    ) -> str:
        """
        Produce the deep-analysis markdown for a profile and optionally persist it.

        The LLM agent is skipped when ``data_completeness`` is ``"inactive"`` or
        ``"low_signal"``; a fixed explanatory message is returned instead. Every
        path records ``deep_analysis_s`` in the run timings (``0.0`` when the
        agent is skipped) so persisted ``pipeline_timings`` always carry the key.

        Args:
            user_id: user the profile belongs to.
            profile_data: profile dict from ``generate_user_profile_data``.
            max_turns: forwarded to the deep analysis agent.
            persist: when true, the profile and markdown are written through
                ``persist_user_profiling`` before returning.

        Returns:
            The markdown text after ``normalize_markdown_output``.
        """
        completeness = profile_data.get("data_completeness")

        if completeness == "inactive":
            msg = (profile_data.get("_inactive_markdown") or "").strip() or (
                "This user has no qualifying recent activity in user_activity_logs."
            )
            logger.info("Skipping deep analysis for User %s (inactive)", user_id)
            self._last_timings["deep_analysis_s"] = 0.0
            markdown_out = normalize_markdown_output(msg)
        elif completeness == "low_signal":
            logger.info("Skipping Deep Analysis for User %s (low-signal)", user_id)
            # Mirror the inactive branch: the agent did not run, so report zero seconds.
            self._last_timings["deep_analysis_s"] = 0.0
            markdown_out = normalize_markdown_output(
                "Insufficient data for deep analysis. Profile is marked as low-signal."
            )
        else:
            logger.info("Triggering Deep Analysis for User %s", user_id)
            profile_summary = self._build_profile_summary(profile_data)

            # Sections may be present but None; treat those as empty.
            order_s = (profile_data.get("signals") or {}).get("order") or {}
            ip = order_s.get("item_pairs") or []
            cp = order_s.get("category_pairs") or []
            if ip or cp:
                logger.info(
                    "Bought together: %s item pairs, %s category pairs", len(ip), len(cp)
                )

            t_deep_start = time.perf_counter()
            result = self.deep_analyzer.run(user_id, profile_summary, max_turns=max_turns)
            self._last_timings["deep_analysis_s"] = round(time.perf_counter() - t_deep_start, 2)
            logger.info(
                "Milestone: deep_analysis done in %ss", self._last_timings["deep_analysis_s"]
            )
            markdown_out = normalize_markdown_output(result if result is not None else "")

        if persist:
            self.persist_user_profiling(user_id, profile_data, markdown_out)
        return markdown_out

    def _fetch_all_raw_data(self, user_id: int, start: str, end: str) -> Dict[str, Any]:
        """
        Fetch every raw dataset the signal processors consume for one user.

        Orders are fetched first, synchronously, because their ids parameterize
        the order-items query. The remaining seven fetches are then submitted to
        a thread pool sized by ``_inner_fetch_pool_workers`` and collected in
        submission order; an exception from any fetch propagates to the caller.
        """
        orders = self.fetcher.fetch_orders(user_id, start, end)
        order_ids = [oid for oid in (o.get("id") for o in orders) if oid]

        # One zero-argument job per dataset; dict order is the submission order.
        jobs = {
            "order_items": lambda: self.fetcher.fetch_order_items(order_ids),
            "activities": lambda: self.fetcher.fetch_activity_logs(user_id, start, end),
            "impressions": lambda: self.fetcher.fetch_impressions(user_id, start, end),
            "reviews": lambda: self.fetcher.fetch_reviews_and_ratings(user_id, start, end),
            "carts": lambda: self.fetcher.fetch_cart_data(user_id, start, end),
            "dietary_raw": lambda: self.fetcher.fetch_dietary_preferences(user_id),
            "cart_add_to_cart": lambda: self.fetcher.fetch_cart_add_to_cart_signals(user_id, start, end),
        }

        pool_size = _inner_fetch_pool_workers()
        with ThreadPoolExecutor(max_workers=pool_size) as pool:
            pending = {key: pool.submit(job) for key, job in jobs.items()}
            fetched = {key: future.result() for key, future in pending.items()}

        raw: Dict[str, Any] = {"activities": fetched["activities"], "orders": orders}
        for key in ("order_items", "impressions", "reviews", "carts", "dietary_raw", "cart_add_to_cart"):
            raw[key] = fetched[key]
        return raw

    def _process_all_signals(self, raw: Dict[str, Any]) -> Dict[str, Any]:
        """
        Turn the raw fetch results into the per-domain signal dicts.

        Raises:
            ValueError: if any required raw key is missing or ``None``; the
                first offending key (in declaration order) is named.
                ``cart_add_to_cart`` is optional and defaults to an empty dict.
        """
        required = ("activities", "orders", "order_items", "impressions", "reviews", "dietary_raw", "carts")
        key = next((name for name in required if raw.get(name) is None), None)
        if key is not None:
            raise ValueError(f"Raw data key '{key}' is missing or None; fetch must return a list/dict for each key. Check data_fetcher.")

        processor = self.processor
        activity = processor.process_activity_signals(raw["activities"])
        order = processor.process_order_signals(raw["orders"], raw["order_items"])
        impression = processor.process_impression_signals(raw["impressions"])
        sentiment = processor.process_sentiment_signals(raw["reviews"])
        dietary = self._process_dietary_evidence(raw["dietary_raw"], raw["order_items"])
        cart = self._process_cart_signals(raw["carts"], raw.get("cart_add_to_cart") or {})
        return {
            "activity": activity,
            "order": order,
            "impression": impression,
            "sentiment": sentiment,
            "dietary_evidence": dietary,
            "cart": cart,
        }

    def _process_dietary_evidence(self, raw_pref: Dict[str, Any], raw_items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Report declared dietary data without interpreting it.

        Returns the declared preference slugs, the special-requests value and the
        number of order items; no text analysis is done on item names here.
        """
        declared = raw_pref.get('dietary_preferences', [])
        requests = raw_pref.get('special_requests')
        evidence: Dict[str, Any] = {"declared_slugs": declared}
        evidence["special_requests"] = requests
        evidence["item_count"] = len(raw_items)
        return evidence

    def _process_cart_signals(
        self,
        raw_carts: List[Dict[str, Any]],
        cart_add_to_cart: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Summarize cart rows and merge in the optional add-to-cart signals.

        Counts carts and distinct truthy ``restaurant_id`` values. When
        ``cart_add_to_cart`` is provided, its item lists and ``stats`` replace
        the empty / all-zero defaults; empty or missing entries keep the defaults.
        """
        restaurant_ids = {c.get("restaurant_id") for c in raw_carts if c.get("restaurant_id")}
        extras = cart_add_to_cart or {}

        # Used when no add-to-cart stats were supplied.
        zero_stats = {
            "add_to_cart_events_per_item_ordered": 0.0,
            "count_distinct_items_added_but_never_ordered": 0,
            "count_distinct_items_added_then_removed_from_cart": 0,
            "total_add_to_cart_events_in_window": 0,
            "total_items_ordered_in_window": 0,
        }
        return {
            "cart_count": len(raw_carts),
            "unique_restaurants_affected": len(restaurant_ids),
            "items_in_cart_window": extras.get("items_in_window") or [],
            "top_considered_not_bought": extras.get("top_considered_not_bought") or [],
            "considered_not_bought_by_category": extras.get("considered_not_bought_by_category") or [],
            "add_to_cart_stats": extras.get("stats") or zero_stats,
        }

    def _generate_metadata(
        self,
        user_id: int,
        raw: Dict[str, Any],
        window: int,
        basic_details: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Assemble the ``metadata`` section of a profile.

        Contains the user id, the display name derived from ``basic_details``
        (empty string when none is available), today's date and the window
        length in days. ``raw`` is accepted but not read.
        """
        metadata: Dict[str, Any] = {"user_id": user_id}
        metadata["name"] = _display_name_from_user_row(basic_details or {})
        metadata["profile_date"] = datetime.now().strftime("%Y-%m-%d")
        metadata["data_window_days"] = window
        return metadata

    def _build_profile_summary(self, profile_data: Dict[str, Any]) -> str:
        """Return the prose summary produced by the module-level ``build_profile_summary``."""
        summary_text = build_profile_summary(profile_data)
        return summary_text

    def _check_data_completeness(self, signals: Dict[str, Any]) -> str:
        """
        Classify how much usable signal a profile has.

        Returns:
            ``"low_signal"`` when there are no orders and fewer than 5 activity
            events, ``"high_signal"`` when there are more than 3 orders,
            otherwise ``"medium_signal"``.

        Missing sections and counters stored as ``None`` are counted as zero, so
        a partially populated signals dict never raises on the comparisons.
        """
        sig = signals or {}
        order_metrics = (sig.get("order") or {}).get("metrics") or {}
        # `or 0` also covers keys that exist with a None value (a .get default would not).
        order_count = order_metrics.get("total_orders") or 0
        activity_count = (sig.get("activity") or {}).get("total_events") or 0

        if order_count == 0 and activity_count < 5:
            return "low_signal"
        if order_count > 3:
            return "high_signal"
        return "medium_signal"


def _display_name_from_user_row(row: Dict[str, Any]) -> str:
    """
    Pick a display name out of a users-table row whose column names may vary.

    The first non-blank value among ``name``, ``full_name`` and ``display_name``
    wins; otherwise ``first_name`` and ``last_name`` are joined. Returns an empty
    string when the row is empty or holds none of these.
    """
    if not row:
        return ""

    def _cleaned(column: str) -> str:
        # Missing and None values collapse to "" before trimming.
        return (row.get(column) or "").strip()

    single = next(
        (value for value in map(_cleaned, ("name", "full_name", "display_name")) if value),
        "",
    )
    if single:
        return single

    first, last = _cleaned("first_name"), _cleaned("last_name")
    # Stripping the joined pair also yields "" when both halves are blank.
    return f"{first} {last}".strip()


def sanitize_profile_for_api(profile: Dict[str, Any]) -> None:
    """
    Strip internal/redundant fields from ``profile`` in place before it is
    returned to an API client or dashboard.

    The pipeline and the model summary still need the full data, so call this
    only on the way out. Sections that are missing, ``None`` or not dicts are
    left untouched rather than raising.

    Removed fields:
        * top level: ``_inactive_markdown``
        * ``statistics``: ``confidence`` (passed to the model only), ``recency``,
          ``stability``
        * ``signals.activity``: ``stability``
        * ``signals.order``: ``stability``, ``concentration``,
          ``velocity.std_dev_days_between_orders``
        * ``signals.sentiment``: ``text_evidence``
        * ``signals.dietary_evidence``: ``item_count``
        * ``signals.impression``: ``screen_exposure``
    """
    if not profile:
        return
    profile.pop("_inactive_markdown", None)

    stats = profile.get("statistics")
    if isinstance(stats, dict):
        for key in ("confidence", "recency", "stability"):
            stats.pop(key, None)

    signals = profile.get("signals")
    if not isinstance(signals, dict):
        return

    # Signal section -> keys that are internal flags or redundant / low-value.
    drop_by_section = {
        "activity": ("stability",),
        "order": ("stability", "concentration"),
        "sentiment": ("text_evidence",),
        "dietary_evidence": ("item_count",),
        "impression": ("screen_exposure",),
    }
    for section_name, keys in drop_by_section.items():
        section = signals.get(section_name)
        if isinstance(section, dict):
            for key in keys:
                section.pop(key, None)

    # The one nested drop: order.velocity.std_dev_days_between_orders.
    order = signals.get("order")
    velocity = order.get("velocity") if isinstance(order, dict) else None
    if isinstance(velocity, dict):
        velocity.pop("std_dev_days_between_orders", None)


def build_profile_summary(profile_data: Dict[str, Any]) -> str:
    """
    Build the prose summary of a profile that is handed to the deep analysis agent.

    Key facts are rendered as sentences (orders, loyalty, velocity, coupons,
    bought-together pairs, order timing, location, activity, impressions,
    statistics, cart behaviour) so the agent gets context without raw JSON. A
    closing ``RAW ORDER SIGNALS`` block repeats the untrimmed order structures so
    the agent always receives their shape, including when they are empty.
    Can be used by the orchestrator or by other entry points (e.g. API).

    Args:
        profile_data: profile dict; missing or ``None`` sections (and a ``None``
            profile) are treated as empty.

    Returns:
        One string with the fragments joined by single spaces. Whitespace inside
        data values is left exactly as provided.
    """
    profile_data = profile_data or {}
    meta = profile_data.get("metadata") or {}
    signals = profile_data.get("signals") or {}
    stats = profile_data.get("statistics") or {}
    completeness = profile_data.get("data_completeness", "unknown")

    order_s = signals.get("order") or {}
    order_metrics = order_s.get("metrics") or {}
    order_loyalty = order_s.get("loyalty") or {}
    order_inv = order_s.get("inventory") or {}
    order_velocity = order_s.get("velocity") or {}
    order_fin = order_s.get("financial") or {}

    activity_s = signals.get("activity") or {}
    activity_filters = activity_s.get("filters") or {}
    impression_s = signals.get("impression") or {}

    name = (meta.get("name") or "").strip()
    first_line = (
        f"User {meta.get('user_id')} — profile date {meta.get('profile_date')}, "
        f"window {meta.get('data_window_days')} days. Data completeness: {completeness}."
    )
    if name:
        first_line = f"Name: {name}. " + first_line
    parts = [first_line]
    parts.append(
        f"Orders: {order_metrics.get('total_orders', 0)} total, "
        f"total spent {order_metrics.get('total_spent', 0)}, "
        f"avg order value {order_metrics.get('avg_order_value', 0)}. "
    )
    parts.append(
        f"Unique restaurants: {order_loyalty.get('unique_restaurants_count', 0)}; "
        f"reorder ratio {order_loyalty.get('reorder_ratio', 0)}. "
    )
    if order_loyalty.get("top_restaurants"):
        top = order_loyalty["top_restaurants"][:3]
        parts.append(f"Top restaurants by order count: {', '.join(str(x) for x in top)}. ")
    parts.append(
        f"Top items ordered: {order_inv.get('total_items_ordered', 0)} total; "
        f"top items list: {(order_inv.get('top_items') or [])[:5]}. "
    )
    parts.append(
        f"Velocity: avg days between orders {order_velocity.get('avg_days_between_orders')}, "
        f"last order {order_velocity.get('last_order_days_ago')} days ago. "
    )
    parts.append(
        f"Coupon usage: {order_fin.get('coupon_order_count', 0)} orders with coupon, "
        f"ratio {order_fin.get('coupon_ratio', 0)}. "
    )

    # A stored None must not break the comparisons below; the sentence above
    # still prints the value exactly as stored.
    num_orders = order_metrics.get("total_orders", 0) or 0

    # Full structures are kept for the RAW block; sentences use the top-5 slices.
    all_item_pairs = order_s.get("item_pairs") or []
    all_cat_pairs = order_s.get("category_pairs") or []
    item_pairs = all_item_pairs[:5]
    cat_pairs = all_cat_pairs[:5]

    # Always report bought-together when there are orders (so agent knows we looked)
    if num_orders > 0:
        if item_pairs:
            parts.append(
                f"Item pairs (ordered together): {[(p.get('item_a'), p.get('item_b'), p.get('times_together')) for p in item_pairs]}. "
            )
        if cat_pairs:
            parts.append(
                f"Category pairs (ordered together): {[(p.get('category_a'), p.get('category_b'), p.get('times_together')) for p in cat_pairs]}. "
            )
        if not (item_pairs or cat_pairs):
            parts.append(
                "Bought together: 0 item pairs, 0 category pairs (orders exist but no multi-item combos). "
            )

    order_temporal = order_s.get("order_temporal") or {"hour": {}, "day": {}}
    if order_temporal.get("hour") or order_temporal.get("day"):
        parts.append(
            f"Order temporal (when they place orders): hour {order_temporal.get('hour', {})}, day {order_temporal.get('day', {})}. "
        )

    loc_order = order_s.get("location_order") or []
    loc_time = order_s.get("location_time") or {}
    loc_triplet = order_s.get("location_triplet") or []
    loc_sig = order_s.get("location_significant", False)
    has_location = bool(loc_order or loc_time.get("by_hour") or loc_time.get("by_day"))
    if num_orders > 0 and has_location:
        parts.append(
            f"Location (where they order): location_order={loc_order[:5]}, location_time by_hour={(loc_time.get('by_hour') or [])[:3]}, by_day={(loc_time.get('by_day') or [])[:3]}. "
        )
        if loc_sig and loc_triplet:
            parts.append(f"Location triplet (location+time+category, only when significant): {loc_triplet[:5]}. ")
        elif not loc_sig:
            parts.append("Location triplet: not computed (location+order or location+time below threshold). ")

    parts.append(
        f"Activity: {activity_s.get('total_events', 0)} events; "
        f"top searches {activity_s.get('top_searches', [])}; "
        f"filters: pure_veg {activity_filters.get('pure_veg_count', 0)}, "
        f"offers {activity_filters.get('offers_count', 0)}. "
    )
    parts.append(f"Impressions: {impression_s.get('total_impressions', 0)}. ")

    if stats:
        temporal = stats.get("temporal") or {}
        concentration = stats.get("concentration") or {}
        engagement = stats.get("engagement") or {}
        parts.append(
            f"Statistics — temporal: dominant hour share {temporal.get('dominant_hour_share')}, "
            f"entropy {temporal.get('entropy')}; "
        )
        parts.append(
            f"concentration (HHI): restaurant {concentration.get('restaurant_hhi')}, "
            f"item {concentration.get('item_hhi')}; "
        )
        parts.append(
            f"engagement: browse-to-order {engagement.get('browse_to_order_ratio')}, "
            f"search-to-order {engagement.get('search_to_order_ratio')}; "
        )
        parts.append(f"math confidence {stats.get('confidence')}.")
        cart_stats = stats.get("cart") or {}
        if cart_stats:
            parts.append(
                f"Cart add-to-cart: add_to_cart_events_per_item_ordered {cart_stats.get('add_to_cart_events_per_item_ordered', 0)}, "
                f"count_distinct_items_added_but_never_ordered {cart_stats.get('count_distinct_items_added_but_never_ordered', 0)}, "
                f"count_distinct_items_added_then_removed_from_cart {cart_stats.get('count_distinct_items_added_then_removed_from_cart', 0)}. "
            )

    cart_s = signals.get("cart") or {}
    top_not_bought = (cart_s.get("top_considered_not_bought") or [])[:5]
    not_bought_by_cat = (cart_s.get("considered_not_bought_by_category") or [])[:5]
    if top_not_bought:
        parts.append(
            f"Top items added to cart but never ordered: {[p.get('item_name') for p in top_not_bought]}. "
        )
    if not_bought_by_cat:
        parts.append(
            f"Considered but not bought by category: {[(p.get('category_name'), p.get('distinct_items_added_but_never_ordered_in_this_category')) for p in not_bought_by_cat]}. "
        )

    # Raw order signals block so LLM always receives structure (including when empty)
    parts.append(
        " RAW ORDER SIGNALS (exact data): "
        f"item_pairs={all_item_pairs}, "
        f"category_pairs={all_cat_pairs}, "
        f"order_temporal={order_temporal}, "
        f"location_order={loc_order}, location_time={loc_time}, "
        f"location_triplet={loc_triplet}, location_significant={loc_sig}."
    )

    # Trim each fragment's padding and join with one space. A blanket
    # replace("  ", " ") would leave a stray double space before the RAW block
    # and would also alter double spaces inside the data values themselves.
    return " ".join(part.strip() for part in parts)
