summaryrefslogtreecommitdiff
path: root/research/flossing/eval_directional_ckpts.py
blob: 68653e67829902559cfc355e94ba6ddca80c8018 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from __future__ import annotations

import argparse
import csv
import json
import sys
import time
from pathlib import Path

import torch


# Absolute directory of this script (symlinks resolved). It is placed at the
# very front of the module search path so that the local `step7_interfloss`
# module (presumably a sibling of this file) wins over any same-named module.
FLOSS_DIR = Path(__file__).resolve().parent
sys.path[:0] = [str(FLOSS_DIR)]

from step7_interfloss import evaluate, load_model  # noqa: E402


def parse_steps(text: str) -> list[int]:
    """Turn a comma-separated list such as ``"1000, 2000,3000"`` into ints.

    Whitespace around each entry is ignored and empty entries (for example
    from a trailing comma or ``",,"``) are dropped. Order and duplicates are
    preserved. A non-integer entry raises ``ValueError`` from ``int``.
    """
    steps: list[int] = []
    for token in text.split(","):
        token = token.strip()
        if not token:
            continue
        steps.append(int(token))
    return steps


# Column order of the results table; also the key order of each row dict.
_FIELDNAMES = ("step", "n", "seed", "exact_acc", "token_acc", "elapsed_sec")


def _result_paths(out_path: Path) -> tuple[Path, Path]:
    """Return ``(json_path, csv_path)`` for the ``--out`` argument.

    Normally the CSV goes to ``out_path`` and the JSON to the same stem with
    a ``.json`` suffix. If ``out_path`` itself ends in ``.json`` those two
    would be the same file (the CSV would overwrite the JSON), so in that
    case the JSON keeps ``out_path`` and the CSV moves to the ``.csv`` sibling.
    """
    if out_path.suffix.lower() == ".json":
        return out_path, out_path.with_suffix(".csv")
    return out_path.with_suffix(".json"), out_path


def _tmp_sibling(path: Path) -> Path:
    """Return a scratch path in the same directory as *path* (for atomic rename)."""
    return path.with_name(path.name + ".tmp")


def _write_results(rows: list[dict], out_path: Path) -> None:
    """Write *rows* as JSON and CSV, replacing any previous files.

    Each file is first written to a ``.tmp`` sibling and then renamed over
    the target with ``Path.replace``. On POSIX that rename is atomic within
    one filesystem, so an interrupted run leaves the previous complete
    snapshot in place rather than a truncated file.
    """
    json_path, csv_path = _result_paths(out_path)

    json_tmp = _tmp_sibling(json_path)
    json_tmp.write_text(json.dumps(rows, indent=2))
    json_tmp.replace(json_path)

    csv_tmp = _tmp_sibling(csv_path)
    with csv_tmp.open("w", newline="") as f:
        # A fixed header (not rows[0].keys()) lets an empty table be written too.
        writer = csv.DictWriter(f, fieldnames=_FIELDNAMES)
        writer.writeheader()
        writer.writerows(rows)
    csv_tmp.replace(csv_path)


def main(argv: list[str] | None = None) -> None:
    """Evaluate a series of checkpoints and tabulate their accuracy.

    For each step in ``--steps``, loads ``<ckpt-root>/step_<step>`` via
    ``load_model("trm", ...)``. It then runs ``evaluate`` on ``--n`` samples
    from the ``data_path`` stored in that checkpoint's config. Each step adds
    a row with exact/token accuracy and the elapsed time (load + evaluate).
    Missing checkpoint directories are skipped with a ``[skip]`` message.

    Results are rewritten after every evaluated checkpoint so partial
    progress survives a crash. The CSV goes to ``--out`` and a JSON copy goes
    next to it; if ``--out`` ends in ``.json``, the CSV goes to the ``.csv``
    sibling instead. If no checkpoint was evaluated, empty results are
    written so stale output from an earlier run is not mistaken for this one.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-root", required=True)
    parser.add_argument("--steps", required=True)
    parser.add_argument("--n", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seed", type=int, default=20260611)
    parser.add_argument("--out", required=True)
    args = parser.parse_args(argv)

    ckpt_root = Path(args.ckpt_root)
    out_path = Path(args.out)
    # Create the output directory up front so a bad --out fails before any
    # expensive checkpoint loading / evaluation happens.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    rows: list[dict] = []
    for step in parse_steps(args.steps):
        ckpt_name = f"step_{step}"
        if not (ckpt_root / ckpt_name).exists():
            print(f"[skip] missing {ckpt_name}", flush=True)
            continue

        # perf_counter is monotonic, so the timing is immune to wall-clock jumps.
        t0 = time.perf_counter()
        head, base, cfg, _adam_cls, _sparse_cls = load_model(
            "trm",
            ckpt_root,
            ckpt_name,
            "cuda",
            batch_size_override=args.batch_size,
        )
        data_path = Path(cfg["data_path"])
        acc, tok = evaluate(
            head,
            base,
            data_path,
            n_samples=args.n,
            batch_size=args.batch_size,
            device="cuda",
            seed=args.seed,
        )
        elapsed = time.perf_counter() - t0
        # evaluate()'s return types are not visible here; coerce to built-in
        # float so json.dumps cannot fail on a tensor / numpy scalar.
        acc, tok = float(acc), float(tok)

        rows.append(
            {
                "step": step,
                "n": args.n,
                "seed": args.seed,
                "exact_acc": acc,
                "token_acc": tok,
                "elapsed_sec": elapsed,
            }
        )
        print(
            f"step={step} exact={acc:.4f} token={tok:.4f} elapsed={elapsed:.1f}s",
            flush=True,
        )
        # Drop the model references and release cached CUDA memory before the
        # next checkpoint is loaded (effective only if nothing else, e.g. cfg,
        # still references those tensors -- not verifiable from here).
        del head, base
        torch.cuda.empty_cache()

        # Rewrite after every checkpoint so partial results survive a crash.
        _write_results(rows, out_path)

    if not rows:
        print("[warn] no checkpoints evaluated; writing empty results", flush=True)
        _write_results(rows, out_path)


# Script entry point: importing this module does not run the evaluation.
# main() returns None, so SystemExit(None) exits with status 0, as before.
if __name__ == "__main__":
    raise SystemExit(main())