diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/analyze_all.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/analyze_all.py')
| -rw-r--r-- | ep_run/analyze_all.py | 86 |
1 files changed, 86 insertions, 0 deletions
# ep_run/analyze_all.py
"""Analyze all runs in runs_local/ — produces a summary table sorted by val_loss.

For each run: final val_loss, per-projection grad_cos breakdown, STE flags used.
"""
import json
import math
from pathlib import Path

# (substring of a grad_cos entry name, short column key). Checked in order;
# the first substring found in the name decides the group.
_GROUP_PATTERNS = (
    ("q_proj", "q"),
    ("k_proj", "k"),
    ("v_proj", "v"),
    ("o_proj", "o"),
    ("mlp.fc", "fc"),
    ("mlp.proj", "pr"),
    ("head", "hd"),
)

# Order in which the per-group cosine columns are printed.
_GROUP_COLUMNS = ("hd", "o", "v", "q", "k", "fc", "pr")

# (config key, label shown in the "flags" column when the key is truthy).
_FLAG_LABELS = (
    ("ste_sigmoid", "σSTE"),
    ("ste_gelu", "gSTE"),
    ("ste_ln", "lSTE"),
    ("freeze_emb", "frzE"),
)


def _is_valid(x):
    """Return True when *x* is an int/float that is not NaN."""
    return isinstance(x, (int, float)) and not math.isnan(x)


def _mean(values):
    """Return the mean of the valid numbers in *values*, or NaN if there are none."""
    valid = [x for x in values if _is_valid(x)]
    return sum(valid) / len(valid) if valid else float("nan")


def _fmt_float(value, width, precision):
    """Right-align *value* in *width* columns; invalid values render as 'nan'."""
    if _is_valid(value):
        return f"{value:>{width}.{precision}f}"
    # Pad the placeholder to the same width so the columns stay aligned.
    return f"{'nan':>{width}s}"


def get_final(arm_dir):
    """Return the last ``"event": "eval"`` record in ``arm_dir/log.jsonl``.

    Args:
        arm_dir: ``pathlib.Path`` of a single run directory.

    Returns:
        The decoded JSON object (a dict) of the last eval record, or ``None``
        when the log file is missing or contains no eval record.

    Blank lines and lines that are not valid JSON (for example a partially
    written final line) are skipped rather than aborting the analysis.
    """
    log_path = arm_dir / "log.jsonl"
    if not log_path.exists():
        return None
    last_eval = None
    with log_path.open(encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if isinstance(record, dict) and record.get("event") == "eval":
                last_eval = record
    return last_eval


def group_cos(gc):
    """Average grad_cos values per projection group.

    Args:
        gc: mapping of entry name -> cosine value.

    Returns:
        Dict mapping a short group key ("q", "k", "v", "o", "fc", "pr", "hd",
        or "?" for names that match no pattern) to the mean of that group's
        values. NaN and non-numeric values are ignored; a group with no valid
        value is omitted from the result.
    """
    buckets = {}
    for name, val in gc.items():
        key = next(
            (short for pattern, short in _GROUP_PATTERNS if pattern in name),
            "?",
        )
        buckets.setdefault(key, []).append(val)
    out = {}
    for key, vals in buckets.items():
        if any(_is_valid(x) for x in vals):
            out[key] = _mean(vals)
    return out


def _load_config(run_dir):
    """Return the dict stored in ``run_dir/config.json``, or ``{}``.

    A missing file, malformed JSON, or a non-object top level all yield ``{}``
    so a single bad config does not prevent the table from being printed.
    """
    cfg_path = run_dir / "config.json"
    if not cfg_path.exists():
        return {}
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return {}
    return cfg if isinstance(cfg, dict) else {}


def _build_row(run_dir, ev):
    """Build one summary-table row from a run directory and its final eval record."""
    cfg = _load_config(run_dir)
    gc = ev.get("grad_cos") or {}
    if not isinstance(gc, dict):
        gc = {}
    flags = [label for key, label in _FLAG_LABELS if cfg.get(key)]
    val = ev.get("val_loss")
    return {
        "name": run_dir.name,
        # str() guards the later ":4s" format against null / non-string values.
        "method": str(cfg.get("method") or "?"),
        "attn": str(cfg.get("attn_mode") or "?")[:3],
        "L": cfg.get("n_layer", "?"),
        "flags": "+".join(flags) if flags else "-",
        "val": val if _is_valid(val) else float("nan"),
        "μcos": _mean(gc.values()),
        "groups": group_cos(gc),
    }


def _format_table(rows):
    """Return the header, separator and one formatted line per row."""
    header = (
        f"{'name':24s} {'meth':4s} {'attn':4s} {'L':>2s} {'flags':20s} "
        f"{'val':>8s} {'μcos':>6s} "
        + " ".join(f"{col:>5s}" for col in _GROUP_COLUMNS)
    )
    lines = [header, "-" * len(header)]
    for r in rows:
        cells = " ".join(
            _fmt_float(r["groups"].get(col), 5, 2) for col in _GROUP_COLUMNS
        )
        lines.append(
            f"{r['name']:24s} {r['method']:4s} {r['attn']:4s} "
            f"{str(r['L']):>2s} {r['flags']:20s} "
            f"{_fmt_float(r['val'], 8, 4)} {_fmt_float(r['μcos'], 6, 3)} {cells}"
        )
    return lines


def main():
    """Print a summary table of every run under ``runs_local/``.

    Each sub-directory that has an eval record in its ``log.jsonl`` becomes one
    row. Rows are sorted by ascending final val_loss; runs without a numeric
    val_loss are listed last.

    Raises:
        SystemExit: if ``runs_local`` is not a directory.
    """
    runs_dir = Path("runs_local")
    if not runs_dir.is_dir():
        raise SystemExit(f"runs directory not found: {runs_dir}")

    rows = []
    for run_dir in sorted(runs_dir.iterdir()):
        if not run_dir.is_dir():
            continue
        ev = get_final(run_dir)
        if ev is None:
            continue
        rows.append(_build_row(run_dir, ev))

    # NaN never compares as smaller, so map it to +inf to push it to the end.
    rows.sort(key=lambda r: r["val"] if _is_valid(r["val"]) else math.inf)

    for line in _format_table(rows):
        print(line)


if __name__ == "__main__":
    main()
