From 896df7f11b441a9b8dfa50820024a82884da58d0 Mon Sep 17 00:00:00 2001 From: BLUESKY477 Date: Fri, 22 May 2026 19:23:44 -0500 Subject: Add files via upload --- resulets/adapt/fit_theta_lm_head_update.py | 102 +++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 resulets/adapt/fit_theta_lm_head_update.py (limited to 'resulets/adapt/fit_theta_lm_head_update.py') diff --git a/resulets/adapt/fit_theta_lm_head_update.py b/resulets/adapt/fit_theta_lm_head_update.py new file mode 100644 index 0000000..7a38e62 --- /dev/null +++ b/resulets/adapt/fit_theta_lm_head_update.py @@ -0,0 +1,102 @@ +"""Fit theta_u for a low-rank LM-head weight update. + +The update is W'_u = W + gamma * alpha * C diag(theta_u) A, so the +per-token logit correction depends on the current hidden state. +""" + +import torch +import torch.nn.functional as F + + +CHUNK_SIZE = 16 + + +def _backward_chunked_ce_kl( + h_cpu, lm_w, lm_bias, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens +): + total_ce_value = 0.0 + total_kl_value = 0.0 + + for start in range(0, h_cpu.shape[0], CHUNK_SIZE): + end = min(start + CHUNK_SIZE, h_cpu.shape[0]) + h_chunk = h_cpu[start:end].to(device).float() + y_chunk = y_cpu[start:end].to(device) + + base_logits = F.linear(h_chunk, lm_w, lm_bias) + delta_logits = head_update.logit_delta(h_chunk, theta) + logits = base_logits + blend_gamma * delta_logits + + ce = F.cross_entropy(logits, y_chunk, reduction='sum') + + if beta > 0: + log_p = F.log_softmax(logits, dim=-1) + p0 = F.softmax(base_logits.detach(), dim=-1) + kl = F.kl_div(log_p, p0, reduction='sum') + else: + kl = torch.zeros((), device=device) + + ((ce + beta * kl) / max(total_tokens, 1)).backward() + total_ce_value += float(ce.detach().cpu()) + total_kl_value += float(kl.detach().cpu()) + + if beta > 0: + del log_p, p0 + + del h_chunk, y_chunk, base_logits, delta_logits, logits, ce, kl + + return total_ce_value, total_kl_value + + +def fit_theta_lm_head_update( + cached_h: list, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor | None, + head_update, + d: int = 64, + lr: float = 0.05, + steps: int = 30, + beta: float = 0.05, + lam: float = 1e-4, + blend_gamma: float = 0.5, + max_grad_norm: float = 5.0, + device: str = "cuda:0", + verbose: bool = False, +) -> torch.Tensor: + """Fit the user vector theta_u for an LM-head update.""" + theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32) + optimizer = torch.optim.Adam([theta], lr=lr) + + lm_w = lm_head_weight.float() + lm_b = lm_head_bias.float() if lm_head_bias is not None else None + + total_tokens = sum(y_cpu.shape[0] for _, y_cpu in cached_h) + + for step in range(steps): + total_ce_value = 0.0 + total_kl_value = 0.0 + optimizer.zero_grad() + + for h_cpu, y_cpu in cached_h: + ce_value, kl_value = _backward_chunked_ce_kl( + h_cpu, lm_w, lm_b, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens + ) + total_ce_value += ce_value + total_kl_value += kl_value + + reg = lam * theta.square().sum() + reg.backward() + torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm) + optimizer.step() + + with torch.no_grad(): + norm = theta.norm() + if norm > max_grad_norm: + theta.mul_(max_grad_norm / norm) + + if verbose and (step % 10 == 0 or step == steps - 1): + loss_value = (total_ce_value + beta * total_kl_value) / max(total_tokens, 1) + float(reg.detach().cpu()) + print(f" Step {step:3d}: loss={loss_value:.4f}, |theta|={theta.norm().item():.4f}") + + del reg + + return theta.detach() -- cgit v1.2.3