summaryrefslogtreecommitdiff
path: root/protocol/examples/plot_penalty_rescue.py
diff options
context:
space:
mode:
Diffstat (limited to 'protocol/examples/plot_penalty_rescue.py')
-rw-r--r--protocol/examples/plot_penalty_rescue.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/protocol/examples/plot_penalty_rescue.py b/protocol/examples/plot_penalty_rescue.py
index 37b0fa9..fff300e 100644
--- a/protocol/examples/plot_penalty_rescue.py
+++ b/protocol/examples/plot_penalty_rescue.py
@@ -10,7 +10,7 @@ Data sources:
- vanilla DFA trajectory: results/snapshot_evolution_v2/snapshot_evolution_s42.json
- penalized DFA (lam=1e-2): results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json
- DFA-shallow baseline 3-seed mean (drawn as horizontal line): 0.349
- - BP-trainable 3-seed mean: 0.609
+ - BP-trainable 3-seed mean: 0.6147 (100 ep) / 0.585 (matched 30 ep)
Run:
python -m protocol.examples.plot_penalty_rescue
@@ -91,7 +91,7 @@ def main():
ax.plot([e["epoch"] for e in penalty], [e["acc_eval"] for e in penalty],
label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4)
ax.axhline(0.349, color="k", linestyle="--", lw=1.2, label="DFA-shallow 0.349")
- ax.axhline(0.609, color="C0", linestyle=":", lw=1, label="BP-trainable 0.609")
+ ax.axhline(0.6147, color="C0", linestyle=":", lw=1, label="BP-trainable 100ep 0.615")
ax.set_xlabel("epoch", fontsize=10)
ax.set_ylabel("test acc", fontsize=10)
ax.set_title("(d) headline accuracy", fontsize=11)