# Provenance: experiments/plot_cifar_final.py as added (new file, 208 lines)
# by commit 32123cb36ae9521f60c9b6f67458b931b6540ef2
# (YurenHao0426, Mon, 23 Mar 2026 19:46:08 -0500; cgit v1.2.3 patch view),
# subject "Add final report, plots, experiment guide, and complete NOTE.md".
# Commit message:
#   All experiments complete:
#   - Toy LQ: credit bridge matches state bridge (~0.94 costate cosine)
#   - CIFAR-10: credit bridge (29.6%) comparable to DFA (30.0%), both beat
#     state bridge (18.5%)
#   - State bridge confirms core hypothesis: perfect state prediction !=
#     useful credit
#   - Terminal gradient matching is essential for credit bridge
"""Generate final CIFAR-10 plots from 3-seed results.

Reads one ``results_cifar10.json`` per seed and writes five PNG figures into
``output_dir``: accuracy curves, per-layer diagnostics, feature drift, the
state-bridge prediction-vs-accuracy comparison, and a summary bar chart.
"""
import json
import os

import numpy as np
import matplotlib
# Non-interactive backend: every figure is written to disk, none is shown.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

output_dir = 'report'
os.makedirs(output_dir, exist_ok=True)

# One results file per seed; inside each file the payload is keyed by the
# seed rendered as a string.
RESULT_PATHS = [
    (42, 'results/cifar10/results_cifar10.json'),
    (123, 'results/cifar10_seed123/results_cifar10.json'),
    (456, 'results/cifar10_seed456/results_cifar10.json'),
]
# Single source of truth for the seeds every multi-seed plot iterates over.
SEEDS = [seed for seed, _ in RESULT_PATHS]
# Seed whose run is used for the single-seed plots (diagnostics, drift).
DIAG_SEED = 42
# Number of `blocks.<i>.w1.weight` entries probed in the drift record.
NUM_BLOCKS = 12

# Load all seeds
all_results = {}
for seed, path in RESULT_PATHS:
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding='utf-8') as f:
        all_results[seed] = json.load(f)[str(seed)]

methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
colors = {'bp': '#F44336', 'dfa': '#2196F3',
          'state_bridge': '#FF9800', 'credit_bridge': '#4CAF50'}
labels = {'bp': 'BP', 'dfa': 'DFA',
          'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}


def _save(fig, filename, **savefig_kwargs):
    """Write *fig* to ``output_dir/filename`` at 150 dpi, close it, report it."""
    fig.savefig(os.path.join(output_dir, filename), dpi=150, **savefig_kwargs)
    plt.close(fig)
    print(f"Saved {filename}")


def _legend_if_labelled(ax, **legend_kwargs):
    """Draw a legend only when *ax* holds at least one labelled artist.

    Several panels plot a series only if its key is present in the results;
    calling ``ax.legend()`` on an axes with nothing labelled makes matplotlib
    emit a "No artists with labels found" warning.
    """
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(**legend_kwargs)


def _final_test_acc_stats(method_names):
    """Return ``{method: (mean, std)}`` of the last-epoch test accuracy.

    The statistics are taken across ``SEEDS`` using the final entry of each
    run's ``log['test_acc']`` list.
    """
    stats = {}
    for name in method_names:
        finals = [all_results[seed][name]['log']['test_acc'][-1]
                  for seed in SEEDS]
        stats[name] = (np.mean(finals), np.std(finals))
    return stats


def _plot_layer_diagnostic(ax, extract, ylabel, title, zero_line=False):
    """Plot one per-layer diagnostic series per method onto *ax*.

    Args:
        ax: Target matplotlib axes.
        extract: Callable taking a method's ``diagnostics`` dict and
            returning the per-layer sequence, or ``None`` when absent
            (the method is then skipped).
        ylabel: Y-axis label.
        title: Panel title.
        zero_line: When true, add a dashed gray reference line at y = 0.
    """
    for method in methods:
        diag = all_results[DIAG_SEED][method].get('diagnostics', {})
        values = extract(diag)
        if values is None:
            continue
        ax.plot(range(len(values)), values, 'o-', color=colors[method],
                label=labels[method], markersize=4)
    ax.set_xlabel('Layer')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    _legend_if_labelled(ax, fontsize=9)
    ax.grid(True, alpha=0.3)
    if zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)


def plot_accuracy_curves():
    """1. Train/test accuracy curves: mean across seeds with a +/- 1 std band."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = [('train_acc', 'Train Accuracy'), ('test_acc', 'Test Accuracy')]

    for method in methods:
        for ax, (log_key, _) in zip(axes, panels):
            # Rows are seeds, columns are epochs.
            curves = np.array([all_results[seed][method]['log'][log_key]
                               for seed in SEEDS])
            epochs = np.arange(1, curves.shape[1] + 1)
            mean = curves.mean(0)
            std = curves.std(0)
            ax.plot(epochs, mean, '-', color=colors[method],
                    label=labels[method])
            ax.fill_between(epochs, mean - std, mean + std, alpha=0.15,
                            color=colors[method])

    for ax, (_, title) in zip(axes, panels):
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(title, fontsize=13)
        _legend_if_labelled(ax, fontsize=10)
        ax.grid(True, alpha=0.3)

    fig.suptitle(
        f'CIFAR-10 Deep Residual MLP (d=512, L=12, {len(SEEDS)} seeds)',
        fontsize=14, y=1.02)
    fig.tight_layout()
    _save(fig, 'cifar_accuracy.png', bbox_inches='tight')


def plot_layer_diagnostics():
    """2. Per-layer diagnostics for ``DIAG_SEED``: three side-by-side panels."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # BP cosine
    _plot_layer_diagnostic(
        axes[0], lambda diag: diag.get('bp_cosine'),
        ylabel='Cosine with BP Gradient', title='Offline BP Cosine')

    # Perturbation rho
    _plot_layer_diagnostic(
        axes[1], lambda diag: diag.get('perturbation_rho'),
        ylabel='Perturbation Correlation (ρ)',
        title='Local Perturbation Correlation', zero_line=True)

    # Nudging (eta=0.01); results are keyed by the step size as a string.
    _plot_layer_diagnostic(
        axes[2], lambda diag: diag.get('nudging', {}).get('0.01'),
        ylabel='Nudge Delta (negative=good)',
        title='Nudging Test (η=0.01)', zero_line=True)

    fig.suptitle(f'CIFAR-10 Per-Layer Diagnostics (seed {DIAG_SEED})',
                 fontsize=14, y=1.02)
    fig.tight_layout()
    _save(fig, 'cifar_diagnostics.png', bbox_inches='tight')


def plot_feature_drift():
    """3. Feature drift of each block's ``w1.weight`` for ``DIAG_SEED``."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    block_keys = [f'blocks.{idx}.w1.weight' for idx in range(NUM_BLOCKS)]

    for method in methods:
        drift = all_results[DIAG_SEED][method].get('drift', {})
        # Blocks missing from the drift record are skipped, not zero-filled.
        block_drifts = [drift[key] for key in block_keys if key in drift]
        if block_drifts:
            ax.plot(range(len(block_drifts)), block_drifts, 'o-',
                    color=colors[method], label=labels[method], markersize=4)

    ax.set_xlabel('Block', fontsize=12)
    ax.set_ylabel('Feature Drift ||W_final - W_init|| / ||W_init||',
                  fontsize=11)
    ax.set_title(f'Feature Drift per Block (CIFAR-10, seed {DIAG_SEED})',
                 fontsize=13)
    _legend_if_labelled(ax, fontsize=10)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    _save(fig, 'cifar_feature_drift.png')


def plot_state_vs_credit():
    """4. State bridge: prediction quality vs credit quality."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # State prediction error over epochs (one line per seed, when logged).
    ax = axes[0]
    for seed in SEEDS:
        log = all_results[seed]['state_bridge']['log']
        if 'state_pred_error' in log:
            errors = log['state_pred_error']
            ax.plot(range(1, len(errors) + 1), errors, '-', alpha=0.7,
                    label=f'Seed {seed}')
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('State Prediction Error', fontsize=12)
    ax.set_title('State Bridge: Prediction Error', fontsize=13)
    ax.set_yscale('log')
    _legend_if_labelled(ax)
    ax.grid(True, alpha=0.3)

    # Compare: prediction error (near zero) vs accuracy (poor)
    ax = axes[1]
    methods_compare = ['dfa', 'state_bridge', 'credit_bridge']
    stats = _final_test_acc_stats(methods_compare)
    x_pos = range(len(methods_compare))
    means = [stats[m][0] for m in methods_compare]
    stds = [stats[m][1] for m in methods_compare]

    ax.bar(x_pos, means, yerr=stds,
           color=[colors[m] for m in methods_compare], capsize=5, alpha=0.8)
    ax.set_xticks(x_pos)
    ax.set_xticklabels([labels[m] for m in methods_compare], fontsize=11)
    ax.set_ylabel('Test Accuracy', fontsize=12)
    ax.set_title('Test Accuracy Comparison\n(State Bridge predicts h_L perfectly\nbut produces worst credit)', fontsize=12)
    ax.grid(True, alpha=0.3, axis='y')

    # Add text annotation pointing at the state-bridge bar; look the bar up
    # by name so reordering `methods_compare` cannot mis-aim the arrow.
    sb = methods_compare.index('state_bridge')
    ax.annotate('State pred err ≈ 0.0000\nbut worst accuracy!',
                xy=(sb, means[sb]), xytext=(sb + 0.3, means[sb] + 0.06),
                arrowprops=dict(arrowstyle='->', color='black'),
                fontsize=9, ha='center')

    fig.tight_layout()
    _save(fig, 'cifar_state_vs_credit.png')


def plot_summary():
    """5. Combined summary bar chart of final test accuracy for all methods."""
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(methods))
    width = 0.6
    stats = _final_test_acc_stats(methods)
    means = [stats[m][0] for m in methods]
    stds = [stats[m][1] for m in methods]

    bars = ax.bar(x, means, width, yerr=stds,
                  color=[colors[m] for m in methods], capsize=5, alpha=0.85)
    ax.set_ylabel('Test Accuracy', fontsize=13)
    ax.set_title(f'CIFAR-10 Test Accuracy ({len(SEEDS)} seeds, 100 epochs)',
                 fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels([labels[m] for m in methods], fontsize=12)
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels just above each error bar.
    for bar, mean, std in zip(bars, means, stds):
        ax.text(bar.get_x() + bar.get_width() / 2.,
                bar.get_height() + std + 0.005,
                f'{mean:.1%}', ha='center', va='bottom',
                fontweight='bold', fontsize=11)

    fig.tight_layout()
    _save(fig, 'cifar_summary.png')


# Figures are generated at module level: executing this file produces every
# plot, in the numbered order below.
plot_accuracy_curves()
plot_layer_diagnostics()
plot_feature_drift()
plot_state_vs_credit()
plot_summary()

print("\nAll CIFAR plots generated.")