diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/retriever_faiss.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/retriever_faiss.py')
| -rw-r--r-- | hag/retriever_faiss.py | 73 |
1 file changed, 73 insertions, 0 deletions
"""Baseline FAISS top-k retriever for vanilla RAG."""

import logging
from typing import List, Optional

import faiss
import numpy as np
import torch

from hag.datatypes import RetrievalResult

logger = logging.getLogger(__name__)


class FAISSRetriever:
    """Standard top-k retrieval using FAISS inner product search.

    Embeddings are L2-normalized before indexing/search, so the inner
    product scores returned by FAISS are cosine similarities.

    This is the baseline to compare against Hopfield retrieval.
    """

    def __init__(self, top_k: int = 5) -> None:
        # Index is created lazily in build_index; None until then.
        self.index: Optional[faiss.IndexFlatIP] = None
        self.passages: List[str] = []
        self.top_k = top_k

    def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
        """Build a FAISS IndexFlatIP from passage embeddings.

        Args:
            embeddings: (N, d) numpy array of passage embeddings. The
                caller's array is NOT modified; a float32 copy is
                normalized and indexed.
            passages: list of N passage strings.

        Raises:
            ValueError: if the number of embeddings does not match the
                number of passages.
        """
        if embeddings.shape[0] != len(passages):
            raise ValueError(
                f"embeddings rows ({embeddings.shape[0]}) != "
                f"passages ({len(passages)})"
            )
        d = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(d)
        # FAISS requires contiguous float32; copying also prevents the
        # in-place normalize_L2 below from mutating the caller's array
        # (retrieve() already copies the query for the same reason).
        emb = np.ascontiguousarray(embeddings, dtype=np.float32)
        # Normalize for cosine similarity via inner product.
        faiss.normalize_L2(emb)
        self.index.add(emb)
        self.passages = list(passages)
        logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)

    def retrieve(self, query: np.ndarray) -> RetrievalResult:
        """Retrieve top-k passages for a query.

        Args:
            query: (d,) or (batch, d) numpy array. Not modified; a
                float32 copy is normalized before searching.

        Returns:
            RetrievalResult with passages, scores, and indices. For a
            single query the score/index tensors are flattened to (k,).

        Raises:
            RuntimeError: if build_index has not been called yet.
        """
        if self.index is None:
            raise RuntimeError("Index not built. Call build_index first.")

        if query.ndim == 1:
            query = query.reshape(1, -1)  # (1, d)

        # Normalize a float32 copy for cosine similarity; normalize_L2
        # works in place, so never pass the caller's array directly.
        query_copy = np.ascontiguousarray(query, dtype=np.float32)
        faiss.normalize_L2(query_copy)

        scores, indices = self.index.search(query_copy, self.top_k)  # (batch, k)

        # Flatten for the single-query case.
        if scores.shape[0] == 1:
            scores = scores[0]  # (k,)
            indices = indices[0]  # (k,)

        passages = [self.passages[i] for i in indices.flatten().tolist()]

        return RetrievalResult(
            passages=passages,
            scores=torch.from_numpy(scores).float(),
            indices=torch.from_numpy(indices).long(),
        )
