#!/usr/bin/env python3
"""Figure 1 (main, 3 panels) + Appendix figure (1 panel) for §2.3.

Main Figure 1 (fig1_bp_bottleneck.{png,pdf}) — three panels:
  (a) BP hidden weight-gradient collapse: ||dL/dW_l||_F per layer, log scale,
      L∈{6,10,20}. Zeros clipped at 1e-39 for log-scale visualization.
      Output-side error is in the (c) summary table, NOT overlaid here.
  (b) Frozen linear-probe accuracy on H_l with chance line at 1/7.
      Caveat goes in figure caption (probes are diagnostic, not a training
      method).
  (c) Summary table — Depth × {BP acc, hidden underflow count, output error
      ||dL/dZ_{L-1}||, mid-layer probe acc}.

Appendix figure (fig_app_forward_magnitude.{png,pdf}) — one panel:
  Raw activation magnitude M_l and centered dispersion D_l per layer.
  Supports the caption note that the §2.3 claim is about scale-normalized
  recoverability, not numerical largeness of the forward pass.

20 seeds, GCN, Cora, paper setup, epoch-100 checkpoint.
Source: results/diag_section23/diag_data_v2.json.
"""
import json

import numpy as np
import matplotlib
matplotlib.use('Agg')  # select the backend before pyplot is imported
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

DATA_PATH = '/home/yurenh2/graph-grape/results/diag_section23/diag_data_v2.json'
OUT_PNG = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.png'
OUT_PDF = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.pdf'
APP_PNG = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.png'
APP_PDF = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.pdf'

# Horizontal reference line drawn in panel (b).
CHANCE = 1.0 / 7.0
# Display floor: non-positive values are replaced by this before log plotting.
UNDERFLOW = 1e-39

# Loaded once at import; indexed below as DATA['L=<depth>'] -> list of records.
# The context manager closes the file handle as soon as parsing finishes.
with open(DATA_PATH, encoding='utf-8') as _fh:
    DATA = json.load(_fh)

# (depth, line colour, legend label) for every curve / table row.
DEPTHS = [(6, '#5b8def', 'GCN $L\\!=\\!6$'),
          (10, '#cc6677', 'GCN $L\\!=\\!10$'),
          (20, '#882255', 'GCN $L\\!=\\!20$')]

plt.rcParams.update({
    'font.size': 9,
    'axes.labelsize': 9,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 8,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

GRID = '#ECEFF3'
TEXT = '#2F3437'


def _style_axes(ax):
    """Apply the shared look: light grid behind the data, no top/right spine."""
    ax.grid(axis='both', color=GRID, linewidth=0.6)
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def panel_weight_grad(ax):
    """Draw panel (a): per-layer median of ``W_grads_F`` with an IQR band.

    For each depth in ``DEPTHS`` the ``W_grads_F`` lists of all records are
    stacked, non-positive entries are replaced by ``UNDERFLOW`` so they can
    be shown on a log axis, and the median plus 25th-75th percentile band
    across records is plotted against ``0..L-1``.

    Args:
        ax: Matplotlib axes to draw into (modified in place).
    """
    for L, color, label in DEPTHS:
        rows = DATA[f'L={L}']
        Wg = np.array([r['W_grads_F'] for r in rows])
        # Log scale cannot show values <= 0; pin them to the display floor.
        Wg_c = np.where(Wg <= 0, UNDERFLOW, Wg)
        med = np.median(Wg_c, axis=0)
        p25 = np.percentile(Wg_c, 25, axis=0)
        p75 = np.percentile(Wg_c, 75, axis=0)
        xs = np.arange(L)
        ax.plot(xs, med, marker='o', markersize=4, color=color,
                linewidth=1.6, label=label, zorder=3)
        ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
                        edgecolor='none', zorder=2)

    # Dashed guide just above the floor, with its explanatory label.
    ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--',
               linewidth=0.7)
    ax.text(0.5, UNDERFLOW * 3, 'recorded as zero (display floor)',
            fontsize=7, color='#666666', va='bottom')
    ax.set_yscale('log')
    ax.set_ylim(UNDERFLOW * 0.5, 5)
    ax.set_xlabel('Layer index $\\ell$', color=TEXT)
    ax.set_ylabel('$\\|\\partial \\mathcal{L}/\\partial W_\\ell\\|_F$',
                  color=TEXT)
    ax.set_title('(a) BP returns zero hidden weight gradients',
                 fontsize=10, color=TEXT, pad=4)
    _style_axes(ax)
    ax.legend(loc='lower right', fontsize=7, frameon=False,
              handletextpad=0.4, labelspacing=0.3)


def panel_linear_probe(ax):
    """Draw panel (b): per-layer median of ``probe_acc`` with an IQR band.

    Uses the NaN-ignoring median/percentiles across records, so NaN entries
    in ``probe_acc`` are skipped rather than propagated. A dashed line marks
    ``CHANCE``.

    Args:
        ax: Matplotlib axes to draw into (modified in place).
    """
    for L, color, label in DEPTHS:
        rows = DATA[f'L={L}']
        P = np.array([r['probe_acc'] for r in rows])
        med = np.nanmedian(P, axis=0)
        p25 = np.nanpercentile(P, 25, axis=0)
        p75 = np.nanpercentile(P, 75, axis=0)
        # x extent follows the data rather than L.
        xs = np.arange(P.shape[1])
        ax.plot(xs, med, marker='o', markersize=4, color=color,
                linewidth=1.6, label=label, zorder=3)
        ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
                        edgecolor='none', zorder=2)

    ax.axhline(y=CHANCE, color='#999999', linestyle='--', linewidth=0.7)
    ax.text(0.4, CHANCE + 0.015, 'chance ($1/7$)', fontsize=7,
            color='#666666')
    ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
    ax.set_ylabel('Frozen linear-probe accuracy', color=TEXT)
    ax.set_title('(b) Linear probe on hidden states', fontsize=10,
                 color=TEXT, pad=4)
    ax.set_ylim(0.05, 0.85)
    _style_axes(ax)
    ax.legend(loc='upper right', fontsize=7, frameon=False,
              handletextpad=0.4, labelspacing=0.3)


def compute_summary_rows():
    """Build the formatted cells of the summary table, one row per depth.

    Returns:
        A list with one tuple per entry of ``DEPTHS``:
        ``(depth, bp_acc_str, underflow_str, out_err_str, probe_str)``,
        all already formatted as strings.
    """
    out = []
    for L, _, _ in DEPTHS:
        rows = DATA[f'L={L}']

        # Count of W-gradient entries that are <= 0, over all records/layers.
        Wg = np.array([r['W_grads_F'] for r in rows])
        n_under = int((Wg <= 0).sum())
        n_total = Wg.size

        accs = np.array([r['bp_acc'] for r in rows]) * 100  # percent

        # Median over records of the LAST entry of Z_grads_F.
        Zg_out = np.array([r['Z_grads_F'][-1] for r in rows])
        Zg_med = np.median(Zg_out)

        P = np.array([r['probe_acc'] for r in rows])
        # `1:L` drops column 0 and every column at index >= L. The else
        # branch is not reached with the current DEPTHS (all depths >= 6).
        if L >= 6:
            mid_slice = P[:, 1:L]
        else:
            mid_slice = P[:, 1:]
        probe_mid = np.nanmedian(mid_slice)

        # tight "xx.x ± y.y %" (% in the value since the column header
        # dropped it); std() is numpy's default (ddof=0).
        bp_str = f'{accs.mean():.1f} ± {accs.std():.1f}%'
        out.append((
            f'$L = {L}$',
            bp_str,
            f'{n_under}/{n_total}',
            f'{Zg_med:.1e}',
            f'{probe_mid:.2f}',
        ))
    return out


def panel_summary_table(ax):
    """Hand-render a clean summary table that fills the panel.

    The table is drawn with rectangles, text and rule lines in a unit
    coordinate system. The number of data rows follows
    ``compute_summary_rows()`` instead of being fixed, so the rows always
    fill the space below the header band.

    Args:
        ax: Matplotlib axes to draw into (ticks and spines are hidden).
    """
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_title('(c) Summary across depth (20 seeds)', fontsize=10,
                 color=TEXT, pad=4)

    rows = compute_summary_rows()
    headers = ['Depth', 'BP test acc', '$W$-grad zeros', 'out. err.',
               'mid-layer\nprobe']
    n_cols = len(headers)
    n_rows = len(rows)

    # column boundaries: depth narrow, BP-acc / probe slightly wider,
    # remainder evenly split. Then x-centers are exact midpoints so every
    # cell is centred between its dividers.
    col_edges = [0.13, 0.36, 0.58, 0.78]        # 4 inner dividers
    bounds = [0.0] + col_edges + [1.0]          # 6 outer / inner edges
    col_x = [(bounds[i] + bounds[i + 1]) / 2 for i in range(n_cols)]

    # Stretch table to fill axes height: header band on top, data rows
    # sharing the rest of the panel equally (0.26 each for three rows).
    header_h = 0.22
    header_y = 1.0 - header_h / 2               # = 0.89
    header_bot = 1.0 - header_h                 # = 0.78
    row_h = (1.0 - header_h) / n_rows
    row_ys = [header_bot - row_h * (i + 0.5) for i in range(n_rows)]

    # Alternating row backgrounds
    for i, y in enumerate(row_ys):
        bg = '#F7F8FA' if i % 2 else '#FFFFFF'
        ax.add_patch(plt.Rectangle((0.0, y - row_h / 2), 1.0, row_h,
                                   facecolor=bg, edgecolor='none',
                                   zorder=1))

    # Header band
    ax.add_patch(plt.Rectangle((0.0, header_bot), 1.0, header_h,
                               facecolor='#EAEDF1', edgecolor='none',
                               zorder=1))

    # Header text
    for x, h in zip(col_x, headers):
        ax.text(x, header_y, h, ha='center', va='center', fontsize=8.5,
                fontweight='bold', color=TEXT, zorder=3, linespacing=1.0)

    # Data rows: the depth cell is bold and tinted with the curve colour of
    # that depth; the remaining cells use the neutral text colour.
    for cells, (_, color, _), y in zip(rows, DEPTHS, row_ys):
        ax.text(col_x[0], y, cells[0], ha='center', va='center',
                fontsize=9, fontweight='bold', color=color, zorder=3)
        for x, cell in zip(col_x[1:], cells[1:]):
            ax.text(x, y, cell, ha='center', va='center', fontsize=8.5,
                    color=TEXT, zorder=3)

    # Horizontal rules: top, under header, bottom
    bottom = row_ys[-1] - row_h / 2
    for y in (1.0, header_bot, bottom):
        ax.plot([0, 1], [y, y], color='#C9CDD3', linewidth=0.8, zorder=2)

    # Vertical separators between columns, full height
    for x in col_edges:
        ax.plot([x, x], [bottom, 1.0], color='#C9CDD3', linewidth=0.6,
                zorder=2)

    # Outer left/right borders for symmetry
    for x in (0.0, 1.0):
        ax.plot([x, x], [bottom, 1.0], color='#C9CDD3', linewidth=0.8,
                zorder=2)

    # Pin axes to the table extent so title sits flush like (a)/(b)
    ax.set_xlim(0, 1)
    ax.set_ylim(bottom, 1.0)


def panel_forward_magnitude(ax):
    """Draw the appendix panel: per-layer medians of ``M_rms`` and ``D_norm``.

    Non-positive values are replaced by ``UNDERFLOW`` before taking the
    median across records so the curves can be drawn on a log axis. Each
    depth gets a solid line (``M_rms``) and a dashed line (``D_norm``),
    plotted against ``0..L``.

    Args:
        ax: Matplotlib axes to draw into (modified in place).
    """
    for L, color, label in DEPTHS:
        rows = DATA[f'L={L}']
        M = np.array([r['M_rms'] for r in rows])
        D = np.array([r['D_norm'] for r in rows])
        M_c = np.where(M <= 0, UNDERFLOW, M)
        D_c = np.where(D <= 0, UNDERFLOW, D)
        M_med = np.median(M_c, axis=0)
        D_med = np.median(D_c, axis=0)
        xs = np.arange(L + 1)
        ax.plot(xs, M_med, marker='o', markersize=3.5, color=color,
                linewidth=1.4, label=f'{label} : $M_\\ell$', zorder=3)
        ax.plot(xs, D_med, marker='s', markersize=3.5, color=color,
                linewidth=1.0, linestyle='--', alpha=0.7,
                label=f'{label} : $D_\\ell$', zorder=3)

    ax.set_yscale('log')
    ax.set_ylim(UNDERFLOW * 0.5, 200)
    ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--',
               linewidth=0.7)
    ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
    ax.set_ylabel('Forward magnitude $M_\\ell$, dispersion $D_\\ell$',
                  color=TEXT)
    ax.set_title('Raw activation magnitude and centered dispersion',
                 fontsize=10, color=TEXT, pad=4)
    _style_axes(ax)

    # The legend uses explicit proxy handles (one per depth colour plus one
    # per line style) instead of the per-curve labels set above.
    color_handles = [Line2D([0], [0], color=c, linewidth=1.6, label=lbl)
                     for _, c, lbl in DEPTHS]
    mag_handle = Line2D([0], [0], color='gray', linewidth=1.4, marker='o',
                        markersize=3.5, label='$M_\\ell$ (RMS magnitude)')
    disp_handle = Line2D([0], [0], color='gray', linewidth=1.0, marker='s',
                         markersize=3.5, linestyle='--', alpha=0.7,
                         label='$D_\\ell$ (centered dispersion)')
    ax.legend(handles=color_handles + [mag_handle, disp_handle],
              loc='lower right', fontsize=7, frameon=False,
              handletextpad=0.4, labelspacing=0.3)


# Main Figure 1 — 3 panels (weight grad / probe / summary table)
fig, axes = plt.subplots(1, 3, figsize=(13.5, 3.4),
                         gridspec_kw={'width_ratios': [1.0, 1.0, 1.45]})
panel_weight_grad(axes[0])
panel_linear_probe(axes[1])
panel_summary_table(axes[2])
fig.tight_layout(w_pad=2.5)
fig.savefig(OUT_PNG, dpi=300, bbox_inches='tight')
fig.savefig(OUT_PDF, bbox_inches='tight')
plt.close(fig)
print(f'Saved {OUT_PNG} and {OUT_PDF}')

# Appendix figure
fig, ax = plt.subplots(1, 1, figsize=(5.5, 3.4))
panel_forward_magnitude(ax)
fig.tight_layout()
fig.savefig(APP_PNG, dpi=300, bbox_inches='tight')
fig.savefig(APP_PDF, bbox_inches='tight')
plt.close(fig)
print(f'Saved {APP_PNG} and {APP_PDF}')

print('\nSummary table:')
for row in compute_summary_rows():
    print(' ', row)