"""Dense Retrieval ICL baselines. Uses sentence-transformers for dense retrieval over the user support set, then places top-K retrieved items as in-context examples. """ from dataclasses import dataclass import torch @dataclass(frozen=True) class DenseRetrieverConfig: method_name: str model_name: str text_mode: str = "input_output" query_prefix: str = "" passage_prefix: str = "" normalize_embeddings: bool = True citation_year: str = "" description: str = "" DENSE_RETRIEVER_CONFIGS = { "dense_minilm_top1": DenseRetrieverConfig( method_name="dense_minilm_top1", model_name="sentence-transformers/all-MiniLM-L6-v2", citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021", description="Lightweight SBERT/MiniLM dense retriever.", ), "dense_mpnet_top1": DenseRetrieverConfig( method_name="dense_mpnet_top1", model_name="sentence-transformers/all-mpnet-base-v2", citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021", description="Stronger SBERT/MPNet dense retriever.", ), "dense_e5_top1": DenseRetrieverConfig( method_name="dense_e5_top1", model_name="intfloat/e5-base-v2", query_prefix="query: ", passage_prefix="passage: ", citation_year="E5 2022", description="E5 dense retriever with the model-card query/passage prefixes.", ), "dense_bge_top1": DenseRetrieverConfig( method_name="dense_bge_top1", model_name="BAAI/bge-base-en-v1.5", query_prefix="Represent this sentence for searching relevant passages: ", citation_year="BGE v1.5 2023", description="BGE v1.5 dense retriever with the recommended query instruction.", ), } def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig: return DENSE_RETRIEVER_CONFIGS[method_name] class DenseRetriever: """Dense retriever using sentence-transformers embeddings.""" def __init__( self, model_name='sentence-transformers/all-MiniLM-L6-v2', device='cpu', text_mode='input_output', query_prefix='', passage_prefix='', normalize_embeddings=True, ): self.model_name = model_name self.text_mode = text_mode self.query_prefix = query_prefix self.passage_prefix = passage_prefix self.normalize_embeddings = normalize_embeddings from sentence_transformers import SentenceTransformer self.model = SentenceTransformer(model_name, device=device) def _support_text(self, item: dict) -> str: if self.text_mode == 'input': return item['support_input'] if self.text_mode == 'output': return item['support_output'] if self.text_mode == 'input_output': return f"{item['support_input']}\n{item['support_output']}" raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}") def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata=False): """Retrieve top-k support items most relevant to query. Args: query: query input text support_items: list of dicts with 'support_input', 'support_output' k: number of items to retrieve return_metadata: whether to also return retrieval diagnostics Returns: List of top-k support items, optionally with metadata. """ if len(support_items) <= k: metadata = [ { 'rank': rank + 1, 'support_index': rank, 'score': None, 'model_name': self.model_name, 'text_mode': self.text_mode, } for rank in range(len(support_items)) ] return (support_items, metadata) if return_metadata else support_items query_text = self.query_prefix + query texts = [self.passage_prefix + self._support_text(item) for item in support_items] embeddings = self.model.encode( [query_text] + texts, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings, ) query_emb = embeddings[0] item_embs = embeddings[1:] if self.normalize_embeddings: similarities = item_embs @ query_emb else: similarities = torch.nn.functional.cosine_similarity( query_emb.unsqueeze(0), item_embs, dim=1 ) top_indices = similarities.argsort(descending=True)[:k].tolist() selected = [support_items[i] for i in top_indices] metadata = [ { 'rank': rank + 1, 'support_index': idx, 'score': float(similarities[idx].detach().cpu()), 'model_name': self.model_name, 'text_mode': self.text_mode, } for rank, idx in enumerate(top_indices) ] return (selected, metadata) if return_metadata else selected