summaryrefslogtreecommitdiff
path: root/files/analysis/plot_mvp.py
diff options
context:
space:
mode:
Diffstat (limited to 'files/analysis/plot_mvp.py')
-rw-r--r--files/analysis/plot_mvp.py139
1 files changed, 139 insertions, 0 deletions
diff --git a/files/analysis/plot_mvp.py b/files/analysis/plot_mvp.py
new file mode 100644
index 0000000..495a9c2
--- /dev/null
+++ b/files/analysis/plot_mvp.py
@@ -0,0 +1,139 @@
+import argparse
+import csv
+import json
+import os
+from glob import glob
+from typing import Dict, List, Tuple
+
+import matplotlib.pyplot as plt
+
+
+def load_run(run_dir: str) -> Dict:
+ args_path = os.path.join(run_dir, "args.json")
+ metrics_path = os.path.join(run_dir, "metrics.csv")
+ if not (os.path.exists(args_path) and os.path.exists(metrics_path)):
+ return {}
+ with open(args_path, "r") as f:
+ args = json.load(f)
+ epochs = []
+ loss = []
+ acc = []
+ lyap = []
+ with open(metrics_path, "r") as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+ if row.get("step", "") != "epoch":
+ continue
+ try:
+ e = int(row.get("epoch", "0"))
+ l = float(row.get("loss", "nan"))
+ a = float(row.get("acc", "nan"))
+ y = row.get("lyap", "nan")
+ y = float("nan") if (y is None or y == "" or str(y).lower() == "nan") else float(y)
+ except Exception:
+ continue
+ epochs.append(e)
+ loss.append(l)
+ acc.append(a)
+ lyap.append(y)
+ if not epochs:
+ return {}
+ return {
+ "args": args,
+ "epochs": epochs,
+ "loss": loss,
+ "acc": acc,
+ "lyap": lyap,
+ "run_dir": run_dir,
+ }
+
+
+def label_for_run(run: Dict) -> str:
+ args = run["args"]
+ if args.get("lyapunov", False):
+ lam = args.get("lambda_reg", None)
+ tgt = args.get("lambda_target", None)
+ hid = args.get("hidden", None)
+ return f"Lyap λ={lam}, tgt={tgt}, H={hid}"
+ else:
+ hid = args.get("hidden", None)
+ return f"Baseline H={hid}"
+
+
+def gather_runs(base_dir: str) -> List[Dict]:
+ cand = sorted(glob(os.path.join(base_dir, "*")))
+ runs = []
+ for rd in cand:
+ data = load_run(rd)
+ if data:
+ runs.append(data)
+ return runs
+
+
+def plot_runs(runs: List[Dict], out_path: str):
+ if not runs:
+ raise SystemExit("No valid runs found (expected args.json and metrics.csv under base_dir/*)")
+ # Split baseline vs lyapunov
+ base = [r for r in runs if not r["args"].get("lyapunov", False)]
+ lyap = [r for r in runs if r["args"].get("lyapunov", False)]
+
+ fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
+
+ # Loss
+ ax = axes[0]
+ for r in base:
+ ax.plot(r["epochs"], r["loss"], label=label_for_run(r), alpha=0.9)
+ for r in lyap:
+ ax.plot(r["epochs"], r["loss"], label=label_for_run(r), linestyle="--", alpha=0.9)
+ ax.set_title("Training loss")
+ ax.set_xlabel("Epoch")
+ ax.set_ylabel("Loss")
+ ax.grid(True, alpha=0.3)
+ ax.legend(fontsize=8)
+
+ # Accuracy
+ ax = axes[1]
+ for r in base:
+ ax.plot(r["epochs"], r["acc"], label=label_for_run(r), alpha=0.9)
+ for r in lyap:
+ ax.plot(r["epochs"], r["acc"], label=label_for_run(r), linestyle="--", alpha=0.9)
+ ax.set_title("Training accuracy")
+ ax.set_xlabel("Epoch")
+ ax.set_ylabel("Accuracy")
+ ax.grid(True, alpha=0.3)
+ ax.legend(fontsize=8)
+
+ # Lyapunov estimate (only lyap runs)
+ ax = axes[2]
+ if lyap:
+ for r in lyap:
+ ax.plot(r["epochs"], r["lyap"], label=label_for_run(r), alpha=0.9)
+ ax.set_title("Surrogate Lyapunov estimate")
+ ax.set_xlabel("Epoch")
+ ax.set_ylabel("Avg log growth")
+ ax.grid(True, alpha=0.3)
+ ax.legend(fontsize=8)
+ else:
+ ax.text(0.5, 0.5, "No Lyapunov runs found", ha="center", va="center", transform=ax.transAxes)
+ ax.set_axis_off()
+
+ fig.tight_layout()
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ fig.savefig(out_path, dpi=150)
+ print(f"Saved figure to {out_path}")
+
+
+def main():
+ ap = argparse.ArgumentParser(description="Plot MVP runs: baseline vs Lyapunov comparison")
+ ap.add_argument("--base_dir", type=str, default="runs/mvp", help="Directory containing run subfolders")
+ ap.add_argument("--out", type=str, default="runs/mvp_summary.png", help="Output figure path")
+ args = ap.parse_args()
+
+ runs = gather_runs(args.base_dir)
+ plot_runs(runs, args.out)
+
+
+if __name__ == "__main__":
+ main()
+
+