diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/eval_directional_ckpts.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/eval_directional_ckpts.py')
| -rw-r--r-- | research/flossing/eval_directional_ckpts.py | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/research/flossing/eval_directional_ckpts.py b/research/flossing/eval_directional_ckpts.py new file mode 100644 index 0000000..68653e6 --- /dev/null +++ b/research/flossing/eval_directional_ckpts.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +import time +from pathlib import Path + +import torch + + +FLOSS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(FLOSS_DIR)) + +from step7_interfloss import evaluate, load_model # noqa: E402 + + +def parse_steps(text: str) -> list[int]: + return [int(x.strip()) for x in text.split(",") if x.strip()] + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--steps", required=True) + parser.add_argument("--n", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--seed", type=int, default=20260611) + parser.add_argument("--out", required=True) + args = parser.parse_args() + + ckpt_root = Path(args.ckpt_root) + rows = [] + for step in parse_steps(args.steps): + ckpt_name = f"step_{step}" + if not (ckpt_root / ckpt_name).exists(): + print(f"[skip] missing {ckpt_name}", flush=True) + continue + + t0 = time.time() + head, base, cfg, _adam_cls, _sparse_cls = load_model( + "trm", + ckpt_root, + ckpt_name, + "cuda", + batch_size_override=args.batch_size, + ) + data_path = Path(cfg["data_path"]) + acc, tok = evaluate( + head, + base, + data_path, + n_samples=args.n, + batch_size=args.batch_size, + device="cuda", + seed=args.seed, + ) + elapsed = time.time() - t0 + row = { + "step": step, + "n": args.n, + "seed": args.seed, + "exact_acc": acc, + "token_acc": tok, + "elapsed_sec": elapsed, + } + rows.append(row) + print( + f"step={step} exact={acc:.4f} token={tok:.4f} elapsed={elapsed:.1f}s", + flush=True, + ) + del head, base + torch.cuda.empty_cache() + + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.with_suffix(".json").write_text(json.dumps(rows, indent=2)) + with out_path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + +if __name__ == "__main__": + main() |
