"""Unit tests for both FAISS and Hopfield retrievers."""

import numpy as np
import torch
import torch.nn.functional as F

from hag.config import HopfieldConfig, MemoryBankConfig
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.retriever_faiss import FAISSRetriever
from hag.retriever_hopfield import HopfieldRetriever


class TestHopfieldRetriever:
    """Tests for the Hopfield-based retriever."""

    def setup_method(self) -> None:
        """Set up a small memory bank and Hopfield retriever.

        Builds a ``MemoryBank`` (configured with ``normalize=True``) from
        ``N`` random ``d``-dimensional embeddings, each labelled
        ``"passage {i}"``, and wraps it in a ``HopfieldRetriever`` that
        returns ``top_k`` passages per query.
        """
        torch.manual_seed(42)
        self.d = 64
        self.N = 50
        self.top_k = 3

        config = MemoryBankConfig(embedding_dim=self.d, normalize=True)
        self.mb = MemoryBank(config)
        embeddings = torch.randn(self.N, self.d)
        self.passages = [f"passage {i}" for i in range(self.N)]
        self.mb.build_from_embeddings(embeddings, self.passages)

        hopfield_config = HopfieldConfig(beta=2.0, max_iter=5)
        hopfield = HopfieldRetrieval(hopfield_config)
        self.retriever = HopfieldRetriever(hopfield, self.mb, top_k=self.top_k)

    def _random_query(self) -> torch.Tensor:
        """Return a random unit-norm query of shape ``(1, d)``."""
        return F.normalize(torch.randn(1, self.d), dim=-1)

    def test_retrieves_correct_number_of_passages(self) -> None:
        """top_k=3 should return exactly 3 passages."""
        result = self.retriever.retrieve(self._random_query())
        assert len(result.passages) == self.top_k
        assert result.scores.shape == (self.top_k,)
        assert result.indices.shape == (self.top_k,)

    def test_scores_are_sorted_descending(self) -> None:
        """Returned scores should be in descending order."""
        result = self.retriever.retrieve(self._random_query())
        scores = result.scores.tolist()
        assert scores == sorted(scores, reverse=True)

    def test_passages_match_indices(self) -> None:
        """Returned passage texts should correspond to returned indices."""
        result = self.retriever.retrieve(self._random_query())
        expected_passages = [self.passages[i] for i in result.indices.tolist()]
        assert result.passages == expected_passages

    def test_analysis_mode_returns_hopfield_result(self) -> None:
        """When return_analysis=True, hopfield_result should be populated."""
        result = self.retriever.retrieve(self._random_query(), return_analysis=True)
        assert result.hopfield_result is not None
        assert result.hopfield_result.energy_curve is not None
        assert result.hopfield_result.trajectory is not None


class TestFAISSRetriever:
    """Tests for the FAISS baseline retriever."""

    def setup_method(self) -> None:
        """Set up a small FAISS index.

        ``self.embeddings`` keeps the raw float32 vectors. ``build_index``
        receives a copy, so the raw vectors stay untouched even if the
        retriever normalizes its input in place (TODO confirm whether it
        does).
        """
        np.random.seed(42)
        self.d = 64
        self.N = 50
        self.top_k = 3

        self.embeddings = np.random.randn(self.N, self.d).astype(np.float32)
        self.passages = [f"passage {i}" for i in range(self.N)]
        self.retriever = FAISSRetriever(top_k=self.top_k)
        self.retriever.build_index(self.embeddings.copy(), self.passages)

    def _random_query(self) -> np.ndarray:
        """Return a random, unnormalized float32 query of shape ``(d,)``."""
        return np.random.randn(self.d).astype(np.float32)

    def test_retrieves_correct_number(self) -> None:
        """top_k=3 should return exactly 3 passages."""
        result = self.retriever.retrieve(self._random_query())
        assert len(result.passages) == self.top_k

    def test_nearest_neighbor_is_self(self) -> None:
        """If query = passage_i's embedding, top-1 should be passage_i."""
        # Fresh top-1 retriever over a copy of the raw (unnormalized)
        # embeddings; the copy protects self.embeddings from any in-place
        # normalization performed by build_index.
        retriever = FAISSRetriever(top_k=1)
        retriever.build_index(self.embeddings.copy(), self.passages)

        # Query with the target's raw, unnormalized embedding. It points in
        # exactly the same direction as the indexed vector, so it should be
        # the top-1 hit whether or not retrieve() normalizes the query.
        # NOTE(review): assumes a cosine / inner-product (or L2) index — confirm
        # against FAISSRetriever.
        target_idx = 10
        query = self.embeddings[target_idx].copy()
        result = retriever.retrieve(query)
        assert result.indices[0].item() == target_idx

    def test_scores_sorted_descending(self) -> None:
        """Returned scores should be in descending order."""
        result = self.retriever.retrieve(self._random_query())
        scores = result.scores.tolist()
        assert scores == sorted(scores, reverse=True)


class TestRetrieverComparison:
    """Compare FAISS and Hopfield retrievers."""

    def test_same_top1_for_obvious_query(self) -> None:
        """When query is very close to one memory, both should agree on top-1.

        Both retrievers are built over the same L2-normalized embeddings and
        queried with one of those exact embeddings; each must rank that
        embedding's passage first.
        """
        torch.manual_seed(42)
        np.random.seed(42)
        d = 64
        N = 50
        target_idx = 25

        # Random embeddings, L2-normalized row-wise up front so that both
        # retrievers index identical unit vectors.
        embeddings_np = np.random.randn(N, d).astype(np.float32)
        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
        embeddings_np = embeddings_np / norms

        # The query is exactly the target embedding.
        query_np = embeddings_np[target_idx].copy()
        query_torch = torch.from_numpy(query_np).unsqueeze(0)  # (1, d)

        passages = [f"passage {i}" for i in range(N)]

        # FAISS retriever (copies guard against in-place normalization).
        faiss_ret = FAISSRetriever(top_k=1)
        faiss_ret.build_index(embeddings_np.copy(), passages)
        faiss_result = faiss_ret.retrieve(query_np.copy())

        # Hopfield retriever; normalize=False because the embeddings are
        # already unit-norm. A large beta sharpens the attention so the
        # exact-match memory dominates.
        mb_config = MemoryBankConfig(embedding_dim=d, normalize=False)
        mb = MemoryBank(mb_config)
        mb.build_from_embeddings(torch.from_numpy(embeddings_np), passages)

        hop_config = HopfieldConfig(beta=10.0, max_iter=5)
        hopfield = HopfieldRetrieval(hop_config)
        hop_ret = HopfieldRetriever(hopfield, mb, top_k=1)
        hop_result = hop_ret.retrieve(query_torch)

        assert faiss_result.indices[0].item() == target_idx
        assert hop_result.indices[0].item() == target_idx