Diffstat (limited to 'src/trainers.py')
-rw-r--r--   src/trainers.py   697
1 file changed, 697 insertions, 0 deletions
diff --git a/src/trainers.py b/src/trainers.py
new file mode 100644
index 0000000..651dffc
--- /dev/null
+++ b/src/trainers.py
@@ -0,0 +1,697 @@
+"""
+Training methods for Graph-GrAPE experiments.
+Generalized to L-layer residual GNNs (GCN / SAGE / GIN / APPNP backbones).
+
+Methods compared:
+ BP — Standard backprop GCN
+ DFA — Fixed random R, P=I
+ DFA-GNN — Fixed random R, P=Â^{L-l}
+ VanillaGrAPE — Aligned R (per layer), P=I
+ GraphGrAPE — Aligned R (per layer) + topology P=Â^{L-l}
+"""
+
+import torch
+import torch.nn.functional as F
+from src.data import spmm
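+
+# Each trainer consumes a `data` dict providing at least the keys used below:
+#   'X' (N, F) node features, 'y' (N,) integer labels,
+#   'A_hat' normalized adjacency (sparse), 'A_row' / 'A_row_T' row-normalized
+#   adjacency and its transpose (SAGE backbone),
+#   'train_mask' / 'val_mask' / 'test_mask' boolean node masks,
+#   'num_features', 'num_classes'.
+# Minimal usage sketch (hyperparameters here are illustrative, not tuned):
+#   trainer = GraphGrAPETrainer(data, hidden_dim=64, lr=0.01, weight_decay=5e-4)
+#   hist = trainer.train(epochs=200)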
+
+
+# ---------------------------------------------------------------------------
+# Error diffusion (DFA-GNN style label spreading)
+# ---------------------------------------------------------------------------
+
+def label_spreading(E, A_hat, alpha=0.5, num_iters=10):
+ """Diffuse error from labeled to unlabeled nodes."""
+ Z = E.clone()
+ for _ in range(num_iters):
+ Z = (1 - alpha) * E + alpha * spmm(A_hat, Z)
+ labeled_mask = E.abs().sum(dim=1) > 0
+ if labeled_mask.any():
+ avg_norm = E[labeled_mask].norm(dim=1).mean()
+ unlabeled = ~labeled_mask
+ norms = Z[unlabeled].norm(dim=1, keepdim=True).clamp(min=1e-8)
+ Z[unlabeled] = Z[unlabeled] * (avg_norm / norms)
+ return Z
+
+
+# ---------------------------------------------------------------------------
+# BP Trainer
+# ---------------------------------------------------------------------------
+
+class BPTrainer:
+ """L-layer GNN with backpropagation. Supports GCN/SAGE/GIN + BN/Dropout."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ num_layers=2, residual_alpha=0.0, backbone='gcn',
+ use_batchnorm=False, dropout=0.0, **_kw):
+ dev = data['X'].device
+ d_in, d_out = data['num_features'], data['num_classes']
+ self.data = data
+ self.num_layers = num_layers
+ self.residual_alpha = residual_alpha
+ self.backbone = backbone
+ self.dropout = dropout
+ self._training = True
+
+ dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out]
+ self.weights = []
+ for i in range(num_layers):
+ w = torch.nn.Parameter(torch.empty(dims[i], dims[i + 1], device=dev))
+ torch.nn.init.xavier_uniform_(w)
+ self.weights.append(w)
+
+ # GIN: learnable ε per layer
+ if backbone == 'gin':
+ self.gin_eps = [torch.nn.Parameter(torch.zeros(1, device=dev))
+ for _ in range(num_layers)]
+ else:
+ self.gin_eps = None
+
+ # BatchNorm (using nn.BatchNorm1d for autograd compatibility)
+ self.use_batchnorm = use_batchnorm
+ self.bns = []
+ if use_batchnorm:
+ for _ in range(num_layers - 1):
+ self.bns.append(torch.nn.BatchNorm1d(hidden_dim).to(dev))
+
+ # Optimizer — include all learnable params
+ all_params = list(self.weights)
+ if self.gin_eps:
+ all_params += self.gin_eps
+ for bn in self.bns:
+ all_params += list(bn.parameters())
+ self.optimizer = torch.optim.Adam(all_params, lr=lr, weight_decay=weight_decay)
+
+ def _graph_conv(self, H, W, l):
+ HW = H @ W
+ if self.backbone in ('gcn', 'appnp'):
+ return spmm(self.data['A_hat'], HW)
+ elif self.backbone == 'sage':
+ return spmm(self.data['A_row'], HW)
+ elif self.backbone == 'gin':
+ return (1 + self.gin_eps[l]) * HW + spmm(self.data['A_hat'], HW)
+ raise ValueError(self.backbone)
+
+ def _appnp_propagate(self, Z, alpha=0.1, K=10):
+ """APPNP-style propagation: H = α·Z + (1-α)·Â·H, iterated K times."""
+ H = Z
+ A = self.data['A_hat']
+ for _ in range(K):
+ H = alpha * Z + (1 - alpha) * spmm(A, H)
+ return H
+
+ def forward(self):
+ X = self.data['X']
+ H = X
+ H0 = None
+
+ if self.backbone == 'appnp':
+ # APPNP: MLP first, then propagate
+ for l in range(self.num_layers):
+ Z = H @ self.weights[l] # pure linear (no graph conv)
+ if l < self.num_layers - 1:
+ if self.use_batchnorm:
+ Z = self.bns[l](Z)
+ H = F.relu(Z)
+ if self.dropout > 0 and self._training:
+ H = F.dropout(H, p=self.dropout, training=True)
+ else:
+ # Propagate only at the end
+ Z = self._appnp_propagate(Z)
+ return Z, {}
+ return Z, {}
+
+ # Standard per-layer graph conv (GCN/SAGE/GIN)
+ for l in range(self.num_layers):
+ if l > 0 and l < self.num_layers - 1 and self.residual_alpha > 0 and H0 is not None:
+ H = (1 - self.residual_alpha) * H + self.residual_alpha * H0
+ Z = self._graph_conv(H, self.weights[l], l)
+ if l < self.num_layers - 1:
+ if self.use_batchnorm:
+ Z = self.bns[l](Z)
+ H = F.relu(Z)
+ if self.dropout > 0 and self._training:
+ H = F.dropout(H, p=self.dropout, training=True)
+ if l == 0:
+ H0 = H
+ else:
+ return Z, {}
+ return Z, {}
+
+ def train_step(self):
+ self.optimizer.zero_grad()
+ Z_out, _ = self.forward()
+ mask = self.data['train_mask']
+ loss = F.cross_entropy(Z_out[mask], self.data['y'][mask])
+ loss.backward()
+ self.optimizer.step()
+ with torch.no_grad():
+ acc = (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean()
+ return loss.item(), acc.item(), {}
+
+ @torch.no_grad()
+ def evaluate(self, mask_name='test_mask'):
+ self._training = False
+ for bn in self.bns:
+ bn.eval()
+ Z_out, _ = self.forward()
+ self._training = True
+ for bn in self.bns:
+ bn.train()
+ mask = self.data[mask_name]
+ return (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item()
+
+ def train(self, epochs, verbose=True):
+ hist = {k: [] for k in ['train_loss', 'train_acc', 'val_acc', 'test_acc']}
+ for ep in range(epochs):
+ loss, tacc, _ = self.train_step()
+ vacc = self.evaluate('val_mask')
+ teacc = self.evaluate('test_mask')
+ for k, v in zip(hist, [loss, tacc, vacc, teacc]):
+ hist[k].append(v)
+ if verbose and ep % 50 == 0:
+ print(f" [BP] ep {ep:3d} | loss {loss:.4f} | "
+ f"train {tacc:.4f} | val {vacc:.4f} | test {teacc:.4f}")
+ return hist
+
+
+# ---------------------------------------------------------------------------
+# Base class for non-BP methods (L-layer)
+# ---------------------------------------------------------------------------
+
+class _FeedbackTrainerBase:
+ """Shared logic for DFA / GrAPE variants, generalized to L layers."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ diffusion_alpha, diffusion_iters,
+ num_layers=2, residual_alpha=0.0, backbone='gcn',
+ use_batchnorm=False, dropout=0.0):
+ dev = data['X'].device
+ self.device = dev
+ d_in = data['num_features']
+ d_out = data['num_classes']
+ self.data = data
+ self.d_in = d_in
+ self.d_out = d_out
+ self.hidden_dim = hidden_dim
+ self.lr = lr
+ self.wd = weight_decay
+ self.diff_alpha = diffusion_alpha
+ self.diff_iters = diffusion_iters
+ self.num_layers = num_layers
+ self.residual_alpha = residual_alpha
+ self.backbone = backbone
+ self.dropout = dropout
+ self._training = True
+
+ dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out]
+ self.weights = []
+ for i in range(num_layers):
+ w = torch.empty(dims[i], dims[i + 1], device=dev)
+ torch.nn.init.xavier_uniform_(w)
+ self.weights.append(w)
+
+ # GIN: learnable ε per layer
+ if backbone == 'gin':
+ self.gin_eps = [torch.zeros(1, device=dev) for _ in range(num_layers)]
+ else:
+ self.gin_eps = None
+
+ # BatchNorm per hidden layer (running stats tracked manually)
+ self.use_batchnorm = use_batchnorm
+ if use_batchnorm:
+ self.bn_weight = [torch.ones(hidden_dim, device=dev) for _ in range(num_layers - 1)]
+ self.bn_bias = [torch.zeros(hidden_dim, device=dev) for _ in range(num_layers - 1)]
+ self.bn_running_mean = [torch.zeros(hidden_dim, device=dev) for _ in range(num_layers - 1)]
+ self.bn_running_var = [torch.ones(hidden_dim, device=dev) for _ in range(num_layers - 1)]
+ self.bn_momentum = 0.1
+
+ # Adam state (per weight)
+ self._use_adam = True
+ self._adam = [{'m': torch.zeros_like(w), 'v': torch.zeros_like(w)}
+ for w in self.weights]
+ self._adam_t = 0
+ self._adam_beta1 = 0.9
+ self._adam_beta2 = 0.999
+ self._adam_eps = 1e-8
+
+ # SGD momentum state
+ self._momentum = 0.0
+ self._sgd_vel = [torch.zeros_like(w) for w in self.weights]
+
+ # --- graph conv helpers -------------------------------------------------
+
+ def _graph_conv(self, H, W, l):
+ """Forward graph convolution (backbone-dependent)."""
+ HW = H @ W
+ if self.backbone in ('gcn', 'appnp'):
+ return spmm(self.data['A_hat'], HW)
+ elif self.backbone == 'sage':
+ return spmm(self.data['A_row'], HW)
+ elif self.backbone == 'gin':
+ return (1 + self.gin_eps[l]) * HW + spmm(self.data['A_hat'], HW)
+ raise ValueError(self.backbone)
+
+ def _graph_conv_T(self, delta, l):
+ """Transpose of graph conv applied to delta (for gradient computation)."""
+ if self.backbone in ('gcn', 'appnp'):
+ return spmm(self.data['A_hat'], delta)
+ elif self.backbone == 'sage':
+ return spmm(self.data['A_row_T'], delta)
+ elif self.backbone == 'gin':
+ return (1 + self.gin_eps[l]) * delta + spmm(self.data['A_hat'], delta)
+ raise ValueError(self.backbone)
+
+ # --- batchnorm helper --------------------------------------------------
+
+ def _apply_bn(self, H, l):
+ """Manual BatchNorm (no autograd needed)."""
+ if not self.use_batchnorm:
+ return H
+ if self._training:
+ mean = H.mean(dim=0)
+ var = H.var(dim=0, unbiased=False)
+ # Update running stats
+ self.bn_running_mean[l] = (1 - self.bn_momentum) * self.bn_running_mean[l] + self.bn_momentum * mean
+ self.bn_running_var[l] = (1 - self.bn_momentum) * self.bn_running_var[l] + self.bn_momentum * var
+ else:
+ mean = self.bn_running_mean[l]
+ var = self.bn_running_var[l]
+ H_norm = (H - mean) / (var + 1e-5).sqrt()
+ return H_norm * self.bn_weight[l] + self.bn_bias[l]
+
+ # --- APPNP propagation -------------------------------------------------
+
+ def _appnp_propagate(self, Z, alpha=0.1, K=10):
+ H = Z
+ A = self.data['A_hat']
+ for _ in range(K):
+ H = alpha * Z + (1 - alpha) * spmm(A, H)
+ return H
+
+ # --- forward -----------------------------------------------------------
+
+ def forward(self):
+ X = self.data['X']
+ Zs = []
+ Hs = []
+ H = X
+ H0 = None
+
+ if self.backbone == 'appnp':
+ # APPNP: MLP layers, then propagate at end
+ for l in range(self.num_layers):
+ Z = H @ self.weights[l] # pure linear
+ Zs.append(Z)
+ if l < self.num_layers - 1:
+ Z_bn = self._apply_bn(Z, l)
+ H = F.relu(Z_bn)
+ if self.dropout > 0 and self._training:
+ H = F.dropout(H, p=self.dropout, training=True)
+ Hs.append(H)
+ else:
+ Z = self._appnp_propagate(Z)
+ Zs[-1] = Z # replace with propagated version
+ return Z, {'Zs': Zs, 'Hs': Hs, 'H0': H0}
+
+ # Standard per-layer graph conv
+ for l in range(self.num_layers):
+ if l > 0 and l < self.num_layers - 1 and self.residual_alpha > 0 and H0 is not None:
+ H = (1 - self.residual_alpha) * H + self.residual_alpha * H0
+
+ Z = self._graph_conv(H, self.weights[l], l)
+ Zs.append(Z)
+
+ if l < self.num_layers - 1:
+ Z_bn = self._apply_bn(Z, l)
+ H = F.relu(Z_bn)
+ if self.dropout > 0 and self._training:
+ H = F.dropout(H, p=self.dropout, training=True)
+ Hs.append(H)
+ if l == 0:
+ H0 = H
+
+ return Z, {'Zs': Zs, 'Hs': Hs, 'H0': H0}
+
+ # --- output error ------------------------------------------------------
+
+ def _output_error(self, Z_out):
+ mask = self.data['train_mask']
+ y = self.data['y']
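+        # E0 is the exact softmax cross-entropy gradient on labeled nodes,
+        # (softmax(Z) - onehot(y)) / |train|, and zero elsewhere; E_bar is its
+        # diffusion to unlabeled nodes via label spreading.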
+ n_labeled = mask.sum().float().clamp(min=1.0)
+ probs = F.softmax(Z_out.detach(), dim=1)
+ y_oh = F.one_hot(y, self.d_out).float()
+ E0 = torch.zeros_like(probs)
+ E0[mask] = (probs[mask] - y_oh[mask]) / n_labeled
+ E_bar = label_spreading(
+ E0, self.data['A_hat'], self.diff_alpha, self.diff_iters
+ )
+ return E0, E_bar
+
+ # --- weight update (Adam / SGD / SGD+momentum) -------------------------
+
+ def _adam_step(self, idx, grad):
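+        # Manual Adam with bias correction. Weight decay is added after the
+        # adaptive scaling (AdamW-style); the returned value is the step to
+        # subtract from the weight.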
+ s = self._adam[idx]
+ b1, b2, eps = self._adam_beta1, self._adam_beta2, self._adam_eps
+ t = self._adam_t
+ s['m'] = b1 * s['m'] + (1 - b1) * grad
+ s['v'] = b2 * s['v'] + (1 - b2) * grad ** 2
+ m_hat = s['m'] / (1 - b1 ** t)
+ v_hat = s['v'] / (1 - b2 ** t)
+ return self.lr * (m_hat / (v_hat.sqrt() + eps) + self.wd * self.weights[idx])
+
+ def _update_weights(self, inter, E0, deltas):
+ """Update all weights.
+
+ Output layer (last): true gradient from E0.
+ Hidden layers: feedback-based deltas[l].
+ """
+ X = self.data['X']
+ Hs = inter['Hs']
+ H0 = inter['H0']
+
+ grads = []
+ for l in range(self.num_layers):
+ if l == self.num_layers - 1:
+ H_prev = Hs[-1] if Hs else X
+ g = H_prev.t() @ self._graph_conv_T(E0, l)
+ else:
+ if l == 0:
+ H_in = X
+ else:
+ H_prev = Hs[l - 1]
+ if self.residual_alpha > 0 and H0 is not None:
+ H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * H0
+ else:
+ H_in = H_prev
+ g = H_in.t() @ self._graph_conv_T(deltas[l], l)
+ grads.append(g)
+
+ if self._use_adam:
+ self._adam_t += 1
+ for i in range(self.num_layers):
+ self.weights[i] = self.weights[i] - self._adam_step(i, grads[i])
+ else:
+ for i in range(self.num_layers):
+ if self._momentum > 0:
+ self._sgd_vel[i] = self._momentum * self._sgd_vel[i] + grads[i] + self.wd * self.weights[i]
+ self.weights[i] = self.weights[i] - self.lr * self._sgd_vel[i]
+ else:
+ self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i])
+
+ # --- alignment / feedback (override in subclasses) ---------------------
+
+ def _alignment_step(self, inter):
+ return {}
+
+ def _compute_hidden_feedback(self, l, inter, E_bar):
+ raise NotImplementedError
+
+ # --- train loop --------------------------------------------------------
+
+ def train_step(self):
+ Z_out, inter = self.forward()
+ E0, E_bar = self._output_error(Z_out)
+ align_metrics = self._alignment_step(inter)
+
+ deltas = []
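+        # Hidden-layer pseudo-gradients: delta_l = 1[Z_l > 0] * feedback_l(E_bar).
+        # The ReLU gate uses the pre-BatchNorm pre-activation, so it is only
+        # approximate when BatchNorm is enabled.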
+ for l in range(self.num_layers - 1):
+ relu_gate = (inter['Zs'][l].detach() > 0).float()
+ raw_fb = self._compute_hidden_feedback(l, inter, E_bar)
+ deltas.append(relu_gate * raw_fb)
+
+ self._update_weights(inter, E0, deltas)
+
+ with torch.no_grad():
+ mask = self.data['train_mask']
+ loss = F.cross_entropy(Z_out[mask], self.data['y'][mask]).item()
+ acc = (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item()
+ return loss, acc, align_metrics
+
+ @torch.no_grad()
+ def evaluate(self, mask_name='test_mask'):
+ self._training = False
+ Z_out, _ = self.forward()
+ self._training = True
+ mask = self.data[mask_name]
+ return (Z_out[mask].argmax(1) == self.data['y'][mask]).float().mean().item()
+
+ def compute_bp_gradient_cosine(self):
+ """Average cos(feedback grad, BP grad) across hidden layers."""
+ if self.backbone == 'appnp':
+ return 0.0 # APPNP forward differs; skip cos_bp for now
+
+ wp = []
+ for w in self.weights:
+ wp.append(w.clone().detach().requires_grad_(True))
+
+ # Also handle GIN eps for autograd
+ eps_p = None
+ if self.backbone == 'gin':
+ eps_p = [e.clone().detach().requires_grad_(True) for e in self.gin_eps]
+
+ X = self.data['X']
+ H = X
+ H0_a = None
+ for l in range(self.num_layers):
+ if l > 0 and l < self.num_layers - 1 and self.residual_alpha > 0 and H0_a is not None:
+ H = (1 - self.residual_alpha) * H + self.residual_alpha * H0_a
+ HW = H @ wp[l]
+ if self.backbone == 'gcn':
+ Z = spmm(self.data['A_hat'], HW)
+ elif self.backbone == 'sage':
+ Z = spmm(self.data['A_row'], HW)
+ elif self.backbone == 'gin':
+ Z = (1 + eps_p[l]) * HW + spmm(self.data['A_hat'], HW)
+ if l < self.num_layers - 1:
+ H = F.relu(Z)
+ if l == 0:
+ H0_a = H
+
+ mask = self.data['train_mask']
+ loss = F.cross_entropy(Z[mask], self.data['y'][mask])
+ loss.backward()
+
+ _, inter = self.forward()
+ E0, E_bar = self._output_error(Z)
+
+ cosines = []
+ for l in range(self.num_layers - 1):
+ bp_grad_l = wp[l].grad.detach()
+ relu_gate = (inter['Zs'][l].detach() > 0).float()
+ raw_fb = self._compute_hidden_feedback(l, inter, E_bar)
+ delta_l = relu_gate * raw_fb
+
+ if l == 0:
+ H_in = X
+ else:
+ H_prev = inter['Hs'][l - 1]
+ if self.residual_alpha > 0 and inter['H0'] is not None:
+ H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * inter['H0']
+ else:
+ H_in = H_prev
+ our_grad_l = H_in.t() @ self._graph_conv_T(delta_l, l)
+
+ c = F.cosine_similarity(
+ bp_grad_l.reshape(1, -1), our_grad_l.reshape(1, -1)
+ ).item()
+ cosines.append(c)
+
+ return sum(cosines) / len(cosines) if cosines else 0.0
+
+ def train(self, epochs, verbose=True):
+ hist = {k: [] for k in
+ ['train_loss', 'train_acc', 'val_acc', 'test_acc', 'cos_bp']}
+ for ep in range(epochs):
+ loss, tacc, metrics = self.train_step()
+ vacc = self.evaluate('val_mask')
+ teacc = self.evaluate('test_mask')
+
+ cos_bp = 0.0
+ if ep % 10 == 0:
+ cos_bp = self.compute_bp_gradient_cosine()
+
+ hist['train_loss'].append(loss)
+ hist['train_acc'].append(tacc)
+ hist['val_acc'].append(vacc)
+ hist['test_acc'].append(teacc)
+ hist['cos_bp'].append(cos_bp)
+ for k, v in metrics.items():
+ hist.setdefault(k, []).append(v)
+
+ if verbose and ep % 50 == 0:
+ tag = self.__class__.__name__
+ extra = ''.join(f' | {k} {v:.4f}' for k, v in metrics.items())
+ print(f" [{tag}] ep {ep:3d} | loss {loss:.4f} | "
+ f"train {tacc:.4f} | val {vacc:.4f} | test {teacc:.4f} | "
+ f"cos_bp {cos_bp:.4f}{extra}")
+ return hist
+
+
+# ---------------------------------------------------------------------------
+# DFA Trainer
+# ---------------------------------------------------------------------------
+
+class DFATrainer(_FeedbackTrainerBase):
+ """DFA: fixed random R, no topology. Same R for all layers."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ diffusion_alpha=0.5, diffusion_iters=10,
+ num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw):
+ super().__init__(data, hidden_dim, lr, weight_decay,
+ diffusion_alpha, diffusion_iters,
+ num_layers, residual_alpha, backbone,
+ _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
+ self.R_fixed = torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
+
+ def _compute_hidden_feedback(self, l, inter, E_bar):
+ return E_bar @ self.R_fixed
+
+
+# ---------------------------------------------------------------------------
+# DFA-GNN Trainer
+# ---------------------------------------------------------------------------
+
+class DFAGNNTrainer(_FeedbackTrainerBase):
+ """DFA-GNN: fixed random R, topology P = Â^{min(L-l, max_power)} per layer."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ diffusion_alpha=0.5, diffusion_iters=10,
+ num_layers=2, residual_alpha=0.0, backbone='gcn',
+ max_topo_power=3, **_kw):
+ super().__init__(data, hidden_dim, lr, weight_decay,
+ diffusion_alpha, diffusion_iters,
+ num_layers, residual_alpha, backbone,
+ _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
+ self.max_topo_power = max_topo_power
+ self.R_fixed = torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
+
+ def _compute_hidden_feedback(self, l, inter, E_bar):
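+        # Feedback = (A_hat^power @ E_bar) @ R_fixed: the diffused error is
+        # propagated over `power` graph hops before being mapped to the hidden
+        # width by the fixed random R.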
+ A = self.data['A_hat']
+ power = min(self.num_layers - l, self.max_topo_power)
+ out = E_bar
+ for _ in range(power):
+ out = spmm(A, out)
+ return out @ self.R_fixed
+
+
+# ---------------------------------------------------------------------------
+# Vanilla GrAPE Trainer
+# ---------------------------------------------------------------------------
+
+class VanillaGrAPETrainer(_FeedbackTrainerBase):
+ """Aligned R per layer, no topology (P=I)."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ lr_feedback=0.5, num_probes=64,
+ diffusion_alpha=0.5, diffusion_iters=10,
+ num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw):
+ super().__init__(data, hidden_dim, lr, weight_decay,
+ diffusion_alpha, diffusion_iters,
+ num_layers, residual_alpha, backbone,
+ _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
+ self.lr_fb = lr_feedback
+ self.num_probes = num_probes
+ # One R per hidden layer
+ self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
+ for _ in range(num_layers - 1)]
+
+ def _alignment_step(self, inter):
+ metrics = {}
+ for l in range(self.num_layers - 1):
+ cos = _align_R_layer(self, l)
+ metrics[f'cos_feat_L{l}'] = cos
+ metrics['cos_feat'] = sum(metrics.values()) / len(metrics)
+ return metrics
+
+ def _compute_hidden_feedback(self, l, inter, E_bar):
+ return E_bar @ self.Rs[l]
+
+
+# ---------------------------------------------------------------------------
+# Graph-GrAPE Trainer
+# ---------------------------------------------------------------------------
+
+class GraphGrAPETrainer(_FeedbackTrainerBase):
+ """Aligned R per layer + topology P = Â^{min(L-l, max_power)}."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ lr_feedback=0.5, num_probes=64,
+ topo_mode='fixed_A', max_topo_power=3,
+ diffusion_alpha=0.5, diffusion_iters=10,
+ num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw):
+ super().__init__(data, hidden_dim, lr, weight_decay,
+ diffusion_alpha, diffusion_iters,
+ num_layers, residual_alpha, backbone,
+ _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
+ self.lr_fb = lr_feedback
+ self.num_probes = num_probes
+ self.topo_mode = topo_mode
+ self.max_topo_power = max_topo_power
+ self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
+ for _ in range(num_layers - 1)]
+
+ def _alignment_step(self, inter):
+ metrics = {}
+ for l in range(self.num_layers - 1):
+ cos = _align_R_layer(self, l)
+ metrics[f'cos_feat_L{l}'] = cos
+ metrics['cos_feat'] = sum(metrics.values()) / len(metrics)
+ return metrics
+
+ def _compute_hidden_feedback(self, l, inter, E_bar):
+ A = self.data['A_hat']
+ power = min(self.num_layers - l, self.max_topo_power)
+ topo_E = E_bar
+ for _ in range(power):
+ topo_E = spmm(A, topo_E)
+ return topo_E @ self.Rs[l]
+
+
+# ---------------------------------------------------------------------------
+# Shared multi-probe feature-side alignment (per layer)
+# ---------------------------------------------------------------------------
+
+def _align_R_layer(trainer, l):
+ """Align R_l via multi-probe estimation.
+
+ Two modes controlled by trainer.align_mode:
+ 'chain_norm' (default): full chain with per-step normalization to prevent explosion
+      'next_layer': align to the chain of (at most) the last two layers' transposed weights (local, stable for any depth)
+ """
+ mode = getattr(trainer, 'align_mode', 'chain_norm')
+ B_mat = torch.randn(trainer.hidden_dim, trainer.num_probes, device=trainer.device)
+
+ if mode == 'next_layer':
+ # Align to the last two layers' chain (stable, captures output mapping)
+ # For any layer l: target = W_{L-1}^T @ W_{L-2}^T (last 2 layers)
+ # This keeps the target shape consistent (d_out × hidden)
+ result = B_mat
+ start = max(l + 1, trainer.num_layers - 2) # at most last 2 layers
+ for k in range(start, trainer.num_layers):
+ result = trainer.weights[k].t() @ result
+ else:
+ # Full chain with per-step normalization to prevent explosion
+ result = B_mat
+ for k in range(l + 1, trainer.num_layers):
+ result = trainer.weights[k].t() @ result
+ # Normalize to prevent chain explosion (preserve direction, bound magnitude)
+ col_norms = result.norm(dim=0, keepdim=True).clamp(min=1e-8)
+ result = result / col_norms * B_mat.norm(dim=0, keepdim=True).mean()
+
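+    # With Gaussian probes, B_mat @ B_mat.T / num_probes concentrates around the
+    # identity, so J_feat is a multi-probe estimate of the (possibly normalized)
+    # transposed weight chain that BP would apply on the feature side.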
+ J_feat = result @ B_mat.t() / trainer.num_probes # (d_out, hidden_dim)
+
+ R_l = trainer.Rs[l]
+ cos_feat = F.cosine_similarity(
+ R_l.reshape(1, -1), J_feat.reshape(1, -1)
+ ).item()
+
+ R_norm = R_l.norm().clamp(min=1e-8)
+ J_norm = J_feat.norm().clamp(min=1e-8)
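+    # Gradient-ascent step on cos(R_l, J_feat):
+    #   d cos / dR = J / (|R| |J|) - cos * R / |R|^2.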
+ grad_R = J_feat / (R_norm * J_norm) - cos_feat * R_l / (R_norm ** 2)
+ trainer.Rs[l] = R_l + trainer.lr_fb * grad_R
+
+ # Column normalization (standard)
+ col_norms = trainer.Rs[l].norm(dim=0, keepdim=True).clamp(min=1e-8)
+ trainer.Rs[l] = trainer.Rs[l] / col_norms
+
+ return cos_feat