# Recovered from commit 896df7f11b441a9b8dfa50820024a82884da58d0
# ("Add files via upload"), file resulets/baselines/dense_retrieval.py.
"""Dense Retrieval ICL baselines.

Uses sentence-transformers for dense retrieval over the user support set,
then places top-K retrieved items as in-context examples.
"""

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class DenseRetrieverConfig:
    """Immutable settings describing one named dense-retrieval baseline.

    The field names mirror the keyword arguments of ``DenseRetriever`` where
    both exist (``model_name``, ``text_mode``, ``query_prefix``,
    ``passage_prefix``, ``normalize_embeddings``).
    """

    # Baseline name; equals the entry's key in DENSE_RETRIEVER_CONFIGS.
    method_name: str
    # Embedding checkpoint identifier.
    model_name: str
    # Which support-item fields are embedded: 'input', 'output' or
    # 'input_output' (the modes understood by DenseRetriever._support_text).
    text_mode: str = "input_output"
    # Text prepended to the query before it is embedded.
    query_prefix: str = ""
    # Text prepended to every support passage before it is embedded.
    passage_prefix: str = ""
    # Whether embeddings are L2-normalized by the encoder.
    normalize_embeddings: bool = True
    # Free-text provenance note for reporting.
    citation_year: str = ""
    # Free-text human-readable summary of the baseline.
    description: str = ""


# Registry of the available dense-retrieval baselines, keyed by method name.
DENSE_RETRIEVER_CONFIGS = {
    "dense_minilm_top1": DenseRetrieverConfig(
        method_name="dense_minilm_top1",
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021",
        description="Lightweight SBERT/MiniLM dense retriever.",
    ),
    "dense_mpnet_top1": DenseRetrieverConfig(
        method_name="dense_mpnet_top1",
        model_name="sentence-transformers/all-mpnet-base-v2",
        citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021",
        description="Stronger SBERT/MPNet dense retriever.",
    ),
    "dense_e5_top1": DenseRetrieverConfig(
        method_name="dense_e5_top1",
        model_name="intfloat/e5-base-v2",
        query_prefix="query: ",
        passage_prefix="passage: ",
        citation_year="E5 2022",
        description="E5 dense retriever with the model-card query/passage prefixes.",
    ),
    "dense_bge_top1": DenseRetrieverConfig(
        method_name="dense_bge_top1",
        model_name="BAAI/bge-base-en-v1.5",
        query_prefix="Represent this sentence for searching relevant passages: ",
        citation_year="BGE v1.5 2023",
        description="BGE v1.5 dense retriever with the recommended query instruction.",
    ),
}


def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig:
    """Return the registered configuration for ``method_name``.

    Args:
        method_name: key of an entry in ``DENSE_RETRIEVER_CONFIGS``.

    Returns:
        The matching ``DenseRetrieverConfig``.

    Raises:
        KeyError: if ``method_name`` is not registered. The message names the
            offending value and lists the registered method names.
    """
    try:
        return DENSE_RETRIEVER_CONFIGS[method_name]
    except KeyError:
        known = ", ".join(sorted(DENSE_RETRIEVER_CONFIGS))
        raise KeyError(
            f"Unknown dense retriever method {method_name!r}; known methods: {known}"
        ) from None


class DenseRetriever:
    """Dense retriever using sentence-transformers embeddings."""

    def __init__(
        self,
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        device='cpu',
        text_mode='input_output',
        query_prefix='',
        passage_prefix='',
        normalize_embeddings=True,
    ):
        """Store the retrieval settings and load the embedding model.

        Args:
            model_name: checkpoint name handed to ``SentenceTransformer``.
            device: device string handed to ``SentenceTransformer``.
            text_mode: which support fields are embedded; one of 'input',
                'output' or 'input_output'. It is validated when a support
                item is first converted to text, not here.
            query_prefix: text prepended to the query before encoding.
            passage_prefix: text prepended to each support text before encoding.
            normalize_embeddings: forwarded to ``encode``; also selects how
                similarities are computed in ``retrieve_top_k``.
        """
        self.model_name = model_name
        self.text_mode = text_mode
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.normalize_embeddings = normalize_embeddings
        # Imported here rather than at module level, so the dependency is only
        # needed once a retriever is actually constructed.
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name, device=device)

    def _support_text(self, item: dict) -> str:
        """Build the text that represents one support item for embedding.

        Raises:
            ValueError: if ``self.text_mode`` is not a supported mode.
        """
        if self.text_mode == 'input':
            return item['support_input']
        if self.text_mode == 'output':
            return item['support_output']
        if self.text_mode == 'input_output':
            return f"{item['support_input']}\n{item['support_output']}"
        raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}")

    def _metadata_entry(self, rank, support_index, score) -> dict:
        """Build the diagnostics record for one returned support item."""
        return {
            'rank': rank,
            'support_index': support_index,
            'score': score,
            'model_name': self.model_name,
            'text_mode': self.text_mode,
        }

    def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata=False):
        """Retrieve top-k support items most relevant to query.

        Args:
            query: query input text
            support_items: list of dicts with 'support_input', 'support_output'
            k: number of items to retrieve; values <= 0 retrieve nothing
            return_metadata: whether to also return retrieval diagnostics

        Returns:
            List of top-k support items, ordered from most to least similar.
            With ``return_metadata=True`` a ``(items, metadata)`` tuple is
            returned instead, where ``metadata`` holds one dict per item with
            the keys 'rank' (1-based), 'support_index', 'score', 'model_name'
            and 'text_mode'.

            When there are at most ``k`` support items, nothing is encoded:
            ``support_items`` itself is returned in its original order and
            every 'score' is ``None``.

        Raises:
            ValueError: if ``text_mode`` is unknown (only reached when the
                support items actually have to be embedded).
        """
        # A non-positive k asks for nothing. Without this guard a negative k
        # would slice the ranking as [:k] and return all but the worst |k|
        # items, and k == 0 would encode every item only to discard the result.
        if k <= 0:
            return ([], []) if return_metadata else []

        # Nothing to rank: every support item is returned, so skip encoding.
        if len(support_items) <= k:
            metadata = [
                self._metadata_entry(position + 1, position, None)
                for position in range(len(support_items))
            ]
            return (support_items, metadata) if return_metadata else support_items

        query_text = self.query_prefix + query
        texts = [self.passage_prefix + self._support_text(item) for item in support_items]
        # Encode query and passages in a single call; row 0 is the query.
        embeddings = self.model.encode(
            [query_text] + texts,
            convert_to_tensor=True,
            normalize_embeddings=self.normalize_embeddings,
        )

        query_emb = embeddings[0]
        item_embs = embeddings[1:]

        if self.normalize_embeddings:
            # For unit-length vectors the dot product is the cosine similarity.
            similarities = item_embs @ query_emb
        else:
            similarities = torch.nn.functional.cosine_similarity(
                query_emb.unsqueeze(0), item_embs, dim=1
            )

        top_indices = similarities.argsort(descending=True)[:k].tolist()
        selected = [support_items[i] for i in top_indices]
        metadata = [
            self._metadata_entry(rank + 1, idx, float(similarities[idx].detach().cpu()))
            for rank, idx in enumerate(top_indices)
        ]
        return (selected, metadata) if return_metadata else selected