"""Dense Retrieval ICL baseline.

Uses sentence-transformers for dense retrieval over a user support set,
then places the top-K retrieved items as in-context examples.
"""
import torch
from sentence_transformers import SentenceTransformer


class DenseRetriever:
    """Dense retriever using sentence-transformers embeddings."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', device: str = 'cpu'):
        """Load the sentence-transformers encoder.

        Args:
            model_name: sentence-transformers model identifier.
            device: torch device string for the encoder (e.g. 'cpu', 'cuda').
        """
        self.model = SentenceTransformer(model_name, device=device)

    def retrieve_top_k(self, query: str, support_items: list, k: int = 1) -> list:
        """Retrieve the top-k support items most relevant to query.

        Args:
            query: query input text.
            support_items: list of dicts with 'support_input', 'support_output'.
            k: number of items to retrieve.

        Returns:
            A new list of up to k support items sorted by descending cosine
            similarity to the query. Unlike the earlier implementation, the
            result is relevance-sorted even when len(support_items) <= k,
            matching the documented contract, and is always a fresh list
            (never an alias of the caller's input).
        """
        if not support_items:
            return []
        texts = [item['support_input'] for item in support_items]
        # Encode query and all support inputs in one batched call; row 0 is
        # the query embedding, the rest are the support-item embeddings.
        embeddings = self.model.encode([query] + texts, convert_to_tensor=True)
        query_emb = embeddings[0]
        item_embs = embeddings[1:]
        similarities = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), item_embs, dim=1
        )
        # topk is clearer (and cheaper) than a full descending argsort when
        # only the best k indices are needed; clamp k so small support sets
        # still go through the same ranked path.
        k = min(k, len(support_items))
        top_indices = torch.topk(similarities, k).indices.tolist()
        return [support_items[i] for i in top_indices]