Diffstat (limited to 'tests/test_retriever.py')
| -rw-r--r-- | tests/test_retriever.py | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/tests/test_retriever.py b/tests/test_retriever.py
new file mode 100644
index 0000000..fa96dca
--- /dev/null
+++ b/tests/test_retriever.py
@@ -0,0 +1,149 @@
+"""Unit tests for both FAISS and Hopfield retrievers."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from hag.config import HopfieldConfig, MemoryBankConfig
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+
+class TestHopfieldRetriever:
+    """Tests for the Hopfield-based retriever."""
+
+    def setup_method(self) -> None:
+        """Set up a small memory bank and Hopfield retriever."""
+        torch.manual_seed(42)
+        self.d = 64
+        self.N = 50
+        self.top_k = 3
+
+        config = MemoryBankConfig(embedding_dim=self.d, normalize=True)
+        self.mb = MemoryBank(config)
+        embeddings = torch.randn(self.N, self.d)
+        self.passages = [f"passage {i}" for i in range(self.N)]
+        self.mb.build_from_embeddings(embeddings, self.passages)
+
+        hopfield_config = HopfieldConfig(beta=2.0, max_iter=5)
+        hopfield = HopfieldRetrieval(hopfield_config)
+        self.retriever = HopfieldRetriever(hopfield, self.mb, top_k=self.top_k)
+
+    def test_retrieves_correct_number_of_passages(self) -> None:
+        """top_k=3 should return exactly 3 passages."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query)
+        assert len(result.passages) == self.top_k
+        assert result.scores.shape == (self.top_k,)
+        assert result.indices.shape == (self.top_k,)
+
+    def test_scores_are_sorted_descending(self) -> None:
+        """Returned scores should be in descending order."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query)
+        scores = result.scores.tolist()
+        assert scores == sorted(scores, reverse=True)
+
+    def test_passages_match_indices(self) -> None:
+        """Returned passage texts should correspond to returned indices."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query)
+        expected_passages = [self.passages[i] for i in result.indices.tolist()]
+        assert result.passages == expected_passages
+
+    def test_analysis_mode_returns_hopfield_result(self) -> None:
+        """When return_analysis=True, hopfield_result should be populated."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query, return_analysis=True)
+        assert result.hopfield_result is not None
+        assert result.hopfield_result.energy_curve is not None
+        assert result.hopfield_result.trajectory is not None
+
+
+class TestFAISSRetriever:
+    """Tests for the FAISS baseline retriever."""
+
+    def setup_method(self) -> None:
+        """Set up a small FAISS index."""
+        np.random.seed(42)
+        self.d = 64
+        self.N = 50
+        self.top_k = 3
+
+        self.embeddings = np.random.randn(self.N, self.d).astype(np.float32)
+        self.passages = [f"passage {i}" for i in range(self.N)]
+
+        self.retriever = FAISSRetriever(top_k=self.top_k)
+        self.retriever.build_index(self.embeddings.copy(), self.passages)
+
+    def test_retrieves_correct_number(self) -> None:
+        """top_k=3 should return exactly 3 passages."""
+        query = np.random.randn(self.d).astype(np.float32)
+        result = self.retriever.retrieve(query)
+        assert len(result.passages) == self.top_k
+
+    def test_nearest_neighbor_is_self(self) -> None:
+        """If query = passage_i's embedding, top-1 should be passage_i."""
+        # Use the original embedding (before normalization in build_index)
+        # Rebuild with fresh copy so we know the normalization state
+        embeddings = self.embeddings.copy()
+        retriever = FAISSRetriever(top_k=1)
+        retriever.build_index(embeddings.copy(), self.passages)
+
+        # Use the normalized version as query (FAISS normalizes internally)
+        target_idx = 10
+        query = self.embeddings[target_idx].copy()
+        result = retriever.retrieve(query)
+        assert result.indices[0].item() == target_idx
+
+    def test_scores_sorted_descending(self) -> None:
+        """Returned scores should be in descending order."""
+        query = np.random.randn(self.d).astype(np.float32)
+        result = self.retriever.retrieve(query)
+        scores = result.scores.tolist()
+        assert scores == sorted(scores, reverse=True)
+
+
+class TestRetrieverComparison:
+    """Compare FAISS and Hopfield retrievers."""
+
+    def test_same_top1_for_obvious_query(self) -> None:
+        """When query is very close to one memory, both should agree on top-1."""
+        torch.manual_seed(42)
+        np.random.seed(42)
+        d = 64
+        N = 50
+        target_idx = 25
+
+        # Create embeddings
+        embeddings_np = np.random.randn(N, d).astype(np.float32)
+        # Normalize
+        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
+        embeddings_np = embeddings_np / norms
+
+        # Query is exactly the target embedding
+        query_np = embeddings_np[target_idx].copy()
+        query_torch = torch.from_numpy(query_np).unsqueeze(0)  # (1, d)
+
+        passages = [f"passage {i}" for i in range(N)]
+
+        # FAISS retriever
+        faiss_ret = FAISSRetriever(top_k=1)
+        faiss_ret.build_index(embeddings_np.copy(), passages)
+        faiss_result = faiss_ret.retrieve(query_np.copy())
+
+        # Hopfield retriever
+        mb_config = MemoryBankConfig(embedding_dim=d, normalize=False)
+        mb = MemoryBank(mb_config)
+        mb.build_from_embeddings(
+            torch.from_numpy(embeddings_np), passages
+        )
+        hop_config = HopfieldConfig(beta=10.0, max_iter=5)
+        hopfield = HopfieldRetrieval(hop_config)
+        hop_ret = HopfieldRetriever(hopfield, mb, top_k=1)
+        hop_result = hop_ret.retrieve(query_torch)
+
+        assert faiss_result.indices[0].item() == target_idx
+        assert hop_result.indices[0].item() == target_idx
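
The tests above exercise a small public surface: MemoryBank.build_from_embeddings, FAISSRetriever.build_index/retrieve, and HopfieldRetriever.retrieve with an optional return_analysis flag. The following is a minimal usage sketch assembled only from the calls seen in this file; the argument values are illustrative (copied from the test fixtures), not a documented recommendation.

    import torch
    import torch.nn.functional as F

    from hag.config import HopfieldConfig, MemoryBankConfig
    from hag.hopfield import HopfieldRetrieval
    from hag.memory_bank import MemoryBank
    from hag.retriever_hopfield import HopfieldRetriever

    # Build a 50-passage memory bank of 64-d normalized embeddings (fixture-sized toy data).
    mb = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=True))
    mb.build_from_embeddings(torch.randn(50, 64), [f"passage {i}" for i in range(50)])

    # Wrap the bank in the Hopfield retriever; retrieve() returns passages, scores, indices,
    # and (with return_analysis=True) a hopfield_result with energy_curve and trajectory.
    retriever = HopfieldRetriever(HopfieldRetrieval(HopfieldConfig(beta=2.0, max_iter=5)), mb, top_k=3)
    query = F.normalize(torch.randn(1, 64), dim=-1)
    result = retriever.retrieve(query, return_analysis=True)
    print(result.passages, result.scores, result.hopfield_result.energy_curve)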
