1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
|
"""Dense Retrieval ICL baselines.
Uses sentence-transformers for dense retrieval over the user support set,
then places top-K retrieved items as in-context examples.
"""
from dataclasses import dataclass
import torch
@dataclass(frozen=True)
class DenseRetrieverConfig:
    """Immutable description of one dense-retrieval baseline.

    Instances are frozen, so a config cannot be modified after it has
    been created.

    Attributes:
        method_name: Name identifying the baseline.
        model_name: Identifier of the embedding checkpoint to use.
        text_mode: Which support-item text is embedded; defaults to
            ``"input_output"``.
        query_prefix: Text placed in front of the query before it is
            embedded; empty by default.
        passage_prefix: Text placed in front of each support text before
            it is embedded; empty by default.
        normalize_embeddings: Whether embeddings should be normalized;
            defaults to ``True``.
        citation_year: Free-form note on when the model was published.
        description: Short human-readable summary of the baseline.
    """

    # Required identification of the baseline and its checkpoint.
    method_name: str
    model_name: str

    # Encoding options.
    text_mode: str = "input_output"
    query_prefix: str = ""
    passage_prefix: str = ""
    normalize_embeddings: bool = True

    # Purely informational fields.
    citation_year: str = ""
    description: str = ""
# Registry of the available dense-retrieval baselines. Each entry is keyed
# by its own ``method_name`` so the key and the config can never drift apart.
DENSE_RETRIEVER_CONFIGS = {
    config.method_name: config
    for config in (
        DenseRetrieverConfig(
            method_name="dense_minilm_top1",
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021",
            description="Lightweight SBERT/MiniLM dense retriever.",
        ),
        DenseRetrieverConfig(
            method_name="dense_mpnet_top1",
            model_name="sentence-transformers/all-mpnet-base-v2",
            citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021",
            description="Stronger SBERT/MPNet dense retriever.",
        ),
        # E5 uses a prefix on both the query side and the passage side.
        DenseRetrieverConfig(
            method_name="dense_e5_top1",
            model_name="intfloat/e5-base-v2",
            query_prefix="query: ",
            passage_prefix="passage: ",
            citation_year="E5 2022",
            description="E5 dense retriever with the model-card query/passage prefixes.",
        ),
        # BGE only sets a query-side instruction; passages keep the empty default prefix.
        DenseRetrieverConfig(
            method_name="dense_bge_top1",
            model_name="BAAI/bge-base-en-v1.5",
            query_prefix="Represent this sentence for searching relevant passages: ",
            citation_year="BGE v1.5 2023",
            description="BGE v1.5 dense retriever with the recommended query instruction.",
        ),
    )
}
def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig:
    """Return the registered configuration for a dense-retrieval method.

    Args:
        method_name: Key of the wanted entry in ``DENSE_RETRIEVER_CONFIGS``.

    Returns:
        The ``DenseRetrieverConfig`` stored under ``method_name``.

    Raises:
        KeyError: If ``method_name`` is not registered. The message lists
            the registered method names.
    """
    try:
        return DENSE_RETRIEVER_CONFIGS[method_name]
    except KeyError:
        available = ", ".join(sorted(DENSE_RETRIEVER_CONFIGS))
        # Still a KeyError so existing handlers keep working; ``from None``
        # hides the uninformative original lookup error from the traceback.
        raise KeyError(
            f"Unknown dense retriever method {method_name!r}; "
            f"available methods: {available}"
        ) from None
class DenseRetriever:
    """Dense retriever using sentence-transformers embeddings.

    Embeds a query together with the candidate support items using a
    single SentenceTransformer model and ranks the items by their
    similarity to the query.
    """

    def __init__(
        self,
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        device='cpu',
        text_mode='input_output',
        query_prefix='',
        passage_prefix='',
        normalize_embeddings=True,
    ):
        """Store the retrieval settings and load the embedding model.

        Args:
            model_name: checkpoint identifier passed to SentenceTransformer
            device: device passed to SentenceTransformer
            text_mode: which support fields are embedded; one of 'input',
                'output' or 'input_output'
            query_prefix: text prepended to the query before encoding
            passage_prefix: text prepended to every support text before
                encoding
            normalize_embeddings: forwarded to the model's ``encode`` call;
                also selects how similarities are computed
        """
        self.model_name = model_name
        self.text_mode = text_mode
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.normalize_embeddings = normalize_embeddings
        # Imported here rather than at module level, so importing this
        # module does not by itself require sentence-transformers.
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name, device=device)

    def _support_text(self, item: dict) -> str:
        """Return the text of a support item that should be embedded.

        Args:
            item: dict with 'support_input' and/or 'support_output',
                depending on ``self.text_mode``

        Raises:
            ValueError: if ``self.text_mode`` is not a known mode.
        """
        if self.text_mode == 'input':
            return item['support_input']
        if self.text_mode == 'output':
            return item['support_output']
        if self.text_mode == 'input_output':
            return f"{item['support_input']}\n{item['support_output']}"
        raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}")

    def _build_metadata(self, indices, scores=None):
        """Build one diagnostics dict per selected item, in rank order.

        Args:
            indices: positions of the selected items in the support list,
                best first
            scores: per-support-item scores indexable by those positions,
                or None when no ranking was computed
        """
        return [
            {
                'rank': rank,
                'support_index': idx,
                'score': None if scores is None else scores[idx],
                'model_name': self.model_name,
                'text_mode': self.text_mode,
            }
            for rank, idx in enumerate(indices, start=1)
        ]

    def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata=False):
        """Retrieve top-k support items most relevant to query.

        Args:
            query: query input text
            support_items: list of dicts with 'support_input', 'support_output'
            k: number of items to retrieve; must be non-negative
            return_metadata: whether to also return retrieval diagnostics

        Returns:
            List of top-k support items, best first. With
            ``return_metadata=True`` a ``(items, metadata)`` tuple, where
            each metadata dict holds 'rank', 'support_index', 'score',
            'model_name' and 'text_mode'. When ``k`` is 0 or at least
            ``len(support_items)``, no ranking is computed: the items are
            returned in their original order and every 'score' is None.

        Raises:
            ValueError: if ``k`` is negative, or (when ranking is needed)
                if ``self.text_mode`` is not a known mode.
        """
        if k < 0:
            # A negative k would otherwise reach the ``[:k]`` slice below
            # and silently return all but the last |k| items.
            raise ValueError(f"k must be non-negative, got {k}")

        if k == 0 or len(support_items) <= k:
            # Nothing to rank, so skip the model call. Slice to get a
            # shallow copy instead of handing back the caller's own list.
            selected = list(support_items[:k])
            metadata = self._build_metadata(range(len(selected)))
            return (selected, metadata) if return_metadata else selected

        query_text = self.query_prefix + query
        texts = [self.passage_prefix + self._support_text(item) for item in support_items]

        # Encode query and support texts in one call; row 0 is the query.
        embeddings = self.model.encode(
            [query_text] + texts,
            convert_to_tensor=True,
            normalize_embeddings=self.normalize_embeddings,
        )
        query_emb = embeddings[0]
        item_embs = embeddings[1:]

        if self.normalize_embeddings:
            # Plain dot product; equals cosine similarity when the model
            # returned unit-norm embeddings as requested.
            similarities = item_embs @ query_emb
        else:
            similarities = torch.nn.functional.cosine_similarity(
                query_emb.unsqueeze(0), item_embs, dim=1
            )

        top_indices = similarities.argsort(descending=True)[:k].tolist()
        # Move all scores to the host once instead of once per selected item.
        scores = similarities.detach().cpu().tolist()

        selected = [support_items[i] for i in top_indices]
        metadata = self._build_metadata(top_indices, scores)
        return (selected, metadata) if return_metadata else selected
|