# Reconstructed from commit 8f786597d1007f0ef6012f53c22958d9c4e9b81a
# (YurenHao0426, Tue, 24 Mar 2026 01:21:31 -0500):
#   "Add exploration visualization: CIFAR depth scan, boundary ablation,
#    synth vs CIFAR gap"
# which creates experiments/plot_explore_all.py.
"""Generate all exploration phase plots.

Three figures are written into ``output_dir``:

- ``cifar_depth_scan.png``: accuracy / Gamma / rho vs depth for all methods
- ``boundary_ablation.png``: s_type, tgw and warmup-ratio sweeps
- ``synth_vs_cifar.png``: synthetic (d=128) vs CIFAR (d=512) comparison

All inputs are JSON files read from paths relative to the current working
directory (``results/...``).  Missing input files are tolerated: the
corresponding series or panel is simply left out (or plotted as 0 in the
synthetic-vs-CIFAR comparison).
"""
import os
import json
import numpy as np
import matplotlib
# Select the non-interactive backend before pyplot is imported so the script
# can run without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

output_dir = 'report_explore'
os.makedirs(output_dir, exist_ok=True)

# Colors and legend labels shared by every per-method plot.
_METHOD_COLORS = {'bp': '#F44336', 'dfa': '#2196F3', 'credit_bridge': '#4CAF50'}
_METHOD_LABELS = {'bp': 'BP', 'dfa': 'DFA', 'credit_bridge': 'Credit Bridge'}

# (record key, line format, color, legend label) drawn in each sweep panel.
_SWEEP_SERIES = (
    ('mean_bp_cosine', 'o-', '#F44336', 'Gamma'),
    ('mean_rho', 's-', '#4CAF50', 'rho'),
    ('test_acc', '^-', '#2196F3', 'Accuracy'),
)


def _load_json(path):
    """Return the parsed JSON document at *path*, or ``None`` if it is missing."""
    try:
        with open(path, encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def _mean_or_zero(values):
    """Return the mean of *values*, or 0.0 when *values* is empty.

    ``np.mean`` of an empty sequence yields nan plus a RuntimeWarning; 0.0
    matches the fallback used when the data is absent altogether.
    """
    return float(np.mean(values)) if np.size(values) else 0.0


def _save_figure(fig, suptitle, filename):
    """Add *suptitle* to *fig*, write it to ``output_dir/filename`` and close it."""
    fig.suptitle(suptitle, fontsize=14, y=1.02)
    fig.tight_layout()
    # Re-create the directory in case it was removed (or the working
    # directory changed) after import time.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, filename), dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {filename}")


# =============================================================================
# CIFAR Depth Scan
# =============================================================================
def _plot_depth_panel(ax, results, methods, extract, ylabel, title):
    """Plot one metric against depth, one line per method.

    Args:
        ax: Matplotlib axes to draw on.
        results: Mapping ``depth -> loaded result dict``; only depths whose
            file was found are present.
        methods: Method keys to look up in each per-depth result.
        extract: Callable mapping a per-method result dict to a scalar.
        ylabel: Y-axis label.
        title: Panel title.
    """
    loaded_depths = sorted(results)
    for m in methods:
        # Only depths that were actually loaded AND contain this method.
        # Iterating the full requested depth list here would raise KeyError
        # whenever a single depth file is missing.
        valid_depths = [L for L in loaded_depths if m in results[L]]
        values = [extract(results[L][m]) for L in valid_depths]
        ax.plot(valid_depths, values, 'o-', color=_METHOD_COLORS[m],
                label=_METHOD_LABELS[m], markersize=6, linewidth=2)
    ax.set_xlabel('Depth (L)', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)


def plot_cifar_depth_scan():
    """Plot test accuracy, mean BP cosine and mean perturbation rho vs depth.

    Reads ``results/cifar_depth_scan_s42/d512_L{L}_s42.json`` for
    L in (2, 4, 6, 8, 12) and writes ``cifar_depth_scan.png``.  Depths whose
    file is missing are skipped; if none are found a message is printed and
    no figure is produced.
    """
    depths = [2, 4, 6, 8, 12]
    results = {}
    for L in depths:
        loaded = _load_json(f'results/cifar_depth_scan_s42/d512_L{L}_s42.json')
        if loaded is not None:
            results[L] = loaded

    if not results:
        print("No CIFAR depth scan results found")
        return

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Accuracy vs depth: last entry of the logged test accuracy.
    _plot_depth_panel(
        axes[0], results, ['bp', 'dfa', 'credit_bridge'],
        lambda r: r['log']['test_acc'][-1],
        'Test Accuracy', 'CIFAR-10 Test Accuracy vs Depth')

    # Gamma vs depth
    _plot_depth_panel(
        axes[1], results, ['bp', 'dfa', 'credit_bridge'],
        lambda r: np.mean(r['diagnostics']['bp_cosine']),
        'Mean BP Cosine (Gamma)', 'CIFAR-10 BP Cosine vs Depth')
    # Fixed y-range; values outside it are not visible in this panel.
    axes[1].set_ylim(-0.05, 0.25)

    # rho vs depth (not plotted for 'bp')
    _plot_depth_panel(
        axes[2], results, ['dfa', 'credit_bridge'],
        lambda r: np.mean(r['diagnostics']['perturbation_rho']),
        'Mean Perturbation rho', 'CIFAR-10 Perturbation rho vs Depth')
    axes[2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    _save_figure(fig, 'CIFAR-10 Depth Scan (d=512, seed=42)', 'cifar_depth_scan.png')


# =============================================================================
# Boundary Ablation
# =============================================================================
def _plot_sweep_panel(ax, path, x_key, xlabel, title):
    """Plot Gamma, rho and accuracy against the swept hyper-parameter.

    Args:
        ax: Matplotlib axes to draw on.
        path: JSON file holding a mapping ``run key -> record``; each record
            must provide *x_key*, ``mean_bp_cosine``, ``mean_rho`` and
            ``test_acc``.
        x_key: Record field used as the x coordinate.
        xlabel: X-axis label.
        title: Panel title.

    The panel keeps its labels but stays empty when *path* is missing.
    """
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.grid(True, alpha=0.3)

    data = _load_json(path)
    if not data:
        return

    # Order points by the swept value itself.  Sorting the string run keys
    # would place e.g. "10.0" before "2.0" and make the line double back.
    records = sorted(data.values(), key=lambda r: r[x_key])
    xs = [r[x_key] for r in records]
    for metric, fmt, color, label in _SWEEP_SERIES:
        ax.plot(xs, [r[metric] for r in records], fmt, color=color,
                label=label, markersize=8, linewidth=2)
    ax.legend(fontsize=10)


def plot_boundary_ablation():
    """Plot the boundary-condition ablation and write ``boundary_ablation.png``.

    Panels:
        1. Mean (and std across available seeds) of ``mean_bp_cosine`` and
           ``mean_rho`` per s_type, read from the ``s_{s_type}_tgw1.0_wr0.2``
           entries of the per-seed ablation files.
        2. Metrics vs ``term_grad_weight`` from the tgw sweep file.
        3. Metrics vs ``warmup_ratio`` from the warmup-ratio sweep file.
    """
    seed_paths = [
        (42, 'results/boundary_ablation_s_sweep/ablation_a1.0_L4_s42.json'),
        (123, 'results/boundary_ablation_s123/ablation_a1.0_L4_s123.json'),
        (456, 'results/boundary_ablation_s456/ablation_a1.0_L4_s456.json'),
    ]
    seed_data = {}
    for seed, path in seed_paths:
        loaded = _load_json(path)
        if loaded is not None:
            seed_data[seed] = loaded

    # The extra eT_hL / deltaL_hL columns are only shown when the seed-42
    # file is available.
    if 42 in seed_data:
        s_types_full = ['eT', 'deltaL', 'eT_hL', 'deltaL_hL']
    else:
        s_types_full = ['eT', 'deltaL']

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # s_type bar chart: two bars (Gamma, rho) per s_type.
    ax = axes[0]
    x = np.arange(len(s_types_full))
    width = 0.35
    bar_metrics = [('mean_bp_cosine', 'Gamma', '#F44336'),
                   ('mean_rho', 'rho', '#4CAF50')]
    for i, (metric, label, color) in enumerate(bar_metrics):
        means = []
        stds = []
        for s_type in s_types_full:
            key = f's_{s_type}_tgw1.0_wr0.2'
            vals = [run[key][metric] for run in seed_data.values() if key in run]
            # No data -> bar of height 0; a single seed -> no error bar.
            means.append(np.mean(vals) if vals else 0)
            stds.append(np.std(vals) if len(vals) > 1 else 0)
        ax.bar(x + (i - 0.5) * width, means, width, yerr=stds, capsize=3,
               label=label, color=color, alpha=0.7)

    ax.set_xticks(x)
    ax.set_xticklabels(s_types_full, fontsize=10)
    ax.set_ylabel('Value', fontsize=12)
    ax.set_title('Terminal Conditioning Code', fontsize=13)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')

    # tgw sweep
    _plot_sweep_panel(
        axes[1], 'results/boundary_ablation_tgw_sweep/ablation_a1.0_L4_s42.json',
        'term_grad_weight', 'Terminal Gradient Weight',
        'Terminal Gradient Matching Weight')

    # Warmup ratio sweep
    _plot_sweep_panel(
        axes[2], 'results/boundary_ablation_wr_sweep/ablation_a1.0_L4_s42.json',
        'warmup_ratio', 'Warmup Ratio', 'DFA Warmup Ratio')

    _save_figure(fig, 'Boundary-Condition Ablation (alpha=1.0, L=4)',
                 'boundary_ablation.png')


# =============================================================================
# Comparison: Synthetic vs CIFAR (the gap)
# =============================================================================
def _load_synth_metrics(summary_path):
    """Collect per-run DFA / Credit Bridge metrics for alpha=1.0, L=4.

    Keys of the summary file are split on ``_``; the first two parts are
    parsed as ``<char><alpha>`` and ``<char><depth>``.

    Returns:
        Dict with ``{method}_gamma`` and ``{method}_rho`` lists for
        ``dfa`` and ``credit_bridge``; the lists are empty when the file is
        missing or holds no matching run.
    """
    metrics = {f'{m}_{suffix}': []
               for m in ('dfa', 'credit_bridge') for suffix in ('gamma', 'rho')}
    data = _load_json(summary_path)
    if not data:
        return metrics
    for key, d in data.items():
        parts = key.split('_')
        alpha = float(parts[0][1:])
        L = int(parts[1][1:])
        if alpha == 1.0 and L == 4:
            for m in ('dfa', 'credit_bridge'):
                metrics[f'{m}_gamma'].append(d[m]['mean_bp_cosine'])
                metrics[f'{m}_rho'].append(d[m]['mean_rho'])
    return metrics


def _plot_gap_panel(ax, synth_vals, cifar_vals, ylabel, title):
    """Draw paired Synthetic / CIFAR bars for DFA and Credit Bridge.

    Args:
        ax: Matplotlib axes to draw on.
        synth_vals: ``[dfa, credit_bridge]`` values for the synthetic task.
        cifar_vals: ``[dfa, credit_bridge]`` values for CIFAR.
        ylabel: Y-axis label.
        title: Panel title.
    """
    x = np.arange(2)
    width = 0.35
    bar_colors = [_METHOD_COLORS['dfa'], _METHOD_COLORS['credit_bridge']]

    ax.bar(x - width/2, synth_vals, width,
           label='Synthetic (d=128)', color=bar_colors, alpha=0.7)
    ax.bar(x + width/2, cifar_vals, width,
           label='CIFAR (d=512)', color=bar_colors, alpha=0.3, edgecolor='black')
    ax.set_xticks(x)
    ax.set_xticklabels(['DFA', 'Credit Bridge'])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(['Synthetic', 'CIFAR'], fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')

    # Value annotations.  Clamp at 0 so the label of a negative bar sits just
    # above the baseline instead of inside the bar.
    for i, (sv, cv) in enumerate(zip(synth_vals, cifar_vals)):
        ax.text(i - width/2, max(sv, 0) + 0.01, f'{sv:.3f}', ha='center', fontsize=9)
        ax.text(i + width/2, max(cv, 0) + 0.01, f'{cv:.3f}', ha='center', fontsize=9)


def plot_synth_vs_cifar():
    """Compare Gamma and rho between the synthetic task and CIFAR at L=4.

    Reads ``results/synth_ladder_v2_hi/summary.json`` (averaging every
    alpha=1.0, L=4 run) and ``results/cifar_depth_scan_s42/d512_L4_s42.json``,
    then writes ``synth_vs_cifar.png``.  A value whose source data is missing
    is plotted as 0.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    synth = _load_synth_metrics(
        os.path.join('results/synth_ladder_v2_hi', 'summary.json'))
    cifar_data = _load_json('results/cifar_depth_scan_s42/d512_L4_s42.json') or {}

    def cifar_mean(method, diagnostic):
        # Mean of one CIFAR diagnostic, 0.0 when any level of the lookup is absent.
        diagnostics = cifar_data.get(method, {}).get('diagnostics', {})
        return _mean_or_zero(diagnostics.get(diagnostic, [0]))

    # Gamma comparison
    _plot_gap_panel(
        axes[0],
        [_mean_or_zero(synth['dfa_gamma']),
         _mean_or_zero(synth['credit_bridge_gamma'])],
        [cifar_mean('dfa', 'bp_cosine'), cifar_mean('credit_bridge', 'bp_cosine')],
        'Mean BP Cosine (Gamma)', 'Gamma: Synthetic vs CIFAR (L=4)')

    # rho comparison
    _plot_gap_panel(
        axes[1],
        [_mean_or_zero(synth['dfa_rho']),
         _mean_or_zero(synth['credit_bridge_rho'])],
        [cifar_mean('dfa', 'perturbation_rho'),
         cifar_mean('credit_bridge', 'perturbation_rho')],
        'Mean Perturbation rho', 'rho: Synthetic vs CIFAR (L=4)')

    _save_figure(fig, 'The Dimensionality Gap: Same Method, Different Scales',
                 'synth_vs_cifar.png')


def main():
    """Generate every exploration plot in sequence."""
    plot_cifar_depth_scan()
    plot_boundary_ablation()
    plot_synth_vs_cifar()
    print("\nAll exploration plots generated!")


if __name__ == '__main__':
    main()