From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 13 Jan 2026 23:49:05 -0600 Subject: init commit --- files/analysis/plot_mvp.py | 139 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 files/analysis/plot_mvp.py (limited to 'files/analysis/plot_mvp.py') diff --git a/files/analysis/plot_mvp.py b/files/analysis/plot_mvp.py new file mode 100644 index 0000000..495a9c2 --- /dev/null +++ b/files/analysis/plot_mvp.py @@ -0,0 +1,139 @@ +import argparse +import csv +import json +import os +from glob import glob +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt + + +def load_run(run_dir: str) -> Dict: + args_path = os.path.join(run_dir, "args.json") + metrics_path = os.path.join(run_dir, "metrics.csv") + if not (os.path.exists(args_path) and os.path.exists(metrics_path)): + return {} + with open(args_path, "r") as f: + args = json.load(f) + epochs = [] + loss = [] + acc = [] + lyap = [] + with open(metrics_path, "r") as f: + reader = csv.DictReader(f) + for row in reader: + if row.get("step", "") != "epoch": + continue + try: + e = int(row.get("epoch", "0")) + l = float(row.get("loss", "nan")) + a = float(row.get("acc", "nan")) + y = row.get("lyap", "nan") + y = float("nan") if (y is None or y == "" or str(y).lower() == "nan") else float(y) + except Exception: + continue + epochs.append(e) + loss.append(l) + acc.append(a) + lyap.append(y) + if not epochs: + return {} + return { + "args": args, + "epochs": epochs, + "loss": loss, + "acc": acc, + "lyap": lyap, + "run_dir": run_dir, + } + + +def label_for_run(run: Dict) -> str: + args = run["args"] + if args.get("lyapunov", False): + lam = args.get("lambda_reg", None) + tgt = args.get("lambda_target", None) + hid = args.get("hidden", None) + return f"Lyap λ={lam}, tgt={tgt}, H={hid}" + else: + hid = args.get("hidden", None) + return f"Baseline H={hid}" + + +def gather_runs(base_dir: str) -> List[Dict]: + cand = sorted(glob(os.path.join(base_dir, "*"))) + runs = [] + for rd in cand: + data = load_run(rd) + if data: + runs.append(data) + return runs + + +def plot_runs(runs: List[Dict], out_path: str): + if not runs: + raise SystemExit("No valid runs found (expected args.json and metrics.csv under base_dir/*)") + # Split baseline vs lyapunov + base = [r for r in runs if not r["args"].get("lyapunov", False)] + lyap = [r for r in runs if r["args"].get("lyapunov", False)] + + fig, axes = plt.subplots(1, 3, figsize=(14, 4.5)) + + # Loss + ax = axes[0] + for r in base: + ax.plot(r["epochs"], r["loss"], label=label_for_run(r), alpha=0.9) + for r in lyap: + ax.plot(r["epochs"], r["loss"], label=label_for_run(r), linestyle="--", alpha=0.9) + ax.set_title("Training loss") + ax.set_xlabel("Epoch") + ax.set_ylabel("Loss") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=8) + + # Accuracy + ax = axes[1] + for r in base: + ax.plot(r["epochs"], r["acc"], label=label_for_run(r), alpha=0.9) + for r in lyap: + ax.plot(r["epochs"], r["acc"], label=label_for_run(r), linestyle="--", alpha=0.9) + ax.set_title("Training accuracy") + ax.set_xlabel("Epoch") + ax.set_ylabel("Accuracy") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=8) + + # Lyapunov estimate (only lyap runs) + ax = axes[2] + if lyap: + for r in lyap: + ax.plot(r["epochs"], r["lyap"], label=label_for_run(r), alpha=0.9) + ax.set_title("Surrogate Lyapunov estimate") + ax.set_xlabel("Epoch") + ax.set_ylabel("Avg log growth") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=8) + else: + ax.text(0.5, 0.5, "No Lyapunov runs found", ha="center", va="center", transform=ax.transAxes) + ax.set_axis_off() + + fig.tight_layout() + os.makedirs(os.path.dirname(out_path), exist_ok=True) + fig.savefig(out_path, dpi=150) + print(f"Saved figure to {out_path}") + + +def main(): + ap = argparse.ArgumentParser(description="Plot MVP runs: baseline vs Lyapunov comparison") + ap.add_argument("--base_dir", type=str, default="runs/mvp", help="Directory containing run subfolders") + ap.add_argument("--out", type=str, default="runs/mvp_summary.png", help="Output figure path") + args = ap.parse_args() + + runs = gather_runs(args.base_dir) + plot_runs(runs, args.out) + + +if __name__ == "__main__": + main() + + -- cgit v1.2.3