"""Fit theta_u for a low-rank LM-head weight update. The update is W'_u = W + gamma * alpha * C diag(theta_u) A, so the per-token logit correction depends on the current hidden state. """ import torch import torch.nn.functional as F CHUNK_SIZE = 16 def _backward_chunked_ce_kl( h_cpu, lm_w, lm_bias, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens ): total_ce_value = 0.0 total_kl_value = 0.0 for start in range(0, h_cpu.shape[0], CHUNK_SIZE): end = min(start + CHUNK_SIZE, h_cpu.shape[0]) h_chunk = h_cpu[start:end].to(device).float() y_chunk = y_cpu[start:end].to(device) base_logits = F.linear(h_chunk, lm_w, lm_bias) delta_logits = head_update.logit_delta(h_chunk, theta) logits = base_logits + blend_gamma * delta_logits ce = F.cross_entropy(logits, y_chunk, reduction='sum') if beta > 0: log_p = F.log_softmax(logits, dim=-1) p0 = F.softmax(base_logits.detach(), dim=-1) kl = F.kl_div(log_p, p0, reduction='sum') else: kl = torch.zeros((), device=device) ((ce + beta * kl) / max(total_tokens, 1)).backward() total_ce_value += float(ce.detach().cpu()) total_kl_value += float(kl.detach().cpu()) if beta > 0: del log_p, p0 del h_chunk, y_chunk, base_logits, delta_logits, logits, ce, kl return total_ce_value, total_kl_value def fit_theta_lm_head_update( cached_h: list, lm_head_weight: torch.Tensor, lm_head_bias: torch.Tensor | None, head_update, d: int = 64, lr: float = 0.05, steps: int = 30, beta: float = 0.05, lam: float = 1e-4, blend_gamma: float = 0.5, max_grad_norm: float = 5.0, device: str = "cuda:0", verbose: bool = False, ) -> torch.Tensor: """Fit the user vector theta_u for an LM-head update.""" theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32) optimizer = torch.optim.Adam([theta], lr=lr) lm_w = lm_head_weight.float() lm_b = lm_head_bias.float() if lm_head_bias is not None else None total_tokens = sum(y_cpu.shape[0] for _, y_cpu in cached_h) for step in range(steps): total_ce_value = 0.0 total_kl_value = 0.0 optimizer.zero_grad() for h_cpu, y_cpu in cached_h: ce_value, kl_value = _backward_chunked_ce_kl( h_cpu, lm_w, lm_b, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens ) total_ce_value += ce_value total_kl_value += kl_value reg = lam * theta.square().sum() reg.backward() torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm) optimizer.step() with torch.no_grad(): norm = theta.norm() if norm > max_grad_norm: theta.mul_(max_grad_norm / norm) if verbose and (step % 10 == 0 or step == steps - 1): loss_value = (total_ce_value + beta * total_kl_value) / max(total_tokens, 1) + float(reg.detach().cpu()) print(f" Step {step:3d}: loss={loss_value:.4f}, |theta|={theta.norm().item():.4f}") del reg return theta.detach()