summaryrefslogtreecommitdiff
path: root/ep_run/analyze.py
blob: 9514773f60544864b2a1066710cf7e8ce5814cfd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Compare training runs from runs/<name>/log.jsonl.

Prints a table of final/best losses and ASCII plots of eval loss vs. iteration.
Nothing is written to disk (no PNG output).
"""
import json
import sys
from pathlib import Path

import numpy as np


def load_run(run_dir: Path):
    """Parse ``<run_dir>/log.jsonl`` into per-step and per-eval arrays.

    Each non-blank line of the log is a JSON object. Records with
    ``"event": "step"`` contribute ``iter`` and ``train_loss``; records with
    ``"event": "eval"`` contribute ``(iter, train_loss, val_loss)``. Records
    with any other (or no) ``event`` are ignored, and blank lines are skipped.

    Args:
        run_dir: directory of one run; its name becomes the run's name.

    Returns:
        ``None`` if ``log.jsonl`` does not exist, otherwise a dict with
        ``"name"`` (the directory name), ``"steps"`` and ``"step_losses"``
        (1-D arrays, one entry per step record) and ``"evals"`` (an ``(n, 3)``
        array of ``(iter, train_loss, val_loss)``; shape ``(0, 3)`` when there
        are no eval records).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a step/eval record lacks one of the fields listed above.
    """
    log_path = run_dir / "log.jsonl"
    if not log_path.exists():
        return None
    steps, step_losses, evals = [], [], []
    with log_path.open(encoding="utf-8") as fh:
        for line in fh:
            # Blank lines (e.g. stray newlines from the writer) are not records.
            if not line.strip():
                continue
            rec = json.loads(line)
            event = rec.get("event")
            if event == "step":
                steps.append(rec["iter"])
                step_losses.append(rec["train_loss"])
            elif event == "eval":
                evals.append((rec["iter"], rec["train_loss"], rec["val_loss"]))
    return {
        "name": run_dir.name,
        "steps": np.array(steps),
        "step_losses": np.array(step_losses),
        # Keep a 2-D shape even when empty so callers can index [:, k] safely.
        "evals": np.array(evals) if evals else np.zeros((0, 3)),
    }


def ascii_plot(runs, key_idx=2, width=60, height=15, title="val loss"):
    """Render eval losses of several runs as a single ASCII scatter plot.

    Args:
        runs: run dicts with a ``"name"`` and an ``"evals"`` array whose
            columns are ``(iter, train_loss, val_loss)``.
        key_idx: column of ``evals`` plotted on the y axis:
            1 = train loss (from eval), 2 = val loss (from eval).
        width: plot area width in characters.
        height: plot area height in characters.
        title: first line of the returned text.

    Returns:
        The plot as a multi-line string (title, axis ranges, grid, legend), or
        ``"(no eval data)"`` when no run has any eval records.
    """
    plotted = [r for r in runs if len(r["evals"]) > 0]
    if not plotted:
        # np.concatenate([]) raises, so bail out before touching the data.
        return "(no eval data)"
    all_x = np.concatenate([r["evals"][:, 0] for r in plotted])
    all_y = np.concatenate([r["evals"][:, key_idx] for r in plotted])
    xmin, xmax = float(all_x.min()), float(all_x.max())
    ymin, ymax = float(all_y.min()), float(all_y.max())
    # Pad the y range by 2% of the data span on both sides; the span is computed
    # once so the padding is symmetric.
    pad = 0.02 * (ymax - ymin + 1e-9)
    ymin, ymax = ymin - pad, ymax + pad
    # Guard zero-width ranges (single eval / constant loss) without biasing the
    # mapping, so the extreme values land on the edge cells.
    xspan = max(xmax - xmin, 1e-9)
    yspan = max(ymax - ymin, 1e-9)

    markers = {0: "o", 1: "x", 2: "+", 3: "*"}
    grid = [[" "] * width for _ in range(height)]
    # Enumerate over all runs (not just plotted ones) so marker assignment
    # matches the legend below.
    for i, r in enumerate(runs):
        if len(r["evals"]) == 0:
            continue
        mk = markers.get(i, "#")
        for x, y in zip(r["evals"][:, 0], r["evals"][:, key_idx]):
            col = int((x - xmin) / xspan * (width - 1))
            # Row 0 is the top of the plot, so larger y maps to smaller rows.
            row = height - 1 - int((y - ymin) / yspan * (height - 1))
            row = max(0, min(height - 1, row))
            col = max(0, min(width - 1, col))
            grid[row][col] = mk

    lines = [
        title,
        f"  y: [{ymin:.3f} .. {ymax:.3f}]   x: [{int(xmin)} .. {int(xmax)}]",
    ]
    lines.extend("  |" + "".join(row) + "|" for row in grid)
    lines.append("  +" + "-" * width + "+")
    lines.append(
        "  legend: "
        + "  ".join(f"{markers.get(i, '#')}={r['name']}" for i, r in enumerate(runs))
    )
    return "\n".join(lines)


def main():
    """Print a loss comparison table and ASCII loss plots for runs under ``runs/``.

    Run names come from the command-line arguments, or default to the list
    below. Runs whose ``log.jsonl`` is missing are reported on stderr and
    skipped. For each remaining run the table shows the final eval train/val
    loss, the best val loss, and the iteration at which that best val loss was
    recorded; then ASCII plots of eval val loss and eval train loss vs.
    iteration are printed.
    """
    runs_dir = Path("runs")
    run_names = sys.argv[1:] or ["softmax_baseline", "sigmoid_b0", "sigmoid_blogn"]
    runs = []
    for name in run_names:
        r = load_run(runs_dir / name)
        if r is None:
            # stderr keeps the table on stdout clean when output is piped.
            print(f"WARNING: {name}/log.jsonl missing", file=sys.stderr)
            continue
        runs.append(r)
    if not runs:
        print("no run logs found; nothing to analyze", file=sys.stderr)
        return

    # Widen the name column for long run names instead of breaking alignment.
    name_w = max([20] + [len(r["name"]) for r in runs])
    print("\n=== final losses ===")
    print(
        f"{'run':<{name_w}s}  {'final train':>12s}  {'final val':>10s}  "
        f"{'best val':>10s}  {'best iter':>9s}"
    )
    for r in runs:
        evals = r["evals"]
        if len(evals) == 0:
            print(f"{r['name']:<{name_w}s}  (no evals)")
            continue
        last = evals[-1]
        best = evals[int(np.argmin(evals[:, 2]))]
        print(
            f"{r['name']:<{name_w}s}  {last[1]:>12.4f}  {last[2]:>10.4f}  "
            f"{best[2]:>10.4f}  {int(best[0]):>9d}"
        )

    print("\n" + ascii_plot(runs, key_idx=2, title="val loss vs iter"))
    print("\n" + ascii_plot(runs, key_idx=1, title="eval train loss vs iter"))


# Run the analysis only when executed as a script; importing the module just
# defines load_run / ascii_plot / main.
if __name__ == "__main__":
    main()