path: root/tests/test_retriever.py
author     YurenHao0426 <blackhao0426@gmail.com>   2026-02-15 18:19:50 +0000
committer  YurenHao0426 <blackhao0426@gmail.com>   2026-02-15 18:19:50 +0000
commit     c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree       43edac8013fec4e65a0b9cddec5314489b4aafc2 /tests/test_retriever.py
Initial implementation of HAG (Hopfield-Augmented Generation) (HEAD, master)
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
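The "energy-based convergence guarantees" mentioned above most likely refer to the modern (dense associative memory) Hopfield update, in which each retrieval iteration provably lowers an energy function. The sketch below illustrates that update under that assumption only; hopfield_update is a hypothetical helper for exposition, not the hag.hopfield.HopfieldRetrieval API, although beta and max_iter mirror the HopfieldConfig fields exercised by the tests in this commit.

import torch
import torch.nn.functional as F

def hopfield_update(memory: torch.Tensor, query: torch.Tensor,
                    beta: float = 2.0, max_iter: int = 5):
    """Iterate xi <- M^T softmax(beta * M @ xi), recording the energy curve."""
    xi = query.clone()                        # state vector, shape (d,)
    energies = []
    for _ in range(max_iter):
        logits = beta * (memory @ xi)         # similarity to each stored pattern, (N,)
        # Modern Hopfield energy: -lse(beta, M @ xi) + 0.5 * ||xi||^2 (up to constants)
        energies.append((-torch.logsumexp(logits, dim=0) / beta
                         + 0.5 * xi.dot(xi)).item())
        xi = memory.t() @ F.softmax(logits, dim=0)  # step toward the nearest pattern
    return xi, energies

torch.manual_seed(0)
M = F.normalize(torch.randn(50, 64), dim=-1)          # 50 unit-norm memories
xi_star, curve = hopfield_update(M, M[25], beta=10.0)
assert torch.argmax(M @ xi_star).item() == 25         # retrieves the queried pattern
assert all(b <= a + 1e-5 for a, b in zip(curve, curve[1:]))  # energy never increases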
Diffstat (limited to 'tests/test_retriever.py')
-rw-r--r--  tests/test_retriever.py  149
1 file changed, 149 insertions, 0 deletions
diff --git a/tests/test_retriever.py b/tests/test_retriever.py
new file mode 100644
index 0000000..fa96dca
--- /dev/null
+++ b/tests/test_retriever.py
@@ -0,0 +1,149 @@
+"""Unit tests for both FAISS and Hopfield retrievers."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from hag.config import HopfieldConfig, MemoryBankConfig
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+
+class TestHopfieldRetriever:
+    """Tests for the Hopfield-based retriever."""
+
+    def setup_method(self) -> None:
+        """Set up a small memory bank and Hopfield retriever."""
+        torch.manual_seed(42)
+        self.d = 64
+        self.N = 50
+        self.top_k = 3
+
+        config = MemoryBankConfig(embedding_dim=self.d, normalize=True)
+        self.mb = MemoryBank(config)
+        embeddings = torch.randn(self.N, self.d)
+        self.passages = [f"passage {i}" for i in range(self.N)]
+        self.mb.build_from_embeddings(embeddings, self.passages)
+
+        hopfield_config = HopfieldConfig(beta=2.0, max_iter=5)
+        hopfield = HopfieldRetrieval(hopfield_config)
+        self.retriever = HopfieldRetriever(hopfield, self.mb, top_k=self.top_k)
+
+    def test_retrieves_correct_number_of_passages(self) -> None:
+        """top_k=3 should return exactly 3 passages."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query)
+        assert len(result.passages) == self.top_k
+        assert result.scores.shape == (self.top_k,)
+        assert result.indices.shape == (self.top_k,)
+
+    def test_scores_are_sorted_descending(self) -> None:
+        """Returned scores should be in descending order."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query)
+        scores = result.scores.tolist()
+        assert scores == sorted(scores, reverse=True)
+
+    def test_passages_match_indices(self) -> None:
+        """Returned passage texts should correspond to returned indices."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query)
+        expected_passages = [self.passages[i] for i in result.indices.tolist()]
+        assert result.passages == expected_passages
+
+    def test_analysis_mode_returns_hopfield_result(self) -> None:
+        """When return_analysis=True, hopfield_result should be populated."""
+        query = F.normalize(torch.randn(1, self.d), dim=-1)
+        result = self.retriever.retrieve(query, return_analysis=True)
+        assert result.hopfield_result is not None
+        assert result.hopfield_result.energy_curve is not None
+        assert result.hopfield_result.trajectory is not None
+
+
+class TestFAISSRetriever:
+    """Tests for the FAISS baseline retriever."""
+
+    def setup_method(self) -> None:
+        """Set up a small FAISS index."""
+        np.random.seed(42)
+        self.d = 64
+        self.N = 50
+        self.top_k = 3
+
+        self.embeddings = np.random.randn(self.N, self.d).astype(np.float32)
+        self.passages = [f"passage {i}" for i in range(self.N)]
+
+        self.retriever = FAISSRetriever(top_k=self.top_k)
+        self.retriever.build_index(self.embeddings.copy(), self.passages)
+
+    def test_retrieves_correct_number(self) -> None:
+        """top_k=3 should return exactly 3 passages."""
+        query = np.random.randn(self.d).astype(np.float32)
+        result = self.retriever.retrieve(query)
+        assert len(result.passages) == self.top_k
+
+    def test_nearest_neighbor_is_self(self) -> None:
+        """If query = passage_i's embedding, top-1 should be passage_i."""
+        # Rebuild a top-1 retriever from a fresh copy of the embeddings so
+        # that self.embeddings is untouched by any normalization inside
+        # build_index.
+        retriever = FAISSRetriever(top_k=1)
+        retriever.build_index(self.embeddings.copy(), self.passages)
+
+        # Query with the raw (unnormalized) embedding of the target passage.
+        # If the index normalizes internally, that only rescales the query,
+        # so the top-1 match is unchanged.
+        target_idx = 10
+        query = self.embeddings[target_idx].copy()
+        result = retriever.retrieve(query)
+        assert result.indices[0].item() == target_idx
+
+    def test_scores_sorted_descending(self) -> None:
+        """Returned scores should be in descending order."""
+        query = np.random.randn(self.d).astype(np.float32)
+        result = self.retriever.retrieve(query)
+        scores = result.scores.tolist()
+        assert scores == sorted(scores, reverse=True)
+
+
+class TestRetrieverComparison:
+    """Compare FAISS and Hopfield retrievers."""
+
+    def test_same_top1_for_obvious_query(self) -> None:
+        """When query is very close to one memory, both should agree on top-1."""
+        torch.manual_seed(42)
+        np.random.seed(42)
+        d = 64
+        N = 50
+        target_idx = 25
+
+        # Create unit-norm embeddings
+        embeddings_np = np.random.randn(N, d).astype(np.float32)
+        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
+        embeddings_np = embeddings_np / norms
+
+        # Query is exactly the target embedding
+        query_np = embeddings_np[target_idx].copy()
+        query_torch = torch.from_numpy(query_np).unsqueeze(0)  # (1, d)
+
+        passages = [f"passage {i}" for i in range(N)]
+
+        # FAISS retriever
+        faiss_ret = FAISSRetriever(top_k=1)
+        faiss_ret.build_index(embeddings_np.copy(), passages)
+        faiss_result = faiss_ret.retrieve(query_np.copy())
+
+        # Hopfield retriever
+        mb_config = MemoryBankConfig(embedding_dim=d, normalize=False)
+        mb = MemoryBank(mb_config)
+        mb.build_from_embeddings(
+            torch.from_numpy(embeddings_np), passages
+        )
+        hop_config = HopfieldConfig(beta=10.0, max_iter=5)
+        hopfield = HopfieldRetrieval(hop_config)
+        hop_ret = HopfieldRetriever(hopfield, mb, top_k=1)
+        hop_result = hop_ret.retrieve(query_torch)
+
+        assert faiss_result.indices[0].item() == target_idx
+        assert hop_result.indices[0].item() == target_idx
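
For reference, a minimal usage sketch distilled from the setup these tests exercise. It relies only on calls that appear in the tests above (MemoryBank.build_from_embeddings, HopfieldRetriever.retrieve, FAISSRetriever.build_index and retrieve); any behavior beyond what the tests assert is assumed.

import numpy as np
import torch
import torch.nn.functional as F

from hag.config import HopfieldConfig, MemoryBankConfig
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.retriever_faiss import FAISSRetriever
from hag.retriever_hopfield import HopfieldRetriever

d, N = 64, 50
embeddings = F.normalize(torch.randn(N, d), dim=-1)   # stand-in for real passage embeddings
passages = [f"passage {i}" for i in range(N)]

# Hopfield path: memory bank + energy-based retrieval dynamics.
bank = MemoryBank(MemoryBankConfig(embedding_dim=d, normalize=False))
bank.build_from_embeddings(embeddings, passages)
hopfield = HopfieldRetrieval(HopfieldConfig(beta=10.0, max_iter=5))
hop_retriever = HopfieldRetriever(hopfield, bank, top_k=3)

# FAISS baseline over the same corpus.
faiss_retriever = FAISSRetriever(top_k=3)
faiss_retriever.build_index(embeddings.numpy().astype(np.float32), passages)

query = F.normalize(torch.randn(1, d), dim=-1)
hop_result = hop_retriever.retrieve(query, return_analysis=True)
faiss_result = faiss_retriever.retrieve(query.squeeze(0).numpy())

print(hop_result.passages, hop_result.scores)
print(faiss_result.passages, faiss_result.scores)
print(hop_result.hopfield_result.energy_curve)         # expected to be non-increasing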