"""Fit theta_u on cached hidden states. Loss = CE(lm_head(h + alpha * B(theta ⊙ A@h)), y) + beta * KL(p_theta || p_0) + lambda * ||theta||^2 """ import torch import torch.nn.functional as F # Maximum chunk size for logit computation to avoid OOM CHUNK_SIZE = 128 def _chunked_ce_kl(h_prime, h_base, lm_w, lm_bias, y, beta): """Compute CE + KL in chunks to avoid OOM from huge vocab logits.""" seq_len = h_prime.shape[0] total_ce = 0.0 total_kl = 0.0 for start in range(0, seq_len, CHUNK_SIZE): end = min(start + CHUNK_SIZE, seq_len) hp_chunk = h_prime[start:end] hb_chunk = h_base[start:end] y_chunk = y[start:end] logits = F.linear(hp_chunk, lm_w, lm_bias) base_logits = F.linear(hb_chunk, lm_w, lm_bias) total_ce = total_ce + F.cross_entropy(logits, y_chunk, reduction='sum') if beta > 0: log_p = F.log_softmax(logits, dim=-1) p0 = F.softmax(base_logits.detach(), dim=-1) total_kl = total_kl + F.kl_div(log_p, p0, reduction='sum') # Free intermediates del logits, base_logits if beta > 0: del log_p, p0 return total_ce, total_kl def fit_theta( cached_h: list, # list of (h_states: (T_i, H), label_ids: (T_i,)) lm_head_weight: torch.Tensor, # (vocab_size, H) lm_head_bias: torch.Tensor | None, head_module, # CVHHead or UnconditionalHead d: int = 64, lr: float = 0.05, steps: int = 30, beta: float = 0.05, lam: float = 1e-4, max_grad_norm: float = 5.0, device: str = "cuda:1", verbose: bool = False, ) -> torch.Tensor: """Fit a user vector theta_u on cached hidden states. Memory-efficient: computes logits in chunks, no pre-computation of base logits. """ theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32) optimizer = torch.optim.Adam([theta], lr=lr) lm_w = lm_head_weight.float() lm_b = lm_head_bias.float() if lm_head_bias is not None else None for step in range(steps): total_loss = 0.0 total_tokens = 0 for h_cpu, y_cpu in cached_h: h = h_cpu.to(device).float() y = y_cpu.to(device) # Apply head to get personalized hidden states h_prime = head_module(h, theta) # Compute CE + KL in chunks ce, kl = _chunked_ce_kl(h_prime, h.detach(), lm_w, lm_b, y, beta) total_loss = total_loss + ce + beta * kl total_tokens += y.shape[0] # Free GPU memory del h, y, h_prime # Average over tokens + L2 reg loss = total_loss / max(total_tokens, 1) + lam * theta.square().sum() optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm) optimizer.step() # Clip theta L2 norm with torch.no_grad(): norm = theta.norm() if norm > max_grad_norm: theta.mul_(max_grad_norm / norm) if verbose and (step % 10 == 0 or step == steps - 1): print(f" Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}") # Free graph del total_loss, loss return theta.detach()