# path: root/ep_run/analyze_all.py
# blob: f309190ee0d98d9aea3142ae1f0da968c52300b9 (plain)
"""Analyze all runs in runs_local/ — produces a summary table sorted by val_loss.

For each run: final val_loss, per-projection grad_cos breakdown, STE flags used.
"""
import json
from pathlib import Path

def get_final(arm_dir):
    """Return the last ``eval`` record logged for a run, or ``None``.

    Reads ``arm_dir / "log.jsonl"`` (one JSON value per line) and returns the
    final record whose ``"event"`` field equals ``"eval"``.

    Args:
        arm_dir: ``pathlib.Path`` of the run directory.

    Returns:
        The decoded dict of the last eval record, or ``None`` when the log
        file does not exist or holds no eval record.
    """
    log_path = arm_dir / "log.jsonl"
    if not log_path.exists():
        return None

    last_eval = None
    # Split on "\n" only: str.splitlines() would also break on characters
    # such as U+2028 that may legally appear inside a JSON string.
    for raw in log_path.read_text(encoding="utf-8").split("\n"):
        raw = raw.strip()
        if not raw:
            continue
        try:
            record = json.loads(raw)  # decode each line exactly once
        except json.JSONDecodeError:
            # A run that is still writing (or was killed) can leave a
            # truncated line; skip it instead of failing the whole run.
            continue
        # Only JSON objects can carry an "event" field.
        if isinstance(record, dict) and record.get("event") == "eval":
            last_eval = record
    return last_eval

def group_cos(gc):
    """Average the ``grad_cos`` values per projection group.

    Each name in ``gc`` is assigned to the first group whose substring it
    contains (``q_proj`` -> ``"q"``, ``k_proj`` -> ``"k"``, ``v_proj`` ->
    ``"v"``, ``o_proj`` -> ``"o"``, ``mlp.fc`` -> ``"fc"``, ``mlp.proj`` ->
    ``"pr"``, ``head`` -> ``"hd"``); names matching none go to ``"?"``.

    Args:
        gc: Mapping from name to numeric value.

    Returns:
        Dict mapping group key to the mean of its non-NaN values. Groups
        whose values are all NaN are left out.
    """
    # Ordered (substring, group key) pairs; the first match wins.
    patterns = (
        ("q_proj", "q"),
        ("k_proj", "k"),
        ("v_proj", "v"),
        ("o_proj", "o"),
        ("mlp.fc", "fc"),
        ("mlp.proj", "pr"),
        ("head", "hd"),
    )

    buckets = {}
    for param_name, cos in gc.items():
        label = next((tag for needle, tag in patterns if needle in param_name), "?")
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(cos)

    means = {}
    for label, values in buckets.items():
        non_nan = [c for c in values if c == c]  # NaN is the only value != itself
        if non_nan:
            means[label] = sum(non_nan) / len(non_nan)
    return means

def main():
    """Print a table of all runs under ``runs_local/``, lowest val_loss first.

    For every sub-directory that has a final eval record (see ``get_final``),
    collects the run's config flags, its final ``val_loss``, the mean of all
    non-NaN ``grad_cos`` values and the per-group means from ``group_cos``,
    then prints one fixed-width row per run. Runs without a numeric
    ``val_loss`` are listed last.

    Raises:
        SystemExit: if ``runs_local`` is not a directory relative to the
            current working directory.
    """
    nan = float("nan")
    # (config key, label shown in the "flags" column)
    flag_labels = (
        ("ste_sigmoid", "σSTE"),
        ("ste_gelu", "gSTE"),
        ("ste_ln", "lSTE"),
        ("freeze_emb", "frzE"),
    )
    # Group columns, in display order.
    group_keys = ("hd", "o", "v", "q", "k", "fc", "pr")

    runs_dir = Path("runs_local")
    if not runs_dir.is_dir():
        raise SystemExit(f"no such directory: {runs_dir}")

    rows = []
    for d in sorted(runs_dir.iterdir()):
        if not d.is_dir():
            continue
        ev = get_final(d)
        if ev is None:
            continue
        cfg_path = d / "config.json"
        cfg = json.loads(cfg_path.read_text()) if cfg_path.exists() else {}
        gc = ev.get("grad_cos") or {}
        valid = [v for v in gc.values() if v == v]  # drop NaN entries
        val = ev.get("val_loss")
        # A missing, null or non-numeric val_loss is treated as NaN so the
        # numeric formatting and sorting below cannot fail on it.
        if isinstance(val, bool) or not isinstance(val, (int, float)):
            val = nan
        flags = [label for key, label in flag_labels if cfg.get(key)]
        rows.append({
            "name": d.name,
            # str(...) guards the "s" format spec against non-string values.
            "method": str(cfg.get("method") or "?"),
            "attn": str(cfg.get("attn_mode") or "?")[:3],
            "L": cfg.get("n_layer", "?"),
            "flags": "+".join(flags) if flags else "-",
            "val": val,
            "μcos": sum(valid) / len(valid) if valid else nan,
            "groups": group_cos(gc),
        })

    # NaN sorts as +inf so it always lands after every real loss value.
    rows.sort(key=lambda r: r["val"] if r["val"] == r["val"] else float("inf"))

    # The attn column is three characters wide in the rows, so its header
    # label must be too; a four-character label shifts all later headers.
    hdr = (
        f"{'name':24s} {'meth':4s} {'att':3s} {'L':>2s} {'flags':20s} "
        f"{'val':>8s} {'μcos':>6s}  "
        + " ".join(f"{k:>5s}" for k in group_keys)
    )
    print(hdr)
    print("-" * len(hdr))
    for r in rows:
        g = r["groups"]
        # The float format spec renders NaN as a right-aligned "nan".
        cells = " ".join(f"{g.get(k, nan):>5.2f}" for k in group_keys)
        print(
            f"{r['name']:24s} {r['method']:4s} {r['attn']:3s} {str(r['L']):>2s} "
            f"{r['flags']:20s} {r['val']:>8.4f} {r['μcos']:>6.3f}  {cells}"
        )

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()