summaryrefslogtreecommitdiff
path: root/hag/retriever_faiss.py
blob: cd54a85d724e1174d5738c83de5bf6744c21228e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Baseline FAISS top-k retriever for vanilla RAG."""

import logging
from typing import List, Optional

import faiss
import numpy as np
import torch

from hag.datatypes import RetrievalResult

logger = logging.getLogger(__name__)


class FAISSRetriever:
    """Standard top-k retrieval using FAISS inner product search.

    Passage embeddings and queries are L2-normalized before indexing and
    searching, so the inner-product scores returned are cosine similarities.

    This is the baseline to compare against Hopfield retrieval.
    """

    def __init__(self, top_k: int = 5) -> None:
        """Create an empty retriever.

        Args:
            top_k: number of passages to return per query (expected to be
                positive). If it exceeds the number of indexed passages, only
                the indexed passages are returned.
        """
        self.index: Optional[faiss.IndexFlatIP] = None
        self.passages: List[str] = []
        self.top_k = top_k

    def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
        """Build FAISS IndexFlatIP from embeddings.

        The embeddings are copied into a C-contiguous float32 array and
        L2-normalized there; the caller's array is left unmodified.

        Args:
            embeddings: (N, d) numpy array of passage embeddings
            passages: list of N passage strings
        """
        assert embeddings.shape[0] == len(passages), (
            "embeddings and passages must have the same length"
        )
        # faiss.normalize_L2 works in place and requires C-contiguous float32
        # data; normalize a private copy so the caller's array is not mutated
        # and other dtypes (e.g. float64) are accepted.
        vectors = np.array(embeddings, dtype=np.float32, order="C", copy=True)
        d = vectors.shape[1]
        self.index = faiss.IndexFlatIP(d)
        # Normalize for cosine similarity via inner product
        faiss.normalize_L2(vectors)
        self.index.add(vectors)
        self.passages = list(passages)
        logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)

    def retrieve(self, query: np.ndarray) -> RetrievalResult:
        """Retrieve top-k passages for a query.

        Args:
            query: (d,) or (batch, d) numpy array. It is not modified.

        Returns:
            RetrievalResult with passages, scores, and indices. For a single
            query (a (d,) array or a batch of one), scores/indices have shape
            (k,); otherwise (batch, k). ``passages`` is always a flat list of
            the retrieved strings in row-major order. ``k`` is
            ``min(top_k, number of indexed passages)``.
        """
        assert self.index is not None, "Index not built. Call build_index first."
        assert self.index.ntotal > 0, "Index is empty; nothing to retrieve."

        if query.ndim == 1:
            query = query.reshape(1, -1)  # (1, d)

        # Normalize a float32, C-contiguous copy of the query for cosine
        # similarity (normalize_L2 is in place and float32-only).
        query_copy = np.array(query, dtype=np.float32, order="C", copy=True)
        faiss.normalize_L2(query_copy)

        # FAISS pads results with index -1 when k exceeds the number of stored
        # vectors, and -1 would silently map to the *last* passage below.
        # Clamp k so every returned index refers to a real passage.
        k = min(self.top_k, self.index.ntotal)
        scores, indices = self.index.search(query_copy, k)  # (batch, k)

        # Flatten for single query case
        if scores.shape[0] == 1:
            scores = scores[0]  # (k,)
            indices = indices[0]  # (k,)

        passages = [self.passages[i] for i in indices.flatten().tolist()]

        return RetrievalResult(
            passages=passages,
            scores=torch.from_numpy(scores).float(),
            indices=torch.from_numpy(indices).long(),
        )