"""Plot per-epoch metrics of training runs, comparing baseline and Lyapunov runs.

Each run is a sub-directory of ``--base_dir`` that contains an ``args.json``
file and a ``metrics.csv`` file.
"""

import argparse
import csv
import json
import os
import sys
from glob import glob
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt


def _parse_optional_float(raw) -> float:
    """Convert a CSV cell to float, mapping ``None``, ``""`` and "nan" to NaN.

    Raises:
        ValueError: if the cell holds any other text that is not a number.
    """
    if raw is None or raw == "" or str(raw).lower() == "nan":
        return float("nan")
    return float(raw)


def load_run(run_dir: str) -> Dict:
    """Load the arguments and per-epoch metrics stored in ``run_dir``.

    Only rows of ``metrics.csv`` whose ``step`` column equals ``"epoch"`` are
    used; rows whose ``epoch``, ``loss``, ``acc`` or ``lyap`` cell cannot be
    parsed are skipped. An empty or missing ``lyap`` cell becomes NaN.

    Args:
        run_dir: Directory expected to contain ``args.json`` and
            ``metrics.csv``.

    Returns:
        A dict with the keys ``args``, ``epochs``, ``loss``, ``acc``, ``lyap``
        and ``run_dir``, or an empty dict when either file is missing,
        ``args.json`` is not a valid JSON object, or no usable epoch row
        exists.
    """
    args_path = os.path.join(run_dir, "args.json")
    metrics_path = os.path.join(run_dir, "metrics.csv")
    if not (os.path.exists(args_path) and os.path.exists(metrics_path)):
        return {}

    try:
        with open(args_path, "r", encoding="utf-8") as f:
            args = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        # One corrupt run must not abort the scan of all the other runs.
        print(f"Skipping {run_dir}: cannot parse args.json ({err})", file=sys.stderr)
        return {}
    if not isinstance(args, dict):
        # label_for_run() and plot_runs() call args.get(), so require a mapping.
        print(f"Skipping {run_dir}: args.json is not a JSON object", file=sys.stderr)
        return {}

    epochs = []
    loss = []
    acc = []
    lyap = []
    # newline="" is what the csv module documents for files handed to a reader.
    with open(metrics_path, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            if row.get("step", "") != "epoch":
                continue
            try:
                epoch = int(row.get("epoch", "0"))
                loss_val = float(row.get("loss", "nan"))
                acc_val = float(row.get("acc", "nan"))
                lyap_val = _parse_optional_float(row.get("lyap", "nan"))
            except (TypeError, ValueError):
                # TypeError: a short row leaves the cell as None.
                # ValueError: the cell holds non-numeric text.
                continue
            epochs.append(epoch)
            loss.append(loss_val)
            acc.append(acc_val)
            lyap.append(lyap_val)

    if not epochs:
        return {}
    return {
        "args": args,
        "epochs": epochs,
        "loss": loss,
        "acc": acc,
        "lyap": lyap,
        "run_dir": run_dir,
    }


def label_for_run(run: Dict) -> str:
    """Build the legend label of a run from its stored arguments.

    Runs whose ``lyapunov`` argument is truthy are labelled with their
    ``lambda_reg``, ``lambda_target`` and ``hidden`` values; all other runs
    are labelled as baseline with their ``hidden`` value.
    """
    args = run["args"]
    hid = args.get("hidden", None)
    if args.get("lyapunov", False):
        lam = args.get("lambda_reg", None)
        tgt = args.get("lambda_target", None)
        return f"Lyap λ={lam}, tgt={tgt}, H={hid}"
    return f"Baseline H={hid}"


def gather_runs(base_dir: str) -> List[Dict]:
    """Return the loaded data of every valid run directly under ``base_dir``.

    Entries are visited in sorted path order; entries for which
    ``load_run`` returns an empty dict are left out.
    """
    candidates = sorted(glob(os.path.join(base_dir, "*")))
    loaded = (load_run(path) for path in candidates)
    return [run for run in loaded if run]


def _plot_metric_panel(
    ax,
    base_runs: List[Dict],
    lyap_runs: List[Dict],
    key: str,
    title: str,
    ylabel: str,
) -> None:
    """Draw ``run[key]`` against epochs: baseline solid, Lyapunov runs dashed."""
    for run in base_runs:
        ax.plot(run["epochs"], run[key], label=label_for_run(run), alpha=0.9)
    for run in lyap_runs:
        ax.plot(
            run["epochs"],
            run[key],
            label=label_for_run(run),
            linestyle="--",
            alpha=0.9,
        )
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=8)


def plot_runs(runs: List[Dict], out_path: str) -> None:
    """Render a three-panel summary figure and save it to ``out_path``.

    The panels show the ``loss`` series, the ``acc`` series and, for runs
    whose ``lyapunov`` argument is truthy, the ``lyap`` series.

    Args:
        runs: Run dicts as returned by ``load_run``.
        out_path: Destination image path. Its parent directory is created
            when missing; a bare file name is written to the current
            directory.

    Raises:
        SystemExit: if ``runs`` is empty.
    """
    if not runs:
        raise SystemExit("No valid runs found (expected args.json and metrics.csv under base_dir/*)")

    # Split baseline vs lyapunov
    base = [r for r in runs if not r["args"].get("lyapunov", False)]
    lyap = [r for r in runs if r["args"].get("lyapunov", False)]

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
    try:
        _plot_metric_panel(axes[0], base, lyap, "loss", "Training loss", "Loss")
        _plot_metric_panel(axes[1], base, lyap, "acc", "Training accuracy", "Accuracy")

        # Lyapunov estimate (only lyap runs)
        ax = axes[2]
        if lyap:
            for r in lyap:
                ax.plot(r["epochs"], r["lyap"], label=label_for_run(r), alpha=0.9)
            ax.set_title("Surrogate Lyapunov estimate")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Avg log growth")
            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=8)
        else:
            ax.text(
                0.5,
                0.5,
                "No Lyapunov runs found",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
            ax.set_axis_off()

        fig.tight_layout()
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=150)
    finally:
        # pyplot keeps figures registered until closed; release this one.
        plt.close(fig)
    print(f"Saved figure to {out_path}")


def main() -> None:
    """Parse command-line options, load the runs and write the figure."""
    ap = argparse.ArgumentParser(description="Plot MVP runs: baseline vs Lyapunov comparison")
    ap.add_argument("--base_dir", type=str, default="runs/mvp", help="Directory containing run subfolders")
    ap.add_argument("--out", type=str, default="runs/mvp_summary.png", help="Output figure path")
    args = ap.parse_args()

    runs = gather_runs(args.base_dir)
    plot_runs(runs, args.out)


if __name__ == "__main__":
    main()