"""
배치 스크립트:
shopprod_group.winner_* 가 shopprod_group_map 기준 "최적 winner"와 불일치하는 group_id를 찾고,
shopprod_group_winner_policy.refresh_group_winner_and_tokens_if_needed()로
winner 재선정 + winner 변경 시 content_token 정리/재생성까지 수행한다.

대상 쿼리(요청 쿼리 기반):
  - group_id별 price/grade/reg_date/icode 기준 1등(rn=1)을 계산
  - g.winner_* 와 불일치하면 처리 대상

사용 예)
  python batch_fix_group_winner_and_tokens.py
  python batch_fix_group_winner_and_tokens.py --limit 5000
  python batch_fix_group_winner_and_tokens.py --workers 4 --log-file log/fix_{worker_idx}.log
  python batch_fix_group_winner_and_tokens.py --dry-run
"""

from __future__ import annotations

import argparse
import asyncio
import csv
import logging
import os
import sys
import time
from dataclasses import dataclass
from datetime import datetime
import multiprocessing as mp

import asyncpg

from shopprod_group_winner_policy import refresh_group_winner_and_tokens_if_needed


from db_config import DB_INFO_ASYNCPG as DB_INFO


# Target-selection query (asyncpg positional parameters):
#   $1 = min_group_id  -- exclusive lower bound; drives checkpoint paging
#   $2 = total worker count (modulo shard size; $2 = 1 disables sharding)
#   $3 = this worker's shard index (matched against group_id % $2)
#   $4 = LIMIT, i.e. the batch size
# For each group_id the "best" map row is rn = 1, ranked by:
#   price (NULLs last), numeric vender_grade (non-numeric grades sort last
#   via the 2147483647 sentinel), reg_date (NULLs last), then icode.
# A group qualifies when any stored g.winner_* column IS DISTINCT FROM the
# corresponding best-row value (IS DISTINCT FROM treats NULLs as comparable).
# NOTE(review): this reads the *2 tables (shopprod_group2/shopprod_group_map2)
# while the module docstring names the un-suffixed tables — confirm intended.
QUERY_TMPL = """
SELECT
  g.group_id,
  r.vender_code AS best_vender_code,
  r.icode       AS best_icode,
  r.price       AS best_price
FROM mlinkdw.shopprod_group2 g
JOIN (
  SELECT
    group_id,
    vender_code,
    icode,
    price,
    ROW_NUMBER() OVER (
      PARTITION BY group_id
      ORDER BY
        price NULLS LAST,
        CASE
          WHEN vender_grade ~ '^[0-9]+$' THEN vender_grade::int
          ELSE 2147483647
        END,
        reg_date ASC NULLS LAST,
        icode ASC
    ) rn
  FROM mlinkdw.shopprod_group_map2
) r ON r.group_id = g.group_id AND r.rn = 1
WHERE
  (g.winner_vender_code IS DISTINCT FROM r.vender_code
   OR g.winner_icode    IS DISTINCT FROM r.icode
   OR g.winner_price    IS DISTINCT FROM r.price)
  AND g.group_id > $1
  AND ($2::int = 1 OR (g.group_id % $2::int) = $3::int)
ORDER BY g.group_id
LIMIT $4;
""".strip()


def _setup_logger(log_file: str | None, *, name: str, force: bool = False):
    logger = logging.getLogger(str(name))
    if force and logger.handlers:
        for h in list(logger.handlers):
            try:
                h.flush()
                h.close()
            except Exception:
                pass
            try:
                logger.removeHandler(h)
            except Exception:
                pass
    if (not force) and logger.handlers:
        return logger
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    sh = logging.StreamHandler(stream=sys.stdout)
    sh.setFormatter(fmt)
    logger.addHandler(sh)

    lf = (str(log_file or "")).strip()
    if lf:
        try:
            os.makedirs(os.path.dirname(os.path.abspath(lf)) or ".", exist_ok=True)
        except Exception:
            pass
        fh = logging.FileHandler(lf, encoding="utf-8")
        fh.setFormatter(fmt)
        logger.addHandler(fh)

    return logger


def _with_worker_suffix(path: str, *, worker_idx: int) -> str:
    p = str(path or "").strip()
    if not p:
        return p
    try:
        if ("{worker_idx}" in p) or ("{pid}" in p):
            return p.format(worker_idx=int(worker_idx), pid=os.getpid())
    except Exception:
        pass
    root, ext = os.path.splitext(p)
    return f"{root}.w{int(worker_idx)}{ext}"


def _resolve_worker_path(path: str, *, worker_idx: int) -> str:
    p = str(path or "").strip()
    if not p:
        return p
    try:
        if ("{worker_idx}" in p) or ("{pid}" in p):
            return p.format(worker_idx=int(worker_idx), pid=os.getpid())
    except Exception:
        pass
    return p


def _read_checkpoint(path: str) -> int:
    try:
        if not path or (not os.path.exists(path)):
            return 0
        with open(path, "r", encoding="utf-8") as fp:
            s = (fp.read() or "").strip()
        return int(s) if s else 0
    except Exception:
        return 0


def _write_checkpoint(path: str, value: int):
    if not path:
        return
    try:
        os.makedirs(os.path.dirname(os.path.abspath(path)) or ".", exist_ok=True)
    except Exception:
        pass
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as fp:
        fp.write(str(int(value)))
        fp.write("\n")
    try:
        os.replace(tmp, path)
    except Exception:
        try:
            with open(path, "w", encoding="utf-8") as fp:
                fp.write(str(int(value)))
                fp.write("\n")
        except Exception:
            pass


async def _run_worker_async(args, *, worker_idx: int, workers: int) -> int:
    """Run one worker's winner-fix loop over its shard of group_ids.

    Per batch:
      1. Fetch up to ``args.limit`` mismatched groups via QUERY_TMPL,
         sharded by ``group_id % workers == worker_idx``.
      2. Call refresh_group_winner_and_tokens_if_needed() for each group
         inside a per-group savepoint, so one failing group rolls back
         alone while the batch-level transaction survives; failures are
         appended to a CSV.
      3. After the batch commits, persist the highest seen group_id to a
         checkpoint file so reruns resume past already-processed groups.

    Returns 0 when every processed group succeeded, otherwise 2.
    """
    # Per-worker log destination: expand {worker_idx}/{pid} placeholders,
    # else append a .wN suffix when running with multiple workers.
    log_file = _resolve_worker_path((args.log_file or "").strip(), worker_idx=worker_idx)
    if log_file and int(workers) > 1 and ("{worker_idx}" not in log_file) and ("{pid}" not in log_file):
        log_file = _with_worker_suffix(log_file, worker_idx=worker_idx)
    log = _setup_logger(log_file, name=f"batch_fix_winner.w{worker_idx}", force=True)

    # Checkpoint: resume point is max(--min-group-id, persisted checkpoint).
    checkpoint_path = (args.checkpoint_file or "").strip()
    if not checkpoint_path:
        checkpoint_path = os.path.join(str(args.checkpoint_dir), f"checkpoint_fix_winner_w{int(worker_idx)}.txt")
    checkpoint_path = _resolve_worker_path(checkpoint_path, worker_idx=worker_idx)
    if checkpoint_path and int(workers) > 1 and ("{worker_idx}" not in checkpoint_path) and ("{pid}" not in checkpoint_path):
        checkpoint_path = _with_worker_suffix(checkpoint_path, worker_idx=worker_idx)
    checkpoint_path = os.path.abspath(checkpoint_path) if checkpoint_path else ""
    min_group_id = max(int(getattr(args, "min_group_id", 0) or 0), _read_checkpoint(checkpoint_path))

    # Failure CSV: opened in append mode; the header is written only when
    # the file is new or empty so repeated runs keep a single header row.
    fail_csv_path = (args.fail_csv or "").strip()
    if fail_csv_path and int(workers) > 1:
        fail_csv_path = _with_worker_suffix(fail_csv_path, worker_idx=worker_idx)
    if not fail_csv_path:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        fail_csv_path = os.path.join(str(args.fail_csv_dir), f"fail_fix_winner_{ts}_w{int(worker_idx)}.csv")
    fail_csv_path = os.path.abspath(fail_csv_path)
    fail_csv_header = ["ts", "batch_idx", "idx_in_batch", "group_id", "best_vender_code", "best_icode", "best_price", "error"]
    try:
        os.makedirs(os.path.dirname(fail_csv_path), exist_ok=True)
    except Exception:
        pass
    need_header = True
    try:
        if os.path.exists(fail_csv_path) and os.path.getsize(fail_csv_path) > 0:
            need_header = False
    except Exception:
        need_header = True
    # utf-8-sig (BOM) so the CSV opens cleanly in Excel.
    fail_fp = open(fail_csv_path, "a", newline="", encoding="utf-8-sig")
    fail_writer = csv.DictWriter(fail_fp, fieldnames=fail_csv_header)
    if need_header:
        fail_writer.writeheader()
        fail_fp.flush()

    log.info(
        f"worker={worker_idx}/{workers} pid={os.getpid()} "
        f"fail-csv={fail_csv_path} checkpoint={checkpoint_path} min_group_id={min_group_id}"
    )

    conn: asyncpg.Connection | None = None
    total_ok = 0
    total_fail = 0
    batch_no_progress = 0  # consecutive batches with changed == 0
    batch_idx = 0
    t0 = time.perf_counter()

    async def _ensure_conn():
        # Lazily (re)connect; the connection is reused across batches
        # unless it was closed or dropped in the meantime.
        nonlocal conn
        if conn is not None and not conn.is_closed():
            return
        conn = await asyncpg.connect(**DB_INFO)

    try:
        while True:
            # Optional hard stops, checked before fetching the next batch.
            if args.max_batches and batch_idx >= int(args.max_batches):
                log.info(f"[STOP] max-batches reached: {args.max_batches}")
                break
            if args.max_total and total_ok >= int(args.max_total):
                log.info(f"[STOP] max-total reached: {args.max_total}")
                break

            batch_idx += 1
            await _ensure_conn()
            assert conn is not None

            # $1=min_group_id, $2=shard modulus, $3=shard index, $4=limit
            rows = await conn.fetch(QUERY_TMPL, int(min_group_id), int(workers), int(worker_idx), int(args.limit))
            if not rows:
                log.info("[DONE] no more rows")
                break

            log.info(f"[BATCH {batch_idx}] rows={len(rows)} (limit={args.limit}) min_group_id={min_group_id}")

            batch_ok = 0
            batch_fail = 0
            batch_changed = 0
            batch_max_gid = min_group_id

            if args.dry_run:
                # Dry run: list the first batch's targets, touch nothing, stop.
                for i, r in enumerate(rows, start=1):
                    gid = r.get("group_id")
                    log.info(
                        f"[DRY] {i}/{len(rows)} group_id={gid} "
                        f"best={r.get('best_vender_code')}/{r.get('best_icode')} price={r.get('best_price')}"
                    )
                log.info("[STOP] dry-run mode")
                break

            # Batch-level commit + per-group savepoint (nested transaction):
            # a failing group rolls back only its own savepoint.
            async with conn.transaction():
                for i, r in enumerate(rows, start=1):
                    gid = r.get("group_id")
                    try:
                        if gid is not None:
                            batch_max_gid = max(int(batch_max_gid), int(gid))
                    except Exception:
                        pass

                    try:
                        async with conn.transaction():
                            changed = await refresh_group_winner_and_tokens_if_needed(conn, int(gid))
                        batch_ok += 1
                        total_ok += 1
                        if changed:
                            batch_changed += 1
                        # Log the first 20 rows, then every 50th, to bound noise.
                        if i <= 20 or (i % 50 == 0):
                            log.info(f"[OK] {i}/{len(rows)} group_id={gid} changed={bool(changed)}")
                    except Exception as e:
                        batch_fail += 1
                        total_fail += 1
                        log.error(f"[FAIL] {i}/{len(rows)} group_id={gid}: {e}")
                        # Best effort: CSV bookkeeping must never abort the batch.
                        try:
                            fail_writer.writerow(
                                dict(
                                    ts=datetime.now().isoformat(timespec="seconds"),
                                    batch_idx=batch_idx,
                                    idx_in_batch=i,
                                    group_id=gid,
                                    best_vender_code=r.get("best_vender_code"),
                                    best_icode=r.get("best_icode"),
                                    best_price=r.get("best_price"),
                                    error=str(e),
                                )
                            )
                            fail_fp.flush()
                        except Exception:
                            pass

                    if args.max_total and total_ok >= int(args.max_total):
                        break
                    if args.sleep and args.sleep > 0:
                        await asyncio.sleep(float(args.sleep))

            # Advance the checkpoint only after the batch transaction committed.
            try:
                if int(batch_max_gid) > int(min_group_id):
                    min_group_id = int(batch_max_gid)
                    _write_checkpoint(checkpoint_path, int(min_group_id))
                    log.info(f"[CHECKPOINT] min_group_id={min_group_id}")
            except Exception:
                pass

            if batch_changed == 0:
                batch_no_progress += 1
            else:
                batch_no_progress = 0

            dt = time.perf_counter() - t0
            log.info(
                f"[BATCH {batch_idx} DONE] ok={batch_ok} fail={batch_fail} changed={batch_changed} "
                f"total_ok={total_ok} total_fail={total_fail} elapsed={dt:.1f}s"
            )

            # Infinite-loop guard: stop after N consecutive changed=0 batches.
            if args.stop_after_no_progress and batch_no_progress >= int(args.stop_after_no_progress):
                log.info(
                    f"[STOP] no progress for {batch_no_progress} consecutive batches "
                    f"(stop-after-no-progress={args.stop_after_no_progress})"
                )
                break

    finally:
        # Always release the connection and the CSV handle, best effort.
        try:
            if conn is not None and not conn.is_closed():
                await conn.close()
        except Exception:
            pass
        try:
            fail_fp.close()
        except Exception:
            pass

    return 0 if total_fail == 0 else 2


def _run_worker(args, *, worker_idx: int, workers: int) -> int:
    """Synchronous wrapper: drive one worker's async loop to completion."""
    coro = _run_worker_async(args, worker_idx=worker_idx, workers=workers)
    return asyncio.run(coro)


def _worker_entry(args, worker_idx: int, workers: int) -> None:
    """Child-process entry: run one worker and exit with its status code.

    Must be a picklable module-level function: under the 'spawn'
    multiprocessing start method (the default on Windows and macOS) the
    Process target is pickled, which a lambda/closure cannot be.
    """
    sys.exit(_run_worker(args, worker_idx=worker_idx, workers=workers))


def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    With ``--workers`` <= 1 (or an explicit ``--worker-idx``) the batch
    runs in this process; otherwise N child processes are spawned, each
    handling the shard ``group_id % workers == worker_idx``.

    Returns 0 on success, 2 if any worker exited nonzero, 130 when
    interrupted (Ctrl-C) while supervising workers.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--limit", type=int, default=1500, help="배치 1회당 처리할 최대 group 수")
    ap.add_argument("--sleep", type=float, default=0.0, help="각 처리 사이 sleep(초)")
    ap.add_argument("--dry-run", action="store_true", help="DB 수정 없이 대상만 출력")
    ap.add_argument("--max-batches", type=int, default=0, help="0이면 무제한, 그 외 배치 반복 횟수 제한")
    ap.add_argument("--max-total", type=int, default=0, help="0이면 무제한, 그 외 총 OK 처리 건수 제한")
    ap.add_argument(
        "--stop-after-no-progress",
        type=int,
        default=3,
        help="changed=0인 배치가 연속 N회면 중단(무한루프 방지)",
    )
    ap.add_argument("--min-group-id", type=int, default=0, help="WHERE g.group_id > min_group_id 시작값(체크포인트 없을 때)")
    ap.add_argument("--workers", type=int, default=1, help="병렬 워커 프로세스 수(기본 1)")
    ap.add_argument("--worker-idx", type=int, default=-1, help="내부용(직접 실행 시 워커 인덱스)")
    ap.add_argument(
        "--log-file",
        type=str,
        default="",
        help="실행 로그 파일 경로. 예: log/fix_{worker_idx}.log (병렬이면 placeholder 권장, 없으면 .w{idx} 자동 추가)",
    )
    ap.add_argument(
        "--fail-csv",
        type=str,
        default="",
        help="실패 항목 CSV 경로(비우면 자동 파일명). 예: log/fail_{worker_idx}.csv (병렬이면 .w{idx} 자동 추가)",
    )
    ap.add_argument("--fail-csv-dir", type=str, default=".", help="--fail-csv가 비어있을 때 자동 생성 파일 저장 폴더")
    ap.add_argument("--checkpoint-file", type=str, default="", help="체크포인트 파일 경로(워커별 권장: log/ckpt_{worker_idx}.txt)")
    ap.add_argument("--checkpoint-dir", type=str, default=".", help="--checkpoint-file이 비어있을 때 자동 생성 파일 저장 폴더")
    args = ap.parse_args(argv)

    workers = max(1, int(args.workers or 1))
    if workers == 1 or int(args.worker_idx) >= 0:
        # Single-process mode, or this process *is* an explicitly indexed worker.
        wi = int(args.worker_idx) if int(args.worker_idx) >= 0 else 0
        return _run_worker(args, worker_idx=wi, workers=workers)

    parent_log = _setup_logger("", name="batch_fix_winner.parent", force=True)
    parent_log.info(f"starting workers={workers}")
    procs: list[mp.Process] = []
    exit_codes: dict[int, int] = {}

    try:
        for wi in range(workers):
            # FIX: pass a module-level function with explicit args instead of
            # a lambda. A lambda target cannot be pickled under the 'spawn'
            # start method, and the previous closure captured the loop
            # variable `wi`, relying on fork-time memory snapshots.
            p = mp.Process(target=_worker_entry, args=(args, wi, workers))
            p.daemon = False
            p.start()
            procs.append(p)

        for p in procs:
            p.join()
            exit_codes[p.pid or -1] = int(p.exitcode or 0)
    except KeyboardInterrupt:
        # Forcefully stop children, then give each a short grace period.
        parent_log.info("KeyboardInterrupt: terminating workers...")
        for p in procs:
            try:
                p.terminate()
            except Exception:
                pass
        for p in procs:
            try:
                p.join(timeout=5)
            except Exception:
                pass
        return 130

    bad = [c for c in exit_codes.values() if c != 0]
    parent_log.info(f"workers done. nonzero={len(bad)}")
    return 0 if not bad else 2


if __name__ == "__main__":
    # Propagate main()'s return value to the shell as the exit code.
    raise SystemExit(main(sys.argv[1:]))

