#!/usr/bin/env python3
"""Figure 4-style combined plot: 4 panels (depth / add / remove / flip).

Each panel: 9 curves = 3 datasets × 3 methods.
  color     = method  (BP gray, DFA-GNN blue, GRAFT/KAFT brick red)
  linestyle = dataset (Cora solid, CiteSeer dashed, PubMed dotted)
  marker    = dataset (Cora circle, CiteSeer square, PubMed triangle)
Each curve is the per-seed mean test accuracy (%); the light band around it
is mean ± one (population) standard deviation across seeds.
Matches DFA-GNN Figure 4 layout.
"""
import json
import sys

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.lines import Line2D

# Repository root: every data source is read from here and figures are written here.
ROOT = '/home/yurenh2/graph-grape'

# Minimum number of per-seed entries a config needs before it is plotted.
MIN_SEEDS = 15

DATASETS = ['Cora', 'CiteSeer', 'PubMed']
METHODS = ['BP', 'DFA-GNN', 'GRAFT']  # data-lookup keys (unchanged)
# Legend labels for each data-lookup key (GRAFT is shown as KAFT).
DISPLAY_NAME = {'BP': 'BP', 'DFA-GNN': 'DFA-GNN', 'GRAFT': 'KAFT'}

DEPTHS = [4, 6, 8, 10, 12, 14, 16, 18, 20]
RATES = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
ATTACKS = ['add', 'remove', 'flip']

# Method colors — consistent with other GRAFT figures
METHOD_COLORS = {
    'BP': '#888888',       # gray
    'DFA-GNN': '#3B7AC2',  # complementary blue
    'GRAFT': '#C23B3B',    # brick red (ours)
}

# Dataset linestyles, as matplotlib dash tuples: (offset, (on, off, ...)).
DS_STYLE = {
    'Cora': (0, ()),               # solid
    'CiteSeer': (0, (5, 2)),       # dashed
    'PubMed': (0, (1, 1.5)),       # dotted
}
DS_MARKER = {
    'Cora': 'o',
    'CiteSeer': 's',
    'PubMed': '^',
}

GRID_COLOR = '#ECEFF3'
TEXT_COLOR = '#2F3437'

# --- depth data sources (depth_sweep reuses gen_depth_sweep_fig loaders) -----
# Paths are relative to ROOT. Order matters: earlier files win on conflicts.
DEPTH_SOURCES = [
    'results/combo_20seeds/per_seed_data.json',
    'results/hero_extras_20seeds/per_seed_data.json',
    'results/shallow_depth_20seeds/per_seed_data.json',
    'results/bp_graft_depth_20seeds/per_seed_data.json',
    'results/dfagnn_depth_20seeds/per_seed_data.json',
    'results/dfagnn_resgcn_20seeds/per_seed_data.json',
    'results/depth_extras_20seeds/per_seed_data.json',  # L=14, 18
]
PERTURB_SOURCE = 'results/perturb_sweep_20seeds/per_seed_data.json'


def load_depth():
    """Merge per-seed depth results from every file in DEPTH_SOURCES.

    Files are read in list order. When a config key appears in several
    files, seed entries from earlier files win; later files only contribute
    seeds that are still missing. A missing file is skipped (best effort),
    but a warning is printed to stderr so the gap is visible.

    Returns:
        dict mapping config key -> {seed key -> value}.
    """
    merged = {}
    for path in DEPTH_SOURCES:
        full_path = f'{ROOT}/{path}'
        try:
            with open(full_path, encoding='utf-8') as f:
                d = json.load(f)
        except FileNotFoundError:
            print(f'[warn] depth source not found, skipping: {full_path}',
                  file=sys.stderr)
            continue
        for k, v in d.items():
            # Fresh dict per key, so the loaded JSON objects are never mutated.
            bucket = merged.setdefault(k, {})
            for sk, sv in v.items():
                bucket.setdefault(sk, sv)
    return merged


def _seed_stats(data, key):
    """Return (mean, std) of data[key]'s per-seed values, scaled by 100.

    Returns None when `key` is absent or has fewer than MIN_SEEDS entries.
    The std is numpy's default population std (ddof=0).
    NOTE(review): values are assumed to be accuracies in [0, 1], given the
    x100 scaling and the "%" axis label; confirm against the result writers.
    """
    if key not in data or len(data[key]) < MIN_SEEDS:
        return None
    vals = np.array(list(data[key].values())) * 100
    return vals.mean(), vals.std()


def depth_lookup(data, ds, L, method):
    """Return (mean, std) accuracy for dataset `ds`, depth `L`, and `method`, or None.

    The depth-tagged key `{ds}_L{L}_{method}` is tried first. For L=6 only,
    the untagged key `{ds}_{method}` is used as a fallback.
    """
    keys = [f'{ds}_L{L}_{method}']
    if L == 6:
        keys.append(f'{ds}_{method}')
    for key in keys:
        stats = _seed_stats(data, key)
        if stats is not None:
            return stats
    return None


def perturb_lookup(data, ds, attack, rate, method):
    """Return (mean, std) accuracy for `ds`/`attack`/`rate`/`method`, or None."""
    return _seed_stats(data, f'{ds}_{attack}_r{rate}_{method}')


def _collect_series(data, panel_type, ds, method):
    """Gather one curve's points as (xs_used, means, stds), skipping x values with no data."""
    xs = DEPTHS if panel_type == 'depth' else RATES
    xs_used, means, stds = [], [], []
    for x in xs:
        if panel_type == 'depth':
            r = depth_lookup(data, ds, x, method)
        else:
            r = perturb_lookup(data, ds, panel_type, x, method)
        if r is not None:
            xs_used.append(x)
            means.append(r[0])
            stds.append(r[1])
    return xs_used, np.array(means), np.array(stds)


def _style_axes(ax, panel_type, title):
    """Apply the title, grid, spine and tick styling, and the x axis, to one panel."""
    ax.set_title(title, fontsize=10, color=TEXT_COLOR, pad=5)
    ax.grid(axis='both', color=GRID_COLOR, linewidth=0.6)
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_color('#C9CDD3')
    ax.spines['bottom'].set_color('#C9CDD3')
    ax.tick_params(colors=TEXT_COLOR)
    if panel_type == 'depth':
        ax.set_xticks(DEPTHS)
        ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT_COLOR)
    else:
        ax.set_xticks(RATES)
        ax.set_xlabel('Perturbation rate $\\lambda$', fontsize=9,
                      color=TEXT_COLOR)


def plot_panel(ax, panel_type, data, title):
    """Draw every dataset × method curve on one panel.

    Args:
        ax: matplotlib Axes to draw on.
        panel_type: 'depth', or an attack name ('add' / 'remove' / 'flip').
        data: per-seed results dict (from load_depth or the perturb JSON).
        title: panel title text.

    Curves with no data at any x value are left out.
    """
    for ds in DATASETS:
        for method in METHODS:
            xs_used, means, stds = _collect_series(data, panel_type, ds, method)
            if not xs_used:
                continue
            color = METHOD_COLORS[method]
            ax.plot(xs_used, means, color=color,
                    linestyle=DS_STYLE[ds], marker=DS_MARKER[ds],
                    markersize=4.5, linewidth=1.3,
                    markerfacecolor=to_rgba(color, alpha=0.35),
                    markeredgecolor=color, markeredgewidth=0.7, zorder=3)
            # Light mean ± std band, drawn underneath the line.
            ax.fill_between(xs_used, means - stds, means + stds,
                            color=color, alpha=0.06, edgecolor='none',
                            zorder=1)
    _style_axes(ax, panel_type, title)


def main():
    """Build the 4-panel figure and save it as PNG and PDF under ROOT."""
    depth_data = load_depth()
    with open(f'{ROOT}/{PERTURB_SOURCE}', encoding='utf-8') as f:
        perturb_data = json.load(f)

    plt.rcParams.update({
        'font.size': 9,
        'axes.labelsize': 9,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    })

    fig, axes = plt.subplots(1, 4, figsize=(13.5, 3.2))
    panels = [
        ('depth', depth_data, '(a) Depth'),
        ('add', perturb_data, '(b) Add'),
        ('remove', perturb_data, '(c) Remove'),
        ('flip', perturb_data, '(d) Flip'),
    ]
    for ax, (panel_type, data, title) in zip(axes, panels):
        plot_panel(ax, panel_type, data, title)
    axes[0].set_ylabel('Test accuracy (%)', fontsize=9, color=TEXT_COLOR)

    # Dual-legend: colors (methods) + linestyles/markers (datasets)
    method_handles = [
        Line2D([0], [0], color=METHOD_COLORS[m], linewidth=2.5,
               label=DISPLAY_NAME[m])
        for m in METHODS
    ]
    ds_handles = [
        Line2D([0], [0], color='#444', linestyle=DS_STYLE[ds],
               marker=DS_MARKER[ds], markersize=4.5, linewidth=1.5, label=ds)
        for ds in DATASETS
    ]

    # Leave a strip at the bottom of the figure for the two legends.
    fig.tight_layout(rect=(0.0, 0.09, 1.0, 1.0), w_pad=1.3)
    legend_kw = dict(frameon=False, ncol=3, handletextpad=0.5,
                     columnspacing=1.5, title_fontsize=9)
    fig.legend(handles=method_handles, loc='lower left',
               bbox_to_anchor=(0.08, -0.01), title='Method', **legend_kw)
    fig.legend(handles=ds_handles, loc='lower right',
               bbox_to_anchor=(0.92, -0.01), title='Dataset', **legend_kw)

    out_stem = f'{ROOT}/kaft_fig4_combined'
    fig.savefig(f'{out_stem}.png', dpi=300, bbox_inches='tight')
    fig.savefig(f'{out_stem}.pdf', bbox_inches='tight')
    plt.close(fig)
    print(f'Saved {out_stem}.{{png,pdf}}')


if __name__ == '__main__':
    main()