from __future__ import annotations import argparse import csv import json import sys import time from pathlib import Path import torch FLOSS_DIR = Path(__file__).resolve().parent sys.path.insert(0, str(FLOSS_DIR)) from step7_interfloss import evaluate, load_model # noqa: E402 def parse_steps(text: str) -> list[int]: return [int(x.strip()) for x in text.split(",") if x.strip()] def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--ckpt-root", required=True) parser.add_argument("--steps", required=True) parser.add_argument("--n", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--seed", type=int, default=20260611) parser.add_argument("--out", required=True) args = parser.parse_args() ckpt_root = Path(args.ckpt_root) rows = [] for step in parse_steps(args.steps): ckpt_name = f"step_{step}" if not (ckpt_root / ckpt_name).exists(): print(f"[skip] missing {ckpt_name}", flush=True) continue t0 = time.time() head, base, cfg, _adam_cls, _sparse_cls = load_model( "trm", ckpt_root, ckpt_name, "cuda", batch_size_override=args.batch_size, ) data_path = Path(cfg["data_path"]) acc, tok = evaluate( head, base, data_path, n_samples=args.n, batch_size=args.batch_size, device="cuda", seed=args.seed, ) elapsed = time.time() - t0 row = { "step": step, "n": args.n, "seed": args.seed, "exact_acc": acc, "token_acc": tok, "elapsed_sec": elapsed, } rows.append(row) print( f"step={step} exact={acc:.4f} token={tok:.4f} elapsed={elapsed:.1f}s", flush=True, ) del head, base torch.cuda.empty_cache() out_path = Path(args.out) out_path.parent.mkdir(parents=True, exist_ok=True) out_path.with_suffix(".json").write_text(json.dumps(rows, indent=2)) with out_path.open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) if __name__ == "__main__": main()