From bd9333eda60a9029a198acaeacb1eca4312bd1e8 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 4 May 2026 23:05:16 -0500
Subject: =?UTF-8?q?Initial=20release:=20GRAFT=20(KAFT)=20=E2=80=94=20NeurI?=
 =?UTF-8?q?PS=202026=20submission=20code?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Topology-factorized Jacobian-aligned feedback for deep GNNs.

Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines
  + multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.

Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
---
 src/__init__.py |   0
 src/data.py     | 189 +++++++++++++++
 src/trainers.py | 697 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 886 insertions(+)
 create mode 100644 src/__init__.py
 create mode 100644 src/data.py
 create mode 100644 src/trainers.py

diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/data.py b/src/data.py
new file mode 100644
index 0000000..6e80285
--- /dev/null
+++ b/src/data.py
@@ -0,0 +1,189 @@
+"""Data loading and preprocessing for Graph-GrAPE experiments."""
+
+import torch
+import torch.nn.functional as F
+from torch_geometric.datasets import Planetoid
+import torch_geometric.transforms as T
+
+
+def spmm(A, B):
+    """Sparse matrix @ dense matrix."""
+    return torch.sparse.mm(A, B)
+
+
+def build_normalized_adj(edge_index, num_nodes):
+    """Build Â = D̃^{-1/2} Ã D̃^{-1/2} with self-loops, as a sparse tensor."""
+    row, col = edge_index
+    # Add self-loops: Ã = A + I
+    self_loops = torch.arange(num_nodes, device=edge_index.device)
+    row = torch.cat([row, self_loops])
+    col = torch.cat([col, self_loops])
+
+    # Degree
+    deg = torch.zeros(num_nodes, device=edge_index.device)
+    deg.scatter_add_(0, row, torch.ones_like(row, dtype=torch.float))
+
+    # D̃^{-1/2}
+    deg_inv_sqrt = deg.pow(-0.5)
+    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.0
+
+    # Edge weights: d_i^{-1/2} * d_j^{-1/2}
+    values = deg_inv_sqrt[row] * deg_inv_sqrt[col]
+
+    A_hat = torch.sparse_coo_tensor(
+        torch.stack([row, col]), values, (num_nodes, num_nodes)
+    ).coalesce()
+
+    return A_hat
+
+
+def precompute_traces(A_hat, max_power=4):
+    """Precompute tr(Â^k) for k=0..max_power.
+
+    k = 0, 1, 2 are computed exactly (node count, diagonal sum, squared
+    Frobenius norm); higher powers use a Hutchinson estimator, which also
+    scales to larger graphs.
+    """
+    N = A_hat.size(0)
+    traces = {0: torch.tensor(float(N), device=A_hat.device)}
+
+    # tr(Â) = sum of diagonal entries
+    indices = A_hat.indices()
+    values = A_hat.values()
+    diag_mask = indices[0] == indices[1]
+    traces[1] = values[diag_mask].sum()
+
+    # tr(Â^2) = ||Â||_F^2 = sum of squared entries
+    traces[2] = (values ** 2).sum()
+
+    # For higher powers, use Hutchinson estimator: tr(M) ≈ (1/K) Σ z^T M z
+    if max_power >= 3:
+        num_probes = 100
+        for power in range(3, max_power + 1):
+            est = 0.0
+            for _ in range(num_probes):
+                z = torch.randn(N, 1, device=A_hat.device)
+                Az = z
+                for _ in range(power):
+                    Az = spmm(A_hat, Az)
+                est += (z * Az).sum().item()
+            traces[power] = torch.tensor(est / num_probes, device=A_hat.device)
+
+    return traces
+
+
+def subsample_train_mask(data, label_rate, seed=0):
+    """Create a train mask with `label_rate` fraction of total nodes as labels.
+
+    Ensures at least 1 node per class.
+ """ + y = data['y'] + N = data['num_nodes'] + C = data['num_classes'] + n_per_class = max(1, int(label_rate * N / C)) + + rng = torch.Generator() + rng.manual_seed(seed) + + mask = torch.zeros(N, dtype=torch.bool, device=y.device) + for c in range(C): + idx_c = (y == c).nonzero(as_tuple=True)[0] + perm = torch.randperm(len(idx_c), generator=rng) + selected = idx_c[perm[:n_per_class]] + mask[selected] = True + + return mask + + +def build_row_normalized_adj(edge_index, num_nodes): + """Build D⁻¹Ã (row-normalized) and its transpose, as sparse tensors.""" + row, col = edge_index + self_loops = torch.arange(num_nodes, device=edge_index.device) + row = torch.cat([row, self_loops]) + col = torch.cat([col, self_loops]) + + deg = torch.zeros(num_nodes, device=edge_index.device) + deg.scatter_add_(0, row, torch.ones_like(row, dtype=torch.float)) + deg_inv = deg.pow(-1) + deg_inv[deg_inv == float('inf')] = 0.0 + + # D⁻¹Ã: normalize by row (target node degree) + vals = deg_inv[row] + A_row = torch.sparse_coo_tensor( + torch.stack([row, col]), vals, (num_nodes, num_nodes) + ).coalesce() + + # Transpose: à D⁻¹ (normalize by column / source node degree) + vals_T = deg_inv[col] + A_row_T = torch.sparse_coo_tensor( + torch.stack([row, col]), vals_T, (num_nodes, num_nodes) + ).coalesce() + + return A_row, A_row_T + + +def load_amazon(name, root='./data', device='cuda:0', train_ratio=0.1, val_ratio=0.1, seed=0): + """Load Amazon Photo or Computers with random split.""" + from torch_geometric.datasets import Amazon + dataset = Amazon(root=root, name=name) + data = dataset[0] + N = data.num_nodes + C = dataset.num_classes + + # Random split: train_ratio per class, val_ratio per class, rest = test + rng = torch.Generator().manual_seed(seed) + train_mask = torch.zeros(N, dtype=torch.bool) + val_mask = torch.zeros(N, dtype=torch.bool) + test_mask = torch.zeros(N, dtype=torch.bool) + for c in range(C): + idx = (data.y == c).nonzero(as_tuple=True)[0] + perm = torch.randperm(len(idx), generator=rng) + n_train = max(1, int(train_ratio * len(idx))) + n_val = max(1, int(val_ratio * len(idx))) + train_mask[idx[perm[:n_train]]] = True + val_mask[idx[perm[n_train:n_train + n_val]]] = True + test_mask[idx[perm[n_train + n_val:]]] = True + + A_hat = build_normalized_adj(data.edge_index, N) + A_row, A_row_T = build_row_normalized_adj(data.edge_index, N) + traces = precompute_traces(A_hat, max_power=4) + + return { + 'X': data.x.to(device), + 'y': data.y.to(device), + 'A_hat': A_hat.to(device), + 'A_row': A_row.to(device), + 'A_row_T': A_row_T.to(device), + 'train_mask': train_mask.to(device), + 'val_mask': val_mask.to(device), + 'test_mask': test_mask.to(device), + 'num_nodes': N, + 'num_features': data.x.shape[1], + 'num_classes': C, + 'traces': {k: v.to(device) for k, v in traces.items()}, + } + + +def load_dataset(name, root='./data', device='cuda:3'): + """Load Planetoid dataset and precompute graph structures.""" + dataset = Planetoid(root=root, name=name, transform=T.NormalizeFeatures()) + data = dataset[0] + + A_hat = build_normalized_adj(data.edge_index, data.num_nodes) + A_row, A_row_T = build_row_normalized_adj(data.edge_index, data.num_nodes) + traces = precompute_traces(A_hat, max_power=4) + + result = { + 'X': data.x.to(device), + 'y': data.y.to(device), + 'A_hat': A_hat.to(device), + 'A_row': A_row.to(device), + 'A_row_T': A_row_T.to(device), + 'train_mask': data.train_mask.to(device), + 'val_mask': data.val_mask.to(device), + 'test_mask': data.test_mask.to(device), + 'num_nodes': data.num_nodes, + 
'num_features': dataset.num_features, + 'num_classes': dataset.num_classes, + 'traces': {k: v.to(device) for k, v in traces.items()}, + } + return result diff --git a/src/trainers.py b/src/trainers.py new file mode 100644 index 0000000..651dffc --- /dev/null +++ b/src/trainers.py @@ -0,0 +1,697 @@ +""" +Training methods for Graph-GrAPE experiments. +Generalized to L-layer residual GCN. + +Methods compared: + BP — Standard backprop GCN + DFA — Fixed random R, P=I + DFA-GNN — Fixed random R, P=Â^{L-l} + VanillaGrAPE — Aligned R (per layer), P=I + GraphGrAPE — Aligned R (per layer) + topology P=Â^{L-l} +""" + +import torch +import torch.nn.functional as F +from src.data import spmm + + +# --------------------------------------------------------------------------- +# Error diffusion (DFA-GNN style label spreading) +# --------------------------------------------------------------------------- + +def label_spreading(E, A_hat, alpha=0.5, num_iters=10): + """Diffuse error from labeled to unlabeled nodes.""" + Z = E.clone() + for _ in range(num_iters): + Z = (1 - alpha) * E + alpha * spmm(A_hat, Z) + labeled_mask = E.abs().sum(dim=1) > 0 + if labeled_mask.any(): + avg_norm = E[labeled_mask].norm(dim=1).mean() + unlabeled = ~labeled_mask + norms = Z[unlabeled].norm(dim=1, keepdim=True).clamp(min=1e-8) + Z[unlabeled] = Z[unlabeled] * (avg_norm / norms) + return Z + + +# --------------------------------------------------------------------------- +# BP Trainer +# --------------------------------------------------------------------------- + +class BPTrainer: + """L-layer GNN with backpropagation. Supports GCN/SAGE/GIN + BN/Dropout.""" + + def __init__(self, data, hidden_dim, lr, weight_decay, + num_layers=2, residual_alpha=0.0, backbone='gcn', + use_batchnorm=False, dropout=0.0, **_kw): + dev = data['X'].device + d_in, d_out = data['num_features'], data['num_classes'] + self.data = data + self.num_layers = num_layers + self.residual_alpha = residual_alpha + self.backbone = backbone + self.dropout = dropout + self._training = True + + dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out] + self.weights = [] + for i in range(num_layers): + w = torch.nn.Parameter(torch.empty(dims[i], dims[i + 1], device=dev)) + torch.nn.init.xavier_uniform_(w) + self.weights.append(w) + + # GIN: learnable ε per layer + if backbone == 'gin': + self.gin_eps = [torch.nn.Parameter(torch.zeros(1, device=dev)) + for _ in range(num_layers)] + else: + self.gin_eps = None + + # BatchNorm (using nn.BatchNorm1d for autograd compatibility) + self.use_batchnorm = use_batchnorm + self.bns = [] + if use_batchnorm: + for _ in range(num_layers - 1): + self.bns.append(torch.nn.BatchNorm1d(hidden_dim).to(dev)) + + # Optimizer — include all learnable params + all_params = list(self.weights) + if self.gin_eps: + all_params += self.gin_eps + for bn in self.bns: + all_params += list(bn.parameters()) + self.optimizer = torch.optim.Adam(all_params, lr=lr, weight_decay=weight_decay) + + def _graph_conv(self, H, W, l): + HW = H @ W + if self.backbone in ('gcn', 'appnp'): + return spmm(self.data['A_hat'], HW) + elif self.backbone == 'sage': + return spmm(self.data['A_row'], HW) + elif self.backbone == 'gin': + return (1 + self.gin_eps[l]) * HW + spmm(self.data['A_hat'], HW) + raise ValueError(self.backbone) + + def _appnp_propagate(self, Z, alpha=0.1, K=10): + """APPNP-style propagation: H = α·Z + (1-α)·Â·H, iterated K times.""" + H = Z + A = self.data['A_hat'] + for _ in range(K): + H = alpha * Z + (1 - alpha) * spmm(A, H) + return H + + def 
forward(self): + X = self.data['X'] + H = X + H0 = None + + if self.backbone == 'appnp': + # APPNP: MLP first, then propagate + for l in range(self.num_layers): + Z = H @ self.weights[l] # pure linear (no graph conv) + if l < self.num_layers - 1: + if self.use_batchnorm: + Z = self.bns[l](Z) + H = F.relu(Z) + if self.dropout > 0 and self._training: + H = F.dropout(H, p=self.dropout, training=True) + else: + # Propagate only at the end + Z = self._appnp_propagate(Z) + return Z, {} + return Z, {} + + # Standard per-layer graph conv (GCN/SAGE/GIN) + for l in range(self.num_layers): + if l > 0 and l < self.num_layers - 1 and self.residual_alpha > 0 and H0 is not None: + H = (1 - self.residual_alpha) * H + self.residual_alpha * H0 + Z = self._graph_conv(H, self.weights[l], l) + if l < self.num_layers - 1: + if self.use_batchnorm: + Z = self.bns[l](Z) + H = F.relu(Z) + if self.dropout > 0 and self._training: + H = F.dropout(H, p=self.dropout, training=True) + if l == 0: + H0 = H + else: + return Z, {} + return Z, {} + + def train_step(self): + self.optimizer.zero_grad() + Z_out, _ = self.forward() + mask = self.data['train_mask'] + loss = F.cross_entropy(Z_out[mask], self.data['y'][mask]) + loss.backward() + self.optimizer.step() + with torch.no_grad(): + acc = (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean() + return loss.item(), acc.item(), {} + + @torch.no_grad() + def evaluate(self, mask_name='test_mask'): + self._training = False + for bn in self.bns: + bn.eval() + Z_out, _ = self.forward() + self._training = True + for bn in self.bns: + bn.train() + mask = self.data[mask_name] + return (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item() + + def train(self, epochs, verbose=True): + hist = {k: [] for k in ['train_loss', 'train_acc', 'val_acc', 'test_acc']} + for ep in range(epochs): + loss, tacc, _ = self.train_step() + vacc = self.evaluate('val_mask') + teacc = self.evaluate('test_mask') + for k, v in zip(hist, [loss, tacc, vacc, teacc]): + hist[k].append(v) + if verbose and ep % 50 == 0: + print(f" [BP] ep {ep:3d} | loss {loss:.4f} | " + f"train {tacc:.4f} | val {vacc:.4f} | test {teacc:.4f}") + return hist + + +# --------------------------------------------------------------------------- +# Base class for non-BP methods (L-layer) +# --------------------------------------------------------------------------- + +class _FeedbackTrainerBase: + """Shared logic for DFA / GrAPE variants, generalized to L layers.""" + + def __init__(self, data, hidden_dim, lr, weight_decay, + diffusion_alpha, diffusion_iters, + num_layers=2, residual_alpha=0.0, backbone='gcn', + use_batchnorm=False, dropout=0.0): + dev = data['X'].device + self.device = dev + d_in = data['num_features'] + d_out = data['num_classes'] + self.data = data + self.d_in = d_in + self.d_out = d_out + self.hidden_dim = hidden_dim + self.lr = lr + self.wd = weight_decay + self.diff_alpha = diffusion_alpha + self.diff_iters = diffusion_iters + self.num_layers = num_layers + self.residual_alpha = residual_alpha + self.backbone = backbone + self.dropout = dropout + self._training = True + + dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out] + self.weights = [] + for i in range(num_layers): + w = torch.empty(dims[i], dims[i + 1], device=dev) + torch.nn.init.xavier_uniform_(w) + self.weights.append(w) + + # GIN: learnable ε per layer + if backbone == 'gin': + self.gin_eps = [torch.zeros(1, device=dev) for _ in range(num_layers)] + else: + self.gin_eps = None + + # BatchNorm per hidden layer (running stats 
tracked manually) + self.use_batchnorm = use_batchnorm + if use_batchnorm: + self.bn_weight = [torch.ones(hidden_dim, device=dev) for _ in range(num_layers - 1)] + self.bn_bias = [torch.zeros(hidden_dim, device=dev) for _ in range(num_layers - 1)] + self.bn_running_mean = [torch.zeros(hidden_dim, device=dev) for _ in range(num_layers - 1)] + self.bn_running_var = [torch.ones(hidden_dim, device=dev) for _ in range(num_layers - 1)] + self.bn_momentum = 0.1 + + # Adam state (per weight) + self._use_adam = True + self._adam = [{'m': torch.zeros_like(w), 'v': torch.zeros_like(w)} + for w in self.weights] + self._adam_t = 0 + self._adam_beta1 = 0.9 + self._adam_beta2 = 0.999 + self._adam_eps = 1e-8 + + # SGD momentum state + self._momentum = 0.0 + self._sgd_vel = [torch.zeros_like(w) for w in self.weights] + + # --- graph conv helpers ------------------------------------------------- + + def _graph_conv(self, H, W, l): + """Forward graph convolution (backbone-dependent).""" + HW = H @ W + if self.backbone in ('gcn', 'appnp'): + return spmm(self.data['A_hat'], HW) + elif self.backbone == 'sage': + return spmm(self.data['A_row'], HW) + elif self.backbone == 'gin': + return (1 + self.gin_eps[l]) * HW + spmm(self.data['A_hat'], HW) + raise ValueError(self.backbone) + + def _graph_conv_T(self, delta, l): + """Transpose of graph conv applied to delta (for gradient computation).""" + if self.backbone in ('gcn', 'appnp'): + return spmm(self.data['A_hat'], delta) + elif self.backbone == 'sage': + return spmm(self.data['A_row_T'], delta) + elif self.backbone == 'gin': + return (1 + self.gin_eps[l]) * delta + spmm(self.data['A_hat'], delta) + raise ValueError(self.backbone) + + # --- batchnorm helper -------------------------------------------------- + + def _apply_bn(self, H, l): + """Manual BatchNorm (no autograd needed).""" + if not self.use_batchnorm: + return H + if self._training: + mean = H.mean(dim=0) + var = H.var(dim=0, unbiased=False) + # Update running stats + self.bn_running_mean[l] = (1 - self.bn_momentum) * self.bn_running_mean[l] + self.bn_momentum * mean + self.bn_running_var[l] = (1 - self.bn_momentum) * self.bn_running_var[l] + self.bn_momentum * var + else: + mean = self.bn_running_mean[l] + var = self.bn_running_var[l] + H_norm = (H - mean) / (var + 1e-5).sqrt() + return H_norm * self.bn_weight[l] + self.bn_bias[l] + + # --- APPNP propagation ------------------------------------------------- + + def _appnp_propagate(self, Z, alpha=0.1, K=10): + H = Z + A = self.data['A_hat'] + for _ in range(K): + H = alpha * Z + (1 - alpha) * spmm(A, H) + return H + + # --- forward ----------------------------------------------------------- + + def forward(self): + X = self.data['X'] + Zs = [] + Hs = [] + H = X + H0 = None + + if self.backbone == 'appnp': + # APPNP: MLP layers, then propagate at end + for l in range(self.num_layers): + Z = H @ self.weights[l] # pure linear + Zs.append(Z) + if l < self.num_layers - 1: + Z_bn = self._apply_bn(Z, l) + H = F.relu(Z_bn) + if self.dropout > 0 and self._training: + H = F.dropout(H, p=self.dropout, training=True) + Hs.append(H) + else: + Z = self._appnp_propagate(Z) + Zs[-1] = Z # replace with propagated version + return Z, {'Zs': Zs, 'Hs': Hs, 'H0': H0} + + # Standard per-layer graph conv + for l in range(self.num_layers): + if l > 0 and l < self.num_layers - 1 and self.residual_alpha > 0 and H0 is not None: + H = (1 - self.residual_alpha) * H + self.residual_alpha * H0 + + Z = self._graph_conv(H, self.weights[l], l) + Zs.append(Z) + + if l < self.num_layers 
- 1: + Z_bn = self._apply_bn(Z, l) + H = F.relu(Z_bn) + if self.dropout > 0 and self._training: + H = F.dropout(H, p=self.dropout, training=True) + Hs.append(H) + if l == 0: + H0 = H + + return Z, {'Zs': Zs, 'Hs': Hs, 'H0': H0} + + # --- output error ------------------------------------------------------ + + def _output_error(self, Z_out): + mask = self.data['train_mask'] + y = self.data['y'] + n_labeled = mask.sum().float().clamp(min=1.0) + probs = F.softmax(Z_out.detach(), dim=1) + y_oh = F.one_hot(y, self.d_out).float() + E0 = torch.zeros_like(probs) + E0[mask] = (probs[mask] - y_oh[mask]) / n_labeled + E_bar = label_spreading( + E0, self.data['A_hat'], self.diff_alpha, self.diff_iters + ) + return E0, E_bar + + # --- weight update (Adam / SGD / SGD+momentum) ------------------------- + + def _adam_step(self, idx, grad): + s = self._adam[idx] + b1, b2, eps = self._adam_beta1, self._adam_beta2, self._adam_eps + t = self._adam_t + s['m'] = b1 * s['m'] + (1 - b1) * grad + s['v'] = b2 * s['v'] + (1 - b2) * grad ** 2 + m_hat = s['m'] / (1 - b1 ** t) + v_hat = s['v'] / (1 - b2 ** t) + return self.lr * (m_hat / (v_hat.sqrt() + eps) + self.wd * self.weights[idx]) + + def _update_weights(self, inter, E0, deltas): + """Update all weights. + + Output layer (last): true gradient from E0. + Hidden layers: feedback-based deltas[l]. + """ + X = self.data['X'] + Hs = inter['Hs'] + H0 = inter['H0'] + + grads = [] + for l in range(self.num_layers): + if l == self.num_layers - 1: + H_prev = Hs[-1] if Hs else X + g = H_prev.t() @ self._graph_conv_T(E0, l) + else: + if l == 0: + H_in = X + else: + H_prev = Hs[l - 1] + if self.residual_alpha > 0 and H0 is not None: + H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * H0 + else: + H_in = H_prev + g = H_in.t() @ self._graph_conv_T(deltas[l], l) + grads.append(g) + + if self._use_adam: + self._adam_t += 1 + for i in range(self.num_layers): + self.weights[i] = self.weights[i] - self._adam_step(i, grads[i]) + else: + for i in range(self.num_layers): + if self._momentum > 0: + self._sgd_vel[i] = self._momentum * self._sgd_vel[i] + grads[i] + self.wd * self.weights[i] + self.weights[i] = self.weights[i] - self.lr * self._sgd_vel[i] + else: + self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i]) + + # --- alignment / feedback (override in subclasses) --------------------- + + def _alignment_step(self, inter): + return {} + + def _compute_hidden_feedback(self, l, inter, E_bar): + raise NotImplementedError + + # --- train loop -------------------------------------------------------- + + def train_step(self): + Z_out, inter = self.forward() + E0, E_bar = self._output_error(Z_out) + align_metrics = self._alignment_step(inter) + + deltas = [] + for l in range(self.num_layers - 1): + relu_gate = (inter['Zs'][l].detach() > 0).float() + raw_fb = self._compute_hidden_feedback(l, inter, E_bar) + deltas.append(relu_gate * raw_fb) + + self._update_weights(inter, E0, deltas) + + with torch.no_grad(): + mask = self.data['train_mask'] + loss = F.cross_entropy(Z_out[mask], self.data['y'][mask]).item() + acc = (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item() + return loss, acc, align_metrics + + @torch.no_grad() + def evaluate(self, mask_name='test_mask'): + self._training = False + Z_out, _ = self.forward() + self._training = True + mask = self.data[mask_name] + return (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item() + + def compute_bp_gradient_cosine(self): + """Average cos(feedback grad, BP 
grad) across hidden layers.""" + if self.backbone == 'appnp': + return 0.0 # APPNP forward differs; skip cos_bp for now + + wp = [] + for w in self.weights: + wp.append(w.clone().detach().requires_grad_(True)) + + # Also handle GIN eps for autograd + eps_p = None + if self.backbone == 'gin': + eps_p = [e.clone().detach().requires_grad_(True) for e in self.gin_eps] + + X = self.data['X'] + H = X + H0_a = None + for l in range(self.num_layers): + if l > 0 and l < self.num_layers - 1 and self.residual_alpha > 0 and H0_a is not None: + H = (1 - self.residual_alpha) * H + self.residual_alpha * H0_a + HW = H @ wp[l] + if self.backbone == 'gcn': + Z = spmm(self.data['A_hat'], HW) + elif self.backbone == 'sage': + Z = spmm(self.data['A_row'], HW) + elif self.backbone == 'gin': + Z = (1 + eps_p[l]) * HW + spmm(self.data['A_hat'], HW) + if l < self.num_layers - 1: + H = F.relu(Z) + if l == 0: + H0_a = H + + mask = self.data['train_mask'] + loss = F.cross_entropy(Z[mask], self.data['y'][mask]) + loss.backward() + + _, inter = self.forward() + E0, E_bar = self._output_error(Z) + + cosines = [] + for l in range(self.num_layers - 1): + bp_grad_l = wp[l].grad.detach() + relu_gate = (inter['Zs'][l].detach() > 0).float() + raw_fb = self._compute_hidden_feedback(l, inter, E_bar) + delta_l = relu_gate * raw_fb + + if l == 0: + H_in = X + else: + H_prev = inter['Hs'][l - 1] + if self.residual_alpha > 0 and inter['H0'] is not None: + H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * inter['H0'] + else: + H_in = H_prev + our_grad_l = H_in.t() @ self._graph_conv_T(delta_l, l) + + c = F.cosine_similarity( + bp_grad_l.reshape(1, -1), our_grad_l.reshape(1, -1) + ).item() + cosines.append(c) + + return sum(cosines) / len(cosines) if cosines else 0.0 + + def train(self, epochs, verbose=True): + hist = {k: [] for k in + ['train_loss', 'train_acc', 'val_acc', 'test_acc', 'cos_bp']} + for ep in range(epochs): + loss, tacc, metrics = self.train_step() + vacc = self.evaluate('val_mask') + teacc = self.evaluate('test_mask') + + cos_bp = 0.0 + if ep % 10 == 0: + cos_bp = self.compute_bp_gradient_cosine() + + hist['train_loss'].append(loss) + hist['train_acc'].append(tacc) + hist['val_acc'].append(vacc) + hist['test_acc'].append(teacc) + hist['cos_bp'].append(cos_bp) + for k, v in metrics.items(): + hist.setdefault(k, []).append(v) + + if verbose and ep % 50 == 0: + tag = self.__class__.__name__ + extra = ''.join(f' | {k} {v:.4f}' for k, v in metrics.items()) + print(f" [{tag}] ep {ep:3d} | loss {loss:.4f} | " + f"train {tacc:.4f} | val {vacc:.4f} | test {teacc:.4f} | " + f"cos_bp {cos_bp:.4f}{extra}") + return hist + + +# --------------------------------------------------------------------------- +# DFA Trainer +# --------------------------------------------------------------------------- + +class DFATrainer(_FeedbackTrainerBase): + """DFA: fixed random R, no topology. 
Same R for all layers.""" + + def __init__(self, data, hidden_dim, lr, weight_decay, + diffusion_alpha=0.5, diffusion_iters=10, + num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw): + super().__init__(data, hidden_dim, lr, weight_decay, + diffusion_alpha, diffusion_iters, + num_layers, residual_alpha, backbone, + _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0)) + self.R_fixed = torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01 + + def _compute_hidden_feedback(self, l, inter, E_bar): + return E_bar @ self.R_fixed + + +# --------------------------------------------------------------------------- +# DFA-GNN Trainer +# --------------------------------------------------------------------------- + +class DFAGNNTrainer(_FeedbackTrainerBase): + """DFA-GNN: fixed random R, topology P = Â^{min(L-l, max_power)} per layer.""" + + def __init__(self, data, hidden_dim, lr, weight_decay, + diffusion_alpha=0.5, diffusion_iters=10, + num_layers=2, residual_alpha=0.0, backbone='gcn', + max_topo_power=3, **_kw): + super().__init__(data, hidden_dim, lr, weight_decay, + diffusion_alpha, diffusion_iters, + num_layers, residual_alpha, backbone, + _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0)) + self.max_topo_power = max_topo_power + self.R_fixed = torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01 + + def _compute_hidden_feedback(self, l, inter, E_bar): + A = self.data['A_hat'] + power = min(self.num_layers - l, self.max_topo_power) + out = E_bar + for _ in range(power): + out = spmm(A, out) + return out @ self.R_fixed + + +# --------------------------------------------------------------------------- +# Vanilla GrAPE Trainer +# --------------------------------------------------------------------------- + +class VanillaGrAPETrainer(_FeedbackTrainerBase): + """Aligned R per layer, no topology (P=I).""" + + def __init__(self, data, hidden_dim, lr, weight_decay, + lr_feedback=0.5, num_probes=64, + diffusion_alpha=0.5, diffusion_iters=10, + num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw): + super().__init__(data, hidden_dim, lr, weight_decay, + diffusion_alpha, diffusion_iters, + num_layers, residual_alpha, backbone, + _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0)) + self.lr_fb = lr_feedback + self.num_probes = num_probes + # One R per hidden layer + self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01 + for _ in range(num_layers - 1)] + + def _alignment_step(self, inter): + metrics = {} + for l in range(self.num_layers - 1): + cos = _align_R_layer(self, l) + metrics[f'cos_feat_L{l}'] = cos + metrics['cos_feat'] = sum(metrics.values()) / len(metrics) + return metrics + + def _compute_hidden_feedback(self, l, inter, E_bar): + return E_bar @ self.Rs[l] + + +# --------------------------------------------------------------------------- +# Graph-GrAPE Trainer +# --------------------------------------------------------------------------- + +class GraphGrAPETrainer(_FeedbackTrainerBase): + """Aligned R per layer + topology P = Â^{min(L-l, max_power)}.""" + + def __init__(self, data, hidden_dim, lr, weight_decay, + lr_feedback=0.5, num_probes=64, + topo_mode='fixed_A', max_topo_power=3, + diffusion_alpha=0.5, diffusion_iters=10, + num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw): + super().__init__(data, hidden_dim, lr, weight_decay, + diffusion_alpha, diffusion_iters, + num_layers, residual_alpha, backbone, + _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0)) + self.lr_fb = lr_feedback + self.num_probes = num_probes + 
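+        # Topology factor P = Â^{min(L-l, max_topo_power)} is applied to the diffused
+        # error before the per-layer feedback matrix R_l (see _compute_hidden_feedback).
+        # topo_mode is stored here, but only the fixed-Â behaviour is used in this file.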
+        self.topo_mode = topo_mode
+        self.max_topo_power = max_topo_power
+        self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
+                   for _ in range(num_layers - 1)]
+
+    def _alignment_step(self, inter):
+        metrics = {}
+        for l in range(self.num_layers - 1):
+            cos = _align_R_layer(self, l)
+            metrics[f'cos_feat_L{l}'] = cos
+        metrics['cos_feat'] = sum(metrics.values()) / len(metrics)
+        return metrics
+
+    def _compute_hidden_feedback(self, l, inter, E_bar):
+        A = self.data['A_hat']
+        power = min(self.num_layers - l, self.max_topo_power)
+        topo_E = E_bar
+        for _ in range(power):
+            topo_E = spmm(A, topo_E)
+        return topo_E @ self.Rs[l]
+
+
+# ---------------------------------------------------------------------------
+# Shared multi-probe feature-side alignment (per layer)
+# ---------------------------------------------------------------------------
+
+def _align_R_layer(trainer, l):
+    """Align R_l via multi-probe estimation.
+
+    Two modes controlled by trainer.align_mode:
+      'chain_norm' (default): full transposed weight chain W_{L-1}^T ... W_{l+1}^T,
+          with per-step column normalization to prevent explosion
+      'next_layer': chain of at most the last two layers (local, stable at any depth)
+    """
+    mode = getattr(trainer, 'align_mode', 'chain_norm')
+    B_mat = torch.randn(trainer.hidden_dim, trainer.num_probes, device=trainer.device)
+
+    if mode == 'next_layer':
+        # Align to the last two layers' chain (stable, captures the output mapping).
+        # For any layer l: target = W_{L-1}^T @ W_{L-2}^T (last 2 layers at most),
+        # which keeps the target shape consistent (d_out × hidden).
+        result = B_mat
+        start = max(l + 1, trainer.num_layers - 2)  # at most last 2 layers
+        for k in range(start, trainer.num_layers):
+            result = trainer.weights[k].t() @ result
+    else:
+        # Full chain with per-step normalization to prevent explosion
+        result = B_mat
+        for k in range(l + 1, trainer.num_layers):
+            result = trainer.weights[k].t() @ result
+            # Normalize to prevent chain explosion (preserve direction, bound magnitude)
+            col_norms = result.norm(dim=0, keepdim=True).clamp(min=1e-8)
+            result = result / col_norms * B_mat.norm(dim=0, keepdim=True).mean()
+
+    J_feat = result @ B_mat.t() / trainer.num_probes  # (d_out, hidden_dim)
+
+    R_l = trainer.Rs[l]
+    cos_feat = F.cosine_similarity(
+        R_l.reshape(1, -1), J_feat.reshape(1, -1)
+    ).item()
+
+    R_norm = R_l.norm().clamp(min=1e-8)
+    J_norm = J_feat.norm().clamp(min=1e-8)
+    grad_R = J_feat / (R_norm * J_norm) - cos_feat * R_l / (R_norm ** 2)
+    trainer.Rs[l] = R_l + trainer.lr_fb * grad_R
+
+    # Normalize columns of R_l to unit norm (keeps the feedback scale stable)
+    col_norms = trainer.Rs[l].norm(dim=0, keepdim=True).clamp(min=1e-8)
+    trainer.Rs[l] = trainer.Rs[l] / col_norms
+
+    return cos_feat
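
Usage note (smoke-test sketch): the Cora comparison quoted in the commit message
can be approximated with the APIs added in this patch. This is a minimal sketch,
not the paper's runner: hidden_dim=64, lr=0.01, weight_decay=5e-4 and the device
string are illustrative assumptions; the canonical settings live in the
experiments/ runners, which are not part of this src-only diff.

    # Hypothetical smoke test: 4-layer Cora GCN, BP vs KAFT (GraphGrAPETrainer).
    from src.data import load_dataset
    from src.trainers import BPTrainer, GraphGrAPETrainer

    data = load_dataset('Cora', root='./data', device='cpu')  # or e.g. 'cuda:0'
    common = dict(hidden_dim=64, lr=0.01, weight_decay=5e-4, num_layers=4)

    bp = BPTrainer(data, **common)
    bp.train(epochs=50, verbose=False)
    print('BP   test acc:', bp.evaluate('test_mask'))

    kaft = GraphGrAPETrainer(data, **common)
    kaft.train(epochs=50, verbose=False)
    print('KAFT test acc:', kaft.evaluate('test_mask'))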