"""Baseline FAISS top-k retriever for vanilla RAG."""

import logging
from typing import List, Optional

import faiss
import numpy as np
import torch

from hag.datatypes import RetrievalResult

logger = logging.getLogger(__name__)


class FAISSRetriever:
    """Standard top-k retrieval using FAISS inner product search.

    This is the baseline to compare against Hopfield retrieval. Passage and
    query vectors are L2-normalized before being handed to FAISS, so the
    inner-product scores returned by the index are cosine similarities.
    """

    def __init__(self, top_k: int = 5) -> None:
        """Create an empty retriever.

        Args:
            top_k: maximum number of passages returned per query.
        """
        # The index is created lazily by build_index(); retrieve() refuses to
        # run while it is still None.
        self.index: Optional[faiss.IndexFlatIP] = None
        self.passages: List[str] = []
        self.top_k = top_k

    def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
        """Build FAISS IndexFlatIP from embeddings.

        The caller's ``embeddings`` array is left untouched: normalization is
        applied to a private float32 copy.

        Args:
            embeddings: (N, d) numpy array of passage embeddings
            passages: list of N passage strings

        Raises:
            AssertionError: if the number of embeddings and passages differ.
        """
        # Raised explicitly (instead of `assert`) so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if embeddings.shape[0] != len(passages):
            raise AssertionError(
                "embeddings has %d rows but %d passages were given"
                % (embeddings.shape[0], len(passages))
            )

        d = embeddings.shape[1]

        # faiss.normalize_L2 works in place and FAISS expects C-contiguous
        # float32 data, so normalize a copy rather than the caller's array.
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        # Normalize for cosine similarity via inner product
        faiss.normalize_L2(vectors)

        index = faiss.IndexFlatIP(d)
        index.add(vectors)

        self.index = index
        self.passages = list(passages)
        logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)

    def retrieve(self, query: np.ndarray) -> RetrievalResult:
        """Retrieve top-k passages for a query.

        At most ``min(top_k, number of indexed passages)`` results are
        returned per query, so every returned index refers to a real passage.

        Args:
            query: (d,) or (batch, d) numpy array

        Returns:
            RetrievalResult with passages, scores, and indices. For a single
            query, scores and indices have shape (k,). For a batch of more
            than one query they have shape (batch, k) and ``passages`` is the
            row-major flattened list of batch * k passages.

        Raises:
            AssertionError: if build_index has not been called yet.
        """
        # Explicit raise so the guard is not stripped under `python -O`.
        if self.index is None:
            raise AssertionError("Index not built. Call build_index first.")

        # Private contiguous float32 copy: FAISS requires that layout and
        # normalize_L2 mutates its argument in place.
        queries = np.array(query, dtype=np.float32, order="C")
        if queries.ndim == 1:
            queries = queries.reshape(1, -1)  # (1, d)

        # Normalize query for cosine similarity
        faiss.normalize_L2(queries)

        # When k exceeds the number of stored vectors FAISS pads the result
        # with label -1, which would index the *last* passage in Python.
        # Clamp k so only real neighbours are ever returned.
        k = min(self.top_k, self.index.ntotal)
        if k > 0:
            scores, indices = self.index.search(queries, k)  # (batch, k)
        else:
            # Empty index (or non-positive top_k): nothing to retrieve.
            scores = np.empty((queries.shape[0], 0), dtype=np.float32)
            indices = np.empty((queries.shape[0], 0), dtype=np.int64)

        # Flatten for single query case
        if scores.shape[0] == 1:
            scores = scores[0]  # (k,)
            indices = indices[0]  # (k,)

        passages = [self.passages[i] for i in indices.ravel().tolist()]

        return RetrievalResult(
            passages=passages,
            scores=torch.from_numpy(scores).float(),
            indices=torch.from_numpy(indices).long(),
        )