"""Static user logit-bias baselines. These baselines test whether UPH's gains can be explained by a trivial user-specific vocabulary prior. They do not modify hidden states or model weights; generation only receives an additive bias on the output logits. """ import torch import torch.nn.functional as F from transformers import LogitsProcessor, LogitsProcessorList CHUNK_SIZE = 32 class StaticBiasLogitsProcessor(LogitsProcessor): """Add a fixed user-specific vector to next-token scores.""" def __init__(self, bias: torch.Tensor): self.bias = bias def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): bias = self.bias if bias.numel() != scores.shape[-1]: adjusted = torch.zeros(scores.shape[-1], dtype=bias.dtype) n = min(bias.numel(), scores.shape[-1]) adjusted[:n] = bias[:n] bias = adjusted return scores + bias.to(device=scores.device, dtype=scores.dtype).unsqueeze(0) def _encode_output(tokenizer, text: str) -> list[int]: return tokenizer.encode(text or "", add_special_tokens=False) def _count_tokens(tokenizer, texts: list[str], vocab_size: int, smoothing: float = 0.1): counts = torch.full((vocab_size,), smoothing, dtype=torch.float32) for text in texts: ids = _encode_output(tokenizer, text) if ids: binc = torch.bincount(torch.tensor(ids, dtype=torch.long), minlength=vocab_size) counts += binc.float() return counts def build_global_log_probs( tokenizer, support_sets: list[list[dict]], smoothing: float = 0.1, vocab_size: int | None = None, ): """Estimate the background token distribution from support outputs only.""" if vocab_size is None: vocab_size = len(tokenizer) texts = [ item["support_output"] for support in support_sets for item in support if item.get("support_output") ] counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing) return torch.log(counts / counts.sum()) def build_user_unigram_bias( tokenizer, support_items: list[dict], global_log_probs: torch.Tensor, vocab_size: int | None = None, top_m: int = 512, scale: float = 0.5, smoothing: float = 0.1, only_positive: bool = True, ): """Build a sparse log-odds vocabulary bias from one user's support outputs.""" if vocab_size is None: vocab_size = global_log_probs.numel() texts = [item["support_output"] for item in support_items if item.get("support_output")] counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing) user_log_probs = torch.log(counts / counts.sum()) log_odds = user_log_probs - global_log_probs special_ids = { tid for tid in [ getattr(tokenizer, "pad_token_id", None), getattr(tokenizer, "eos_token_id", None), getattr(tokenizer, "bos_token_id", None), ] if tid is not None } for tid in special_ids: if 0 <= tid < log_odds.numel(): log_odds[tid] = float("-inf") if only_positive: log_odds = torch.clamp(log_odds, min=0.0) k = min(top_m, log_odds.numel()) values, token_ids = torch.topk(log_odds, k=k) keep = torch.isfinite(values) & (values > 0) values = values[keep] token_ids = token_ids[keep] bias = torch.zeros(vocab_size, dtype=torch.float32) if token_ids.numel() > 0: bias[token_ids] = scale * values return bias, token_ids.tolist() def generate_with_logit_bias( wrapper, prompt: str, bias: torch.Tensor, max_new_tokens: int = 512, min_new_tokens: int = 128, temperature: float = 0.0, ): """Generate with a fixed additive bias on the model's next-token logits.""" chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": prompt}, ] prompt_text = wrapper.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) logits_processor = LogitsProcessorList([StaticBiasLogitsProcessor(bias)]) with torch.no_grad(): outputs = wrapper.model.generate( input_ids, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, temperature=temperature if temperature > 0 else None, top_p=None, do_sample=temperature > 0, pad_token_id=wrapper.tokenizer.pad_token_id, logits_processor=logits_processor, ) generated_ids = outputs[0, input_ids.shape[1]:] return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) def _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids, token_id_to_pos): selected_logits = [] base_lse = [] non_selected_lse = [] target_base_logits = [] target_selected_pos = [] with torch.no_grad(): for start in range(0, h.shape[0], CHUNK_SIZE): end = min(start + CHUNK_SIZE, h.shape[0]) h_chunk = h[start:end] y_chunk = y[start:end] base_logits = F.linear(h_chunk, lm_w, lm_b) selected = base_logits[:, token_ids] all_lse = torch.logsumexp(base_logits, dim=-1) selected_lse = torch.logsumexp(selected, dim=-1) # Logsumexp over tokens that do not receive the learned sparse bias. # This lets every optimization step avoid recomputing full-vocab logits. selected_mass = torch.exp(selected_lse - all_lse).clamp(max=1.0 - 1e-7) rest_lse = all_lse + torch.log1p(-selected_mass) selected_logits.append(selected.detach().cpu()) base_lse.append(all_lse.detach().cpu()) non_selected_lse.append(rest_lse.detach().cpu()) target_base_logits.append( base_logits.gather(1, y_chunk.unsqueeze(1)).squeeze(1).detach().cpu() ) target_selected_pos.append(token_id_to_pos[y_chunk].detach().cpu()) del base_logits, selected, all_lse, selected_lse, rest_lse return { "selected_logits": torch.cat(selected_logits, dim=0), "base_lse": torch.cat(base_lse, dim=0), "non_selected_lse": torch.cat(non_selected_lse, dim=0), "target_base_logits": torch.cat(target_base_logits, dim=0), "target_selected_pos": torch.cat(target_selected_pos, dim=0), "num_tokens": y.shape[0], } def _sparse_bias_loss_from_terms(terms, bias_params, beta, device): selected_logits = terms["selected_logits"].to(device) base_lse = terms["base_lse"].to(device) non_selected_lse = terms["non_selected_lse"].to(device) target_base_logits = terms["target_base_logits"].to(device) target_selected_pos = terms["target_selected_pos"].to(device) adjusted_selected = selected_logits + bias_params.unsqueeze(0) selected_adjusted_lse = torch.logsumexp(adjusted_selected, dim=-1) denom = torch.logaddexp(non_selected_lse, selected_adjusted_lse) target_logits = target_base_logits selected_mask = target_selected_pos >= 0 if selected_mask.any(): target_logits = target_logits.clone() selected_rows = selected_mask.nonzero(as_tuple=False).squeeze(1) selected_cols = target_selected_pos[selected_rows] target_logits[selected_rows] = adjusted_selected[selected_rows, selected_cols] ce = (denom - target_logits).sum() if beta > 0: selected_base_probs = torch.exp(selected_logits - base_lse.unsqueeze(1)) selected_bias_expectation = (selected_base_probs * bias_params.unsqueeze(0)).sum(dim=-1) kl = (denom - base_lse - selected_bias_expectation).sum() else: kl = torch.zeros((), device=device, dtype=ce.dtype) del selected_logits, base_lse, non_selected_lse del target_base_logits, target_selected_pos, adjusted_selected return ce, kl def fit_sparse_logit_bias( cached_h: list, lm_head_weight: torch.Tensor, lm_head_bias: torch.Tensor | None, token_ids: list[int], vocab_size: int, init_values: torch.Tensor | None = None, lr: float = 0.05, steps: int = 30, beta: float = 0.05, lam: float = 1e-4, max_grad_norm: float = 5.0, device: str = "cuda:0", verbose: bool = False, ): """Fit a sparse vocabulary bias on support hidden states.""" if not cached_h or not token_ids: return torch.zeros(vocab_size, dtype=torch.float32), 0 token_ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=device) token_id_to_pos = torch.full((vocab_size,), -1, dtype=torch.long, device=device) token_id_to_pos[token_ids_tensor] = torch.arange(len(token_ids), device=device) if init_values is None: bias_params = torch.zeros(len(token_ids), device=device, dtype=torch.float32) else: bias_params = init_values.to(device=device, dtype=torch.float32).clone() bias_params.requires_grad_(True) optimizer = torch.optim.Adam([bias_params], lr=lr) lm_w = lm_head_weight.float() lm_b = lm_head_bias.float() if lm_head_bias is not None else None precomputed = [] total_tokens = 0 for h_cpu, y_cpu in cached_h: h = h_cpu.to(device).float() y = y_cpu.to(device) terms = _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids_tensor, token_id_to_pos) precomputed.append(terms) total_tokens += terms["num_tokens"] del h, y for step in range(steps): total_loss = 0.0 for terms in precomputed: ce, kl = _sparse_bias_loss_from_terms(terms, bias_params, beta, device) total_loss = total_loss + ce + beta * kl loss = total_loss / max(total_tokens, 1) + lam * bias_params.square().sum() optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_([bias_params], max_norm=max_grad_norm) optimizer.step() with torch.no_grad(): norm = bias_params.norm() if norm > max_grad_norm: bias_params.mul_(max_grad_norm / norm) if verbose and (step % 10 == 0 or step == steps - 1): print(f" Step {step:3d}: loss={loss.item():.4f}, |bias|={bias_params.norm().item():.4f}") del total_loss, loss bias = torch.zeros(vocab_size, dtype=torch.float32) bias[token_ids] = bias_params.detach().cpu() return bias, len(token_ids)