summaryrefslogtreecommitdiff
path: root/resulets/baselines/dense_retrieval.py
blob: a6273193eb5b1e437e2c0a9624af0d2a3da8d101 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Dense Retrieval ICL baselines.

Uses sentence-transformers for dense retrieval over the user support set,
then places top-K retrieved items as in-context examples.
"""

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class DenseRetrieverConfig:
    """Immutable description of one dense-retrieval baseline.

    Instances are frozen, so a config cannot be modified after construction.
    Several field names mirror the constructor parameters of
    ``DenseRetriever`` (``model_name``, ``text_mode``, ``query_prefix``,
    ``passage_prefix``, ``normalize_embeddings``); presumably callers forward
    them there -- TODO confirm against the call sites, which are not visible
    in this module.
    """

    # Identifier of the baseline. The entries of DENSE_RETRIEVER_CONFIGS use
    # this same string as their dictionary key.
    method_name: str
    # Name of the embedding checkpoint, e.g.
    # "sentence-transformers/all-MiniLM-L6-v2".
    model_name: str
    # Which support-item text gets embedded. DenseRetriever accepts "input",
    # "output" and "input_output" -- presumably the same set applies here.
    text_mode: str = "input_output"
    # Text prepended to the query before encoding (empty means no prefix).
    query_prefix: str = ""
    # Text prepended to every support passage before encoding.
    passage_prefix: str = ""
    # Whether embeddings should be normalized when they are computed.
    normalize_embeddings: bool = True
    # Free-text citation note for the model, e.g. "E5 2022".
    citation_year: str = ""
    # Short human-readable summary of the baseline.
    description: str = ""


# Registry of the available dense-retrieval baselines, keyed by method name.
# The key is derived from each config's own ``method_name`` so the two cannot
# drift apart; insertion order follows the tuple below.
DENSE_RETRIEVER_CONFIGS = {
    config.method_name: config
    for config in (
        # Sentence-Transformers checkpoints: no query/passage prefixes.
        DenseRetrieverConfig(
            method_name="dense_minilm_top1",
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            citation_year="MiniLM 2020; Sentence-Transformers checkpoint circa 2021",
            description="Lightweight SBERT/MiniLM dense retriever.",
        ),
        DenseRetrieverConfig(
            method_name="dense_mpnet_top1",
            model_name="sentence-transformers/all-mpnet-base-v2",
            citation_year="MPNet 2020; Sentence-Transformers checkpoint circa 2021",
            description="Stronger SBERT/MPNet dense retriever.",
        ),
        # E5 prefixes both the query and the passages.
        DenseRetrieverConfig(
            method_name="dense_e5_top1",
            model_name="intfloat/e5-base-v2",
            query_prefix="query: ",
            passage_prefix="passage: ",
            citation_year="E5 2022",
            description="E5 dense retriever with the model-card query/passage prefixes.",
        ),
        # BGE only prefixes the query with an instruction.
        DenseRetrieverConfig(
            method_name="dense_bge_top1",
            model_name="BAAI/bge-base-en-v1.5",
            query_prefix="Represent this sentence for searching relevant passages: ",
            citation_year="BGE v1.5 2023",
            description="BGE v1.5 dense retriever with the recommended query instruction.",
        ),
    )
}


def get_dense_retriever_config(method_name: str) -> DenseRetrieverConfig:
    """Return the registered configuration for a dense-retrieval baseline.

    Args:
        method_name: Key into ``DENSE_RETRIEVER_CONFIGS``, e.g.
            ``"dense_e5_top1"``.

    Returns:
        The ``DenseRetrieverConfig`` stored under ``method_name``.

    Raises:
        KeyError: If ``method_name`` is not registered. The message names the
            offending value and lists the registered method names.
    """
    try:
        return DENSE_RETRIEVER_CONFIGS[method_name]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message that
        # tells the caller which names are valid.
        known = ", ".join(sorted(DENSE_RETRIEVER_CONFIGS))
        raise KeyError(
            f"Unknown dense retrieval method {method_name!r}; "
            f"expected one of: {known}"
        ) from None


class DenseRetriever:
    """Dense retriever using sentence-transformers embeddings.

    Embeds a query together with a list of support items and returns the
    ``k`` items whose embeddings are most similar to the query.
    """

    def __init__(
        self,
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        device='cpu',
        text_mode='input_output',
        query_prefix='',
        passage_prefix='',
        normalize_embeddings=True,
    ):
        """Load the embedding model and store the retrieval settings.

        Args:
            model_name: checkpoint name passed to ``SentenceTransformer``
            device: device passed to ``SentenceTransformer``
            text_mode: which support text is embedded; one of 'input',
                'output' or 'input_output'
            query_prefix: text prepended to the query before encoding
            passage_prefix: text prepended to each support text before encoding
            normalize_embeddings: forwarded to ``model.encode``; also selects
                how similarities are computed in ``retrieve_top_k``
        """
        self.model_name = model_name
        self.text_mode = text_mode
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.normalize_embeddings = normalize_embeddings
        # Imported inside the constructor, so importing this module does not
        # by itself import sentence_transformers.
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name, device=device)

    def _support_text(self, item: dict) -> str:
        """Return the text of ``item`` to embed, as selected by ``text_mode``.

        Raises:
            ValueError: if ``self.text_mode`` is not 'input', 'output' or
                'input_output'.
        """
        if self.text_mode == 'input':
            return item['support_input']
        if self.text_mode == 'output':
            return item['support_output']
        if self.text_mode == 'input_output':
            return f"{item['support_input']}\n{item['support_output']}"
        raise ValueError(f"Unknown dense retrieval text_mode: {self.text_mode}")

    def _metadata_entry(self, position: int, support_index: int, score) -> dict:
        """Build the diagnostics dict for one retrieved item.

        ``position`` is the zero-based place in the returned list; the
        reported 'rank' is one-based.
        """
        return {
            'rank': position + 1,
            'support_index': support_index,
            'score': score,
            'model_name': self.model_name,
            'text_mode': self.text_mode,
        }

    def retrieve_top_k(self, query: str, support_items: list, k: int = 1, return_metadata: bool = False):
        """Retrieve top-k support items most relevant to query.

        If there are no more than ``k`` support items, all of them are
        returned in their original order without being embedded, and their
        metadata carries ``score=None``. Otherwise items are ordered by
        descending similarity to the query.

        Args:
            query: query input text
            support_items: list of dicts with 'support_input', 'support_output'
            k: number of items to retrieve; values <= 0 yield an empty result
            return_metadata: whether to also return retrieval diagnostics

        Returns:
            A new list with the top-k support items. If ``return_metadata`` is
            true, a ``(items, metadata)`` tuple where ``metadata`` holds one
            dict per item with 'rank', 'support_index', 'score', 'model_name'
            and 'text_mode'.

        Raises:
            ValueError: if ranking is required and ``text_mode`` is unknown.
        """
        if k <= 0:
            # Nothing was requested. Without this guard a negative k would
            # reach the ``[:k]`` slice below and drop items from the *end* of
            # the ranking instead of returning none.
            return ([], []) if return_metadata else []

        if len(support_items) <= k:
            # Everything fits: skip the encoder. Return a copy so that, like
            # the ranking path, the result never aliases the caller's list.
            selected = list(support_items)
            metadata = [
                self._metadata_entry(position, position, None)
                for position in range(len(selected))
            ]
            return (selected, metadata) if return_metadata else selected

        query_text = self.query_prefix + query
        texts = [self.passage_prefix + self._support_text(item) for item in support_items]
        # Encode query and passages in one call; row 0 is the query.
        embeddings = self.model.encode(
            [query_text] + texts,
            convert_to_tensor=True,
            normalize_embeddings=self.normalize_embeddings,
        )

        query_emb = embeddings[0]
        item_embs = embeddings[1:]

        if self.normalize_embeddings:
            # Embeddings were requested normalized, so a plain dot product is
            # used as the similarity.
            similarities = item_embs @ query_emb
        else:
            similarities = torch.nn.functional.cosine_similarity(
                query_emb.unsqueeze(0), item_embs, dim=1
            )

        ranked = similarities.argsort(descending=True)[:k]
        top_indices = ranked.tolist()
        # Gather the k scores with a single transfer instead of one per item.
        top_scores = similarities[ranked].detach().cpu().tolist()

        selected = [support_items[i] for i in top_indices]
        metadata = [
            self._metadata_entry(position, idx, float(score))
            for position, (idx, score) in enumerate(zip(top_indices, top_scores))
        ]
        return (selected, metadata) if return_metadata else selected