summaryrefslogtreecommitdiff
path: root/resulets/baselines
diff options
context:
space:
mode:
authorBLUESKY477 <abcd15803148000@163.com>2026-05-22 19:23:44 -0500
committerGitHub <noreply@github.com>2026-05-22 19:23:44 -0500
commit896df7f11b441a9b8dfa50820024a82884da58d0 (patch)
tree0182ae4a7a0bb16ee6a764393838a580e1ba1c31 /resulets/baselines
parent6f48c4fae3243e280b27a977c6a8cb731becf446 (diff)
Add files via uploadHEADmaster
Diffstat (limited to 'resulets/baselines')
-rw-r--r--resulets/baselines/dense_retrieval.py143
-rw-r--r--resulets/baselines/logit_bias.py286
2 files changed, 429 insertions, 0 deletions
diff --git a/resulets/baselines/dense_retrieval.py b/resulets/baselines/dense_retrieval.py
new file mode 100644
index 0000000..a627319
--- /dev/null
+++ b/resulets/baselines/dense_retrieval.py
@@ -0,0 +1,143 @@
+"""Dense Retrieval ICL baselines.
+
+Uses sentence-transformers for dense retrieval over the user support set,
+then places top-K retrieved items as in-context examples.
+"""
+
+from dataclasses import dataclass
+
+import torch
+
+
@dataclass(frozen=True)
class DenseRetrieverConfig:
    """Immutable description of one dense-retrieval baseline.

    Only ``method_name`` and ``model_name`` are required; every other field
    has a default. Instances are frozen, so they cannot be modified after
    construction.
    """

    # Name under which the baseline is registered and looked up.
    method_name: str
    # Identifier of the embedding checkpoint to load.
    model_name: str
    # Which support fields are embedded: "input", "output" or "input_output".
    text_mode: str = "input_output"
    # Text prepended to the query before it is embedded.
    query_prefix: str = ""
    # Text prepended to every support passage before it is embedded.
    passage_prefix: str = ""
    # Whether embeddings are requested in normalized form from the encoder.
    normalize_embeddings: bool = True
    # Free-text citation note for the model/checkpoint.
    citation_year: str = ""
    # Free-text, human-readable summary of the baseline.
    description: str = ""
+
+
# Registry of the built-in dense-retrieval baselines. Each entry is keyed by
# the config's own ``method_name`` so the key and the field cannot drift apart.
DENSE_RETRIEVER_CONFIGS = {
    config.method_name: config
    for config in (
        DenseRetrieverConfig(
            method_name="dense_minilm_top1",
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021",
            description="Lightweight SBERT/MiniLM dense retriever.",
        ),
        DenseRetrieverConfig(
            method_name="dense_mpnet_top1",
            model_name="sentence-transformers/all-mpnet-base-v2",
            citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021",
            description="Stronger SBERT/MPNet dense retriever.",
        ),
        DenseRetrieverConfig(
            method_name="dense_e5_top1",
            model_name="intfloat/e5-base-v2",
            query_prefix="query: ",
            passage_prefix="passage: ",
            citation_year="E5 2022",
            description="E5 dense retriever with the model-card query/passage prefixes.",
        ),
        DenseRetrieverConfig(
            method_name="dense_bge_top1",
            model_name="BAAI/bge-base-en-v1.5",
            query_prefix="Represent this sentence for searching relevant passages: ",
            citation_year="BGE v1.5 2023",
            description="BGE v1.5 dense retriever with the recommended query instruction.",
        ),
    )
}
+
+
def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig:
    """Return the registered config for ``method_name``.

    Args:
        method_name: key into ``DENSE_RETRIEVER_CONFIGS``.

    Returns:
        The matching ``DenseRetrieverConfig``.

    Raises:
        KeyError: if ``method_name`` is not registered. The message lists the
            registered names so a typo is easy to spot.
    """
    try:
        return DENSE_RETRIEVER_CONFIGS[method_name]
    except KeyError:
        known = ", ".join(sorted(DENSE_RETRIEVER_CONFIGS))
        raise KeyError(
            f"Unknown dense retriever method {method_name!r}; expected one of: {known}"
        ) from None
+
+
class DenseRetriever:
    """Dense retriever using sentence-transformers embeddings.

    The query and the support items are embedded in one batch and the items
    are ranked by their cosine similarity to the query.
    """

    def __init__(
        self,
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        device='cpu',
        text_mode='input_output',
        query_prefix='',
        passage_prefix='',
        normalize_embeddings=True,
    ):
        """Store the retrieval settings and load the embedding model.

        Args:
            model_name: checkpoint name passed to ``SentenceTransformer``.
            device: device passed to ``SentenceTransformer``.
            text_mode: 'input', 'output' or 'input_output'; selects which
                support fields are embedded.
            query_prefix: text prepended to the query before encoding.
            passage_prefix: text prepended to each support text before encoding.
            normalize_embeddings: forwarded to ``encode``; also selects the
                similarity computation (dot product vs. explicit cosine).
        """
        self.model_name = model_name
        self.text_mode = text_mode
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.normalize_embeddings = normalize_embeddings
        # Function-scope import: the dependency is only needed once a
        # retriever is actually constructed.
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name, device=device)

    def _support_text(self, item: dict) -> str:
        """Return the text of ``item`` to embed, according to ``text_mode``.

        Raises:
            ValueError: if ``text_mode`` is not one of the supported modes.
        """
        if self.text_mode == 'input':
            return item['support_input']
        if self.text_mode == 'output':
            return item['support_output']
        if self.text_mode == 'input_output':
            return f"{item['support_input']}\n{item['support_output']}"
        raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}")

    def _metadata_entry(self, rank: int, support_index: int, score) -> dict:
        """Build one retrieval-diagnostics record (``rank`` is 1-based)."""
        return {
            'rank': rank,
            'support_index': support_index,
            'score': score,
            'model_name': self.model_name,
            'text_mode': self.text_mode,
        }

    def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata=False):
        """Retrieve top-k support items most relevant to query.

        When there are no more than ``k`` support items, nothing is embedded:
        all items are returned in their original order with ``score`` set to
        ``None``. A non-positive ``k`` yields an empty result.

        Args:
            query: query input text
            support_items: list of dicts with 'support_input', 'support_output'
            k: number of items to retrieve
            return_metadata: whether to also return retrieval diagnostics

        Returns:
            List of top-k support items (always a new list), or a
            ``(items, metadata)`` tuple when ``return_metadata`` is true.
        """
        # A negative k would otherwise reach ``[:k]`` below and silently
        # return "all but the last |k|" items.
        if k <= 0:
            return ([], []) if return_metadata else []

        if len(support_items) <= k:
            # Copy so the caller never receives an alias of its own input list.
            selected = list(support_items)
            metadata = [
                self._metadata_entry(position + 1, position, None)
                for position in range(len(selected))
            ]
            return (selected, metadata) if return_metadata else selected

        query_text = self.query_prefix + query
        texts = [self.passage_prefix + self._support_text(item) for item in support_items]
        embeddings = self.model.encode(
            [query_text] + texts,
            convert_to_tensor=True,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Row 0 is the query; the remaining rows follow support_items order.
        query_emb = embeddings[0]
        item_embs = embeddings[1:]

        if self.normalize_embeddings:
            # Embeddings were requested normalized, so a dot product suffices.
            similarities = item_embs @ query_emb
        else:
            similarities = torch.nn.functional.cosine_similarity(
                query_emb.unsqueeze(0), item_embs, dim=1
            )

        top_indices = similarities.argsort(descending=True)[:k].tolist()
        selected = [support_items[i] for i in top_indices]
        metadata = [
            self._metadata_entry(rank, idx, float(similarities[idx].detach().cpu()))
            for rank, idx in enumerate(top_indices, start=1)
        ]
        return (selected, metadata) if return_metadata else selected
diff --git a/resulets/baselines/logit_bias.py b/resulets/baselines/logit_bias.py
new file mode 100644
index 0000000..a219d04
--- /dev/null
+++ b/resulets/baselines/logit_bias.py
@@ -0,0 +1,286 @@
+"""Static user logit-bias baselines.
+
+These baselines test whether UPH's gains can be explained by a trivial
+user-specific vocabulary prior. They do not modify hidden states or model
+weights; generation only receives an additive bias on the output logits.
+"""
+
+import torch
+import torch.nn.functional as F
+from transformers import LogitsProcessor, LogitsProcessorList
+
+
+CHUNK_SIZE = 32  # Rows of hidden states projected through the LM head per iteration in _precompute_sparse_bias_terms.
+
+
class StaticBiasLogitsProcessor(LogitsProcessor):
    """Add a fixed user-specific vector to next-token scores.

    The bias is padded with zeros or truncated to the width of ``scores`` and
    cast to the scores' device and dtype. That prepared tensor is cached, so
    the padding and the device/dtype conversion happen once per
    (bias object, vocab width, device, dtype) instead of on every decoding
    step.

    Note: the cache is keyed on the identity of ``self.bias``. Assign a new
    tensor to ``self.bias`` to change the bias; in-place edits made after the
    first call are not picked up.
    """

    def __init__(self, bias: torch.Tensor):
        self.bias = bias
        self._prepared = None
        self._prepared_source = None
        self._prepared_target = None

    def _prepared_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Return the bias as a ``(1, scores.shape[-1])`` tensor matching ``scores``."""
        width = scores.shape[-1]
        target = (width, scores.device, scores.dtype)
        if self._prepared_source is self.bias and self._prepared_target == target:
            return self._prepared

        bias = self.bias
        if bias.numel() != width:
            # Width mismatch: copy the overlapping prefix and leave any
            # remaining positions unbiased (zero).
            adjusted = torch.zeros(width, dtype=bias.dtype)
            n = min(bias.numel(), width)
            adjusted[:n] = bias[:n]
            bias = adjusted

        prepared = bias.to(device=scores.device, dtype=scores.dtype).unsqueeze(0)
        self._prepared = prepared
        self._prepared_source = self.bias
        self._prepared_target = target
        return prepared

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return ``scores`` plus the bias; ``input_ids`` is not used."""
        return scores + self._prepared_bias(scores)
+
+
+def _encode_output(tokenizer, text: str) -> list[int]:
+ return tokenizer.encode(text or "", add_special_tokens=False)
+
+
def _count_tokens(tokenizer, texts: list[str], vocab_size: int, smoothing: float = 0.1):
    """Return smoothed per-token counts over ``texts``.

    Args:
        tokenizer: object whose ``encode`` is used via ``_encode_output``.
        texts: strings to tokenize and count.
        vocab_size: length of the returned count vector.
        smoothing: value every count starts from (additive smoothing).

    Returns:
        A float32 tensor of shape ``(vocab_size,)`` holding ``smoothing`` plus
        the number of occurrences of each token id. Token ids that are
        ``>= vocab_size`` are ignored.
    """
    counts = torch.full((vocab_size,), smoothing, dtype=torch.float32)
    for text in texts:
        ids = _encode_output(tokenizer, text)
        if not ids:
            continue
        binc = torch.bincount(torch.tensor(ids, dtype=torch.long), minlength=vocab_size)
        # ``minlength`` is only a lower bound: an id >= vocab_size makes the
        # result longer than ``counts``, so truncate before accumulating.
        counts += binc[:vocab_size].float()
    return counts
+
+
def build_global_log_probs(
    tokenizer,
    support_sets: list[list[dict]],
    smoothing: float = 0.1,
    vocab_size: int | None = None,
):
    """Estimate the background token distribution from support outputs only.

    Args:
        tokenizer: tokenizer used for counting; ``len(tokenizer)`` supplies
            the vocabulary size when ``vocab_size`` is None.
        support_sets: one list of support dicts per user; only truthy
            ``"support_output"`` values are counted.
        smoothing: additive smoothing forwarded to ``_count_tokens``.
        vocab_size: optional explicit vocabulary size.

    Returns:
        Tensor of log-probabilities (log of the normalized smoothed counts).
    """
    size = len(tokenizer) if vocab_size is None else vocab_size

    outputs = []
    for user_support in support_sets:
        for entry in user_support:
            output_text = entry.get("support_output")
            if output_text:
                outputs.append(output_text)

    token_counts = _count_tokens(tokenizer, outputs, size, smoothing=smoothing)
    normalized = token_counts / token_counts.sum()
    return torch.log(normalized)
+
+
def build_user_unigram_bias(
    tokenizer,
    support_items: list[dict],
    global_log_probs: torch.Tensor,
    vocab_size: int | None = None,
    top_m: int = 512,
    scale: float = 0.5,
    smoothing: float = 0.1,
    only_positive: bool = True,
):
    """Build a sparse log-odds vocabulary bias from one user's support outputs.

    The bias for a token is ``scale * (log p_user - log p_global)``, kept only
    for the ``top_m`` tokens with the largest log-odds. The tokenizer's pad,
    eos and bos ids are never biased.

    Args:
        tokenizer: tokenizer used to count the user's tokens.
        support_items: the user's support dicts; only truthy
            ``"support_output"`` values are counted.
        global_log_probs: background log-probabilities to compare against.
        vocab_size: length of the bias vector; defaults to
            ``global_log_probs.numel()``.
        top_m: maximum number of tokens that receive a bias.
        scale: multiplier applied to the kept log-odds.
        smoothing: additive smoothing forwarded to ``_count_tokens``.
        only_positive: when true, only tokens the user over-uses (positive
            log-odds) are kept. When false, negative log-odds that make it
            into the top ``top_m`` are kept as well.

    Returns:
        ``(bias, token_ids)``: a float32 tensor of shape ``(vocab_size,)`` and
        the list of token ids with a non-zero bias.
    """
    if vocab_size is None:
        vocab_size = global_log_probs.numel()
    texts = [item["support_output"] for item in support_items if item.get("support_output")]
    counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing)
    user_log_probs = torch.log(counts / counts.sum())
    log_odds = user_log_probs - global_log_probs

    special_ids = {
        tid
        for tid in (
            getattr(tokenizer, "pad_token_id", None),
            getattr(tokenizer, "eos_token_id", None),
            getattr(tokenizer, "bos_token_id", None),
        )
        if tid is not None
    }
    for tid in special_ids:
        if 0 <= tid < log_odds.numel():
            # -inf guarantees special tokens are dropped by the filter below.
            log_odds[tid] = float("-inf")

    if only_positive:
        log_odds = torch.clamp(log_odds, min=0.0)

    k = min(top_m, log_odds.numel())
    values, token_ids = torch.topk(log_odds, k=k)
    # Previously ``values > 0`` was applied unconditionally, which made
    # ``only_positive=False`` behave exactly like ``True``.
    nonzero = (values > 0) if only_positive else (values != 0)
    keep = torch.isfinite(values) & nonzero
    values = values[keep]
    token_ids = token_ids[keep]

    bias = torch.zeros(vocab_size, dtype=torch.float32)
    if token_ids.numel() > 0:
        bias[token_ids] = scale * values
    return bias, token_ids.tolist()
+
+
def generate_with_logit_bias(
    wrapper,
    prompt: str,
    bias: torch.Tensor,
    max_new_tokens: int = 512,
    min_new_tokens: int = 128,
    temperature: float = 0.0,
):
    """Generate with a fixed additive bias on the model's next-token logits.

    Args:
        wrapper: object exposing ``tokenizer``, ``model`` and ``device``.
        prompt: user message placed after a fixed system message.
        bias: vocabulary bias handed to ``StaticBiasLogitsProcessor``.
        max_new_tokens: upper bound on generated tokens.
        min_new_tokens: lower bound on generated tokens; clamped so that it
            never exceeds ``max_new_tokens``.
        temperature: sampling temperature; values ``<= 0`` select greedy
            decoding.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped).
    """
    tokenizer = wrapper.tokenizer
    chat_messages = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": prompt},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): the templated text is encoded with the tokenizer's default
    # special-token handling. For tokenizers that add BOS on encode this may
    # duplicate a BOS already emitted by the chat template -- confirm against
    # the other baselines before changing, to keep prompts comparable.
    input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
    # Single unpadded prompt: every position is a real token. Passing the mask
    # explicitly avoids generate() having to infer it from the pad token id.
    attention_mask = torch.ones_like(input_ids)

    # A minimum above the maximum is an unsatisfiable length constraint.
    if min_new_tokens is not None and max_new_tokens is not None:
        min_new_tokens = min(min_new_tokens, max_new_tokens)

    do_sample = temperature > 0
    logits_processor = LogitsProcessorList([StaticBiasLogitsProcessor(bias)])
    with torch.no_grad():
        outputs = wrapper.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            temperature=temperature if do_sample else None,
            top_p=None,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            logits_processor=logits_processor,
        )
    # Keep only the continuation; the output starts with the prompt tokens.
    generated_ids = outputs[0, input_ids.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
def _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids, token_id_to_pos):
    """Cache the bias-independent quantities needed by the sparse-bias loss.

    Rows of ``h`` are projected with ``F.linear(h, lm_w, lm_b)`` in chunks of
    ``CHUNK_SIZE`` under ``torch.no_grad()``; every per-row result is moved to
    the CPU and the chunks are concatenated.

    Args:
        h: hidden states, one row per position (presumably (tokens, hidden) --
            TODO confirm with the caller).
        y: target token id for each row of ``h``.
        lm_w: LM-head weight passed to ``F.linear``.
        lm_b: LM-head bias passed to ``F.linear`` (may be None).
        token_ids: ids of the tokens that receive a learned bias.
        token_id_to_pos: lookup tensor indexed by token id; presumably the
            position within ``token_ids`` or a negative value -- verify
            against the caller.

    Returns:
        Dict with ``selected_logits`` (base logits of ``token_ids``),
        ``base_lse`` (logsumexp over all logits), ``non_selected_lse``
        (logsumexp over the tokens not in ``token_ids``),
        ``target_base_logits`` (base logit of each target),
        ``target_selected_pos`` (``token_id_to_pos[y]``) and ``num_tokens``
        (``y.shape[0]``).
    """
    tensor_keys = (
        "selected_logits",
        "base_lse",
        "non_selected_lse",
        "target_base_logits",
        "target_selected_pos",
    )
    pieces = {key: [] for key in tensor_keys}
    num_rows = h.shape[0]

    with torch.no_grad():
        for lo in range(0, num_rows, CHUNK_SIZE):
            hi = min(lo + CHUNK_SIZE, num_rows)
            targets = y[lo:hi]

            logits = F.linear(h[lo:hi], lm_w, lm_b)
            picked = logits[:, token_ids]
            full_lse = torch.logsumexp(logits, dim=-1)
            picked_lse = torch.logsumexp(picked, dim=-1)

            # Recover the logsumexp of the un-biased tokens by removing the
            # selected tokens' probability mass from the full partition
            # function; the clamp keeps log1p's argument above -1.
            picked_mass = torch.exp(picked_lse - full_lse).clamp(max=1.0 - 1e-7)
            remainder_lse = full_lse + torch.log1p(-picked_mass)

            target_logits = logits.gather(1, targets.unsqueeze(1)).squeeze(1)

            pieces["selected_logits"].append(picked.detach().cpu())
            pieces["base_lse"].append(full_lse.detach().cpu())
            pieces["non_selected_lse"].append(remainder_lse.detach().cpu())
            pieces["target_base_logits"].append(target_logits.detach().cpu())
            pieces["target_selected_pos"].append(token_id_to_pos[targets].detach().cpu())

            del logits, picked, full_lse, picked_lse, remainder_lse, target_logits

    terms = {key: torch.cat(chunks, dim=0) for key, chunks in pieces.items()}
    terms["num_tokens"] = y.shape[0]
    return terms
+
+
+def _sparse_bias_loss_from_terms(terms, bias_params, beta, device):
+ selected_logits = terms["selected_logits"].to(device)
+ base_lse = terms["base_lse"].to(device)
+ non_selected_lse = terms["non_selected_lse"].to(device)
+ target_base_logits = terms["target_base_logits"].to(device)
+ target_selected_pos = terms["target_selected_pos"].to(device)
+
+ adjusted_selected = selected_logits + bias_params.unsqueeze(0)
+ selected_adjusted_lse = torch.logsumexp(adjusted_selected, dim=-1)
+ denom = torch.logaddexp(non_selected_lse, selected_adjusted_lse)
+
+ target_logits = target_base_logits
+ selected_mask = target_selected_pos >= 0
+ if selected_mask.any():
+ target_logits = target_logits.clone()
+ selected_rows = selected_mask.nonzero(as_tuple=False).squeeze(1)
+ selected_cols = target_selected_pos[selected_rows]
+ target_logits[selected_rows] = adjusted_selected[selected_rows, selected_cols]
+
+ ce = (denom - target_logits).sum()
+
+ if beta > 0:
+ selected_base_probs = torch.exp(selected_logits - base_lse.unsqueeze(1))
+ selected_bias_expectation = (selected_base_probs * bias_params.unsqueeze(0)).sum(dim=-1)
+ kl = (denom - base_lse - selected_bias_expectation).sum()
+ else:
+ kl = torch.zeros((), device=device, dtype=ce.dtype)
+
+ del selected_logits, base_lse, non_selected_lse
+ del target_base_logits, target_selected_pos, adjusted_selected
+ return ce, kl
+
+
+def fit_sparse_logit_bias(
+ cached_h: list,
+ lm_head_weight: torch.Tensor,
+ lm_head_bias: torch.Tensor | None,
+ token_ids: list[int],
+ vocab_size: int,
+ init_values: torch.Tensor | None = None,
+ lr: float = 0.05,
+ steps: int = 30,
+ beta: float = 0.05,
+ lam: float = 1e-4,
+ max_grad_norm: float = 5.0,
+ device: str = "cuda:0",
+ verbose: bool = False,
+):
+ """Fit a sparse vocabulary bias on support hidden states."""
+ if not cached_h or not token_ids:
+ return torch.zeros(vocab_size, dtype=torch.float32), 0
+
+ token_ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=device)
+ token_id_to_pos = torch.full((vocab_size,), -1, dtype=torch.long, device=device)
+ token_id_to_pos[token_ids_tensor] = torch.arange(len(token_ids), device=device)
+ if init_values is None:
+ bias_params = torch.zeros(len(token_ids), device=device, dtype=torch.float32)
+ else:
+ bias_params = init_values.to(device=device, dtype=torch.float32).clone()
+ bias_params.requires_grad_(True)
+
+ optimizer = torch.optim.Adam([bias_params], lr=lr)
+ lm_w = lm_head_weight.float()
+ lm_b = lm_head_bias.float() if lm_head_bias is not None else None
+ precomputed = []
+ total_tokens = 0
+
+ for h_cpu, y_cpu in cached_h:
+ h = h_cpu.to(device).float()
+ y = y_cpu.to(device)
+ terms = _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids_tensor, token_id_to_pos)
+ precomputed.append(terms)
+ total_tokens += terms["num_tokens"]
+ del h, y
+
+ for step in range(steps):
+ total_loss = 0.0
+
+ for terms in precomputed:
+ ce, kl = _sparse_bias_loss_from_terms(terms, bias_params, beta, device)
+ total_loss = total_loss + ce + beta * kl
+
+ loss = total_loss / max(total_tokens, 1) + lam * bias_params.square().sum()
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_([bias_params], max_norm=max_grad_norm)
+ optimizer.step()
+
+ with torch.no_grad():
+ norm = bias_params.norm()
+ if norm > max_grad_norm:
+ bias_params.mul_(max_grad_norm / norm)
+
+ if verbose and (step % 10 == 0 or step == steps - 1):
+ print(f" Step {step:3d}: loss={loss.item():.4f}, |bias|={bias_params.norm().item():.4f}")
+
+ del total_loss, loss
+
+ bias = torch.zeros(vocab_size, dtype=torch.float32)
+ bias[token_ids] = bias_params.detach().cpu()
+ return bias, len(token_ids)