diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-29 12:15:51 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-29 12:15:51 -0500 |
| commit | a6ec4288a2232988b130b2f00bb2565f81706966 (patch) | |
| tree | 1bb86e7f0b899b823b9e7fdf383e832d30a181e0 /eval_directional_ckpts.py | |
Recursive reasoning dynamics: analysis pipeline, paper drafts, toy models
Failure=more-chaotic (task-general under validity labeling) reduces to convergence/completeness
detection; mechanism (transient chaos vs multistability vs input-induced) under investigation.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'eval_directional_ckpts.py')
| -rw-r--r-- | eval_directional_ckpts.py | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/eval_directional_ckpts.py b/eval_directional_ckpts.py new file mode 100644 index 0000000..68653e6 --- /dev/null +++ b/eval_directional_ckpts.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +import time +from pathlib import Path + +import torch + + +FLOSS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(FLOSS_DIR)) + +from step7_interfloss import evaluate, load_model # noqa: E402 + + +def parse_steps(text: str) -> list[int]: + return [int(x.strip()) for x in text.split(",") if x.strip()] + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--steps", required=True) + parser.add_argument("--n", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--seed", type=int, default=20260611) + parser.add_argument("--out", required=True) + args = parser.parse_args() + + ckpt_root = Path(args.ckpt_root) + rows = [] + for step in parse_steps(args.steps): + ckpt_name = f"step_{step}" + if not (ckpt_root / ckpt_name).exists(): + print(f"[skip] missing {ckpt_name}", flush=True) + continue + + t0 = time.time() + head, base, cfg, _adam_cls, _sparse_cls = load_model( + "trm", + ckpt_root, + ckpt_name, + "cuda", + batch_size_override=args.batch_size, + ) + data_path = Path(cfg["data_path"]) + acc, tok = evaluate( + head, + base, + data_path, + n_samples=args.n, + batch_size=args.batch_size, + device="cuda", + seed=args.seed, + ) + elapsed = time.time() - t0 + row = { + "step": step, + "n": args.n, + "seed": args.seed, + "exact_acc": acc, + "token_acc": tok, + "elapsed_sec": elapsed, + } + rows.append(row) + print( + f"step={step} exact={acc:.4f} token={tok:.4f} elapsed={elapsed:.1f}s", + flush=True, + ) + del head, base + torch.cuda.empty_cache() + + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.with_suffix(".json").write_text(json.dumps(rows, indent=2)) + with out_path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + +if __name__ == "__main__": + main() |
