diff options
| author | BLUESKY477 <abcd15803148000@163.com> | 2026-05-22 19:23:44 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-05-22 19:23:44 -0500 |
| commit | 896df7f11b441a9b8dfa50820024a82884da58d0 (patch) | |
| tree | 0182ae4a7a0bb16ee6a764393838a580e1ba1c31 | |
| parent | 6f48c4fae3243e280b27a977c6a8cb731becf446 (diff) | |
| -rw-r--r-- | resulets/README_UPLOAD.md | 21 | ||||
| -rw-r--r-- | resulets/adapt/fit_theta_lm_head_update.py | 102 | ||||
| -rw-r--r-- | resulets/baselines/dense_retrieval.py | 143 | ||||
| -rw-r--r-- | resulets/baselines/logit_bias.py | 286 | ||||
| -rw-r--r-- | resulets/outputs/dense_retrieval_baselines/dense_summary.csv | 9 | ||||
| -rw-r--r-- | resulets/outputs/dense_retrieval_baselines/dense_summary.md | 20 | ||||
| -rw-r--r-- | resulets/outputs/dense_retrieval_baselines/review_user_K4/summary.json | 39 | ||||
| -rw-r--r-- | resulets/outputs/dense_retrieval_baselines/topic_user_K4/summary.json | 39 | ||||
| -rw-r--r-- | resulets/outputs/diagnostics_full/review_user_K4/summary.json | 25 | ||||
| -rw-r--r-- | resulets/outputs/diagnostics_full/topic_user_K4/summary.json | 25 | ||||
| -rw-r--r-- | resulets/outputs/injection_ablation_rerun/review_user_K4/summary.json | 46 | ||||
| -rw-r--r-- | resulets/outputs/injection_ablation_rerun/topic_user_K4/summary.json | 46 | ||||
| -rw-r--r-- | resulets/scripts/run_all_methods.py | 908 | ||||
| -rw-r--r-- | resulets/scripts/summarize_dense_baselines.py | 63 |
14 files changed, 1772 insertions, 0 deletions
diff --git a/resulets/README_UPLOAD.md b/resulets/README_UPLOAD.md new file mode 100644 index 0000000..a248708 --- /dev/null +++ b/resulets/README_UPLOAD.md @@ -0,0 +1,21 @@ +# Files to upload + +This folder contains only the summarized experiment results and the code files needed for the new UPH experiments. + +Upload these paths to GitHub with the same relative structure: + +- `outputs/dense_retrieval_baselines/dense_summary.csv` +- `outputs/dense_retrieval_baselines/dense_summary.md` +- `outputs/dense_retrieval_baselines/topic_user_K4/summary.json` +- `outputs/dense_retrieval_baselines/review_user_K4/summary.json` +- `outputs/diagnostics_full/topic_user_K4/summary.json` +- `outputs/diagnostics_full/review_user_K4/summary.json` +- `outputs/injection_ablation_rerun/topic_user_K4/summary.json` +- `outputs/injection_ablation_rerun/review_user_K4/summary.json` +- `baselines/dense_retrieval.py` +- `baselines/logit_bias.py` +- `adapt/fit_theta_lm_head_update.py` +- `scripts/run_all_methods.py` +- `scripts/summarize_dense_baselines.py` + +Do not upload raw `per_user.json`, `progress.jsonl`, smoke runs, or logs. diff --git a/resulets/adapt/fit_theta_lm_head_update.py b/resulets/adapt/fit_theta_lm_head_update.py new file mode 100644 index 0000000..7a38e62 --- /dev/null +++ b/resulets/adapt/fit_theta_lm_head_update.py @@ -0,0 +1,102 @@ +"""Fit theta_u for a low-rank LM-head weight update. + +The update is W'_u = W + gamma * alpha * C diag(theta_u) A, so the +per-token logit correction depends on the current hidden state. +""" + +import torch +import torch.nn.functional as F + + +CHUNK_SIZE = 16 + + +def _backward_chunked_ce_kl( + h_cpu, lm_w, lm_bias, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens +): + total_ce_value = 0.0 + total_kl_value = 0.0 + + for start in range(0, h_cpu.shape[0], CHUNK_SIZE): + end = min(start + CHUNK_SIZE, h_cpu.shape[0]) + h_chunk = h_cpu[start:end].to(device).float() + y_chunk = y_cpu[start:end].to(device) + + base_logits = F.linear(h_chunk, lm_w, lm_bias) + delta_logits = head_update.logit_delta(h_chunk, theta) + logits = base_logits + blend_gamma * delta_logits + + ce = F.cross_entropy(logits, y_chunk, reduction='sum') + + if beta > 0: + log_p = F.log_softmax(logits, dim=-1) + p0 = F.softmax(base_logits.detach(), dim=-1) + kl = F.kl_div(log_p, p0, reduction='sum') + else: + kl = torch.zeros((), device=device) + + ((ce + beta * kl) / max(total_tokens, 1)).backward() + total_ce_value += float(ce.detach().cpu()) + total_kl_value += float(kl.detach().cpu()) + + if beta > 0: + del log_p, p0 + + del h_chunk, y_chunk, base_logits, delta_logits, logits, ce, kl + + return total_ce_value, total_kl_value + + +def fit_theta_lm_head_update( + cached_h: list, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor | None, + head_update, + d: int = 64, + lr: float = 0.05, + steps: int = 30, + beta: float = 0.05, + lam: float = 1e-4, + blend_gamma: float = 0.5, + max_grad_norm: float = 5.0, + device: str = "cuda:0", + verbose: bool = False, +) -> torch.Tensor: + """Fit the user vector theta_u for an LM-head update.""" + theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32) + optimizer = torch.optim.Adam([theta], lr=lr) + + lm_w = lm_head_weight.float() + lm_b = lm_head_bias.float() if lm_head_bias is not None else None + + total_tokens = sum(y_cpu.shape[0] for _, y_cpu in cached_h) + + for step in range(steps): + total_ce_value = 0.0 + total_kl_value = 0.0 + optimizer.zero_grad() + + for h_cpu, y_cpu in cached_h: + ce_value, kl_value = _backward_chunked_ce_kl( + h_cpu, lm_w, lm_b, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens + ) + total_ce_value += ce_value + total_kl_value += kl_value + + reg = lam * theta.square().sum() + reg.backward() + torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm) + optimizer.step() + + with torch.no_grad(): + norm = theta.norm() + if norm > max_grad_norm: + theta.mul_(max_grad_norm / norm) + + if verbose and (step % 10 == 0 or step == steps - 1): + loss_value = (total_ce_value + beta * total_kl_value) / max(total_tokens, 1) + float(reg.detach().cpu()) + print(f" Step {step:3d}: loss={loss_value:.4f}, |theta|={theta.norm().item():.4f}") + + del reg + + return theta.detach() diff --git a/resulets/baselines/dense_retrieval.py b/resulets/baselines/dense_retrieval.py new file mode 100644 index 0000000..a627319 --- /dev/null +++ b/resulets/baselines/dense_retrieval.py @@ -0,0 +1,143 @@ +"""Dense Retrieval ICL baselines. + +Uses sentence-transformers for dense retrieval over the user support set, +then places top-K retrieved items as in-context examples. +""" + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class DenseRetrieverConfig: + method_name: str + model_name: str + text_mode: str = "input_output" + query_prefix: str = "" + passage_prefix: str = "" + normalize_embeddings: bool = True + citation_year: str = "" + description: str = "" + + +DENSE_RETRIEVER_CONFIGS = { + "dense_minilm_top1": DenseRetrieverConfig( + method_name="dense_minilm_top1", + model_name="sentence-transformers/all-MiniLM-L6-v2", + citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021", + description="Lightweight SBERT/MiniLM dense retriever.", + ), + "dense_mpnet_top1": DenseRetrieverConfig( + method_name="dense_mpnet_top1", + model_name="sentence-transformers/all-mpnet-base-v2", + citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021", + description="Stronger SBERT/MPNet dense retriever.", + ), + "dense_e5_top1": DenseRetrieverConfig( + method_name="dense_e5_top1", + model_name="intfloat/e5-base-v2", + query_prefix="query: ", + passage_prefix="passage: ", + citation_year="E5 2022", + description="E5 dense retriever with the model-card query/passage prefixes.", + ), + "dense_bge_top1": DenseRetrieverConfig( + method_name="dense_bge_top1", + model_name="BAAI/bge-base-en-v1.5", + query_prefix="Represent this sentence for searching relevant passages: ", + citation_year="BGE v1.5 2023", + description="BGE v1.5 dense retriever with the recommended query instruction.", + ), +} + + +def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig: + return DENSE_RETRIEVER_CONFIGS[method_name] + + +class DenseRetriever: + """Dense retriever using sentence-transformers embeddings.""" + + def __init__( + self, + model_name='sentence-transformers/all-MiniLM-L6-v2', + device='cpu', + text_mode='input_output', + query_prefix='', + passage_prefix='', + normalize_embeddings=True, + ): + self.model_name = model_name + self.text_mode = text_mode + self.query_prefix = query_prefix + self.passage_prefix = passage_prefix + self.normalize_embeddings = normalize_embeddings + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer(model_name, device=device) + + def _support_text(self, item: dict) -> str: + if self.text_mode == 'input': + return item['support_input'] + if self.text_mode == 'output': + return item['support_output'] + if self.text_mode == 'input_output': + return f"{item['support_input']}\n{item['support_output']}" + raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}") + + def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata=False): + """Retrieve top-k support items most relevant to query. + + Args: + query: query input text + support_items: list of dicts with 'support_input', 'support_output' + k: number of items to retrieve + return_metadata: whether to also return retrieval diagnostics + + Returns: + List of top-k support items, optionally with metadata. + """ + if len(support_items) <= k: + metadata = [ + { + 'rank': rank + 1, + 'support_index': rank, + 'score': None, + 'model_name': self.model_name, + 'text_mode': self.text_mode, + } + for rank in range(len(support_items)) + ] + return (support_items, metadata) if return_metadata else support_items + + query_text = self.query_prefix + query + texts = [self.passage_prefix + self._support_text(item) for item in support_items] + embeddings = self.model.encode( + [query_text] + texts, + convert_to_tensor=True, + normalize_embeddings=self.normalize_embeddings, + ) + + query_emb = embeddings[0] + item_embs = embeddings[1:] + + if self.normalize_embeddings: + similarities = item_embs @ query_emb + else: + similarities = torch.nn.functional.cosine_similarity( + query_emb.unsqueeze(0), item_embs, dim=1 + ) + + top_indices = similarities.argsort(descending=True)[:k].tolist() + selected = [support_items[i] for i in top_indices] + metadata = [ + { + 'rank': rank + 1, + 'support_index': idx, + 'score': float(similarities[idx].detach().cpu()), + 'model_name': self.model_name, + 'text_mode': self.text_mode, + } + for rank, idx in enumerate(top_indices) + ] + return (selected, metadata) if return_metadata else selected diff --git a/resulets/baselines/logit_bias.py b/resulets/baselines/logit_bias.py new file mode 100644 index 0000000..a219d04 --- /dev/null +++ b/resulets/baselines/logit_bias.py @@ -0,0 +1,286 @@ +"""Static user logit-bias baselines. + +These baselines test whether UPH's gains can be explained by a trivial +user-specific vocabulary prior. They do not modify hidden states or model +weights; generation only receives an additive bias on the output logits. +""" + +import torch +import torch.nn.functional as F +from transformers import LogitsProcessor, LogitsProcessorList + + +CHUNK_SIZE = 32 + + +class StaticBiasLogitsProcessor(LogitsProcessor): + """Add a fixed user-specific vector to next-token scores.""" + + def __init__(self, bias: torch.Tensor): + self.bias = bias + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + bias = self.bias + if bias.numel() != scores.shape[-1]: + adjusted = torch.zeros(scores.shape[-1], dtype=bias.dtype) + n = min(bias.numel(), scores.shape[-1]) + adjusted[:n] = bias[:n] + bias = adjusted + return scores + bias.to(device=scores.device, dtype=scores.dtype).unsqueeze(0) + + +def _encode_output(tokenizer, text: str) -> list[int]: + return tokenizer.encode(text or "", add_special_tokens=False) + + +def _count_tokens(tokenizer, texts: list[str], vocab_size: int, smoothing: float = 0.1): + counts = torch.full((vocab_size,), smoothing, dtype=torch.float32) + for text in texts: + ids = _encode_output(tokenizer, text) + if ids: + binc = torch.bincount(torch.tensor(ids, dtype=torch.long), minlength=vocab_size) + counts += binc.float() + return counts + + +def build_global_log_probs( + tokenizer, + support_sets: list[list[dict]], + smoothing: float = 0.1, + vocab_size: int | None = None, +): + """Estimate the background token distribution from support outputs only.""" + if vocab_size is None: + vocab_size = len(tokenizer) + texts = [ + item["support_output"] + for support in support_sets + for item in support + if item.get("support_output") + ] + counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing) + return torch.log(counts / counts.sum()) + + +def build_user_unigram_bias( + tokenizer, + support_items: list[dict], + global_log_probs: torch.Tensor, + vocab_size: int | None = None, + top_m: int = 512, + scale: float = 0.5, + smoothing: float = 0.1, + only_positive: bool = True, +): + """Build a sparse log-odds vocabulary bias from one user's support outputs.""" + if vocab_size is None: + vocab_size = global_log_probs.numel() + texts = [item["support_output"] for item in support_items if item.get("support_output")] + counts = _count_tokens(tokenizer, texts, vocab_size, smoothing=smoothing) + user_log_probs = torch.log(counts / counts.sum()) + log_odds = user_log_probs - global_log_probs + + special_ids = { + tid for tid in [ + getattr(tokenizer, "pad_token_id", None), + getattr(tokenizer, "eos_token_id", None), + getattr(tokenizer, "bos_token_id", None), + ] if tid is not None + } + for tid in special_ids: + if 0 <= tid < log_odds.numel(): + log_odds[tid] = float("-inf") + + if only_positive: + log_odds = torch.clamp(log_odds, min=0.0) + + k = min(top_m, log_odds.numel()) + values, token_ids = torch.topk(log_odds, k=k) + keep = torch.isfinite(values) & (values > 0) + values = values[keep] + token_ids = token_ids[keep] + + bias = torch.zeros(vocab_size, dtype=torch.float32) + if token_ids.numel() > 0: + bias[token_ids] = scale * values + return bias, token_ids.tolist() + + +def generate_with_logit_bias( + wrapper, + prompt: str, + bias: torch.Tensor, + max_new_tokens: int = 512, + min_new_tokens: int = 128, + temperature: float = 0.0, +): + """Generate with a fixed additive bias on the model's next-token logits.""" + chat_messages = [ + {"role": "system", "content": "You are a helpful writing assistant."}, + {"role": "user", "content": prompt}, + ] + prompt_text = wrapper.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) + + logits_processor = LogitsProcessorList([StaticBiasLogitsProcessor(bias)]) + with torch.no_grad(): + outputs = wrapper.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + temperature=temperature if temperature > 0 else None, + top_p=None, + do_sample=temperature > 0, + pad_token_id=wrapper.tokenizer.pad_token_id, + logits_processor=logits_processor, + ) + generated_ids = outputs[0, input_ids.shape[1]:] + return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) + + +def _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids, token_id_to_pos): + selected_logits = [] + base_lse = [] + non_selected_lse = [] + target_base_logits = [] + target_selected_pos = [] + + with torch.no_grad(): + for start in range(0, h.shape[0], CHUNK_SIZE): + end = min(start + CHUNK_SIZE, h.shape[0]) + h_chunk = h[start:end] + y_chunk = y[start:end] + + base_logits = F.linear(h_chunk, lm_w, lm_b) + selected = base_logits[:, token_ids] + all_lse = torch.logsumexp(base_logits, dim=-1) + selected_lse = torch.logsumexp(selected, dim=-1) + + # Logsumexp over tokens that do not receive the learned sparse bias. + # This lets every optimization step avoid recomputing full-vocab logits. + selected_mass = torch.exp(selected_lse - all_lse).clamp(max=1.0 - 1e-7) + rest_lse = all_lse + torch.log1p(-selected_mass) + + selected_logits.append(selected.detach().cpu()) + base_lse.append(all_lse.detach().cpu()) + non_selected_lse.append(rest_lse.detach().cpu()) + target_base_logits.append( + base_logits.gather(1, y_chunk.unsqueeze(1)).squeeze(1).detach().cpu() + ) + target_selected_pos.append(token_id_to_pos[y_chunk].detach().cpu()) + + del base_logits, selected, all_lse, selected_lse, rest_lse + + return { + "selected_logits": torch.cat(selected_logits, dim=0), + "base_lse": torch.cat(base_lse, dim=0), + "non_selected_lse": torch.cat(non_selected_lse, dim=0), + "target_base_logits": torch.cat(target_base_logits, dim=0), + "target_selected_pos": torch.cat(target_selected_pos, dim=0), + "num_tokens": y.shape[0], + } + + +def _sparse_bias_loss_from_terms(terms, bias_params, beta, device): + selected_logits = terms["selected_logits"].to(device) + base_lse = terms["base_lse"].to(device) + non_selected_lse = terms["non_selected_lse"].to(device) + target_base_logits = terms["target_base_logits"].to(device) + target_selected_pos = terms["target_selected_pos"].to(device) + + adjusted_selected = selected_logits + bias_params.unsqueeze(0) + selected_adjusted_lse = torch.logsumexp(adjusted_selected, dim=-1) + denom = torch.logaddexp(non_selected_lse, selected_adjusted_lse) + + target_logits = target_base_logits + selected_mask = target_selected_pos >= 0 + if selected_mask.any(): + target_logits = target_logits.clone() + selected_rows = selected_mask.nonzero(as_tuple=False).squeeze(1) + selected_cols = target_selected_pos[selected_rows] + target_logits[selected_rows] = adjusted_selected[selected_rows, selected_cols] + + ce = (denom - target_logits).sum() + + if beta > 0: + selected_base_probs = torch.exp(selected_logits - base_lse.unsqueeze(1)) + selected_bias_expectation = (selected_base_probs * bias_params.unsqueeze(0)).sum(dim=-1) + kl = (denom - base_lse - selected_bias_expectation).sum() + else: + kl = torch.zeros((), device=device, dtype=ce.dtype) + + del selected_logits, base_lse, non_selected_lse + del target_base_logits, target_selected_pos, adjusted_selected + return ce, kl + + +def fit_sparse_logit_bias( + cached_h: list, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor | None, + token_ids: list[int], + vocab_size: int, + init_values: torch.Tensor | None = None, + lr: float = 0.05, + steps: int = 30, + beta: float = 0.05, + lam: float = 1e-4, + max_grad_norm: float = 5.0, + device: str = "cuda:0", + verbose: bool = False, +): + """Fit a sparse vocabulary bias on support hidden states.""" + if not cached_h or not token_ids: + return torch.zeros(vocab_size, dtype=torch.float32), 0 + + token_ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=device) + token_id_to_pos = torch.full((vocab_size,), -1, dtype=torch.long, device=device) + token_id_to_pos[token_ids_tensor] = torch.arange(len(token_ids), device=device) + if init_values is None: + bias_params = torch.zeros(len(token_ids), device=device, dtype=torch.float32) + else: + bias_params = init_values.to(device=device, dtype=torch.float32).clone() + bias_params.requires_grad_(True) + + optimizer = torch.optim.Adam([bias_params], lr=lr) + lm_w = lm_head_weight.float() + lm_b = lm_head_bias.float() if lm_head_bias is not None else None + precomputed = [] + total_tokens = 0 + + for h_cpu, y_cpu in cached_h: + h = h_cpu.to(device).float() + y = y_cpu.to(device) + terms = _precompute_sparse_bias_terms(h, y, lm_w, lm_b, token_ids_tensor, token_id_to_pos) + precomputed.append(terms) + total_tokens += terms["num_tokens"] + del h, y + + for step in range(steps): + total_loss = 0.0 + + for terms in precomputed: + ce, kl = _sparse_bias_loss_from_terms(terms, bias_params, beta, device) + total_loss = total_loss + ce + beta * kl + + loss = total_loss / max(total_tokens, 1) + lam * bias_params.square().sum() + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_([bias_params], max_norm=max_grad_norm) + optimizer.step() + + with torch.no_grad(): + norm = bias_params.norm() + if norm > max_grad_norm: + bias_params.mul_(max_grad_norm / norm) + + if verbose and (step % 10 == 0 or step == steps - 1): + print(f" Step {step:3d}: loss={loss.item():.4f}, |bias|={bias_params.norm().item():.4f}") + + del total_loss, loss + + bias = torch.zeros(vocab_size, dtype=torch.float32) + bias[token_ids] = bias_params.detach().cpu() + return bias, len(token_ids) diff --git a/resulets/outputs/dense_retrieval_baselines/dense_summary.csv b/resulets/outputs/dense_retrieval_baselines/dense_summary.csv new file mode 100644 index 0000000..587b09f --- /dev/null +++ b/resulets/outputs/dense_retrieval_baselines/dense_summary.csv @@ -0,0 +1,9 @@ +task,setting,K,method,model,retrieval_text,year,rougeL,meteor,sfd_nolen,avg_len
+review,user,4,dense_minilm_top1,sentence-transformers/all-MiniLM-L6-v2,input_output,MiniLM 2020; Sentence-Transformers checkpoint circa 2021,0.13634167996937627,0.19007702913610236,0.6699762816479236,213.215
+review,user,4,dense_mpnet_top1,sentence-transformers/all-mpnet-base-v2,input_output,MPNet 2020; Sentence-Transformers checkpoint circa 2021,0.1373899379012414,0.19026458136155466,0.7026739815740055,217.525
+review,user,4,dense_e5_top1,intfloat/e5-base-v2,input_output,E5 2022,0.13778466039449483,0.19137571071197623,0.6781906955515157,219.245
+review,user,4,dense_bge_top1,BAAI/bge-base-en-v1.5,input_output,BGE v1.5 2023,0.1397550736294082,0.19234841124713037,0.6609804389020235,214.785
+topic,user,4,dense_minilm_top1,sentence-transformers/all-MiniLM-L6-v2,input_output,MiniLM 2020; Sentence-Transformers checkpoint circa 2021,0.1165623349420512,0.18663326580520853,0.7993059511162214,227.195
+topic,user,4,dense_mpnet_top1,sentence-transformers/all-mpnet-base-v2,input_output,MPNet 2020; Sentence-Transformers checkpoint circa 2021,0.11818723277088544,0.18875778323321796,0.8179664549308402,232.69
+topic,user,4,dense_e5_top1,intfloat/e5-base-v2,input_output,E5 2022,0.11910482780107752,0.18752784709181344,0.8803836118758828,236.06
+topic,user,4,dense_bge_top1,BAAI/bge-base-en-v1.5,input_output,BGE v1.5 2023,0.11914980629806343,0.18572971446713582,0.793203306935874,224.995
diff --git a/resulets/outputs/dense_retrieval_baselines/dense_summary.md b/resulets/outputs/dense_retrieval_baselines/dense_summary.md new file mode 100644 index 0000000..54529fd --- /dev/null +++ b/resulets/outputs/dense_retrieval_baselines/dense_summary.md @@ -0,0 +1,20 @@ +# Dense Retrieval Baseline Summary + +All runs use `K=4`, greedy decoding, and `support_input + support_output` as the retrieval passage text. + +| Task | Method | Model | Year | ROUGE-L | METEOR | SFD_-len | Avg Len | +|---|---|---|---|---:|---:|---:|---:| +| topic_user | dense_minilm_top1 | sentence-transformers/all-MiniLM-L6-v2 | MiniLM 2020; ST checkpoint circa 2021 | 0.1166 | 0.1866 | 0.7993 | 227.2 | +| topic_user | dense_mpnet_top1 | sentence-transformers/all-mpnet-base-v2 | MPNet 2020; ST checkpoint circa 2021 | 0.1182 | 0.1888 | 0.8180 | 232.7 | +| topic_user | dense_e5_top1 | intfloat/e5-base-v2 | E5 2022 | 0.1191 | 0.1875 | 0.8804 | 236.1 | +| topic_user | dense_bge_top1 | BAAI/bge-base-en-v1.5 | BGE v1.5 2023 | 0.1191 | 0.1857 | 0.7932 | 225.0 | +| review_user | dense_minilm_top1 | sentence-transformers/all-MiniLM-L6-v2 | MiniLM 2020; ST checkpoint circa 2021 | 0.1363 | 0.1901 | 0.6700 | 213.2 | +| review_user | dense_mpnet_top1 | sentence-transformers/all-mpnet-base-v2 | MPNet 2020; ST checkpoint circa 2021 | 0.1374 | 0.1903 | 0.7027 | 217.5 | +| review_user | dense_e5_top1 | intfloat/e5-base-v2 | E5 2022 | 0.1378 | 0.1914 | 0.6782 | 219.2 | +| review_user | dense_bge_top1 | BAAI/bge-base-en-v1.5 | BGE v1.5 2023 | 0.1398 | 0.1923 | 0.6610 | 214.8 | + +Best dense retriever: + +- `topic_user`: `dense_bge_top1` by ROUGE-L, essentially tied with `dense_e5_top1`. +- `review_user`: `dense_bge_top1`. + diff --git a/resulets/outputs/dense_retrieval_baselines/review_user_K4/summary.json b/resulets/outputs/dense_retrieval_baselines/review_user_K4/summary.json new file mode 100644 index 0000000..08e7540 --- /dev/null +++ b/resulets/outputs/dense_retrieval_baselines/review_user_K4/summary.json @@ -0,0 +1,39 @@ +{
+ "aggregate": {
+ "dense_minilm_top1": {
+ "rougeL": 0.13634167996937627,
+ "meteor": 0.19007702913610236,
+ "sfd_nolen": 0.6699762816479236,
+ "avg_len": 213.215
+ },
+ "dense_mpnet_top1": {
+ "rougeL": 0.1373899379012414,
+ "meteor": 0.19026458136155466,
+ "sfd_nolen": 0.7026739815740055,
+ "avg_len": 217.525
+ },
+ "dense_e5_top1": {
+ "rougeL": 0.13778466039449483,
+ "meteor": 0.19137571071197623,
+ "sfd_nolen": 0.6781906955515157,
+ "avg_len": 219.245
+ },
+ "dense_bge_top1": {
+ "rougeL": 0.1397550736294082,
+ "meteor": 0.19234841124713037,
+ "sfd_nolen": 0.6609804389020235,
+ "avg_len": 214.785
+ }
+ },
+ "significance": {},
+ "num_examples": 200,
+ "task": "review",
+ "setting": "user",
+ "K": 4,
+ "methods": [
+ "dense_minilm_top1",
+ "dense_mpnet_top1",
+ "dense_e5_top1",
+ "dense_bge_top1"
+ ]
+}
\ No newline at end of file diff --git a/resulets/outputs/dense_retrieval_baselines/topic_user_K4/summary.json b/resulets/outputs/dense_retrieval_baselines/topic_user_K4/summary.json new file mode 100644 index 0000000..bd2cdd8 --- /dev/null +++ b/resulets/outputs/dense_retrieval_baselines/topic_user_K4/summary.json @@ -0,0 +1,39 @@ +{
+ "aggregate": {
+ "dense_minilm_top1": {
+ "rougeL": 0.1165623349420512,
+ "meteor": 0.18663326580520853,
+ "sfd_nolen": 0.7993059511162214,
+ "avg_len": 227.195
+ },
+ "dense_mpnet_top1": {
+ "rougeL": 0.11818723277088544,
+ "meteor": 0.18875778323321796,
+ "sfd_nolen": 0.8179664549308402,
+ "avg_len": 232.69
+ },
+ "dense_e5_top1": {
+ "rougeL": 0.11910482780107752,
+ "meteor": 0.18752784709181344,
+ "sfd_nolen": 0.8803836118758828,
+ "avg_len": 236.06
+ },
+ "dense_bge_top1": {
+ "rougeL": 0.11914980629806343,
+ "meteor": 0.18572971446713582,
+ "sfd_nolen": 0.793203306935874,
+ "avg_len": 224.995
+ }
+ },
+ "significance": {},
+ "num_examples": 200,
+ "task": "topic",
+ "setting": "user",
+ "K": 4,
+ "methods": [
+ "dense_minilm_top1",
+ "dense_mpnet_top1",
+ "dense_e5_top1",
+ "dense_bge_top1"
+ ]
+}
\ No newline at end of file diff --git a/resulets/outputs/diagnostics_full/review_user_K4/summary.json b/resulets/outputs/diagnostics_full/review_user_K4/summary.json new file mode 100644 index 0000000..e017734 --- /dev/null +++ b/resulets/outputs/diagnostics_full/review_user_K4/summary.json @@ -0,0 +1,25 @@ +{
+ "aggregate": {
+ "user_unigram_bias": {
+ "rougeL": 0.12527160864644804,
+ "meteor": 0.15533419842583318,
+ "sfd_nolen": 0.9565568018108134,
+ "avg_len": 166.495
+ },
+ "learned_sparse_logit_bias": {
+ "rougeL": 0.12442403714079388,
+ "meteor": 0.15573424774252828,
+ "sfd_nolen": 0.9361127095590612,
+ "avg_len": 165.905
+ }
+ },
+ "significance": {},
+ "num_examples": 200,
+ "task": "review",
+ "setting": "user",
+ "K": 4,
+ "methods": [
+ "user_unigram_bias",
+ "learned_sparse_logit_bias"
+ ]
+}
\ No newline at end of file diff --git a/resulets/outputs/diagnostics_full/topic_user_K4/summary.json b/resulets/outputs/diagnostics_full/topic_user_K4/summary.json new file mode 100644 index 0000000..b3b82d4 --- /dev/null +++ b/resulets/outputs/diagnostics_full/topic_user_K4/summary.json @@ -0,0 +1,25 @@ +{
+ "aggregate": {
+ "user_unigram_bias": {
+ "rougeL": 0.11952007854701062,
+ "meteor": 0.20468026869316788,
+ "sfd_nolen": 1.045830797035999,
+ "avg_len": 247.045
+ },
+ "learned_sparse_logit_bias": {
+ "rougeL": 0.11851260526759347,
+ "meteor": 0.20384780674291916,
+ "sfd_nolen": 0.891284645760399,
+ "avg_len": 246.92
+ }
+ },
+ "significance": {},
+ "num_examples": 200,
+ "task": "topic",
+ "setting": "user",
+ "K": 4,
+ "methods": [
+ "user_unigram_bias",
+ "learned_sparse_logit_bias"
+ ]
+}
\ No newline at end of file diff --git a/resulets/outputs/injection_ablation_rerun/review_user_K4/summary.json b/resulets/outputs/injection_ablation_rerun/review_user_K4/summary.json new file mode 100644 index 0000000..4f1bce5 --- /dev/null +++ b/resulets/outputs/injection_ablation_rerun/review_user_K4/summary.json @@ -0,0 +1,46 @@ +{
+ "aggregate": {
+ "uph": {
+ "rougeL": 0.12591913138908858,
+ "meteor": 0.15704431994591794,
+ "sfd_nolen": 0.9380754971612366,
+ "avg_len": 165.04
+ },
+ "lm_head_update": {
+ "rougeL": 0.1381619922784921,
+ "meteor": 0.14988041373383443,
+ "sfd_nolen": 1.2312511738320773,
+ "avg_len": 142.135
+ }
+ },
+ "significance": {
+ "lm_head_update": {
+ "rougeL": {
+ "mean_a": 0.12591913138908858,
+ "mean_b": 0.1381619922784921,
+ "mean_diff": -0.012242860889403491,
+ "ci_low": -0.015769665293717965,
+ "ci_high": -0.008716056485089017,
+ "t_pval": 1.1679421884663955e-10,
+ "w_pval": 5.342212114561821e-11
+ },
+ "sfd_nolen": {
+ "mean_a": 0.9380754971612366,
+ "mean_b": 1.2312511738320773,
+ "mean_diff": -0.2931756766708408,
+ "ci_low": -1.0036757003149965,
+ "ci_high": 0.41732434697331494,
+ "t_pval": 0.41961878992446333,
+ "w_pval": 0.04049481176403265
+ }
+ }
+ },
+ "num_examples": 200,
+ "task": "review",
+ "setting": "user",
+ "K": 4,
+ "methods": [
+ "uph",
+ "lm_head_update"
+ ]
+}
\ No newline at end of file diff --git a/resulets/outputs/injection_ablation_rerun/topic_user_K4/summary.json b/resulets/outputs/injection_ablation_rerun/topic_user_K4/summary.json new file mode 100644 index 0000000..47e5734 --- /dev/null +++ b/resulets/outputs/injection_ablation_rerun/topic_user_K4/summary.json @@ -0,0 +1,46 @@ +{
+ "aggregate": {
+ "uph": {
+ "rougeL": 0.11947665707568338,
+ "meteor": 0.2031354029453746,
+ "sfd_nolen": 0.8995390462886158,
+ "avg_len": 246.47
+ },
+ "lm_head_update": {
+ "rougeL": 0.12993177009628162,
+ "meteor": 0.19363473440376885,
+ "sfd_nolen": 0.9707037426803997,
+ "avg_len": 253.99
+ }
+ },
+ "significance": {
+ "lm_head_update": {
+ "rougeL": {
+ "mean_a": 0.11947665707568338,
+ "mean_b": 0.12993177009628162,
+ "mean_diff": -0.010455113020598263,
+ "ci_low": -0.013828552374635705,
+ "ci_high": -0.0070816736665608206,
+ "t_pval": 6.213636827361804e-09,
+ "w_pval": 1.1330271751337874e-11
+ },
+ "sfd_nolen": {
+ "mean_a": 0.8995390462886158,
+ "mean_b": 0.9707037426803997,
+ "mean_diff": -0.07116469639178391,
+ "ci_low": -0.1589445052822547,
+ "ci_high": 0.016615112498686885,
+ "t_pval": 0.1136461024786351,
+ "w_pval": 0.7252815218848135
+ }
+ }
+ },
+ "num_examples": 200,
+ "task": "topic",
+ "setting": "user",
+ "K": 4,
+ "methods": [
+ "uph",
+ "lm_head_update"
+ ]
+}
\ No newline at end of file diff --git a/resulets/scripts/run_all_methods.py b/resulets/scripts/run_all_methods.py new file mode 100644 index 0000000..3809333 --- /dev/null +++ b/resulets/scripts/run_all_methods.py @@ -0,0 +1,908 @@ +"""Unified evaluation pipeline: all methods, all per-user data saved. + +CRASH-SAFE: Each example is appended to a JSONL file immediately after +computation. If the process is killed, all completed examples are preserved. +Already-complete methods are automatically skipped on re-run. + +Usage: + python scripts/run_all_methods.py --task review --setting user --device cuda:0 + python scripts/run_all_methods.py --task review --setting user --methods base,uph,lora +""" + +import sys +import os +import json +import time +import numpy as np +import torch +from scipy import stats + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt, build_prompt_with_examples +from data.style_features import compute_sfd, compute_feature_deltas +from transformers import AutoModelForCausalLM +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, LMHeadUpdate, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from adapt.fit_theta_lm_head_update import fit_theta_lm_head_update +from baselines.peft_baseline import ( + PEFTBaseline, get_lora_config, get_tiny_lora_config, get_vera_config, + get_prompt_tuning_config, get_prefix_tuning_config, +) +from baselines.bm25_top1 import bm25_select_top1 +from baselines.dense_retrieval import ( + DENSE_RETRIEVER_CONFIGS, + DenseRetriever, + get_dense_retriever_config, +) +from baselines.logit_bias import ( + build_global_log_probs, + build_user_unigram_bias, + fit_sparse_logit_bias, + generate_with_logit_bias, +) +from baselines.profile_based import generate_profile, build_profile_conditioned_prompt +from eval.metrics import compute_rouge, compute_meteor + + +ALL_METHODS = [ + 'base', 'uph', 'cvh', 'lm_head_update', + 'user_unigram_bias', 'learned_sparse_logit_bias', + 'prompt_all_k', 'bm25_top1', 'dense_top1', + 'dense_minilm_top1', 'dense_mpnet_top1', 'dense_e5_top1', 'dense_bge_top1', + 'profile_based', + 'lora', 'tiny_lora', 'vera', + 'prompt_tuning_5', 'prompt_tuning_10', 'prompt_tuning_20', + 'prefix_tuning_5', 'prefix_tuning_10', +] + + +def compute_per_user_metrics(pred, ref, support_texts): + r = compute_rouge([pred], [ref]) + m = compute_meteor([pred], [ref]) + p = pred if pred.strip() else "empty" + sfd_all = compute_sfd(p, support_texts, exclude_length=False) + sfd_nolen = compute_sfd(p, support_texts, exclude_length=True) + deltas = compute_feature_deltas(p, support_texts) + return { + 'rouge1': r['rouge1'], + 'rougeL': r['rougeL'], + 'meteor': m, + 'sfd_all': sfd_all, + 'sfd_nolen': sfd_nolen, + 'length': len(pred.split()), + 'feature_deltas': {k: v['delta'] for k, v in deltas.items()}, + } + + +def generate_greedy(wrapper, prompt, max_new_tokens=512, min_new_tokens=128): + chat_messages = [ + {"role": "system", "content": "You are a helpful writing assistant."}, + {"role": "user", "content": prompt}, + ] + prompt_text = wrapper.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) + with torch.no_grad(): + outputs = wrapper.model.generate( + input_ids, + max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, + temperature=None, top_p=None, do_sample=False, + pad_token_id=wrapper.tokenizer.pad_token_id, + ) + return wrapper.tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True) + + +# ─── Incremental saving ────────────────────────────────────────────── + +def get_method_dir(output_dir, task, setting, K, method_name, d=64): + """Get the output directory for a method.""" + exp_dir = os.path.join(output_dir, f"{task}_{setting}_K{K}") + method_label = f"uph_d{d}" if method_name == 'uph' and d != 64 else method_name + return os.path.join(exp_dir, method_label), method_label + + +def is_method_complete(method_dir, N): + """Check if a method already has a complete per_user.json.""" + path = os.path.join(method_dir, 'per_user.json') + if not os.path.exists(path): + return False + try: + with open(path) as f: + data = json.load(f) + return len(data.get('per_user', [])) >= N + except: + return False + + +def append_jsonl(path, entry): + """Append one JSON entry to a JSONL file (crash-safe).""" + with open(path, 'a') as f: + f.write(json.dumps(entry, default=str) + '\n') + + +def read_jsonl(path): + """Read all entries from a JSONL file.""" + entries = [] + if os.path.exists(path): + with open(path) as f: + for line in f: + line = line.strip() + if line: + entries.append(json.loads(line)) + return entries + + +def finalize_method(method_dir, method_label, per_user, task, setting, K, d=64): + """Write final per_user.json from completed per-user list.""" + agg = { + 'rougeL': float(np.mean([u['metrics']['rougeL'] for u in per_user])), + 'meteor': float(np.mean([u['metrics']['meteor'] for u in per_user])), + 'sfd_nolen': float(np.mean([u['metrics']['sfd_nolen'] for u in per_user])), + 'avg_len': float(np.mean([u['metrics']['length'] for u in per_user])), + } + save_data = { + 'per_user': per_user, + 'aggregate': agg, + 'num_examples': len(per_user), + 'task': task, 'setting': setting, 'K': K, + 'method': method_label, + 'decode_policy': 'greedy, min=128, max=512', + } + if 'uph' in method_label: + save_data['d'] = d + path = os.path.join(method_dir, 'per_user.json') + with open(path, 'w') as f: + json.dump(save_data, f, indent=2, default=str) + print(f" Saved: {path} ({len(per_user)} examples)") + + +# ─── Method runners ────────────────────────────────────────────────── + +class MethodRunner: + def __init__( + self, + wrapper, + device, + dense_retriever=None, + uph_d=64, + bias_top_m=512, + unigram_scale=0.5, + sparse_bias_lr=0.05, + sparse_bias_steps=30, + ): + self.wrapper = wrapper + self.device = device + self.dense_retriever = dense_retriever + self.dense_retrievers = {} + self.uph_d = uph_d + self.bias_top_m = bias_top_m + self.unigram_scale = unigram_scale + self.sparse_bias_lr = sparse_bias_lr + self.sparse_bias_steps = sparse_bias_steps + + def _make_entry(self, ex, ref, stexts, K, pred, timing, extra=None): + metrics = compute_per_user_metrics(pred, ref, stexts) + entry = { + 'example_id': ex['example_id'], + 'user_id': ex['user_id'], + 'prediction': pred, + 'reference': ref, + 'support_texts': stexts, + 'K': K, + 'metrics': metrics, + **timing, + } + if extra: + entry.update(extra) + return entry + + def run(self, method_name, examples, support_sets, references, support_texts, + N, method_dir, method_label, task, setting, K, d=64): + """Run a method with incremental JSONL saving. Returns per_user list.""" + + dispatch = { + 'base': self._run_base, + 'uph': self._run_uph, + 'cvh': self._run_cvh, + 'lm_head_update': self._run_lm_head_update, + 'user_unigram_bias': self._run_user_unigram_bias, + 'learned_sparse_logit_bias': self._run_learned_sparse_logit_bias, + 'prompt_all_k': self._run_prompt_all_k, + 'bm25_top1': self._run_bm25_top1, + 'dense_top1': self._run_dense_top1, + 'profile_based': self._run_profile_based, + 'lora': lambda *a, **kw: self._run_peft(*a, config=get_lora_config(rank=8), lr=1e-4, desc='LoRA r=8', **kw), + 'tiny_lora': lambda *a, **kw: self._run_peft(*a, config=get_tiny_lora_config(rank=1), lr=1e-4, desc='Tiny LoRA r=1', **kw), + 'vera': lambda *a, **kw: self._run_peft(*a, config=get_vera_config(rank=256), lr=1e-3, desc='VeRA r=256', **kw), + 'prompt_tuning_5': lambda *a, **kw: self._run_peft(*a, config=get_prompt_tuning_config(5), lr=1e-3, desc='PromptTuning L=5', steps=100, **kw), + 'prompt_tuning_10': lambda *a, **kw: self._run_peft(*a, config=get_prompt_tuning_config(10), lr=1e-3, desc='PromptTuning L=10', steps=100, **kw), + 'prompt_tuning_20': lambda *a, **kw: self._run_peft(*a, config=get_prompt_tuning_config(20), lr=1e-3, desc='PromptTuning L=20', steps=100, **kw), + 'prefix_tuning_5': lambda *a, **kw: self._run_peft(*a, config=get_prefix_tuning_config(5), lr=5e-4, desc='PrefixTuning L=5', steps=100, **kw), + 'prefix_tuning_10': lambda *a, **kw: self._run_peft(*a, config=get_prefix_tuning_config(10), lr=5e-4, desc='PrefixTuning L=10', steps=100, **kw), + } + + if method_name not in dispatch: + if method_name in DENSE_RETRIEVER_CONFIGS: + run_fn = lambda *a, **kw: self._run_dense_configured(method_name, *a, **kw) + else: + print(f"Unknown method: {method_name}") + return [] + else: + run_fn = dispatch[method_name] + + os.makedirs(method_dir, exist_ok=True) + jsonl_path = os.path.join(method_dir, 'progress.jsonl') + + # Resume: check how many examples already done + existing = read_jsonl(jsonl_path) + start_idx = len(existing) + + if start_idx >= N: + print(f"\n--- {method_name} --- SKIPPED (already {start_idx}/{N} done)") + per_user = existing[:N] + else: + if start_idx > 0: + print(f"\n--- {method_name} --- RESUMING from {start_idx}/{N}") + else: + print(f"\n--- {method_name} ---") + + per_user = run_fn( + examples, support_sets, references, support_texts, N, + jsonl_path=jsonl_path, start_idx=start_idx, existing=existing, + ) + + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + avg_sfd = np.mean([u['metrics']['sfd_nolen'] for u in per_user]) + print(f" Mean R-L: {avg_rl:.4f}, SFD_-len: {avg_sfd:.4f}") + + # Write final per_user.json + finalize_method(method_dir, method_label, per_user, task, setting, K, d) + return per_user + + # --- Individual method runners --- + # All accept jsonl_path, start_idx, existing for resume support + + def _run_base(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + per_user = list(existing) + for i in range(start_idx, N): + ex = examples[i] + t0 = time.time() + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_greedy(self.wrapper, prompt) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support_sets[i]), + pred, {'gen_time': time.time() - t0} + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + return per_user + + def _run_prompt_all_k(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + prompt = build_prompt_with_examples(ex['query_input'], support, ex['task']) + pred = generate_greedy(self.wrapper, prompt) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'gen_time': time.time() - t0} + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + return per_user + + def _run_bm25_top1(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + selected = bm25_select_top1(ex['query_input'], support) + prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task']) + pred = generate_greedy(self.wrapper, prompt) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'gen_time': time.time() - t0} + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + return per_user + + def _run_dense_top1(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + if self.dense_retriever is None: + self.dense_retriever = DenseRetriever( + model_name='sentence-transformers/all-MiniLM-L6-v2', + device='cpu', + text_mode='input', + normalize_embeddings=True, + ) + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + selected, retrieval = self.dense_retriever.retrieve_top_k( + ex['query_input'], support, k=1, return_metadata=True + ) + prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task']) + pred = generate_greedy(self.wrapper, prompt) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'gen_time': time.time() - t0}, + extra={ + 'retriever_model': self.dense_retriever.model_name, + 'retrieval_text_mode': self.dense_retriever.text_mode, + 'retrieval': retrieval, + }, + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + return per_user + + def _get_dense_retriever(self, config): + key = ( + config.model_name, + config.text_mode, + config.query_prefix, + config.passage_prefix, + config.normalize_embeddings, + ) + if key not in self.dense_retrievers: + self.dense_retrievers[key] = DenseRetriever( + model_name=config.model_name, + device='cpu', + text_mode=config.text_mode, + query_prefix=config.query_prefix, + passage_prefix=config.passage_prefix, + normalize_embeddings=config.normalize_embeddings, + ) + return self.dense_retrievers[key] + + def _run_dense_configured(self, method_name, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + config = get_dense_retriever_config(method_name) + retriever = self._get_dense_retriever(config) + print( + f" Dense retriever: {config.model_name}, " + f"text_mode={config.text_mode}, year={config.citation_year}" + ) + + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + selected, retrieval = retriever.retrieve_top_k( + ex['query_input'], support, k=1, return_metadata=True + ) + prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task']) + pred = generate_greedy(self.wrapper, prompt) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'gen_time': time.time() - t0}, + extra={ + 'retriever_model': config.model_name, + 'retrieval_text_mode': config.text_mode, + 'retriever_year': config.citation_year, + 'retriever_description': config.description, + 'retrieval': retrieval, + }, + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})") + return per_user + + def _run_profile_based(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + profile = generate_profile(self.wrapper, support, ex['task']) + prompt = build_profile_conditioned_prompt(ex['query_input'], profile, ex['task']) + pred = generate_greedy(self.wrapper, prompt) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'gen_time': time.time() - t0}, + extra={'profile_summary': profile}, + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + return per_user + + def _run_uph(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + d = self.uph_d + H = self.wrapper.hidden_size + uncond = UnconditionalHead(H, d=d, alpha=0.1, basis_seed=42).to(self.device) + print(f" UPH d={d}, params={d}, bytes={d*2}") + lm_head_bias = None + if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None: + lm_head_bias = self.wrapper.model.lm_head.bias.data + + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + cached_h = cache_support_hidden_states(self.wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_greedy(self.wrapper, prompt) + else: + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=self.wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=uncond, + d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4, + max_grad_norm=5.0, device=self.device, + ) + prompt = build_query_prompt(ex['query_input'], ex['task']) + delta_h = uncond.alpha * (uncond.U.float() @ theta.to(self.device).float()) + logit_bias = 0.5 * torch.mv(self.wrapper.lm_head_weight.float(), delta_h) + pred = generate_with_logit_bias( + self.wrapper, + prompt, + logit_bias.detach().cpu(), + max_new_tokens=512, + min_new_tokens=128, + temperature=0.0, + ) + del cached_h, theta + torch.cuda.empty_cache() + + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'adapt_time': time.time() - t0} + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})") + return per_user + + def _run_cvh(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + d = self.uph_d + H = self.wrapper.hidden_size + cvh = CVHHead(H, d=d, alpha=0.1, basis_seed=42).to(self.device) + print(f" CVH d={d}, params={d}, bytes={d*2}") + lm_head_bias = None + if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None: + lm_head_bias = self.wrapper.model.lm_head.bias.data + + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + cached_h = cache_support_hidden_states(self.wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_greedy(self.wrapper, prompt) + else: + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=self.wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=cvh, + d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4, + max_grad_norm=5.0, device=self.device, + ) + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = self.wrapper.generate_with_head_blended( + prompt, theta, cvh.forward_fn, + blend_gamma=0.5, max_new_tokens=512, + min_new_tokens=128, temperature=0.0, + ) + del cached_h, theta + torch.cuda.empty_cache() + + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'adapt_time': time.time() - t0} + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})") + return per_user + + def _run_lm_head_update(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + d = self.uph_d + H = self.wrapper.hidden_size + vocab_size = self.wrapper.lm_head_weight.shape[0] + head_update = LMHeadUpdate(H, vocab_size, d=d, alpha=0.1, basis_seed=42).to(self.device) + print( + f" LM-head update d={d}, user params={d}, " + f"fixed basis params={H*d + vocab_size*d}, bytes={d*2}" + ) + lm_head_bias = None + if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None: + lm_head_bias = self.wrapper.model.lm_head.bias.data + + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + cached_h = cache_support_hidden_states(self.wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_greedy(self.wrapper, prompt) + else: + theta = fit_theta_lm_head_update( + cached_h=cached_h, + lm_head_weight=self.wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_update=head_update, + d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4, + blend_gamma=0.5, max_grad_norm=5.0, device=self.device, + ) + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = self.wrapper.generate_with_lm_head_update( + prompt, theta, head_update, + blend_gamma=0.5, max_new_tokens=512, + min_new_tokens=128, temperature=0.0, + ) + del cached_h, theta + torch.cuda.empty_cache() + + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'adapt_time': time.time() - t0}, + extra={ + 'update_form': 'W + gamma * alpha * C diag(theta) A', + 'blend_gamma': 0.5, + }, + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})") + return per_user + + def _run_user_unigram_bias(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + print(f" User-Unigram Bias top_m={self.bias_top_m}, scale={self.unigram_scale}") + vocab_size = self.wrapper.lm_head_weight.shape[0] + global_log_probs = build_global_log_probs( + self.wrapper.tokenizer, support_sets[:N], smoothing=0.1, vocab_size=vocab_size + ) + + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + bias, token_ids = build_user_unigram_bias( + self.wrapper.tokenizer, + support, + global_log_probs, + vocab_size=vocab_size, + top_m=self.bias_top_m, + scale=self.unigram_scale, + smoothing=0.1, + ) + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_with_logit_bias( + self.wrapper, prompt, bias, + max_new_tokens=512, min_new_tokens=128, temperature=0.0, + ) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'gen_time': time.time() - t0}, + extra={'bias_top_m': self.bias_top_m, 'bias_tokens': len(token_ids), + 'unigram_scale': self.unigram_scale}, + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})") + return per_user + + def _run_learned_sparse_logit_bias(self, examples, support_sets, references, support_texts, N, + jsonl_path, start_idx, existing): + print( + f" Learned Sparse Logit Bias top_m={self.bias_top_m}, " + f"steps={self.sparse_bias_steps}, lr={self.sparse_bias_lr}" + ) + vocab_size = self.wrapper.lm_head_weight.shape[0] + global_log_probs = build_global_log_probs( + self.wrapper.tokenizer, support_sets[:N], smoothing=0.1, vocab_size=vocab_size + ) + lm_head_bias = None + if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None: + lm_head_bias = self.wrapper.model.lm_head.bias.data + + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + init_bias, token_ids = build_user_unigram_bias( + self.wrapper.tokenizer, + support, + global_log_probs, + vocab_size=vocab_size, + top_m=self.bias_top_m, + scale=0.0, + smoothing=0.1, + ) + cached_h = cache_support_hidden_states(self.wrapper, support, ex['task']) + if not cached_h or not token_ids: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_greedy(self.wrapper, prompt) + n_bias = 0 + else: + learned_bias, n_bias = fit_sparse_logit_bias( + cached_h=cached_h, + lm_head_weight=self.wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + token_ids=token_ids, + vocab_size=vocab_size, + init_values=None, + lr=self.sparse_bias_lr, + steps=self.sparse_bias_steps, + beta=0.05, + lam=1e-4, + max_grad_norm=5.0, + device=self.device, + ) + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_with_logit_bias( + self.wrapper, prompt, learned_bias, + max_new_tokens=512, min_new_tokens=128, temperature=0.0, + ) + del cached_h, learned_bias + torch.cuda.empty_cache() + + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'adapt_time': time.time() - t0}, + extra={'bias_top_m': self.bias_top_m, 'bias_tokens': n_bias, + 'sparse_bias_steps': self.sparse_bias_steps, + 'sparse_bias_lr': self.sparse_bias_lr}, + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 40 == 0: + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})") + return per_user + + def _run_peft(self, examples, support_sets, references, support_texts, N, + config, lr, desc, steps=30, jsonl_path=None, start_idx=0, existing=None): + if existing is None: + existing = [] + + # Reload model fresh to avoid contamination from previous PEFT methods + print(f" Reloading model for {desc}...") + self.wrapper.model = AutoModelForCausalLM.from_pretrained( + 'Qwen/Qwen2.5-1.5B-Instruct', + torch_dtype=torch.bfloat16, + trust_remote_code=True, + ).to(self.device) + self.wrapper.model.eval() + self.wrapper.lm_head_weight = self.wrapper.model.lm_head.weight.data + torch.cuda.empty_cache() + + baseline = PEFTBaseline(self.wrapper, config) + print(f" {desc}: {baseline.n_params:,} params ({baseline.n_bytes:,} bytes), steps={steps}, lr={lr}") + + per_user = list(existing) + for i in range(start_idx, N): + ex, support = examples[i], support_sets[i] + t0 = time.time() + pred = baseline.adapt_and_generate( + support_items=support, + query_input=ex['query_input'], + task=ex['task'], + lr=lr, steps=steps, + max_new_tokens=512, min_new_tokens=128, + ) + entry = self._make_entry( + ex, references[i], support_texts[i], len(support), + pred, {'adapt_time': time.time() - t0}, + extra={'n_params': baseline.n_params, 'n_bytes': baseline.n_bytes}, + ) + per_user.append(entry) + append_jsonl(jsonl_path, entry) + if (i + 1) % 20 == 0: + avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user]) + avg_t = np.mean([u['adapt_time'] for u in per_user]) + print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f}, avg time: {avg_t:.1f}s)") + + # No cleanup needed — model will be reloaded fresh for next PEFT method + del baseline + torch.cuda.empty_cache() + return per_user + + +# ─── Main ──────────────────────────────────────────────────────────── + +def paired_test(scores_a, scores_b, name_a, name_b, metric_name): + a, b = np.array(scores_a), np.array(scores_b) + diff = a - b + mean_diff = np.mean(diff) + t_stat, t_pval = stats.ttest_rel(a, b) + try: + w_stat, w_pval = stats.wilcoxon(a, b) + except ValueError: + w_stat, w_pval = float('nan'), float('nan') + se = stats.sem(diff) + ci_low, ci_high = mean_diff - 1.96 * se, mean_diff + 1.96 * se + return { + 'mean_a': float(np.mean(a)), 'mean_b': float(np.mean(b)), + 'mean_diff': float(mean_diff), + 'ci_low': float(ci_low), 'ci_high': float(ci_high), + 't_pval': float(t_pval), 'w_pval': float(w_pval), + } + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--num_eval', type=int, default=200) + parser.add_argument('--task', type=str, default='review', choices=['review', 'topic']) + parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal']) + parser.add_argument('--methods', type=str, default='all', + help='Comma-separated methods or "all"') + parser.add_argument('--device', type=str, default='cuda:0') + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--d', type=int, default=64, help='UPH theta dimension') + parser.add_argument('--output_dir', type=str, default='outputs/unified') + parser.add_argument('--bias_top_m', type=int, default=512, + help='Number of user-specific tokens for logit-bias baselines') + parser.add_argument('--unigram_scale', type=float, default=0.5, + help='Scale for zero-training user unigram logit bias') + parser.add_argument('--sparse_bias_lr', type=float, default=0.05, + help='Learning rate for learned sparse logit-bias baseline') + parser.add_argument('--sparse_bias_steps', type=int, default=30, + help='Adaptation steps for learned sparse logit-bias baseline') + args = parser.parse_args() + + N = args.num_eval + task = args.task + setting = args.setting + K = args.K + + config_map = { + ('review', 'user'): 'product_review_user', + ('review', 'temporal'): 'product_review_temporal', + ('topic', 'user'): 'topic_writing_user', + ('topic', 'temporal'): 'topic_writing_temporal', + } + config_name = config_map[(task, setting)] + + if args.methods == 'all': + methods = ALL_METHODS + else: + methods = [m.strip() for m in args.methods.split(',')] + + print(f"=== Unified Eval: {task}_{setting}, N={N}, K={K}, d={args.d} ===") + print(f"Methods: {methods}") + print(f"Decode: greedy, min=128, max=512") + + print("\nLoading data...") + examples = load_longlamp(config_name, split='val')[:N] + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + print(f"Loading model on {args.device}...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device) + + runner = MethodRunner( + wrapper, + args.device, + uph_d=args.d, + bias_top_m=args.bias_top_m, + unigram_scale=args.unigram_scale, + sparse_bias_lr=args.sparse_bias_lr, + sparse_bias_steps=args.sparse_bias_steps, + ) + all_per_user = {} + + for method in methods: + method_dir, method_label = get_method_dir( + args.output_dir, task, setting, K, method, args.d + ) + + # Skip if already complete + if is_method_complete(method_dir, N): + print(f"\n--- {method} --- COMPLETE (loading from disk)") + with open(os.path.join(method_dir, 'per_user.json')) as f: + data = json.load(f) + all_per_user[method] = data['per_user'][:N] + avg_rl = np.mean([u['metrics']['rougeL'] for u in all_per_user[method]]) + print(f" Mean R-L: {avg_rl:.4f}") + continue + + per_user = runner.run( + method, examples, support_sets, references, support_texts, + N, method_dir, method_label, task, setting, K, args.d, + ) + all_per_user[method] = per_user + + # Summary table + print("\n" + "=" * 90) + print(f"{'Method':<15} {'R-L':<8} {'METEOR':<8} {'SFD_-len':<9} {'Len':<6}") + print("-" * 90) + for method in methods: + if method not in all_per_user: + continue + pu = all_per_user[method] + rl = np.mean([u['metrics']['rougeL'] for u in pu]) + mt = np.mean([u['metrics']['meteor'] for u in pu]) + sf = np.mean([u['metrics']['sfd_nolen'] for u in pu]) + ln = np.mean([u['metrics']['length'] for u in pu]) + print(f"{method:<15} {rl:<8.4f} {mt:<8.4f} {sf:<9.4f} {ln:<6.0f}") + + # Significance tests (UPH vs all others) + sig_results = {} + if 'uph' in all_per_user: + print("\n" + "=" * 90) + print("Significance (UPH vs each, paired t-test p-value)") + print("=" * 90) + uph_rl = [u['metrics']['rougeL'] for u in all_per_user['uph']] + uph_sf = [u['metrics']['sfd_nolen'] for u in all_per_user['uph']] + for method in methods: + if method == 'uph' or method not in all_per_user: + continue + other_rl = [u['metrics']['rougeL'] for u in all_per_user[method]] + other_sf = [u['metrics']['sfd_nolen'] for u in all_per_user[method]] + rl_t = paired_test(uph_rl, other_rl, 'uph', method, 'R-L') + sf_t = paired_test(uph_sf, other_sf, 'uph', method, 'SFD') + sig_results[method] = {'rougeL': rl_t, 'sfd_nolen': sf_t} + print(f" vs {method:<12} R-L: diff={rl_t['mean_diff']:+.4f} p={rl_t['t_pval']:.2e} " + f"SFD: diff={sf_t['mean_diff']:+.4f} p={sf_t['t_pval']:.2e}") + + # Save summary + exp_dir = os.path.join(args.output_dir, f"{task}_{setting}_K{K}") + summary = {} + for method in methods: + if method not in all_per_user: + continue + pu = all_per_user[method] + summary[method] = { + 'rougeL': float(np.mean([u['metrics']['rougeL'] for u in pu])), + 'meteor': float(np.mean([u['metrics']['meteor'] for u in pu])), + 'sfd_nolen': float(np.mean([u['metrics']['sfd_nolen'] for u in pu])), + 'avg_len': float(np.mean([u['metrics']['length'] for u in pu])), + } + summary_path = os.path.join(exp_dir, 'summary.json') + with open(summary_path, 'w') as f: + json.dump({ + 'aggregate': summary, + 'significance': sig_results, + 'num_examples': N, 'task': task, 'setting': setting, 'K': K, + 'methods': methods, + }, f, indent=2, default=str) + + print(f"\nSummary: {summary_path}") + + +if __name__ == '__main__': + main() diff --git a/resulets/scripts/summarize_dense_baselines.py b/resulets/scripts/summarize_dense_baselines.py new file mode 100644 index 0000000..081667e --- /dev/null +++ b/resulets/scripts/summarize_dense_baselines.py @@ -0,0 +1,63 @@ +"""Summarize dense retrieval baseline runs into a compact CSV table.""" + +import argparse +import csv +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from baselines.dense_retrieval import DENSE_RETRIEVER_CONFIGS + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", default="outputs/dense_retrieval_baselines") + parser.add_argument("--csv_name", default="dense_summary.csv") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + rows = [] + + for summary_path in sorted(output_dir.glob("*_K*/summary.json")): + with summary_path.open() as f: + summary = json.load(f) + + task = summary.get("task") + setting = summary.get("setting") + K = summary.get("K") + aggregate = summary.get("aggregate", {}) + + for method, metrics in aggregate.items(): + config = DENSE_RETRIEVER_CONFIGS.get(method) + rows.append({ + "task": task, + "setting": setting, + "K": K, + "method": method, + "model": config.model_name if config else "", + "retrieval_text": config.text_mode if config else "", + "year": config.citation_year if config else "", + "rougeL": metrics.get("rougeL"), + "meteor": metrics.get("meteor"), + "sfd_nolen": metrics.get("sfd_nolen"), + "avg_len": metrics.get("avg_len"), + }) + + csv_path = output_dir / args.csv_name + csv_path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "task", "setting", "K", "method", "model", "retrieval_text", "year", + "rougeL", "meteor", "sfd_nolen", "avg_len", + ] + with csv_path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + print(f"Wrote {csv_path} ({len(rows)} rows)") + + +if __name__ == "__main__": + main() |
