
from __future__ import annotations

import argparse
import json
import logging
import sys
import time
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

load_dotenv()

from scheduler_config import BatchSchedulerConfig
from batch_scheduled_profiling import process_keyset_page_users
from core.db_connection import db_session, verify_database_connectivity
from core.logger_config import setup_logging
from profiling_orchestrator import ProfilingOrchestrator
from storage.user_profiling import (
    ensure_profiling_storage_tables,
    publish_profiling_staging,
    truncate_profiling_staging,
)

logger = logging.getLogger(__name__)

# Log file for this script only: receives INFO and above and is truncated at the
# start of each run (the file handler opens it in "w" mode). It lives next to this
# script. Its format is meant to mirror the console output — confirm against
# setup_logging() if that format changes.
TESTING_LOG_FILE = Path(__file__).resolve().with_name("batch_scheduled_testing.log")


def _attach_testing_info_log_file() -> None:
    root = logging.getLogger()
    fmt = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    fh = logging.FileHandler(TESTING_LOG_FILE, mode="w", encoding="utf-8")
    fh.setLevel(logging.INFO)
    fh.setFormatter(fmt)
    root.addHandler(fh)


# Manual smoke-test knobs (edit here): SMOKE_PAGES keyset pages of
# SMOKE_PAGE_SIZE users each, pausing SMOKE_PAUSE_SECONDS between pages.
SMOKE_PAGES = 3
SMOKE_PAGE_SIZE = 5
SMOKE_PAUSE_SECONDS = 10.0

# Only fetch user ids that have ≥1 row in dbo.user_activity_logs within this many
# days (per the original note, same as the production batch). This avoids scanning
# the whole users table and keeps deep analysis (when enabled) off inactive accounts.
SMOKE_RECENT_ACTIVITY_DAYS = 90

# Users processed in parallel within one keyset page (1 = legacy sequential).
SMOKE_BATCH_MAX_WORKERS = 1

SMOKE_TEST_CFG = BatchSchedulerConfig(
    # Paging, pacing and parallelism.
    batch_size=SMOKE_PAGE_SIZE,
    max_users=None,
    sleep_between_batches=SMOKE_PAUSE_SECONDS,
    sleep_between_users=0.0,
    batch_max_workers=SMOKE_BATCH_MAX_WORKERS,
    # Starting point: no checkpoint file, so every run starts at id 0
    # (run_batch truncates staging in that case unless --dry-run is given).
    start_after_id=0,
    checkpoint_file=None,
    # Which users are selected.
    batch_require_recent_activity=True,
    recent_activity_days=SMOKE_RECENT_ACTIVITY_DAYS,
    ignore_recency=False,
    # Profiling scope and output.
    window_days=60,
    with_deep_analysis=True,
    skip_publish=True,
)


def _read_checkpoint(path: Path) -> int:
    if not path.exists():
        return 0
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
        return int(data.get("last_success_user_id", 0) or 0)
    except (json.JSONDecodeError, OSError, TypeError, ValueError):
        logger.warning("Invalid checkpoint %s; starting from 0", path)
        return 0


def _write_checkpoint(path: Path, last_success_user_id: int) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps({"last_success_user_id": last_success_user_id}, indent=2),
        encoding="utf-8",
    )

def _publish_if_eligible(
    cfg: BatchSchedulerConfig,
    *,
    natural_scan_complete: bool,
    dry_run: bool,
    fail_count: int,
) -> bool:
    """Publish staging to user_profiling when the run qualifies.

    Conditions are checked in this order:

    1. the scan reached the end of the user ids;
    2. this is not a dry run;
    3. ``cfg.skip_publish`` is off;
    4. there were no failures, or ``cfg.publish_with_failures`` is set.

    The first unmet condition is the one logged as the skip reason.

    Returns:
        False only if publishing was attempted and raised; True otherwise
        (published or deliberately skipped).
    """
    if not natural_scan_complete:
        logger.info("Skipping publish: scan stopped before reaching the end of the user ids.")
        return True
    if dry_run:
        logger.info("Skipping publish (dry run).")
        return True
    if cfg.skip_publish:
        logger.info("Skipping publish (skip_publish in scheduler_config). Staging left populated for inspection.")
        return True
    if fail_count > 0 and not cfg.publish_with_failures:
        logger.warning(
            "Skipping publish because fail_count=%s (set publish_with_failures in scheduler_config to override).",
            fail_count,
        )
        return True
    try:
        publish_profiling_staging(db_session)
    except Exception:
        logger.exception("Publish failed; published table unchanged (rolled back).")
        return False
    logger.info("Published staging → user_profiling (single transaction).")
    return True


def run_batch(
    cfg: BatchSchedulerConfig,
    *,
    dry_run: bool = False,
    fail_fast: bool = False,
    max_pages: Optional[int] = None,
) -> int:
    """Run keyset-paginated batch profiling, then publish staging if eligible.

    Setup:
        Calls ``setup_logging()``, attaches the INFO file log, verifies DB
        connectivity and ensures the profiling tables exist.

    Starting point:
        The scan starts after ``cfg.start_after_id``. When that is 0 and
        ``cfg.checkpoint_file`` is set, the checkpoint's value is used instead.
        A non-dry run that starts from id 0 truncates the staging table first.

    Main loop:
        Fetches pages of up to ``cfg.batch_size`` ids after ``last_id`` and
        hands each page to ``process_keyset_page_users``.

    Stopping:
        - no more ids are returned (a natural full scan);
        - ``cfg.max_users`` successes are reached;
        - ``max_pages`` pages have been processed.

        Only a natural full scan can lead to publishing.

    Args:
        cfg: Scheduler settings.
        dry_run: Forwarded to ``process_keyset_page_users``; here it also skips
            the staging truncate and the publish step.
        fail_fast: Forwarded to ``process_keyset_page_users``.
        max_pages: Optional cap on the number of pages processed (testing aid).

    Returns:
        130 if interrupted with Ctrl+C; 1 if publishing failed or any user
        failed; otherwise 0.
    """
    setup_logging()
    _attach_testing_info_log_file()
    logger.info("Writing INFO+ logs to %s", TESTING_LOG_FILE)
    verify_database_connectivity()
    ensure_profiling_storage_tables(db_session)
    orchestrator = ProfilingOrchestrator(db_session)
    fetcher = orchestrator.fetcher

    checkpoint_path: Optional[Path] = (
        Path(cfg.checkpoint_file).resolve() if cfg.checkpoint_file else None
    )
    last_id = cfg.start_after_id
    if last_id == 0 and checkpoint_path is not None:
        last_id = _read_checkpoint(checkpoint_path)

    success_count = 0
    fail_count = 0
    skipped_recency_count = 0
    page_num = 0
    recency_days = 0 if cfg.ignore_recency else cfg.skip_recency_days
    natural_scan_complete = False

    if not dry_run and last_id == 0:
        truncate_profiling_staging(db_session)
        logger.info("Truncated user_profiling_staging (fresh batch from user id 0).")

    logger.info(
        "Batch profiling start: batch_size=%s batch_max_workers=%s window_days=%s deep=%s last_id=%s "
        "max_users=%s recency_skip_days=%s activity_filter=%s recent_activity_days=%s max_pages=%s",
        cfg.batch_size,
        cfg.batch_max_workers,
        cfg.window_days,
        cfg.with_deep_analysis,
        last_id,
        cfg.max_users,
        recency_days if recency_days > 0 else "off",
        cfg.batch_require_recent_activity,
        cfg.recent_activity_days,
        max_pages,
    )

    try:
        while True:
            if cfg.max_users is not None and success_count >= cfg.max_users:
                logger.info("Stopping: reached max_users=%s", cfg.max_users)
                break

            ids = fetcher.fetch_user_ids_keyset(
                last_id,
                cfg.batch_size,
                recent_activity_days=(
                    cfg.recent_activity_days if cfg.batch_require_recent_activity else None
                ),
            )
            if not ids:
                natural_scan_complete = True
                logger.info("No more user ids after id>%s — full scan complete.", last_id)
                break

            page_num += 1
            logger.info("Page %s: fetched %s ids (>%s .. %s)", page_num, len(ids), last_id, ids[-1])

            success_count, fail_count, skipped_recency_count, stop_max = process_keyset_page_users(
                ids,
                cfg,
                dry_run=dry_run,
                fail_fast=fail_fast,
                recency_days=recency_days,
                checkpoint_path=checkpoint_path,
                success_count=success_count,
                fail_count=fail_count,
                skipped_recency_count=skipped_recency_count,
            )
            if stop_max:
                logger.info("Stopping: reached max_users=%s", cfg.max_users)
                break

            last_id = ids[-1]

            # Check the page cap *before* pausing, so a capped run exits right
            # after its last page instead of sleeping sleep_between_batches for nothing.
            if max_pages is not None and page_num >= max_pages:
                logger.info("Stopping: reached max_pages=%s (testing limit).", max_pages)
                break

            time.sleep(cfg.sleep_between_batches)

    except KeyboardInterrupt:
        logger.warning("Interrupted by user")
        return 130

    logger.info(
        "Batch finished: success=%s failed=%s skipped_recency=%s full_scan=%s",
        success_count,
        fail_count,
        skipped_recency_count,
        natural_scan_complete,
    )

    published_ok = _publish_if_eligible(
        cfg,
        natural_scan_complete=natural_scan_complete,
        dry_run=dry_run,
        fail_count=fail_count,
    )
    if not published_ok:
        return 1

    return 1 if fail_count > 0 else 0


def build_arg_parser() -> argparse.ArgumentParser:
    """Create the parser for this smoke-test script with only its help text set.

    No options are registered here; ``main`` adds ``--dry-run`` and ``--fail-fast``.
    """
    summary = "Manual batch test: fixed pages/users/pause in this file (SMOKE_*), not BATCH_SCHEDULER."
    usage_example = "Example: python batch_scheduled_testing.py --dry-run"
    parser = argparse.ArgumentParser(description=summary, epilog=usage_example)
    return parser


def main(argv: Optional[list] = None) -> int:
    """CLI entry point for the manual smoke test.

    Steps:
        1. Parse ``--dry-run`` / ``--fail-fast``.
        2. Sanity-check ``SMOKE_TEST_CFG``.
        3. Print a one-line summary to stderr.
        4. Run at most ``SMOKE_PAGES`` keyset pages.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        2 if the config is invalid, otherwise ``run_batch``'s exit code.
    """
    parser = build_arg_parser()
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only fetch ids and log; no profiling or DB writes (overrides for this run).",
    )
    parser.add_argument(
        "--fail-fast",
        action="store_true",
        help="Exit on first user failure (overrides for this run).",
    )
    args = parser.parse_args(argv)

    cfg = SMOKE_TEST_CFG
    # (field name, value, minimum allowed), checked in order; the first
    # violation is reported. sleep_between_batches is validated up front because
    # time.sleep() raises ValueError on a negative value, which would otherwise
    # only surface after the first page had already been processed.
    checks = (
        ("start_after_id", cfg.start_after_id, 0),
        ("batch_size", cfg.batch_size, 1),
        ("skip_recency_days", cfg.skip_recency_days, 0),
        ("batch_max_workers", cfg.batch_max_workers, 1),
        ("sleep_between_batches", cfg.sleep_between_batches, 0),
    )
    for field_name, value, minimum in checks:
        if value < minimum:
            print(
                f"batch_scheduled_testing: {field_name} must be >= {minimum}",
                file=sys.stderr,
            )
            return 2

    # Describe the run from cfg itself so the banner cannot drift from what runs.
    print(
        f"Testing run: {SMOKE_PAGES} pages × {cfg.batch_size} users, "
        f"batch_max_workers={cfg.batch_max_workers}, "
        f"{cfg.sleep_between_batches}s between pages, skip_publish={cfg.skip_publish}",
        file=sys.stderr,
    )
    return run_batch(
        cfg,
        dry_run=args.dry_run,
        fail_fast=args.fail_fast,
        max_pages=SMOKE_PAGES,
    )


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
