path: root/metrics/credit_metrics.py
author	YurenHao0426 <Blackhao0426@gmail.com>	2026-05-04 19:50:45 -0500
committer	YurenHao0426 <Blackhao0426@gmail.com>	2026-05-04 19:50:45 -0500
commit	b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)
tree	f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /metrics/credit_metrics.py
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and full reproduction pipeline. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'metrics/credit_metrics.py')
-rw-r--r--	metrics/credit_metrics.py	156
1 files changed, 156 insertions, 0 deletions
diff --git a/metrics/credit_metrics.py b/metrics/credit_metrics.py
new file mode 100644
index 0000000..516dca2
--- /dev/null
+++ b/metrics/credit_metrics.py
@@ -0,0 +1,156 @@
+"""
+Credit assignment diagnostic metrics:
+1. Exact costate cosine (for toy LQ)
+2. Local perturbation correlation rho_l
+3. Nudging test Delta_l^nudge
+4. Offline BP cosine Gamma_l
+5. Bridge residual R_l
+6. Feature drift M_l
+"""
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.stats import pearsonr
+
+
+def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> float:
+    """Compute cosine similarity between a and b along the last dim, averaged over the batch."""
+    a_flat = a.reshape(a.shape[0], -1)
+    b_flat = b.reshape(b.shape[0], -1)
+    cos = F.cosine_similarity(a_flat, b_flat, dim=-1)
+    return cos.mean().item()
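+
+# Sanity-check sketch (not part of the original commit): identical inputs
+# should score 1.0 and sign-flipped inputs -1.0. Shapes are illustrative.
+def _demo_cosine_similarity_batch():
+    a = torch.randn(8, 16)
+    assert abs(cosine_similarity_batch(a, a) - 1.0) < 1e-6
+    assert abs(cosine_similarity_batch(a, -a) + 1.0) < 1e-6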
+
+
+def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32):
+    """
+    Compute local perturbation correlation rho_l.
+
+    Args:
+        h_l: (batch, d) hidden state at layer l
+        a_l: (batch, d) credit signal at layer l
+        forward_fn: callable mapping h_l -> per-sample loss of shape (batch,)
+        epsilon: perturbation magnitude
+        M: number of random directions
+
+    Returns:
+        rho: Pearson correlation between predicted and true loss changes
+    """
+    batch_size, d = h_l.shape
+    device = h_l.device
+
+    pred_list = []
+    true_list = []
+
+    base_loss = forward_fn(h_l)  # (batch,)
+
+    for _ in range(M):
+        v = torch.randn(batch_size, d, device=device)
+        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+
+        # Predicted first-order change: <a_l, epsilon * v>
+        delta_pred = (a_l * (epsilon * v)).sum(dim=-1)  # (batch,)
+
+        # True change: forward pass from the perturbed h
+        perturbed_loss = forward_fn(h_l + epsilon * v)  # (batch,)
+        delta_true = perturbed_loss - base_loss  # (batch,)
+
+        pred_list.append(delta_pred.detach().cpu().numpy())
+        true_list.append(delta_true.detach().cpu().numpy())
+
+    pred_arr = np.concatenate(pred_list)
+    true_arr = np.concatenate(true_list)
+
+    # Degenerate case: constant arrays make the Pearson correlation undefined.
+    if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12:
+        return 0.0
+
+    rho, _ = pearsonr(pred_arr, true_arr)
+    return float(rho)
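+
+# Usage sketch (an illustration, not part of the original commit): for the
+# quadratic per-sample loss L(h) = ||h||^2 / 2 the exact gradient is h itself,
+# so passing a_l = h_l as the credit signal should yield rho close to 1.
+def _demo_perturbation_correlation():
+    torch.manual_seed(0)
+    h = torch.randn(16, 32)
+    quad_loss = lambda x: 0.5 * (x ** 2).sum(dim=-1)  # per-sample, shape (batch,)
+    rho = perturbation_correlation(h, a_l=h, forward_fn=quad_loss)
+    print(f"rho for the exact gradient: {rho:.3f}")  # expected ~1.0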
+
+
+def nudging_test(h_l, a_l, forward_fn, eta=0.01):
+    """
+    Nudging test: check whether moving h_l in the -a_l direction decreases the loss.
+
+    Args:
+        h_l: (batch, d) hidden state
+        a_l: (batch, d) credit signal
+        forward_fn: callable mapping h -> per-sample loss of shape (batch,)
+        eta: step size
+
+    Returns:
+        mean delta_nudge (negative is good)
+    """
+    # RMS-normalize the credit signal so eta sets a consistent step scale.
+    rms_a = (a_l ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+    a_normed = a_l / rms_a
+    h_nudged = h_l - eta * a_normed
+
+    base_loss = forward_fn(h_l)
+    nudged_loss = forward_fn(h_nudged)
+    delta = (nudged_loss - base_loss).mean().item()
+    return delta
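+
+# Usage sketch (illustrative, not from the original commit): with the same
+# quadratic loss, nudging along -h moves every sample toward the minimum at 0,
+# so the returned delta should be negative.
+def _demo_nudging_test():
+    h = torch.randn(16, 32)
+    quad_loss = lambda x: 0.5 * (x ** 2).sum(dim=-1)  # per-sample, shape (batch,)
+    delta = nudging_test(h, a_l=h, forward_fn=quad_loss, eta=0.01)
+    print(f"delta_nudge: {delta:.5f}")  # negative means the nudge helped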
+
+
+def offline_bp_cosine(a_l, bp_grad_l):
+    """
+    Compute the offline BP cosine similarity Gamma_l.
+
+    Args:
+        a_l: (batch, d) credit signal
+        bp_grad_l: (batch, d) true BP gradient at layer l
+    """
+    return cosine_similarity_batch(a_l, bp_grad_l)
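+
+# One way to obtain bp_grad_l (an assumption; this file does not prescribe it)
+# is to differentiate a per-sample loss with respect to h_l via autograd.
+def _demo_offline_bp_cosine(a_l, h_l, forward_fn):
+    h = h_l.detach().requires_grad_(True)
+    loss = forward_fn(h).sum()  # summing keeps per-sample gradients independent
+    bp_grad = torch.autograd.grad(loss, h)[0]  # (batch, d) true BP gradient
+    return offline_bp_cosine(a_l, bp_grad)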
+
+
+def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list, t_l_next, lam=0.1):
+    """
+    Compute bridge residual R_l.
+
+    Args:
+        V_phi: value network
+        V_bar_phi: EMA target value network
+        h_l: (batch, d)
+        t_l: (batch,)
+        s: (batch, s_dim)
+        h_l_next_noisy_list: list of K tensors (batch, d), noisy next states
+        t_l_next: (batch,)
+        lam: temperature
+
+    Returns:
+        mean absolute bridge residual
+    """
+    with torch.no_grad():
+        V_current = V_phi(h_l, t_l, s)  # (batch,)
+
+        # Soft-min target over the K noisy next states:
+        # V_target = -lam * log( (1/K) * sum_k exp(-V_next_k / lam) )
+        K = len(h_l_next_noisy_list)
+        log_terms = []
+        for h_next in h_l_next_noisy_list:
+            V_next = V_bar_phi(h_next, t_l_next, s)  # (batch,)
+            log_terms.append(-V_next / lam)
+
+        log_terms = torch.stack(log_terms, dim=-1)  # (batch, K)
+        V_target = -lam * torch.logsumexp(log_terms, dim=-1) + lam * np.log(K)
+
+        residual = (V_current - V_target).abs().mean().item()
+    return residual
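+
+# Call sketch (hypothetical, not part of the original commit): a stand-in
+# value function shows the expected shapes; real V_phi / V_bar_phi would be
+# the trained value network and its EMA target.
+def _demo_bridge_residual():
+    batch, d, s_dim, K = 8, 32, 4, 5
+    value_stub = lambda h, t, s: h.pow(2).mean(dim=-1) + t  # (batch,); ignores s
+    R = bridge_residual(
+        V_phi=value_stub, V_bar_phi=value_stub,
+        h_l=torch.randn(batch, d), t_l=torch.rand(batch),
+        s=torch.randn(batch, s_dim),
+        h_l_next_noisy_list=[torch.randn(batch, d) for _ in range(K)],
+        t_l_next=torch.rand(batch),
+    )
+    print(f"bridge residual: {R:.4f}")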
+
+
+def feature_drift(model_init_params, model_final_params):
+    """
+    Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F.
+
+    Args:
+        model_init_params: dict of {name: tensor} initial parameters
+        model_final_params: dict of {name: tensor} final parameters
+
+    Returns:
+        dict of {name: drift_ratio}
+    """
+    drifts = {}
+    for name in model_init_params:
+        if name in model_final_params:
+            w_init = model_init_params[name]
+            w_final = model_final_params[name]
+            init_norm = w_init.norm().item()
+            if init_norm > 1e-8:  # skip parameters initialized at (or near) zero
+                drift = (w_final - w_init).norm().item() / init_norm
+                drifts[name] = drift
+    return drifts
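+
+
+# Capture sketch (names are illustrative, not from the original commit):
+# snapshot the parameters before and after training and compare the copies.
+def _demo_feature_drift(model, train_fn):
+    init = {k: v.detach().clone() for k, v in model.named_parameters()}
+    train_fn(model)  # any training loop, assumed to be supplied by the caller
+    final = {k: v.detach().clone() for k, v in model.named_parameters()}
+    return feature_drift(init, final)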