summaryrefslogtreecommitdiff
path: root/resulets/baselines/logit_bias.py
diff options
context:
space:
mode:
authorBLUESKY477 <abcd15803148000@163.com>2026-05-22 19:23:44 -0500
committerGitHub <noreply@github.com>2026-05-22 19:23:44 -0500
commit896df7f11b441a9b8dfa50820024a82884da58d0 (patch)
tree0182ae4a7a0bb16ee6a764393838a580e1ba1c31 /resulets/baselines/logit_bias.py
parent6f48c4fae3243e280b27a977c6a8cb731becf446 (diff)
Add files via uploadHEADmaster
Diffstat (limited to 'resulets/baselines/logit_bias.py')
-rw-r--r--resulets/baselines/logit_bias.py286
1 files changed, 286 insertions, 0 deletions
diff --git a/resulets/baselines/logit_bias.py b/resulets/baselines/logit_bias.py
new file mode 100644
index 0000000..a219d04
--- /dev/null
+++ b/resulets/baselines/logit_bias.py
@@ -0,0 +1,286 @@
+"""Static user logit-bias baselines.
+
+These baselines test whether UPH's gains can be explained by a trivial
+user-specific vocabulary prior. They do not modify hidden states or model
+weights; generation only receives an additive bias on the output logits.
+"""
+
+import torch
+import torch.nn.functional as F
+from transformers import LogitsProcessor, LogitsProcessorList
+
+
+CHUNK_SIZE = 32
+
+
class StaticBiasLogitsProcessor(LogitsProcessor):
    """Add a fixed user-specific vector to next-token scores.

    ``bias`` is indexed by token id. When its length differs from the last
    dimension of ``scores`` it is truncated, or zero-padded on the right, to
    that width before being added.

    The resized, device- and dtype-matched bias is cached so the conversion
    happens once rather than on every decoding step. The cache is rebuilt when
    ``self.bias`` is reassigned or when the score width, device or dtype
    changes. In-place edits to the same tensor object are NOT detected;
    assign a new tensor to ``self.bias`` instead.
    """

    # Class-level defaults keep the cache lookups valid even when an instance
    # is created without running ``__init__``.
    _cached_source = None
    _cached_key = None
    _cached_bias = None

    def __init__(self, bias: torch.Tensor):
        self.bias = bias

    def _prepared_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Return the bias as a ``(1, width)`` tensor matching ``scores``."""
        source = self.bias
        width = scores.shape[-1]
        key = (width, scores.device, scores.dtype)
        if self._cached_source is source and self._cached_key == key:
            return self._cached_bias

        resized = source
        if source.numel() != width:
            # Tokens beyond the bias length get zero bias; bias entries
            # beyond the score width are dropped.
            resized = torch.zeros(width, dtype=source.dtype, device=source.device)
            n = min(source.numel(), width)
            resized[:n] = source[:n]

        prepared = resized.to(device=scores.device, dtype=scores.dtype).unsqueeze(0)
        self._cached_source = source
        self._cached_key = key
        self._cached_bias = prepared
        return prepared

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return ``scores`` plus the bias, broadcast over the leading dimension.

        ``input_ids`` is accepted for interface compatibility and is not used.
        """
        return scores + self._prepared_bias(scores)
+
+
+def _encode_output(tokenizer, text: str) -> list[int]:
+ return tokenizer.encode(text or "", add_special_tokens=False)
+
+
+def _count_tokens(tokenizer, texts: list[str], vocab_size: int, smoothing: float = 0.1):
+ counts = torch.full((vocab_size,), smoothing, dtype=torch.float32)
+ for text in texts:
+ ids = _encode_output(tokenizer, text)
+ if ids:
+ binc = torch.bincount(torch.tensor(ids, dtype=torch.long), minlength=vocab_size)
+ counts += binc.float()
+ return counts
+
+
def build_global_log_probs(
    tokenizer,
    support_sets: list[list[dict]],
    smoothing: float = 0.1,
    vocab_size: int | None = None,
):
    """Estimate the background token distribution from support outputs only.

    Every non-empty ``"support_output"`` string across all support sets is
    tokenized and counted with additive ``smoothing``; the result is the log
    of the normalized counts. ``vocab_size`` defaults to ``len(tokenizer)``.
    """
    if vocab_size is None:
        vocab_size = len(tokenizer)

    texts = []
    for support in support_sets:
        for item in support:
            # Skip items whose output is missing or empty.
            if item.get("support_output"):
                texts.append(item["support_output"])

    counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing)
    probs = counts / counts.sum()
    return probs.log()
+
+
def build_user_unigram_bias(
    tokenizer,
    support_items: list[dict],
    global_log_probs: torch.Tensor,
    vocab_size: int | None = None,
    top_m: int = 512,
    scale: float = 0.5,
    smoothing: float = 0.1,
    only_positive: bool = True,
):
    """Build a sparse log-odds vocabulary bias from one user's support outputs.

    The user's smoothed unigram log-probabilities are compared with
    ``global_log_probs``. The ``top_m`` largest log-odds that are finite and
    strictly positive are kept and multiplied by ``scale``.

    Args:
        tokenizer: Used for tokenization and for its pad/eos/bos token ids,
            which are excluded from the bias.
        support_items: Dicts whose non-empty ``"support_output"`` is counted.
        global_log_probs: Background log-probabilities; also supplies
            ``vocab_size`` when that argument is ``None``.
        vocab_size: Length of the returned bias vector.
        top_m: Maximum number of biased tokens.
        scale: Multiplier applied to the kept log-odds.
        smoothing: Additive pseudo-count for the user's token counts.
        only_positive: Clamp log-odds at zero before selection. The final
            filter keeps only values ``> 0`` either way, so negative log-odds
            never reach the returned bias.

    Returns:
        ``(bias, token_ids)``: a float32 tensor of shape ``(vocab_size,)``
        that is zero outside the selected tokens, and the selected token ids
        as a list.
    """
    if vocab_size is None:
        vocab_size = global_log_probs.numel()

    texts = []
    for item in support_items:
        if item.get("support_output"):
            texts.append(item["support_output"])

    counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing)
    log_odds = (counts / counts.sum()).log() - global_log_probs

    # Special tokens must never be boosted: push them to -inf before selection.
    for attr in ("pad_token_id", "eos_token_id", "bos_token_id"):
        tid = getattr(tokenizer, attr, None)
        if tid is not None and 0 <= tid < log_odds.numel():
            log_odds[tid] = float("-inf")

    if only_positive:
        log_odds = log_odds.clamp(min=0.0)

    values, token_ids = log_odds.topk(k=min(top_m, log_odds.numel()))
    keep = values.isfinite() & values.gt(0)
    values, token_ids = values[keep], token_ids[keep]

    bias = torch.zeros(vocab_size, dtype=torch.float32)
    if token_ids.numel() > 0:
        bias[token_ids] = scale * values
    return bias, token_ids.tolist()
+
+
def generate_with_logit_bias(
    wrapper,
    prompt: str,
    bias: torch.Tensor,
    max_new_tokens: int = 512,
    min_new_tokens: int = 128,
    temperature: float = 0.0,
):
    """Generate with a fixed additive bias on the model's next-token logits.

    ``prompt`` is wrapped in a two-message chat (fixed system message plus
    the user prompt), rendered with the tokenizer's chat template, and passed
    to ``wrapper.model.generate`` with a ``StaticBiasLogitsProcessor``.
    Sampling is enabled only when ``temperature > 0``; otherwise decoding is
    non-sampling and ``temperature`` is passed as ``None``.

    Returns the decoded continuation (prompt tokens stripped, special tokens
    skipped).
    """
    tokenizer = wrapper.tokenizer
    messages = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    input_ids = tokenizer.encode(rendered, return_tensors="pt").to(wrapper.device)

    sampling = temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature if sampling else None,
        "top_p": None,
        "do_sample": sampling,
        "pad_token_id": tokenizer.pad_token_id,
        "logits_processor": LogitsProcessorList([StaticBiasLogitsProcessor(bias)]),
    }
    with torch.no_grad():
        outputs = wrapper.model.generate(input_ids, **generation_kwargs)

    # Only the first sequence is used; drop the echoed prompt tokens.
    prompt_length = input_ids.shape[1]
    continuation = outputs[0, prompt_length:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
+
+
def _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids, token_id_to_pos):
    """Cache the bias-independent quantities needed by the sparse-bias loss.

    For rows of ``h`` taken ``CHUNK_SIZE`` at a time, the full logits
    ``F.linear(h, lm_w, lm_b)`` are computed once and reduced to:

    * ``selected_logits``: logits at ``token_ids``;
    * ``base_lse``: logsumexp over all logits;
    * ``non_selected_lse``: logsumexp over the logits NOT in ``token_ids``;
    * ``target_base_logits``: the logit of each target ``y``;
    * ``target_selected_pos``: ``token_id_to_pos[y]``;
    * ``num_tokens``: ``y.shape[0]``.

    All tensors in the returned dict are moved to the CPU. Presumably ``h``
    is ``(num_tokens, hidden)`` and ``y`` is a 1-D tensor of target ids of
    the same length -- confirm against the caller.
    """
    buffers = {
        "selected_logits": [],
        "base_lse": [],
        "non_selected_lse": [],
        "target_base_logits": [],
        "target_selected_pos": [],
    }
    num_rows = h.shape[0]

    with torch.no_grad():
        for lo in range(0, num_rows, CHUNK_SIZE):
            hi = lo + CHUNK_SIZE  # slicing clamps at the end of the tensor
            targets = y[lo:hi]

            logits = F.linear(h[lo:hi], lm_w, lm_b)
            picked = logits[:, token_ids]
            full_lse = logits.logsumexp(dim=-1)
            picked_lse = picked.logsumexp(dim=-1)

            # Logsumexp over the tokens that do not receive the learned bias,
            # derived from the probability mass of the selected tokens. This
            # lets each optimization step skip the full-vocabulary logits.
            # The clamp keeps log1p away from log(0).
            picked_mass = (picked_lse - full_lse).exp().clamp(max=1.0 - 1e-7)
            remainder_lse = full_lse + torch.log1p(-picked_mass)

            target_logit = logits.gather(1, targets.unsqueeze(1)).squeeze(1)

            buffers["selected_logits"].append(picked.detach().cpu())
            buffers["base_lse"].append(full_lse.detach().cpu())
            buffers["non_selected_lse"].append(remainder_lse.detach().cpu())
            buffers["target_base_logits"].append(target_logit.detach().cpu())
            buffers["target_selected_pos"].append(token_id_to_pos[targets].detach().cpu())

    terms = {name: torch.cat(chunks, dim=0) for name, chunks in buffers.items()}
    terms["num_tokens"] = y.shape[0]
    return terms
+
+
+def _sparse_bias_loss_from_terms(terms, bias_params, beta, device):
+ selected_logits = terms["selected_logits"].to(device)
+ base_lse = terms["base_lse"].to(device)
+ non_selected_lse = terms["non_selected_lse"].to(device)
+ target_base_logits = terms["target_base_logits"].to(device)
+ target_selected_pos = terms["target_selected_pos"].to(device)
+
+ adjusted_selected = selected_logits + bias_params.unsqueeze(0)
+ selected_adjusted_lse = torch.logsumexp(adjusted_selected, dim=-1)
+ denom = torch.logaddexp(non_selected_lse, selected_adjusted_lse)
+
+ target_logits = target_base_logits
+ selected_mask = target_selected_pos >= 0
+ if selected_mask.any():
+ target_logits = target_logits.clone()
+ selected_rows = selected_mask.nonzero(as_tuple=False).squeeze(1)
+ selected_cols = target_selected_pos[selected_rows]
+ target_logits[selected_rows] = adjusted_selected[selected_rows, selected_cols]
+
+ ce = (denom - target_logits).sum()
+
+ if beta > 0:
+ selected_base_probs = torch.exp(selected_logits - base_lse.unsqueeze(1))
+ selected_bias_expectation = (selected_base_probs * bias_params.unsqueeze(0)).sum(dim=-1)
+ kl = (denom - base_lse - selected_bias_expectation).sum()
+ else:
+ kl = torch.zeros((), device=device, dtype=ce.dtype)
+
+ del selected_logits, base_lse, non_selected_lse
+ del target_base_logits, target_selected_pos, adjusted_selected
+ return ce, kl
+
+
+def fit_sparse_logit_bias(
+ cached_h: list,
+ lm_head_weight: torch.Tensor,
+ lm_head_bias: torch.Tensor | None,
+ token_ids: list[int],
+ vocab_size: int,
+ init_values: torch.Tensor | None = None,
+ lr: float = 0.05,
+ steps: int = 30,
+ beta: float = 0.05,
+ lam: float = 1e-4,
+ max_grad_norm: float = 5.0,
+ device: str = "cuda:0",
+ verbose: bool = False,
+):
+ """Fit a sparse vocabulary bias on support hidden states."""
+ if not cached_h or not token_ids:
+ return torch.zeros(vocab_size, dtype=torch.float32), 0
+
+ token_ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=device)
+ token_id_to_pos = torch.full((vocab_size,), -1, dtype=torch.long, device=device)
+ token_id_to_pos[token_ids_tensor] = torch.arange(len(token_ids), device=device)
+ if init_values is None:
+ bias_params = torch.zeros(len(token_ids), device=device, dtype=torch.float32)
+ else:
+ bias_params = init_values.to(device=device, dtype=torch.float32).clone()
+ bias_params.requires_grad_(True)
+
+ optimizer = torch.optim.Adam([bias_params], lr=lr)
+ lm_w = lm_head_weight.float()
+ lm_b = lm_head_bias.float() if lm_head_bias is not None else None
+ precomputed = []
+ total_tokens = 0
+
+ for h_cpu, y_cpu in cached_h:
+ h = h_cpu.to(device).float()
+ y = y_cpu.to(device)
+ terms = _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids_tensor, token_id_to_pos)
+ precomputed.append(terms)
+ total_tokens += terms["num_tokens"]
+ del h, y
+
+ for step in range(steps):
+ total_loss = 0.0
+
+ for terms in precomputed:
+ ce, kl = _sparse_bias_loss_from_terms(terms, bias_params, beta, device)
+ total_loss = total_loss + ce + beta * kl
+
+ loss = total_loss / max(total_tokens, 1) + lam * bias_params.square().sum()
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_([bias_params], max_norm=max_grad_norm)
+ optimizer.step()
+
+ with torch.no_grad():
+ norm = bias_params.norm()
+ if norm > max_grad_norm:
+ bias_params.mul_(max_grad_norm / norm)
+
+ if verbose and (step % 10 == 0 or step == steps - 1):
+ print(f" Step {step:3d}: loss={loss.item():.4f}, |bias|={bias_params.norm().item():.4f}")
+
+ del total_loss, loss
+
+ bias = torch.zeros(vocab_size, dtype=torch.float32)
+ bias[token_ids] = bias_params.detach().cpu()
+ return bias, len(token_ids)