import collections
import math
from datetime import datetime
from typing import List, Dict, Any, Optional

from .constants import (
    EMPTY_ACTIVITY_SIGNALS,
    EMPTY_ORDER_SIGNALS,
    EMPTY_IMPRESSION_SIGNALS,
    EMPTY_SENTIMENT_SIGNALS,
    MIN_ORDERS_WITH_LOCATION,
    MIN_DISTINCT_LOCATIONS,
    MIN_LOCATION_TIME_PAIRS,
    TOP_SEARCHES_N,
    TOP_RESTAURANTS_N,
    TOP_ITEMS_N,
    TOP_ITEM_PAIRS_N,
    TOP_CATEGORY_PAIRS_N,
    TOP_LOCATION_ORDER_N,
    TOP_LOCATION_TIME_N,
    TOP_LOCATION_TRIPLET_N,
    ACTIVITY_VOLUME_LOW_THRESHOLD,
    ORDER_VOLUME_LOW_THRESHOLD,
    POSITIVE_KEYWORDS,
    NEGATIVE_KEYWORDS,
)
from .helpers import (
    parse_date,
    calculate_reorder_ratio,
    temporal_distribution,
    pair_counts_from_groups,
)


class SignalProcessor:
    """
    Computes deterministic behavioral signals from raw data.
    Strictly returning evidence (counts, ratios), not conclusions.

    All ``process_*`` methods accept lists of dict rows (presumably DB rows —
    schemas are inferred from the keys read here) and return plain dicts of
    counts/ratios. Empty input always yields the corresponding ``EMPTY_*``
    constant so callers get a stable shape.
    """

    def _parse_date(self, date_val: Any) -> Optional[datetime]:
        """Parse *date_val* into a datetime via the shared helper; None if unparseable."""
        return parse_date(date_val)

    def process_activity_signals(self, raw_activities: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compute browsing-behavior evidence from raw activity events.

        Args:
            raw_activities: Event rows; keys read: ``screen``, ``search_keyword``,
                ``created_at``, ``platform``, and the four ``*_filter`` flags
                (flag values are the string ``"1"`` when set).

        Returns:
            Dict of counts/ratios keyed by signal group; ``EMPTY_ACTIVITY_SIGNALS``
            copy when *raw_activities* is empty.
        """
        if not raw_activities:
            return dict(EMPTY_ACTIVITY_SIGNALS)

        screen_visits = collections.Counter(
            a.get("screen") for a in raw_activities if a.get("screen")
        )
        search_keywords = [
            a.get("search_keyword") for a in raw_activities if a.get("search_keyword")
        ]

        hour_dist, day_dist = temporal_distribution(raw_activities, parse_fn=self._parse_date)

        now = datetime.now()
        # Parse each timestamp once (the original parsed twice per event) and
        # require 0 <= age <= 7 days: a future-dated event has a negative
        # ``.days`` and would otherwise inflate the 7-day recency ratio.
        activity_7d_count = 0
        for a in raw_activities:
            dt = self._parse_date(a.get("created_at"))
            if dt is not None and 0 <= (now - dt).days <= 7:
                activity_7d_count += 1

        # Filter flags arrive as the string "1" when enabled.
        filter_signals = {
            "pure_veg_count": sum(1 for a in raw_activities if a.get("pure_veg_filter") == "1"),
            "offers_count": sum(1 for a in raw_activities if a.get("offers_filter") == "1"),
            "fast_delivery_count": sum(1 for a in raw_activities if a.get("fast_delivery_filter") == "1"),
            "rating_4plus_count": sum(1 for a in raw_activities if a.get("rating4plus_filter") == "1"),
        }

        total_events = len(raw_activities)
        return {
            "total_events": total_events,
            "screen_visits": dict(screen_visits),
            "top_searches": collections.Counter(search_keywords).most_common(TOP_SEARCHES_N),
            "temporal_distribution": {"hour": hour_dist, "day": day_dist},
            "filters": filter_signals,
            "platform_usage": dict(
                collections.Counter((a.get("platform") or "unknown").lower() for a in raw_activities)
            ),
            "recency": {
                # total_events > 0 here (guarded above), so the division is safe.
                "recent_activity_ratio_7d": round(activity_7d_count / total_events, 2),
            },
            "stability": {"activity_volume_low": total_events < ACTIVITY_VOLUME_LOW_THRESHOLD},
        }

    def _location_signals(
        self,
        raw_orders: List[Dict[str, Any]],
        items_by_order: Dict[Any, List[Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Location-based order evidence: per-location counts, time buckets, triplets.

        Returns the four ``location_*`` entries of the order-signal payload.
        """
        def order_location(o: Dict[str, Any]) -> str:
            # Prefer the finer-grained area; fall back to city, then empty.
            area = (o.get("order_area") or "").strip()
            city = (o.get("order_city") or "").strip()
            return area or city or ""

        # Compute each order's location once (originally evaluated twice).
        orders_with_location = []
        for o in raw_orders:
            loc = order_location(o)
            if loc:
                orders_with_location.append((o, loc))

        location_order_counts = collections.Counter(loc for _, loc in orders_with_location)
        location_order = [
            {"location": loc, "order_count": count}
            for loc, count in location_order_counts.most_common(TOP_LOCATION_ORDER_N)
        ]

        # Bucket orders by (location, hour) and (location, weekday).
        loc_time_hour: collections.Counter = collections.Counter()
        loc_time_day: collections.Counter = collections.Counter()
        for o, loc in orders_with_location:
            dt = self._parse_date(o.get("created_at"))
            if dt:
                loc_time_hour[(loc, dt.strftime("%H"))] += 1
                loc_time_day[(loc, dt.strftime("%A"))] += 1
        location_time_by_hour = [
            {"location": loc, "hour": h, "count": c}
            for (loc, h), c in loc_time_hour.most_common(TOP_LOCATION_TIME_N)
        ]
        location_time_by_day = [
            {"location": loc, "day": d, "count": c}
            for (loc, d), c in loc_time_day.most_common(TOP_LOCATION_TIME_N)
        ]

        # Significance gates: enough located orders, enough distinct places,
        # and enough (location, time) evidence to be worth reporting.
        location_order_significant = (
            len(orders_with_location) >= MIN_ORDERS_WITH_LOCATION
            and len(location_order_counts) >= MIN_DISTINCT_LOCATIONS
        )
        location_time_significant = (
            len(loc_time_hour) + len(loc_time_day) >= MIN_LOCATION_TIME_PAIRS
        )
        location_significant = location_order_significant and location_time_significant

        # (location, hour, category) triplets — only computed once the broader
        # location evidence is significant, to avoid reporting noise.
        location_triplet = []
        if location_significant and items_by_order:
            triplet_counts: collections.Counter = collections.Counter()
            for o, loc in orders_with_location:
                dt = self._parse_date(o.get("created_at"))
                if not dt:
                    continue
                hour_bucket = dt.strftime("%H")
                # NOTE(review): items are grouped by item["order_id"] but looked
                # up via order["id"] — presumably the FK relationship; confirm.
                items = items_by_order.get(o.get("id"), [])
                cats = {
                    (item.get("category_name") or "").strip()
                    for item in items
                    if (item.get("category_name") or "").strip()
                }
                for cat in cats:  # cats already excludes empty strings
                    triplet_counts[(loc, hour_bucket, cat)] += 1
            location_triplet = [
                {"location": loc, "hour": h, "category": cat, "count": c}
                for (loc, h, cat), c in triplet_counts.most_common(TOP_LOCATION_TRIPLET_N)
            ]

        return {
            "location_order": location_order,
            "location_time": {"by_hour": location_time_by_hour, "by_day": location_time_by_day},
            "location_triplet": location_triplet,
            "location_significant": location_significant,
        }

    def _velocity_signals(self, order_dates: List[datetime]) -> Dict[str, Any]:
        """Inter-order cadence stats from a sorted list of order datetimes.

        Returns mean/population-std of day gaps between consecutive orders and
        days since the most recent order; fields are None when undeterminable.
        """
        avg_velocity = None
        std_dev_velocity = None
        last_order_days = None
        if len(order_dates) > 1:
            intervals = [(b - a).days for a, b in zip(order_dates, order_dates[1:])]
            mean = sum(intervals) / len(intervals)
            avg_velocity = round(mean, 1)
            # Population variance (divide by N), matching the original.
            variance = sum((x - mean) ** 2 for x in intervals) / len(intervals)
            std_dev_velocity = round(math.sqrt(variance), 2)
        if order_dates:
            last_order_days = (datetime.now() - order_dates[-1]).days
        return {
            "avg_days_between_orders": avg_velocity,
            "std_dev_days_between_orders": std_dev_velocity,
            "last_order_days_ago": last_order_days,
        }

    def process_order_signals(self, raw_orders: List[Dict[str, Any]], raw_items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compute purchase-behavior evidence from raw orders and their line items.

        Args:
            raw_orders: Order rows; keys read include ``id``, ``created_at``,
                ``sub_total``, ``restaurant_name``, ``coupon_id``,
                ``order_area``, ``order_city``.
            raw_items: Item rows; keys read include ``order_id``, ``menu_name``,
                ``item_name``, ``category_name``.

        Returns:
            Dict of metric groups; ``EMPTY_ORDER_SIGNALS`` copy when
            *raw_orders* is empty.
        """
        if not raw_orders:
            return dict(EMPTY_ORDER_SIGNALS)

        items_by_order: Dict[Any, List[Dict[str, Any]]] = collections.defaultdict(list)
        for item in raw_items:
            items_by_order[item.get("order_id")].append(item)

        # ``or 0`` also covers a present-but-None sub_total, which would make
        # float() raise under the original ``o.get("sub_total", 0)``.
        total_spent = sum(float(o.get("sub_total") or 0) for o in raw_orders)

        order_hour, order_day = temporal_distribution(raw_orders, parse_fn=self._parse_date)

        item_pairs = pair_counts_from_groups(
            items_by_order,
            lambda i: i.get("menu_name") or i.get("item_name"),
            TOP_ITEM_PAIRS_N,
            "item_a",
            "item_b",
        )
        category_pairs = pair_counts_from_groups(
            items_by_order,
            lambda i: i.get("category_name"),
            TOP_CATEGORY_PAIRS_N,
            "category_a",
            "category_b",
        )

        location_signals = self._location_signals(raw_orders, items_by_order)

        restaurants = collections.Counter(
            o.get("restaurant_name") for o in raw_orders if o.get("restaurant_name")
        )
        top_restaurant_count = restaurants.most_common(1)[0][1] if restaurants else 0

        item_names = [
            item.get("menu_name") or item.get("item_name")
            for item in raw_items
            if (item.get("menu_name") or item.get("item_name"))
        ]
        top_items = collections.Counter(item_names).most_common(TOP_ITEMS_N)
        top_item_count = top_items[0][1] if top_items else 0

        # Single lookup per order; truthy + != 0 mirrors the original check.
        coupon_orders = []
        for o in raw_orders:
            coupon_id = o.get("coupon_id")
            if coupon_id and coupon_id != 0:
                coupon_orders.append(o)

        # Parse each order date once (the original parsed twice per order).
        order_dates = sorted(
            dt
            for dt in (self._parse_date(o.get("created_at")) for o in raw_orders)
            if dt is not None
        )

        total_orders = len(raw_orders)  # > 0 here, guarded above
        return {
            "metrics": {
                "total_orders": total_orders,
                "total_spent": total_spent,
                "avg_order_value": round(total_spent / total_orders, 2),
            },
            "loyalty": {
                "unique_restaurants_count": len(restaurants),
                "top_restaurants": restaurants.most_common(TOP_RESTAURANTS_N),
                "reorder_ratio": calculate_reorder_ratio(restaurants),
            },
            "inventory": {"top_items": top_items, "total_items_ordered": len(raw_items)},
            "financial": {
                "coupon_order_count": len(coupon_orders),
                "coupon_ratio": round(len(coupon_orders) / total_orders, 2),
            },
            "velocity": self._velocity_signals(order_dates),
            "concentration": {
                "top_restaurant_share": round(top_restaurant_count / total_orders, 2),
                "top_item_share": round(top_item_count / len(raw_items), 2) if raw_items else 0,
            },
            "stability": {"order_volume_low": total_orders < ORDER_VOLUME_LOW_THRESHOLD},
            "item_pairs": item_pairs,
            "category_pairs": category_pairs,
            "order_temporal": {"hour": order_hour, "day": order_day},
            "location_order": location_signals["location_order"],
            "location_time": location_signals["location_time"],
            "location_triplet": location_signals["location_triplet"],
            "location_significant": location_signals["location_significant"],
        }

    def process_impression_signals(self, raw_impressions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compute exposure evidence from raw impression rows.

        Args:
            raw_impressions: Rows with ``screen`` and ``restaurant_id`` keys.

        Returns:
            Totals plus per-screen exposure counts; ``EMPTY_IMPRESSION_SIGNALS``
            copy when *raw_impressions* is empty.
        """
        if not raw_impressions:
            return dict(EMPTY_IMPRESSION_SIGNALS)
        screens = collections.Counter(
            i.get("screen") for i in raw_impressions if i.get("screen")
        )
        restaurants_seen = {
            i.get("restaurant_id") for i in raw_impressions if i.get("restaurant_id")
        }
        return {
            "total_impressions": len(raw_impressions),
            "unique_restaurants_seen": len(restaurants_seen),
            "screen_exposure": dict(screens),
        }

    def process_sentiment_signals(self, raw_reviews: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compute rating and keyword evidence from raw review rows.

        Args:
            raw_reviews: Rows with ``rating`` (numeric or numeric string) and
                ``review`` (free text) keys.

        Returns:
            Rating stats plus positive/negative keyword counts;
            ``EMPTY_SENTIMENT_SIGNALS`` copy when *raw_reviews* is empty.
        """
        if not raw_reviews:
            return dict(EMPTY_SENTIMENT_SIGNALS)

        # Skip malformed ratings instead of letting one bad row (e.g. a
        # non-numeric string) raise and abort all sentiment signals.
        ratings: List[float] = []
        for r in raw_reviews:
            raw_rating = r.get("rating")
            if raw_rating is None:
                continue
            try:
                ratings.append(float(raw_rating))
            except (TypeError, ValueError):
                continue

        all_text = " ".join(
            (r.get("review") or "") for r in raw_reviews if (r.get("review") or "")
        ).lower()
        words = all_text.split()
        positive_evidence = sum(1 for w in words if w in POSITIVE_KEYWORDS)
        negative_evidence = sum(1 for w in words if w in NEGATIVE_KEYWORDS)
        return {
            "review_count": len(raw_reviews),
            "avg_rating": round(sum(ratings) / len(ratings), 2) if ratings else 0,
            "rating_distribution": dict(collections.Counter(ratings)),
            "text_evidence": {
                "positive_keyword_count": positive_evidence,
                "negative_keyword_count": negative_evidence,
                "total_word_count": len(words),
            },
        }
