summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_step3_all.py
blob: bfd2676b6e3c4218590961c2b0f37a2c9f7e86e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Compile ABCDEF Step 3 final analysis."""
import json, os
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Directory holding the step3_*.json run logs. Set FLOSSING_ROOT to point the
# analysis at another checkout/machine; the original path is the default.
# Plots are written to a subdirectory of it.
ROOT = os.environ.get("FLOSSING_ROOT", "/home/yurenh2/rrm/research/flossing")
OUT = f"{ROOT}/plots_step3_final"
os.makedirs(OUT, exist_ok=True)

# label -> (json filename under ROOT, matplotlib colour, linestyle).
# The text before ":" in each label is used as the short run key.
# NOTE(review): F is labelled "extended D" but reuses C's colour (C2) rather
# than D's (C4); confirm this is intended before comparing curves by colour.
runs = {
    "A: baseline α=0 from 18228":      ("step3_A_baseline_18228.json",         "C0", "-"),
    "B: CF α=10 λ*=-0.15 from 18228":   ("step3_B_rf_18228.json",               "C3", "-"),
    "C: CF α=10 λ*=-0.05 from 26040":   ("step3_C_rf_26040.json",               "C2", "-"),
    "D: CF α=10 λ*=0 from 26040":       ("step3_D_rf_26040_lstar0.json",        "C4", "-"),
    "E: CF α=10 λ*=0 from 18228":       ("step3_E_rf_18228_lstar0.json",        "C1", "-"),
    "F: extended D (1500 step)":        ("step3_F_rf_26040_lstar0_1500.json",   "C2", "--"),
}

def smooth(x, w=20):
    """Return a centred, edge-corrected moving average of ``x``.

    A plain ``np.convolve(x, ones(w)/w, mode="same")`` zero-pads both ends,
    so the first and last ~w/2 samples are pulled toward 0. Dividing by the
    same convolution of a ones-vector renormalises each output by the number
    of real samples that actually fell inside its window.

    Args:
        x: 1-D sequence of per-step values.
        w: window length in steps.

    Returns:
        ``x`` unchanged if it is shorter than ``w`` (``mode="same"`` would
        otherwise return ``w`` points and no longer line up with the step
        axis); else a float array with the same length as ``x``.
    """
    if len(x) < w:
        return x
    x = np.asarray(x, dtype=float)
    kernel = np.ones(w)
    counts = np.convolve(np.ones_like(x), kernel, mode="same")
    return np.convolve(x, kernel, mode="same") / counts


fig, axes = plt.subplots(2, 2, figsize=(15, 9))

# One summary row per run, consumed by the bar chart and the printed table.
summary = []
for label, (fn, color, ls) in runs.items():
    with open(f"{ROOT}/{fn}", encoding="utf-8") as fh:
        d = json.load(fh)
    key = label.split(":")[0]  # "A: baseline ..." -> "A"

    eval_steps = [e["step"] for e in d["evals"]]
    eval_accs = [e["acc"] for e in d["evals"]]
    steps = [r["step"] for r in d["steps"]]
    sup = np.array([r["sup_loss"] for r in d["steps"]])
    lyap_mean = np.array([r["lyap1_mean"] for r in d["steps"]])
    frac = np.array([r["frac_above_star"] for r in d["steps"]])

    fmt = f"{color}{ls}"  # matplotlib format string: colour + linestyle, e.g. "C2--"
    axes[0, 0].plot(eval_steps, eval_accs, fmt, marker="o", label=label, lw=1.5, alpha=0.85)
    axes[0, 1].plot(steps, smooth(sup), fmt, label=key, alpha=0.7)
    axes[1, 0].plot(steps, smooth(lyap_mean), fmt, label=key, alpha=0.7)
    axes[1, 1].plot(steps, smooth(frac), fmt, label=key, alpha=0.7)

    summary.append({
        "key": key, "label": label,
        "init_acc": d["initial_acc"],
        "final_acc": d["final_acc"],
        # Change relative to this run's own starting checkpoint.
        "delta": d["final_acc"] - d["initial_acc"],
        "final_lyap_mean": d["steps"][-1]["lyap1_mean"],
        "final_frac_above": d["steps"][-1]["frac_above_star"],
        "n_steps": len(d["steps"]),
    })

# Zero reference line for the Lyapunov panel (unlabelled, so not in the legend).
axes[1, 0].axhline(0, color="k", ls=":", lw=0.6, alpha=0.6)

# (panel, title, y-label, legend kwargs). Every panel shares the "step"
# x-label and a light grid.
panel_specs = [
    (axes[0, 0], "Test exact accuracy vs training step", "exact_acc",
     {"fontsize": 8, "loc": "best"}),
    (axes[0, 1], "Supervised loss (smoothed)", "sup_loss",
     {"fontsize": 8}),
    (axes[1, 0], r"$\lambda_{joint,1}$ mean trajectory (smoothed)", r"$\lambda_{1,joint}$",
     {"fontsize": 8}),
    (axes[1, 1], r"Fraction of batch with $\lambda > \lambda^*$ (smoothed)", "frac > λ*",
     {"fontsize": 8}),
]
for panel, panel_title, panel_ylabel, legend_kw in panel_specs:
    panel.set_title(panel_title)
    panel.set_xlabel("step")
    panel.set_ylabel(panel_ylabel)
    panel.legend(**legend_kw)
    panel.grid(alpha=0.3)

fig.suptitle("Step 3 — Contractive Flossing (CF) as training-time regularizer on HRM Sudoku-Extreme-1k", fontsize=11)
fig.tight_layout()
fig.savefig(f"{OUT}/abcdef_full.png", dpi=130)
plt.close(fig)

# Bar chart of final Δ acc: final_acc - initial_acc per run, in percentage points.
fig, ax = plt.subplots(1, 1, figsize=(9, 5))
keys = [s["key"] for s in summary]
deltas = [s["delta"] * 100 for s in summary]
# Take each run's colour from `runs` so the bars always match the curves,
# rather than keeping a hand-maintained parallel list that can drift.
colors = [runs[s["label"]][1] for s in summary]
bars = ax.bar(keys, deltas, color=colors)
ax.axhline(0, color="k", lw=0.6)
ax.margins(y=0.15)  # headroom so value labels stay inside the axes
for bar, delta in zip(bars, deltas):
    # Offset in points (not data units) so the gap is independent of the
    # y-scale; labels on negative bars hang below the bar end.
    above = delta >= 0
    ax.annotate(
        f"{delta:+.1f}%",
        xy=(bar.get_x() + bar.get_width() / 2, delta),
        xytext=(0, 3 if above else -3),
        textcoords="offset points",
        ha="center",
        va="bottom" if above else "top",
        fontsize=9,
        fontweight="bold",
    )
ax.set_ylabel("Δ test exact_accuracy (pp)")
ax.set_title("CF intervention: final accuracy change vs baseline")
ax.grid(alpha=0.3, axis="y")
fig.tight_layout()
fig.savefig(f"{OUT}/abcdef_deltas.png", dpi=130)
plt.close(fig)

# Plain-text summary table. The header carries the same two-space indent and
# column widths as each data row so the columns line up.
print(f"  {'key':>3} {'init':>7} {'final':>7} {'Δ':>7} {'n_steps':>8} {'final_λ':>9} {'frac>λ*':>9}")
for s in summary:
    print(f"  {s['key']:>3} {s['init_acc']:>7.3f} {s['final_acc']:>7.3f} {s['delta']*100:>6.1f}% {s['n_steps']:>8} "
          f"{s['final_lyap_mean']:>+9.3f} {s['final_frac_above']:>9.2f}")

print(f"\nplots → {OUT}/")