summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 19:46:08 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 19:46:08 -0500
commit32123cb36ae9521f60c9b6f67458b931b6540ef2 (patch)
tree4731e1dc513f5b613f80c4d20fc4114044c266d3 /experiments
parentbbb1a36d67f2f0c83106c1e771ea2c2fcb7fd83a (diff)
Add final report, plots, experiment guide, and complete NOTE.md
All experiments complete: - Toy LQ: credit bridge matches state bridge (~0.94 costate cosine) - CIFAR-10: credit bridge (29.6%) comparable to DFA (30.0%), both beat state bridge (18.5%) - State bridge confirms core hypothesis: perfect state prediction != useful credit - Terminal gradient matching is essential for credit bridge
Diffstat (limited to 'experiments')
-rw-r--r--experiments/plot_cifar_final.py208
-rw-r--r--experiments/plot_toy_final.py2
2 files changed, 209 insertions, 1 deletions
diff --git a/experiments/plot_cifar_final.py b/experiments/plot_cifar_final.py
new file mode 100644
index 0000000..b0e60f1
--- /dev/null
+++ b/experiments/plot_cifar_final.py
@@ -0,0 +1,208 @@
+"""Generate final CIFAR-10 plots from 3-seed results."""
+import os, json, numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+output_dir = 'report'
+os.makedirs(output_dir, exist_ok=True)
+
+# Load all seeds
+all_results = {}
+for seed, path in [(42, 'results/cifar10/results_cifar10.json'),
+ (123, 'results/cifar10_seed123/results_cifar10.json'),
+ (456, 'results/cifar10_seed456/results_cifar10.json')]:
+ with open(path) as f:
+ d = json.load(f)
+ all_results[seed] = d[str(seed)]
+
+methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+colors = {'bp': '#F44336', 'dfa': '#2196F3', 'state_bridge': '#FF9800', 'credit_bridge': '#4CAF50'}
+labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+# 1. Accuracy curves
+fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+for method in methods:
+ train_accs = []
+ test_accs = []
+ for seed in [42, 123, 456]:
+ log = all_results[seed][method]['log']
+ train_accs.append(log['train_acc'])
+ test_accs.append(log['test_acc'])
+
+ train_arr = np.array(train_accs)
+ test_arr = np.array(test_accs)
+ epochs = np.arange(1, train_arr.shape[1] + 1)
+
+ for ax, arr, title in zip(axes, [train_arr, test_arr], ['Train Accuracy', 'Test Accuracy']):
+ mean = arr.mean(0)
+ std = arr.std(0)
+ ax.plot(epochs, mean, '-', color=colors[method], label=labels[method])
+ ax.fill_between(epochs, mean - std, mean + std, alpha=0.15, color=colors[method])
+
+for ax, title in zip(axes, ['Train Accuracy', 'Test Accuracy']):
+ ax.set_xlabel('Epoch', fontsize=12)
+ ax.set_ylabel(title, fontsize=12)
+ ax.set_title(title, fontsize=13)
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3)
+
+fig.suptitle('CIFAR-10 Deep Residual MLP (d=512, L=12, 3 seeds)', fontsize=14, y=1.02)
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'cifar_accuracy.png'), dpi=150, bbox_inches='tight')
+plt.close(fig)
+print("Saved cifar_accuracy.png")
+
+# 2. Per-layer diagnostics (seed 42)
+fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+# BP cosine
+ax = axes[0]
+for method in methods:
+ diag = all_results[42][method].get('diagnostics', {})
+ if 'bp_cosine' in diag:
+ layers = range(len(diag['bp_cosine']))
+ ax.plot(layers, diag['bp_cosine'], 'o-', color=colors[method], label=labels[method], markersize=4)
+ax.set_xlabel('Layer')
+ax.set_ylabel('Cosine with BP Gradient')
+ax.set_title('Offline BP Cosine')
+ax.legend(fontsize=9)
+ax.grid(True, alpha=0.3)
+
+# Perturbation rho
+ax = axes[1]
+for method in methods:
+ diag = all_results[42][method].get('diagnostics', {})
+ if 'perturbation_rho' in diag:
+ layers = range(len(diag['perturbation_rho']))
+ ax.plot(layers, diag['perturbation_rho'], 'o-', color=colors[method], label=labels[method], markersize=4)
+ax.set_xlabel('Layer')
+ax.set_ylabel('Perturbation Correlation (ρ)')
+ax.set_title('Local Perturbation Correlation')
+ax.legend(fontsize=9)
+ax.grid(True, alpha=0.3)
+ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+# Nudging (eta=0.01)
+ax = axes[2]
+for method in methods:
+ diag = all_results[42][method].get('diagnostics', {})
+ if 'nudging' in diag and '0.01' in diag['nudging']:
+ nud = diag['nudging']['0.01']
+ layers = range(len(nud))
+ ax.plot(layers, nud, 'o-', color=colors[method], label=labels[method], markersize=4)
+ax.set_xlabel('Layer')
+ax.set_ylabel('Nudge Delta (negative=good)')
+ax.set_title('Nudging Test (η=0.01)')
+ax.legend(fontsize=9)
+ax.grid(True, alpha=0.3)
+ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+fig.suptitle('CIFAR-10 Per-Layer Diagnostics (seed 42)', fontsize=14, y=1.02)
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'cifar_diagnostics.png'), dpi=150, bbox_inches='tight')
+plt.close(fig)
+print("Saved cifar_diagnostics.png")
+
+# 3. Feature drift
+fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+for method in methods:
+ drift = all_results[42][method].get('drift', {})
+ block_drifts = []
+ for l in range(12):
+ key = f'blocks.{l}.w1.weight'
+ if key in drift:
+ block_drifts.append(drift[key])
+ if block_drifts:
+ ax.plot(range(len(block_drifts)), block_drifts, 'o-', color=colors[method],
+ label=labels[method], markersize=4)
+ax.set_xlabel('Block', fontsize=12)
+ax.set_ylabel('Feature Drift ||W_final - W_init|| / ||W_init||', fontsize=11)
+ax.set_title('Feature Drift per Block (CIFAR-10, seed 42)', fontsize=13)
+ax.legend(fontsize=10)
+ax.grid(True, alpha=0.3)
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'cifar_feature_drift.png'), dpi=150)
+plt.close(fig)
+print("Saved cifar_feature_drift.png")
+
+# 4. State bridge: prediction quality vs credit quality
+fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+
+# State prediction error over epochs
+ax = axes[0]
+for seed in [42, 123, 456]:
+ log = all_results[seed]['state_bridge']['log']
+ if 'state_pred_error' in log:
+ epochs = range(1, len(log['state_pred_error']) + 1)
+ ax.plot(epochs, log['state_pred_error'], '-', alpha=0.7, label=f'Seed {seed}')
+ax.set_xlabel('Epoch', fontsize=12)
+ax.set_ylabel('State Prediction Error', fontsize=12)
+ax.set_title('State Bridge: Prediction Error', fontsize=13)
+ax.set_yscale('log')
+ax.legend()
+ax.grid(True, alpha=0.3)
+
+# Compare: prediction error (near zero) vs accuracy (poor)
+ax = axes[1]
+methods_compare = ['dfa', 'state_bridge', 'credit_bridge']
+test_accs = {m: [] for m in methods_compare}
+for seed in [42, 123, 456]:
+ for m in methods_compare:
+ test_accs[m].append(all_results[seed][m]['log']['test_acc'][-1])
+
+x_pos = range(len(methods_compare))
+means = [np.mean(test_accs[m]) for m in methods_compare]
+stds = [np.std(test_accs[m]) for m in methods_compare]
+bar_colors = [colors[m] for m in methods_compare]
+bar_labels = [labels[m] for m in methods_compare]
+
+bars = ax.bar(x_pos, means, yerr=stds, color=bar_colors, capsize=5, alpha=0.8)
+ax.set_xticks(x_pos)
+ax.set_xticklabels(bar_labels, fontsize=11)
+ax.set_ylabel('Test Accuracy', fontsize=12)
+ax.set_title('Test Accuracy Comparison\n(State Bridge predicts h_L perfectly\nbut produces worst credit)', fontsize=12)
+ax.grid(True, alpha=0.3, axis='y')
+
+# Add text annotation
+ax.annotate('State pred err ≈ 0.0000\nbut worst accuracy!',
+ xy=(1, means[1]), xytext=(1.3, means[1] + 0.06),
+ arrowprops=dict(arrowstyle='->', color='black'),
+ fontsize=9, ha='center')
+
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'cifar_state_vs_credit.png'), dpi=150)
+plt.close(fig)
+print("Saved cifar_state_vs_credit.png")
+
+# 5. Combined summary bar chart
+fig, ax = plt.subplots(figsize=(10, 6))
+x = np.arange(len(methods))
+width = 0.6
+test_accs_all = {}
+for method in methods:
+ accs = [all_results[seed][method]['log']['test_acc'][-1] for seed in [42, 123, 456]]
+ test_accs_all[method] = (np.mean(accs), np.std(accs))
+
+means = [test_accs_all[m][0] for m in methods]
+stds = [test_accs_all[m][1] for m in methods]
+bar_colors = [colors[m] for m in methods]
+
+bars = ax.bar(x, means, width, yerr=stds, color=bar_colors, capsize=5, alpha=0.85)
+ax.set_ylabel('Test Accuracy', fontsize=13)
+ax.set_title('CIFAR-10 Test Accuracy (3 seeds, 100 epochs)', fontsize=14)
+ax.set_xticks(x)
+ax.set_xticklabels([labels[m] for m in methods], fontsize=12)
+ax.grid(True, alpha=0.3, axis='y')
+
+# Add value labels
+for bar, mean, std in zip(bars, means, stds):
+ ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + std + 0.005,
+ f'{mean:.1%}', ha='center', va='bottom', fontweight='bold', fontsize=11)
+
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'cifar_summary.png'), dpi=150)
+plt.close(fig)
+print("Saved cifar_summary.png")
+
+print("\nAll CIFAR plots generated.")
diff --git a/experiments/plot_toy_final.py b/experiments/plot_toy_final.py
index 2f7c109..e4e1a68 100644
--- a/experiments/plot_toy_final.py
+++ b/experiments/plot_toy_final.py
@@ -13,7 +13,7 @@ os.makedirs(output_dir, exist_ok=True)
seeds = [42, 123, 456]
all_data = []
for seed in seeds:
- path = f'results/toy_lq/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json'
+ path = f'results/toy_lq_frozen/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json'
if os.path.exists(path):
with open(path) as f:
all_data.append(json.load(f))