#!/usr/bin/env python3
import argparse
import csv
import json
import math
import random
from collections import Counter, defaultdict

try:
    from sklearn.svm import LinearSVC
except Exception:
    LinearSVC = None


def log1p_safe(v):
    """Return ``log(1 + v)``, or ``0.0`` when ``v`` is zero or negative.

    Used to compress length features; non-positive lengths contribute nothing.
    """
    return 0.0 if v <= 0 else math.log1p(v)


def norm_profile_name(v):
    """Turn a free-form intent/profile label into a safe section name.

    The value is trimmed and lower-cased. Every character that is not
    alphanumeric, ``_`` or ``-`` becomes ``_``, and leading/trailing
    underscores are removed. ``None``, empty or all-symbol inputs map to
    ``"default"``.
    """
    lowered = (v or "default").strip().lower()
    if lowered == "":
        return "default"
    cleaned = "".join(
        ch if (ch.isalnum() or ch in "_-") else "_" for ch in lowered
    ).strip("_")
    return cleaned if cleaned else "default"


def load_rows(path, intent_col):
    """Read ranking examples from a CSV file.

    Expected columns: ``query_id``, ``url``, ``level``, ``title_len``,
    ``text_len``, ``label``, and optionally the intent column named by
    ``intent_col``. Missing or blank numeric cells become ``0.0``. The intent
    value is normalized with ``norm_profile_name``, so a missing intent becomes
    ``"default"``. Rows with an empty ``query_id`` are skipped.

    Args:
        path: CSV file path (UTF-8, with or without a byte-order mark).
        intent_col: column used as the profile/intent key.

    Returns:
        List of dicts with keys ``query_id``, ``url``, ``level``,
        ``title_len``, ``text_len``, ``label`` and ``intent``.

    Raises:
        ValueError: a numeric cell of a kept row is not a number, or no row
            has a non-empty ``query_id``.
    """

    def _text(value):
        # DictReader fills cells missing from short rows with None; str(None)
        # would otherwise yield the literal string "None".
        return "" if value is None else str(value).strip()

    def _number(record, column, line_no):
        raw = _text(record.get(column))
        if not raw:
            # float() rejects whitespace-only strings; treat blank cells like
            # empty ones and default them to 0.
            return 0.0
        try:
            return float(raw)
        except ValueError as err:
            raise ValueError(
                f"{path}:{line_no}: column {column!r} is not a number: {raw!r}"
            ) from err

    rows = []
    # utf-8-sig drops a leading BOM (common in spreadsheet exports). With plain
    # utf-8 the first header would read as "\ufeffquery_id" and every row
    # would be discarded for lacking a query_id.
    with open(path, newline="", encoding="utf-8-sig") as f:
        rdr = csv.DictReader(f)
        for r in rdr:
            query_id = _text(r.get("query_id"))
            if not query_id:
                continue
            line_no = rdr.line_num
            rows.append(
                {
                    "query_id": query_id,
                    "url": _text(r.get("url")),
                    "level": _number(r, "level", line_no),
                    "title_len": _number(r, "title_len", line_no),
                    "text_len": _number(r, "text_len", line_no),
                    "label": _number(r, "label", line_no),
                    "intent": norm_profile_name(r.get(intent_col, "default")),
                }
            )
    if not rows:
        raise ValueError("No valid rows with query_id")
    return rows


def normalize_features(rows):
    """Return copies of ``rows``, each with a 3-element feature vector ``"x"``.

    Features, in order:
      * level: inverted so smaller levels score higher. Levels are clamped to
        at least 1, then ``(max_level - level + 1) / max_level``, giving values
        in ``(0, 1]``.
      * title: ``log1p(title_len)`` divided by its maximum over ``rows``.
      * text: ``log1p(text_len)`` divided by its maximum over ``rows``.
    A zero maximum falls back to 1.0 to avoid division by zero. Input dicts
    are not modified.
    """
    level_cap = max(max(row["level"], 1.0) for row in rows)
    title_scale = max(log1p_safe(row["title_len"]) for row in rows) or 1.0
    text_scale = max(log1p_safe(row["text_len"]) for row in rows) or 1.0

    def _with_features(row):
        enriched = dict(row)
        clamped_level = max(enriched["level"], 1.0)
        enriched["x"] = [
            (level_cap - clamped_level + 1.0) / level_cap,
            log1p_safe(enriched["title_len"]) / title_scale,
            log1p_safe(enriched["text_len"]) / text_scale,
        ]
        return enriched

    return [_with_features(row) for row in rows]


def group_by_query(rows):
    """Bucket ``rows`` by their ``query_id``.

    Returns a ``defaultdict(list)`` keyed by query id in first-seen order.
    Each bucket keeps the rows' input order.
    """
    buckets = defaultdict(list)
    for row in rows:
        qid = row["query_id"]
        buckets[qid].append(row)
    return buckets


def make_pairwise(rows):
    """Build RankSVM pairwise training samples from feature rows.

    Pairs are formed only between rows that share a ``query_id``. Relevance
    labels are judged per query, so comparing documents from different
    queries would train on noise. For every pair where ``a["label"] >
    b["label"]``, two samples are emitted:

    * ``x_a - x_b`` with target ``1``;
    * the mirrored ``x_b - x_a`` with target ``-1``.

    The mirror keeps both classes present, because ``LinearSVC.fit`` rejects
    single-class targets. It also makes the problem symmetric, so the
    learned coefficients for class ``1`` point toward higher relevance.

    Args:
        rows: dicts with a numeric ``"label"`` and a feature list ``"x"``.
            Rows are grouped by ``"query_id"``; rows without it share one group.

    Returns:
        ``(x, y)``: a list of feature-difference vectors and the matching list
        of ``+1`` / ``-1`` targets.
    """
    by_q = defaultdict(list)
    for r in rows:
        by_q[r.get("query_id")].append(r)

    x = []
    y = []
    for q_rows in by_q.values():
        for hi in q_rows:
            for lo in q_rows:
                # Strict ">" also excludes self-pairs and equal-label ties.
                if hi["label"] > lo["label"]:
                    d = [a - b for a, b in zip(hi["x"], lo["x"])]
                    x.append(d)
                    y.append(1)
                    x.append([-v for v in d])
                    y.append(-1)
    return x, y


def count_positive_pairs(rows):
    """Count (better, worse) document pairs with strictly different labels.

    Pairs are counted within each ``query_id`` only. The label histogram is
    walked in ascending label order. Each label contributes its count times
    the number of documents already seen with a lower label.
    """
    total = 0
    for members in group_by_query(rows).values():
        label_counts = Counter(member["label"] for member in members)
        seen_below = 0
        for label in sorted(label_counts):
            how_many = label_counts[label]
            total += how_many * seen_below
            seen_below += how_many
    return total


def dcg(labels, k):
    """Discounted cumulative gain over the first ``k`` labels.

    Uses exponential gain ``2**rel - 1`` and a ``log2(rank + 1)`` discount,
    with ranks starting at 1.
    """
    return sum(
        ((2.0**rel - 1.0) / math.log2(rank + 2.0) for rank, rel in enumerate(labels[:k])),
        0.0,
    )


def ndcg_for_query(rows, w, k=10):
    """NDCG@k of ranking ``rows`` by the linear score ``w · x``.

    The DCG of the model's ordering is divided by the DCG of the ideal
    (label-descending) ordering. Returns 0.0 when the ideal DCG is not
    positive, e.g. when all labels are zero.
    """

    def _score(row):
        return sum(w[i] * row["x"][i] for i in range(3))

    ranked_labels = [row["label"] for row in sorted(rows, key=_score, reverse=True)]
    ideal_labels = sorted((row["label"] for row in rows), reverse=True)
    ideal_dcg = dcg(ideal_labels, k)
    return dcg(ranked_labels, k) / ideal_dcg if ideal_dcg > 0 else 0.0


def train_once(train_q, test_q, by_q, c_value):
    """Fit one pairwise LinearSVC on ``train_q`` and score it on ``test_q``.

    Args:
        train_q: query ids used to build pairwise training samples.
        test_q: held-out query ids used for NDCG@10 evaluation.
        by_q: mapping query id -> feature rows (with ``"x"`` and ``"label"``).
        c_value: LinearSVC regularization parameter ``C``.

    Returns:
        ``None`` when fewer than two pairwise samples exist. Otherwise a dict
        with ``"w"`` (per-feature weights: absolute values floored at 1e-9,
        normalized to sum to 1) and ``"ndcg10"`` (mean NDCG@10 over
        ``test_q``, or 0.0 if ``test_q`` is empty).

    Raises:
        RuntimeError: scikit-learn is not installed.
    """
    train_rows = [row for q in train_q for row in by_q[q]]
    x, y = make_pairwise(train_rows)
    if len(x) < 2:
        return None
    if LinearSVC is None:
        raise RuntimeError("scikit-learn is required (pip install scikit-learn)")

    model = LinearSVC(C=c_value, max_iter=5000)
    model.fit(x, y)

    # NOTE(review): abs() discards the coefficient sign, so a feature the model
    # found *negatively* related to relevance is exported as a positive weight.
    # Presumably the conf consumer requires non-negative weights. Confirm
    # before relying on this for features that can hurt relevance.
    floored = [max(1e-9, abs(coef)) for coef in model.coef_[0].tolist()]
    total = sum(floored)
    w = [value / total for value in floored]

    scores = [ndcg_for_query(by_q[q], w, k=10) for q in test_q]
    mean_ndcg = (sum(scores) / len(scores)) if scores else 0.0
    return {"w": w, "ndcg10": mean_ndcg}


def aggregate_runs(w_runs, ndcg_runs):
    """Summarize repeated training runs into means and population stddevs.

    Args:
        w_runs: non-empty list of per-run weight vectors ``[level, title, text]``.
        ndcg_runs: per-run NDCG@10 values (may be empty).

    Returns:
        Dict with ``runs_ok``, ``mean_w`` / ``std_w`` keyed by feature name,
        and ``ndcg10_mean`` / ``ndcg10_std`` (both 0.0 when ``ndcg_runs``
        is empty). Standard deviations divide by N, not N - 1.
    """

    def _mean_std(values):
        mu = sum(values) / len(values)
        return mu, math.sqrt(sum((v - mu) ** 2 for v in values) / len(values))

    mean_w = {}
    std_w = {}
    for idx, feature in enumerate(("level", "title", "text")):
        mu, sd = _mean_std([run[idx] for run in w_runs])
        mean_w[feature] = mu
        std_w[feature] = sd

    if ndcg_runs:
        ndcg_mean, ndcg_std = _mean_std(list(ndcg_runs))
    else:
        ndcg_mean, ndcg_std = 0.0, 0.0

    return {
        "runs_ok": len(w_runs),
        "mean_w": mean_w,
        "std_w": std_w,
        "ndcg10_mean": ndcg_mean,
        "ndcg10_std": ndcg_std,
    }


def train_profile(rows, args, seed):
    """Train one weight profile from ``rows`` via repeated random query splits.

    Query ids are shuffled (with ``random.Random(seed)``) and split into
    train/test by ``args.train_ratio``, keeping at least one query on each
    side when possible. This repeats ``max(5, args.runs)`` times, so at least
    5 runs happen regardless of ``args.runs``.

    Returns:
        ``(report, meta)``. ``meta`` holds raw row/query/pair counts.
        ``report`` is the aggregated statistics merged with ``meta``, or
        ``None`` when there are fewer than ``args.min_queries_per_intent``
        queries or no run produced a model.
    """
    meta = {
        "rows": len(rows),
        "queries": len({r["query_id"] for r in rows}),
        "pairs": count_positive_pairs(rows),
    }

    by_q = group_by_query(normalize_features(rows))
    queries = list(by_q)
    if len(queries) < args.min_queries_per_intent:
        return None, meta

    rng = random.Random(seed)
    n_total = len(queries)
    w_runs, ndcg_runs = [], []
    for _ in range(max(5, args.runs)):
        order = list(queries)
        rng.shuffle(order)
        split_at = max(1, min(n_total - 1, int(n_total * args.train_ratio)))
        result = train_once(order[:split_at], order[split_at:], by_q, args.c)
        if not result:
            continue
        w_runs.append(result["w"])
        ndcg_runs.append(result["ndcg10"])

    if not w_runs:
        return None, meta
    report = aggregate_runs(w_runs, ndcg_runs)
    report.update(meta)
    return report, meta


def write_conf(path, profiles):
    with open(path, "w", encoding="utf-8") as f:
        f.write("# Auto-generated by tools/ltr/train_rank_ltr.py\n")
        f.write("# Multi-profile format: [default], [blog], [docs], [ecommerce], ...\n\n")
        for name, report in profiles.items():
            f.write(f"[{name}]\n")
            f.write(f"w_level={report['mean_w']['level']:.8f}\n")
            f.write(f"w_title={report['mean_w']['title']:.8f}\n")
            f.write(f"w_text={report['mean_w']['text']:.8f}\n")
            f.write(f"std_level={report['std_w']['level']:.8f}\n")
            f.write(f"std_title={report['std_w']['title']:.8f}\n")
            f.write(f"std_text={report['std_w']['text']:.8f}\n\n")


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it does not exist yet."""
    import os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def main():
    """CLI entry point: train the default and per-intent profiles and export them.

    Loads ``--input``, trains a ``[default]`` profile over all rows, and then
    trains one profile per distinct normalized intent other than ``default``.
    It writes the weights file (``--out-conf``) and a JSON report
    (``--out-json``), creating missing parent directories, and echoes the
    report to stdout.

    Raises:
        RuntimeError: no training run succeeded for the all-rows default profile.
    """
    ap = argparse.ArgumentParser(
        description="Train RankSVM-like linear LTR and export multi-profile rank_weights.conf"
    )
    ap.add_argument("--input", required=True, help="CSV with query_id,url,level,title_len,text_len,label[,intent]")
    ap.add_argument("--intent-col", default="intent", help="CSV column used as profile key (default: intent)")
    ap.add_argument("--runs", type=int, default=100)
    ap.add_argument("--train-ratio", type=float, default=0.8)
    ap.add_argument("--c", type=float, default=1.0)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--min-queries-per-intent", type=int, default=3)
    ap.add_argument("--out-conf", default="rank_weights.conf")
    ap.add_argument("--out-json", default="tools/ltr/ltr_report.json")
    args = ap.parse_args()
    # LinearSVC rejects C <= 0, but only after the data is loaded and pairs are
    # built. Report it as a normal CLI usage error up front instead.
    if not args.c > 0:
        ap.error("--c must be a positive number")

    rows = load_rows(args.input, args.intent_col)

    all_report, _ = train_profile(rows, args, args.seed)
    if not all_report:
        raise RuntimeError("No valid training run produced default profile")

    by_intent = defaultdict(list)
    for r in rows:
        by_intent[r["intent"]].append(r)

    profiles = {"default": all_report}
    intent_reports = {}
    skipped = {}
    offset = 1

    for intent in sorted(by_intent.keys()):
        # Rows tagged "default" are already covered by the all-rows profile.
        if intent == "default":
            continue
        # Distinct seed per intent so each profile gets its own split sequence.
        rep, meta = train_profile(by_intent[intent], args, args.seed + offset)
        offset += 1
        if rep:
            profiles[intent] = rep
            intent_reports[intent] = rep
        else:
            skipped[intent] = {
                "reason": "insufficient_data_or_pairs",
                "queries": meta["queries"],
                "rows": meta["rows"],
                "pairs": meta["pairs"],
            }

    # The default --out-json is repo-relative ("tools/ltr/..."). Without this,
    # running from another directory would fail only after the conf file
    # has already been written.
    _ensure_parent_dir(args.out_conf)
    write_conf(args.out_conf, profiles)

    report = {
        "default": all_report,
        "profiles": intent_reports,
        "skipped": skipped,
        "profiles_written": list(profiles.keys()),
    }
    _ensure_parent_dir(args.out_json)
    with open(args.out_json, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    print("Training complete")
    print(json.dumps(report, indent=2))
    print(f"Exported weights: {args.out_conf}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
