diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
| commit | bd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch) | |
| tree | 7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /figures/gen_fig1_diagnostic.py | |
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines
+ multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.
Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'figures/gen_fig1_diagnostic.py')
| -rw-r--r-- | figures/gen_fig1_diagnostic.py | 271 |
1 files changed, 271 insertions, 0 deletions
#!/usr/bin/env python3
"""Figure 1 (main, 3 panels) + Appendix figure (1 panel) for §2.3.

Main Figure 1 (fig1_bp_bottleneck.{png,pdf}) — three panels:
  (a) BP hidden weight-gradient collapse: ||dL/dW_l||_F per layer, log scale,
      L∈{6,10,20}. Zeros clipped at 1e-39 for log-scale visualization.
      Output-side error is in the (c) summary table, NOT overlaid here.
  (b) Frozen linear-probe accuracy on H_l with chance line at 1/7. Caveat
      goes in figure caption (probes are diagnostic, not a training method).
  (c) Summary table — Depth × {BP acc, hidden underflow count,
      output error ||dL/dZ_{L-1}||, mid-layer probe acc}.

Appendix figure (fig_app_forward_magnitude.{png,pdf}) — one panel:
  Raw activation magnitude M_l and centered dispersion D_l per layer.
  Supports the caption note that the §2.3 claim is about scale-normalized
  recoverability, not numerical largeness of the forward pass.

20 seeds, GCN, Cora, paper setup, epoch-100 checkpoint.
Source: results/diag_section23/diag_data_v2.json.

Running this module loads the JSON once, renders both figures, writes them to
the OUT_* / APP_* paths below and prints the summary table to stdout.
"""
import json

import numpy as np
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

DATA_PATH = '/home/yurenh2/graph-grape/results/diag_section23/diag_data_v2.json'
OUT_PNG = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.png'
OUT_PDF = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.pdf'
APP_PNG = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.png'
APP_PDF = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.pdf'
CHANCE = 1.0 / 7.0
# Display floor: non-positive values are replaced by this so they can be drawn
# on a log axis.
UNDERFLOW = 1e-39

# Use a context manager so the file handle is closed deterministically instead
# of being left to the garbage collector.
with open(DATA_PATH, encoding='utf-8') as _fh:
    DATA = json.load(_fh)

# (depth L, line colour, legend label). DATA is indexed by the key f'L={L}'.
DEPTHS = [(6, '#5b8def', 'GCN $L\\!=\\!6$'),
          (10, '#cc6677', 'GCN $L\\!=\\!10$'),
          (20, '#882255', 'GCN $L\\!=\\!20$')]

plt.rcParams.update({
    'font.size': 9, 'axes.labelsize': 9,
    'xtick.labelsize': 8, 'ytick.labelsize': 8,
    'legend.fontsize': 8,
    'pdf.fonttype': 42, 'ps.fonttype': 42,
})

GRID = '#ECEFF3'
TEXT = '#2F3437'


def _clip_nonpositive(values):
    """Return *values* with every entry <= 0 replaced by UNDERFLOW."""
    return np.where(values <= 0, UNDERFLOW, values)


def _style_axes(ax):
    """Apply the shared panel look: light grid behind the data, no top/right spines."""
    ax.grid(axis='both', color=GRID, linewidth=0.6)
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def panel_weight_grad(ax):
    """Draw panel (a): per-layer median of 'W_grads_F' with a 25-75% band.

    For each depth in DEPTHS, the 'W_grads_F' lists of all rows under
    DATA[f'L={L}'] are stacked, non-positive entries are lifted to UNDERFLOW,
    and the median / quartiles along axis 0 are plotted on a log y-axis
    against layer indices 0..L-1.
    """
    for L, color, label in DEPTHS:
        rows = DATA[f'L={L}']
        Wg = np.array([r['W_grads_F'] for r in rows])
        Wg_c = _clip_nonpositive(Wg)
        med = np.median(Wg_c, axis=0)
        p25 = np.percentile(Wg_c, 25, axis=0)
        p75 = np.percentile(Wg_c, 75, axis=0)
        xs = np.arange(L)
        ax.plot(xs, med, marker='o', markersize=4, color=color,
                linewidth=1.6, label=label, zorder=3)
        ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
                        edgecolor='none', zorder=2)
    # Dashed guide just above the display floor, with its annotation.
    ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--', linewidth=0.7)
    ax.text(0.5, UNDERFLOW * 3, 'recorded as zero (display floor)',
            fontsize=7, color='#666666', va='bottom')
    ax.set_yscale('log')
    ax.set_ylim(UNDERFLOW * 0.5, 5)
    ax.set_xlabel('Layer index $\\ell$', color=TEXT)
    ax.set_ylabel('$\\|\\partial \\mathcal{L}/\\partial W_\\ell\\|_F$', color=TEXT)
    ax.set_title('(a) BP returns zero hidden weight gradients',
                 fontsize=10, color=TEXT, pad=4)
    _style_axes(ax)
    ax.legend(loc='lower right', fontsize=7, frameon=False,
              handletextpad=0.4, labelspacing=0.3)


def panel_linear_probe(ax):
    """Draw panel (b): per-layer median of 'probe_acc' with a 25-75% band.

    NaN-aware statistics (nanmedian / nanpercentile) are taken along axis 0
    of the stacked 'probe_acc' lists; the x-axis length comes from the data
    itself. A dashed horizontal line marks CHANCE.
    """
    for L, color, label in DEPTHS:
        rows = DATA[f'L={L}']
        P = np.array([r['probe_acc'] for r in rows])
        med = np.nanmedian(P, axis=0)
        p25 = np.nanpercentile(P, 25, axis=0)
        p75 = np.nanpercentile(P, 75, axis=0)
        xs = np.arange(P.shape[1])
        ax.plot(xs, med, marker='o', markersize=4, color=color,
                linewidth=1.6, label=label, zorder=3)
        ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
                        edgecolor='none', zorder=2)
    ax.axhline(y=CHANCE, color='#999999', linestyle='--', linewidth=0.7)
    ax.text(0.4, CHANCE + 0.015, 'chance ($1/7$)', fontsize=7, color='#666666')
    ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
    ax.set_ylabel('Frozen linear-probe accuracy', color=TEXT)
    ax.set_title('(b) Linear probe on hidden states',
                 fontsize=10, color=TEXT, pad=4)
    ax.set_ylim(0.05, 0.85)
    _style_axes(ax)
    ax.legend(loc='upper right', fontsize=7, frameon=False,
              handletextpad=0.4, labelspacing=0.3)


def compute_summary_rows():
    """Return list of (depth, bp_acc_str, underflow_str, out_err_str, probe_str).

    One tuple of display strings per entry of DEPTHS, in the same order:
      * depth label, e.g. '$L = 6$';
      * mean ± std (numpy default, population std) of 'bp_acc' times 100;
      * count of non-positive 'W_grads_F' entries over the total count;
      * median of the last 'Z_grads_F' entry of each row, scientific notation;
      * NaN-aware median of 'probe_acc' over a column slice (see below).
    """
    out = []
    for L, _, _ in DEPTHS:
        rows = DATA[f'L={L}']
        Wg = np.array([r['W_grads_F'] for r in rows])
        n_under = int((Wg <= 0).sum())
        n_total = Wg.size
        accs = np.array([r['bp_acc'] for r in rows]) * 100  # percent
        Zg_out = np.array([r['Z_grads_F'][-1] for r in rows])
        Zg_med = np.median(Zg_out)
        P = np.array([r['probe_acc'] for r in rows])
        # Column 0 is always skipped. For L >= 6 the slice also stops before
        # column L; shallower depths keep every remaining column.
        mid_slice = P[:, 1:L] if L >= 6 else P[:, 1:]
        probe_mid = np.nanmedian(mid_slice)
        # tight "xx.x ± y.y %" (% in the value since the column header dropped it)
        bp_str = f'{accs.mean():.1f} ± {accs.std():.1f}%'
        out.append((
            f'$L = {L}$',
            bp_str,
            f'{n_under}/{n_total}',
            f'{Zg_med:.1e}',
            f'{probe_mid:.2f}',
        ))
    return out


def panel_summary_table(ax):
    """Hand-render a clean summary table that fills the panel.

    The table is drawn with rectangles, text and rule lines in axes data
    coordinates (x and y both in [0, 1]). The number of body rows follows
    compute_summary_rows(), so adding or removing a depth in DEPTHS keeps the
    table complete instead of silently truncating it.
    """
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_title('(c) Summary across depth (20 seeds)',
                 fontsize=10, color=TEXT, pad=4)

    rows = compute_summary_rows()
    n_rows = len(rows)
    headers = ['Depth', 'BP test acc', '$W$-grad zeros',
               'out. err.', 'mid-layer\nprobe']
    n_cols = len(headers)
    # column boundaries: depth narrow, BP-acc / probe slightly wider,
    # remainder evenly split. Then x-centers are exact midpoints so every
    # cell is centred between its dividers.
    col_edges = [0.13, 0.36, 0.58, 0.78]        # 4 inner dividers
    bounds = [0.0] + col_edges + [1.0]          # 6 outer / inner edges
    col_x = [(bounds[i] + bounds[i + 1]) / 2 for i in range(n_cols)]
    # Stretch table to fill axes height: header band on top, body rows
    # sharing the rest of the panel down to y=0.
    header_h = 0.22
    header_bot = 1.0 - header_h                 # = 0.78
    header_y = 1.0 - header_h / 2               # = 0.89
    row_h = header_bot / max(n_rows, 1)         # 0.26 for three rows
    row_ys = [header_bot - row_h * (i + 0.5)    # 0.65 / 0.39 / 0.13 for three rows
              for i in range(n_rows)]
    bottom = header_bot - row_h * n_rows        # lower edge of the last row
    # Alternating row backgrounds
    for i, y in enumerate(row_ys):
        bg = '#F7F8FA' if i % 2 else '#FFFFFF'
        ax.add_patch(plt.Rectangle((0.0, y - row_h / 2), 1.0, row_h,
                                   facecolor=bg, edgecolor='none', zorder=1))
    # Header band
    ax.add_patch(plt.Rectangle((0.0, header_bot), 1.0, header_h,
                               facecolor='#EAEDF1', edgecolor='none', zorder=1))
    # Header text
    for x, h in zip(col_x, headers):
        ax.text(x, header_y, h, ha='center', va='center',
                fontsize=8.5, fontweight='bold', color=TEXT, zorder=3,
                linespacing=1.0)
    # Data rows: the depth cell is bold and tinted with that depth's line
    # colour; the remaining cells use the regular text colour.
    for cells, (_, color, _), y in zip(rows, DEPTHS, row_ys):
        ax.text(col_x[0], y, cells[0], ha='center', va='center',
                fontsize=9, fontweight='bold', color=color, zorder=3)
        for x, cell in zip(col_x[1:], cells[1:]):
            ax.text(x, y, cell, ha='center', va='center',
                    fontsize=8.5, color=TEXT, zorder=3)
    # Horizontal rules: top, under header, bottom
    for y in (1.0, header_bot, bottom):
        ax.plot([0, 1], [y, y], color='#C9CDD3', linewidth=0.8, zorder=2)
    # Vertical separators between columns, full height
    for x in col_edges:
        ax.plot([x, x], [bottom, 1.0],
                color='#C9CDD3', linewidth=0.6, zorder=2)
    # Outer left/right borders for symmetry
    for x in (0.0, 1.0):
        ax.plot([x, x], [bottom, 1.0],
                color='#C9CDD3', linewidth=0.8, zorder=2)
    # Pin axes to the table extent so title sits flush like (a)/(b)
    ax.set_xlim(0, 1)
    ax.set_ylim(bottom, 1.0)


def panel_forward_magnitude(ax):
    """Draw the appendix panel: per-layer medians of 'M_rms' and 'D_norm'.

    Non-positive entries are lifted to UNDERFLOW before the median so the
    curves can be drawn on a log y-axis. Each depth contributes a solid
    line (M) and a dashed line (D) over x = 0..L. The legend is built from
    proxy handles: one per depth colour plus one per line style.
    """
    for L, color, label in DEPTHS:
        rows = DATA[f'L={L}']
        M = np.array([r['M_rms'] for r in rows])
        D = np.array([r['D_norm'] for r in rows])
        M_med = np.median(_clip_nonpositive(M), axis=0)
        D_med = np.median(_clip_nonpositive(D), axis=0)
        xs = np.arange(L + 1)
        ax.plot(xs, M_med, marker='o', markersize=3.5, color=color,
                linewidth=1.4, label=f'{label} : $M_\\ell$', zorder=3)
        ax.plot(xs, D_med, marker='s', markersize=3.5, color=color,
                linewidth=1.0, linestyle='--', alpha=0.7,
                label=f'{label} : $D_\\ell$', zorder=3)
    ax.set_yscale('log')
    ax.set_ylim(UNDERFLOW * 0.5, 200)
    ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--', linewidth=0.7)
    ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
    ax.set_ylabel('Forward magnitude $M_\\ell$, dispersion $D_\\ell$', color=TEXT)
    ax.set_title('Raw activation magnitude and centered dispersion',
                 fontsize=10, color=TEXT, pad=4)
    _style_axes(ax)
    color_handles = [Line2D([0], [0], color=c, linewidth=1.6, label=lbl)
                     for _, c, lbl in DEPTHS]
    mag_handle = Line2D([0], [0], color='gray', linewidth=1.4, marker='o',
                        markersize=3.5, label='$M_\\ell$ (RMS magnitude)')
    disp_handle = Line2D([0], [0], color='gray', linewidth=1.0, marker='s',
                         markersize=3.5, linestyle='--', alpha=0.7,
                         label='$D_\\ell$ (centered dispersion)')
    ax.legend(handles=color_handles + [mag_handle, disp_handle],
              loc='lower right', fontsize=7, frameon=False,
              handletextpad=0.4, labelspacing=0.3)


# Main Figure 1 — 3 panels (weight grad / probe / summary table)
fig, axes = plt.subplots(1, 3, figsize=(13.5, 3.4),
                         gridspec_kw={'width_ratios': [1.0, 1.0, 1.45]})
panel_weight_grad(axes[0])
panel_linear_probe(axes[1])
panel_summary_table(axes[2])
fig.tight_layout(w_pad=2.5)
fig.savefig(OUT_PNG, dpi=300, bbox_inches='tight')
fig.savefig(OUT_PDF, bbox_inches='tight')
plt.close(fig)
print(f'Saved {OUT_PNG} and {OUT_PDF}')

# Appendix figure
fig, ax = plt.subplots(1, 1, figsize=(5.5, 3.4))
panel_forward_magnitude(ax)
fig.tight_layout()
fig.savefig(APP_PNG, dpi=300, bbox_inches='tight')
fig.savefig(APP_PDF, bbox_inches='tight')
plt.close(fig)
print(f'Saved {APP_PNG} and {APP_PDF}')

print('\nSummary table:')
for row in compute_summary_rows():
    print(' ', row)
