summaryrefslogtreecommitdiff
path: root/baselines/dense_retrieval.py
blob: db403a8b06ff59f8df0374d2e795b5e80e9a3d6a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""Dense Retrieval ICL baseline.

Uses sentence-transformers embeddings to retrieve the items in a user's
support set that are most similar to the query, then places the top-K
retrieved items as in-context examples.
"""

import torch
from sentence_transformers import SentenceTransformer


class DenseRetriever:
    """Dense retriever that ranks support items by embedding similarity.

    The query and each item's 'support_input' are embedded with a
    sentence-transformers model and compared by cosine similarity.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'):
        """Load the sentence-transformers encoder.

        Args:
            model_name: model name or path passed to ``SentenceTransformer``.
            device: device the encoder is loaded on (e.g. ``'cpu'``).
        """
        self.model = SentenceTransformer(model_name, device=device)

    def retrieve_top_k(self, query: str, support_items: list, k: int = 1):
        """Retrieve the top-k support items most relevant to ``query``.

        Args:
            query: query input text.
            support_items: list of dicts with 'support_input' and
                'support_output'; only 'support_input' is embedded.
            k: maximum number of items to retrieve. Values <= 0 yield an
                empty list.

        Returns:
            A new list of at most ``k`` support items, ordered from most to
            least similar to ``query``. When ``len(support_items) <= k``,
            every item is returned, still ordered by relevance.
        """
        if k <= 0 or not support_items:
            return []
        if len(support_items) == 1:
            # A single item is trivially ranked; skip the encoder call.
            return [support_items[0]]

        texts = [item['support_input'] for item in support_items]
        # Encode query and items in one batch; row 0 is the query.
        embeddings = self.model.encode([query] + texts, convert_to_tensor=True)
        query_emb, item_embs = embeddings[0], embeddings[1:]

        similarities = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), item_embs, dim=1
        )

        # Always rank, even when every item fits within k, so the returned
        # order (used as in-context example order) reflects relevance.
        n_keep = min(k, len(support_items))
        top_indices = torch.topk(similarities, k=n_keep).indices.tolist()
        return [support_items[i] for i in top_indices]