1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
from __future__ import annotations
import argparse
import csv
import json
import sys
import time
from pathlib import Path
import torch
FLOSS_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(FLOSS_DIR))
from step7_interfloss import evaluate, load_model # noqa: E402
def parse_steps(text: str) -> list[int]:
    """Turn a comma-separated string of checkpoint steps into integers.

    Whitespace around each entry is stripped, and empty entries (from a
    trailing comma, ``"1,,2"``, or an empty string) are skipped.

    Args:
        text: Comma-separated step numbers, e.g. ``"1000, 2000,3000"``.

    Returns:
        The steps in the order they appear; duplicates are preserved.

    Raises:
        ValueError: If a non-empty entry is not an integer literal.
    """
    steps: list[int] = []
    for piece in text.split(","):
        token = piece.strip()
        if not token:
            continue
        steps.append(int(token))
    return steps
# Key order of every result row, and the column order of the CSV output.
# Kept in sync with the dict built in _evaluate_checkpoint.
_RESULT_FIELDS = ("step", "n", "seed", "exact_acc", "token_acc", "elapsed_sec")


def _evaluate_checkpoint(
    ckpt_root: Path,
    ckpt_name: str,
    step: int,
    *,
    n: int,
    batch_size: int,
    seed: int,
    device: str,
) -> dict[str, object]:
    """Load one checkpoint, evaluate it, release it, and return its result row.

    ``elapsed_sec`` covers both ``load_model`` and ``evaluate``.
    """
    t0 = time.perf_counter()  # monotonic clock: immune to wall-clock jumps
    head, base, cfg, _adam_cls, _sparse_cls = load_model(
        "trm",
        ckpt_root,
        ckpt_name,
        device,
        batch_size_override=batch_size,
    )
    acc, tok = evaluate(
        head,
        base,
        Path(cfg["data_path"]),
        n_samples=n,
        batch_size=batch_size,
        device=device,
        seed=seed,
    )
    elapsed = time.perf_counter() - t0
    # Drop our references first so empty_cache() can release the memory
    # they held before the next checkpoint is loaded.
    del head, base
    torch.cuda.empty_cache()
    # float() makes the row JSON-serializable even if evaluate returns scalar
    # tensors / numpy scalars (NOTE(review): return type not visible here);
    # it is a no-op for plain Python floats.
    return {
        "step": step,
        "n": n,
        "seed": seed,
        "exact_acc": float(acc),
        "token_acc": float(tok),
        "elapsed_sec": elapsed,
    }


def _write_results(out_path: Path, rows: list[dict[str, object]]) -> None:
    """Write ``rows`` as JSON (``.json`` sibling of ``out_path``) and as CSV.

    The CSV goes to ``out_path`` unless ``out_path`` itself ends in ``.json``,
    in which case it goes to the ``.csv`` sibling so the JSON is not
    overwritten. An empty ``rows`` yields ``[]`` and a header-only CSV.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    json_path = out_path.with_suffix(".json")
    csv_path = out_path.with_suffix(".csv") if out_path.suffix == ".json" else out_path

    json_path.write_text(json.dumps(rows, indent=2), encoding="utf-8")
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        # Fixed field list: works even when no checkpoint produced a row.
        writer = csv.DictWriter(f, fieldnames=list(_RESULT_FIELDS))
        writer.writeheader()
        writer.writerows(rows)

    if not rows:
        print(
            "[warn] no checkpoints were evaluated; wrote empty results",
            file=sys.stderr,
            flush=True,
        )


def main() -> None:
    """Evaluate a sweep of TRM checkpoints and save per-step accuracies.

    For each step in ``--steps``, loads ``<ckpt-root>/step_<step>`` with
    ``load_model``, runs ``evaluate`` on ``--n`` samples from the dataset at
    ``cfg["data_path"]``, prints a one-line summary, and records a row.
    Missing checkpoints are skipped with a ``[skip]`` message. Results are
    written once at the end as JSON and CSV (see ``_write_results``).
    """
    parser = argparse.ArgumentParser(
        description="Evaluate step_<N> checkpoints under --ckpt-root and write "
        "exact/token accuracy per step to CSV and JSON."
    )
    parser.add_argument(
        "--ckpt-root", required=True, help="Directory containing step_<N> checkpoints."
    )
    parser.add_argument(
        "--steps", required=True, help="Comma-separated steps, e.g. '1000,2000'."
    )
    parser.add_argument(
        "--n", type=int, default=1000, help="Samples evaluated per checkpoint."
    )
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seed", type=int, default=20260611)
    parser.add_argument(
        "--out",
        required=True,
        help="CSV output path; JSON is written next to it with a .json suffix.",
    )
    parser.add_argument(
        "--device", default="cuda", help="Device for loading and evaluation."
    )
    args = parser.parse_args()

    ckpt_root = Path(args.ckpt_root)
    rows: list[dict[str, object]] = []
    for step in parse_steps(args.steps):
        ckpt_name = f"step_{step}"
        if not (ckpt_root / ckpt_name).exists():
            print(f"[skip] missing {ckpt_name}", flush=True)
            continue
        row = _evaluate_checkpoint(
            ckpt_root,
            ckpt_name,
            step,
            n=args.n,
            batch_size=args.batch_size,
            seed=args.seed,
            device=args.device,
        )
        rows.append(row)
        print(
            f"step={step} exact={row['exact_acc']:.4f} "
            f"token={row['token_acc']:.4f} elapsed={row['elapsed_sec']:.1f}s",
            flush=True,
        )

    _write_results(Path(args.out), rows)
# Run the checkpoint evaluation sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|