diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
| commit | bd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch) | |
| tree | 7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /src/data.py | |
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines
+ multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.
Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'src/data.py')
| -rw-r--r-- | src/data.py | 189 |
1 file changed, 189 insertions, 0 deletions
"""Data loading and preprocessing for Graph-GrAPE experiments.

Utilities for building normalized sparse adjacency matrices, estimating
traces of adjacency powers, and loading Planetoid / Amazon benchmarks.

``torch_geometric`` is imported lazily inside the dataset loaders so the
pure-torch helpers (adjacency builders, trace estimator, mask subsampling)
remain usable without PyG installed.
"""

import torch


def spmm(A, B):
    """Sparse matrix @ dense matrix (thin wrapper over torch.sparse.mm)."""
    return torch.sparse.mm(A, B)


def _self_loop_edges_and_degree(edge_index, num_nodes):
    """Append self-loops to ``edge_index`` and compute node degrees.

    Shared preamble of :func:`build_normalized_adj` and
    :func:`build_row_normalized_adj`.

    Args:
        edge_index: (2, E) LongTensor of edges. Assumed to contain no
            self-loops already (they are added unconditionally here).
        num_nodes: number of nodes N.

    Returns:
        (row, col, deg): extended row/col index vectors of length E + N and
        a float degree vector of the self-looped graph. With self-loops
        every degree is >= 1, so downstream ``pow(-1)`` / ``pow(-0.5)``
        cannot produce inf for nodes present in the graph.
    """
    row, col = edge_index
    self_loops = torch.arange(num_nodes, device=edge_index.device)
    row = torch.cat([row, self_loops])
    col = torch.cat([col, self_loops])

    deg = torch.zeros(num_nodes, device=edge_index.device)
    deg.scatter_add_(0, row, torch.ones_like(row, dtype=torch.float))
    return row, col, deg


def build_normalized_adj(edge_index, num_nodes):
    """Build Â = D̃^{-1/2} Ã D̃^{-1/2} with self-loops, as a sparse tensor.

    Standard GCN symmetric normalization.

    Args:
        edge_index: (2, E) LongTensor; assumed undirected (both directions
            present) so that Â comes out symmetric — TODO confirm for any
            new dataset added.
        num_nodes: number of nodes N.

    Returns:
        Coalesced (N, N) sparse COO tensor.
    """
    row, col, deg = _self_loop_edges_and_degree(edge_index, num_nodes)

    # D̃^{-1/2}; the inf-guard is defensive — self-loops make deg >= 1.
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.0

    # Edge weight for (i, j): d_i^{-1/2} * d_j^{-1/2}
    values = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    A_hat = torch.sparse_coo_tensor(
        torch.stack([row, col]), values, (num_nodes, num_nodes)
    ).coalesce()

    return A_hat


def precompute_traces(A_hat, max_power=4, num_probes=100, generator=None):
    """Precompute tr(Â^k) for k = 0..max_power.

    k = 0, 1, 2 are exact: N, the diagonal sum, and the squared Frobenius
    norm (equal to tr(Â²) when Â is symmetric, which holds for the
    undirected graphs used here). Powers k >= 3 use the Hutchinson
    estimator tr(M) ≈ (1/K) Σ zᵀ M z with Gaussian probes.

    Args:
        A_hat: coalesced (N, N) sparse tensor.
        max_power: highest power whose trace is computed.
        num_probes: probe count K for the Hutchinson estimator
            (default 100, matching the original behavior).
        generator: optional ``torch.Generator`` for reproducible probes
            (must live on ``A_hat``'s device); ``None`` keeps the original
            unseeded (nondeterministic) behavior.

    Returns:
        dict mapping k -> 0-dim tensor on ``A_hat``'s device.
    """
    N = A_hat.size(0)
    device = A_hat.device
    traces = {0: torch.tensor(float(N), device=device)}

    indices = A_hat.indices()
    values = A_hat.values()

    # tr(Â): sum of diagonal entries of the coalesced tensor.
    diag_mask = indices[0] == indices[1]
    traces[1] = values[diag_mask].sum()

    # tr(Â²) = ||Â||_F² (sum of squared entries) for symmetric Â.
    traces[2] = (values ** 2).sum()

    # Hutchinson estimates for k >= 3 (loop is empty when max_power < 3).
    for power in range(3, max_power + 1):
        est = 0.0
        for _ in range(num_probes):
            z = torch.randn(N, 1, device=device, generator=generator)
            Az = z
            for _ in range(power):
                Az = spmm(A_hat, Az)
            est += (z * Az).sum().item()
        traces[power] = torch.tensor(est / num_probes, device=device)

    return traces


def subsample_train_mask(data, label_rate, seed=0):
    """Create a train mask with ~``label_rate`` of all nodes labeled.

    The selection is class-stratified: the same number of nodes is drawn
    from every class, with at least 1 node per class.

    Args:
        data: dict with keys 'y', 'num_nodes', 'num_classes' (as returned
            by :func:`load_dataset` / :func:`load_amazon`).
        label_rate: target fraction of the N total nodes to label.
        seed: seed for the CPU permutation generator (selection is
            deterministic given the seed).

    Returns:
        Boolean mask of shape (N,) on the same device as ``data['y']``.
    """
    y = data['y']
    N = data['num_nodes']
    C = data['num_classes']
    n_per_class = max(1, int(label_rate * N / C))

    rng = torch.Generator()
    rng.manual_seed(seed)

    mask = torch.zeros(N, dtype=torch.bool, device=y.device)
    for c in range(C):
        idx_c = (y == c).nonzero(as_tuple=True)[0]
        # randperm runs on CPU (CPU generator); CPU index tensors are
        # accepted when fancy-indexing a CUDA tensor.
        perm = torch.randperm(len(idx_c), generator=rng)
        selected = idx_c[perm[:n_per_class]]
        mask[selected] = True

    return mask


def build_row_normalized_adj(edge_index, num_nodes):
    """Build D̃⁻¹Ã (row-normalized) and Ã D̃⁻¹, as sparse tensors.

    For a symmetric Ã (undirected ``edge_index``) the second tensor equals
    the transpose of the first: (D̃⁻¹Ã)ᵀ = Ã D̃⁻¹.

    Args:
        edge_index: (2, E) LongTensor of edges (self-loops added here).
        num_nodes: number of nodes N.

    Returns:
        (A_row, A_row_T): two coalesced (N, N) sparse COO tensors.
    """
    row, col, deg = _self_loop_edges_and_degree(edge_index, num_nodes)

    deg_inv = deg.pow(-1)
    deg_inv[deg_inv == float('inf')] = 0.0

    # D̃⁻¹Ã: each edge weighted by the target (row) node's inverse degree.
    vals = deg_inv[row]
    A_row = torch.sparse_coo_tensor(
        torch.stack([row, col]), vals, (num_nodes, num_nodes)
    ).coalesce()

    # Ã D̃⁻¹: same index pattern, weighted by the source (col) node's
    # inverse degree — the transpose of A_row when Ã is symmetric.
    vals_T = deg_inv[col]
    A_row_T = torch.sparse_coo_tensor(
        torch.stack([row, col]), vals_T, (num_nodes, num_nodes)
    ).coalesce()

    return A_row, A_row_T


def _pack_graph(data_x, data_y, edge_index, num_nodes, num_classes,
                train_mask, val_mask, test_mask, device):
    """Build all derived graph structures and move everything to ``device``."""
    A_hat = build_normalized_adj(edge_index, num_nodes)
    A_row, A_row_T = build_row_normalized_adj(edge_index, num_nodes)
    traces = precompute_traces(A_hat, max_power=4)

    return {
        'X': data_x.to(device),
        'y': data_y.to(device),
        'A_hat': A_hat.to(device),
        'A_row': A_row.to(device),
        'A_row_T': A_row_T.to(device),
        'train_mask': train_mask.to(device),
        'val_mask': val_mask.to(device),
        'test_mask': test_mask.to(device),
        'num_nodes': num_nodes,
        'num_features': data_x.shape[1],
        'num_classes': num_classes,
        'traces': {k: v.to(device) for k, v in traces.items()},
    }


def load_amazon(name, root='./data', device='cuda:0', train_ratio=0.1,
                val_ratio=0.1, seed=0):
    """Load Amazon Photo or Computers with a random class-stratified split.

    Per class: ``train_ratio`` of nodes for train, ``val_ratio`` for val
    (each at least 1 node), remainder for test.

    Returns:
        dict with features, labels, sparse adjacencies, masks, trace table
        (see :func:`_pack_graph`), all on ``device``.
    """
    # Lazy import: keeps this module importable without torch_geometric.
    from torch_geometric.datasets import Amazon

    dataset = Amazon(root=root, name=name)
    data = dataset[0]
    N = data.num_nodes
    C = dataset.num_classes

    rng = torch.Generator().manual_seed(seed)
    train_mask = torch.zeros(N, dtype=torch.bool)
    val_mask = torch.zeros(N, dtype=torch.bool)
    test_mask = torch.zeros(N, dtype=torch.bool)
    for c in range(C):
        idx = (data.y == c).nonzero(as_tuple=True)[0]
        perm = torch.randperm(len(idx), generator=rng)
        n_train = max(1, int(train_ratio * len(idx)))
        n_val = max(1, int(val_ratio * len(idx)))
        train_mask[idx[perm[:n_train]]] = True
        val_mask[idx[perm[n_train:n_train + n_val]]] = True
        test_mask[idx[perm[n_train + n_val:]]] = True

    return _pack_graph(data.x, data.y, data.edge_index, N, C,
                       train_mask, val_mask, test_mask, device)


def load_dataset(name, root='./data', device='cuda:3'):
    """Load a Planetoid dataset (Cora/CiteSeer/PubMed) with standard splits.

    Features are row-normalized via PyG's NormalizeFeatures transform.

    NOTE(review): the default device 'cuda:3' differs from load_amazon's
    'cuda:0' — looks like a dev-machine leftover; kept for backward
    compatibility, pass ``device`` explicitly.

    Returns:
        dict with features, labels, sparse adjacencies, Planetoid masks,
        and the trace table (see :func:`_pack_graph`), all on ``device``.
    """
    # Lazy import: keeps this module importable without torch_geometric.
    from torch_geometric.datasets import Planetoid
    import torch_geometric.transforms as T

    dataset = Planetoid(root=root, name=name, transform=T.NormalizeFeatures())
    data = dataset[0]

    return _pack_graph(data.x, data.y, data.edge_index, data.num_nodes,
                       dataset.num_classes, data.train_mask, data.val_mask,
                       data.test_mask, device)
