summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--experiments/plot_explore_all.py285
-rw-r--r--report_explore/boundary_ablation.pngbin0 -> 151623 bytes
-rw-r--r--report_explore/cifar_depth_scan.pngbin0 -> 177260 bytes
-rw-r--r--report_explore/synth_vs_cifar.pngbin0 -> 76788 bytes
4 files changed, 285 insertions, 0 deletions
diff --git a/experiments/plot_explore_all.py b/experiments/plot_explore_all.py
new file mode 100644
index 0000000..3d6660b
--- /dev/null
+++ b/experiments/plot_explore_all.py
@@ -0,0 +1,285 @@
+"""Generate all exploration phase plots."""
+import os
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+output_dir = 'report_explore'
+os.makedirs(output_dir, exist_ok=True)
+
+# =============================================================================
+# CIFAR Depth Scan
+# =============================================================================
+def plot_cifar_depth_scan():
+ depths = [2, 4, 6, 8, 12]
+ results = {}
+ for L in depths:
+ path = f'results/cifar_depth_scan_s42/d512_L{L}_s42.json'
+ if os.path.exists(path):
+ with open(path) as f:
+ results[L] = json.load(f)
+
+ if not results:
+ print("No CIFAR depth scan results found")
+ return
+
+ methods = ['bp', 'dfa', 'credit_bridge']
+ colors = {'bp': '#F44336', 'dfa': '#2196F3', 'credit_bridge': '#4CAF50'}
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'credit_bridge': 'Credit Bridge'}
+
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+ # Accuracy vs depth
+ ax = axes[0]
+ for m in methods:
+ accs = [results[L][m]['log']['test_acc'][-1] for L in depths if m in results[L]]
+ valid_depths = [L for L in depths if m in results[L]]
+ ax.plot(valid_depths, accs, 'o-', color=colors[m], label=labels[m], markersize=6, linewidth=2)
+ ax.set_xlabel('Depth (L)', fontsize=12)
+ ax.set_ylabel('Test Accuracy', fontsize=12)
+ ax.set_title('CIFAR-10 Test Accuracy vs Depth', fontsize=13)
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3)
+
+ # Gamma vs depth
+ ax = axes[1]
+ for m in methods:
+ gammas = [np.mean(results[L][m]['diagnostics']['bp_cosine']) for L in depths if m in results[L]]
+ valid_depths = [L for L in depths if m in results[L]]
+ ax.plot(valid_depths, gammas, 'o-', color=colors[m], label=labels[m], markersize=6, linewidth=2)
+ ax.set_xlabel('Depth (L)', fontsize=12)
+ ax.set_ylabel('Mean BP Cosine (Gamma)', fontsize=12)
+ ax.set_title('CIFAR-10 BP Cosine vs Depth', fontsize=13)
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.05, 0.25)
+
+ # rho vs depth
+ ax = axes[2]
+ for m in ['dfa', 'credit_bridge']:
+ rhos = [np.mean(results[L][m]['diagnostics']['perturbation_rho']) for L in depths if m in results[L]]
+ valid_depths = [L for L in depths if m in results[L]]
+ ax.plot(valid_depths, rhos, 'o-', color=colors[m], label=labels[m], markersize=6, linewidth=2)
+ ax.set_xlabel('Depth (L)', fontsize=12)
+ ax.set_ylabel('Mean Perturbation rho', fontsize=12)
+ ax.set_title('CIFAR-10 Perturbation rho vs Depth', fontsize=13)
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ fig.suptitle('CIFAR-10 Depth Scan (d=512, seed=42)', fontsize=14, y=1.02)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_depth_scan.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print("Saved cifar_depth_scan.png")
+
+
+# =============================================================================
+# Boundary Ablation
+# =============================================================================
+def plot_boundary_ablation():
+ # s_type comparison across 3 seeds
+ s_types = ['eT', 'deltaL']
+ seed_data = {}
+ for seed, path in [(42, 'results/boundary_ablation_s_sweep/ablation_a1.0_L4_s42.json'),
+ (123, 'results/boundary_ablation_s123/ablation_a1.0_L4_s123.json'),
+ (456, 'results/boundary_ablation_s456/ablation_a1.0_L4_s456.json')]:
+ if os.path.exists(path):
+ with open(path) as f:
+ seed_data[seed] = json.load(f)
+
+ # Also load eT_hL and deltaL_hL from seed 42
+ if 42 in seed_data:
+ s_types_full = ['eT', 'deltaL', 'eT_hL', 'deltaL_hL']
+ else:
+ s_types_full = s_types
+
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+ # s_type bar chart (3 seeds for eT and deltaL)
+ ax = axes[0]
+ colors_s = {'eT': '#2196F3', 'deltaL': '#FF9800', 'eT_hL': '#9C27B0', 'deltaL_hL': '#795548'}
+ x = np.arange(len(s_types_full))
+ width = 0.35
+
+ for i, metric in enumerate(['mean_bp_cosine', 'mean_rho']):
+ means = []
+ stds = []
+ for s_type in s_types_full:
+ vals = []
+ for seed in seed_data:
+ key = f's_{s_type}_tgw1.0_wr0.2'
+ if key in seed_data[seed]:
+ vals.append(seed_data[seed][key][metric])
+ means.append(np.mean(vals) if vals else 0)
+ stds.append(np.std(vals) if len(vals) > 1 else 0)
+
+ offset = (i - 0.5) * width
+ label = 'Gamma' if metric == 'mean_bp_cosine' else 'rho'
+ color = '#F44336' if metric == 'mean_bp_cosine' else '#4CAF50'
+ ax.bar(x + offset, means, width, yerr=stds, capsize=3,
+ label=label, color=color, alpha=0.7)
+
+ ax.set_xticks(x)
+ ax.set_xticklabels(s_types_full, fontsize=10)
+ ax.set_ylabel('Value', fontsize=12)
+ ax.set_title('Terminal Conditioning Code', fontsize=13)
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3, axis='y')
+
+ # tgw sweep
+ ax = axes[1]
+ tgw_path = 'results/boundary_ablation_tgw_sweep/ablation_a1.0_L4_s42.json'
+ if os.path.exists(tgw_path):
+ with open(tgw_path) as f:
+ tgw_data = json.load(f)
+
+ tgws = []
+ gammas = []
+ rhos = []
+ accs = []
+ for key in sorted(tgw_data.keys()):
+ r = tgw_data[key]
+ tgws.append(r['term_grad_weight'])
+ gammas.append(r['mean_bp_cosine'])
+ rhos.append(r['mean_rho'])
+ accs.append(r['test_acc'])
+
+ ax.plot(tgws, gammas, 'o-', color='#F44336', label='Gamma', markersize=8, linewidth=2)
+ ax.plot(tgws, rhos, 's-', color='#4CAF50', label='rho', markersize=8, linewidth=2)
+ ax.plot(tgws, accs, '^-', color='#2196F3', label='Accuracy', markersize=8, linewidth=2)
+ ax.set_xlabel('Terminal Gradient Weight', fontsize=12)
+ ax.set_ylabel('Value', fontsize=12)
+ ax.set_title('Terminal Gradient Matching Weight', fontsize=13)
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3)
+
+ # Warmup ratio sweep
+ ax = axes[2]
+ wr_path = 'results/boundary_ablation_wr_sweep/ablation_a1.0_L4_s42.json'
+ if os.path.exists(wr_path):
+ with open(wr_path) as f:
+ wr_data = json.load(f)
+
+ wrs = []
+ gammas = []
+ rhos = []
+ accs = []
+ for key in sorted(wr_data.keys()):
+ r = wr_data[key]
+ wrs.append(r['warmup_ratio'])
+ gammas.append(r['mean_bp_cosine'])
+ rhos.append(r['mean_rho'])
+ accs.append(r['test_acc'])
+
+ ax.plot(wrs, gammas, 'o-', color='#F44336', label='Gamma', markersize=8, linewidth=2)
+ ax.plot(wrs, rhos, 's-', color='#4CAF50', label='rho', markersize=8, linewidth=2)
+ ax.plot(wrs, accs, '^-', color='#2196F3', label='Accuracy', markersize=8, linewidth=2)
+ ax.set_xlabel('Warmup Ratio', fontsize=12)
+ ax.set_ylabel('Value', fontsize=12)
+ ax.set_title('DFA Warmup Ratio', fontsize=13)
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3)
+
+ fig.suptitle('Boundary-Condition Ablation (alpha=1.0, L=4)', fontsize=14, y=1.02)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'boundary_ablation.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print("Saved boundary_ablation.png")
+
+
+# =============================================================================
+# Comparison: Synthetic vs CIFAR (the gap)
+# =============================================================================
+def plot_synth_vs_cifar():
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+
+ # Load synthetic data (alpha=1.0, L=4, 3 seeds)
+ from collections import defaultdict
+ synth_data = defaultdict(list)
+ for path_dir in ['results/synth_ladder_v2_hi']:
+ summary_path = os.path.join(path_dir, 'summary.json')
+ if os.path.exists(summary_path):
+ with open(summary_path) as f:
+ data = json.load(f)
+ for key, d in data.items():
+ parts = key.split('_')
+ alpha = float(parts[0][1:])
+ L = int(parts[1][1:])
+ if alpha == 1.0 and L == 4:
+ synth_data['seed'].append(int(parts[2][1:]))
+ for m in ['dfa', 'credit_bridge']:
+ synth_data[f'{m}_gamma'].append(d[m]['mean_bp_cosine'])
+ synth_data[f'{m}_rho'].append(d[m]['mean_rho'])
+
+ # Load CIFAR data (L=4)
+ cifar_path = 'results/cifar_depth_scan_s42/d512_L4_s42.json'
+ cifar_data = {}
+ if os.path.exists(cifar_path):
+ with open(cifar_path) as f:
+ cifar_data = json.load(f)
+
+ # Gamma comparison
+ ax = axes[0]
+ x = np.arange(2)
+ width = 0.35
+
+ synth_dfa_gamma = np.mean(synth_data.get('dfa_gamma', [0]))
+ synth_cb_gamma = np.mean(synth_data.get('credit_bridge_gamma', [0]))
+ cifar_dfa_gamma = np.mean(cifar_data.get('dfa', {}).get('diagnostics', {}).get('bp_cosine', [0]))
+ cifar_cb_gamma = np.mean(cifar_data.get('credit_bridge', {}).get('diagnostics', {}).get('bp_cosine', [0]))
+
+ ax.bar(x - width/2, [synth_dfa_gamma, synth_cb_gamma], width,
+ label='Synthetic (d=128)', color=['#2196F3', '#4CAF50'], alpha=0.7)
+ ax.bar(x + width/2, [cifar_dfa_gamma, cifar_cb_gamma], width,
+ label='CIFAR (d=512)', color=['#2196F3', '#4CAF50'], alpha=0.3, edgecolor='black')
+ ax.set_xticks(x)
+ ax.set_xticklabels(['DFA', 'Credit Bridge'])
+ ax.set_ylabel('Mean BP Cosine (Gamma)')
+ ax.set_title('Gamma: Synthetic vs CIFAR (L=4)')
+ ax.legend(['Synthetic', 'CIFAR'], fontsize=10)
+ ax.grid(True, alpha=0.3, axis='y')
+
+ # Add annotations
+ for i, (sv, cv) in enumerate([(synth_dfa_gamma, cifar_dfa_gamma),
+ (synth_cb_gamma, cifar_cb_gamma)]):
+ ax.text(i - width/2, sv + 0.01, f'{sv:.3f}', ha='center', fontsize=9)
+ ax.text(i + width/2, cv + 0.01, f'{cv:.3f}', ha='center', fontsize=9)
+
+ # rho comparison
+ ax = axes[1]
+ synth_dfa_rho = np.mean(synth_data.get('dfa_rho', [0]))
+ synth_cb_rho = np.mean(synth_data.get('credit_bridge_rho', [0]))
+ cifar_dfa_rho = np.mean(cifar_data.get('dfa', {}).get('diagnostics', {}).get('perturbation_rho', [0]))
+ cifar_cb_rho = np.mean(cifar_data.get('credit_bridge', {}).get('diagnostics', {}).get('perturbation_rho', [0]))
+
+ ax.bar(x - width/2, [synth_dfa_rho, synth_cb_rho], width,
+ label='Synthetic (d=128)', color=['#2196F3', '#4CAF50'], alpha=0.7)
+ ax.bar(x + width/2, [cifar_dfa_rho, cifar_cb_rho], width,
+ label='CIFAR (d=512)', color=['#2196F3', '#4CAF50'], alpha=0.3, edgecolor='black')
+ ax.set_xticks(x)
+ ax.set_xticklabels(['DFA', 'Credit Bridge'])
+ ax.set_ylabel('Mean Perturbation rho')
+ ax.set_title('rho: Synthetic vs CIFAR (L=4)')
+ ax.legend(['Synthetic', 'CIFAR'], fontsize=10)
+ ax.grid(True, alpha=0.3, axis='y')
+
+ for i, (sv, cv) in enumerate([(synth_dfa_rho, cifar_dfa_rho),
+ (synth_cb_rho, cifar_cb_rho)]):
+ ax.text(i - width/2, sv + 0.01, f'{sv:.3f}', ha='center', fontsize=9)
+ ax.text(i + width/2, max(cv, 0) + 0.01, f'{cv:.3f}', ha='center', fontsize=9)
+
+ fig.suptitle('The Dimensionality Gap: Same Method, Different Scales', fontsize=14, y=1.02)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'synth_vs_cifar.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print("Saved synth_vs_cifar.png")
+
+
+if __name__ == '__main__':
+ plot_cifar_depth_scan()
+ plot_boundary_ablation()
+ plot_synth_vs_cifar()
+ print("\nAll exploration plots generated!")
diff --git a/report_explore/boundary_ablation.png b/report_explore/boundary_ablation.png
new file mode 100644
index 0000000..f769257
--- /dev/null
+++ b/report_explore/boundary_ablation.png
Binary files differ
diff --git a/report_explore/cifar_depth_scan.png b/report_explore/cifar_depth_scan.png
new file mode 100644
index 0000000..01f064f
--- /dev/null
+++ b/report_explore/cifar_depth_scan.png
Binary files differ
diff --git a/report_explore/synth_vs_cifar.png b/report_explore/synth_vs_cifar.png
new file mode 100644
index 0000000..a3daf51
--- /dev/null
+++ b/report_explore/synth_vs_cifar.png
Binary files differ