summaryrefslogtreecommitdiff
path: root/tests/test_retriever.py
blob: fa96dca22e8b2825c82c952f79d5859a7f879b44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Unit tests for both FAISS and Hopfield retrievers."""

import numpy as np
import torch
import torch.nn.functional as F

from hag.config import HopfieldConfig, MemoryBankConfig
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.retriever_faiss import FAISSRetriever
from hag.retriever_hopfield import HopfieldRetriever


class TestHopfieldRetriever:
    """Behavioral tests for the Hopfield-based retriever."""

    def setup_method(self) -> None:
        """Build a 50-passage memory bank and a top-3 Hopfield retriever."""
        torch.manual_seed(42)
        self.d = 64
        self.N = 50
        self.top_k = 3

        bank_cfg = MemoryBankConfig(embedding_dim=self.d, normalize=True)
        self.mb = MemoryBank(bank_cfg)
        vectors = torch.randn(self.N, self.d)
        self.passages = [f"passage {i}" for i in range(self.N)]
        self.mb.build_from_embeddings(vectors, self.passages)

        dynamics = HopfieldRetrieval(HopfieldConfig(beta=2.0, max_iter=5))
        self.retriever = HopfieldRetriever(dynamics, self.mb, top_k=self.top_k)

    def _random_query(self) -> torch.Tensor:
        """Return one unit-norm random query of shape (1, d)."""
        return F.normalize(torch.randn(1, self.d), dim=-1)

    def test_retrieves_correct_number_of_passages(self) -> None:
        """Exactly top_k passages, scores, and indices come back."""
        result = self.retriever.retrieve(self._random_query())
        assert len(result.passages) == self.top_k
        assert result.scores.shape == (self.top_k,)
        assert result.indices.shape == (self.top_k,)

    def test_scores_are_sorted_descending(self) -> None:
        """Scores arrive best-first."""
        result = self.retriever.retrieve(self._random_query())
        values = result.scores.tolist()
        assert values == sorted(values, reverse=True)

    def test_passages_match_indices(self) -> None:
        """Each returned passage text is the one stored at its index."""
        result = self.retriever.retrieve(self._random_query())
        expected = [self.passages[i] for i in result.indices.tolist()]
        assert result.passages == expected

    def test_analysis_mode_returns_hopfield_result(self) -> None:
        """return_analysis=True attaches a populated hopfield_result."""
        result = self.retriever.retrieve(self._random_query(), return_analysis=True)
        analysis = result.hopfield_result
        assert analysis is not None
        assert analysis.energy_curve is not None
        assert analysis.trajectory is not None


class TestFAISSRetriever:
    """Tests for the FAISS baseline retriever."""

    def setup_method(self) -> None:
        """Index 50 random 64-d float32 embeddings with a top-3 retriever."""
        np.random.seed(42)
        self.d = 64
        self.N = 50
        self.top_k = 3

        self.embeddings = np.random.randn(self.N, self.d).astype(np.float32)
        self.passages = [f"passage {i}" for i in range(self.N)]

        self.retriever = FAISSRetriever(top_k=self.top_k)
        # Index a copy so that if build_index mutates its input (e.g.
        # normalizes in place), self.embeddings stays pristine for the
        # self-retrieval test below.
        self.retriever.build_index(self.embeddings.copy(), self.passages)

    def test_retrieves_correct_number(self) -> None:
        """top_k=3 should return exactly 3 passages."""
        query = np.random.randn(self.d).astype(np.float32)
        result = self.retriever.retrieve(query)
        assert len(result.passages) == self.top_k

    def test_nearest_neighbor_is_self(self) -> None:
        """If query = passage_i's embedding, top-1 should be passage_i."""
        # Dedicated top_k=1 retriever; again index a copy to keep
        # self.embeddings untouched.
        retriever = FAISSRetriever(top_k=1)
        retriever.build_index(self.embeddings.copy(), self.passages)

        # NOTE(review): the query is the raw (unnormalized) stored vector;
        # this presumes retrieval is scale-invariant (e.g. build_index
        # normalizes internally or uses cosine similarity) — confirm against
        # FAISSRetriever's implementation.
        target_idx = 10
        query = self.embeddings[target_idx].copy()
        result = retriever.retrieve(query)
        assert result.indices[0].item() == target_idx

    def test_scores_sorted_descending(self) -> None:
        """Returned scores should be in descending order."""
        query = np.random.randn(self.d).astype(np.float32)
        result = self.retriever.retrieve(query)
        scores = result.scores.tolist()
        assert scores == sorted(scores, reverse=True)


class TestRetrieverComparison:
    """Cross-check the FAISS and Hopfield retrievers against each other."""

    def test_same_top1_for_obvious_query(self) -> None:
        """A query identical to one stored memory should be every retriever's top-1."""
        torch.manual_seed(42)
        np.random.seed(42)
        d = 64
        N = 50
        target_idx = 25

        # Random memories, row-normalized to unit length.
        raw = np.random.randn(N, d).astype(np.float32)
        embeddings_np = raw / np.linalg.norm(raw, axis=1, keepdims=True)

        # The query is an exact copy of the target memory.
        query_np = embeddings_np[target_idx].copy()
        query_torch = torch.from_numpy(query_np).unsqueeze(0)  # (1, d)

        passages = [f"passage {i}" for i in range(N)]

        # Baseline: FAISS nearest-neighbour search.
        faiss_ret = FAISSRetriever(top_k=1)
        faiss_ret.build_index(embeddings_np.copy(), passages)
        faiss_result = faiss_ret.retrieve(query_np.copy())

        # Hopfield dynamics over the same (pre-normalized) memories.
        bank = MemoryBank(MemoryBankConfig(embedding_dim=d, normalize=False))
        bank.build_from_embeddings(torch.from_numpy(embeddings_np), passages)
        hopfield = HopfieldRetrieval(HopfieldConfig(beta=10.0, max_iter=5))
        hop_ret = HopfieldRetriever(hopfield, bank, top_k=1)
        hop_result = hop_ret.retrieve(query_torch)

        # Both retrievers must rank the copied memory first.
        assert faiss_result.indices[0].item() == target_idx
        assert hop_result.indices[0].item() == target_idx