"""Generate phase diagram plots from synthetic nonlinearity ladder results."""
import os
import sys
import json
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from collections import defaultdict

output_dir = 'report_explore'
os.makedirs(output_dir, exist_ok=True)

# Every method recorded per run, in plotting order.
ALL_METHODS = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
# Methods compared against BP in the nudging panel.
NON_BP_METHODS = ['dfa', 'state_bridge', 'credit_bridge']
# Scalar metrics averaged over seeds for every method.
SCALAR_METRICS = ['test_acc', 'mean_bp_cosine', 'mean_rho',
                  'mean_nudge_01', 'mean_nudge_001', 'mean_nudge_003']
# List-valued metrics averaged element-wise over seeds.
PER_LAYER_METRICS = ['bp_cosine_per_layer', 'rho_per_layer', 'nudge_per_layer']
# Values whose magnitude reaches this limit are treated as a blown-up run.
BLOWUP_LIMIT = 1e6

METHOD_COLORS = {'bp': '#F44336', 'dfa': '#2196F3',
                 'state_bridge': '#FF9800', 'credit_bridge': '#4CAF50'}
METHOD_LABELS = {'bp': 'BP', 'dfa': 'DFA',
                 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}


def load_summaries(result_dirs):
    """Load and merge ``summary.json`` files from several result directories.

    Args:
        result_dirs: Iterable of directory paths. Directories without a
            ``summary.json`` file are skipped.

    Returns:
        dict: Union of all loaded summaries. When a key appears in more than
        one directory, the value from the later directory wins.
    """
    merged = {}
    for d in result_dirs:
        path = os.path.join(d, 'summary.json')
        if os.path.exists(path):
            with open(path, encoding='utf-8') as f:
                merged.update(json.load(f))
    return merged


def parse_key(key):
    """Parse 'a0.5_L8_s42' -> (alpha, L, seed).

    Each underscore-separated part is a one-letter prefix followed by the
    value: ``a`` -> float alpha, ``L`` -> int depth, ``s`` -> int seed.
    """
    parts = key.split('_')
    alpha = float(parts[0][1:])
    L = int(parts[1][1:])
    seed = int(parts[2][1:])
    return alpha, L, seed


def _is_valid_value(value):
    """Return True if *value* is numeric, finite and below ``BLOWUP_LIMIT``.

    NaN/inf fail the comparison; ``None`` and non-numeric values fail the
    float conversion. Either way the value is rejected instead of crashing.
    """
    try:
        return abs(float(value)) < BLOWUP_LIMIT
    except (TypeError, ValueError):
        return False


def _valid_scalars(runs, metric):
    """Collect ``run[metric]`` over *runs*, skipping missing or blown-up values.

    Missing metrics are skipped rather than counted as 0, so they do not
    bias the seed average.
    """
    return [float(run[metric]) for run in runs
            if metric in run and _is_valid_value(run[metric])]


def _mean_std(values):
    """Return (mean, population std) of *values*, or (nan, nan) when empty."""
    if not values:
        return np.nan, np.nan
    return np.mean(values), np.std(values)


def aggregate_seeds(summary):
    """Aggregate per-seed results into per-(alpha, L) statistics.

    Args:
        summary: Mapping of run keys (see ``parse_key``) to per-run dicts of
            ``{method: {metric: value, ...}}``.

    Returns:
        dict: ``{(alpha, L): {method: stats}}`` for every method in
        ``ALL_METHODS``. ``stats`` holds ``<metric>_mean`` / ``<metric>_std``
        for each scalar metric (NaN when no seed has a valid value),
        ``state_pred_error_mean``/``_std`` for ``state_bridge``, and
        element-wise ``<per_layer_metric>_mean``/``_std`` lists when at least
        one seed has a fully valid per-layer list.
    """
    groups = defaultdict(list)
    for key, data in summary.items():
        alpha, L, _seed = parse_key(key)
        groups[(alpha, L)].append(data)

    agg = {}
    for (alpha, L), entries in groups.items():
        agg[(alpha, L)] = {}
        for method in ALL_METHODS:
            runs = [e[method] for e in entries if method in e]
            vals = {}
            for metric in SCALAR_METRICS:
                mean, std = _mean_std(_valid_scalars(runs, metric))
                vals[f'{metric}_mean'] = mean
                vals[f'{metric}_std'] = std

            # State prediction error is only recorded for the state bridge.
            if method == 'state_bridge':
                mean, std = _mean_std(_valid_scalars(runs, 'mean_state_pred_error'))
                vals['state_pred_error_mean'] = mean
                vals['state_pred_error_std'] = std

            # Per-layer diagnostics: a seed contributes only if every layer is valid.
            for pl_metric in PER_LAYER_METRICS:
                arrays = []
                for run in runs:
                    layer_vals = run.get(pl_metric)
                    if (isinstance(layer_vals, (list, tuple))
                            and all(_is_valid_value(v) for v in layer_vals)):
                        arrays.append([float(v) for v in layer_vals])
                if arrays:
                    arr2d = np.array(arrays, dtype=float)
                    vals[f'{pl_metric}_mean'] = arr2d.mean(axis=0).tolist()
                    vals[f'{pl_metric}_std'] = arr2d.std(axis=0).tolist()
            agg[(alpha, L)][method] = vals
    return agg


def _collect_series(agg, method, L, alphas, key, err_key=None):
    """Gather one curve ``agg[(alpha, L)][method][key]`` over *alphas*.

    Cells that are absent or NaN are skipped. Returns ``(xs, ys, errs)``;
    ``errs`` holds ``err_key`` values (default 0) or stays empty when
    *err_key* is None.
    """
    xs, ys, errs = [], [], []
    for alpha in alphas:
        cell = agg.get((alpha, L), {}).get(method)
        if cell is None:
            continue
        value = cell.get(key, np.nan)
        if np.isnan(value):
            continue
        xs.append(alpha)
        ys.append(value)
        if err_key is not None:
            errs.append(cell.get(err_key, 0))
    return xs, ys, errs


def _legend_if_any(ax, **kwargs):
    """Draw a legend only when *ax* has labelled artists (avoids a warning)."""
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(**kwargs)


def _nudge_color_limits(grid):
    """Return (vmin, vmax) for the nudge heatmap.

    The colour scale is clipped at 0 because nudges should be <= 0. When
    every value is positive, that clipping would give vmin > vmax, which
    matplotlib rejects, so the data maximum is used instead. An all-NaN
    grid gets the default (-1, 0).
    """
    valid = grid[np.isfinite(grid)]
    if valid.size == 0:
        return -1.0, 0.0
    lo = float(valid.min())
    hi = min(0.0, float(valid.max()))
    if hi < lo:
        hi = float(valid.max())
    return lo, hi


def _save_figure(fig, filename):
    """Save *fig* under ``output_dir`` at 150 dpi, always close it, and report."""
    os.makedirs(output_dir, exist_ok=True)
    try:
        fig.savefig(os.path.join(output_dir, filename), dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"Saved {filename}")


def _plot_heatmaps(agg, alphas, depths, save_prefix):
    """Plot 1: heatmaps of Gamma, rho and nudge over (alpha, L) for each method."""
    metrics_info = [
        ('mean_bp_cosine_mean', 'BP Cosine (Γ)', 'RdYlGn', -0.1, 1.0),
        ('mean_rho_mean', 'Perturbation ρ', 'RdYlGn', -0.2, 1.0),
        ('mean_nudge_01_mean', 'Nudge (η=0.01)', 'RdYlGn_r', None, None),
    ]
    fig, axes = plt.subplots(len(metrics_info), len(ALL_METHODS), figsize=(20, 12),
                             squeeze=False)
    for row, (metric_key, metric_name, cmap, vmin, vmax) in enumerate(metrics_info):
        for col, method in enumerate(ALL_METHODS):
            ax = axes[row, col]
            grid = np.full((len(alphas), len(depths)), np.nan)
            for i, alpha in enumerate(alphas):
                for j, L in enumerate(depths):
                    cell = agg.get((alpha, L), {}).get(method)
                    if cell is not None:
                        grid[i, j] = cell.get(metric_key, np.nan)

            if vmin is None:
                vmin_use, vmax_use = _nudge_color_limits(grid)
            else:
                vmin_use, vmax_use = vmin, vmax
            im = ax.imshow(grid, cmap=cmap, vmin=vmin_use, vmax=vmax_use,
                           aspect='auto', origin='lower')
            ax.set_xticks(range(len(depths)))
            ax.set_xticklabels([str(d) for d in depths])
            ax.set_yticks(range(len(alphas)))
            ax.set_yticklabels([str(a) for a in alphas])

            # Annotate cells; NaN cells (blown up / missing) get a red X.
            for i in range(len(alphas)):
                for j in range(len(depths)):
                    val = grid[i, j]
                    if np.isnan(val):
                        ax.text(j, i, 'X', ha='center', va='center', fontsize=10,
                                color='red', fontweight='bold')
                    else:
                        txt = f'{val:.3f}' if abs(val) < 100 else f'{val:.1e}'
                        ax.text(j, i, txt, ha='center', va='center', fontsize=8,
                                color='black' if abs(val) < 0.5 else 'white')

            if row == 0:
                ax.set_title(METHOD_LABELS[method], fontsize=12)
            if col == 0:
                ax.set_ylabel(f'{metric_name}\nα (nonlinearity)', fontsize=10)
            if row == len(metrics_info) - 1:
                ax.set_xlabel('Depth (L)', fontsize=10)
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.suptitle('Synthetic Ladder: Phase Diagrams (X = blown up)', fontsize=14, y=1.02)
    fig.tight_layout()
    _save_figure(fig, f'{save_prefix}_phase_heatmaps.png')


def _plot_gamma_rho_vs_alpha(agg, alphas, depths, save_prefix):
    """Plot 2: Gamma (top row) and rho (bottom row) vs alpha, one column per depth."""
    ncols = max(len(depths), 1)
    fig, axes = plt.subplots(2, ncols, figsize=(6 * ncols, 10), squeeze=False)
    row_specs = [
        ('mean_bp_cosine', 'o', 'Mean BP Cosine (Γ)'),
        ('mean_rho', 's', 'Mean Perturbation ρ'),
    ]
    for j, L in enumerate(depths):
        for row, (metric, marker, ylabel) in enumerate(row_specs):
            ax = axes[row, j]
            for method in ALL_METHODS:
                xs, ys, errs = _collect_series(agg, method, L, alphas,
                                               f'{metric}_mean', f'{metric}_std')
                if ys:
                    ax.errorbar(xs, ys, yerr=errs, marker=marker,
                                color=METHOD_COLORS[method], label=METHOD_LABELS[method],
                                capsize=3, linewidth=2, markersize=6)
            ax.set_xlabel('α (nonlinearity)', fontsize=11)
            ax.set_ylabel(ylabel, fontsize=11)
            ax.set_title(f'L={L}', fontsize=13)
            _legend_if_any(ax, fontsize=9)
            ax.grid(True, alpha=0.3)
            if row == 0:
                ax.set_ylim(-0.1, 1.05)
            else:
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.suptitle('Synthetic Ladder: Credit Quality vs Nonlinearity', fontsize=14, y=1.02)
    fig.tight_layout()
    _save_figure(fig, f'{save_prefix}_gamma_rho_vs_alpha.png')


def _plot_state_vs_credit(agg, alphas, depths, save_prefix):
    """Plot 3: state-bridge prediction error, and state- vs credit-bridge Gamma."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: state prediction error vs alpha, one curve per depth.
    ax = axes[0]
    for L in depths:
        xs, ys, _errs = _collect_series(agg, 'state_bridge', L, alphas,
                                        'state_pred_error_mean')
        if ys:
            ax.plot(xs, ys, 'o-', label=f'L={L}', markersize=6, linewidth=2)
    ax.set_xlabel('α (nonlinearity)', fontsize=11)
    ax.set_ylabel('State Prediction Error', fontsize=11)
    ax.set_title('State Bridge: Terminal Prediction Error', fontsize=12)
    _legend_if_any(ax)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Right: state-bridge Gamma vs credit-bridge Gamma, annotated with alpha.
    ax = axes[1]
    for L in depths:
        sb_gammas, cb_gammas, valid_alphas = [], [], []
        for alpha in alphas:
            cell = agg.get((alpha, L))
            if cell is None:
                continue
            sb_g = cell.get('state_bridge', {}).get('mean_bp_cosine_mean', np.nan)
            cb_g = cell.get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan)
            if not (np.isnan(sb_g) or np.isnan(cb_g)):
                sb_gammas.append(sb_g)
                cb_gammas.append(cb_g)
                valid_alphas.append(alpha)
        if sb_gammas:
            ax.scatter(sb_gammas, cb_gammas, s=80, label=f'L={L}', zorder=5)
            for a, sx, sy in zip(valid_alphas, sb_gammas, cb_gammas):
                ax.annotate(f'α={a}', (sx, sy), textcoords="offset points",
                            xytext=(5, 5), fontsize=8)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x')
    ax.set_xlabel('State Bridge Γ', fontsize=11)
    ax.set_ylabel('Credit Bridge Γ', fontsize=11)
    ax.set_title('State Bridge vs Credit Bridge BP Cosine', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-0.1, 1.0)
    ax.set_ylim(-0.1, 1.0)

    fig.tight_layout()
    _save_figure(fig, f'{save_prefix}_state_vs_credit.png')


def _plot_nudge_acc(agg, alphas, depths, save_prefix):
    """Plot 4: nudge (left) and test accuracy (right) vs alpha.

    The shallowest depth is drawn solid; all other depths are dashed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    min_depth = min(depths) if depths else None
    panels = [
        (axes[0], NON_BP_METHODS, 'mean_nudge_01_mean',
         'Mean Nudge (η=0.01, negative=good)', 'Nudging Test'),
        (axes[1], ALL_METHODS, 'test_acc_mean', 'Test Accuracy', 'Test Accuracy'),
    ]
    for ax, methods, key, ylabel, title in panels:
        for method in methods:
            for L in depths:
                xs, ys, _errs = _collect_series(agg, method, L, alphas, key)
                if ys:
                    ls = '-' if L == min_depth else '--'
                    ax.plot(xs, ys, f'o{ls}', color=METHOD_COLORS[method],
                            label=f'{METHOD_LABELS[method]} L={L}', markersize=5)
        ax.set_xlabel('α (nonlinearity)', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=12)
        _legend_if_any(ax, fontsize=8, ncol=2)
        ax.grid(True, alpha=0.3)
    # Zero line marks the boundary between helpful (<0) and harmful nudges.
    axes[0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.tight_layout()
    _save_figure(fig, f'{save_prefix}_nudge_acc.png')


def _plot_per_layer(agg, alphas, depths, save_prefix):
    """Plot 5: per-layer BP cosine and rho for each (alpha, L) where CB did not blow up.

    Nothing is written when no such combination exists.
    """
    valid_combos = [
        (a, L) for a in alphas for L in depths
        if (a, L) in agg and not np.isnan(
            agg[(a, L)].get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan))
    ]
    if not valid_combos:
        return

    n_combos = len(valid_combos)
    fig, axes = plt.subplots(2, n_combos, figsize=(5 * n_combos, 8), squeeze=False)
    row_specs = [
        ('bp_cosine_per_layer_mean', 'o-', 'BP Cosine'),
        ('rho_per_layer_mean', 's-', 'Perturbation ρ'),
    ]
    for idx, (alpha, L) in enumerate(valid_combos):
        cell = agg[(alpha, L)]
        for row, (key, fmt, ylabel) in enumerate(row_specs):
            ax = axes[row, idx]
            for method in ALL_METHODS:
                pl = cell.get(method, {}).get(key)
                if pl is not None:
                    ax.plot(range(len(pl)), pl, fmt, color=METHOD_COLORS[method],
                            label=METHOD_LABELS[method], markersize=4)
            ax.set_xlabel('Layer')
            ax.set_ylabel(ylabel)
            ax.set_title(f'α={alpha}, L={L}', fontsize=11)
            _legend_if_any(ax, fontsize=8)
            ax.grid(True, alpha=0.3)
            if row == 0:
                ax.set_ylim(-0.3, 1.05)
            else:
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.suptitle('Per-Layer Diagnostics', fontsize=14, y=1.02)
    fig.tight_layout()
    _save_figure(fig, f'{save_prefix}_per_layer.png')


def _plot_cb_advantage(agg, alphas, depths, save_prefix):
    """Plot 6: credit-bridge minus DFA for Gamma, rho and nudge vs alpha."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    metrics = [
        ('mean_bp_cosine_mean', 'Γ(CB) - Γ(DFA)'),
        ('mean_rho_mean', 'ρ(CB) - ρ(DFA)'),
        ('mean_nudge_01_mean', 'nudge(CB) - nudge(DFA)'),
    ]
    for ax, (metric_key, ylabel) in zip(axes, metrics):
        for L in depths:
            diffs, valid_alphas = [], []
            for alpha in alphas:
                cell = agg.get((alpha, L))
                if cell is None:
                    continue
                cb_v = cell.get('credit_bridge', {}).get(metric_key, np.nan)
                dfa_v = cell.get('dfa', {}).get(metric_key, np.nan)
                if not (np.isnan(cb_v) or np.isnan(dfa_v)):
                    diffs.append(cb_v - dfa_v)
                    valid_alphas.append(alpha)
            if diffs:
                ax.plot(valid_alphas, diffs, 'o-', label=f'L={L}', markersize=6, linewidth=2)
        ax.set_xlabel('α (nonlinearity)', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='zero (parity)')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    fig.suptitle('Credit Bridge Advantage over DFA', fontsize=14, y=1.02)
    fig.tight_layout()
    _save_figure(fig, f'{save_prefix}_cb_advantage.png')


def plot_phase_diagrams(agg, alphas, depths, save_prefix='synth'):
    """Generate all phase-diagram figures into ``output_dir``.

    Writes ``<save_prefix>_phase_heatmaps.png``, ``_gamma_rho_vs_alpha.png``,
    ``_state_vs_credit.png``, ``_nudge_acc.png``, ``_per_layer.png`` (only
    when at least one credit-bridge cell is finite) and ``_cb_advantage.png``.

    Args:
        agg: Output of ``aggregate_seeds``.
        alphas: Iterable of nonlinearity values to plot (sorted internally).
        depths: Iterable of depths to plot (sorted internally).
        save_prefix: Filename prefix for every saved figure.
    """
    alphas = sorted(alphas)
    depths = sorted(depths)
    _plot_heatmaps(agg, alphas, depths, save_prefix)
    _plot_gamma_rho_vs_alpha(agg, alphas, depths, save_prefix)
    _plot_state_vs_credit(agg, alphas, depths, save_prefix)
    _plot_nudge_acc(agg, alphas, depths, save_prefix)
    _plot_per_layer(agg, alphas, depths, save_prefix)
    _plot_cb_advantage(agg, alphas, depths, save_prefix)


def main(argv=None):
    """CLI entry point: load summaries from the given dirs and plot everything.

    Args:
        argv: Result directories; defaults to ``sys.argv[1:]``. When empty,
            ``results/synth_ladder_smoke`` is used. Exits with status 1 when
            no summaries are found.
    """
    if argv is None:
        argv = sys.argv[1:]
    result_dirs = list(argv) if argv else ['results/synth_ladder_smoke']
    print(f"Loading from: {result_dirs}")
    summary = load_summaries(result_dirs)

    if not summary:
        print("No results found!")
        sys.exit(1)

    # Extract alphas and depths from keys
    parsed = [parse_key(key) for key in summary]
    alphas = sorted({alpha for alpha, _L, _seed in parsed})
    depths = sorted({L for _alpha, L, _seed in parsed})
    print(f"Alphas: {alphas}")
    print(f"Depths: {depths}")
    print(f"Total configs: {len(summary)}")

    agg = aggregate_seeds(summary)
    plot_phase_diagrams(agg, alphas, depths)
    print("\nAll plots generated!")


if __name__ == '__main__':
    main()