summaryrefslogtreecommitdiff
path: root/figures
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
commitbd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch)
tree7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /figures
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes: - src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities. - experiments/: 19 runners reproducing every figure / table in the paper. - figures/: 4 generators + the 4 PDFs cited in the report. - paper/: NeurIPS .tex and consolidated experiments_master notes. Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'figures')
-rw-r--r--figures/fig1_bp_bottleneck.pdfbin0 -> 35302 bytes
-rw-r--r--figures/gen_depth_sweep_fig.py165
-rw-r--r--figures/gen_fig1_diagnostic.py271
-rw-r--r--figures/gen_fig4_combined.py191
-rw-r--r--figures/gen_realworld_depth_fig.py93
-rw-r--r--figures/graft_depth_sweep.pdfbin0 -> 27139 bytes
-rw-r--r--figures/kaft_fig4_combined.pdfbin0 -> 37989 bytes
-rw-r--r--figures/kaft_realworld_depth.pdfbin0 -> 28763 bytes
8 files changed, 720 insertions, 0 deletions
diff --git a/figures/fig1_bp_bottleneck.pdf b/figures/fig1_bp_bottleneck.pdf
new file mode 100644
index 0000000..df55379
--- /dev/null
+++ b/figures/fig1_bp_bottleneck.pdf
Binary files differ
diff --git a/figures/gen_depth_sweep_fig.py b/figures/gen_depth_sweep_fig.py
new file mode 100644
index 0000000..9604a6a
--- /dev/null
+++ b/figures/gen_depth_sweep_fig.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+"""H8: Generate Figure 4(a)-style depth sweep plot.
+
+4 panels (Cora/CiteSeer/PubMed/DBLP), 3 curves per panel (BP/DFA-GNN/GRAFT).
+x = number of layers L; y = test accuracy (%) with shaded std band.
+
+Method distinguished by color only (per memory `feedback_viz_shape`:
+shape encodes sweep axis — here L is the x-axis, so same marker for all methods).
+"""
+
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
+
+DATASETS = ['Cora', 'CiteSeer', 'PubMed', 'DBLP']
+METHODS = ['BP', 'DFA-GNN', 'GRAFT']
+# Per-dataset depth grids — DBLP extends to 24, 32 from dblp_depth_scaling.
+# Other datasets cover 2..20. Missing entries (e.g. DFA-GNN at L=2/3, DBLP L=10
+# for BP/GRAFT) will be silently skipped by lookup().
+DEPTHS_DEFAULT = [2, 3, 4, 6, 8, 10, 12, 14, 16, 18, 20]
+DEPTHS_DBLP = [2, 3, 4, 6, 8, 10, 12, 14, 16, 18, 20, 24, 32]
+DEPTHS_BY_DS = {ds: (DEPTHS_DBLP if ds == 'DBLP' else DEPTHS_DEFAULT)
+ for ds in DATASETS}
+
+# All result files we might need to consult
+SOURCES = [
+ 'results/combo_20seeds/per_seed_data.json', # L=6 BP/GRAFT/stacks on Cora/CS/DBLP
+ 'results/hero_extras_20seeds/per_seed_data.json', # L=6 on PubMed + Coauthor
+ 'results/shallow_depth_20seeds/per_seed_data.json', # L=2,3,4 on 4ds
+ 'results/dblp_depth_scaling_20seeds/per_seed_data.json', # DBLP L=8-32
+ 'results/bp_graft_depth_20seeds/per_seed_data.json', # Cora/CS/PubMed L=8-20
+ 'results/dfagnn_depth_20seeds/per_seed_data.json', # DFA-GNN at all depths
+ 'results/dfagnn_resgcn_20seeds/per_seed_data.json', # DFA-GNN L=6 Cora/CS/DBLP
+ 'results/depth_extras_20seeds/per_seed_data.json', # L=14, L=18 × 4ds × 3 methods
+]
+
+# Colors — GRAFT brick red (main method), BP gray, DFA-GNN complementary blue
+COLORS = {
+ 'BP': '#888888', # reference gray
+ 'DFA-GNN': '#3B7AC2', # complementary blue
+ 'GRAFT': '#C23B3B', # brick red (our method)
+}
+
+GRID_COLOR = '#ECEFF3'
+TEXT_COLOR = '#2F3437'
+
+
+def load_all():
+ """Load all sources into a single dict keyed by original keys."""
+ merged = {}
+ for path in SOURCES:
+ try:
+ with open(f'/home/yurenh2/graph-grape/{path}') as f:
+ d = json.load(f)
+ for k, v in d.items():
+ if k not in merged:
+ merged[k] = v
+ else:
+ # Merge seed dicts (take first available if conflict)
+ for sk, sv in v.items():
+ if sk not in merged[k]:
+ merged[k][sk] = sv
+ except FileNotFoundError:
+ pass
+ return merged
+
+
+def lookup(data, ds, L, method):
+ """Return (mean, std) or None if unavailable."""
+ # Try multiple key formats
+ # 1. {ds}_L{L}_{method} (depth-indexed)
+ # 2. {ds}_{method} (for L=6, assumed default in combo/hero files)
+ for key in [f'{ds}_L{L}_{method}', f'{ds}_{method}' if L == 6 else None]:
+ if key and key in data:
+ seeds = data[key]
+ if len(seeds) >= 15: # allow a few missing seeds
+ vals = np.array(list(seeds.values())) * 100
+ return vals.mean(), vals.std()
+ return None
+
+
+def main():
+ data = load_all()
+
+ plt.rcParams.update({
+ 'font.size': 10,
+ 'axes.labelsize': 10,
+ 'xtick.labelsize': 9,
+ 'ytick.labelsize': 9,
+ 'legend.fontsize': 9,
+ 'pdf.fonttype': 42,
+ 'ps.fonttype': 42,
+ })
+
+ fig, axes = plt.subplots(1, 4, figsize=(13.0, 3.3), sharey=False)
+
+ legend_handles = {}
+
+ for ax, ds in zip(axes, DATASETS):
+ depths = DEPTHS_BY_DS[ds]
+ for method in METHODS:
+ xs, means, stds = [], [], []
+ for L in depths:
+ r = lookup(data, ds, L, method)
+ if r is not None:
+ xs.append(L)
+ means.append(r[0])
+ stds.append(r[1])
+ if not xs:
+ continue
+ xs = np.array(xs); means = np.array(means); stds = np.array(stds)
+ color = COLORS[method]
+ line, = ax.plot(xs, means, marker='o', markersize=5,
+ color=color, linewidth=1.6,
+ markerfacecolor=to_rgba(color, alpha=0.35),
+ markeredgecolor=color, markeredgewidth=0.8,
+ zorder=3)
+ ax.fill_between(xs, means - stds, means + stds,
+ color=color, alpha=0.12, edgecolor='none', zorder=2)
+ if method not in legend_handles:
+ legend_handles[method] = line
+
+ ax.set_title(ds, fontsize=10, color=TEXT_COLOR, pad=6)
+ ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT_COLOR)
+ ax.grid(axis='both', color=GRID_COLOR, linewidth=0.7)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_color('#C9CDD3')
+ ax.spines['bottom'].set_color('#C9CDD3')
+ ax.tick_params(colors=TEXT_COLOR)
+ # Show every other tick for readability when grid is dense
+ ticks = depths if len(depths) <= 8 else depths[::2]
+ ax.set_xticks(ticks)
+
+ axes[0].set_ylabel('Test accuracy (%)', fontsize=10, color=TEXT_COLOR)
+
+ handles = [legend_handles[m] for m in METHODS if m in legend_handles]
+ labels = [m for m in METHODS if m in legend_handles]
+ fig.tight_layout(rect=(0.0, 0.06, 1.0, 1.0), w_pad=1.5)
+ fig.legend(handles, labels,
+ frameon=False, loc='lower center',
+ ncol=len(labels), bbox_to_anchor=(0.5, -0.005),
+ handletextpad=0.6, columnspacing=1.8)
+ fig.savefig('/home/yurenh2/graph-grape/graft_depth_sweep.png', dpi=300, bbox_inches='tight')
+ fig.savefig('/home/yurenh2/graph-grape/graft_depth_sweep.pdf', bbox_inches='tight')
+ plt.close(fig)
+ print('Saved /home/yurenh2/graph-grape/graft_depth_sweep.{png,pdf}')
+
+ # Data dump
+ print('\nData (mean ± std):')
+ for ds in DATASETS:
+ print(f'\n{ds}:')
+ depths = DEPTHS_BY_DS[ds]
+ for method in METHODS:
+ row = [f'{method:<9}']
+ for L in depths:
+ r = lookup(data, ds, L, method)
+ row.append(f'L{L}: {r[0]:5.1f}±{r[1]:4.1f}' if r else f'L{L}: {"—":>10}')
+ print(' ' + ' '.join(row))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/figures/gen_fig1_diagnostic.py b/figures/gen_fig1_diagnostic.py
new file mode 100644
index 0000000..99ffc15
--- /dev/null
+++ b/figures/gen_fig1_diagnostic.py
@@ -0,0 +1,271 @@
+#!/usr/bin/env python3
+"""Figure 1 (main, 3 panels) + Appendix figure (1 panel) for §2.3.
+
+Main Figure 1 (fig1_bp_bottleneck.{png,pdf}) — three panels:
+ (a) BP hidden weight-gradient collapse: ||dL/dW_l||_F per layer, log scale,
+ L∈{6,10,20}. Zeros clipped at 1e-39 for log-scale visualization.
+ Output-side error is in the (c) summary table, NOT overlaid here.
+ (b) Frozen linear-probe accuracy on H_l with chance line at 1/7. Caveat
+ goes in figure caption (probes are diagnostic, not a training method).
+ (c) Summary table — Depth × {BP acc, hidden underflow count,
+ output error ||dL/dZ_{L-1}||, mid-layer probe acc}.
+
+Appendix figure (fig_app_forward_magnitude.{png,pdf}) — one panel:
+ Raw activation magnitude M_l and centered dispersion D_l per layer.
+ Supports the caption note that the §2.3 claim is about scale-normalized
+ recoverability, not numerical largeness of the forward pass.
+
+20 seeds, GCN, Cora, paper setup, epoch-100 checkpoint.
+Source: results/diag_section23/diag_data_v2.json.
+"""
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+
+DATA_PATH = '/home/yurenh2/graph-grape/results/diag_section23/diag_data_v2.json'
+OUT_PNG = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.png'
+OUT_PDF = '/home/yurenh2/graph-grape/fig1_bp_bottleneck.pdf'
+APP_PNG = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.png'
+APP_PDF = '/home/yurenh2/graph-grape/fig_app_forward_magnitude.pdf'
+CHANCE = 1.0 / 7.0
+UNDERFLOW = 1e-39
+
+DATA = json.load(open(DATA_PATH))
+DEPTHS = [(6, '#5b8def', 'GCN $L\\!=\\!6$'),
+ (10, '#cc6677', 'GCN $L\\!=\\!10$'),
+ (20, '#882255', 'GCN $L\\!=\\!20$')]
+
+plt.rcParams.update({
+ 'font.size': 9, 'axes.labelsize': 9,
+ 'xtick.labelsize': 8, 'ytick.labelsize': 8,
+ 'legend.fontsize': 8,
+ 'pdf.fonttype': 42, 'ps.fonttype': 42,
+})
+
+GRID = '#ECEFF3'
+TEXT = '#2F3437'
+
+
+def panel_weight_grad(ax):
+ for L, color, label in DEPTHS:
+ rows = DATA[f'L={L}']
+ Wg = np.array([r['W_grads_F'] for r in rows])
+ Wg_c = np.where(Wg <= 0, UNDERFLOW, Wg)
+ med = np.median(Wg_c, axis=0)
+ p25 = np.percentile(Wg_c, 25, axis=0)
+ p75 = np.percentile(Wg_c, 75, axis=0)
+ xs = np.arange(L)
+ ax.plot(xs, med, marker='o', markersize=4, color=color,
+ linewidth=1.6, label=label, zorder=3)
+ ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
+ edgecolor='none', zorder=2)
+ ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--', linewidth=0.7)
+ ax.text(0.5, UNDERFLOW * 3, 'recorded as zero (display floor)',
+ fontsize=7, color='#666666', va='bottom')
+ ax.set_yscale('log')
+ ax.set_ylim(UNDERFLOW * 0.5, 5)
+ ax.set_xlabel('Layer index $\\ell$', color=TEXT)
+ ax.set_ylabel('$\\|\\partial \\mathcal{L}/\\partial W_\\ell\\|_F$', color=TEXT)
+ ax.set_title('(a) BP returns zero hidden weight gradients',
+ fontsize=10, color=TEXT, pad=4)
+ ax.grid(axis='both', color=GRID, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.legend(loc='lower right', fontsize=7, frameon=False,
+ handletextpad=0.4, labelspacing=0.3)
+
+
+def panel_linear_probe(ax):
+ for L, color, label in DEPTHS:
+ rows = DATA[f'L={L}']
+ P = np.array([r['probe_acc'] for r in rows])
+ med = np.nanmedian(P, axis=0)
+ p25 = np.nanpercentile(P, 25, axis=0)
+ p75 = np.nanpercentile(P, 75, axis=0)
+ xs = np.arange(P.shape[1])
+ ax.plot(xs, med, marker='o', markersize=4, color=color,
+ linewidth=1.6, label=label, zorder=3)
+ ax.fill_between(xs, p25, p75, color=color, alpha=0.15,
+ edgecolor='none', zorder=2)
+ ax.axhline(y=CHANCE, color='#999999', linestyle='--', linewidth=0.7)
+ ax.text(0.4, CHANCE + 0.015, 'chance ($1/7$)', fontsize=7, color='#666666')
+ ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
+ ax.set_ylabel('Frozen linear-probe accuracy', color=TEXT)
+ ax.set_title('(b) Linear probe on hidden states',
+ fontsize=10, color=TEXT, pad=4)
+ ax.set_ylim(0.05, 0.85)
+ ax.grid(axis='both', color=GRID, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.legend(loc='upper right', fontsize=7, frameon=False,
+ handletextpad=0.4, labelspacing=0.3)
+
+
+def compute_summary_rows():
+ """Return list of (depth, bp_acc_str, underflow_str, out_err_str, probe_str)."""
+ out = []
+ for L, _, _ in DEPTHS:
+ rows = DATA[f'L={L}']
+ Wg = np.array([r['W_grads_F'] for r in rows])
+ n_under = int((Wg <= 0).sum())
+ n_total = Wg.size
+ accs = np.array([r['bp_acc'] for r in rows]) * 100 # percent
+ Zg_out = np.array([r['Z_grads_F'][-1] for r in rows])
+ Zg_med = np.median(Zg_out)
+ P = np.array([r['probe_acc'] for r in rows])
+ if L >= 6:
+ mid_slice = P[:, 1:L]
+ else:
+ mid_slice = P[:, 1:]
+ probe_mid = np.nanmedian(mid_slice)
+ # tight "xx.x ± y.y %" (% in the value since the column header dropped it)
+ bp_str = f'{accs.mean():.1f} ± {accs.std():.1f}%'
+ out.append((
+ f'$L = {L}$',
+ bp_str,
+ f'{n_under}/{n_total}',
+ f'{Zg_med:.1e}',
+ f'{probe_mid:.2f}',
+ ))
+ return out
+
+
+def panel_summary_table(ax):
+ """Hand-render a clean summary table that fills the panel."""
+ ax.set_xlim(0, 1)
+ ax.set_ylim(0, 1)
+ ax.set_xticks([]); ax.set_yticks([])
+ for s in ax.spines.values():
+ s.set_visible(False)
+ ax.set_title('(c) Summary across depth (20 seeds)',
+ fontsize=10, color=TEXT, pad=4)
+
+ rows = compute_summary_rows()
+ headers = ['Depth', 'BP test acc', '$W$-grad zeros',
+ 'out. err.', 'mid-layer\nprobe']
+ n_cols = len(headers)
+ # column boundaries: depth narrow, BP-acc / probe slightly wider,
+ # remainder evenly split. Then x-centers are exact midpoints so every
+ # cell is centred between its dividers.
+ col_edges = [0.13, 0.36, 0.58, 0.78] # 4 inner dividers
+ bounds = [0.0] + col_edges + [1.0] # 6 outer / inner edges
+ col_x = [(bounds[i] + bounds[i + 1]) / 2 for i in range(n_cols)]
+ # Stretch table to fill axes height: header band on top, three rows
+ # filling the rest of the panel down to y=0.
+ header_h = 0.22
+ row_h = 0.26 # 0.78 / 3
+ header_y = 1.0 - header_h / 2 # = 0.89
+ header_top = 1.0
+ header_bot = 1.0 - header_h # = 0.78
+ row_ys = [header_bot - row_h * (i + 0.5) # 0.65 / 0.39 / 0.13
+ for i in range(3)]
+ # Alternating row backgrounds
+ for i, y in enumerate(row_ys):
+ bg = '#F7F8FA' if i % 2 else '#FFFFFF'
+ ax.add_patch(plt.Rectangle((0.0, y - row_h / 2), 1.0, row_h,
+ facecolor=bg, edgecolor='none', zorder=1))
+ # Header band
+ ax.add_patch(plt.Rectangle((0.0, header_bot), 1.0, header_h,
+ facecolor='#EAEDF1', edgecolor='none', zorder=1))
+ # Header text
+ for x, h in zip(col_x, headers):
+ ax.text(x, header_y, h, ha='center', va='center',
+ fontsize=8.5, fontweight='bold', color=TEXT, zorder=3,
+ linespacing=1.0)
+ # Data rows
+ for i, ((depth_str, bp_str, under_str, out_str, probe_str),
+ (_, color, _), y) in enumerate(zip(rows, DEPTHS, row_ys)):
+ ax.text(col_x[0], y, depth_str, ha='center', va='center',
+ fontsize=9, fontweight='bold', color=color, zorder=3)
+ ax.text(col_x[1], y, bp_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ ax.text(col_x[2], y, under_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ ax.text(col_x[3], y, out_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ ax.text(col_x[4], y, probe_str, ha='center', va='center',
+ fontsize=8.5, color=TEXT, zorder=3)
+ # Horizontal rules: top, under header, between rows, bottom
+ bottom = row_ys[-1] - row_h / 2
+ for y in (1.0, header_bot, bottom):
+ ax.plot([0, 1], [y, y], color='#C9CDD3', linewidth=0.8, zorder=2)
+ # Vertical separators between columns, full height
+ for x in col_edges:
+ ax.plot([x, x], [bottom, 1.0],
+ color='#C9CDD3', linewidth=0.6, zorder=2)
+ # Outer left/right borders for symmetry
+ for x in (0.0, 1.0):
+ ax.plot([x, x], [bottom, 1.0],
+ color='#C9CDD3', linewidth=0.8, zorder=2)
+ # Pin axes to the table extent so title sits flush like (a)/(b)
+ ax.set_xlim(0, 1)
+ ax.set_ylim(bottom, 1.0)
+
+
+def panel_forward_magnitude(ax):
+ for L, color, label in DEPTHS:
+ rows = DATA[f'L={L}']
+ M = np.array([r['M_rms'] for r in rows])
+ D = np.array([r['D_norm'] for r in rows])
+ M_c = np.where(M <= 0, UNDERFLOW, M)
+ D_c = np.where(D <= 0, UNDERFLOW, D)
+ M_med = np.median(M_c, axis=0)
+ D_med = np.median(D_c, axis=0)
+ xs = np.arange(L + 1)
+ ax.plot(xs, M_med, marker='o', markersize=3.5, color=color,
+ linewidth=1.4, label=f'{label} : $M_\\ell$', zorder=3)
+ ax.plot(xs, D_med, marker='s', markersize=3.5, color=color,
+ linewidth=1.0, linestyle='--', alpha=0.7,
+ label=f'{label} : $D_\\ell$', zorder=3)
+ ax.set_yscale('log')
+ ax.set_ylim(UNDERFLOW * 0.5, 200)
+ ax.axhline(y=UNDERFLOW * 1.5, color='#999999', linestyle='--', linewidth=0.7)
+ ax.set_xlabel('Layer index $\\ell$ (post-act $H_\\ell$)', color=TEXT)
+ ax.set_ylabel('Forward magnitude $M_\\ell$, dispersion $D_\\ell$', color=TEXT)
+ ax.set_title('Raw activation magnitude and centered dispersion',
+ fontsize=10, color=TEXT, pad=4)
+ ax.grid(axis='both', color=GRID, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ color_handles = [Line2D([0], [0], color=c, linewidth=1.6, label=lbl)
+ for _, c, lbl in DEPTHS]
+ mag_handle = Line2D([0], [0], color='gray', linewidth=1.4, marker='o',
+ markersize=3.5, label='$M_\\ell$ (RMS magnitude)')
+ disp_handle = Line2D([0], [0], color='gray', linewidth=1.0, marker='s',
+ markersize=3.5, linestyle='--', alpha=0.7,
+ label='$D_\\ell$ (centered dispersion)')
+ ax.legend(handles=color_handles + [mag_handle, disp_handle],
+ loc='lower right', fontsize=7, frameon=False,
+ handletextpad=0.4, labelspacing=0.3)
+
+
+# Main Figure 1 — 3 panels (weight grad / probe / summary table)
+fig, axes = plt.subplots(1, 3, figsize=(13.5, 3.4),
+ gridspec_kw={'width_ratios': [1.0, 1.0, 1.45]})
+panel_weight_grad(axes[0])
+panel_linear_probe(axes[1])
+panel_summary_table(axes[2])
+fig.tight_layout(w_pad=2.5)
+fig.savefig(OUT_PNG, dpi=300, bbox_inches='tight')
+fig.savefig(OUT_PDF, bbox_inches='tight')
+plt.close(fig)
+print(f'Saved {OUT_PNG} and {OUT_PDF}')
+
+# Appendix figure
+fig, ax = plt.subplots(1, 1, figsize=(5.5, 3.4))
+panel_forward_magnitude(ax)
+fig.tight_layout()
+fig.savefig(APP_PNG, dpi=300, bbox_inches='tight')
+fig.savefig(APP_PDF, bbox_inches='tight')
+plt.close(fig)
+print(f'Saved {APP_PNG} and {APP_PDF}')
+
+print('\nSummary table:')
+for row in compute_summary_rows():
+ print(' ', row)
diff --git a/figures/gen_fig4_combined.py b/figures/gen_fig4_combined.py
new file mode 100644
index 0000000..5b8d464
--- /dev/null
+++ b/figures/gen_fig4_combined.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python3
+"""Figure 4-style combined plot: 4 panels (depth / add / remove / flip).
+
+Each panel: 9 curves = 3 datasets × 3 methods.
+ color = dataset (Cora / CiteSeer / PubMed)
+ linestyle = method (BP dashed, DFA-GNN dotted, GRAFT solid)
+
+Matches DFA-GNN Figure 4 layout.
+"""
+
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
+from matplotlib.lines import Line2D
+
+DATASETS = ['Cora', 'CiteSeer', 'PubMed']
+METHODS = ['BP', 'DFA-GNN', 'GRAFT'] # data-lookup keys (unchanged)
+DISPLAY_NAME = {'BP': 'BP', 'DFA-GNN': 'DFA-GNN', 'GRAFT': 'KAFT'}
+
+DEPTHS = [4, 6, 8, 10, 12, 14, 16, 18, 20]
+RATES = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
+ATTACKS = ['add', 'remove', 'flip']
+
+# Method colors — consistent with other GRAFT figures
+METHOD_COLORS = {
+ 'BP': '#888888', # gray
+ 'DFA-GNN': '#3B7AC2', # complementary blue
+ 'GRAFT': '#C23B3B', # brick red (ours)
+}
+# Dataset linestyles
+DS_STYLE = {
+ 'Cora': (0, ()), # solid
+ 'CiteSeer': (0, (5, 2)), # dashed
+ 'PubMed': (0, (1, 1.5)), # dotted
+}
+DS_MARKER = {
+ 'Cora': 'o',
+ 'CiteSeer': 's',
+ 'PubMed': '^',
+}
+
+GRID_COLOR = '#ECEFF3'
+TEXT_COLOR = '#2F3437'
+
+# --- depth data sources (depth_sweep reuses gen_depth_sweep_fig loaders) -----
+DEPTH_SOURCES = [
+ 'results/combo_20seeds/per_seed_data.json',
+ 'results/hero_extras_20seeds/per_seed_data.json',
+ 'results/shallow_depth_20seeds/per_seed_data.json',
+ 'results/bp_graft_depth_20seeds/per_seed_data.json',
+ 'results/dfagnn_depth_20seeds/per_seed_data.json',
+ 'results/dfagnn_resgcn_20seeds/per_seed_data.json',
+ 'results/depth_extras_20seeds/per_seed_data.json', # L=14, 18
+]
+PERTURB_SOURCE = 'results/perturb_sweep_20seeds/per_seed_data.json'
+
+
+def load_depth():
+ merged = {}
+ for path in DEPTH_SOURCES:
+ try:
+ with open(f'/home/yurenh2/graph-grape/{path}') as f:
+ d = json.load(f)
+ for k, v in d.items():
+ if k not in merged:
+ merged[k] = v
+ else:
+ for sk, sv in v.items():
+ if sk not in merged[k]:
+ merged[k][sk] = sv
+ except FileNotFoundError:
+ pass
+ return merged
+
+
+def depth_lookup(data, ds, L, method):
+ for key in [f'{ds}_L{L}_{method}', f'{ds}_{method}' if L == 6 else None]:
+ if key and key in data and len(data[key]) >= 15:
+ vals = np.array(list(data[key].values())) * 100
+ return vals.mean(), vals.std()
+ return None
+
+
+def perturb_lookup(data, ds, attack, rate, method):
+ key = f'{ds}_{attack}_r{rate}_{method}'
+ if key in data and len(data[key]) >= 15:
+ vals = np.array(list(data[key].values())) * 100
+ return vals.mean(), vals.std()
+ return None
+
+
+def plot_panel(ax, panel_type, data, title):
+ """panel_type: 'depth' or attack name."""
+ xs = DEPTHS if panel_type == 'depth' else RATES
+ for ds in DATASETS:
+ for method in METHODS:
+ means = []
+ stds = []
+ xs_used = []
+ for x in xs:
+ if panel_type == 'depth':
+ r = depth_lookup(data, ds, x, method)
+ else:
+ r = perturb_lookup(data, ds, panel_type, x, method)
+ if r is not None:
+ xs_used.append(x)
+ means.append(r[0])
+ stds.append(r[1])
+ if not means:
+ continue
+ color = METHOD_COLORS[method]
+ style = DS_STYLE[ds]
+ marker = DS_MARKER[ds]
+ ax.plot(xs_used, means, color=color, linestyle=style, marker=marker,
+ markersize=4.5, linewidth=1.3,
+ markerfacecolor=to_rgba(color, alpha=0.35),
+ markeredgecolor=color, markeredgewidth=0.7,
+ zorder=3)
+ # Shaded band (light)
+ means = np.array(means); stds = np.array(stds)
+ ax.fill_between(xs_used, means - stds, means + stds,
+ color=color, alpha=0.06, edgecolor='none', zorder=1)
+
+ ax.set_title(title, fontsize=10, color=TEXT_COLOR, pad=5)
+ ax.grid(axis='both', color=GRID_COLOR, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_color('#C9CDD3')
+ ax.spines['bottom'].set_color('#C9CDD3')
+ ax.tick_params(colors=TEXT_COLOR)
+ if panel_type == 'depth':
+ ax.set_xticks(DEPTHS)
+ ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT_COLOR)
+ else:
+ ax.set_xticks(RATES)
+ ax.set_xlabel('Perturbation rate $\\lambda$', fontsize=9, color=TEXT_COLOR)
+
+
+def main():
+ depth_data = load_depth()
+ with open(f'/home/yurenh2/graph-grape/{PERTURB_SOURCE}') as f:
+ perturb_data = json.load(f)
+
+ plt.rcParams.update({
+ 'font.size': 9,
+ 'axes.labelsize': 9,
+ 'xtick.labelsize': 8,
+ 'ytick.labelsize': 8,
+ 'legend.fontsize': 8,
+ 'pdf.fonttype': 42,
+ 'ps.fonttype': 42,
+ })
+
+ fig, axes = plt.subplots(1, 4, figsize=(13.5, 3.2))
+
+ plot_panel(axes[0], 'depth', depth_data, '(a) Depth')
+ plot_panel(axes[1], 'add', perturb_data, '(b) Add')
+ plot_panel(axes[2], 'remove', perturb_data, '(c) Remove')
+ plot_panel(axes[3], 'flip', perturb_data, '(d) Flip')
+
+ axes[0].set_ylabel('Test accuracy (%)', fontsize=9, color=TEXT_COLOR)
+
+ # Dual-legend: colors (methods) + linestyles (datasets)
+ method_handles = [Line2D([0], [0], color=METHOD_COLORS[m], linewidth=2.5,
+ label=DISPLAY_NAME[m])
+ for m in METHODS]
+ ds_handles = [Line2D([0], [0], color='#444', linestyle=DS_STYLE[ds],
+ marker=DS_MARKER[ds], markersize=4.5,
+ linewidth=1.5, label=ds)
+ for ds in DATASETS]
+
+ fig.tight_layout(rect=(0.0, 0.09, 1.0, 1.0), w_pad=1.3)
+ fig.legend(handles=method_handles, loc='lower left', bbox_to_anchor=(0.08, -0.01),
+ frameon=False, ncol=3, handletextpad=0.5, columnspacing=1.5,
+ title='Method', title_fontsize=9)
+ fig.legend(handles=ds_handles, loc='lower right', bbox_to_anchor=(0.92, -0.01),
+ frameon=False, ncol=3, handletextpad=0.5, columnspacing=1.5,
+ title='Dataset', title_fontsize=9)
+
+ fig.savefig('/home/yurenh2/graph-grape/kaft_fig4_combined.png',
+ dpi=300, bbox_inches='tight')
+ fig.savefig('/home/yurenh2/graph-grape/kaft_fig4_combined.pdf',
+ bbox_inches='tight')
+ plt.close(fig)
+ print('Saved /home/yurenh2/graph-grape/kaft_fig4_combined.{png,pdf}')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/figures/gen_realworld_depth_fig.py b/figures/gen_realworld_depth_fig.py
new file mode 100644
index 0000000..1ff7e2d
--- /dev/null
+++ b/figures/gen_realworld_depth_fig.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+"""Real-world dataset depth-sweep figure (Fig 4(a)' style).
+4 panels: CFull-CiteSeer, CFull-DBLP, CFull-PubMed (biomed), Coauthor-Physics.
+Data hardcoded from cfull_paper_setup.log + dblpfull_full_depth.log +
+pubmedfull_full_depth.log + physics_full_depth.log + dblp_paper_setup.log + cs_paper_setup.log."""
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
+
+# Aggregated paper-setup data: (mean, std) for BP and GRAFT at each depth
+DATA = {
+ 'CFull-CiteSeer': {
+ 'depths': [3, 5, 8, 10, 12, 14, 16, 18, 20],
+ 'BP': [(0.870, 0.0072), (0.860, 0.0056), (0.825, 0.0208), (0.549, 0.1164), (0.365, 0.0209), (0.297, 0.0421), (0.230, 0.0209), (0.238, 0.0131), (0.209, 0.0319)],
+ 'DFA': [(0.855, 0.0044), (0.834, 0.0106), (0.566, 0.0289), (0.425, 0.0993), (0.329, 0.1060), (0.368, 0.0604), (0.297, 0.0722), (0.243, 0.0661), (0.244, 0.0667)],
+ 'DFA-GNN': [(0.858, 0.0038), (0.826, 0.0187), (0.581, 0.1085), (0.465, 0.0698), (0.289, 0.0677), (0.296, 0.1372), (0.244, 0.0673), (0.211, 0.0204), (0.193, 0.0051)],
+ 'GRAFT': [(0.857, 0.0006), (0.846, 0.0019), (0.829, 0.0021), (0.780, 0.0197), (0.667, 0.0630), (0.487, 0.0621), (0.430, 0.1145), (0.369, 0.0089), (0.380, 0.0258)],
+ },
+ 'CFull-DBLP': {
+ 'depths': [3, 5, 8, 10, 12, 14, 16, 18, 20],
+ 'BP': [(0.826, 0.0027), (0.814, 0.0006), (0.793, 0.0070), (0.710, 0.1180), (0.652, 0.0728), (0.559, 0.1132), (0.454, 0.0065), (0.469, 0.0077), (0.461, 0.0144)],
+ 'DFA': [(0.829, 0.0031), (0.819, 0.0076), (0.736, 0.0409), (0.703, 0.0025), (0.682, 0.0257), (0.548, 0.1104), (0.532, 0.1206), (0.533, 0.1209), (0.447, 0.0000)],
+ 'DFA-GNN': [(0.832, 0.0024), (0.823, 0.0033), (0.766, 0.0362), (0.617, 0.1203), (0.617, 0.1203), (0.523, 0.1018), (0.447, 0.0000), (0.447, 0.0000), (0.531, 0.1187)],
+ 'GRAFT': [(0.827, 0.0024), (0.825, 0.0090), (0.813, 0.0121), (0.786, 0.0032), (0.730, 0.0315), (0.701, 0.0020), (0.700, 0.0001), (0.610, 0.1150), (0.613, 0.1175)],
+ },
+ 'CFull-PubMed (biomed)': {
+ 'depths': [3, 5, 8, 10, 12, 14, 16, 18, 20],
+ 'BP': [(0.845, 0.0018), (0.833, 0.0023), (0.825, 0.0026), (0.824, 0.0025), (0.699, 0.0096), (0.499, 0.1413), (0.399, 0.0000), (0.500, 0.1421), (0.399, 0.0000)],
+ 'DFA': [(0.822, 0.0041), (0.793, 0.0188), (0.585, 0.1353), (0.531, 0.0768), (0.484, 0.0833), (0.431, 0.0446), (0.427, 0.0383), (0.399, 0.0000), (0.399, 0.0000)],
+ 'DFA-GNN': [(0.822, 0.0040), (0.750, 0.0551), (0.604, 0.1572), (0.522, 0.1154), (0.462, 0.0888), (0.399, 0.0000), (0.438, 0.0550), (0.399, 0.0000), (0.466, 0.0945)],
+ 'GRAFT': [(0.830, 0.0068), (0.814, 0.0049), (0.789, 0.0099), (0.732, 0.0713), (0.690, 0.0585), (0.646, 0.0134), (0.603, 0.0086), (0.545, 0.1031), (0.525, 0.0887)],
+ },
+ 'Coauthor-Physics': {
+ 'depths': [3, 5, 8, 10, 12, 14, 16, 18, 20],
+ 'BP': [(0.949, 0.0005), (0.943, 0.0014), (0.937, 0.0011), (0.829, 0.0344), (0.818, 0.0387), (0.770, 0.0151), (0.743, 0.0038), (0.682, 0.1000), (0.521, 0.0215)],
+ 'DFA': [(0.948, 0.0007), (0.920, 0.0067), (0.711, 0.0227), (0.686, 0.1275), (0.560, 0.0751), (0.506, 0.0005), (0.557, 0.0737), (0.559, 0.0762), (0.505, 0.0000)],
+ 'DFA-GNN': [(0.947, 0.0012), (0.836, 0.0451), (0.712, 0.0369), (0.567, 0.0720), (0.505, 0.0003), (0.505, 0.0000), (0.505, 0.0000), (0.559, 0.0756), (0.505, 0.0000)],
+ 'GRAFT': [(0.947, 0.0008), (0.943, 0.0004), (0.922, 0.0092), (0.867, 0.0368), (0.749, 0.0423), (0.686, 0.0122), (0.614, 0.0771), (0.666, 0.0010), (0.667, 0.0003)],
+ },
+}
+
+COLORS = {'BP': '#888888', 'DFA': '#7A5BAA', 'DFA-GNN': '#3B7AC2', 'GRAFT': '#C23B3B'}
+GRID = '#ECEFF3'
+TEXT = '#2F3437'
+
+plt.rcParams.update({
+ 'font.size': 9, 'axes.labelsize': 9,
+ 'xtick.labelsize': 8, 'ytick.labelsize': 8, 'legend.fontsize': 9,
+ 'pdf.fonttype': 42, 'ps.fonttype': 42,
+})
+
+fig, axes = plt.subplots(1, 4, figsize=(13.0, 3.0))
+
+datasets = list(DATA.keys())
+legend_handles = {}
+for ax, ds in zip(axes, datasets):
+ d = DATA[ds]
+ xs = d['depths']
+ for method in ['BP', 'DFA', 'DFA-GNN', 'GRAFT']:
+ means = np.array([v[0] for v in d[method]])
+ stds = np.array([v[1] for v in d[method]])
+ c = COLORS[method]
+ line, = ax.plot(xs, means, marker='o', markersize=5, color=c, linewidth=1.6,
+ markerfacecolor=to_rgba(c, alpha=0.35), markeredgecolor=c,
+ markeredgewidth=0.8, zorder=3)
+ ax.fill_between(xs, means - stds, means + stds, color=c, alpha=0.12, edgecolor='none', zorder=2)
+ if method not in legend_handles:
+ legend_handles[method] = line
+
+ ax.set_title(ds, fontsize=10, color=TEXT, pad=4)
+ ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT)
+ ax.grid(axis='both', color=GRID, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_color('#C9CDD3')
+ ax.spines['bottom'].set_color('#C9CDD3')
+ ax.tick_params(colors=TEXT)
+ ax.set_xticks([3, 5, 10, 14, 18, 20])
+
+axes[0].set_ylabel('Test accuracy', fontsize=9, color=TEXT)
+
+handles = [legend_handles[m] for m in ['BP', 'DFA', 'DFA-GNN', 'GRAFT']]
+fig.tight_layout(rect=(0.0, 0.06, 1.0, 1.0), w_pad=1.5)
+# Display label: GRAFT data key stays for the lookup, render as KAFT
+fig.legend(handles, ['BP', 'DFA', 'DFA-GNN', 'KAFT'], frameon=False, loc='lower center',
+ ncol=4, bbox_to_anchor=(0.5, -0.005), handletextpad=0.6, columnspacing=1.8)
+
+fig.savefig('/home/yurenh2/graph-grape/kaft_realworld_depth.png', dpi=300, bbox_inches='tight')
+fig.savefig('/home/yurenh2/graph-grape/kaft_realworld_depth.pdf', bbox_inches='tight')
+plt.close(fig)
+print('Saved /home/yurenh2/graph-grape/kaft_realworld_depth.{png,pdf}')
diff --git a/figures/graft_depth_sweep.pdf b/figures/graft_depth_sweep.pdf
new file mode 100644
index 0000000..21b06f2
--- /dev/null
+++ b/figures/graft_depth_sweep.pdf
Binary files differ
diff --git a/figures/kaft_fig4_combined.pdf b/figures/kaft_fig4_combined.pdf
new file mode 100644
index 0000000..0420951
--- /dev/null
+++ b/figures/kaft_fig4_combined.pdf
Binary files differ
diff --git a/figures/kaft_realworld_depth.pdf b/figures/kaft_realworld_depth.pdf
new file mode 100644
index 0000000..7e07a37
--- /dev/null
+++ b/figures/kaft_realworld_depth.pdf
Binary files differ