summaryrefslogtreecommitdiff
path: root/ep_run/analyze_all.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/analyze_all.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/analyze_all.py')
-rw-r--r--ep_run/analyze_all.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/ep_run/analyze_all.py b/ep_run/analyze_all.py
new file mode 100644
index 0000000..f309190
--- /dev/null
+++ b/ep_run/analyze_all.py
@@ -0,0 +1,86 @@
+"""Analyze all runs in runs_local/ — produces a summary table sorted by val_loss.
+
+For each run: final val_loss, per-projection grad_cos breakdown, STE flags used.
+"""
+import json
+from pathlib import Path
+
+def get_final(arm_dir):
+ p = arm_dir / "log.jsonl"
+ if not p.exists():
+ return None
+ lines = p.read_text().strip().split("\n")
+ evals = [json.loads(l) for l in lines if l and json.loads(l).get("event") == "eval"]
+ if not evals:
+ return None
+ return evals[-1]
+
+def group_cos(gc):
+ groups = {}
+ for name, val in gc.items():
+ if "q_proj" in name: key = "q"
+ elif "k_proj" in name: key = "k"
+ elif "v_proj" in name: key = "v"
+ elif "o_proj" in name: key = "o"
+ elif "mlp.fc" in name: key = "fc"
+ elif "mlp.proj" in name: key = "pr"
+ elif "head" in name: key = "hd"
+ else: key = "?"
+ groups.setdefault(key, []).append(val)
+ out = {}
+ for k, v in groups.items():
+ valid = [x for x in v if x == x]
+ if valid:
+ out[k] = sum(valid) / len(valid)
+ return out
+
+def main():
+ runs_dir = Path("runs_local")
+ rows = []
+ for d in sorted(runs_dir.iterdir()):
+ if not d.is_dir():
+ continue
+ ev = get_final(d)
+ if ev is None:
+ continue
+ cfg_path = d / "config.json"
+ cfg = json.loads(cfg_path.read_text()) if cfg_path.exists() else {}
+ gc = ev.get("grad_cos") or {}
+ groups = group_cos(gc)
+ valid = [v for v in gc.values() if v == v]
+ mean_cos = sum(valid) / len(valid) if valid else float("nan")
+ flags = []
+ if cfg.get("ste_sigmoid"): flags.append("σSTE")
+ if cfg.get("ste_gelu"): flags.append("gSTE")
+ if cfg.get("ste_ln"): flags.append("lSTE")
+ if cfg.get("freeze_emb"): flags.append("frzE")
+ method = cfg.get("method", "?")
+ attn = cfg.get("attn_mode", "?")
+ nl = cfg.get("n_layer", "?")
+ rows.append({
+ "name": d.name,
+ "method": method,
+ "attn": attn[:3],
+ "L": nl,
+ "flags": "+".join(flags) if flags else "-",
+ "val": ev.get("val_loss", float("nan")),
+ "μcos": mean_cos,
+ "groups": groups,
+ })
+
+ rows.sort(key=lambda r: r["val"] if r["val"] == r["val"] else 999)
+
+ hdr = f"{'name':24s} {'meth':4s} {'attn':3s} {'L':>2s} {'flags':20s} {'val':>8s} {'μcos':>6s} {'hd':>5s} {'o':>5s} {'v':>5s} {'q':>5s} {'k':>5s} {'fc':>5s} {'pr':>5s}"
+ print(hdr)
+ print("-" * len(hdr))
+ for r in rows:
+ g = r["groups"]
+ def fmt(k):
+ v = g.get(k, float("nan"))
+ return f"{v:>5.2f}" if v == v else " nan"
+ val_s = f"{r['val']:>8.4f}" if r["val"] == r["val"] else " nan"
+ cos_s = f"{r['μcos']:>6.3f}" if r["μcos"] == r["μcos"] else " nan"
+ print(f"{r['name']:24s} {r['method']:4s} {r['attn']:3s} {str(r['L']):>2s} {r['flags']:20s} {val_s} {cos_s} {fmt('hd')} {fmt('o')} {fmt('v')} {fmt('q')} {fmt('k')} {fmt('fc')} {fmt('pr')}")
+
+if __name__ == "__main__":
+ main()