summaryrefslogtreecommitdiff
path: root/figures/gen_depth_sweep_fig.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
commitbd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch)
tree7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /figures/gen_depth_sweep_fig.py
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes: - src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities. - experiments/: 19 runners reproducing every figure / table in the paper. - figures/: 4 generators + the 4 PDFs cited in the report. - paper/: NeurIPS .tex and consolidated experiments_master notes. Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'figures/gen_depth_sweep_fig.py')
-rw-r--r--figures/gen_depth_sweep_fig.py165
1 files changed, 165 insertions, 0 deletions
diff --git a/figures/gen_depth_sweep_fig.py b/figures/gen_depth_sweep_fig.py
new file mode 100644
index 0000000..9604a6a
--- /dev/null
+++ b/figures/gen_depth_sweep_fig.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+"""H8: Generate Figure 4(a)-style depth sweep plot.
+
+4 panels (Cora/CiteSeer/PubMed/DBLP), 3 curves per panel (BP/DFA-GNN/GRAFT).
+x = number of layers L; y = test accuracy (%) with shaded std band.
+
+Method distinguished by color only (per memory `feedback_viz_shape`:
+shape encodes sweep axis — here L is the x-axis, so same marker for all methods).
+"""
+
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
+
+DATASETS = ['Cora', 'CiteSeer', 'PubMed', 'DBLP']
+METHODS = ['BP', 'DFA-GNN', 'GRAFT']
+# Per-dataset depth grids — DBLP extends to 24, 32 from dblp_depth_scaling.
+# Other datasets cover 2..20. Missing entries (e.g. DFA-GNN at L=2/3, DBLP L=10
+# for BP/GRAFT) will be silently skipped by lookup().
+DEPTHS_DEFAULT = [2, 3, 4, 6, 8, 10, 12, 14, 16, 18, 20]
+DEPTHS_DBLP = [2, 3, 4, 6, 8, 10, 12, 14, 16, 18, 20, 24, 32]
+DEPTHS_BY_DS = {ds: (DEPTHS_DBLP if ds == 'DBLP' else DEPTHS_DEFAULT)
+ for ds in DATASETS}
+
+# All result files we might need to consult
+SOURCES = [
+ 'results/combo_20seeds/per_seed_data.json', # L=6 BP/GRAFT/stacks on Cora/CS/DBLP
+ 'results/hero_extras_20seeds/per_seed_data.json', # L=6 on PubMed + Coauthor
+ 'results/shallow_depth_20seeds/per_seed_data.json', # L=2,3,4 on 4ds
+ 'results/dblp_depth_scaling_20seeds/per_seed_data.json', # DBLP L=8-32
+ 'results/bp_graft_depth_20seeds/per_seed_data.json', # Cora/CS/PubMed L=8-20
+ 'results/dfagnn_depth_20seeds/per_seed_data.json', # DFA-GNN at all depths
+ 'results/dfagnn_resgcn_20seeds/per_seed_data.json', # DFA-GNN L=6 Cora/CS/DBLP
+ 'results/depth_extras_20seeds/per_seed_data.json', # L=14, L=18 × 4ds × 3 methods
+]
+
+# Colors — GRAFT brick red (main method), BP gray, DFA-GNN complementary blue
+COLORS = {
+ 'BP': '#888888', # reference gray
+ 'DFA-GNN': '#3B7AC2', # complementary blue
+ 'GRAFT': '#C23B3B', # brick red (our method)
+}
+
+GRID_COLOR = '#ECEFF3'
+TEXT_COLOR = '#2F3437'
+
+
+def load_all():
+ """Load all sources into a single dict keyed by original keys."""
+ merged = {}
+ for path in SOURCES:
+ try:
+ with open(f'/home/yurenh2/graph-grape/{path}') as f:
+ d = json.load(f)
+ for k, v in d.items():
+ if k not in merged:
+ merged[k] = v
+ else:
+ # Merge seed dicts (take first available if conflict)
+ for sk, sv in v.items():
+ if sk not in merged[k]:
+ merged[k][sk] = sv
+ except FileNotFoundError:
+ pass
+ return merged
+
+
+def lookup(data, ds, L, method):
+ """Return (mean, std) or None if unavailable."""
+ # Try multiple key formats
+ # 1. {ds}_L{L}_{method} (depth-indexed)
+ # 2. {ds}_{method} (for L=6, assumed default in combo/hero files)
+ for key in [f'{ds}_L{L}_{method}', f'{ds}_{method}' if L == 6 else None]:
+ if key and key in data:
+ seeds = data[key]
+ if len(seeds) >= 15: # allow a few missing seeds
+ vals = np.array(list(seeds.values())) * 100
+ return vals.mean(), vals.std()
+ return None
+
+
+def main():
+ data = load_all()
+
+ plt.rcParams.update({
+ 'font.size': 10,
+ 'axes.labelsize': 10,
+ 'xtick.labelsize': 9,
+ 'ytick.labelsize': 9,
+ 'legend.fontsize': 9,
+ 'pdf.fonttype': 42,
+ 'ps.fonttype': 42,
+ })
+
+ fig, axes = plt.subplots(1, 4, figsize=(13.0, 3.3), sharey=False)
+
+ legend_handles = {}
+
+ for ax, ds in zip(axes, DATASETS):
+ depths = DEPTHS_BY_DS[ds]
+ for method in METHODS:
+ xs, means, stds = [], [], []
+ for L in depths:
+ r = lookup(data, ds, L, method)
+ if r is not None:
+ xs.append(L)
+ means.append(r[0])
+ stds.append(r[1])
+ if not xs:
+ continue
+ xs = np.array(xs); means = np.array(means); stds = np.array(stds)
+ color = COLORS[method]
+ line, = ax.plot(xs, means, marker='o', markersize=5,
+ color=color, linewidth=1.6,
+ markerfacecolor=to_rgba(color, alpha=0.35),
+ markeredgecolor=color, markeredgewidth=0.8,
+ zorder=3)
+ ax.fill_between(xs, means - stds, means + stds,
+ color=color, alpha=0.12, edgecolor='none', zorder=2)
+ if method not in legend_handles:
+ legend_handles[method] = line
+
+ ax.set_title(ds, fontsize=10, color=TEXT_COLOR, pad=6)
+ ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT_COLOR)
+ ax.grid(axis='both', color=GRID_COLOR, linewidth=0.7)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_color('#C9CDD3')
+ ax.spines['bottom'].set_color('#C9CDD3')
+ ax.tick_params(colors=TEXT_COLOR)
+ # Show every other tick for readability when grid is dense
+ ticks = depths if len(depths) <= 8 else depths[::2]
+ ax.set_xticks(ticks)
+
+ axes[0].set_ylabel('Test accuracy (%)', fontsize=10, color=TEXT_COLOR)
+
+ handles = [legend_handles[m] for m in METHODS if m in legend_handles]
+ labels = [m for m in METHODS if m in legend_handles]
+ fig.tight_layout(rect=(0.0, 0.06, 1.0, 1.0), w_pad=1.5)
+ fig.legend(handles, labels,
+ frameon=False, loc='lower center',
+ ncol=len(labels), bbox_to_anchor=(0.5, -0.005),
+ handletextpad=0.6, columnspacing=1.8)
+ fig.savefig('/home/yurenh2/graph-grape/graft_depth_sweep.png', dpi=300, bbox_inches='tight')
+ fig.savefig('/home/yurenh2/graph-grape/graft_depth_sweep.pdf', bbox_inches='tight')
+ plt.close(fig)
+ print('Saved /home/yurenh2/graph-grape/graft_depth_sweep.{png,pdf}')
+
+ # Data dump
+ print('\nData (mean ± std):')
+ for ds in DATASETS:
+ print(f'\n{ds}:')
+ depths = DEPTHS_BY_DS[ds]
+ for method in METHODS:
+ row = [f'{method:<9}']
+ for L in depths:
+ r = lookup(data, ds, L, method)
+ row.append(f'L{L}: {r[0]:5.1f}±{r[1]:4.1f}' if r else f'L{L}: {"—":>10}')
+ print(' ' + ' '.join(row))
+
+
+if __name__ == '__main__':
+ main()