1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
"""Baseline FAISS top-k retriever for vanilla RAG."""
import logging
from typing import List, Optional
import faiss
import numpy as np
import torch
from hag.datatypes import RetrievalResult
logger = logging.getLogger(__name__)
class FAISSRetriever:
    """Standard top-k retrieval using FAISS inner product search.

    This is the baseline to compare against Hopfield retrieval. Both the
    indexed embeddings and the queries are L2-normalized (as copies) before
    use, so the returned inner-product scores are cosine similarities.
    """

    def __init__(self, top_k: int = 5) -> None:
        self.index: Optional[faiss.IndexFlatIP] = None
        self.passages: List[str] = []
        self.top_k = top_k

    @staticmethod
    def _normalized_copy(x: np.ndarray) -> np.ndarray:
        """Return an L2-normalized, float32, C-contiguous copy of ``x``.

        ``faiss.normalize_L2`` normalizes in place and expects a float32,
        C-contiguous array. Always copying first leaves the caller's array
        untouched and lets e.g. float64 inputs through.
        """
        out = np.array(x, dtype=np.float32, order="C", copy=True)
        faiss.normalize_L2(out)
        return out

    def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
        """Build FAISS IndexFlatIP from embeddings.

        Args:
            embeddings: (N, d) numpy array of passage embeddings. The array
                is not modified; a normalized float32 copy is indexed.
            passages: list of N passage strings, aligned row-for-row with
                ``embeddings``.
        """
        assert embeddings.ndim == 2, (
            f"embeddings must be 2-D (N, d), got shape {embeddings.shape}"
        )
        assert embeddings.shape[0] == len(passages), (
            f"got {embeddings.shape[0]} embeddings but {len(passages)} passages"
        )
        d = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(d)
        # Normalize (a copy) so inner product equals cosine similarity.
        self.index.add(self._normalized_copy(embeddings))
        self.passages = list(passages)
        logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)

    def retrieve(self, query: np.ndarray) -> RetrievalResult:
        """Retrieve top-k passages for a query.

        Args:
            query: (d,) or (batch, d) numpy array. The array is not modified.

        Returns:
            RetrievalResult with passages, scores, and indices. k is
            ``min(self.top_k, number of indexed passages)``. When the batch
            has a single row (including 1-D input), scores and indices have
            shape (k,). Otherwise they have shape (batch, k), and
            ``passages`` is the row-major flattened list of batch * k
            passages.
        """
        assert self.index is not None, "Index not built. Call build_index first."
        if query.ndim == 1:
            query = query.reshape(1, -1)  # (1, d)
        # Normalize a float32 copy so inner product equals cosine similarity.
        query_norm = self._normalized_copy(query)

        # FAISS pads results with label -1 when k exceeds the index size.
        # self.passages[-1] would then silently return the last passage, so
        # never request more neighbours than are indexed.
        k = min(self.top_k, self.index.ntotal)
        if k > 0:
            scores, indices = self.index.search(query_norm, k)  # (batch, k)
        else:
            batch = query_norm.shape[0]
            scores = np.empty((batch, 0), dtype=np.float32)
            indices = np.empty((batch, 0), dtype=np.int64)

        # Flatten for single query case
        if scores.shape[0] == 1:
            scores = scores[0]  # (k,)
            indices = indices[0]  # (k,)
        passages = [self.passages[i] for i in indices.flatten().tolist()]
        return RetrievalResult(
            passages=passages,
            scores=torch.from_numpy(scores).float(),
            indices=torch.from_numpy(indices).long(),
        )
|