Diffstat (limited to 'metrics/credit_metrics.py')
| -rw-r--r-- | metrics/credit_metrics.py | 156 |
1 file changed, 156 insertions, 0 deletions
diff --git a/metrics/credit_metrics.py b/metrics/credit_metrics.py
new file mode 100644
index 0000000..516dca2
--- /dev/null
+++ b/metrics/credit_metrics.py
@@ -0,0 +1,156 @@
+"""
+Credit assignment diagnostic metrics:
+1. Exact costate cosine (for toy LQ)
+2. Local perturbation correlation rho_l
+3. Nudging test Delta_l^nudge
+4. Offline BP cosine Gamma_l
+5. Bridge residual R_l
+6. Feature drift M_l
+"""
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.stats import pearsonr
+
+
+def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> float:
+    """Flatten a and b per sample, compute cosine similarity along the
+    feature dimension, and average over the batch."""
+    a_flat = a.reshape(a.shape[0], -1)
+    b_flat = b.reshape(b.shape[0], -1)
+    cos = F.cosine_similarity(a_flat, b_flat, dim=-1)
+    return cos.mean().item()
+
+
+def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32):
+    """
+    Compute the local perturbation correlation rho_l.
+
+    Args:
+        h_l: (batch, d) hidden state at layer l
+        a_l: (batch, d) credit signal at layer l
+        forward_fn: callable mapping a hidden state (batch, d) to a
+            per-sample loss of shape (batch,)
+        epsilon: perturbation magnitude
+        M: number of random directions
+
+    Returns:
+        rho: Pearson correlation between predicted and true loss changes
+    """
+    batch_size, d = h_l.shape
+    device = h_l.device
+
+    pred_list = []
+    true_list = []
+
+    base_loss = forward_fn(h_l)  # (batch,)
+
+    for _ in range(M):
+        # Unit-norm random perturbation direction per sample.
+        v = torch.randn(batch_size, d, device=device)
+        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+
+        # Predicted first-order change: <a_l, epsilon * v>
+        delta_pred = (a_l * (epsilon * v)).sum(dim=-1)  # (batch,)
+
+        # True change: re-run the forward pass from the perturbed h
+        perturbed_loss = forward_fn(h_l + epsilon * v)  # (batch,)
+        delta_true = perturbed_loss - base_loss  # (batch,)
+
+        pred_list.append(delta_pred.detach().cpu().numpy())
+        true_list.append(delta_true.detach().cpu().numpy())
+
+    pred_arr = np.concatenate(pred_list)
+    true_arr = np.concatenate(true_list)
+
+    # Pearson r is undefined for (near-)constant inputs.
+    if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12:
+        return 0.0
+
+    rho, _ = pearsonr(pred_arr, true_arr)
+    return float(rho)
+
+
+def nudging_test(h_l, a_l, forward_fn, eta=0.01):
+    """
+    Nudging test: check whether moving h_l in the -a_l direction
+    decreases the loss.
+
+    Args:
+        h_l: (batch, d) hidden state
+        a_l: (batch, d) credit signal
+        forward_fn: callable mapping a hidden state to a per-sample
+            loss of shape (batch,)
+        eta: step size
+
+    Returns:
+        mean Delta_l^nudge; negative values mean the credit signal
+        points in a descent direction
+    """
+    # RMS-normalize the credit signal so eta sets the step scale.
+    rms_a = (a_l ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+    a_normed = a_l / rms_a
+    h_nudged = h_l - eta * a_normed
+
+    base_loss = forward_fn(h_l)
+    nudged_loss = forward_fn(h_nudged)
+    delta = (nudged_loss - base_loss).mean().item()
+    return delta
+
+
+def offline_bp_cosine(a_l, bp_grad_l):
+    """
+    Compute the offline BP cosine Gamma_l.
+
+    Args:
+        a_l: (batch, d) credit signal
+        bp_grad_l: (batch, d) true BP gradient at layer l
+    """
+    return cosine_similarity_batch(a_l, bp_grad_l)
+
+
+def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list,
+                    t_l_next, lam=0.1):
+    """
+    Compute the bridge residual R_l.
+
+    Args:
+        V_phi: value network
+        V_bar_phi: EMA target value network
+        h_l: (batch, d)
+        t_l: (batch,)
+        s: (batch, s_dim)
+        h_l_next_noisy_list: list of K tensors (batch, d), noisy next states
+        t_l_next: (batch,)
+        lam: soft-min temperature
+
+    Returns:
+        mean absolute bridge residual
+    """
+    with torch.no_grad():
+        V_current = V_phi(h_l, t_l, s)  # (batch,)
+
+        # Soft-min target over the K noisy next states:
+        # softmin_lam(V) = -lam * log((1/K) * sum_k exp(-V_k / lam))
+        #               = -lam * logsumexp(-V / lam) + lam * log(K)
+        K = len(h_l_next_noisy_list)
+        log_terms = []
+        for h_next in h_l_next_noisy_list:
+            V_next = V_bar_phi(h_next, t_l_next, s)  # (batch,)
+            log_terms.append(-V_next / lam)
+
+        log_terms = torch.stack(log_terms, dim=-1)  # (batch, K)
+        V_target = -lam * torch.logsumexp(log_terms, dim=-1) + lam * np.log(K)
+
+        residual = (V_current - V_target).abs().mean().item()
+    return residual
+
+
+def feature_drift(model_init_params, model_final_params):
+    """
+    Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F.
+
+    Args:
+        model_init_params: dict of {name: tensor} initial parameters
+        model_final_params: dict of {name: tensor} final parameters
+
+    Returns:
+        dict of {name: drift_ratio}
+    """
+    drifts = {}
+    for name in model_init_params:
+        if name in model_final_params:
+            w_init = model_init_params[name]
+            w_final = model_final_params[name]
+            init_norm = w_init.norm().item()
+            if init_norm > 1e-8:
+                drift = (w_final - w_init).norm().item() / init_norm
+                drifts[name] = drift
+    return drifts
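
A minimal sanity check for the perturbation metrics (an illustrative sketch, not part of the diff; it assumes the file is importable as metrics.credit_metrics): for the toy quadratic loss 0.5 * ||h||^2 the exact per-sample gradient is h itself, so feeding a_l = h_l should give rho_l near 1, a negative nudge delta, and Gamma_l of exactly 1.

import torch
from metrics.credit_metrics import (
    perturbation_correlation, nudging_test, offline_bp_cosine)

torch.manual_seed(0)
h = torch.randn(64, 16)

def quad_loss(x):
    # Per-sample loss of shape (batch,), as the metrics expect.
    return 0.5 * (x ** 2).sum(dim=-1)

a = h.clone()  # exact gradient of quad_loss at h

print(perturbation_correlation(h, a, quad_loss))  # expect rho close to 1.0
print(nudging_test(h, a, quad_loss))              # expect a negative delta
print(offline_bp_cosine(a, h))                    # expect 1.0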

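The value-network metrics follow the same pattern. Below is a hedged sketch of calling bridge_residual and feature_drift; ToyValueNet is a hypothetical stand-in whose (h, t, s) -> (batch,) interface matches the docstring above, not the project's actual value network.

import torch
from metrics.credit_metrics import bridge_residual, feature_drift

class ToyValueNet(torch.nn.Module):
    # Placeholder value net with the (h, t, s) -> (batch,) signature.
    def __init__(self, d, s_dim):
        super().__init__()
        self.lin = torch.nn.Linear(d + 1 + s_dim, 1)

    def forward(self, h, t, s):
        x = torch.cat([h, t.unsqueeze(-1), s], dim=-1)
        return self.lin(x).squeeze(-1)

d, s_dim, K = 16, 4, 8
V = ToyValueNet(d, s_dim)
V_bar = ToyValueNet(d, s_dim)  # stands in for the EMA target network

h = torch.randn(32, d)
t = torch.zeros(32)
s = torch.randn(32, s_dim)
noisy_next = [h + 0.1 * torch.randn_like(h) for _ in range(K)]

print(bridge_residual(V, V_bar, h, t, s, noisy_next, t + 1.0, lam=0.1))

# feature_drift compares parameter snapshots taken before and after training:
init = {k: v.detach().clone() for k, v in V.named_parameters()}
# ... train V here ...
final = dict(V.named_parameters())
print(feature_drift(init, final))  # all zeros until V is actually trained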