1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
|
import argparse
import csv
import json
import os
from glob import glob
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
def load_run(run_dir: str) -> Dict:
    """Load one run's configuration and per-epoch metrics from ``run_dir``.

    Two files are read from ``run_dir``:

    * ``args.json`` -- parsed with ``json.load`` and returned unchanged
      under the ``"args"`` key.
    * ``metrics.csv`` -- a CSV with a header row. Only rows whose ``step``
      column equals ``"epoch"`` are used; from those, the ``epoch`` (int),
      ``loss``, ``acc`` and ``lyap`` (float) columns are collected.

    Epoch rows whose numeric fields cannot be parsed (non-numeric text, or a
    short row where ``csv.DictReader`` fills missing fields with ``None``)
    are skipped. An empty or missing ``lyap`` value is recorded as NaN.

    Args:
        run_dir: Path to a single run directory.

    Returns:
        A dict with keys ``args``, ``epochs``, ``loss``, ``acc``, ``lyap`` and
        ``run_dir``. Returns an empty dict if either file is missing or if no
        usable epoch rows were found.

    Raises:
        json.JSONDecodeError: if ``args.json`` exists but is not valid JSON.
    """
    args_path = os.path.join(run_dir, "args.json")
    metrics_path = os.path.join(run_dir, "metrics.csv")
    if not (os.path.exists(args_path) and os.path.exists(metrics_path)):
        return {}

    with open(args_path, "r", encoding="utf-8") as f:
        args = json.load(f)

    epochs: List[int] = []
    loss: List[float] = []
    acc: List[float] = []
    lyap: List[float] = []
    # The csv module requires newline="" so that quoted fields containing
    # newlines and \r\n line endings are handled by the reader itself.
    with open(metrics_path, "r", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            if row.get("step", "") != "epoch":
                continue  # only per-epoch summary rows are plotted
            try:
                epoch_val = int(row.get("epoch", "0"))
                loss_val = float(row.get("loss", "nan"))
                acc_val = float(row.get("acc", "nan"))
                raw_lyap = row.get("lyap", "nan")
                # Empty/None means "not logged". float() already parses any
                # casing of "nan", so no special case is needed for it.
                lyap_val = float(raw_lyap) if raw_lyap else float("nan")
            except (TypeError, ValueError):
                # TypeError: short row, DictReader filled the field with None.
                # ValueError: non-numeric text in a numeric column.
                continue
            epochs.append(epoch_val)
            loss.append(loss_val)
            acc.append(acc_val)
            lyap.append(lyap_val)

    if not epochs:
        return {}
    return {
        "args": args,
        "epochs": epochs,
        "loss": loss,
        "acc": acc,
        "lyap": lyap,
        "run_dir": run_dir,
    }
def label_for_run(run: Dict) -> str:
    """Return the legend label used for ``run`` in every plot panel.

    Runs whose ``args["lyapunov"]`` is truthy are labelled with their
    ``lambda_reg``, ``lambda_target`` and ``hidden`` settings; all others are
    labelled as baselines showing only ``hidden``. Any setting absent from
    ``args`` is rendered as ``None``.
    """
    cfg = run["args"]
    hidden_size = cfg.get("hidden", None)
    if not cfg.get("lyapunov", False):
        return f"Baseline H={hidden_size}"
    reg_weight = cfg.get("lambda_reg", None)
    reg_target = cfg.get("lambda_target", None)
    return f"Lyap λ={reg_weight}, tgt={reg_target}, H={hidden_size}"
def gather_runs(base_dir: str) -> List[Dict]:
    """Load every valid run found directly under ``base_dir``.

    Each entry matched by ``base_dir/*`` is passed to ``load_run`` in sorted
    path order; entries for which ``load_run`` returns an empty dict are
    dropped. A missing ``base_dir`` simply yields an empty list.
    """
    pattern = os.path.join(base_dir, "*")
    loaded = (load_run(path) for path in sorted(glob(pattern)))
    return [run for run in loaded if run]
def _plot_metric_panel(ax, base: List[Dict], lyap: List[Dict], key: str, title: str, ylabel: str) -> None:
    """Draw metric ``key`` for every run on ``ax``: baselines solid, Lyapunov runs dashed."""
    for r in base:
        ax.plot(r["epochs"], r[key], label=label_for_run(r), alpha=0.9)
    for r in lyap:
        ax.plot(r["epochs"], r[key], label=label_for_run(r), linestyle="--", alpha=0.9)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=8)


def plot_runs(runs: List[Dict], out_path: str):
    """Render a 3-panel comparison figure of ``runs`` and save it to ``out_path``.

    The panels are: training loss and training accuracy for all runs (baseline
    runs solid, Lyapunov runs dashed), and the per-epoch ``lyap`` values for
    the Lyapunov runs only. A run counts as a Lyapunov run when its
    ``args["lyapunov"]`` is truthy.

    Args:
        runs: Run dicts as produced by ``load_run``.
        out_path: Destination image path. Its parent directory is created if
            needed. A bare filename saves into the current directory.

    Raises:
        SystemExit: if ``runs`` is empty.
    """
    if not runs:
        raise SystemExit("No valid runs found (expected args.json and metrics.csv under base_dir/*)")
    # Split baseline vs lyapunov
    base = [r for r in runs if not r["args"].get("lyapunov", False)]
    lyap = [r for r in runs if r["args"].get("lyapunov", False)]

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
    try:
        _plot_metric_panel(axes[0], base, lyap, "loss", "Training loss", "Loss")
        _plot_metric_panel(axes[1], base, lyap, "acc", "Training accuracy", "Accuracy")

        # Lyapunov estimate (only lyap runs)
        ax = axes[2]
        if lyap:
            for r in lyap:
                ax.plot(r["epochs"], r["lyap"], label=label_for_run(r), alpha=0.9)
            ax.set_title("Surrogate Lyapunov estimate")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Avg log growth")
            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=8)
        else:
            ax.text(0.5, 0.5, "No Lyapunov runs found", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()

        fig.tight_layout()
        # os.path.dirname() is "" for a bare filename, and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when one is named.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=150)
    finally:
        # pyplot keeps every figure alive until closed; release it even on error.
        plt.close(fig)
    print(f"Saved figure to {out_path}")
def main():
    """Command-line entry point: plot every run under ``--base_dir`` into ``--out``."""
    cli = argparse.ArgumentParser(description="Plot MVP runs: baseline vs Lyapunov comparison")
    for flag, default_value, help_text in (
        ("--base_dir", "runs/mvp", "Directory containing run subfolders"),
        ("--out", "runs/mvp_summary.png", "Output figure path"),
    ):
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    opts = cli.parse_args()
    plot_runs(gather_runs(opts.base_dir), opts.out)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|