summary | refs | log | tree | commit | diff
path: root/experiments/plot_synth_ladder.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/plot_synth_ladder.py')
-rw-r--r-- experiments/plot_synth_ladder.py 449
1 files changed, 449 insertions, 0 deletions
diff --git a/experiments/plot_synth_ladder.py b/experiments/plot_synth_ladder.py
new file mode 100644
index 0000000..a3fb0b4
--- /dev/null
+++ b/experiments/plot_synth_ladder.py
@@ -0,0 +1,449 @@
+"""Generate phase diagram plots from synthetic nonlinearity ladder results."""
+import os
+import sys
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from collections import defaultdict
+
# Every figure saved by this script is written into this directory.
output_dir = 'report_explore'
# Created eagerly at import time; an already existing directory is not an error.
os.makedirs(name=output_dir, exist_ok=True)
+
+
def load_summaries(result_dirs):
    """Load and merge ``summary.json`` files from several result directories.

    Args:
        result_dirs: Iterable of directory paths. Each directory is looked up
            for a file named ``summary.json`` whose top level is a JSON object.

    Returns:
        One dict holding the entries of every summary that was found.
        Directories are processed in the given order, so when the same key
        appears in more than one file the value from the later directory wins.
        Directories without a ``summary.json`` contribute nothing.
    """
    merged = {}
    for d in result_dirs:
        path = os.path.join(d, 'summary.json')
        if not os.path.exists(path):
            # Still a best-effort merge, but report the skip: a mistyped
            # directory would otherwise silently drop all of its runs.
            print(f"Warning: no summary.json in {d!r}, skipping", file=sys.stderr)
            continue
        # Read as UTF-8 explicitly instead of relying on the locale default.
        with open(path, encoding='utf-8') as f:
            merged.update(json.load(f))
    return merged
+
+
def parse_key(key):
    """Parse a run key such as ``'a0.5_L8_s42'`` into ``(alpha, L, seed)``.

    The key is split on underscores. Each of the first three fields carries a
    one-character prefix that is dropped before conversion; any fields after
    the third are ignored.

    Args:
        key: Run identifier string.

    Returns:
        Tuple ``(alpha, L, seed)`` converted to ``(float, int, int)``.

    Raises:
        ValueError: If the key has fewer than three fields or one of the
            fields cannot be converted. The message names the offending key.
    """
    parts = key.split('_')
    if len(parts) < 3:
        raise ValueError(
            f"Malformed result key {key!r}: expected 'a<alpha>_L<depth>_s<seed>'"
        )
    try:
        alpha = float(parts[0][1:])
        L = int(parts[1][1:])
        seed = int(parts[2][1:])
    except ValueError as err:
        # Re-raise with the full key so the bad summary entry can be located.
        raise ValueError(
            f"Malformed result key {key!r}: expected 'a<alpha>_L<depth>_s<seed>'"
        ) from err
    return alpha, L, seed
+
+
# Methods, scalar metrics and per-layer metrics that are aggregated over seeds.
_AGG_METHODS = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
_SCALAR_METRICS = ['test_acc', 'mean_bp_cosine', 'mean_rho', 'mean_nudge_01',
                   'mean_nudge_001', 'mean_nudge_003']
_PER_LAYER_METRICS = ['bp_cosine_per_layer', 'rho_per_layer', 'nudge_per_layer']
# Values at or beyond this magnitude are treated as a blown-up run and dropped.
_BLOWUP_LIMIT = 1e6


def _is_usable(value):
    """Return True if *value* is a number with magnitude below the blow-up limit.

    NaN fails the comparison and is rejected; values without ``abs`` support
    (e.g. a JSON ``null`` loaded as ``None``) are rejected instead of raising.
    """
    try:
        return abs(value) < _BLOWUP_LIMIT
    except TypeError:
        return False


def _seed_stats(entries, method, metric):
    """Return ``(mean, std)`` of one scalar metric over the seed entries.

    Entries lacking the method or the metric are skipped rather than counted
    as 0, so an absent measurement cannot bias the mean. ``(nan, nan)`` is
    returned when no usable value remains.
    """
    values = [
        e[method][metric]
        for e in entries
        if method in e and metric in e[method]
    ]
    values = [v for v in values if _is_usable(v)]
    if not values:
        return np.nan, np.nan
    return np.mean(values), np.std(values)


def _per_layer_stats(entries, method, metric):
    """Return per-layer ``(mean, std)`` lists over seeds, or None if no data.

    A seed's whole per-layer list is discarded when any element is unusable.
    """
    rows = [
        e[method][metric]
        for e in entries
        if method in e and metric in e[method]
    ]
    rows = [row for row in rows if all(_is_usable(v) for v in row)]
    if not rows:
        return None
    # assumes every seed reports the same number of layers for a given
    # (alpha, L) -- TODO confirm; ragged rows are not handled here.
    stacked = np.array(rows)
    return stacked.mean(axis=0).tolist(), stacked.std(axis=0).tolist()


def aggregate_seeds(summary):
    """Aggregate run results over seeds, grouping by ``(alpha, L)``.

    Args:
        summary: Dict mapping run keys (parsed with ``parse_key``) to per-run
            dicts of the form ``{method: {metric: value, ...}, ...}``.

    Returns:
        Dict ``{(alpha, L): {method: stats}}`` with an entry for each of the
        four methods. ``stats`` holds ``<metric>_mean`` / ``<metric>_std`` for
        every scalar metric (NaN when no seed supplied a usable value),
        ``state_pred_error_mean`` / ``state_pred_error_std`` for
        ``'state_bridge'`` only, and ``<metric>_mean`` / ``<metric>_std``
        lists for each per-layer metric that at least one seed supplied.
    """
    groups = defaultdict(list)
    for key, data in summary.items():
        alpha, L, _seed = parse_key(key)
        groups[(alpha, L)].append(data)

    agg = {}
    for cell, entries in groups.items():
        per_method = {}
        for method in _AGG_METHODS:
            vals = {}
            for metric in _SCALAR_METRICS:
                mean, std = _seed_stats(entries, method, metric)
                vals[f'{metric}_mean'] = mean
                vals[f'{metric}_std'] = std

            # State prediction error is only reported for the state bridge.
            if method == 'state_bridge':
                mean, std = _seed_stats(entries, method, 'mean_state_pred_error')
                vals['state_pred_error_mean'] = mean
                vals['state_pred_error_std'] = std

            # Per-layer diagnostics (averaged over seeds, layer by layer).
            for pl_metric in _PER_LAYER_METRICS:
                stats = _per_layer_stats(entries, method, pl_metric)
                if stats is not None:
                    vals[f'{pl_metric}_mean'], vals[f'{pl_metric}_std'] = stats

            per_method[method] = vals
        agg[cell] = per_method
    return agg
+
+
# Presentation shared by all figures: plotting order, colour and legend label.
_PLOT_METHODS = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
_NON_BP_METHODS = ['dfa', 'state_bridge', 'credit_bridge']
_COLORS = {'bp': '#F44336', 'dfa': '#2196F3', 'state_bridge': '#FF9800', 'credit_bridge': '#4CAF50'}
_LABELS = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}


def _metric_series(agg, alphas, L, method, mean_key, std_key=None):
    """Collect one curve (fixed depth and method) as ``(xs, ys, errs)``.

    Alphas whose cell or method is absent, or whose value is NaN, are skipped.
    ``errs`` holds ``std_key`` values (0 when absent or when no key is given).
    """
    xs, ys, errs = [], [], []
    for alpha in alphas:
        cell = agg.get((alpha, L))
        if cell is None or method not in cell:
            continue
        value = cell[method].get(mean_key, np.nan)
        if np.isnan(value):
            continue
        xs.append(alpha)
        ys.append(value)
        errs.append(cell[method].get(std_key, 0) if std_key is not None else 0)
    return xs, ys, errs


def _paired_series(agg, alphas, L, method_a, method_b, key):
    """Collect ``(xs, a_vals, b_vals)`` for alphas where both methods have *key*."""
    xs, a_vals, b_vals = [], [], []
    for alpha in alphas:
        cell = agg.get((alpha, L))
        if cell is None:
            continue
        a = cell.get(method_a, {}).get(key, np.nan)
        b = cell.get(method_b, {}).get(key, np.nan)
        if np.isnan(a) or np.isnan(b):
            continue
        xs.append(alpha)
        a_vals.append(a)
        b_vals.append(b)
    return xs, a_vals, b_vals


def _save_figure(fig, filename):
    """Tighten the layout, write *fig* into ``output_dir``, close it and report."""
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, filename), dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {filename}")


def _plot_heatmaps(agg, alphas, depths, save_prefix):
    """Plot 1: one heatmap per (metric, method) over the alpha x depth grid."""
    # (stats key, row label, colormap, vmin, vmax); None limits mean the
    # colour range is derived from the data.
    # NOTE(review): 'mean_nudge_01' is labelled η=0.01 although a
    # 'mean_nudge_001' metric is also aggregated -- confirm the step size.
    metrics_info = [
        ('mean_bp_cosine_mean', 'BP Cosine (Γ)', 'RdYlGn', -0.1, 1.0),
        ('mean_rho_mean', 'Perturbation ρ', 'RdYlGn', -0.2, 1.0),
        ('mean_nudge_01_mean', 'Nudge (η=0.01)', 'RdYlGn_r', None, None),
    ]
    fig, axes = plt.subplots(len(metrics_info), len(_PLOT_METHODS), figsize=(20, 12))

    for row, (metric_key, metric_name, cmap, vmin, vmax) in enumerate(metrics_info):
        for col, method in enumerate(_PLOT_METHODS):
            ax = axes[row, col]
            grid = np.full((len(alphas), len(depths)), np.nan)
            for i, alpha in enumerate(alphas):
                for j, L in enumerate(depths):
                    cell = agg.get((alpha, L), {})
                    if method in cell:
                        grid[i, j] = cell[method].get(metric_key, np.nan)

            if vmin is not None:
                vmin_use, vmax_use = vmin, vmax
            elif np.isnan(grid).all():
                vmin_use, vmax_use = -1, 0
            else:
                vmin_use = np.nanmin(grid)
                vmax_use = min(0, np.nanmax(grid))  # nudge should be <=0
                if vmin_use > vmax_use:
                    # Every nudge is positive: clamping the top at 0 would
                    # invert the limits, which the colour normalisation
                    # rejects. Fall back to the real data range.
                    vmax_use = np.nanmax(grid)

            im = ax.imshow(grid, cmap=cmap, vmin=vmin_use, vmax=vmax_use, aspect='auto',
                           origin='lower')
            ax.set_xticks(range(len(depths)))
            ax.set_xticklabels([str(d) for d in depths])
            ax.set_yticks(range(len(alphas)))
            ax.set_yticklabels([str(a) for a in alphas])

            # Annotate cells; a missing (NaN) cell is marked with a red X.
            for (i, j), val in np.ndenumerate(grid):
                if np.isnan(val):
                    ax.text(j, i, 'X', ha='center', va='center', fontsize=10,
                            color='red', fontweight='bold')
                else:
                    txt = f'{val:.3f}' if abs(val) < 100 else f'{val:.1e}'
                    ax.text(j, i, txt, ha='center', va='center', fontsize=8,
                            color='black' if abs(val) < 0.5 else 'white')

            if row == 0:
                ax.set_title(_LABELS[method], fontsize=12)
            if col == 0:
                ax.set_ylabel(f'{metric_name}\nα (nonlinearity)', fontsize=10)
            if row == len(metrics_info) - 1:
                ax.set_xlabel('Depth (L)', fontsize=10)
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.suptitle('Synthetic Ladder: Phase Diagrams (X = blown up)', fontsize=14, y=1.02)
    _save_figure(fig, f'{save_prefix}_phase_heatmaps.png')


def _plot_gamma_rho_vs_alpha(agg, alphas, depths, save_prefix):
    """Plot 2: Γ (top row) and ρ (bottom row) against alpha, one column per depth."""
    n_cols = max(len(depths), 1)
    fig, axes = plt.subplots(2, n_cols, figsize=(6 * n_cols, 10), squeeze=False)
    # (metric stem, y label, marker) for row 0 and row 1.
    rows = [
        ('mean_bp_cosine', 'Mean BP Cosine (Γ)', 'o'),
        ('mean_rho', 'Mean Perturbation ρ', 's'),
    ]

    for j, L in enumerate(depths):
        for row, (stem, ylabel, marker) in enumerate(rows):
            ax = axes[row, j]
            for method in _PLOT_METHODS:
                xs, ys, errs = _metric_series(agg, alphas, L, method,
                                              f'{stem}_mean', f'{stem}_std')
                if ys:
                    ax.errorbar(xs, ys, yerr=errs, marker=marker, color=_COLORS[method],
                                label=_LABELS[method], capsize=3, linewidth=2, markersize=6)
            ax.set_xlabel('α (nonlinearity)', fontsize=11)
            ax.set_ylabel(ylabel, fontsize=11)
            ax.set_title(f'L={L}', fontsize=13)
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)
            if row == 0:
                ax.set_ylim(-0.1, 1.05)
            else:
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.suptitle('Synthetic Ladder: Credit Quality vs Nonlinearity', fontsize=14, y=1.02)
    _save_figure(fig, f'{save_prefix}_gamma_rho_vs_alpha.png')


def _plot_state_vs_credit(agg, alphas, depths, save_prefix):
    """Plot 3: state-bridge prediction error, and state-bridge Γ vs credit-bridge Γ."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: state prediction error vs alpha, one line per depth.
    ax = axes[0]
    for L in depths:
        xs, ys, _ = _metric_series(agg, alphas, L, 'state_bridge', 'state_pred_error_mean')
        if ys:
            ax.plot(xs, ys, 'o-', label=f'L={L}', markersize=6, linewidth=2)
    ax.set_xlabel('α (nonlinearity)', fontsize=11)
    ax.set_ylabel('State Prediction Error', fontsize=11)
    ax.set_title('State Bridge: Terminal Prediction Error', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Right: scatter of state-bridge Γ against credit-bridge Γ.
    ax = axes[1]
    for L in depths:
        xs, sb_gammas, cb_gammas = _paired_series(
            agg, alphas, L, 'state_bridge', 'credit_bridge', 'mean_bp_cosine_mean')
        if sb_gammas:
            ax.scatter(sb_gammas, cb_gammas, s=80, label=f'L={L}', zorder=5)
            for a, sx, sy in zip(xs, sb_gammas, cb_gammas):
                ax.annotate(f'α={a}', (sx, sy), textcoords="offset points",
                            xytext=(5, 5), fontsize=8)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x')
    ax.set_xlabel('State Bridge Γ', fontsize=11)
    ax.set_ylabel('Credit Bridge Γ', fontsize=11)
    ax.set_title('State Bridge vs Credit Bridge BP Cosine', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-0.1, 1.0)
    ax.set_ylim(-0.1, 1.0)

    _save_figure(fig, f'{save_prefix}_state_vs_credit.png')


def _plot_nudge_acc(agg, alphas, depths, save_prefix):
    """Plot 4: nudge (left) and test accuracy (right) against alpha."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    # (axis, methods drawn, stats key, y label, title, draw zero line)
    panels = [
        (axes[0], _NON_BP_METHODS, 'mean_nudge_01_mean',
         'Mean Nudge (η=0.01, negative=good)', 'Nudging Test', True),
        (axes[1], _PLOT_METHODS, 'test_acc_mean',
         'Test Accuracy', 'Test Accuracy', False),
    ]
    # The shallowest depth is drawn solid, every other depth dashed.
    shallowest = min(depths)

    for ax, methods, key, ylabel, title, zero_line in panels:
        for method in methods:
            for L in depths:
                xs, ys, _ = _metric_series(agg, alphas, L, method, key)
                if ys:
                    ls = '-' if L == shallowest else '--'
                    ax.plot(xs, ys, f'o{ls}', color=_COLORS[method],
                            label=f'{_LABELS[method]} L={L}', markersize=5)
        ax.set_xlabel('α (nonlinearity)', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=12)
        ax.legend(fontsize=8, ncol=2)
        ax.grid(True, alpha=0.3)
        if zero_line:
            ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    _save_figure(fig, f'{save_prefix}_nudge_acc.png')


def _plot_per_layer(agg, alphas, depths, save_prefix):
    """Plot 5: per-layer BP cosine and ρ for each cell with a credit-bridge Γ."""
    # Only cells whose credit-bridge Γ is available (not NaN) get a column.
    valid_combos = [
        (a, L)
        for a in alphas
        for L in depths
        if (a, L) in agg
        and not np.isnan(
            agg[(a, L)].get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan))
    ]
    if not valid_combos:
        return

    n_combos = len(valid_combos)
    fig, axes = plt.subplots(2, n_combos, figsize=(5 * n_combos, 8), squeeze=False)
    # (stats key, line format, y label) for row 0 and row 1.
    rows = [
        ('bp_cosine_per_layer_mean', 'o-', 'BP Cosine'),
        ('rho_per_layer_mean', 's-', 'Perturbation ρ'),
    ]

    for idx, (alpha, L) in enumerate(valid_combos):
        cell = agg[(alpha, L)]
        for row, (key, fmt, ylabel) in enumerate(rows):
            ax = axes[row, idx]
            for method in _PLOT_METHODS:
                pl = cell[method].get(key) if method in cell else None
                if pl is not None:
                    ax.plot(range(len(pl)), pl, fmt, color=_COLORS[method],
                            label=_LABELS[method], markersize=4)
            ax.set_xlabel('Layer')
            ax.set_ylabel(ylabel)
            ax.set_title(f'α={alpha}, L={L}', fontsize=11)
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
            if row == 0:
                ax.set_ylim(-0.3, 1.05)
            else:
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.suptitle('Per-Layer Diagnostics', fontsize=14, y=1.02)
    _save_figure(fig, f'{save_prefix}_per_layer.png')


def _plot_cb_advantage(agg, alphas, depths, save_prefix):
    """Plot 6: credit-bridge minus DFA difference for Γ, ρ and nudge."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    metrics = [
        ('mean_bp_cosine_mean', 'Γ(CB) - Γ(DFA)'),
        ('mean_rho_mean', 'ρ(CB) - ρ(DFA)'),
        ('mean_nudge_01_mean', 'nudge(CB) - nudge(DFA)'),
    ]

    for ax, (metric_key, ylabel) in zip(axes, metrics):
        for L in depths:
            xs, cb_vals, dfa_vals = _paired_series(
                agg, alphas, L, 'credit_bridge', 'dfa', metric_key)
            if xs:
                diffs = [cb - dfa for cb, dfa in zip(cb_vals, dfa_vals)]
                ax.plot(xs, diffs, 'o-', label=f'L={L}', markersize=6, linewidth=2)
        ax.set_xlabel('α (nonlinearity)', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='zero (parity)')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    fig.suptitle('Credit Bridge Advantage over DFA', fontsize=14, y=1.02)
    _save_figure(fig, f'{save_prefix}_cb_advantage.png')


def plot_phase_diagrams(agg, alphas, depths, save_prefix='synth'):
    """Generate the six summary figures for the synthetic ladder results.

    Figures are written as PNG files into the module-level ``output_dir``:
    ``<prefix>_phase_heatmaps``, ``<prefix>_gamma_rho_vs_alpha``,
    ``<prefix>_state_vs_credit``, ``<prefix>_nudge_acc``,
    ``<prefix>_per_layer`` (only when at least one cell has a credit-bridge
    Γ) and ``<prefix>_cb_advantage``.

    Args:
        agg: Dict ``{(alpha, L): {method: stats}}`` in the layout built by
            ``aggregate_seeds``.
        alphas: Iterable of alpha values to plot; sorted internally.
        depths: Iterable of depth values to plot; sorted internally.
        save_prefix: Prefix of every output file name.

    Returns:
        None. Nothing is written when ``alphas`` or ``depths`` is empty.
    """
    alphas = sorted(alphas)
    depths = sorted(depths)
    if not alphas or not depths:
        # An empty axis cannot be laid out (zero-width figure); skip cleanly.
        print("No (alpha, depth) configurations to plot.")
        return

    _plot_heatmaps(agg, alphas, depths, save_prefix)
    _plot_gamma_rho_vs_alpha(agg, alphas, depths, save_prefix)
    _plot_state_vs_credit(agg, alphas, depths, save_prefix)
    _plot_nudge_acc(agg, alphas, depths, save_prefix)
    _plot_per_layer(agg, alphas, depths, save_prefix)
    _plot_cb_advantage(agg, alphas, depths, save_prefix)
+
+
if __name__ == '__main__':
    # Result directories come from the command line; without arguments the
    # smoke-test results are used.
    cli_dirs = sys.argv[1:]
    result_dirs = cli_dirs if cli_dirs else ['results/synth_ladder_smoke']

    print(f"Loading from: {result_dirs}")
    summary = load_summaries(result_dirs)

    if not summary:
        print("No results found!")
        sys.exit(1)

    # The alpha and depth axes are recovered from the run keys themselves.
    parsed_keys = [parse_key(key) for key in summary]
    alphas = sorted({alpha for alpha, _, _ in parsed_keys})
    depths = sorted({L for _, L, _ in parsed_keys})
    print(f"Alphas: {alphas}")
    print(f"Depths: {depths}")
    print(f"Total configs: {len(summary)}")

    agg = aggregate_seeds(summary)
    plot_phase_diagrams(agg, alphas, depths)
    print("\nAll plots generated!")