summaryrefslogtreecommitdiff
path: root/figures/gen_fig4_combined.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
commitbd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch)
tree7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /figures/gen_fig4_combined.py
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes: - src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities. - experiments/: 19 runners reproducing every figure / table in the paper. - figures/: 4 generators + the 4 PDFs cited in the report. - paper/: NeurIPS .tex and consolidated experiments_master notes. Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'figures/gen_fig4_combined.py')
-rw-r--r--figures/gen_fig4_combined.py191
1 files changed, 191 insertions, 0 deletions
diff --git a/figures/gen_fig4_combined.py b/figures/gen_fig4_combined.py
new file mode 100644
index 0000000..5b8d464
--- /dev/null
+++ b/figures/gen_fig4_combined.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python3
+"""Figure 4-style combined plot: 4 panels (depth / add / remove / flip).
+
+Each panel: 9 curves = 3 datasets × 3 methods.
+ color = dataset (Cora / CiteSeer / PubMed)
+ linestyle = method (BP dashed, DFA-GNN dotted, GRAFT solid)
+
+Matches DFA-GNN Figure 4 layout.
+"""
+
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
+from matplotlib.lines import Line2D
+
+DATASETS = ['Cora', 'CiteSeer', 'PubMed']
+METHODS = ['BP', 'DFA-GNN', 'GRAFT'] # data-lookup keys (unchanged)
+DISPLAY_NAME = {'BP': 'BP', 'DFA-GNN': 'DFA-GNN', 'GRAFT': 'KAFT'}
+
+DEPTHS = [4, 6, 8, 10, 12, 14, 16, 18, 20]
+RATES = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
+ATTACKS = ['add', 'remove', 'flip']
+
+# Method colors — consistent with other GRAFT figures
+METHOD_COLORS = {
+ 'BP': '#888888', # gray
+ 'DFA-GNN': '#3B7AC2', # complementary blue
+ 'GRAFT': '#C23B3B', # brick red (ours)
+}
+# Dataset linestyles
+DS_STYLE = {
+ 'Cora': (0, ()), # solid
+ 'CiteSeer': (0, (5, 2)), # dashed
+ 'PubMed': (0, (1, 1.5)), # dotted
+}
+DS_MARKER = {
+ 'Cora': 'o',
+ 'CiteSeer': 's',
+ 'PubMed': '^',
+}
+
+GRID_COLOR = '#ECEFF3'
+TEXT_COLOR = '#2F3437'
+
+# --- depth data sources (depth_sweep reuses gen_depth_sweep_fig loaders) -----
+DEPTH_SOURCES = [
+ 'results/combo_20seeds/per_seed_data.json',
+ 'results/hero_extras_20seeds/per_seed_data.json',
+ 'results/shallow_depth_20seeds/per_seed_data.json',
+ 'results/bp_graft_depth_20seeds/per_seed_data.json',
+ 'results/dfagnn_depth_20seeds/per_seed_data.json',
+ 'results/dfagnn_resgcn_20seeds/per_seed_data.json',
+ 'results/depth_extras_20seeds/per_seed_data.json', # L=14, 18
+]
+PERTURB_SOURCE = 'results/perturb_sweep_20seeds/per_seed_data.json'
+
+
+def load_depth():
+ merged = {}
+ for path in DEPTH_SOURCES:
+ try:
+ with open(f'/home/yurenh2/graph-grape/{path}') as f:
+ d = json.load(f)
+ for k, v in d.items():
+ if k not in merged:
+ merged[k] = v
+ else:
+ for sk, sv in v.items():
+ if sk not in merged[k]:
+ merged[k][sk] = sv
+ except FileNotFoundError:
+ pass
+ return merged
+
+
+def depth_lookup(data, ds, L, method):
+ for key in [f'{ds}_L{L}_{method}', f'{ds}_{method}' if L == 6 else None]:
+ if key and key in data and len(data[key]) >= 15:
+ vals = np.array(list(data[key].values())) * 100
+ return vals.mean(), vals.std()
+ return None
+
+
+def perturb_lookup(data, ds, attack, rate, method):
+ key = f'{ds}_{attack}_r{rate}_{method}'
+ if key in data and len(data[key]) >= 15:
+ vals = np.array(list(data[key].values())) * 100
+ return vals.mean(), vals.std()
+ return None
+
+
+def plot_panel(ax, panel_type, data, title):
+ """panel_type: 'depth' or attack name."""
+ xs = DEPTHS if panel_type == 'depth' else RATES
+ for ds in DATASETS:
+ for method in METHODS:
+ means = []
+ stds = []
+ xs_used = []
+ for x in xs:
+ if panel_type == 'depth':
+ r = depth_lookup(data, ds, x, method)
+ else:
+ r = perturb_lookup(data, ds, panel_type, x, method)
+ if r is not None:
+ xs_used.append(x)
+ means.append(r[0])
+ stds.append(r[1])
+ if not means:
+ continue
+ color = METHOD_COLORS[method]
+ style = DS_STYLE[ds]
+ marker = DS_MARKER[ds]
+ ax.plot(xs_used, means, color=color, linestyle=style, marker=marker,
+ markersize=4.5, linewidth=1.3,
+ markerfacecolor=to_rgba(color, alpha=0.35),
+ markeredgecolor=color, markeredgewidth=0.7,
+ zorder=3)
+ # Shaded band (light)
+ means = np.array(means); stds = np.array(stds)
+ ax.fill_between(xs_used, means - stds, means + stds,
+ color=color, alpha=0.06, edgecolor='none', zorder=1)
+
+ ax.set_title(title, fontsize=10, color=TEXT_COLOR, pad=5)
+ ax.grid(axis='both', color=GRID_COLOR, linewidth=0.6)
+ ax.set_axisbelow(True)
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_color('#C9CDD3')
+ ax.spines['bottom'].set_color('#C9CDD3')
+ ax.tick_params(colors=TEXT_COLOR)
+ if panel_type == 'depth':
+ ax.set_xticks(DEPTHS)
+ ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT_COLOR)
+ else:
+ ax.set_xticks(RATES)
+ ax.set_xlabel('Perturbation rate $\\lambda$', fontsize=9, color=TEXT_COLOR)
+
+
+def main():
+ depth_data = load_depth()
+ with open(f'/home/yurenh2/graph-grape/{PERTURB_SOURCE}') as f:
+ perturb_data = json.load(f)
+
+ plt.rcParams.update({
+ 'font.size': 9,
+ 'axes.labelsize': 9,
+ 'xtick.labelsize': 8,
+ 'ytick.labelsize': 8,
+ 'legend.fontsize': 8,
+ 'pdf.fonttype': 42,
+ 'ps.fonttype': 42,
+ })
+
+ fig, axes = plt.subplots(1, 4, figsize=(13.5, 3.2))
+
+ plot_panel(axes[0], 'depth', depth_data, '(a) Depth')
+ plot_panel(axes[1], 'add', perturb_data, '(b) Add')
+ plot_panel(axes[2], 'remove', perturb_data, '(c) Remove')
+ plot_panel(axes[3], 'flip', perturb_data, '(d) Flip')
+
+ axes[0].set_ylabel('Test accuracy (%)', fontsize=9, color=TEXT_COLOR)
+
+ # Dual-legend: colors (methods) + linestyles (datasets)
+ method_handles = [Line2D([0], [0], color=METHOD_COLORS[m], linewidth=2.5,
+ label=DISPLAY_NAME[m])
+ for m in METHODS]
+ ds_handles = [Line2D([0], [0], color='#444', linestyle=DS_STYLE[ds],
+ marker=DS_MARKER[ds], markersize=4.5,
+ linewidth=1.5, label=ds)
+ for ds in DATASETS]
+
+ fig.tight_layout(rect=(0.0, 0.09, 1.0, 1.0), w_pad=1.3)
+ fig.legend(handles=method_handles, loc='lower left', bbox_to_anchor=(0.08, -0.01),
+ frameon=False, ncol=3, handletextpad=0.5, columnspacing=1.5,
+ title='Method', title_fontsize=9)
+ fig.legend(handles=ds_handles, loc='lower right', bbox_to_anchor=(0.92, -0.01),
+ frameon=False, ncol=3, handletextpad=0.5, columnspacing=1.5,
+ title='Dataset', title_fontsize=9)
+
+ fig.savefig('/home/yurenh2/graph-grape/kaft_fig4_combined.png',
+ dpi=300, bbox_inches='tight')
+ fig.savefig('/home/yurenh2/graph-grape/kaft_fig4_combined.pdf',
+ bbox_inches='tight')
+ plt.close(fig)
+ print('Saved /home/yurenh2/graph-grape/kaft_fig4_combined.{png,pdf}')
+
+
+if __name__ == '__main__':
+ main()