diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
| commit | bd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch) | |
| tree | 7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /figures/gen_fig4_combined.py | |
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines
+ multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.
Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'figures/gen_fig4_combined.py')
| -rw-r--r-- | figures/gen_fig4_combined.py | 191 |
1 files changed, 191 insertions, 0 deletions
diff --git a/figures/gen_fig4_combined.py b/figures/gen_fig4_combined.py new file mode 100644 index 0000000..5b8d464 --- /dev/null +++ b/figures/gen_fig4_combined.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +"""Figure 4-style combined plot: 4 panels (depth / add / remove / flip). + +Each panel: 9 curves = 3 datasets × 3 methods. + color = dataset (Cora / CiteSeer / PubMed) + linestyle = method (BP dashed, DFA-GNN dotted, GRAFT solid) + +Matches DFA-GNN Figure 4 layout. +""" + +import json +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import to_rgba +from matplotlib.lines import Line2D + +DATASETS = ['Cora', 'CiteSeer', 'PubMed'] +METHODS = ['BP', 'DFA-GNN', 'GRAFT'] # data-lookup keys (unchanged) +DISPLAY_NAME = {'BP': 'BP', 'DFA-GNN': 'DFA-GNN', 'GRAFT': 'KAFT'} + +DEPTHS = [4, 6, 8, 10, 12, 14, 16, 18, 20] +RATES = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] +ATTACKS = ['add', 'remove', 'flip'] + +# Method colors — consistent with other GRAFT figures +METHOD_COLORS = { + 'BP': '#888888', # gray + 'DFA-GNN': '#3B7AC2', # complementary blue + 'GRAFT': '#C23B3B', # brick red (ours) +} +# Dataset linestyles +DS_STYLE = { + 'Cora': (0, ()), # solid + 'CiteSeer': (0, (5, 2)), # dashed + 'PubMed': (0, (1, 1.5)), # dotted +} +DS_MARKER = { + 'Cora': 'o', + 'CiteSeer': 's', + 'PubMed': '^', +} + +GRID_COLOR = '#ECEFF3' +TEXT_COLOR = '#2F3437' + +# --- depth data sources (depth_sweep reuses gen_depth_sweep_fig loaders) ----- +DEPTH_SOURCES = [ + 'results/combo_20seeds/per_seed_data.json', + 'results/hero_extras_20seeds/per_seed_data.json', + 'results/shallow_depth_20seeds/per_seed_data.json', + 'results/bp_graft_depth_20seeds/per_seed_data.json', + 'results/dfagnn_depth_20seeds/per_seed_data.json', + 'results/dfagnn_resgcn_20seeds/per_seed_data.json', + 'results/depth_extras_20seeds/per_seed_data.json', # L=14, 18 +] +PERTURB_SOURCE = 'results/perturb_sweep_20seeds/per_seed_data.json' + + +def load_depth(): + merged = {} + for path in DEPTH_SOURCES: + try: + with open(f'/home/yurenh2/graph-grape/{path}') as f: + d = json.load(f) + for k, v in d.items(): + if k not in merged: + merged[k] = v + else: + for sk, sv in v.items(): + if sk not in merged[k]: + merged[k][sk] = sv + except FileNotFoundError: + pass + return merged + + +def depth_lookup(data, ds, L, method): + for key in [f'{ds}_L{L}_{method}', f'{ds}_{method}' if L == 6 else None]: + if key and key in data and len(data[key]) >= 15: + vals = np.array(list(data[key].values())) * 100 + return vals.mean(), vals.std() + return None + + +def perturb_lookup(data, ds, attack, rate, method): + key = f'{ds}_{attack}_r{rate}_{method}' + if key in data and len(data[key]) >= 15: + vals = np.array(list(data[key].values())) * 100 + return vals.mean(), vals.std() + return None + + +def plot_panel(ax, panel_type, data, title): + """panel_type: 'depth' or attack name.""" + xs = DEPTHS if panel_type == 'depth' else RATES + for ds in DATASETS: + for method in METHODS: + means = [] + stds = [] + xs_used = [] + for x in xs: + if panel_type == 'depth': + r = depth_lookup(data, ds, x, method) + else: + r = perturb_lookup(data, ds, panel_type, x, method) + if r is not None: + xs_used.append(x) + means.append(r[0]) + stds.append(r[1]) + if not means: + continue + color = METHOD_COLORS[method] + style = DS_STYLE[ds] + marker = DS_MARKER[ds] + ax.plot(xs_used, means, color=color, linestyle=style, marker=marker, + markersize=4.5, linewidth=1.3, + markerfacecolor=to_rgba(color, alpha=0.35), + markeredgecolor=color, markeredgewidth=0.7, + zorder=3) + # Shaded band (light) + means = np.array(means); stds = np.array(stds) + ax.fill_between(xs_used, means - stds, means + stds, + color=color, alpha=0.06, edgecolor='none', zorder=1) + + ax.set_title(title, fontsize=10, color=TEXT_COLOR, pad=5) + ax.grid(axis='both', color=GRID_COLOR, linewidth=0.6) + ax.set_axisbelow(True) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['left'].set_color('#C9CDD3') + ax.spines['bottom'].set_color('#C9CDD3') + ax.tick_params(colors=TEXT_COLOR) + if panel_type == 'depth': + ax.set_xticks(DEPTHS) + ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT_COLOR) + else: + ax.set_xticks(RATES) + ax.set_xlabel('Perturbation rate $\\lambda$', fontsize=9, color=TEXT_COLOR) + + +def main(): + depth_data = load_depth() + with open(f'/home/yurenh2/graph-grape/{PERTURB_SOURCE}') as f: + perturb_data = json.load(f) + + plt.rcParams.update({ + 'font.size': 9, + 'axes.labelsize': 9, + 'xtick.labelsize': 8, + 'ytick.labelsize': 8, + 'legend.fontsize': 8, + 'pdf.fonttype': 42, + 'ps.fonttype': 42, + }) + + fig, axes = plt.subplots(1, 4, figsize=(13.5, 3.2)) + + plot_panel(axes[0], 'depth', depth_data, '(a) Depth') + plot_panel(axes[1], 'add', perturb_data, '(b) Add') + plot_panel(axes[2], 'remove', perturb_data, '(c) Remove') + plot_panel(axes[3], 'flip', perturb_data, '(d) Flip') + + axes[0].set_ylabel('Test accuracy (%)', fontsize=9, color=TEXT_COLOR) + + # Dual-legend: colors (methods) + linestyles (datasets) + method_handles = [Line2D([0], [0], color=METHOD_COLORS[m], linewidth=2.5, + label=DISPLAY_NAME[m]) + for m in METHODS] + ds_handles = [Line2D([0], [0], color='#444', linestyle=DS_STYLE[ds], + marker=DS_MARKER[ds], markersize=4.5, + linewidth=1.5, label=ds) + for ds in DATASETS] + + fig.tight_layout(rect=(0.0, 0.09, 1.0, 1.0), w_pad=1.3) + fig.legend(handles=method_handles, loc='lower left', bbox_to_anchor=(0.08, -0.01), + frameon=False, ncol=3, handletextpad=0.5, columnspacing=1.5, + title='Method', title_fontsize=9) + fig.legend(handles=ds_handles, loc='lower right', bbox_to_anchor=(0.92, -0.01), + frameon=False, ncol=3, handletextpad=0.5, columnspacing=1.5, + title='Dataset', title_fontsize=9) + + fig.savefig('/home/yurenh2/graph-grape/kaft_fig4_combined.png', + dpi=300, bbox_inches='tight') + fig.savefig('/home/yurenh2/graph-grape/kaft_fig4_combined.pdf', + bbox_inches='tight') + plt.close(fig) + print('Saved /home/yurenh2/graph-grape/kaft_fig4_combined.{png,pdf}') + + +if __name__ == '__main__': + main() |
