summaryrefslogtreecommitdiff
path: root/resulets/baselines/dense_retrieval.py
diff options
context:
space:
mode:
authorBLUESKY477 <abcd15803148000@163.com>2026-05-22 19:23:44 -0500
committerGitHub <noreply@github.com>2026-05-22 19:23:44 -0500
commit896df7f11b441a9b8dfa50820024a82884da58d0 (patch)
tree0182ae4a7a0bb16ee6a764393838a580e1ba1c31 /resulets/baselines/dense_retrieval.py
parent6f48c4fae3243e280b27a977c6a8cb731becf446 (diff)
Add files via uploadHEADmaster
Diffstat (limited to 'resulets/baselines/dense_retrieval.py')
-rw-r--r--resulets/baselines/dense_retrieval.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/resulets/baselines/dense_retrieval.py b/resulets/baselines/dense_retrieval.py
new file mode 100644
index 0000000..a627319
--- /dev/null
+++ b/resulets/baselines/dense_retrieval.py
@@ -0,0 +1,143 @@
+"""Dense Retrieval ICL baselines.
+
+Uses sentence-transformers for dense retrieval over the user support set,
+then places top-K retrieved items as in-context examples.
+"""
+
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass(frozen=True)
+class DenseRetrieverConfig:
+ method_name: str
+ model_name: str
+ text_mode: str = "input_output"
+ query_prefix: str = ""
+ passage_prefix: str = ""
+ normalize_embeddings: bool = True
+ citation_year: str = ""
+ description: str = ""
+
+
+DENSE_RETRIEVER_CONFIGS = {
+ "dense_minilm_top1": DenseRetrieverConfig(
+ method_name="dense_minilm_top1",
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
+ citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021",
+ description="Lightweight SBERT/MiniLM dense retriever.",
+ ),
+ "dense_mpnet_top1": DenseRetrieverConfig(
+ method_name="dense_mpnet_top1",
+ model_name="sentence-transformers/all-mpnet-base-v2",
+ citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021",
+ description="Stronger SBERT/MPNet dense retriever.",
+ ),
+ "dense_e5_top1": DenseRetrieverConfig(
+ method_name="dense_e5_top1",
+ model_name="intfloat/e5-base-v2",
+ query_prefix="query: ",
+ passage_prefix="passage: ",
+ citation_year="E5 2022",
+ description="E5 dense retriever with the model-card query/passage prefixes.",
+ ),
+ "dense_bge_top1": DenseRetrieverConfig(
+ method_name="dense_bge_top1",
+ model_name="BAAI/bge-base-en-v1.5",
+ query_prefix="Represent this sentence for searching relevant passages: ",
+ citation_year="BGE v1.5 2023",
+ description="BGE v1.5 dense retriever with the recommended query instruction.",
+ ),
+}
+
+
+def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig:
+ return DENSE_RETRIEVER_CONFIGS[method_name]
+
+
+class DenseRetriever:
+ """Dense retriever using sentence-transformers embeddings."""
+
+ def __init__(
+ self,
+ model_name='sentence-transformers/all-MiniLM-L6-v2',
+ device='cpu',
+ text_mode='input_output',
+ query_prefix='',
+ passage_prefix='',
+ normalize_embeddings=True,
+ ):
+ self.model_name = model_name
+ self.text_mode = text_mode
+ self.query_prefix = query_prefix
+ self.passage_prefix = passage_prefix
+ self.normalize_embeddings = normalize_embeddings
+ from sentence_transformers import SentenceTransformer
+ self.model = SentenceTransformer(model_name, device=device)
+
+ def _support_text(self, item: dict) -> str:
+ if self.text_mode == 'input':
+ return item['support_input']
+ if self.text_mode == 'output':
+ return item['support_output']
+ if self.text_mode == 'input_output':
+ return f"{item['support_input']}\n{item['support_output']}"
+ raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}")
+
+ def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata=False):
+ """Retrieve top-k support items most relevant to query.
+
+ Args:
+ query: query input text
+ support_items: list of dicts with 'support_input', 'support_output'
+ k: number of items to retrieve
+ return_metadata: whether to also return retrieval diagnostics
+
+ Returns:
+ List of top-k support items, optionally with metadata.
+ """
+ if len(support_items) <= k:
+ metadata = [
+ {
+ 'rank': rank + 1,
+ 'support_index': rank,
+ 'score': None,
+ 'model_name': self.model_name,
+ 'text_mode': self.text_mode,
+ }
+ for rank in range(len(support_items))
+ ]
+ return (support_items, metadata) if return_metadata else support_items
+
+ query_text = self.query_prefix + query
+ texts = [self.passage_prefix + self._support_text(item) for item in support_items]
+ embeddings = self.model.encode(
+ [query_text] + texts,
+ convert_to_tensor=True,
+ normalize_embeddings=self.normalize_embeddings,
+ )
+
+ query_emb = embeddings[0]
+ item_embs = embeddings[1:]
+
+ if self.normalize_embeddings:
+ similarities = item_embs @ query_emb
+ else:
+ similarities = torch.nn.functional.cosine_similarity(
+ query_emb.unsqueeze(0), item_embs, dim=1
+ )
+
+ top_indices = similarities.argsort(descending=True)[:k].tolist()
+ selected = [support_items[i] for i in top_indices]
+ metadata = [
+ {
+ 'rank': rank + 1,
+ 'support_index': idx,
+ 'score': float(similarities[idx].detach().cpu()),
+ 'model_name': self.model_name,
+ 'text_mode': self.text_mode,
+ }
+ for rank, idx in enumerate(top_indices)
+ ]
+ return (selected, metadata) if return_metadata else selected