#!/usr/bin/env python3

from __future__ import annotations

import argparse
import json
import logging
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from dotenv import load_dotenv

load_dotenv()

from scheduler_config import BATCH_SCHEDULER, BatchSchedulerConfig
from core.db_connection import db_session, verify_database_connectivity
from core.logger_config import setup_logging
from profiling_orchestrator import ProfilingOrchestrator
from storage.user_profiling import (
    ensure_profiling_storage_tables,
    publish_profiling_staging,
    tracker_should_skip_scheduled_run,
    truncate_profiling_staging,
)

logger = logging.getLogger(__name__)


def _read_checkpoint(path: Path) -> int:
    if not path.exists():
        return 0
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
        return int(data.get("last_success_user_id", 0) or 0)
    except (json.JSONDecodeError, OSError, TypeError, ValueError):
        logger.warning("Invalid checkpoint %s; starting from 0", path)
        return 0


def _write_checkpoint(path: Path, last_success_user_id: int) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps({"last_success_user_id": last_success_user_id}, indent=2),
        encoding="utf-8",
    )


def _process_one_user(
    orchestrator: ProfilingOrchestrator,
    user_id: int,
    window_days: int,
    with_deep_analysis: bool,
    activity_gate_days: Optional[int],
) -> None:
    profile = orchestrator.generate_user_profile_data(
        user_id,
        window_days=window_days,
        activity_gate_days=activity_gate_days,
    )
    if not profile or not profile.get("metadata"):
        logger.warning("Skip user_id=%s: no profile metadata", user_id)
        return
    if with_deep_analysis:
        result = orchestrator.run_deep_analysis(user_id, profile, persist=False)
        orchestrator.persist_user_profiling_staging(
            user_id, profile, result if result is not None else ""
        )
    else:
        md = None
        if profile.get("data_completeness") == "inactive":
            md = profile.get("_inactive_markdown")
        orchestrator.persist_user_profiling_staging(user_id, profile, md)


def _worker_run_user(
    user_id: int,
    window_days: int,
    with_deep_analysis: bool,
    activity_gate_days: Optional[int],
) -> Tuple[int, bool, Optional[BaseException]]:
    """Run one complete profiling pass for a single user.

    A fresh ``ProfilingOrchestrator`` is built per task so concurrent pool
    workers never share orchestrator state (thread-safe). Never raises:
    returns ``(user_id, succeeded, exception_or_None)`` so the caller can
    tally results and decide on fail-fast itself.
    """
    orchestrator = ProfilingOrchestrator(db_session)
    try:
        _process_one_user(
            orchestrator, user_id, window_days, with_deep_analysis, activity_gate_days
        )
    except BaseException as exc:  # noqa: BLE001 — worker must report, not raise
        logger.exception("Failed user_id=%s", user_id)
        return (user_id, False, exc)
    return (user_id, True, None)


def process_keyset_page_users(
    ids: List[int],
    cfg: BatchSchedulerConfig,
    *,
    dry_run: bool,
    fail_fast: bool,
    recency_days: int,
    checkpoint_path: Optional[Path],
    success_count: int,
    fail_count: int,
    skipped_recency_count: int,
    print_completed_users: bool = False,
) -> Tuple[int, int, int, bool]:
    """
    Process one keyset page of user ids in keyset order.

    When ``cfg.batch_max_workers > 1``, runs up to that many users in parallel per wave;
    inner per-user parallel SQL fetches are unchanged.

    Args:
        ids: User ids for this page, in ascending keyset order.
        cfg: Batch scheduler settings (workers, window, sleeps, max_users, ...).
        dry_run: Only log what would happen; never profile or write to the DB.
        fail_fast: After a wave finishes, re-raise its first failure.
        recency_days: Skip users the tracker marks completed within this many
            days; 0 disables the recency skip entirely.
        checkpoint_path: Resume-progress file, or None to disable checkpoints.
        success_count: Running success total carried in from previous pages.
        fail_count: Running failure total carried in from previous pages.
        skipped_recency_count: Running recency-skip total carried in.
        print_completed_users: Echo ``completed user_id=N`` to stdout per success.

    Returns ``(success_count, fail_count, skipped_recency_count, stop_due_to_max_users)``.
    """
    # Work on local copies of the running totals; updated values are returned.
    s = success_count
    f = fail_count
    sk = skipped_recency_count
    # The per-user activity gate applies only when the keyset fetch did NOT
    # already filter by recent activity (mirrors run_batch's fetch arguments).
    activity_gate_days: Optional[int] = (
        None if cfg.batch_require_recent_activity else cfg.recent_activity_days
    )
    workers = max(1, cfg.batch_max_workers)
    idx = 0
    n = len(ids)
    stop_max = False

    while idx < n:
        if cfg.max_users is not None and s >= cfg.max_users:
            stop_max = True
            break

        # Cap the wave size so parallel successes can never overshoot max_users.
        remaining = None if cfg.max_users is None else cfg.max_users - s
        room = workers if remaining is None else min(workers, max(0, remaining))
        if room == 0:
            stop_max = True
            break

        # Build the next wave, consuming ids in order. Recency-skipped users
        # (and, in dry-run mode, every user) do not occupy a worker slot.
        chunk: List[int] = []
        while idx < n and len(chunk) < room:
            uid = ids[idx]
            idx += 1
            # NOTE(review): tracker_should_skip_scheduled_run presumably consults a
            # per-user "last completed run" tracker — confirm in storage.user_profiling.
            if recency_days > 0 and tracker_should_skip_scheduled_run(db_session, uid, recency_days):
                sk += 1
                if dry_run:
                    logger.info(
                        "DRY-RUN skip user_id=%s (tracker: within last %s days)",
                        uid,
                        recency_days,
                    )
                else:
                    logger.debug(
                        "Skip user_id=%s: completed within last %s days (tracker)",
                        uid,
                        recency_days,
                    )
                continue
            if dry_run:
                logger.info("DRY-RUN would process user_id=%s", uid)
                continue
            chunk.append(uid)

        # Whole wave was skipped/dry-run: idx has already advanced, so looping
        # back cannot spin forever.
        if not chunk:
            continue

        pool_workers = min(len(chunk), workers)
        futs: Dict[Any, int] = {}  # Future -> user_id
        with ThreadPoolExecutor(max_workers=pool_workers) as ex:
            for uid in chunk:
                futs[
                    ex.submit(
                        _worker_run_user,
                        uid,
                        cfg.window_days,
                        cfg.with_deep_analysis,
                        activity_gate_days,
                    )
                ] = uid
            # _worker_run_user never raises; collect (ok, exc) per user as it completes.
            results: Dict[int, Tuple[bool, Optional[BaseException]]] = {}
            for fut in as_completed(futs):
                uid_r, ok, exc = fut.result()
                results[uid_r] = (ok, exc)

        # Checkpoint only for successes in chunk order until the first failure: write after each
        # such success so we keep progress, but never skip a failed user because a later id succeeded.
        first_err: Optional[BaseException] = None
        contiguous_prefix_ok = True
        for uid in chunk:
            ok, exc = results[uid]
            if ok:
                s += 1
                if contiguous_prefix_ok and checkpoint_path:
                    _write_checkpoint(checkpoint_path, uid)
                if print_completed_users:
                    print(f"completed user_id={uid}", flush=True)
            else:
                f += 1
                contiguous_prefix_ok = False
                if first_err is None:
                    first_err = exc
            # Throttle applies per user, after successes and failures alike.
            time.sleep(cfg.sleep_between_users)

        # fail_fast is honored only after the whole wave has been tallied.
        if fail_fast and first_err is not None:
            raise first_err

    return s, f, sk, stop_max


def run_batch(
    cfg: BatchSchedulerConfig,
    *,
    dry_run: bool = False,
    fail_fast: bool = False,
) -> int:
    """Profile all users via keyset pagination, then publish staging.

    Flow: verify DB connectivity → resolve the resume point (explicit
    ``start_after_id``, else checkpoint file) → truncate staging on a fresh
    run → page through user ids and process each page → on a complete
    natural scan, publish staging to the published table unless configured
    otherwise.

    Returns a process exit code: 0 on full success, 1 if any user failed or
    the publish step failed, 130 on KeyboardInterrupt.
    """
    # NOTE(review): console and file logging are pinned to ERROR here, so the
    # logger.info progress lines below are normally suppressed — confirm intended.
    setup_logging(logging.ERROR, file_level=logging.ERROR)
    verify_database_connectivity()
    ensure_profiling_storage_tables(db_session)
    orchestrator = ProfilingOrchestrator(db_session)
    fetcher = orchestrator.fetcher

    checkpoint_path: Optional[Path] = (
        Path(cfg.checkpoint_file).resolve() if cfg.checkpoint_file else None
    )
    # Resume point: an explicit start_after_id wins; otherwise fall back to
    # the checkpoint file when one is configured.
    last_id = cfg.start_after_id
    if last_id == 0 and checkpoint_path is not None:
        last_id = _read_checkpoint(checkpoint_path)

    success_count = 0
    fail_count = 0
    skipped_recency_count = 0
    page_num = 0
    recency_days = 0 if cfg.ignore_recency else cfg.skip_recency_days
    # True only when the keyset scan exhausts all ids (not on max_users stops);
    # publishing below is gated on a complete natural scan.
    natural_scan_complete = False

    # A fresh run (no resume point) starts from an empty staging table.
    if not dry_run and last_id == 0:
        truncate_profiling_staging(db_session)
        logger.info("Truncated user_profiling_staging (fresh batch from user id 0).")

    logger.info(
        "Batch profiling start: batch_size=%s batch_max_workers=%s window_days=%s deep=%s last_id=%s "
        "max_users=%s recency_skip_days=%s activity_filter=%s recent_activity_days=%s",
        cfg.batch_size,
        cfg.batch_max_workers,
        cfg.window_days,
        cfg.with_deep_analysis,
        last_id,
        cfg.max_users,
        recency_days if recency_days > 0 else "off",
        cfg.batch_require_recent_activity,
        cfg.recent_activity_days,
    )

    try:
        while True:
            if cfg.max_users is not None and success_count >= cfg.max_users:
                logger.info("Stopping: reached max_users=%s", cfg.max_users)
                break

            # Keyset pagination: fetch the next page of ids strictly greater
            # than last_id, optionally pre-filtered by recent activity.
            ids = fetcher.fetch_user_ids_keyset(
                last_id,
                cfg.batch_size,
                recent_activity_days=(
                    cfg.recent_activity_days if cfg.batch_require_recent_activity else None
                ),
            )
            if not ids:
                natural_scan_complete = True
                logger.info("No more user ids after id>%s — full scan complete.", last_id)
                break

            page_num += 1
            logger.info("Page %s: fetched %s ids (>%s .. %s)", page_num, len(ids), last_id, ids[-1])

            success_count, fail_count, skipped_recency_count, stop_max = process_keyset_page_users(
                ids,
                cfg,
                dry_run=dry_run,
                fail_fast=fail_fast,
                recency_days=recency_days,
                checkpoint_path=checkpoint_path,
                success_count=success_count,
                fail_count=fail_count,
                skipped_recency_count=skipped_recency_count,
                print_completed_users=True,
            )
            if stop_max:
                logger.info("Stopping: reached max_users=%s", cfg.max_users)
                break

            # Advance the keyset cursor past this page, then throttle.
            last_id = ids[-1]
            time.sleep(cfg.sleep_between_batches)

    except KeyboardInterrupt:
        # Conventional exit code for SIGINT (128 + 2).
        print("Interrupted by user", file=sys.stderr, flush=True)
        return 130

    logger.info(
        "Batch finished: success=%s failed=%s skipped_recency=%s full_scan=%s",
        success_count,
        fail_count,
        skipped_recency_count,
        natural_scan_complete,
    )
    print(
        f"batch_scheduled_profiling finished: success={success_count} failed={fail_count} "
        f"skipped_recency={skipped_recency_count} full_scan={natural_scan_complete}",
        flush=True,
    )

    # Publish only after a complete natural scan, with no failures unless
    # publish_with_failures overrides.
    # NOTE(review): when skip_publish is True AND fail_count > 0 (and
    # publish_with_failures is False), the fail-count branch below wins and the
    # skip_publish message is never printed — confirm that ordering is intended.
    if (
        natural_scan_complete
        and not dry_run
        and not cfg.skip_publish
        and (cfg.publish_with_failures or fail_count == 0)
    ):
        try:
            publish_profiling_staging(db_session)
            logger.info("Published staging → user_profiling (single transaction).")
            print("Published staging → dbo.user_profiling (single transaction).", flush=True)
        except Exception:
            logger.exception("Publish failed; published table unchanged (rolled back).")
            print("Publish failed; published table unchanged (rolled back).", file=sys.stderr, flush=True)
            return 1
    elif natural_scan_complete and fail_count > 0 and not cfg.publish_with_failures:
        logger.warning(
            "Skipping publish because fail_count=%s (set publish_with_failures in scheduler_config to override).",
            fail_count,
        )
        print(
            f"Skipping publish: fail_count={fail_count} (set publish_with_failures to override).",
            flush=True,
        )
    elif natural_scan_complete and cfg.skip_publish:
        logger.info("Skipping publish (skip_publish in scheduler_config). Staging left populated for inspection.")
        print("Skipping publish (skip_publish=True); staging left as-is.", flush=True)

    # Any per-user failure makes the whole run nonzero for the scheduler.
    return 1 if fail_count > 0 else 0


def build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser; most knobs live in scheduler_config, not flags."""
    description = (
        "Run profiling for many users in throttled keyset batches. "
        "Defaults live in scheduler_config.BATCH_SCHEDULER — edit that module (or import a fork) "
        "for batch_size, checkpoint_file, publish flags, etc."
    )
    epilog = "Example: python batch_scheduled_profiling.py --dry-run"
    return argparse.ArgumentParser(description=description, epilog=epilog)


def main(argv: Optional[list] = None) -> int:
    """CLI entry point: parse flags, sanity-check config, echo settings, run the batch."""
    parser = build_arg_parser()
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only fetch ids and log; no profiling or DB writes (overrides for this run).",
    )
    parser.add_argument(
        "--fail-fast",
        action="store_true",
        help="Exit on first user failure (overrides for this run).",
    )
    args = parser.parse_args(argv)

    cfg = BATCH_SCHEDULER
    # Reject config values that would otherwise fail obscurely mid-run;
    # first violation (in declaration order) wins.
    checks = (
        (cfg.start_after_id < 0, "scheduler_config: start_after_id must be >= 0"),
        (cfg.batch_size < 1, "scheduler_config: batch_size must be >= 1"),
        (cfg.skip_recency_days < 0, "scheduler_config: skip_recency_days must be >= 0"),
        (cfg.batch_max_workers < 1, "scheduler_config: batch_max_workers must be >= 1"),
    )
    for invalid, message in checks:
        if invalid:
            print(message, file=sys.stderr)
            return 2

    # Echo the effective settings so scheduled-run logs are self-describing.
    banner = [
        "batch_scheduled_profiling starting:",
        f"  batch_size={cfg.batch_size}",
        f"  batch_max_workers={cfg.batch_max_workers}",
        f"  window_days={cfg.window_days} recent_activity_days={cfg.recent_activity_days}",
        f"  batch_require_recent_activity={cfg.batch_require_recent_activity}",
        f"  with_deep_analysis={cfg.with_deep_analysis}",
        f"  skip_publish={cfg.skip_publish} publish_with_failures={cfg.publish_with_failures}",
        f"  sleep_between_users={cfg.sleep_between_users} sleep_between_batches={cfg.sleep_between_batches}",
        f"  max_users={cfg.max_users} checkpoint_file={cfg.checkpoint_file!r}",
        f"  skip_recency_days={cfg.skip_recency_days} ignore_recency={cfg.ignore_recency}",
        f"  start_after_id={cfg.start_after_id} dry_run={args.dry_run}",
    ]
    for line in banner:
        print(line, flush=True)

    return run_batch(cfg, dry_run=args.dry_run, fail_fast=args.fail_fast)


if __name__ == "__main__":
    raise SystemExit(main())
