1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
"""Analyze all runs in runs_local/ — produces a summary table sorted by val_loss.
For each run: final val_loss, per-projection grad_cos breakdown, STE flags used.
"""
import json
from pathlib import Path
def get_final(arm_dir):
    """Return the last ``eval`` record logged for a run, or ``None``.

    Reads ``log.jsonl`` inside *arm_dir* (one JSON value per line) and
    returns the final record whose ``"event"`` field equals ``"eval"``.

    Args:
        arm_dir: ``pathlib.Path`` of the run directory.

    Returns:
        The decoded dict of the last eval record, or ``None`` when the log
        file does not exist or contains no eval record.
    """
    log_path = arm_dir / "log.jsonl"
    if not log_path.exists():
        return None

    last_eval = None
    for raw_line in log_path.read_text(encoding="utf-8").split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        try:
            # Decode each line exactly once.
            record = json.loads(line)
        except json.JSONDecodeError:
            # A partially written line (e.g. the run was interrupted while
            # appending) must not abort the analysis of the earlier records.
            continue
        # Only JSON objects can carry an "event" field.
        if isinstance(record, dict) and record.get("event") == "eval":
            last_eval = record
    return last_eval
def group_cos(gc):
    """Average the values of *gc* per projection group.

    Each key of *gc* is matched against a fixed list of substrings (first
    match wins) and its value is placed in the corresponding group:
    ``q``, ``k``, ``v``, ``o``, ``fc``, ``pr``, ``hd``, or ``?`` when no
    substring matches.

    Args:
        gc: Mapping from a name string to a numeric value.

    Returns:
        Dict mapping each group label to the mean of its non-NaN values.
        Groups whose values are all NaN are left out.
    """
    # Checked in order; the first substring found in the name decides.
    needle_to_label = (
        ("q_proj", "q"),
        ("k_proj", "k"),
        ("v_proj", "v"),
        ("o_proj", "o"),
        ("mlp.fc", "fc"),
        ("mlp.proj", "pr"),
        ("head", "hd"),
    )

    buckets = {}
    for param_name, cos in gc.items():
        label = "?"
        for needle, candidate in needle_to_label:
            if needle in param_name:
                label = candidate
                break
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(cos)

    averages = {}
    for label, members in buckets.items():
        # NaN is the only value that does not compare equal to itself.
        usable = [c for c in members if c == c]
        if not usable:
            continue
        averages[label] = sum(usable) / len(usable)
    return averages
def main():
    """Print a summary table of every run under ``runs_local/``.

    For each sub-directory that has a final eval record (see ``get_final``),
    collects the settings from its optional ``config.json``, the final
    ``val_loss``, the mean of all non-NaN ``grad_cos`` values and the
    per-group means from ``group_cos``. Rows are printed sorted by
    ``val_loss`` ascending, with NaN losses last.

    Raises:
        FileNotFoundError: if ``runs_local`` does not exist.
    """
    runs_dir = Path("runs_local")
    nan = float("nan")
    # Per-group columns, in display order.
    group_cols = ("hd", "o", "v", "q", "k", "fc", "pr")
    # config key -> short tag shown in the "flags" column.
    flag_labels = (
        ("ste_sigmoid", "σSTE"),
        ("ste_gelu", "gSTE"),
        ("ste_ln", "lSTE"),
        ("freeze_emb", "frzE"),
    )

    rows = []
    for d in sorted(runs_dir.iterdir()):
        if not d.is_dir():
            continue
        ev = get_final(d)
        if ev is None:
            continue
        cfg_path = d / "config.json"
        cfg = json.loads(cfg_path.read_text(encoding="utf-8")) if cfg_path.exists() else {}

        gc = ev.get("grad_cos") or {}
        valid = [v for v in gc.values() if v == v]  # drop NaN entries
        mean_cos = sum(valid) / len(valid) if valid else nan

        # A missing or null val_loss is treated as NaN so that sorting and
        # float formatting below cannot fail on it.
        val = ev.get("val_loss")
        if isinstance(val, bool) or not isinstance(val, (int, float)):
            val = nan

        flags = [label for key, label in flag_labels if cfg.get(key)]
        rows.append({
            "name": d.name,
            # str(... or "?") keeps the "s" format spec valid for null values.
            "method": str(cfg.get("method") or "?"),
            "attn": str(cfg.get("attn_mode") or "?")[:3],
            "L": cfg.get("n_layer", "?"),
            "flags": "+".join(flags) if flags else "-",
            "val": val,
            "μcos": mean_cos,
            "groups": group_cos(gc),
        })

    # NaN losses sort last; +inf (unlike a finite sentinel) can never be
    # exceeded by a real loss value.
    rows.sort(key=lambda r: r["val"] if r["val"] == r["val"] else float("inf"))

    # Header and rows use the same widths for every column so they line up.
    header_cells = [
        f"{'name':24s}",
        f"{'meth':4s}",
        f"{'attn':4s}",
        f"{'L':>2s}",
        f"{'flags':20s}",
        f"{'val':>8s}",
        f"{'μcos':>6s}",
    ]
    header_cells += [f"{g:>5s}" for g in group_cols]
    hdr = " ".join(header_cells)
    print(hdr)
    print("-" * len(hdr))

    for r in rows:
        groups = r["groups"]
        # Float format specs render NaN as a right-aligned "nan", so missing
        # values keep their column width without a special case.
        cells = [
            f"{r['name']:24s}",
            f"{r['method']:4s}",
            f"{r['attn']:4s}",
            f"{str(r['L']):>2s}",
            f"{r['flags']:20s}",
            f"{r['val']:>8.4f}",
            f"{r['μcos']:>6.3f}",
        ]
        cells += [f"{groups.get(g, nan):>5.2f}" for g in group_cols]
        print(" ".join(cells))
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|