summaryrefslogtreecommitdiff
path: root/research/flossing/eval_directional_ckpts.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/eval_directional_ckpts.py')
-rw-r--r--research/flossing/eval_directional_ckpts.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/research/flossing/eval_directional_ckpts.py b/research/flossing/eval_directional_ckpts.py
new file mode 100644
index 0000000..68653e6
--- /dev/null
+++ b/research/flossing/eval_directional_ckpts.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import sys
+import time
+from pathlib import Path
+
+import torch
+
+
+FLOSS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(FLOSS_DIR))
+
+from step7_interfloss import evaluate, load_model # noqa: E402
+
+
+def parse_steps(text: str) -> list[int]:
+ return [int(x.strip()) for x in text.split(",") if x.strip()]
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt-root", required=True)
+ parser.add_argument("--steps", required=True)
+ parser.add_argument("--n", type=int, default=1000)
+ parser.add_argument("--batch-size", type=int, default=32)
+ parser.add_argument("--seed", type=int, default=20260611)
+ parser.add_argument("--out", required=True)
+ args = parser.parse_args()
+
+ ckpt_root = Path(args.ckpt_root)
+ rows = []
+ for step in parse_steps(args.steps):
+ ckpt_name = f"step_{step}"
+ if not (ckpt_root / ckpt_name).exists():
+ print(f"[skip] missing {ckpt_name}", flush=True)
+ continue
+
+ t0 = time.time()
+ head, base, cfg, _adam_cls, _sparse_cls = load_model(
+ "trm",
+ ckpt_root,
+ ckpt_name,
+ "cuda",
+ batch_size_override=args.batch_size,
+ )
+ data_path = Path(cfg["data_path"])
+ acc, tok = evaluate(
+ head,
+ base,
+ data_path,
+ n_samples=args.n,
+ batch_size=args.batch_size,
+ device="cuda",
+ seed=args.seed,
+ )
+ elapsed = time.time() - t0
+ row = {
+ "step": step,
+ "n": args.n,
+ "seed": args.seed,
+ "exact_acc": acc,
+ "token_acc": tok,
+ "elapsed_sec": elapsed,
+ }
+ rows.append(row)
+ print(
+ f"step={step} exact={acc:.4f} token={tok:.4f} elapsed={elapsed:.1f}s",
+ flush=True,
+ )
+ del head, base
+ torch.cuda.empty_cache()
+
+ out_path = Path(args.out)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ out_path.with_suffix(".json").write_text(json.dumps(rows, indent=2))
+ with out_path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+if __name__ == "__main__":
+ main()