From bd9333eda60a9029a198acaeacb1eca4312bd1e8 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 4 May 2026 23:05:16 -0500 Subject: =?UTF-8?q?Initial=20release:=20GRAFT=20(KAFT)=20=E2=80=94=20NeurI?= =?UTF-8?q?PS=202026=20submission=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes: - src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities. - experiments/: 19 runners reproducing every figure / table in the paper. - figures/: 4 generators + the 4 PDFs cited in the report. - paper/: NeurIPS .tex and consolidated experiments_master notes. Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%. --- src/data.py | 189 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 src/data.py (limited to 'src/data.py') diff --git a/src/data.py b/src/data.py new file mode 100644 index 0000000..6e80285 --- /dev/null +++ b/src/data.py @@ -0,0 +1,189 @@ +"""Data loading and preprocessing for Graph-GrAPE experiments.""" + +import torch +import torch.nn.functional as F +from torch_geometric.datasets import Planetoid +import torch_geometric.transforms as T + + +def spmm(A, B): + """Sparse matrix @ dense matrix.""" + return torch.sparse.mm(A, B) + + +def build_normalized_adj(edge_index, num_nodes): + """Build  = D̃^{-1/2} à D̃^{-1/2} with self-loops, as sparse tensor.""" + row, col = edge_index + # Add self-loops: à = A + I + self_loops = torch.arange(num_nodes, device=edge_index.device) + row = torch.cat([row, self_loops]) + col = torch.cat([col, self_loops]) + + # Degree + deg = torch.zeros(num_nodes, device=edge_index.device) + deg.scatter_add_(0, row, torch.ones_like(row, dtype=torch.float)) + + # D̃^{-1/2} + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.0 + + # Edge weights: d_i^{-1/2} * d_j^{-1/2} + values = deg_inv_sqrt[row] * deg_inv_sqrt[col] + + A_hat = torch.sparse_coo_tensor( + torch.stack([row, col]), values, (num_nodes, num_nodes) + ).coalesce() + + return A_hat + + +def precompute_traces(A_hat, max_power=4): + """Precompute tr(Â^k) for k=0..max_power. + + For Cora-sized graphs, we use exact computation via sparse powers. + For larger graphs, switch to Hutchinson estimator. + """ + N = A_hat.size(0) + traces = {0: torch.tensor(float(N), device=A_hat.device)} + + # tr(Â) = sum of diagonal entries + indices = A_hat.indices() + values = A_hat.values() + diag_mask = indices[0] == indices[1] + traces[1] = values[diag_mask].sum() + + # tr(Â^2) = ||Â||_F^2 = sum of squared entries + traces[2] = (values ** 2).sum() + + # For higher powers, use Hutchinson estimator: tr(M) ≈ (1/K) Σ z^T M z + if max_power >= 3: + num_probes = 100 + for power in range(3, max_power + 1): + est = 0.0 + for _ in range(num_probes): + z = torch.randn(N, 1, device=A_hat.device) + Az = z + for _ in range(power): + Az = spmm(A_hat, Az) + est += (z * Az).sum().item() + traces[power] = torch.tensor(est / num_probes, device=A_hat.device) + + return traces + + +def subsample_train_mask(data, label_rate, seed=0): + """Create a train mask with `label_rate` fraction of total nodes as labels. + + Ensures at least 1 node per class. + """ + y = data['y'] + N = data['num_nodes'] + C = data['num_classes'] + n_per_class = max(1, int(label_rate * N / C)) + + rng = torch.Generator() + rng.manual_seed(seed) + + mask = torch.zeros(N, dtype=torch.bool, device=y.device) + for c in range(C): + idx_c = (y == c).nonzero(as_tuple=True)[0] + perm = torch.randperm(len(idx_c), generator=rng) + selected = idx_c[perm[:n_per_class]] + mask[selected] = True + + return mask + + +def build_row_normalized_adj(edge_index, num_nodes): + """Build D⁻¹Ã (row-normalized) and its transpose, as sparse tensors.""" + row, col = edge_index + self_loops = torch.arange(num_nodes, device=edge_index.device) + row = torch.cat([row, self_loops]) + col = torch.cat([col, self_loops]) + + deg = torch.zeros(num_nodes, device=edge_index.device) + deg.scatter_add_(0, row, torch.ones_like(row, dtype=torch.float)) + deg_inv = deg.pow(-1) + deg_inv[deg_inv == float('inf')] = 0.0 + + # D⁻¹Ã: normalize by row (target node degree) + vals = deg_inv[row] + A_row = torch.sparse_coo_tensor( + torch.stack([row, col]), vals, (num_nodes, num_nodes) + ).coalesce() + + # Transpose: à D⁻¹ (normalize by column / source node degree) + vals_T = deg_inv[col] + A_row_T = torch.sparse_coo_tensor( + torch.stack([row, col]), vals_T, (num_nodes, num_nodes) + ).coalesce() + + return A_row, A_row_T + + +def load_amazon(name, root='./data', device='cuda:0', train_ratio=0.1, val_ratio=0.1, seed=0): + """Load Amazon Photo or Computers with random split.""" + from torch_geometric.datasets import Amazon + dataset = Amazon(root=root, name=name) + data = dataset[0] + N = data.num_nodes + C = dataset.num_classes + + # Random split: train_ratio per class, val_ratio per class, rest = test + rng = torch.Generator().manual_seed(seed) + train_mask = torch.zeros(N, dtype=torch.bool) + val_mask = torch.zeros(N, dtype=torch.bool) + test_mask = torch.zeros(N, dtype=torch.bool) + for c in range(C): + idx = (data.y == c).nonzero(as_tuple=True)[0] + perm = torch.randperm(len(idx), generator=rng) + n_train = max(1, int(train_ratio * len(idx))) + n_val = max(1, int(val_ratio * len(idx))) + train_mask[idx[perm[:n_train]]] = True + val_mask[idx[perm[n_train:n_train + n_val]]] = True + test_mask[idx[perm[n_train + n_val:]]] = True + + A_hat = build_normalized_adj(data.edge_index, N) + A_row, A_row_T = build_row_normalized_adj(data.edge_index, N) + traces = precompute_traces(A_hat, max_power=4) + + return { + 'X': data.x.to(device), + 'y': data.y.to(device), + 'A_hat': A_hat.to(device), + 'A_row': A_row.to(device), + 'A_row_T': A_row_T.to(device), + 'train_mask': train_mask.to(device), + 'val_mask': val_mask.to(device), + 'test_mask': test_mask.to(device), + 'num_nodes': N, + 'num_features': data.x.shape[1], + 'num_classes': C, + 'traces': {k: v.to(device) for k, v in traces.items()}, + } + + +def load_dataset(name, root='./data', device='cuda:3'): + """Load Planetoid dataset and precompute graph structures.""" + dataset = Planetoid(root=root, name=name, transform=T.NormalizeFeatures()) + data = dataset[0] + + A_hat = build_normalized_adj(data.edge_index, data.num_nodes) + A_row, A_row_T = build_row_normalized_adj(data.edge_index, data.num_nodes) + traces = precompute_traces(A_hat, max_power=4) + + result = { + 'X': data.x.to(device), + 'y': data.y.to(device), + 'A_hat': A_hat.to(device), + 'A_row': A_row.to(device), + 'A_row_T': A_row_T.to(device), + 'train_mask': data.train_mask.to(device), + 'val_mask': data.val_mask.to(device), + 'test_mask': data.test_mask.to(device), + 'num_nodes': data.num_nodes, + 'num_features': dataset.num_features, + 'num_classes': dataset.num_classes, + 'traces': {k: v.to(device) for k, v in traces.items()}, + } + return result -- cgit v1.2.3