"""Generate all exploration phase plots.""" import os import json import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt output_dir = 'report_explore' os.makedirs(output_dir, exist_ok=True) # ============================================================================= # CIFAR Depth Scan # ============================================================================= def plot_cifar_depth_scan(): depths = [2, 4, 6, 8, 12] results = {} for L in depths: path = f'results/cifar_depth_scan_s42/d512_L{L}_s42.json' if os.path.exists(path): with open(path) as f: results[L] = json.load(f) if not results: print("No CIFAR depth scan results found") return methods = ['bp', 'dfa', 'credit_bridge'] colors = {'bp': '#F44336', 'dfa': '#2196F3', 'credit_bridge': '#4CAF50'} labels = {'bp': 'BP', 'dfa': 'DFA', 'credit_bridge': 'Credit Bridge'} fig, axes = plt.subplots(1, 3, figsize=(18, 5)) # Accuracy vs depth ax = axes[0] for m in methods: accs = [results[L][m]['log']['test_acc'][-1] for L in depths if m in results[L]] valid_depths = [L for L in depths if m in results[L]] ax.plot(valid_depths, accs, 'o-', color=colors[m], label=labels[m], markersize=6, linewidth=2) ax.set_xlabel('Depth (L)', fontsize=12) ax.set_ylabel('Test Accuracy', fontsize=12) ax.set_title('CIFAR-10 Test Accuracy vs Depth', fontsize=13) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) # Gamma vs depth ax = axes[1] for m in methods: gammas = [np.mean(results[L][m]['diagnostics']['bp_cosine']) for L in depths if m in results[L]] valid_depths = [L for L in depths if m in results[L]] ax.plot(valid_depths, gammas, 'o-', color=colors[m], label=labels[m], markersize=6, linewidth=2) ax.set_xlabel('Depth (L)', fontsize=12) ax.set_ylabel('Mean BP Cosine (Gamma)', fontsize=12) ax.set_title('CIFAR-10 BP Cosine vs Depth', fontsize=13) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) ax.set_ylim(-0.05, 0.25) # rho vs depth ax = axes[2] for m in ['dfa', 'credit_bridge']: rhos = [np.mean(results[L][m]['diagnostics']['perturbation_rho']) for L in depths if m in results[L]] valid_depths = [L for L in depths if m in results[L]] ax.plot(valid_depths, rhos, 'o-', color=colors[m], label=labels[m], markersize=6, linewidth=2) ax.set_xlabel('Depth (L)', fontsize=12) ax.set_ylabel('Mean Perturbation rho', fontsize=12) ax.set_title('CIFAR-10 Perturbation rho vs Depth', fontsize=13) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) fig.suptitle('CIFAR-10 Depth Scan (d=512, seed=42)', fontsize=14, y=1.02) fig.tight_layout() fig.savefig(os.path.join(output_dir, 'cifar_depth_scan.png'), dpi=150, bbox_inches='tight') plt.close(fig) print("Saved cifar_depth_scan.png") # ============================================================================= # Boundary Ablation # ============================================================================= def plot_boundary_ablation(): # s_type comparison across 3 seeds s_types = ['eT', 'deltaL'] seed_data = {} for seed, path in [(42, 'results/boundary_ablation_s_sweep/ablation_a1.0_L4_s42.json'), (123, 'results/boundary_ablation_s123/ablation_a1.0_L4_s123.json'), (456, 'results/boundary_ablation_s456/ablation_a1.0_L4_s456.json')]: if os.path.exists(path): with open(path) as f: seed_data[seed] = json.load(f) # Also load eT_hL and deltaL_hL from seed 42 if 42 in seed_data: s_types_full = ['eT', 'deltaL', 'eT_hL', 'deltaL_hL'] else: s_types_full = s_types fig, axes = plt.subplots(1, 3, figsize=(18, 5)) # s_type bar chart (3 seeds for eT and deltaL) ax = axes[0] colors_s = {'eT': '#2196F3', 'deltaL': '#FF9800', 'eT_hL': '#9C27B0', 'deltaL_hL': '#795548'} x = np.arange(len(s_types_full)) width = 0.35 for i, metric in enumerate(['mean_bp_cosine', 'mean_rho']): means = [] stds = [] for s_type in s_types_full: vals = [] for seed in seed_data: key = f's_{s_type}_tgw1.0_wr0.2' if key in seed_data[seed]: vals.append(seed_data[seed][key][metric]) means.append(np.mean(vals) if vals else 0) stds.append(np.std(vals) if len(vals) > 1 else 0) offset = (i - 0.5) * width label = 'Gamma' if metric == 'mean_bp_cosine' else 'rho' color = '#F44336' if metric == 'mean_bp_cosine' else '#4CAF50' ax.bar(x + offset, means, width, yerr=stds, capsize=3, label=label, color=color, alpha=0.7) ax.set_xticks(x) ax.set_xticklabels(s_types_full, fontsize=10) ax.set_ylabel('Value', fontsize=12) ax.set_title('Terminal Conditioning Code', fontsize=13) ax.legend(fontsize=10) ax.grid(True, alpha=0.3, axis='y') # tgw sweep ax = axes[1] tgw_path = 'results/boundary_ablation_tgw_sweep/ablation_a1.0_L4_s42.json' if os.path.exists(tgw_path): with open(tgw_path) as f: tgw_data = json.load(f) tgws = [] gammas = [] rhos = [] accs = [] for key in sorted(tgw_data.keys()): r = tgw_data[key] tgws.append(r['term_grad_weight']) gammas.append(r['mean_bp_cosine']) rhos.append(r['mean_rho']) accs.append(r['test_acc']) ax.plot(tgws, gammas, 'o-', color='#F44336', label='Gamma', markersize=8, linewidth=2) ax.plot(tgws, rhos, 's-', color='#4CAF50', label='rho', markersize=8, linewidth=2) ax.plot(tgws, accs, '^-', color='#2196F3', label='Accuracy', markersize=8, linewidth=2) ax.set_xlabel('Terminal Gradient Weight', fontsize=12) ax.set_ylabel('Value', fontsize=12) ax.set_title('Terminal Gradient Matching Weight', fontsize=13) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) # Warmup ratio sweep ax = axes[2] wr_path = 'results/boundary_ablation_wr_sweep/ablation_a1.0_L4_s42.json' if os.path.exists(wr_path): with open(wr_path) as f: wr_data = json.load(f) wrs = [] gammas = [] rhos = [] accs = [] for key in sorted(wr_data.keys()): r = wr_data[key] wrs.append(r['warmup_ratio']) gammas.append(r['mean_bp_cosine']) rhos.append(r['mean_rho']) accs.append(r['test_acc']) ax.plot(wrs, gammas, 'o-', color='#F44336', label='Gamma', markersize=8, linewidth=2) ax.plot(wrs, rhos, 's-', color='#4CAF50', label='rho', markersize=8, linewidth=2) ax.plot(wrs, accs, '^-', color='#2196F3', label='Accuracy', markersize=8, linewidth=2) ax.set_xlabel('Warmup Ratio', fontsize=12) ax.set_ylabel('Value', fontsize=12) ax.set_title('DFA Warmup Ratio', fontsize=13) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) fig.suptitle('Boundary-Condition Ablation (alpha=1.0, L=4)', fontsize=14, y=1.02) fig.tight_layout() fig.savefig(os.path.join(output_dir, 'boundary_ablation.png'), dpi=150, bbox_inches='tight') plt.close(fig) print("Saved boundary_ablation.png") # ============================================================================= # Comparison: Synthetic vs CIFAR (the gap) # ============================================================================= def plot_synth_vs_cifar(): fig, axes = plt.subplots(1, 2, figsize=(14, 5)) # Load synthetic data (alpha=1.0, L=4, 3 seeds) from collections import defaultdict synth_data = defaultdict(list) for path_dir in ['results/synth_ladder_v2_hi']: summary_path = os.path.join(path_dir, 'summary.json') if os.path.exists(summary_path): with open(summary_path) as f: data = json.load(f) for key, d in data.items(): parts = key.split('_') alpha = float(parts[0][1:]) L = int(parts[1][1:]) if alpha == 1.0 and L == 4: synth_data['seed'].append(int(parts[2][1:])) for m in ['dfa', 'credit_bridge']: synth_data[f'{m}_gamma'].append(d[m]['mean_bp_cosine']) synth_data[f'{m}_rho'].append(d[m]['mean_rho']) # Load CIFAR data (L=4) cifar_path = 'results/cifar_depth_scan_s42/d512_L4_s42.json' cifar_data = {} if os.path.exists(cifar_path): with open(cifar_path) as f: cifar_data = json.load(f) # Gamma comparison ax = axes[0] x = np.arange(2) width = 0.35 synth_dfa_gamma = np.mean(synth_data.get('dfa_gamma', [0])) synth_cb_gamma = np.mean(synth_data.get('credit_bridge_gamma', [0])) cifar_dfa_gamma = np.mean(cifar_data.get('dfa', {}).get('diagnostics', {}).get('bp_cosine', [0])) cifar_cb_gamma = np.mean(cifar_data.get('credit_bridge', {}).get('diagnostics', {}).get('bp_cosine', [0])) ax.bar(x - width/2, [synth_dfa_gamma, synth_cb_gamma], width, label='Synthetic (d=128)', color=['#2196F3', '#4CAF50'], alpha=0.7) ax.bar(x + width/2, [cifar_dfa_gamma, cifar_cb_gamma], width, label='CIFAR (d=512)', color=['#2196F3', '#4CAF50'], alpha=0.3, edgecolor='black') ax.set_xticks(x) ax.set_xticklabels(['DFA', 'Credit Bridge']) ax.set_ylabel('Mean BP Cosine (Gamma)') ax.set_title('Gamma: Synthetic vs CIFAR (L=4)') ax.legend(['Synthetic', 'CIFAR'], fontsize=10) ax.grid(True, alpha=0.3, axis='y') # Add annotations for i, (sv, cv) in enumerate([(synth_dfa_gamma, cifar_dfa_gamma), (synth_cb_gamma, cifar_cb_gamma)]): ax.text(i - width/2, sv + 0.01, f'{sv:.3f}', ha='center', fontsize=9) ax.text(i + width/2, cv + 0.01, f'{cv:.3f}', ha='center', fontsize=9) # rho comparison ax = axes[1] synth_dfa_rho = np.mean(synth_data.get('dfa_rho', [0])) synth_cb_rho = np.mean(synth_data.get('credit_bridge_rho', [0])) cifar_dfa_rho = np.mean(cifar_data.get('dfa', {}).get('diagnostics', {}).get('perturbation_rho', [0])) cifar_cb_rho = np.mean(cifar_data.get('credit_bridge', {}).get('diagnostics', {}).get('perturbation_rho', [0])) ax.bar(x - width/2, [synth_dfa_rho, synth_cb_rho], width, label='Synthetic (d=128)', color=['#2196F3', '#4CAF50'], alpha=0.7) ax.bar(x + width/2, [cifar_dfa_rho, cifar_cb_rho], width, label='CIFAR (d=512)', color=['#2196F3', '#4CAF50'], alpha=0.3, edgecolor='black') ax.set_xticks(x) ax.set_xticklabels(['DFA', 'Credit Bridge']) ax.set_ylabel('Mean Perturbation rho') ax.set_title('rho: Synthetic vs CIFAR (L=4)') ax.legend(['Synthetic', 'CIFAR'], fontsize=10) ax.grid(True, alpha=0.3, axis='y') for i, (sv, cv) in enumerate([(synth_dfa_rho, cifar_dfa_rho), (synth_cb_rho, cifar_cb_rho)]): ax.text(i - width/2, sv + 0.01, f'{sv:.3f}', ha='center', fontsize=9) ax.text(i + width/2, max(cv, 0) + 0.01, f'{cv:.3f}', ha='center', fontsize=9) fig.suptitle('The Dimensionality Gap: Same Method, Different Scales', fontsize=14, y=1.02) fig.tight_layout() fig.savefig(os.path.join(output_dir, 'synth_vs_cifar.png'), dpi=150, bbox_inches='tight') plt.close(fig) print("Saved synth_vs_cifar.png") if __name__ == '__main__': plot_cifar_depth_scan() plot_boundary_ablation() plot_synth_vs_cifar() print("\nAll exploration plots generated!")