"""Analyze all runs in runs_local/ and print a summary table sorted by val_loss.

For each run directory that has at least one eval record in ``log.jsonl`` the
table shows: the final val_loss, the mean grad_cos overall and per projection
group, and the STE / freeze flags set in ``config.json``.
"""
import json
import math
from pathlib import Path

# (substring, group key) rules; the first matching substring wins, so order matters.
_GROUP_RULES = (
    ("q_proj", "q"),
    ("k_proj", "k"),
    ("v_proj", "v"),
    ("o_proj", "o"),
    ("mlp.fc", "fc"),
    ("mlp.proj", "pr"),
    ("head", "hd"),
)

# Order in which the per-group grad_cos columns are printed.
_GROUP_COLS = ("hd", "o", "v", "q", "k", "fc", "pr")

# (config key, short label) for boolean flags shown in the "flags" column.
_FLAG_LABELS = (
    ("ste_sigmoid", "σSTE"),
    ("ste_gelu", "gSTE"),
    ("ste_ln", "lSTE"),
    ("freeze_emb", "frzE"),
)


def _to_float(x):
    """Return *x* as a float; None or any non-numeric value becomes NaN."""
    if isinstance(x, (int, float)):
        return float(x)
    return math.nan


def _nanmean(values):
    """Mean of *values* ignoring NaN/None entries; NaN if nothing is left."""
    valid = [v for v in map(_to_float, values) if not math.isnan(v)]
    return sum(valid) / len(valid) if valid else math.nan


def _load_json(path):
    """Load a JSON object from *path*; {} if the file is missing, malformed, or not an object."""
    if not path.exists():
        return {}
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return {}
    return data if isinstance(data, dict) else {}


def get_final(arm_dir):
    """Return the last ``{"event": "eval", ...}`` record in ``arm_dir/log.jsonl``.

    Args:
        arm_dir: run directory (``Path`` or str).

    Returns:
        The last eval record as a dict, or None if the log is missing or has
        no eval records. Blank lines, lines that are not valid JSON, and JSON
        values that are not objects are skipped. This way a truncated line, for
        example one still being written, does not abort the report.
    """
    log_path = Path(arm_dir) / "log.jsonl"
    if not log_path.exists():
        return None
    last_eval = None
    with open(log_path, encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)  # parse each line exactly once
            except json.JSONDecodeError:
                continue
            if isinstance(record, dict) and record.get("event") == "eval":
                last_eval = record
    return last_eval


def group_cos(gc):
    """Average grad-cosine values per projection group.

    Args:
        gc: mapping of parameter name -> cosine value.

    Returns:
        Dict of group key ("q", "k", "v", "o", "fc", "pr", "hd", or "?" for
        unmatched names) -> mean of that group's non-NaN values. Groups whose
        values are all NaN/None are omitted.
    """
    groups = {}
    for name, val in gc.items():
        key = next((k for sub, k in _GROUP_RULES if sub in name), "?")
        groups.setdefault(key, []).append(val)
    out = {}
    for key, vals in groups.items():
        mean = _nanmean(vals)
        if not math.isnan(mean):
            out[key] = mean
    return out


def _build_row(run_dir, ev, cfg):
    """Collect the table fields for one run from its final eval record and config."""
    gc = ev.get("grad_cos") or {}
    if not isinstance(gc, dict):
        gc = {}
    flags = [label for key, label in _FLAG_LABELS if cfg.get(key)]
    return {
        "name": run_dir.name,
        # str() so a null / non-string config value cannot break the format spec.
        "method": str(cfg.get("method", "?")),
        "attn": str(cfg.get("attn_mode", "?"))[:3],
        "L": cfg.get("n_layer", "?"),
        "flags": "+".join(flags) if flags else "-",
        "val": _to_float(ev.get("val_loss")),
        "μcos": _nanmean(gc.values()),
        "groups": group_cos(gc),
    }


def _print_table(rows):
    """Print the header and one aligned line per row.

    The header and rows share the same column widths. NaN is rendered by the
    float format spec itself (e.g. "  nan"), so it stays aligned.
    """
    group_hdr = " ".join(f"{k:>5s}" for k in _GROUP_COLS)
    hdr = (
        f"{'name':24s} {'meth':4s} {'attn':4s} {'L':>2s} {'flags':20s} "
        f"{'val':>8s} {'μcos':>6s} {group_hdr}"
    )
    print(hdr)
    print("-" * len(hdr))
    for r in rows:
        group_vals = " ".join(
            f"{r['groups'].get(k, math.nan):>5.2f}" for k in _GROUP_COLS
        )
        print(
            f"{r['name']:24s} {r['method']:4s} {r['attn']:4s} {str(r['L']):>2s} "
            f"{r['flags']:20s} {r['val']:>8.4f} {r['μcos']:>6.3f} {group_vals}"
        )


def main(runs_dir="runs_local"):
    """Summarize every run directory under *runs_dir*, best val_loss first.

    Args:
        runs_dir: directory that contains one sub-directory per run.
    """
    runs_dir = Path(runs_dir)
    rows = []
    for run_dir in sorted(runs_dir.iterdir()):
        if not run_dir.is_dir():
            continue
        ev = get_final(run_dir)
        if ev is None:
            continue
        rows.append(_build_row(run_dir, ev, _load_json(run_dir / "config.json")))
    # NaN losses sort last regardless of magnitude (no magic sentinel value).
    rows.sort(key=lambda r: (math.isnan(r["val"]), r["val"]))
    _print_table(rows)


if __name__ == "__main__":
    main()