"""Baseline FAISS top-k retriever for vanilla RAG.""" import logging from typing import List, Optional import faiss import numpy as np import torch from hag.datatypes import RetrievalResult logger = logging.getLogger(__name__) class FAISSRetriever: """Standard top-k retrieval using FAISS inner product search. This is the baseline to compare against Hopfield retrieval. """ def __init__(self, top_k: int = 5) -> None: self.index: Optional[faiss.IndexFlatIP] = None self.passages: List[str] = [] self.top_k = top_k def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None: """Build FAISS IndexFlatIP from embeddings. Args: embeddings: (N, d) numpy array of passage embeddings passages: list of N passage strings """ assert embeddings.shape[0] == len(passages) d = embeddings.shape[1] self.index = faiss.IndexFlatIP(d) # Normalize for cosine similarity via inner product faiss.normalize_L2(embeddings) self.index.add(embeddings) self.passages = list(passages) logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d) def retrieve(self, query: np.ndarray) -> RetrievalResult: """Retrieve top-k passages for a query. Args: query: (d,) or (batch, d) numpy array Returns: RetrievalResult with passages, scores, and indices. """ assert self.index is not None, "Index not built. Call build_index first." if query.ndim == 1: query = query.reshape(1, -1) # (1, d) # Normalize query for cosine similarity query_copy = query.copy() faiss.normalize_L2(query_copy) scores, indices = self.index.search(query_copy, self.top_k) # (batch, k) # Flatten for single query case if scores.shape[0] == 1: scores = scores[0] # (k,) indices = indices[0] # (k,) passages = [self.passages[i] for i in indices.flatten().tolist()] return RetrievalResult( passages=passages, scores=torch.from_numpy(scores).float(), indices=torch.from_numpy(indices).long(), )