From 896df7f11b441a9b8dfa50820024a82884da58d0 Mon Sep 17 00:00:00 2001 From: BLUESKY477 Date: Fri, 22 May 2026 19:23:44 -0500 Subject: Add files via upload --- resulets/baselines/dense_retrieval.py | 143 +++++++++++++++++ resulets/baselines/logit_bias.py | 286 ++++++++++++++++++++++++++++++++++ 2 files changed, 429 insertions(+) create mode 100644 resulets/baselines/dense_retrieval.py create mode 100644 resulets/baselines/logit_bias.py (limited to 'resulets/baselines') diff --git a/resulets/baselines/dense_retrieval.py b/resulets/baselines/dense_retrieval.py new file mode 100644 index 0000000..a627319 --- /dev/null +++ b/resulets/baselines/dense_retrieval.py @@ -0,0 +1,143 @@ +"""Dense Retrieval ICL baselines. + +Uses sentence-transformers for dense retrieval over the user support set, +then places top-K retrieved items as in-context examples. +""" + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class DenseRetrieverConfig: + method_name: str + model_name: str + text_mode: str = "input_output" + query_prefix: str = "" + passage_prefix: str = "" + normalize_embeddings: bool = True + citation_year: str = "" + description: str = "" + + +DENSE_RETRIEVER_CONFIGS = { + "dense_minilm_top1": DenseRetrieverConfig( + method_name="dense_minilm_top1", + model_name="sentence-transformers/all-MiniLM-L6-v2", + citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021", + description="Lightweight SBERT/MiniLM dense retriever.", + ), + "dense_mpnet_top1": DenseRetrieverConfig( + method_name="dense_mpnet_top1", + model_name="sentence-transformers/all-mpnet-base-v2", + citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021", + description="Stronger SBERT/MPNet dense retriever.", + ), + "dense_e5_top1": DenseRetrieverConfig( + method_name="dense_e5_top1", + model_name="intfloat/e5-base-v2", + query_prefix="query: ", + passage_prefix="passage: ", + citation_year="E5 2022", + description="E5 dense retriever with the model-card query/passage prefixes.", + ), + "dense_bge_top1": DenseRetrieverConfig( + method_name="dense_bge_top1", + model_name="BAAI/bge-base-en-v1.5", + query_prefix="Represent this sentence for searching relevant passages: ", + citation_year="BGE v1.5 2023", + description="BGE v1.5 dense retriever with the recommended query instruction.", + ), +} + + +def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig: + return DENSE_RETRIEVER_CONFIGS[method_name] + + +class DenseRetriever: + """Dense retriever using sentence-transformers embeddings.""" + + def __init__( + self, + model_name='sentence-transformers/all-MiniLM-L6-v2', + device='cpu', + text_mode='input_output', + query_prefix='', + passage_prefix='', + normalize_embeddings=True, + ): + self.model_name = model_name + self.text_mode = text_mode + self.query_prefix = query_prefix + self.passage_prefix = passage_prefix + self.normalize_embeddings = normalize_embeddings + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer(model_name, device=device) + + def _support_text(self, item: dict) -> str: + if self.text_mode == 'input': + return item['support_input'] + if self.text_mode == 'output': + return item['support_output'] + if self.text_mode == 'input_output': + return f"{item['support_input']}\n{item['support_output']}" + raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}") + + def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata=False): + """Retrieve top-k support items most relevant to query. + + Args: + query: query input text + support_items: list of dicts with 'support_input', 'support_output' + k: number of items to retrieve + return_metadata: whether to also return retrieval diagnostics + + Returns: + List of top-k support items, optionally with metadata. + """ + if len(support_items) <= k: + metadata = [ + { + 'rank': rank + 1, + 'support_index': rank, + 'score': None, + 'model_name': self.model_name, + 'text_mode': self.text_mode, + } + for rank in range(len(support_items)) + ] + return (support_items, metadata) if return_metadata else support_items + + query_text = self.query_prefix + query + texts = [self.passage_prefix + self._support_text(item) for item in support_items] + embeddings = self.model.encode( + [query_text] + texts, + convert_to_tensor=True, + normalize_embeddings=self.normalize_embeddings, + ) + + query_emb = embeddings[0] + item_embs = embeddings[1:] + + if self.normalize_embeddings: + similarities = item_embs @ query_emb + else: + similarities = torch.nn.functional.cosine_similarity( + query_emb.unsqueeze(0), item_embs, dim=1 + ) + + top_indices = similarities.argsort(descending=True)[:k].tolist() + selected = [support_items[i] for i in top_indices] + metadata = [ + { + 'rank': rank + 1, + 'support_index': idx, + 'score': float(similarities[idx].detach().cpu()), + 'model_name': self.model_name, + 'text_mode': self.text_mode, + } + for rank, idx in enumerate(top_indices) + ] + return (selected, metadata) if return_metadata else selected diff --git a/resulets/baselines/logit_bias.py b/resulets/baselines/logit_bias.py new file mode 100644 index 0000000..a219d04 --- /dev/null +++ b/resulets/baselines/logit_bias.py @@ -0,0 +1,286 @@ +"""Static user logit-bias baselines. + +These baselines test whether UPH's gains can be explained by a trivial +user-specific vocabulary prior. They do not modify hidden states or model +weights; generation only receives an additive bias on the output logits. +""" + +import torch +import torch.nn.functional as F +from transformers import LogitsProcessor, LogitsProcessorList + + +CHUNK_SIZE = 32 + + +class StaticBiasLogitsProcessor(LogitsProcessor): + """Add a fixed user-specific vector to next-token scores.""" + + def __init__(self, bias: torch.Tensor): + self.bias = bias + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + bias = self.bias + if bias.numel() != scores.shape[-1]: + adjusted = torch.zeros(scores.shape[-1], dtype=bias.dtype) + n = min(bias.numel(), scores.shape[-1]) + adjusted[:n] = bias[:n] + bias = adjusted + return scores + bias.to(device=scores.device, dtype=scores.dtype).unsqueeze(0) + + +def _encode_output(tokenizer, text: str) -> list[int]: + return tokenizer.encode(text or "", add_special_tokens=False) + + +def _count_tokens(tokenizer, texts: list[str], vocab_size: int, smoothing: float = 0.1): + counts = torch.full((vocab_size,), smoothing, dtype=torch.float32) + for text in texts: + ids = _encode_output(tokenizer, text) + if ids: + binc = torch.bincount(torch.tensor(ids, dtype=torch.long), minlength=vocab_size) + counts += binc.float() + return counts + + +def build_global_log_probs( + tokenizer, + support_sets: list[list[dict]], + smoothing: float = 0.1, + vocab_size: int | None = None, +): + """Estimate the background token distribution from support outputs only.""" + if vocab_size is None: + vocab_size = len(tokenizer) + texts = [ + item["support_output"] + for support in support_sets + for item in support + if item.get("support_output") + ] + counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing) + return torch.log(counts / counts.sum()) + + +def build_user_unigram_bias( + tokenizer, + support_items: list[dict], + global_log_probs: torch.Tensor, + vocab_size: int | None = None, + top_m: int = 512, + scale: float = 0.5, + smoothing: float = 0.1, + only_positive: bool = True, +): + """Build a sparse log-odds vocabulary bias from one user's support outputs.""" + if vocab_size is None: + vocab_size = global_log_probs.numel() + texts = [item["support_output"] for item in support_items if item.get("support_output")] + counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing) + user_log_probs = torch.log(counts / counts.sum()) + log_odds = user_log_probs - global_log_probs + + special_ids = { + tid for tid in [ + getattr(tokenizer, "pad_token_id", None), + getattr(tokenizer, "eos_token_id", None), + getattr(tokenizer, "bos_token_id", None), + ] if tid is not None + } + for tid in special_ids: + if 0 <= tid < log_odds.numel(): + log_odds[tid] = float("-inf") + + if only_positive: + log_odds = torch.clamp(log_odds, min=0.0) + + k = min(top_m, log_odds.numel()) + values, token_ids = torch.topk(log_odds, k=k) + keep = torch.isfinite(values) & (values > 0) + values = values[keep] + token_ids = token_ids[keep] + + bias = torch.zeros(vocab_size, dtype=torch.float32) + if token_ids.numel() > 0: + bias[token_ids] = scale * values + return bias, token_ids.tolist() + + +def generate_with_logit_bias( + wrapper, + prompt: str, + bias: torch.Tensor, + max_new_tokens: int = 512, + min_new_tokens: int = 128, + temperature: float = 0.0, +): + """Generate with a fixed additive bias on the model's next-token logits.""" + chat_messages = [ + {"role": "system", "content": "You are a helpful writing assistant."}, + {"role": "user", "content": prompt}, + ] + prompt_text = wrapper.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) + + logits_processor = LogitsProcessorList([StaticBiasLogitsProcessor(bias)]) + with torch.no_grad(): + outputs = wrapper.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + temperature=temperature if temperature > 0 else None, + top_p=None, + do_sample=temperature > 0, + pad_token_id=wrapper.tokenizer.pad_token_id, + logits_processor=logits_processor, + ) + generated_ids = outputs[0, input_ids.shape[1]:] + return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) + + +def _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids, token_id_to_pos): + selected_logits = [] + base_lse = [] + non_selected_lse = [] + target_base_logits = [] + target_selected_pos = [] + + with torch.no_grad(): + for start in range(0, h.shape[0], CHUNK_SIZE): + end = min(start + CHUNK_SIZE, h.shape[0]) + h_chunk = h[start:end] + y_chunk = y[start:end] + + base_logits = F.linear(h_chunk, lm_w, lm_b) + selected = base_logits[:, token_ids] + all_lse = torch.logsumexp(base_logits, dim=-1) + selected_lse = torch.logsumexp(selected, dim=-1) + + # Logsumexp over tokens that do not receive the learned sparse bias. + # This lets every optimization step avoid recomputing full-vocab logits. + selected_mass = torch.exp(selected_lse - all_lse).clamp(max=1.0 - 1e-7) + rest_lse = all_lse + torch.log1p(-selected_mass) + + selected_logits.append(selected.detach().cpu()) + base_lse.append(all_lse.detach().cpu()) + non_selected_lse.append(rest_lse.detach().cpu()) + target_base_logits.append( + base_logits.gather(1, y_chunk.unsqueeze(1)).squeeze(1).detach().cpu() + ) + target_selected_pos.append(token_id_to_pos[y_chunk].detach().cpu()) + + del base_logits, selected, all_lse, selected_lse, rest_lse + + return { + "selected_logits": torch.cat(selected_logits, dim=0), + "base_lse": torch.cat(base_lse, dim=0), + "non_selected_lse": torch.cat(non_selected_lse, dim=0), + "target_base_logits": torch.cat(target_base_logits, dim=0), + "target_selected_pos": torch.cat(target_selected_pos, dim=0), + "num_tokens": y.shape[0], + } + + +def _sparse_bias_loss_from_terms(terms, bias_params, beta, device): + selected_logits = terms["selected_logits"].to(device) + base_lse = terms["base_lse"].to(device) + non_selected_lse = terms["non_selected_lse"].to(device) + target_base_logits = terms["target_base_logits"].to(device) + target_selected_pos = terms["target_selected_pos"].to(device) + + adjusted_selected = selected_logits + bias_params.unsqueeze(0) + selected_adjusted_lse = torch.logsumexp(adjusted_selected, dim=-1) + denom = torch.logaddexp(non_selected_lse, selected_adjusted_lse) + + target_logits = target_base_logits + selected_mask = target_selected_pos >= 0 + if selected_mask.any(): + target_logits = target_logits.clone() + selected_rows = selected_mask.nonzero(as_tuple=False).squeeze(1) + selected_cols = target_selected_pos[selected_rows] + target_logits[selected_rows] = adjusted_selected[selected_rows, selected_cols] + + ce = (denom - target_logits).sum() + + if beta > 0: + selected_base_probs = torch.exp(selected_logits - base_lse.unsqueeze(1)) + selected_bias_expectation = (selected_base_probs * bias_params.unsqueeze(0)).sum(dim=-1) + kl = (denom - base_lse - selected_bias_expectation).sum() + else: + kl = torch.zeros((), device=device, dtype=ce.dtype) + + del selected_logits, base_lse, non_selected_lse + del target_base_logits, target_selected_pos, adjusted_selected + return ce, kl + + +def fit_sparse_logit_bias( + cached_h: list, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor | None, + token_ids: list[int], + vocab_size: int, + init_values: torch.Tensor | None = None, + lr: float = 0.05, + steps: int = 30, + beta: float = 0.05, + lam: float = 1e-4, + max_grad_norm: float = 5.0, + device: str = "cuda:0", + verbose: bool = False, +): + """Fit a sparse vocabulary bias on support hidden states.""" + if not cached_h or not token_ids: + return torch.zeros(vocab_size, dtype=torch.float32), 0 + + token_ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=device) + token_id_to_pos = torch.full((vocab_size,), -1, dtype=torch.long, device=device) + token_id_to_pos[token_ids_tensor] = torch.arange(len(token_ids), device=device) + if init_values is None: + bias_params = torch.zeros(len(token_ids), device=device, dtype=torch.float32) + else: + bias_params = init_values.to(device=device, dtype=torch.float32).clone() + bias_params.requires_grad_(True) + + optimizer = torch.optim.Adam([bias_params], lr=lr) + lm_w = lm_head_weight.float() + lm_b = lm_head_bias.float() if lm_head_bias is not None else None + precomputed = [] + total_tokens = 0 + + for h_cpu, y_cpu in cached_h: + h = h_cpu.to(device).float() + y = y_cpu.to(device) + terms = _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids_tensor, token_id_to_pos) + precomputed.append(terms) + total_tokens += terms["num_tokens"] + del h, y + + for step in range(steps): + total_loss = 0.0 + + for terms in precomputed: + ce, kl = _sparse_bias_loss_from_terms(terms, bias_params, beta, device) + total_loss = total_loss + ce + beta * kl + + loss = total_loss / max(total_tokens, 1) + lam * bias_params.square().sum() + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_([bias_params], max_norm=max_grad_norm) + optimizer.step() + + with torch.no_grad(): + norm = bias_params.norm() + if norm > max_grad_norm: + bias_params.mul_(max_grad_norm / norm) + + if verbose and (step % 10 == 0 or step == steps - 1): + print(f" Step {step:3d}: loss={loss.item():.4f}, |bias|={bias_params.norm().item():.4f}") + + del total_loss, loss + + bias = torch.zeros(vocab_size, dtype=torch.float32) + bias[token_ids] = bias_params.detach().cpu() + return bias, len(token_ids) -- cgit v1.2.3