diff options
Diffstat (limited to 'baselines/dense_retrieval.py')
| -rw-r--r-- | baselines/dense_retrieval.py | 42 |
1 file changed, 42 insertions, 0 deletions
"""Dense Retrieval ICL baseline.

Uses sentence-transformers for dense retrieval over user support set,
then places top-K retrieved items as in-context examples.
"""

import torch


class DenseRetriever:
    """Dense retriever using sentence-transformers embeddings."""

    def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'):
        # Imported lazily so that merely importing this module does not
        # require the heavy sentence-transformers dependency; it is only
        # needed once a retriever is actually constructed.
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name, device=device)

    def retrieve_top_k(self, query: str, support_items: list, k: int = 1) -> list:
        """Retrieve the top-k support items most relevant to ``query``.

        Args:
            query: query input text.
            support_items: list of dicts with 'support_input', 'support_output'.
            k: number of items to retrieve; k <= 0 returns an empty list.

        Returns:
            List of top-k support items sorted by descending cosine
            similarity to the query. If ``len(support_items) <= k`` the
            items are returned unchanged (original order, no scoring).
        """
        # Guard: a non-positive k must yield nothing. (A plain slice
        # ``argsort(...)[:k]`` with negative k would wrongly return
        # len(items) + k items.)
        if k <= 0:
            return []
        if len(support_items) <= k:
            return support_items

        # Encode the query and all support inputs in a single batch call.
        texts = [item['support_input'] for item in support_items]
        embeddings = self.model.encode([query] + texts, convert_to_tensor=True)

        query_emb = embeddings[0]
        item_embs = embeddings[1:]

        # Cosine similarity of the query against every support embedding.
        similarities = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), item_embs, dim=1
        )

        # topk (O(n log k)) replaces a full descending argsort (O(n log n));
        # indices come back ordered by decreasing similarity, same as before.
        top_indices = similarities.topk(k).indices.tolist()
        return [support_items[i] for i in top_indices]
