summaryrefslogtreecommitdiff
path: root/resulets/adapt/fit_theta_lm_head_update.py
diff options
context:
space:
mode:
Diffstat (limited to 'resulets/adapt/fit_theta_lm_head_update.py')
-rw-r--r--resulets/adapt/fit_theta_lm_head_update.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/resulets/adapt/fit_theta_lm_head_update.py b/resulets/adapt/fit_theta_lm_head_update.py
new file mode 100644
index 0000000..7a38e62
--- /dev/null
+++ b/resulets/adapt/fit_theta_lm_head_update.py
@@ -0,0 +1,102 @@
+"""Fit theta_u for a low-rank LM-head weight update.
+
+The update is W'_u = W + gamma * alpha * C diag(theta_u) A, so the
+per-token logit correction depends on the current hidden state.
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+CHUNK_SIZE = 16  # Rows of cached hidden states moved to the device and back-propagated per chunk.
+
+
def _backward_chunked_ce_kl(
    h_cpu, lm_w, lm_bias, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens,
    chunk_size=None,
):
    """Back-propagate the CE + beta * KL objective for one cached sequence.

    The rows of ``h_cpu`` are processed in slices of ``chunk_size`` so that
    only one slice of logits is alive on ``device`` at a time.  ``backward()``
    is called once per slice, so gradients accumulate into every leaf that
    requires grad (normally just ``theta``); nothing is zeroed here.

    Args:
        h_cpu: Cached hidden states, sliced along dim 0; each slice is moved
            to ``device`` and cast to float32.
        lm_w: LM-head weight passed to ``F.linear``.
        lm_bias: LM-head bias passed to ``F.linear``, or ``None``.
        y_cpu: Targets, sliced with the same indices as ``h_cpu``.
            Presumably integer class ids -- TODO confirm against the caller.
        head_update: Object exposing ``logit_delta(h_chunk, theta)``; the
            result is added to the base logits, so it is assumed to have the
            same shape as them -- TODO confirm.
        theta: Tensor being fitted (forwarded to ``logit_delta``).
        beta: Weight of the KL term; the KL term is skipped when ``beta <= 0``.
        blend_gamma: Scale applied to the logit correction.
        device: Device the slices are moved to.
        total_tokens: Normaliser for the back-propagated loss (clamped to at
            least 1).
        chunk_size: Rows per slice.  ``None`` (default) uses the module-level
            ``CHUNK_SIZE``.

    Returns:
        ``(ce_sum, kl_sum)``: un-normalised summed cross-entropy and summed
        KL(p_base || p_updated) over all rows, as Python floats.  ``kl_sum``
        is ``0.0`` when ``beta <= 0``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    num_rows = h_cpu.shape[0]
    scale = max(total_tokens, 1)
    use_kl = beta > 0

    total_ce_value = 0.0
    total_kl_value = 0.0

    for start in range(0, num_rows, chunk_size):
        end = min(start + chunk_size, num_rows)
        h_chunk = h_cpu[start:end].to(device).float()
        y_chunk = y_cpu[start:end].to(device)

        base_logits = F.linear(h_chunk, lm_w, lm_bias)
        delta_logits = head_update.logit_delta(h_chunk, theta)
        logits = base_logits + blend_gamma * delta_logits

        if use_kl:
            # One log-softmax serves both terms: cross_entropy is
            # log_softmax followed by nll_loss.
            log_p = F.log_softmax(logits, dim=-1)
            ce = F.nll_loss(log_p, y_chunk, reduction='sum')
            # The base distribution is a fixed target; no gradient flows into it.
            p0 = F.softmax(base_logits.detach(), dim=-1)
            # kl_div(input=log q, target=p) sums p * (log p - log q) = KL(p0 || p).
            kl = F.kl_div(log_p, p0, reduction='sum')
            loss = ce + beta * kl
        else:
            ce = F.cross_entropy(logits, y_chunk, reduction='sum')
            loss = ce

        # Backward per chunk so the chunk's graph is freed before the next one.
        (loss / scale).backward()

        total_ce_value += ce.item()
        if use_kl:
            total_kl_value += kl.item()

    return total_ce_value, total_kl_value
+
+
+def fit_theta_lm_head_update(
+ cached_h: list,
+ lm_head_weight: torch.Tensor,
+ lm_head_bias: torch.Tensor | None,
+ head_update,
+ d: int = 64,
+ lr: float = 0.05,
+ steps: int = 30,
+ beta: float = 0.05,
+ lam: float = 1e-4,
+ blend_gamma: float = 0.5,
+ max_grad_norm: float = 5.0,
+ device: str = "cuda:0",
+ verbose: bool = False,
+) -> torch.Tensor:
+ """Fit the user vector theta_u for an LM-head update."""
+ theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
+ optimizer = torch.optim.Adam([theta], lr=lr)
+
+ lm_w = lm_head_weight.float()
+ lm_b = lm_head_bias.float() if lm_head_bias is not None else None
+
+ total_tokens = sum(y_cpu.shape[0] for _, y_cpu in cached_h)
+
+ for step in range(steps):
+ total_ce_value = 0.0
+ total_kl_value = 0.0
+ optimizer.zero_grad()
+
+ for h_cpu, y_cpu in cached_h:
+ ce_value, kl_value = _backward_chunked_ce_kl(
+ h_cpu, lm_w, lm_b, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens
+ )
+ total_ce_value += ce_value
+ total_kl_value += kl_value
+
+ reg = lam * theta.square().sum()
+ reg.backward()
+ torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
+ optimizer.step()
+
+ with torch.no_grad():
+ norm = theta.norm()
+ if norm > max_grad_norm:
+ theta.mul_(max_grad_norm / norm)
+
+ if verbose and (step % 10 == 0 or step == steps - 1):
+ loss_value = (total_ce_value + beta * total_kl_value) / max(total_tokens, 1) + float(reg.detach().cpu())
+ print(f" Step {step:3d}: loss={loss_value:.4f}, |theta|={theta.norm().item():.4f}")
+
+ del reg
+
+ return theta.detach()