"""Generate final CIFAR-10 plots from 3-seed results.

Reads the per-seed ``results_cifar10.json`` files and writes PNG figures
into ``report/``.
"""
import json
import os

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only saved to disk
import matplotlib.pyplot as plt
import numpy as np

output_dir = 'report'
os.makedirs(output_dir, exist_ok=True)

# Seed -> results file for that run.  Each file is a JSON object whose key is
# the seed as a string (JSON object keys are always strings).
result_paths = {
    42: 'results/cifar10/results_cifar10.json',
    123: 'results/cifar10_seed123/results_cifar10.json',
    456: 'results/cifar10_seed456/results_cifar10.json',
}
SEEDS = list(result_paths)

# Load all seeds
all_results = {}
for seed, path in result_paths.items():
    with open(path, encoding='utf-8') as f:
        payload = json.load(f)
    if str(seed) not in payload:
        raise KeyError(f"{path} contains no results for seed {seed}")
    all_results[seed] = payload[str(seed)]

methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
colors = {'bp': '#F44336', 'dfa': '#2196F3', 'state_bridge': '#FF9800', 'credit_bridge': '#4CAF50'}
labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}


def _stack_runs(runs):
    """Stack per-seed curves into a float array of shape (n_seeds, n_points).

    Every run is truncated to the shortest one, so seeds that logged a
    different number of points still form a rectangular array instead of
    making ``np.array`` fail on a ragged list.
    """
    n_points = min(len(run) for run in runs)
    return np.array([run[:n_points] for run in runs], dtype=float)


# 1. Accuracy curves: mean +/- std across seeds, train and test side by side.
panel_titles = ['Train Accuracy', 'Test Accuracy']
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for method in methods:
    logs = [all_results[seed][method]['log'] for seed in SEEDS]
    curves = [
        _stack_runs([log['train_acc'] for log in logs]),
        _stack_runs([log['test_acc'] for log in logs]),
    ]
    for ax, arr in zip(axes, curves):
        # Each curve gets its own x-axis, so train and test logs of different
        # lengths are still plotted against matching epoch numbers.
        epochs = np.arange(1, arr.shape[1] + 1)
        mean = arr.mean(axis=0)
        std = arr.std(axis=0)  # ddof=0: population std across seeds
        ax.plot(epochs, mean, '-', color=colors[method], label=labels[method])
        ax.fill_between(epochs, mean - std, mean + std, alpha=0.15, color=colors[method])
for ax, title in zip(axes, panel_titles):
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(title, fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
fig.suptitle(f'CIFAR-10 Deep Residual MLP (d=512, L=12, {len(SEEDS)} seeds)', fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'cifar_accuracy.png'), dpi=150, bbox_inches='tight')
plt.close(fig)
print("Saved cifar_accuracy.png")
# 2. Per-layer diagnostics for a single seed.
DIAG_SEED = 42


def _plot_layer_curves(ax, get_values, seed=DIAG_SEED):
    """Draw one per-layer curve per method on *ax*.

    ``get_values(diag)`` receives a method's ``'diagnostics'`` dict (empty if
    the method has none) and returns the per-layer sequence to plot, or None
    when that diagnostic is absent for the method.

    Returns True if at least one curve was drawn.
    """
    drawn = False
    for method in methods:
        diag = all_results[seed][method].get('diagnostics', {})
        values = get_values(diag)
        if values is None:
            continue
        ax.plot(range(len(values)), values, 'o-', color=colors[method], label=labels[method], markersize=4)
        drawn = True
    return drawn


# (extractor, y-label, title, draw a y=0 reference line)
diag_panels = [
    (lambda d: d.get('bp_cosine'),
     'Cosine with BP Gradient', 'Offline BP Cosine', False),
    (lambda d: d.get('perturbation_rho'),
     'Perturbation Correlation (ρ)', 'Local Perturbation Correlation', True),
    # Nudging results are keyed by step size as a string; plot eta=0.01.
    (lambda d: d.get('nudging', {}).get('0.01'),
     'Nudge Delta (negative=good)', 'Nudging Test (η=0.01)', True),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (extract, ylabel, title, zero_line) in zip(axes, diag_panels):
    any_curve = _plot_layer_curves(ax, extract)
    ax.set_xlabel('Layer')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Skip the legend when no method had this diagnostic; legend() on an
    # axes with no labelled artists warns and draws an empty box.
    if any_curve:
        ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    if zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
fig.suptitle(f'CIFAR-10 Per-Layer Diagnostics (seed {DIAG_SEED})', fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'cifar_diagnostics.png'), dpi=150, bbox_inches='tight')
plt.close(fig)
print("Saved cifar_diagnostics.png")
# Seeds to aggregate over, in the order they were loaded.
seeds = list(all_results)


def _legend_if_any(ax, **kwargs):
    """Draw a legend only if *ax* has labelled artists (avoids an empty legend and a warning)."""
    if ax.get_legend_handles_labels()[0]:
        ax.legend(**kwargs)


def _block_drifts(drift, param='w1.weight'):
    """Return sorted ``(block_index, value)`` pairs for keys ``'blocks.<i>.<param>'``.

    Block indices come from the key itself, so a missing block leaves a gap
    instead of shifting later blocks left, and any number of blocks is found.
    """
    prefix, suffix = 'blocks.', f'.{param}'
    pairs = []
    for key, value in drift.items():
        if key.startswith(prefix) and key.endswith(suffix):
            index = key[len(prefix):-len(suffix)]
            if index.isdigit():
                pairs.append((int(index), value))
    return sorted(pairs)


# 3. Feature drift
drift_seed = 42
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
for method in methods:
    pairs = _block_drifts(all_results[drift_seed][method].get('drift', {}))
    if pairs:
        block_ids, block_values = zip(*pairs)
        ax.plot(block_ids, block_values, 'o-', color=colors[method], label=labels[method], markersize=4)
ax.set_xlabel('Block', fontsize=12)
ax.set_ylabel('Feature Drift ||W_final - W_init|| / ||W_init||', fontsize=11)
ax.set_title(f'Feature Drift per Block (CIFAR-10, seed {drift_seed})', fontsize=13)
_legend_if_any(ax, fontsize=10)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'cifar_feature_drift.png'), dpi=150, bbox_inches='tight')
plt.close(fig)
print("Saved cifar_feature_drift.png")

# 4. State bridge: prediction quality vs credit quality
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left: state prediction error over epochs, one line per seed.
ax = axes[0]
final_pred_errors = []
for seed in seeds:
    errors = all_results[seed]['state_bridge']['log'].get('state_pred_error')
    if errors:
        ax.plot(range(1, len(errors) + 1), errors, '-', alpha=0.7, label=f'Seed {seed}')
        final_pred_errors.append(errors[-1])
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('State Prediction Error', fontsize=12)
ax.set_title('State Bridge: Prediction Error', fontsize=13)
ax.set_yscale('log')
_legend_if_any(ax)
ax.grid(True, alpha=0.3)

# Right: final-epoch test accuracy (mean +/- population std over seeds).
ax = axes[1]
methods_compare = ['dfa', 'state_bridge', 'credit_bridge']
final_accs = {m: [all_results[s][m]['log']['test_acc'][-1] for s in seeds] for m in methods_compare}
x_pos = range(len(methods_compare))
means = [np.mean(final_accs[m]) for m in methods_compare]
stds = [np.std(final_accs[m]) for m in methods_compare]
ax.bar(x_pos, means, yerr=stds, color=[colors[m] for m in methods_compare], capsize=5, alpha=0.8)
ax.set_xticks(x_pos)
ax.set_xticklabels([labels[m] for m in methods_compare], fontsize=11)
ax.set_ylabel('Test Accuracy', fontsize=12)
# NOTE(review): the "perfectly" / "worst" wording in the title and annotation
# is the author's conclusion and is not checked against the data here.
ax.set_title('Test Accuracy Comparison\n(State Bridge predicts h_L perfectly\nbut produces worst credit)', fontsize=12)
ax.grid(True, alpha=0.3, axis='y')
# Annotate the State Bridge bar.  Its position is looked up by name, and the
# quoted error comes from the logs when they contain it.
sb = methods_compare.index('state_bridge')
if final_pred_errors:
    note = f'State pred err ≈ {np.mean(final_pred_errors):.4f}\nbut worst accuracy!'
else:
    note = 'State pred err ≈ 0.0000\nbut worst accuracy!'
ax.annotate(note, xy=(sb, means[sb]), xytext=(sb + 0.3, means[sb] + 0.06),
            arrowprops=dict(arrowstyle='->', color='black'), fontsize=9, ha='center')
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'cifar_state_vs_credit.png'), dpi=150, bbox_inches='tight')
plt.close(fig)
print("Saved cifar_state_vs_credit.png")

# 5. Combined summary bar chart (final-epoch test accuracy, mean +/- std over seeds)
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(methods))
width = 0.6
summary = {}
for method in methods:
    accs = [all_results[seed][method]['log']['test_acc'][-1] for seed in seeds]
    summary[method] = (np.mean(accs), np.std(accs))
means = [summary[m][0] for m in methods]
stds = [summary[m][1] for m in methods]
bars = ax.bar(x, means, width, yerr=stds, color=[colors[m] for m in methods], capsize=5, alpha=0.85)
# Epoch count is taken from the logs rather than hard-coded; assumes one
# test_acc entry per epoch (the same assumption the accuracy-curve x-axis makes).
n_epochs = len(all_results[seeds[0]][methods[0]]['log']['test_acc'])
ax.set_ylabel('Test Accuracy', fontsize=13)
ax.set_title(f'CIFAR-10 Test Accuracy ({len(seeds)} seeds, {n_epochs} epochs)', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels([labels[m] for m in methods], fontsize=12)
ax.grid(True, alpha=0.3, axis='y')
# Add value labels just above each error bar
for bar, mean, std in zip(bars, means, stds):
    ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + std + 0.005,
            f'{mean:.1%}', ha='center', va='bottom', fontweight='bold', fontsize=11)
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'cifar_summary.png'), dpi=150, bbox_inches='tight')
plt.close(fig)
print("Saved cifar_summary.png")
print("\nAll CIFAR plots generated.")