diff options
Diffstat (limited to 'figures/gen_depth_sweep_fig.py')
| -rw-r--r-- | figures/gen_depth_sweep_fig.py | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/figures/gen_depth_sweep_fig.py b/figures/gen_depth_sweep_fig.py new file mode 100644 index 0000000..9604a6a --- /dev/null +++ b/figures/gen_depth_sweep_fig.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +"""H8: Generate Figure 4(a)-style depth sweep plot. + +4 panels (Cora/CiteSeer/PubMed/DBLP), 3 curves per panel (BP/DFA-GNN/GRAFT). +x = number of layers L; y = test accuracy (%) with shaded std band. + +Method distinguished by color only (per memory `feedback_viz_shape`: +shape encodes sweep axis — here L is the x-axis, so same marker for all methods). +""" + +import json +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import to_rgba + +DATASETS = ['Cora', 'CiteSeer', 'PubMed', 'DBLP'] +METHODS = ['BP', 'DFA-GNN', 'GRAFT'] +# Per-dataset depth grids — DBLP extends to 24, 32 from dblp_depth_scaling. +# Other datasets cover 2..20. Missing entries (e.g. DFA-GNN at L=2/3, DBLP L=10 +# for BP/GRAFT) will be silently skipped by lookup(). +DEPTHS_DEFAULT = [2, 3, 4, 6, 8, 10, 12, 14, 16, 18, 20] +DEPTHS_DBLP = [2, 3, 4, 6, 8, 10, 12, 14, 16, 18, 20, 24, 32] +DEPTHS_BY_DS = {ds: (DEPTHS_DBLP if ds == 'DBLP' else DEPTHS_DEFAULT) + for ds in DATASETS} + +# All result files we might need to consult +SOURCES = [ + 'results/combo_20seeds/per_seed_data.json', # L=6 BP/GRAFT/stacks on Cora/CS/DBLP + 'results/hero_extras_20seeds/per_seed_data.json', # L=6 on PubMed + Coauthor + 'results/shallow_depth_20seeds/per_seed_data.json', # L=2,3,4 on 4ds + 'results/dblp_depth_scaling_20seeds/per_seed_data.json', # DBLP L=8-32 + 'results/bp_graft_depth_20seeds/per_seed_data.json', # Cora/CS/PubMed L=8-20 + 'results/dfagnn_depth_20seeds/per_seed_data.json', # DFA-GNN at all depths + 'results/dfagnn_resgcn_20seeds/per_seed_data.json', # DFA-GNN L=6 Cora/CS/DBLP + 'results/depth_extras_20seeds/per_seed_data.json', # L=14, L=18 × 4ds × 3 methods +] + +# Colors — GRAFT brick red (main method), BP gray, DFA-GNN complementary blue +COLORS = { + 'BP': '#888888', # reference gray + 'DFA-GNN': '#3B7AC2', # complementary blue + 'GRAFT': '#C23B3B', # brick red (our method) +} + +GRID_COLOR = '#ECEFF3' +TEXT_COLOR = '#2F3437' + + +def load_all(): + """Load all sources into a single dict keyed by original keys.""" + merged = {} + for path in SOURCES: + try: + with open(f'/home/yurenh2/graph-grape/{path}') as f: + d = json.load(f) + for k, v in d.items(): + if k not in merged: + merged[k] = v + else: + # Merge seed dicts (take first available if conflict) + for sk, sv in v.items(): + if sk not in merged[k]: + merged[k][sk] = sv + except FileNotFoundError: + pass + return merged + + +def lookup(data, ds, L, method): + """Return (mean, std) or None if unavailable.""" + # Try multiple key formats + # 1. {ds}_L{L}_{method} (depth-indexed) + # 2. {ds}_{method} (for L=6, assumed default in combo/hero files) + for key in [f'{ds}_L{L}_{method}', f'{ds}_{method}' if L == 6 else None]: + if key and key in data: + seeds = data[key] + if len(seeds) >= 15: # allow a few missing seeds + vals = np.array(list(seeds.values())) * 100 + return vals.mean(), vals.std() + return None + + +def main(): + data = load_all() + + plt.rcParams.update({ + 'font.size': 10, + 'axes.labelsize': 10, + 'xtick.labelsize': 9, + 'ytick.labelsize': 9, + 'legend.fontsize': 9, + 'pdf.fonttype': 42, + 'ps.fonttype': 42, + }) + + fig, axes = plt.subplots(1, 4, figsize=(13.0, 3.3), sharey=False) + + legend_handles = {} + + for ax, ds in zip(axes, DATASETS): + depths = DEPTHS_BY_DS[ds] + for method in METHODS: + xs, means, stds = [], [], [] + for L in depths: + r = lookup(data, ds, L, method) + if r is not None: + xs.append(L) + means.append(r[0]) + stds.append(r[1]) + if not xs: + continue + xs = np.array(xs); means = np.array(means); stds = np.array(stds) + color = COLORS[method] + line, = ax.plot(xs, means, marker='o', markersize=5, + color=color, linewidth=1.6, + markerfacecolor=to_rgba(color, alpha=0.35), + markeredgecolor=color, markeredgewidth=0.8, + zorder=3) + ax.fill_between(xs, means - stds, means + stds, + color=color, alpha=0.12, edgecolor='none', zorder=2) + if method not in legend_handles: + legend_handles[method] = line + + ax.set_title(ds, fontsize=10, color=TEXT_COLOR, pad=6) + ax.set_xlabel('Number of layers $L$', fontsize=9, color=TEXT_COLOR) + ax.grid(axis='both', color=GRID_COLOR, linewidth=0.7) + ax.set_axisbelow(True) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['left'].set_color('#C9CDD3') + ax.spines['bottom'].set_color('#C9CDD3') + ax.tick_params(colors=TEXT_COLOR) + # Show every other tick for readability when grid is dense + ticks = depths if len(depths) <= 8 else depths[::2] + ax.set_xticks(ticks) + + axes[0].set_ylabel('Test accuracy (%)', fontsize=10, color=TEXT_COLOR) + + handles = [legend_handles[m] for m in METHODS if m in legend_handles] + labels = [m for m in METHODS if m in legend_handles] + fig.tight_layout(rect=(0.0, 0.06, 1.0, 1.0), w_pad=1.5) + fig.legend(handles, labels, + frameon=False, loc='lower center', + ncol=len(labels), bbox_to_anchor=(0.5, -0.005), + handletextpad=0.6, columnspacing=1.8) + fig.savefig('/home/yurenh2/graph-grape/graft_depth_sweep.png', dpi=300, bbox_inches='tight') + fig.savefig('/home/yurenh2/graph-grape/graft_depth_sweep.pdf', bbox_inches='tight') + plt.close(fig) + print('Saved /home/yurenh2/graph-grape/graft_depth_sweep.{png,pdf}') + + # Data dump + print('\nData (mean ± std):') + for ds in DATASETS: + print(f'\n{ds}:') + depths = DEPTHS_BY_DS[ds] + for method in METHODS: + row = [f'{method:<9}'] + for L in depths: + r = lookup(data, ds, L, method) + row.append(f'L{L}: {r[0]:5.1f}±{r[1]:4.1f}' if r else f'L{L}: {"—":>10}') + print(' ' + ' '.join(row)) + + +if __name__ == '__main__': + main() |
