""" Credit assignment diagnostic metrics: 1. Exact costate cosine (for toy LQ) 2. Local perturbation correlation rho_l 3. Nudging test Delta_l^nudge 4. Offline BP cosine Gamma_l 5. Bridge residual R_l 6. Feature drift M_l """ import torch import torch.nn.functional as F import numpy as np from scipy.stats import pearsonr def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Compute cosine similarity between a and b along last dim, averaged over batch.""" a_flat = a.reshape(a.shape[0], -1) b_flat = b.reshape(b.shape[0], -1) cos = F.cosine_similarity(a_flat, b_flat, dim=-1) return cos.mean().item() def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32): """ Compute local perturbation correlation rho_l. Args: h_l: (batch, d) hidden state at layer l a_l: (batch, d) credit signal at layer l forward_fn: callable that takes h_l -> scalar loss (averaged over batch dims handled inside) epsilon: perturbation magnitude M: number of random directions Returns: rho: Pearson correlation between predicted and true loss changes """ batch_size, d = h_l.shape device = h_l.device pred_list = [] true_list = [] base_loss = forward_fn(h_l) # (batch,) or scalar for _ in range(M): v = torch.randn(batch_size, d, device=device) v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) # Predicted change: delta_pred = (a_l * (epsilon * v)).sum(dim=-1) # (batch,) # True change: forward from perturbed h perturbed_loss = forward_fn(h_l + epsilon * v) # (batch,) delta_true = perturbed_loss - base_loss # (batch,) pred_list.append(delta_pred.detach().cpu().numpy()) true_list.append(delta_true.detach().cpu().numpy()) pred_arr = np.concatenate(pred_list) true_arr = np.concatenate(true_list) if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12: return 0.0 rho, _ = pearsonr(pred_arr, true_arr) return float(rho) def nudging_test(h_l, a_l, forward_fn, eta=0.01): """ Nudging test: check if moving h_l in -a_l direction decreases loss. Args: h_l: (batch, d) hidden state a_l: (batch, d) credit signal forward_fn: callable h -> loss per sample (batch,) eta: step size Returns: mean delta_nudge (negative is good) """ rms_a = (a_l ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_normed = a_l / rms_a h_nudged = h_l - eta * a_normed base_loss = forward_fn(h_l) nudged_loss = forward_fn(h_nudged) delta = (nudged_loss - base_loss).mean().item() return delta def offline_bp_cosine(a_l, bp_grad_l): """ Compute offline BP cosine similarity. a_l: (batch, d) credit signal bp_grad_l: (batch, d) true BP gradient at layer l """ return cosine_similarity_batch(a_l, bp_grad_l) def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list, t_l_next, lam=0.1): """ Compute bridge residual R_l. 
def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s,
                    h_l_next_noisy_list, t_l_next, lam=0.1):
    """
    Compute the bridge residual R_l.

    Args:
        V_phi: value network
        V_bar_phi: EMA target value network
        h_l: (batch, d)
        t_l: (batch,)
        s: (batch, s_dim)
        h_l_next_noisy_list: list of K tensors (batch, d), noisy next states
        t_l_next: (batch,)
        lam: temperature

    Returns:
        Mean absolute bridge residual.
    """
    with torch.no_grad():
        V_current = V_phi(h_l, t_l, s)  # (batch,)
        # Soft-min target over the K noisy next states:
        # V_target = -lam * log( (1/K) * sum_k exp(-V_next_k / lam) )
        K = len(h_l_next_noisy_list)
        log_terms = []
        for h_next in h_l_next_noisy_list:
            V_next = V_bar_phi(h_next, t_l_next, s)  # (batch,)
            log_terms.append(-V_next / lam)
        log_terms = torch.stack(log_terms, dim=-1)  # (batch, K)
        V_target = -lam * torch.logsumexp(log_terms, dim=-1) + lam * np.log(K)
        residual = (V_current - V_target).abs().mean().item()
    return residual


def feature_drift(model_init_params, model_final_params):
    """
    Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F.

    Args:
        model_init_params: dict of {name: tensor}, initial parameters
        model_final_params: dict of {name: tensor}, final parameters

    Returns:
        dict of {name: drift_ratio}
    """
    drifts = {}
    for name in model_init_params:
        if name in model_final_params:
            w_init = model_init_params[name]
            w_final = model_final_params[name]
            init_norm = w_init.norm().item()
            # Skip parameters with (near-)zero initial norm to avoid division by zero.
            if init_norm > 1e-8:
                drift = (w_final - w_init).norm().item() / init_norm
                drifts[name] = drift
    return drifts
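
# --- Illustrative usage sketch (an added example, not part of the original suite). ---
# bridge_residual expects V_phi / V_bar_phi to be callables of (h, t, s) -> (batch,).
# The tiny value function below is a hypothetical stand-in for the real networks;
# it only demonstrates the expected shapes and call pattern.

if __name__ == "__main__":
    rho, delta = _sanity_check_local_metrics()
    print(f"toy quadratic: rho_l = {rho:.3f}, Delta_nudge = {delta:.5f}")

    batch, d, s_dim, K = 4, 8, 3, 5

    def toy_value_net(h, t, s):
        # Hypothetical value function: returns a (batch,) tensor.
        return (h ** 2).sum(dim=-1) + t + s.sum(dim=-1)

    h_l = torch.randn(batch, d)
    t_l = torch.zeros(batch)
    t_next = torch.ones(batch)
    s = torch.randn(batch, s_dim)
    noisy_next = [h_l + 0.1 * torch.randn(batch, d) for _ in range(K)]

    # Reusing the same net for V_phi and its EMA target keeps the sketch self-contained.
    R_l = bridge_residual(toy_value_net, toy_value_net, h_l, t_l, s, noisy_next, t_next)
    print(f"bridge residual on toy value net: {R_l:.4f}")

    # feature_drift compares two parameter snapshots by name.
    init = {"layer.weight": torch.randn(4, 4)}
    final = {"layer.weight": init["layer.weight"] + 0.05 * torch.randn(4, 4)}
    print("feature drift:", feature_drift(init, final))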