diff options
Diffstat (limited to 'experiments/plot_synth_ladder.py')
| -rw-r--r-- | experiments/plot_synth_ladder.py | 449 |
1 files changed, 449 insertions, 0 deletions
"""Generate phase diagram plots from synthetic nonlinearity ladder results."""
import os
import sys
import json
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from collections import defaultdict

output_dir = 'report_explore'
os.makedirs(output_dir, exist_ok=True)

# Methods and metrics expected inside each per-run entry of summary.json.
_ALL_METHODS = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
_SCALAR_METRICS = ['test_acc', 'mean_bp_cosine', 'mean_rho', 'mean_nudge_01',
                   'mean_nudge_001', 'mean_nudge_003']
_PER_LAYER_METRICS = ['bp_cosine_per_layer', 'rho_per_layer', 'nudge_per_layer']

# Values with magnitude at or above this are treated as a blown-up run.
_BLOWUP_LIMIT = 1e6

_COLORS = {'bp': '#F44336', 'dfa': '#2196F3', 'state_bridge': '#FF9800',
           'credit_bridge': '#4CAF50'}
_LABELS = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge',
           'credit_bridge': 'Credit Bridge'}

# Line styles cycled over the sorted depths (first depth is solid).
_DEPTH_LINESTYLES = ['-', '--', '-.', ':']


def load_summaries(result_dirs):
    """Load and merge summary.json files from multiple result directories.

    Args:
        result_dirs: iterable of directory paths; each is expected to hold a
            ``summary.json`` file containing a JSON object.

    Returns:
        A single dict with the entries of every file found. When the same key
        appears in several files, the directory listed last wins.
    """
    merged = {}
    for d in result_dirs:
        path = os.path.join(d, 'summary.json')
        if not os.path.exists(path):
            # Keep going (best effort), but do not hide the gap from the user.
            print(f"Warning: {path} not found, skipping", file=sys.stderr)
            continue
        with open(path, encoding='utf-8') as f:
            merged.update(json.load(f))
    return merged


def parse_key(key):
    """Parse 'a0.5_L8_s42' -> (alpha, L, seed).

    The first character of each '_'-separated field is dropped and the rest is
    converted to float (alpha) or int (L, seed). Raises IndexError or
    ValueError when the key does not have that shape.
    """
    parts = key.split('_')
    alpha = float(parts[0][1:])
    L = int(parts[1][1:])
    seed = int(parts[2][1:])
    return alpha, L, seed


def _finite_values(entries, method, metric):
    """Collect ``metric`` of ``method`` over seeds, dropping unusable values.

    Missing metrics, non-numeric values (e.g. JSON null), NaN and values with
    magnitude >= _BLOWUP_LIMIT are all skipped, so they never enter a mean.
    """
    values = []
    for entry in entries:
        if method not in entry:
            continue
        value = entry[method].get(metric)
        # abs(nan) < limit is False, so NaN is filtered by the same test.
        if isinstance(value, (int, float)) and abs(value) < _BLOWUP_LIMIT:
            values.append(value)
    return values


def _mean_std(values):
    """Return (mean, std) of ``values``, or (nan, nan) when it is empty."""
    if values:
        return np.mean(values), np.std(values)
    return np.nan, np.nan


def aggregate_seeds(summary):
    """Aggregate over seeds: group by (alpha, L).

    Args:
        summary: dict mapping run keys (see ``parse_key``) to a dict of
            per-method metric dicts.

    Returns:
        ``{(alpha, L): {method: stats}}`` where ``stats`` holds
        ``<metric>_mean`` / ``<metric>_std`` for every scalar metric (NaN when
        no seed provides a usable value), ``state_pred_error_mean`` /
        ``_std`` for 'state_bridge', and ``<metric>_mean`` / ``_std`` lists
        for each per-layer metric that at least one seed provides.
    """
    groups = defaultdict(list)
    for key, data in summary.items():
        alpha, L, _seed = parse_key(key)
        groups[(alpha, L)].append(data)

    agg = {}
    for (alpha, L), entries in groups.items():
        agg[(alpha, L)] = {}
        for method in _ALL_METHODS:
            vals = {}
            for metric in _SCALAR_METRICS:
                mean, std = _mean_std(_finite_values(entries, method, metric))
                vals[f'{metric}_mean'] = mean
                vals[f'{metric}_std'] = std

            # State prediction error
            if method == 'state_bridge':
                mean, std = _mean_std(
                    _finite_values(entries, method, 'mean_state_pred_error'))
                vals['state_pred_error_mean'] = mean
                vals['state_pred_error_std'] = std

            # Per-layer diagnostics (average over seeds)
            for pl_metric in _PER_LAYER_METRICS:
                arrays = []
                for e in entries:
                    if method in e and pl_metric in e[method]:
                        arr = e[method][pl_metric]
                        # A seed is dropped entirely if any layer blew up.
                        if all(abs(v) < _BLOWUP_LIMIT for v in arr):
                            arrays.append(arr)
                if arrays:
                    arr2d = np.array(arrays)
                    vals[f'{pl_metric}_mean'] = arr2d.mean(axis=0).tolist()
                    vals[f'{pl_metric}_std'] = arr2d.std(axis=0).tolist()

            agg[(alpha, L)][method] = vals
    return agg


def _series(agg, alphas, L, method, mean_key, std_key=None):
    """Return (alphas, means, stds) of one method at depth L, skipping NaN means.

    Stds default to 0 when ``std_key`` is None or absent from the stats.
    """
    xs, ys, errs = [], [], []
    for alpha in alphas:
        stats = agg.get((alpha, L), {}).get(method)
        if stats is None:
            continue
        mean = stats.get(mean_key, np.nan)
        if np.isnan(mean):
            continue
        xs.append(alpha)
        ys.append(mean)
        errs.append(stats.get(std_key, 0) if std_key is not None else 0)
    return xs, ys, errs


def _paired_series(agg, alphas, L, method_a, method_b, key):
    """Return (alphas, a_values, b_values) where both methods have a non-NaN ``key``."""
    xs, a_vals, b_vals = [], [], []
    for alpha in alphas:
        if (alpha, L) not in agg:
            continue
        cell = agg[(alpha, L)]
        a = cell.get(method_a, {}).get(key, np.nan)
        b = cell.get(method_b, {}).get(key, np.nan)
        if not np.isnan(a) and not np.isnan(b):
            xs.append(alpha)
            a_vals.append(a)
            b_vals.append(b)
    return xs, a_vals, b_vals


def _nudge_color_limits(grid):
    """Pick (vmin, vmax) for a nudge heatmap, capping the top of the range at 0.

    Falls back to (-1, 0) when the grid has no valid cell. When every value is
    positive the raw limits would have vmin > vmax, which matplotlib rejects;
    in that case vmin is moved to vmax - 1 so the values saturate at the top
    of the colour map instead of crashing.
    """
    if np.isnan(grid).all():
        return -1, 0
    upper = min(0, np.nanmax(grid))  # nudge should be <=0
    lower = np.nanmin(grid)
    if lower > upper:
        lower = upper - 1
    return lower, upper


def _save(fig, filename):
    """Tighten, write ``fig`` into output_dir as ``filename``, close and report."""
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, filename), dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {filename}")


def _plot_heatmaps(agg, alphas, depths, save_prefix):
    """Plot 1: heatmaps (Gamma, rho, nudge) over (alpha, L) for each method."""
    # NOTE(review): the key 'mean_nudge_01' is labelled eta=0.01 here while
    # 'mean_nudge_001' also exists -- confirm which step size this key holds.
    metrics_info = [
        ('mean_bp_cosine_mean', 'BP Cosine (Γ)', 'RdYlGn', -0.1, 1.0),
        ('mean_rho_mean', 'Perturbation ρ', 'RdYlGn', -0.2, 1.0),
        ('mean_nudge_01_mean', 'Nudge (η=0.01)', 'RdYlGn_r', None, None),
    ]
    fig, axes = plt.subplots(len(metrics_info), len(_ALL_METHODS), figsize=(20, 12))

    for row, (metric_key, metric_name, cmap, vmin, vmax) in enumerate(metrics_info):
        for col, method in enumerate(_ALL_METHODS):
            ax = axes[row, col]
            grid = np.full((len(alphas), len(depths)), np.nan)
            for i, alpha in enumerate(alphas):
                for j, L in enumerate(depths):
                    stats = agg.get((alpha, L), {}).get(method)
                    if stats is not None:
                        grid[i, j] = stats.get(metric_key, np.nan)

            if vmin is None:
                vmin_use, vmax_use = _nudge_color_limits(grid)
            else:
                vmin_use, vmax_use = vmin, vmax

            im = ax.imshow(grid, cmap=cmap, vmin=vmin_use, vmax=vmax_use,
                           aspect='auto', origin='lower')
            ax.set_xticks(range(len(depths)))
            ax.set_xticklabels([str(d) for d in depths])
            ax.set_yticks(range(len(alphas)))
            ax.set_yticklabels([str(a) for a in alphas])

            # Annotate cells; NaN cells (no usable value) are marked with 'X'.
            for (i, j), val in np.ndenumerate(grid):
                if not np.isnan(val):
                    txt = f'{val:.3f}' if abs(val) < 100 else f'{val:.1e}'
                    ax.text(j, i, txt, ha='center', va='center', fontsize=8,
                            color='black' if abs(val) < 0.5 else 'white')
                else:
                    ax.text(j, i, 'X', ha='center', va='center', fontsize=10,
                            color='red', fontweight='bold')

            if row == 0:
                ax.set_title(_LABELS[method], fontsize=12)
            if col == 0:
                ax.set_ylabel(f'{metric_name}\nα (nonlinearity)', fontsize=10)
            if row == len(metrics_info) - 1:
                ax.set_xlabel('Depth (L)', fontsize=10)
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.suptitle('Synthetic Ladder: Phase Diagrams (X = blown up)', fontsize=14, y=1.02)
    _save(fig, f'{save_prefix}_phase_heatmaps.png')


def _plot_gamma_rho_vs_alpha(agg, alphas, depths, save_prefix):
    """Plot 2: Gamma (row 0) and rho (row 1) versus alpha, one column per depth."""
    # Clamp the column count so an empty depth list still yields a valid figure size.
    n_cols = max(len(depths), 1)
    fig, axes = plt.subplots(2, n_cols, figsize=(6 * n_cols, 10), squeeze=False)

    rows = [
        ('mean_bp_cosine_mean', 'mean_bp_cosine_std', 'o', 'Mean BP Cosine (Γ)'),
        ('mean_rho_mean', 'mean_rho_std', 's', 'Mean Perturbation ρ'),
    ]
    for j, L in enumerate(depths):
        for row, (mean_key, std_key, marker, ylabel) in enumerate(rows):
            ax = axes[row, j]
            for method in _ALL_METHODS:
                xs, ys, errs = _series(agg, alphas, L, method, mean_key, std_key)
                if ys:
                    ax.errorbar(xs, ys, yerr=errs, marker=marker, color=_COLORS[method],
                                label=_LABELS[method], capsize=3, linewidth=2,
                                markersize=6)
            ax.set_xlabel('α (nonlinearity)', fontsize=11)
            ax.set_ylabel(ylabel, fontsize=11)
            ax.set_title(f'L={L}', fontsize=13)
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)
            if row == 0:
                ax.set_ylim(-0.1, 1.05)
            else:
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.suptitle('Synthetic Ladder: Credit Quality vs Nonlinearity', fontsize=14, y=1.02)
    _save(fig, f'{save_prefix}_gamma_rho_vs_alpha.png')


def _plot_state_vs_credit(agg, alphas, depths, save_prefix):
    """Plot 3: state-bridge prediction error, and state- vs credit-bridge Gamma."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: State prediction error vs alpha
    ax = axes[0]
    for L in depths:
        xs, ys, _ = _series(agg, alphas, L, 'state_bridge', 'state_pred_error_mean')
        if ys:
            ax.plot(xs, ys, 'o-', label=f'L={L}', markersize=6, linewidth=2)
    ax.set_xlabel('α (nonlinearity)', fontsize=11)
    ax.set_ylabel('State Prediction Error', fontsize=11)
    ax.set_title('State Bridge: Terminal Prediction Error', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Right: State bridge Gamma vs credit bridge Gamma
    ax = axes[1]
    for L in depths:
        xs, sb_gammas, cb_gammas = _paired_series(
            agg, alphas, L, 'state_bridge', 'credit_bridge', 'mean_bp_cosine_mean')
        if sb_gammas:
            ax.scatter(sb_gammas, cb_gammas, s=80, label=f'L={L}', zorder=5)
            for a, sx, sy in zip(xs, sb_gammas, cb_gammas):
                ax.annotate(f'α={a}', (sx, sy), textcoords="offset points",
                            xytext=(5, 5), fontsize=8)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x')
    ax.set_xlabel('State Bridge Γ', fontsize=11)
    ax.set_ylabel('Credit Bridge Γ', fontsize=11)
    ax.set_title('State Bridge vs Credit Bridge BP Cosine', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-0.1, 1.0)
    ax.set_ylim(-0.1, 1.0)

    _save(fig, f'{save_prefix}_state_vs_credit.png')


def _plot_method_depth_lines(ax, agg, alphas, depths, methods, mean_key):
    """Draw one line per (method, depth): colour encodes method, style encodes depth."""
    for method in methods:
        for depth_idx, L in enumerate(depths):
            xs, ys, _ = _series(agg, alphas, L, method, mean_key)
            if ys:
                # depths is sorted, so the smallest depth gets the solid line.
                ls = _DEPTH_LINESTYLES[depth_idx % len(_DEPTH_LINESTYLES)]
                ax.plot(xs, ys, f'o{ls}', color=_COLORS[method],
                        label=f'{_LABELS[method]} L={L}', markersize=5)


def _plot_nudge_acc(agg, alphas, depths, save_prefix):
    """Plot 4: nudging result (non-BP methods) and test accuracy (all methods)."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: Nudge vs alpha for each depth
    ax = axes[0]
    _plot_method_depth_lines(ax, agg, alphas, depths,
                             ['dfa', 'state_bridge', 'credit_bridge'],
                             'mean_nudge_01_mean')
    ax.set_xlabel('α (nonlinearity)', fontsize=11)
    ax.set_ylabel('Mean Nudge (η=0.01, negative=good)', fontsize=11)
    ax.set_title('Nudging Test', fontsize=12)
    ax.legend(fontsize=8, ncol=2)
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    # Right: Test accuracy
    ax = axes[1]
    _plot_method_depth_lines(ax, agg, alphas, depths, _ALL_METHODS, 'test_acc_mean')
    ax.set_xlabel('α (nonlinearity)', fontsize=11)
    ax.set_ylabel('Test Accuracy', fontsize=11)
    ax.set_title('Test Accuracy', fontsize=12)
    ax.legend(fontsize=8, ncol=2)
    ax.grid(True, alpha=0.3)

    _save(fig, f'{save_prefix}_nudge_acc.png')


def _plot_per_layer(agg, alphas, depths, save_prefix):
    """Plot 5: per-layer BP cosine and rho for every usable (alpha, L) combo.

    A combo is used only when credit_bridge has a non-NaN mean BP cosine;
    nothing is written when no combo qualifies.
    """
    valid_combos = [
        (a, L) for a in alphas for L in depths
        if (a, L) in agg and not np.isnan(
            agg[(a, L)].get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan))
    ]
    if not valid_combos:
        return

    n_combos = len(valid_combos)
    fig, axes = plt.subplots(2, n_combos, figsize=(5 * n_combos, 8), squeeze=False)

    rows = [
        ('bp_cosine_per_layer_mean', 'o-', 'BP Cosine'),
        ('rho_per_layer_mean', 's-', 'Perturbation ρ'),
    ]
    for idx, (alpha, L) in enumerate(valid_combos):
        for row, (key, fmt, ylabel) in enumerate(rows):
            ax = axes[row, idx]
            for method in _ALL_METHODS:
                pl = agg[(alpha, L)].get(method, {}).get(key)
                if pl is not None:
                    ax.plot(range(len(pl)), pl, fmt, color=_COLORS[method],
                            label=_LABELS[method], markersize=4)
            ax.set_xlabel('Layer')
            ax.set_ylabel(ylabel)
            ax.set_title(f'α={alpha}, L={L}', fontsize=11)
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
            if row == 0:
                ax.set_ylim(-0.3, 1.05)
            else:
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.suptitle('Per-Layer Diagnostics', fontsize=14, y=1.02)
    _save(fig, f'{save_prefix}_per_layer.png')


def _plot_cb_advantage(agg, alphas, depths, save_prefix):
    """Plot 6: credit_bridge minus dfa for Gamma, rho and nudge, per depth."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    metrics = [
        ('mean_bp_cosine_mean', 'Γ(CB) - Γ(DFA)'),
        ('mean_rho_mean', 'ρ(CB) - ρ(DFA)'),
        ('mean_nudge_01_mean', 'nudge(CB) - nudge(DFA)'),
    ]

    for ax, (metric_key, ylabel) in zip(axes, metrics):
        for L in depths:
            xs, cb_vals, dfa_vals = _paired_series(
                agg, alphas, L, 'credit_bridge', 'dfa', metric_key)
            diffs = [cb - dfa for cb, dfa in zip(cb_vals, dfa_vals)]
            if diffs:
                ax.plot(xs, diffs, 'o-', label=f'L={L}', markersize=6, linewidth=2)
        ax.set_xlabel('α (nonlinearity)', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='zero (parity)')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    fig.suptitle('Credit Bridge Advantage over DFA', fontsize=14, y=1.02)
    _save(fig, f'{save_prefix}_cb_advantage.png')


def plot_phase_diagrams(agg, alphas, depths, save_prefix='synth'):
    """Generate the six diagnostic figures and save them as PNGs in output_dir.

    Args:
        agg: output of ``aggregate_seeds``.
        alphas: iterable of alpha values to plot (sorted internally).
        depths: iterable of depths L to plot (sorted internally).
        save_prefix: filename prefix for every PNG written.

    Files written: ``<prefix>_phase_heatmaps.png``,
    ``<prefix>_gamma_rho_vs_alpha.png``, ``<prefix>_state_vs_credit.png``,
    ``<prefix>_nudge_acc.png``, ``<prefix>_per_layer.png`` (only when at least
    one combo is usable) and ``<prefix>_cb_advantage.png``.
    """
    alphas = sorted(alphas)
    depths = sorted(depths)

    _plot_heatmaps(agg, alphas, depths, save_prefix)
    _plot_gamma_rho_vs_alpha(agg, alphas, depths, save_prefix)
    _plot_state_vs_credit(agg, alphas, depths, save_prefix)
    _plot_nudge_acc(agg, alphas, depths, save_prefix)
    _plot_per_layer(agg, alphas, depths, save_prefix)
    _plot_cb_advantage(agg, alphas, depths, save_prefix)


def main():
    """Load summaries from the directories given on the command line and plot them."""
    result_dirs = sys.argv[1:] if len(sys.argv) > 1 else ['results/synth_ladder_smoke']

    print(f"Loading from: {result_dirs}")
    summary = load_summaries(result_dirs)

    if not summary:
        print("No results found!")
        sys.exit(1)

    # Extract alphas and depths from keys
    alphas = set()
    depths = set()
    for key in summary:
        alpha, L, _seed = parse_key(key)
        alphas.add(alpha)
        depths.add(L)

    alphas = sorted(alphas)
    depths = sorted(depths)
    print(f"Alphas: {alphas}")
    print(f"Depths: {depths}")
    print(f"Total configs: {len(summary)}")

    agg = aggregate_seeds(summary)
    plot_phase_diagrams(agg, alphas, depths)
    print("\nAll plots generated!")


if __name__ == '__main__':
    main()
