Diffstat (limited to 'experiments/run_dblp_depth.py')
-rw-r--r-- | experiments/run_dblp_depth.py | 162
1 file changed, 162 insertions, 0 deletions
diff --git a/experiments/run_dblp_depth.py b/experiments/run_dblp_depth.py
new file mode 100644
index 0000000..d63b94a
--- /dev/null
+++ b/experiments/run_dblp_depth.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+CitationFull-DBLP experiment + fill-in data for the depth sweep L=2-6.
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import json
+import os
+import time
+from torch_geometric.datasets import CitationFull
+from src.data import build_normalized_adj, build_row_normalized_adj, spmm, precompute_traces
+from src.trainers import BPTrainer, DFATrainer, GraphGrAPETrainer
+from benchmark_efficient import GraphGrAPEEfficient
+
+device = 'cuda:0'
+SEEDS = [0, 1, 2, 3, 4]
+EPOCHS = 200
+OUT_DIR = 'results/dblp_depth'
+
+
+def load_dblp():
+    ds = CitationFull(root='./data', name='DBLP')
+    data = ds[0]
+    N, C = data.num_nodes, ds.num_classes
+    # Random split
+    rng = torch.Generator().manual_seed(42)
+    train_mask = torch.zeros(N, dtype=torch.bool)
+    val_mask = torch.zeros(N, dtype=torch.bool)
+    test_mask = torch.zeros(N, dtype=torch.bool)
+    for c in range(C):
+        idx = (data.y == c).nonzero(as_tuple=True)[0]
+        perm = torch.randperm(len(idx), generator=rng)
+        n_tr = max(1, int(0.05 * len(idx)))  # 5% train (like Planetoid)
+        n_va = max(1, int(0.1 * len(idx)))
+        train_mask[idx[perm[:n_tr]]] = True
+        val_mask[idx[perm[n_tr:n_tr + n_va]]] = True
+        test_mask[idx[perm[n_tr + n_va:]]] = True
+
+    A_hat = build_normalized_adj(data.edge_index, N)
+    A_row, A_row_T = build_row_normalized_adj(data.edge_index, N)
+    traces = {k: torch.tensor(0.0) for k in range(5)}  # skip expensive trace computation
+
+    return {
+        'X': data.x.to(device), 'y': data.y.to(device),
+        'A_hat': A_hat.to(device), 'A_row': A_row.to(device), 'A_row_T': A_row_T.to(device),
+        'train_mask': train_mask.to(device), 'val_mask': val_mask.to(device),
+        'test_mask': test_mask.to(device),
+        'num_nodes': N, 'num_features': data.x.shape[1], 'num_classes': C,
+        'traces': {k: v.to(device) for k, v in traces.items()},
+    }
+
+
+def train_one(cls, common, extra, seed):
+    torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+    t = cls(**common, **extra)
+    if hasattr(t, 'align_mode'):
+        t.align_mode = 'chain_norm'
+    bv, bt = 0, 0
+    for ep in range(EPOCHS):
+        t.train_step()
+        if ep % 5 == 0:
+            v, te = t.evaluate('val_mask'), t.evaluate('test_mask')
+            if v > bv: bv, bt = v, te
+    del t; torch.cuda.empty_cache()
+    return bt
+
+
+def time_method(cls, common, extra, n_warmup=10, n_steps=200):
+    torch.manual_seed(0)
+    t = cls(**common, **extra)
+    if hasattr(t, 'align_mode'):
+        t.align_mode = 'chain_norm'
+    for _ in range(n_warmup):
+        t.train_step()
+    torch.cuda.synchronize()
+    times = []
+    for _ in range(n_steps):
+        torch.cuda.synchronize(); t0 = time.perf_counter()
+        t.train_step()
+        torch.cuda.synchronize(); times.append(time.perf_counter() - t0)
+    del t; torch.cuda.empty_cache()
+    return float(np.median(times) * 1000)
+
+
+def main():
+    os.makedirs(OUT_DIR, exist_ok=True)
+    results = {}
+
+    grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
+                       lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')
+
+    # ======== Part 1: DBLP full sweep ========
+    print("=" * 60)
+    print("Part 1: CitationFull-DBLP")
+    print("=" * 60)
+    dblp = load_dblp()
+    print(f"DBLP: N={dblp['num_nodes']}, F={dblp['num_features']}, C={dblp['num_classes']}, "
+          f"train={dblp['train_mask'].sum().item()}", flush=True)
+
+    for bb in ['gcn', 'sage', 'gin', 'appnp']:
+        for L in [5, 6]:
+            for lr in [0.001, 0.005, 0.01]:
+                common = dict(data=dblp, hidden_dim=64, lr=lr, weight_decay=5e-4,
+                              num_layers=L, residual_alpha=0.0, backbone=bb)
+                key = f"DBLP|{bb}|L={L}|lr={lr}"
+                row = {}
+                for mname, cls, extra in [('BP', BPTrainer, {}),
+                                          ('DFA', DFATrainer, dict(diffusion_alpha=0.5, diffusion_iters=10)),
+                                          ('GrAPE', GraphGrAPETrainer, grape_extra)]:
+                    accs = [train_one(cls, common, extra, s) for s in SEEDS]
+                    row[mname] = {'mean': float(np.mean(accs)), 'std': float(np.std(accs))}
+                results[key] = row
+                bp, dfa, gr = row['BP']['mean']*100, row['DFA']['mean']*100, row['GrAPE']['mean']*100
+                print(f" {bb:>6} L={L} lr={lr:.3f} | BP {bp:.1f} DFA {dfa:.1f} GrAPE {gr:.1f} | "
+                      f"Δ(BP) {gr-bp:+.1f} Δ(DFA) {gr-dfa:+.1f}", flush=True)
+
+    # DBLP efficiency
+    print("\nDBLP Efficiency:")
+    for bb in ['gcn', 'sage', 'gin', 'appnp']:
+        for L in [5, 6]:
+            common = dict(data=dblp, hidden_dim=64, lr=0.01, weight_decay=5e-4,
+                          num_layers=L, residual_alpha=0.0, backbone=bb)
+            bp_ms = time_method(BPTrainer, common, {})
+            eff_ms = time_method(GraphGrAPEEfficient, common,
+                                 dict(lr_feedback=0.5, num_probes=64, max_topo_power=3,
+                                      diff_alpha=0.5, align_every=10))
+            key = f"DBLP_eff|{bb}|L={L}"
+            results[key] = {'BP_ms': bp_ms, 'GrAPE_Eff_ms': eff_ms, 'speedup': bp_ms / eff_ms}
+            print(f" {bb:>6} L={L} | BP {bp_ms:.2f}ms GrAPE-Eff {eff_ms:.2f}ms | "
+                  f"speedup {bp_ms/eff_ms:.2f}x", flush=True)
+
+    # ======== Part 2: Depth sweep L=2-4 fill-in data (Planetoid × GCN/SAGE/APPNP) ========
+    print("\n" + "=" * 60)
+    print("Part 2: Depth sweep L=2,3,4 (fill-in runs)")
+    print("=" * 60)
+
+    from src.data import load_dataset
+    for ds_name in ['Cora', 'CiteSeer', 'PubMed']:
+        data = load_dataset(ds_name, device=device)
+        for bb in ['gcn', 'sage', 'appnp']:
+            for L in [2, 3, 4]:
+                common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
+                              num_layers=L, residual_alpha=0.0, backbone=bb)
+                key = f"{ds_name}|{bb}|L={L}|lr=0.01"
+                row = {}
+                for mname, cls, extra in [('BP', BPTrainer, {}),
+                                          ('GrAPE', GraphGrAPETrainer, grape_extra)]:
+                    accs = [train_one(cls, common, extra, s) for s in SEEDS]
+                    row[mname] = {'mean': float(np.mean(accs)), 'std': float(np.std(accs))}
+                results[key] = row
+                bp, gr = row['BP']['mean']*100, row['GrAPE']['mean']*100
+                print(f" {ds_name:>10} {bb:>6} L={L} | BP {bp:.1f} GrAPE {gr:.1f} | Δ {gr-bp:+.1f}", flush=True)
+
+    with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
+        json.dump(results, f, indent=2)
+    print(f"\nSaved to {OUT_DIR}/results.json")
+
+
+if __name__ == '__main__':
+    main()
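The script is run directly (python experiments/run_dblp_depth.py) and writes every row to results/dblp_depth/results.json. Below is a minimal reader sketch, not part of this commit, assuming only the key layout the script produces: accuracy rows keyed "DBLP|{backbone}|L={L}|lr={lr}" or "{dataset}|{backbone}|L={L}|lr=0.01", and timing rows keyed "DBLP_eff|{backbone}|L={L}".

import json

with open('results/dblp_depth/results.json') as f:
    results = json.load(f)

for key, row in sorted(results.items()):
    if key.startswith('DBLP_eff|'):
        # Timing rows hold median ms/step for BP and the efficient GrAPE variant.
        print(f"{key:28s} BP {row['BP_ms']:.2f}ms  "
              f"GrAPE-Eff {row['GrAPE_Eff_ms']:.2f}ms  speedup {row['speedup']:.2f}x")
    else:
        # Accuracy rows map method name -> mean/std over the 5 seeds (fractions in [0, 1]).
        summary = '  '.join(f"{m} {v['mean'] * 100:.1f}±{v['std'] * 100:.1f}"
                            for m, v in row.items())
        print(f"{key:28s} {summary}")

Note that only per-key means and stds are stored; the individual seed accuracies are discarded after aggregation, so any finer-grained statistics would require rerunning the sweep.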
