""" Training methods for KAFT experiments. Generalized to L-layer residual GCN. Methods compared: BP — Standard backprop GCN DFA — Fixed random R, P=I DFA-GNN — Fixed random R, P=Â^{L-l} VanillaGrAPE — Aligned R (per layer), P=I KAFT — Aligned R (per layer) + topology P=Â^{L-l} """ import torch import torch.nn.functional as F from src.data import spmm # --------------------------------------------------------------------------- # Error diffusion (DFA-GNN style label spreading) # --------------------------------------------------------------------------- def label_spreading(E, A_hat, alpha=0.5, num_iters=10): """Diffuse error from labeled to unlabeled nodes.""" Z = E.clone() for _ in range(num_iters): Z = (1 - alpha) * E + alpha * spmm(A_hat, Z) labeled_mask = E.abs().sum(dim=1) > 0 if labeled_mask.any(): avg_norm = E[labeled_mask].norm(dim=1).mean() unlabeled = ~labeled_mask norms = Z[unlabeled].norm(dim=1, keepdim=True).clamp(min=1e-8) Z[unlabeled] = Z[unlabeled] * (avg_norm / norms) return Z # --------------------------------------------------------------------------- # BP Trainer # --------------------------------------------------------------------------- class BPTrainer: """L-layer GNN with backpropagation. Supports GCN/SAGE/GIN + BN/Dropout.""" def __init__(self, data, hidden_dim, lr, weight_decay, num_layers=2, residual_alpha=0.0, backbone='gcn', use_batchnorm=False, dropout=0.0, **_kw): dev = data['X'].device d_in, d_out = data['num_features'], data['num_classes'] self.data = data self.num_layers = num_layers self.residual_alpha = residual_alpha self.backbone = backbone self.dropout = dropout self._training = True dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out] self.weights = [] for i in range(num_layers): w = torch.nn.Parameter(torch.empty(dims[i], dims[i + 1], device=dev)) torch.nn.init.xavier_uniform_(w) self.weights.append(w) # GIN: learnable ε per layer if backbone == 'gin': self.gin_eps = [torch.nn.Parameter(torch.zeros(1, device=dev)) for _ in range(num_layers)] else: self.gin_eps = None # BatchNorm (using nn.BatchNorm1d for autograd compatibility) self.use_batchnorm = use_batchnorm self.bns = [] if use_batchnorm: for _ in range(num_layers - 1): self.bns.append(torch.nn.BatchNorm1d(hidden_dim).to(dev)) # Optimizer — include all learnable params all_params = list(self.weights) if self.gin_eps: all_params += self.gin_eps for bn in self.bns: all_params += list(bn.parameters()) self.optimizer = torch.optim.Adam(all_params, lr=lr, weight_decay=weight_decay) def _graph_conv(self, H, W, l): HW = H @ W if self.backbone in ('gcn', 'appnp'): return spmm(self.data['A_hat'], HW) elif self.backbone == 'sage': return spmm(self.data['A_row'], HW) elif self.backbone == 'gin': return (1 + self.gin_eps[l]) * HW + spmm(self.data['A_hat'], HW) raise ValueError(self.backbone) def _appnp_propagate(self, Z, alpha=0.1, K=10): """APPNP-style propagation: H = α·Z + (1-α)·Â·H, iterated K times.""" H = Z A = self.data['A_hat'] for _ in range(K): H = alpha * Z + (1 - alpha) * spmm(A, H) return H def forward(self): X = self.data['X'] H = X H0 = None if self.backbone == 'appnp': # APPNP: MLP first, then propagate for l in range(self.num_layers): Z = H @ self.weights[l] # pure linear (no graph conv) if l < self.num_layers - 1: if self.use_batchnorm: Z = self.bns[l](Z) H = F.relu(Z) if self.dropout > 0 and self._training: H = F.dropout(H, p=self.dropout, training=True) else: # Propagate only at the end Z = self._appnp_propagate(Z) return Z, {} return Z, {} # Standard per-layer graph conv 


# ---------------------------------------------------------------------------
# BP Trainer
# ---------------------------------------------------------------------------

class BPTrainer:
    """L-layer GNN trained with backprop. Supports GCN/SAGE/GIN/APPNP
    backbones plus optional BatchNorm and Dropout."""

    def __init__(self, data, hidden_dim, lr, weight_decay, num_layers=2,
                 residual_alpha=0.0, backbone='gcn', use_batchnorm=False,
                 dropout=0.0, **_kw):
        dev = data['X'].device
        d_in, d_out = data['num_features'], data['num_classes']
        self.data = data
        self.num_layers = num_layers
        self.residual_alpha = residual_alpha
        self.backbone = backbone
        self.dropout = dropout
        self._training = True

        dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out]
        self.weights = []
        for i in range(num_layers):
            w = torch.nn.Parameter(torch.empty(dims[i], dims[i + 1], device=dev))
            torch.nn.init.xavier_uniform_(w)
            self.weights.append(w)

        # GIN: learnable ε per layer
        if backbone == 'gin':
            self.gin_eps = [torch.nn.Parameter(torch.zeros(1, device=dev))
                            for _ in range(num_layers)]
        else:
            self.gin_eps = None

        # BatchNorm (using nn.BatchNorm1d for autograd compatibility)
        self.use_batchnorm = use_batchnorm
        self.bns = []
        if use_batchnorm:
            for _ in range(num_layers - 1):
                self.bns.append(torch.nn.BatchNorm1d(hidden_dim).to(dev))

        # Optimizer — include all learnable params
        all_params = list(self.weights)
        if self.gin_eps:
            all_params += self.gin_eps
        for bn in self.bns:
            all_params += list(bn.parameters())
        self.optimizer = torch.optim.Adam(all_params, lr=lr,
                                          weight_decay=weight_decay)

    def _graph_conv(self, H, W, l):
        HW = H @ W
        if self.backbone in ('gcn', 'appnp'):
            return spmm(self.data['A_hat'], HW)
        elif self.backbone == 'sage':
            return spmm(self.data['A_row'], HW)
        elif self.backbone == 'gin':
            return (1 + self.gin_eps[l]) * HW + spmm(self.data['A_hat'], HW)
        raise ValueError(self.backbone)

    def _appnp_propagate(self, Z, alpha=0.1, K=10):
        """APPNP-style propagation: H = α·Z + (1-α)·Â·H, iterated K times."""
        H = Z
        A = self.data['A_hat']
        for _ in range(K):
            H = alpha * Z + (1 - alpha) * spmm(A, H)
        return H

    def forward(self):
        X = self.data['X']
        H = X
        H0 = None

        if self.backbone == 'appnp':
            # APPNP: MLP first, then propagate at the end
            for l in range(self.num_layers):
                Z = H @ self.weights[l]  # pure linear (no graph conv)
                if l < self.num_layers - 1:
                    if self.use_batchnorm:
                        Z = self.bns[l](Z)
                    H = F.relu(Z)
                    if self.dropout > 0 and self._training:
                        H = F.dropout(H, p=self.dropout, training=True)
                else:
                    Z = self._appnp_propagate(Z)  # propagate only at the end
            return Z, {}

        # Standard per-layer graph conv (GCN/SAGE/GIN)
        for l in range(self.num_layers):
            if 0 < l < self.num_layers - 1 and self.residual_alpha > 0 and H0 is not None:
                H = (1 - self.residual_alpha) * H + self.residual_alpha * H0
            Z = self._graph_conv(H, self.weights[l], l)
            if l < self.num_layers - 1:
                if self.use_batchnorm:
                    Z = self.bns[l](Z)
                H = F.relu(Z)
                if self.dropout > 0 and self._training:
                    H = F.dropout(H, p=self.dropout, training=True)
                if l == 0:
                    H0 = H
        return Z, {}

    def train_step(self):
        self.optimizer.zero_grad()
        Z_out, _ = self.forward()
        mask = self.data['train_mask']
        loss = F.cross_entropy(Z_out[mask], self.data['y'][mask])
        loss.backward()
        self.optimizer.step()
        with torch.no_grad():
            acc = (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean()
        return loss.item(), acc.item(), {}

    @torch.no_grad()
    def evaluate(self, mask_name='test_mask'):
        self._training = False
        for bn in self.bns:
            bn.eval()
        Z_out, _ = self.forward()
        self._training = True
        for bn in self.bns:
            bn.train()
        mask = self.data[mask_name]
        return (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item()

    def train(self, epochs, verbose=True):
        hist = {k: [] for k in ['train_loss', 'train_acc', 'val_acc', 'test_acc']}
        for ep in range(epochs):
            loss, tacc, _ = self.train_step()
            vacc = self.evaluate('val_mask')
            teacc = self.evaluate('test_mask')
            for k, v in zip(hist, [loss, tacc, vacc, teacc]):
                hist[k].append(v)
            if verbose and ep % 50 == 0:
                print(f"  [BP] ep {ep:3d} | loss {loss:.4f} | "
                      f"train {tacc:.4f} | val {vacc:.4f} | test {teacc:.4f}")
        return hist
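
# Usage sketch (hypothetical hyperparameters; `data` is the dict produced by
# src.data, with keys 'X', 'y', 'A_hat', 'train_mask'/'val_mask'/'test_mask',
# 'num_features' and 'num_classes'):
#
#     trainer = BPTrainer(data, hidden_dim=64, lr=1e-2, weight_decay=5e-4,
#                         num_layers=2, backbone='gcn')
#     hist = trainer.train(epochs=200, verbose=False)
#     print(f"best test acc: {max(hist['test_acc']):.4f}")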


# ---------------------------------------------------------------------------
# Base class for non-BP methods (L-layer)
# ---------------------------------------------------------------------------

class _FeedbackTrainerBase:
    """Shared logic for DFA / KAFT variants, generalized to L layers."""

    def __init__(self, data, hidden_dim, lr, weight_decay,
                 diffusion_alpha, diffusion_iters, num_layers=2,
                 residual_alpha=0.0, backbone='gcn',
                 use_batchnorm=False, dropout=0.0):
        dev = data['X'].device
        self.device = dev
        d_in = data['num_features']
        d_out = data['num_classes']
        self.data = data
        self.d_in = d_in
        self.d_out = d_out
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.wd = weight_decay
        self.diff_alpha = diffusion_alpha
        self.diff_iters = diffusion_iters
        self.num_layers = num_layers
        self.residual_alpha = residual_alpha
        self.backbone = backbone
        self.dropout = dropout
        self._training = True

        dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out]
        self.weights = []
        for i in range(num_layers):
            w = torch.empty(dims[i], dims[i + 1], device=dev)
            torch.nn.init.xavier_uniform_(w)
            self.weights.append(w)

        # GIN: learnable ε per layer
        if backbone == 'gin':
            self.gin_eps = [torch.zeros(1, device=dev) for _ in range(num_layers)]
        else:
            self.gin_eps = None

        # BatchNorm per hidden layer (running stats tracked manually)
        self.use_batchnorm = use_batchnorm
        if use_batchnorm:
            self.bn_weight = [torch.ones(hidden_dim, device=dev)
                              for _ in range(num_layers - 1)]
            self.bn_bias = [torch.zeros(hidden_dim, device=dev)
                            for _ in range(num_layers - 1)]
            self.bn_running_mean = [torch.zeros(hidden_dim, device=dev)
                                    for _ in range(num_layers - 1)]
            self.bn_running_var = [torch.ones(hidden_dim, device=dev)
                                   for _ in range(num_layers - 1)]
            self.bn_momentum = 0.1

        # Adam state (per weight)
        self._use_adam = True
        self._adam = [{'m': torch.zeros_like(w), 'v': torch.zeros_like(w)}
                      for w in self.weights]
        self._adam_t = 0
        self._adam_beta1 = 0.9
        self._adam_beta2 = 0.999
        self._adam_eps = 1e-8

        # SGD momentum state
        self._momentum = 0.0
        self._sgd_vel = [torch.zeros_like(w) for w in self.weights]

    # --- graph conv helpers -------------------------------------------------
    def _graph_conv(self, H, W, l):
        """Forward graph convolution (backbone-dependent)."""
        HW = H @ W
        if self.backbone in ('gcn', 'appnp'):
            return spmm(self.data['A_hat'], HW)
        elif self.backbone == 'sage':
            return spmm(self.data['A_row'], HW)
        elif self.backbone == 'gin':
            return (1 + self.gin_eps[l]) * HW + spmm(self.data['A_hat'], HW)
        raise ValueError(self.backbone)

    def _graph_conv_T(self, delta, l):
        """Transpose of the graph conv applied to delta (for gradient computation).

        Â is symmetric, so it is its own transpose; SAGE uses the explicit
        transpose of the row-normalized adjacency.
        """
        if self.backbone in ('gcn', 'appnp'):
            return spmm(self.data['A_hat'], delta)
        elif self.backbone == 'sage':
            return spmm(self.data['A_row_T'], delta)
        elif self.backbone == 'gin':
            return (1 + self.gin_eps[l]) * delta + spmm(self.data['A_hat'], delta)
        raise ValueError(self.backbone)

    # --- batchnorm helper ---------------------------------------------------
    def _apply_bn(self, H, l):
        """Manual BatchNorm (no autograd needed)."""
        if not self.use_batchnorm:
            return H
        if self._training:
            mean = H.mean(dim=0)
            var = H.var(dim=0, unbiased=False)
            # Update running stats
            self.bn_running_mean[l] = ((1 - self.bn_momentum) * self.bn_running_mean[l]
                                       + self.bn_momentum * mean)
            self.bn_running_var[l] = ((1 - self.bn_momentum) * self.bn_running_var[l]
                                      + self.bn_momentum * var)
        else:
            mean = self.bn_running_mean[l]
            var = self.bn_running_var[l]
        H_norm = (H - mean) / (var + 1e-5).sqrt()
        return H_norm * self.bn_weight[l] + self.bn_bias[l]

    # --- APPNP propagation --------------------------------------------------
    def _appnp_propagate(self, Z, alpha=0.1, K=10):
        H = Z
        A = self.data['A_hat']
        for _ in range(K):
            H = alpha * Z + (1 - alpha) * spmm(A, H)
        return H

    # --- forward ------------------------------------------------------------
    def forward(self):
        X = self.data['X']
        Zs = []
        Hs = []
        H = X
        H0 = None

        if self.backbone == 'appnp':
            # APPNP: MLP layers, then propagate at the end
            for l in range(self.num_layers):
                Z = H @ self.weights[l]  # pure linear
                Zs.append(Z)
                if l < self.num_layers - 1:
                    Z_bn = self._apply_bn(Z, l)
                    H = F.relu(Z_bn)
                    if self.dropout > 0 and self._training:
                        H = F.dropout(H, p=self.dropout, training=True)
                    Hs.append(H)
                else:
                    Z = self._appnp_propagate(Z)
                    Zs[-1] = Z  # replace with propagated version
            return Z, {'Zs': Zs, 'Hs': Hs, 'H0': H0}

        # Standard per-layer graph conv
        for l in range(self.num_layers):
            if 0 < l < self.num_layers - 1 and self.residual_alpha > 0 and H0 is not None:
                H = (1 - self.residual_alpha) * H + self.residual_alpha * H0
            Z = self._graph_conv(H, self.weights[l], l)
            Zs.append(Z)
            if l < self.num_layers - 1:
                Z_bn = self._apply_bn(Z, l)
                H = F.relu(Z_bn)
                if self.dropout > 0 and self._training:
                    H = F.dropout(H, p=self.dropout, training=True)
                Hs.append(H)
                if l == 0:
                    H0 = H
        return Z, {'Zs': Zs, 'Hs': Hs, 'H0': H0}

    # --- output error -------------------------------------------------------
    def _output_error(self, Z_out):
        mask = self.data['train_mask']
        y = self.data['y']
        n_labeled = mask.sum().float().clamp(min=1.0)
        probs = F.softmax(Z_out.detach(), dim=1)
        y_oh = F.one_hot(y, self.d_out).float()
        E0 = torch.zeros_like(probs)
        E0[mask] = (probs[mask] - y_oh[mask]) / n_labeled
        E_bar = label_spreading(
            E0, self.data['A_hat'], self.diff_alpha, self.diff_iters
        )
        return E0, E_bar

    # --- weight update (Adam / SGD / SGD+momentum) ---------------------------
    def _adam_step(self, idx, grad):
        s = self._adam[idx]
        b1, b2, eps = self._adam_beta1, self._adam_beta2, self._adam_eps
        t = self._adam_t
        s['m'] = b1 * s['m'] + (1 - b1) * grad
        s['v'] = b2 * s['v'] + (1 - b2) * grad ** 2
        m_hat = s['m'] / (1 - b1 ** t)
        v_hat = s['v'] / (1 - b2 ** t)
        # Decoupled (AdamW-style) weight decay
        return self.lr * (m_hat / (v_hat.sqrt() + eps) + self.wd * self.weights[idx])

    def _update_weights(self, inter, E0, deltas):
        """Update all weights.

        Output layer (last): true gradient from E0.
        Hidden layers: feedback-based deltas[l].
        """
        X = self.data['X']
        Hs = inter['Hs']
        H0 = inter['H0']
        grads = []
        for l in range(self.num_layers):
            if l == self.num_layers - 1:
                H_prev = Hs[-1] if Hs else X
                g = H_prev.t() @ self._graph_conv_T(E0, l)
            else:
                if l == 0:
                    H_in = X
                else:
                    H_prev = Hs[l - 1]
                    if self.residual_alpha > 0 and H0 is not None:
                        H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * H0
                    else:
                        H_in = H_prev
                g = H_in.t() @ self._graph_conv_T(deltas[l], l)
            grads.append(g)

        if self._use_adam:
            self._adam_t += 1
            for i in range(self.num_layers):
                self.weights[i] = self.weights[i] - self._adam_step(i, grads[i])
        else:
            for i in range(self.num_layers):
                if self._momentum > 0:
                    self._sgd_vel[i] = (self._momentum * self._sgd_vel[i]
                                        + grads[i] + self.wd * self.weights[i])
                    self.weights[i] = self.weights[i] - self.lr * self._sgd_vel[i]
                else:
                    self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i])

    # --- alignment / feedback (override in subclasses) ------------------------
    def _alignment_step(self, inter):
        return {}

    def _compute_hidden_feedback(self, l, inter, E_bar):
        raise NotImplementedError

    # --- train loop -----------------------------------------------------------
    def train_step(self):
        Z_out, inter = self.forward()
        E0, E_bar = self._output_error(Z_out)
        align_metrics = self._alignment_step(inter)

        deltas = []
        for l in range(self.num_layers - 1):
            # NOTE: gate on pre-BN activations — exact when use_batchnorm=False,
            # an approximation otherwise (ReLU is applied to the normalized Z).
            relu_gate = (inter['Zs'][l].detach() > 0).float()
            raw_fb = self._compute_hidden_feedback(l, inter, E_bar)
            deltas.append(relu_gate * raw_fb)

        self._update_weights(inter, E0, deltas)

        with torch.no_grad():
            mask = self.data['train_mask']
            loss = F.cross_entropy(Z_out[mask], self.data['y'][mask]).item()
            acc = (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item()
        return loss, acc, align_metrics

    @torch.no_grad()
    def evaluate(self, mask_name='test_mask'):
        self._training = False
        Z_out, _ = self.forward()
        self._training = True
        mask = self.data[mask_name]
        return (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item()

    def compute_bp_gradient_cosine(self):
        """Average cos(feedback grad, BP grad) across hidden layers."""
        if self.backbone == 'appnp':
            return 0.0  # APPNP forward differs; skip cos_bp for now

        # Rebuild the forward pass with autograd-enabled copies of the weights.
        # BatchNorm and dropout are omitted here, so the cosine is approximate
        # when those are enabled.
        wp = []
        for w in self.weights:
            wp.append(w.clone().detach().requires_grad_(True))
        # Also handle GIN eps for autograd
        eps_p = None
        if self.backbone == 'gin':
            eps_p = [e.clone().detach().requires_grad_(True) for e in self.gin_eps]

        X = self.data['X']
        H = X
        H0_a = None
        for l in range(self.num_layers):
            if 0 < l < self.num_layers - 1 and self.residual_alpha > 0 and H0_a is not None:
                H = (1 - self.residual_alpha) * H + self.residual_alpha * H0_a
            HW = H @ wp[l]
            if self.backbone == 'gcn':
                Z = spmm(self.data['A_hat'], HW)
            elif self.backbone == 'sage':
                Z = spmm(self.data['A_row'], HW)
            elif self.backbone == 'gin':
                Z = (1 + eps_p[l]) * HW + spmm(self.data['A_hat'], HW)
            if l < self.num_layers - 1:
                H = F.relu(Z)
                if l == 0:
                    H0_a = H

        mask = self.data['train_mask']
        loss = F.cross_entropy(Z[mask], self.data['y'][mask])
        loss.backward()

        _, inter = self.forward()
        E0, E_bar = self._output_error(Z)

        cosines = []
        for l in range(self.num_layers - 1):
            bp_grad_l = wp[l].grad.detach()
            relu_gate = (inter['Zs'][l].detach() > 0).float()
            raw_fb = self._compute_hidden_feedback(l, inter, E_bar)
            delta_l = relu_gate * raw_fb
            if l == 0:
                H_in = X
            else:
                H_prev = inter['Hs'][l - 1]
                if self.residual_alpha > 0 and inter['H0'] is not None:
                    H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * inter['H0']
                else:
                    H_in = H_prev
            our_grad_l = H_in.t() @ self._graph_conv_T(delta_l, l)
            c = F.cosine_similarity(
                bp_grad_l.reshape(1, -1), our_grad_l.reshape(1, -1)
            ).item()
            cosines.append(c)
        return sum(cosines) / len(cosines) if cosines else 0.0

    def train(self, epochs, verbose=True):
        hist = {k: [] for k in ['train_loss', 'train_acc', 'val_acc',
                                'test_acc', 'cos_bp']}
        for ep in range(epochs):
            loss, tacc, metrics = self.train_step()
            vacc = self.evaluate('val_mask')
            teacc = self.evaluate('test_mask')
            cos_bp = 0.0
            if ep % 10 == 0:
                cos_bp = self.compute_bp_gradient_cosine()
            hist['train_loss'].append(loss)
            hist['train_acc'].append(tacc)
            hist['val_acc'].append(vacc)
            hist['test_acc'].append(teacc)
            hist['cos_bp'].append(cos_bp)
            for k, v in metrics.items():
                hist.setdefault(k, []).append(v)
            if verbose and ep % 50 == 0:
                tag = self.__class__.__name__
                extra = ''.join(f' | {k} {v:.4f}' for k, v in metrics.items())
                print(f"  [{tag}] ep {ep:3d} | loss {loss:.4f} | "
                      f"train {tacc:.4f} | val {vacc:.4f} | test {teacc:.4f} | "
                      f"cos_bp {cos_bp:.4f}{extra}")
        return hist
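
# --- illustrative sketch (not part of the experiments) ----------------------
# The hidden-layer credit assignment shared by every subclass below, written
# out on dense toy tensors: delta_l = 1[Z_l > 0] ⊙ ((P @ E_bar) @ R_l), where
# P is the identity (DFA, VanillaGrAPE) or a power of Â (DFA-GNN, KAFT).
# All sizes here are hypothetical.
def _hidden_feedback_toy_demo():
    n, h, c = 5, 4, 3                          # nodes, hidden dim, classes
    torch.manual_seed(0)
    Z_l = torch.randn(n, h)                    # pre-activation at layer l
    E_bar = torch.randn(n, c)                  # diffused output error
    R_l = torch.randn(c, h) * 0.01             # feedback matrix
    P = torch.eye(n)                           # stand-in topology (P = I here)
    relu_gate = (Z_l > 0).float()              # φ'(Z_l) for ReLU
    delta_l = relu_gate * ((P @ E_bar) @ R_l)  # (n, h) hidden error signal
    return delta_l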


# ---------------------------------------------------------------------------
# DFA Trainer
# ---------------------------------------------------------------------------

class DFATrainer(_FeedbackTrainerBase):
    """DFA: fixed random R, no topology. Same R for all layers."""

    def __init__(self, data, hidden_dim, lr, weight_decay,
                 diffusion_alpha=0.5, diffusion_iters=10, num_layers=2,
                 residual_alpha=0.0, backbone='gcn', **_kw):
        super().__init__(data, hidden_dim, lr, weight_decay,
                         diffusion_alpha, diffusion_iters, num_layers,
                         residual_alpha, backbone,
                         _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
        self.R_fixed = torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01

    def _compute_hidden_feedback(self, l, inter, E_bar):
        return E_bar @ self.R_fixed


# ---------------------------------------------------------------------------
# DFA-GNN Trainer
# ---------------------------------------------------------------------------

class DFAGNNTrainer(_FeedbackTrainerBase):
    """DFA-GNN: fixed random R, topology P = Â^{min(L-l, max_power)} per layer."""

    def __init__(self, data, hidden_dim, lr, weight_decay,
                 diffusion_alpha=0.5, diffusion_iters=10, num_layers=2,
                 residual_alpha=0.0, backbone='gcn', max_topo_power=3, **_kw):
        super().__init__(data, hidden_dim, lr, weight_decay,
                         diffusion_alpha, diffusion_iters, num_layers,
                         residual_alpha, backbone,
                         _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
        self.max_topo_power = max_topo_power
        self.R_fixed = torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01

    def _compute_hidden_feedback(self, l, inter, E_bar):
        A = self.data['A_hat']
        power = min(self.num_layers - l, self.max_topo_power)
        out = E_bar
        for _ in range(power):
            out = spmm(A, out)
        return out @ self.R_fixed
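
# --- illustrative sketch (not part of the experiments) ----------------------
# The topology operator used by DFAGNNTrainer (and KAFTTrainer below): the
# diffused error is pushed p = min(L - l, max_topo_power) hops through Â
# before meeting R, so layers farther from the output receive error
# aggregated over a wider neighborhood. Dense stand-ins, hypothetical sizes.
def _topology_power_toy_demo(num_layers=4, l=0, max_topo_power=3):
    n, c = 5, 3
    torch.manual_seed(0)
    A_hat = torch.rand(n, n)
    A_hat = A_hat / A_hat.sum(1, keepdim=True)   # stand-in normalized adjacency
    E_bar = torch.randn(n, c)
    power = min(num_layers - l, max_topo_power)  # earlier layer → more hops
    out = E_bar
    for _ in range(power):
        out = A_hat @ out                        # Â^p @ E_bar
    return out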


# ---------------------------------------------------------------------------
# Vanilla GrAPE Trainer
# ---------------------------------------------------------------------------

class VanillaGrAPETrainer(_FeedbackTrainerBase):
    """Aligned R per layer, no topology (P = I)."""

    def __init__(self, data, hidden_dim, lr, weight_decay,
                 lr_feedback=0.5, num_probes=64,
                 diffusion_alpha=0.5, diffusion_iters=10, num_layers=2,
                 residual_alpha=0.0, backbone='gcn', **_kw):
        super().__init__(data, hidden_dim, lr, weight_decay,
                         diffusion_alpha, diffusion_iters, num_layers,
                         residual_alpha, backbone,
                         _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
        self.lr_fb = lr_feedback
        self.num_probes = num_probes
        # One R per hidden layer
        self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
                   for _ in range(num_layers - 1)]

    def _alignment_step(self, inter):
        metrics = {}
        for l in range(self.num_layers - 1):
            cos = _align_R_layer(self, l)
            metrics[f'cos_feat_L{l}'] = cos
        metrics['cos_feat'] = sum(metrics.values()) / len(metrics)
        return metrics

    def _compute_hidden_feedback(self, l, inter, E_bar):
        return E_bar @ self.Rs[l]
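
# --- illustrative sketch (not part of the experiments) ----------------------
# The alignment update applied by _align_R_layer (defined at the bottom of
# this file): gradient ascent on cos(R, J) with respect to R, followed by
# column normalization. Dense toy shapes, hypothetical sizes.
def _cosine_ascent_toy_demo(lr_fb=0.5):
    torch.manual_seed(0)
    R = torch.randn(3, 4) * 0.01                 # feedback matrix (d_out, hidden)
    J = torch.randn(3, 4)                        # alignment target
    cos = F.cosine_similarity(R.reshape(1, -1), J.reshape(1, -1)).item()
    R_n, J_n = R.norm().clamp(min=1e-8), J.norm().clamp(min=1e-8)
    grad_R = J / (R_n * J_n) - cos * R / (R_n ** 2)       # ∂cos(R, J)/∂R
    R = R + lr_fb * grad_R
    R = R / R.norm(dim=0, keepdim=True).clamp(min=1e-8)   # unit columns
    return R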


# ---------------------------------------------------------------------------
# KAFT Trainer
# ---------------------------------------------------------------------------

class KAFTTrainer(_FeedbackTrainerBase):
    """Aligned R per layer + topology P = Â^{min(L-l, max_power)}."""

    def __init__(self, data, hidden_dim, lr, weight_decay,
                 lr_feedback=0.5, num_probes=64,
                 topo_mode='fixed_A', max_topo_power=3,
                 diffusion_alpha=0.5, diffusion_iters=10, num_layers=2,
                 residual_alpha=0.0, backbone='gcn', **_kw):
        super().__init__(data, hidden_dim, lr, weight_decay,
                         diffusion_alpha, diffusion_iters, num_layers,
                         residual_alpha, backbone,
                         _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
        self.lr_fb = lr_feedback
        self.num_probes = num_probes
        self.topo_mode = topo_mode
        self.max_topo_power = max_topo_power
        self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
                   for _ in range(num_layers - 1)]

    def _alignment_step(self, inter):
        metrics = {}
        for l in range(self.num_layers - 1):
            cos = _align_R_layer(self, l)
            metrics[f'cos_feat_L{l}'] = cos
        metrics['cos_feat'] = sum(metrics.values()) / len(metrics)
        return metrics

    def _compute_hidden_feedback(self, l, inter, E_bar):
        A = self.data['A_hat']
        power = min(self.num_layers - l, self.max_topo_power)
        topo_E = E_bar
        for _ in range(power):
            topo_E = spmm(A, topo_E)
        return topo_E @ self.Rs[l]
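
# Usage sketch (hypothetical hyperparameters; same `data` dict as BPTrainer):
#
#     trainer = KAFTTrainer(data, hidden_dim=64, lr=1e-2, weight_decay=5e-4,
#                           lr_feedback=0.5, num_probes=64, num_layers=2)
#     hist = trainer.train(epochs=200, verbose=False)
#     # hist additionally tracks 'cos_bp' and the per-layer alignment
#     # cosines 'cos_feat_L{l}' reported by _alignment_step.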


# ---------------------------------------------------------------------------
# Shared multi-probe feature-side alignment (per layer)
# ---------------------------------------------------------------------------

def _align_R_layer(trainer, l):
    """Align R_l via multi-probe estimation.

    Two modes controlled by trainer.align_mode:
        'chain_norm' (default): full transposed weight chain with per-step
            normalization to prevent explosion
        'next_layer': align to the chain of at most the last two layers'
            weights (local, stable for any depth)
    """
    mode = getattr(trainer, 'align_mode', 'chain_norm')
    B_mat = torch.randn(trainer.hidden_dim, trainer.num_probes,
                        device=trainer.device)

    if mode == 'next_layer':
        # Align to the last two layers' chain (stable, captures output mapping)
        # For any layer l: target = W_{L-1}^T @ W_{L-2}^T (last 2 layers)
        # This keeps the target shape consistent (d_out × hidden)
        result = B_mat
        start = max(l + 1, trainer.num_layers - 2)  # at most last 2 layers
        for k in range(start, trainer.num_layers):
            result = trainer.weights[k].t() @ result
    else:
        # Full chain with per-step normalization to prevent explosion
        result = B_mat
        for k in range(l + 1, trainer.num_layers):
            result = trainer.weights[k].t() @ result
            # Normalize to prevent chain explosion (preserve direction, bound magnitude)
            col_norms = result.norm(dim=0, keepdim=True).clamp(min=1e-8)
            result = result / col_norms * B_mat.norm(dim=0, keepdim=True).mean()

    J_feat = result @ B_mat.t() / trainer.num_probes  # (d_out, hidden_dim)

    R_l = trainer.Rs[l]
    cos_feat = F.cosine_similarity(
        R_l.reshape(1, -1), J_feat.reshape(1, -1)
    ).item()

    R_norm = R_l.norm().clamp(min=1e-8)
    J_norm = J_feat.norm().clamp(min=1e-8)
    grad_R = J_feat / (R_norm * J_norm) - cos_feat * R_l / (R_norm ** 2)
    trainer.Rs[l] = R_l + trainer.lr_fb * grad_R

    # Column normalization (standard)
    col_norms = trainer.Rs[l].norm(dim=0, keepdim=True).clamp(min=1e-8)
    trainer.Rs[l] = trainer.Rs[l] / col_norms
    return cos_feat
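
# --- illustrative sketch (not part of the experiments) ----------------------
# Why the multi-probe estimate works: for B with i.i.d. standard-normal
# entries, E[B Bᵀ] = num_probes · I, so (Wᵀ B) Bᵀ / num_probes is an unbiased
# estimate of Wᵀ without ever materializing the transposed chain explicitly.
# Hypothetical sizes.
def _multi_probe_toy_demo(num_probes=4096):
    torch.manual_seed(0)
    W = torch.randn(4, 3)                     # stand-in weight chain (hidden, d_out)
    B = torch.randn(4, num_probes)            # probes in hidden space
    J_est = (W.t() @ B) @ B.t() / num_probes  # ≈ Wᵀ, shape (d_out, hidden)
    err = (J_est - W.t()).norm() / W.t().norm()
    return err                                # shrinks as num_probes grows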