"""
배치 스크립트: winner 대비 과도하게 비싼(> winner_price * ratio) 항목을 찾아
SearchGroupPositionUI_FullIntegrated.MainWindow._move_one_item_to_winner_group() 로직으로
"자기 winner 그룹(없으면 생성/매칭)"으로 분리 이동한다.

사용 예)
  python batch_move_overpriced_to_winner_group.py
  python batch_move_overpriced_to_winner_group.py --limit 500 --ratio 1.3
  python batch_move_overpriced_to_winner_group.py --dry-run

주의
- DB 접속 정보는 `SearchGroupPositionUI_FullIntegrated.DB_INFO`를 그대로 사용한다.
- 이 스크립트는 Qt(QApplication) 없이 동작하도록, MainWindow 인스턴스를 만들지 않고
  unbound method를 dummy self에 바인딩하여 호출한다.
"""

from __future__ import annotations

import argparse
import csv
import logging
import os
import sys
import time
from dataclasses import dataclass
from datetime import datetime
import multiprocessing as mp

import psycopg2
from psycopg2.extras import DictCursor

import SearchGroupPositionUI_FullIntegrated as ui


@dataclass
class _Dummy:
    """Minimal stand-in for ``ui.MainWindow``.

    The unbound ``_move_one_item_to_winner_group`` method is called with an
    instance of this class as ``self`` so that no Qt/QApplication objects are
    needed; per the module docstring the method is expected to only read
    ``self.conn`` (TODO confirm against MainWindow's implementation).
    """

    # Live psycopg2 connection used by the mover method.
    conn: "psycopg2.extensions.connection"


# Candidate selection query: group members priced above winner_price * ratio,
# excluding the winner row itself.
# - ``%%`` is an escaped literal ``%`` (SQL modulo), so
#   ``(m.group_id %% %(workers)s) = %(worker_idx)s`` partitions groups across
#   parallel workers; ``%(workers)s = 1`` short-circuits that filter when
#   running single-worker.
# - ``m.group_id > %(min_group_id)s`` implements checkpoint-based resume.
# - Ordered by (group_id, vender_code, icode) so progress is monotonic in
#   group_id and LIMIT defines one batch.
QUERY_TMPL = """
SELECT
    m.group_id,
    m.icode,
    m.img_url,
    m.iname,
    m.price,
    m.vender_code
 FROM mlinkdw.shopprod_group_map2 m
 JOIN mlinkdw.shopprod_group2 g
   ON g.group_id = m.group_id
WHERE g.winner_price IS NOT NULL
  AND m.price > g.winner_price * %(ratio)s
  AND NOT (m.vender_code = g.winner_vender_code AND m.icode = g.winner_icode)
  AND m.group_id > %(min_group_id)s
  AND (%(workers)s = 1 OR (m.group_id %% %(workers)s) = %(worker_idx)s)
ORDER BY m.group_id, m.vender_code, m.icode
LIMIT %(limit)s;
""".strip()


def _setup_logger(log_file: str | None, *, name: str = "batch_move_overpriced", force: bool = False):
    """
    - log_file이 비어있으면: 기존처럼 stdout만 사용
    - log_file이 있으면: stdout + 파일 동시 기록
    """
    logger = logging.getLogger(str(name))
    if force and logger.handlers:
        for h in list(logger.handlers):
            try:
                h.flush()
                h.close()
            except Exception:
                pass
            try:
                logger.removeHandler(h)
            except Exception:
                pass
    if (not force) and logger.handlers:
        return logger
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    sh = logging.StreamHandler(stream=sys.stdout)
    sh.setFormatter(fmt)
    logger.addHandler(sh)

    lf = (str(log_file or "")).strip()
    if lf:
        try:
            os.makedirs(os.path.dirname(os.path.abspath(lf)) or ".", exist_ok=True)
        except Exception:
            pass
        fh = logging.FileHandler(lf, encoding="utf-8")
        fh.setFormatter(fmt)
        logger.addHandler(fh)

    return logger


def _with_worker_suffix(path: str, *, worker_idx: int) -> str:
    """
    병렬 실행 시 파일 충돌 방지:
    - 경로에 {worker_idx}/{pid} placeholder가 있으면 format 적용
    - 없으면 확장자 앞에 .w{idx} 추가
    """
    p = str(path or "").strip()
    if not p:
        return p
    try:
        if ("{worker_idx}" in p) or ("{pid}" in p):
            return p.format(worker_idx=int(worker_idx), pid=os.getpid())
    except Exception:
        pass
    root, ext = os.path.splitext(p)
    return f"{root}.w{int(worker_idx)}{ext}"

def _resolve_worker_path(path: str, *, worker_idx: int) -> str:
    """
    워커별 파일 경로 해석:
    - {worker_idx}/{pid} placeholder가 있으면 format 적용
    - 없으면 원본 그대로 반환(사용자가 '워커별 로그를 명시적으로 지정'할 수 있게)
    """
    p = str(path or "").strip()
    if not p:
        return p
    try:
        if ("{worker_idx}" in p) or ("{pid}" in p):
            return p.format(worker_idx=int(worker_idx), pid=os.getpid())
    except Exception:
        pass
    return p

def _read_checkpoint(path: str) -> int:
    try:
        if not path or (not os.path.exists(path)):
            return 0
        with open(path, "r", encoding="utf-8") as fp:
            s = (fp.read() or "").strip()
        return int(s) if s else 0
    except Exception:
        return 0

def _write_checkpoint(path: str, value: int):
    if not path:
        return
    try:
        os.makedirs(os.path.dirname(os.path.abspath(path)) or ".", exist_ok=True)
    except Exception:
        pass
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as fp:
        fp.write(str(int(value)))
        fp.write("\n")
    try:
        os.replace(tmp, path)
    except Exception:
        # best-effort
        try:
            with open(path, "w", encoding="utf-8") as fp:
                fp.write(str(int(value)))
                fp.write("\n")
        except Exception:
            pass


def _run_worker(args, *, worker_idx: int, workers: int) -> int:
    """
    Run one worker: loop over batches of overpriced rows (until the query is
    exhausted or a --max-* limit is hit) and move each row into its winner
    group via the unbound MainWindow method.

    Returns 0 when every attempted move succeeded, 2 if any move failed.
    """
    # Resolve per-worker log/output file paths to avoid collisions.
    log_file = _resolve_worker_path((args.log_file or "").strip(), worker_idx=worker_idx)
    # Backward compatibility: with no placeholder and workers > 1, fall back
    # to the automatic ``.w<idx>`` suffix.
    if log_file and int(workers) > 1 and ("{worker_idx}" not in log_file) and ("{pid}" not in log_file):
        log_file = _with_worker_suffix(log_file, worker_idx=worker_idx)
    log = _setup_logger(log_file, name=f"batch_move_overpriced.w{worker_idx}", force=True)

    # Prepare the unbound method; it is later called with a _Dummy as ``self``.
    mover = ui.MainWindow._move_one_item_to_winner_group

    conn = None
    dummy = None
    total_ok = 0
    total_fail = 0
    batch_no_progress = 0
    batch_idx = 0
    t0 = time.perf_counter()

    # Checkpoint (resume) setup: each worker persists its own group_id progress.
    checkpoint_path = (args.checkpoint_file or "").strip()
    if not checkpoint_path:
        checkpoint_path = os.path.join(str(args.checkpoint_dir), f"checkpoint_overpriced_w{int(worker_idx)}.txt")
    checkpoint_path = _resolve_worker_path(checkpoint_path, worker_idx=worker_idx)
    if checkpoint_path and int(workers) > 1 and ("{worker_idx}" not in checkpoint_path) and ("{pid}" not in checkpoint_path):
        checkpoint_path = _with_worker_suffix(checkpoint_path, worker_idx=worker_idx)
    checkpoint_path = os.path.abspath(checkpoint_path) if checkpoint_path else ""
    # Resume from whichever is further along: the CLI floor or the saved checkpoint.
    min_group_id = max(int(getattr(args, "min_group_id", 0) or 0), _read_checkpoint(checkpoint_path))

    # Failure-CSV setup (appends; header only written for a fresh/empty file).
    fail_csv_path = (args.fail_csv or "").strip()
    if fail_csv_path and int(workers) > 1:
        fail_csv_path = _with_worker_suffix(fail_csv_path, worker_idx=worker_idx)
    if not fail_csv_path:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        fail_csv_path = os.path.join(str(args.fail_csv_dir), f"fail_overpriced_{ts}_w{int(worker_idx)}.csv")
    fail_csv_path = os.path.abspath(fail_csv_path)
    fail_csv_header = [
        "ts",
        "batch_idx",
        "idx_in_batch",
        "group_id",
        "vender_code",
        "icode",
        "price",
        "iname",
        "img_url",
        "error",
    ]
    try:
        os.makedirs(os.path.dirname(fail_csv_path), exist_ok=True)
    except Exception:
        pass
    fail_csv_needs_header = True
    try:
        if os.path.exists(fail_csv_path) and os.path.getsize(fail_csv_path) > 0:
            fail_csv_needs_header = False
    except Exception:
        fail_csv_needs_header = True
    fail_csv_fp = open(fail_csv_path, "a", newline="", encoding="utf-8-sig")
    fail_writer = csv.DictWriter(fail_csv_fp, fieldnames=fail_csv_header)
    if fail_csv_needs_header:
        fail_writer.writeheader()
        fail_csv_fp.flush()
    log.info(
        f"worker={worker_idx}/{workers} pid={os.getpid()} "
        f"fail-csv={fail_csv_path} checkpoint={checkpoint_path} min_group_id={min_group_id}"
    )

    def _ensure_conn():
        # Lazily (re)connect; psycopg2 connections expose ``closed == 0`` while open.
        nonlocal conn, dummy
        try:
            if conn is not None and getattr(conn, "closed", 1) == 0:
                return
        except Exception:
            pass
        conn = psycopg2.connect(**ui.DB_INFO)
        conn.autocommit = False
        dummy = _Dummy(conn=conn)

    try:
        while True:
            if args.max_batches and batch_idx >= int(args.max_batches):
                log.info(f"[STOP] max-batches reached: {args.max_batches}")
                break
            if args.max_total and total_ok >= int(args.max_total):
                log.info(f"[STOP] max-total reached: {args.max_total}")
                break

            batch_idx += 1
            _ensure_conn()

            with conn.cursor(cursor_factory=DictCursor) as cur:
                cur.execute(
                    QUERY_TMPL,
                    {
                        "ratio": float(args.ratio),
                        "limit": int(args.limit),
                        "min_group_id": int(min_group_id),
                        "workers": int(workers),
                        "worker_idx": int(worker_idx),
                    },
                )
                rows = cur.fetchall()

            if not rows:
                log.info("[DONE] no more rows")
                break

            log.info(f"[BATCH {batch_idx}] rows={len(rows)} (limit={args.limit}, ratio={args.ratio})")

            batch_ok = 0
            batch_fail = 0
            batch_max_gid = min_group_id
            # Commit once per batch (default batch size 1500 rows).
            with conn.cursor() as txcur:
                txcur.execute("BEGIN")
            for i, r in enumerate(rows, start=1):
                vender_code = str(r.get("vender_code") or "")
                icode = str(r.get("icode") or "")
                iname = str(r.get("iname") or "")
                price = r.get("price")
                img_url = r.get("img_url")
                group_id = r.get("group_id")
                try:
                    # Track the highest group_id seen; used as the checkpoint
                    # value after a successful commit.
                    if group_id is not None:
                        batch_max_gid = max(int(batch_max_gid), int(group_id))
                except Exception:
                    pass

                if args.dry_run:
                    log.info(
                        f"[DRY] {i}/{len(rows)} group_id={group_id} {vender_code}/{icode} price={price} {iname}"
                    )
                    continue

                try:
                    # Per-row savepoint so one failed move doesn't poison the
                    # whole batch transaction.
                    sp_name = f"sp_b{batch_idx}_{i}"
                    with conn.cursor() as spcur:
                        spcur.execute(f"SAVEPOINT {sp_name}")
                    mover(
                        dummy,  # type: ignore[arg-type]
                        vender_code=vender_code,
                        icode=icode,
                        iname=iname,
                        price=price,
                        img_url=img_url,
                        manage_tx=False,  # transaction is managed here so we can commit per batch
                    )
                    with conn.cursor() as spcur:
                        spcur.execute(f"RELEASE SAVEPOINT {sp_name}")
                    batch_ok += 1
                    total_ok += 1
                    if i <= 20 or (i % 50 == 0):
                        log.info(f"[OK] {i}/{len(rows)} {vender_code}/{icode} (group_id={group_id})")
                except Exception as e:
                    try:
                        with conn.cursor() as spcur:
                            spcur.execute(f"ROLLBACK TO SAVEPOINT {sp_name}")
                            spcur.execute(f"RELEASE SAVEPOINT {sp_name}")
                    except Exception:
                        # If even the savepoint rollback fails the whole
                        # transaction may be broken, so roll everything back
                        # and open a fresh transaction.
                        try:
                            conn.rollback()
                            with conn.cursor() as txcur:
                                txcur.execute("BEGIN")
                        except Exception:
                            pass
                    batch_fail += 1
                    total_fail += 1
                    log.error(f"[FAIL] {i}/{len(rows)} {vender_code}/{icode} (group_id={group_id}): {e}")
                    try:
                        fail_writer.writerow(
                            dict(
                                ts=datetime.now().isoformat(timespec="seconds"),
                                batch_idx=batch_idx,
                                idx_in_batch=i,
                                group_id=group_id,
                                vender_code=vender_code,
                                icode=icode,
                                price=price,
                                iname=iname,
                                img_url=img_url,
                                error=str(e),
                            )
                        )
                        fail_csv_fp.flush()
                    except Exception:
                        pass

                if args.max_total and total_ok >= int(args.max_total):
                    break
                if args.sleep and args.sleep > 0:
                    time.sleep(float(args.sleep))

            if args.dry_run:
                # Dry-run processes only the first batch, then stops.
                log.info("[STOP] dry-run mode")
                try:
                    conn.rollback()
                except Exception:
                    pass
                break

            # Batch-level commit.
            # NOTE(review): if the commit fails, the checkpoint is not
            # advanced, so the same rows are re-fetched next batch; since
            # batch_ok > 0 still resets the no-progress counter below, a
            # persistently failing commit could loop — confirm acceptable.
            try:
                conn.commit()
            except Exception:
                try:
                    conn.rollback()
                except Exception:
                    pass
            else:
                # Commit succeeded -> advance the checkpoint.
                try:
                    if int(batch_max_gid) > int(min_group_id):
                        min_group_id = int(batch_max_gid)
                        _write_checkpoint(checkpoint_path, int(min_group_id))
                        log.info(f"[CHECKPOINT] min_group_id={min_group_id}")
                except Exception:
                    pass

            if batch_ok == 0:
                batch_no_progress += 1
            else:
                batch_no_progress = 0

            dt = time.perf_counter() - t0
            log.info(
                f"[BATCH {batch_idx} DONE] ok={batch_ok} fail={batch_fail} "
                f"total_ok={total_ok} total_fail={total_fail} elapsed={dt:.1f}s"
            )

            if args.stop_after_no_progress and batch_no_progress >= int(args.stop_after_no_progress):
                log.info(
                    f"[STOP] no progress for {batch_no_progress} consecutive batches "
                    f"(stop-after-no-progress={args.stop_after_no_progress})"
                )
                break

    finally:
        # Best-effort cleanup of the DB connection and the failure CSV.
        try:
            if conn is not None:
                conn.close()
        except Exception:
            pass
        try:
            fail_csv_fp.close()
        except Exception:
            pass

    return 0 if total_fail == 0 else 2


def _worker_entry(args, worker_idx: int, workers: int) -> None:
    """Child-process entry point for one worker.

    Must be a module-level function (not a lambda/closure): multiprocessing
    pickles the target under the 'spawn'/'forkserver' start methods (the
    defaults on Windows and macOS), and a lambda cannot be pickled. Exits the
    child with _run_worker's return code so the parent sees it via exitcode.
    """
    sys.exit(_run_worker(args, worker_idx=worker_idx, workers=workers))


def main(argv: list[str] | None = None) -> int:
    """Parse CLI options and run the batch either in-process (single worker /
    explicit --worker-idx) or as N parallel worker processes partitioned by
    ``group_id % workers``.

    Returns 0 on full success, 2 when any worker reported failures, 130 on
    KeyboardInterrupt.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--limit", type=int, default=1500, help="배치 1회당 처리할 최대 row 수")
    ap.add_argument("--ratio", type=float, default=1.3)
    ap.add_argument("--sleep", type=float, default=0.0, help="각 이동 처리 사이 sleep(초)")
    ap.add_argument("--dry-run", action="store_true", help="이동 호출 없이 대상만 출력")
    ap.add_argument("--max-batches", type=int, default=0, help="0이면 무제한, 그 외 배치 반복 횟수 제한")
    ap.add_argument("--max-total", type=int, default=0, help="0이면 무제한, 그 외 총 OK 처리 건수 제한")
    ap.add_argument(
        "--stop-after-no-progress",
        type=int,
        default=3,
        help="OK=0인 배치가 연속 N회면 중단(실패로 무한루프 방지)",
    )
    ap.add_argument(
        "--log-file",
        type=str,
        default="",
        help="실행 로그 파일 경로. 예: log/run_{worker_idx}.log (병렬이면 placeholder 권장, 없으면 .w{idx} 자동 추가)",
    )
    ap.add_argument(
        "--fail-csv",
        type=str,
        default="",
        help="실패 항목 CSV 경로(비우면 자동 파일명). 예: log/fail_{worker_idx}.csv (병렬이면 .w{idx} 자동 추가)",
    )
    ap.add_argument(
        "--fail-csv-dir",
        type=str,
        default=".",
        help="--fail-csv가 비어있을 때 자동 생성 파일을 저장할 폴더",
    )
    ap.add_argument("--min-group-id", type=int, default=0, help="WHERE m.group_id > min_group_id 시작값(체크포인트 없을 때)")
    ap.add_argument("--checkpoint-file", type=str, default="", help="체크포인트 파일 경로(워커별 권장: log/ckpt_{worker_idx}.txt)")
    ap.add_argument("--checkpoint-dir", type=str, default=".", help="--checkpoint-file이 비어있을 때 자동 생성 파일 저장 폴더")
    ap.add_argument("--workers", type=int, default=1, help="병렬 워커 프로세스 수(기본 1)")
    ap.add_argument("--worker-idx", type=int, default=-1, help="내부용(직접 실행 시 워커 인덱스)")
    args = ap.parse_args(argv)

    workers = max(1, int(args.workers or 1))
    # Single worker (or an explicitly chosen worker index): run in-process.
    if workers == 1 or int(args.worker_idx) >= 0:
        wi = int(args.worker_idx) if int(args.worker_idx) >= 0 else 0
        return _run_worker(args, worker_idx=wi, workers=workers)

    # Multiprocess parallel run:
    # - group_id % workers partitioning prevents duplicate processing
    # - each worker has its own DB connection / commits / log / fail-CSV
    procs: list[mp.Process] = []
    exit_codes: dict[int, int] = {}

    # Parent process logs to stdout only.
    parent_log = _setup_logger("", name="batch_move_overpriced.parent", force=True)
    parent_log.info(f"starting workers={workers}")
    try:
        for wi in range(workers):
            # Bug fix: the previous lambda target is not picklable under the
            # 'spawn'/'forkserver' start methods and closed over ``wi``
            # late-bindingly; a module-level entry point with explicit args
            # is both picklable and bound at creation time.
            p = mp.Process(target=_worker_entry, args=(args, wi, workers))
            p.daemon = False
            p.start()
            procs.append(p)

        for p in procs:
            p.join()
            exit_codes[p.pid or -1] = int(p.exitcode or 0)
    except KeyboardInterrupt:
        parent_log.info("KeyboardInterrupt: terminating workers...")
        for p in procs:
            try:
                p.terminate()
            except Exception:
                pass
        for p in procs:
            try:
                p.join(timeout=5)
            except Exception:
                pass
        return 130

    bad = [c for c in exit_codes.values() if c != 0]
    parent_log.info(f"workers done. nonzero={len(bad)}")
    return 0 if not bad else 2


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code, exactly like the
    # original ``raise SystemExit(...)`` form.
    sys.exit(main(sys.argv[1:]))

