summaryrefslogtreecommitdiff
path: root/figures/gen_fig1_diagnostic.py
diff options
context:
space:
mode:
Diffstat (limited to 'figures/gen_fig1_diagnostic.py')
-rw-r--r--figures/gen_fig1_diagnostic.py271
1 files changed, 271 insertions, 0 deletions
diff --git a/figures/gen_fig1_diagnostic.py b/figures/gen_fig1_diagnostic.py
new file mode 100644
index 0000000..99ffc15
--- /dev/null
+++ b/figures/gen_fig1_diagnostic.py
@@ -0,0 +1,271 @@
+#!/usr/bin/env python3
+"""Figure 1 (main, 3 panels) + Appendix figure (1 panel) for §2.3.
+
+Main Figure 1 (fig1_bp_bottleneck.{png,pdf}) — three panels:
+ (a) BP hidden weight-gradient collapse: ||dL/dW_l||_F per layer, log scale,
+ L∈{6,10,20}. Zeros clipped at 1e-39 for log-scale visualization.
+ Output-side error is in the (c) summary table, NOT overlaid here.
+ (b) Frozen linear-probe accuracy on H_l with chance line at 1/7. Caveat
+ goes in figure caption (probes are diagnostic, not a training method).
+ (c) Summary table — Depth × {BP acc, hidden underflow count,
+ output error ||dL/dZ_{L-1}||, mid-layer probe acc}.
+
+Appendix figure (fig_app_forward_magnitude.{png,pdf}) — one panel:
+ Raw activation magnitude M_l and centered dispersion D_l per layer.
+ Supports the caption note that the §2.3 claim is about scale-normalized
+ recoverability, not numerical largeness of the forward pass.
+
+20 seeds, GCN, Cora, paper setup, epoch-100 checkpoint.
+Source: results/diag_section23/diag_data_v2.json.
+"""
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+
+DATA_PATH = '/home/yurenh2/graph-grape/results/diag_section23/diag_data_v2.json'
+OUT_PNG = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.png'
+OUT_PDF = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.pdf'
+APP_PNG = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.png'
+APP_PDF = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.pdf'
+CHANCE = 1.0 / 7.0
+UNDERFLOW = 1e-39
+
+DATA = json.load(open(DATA_PATH))
+DEPTHS = [(6, '#5b8def', 'GCN $L\\!=\\!6$'),
+ (10, '#cc6677', 'GCN $L\\!=\\!10$'),
+ (20, '#882255', 'GCN $L\\!=\\!20$')]
+
+plt.rcParams.update({
+ 'font.size': 9, 'axes.labelsize': 9,
+ 'xtick.labelsize': 8, 'ytick.labelsize': 8,
+ 'legend.fontsize': 8,
+ 'pdf.fonttype': 42, 'ps.fonttype': 42,
+})
+
+GRID = '#ECEFF3'
+TEXT = '#2F3437'
+
+
+def panel_weight_grad(ax):
+ for L, color, label in DEPTHS:
+ rows = DATA[f'L={L}']
+ Wg = np.array([r['W_grads_F'] for r in rows])
+ Wg_c = np.where(Wg <= 0, UNDERFLOW, Wg)
+ med = np.median(Wg_c, axis=0)
+ p25 = np.percentile(Wg_c, 25, axis=0)
+ p75 = np.percentile(Wg_c, 75, axis=0)
+ xs = np.arange(L)
+ ax.plot(xs, med, marker='o', markersize=4, color=color,
+ linewidth=1.6, label=label, zorder=3)
+ ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
+ edgecolor='none', zorder=2)
+ ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--', linewidth=0.7)
+ ax.text(0.5, UNDERFLOW * 3, 'recorded as zero (display floor)',
+ fontsize=7, color='#666666', va='bottom')
+ ax.set_yscale('log')
+ ax.set_ylim(UNDERFLOW * 0.5, 5)
+ ax.set_xlabel('Layer index $\\ell$', color=TEXT)
+ ax.set_ylabel('$\\|\\partial \\mathcal{L}/\\partial W_\\ell\\|_F$', color=TEXT)
+ ax.set_title('(a) BP returns zero hidden weight gradients',
+ fontsize=10, color=TEXT, pad=4)
+ ax.grid(axis='both', color=GRID, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.legend(loc='lower right', fontsize=7, frameon=False,
+ handletextpad=0.4, labelspacing=0.3)
+
+
+def panel_linear_probe(ax):
+ for L, color, label in DEPTHS:
+ rows = DATA[f'L={L}']
+ P = np.array([r['probe_acc'] for r in rows])
+ med = np.nanmedian(P, axis=0)
+ p25 = np.nanpercentile(P, 25, axis=0)
+ p75 = np.nanpercentile(P, 75, axis=0)
+ xs = np.arange(P.shape[1])
+ ax.plot(xs, med, marker='o', markersize=4, color=color,
+ linewidth=1.6, label=label, zorder=3)
+ ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
+ edgecolor='none', zorder=2)
+ ax.axhline(y=CHANCE, color='#999999', linestyle='--', linewidth=0.7)
+ ax.text(0.4, CHANCE + 0.015, 'chance ($1/7$)', fontsize=7, color='#666666')
+ ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
+ ax.set_ylabel('Frozen linear-probe accuracy', color=TEXT)
+ ax.set_title('(b) Linear probe on hidden states',
+ fontsize=10, color=TEXT, pad=4)
+ ax.set_ylim(0.05, 0.85)
+ ax.grid(axis='both', color=GRID, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.legend(loc='upper right', fontsize=7, frameon=False,
+ handletextpad=0.4, labelspacing=0.3)
+
+
+def compute_summary_rows():
+ """Return list of (depth, bp_acc_str, underflow_str, out_err_str, probe_str)."""
+ out = []
+ for L, _, _ in DEPTHS:
+ rows = DATA[f'L={L}']
+ Wg = np.array([r['W_grads_F'] for r in rows])
+ n_under = int((Wg <= 0).sum())
+ n_total = Wg.size
+ accs = np.array([r['bp_acc'] for r in rows]) * 100 # percent
+ Zg_out = np.array([r['Z_grads_F'][-1] for r in rows])
+ Zg_med = np.median(Zg_out)
+ P = np.array([r['probe_acc'] for r in rows])
+ if L >= 6:
+ mid_slice = P[:, 1:L]
+ else:
+ mid_slice = P[:, 1:]
+ probe_mid = np.nanmedian(mid_slice)
+ # tight "xx.x ± y.y %" (% in the value since the column header dropped it)
+ bp_str = f'{accs.mean():.1f} ± {accs.std():.1f}%'
+ out.append((
+ f'$L = {L}$',
+ bp_str,
+ f'{n_under}/{n_total}',
+ f'{Zg_med:.1e}',
+ f'{probe_mid:.2f}',
+ ))
+ return out
+
+
+def panel_summary_table(ax):
+ """Hand-render a clean summary table that fills the panel."""
+ ax.set_xlim(0, 1)
+ ax.set_ylim(0, 1)
+ ax.set_xticks([]); ax.set_yticks([])
+ for s in ax.spines.values():
+ s.set_visible(False)
+ ax.set_title('(c) Summary across depth (20 seeds)',
+ fontsize=10, color=TEXT, pad=4)
+
+ rows = compute_summary_rows()
+ headers = ['Depth', 'BP test acc', '$W$-grad zeros',
+ 'out. err.', 'mid-layer\nprobe']
+ n_cols = len(headers)
+ # column boundaries: depth narrow, BP-acc / probe slightly wider,
+ # remainder evenly split. Then x-centers are exact midpoints so every
+ # cell is centred between its dividers.
+ col_edges = [0.13, 0.36, 0.58, 0.78] # 4 inner dividers
+ bounds = [0.0] + col_edges + [1.0] # 6 outer / inner edges
+ col_x = [(bounds[i] + bounds[i + 1]) / 2 for i in range(n_cols)]
+ # Stretch table to fill axes height: header band on top, three rows
+ # filling the rest of the panel down to y=0.
+ header_h = 0.22
+ row_h = 0.26 # 0.78 / 3
+ header_y = 1.0 - header_h / 2 # = 0.89
+ header_top = 1.0
+ header_bot = 1.0 - header_h # = 0.78
+ row_ys = [header_bot - row_h * (i + 0.5) # 0.65 / 0.39 / 0.13
+ for i in range(3)]
+ # Alternating row backgrounds
+ for i, y in enumerate(row_ys):
+ bg = '#F7F8FA' if i % 2 else '#FFFFFF'
+ ax.add_patch(plt.Rectangle((0.0, y - row_h / 2), 1.0, row_h,
+ facecolor=bg, edgecolor='none', zorder=1))
+ # Header band
+ ax.add_patch(plt.Rectangle((0.0, header_bot), 1.0, header_h,
+ facecolor='#EAEDF1', edgecolor='none', zorder=1))
+ # Header text
+ for x, h in zip(col_x, headers):
+ ax.text(x, header_y, h, ha='center', va='center',
+ fontsize=8.5, fontweight='bold', color=TEXT, zorder=3,
+ linespacing=1.0)
+ # Data rows
+ for i, ((depth_str, bp_str, under_str, out_str, probe_str),
+ (_, color, _), y) in enumerate(zip(rows, DEPTHS, row_ys)):
+ ax.text(col_x[0], y, depth_str, ha='center', va='center',
+ fontsize=9, fontweight='bold', color=color, zorder=3)
+ ax.text(col_x[1], y, bp_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ ax.text(col_x[2], y, under_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ ax.text(col_x[3], y, out_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ ax.text(col_x[4], y, probe_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ # Horizontal rules: top, under header, between rows, bottom
+ bottom = row_ys[-1] - row_h / 2
+ for y in (1.0, header_bot, bottom):
+ ax.plot([0, 1], [y, y], color='#C9CDD3', linewidth=0.8, zorder=2)
+ # Vertical separators between columns, full height
+ for x in col_edges:
+ ax.plot([x, x], [bottom, 1.0],
+ color='#C9CDD3', linewidth=0.6, zorder=2)
+ # Outer left/right borders for symmetry
+ for x in (0.0, 1.0):
+ ax.plot([x, x], [bottom, 1.0],
+ color='#C9CDD3', linewidth=0.8, zorder=2)
+ # Pin axes to the table extent so title sits flush like (a)/(b)
+ ax.set_xlim(0, 1)
+ ax.set_ylim(bottom, 1.0)
+
+
+def panel_forward_magnitude(ax):
+ for L, color, label in DEPTHS:
+ rows = DATA[f'L={L}']
+ M = np.array([r['M_rms'] for r in rows])
+ D = np.array([r['D_norm'] for r in rows])
+ M_c = np.where(M <= 0, UNDERFLOW, M)
+ D_c = np.where(D <= 0, UNDERFLOW, D)
+ M_med = np.median(M_c, axis=0)
+ D_med = np.median(D_c, axis=0)
+ xs = np.arange(L + 1)
+ ax.plot(xs, M_med, marker='o', markersize=3.5, color=color,
+ linewidth=1.4, label=f'{label} : $M_\\ell$', zorder=3)
+ ax.plot(xs, D_med, marker='s', markersize=3.5, color=color,
+ linewidth=1.0, linestyle='--', alpha=0.7,
+ label=f'{label} : $D_\\ell$', zorder=3)
+ ax.set_yscale('log')
+ ax.set_ylim(UNDERFLOW * 0.5, 200)
+ ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--', linewidth=0.7)
+ ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
+ ax.set_ylabel('Forward magnitude $M_\\ell$, dispersion $D_\\ell$', color=TEXT)
+ ax.set_title('Raw activation magnitude and centered dispersion',
+ fontsize=10, color=TEXT, pad=4)
+ ax.grid(axis='both', color=GRID, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ color_handles = [Line2D([0], [0], color=c, linewidth=1.6, label=lbl)
+ for _, c, lbl in DEPTHS]
+ mag_handle = Line2D([0], [0], color='gray', linewidth=1.4, marker='o',
+ markersize=3.5, label='$M_\\ell$ (RMS magnitude)')
+ disp_handle = Line2D([0], [0], color='gray', linewidth=1.0, marker='s',
+ markersize=3.5, linestyle='--', alpha=0.7,
+ label='$D_\\ell$ (centered dispersion)')
+ ax.legend(handles=color_handles + [mag_handle, disp_handle],
+ loc='lower right', fontsize=7, frameon=False,
+ handletextpad=0.4, labelspacing=0.3)
+
+
+# Main Figure 1 — 3 panels (weight grad / probe / summary table)
+fig, axes = plt.subplots(1, 3, figsize=(13.5, 3.4),
+ gridspec_kw={'width_ratios': [1.0, 1.0, 1.45]})
+panel_weight_grad(axes[0])
+panel_linear_probe(axes[1])
+panel_summary_table(axes[2])
+fig.tight_layout(w_pad=2.5)
+fig.savefig(OUT_PNG, dpi=300, bbox_inches='tight')
+fig.savefig(OUT_PDF, bbox_inches='tight')
+plt.close(fig)
+print(f'Saved {OUT_PNG} and {OUT_PDF}')
+
+# Appendix figure
+fig, ax = plt.subplots(1, 1, figsize=(5.5, 3.4))
+panel_forward_magnitude(ax)
+fig.tight_layout()
+fig.savefig(APP_PNG, dpi=300, bbox_inches='tight')
+fig.savefig(APP_PDF, bbox_inches='tight')
+plt.close(fig)
+print(f'Saved {APP_PNG} and {APP_PDF}')
+
+print('\nSummary table:')
+for row in compute_summary_rows():
+ print(' ', row)