"""
Pure helpers for SignalProcessor (date parsing, reorder ratio, temporal aggregation, pair counts).
"""

import collections
import itertools
from datetime import datetime
from typing import Any, List, Dict, Callable, Optional, Tuple

from .constants import DATETIME_FORMAT


def parse_date(date_val: Any) -> Optional[datetime]:
    """Coerce *date_val* to a ``datetime``, or return ``None`` if that fails.

    ``datetime`` instances are returned unchanged. Strings are parsed with
    ``DATETIME_FORMAT``; a string that does not match the format yields
    ``None``. Any other type yields ``None``.
    """
    if isinstance(date_val, datetime):
        return date_val
    if not isinstance(date_val, str):
        return None
    try:
        parsed = datetime.strptime(date_val, DATETIME_FORMAT)
    except ValueError:
        # Malformed strings are reported as "no date" rather than raised.
        parsed = None
    return parsed


def calculate_reorder_ratio(restaurant_counts: collections.Counter) -> float:
    """Return the share of keys in *restaurant_counts* counted more than once.

    The result is ``(keys with count > 1) / (number of keys)``, rounded to
    two decimal places. An empty (or otherwise falsy) input gives ``0.0``.
    """
    if not restaurant_counts:
        return 0.0
    repeated_flags = [times > 1 for times in restaurant_counts.values()]
    # Summing booleans counts the True entries.
    ratio = sum(repeated_flags) / len(repeated_flags)
    return round(ratio, 2)


def temporal_distribution(
    records: List[Dict[str, Any]],
    date_key: str = "created_at",
    parse_fn: Optional[Callable[[Any], Optional[datetime]]] = None,
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Count records per hour of day and per weekday.

    Args:
        records: Dicts holding a date value under *date_key*.
        date_key: Key looked up (via ``dict.get``) in each record.
        parse_fn: Converts the raw value to a ``datetime`` or a falsy value;
            defaults to ``parse_date``.

    Returns:
        ``(hour_dist, day_dist)``. ``hour_dist`` maps zero-padded 24-hour
        strings (``"00"``-``"23"``) to counts; ``day_dist`` maps full weekday
        names to counts. Keys are strings so the dicts serialise to JSON
        directly. Records whose date is missing or cannot be parsed are
        skipped.
    """
    parse = parse_fn or parse_date
    hour_dist: collections.Counter = collections.Counter()
    day_dist: collections.Counter = collections.Counter()
    for record in records:
        dt = parse(record.get(date_key))
        if not dt:
            continue
        hour_dist[dt.strftime("%H")] += 1
        # NOTE: "%A" yields the locale's weekday name, so keys depend on the
        # process locale.
        day_dist[dt.strftime("%A")] += 1
    return dict(hour_dist), dict(day_dist)


def pair_counts_from_groups(
    groups: Dict[Any, List[Dict[str, Any]]],
    key_fn: Callable[[Dict[str, Any]], str],
    top_n: int,
    result_key_a: str = "item_a",
    result_key_b: str = "item_b",
) -> List[Dict[str, Any]]:
    """Count how often two distinct keys appear together in the same group.

    Args:
        groups: Mapping of group id (e.g. an order id) to the group's items.
        key_fn: Extracts a key from an item. A falsy result is treated as an
            empty key; the key is whitespace-stripped before use.
        top_n: Number of most frequent pairs to return.
        result_key_a: Output dict key for the first member of each pair.
        result_key_b: Output dict key for the second member of each pair.

    Returns:
        One dict per pair, most frequent first, shaped
        ``{result_key_a: a, result_key_b: b, "times_together": count}``.
        Within a pair ``a`` sorts before ``b``. A key is counted at most once
        per group, and empty keys never form a pair.
    """
    pair_counts: collections.Counter = collections.Counter()
    for items in groups.values():
        # De-duplicate within the group so repeated items count once.
        keys = {(key_fn(item) or "").strip() for item in items}
        keys.discard("")
        # Sorting gives each unordered pair one canonical (a, b) orientation.
        pair_counts.update(itertools.combinations(sorted(keys), 2))
    return [
        {result_key_a: a, result_key_b: b, "times_together": count}
        for (a, b), count in pair_counts.most_common(top_n)
    ]
