author    YurenHao0426 <blackhao0426@gmail.com>    2026-02-15 18:19:50 +0000
committer YurenHao0426 <blackhao0426@gmail.com>    2026-02-15 18:19:50 +0000
commit    c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree      43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/retriever_faiss.py
Initial implementation of HAG (Hopfield-Augmented Generation) (HEAD -> master)
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/retriever_faiss.py')
-rw-r--r--    hag/retriever_faiss.py    73
1 file changed, 73 insertions, 0 deletions
diff --git a/hag/retriever_faiss.py b/hag/retriever_faiss.py
new file mode 100644
index 0000000..cd54a85
--- /dev/null
+++ b/hag/retriever_faiss.py
@@ -0,0 +1,73 @@
+"""Baseline FAISS top-k retriever for vanilla RAG."""
+
+import logging
+from typing import List, Optional
+
+import faiss
+import numpy as np
+import torch
+
+from hag.datatypes import RetrievalResult
+
+logger = logging.getLogger(__name__)
+
+
+class FAISSRetriever:
+ """Standard top-k retrieval using FAISS inner product search.
+
+ This is the baseline to compare against Hopfield retrieval.
+ """
+
+ def __init__(self, top_k: int = 5) -> None:
+ self.index: Optional[faiss.IndexFlatIP] = None
+ self.passages: List[str] = []
+ self.top_k = top_k
+
+    def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
+        """Build FAISS IndexFlatIP from embeddings.
+
+        Args:
+            embeddings: (N, d) numpy array of passage embeddings
+            passages: list of N passage strings
+        """
+        assert embeddings.shape[0] == len(passages)
+        d = embeddings.shape[1]
+        self.index = faiss.IndexFlatIP(d)
+        # Normalize a float32 copy for cosine similarity via inner product;
+        # faiss.normalize_L2 works in place, so copy to avoid mutating the caller's array.
+        embeddings = np.array(embeddings, dtype=np.float32)
+        faiss.normalize_L2(embeddings)
+        self.index.add(embeddings)
+        self.passages = list(passages)
+        logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)
+
+    def retrieve(self, query: np.ndarray) -> RetrievalResult:
+        """Retrieve top-k passages for a query.
+
+        Args:
+            query: (d,) or (batch, d) numpy array
+
+        Returns:
+            RetrievalResult with passages, scores, and indices.
+        """
+        assert self.index is not None, "Index not built. Call build_index first."
+
+        if query.ndim == 1:
+            query = query.reshape(1, -1)  # (1, d)
+
+        # Normalize a float32 copy of the query for cosine similarity
+        query_copy = np.array(query, dtype=np.float32)
+        faiss.normalize_L2(query_copy)
+
+        scores, indices = self.index.search(query_copy, self.top_k)  # (batch, k)
+
+        # Flatten for single query case
+        if scores.shape[0] == 1:
+            scores = scores[0]  # (k,)
+            indices = indices[0]  # (k,)
+
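+        # Map indices back to passage strings (a batched query yields one flat, row-major list)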
+        passages = [self.passages[i] for i in indices.flatten().tolist()]
+
+        return RetrievalResult(
+            passages=passages,
+            scores=torch.from_numpy(scores).float(),
+            indices=torch.from_numpy(indices).long(),
+        )
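
For reference, a minimal usage sketch of the retriever above (not part of the commit). It uses random synthetic vectors in place of real passage embeddings and assumes only that RetrievalResult exposes the passages, scores, and indices fields used in retrieve():

import numpy as np

from hag.retriever_faiss import FAISSRetriever

rng = np.random.default_rng(0)
passages = ["alpha passage", "beta passage", "gamma passage", "delta passage"]
# Synthetic stand-ins for encoder outputs, shape (N, d)
embeddings = rng.standard_normal((len(passages), 64)).astype(np.float32)

retriever = FAISSRetriever(top_k=2)
retriever.build_index(embeddings, passages)

# Query with one of the stored vectors; after L2 normalization it matches
# itself with cosine similarity ~1.0 and ranks first.
result = retriever.retrieve(embeddings[1])
print(result.passages)  # ['beta passage', ...]
print(result.scores)    # tensor of shape (2,), leading score ~1.0
print(result.indices)   # tensor of shape (2,), leading index 1

Because the index stores L2-normalized vectors and each query is normalized per call, the inner-product scores returned by IndexFlatIP are cosine similarities in [-1, 1].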