summaryrefslogtreecommitdiff
path: root/ep_run/analyze.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/analyze.py')
-rw-r--r--ep_run/analyze.py95
1 files changed, 95 insertions, 0 deletions
diff --git a/ep_run/analyze.py b/ep_run/analyze.py
new file mode 100644
index 0000000..9514773
--- /dev/null
+++ b/ep_run/analyze.py
@@ -0,0 +1,95 @@
+"""Load runs/*/log.jsonl, print comparison table, and save loss curves as ASCII/PNG."""
+import json
+import sys
+from pathlib import Path
+
+import numpy as np
+
+
+def load_run(run_dir: Path):
+ log_path = run_dir / "log.jsonl"
+ if not log_path.exists():
+ return None
+ steps, step_losses, evals = [], [], []
+ for line in log_path.read_text().splitlines():
+ rec = json.loads(line)
+ if rec.get("event") == "step":
+ steps.append(rec["iter"])
+ step_losses.append(rec["train_loss"])
+ elif rec.get("event") == "eval":
+ evals.append((rec["iter"], rec["train_loss"], rec["val_loss"]))
+ return {
+ "name": run_dir.name,
+ "steps": np.array(steps),
+ "step_losses": np.array(step_losses),
+ "evals": np.array(evals) if evals else np.zeros((0, 3)),
+ }
+
+
+def ascii_plot(runs, key_idx=2, width=60, height=15, title="val loss"):
+ """key_idx: 1 = train loss (from eval), 2 = val loss (from eval)."""
+ lines = [title]
+ all_y = np.concatenate([r["evals"][:, key_idx] for r in runs if len(r["evals"]) > 0])
+ all_x = np.concatenate([r["evals"][:, 0] for r in runs if len(r["evals"]) > 0])
+ if len(all_y) == 0:
+ return "(no eval data)"
+ ymin, ymax = float(all_y.min()), float(all_y.max())
+ xmin, xmax = float(all_x.min()), float(all_x.max())
+ ymin -= 0.02 * (ymax - ymin + 1e-9)
+ ymax += 0.02 * (ymax - ymin + 1e-9)
+ grid = [[" "] * width for _ in range(height)]
+ markers = {0: "o", 1: "x", 2: "+", 3: "*"}
+ for i, r in enumerate(runs):
+ if len(r["evals"]) == 0:
+ continue
+ mk = markers.get(i, "#")
+ for x, _tl, vl in r["evals"]:
+ col = int((x - xmin) / (xmax - xmin + 1e-9) * (width - 1))
+ row = height - 1 - int((vl - ymin) / (ymax - ymin + 1e-9) * (height - 1))
+ row = max(0, min(height - 1, row))
+ col = max(0, min(width - 1, col))
+ grid[row][col] = mk
+ lines.append(f" y: [{ymin:.3f} .. {ymax:.3f}] x: [{int(xmin)} .. {int(xmax)}]")
+ for i, row in enumerate(grid):
+ lines.append(" |" + "".join(row) + "|")
+ lines.append(" +" + "-" * width + "+")
+ legend = " legend: " + " ".join(
+ f"{markers.get(i, '#')}={r['name']}" for i, r in enumerate(runs)
+ )
+ lines.append(legend)
+ return "\n".join(lines)
+
+
+def main():
+ runs_dir = Path("runs")
+ run_names = sys.argv[1:] if len(sys.argv) > 1 else [
+ "softmax_baseline", "sigmoid_b0", "sigmoid_blogn",
+ ]
+ runs = []
+ for name in run_names:
+ r = load_run(runs_dir / name)
+ if r is None:
+ print(f"WARNING: {name}/log.jsonl missing")
+ continue
+ runs.append(r)
+
+ print("\n=== final losses ===")
+ print(f"{'run':<20s} {'final train':>12s} {'final val':>10s} {'best val':>10s} {'iter':>6s}")
+ for r in runs:
+ if len(r["evals"]) == 0:
+ print(f"{r['name']:<20s} (no evals)")
+ continue
+ last = r["evals"][-1]
+ best_idx = int(np.argmin(r["evals"][:, 2]))
+ best = r["evals"][best_idx]
+ print(
+ f"{r['name']:<20s} {last[1]:>12.4f} {last[2]:>10.4f} "
+ f"{best[2]:>10.4f} {int(best[0]):>6d}"
+ )
+
+ print("\n" + ascii_plot(runs, key_idx=2, title="val loss vs iter"))
+ print("\n" + ascii_plot(runs, key_idx=1, title="eval train loss vs iter"))
+
+
+if __name__ == "__main__":
+ main()