diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/__init__.py | 0 | ||||
| -rw-r--r-- | tests/test_energy.py | 80 | ||||
| -rw-r--r-- | tests/test_hopfield.py | 147 | ||||
| -rw-r--r-- | tests/test_memory_bank.py | 65 | ||||
| -rw-r--r-- | tests/test_metrics.py | 54 | ||||
| -rw-r--r-- | tests/test_pipeline.py | 132 | ||||
| -rw-r--r-- | tests/test_retriever.py | 149 |
7 files changed, 627 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/test_energy.py b/tests/test_energy.py new file mode 100644 index 0000000..256f2d8 --- /dev/null +++ b/tests/test_energy.py @@ -0,0 +1,80 @@ +"""Unit tests for the energy computation module.""" + +import torch + +from hag.config import HopfieldConfig +from hag.energy import ( + compute_attention_entropy, + compute_energy_curve, + compute_energy_gap, + verify_monotonic_decrease, +) +from hag.hopfield import HopfieldRetrieval + + +class TestEnergy: + """Tests for energy computation and analysis utilities.""" + + def test_energy_computation_matches_formula(self) -> None: + """Manually compute E(q) and verify it matches the module output.""" + torch.manual_seed(42) + d, N = 16, 10 + memory = torch.randn(d, N) + query = torch.randn(1, d) + beta = 2.0 + + config = HopfieldConfig(beta=beta) + hopfield = HopfieldRetrieval(config) + computed = hopfield.compute_energy(query, memory) + + # Manual computation: + # E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2 + logits = beta * (query @ memory) # (1, N) + expected = ( + -1.0 / beta * torch.logsumexp(logits, dim=-1) + + 0.5 * (query**2).sum(dim=-1) + ) + assert torch.allclose(computed, expected, atol=1e-5) + + def test_verify_monotonic(self) -> None: + """Monotonically decreasing energy curve should pass.""" + energies = [10.0, 8.0, 6.5, 6.4, 6.39] + assert verify_monotonic_decrease(energies) is True + + def test_verify_monotonic_fails(self) -> None: + """Non-monotonic energy curve should fail.""" + energies = [10.0, 8.0, 9.0, 6.0] + assert verify_monotonic_decrease(energies) is False + + def test_attention_entropy(self) -> None: + """Uniform attention should have max entropy, one-hot should have zero.""" + uniform = torch.ones(100) / 100 + one_hot = torch.zeros(100) + one_hot[0] = 1.0 + assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot) + assert compute_attention_entropy(one_hot) < 0.01 + + def test_energy_curve_extraction(self) -> None: + """Energy curve should be extractable from a HopfieldResult.""" + torch.manual_seed(42) + d, N = 32, 50 + memory = torch.randn(d, N) + query = torch.randn(1, d) + config = HopfieldConfig(beta=1.0, max_iter=5) + hopfield = HopfieldRetrieval(config) + + result = hopfield.retrieve(query, memory, return_energy=True) + curve = compute_energy_curve(result) + assert len(curve) > 1 + assert all(isinstance(v, float) for v in curve) + + def test_energy_gap(self) -> None: + """Energy gap should be positive when energy decreases.""" + energies = [10.0, 8.0, 6.0] + gap = compute_energy_gap(energies) + assert gap == 4.0 + + def test_energy_gap_single_point(self) -> None: + """Energy gap with fewer than 2 points should be 0.""" + assert compute_energy_gap([5.0]) == 0.0 + assert compute_energy_gap([]) == 0.0 diff --git a/tests/test_hopfield.py b/tests/test_hopfield.py new file mode 100644 index 0000000..246b120 --- /dev/null +++ b/tests/test_hopfield.py @@ -0,0 +1,147 @@ +"""Unit tests for the core Hopfield retrieval module.""" + +import torch +import torch.nn.functional as F + +from hag.config import HopfieldConfig +from hag.hopfield import HopfieldRetrieval + + +class TestHopfieldRetrieval: + """Test the core Hopfield module with synthetic data on CPU.""" + + def setup_method(self) -> None: + """Create small synthetic memory bank and queries.""" + torch.manual_seed(42) + self.d = 64 # embedding dim + self.N = 100 # number of memories + self.memory = F.normalize(torch.randn(self.d, self.N), dim=0) # (d, N) + self.query = F.normalize(torch.randn(1, self.d), dim=-1) # (1, d) + self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6) + self.hopfield = HopfieldRetrieval(self.config) + + def test_output_shapes(self) -> None: + """attention_weights should be (1, N), converged_query should be (1, d).""" + result = self.hopfield.retrieve(self.query, self.memory) + assert result.attention_weights.shape == (1, self.N) + assert result.converged_query.shape == (1, self.d) + + def test_attention_weights_sum_to_one(self) -> None: + """softmax output must sum to 1.""" + result = self.hopfield.retrieve(self.query, self.memory) + assert torch.allclose( + result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5 + ) + + def test_attention_weights_non_negative(self) -> None: + """All attention weights must be >= 0.""" + result = self.hopfield.retrieve(self.query, self.memory) + assert (result.attention_weights >= 0).all() + + def test_energy_monotonic_decrease(self) -> None: + """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property.""" + result = self.hopfield.retrieve( + self.query, self.memory, return_energy=True + ) + energies = [e.item() for e in result.energy_curve] + for i in range(len(energies) - 1): + assert energies[i + 1] <= energies[i] + 1e-6, ( + f"Energy increased at step {i}: {energies[i]} -> {energies[i + 1]}" + ) + + def test_convergence(self) -> None: + """With enough iterations, query should converge (delta < threshold).""" + config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve( + self.query, self.memory, return_trajectory=True + ) + # Final two queries should be very close + q_last = result.trajectory[-1] + q_prev = result.trajectory[-2] + delta = torch.norm(q_last - q_prev) + assert delta < 1e-4, f"Did not converge: delta={delta}" + + def test_high_beta_sharp_retrieval(self) -> None: + """Higher beta should produce sharper (lower entropy) attention.""" + low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5)) + high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5)) + + result_low = low_beta.retrieve(self.query, self.memory) + result_high = high_beta.retrieve(self.query, self.memory) + + entropy_low = -( + result_low.attention_weights * result_low.attention_weights.log() + ).sum() + entropy_high = -( + result_high.attention_weights * result_high.attention_weights.log() + ).sum() + + assert entropy_high < entropy_low, "Higher beta should give lower entropy" + + def test_single_memory_converges_to_it(self) -> None: + """With N=1, retrieval should converge to the single memory.""" + single_memory = F.normalize(torch.randn(self.d, 1), dim=0) + result = self.hopfield.retrieve(self.query, single_memory) + assert torch.allclose( + result.attention_weights, torch.ones(1, 1), atol=1e-5 + ) + + def test_query_near_memory_retrieves_it(self) -> None: + """If query ~= memory_i, attention should peak at index i.""" + target_idx = 42 + query = self.memory[:, target_idx].unsqueeze(0) # (1, d) — exact match + config = HopfieldConfig(beta=10.0, max_iter=5) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve(query, self.memory) + top_idx = result.attention_weights.argmax(dim=-1).item() + assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}" + + def test_batch_retrieval(self) -> None: + """Should handle batch of queries.""" + batch_query = F.normalize(torch.randn(8, self.d), dim=-1) + result = self.hopfield.retrieve(batch_query, self.memory) + assert result.attention_weights.shape == (8, self.N) + assert result.converged_query.shape == (8, self.d) + + def test_iteration_refines_query(self) -> None: + """Multi-hop test: query starts far from target, iteration should bring it closer. + + Setup: memory has two clusters. Query is near cluster A but the "answer" + is in cluster B, reachable through an intermediate memory that bridges both. + After iteration, the query should drift toward cluster B. + """ + torch.manual_seed(0) + d = 32 + + # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...] + # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...] + # Bridge memory 10: between A and B + center_a = torch.zeros(d) + center_a[0] = 1.0 + center_b = torch.zeros(d) + center_b[1] = 1.0 + bridge = F.normalize(center_a + center_b, dim=0) + + memories = [] + for _ in range(5): + memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0)) + for _ in range(5): + memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0)) + memories.append(bridge) + + M = torch.stack(memories, dim=1) # (d, 11) + q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0) + + config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve(q0, M, return_trajectory=True) + + # After iteration, query should have drifted: its dot product with center_b + # should be higher than the initial query's dot product with center_b + initial_sim_b = (q0.squeeze() @ center_b).item() + final_sim_b = (result.converged_query.squeeze() @ center_b).item() + assert final_sim_b > initial_sim_b, ( + f"Iteration should pull query toward cluster B: " + f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}" + ) diff --git a/tests/test_memory_bank.py b/tests/test_memory_bank.py new file mode 100644 index 0000000..0087bbd --- /dev/null +++ b/tests/test_memory_bank.py @@ -0,0 +1,65 @@ +"""Unit tests for the memory bank module.""" + +import torch + +from hag.config import MemoryBankConfig +from hag.memory_bank import MemoryBank + + +class TestMemoryBank: + """Tests for MemoryBank construction, lookup, and persistence.""" + + def test_build_and_size(self) -> None: + """Build memory bank and verify size.""" + config = MemoryBankConfig(embedding_dim=64, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(100, 64) + passages = [f"passage {i}" for i in range(100)] + mb.build_from_embeddings(embeddings, passages) + assert mb.size == 100 + assert mb.dim == 64 + + def test_normalization(self) -> None: + """If normalize=True, stored embeddings should have unit norm.""" + config = MemoryBankConfig(embedding_dim=64, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(50, 64) * 5 # non-unit norm + mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)]) + norms = torch.norm(mb.embeddings, dim=0) + assert torch.allclose(norms, torch.ones(50), atol=1e-5) + + def test_no_normalization(self) -> None: + """If normalize=False, stored embeddings keep original norms.""" + config = MemoryBankConfig(embedding_dim=64, normalize=False) + mb = MemoryBank(config) + embeddings = torch.randn(50, 64) * 5 + original_norms = torch.norm(embeddings, dim=-1) + mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)]) + stored_norms = torch.norm(mb.embeddings, dim=0) + assert torch.allclose(stored_norms, original_norms, atol=1e-5) + + def test_get_passages_by_indices(self) -> None: + """Index -> passage text lookup.""" + config = MemoryBankConfig(embedding_dim=64, normalize=False) + mb = MemoryBank(config) + passages = [f"passage {i}" for i in range(100)] + mb.build_from_embeddings(torch.randn(100, 64), passages) + result = mb.get_passages_by_indices(torch.tensor([0, 50, 99])) + assert result == ["passage 0", "passage 50", "passage 99"] + + def test_save_and_load(self, tmp_path) -> None: # type: ignore[no-untyped-def] + """Save and reload memory bank, verify contents match.""" + config = MemoryBankConfig(embedding_dim=32, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(20, 32) + passages = [f"text {i}" for i in range(20)] + mb.build_from_embeddings(embeddings, passages) + + save_path = str(tmp_path / "mb.pt") + mb.save(save_path) + + mb2 = MemoryBank(config) + mb2.load(save_path) + assert mb2.size == 20 + assert mb2.passages == passages + assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..4a18bec --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,54 @@ +"""Unit tests for evaluation metrics.""" + +from hag.metrics import exact_match, f1_score, retrieval_recall_at_k + + +class TestMetrics: + """Tests for EM, F1, and retrieval recall metrics.""" + + def test_exact_match_basic(self) -> None: + """Basic exact match: case insensitive.""" + assert exact_match("Paris", "paris") == 1.0 + assert exact_match("Paris", "London") == 0.0 + + def test_exact_match_normalization(self) -> None: + """Should strip whitespace, lowercase, remove articles.""" + assert exact_match(" The Paris ", "paris") == 1.0 + assert exact_match("A dog", "dog") == 1.0 + + def test_exact_match_empty(self) -> None: + """Empty strings should match.""" + assert exact_match("", "") == 1.0 + + def test_f1_score_perfect(self) -> None: + """Identical strings should have F1 = 1.0.""" + assert f1_score("the cat sat", "the cat sat") == 1.0 + + def test_f1_score_partial(self) -> None: + """Partial overlap should give 0 < F1 < 1.""" + score = f1_score("the cat sat on the mat", "the cat sat") + assert 0.5 < score < 1.0 + + def test_f1_score_no_overlap(self) -> None: + """No common tokens should give F1 = 0.""" + assert f1_score("hello world", "foo bar") == 0.0 + + def test_f1_score_empty(self) -> None: + """Two empty strings should have F1 = 1.0.""" + assert f1_score("", "") == 1.0 + + def test_retrieval_recall(self) -> None: + """Standard retrieval recall computation.""" + assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2 / 3 + + def test_retrieval_recall_perfect(self) -> None: + """All gold passages retrieved.""" + assert retrieval_recall_at_k([1, 2, 3], [1, 2, 3], k=3) == 1.0 + + def test_retrieval_recall_none(self) -> None: + """No gold passages retrieved.""" + assert retrieval_recall_at_k([4, 5, 6], [1, 2, 3], k=3) == 0.0 + + def test_retrieval_recall_empty_gold(self) -> None: + """No gold passages means perfect recall by convention.""" + assert retrieval_recall_at_k([1, 2, 3], [], k=3) == 1.0 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..3ac81ce --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,132 @@ +"""Integration tests for the RAG/HAG pipeline using mock encoder and generator.""" + +import numpy as np +import torch +import torch.nn.functional as F + +from hag.config import ( + HopfieldConfig, + MemoryBankConfig, + PipelineConfig, +) +from hag.encoder import FakeEncoder +from hag.generator import FakeGenerator +from hag.memory_bank import MemoryBank +from hag.pipeline import RAGPipeline +from hag.retriever_faiss import FAISSRetriever + + +def _make_memory_bank(n: int = 20, d: int = 64) -> MemoryBank: + """Create a small memory bank with random embeddings.""" + torch.manual_seed(123) + config = MemoryBankConfig(embedding_dim=d, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(n, d) + passages = [f"passage {i} content" for i in range(n)] + mb.build_from_embeddings(embeddings, passages) + return mb + + +def _make_faiss_retriever(n: int = 20, d: int = 64) -> FAISSRetriever: + """Create a small FAISS retriever with random embeddings.""" + np.random.seed(123) + embeddings = np.random.randn(n, d).astype(np.float32) + passages = [f"passage {i} content" for i in range(n)] + retriever = FAISSRetriever(top_k=3) + retriever.build_index(embeddings, passages) + return retriever + + +class TestPipeline: + """Integration tests using mock encoder and generator.""" + + def test_pipeline_runs_end_to_end_hopfield(self) -> None: + """Mock encoder + mock LLM + Hopfield retriever -> produces an answer.""" + config = PipelineConfig( + hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3), + memory=MemoryBankConfig(embedding_dim=64, normalize=True), + retriever_type="hopfield", + ) + mb = _make_memory_bank(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + result = pipeline.run("What is the capital of France?") + + assert isinstance(result.answer, str) + assert len(result.answer) > 0 + assert len(result.retrieved_passages) == 3 + + def test_pipeline_runs_end_to_end_faiss(self) -> None: + """Mock encoder + mock LLM + FAISS retriever -> produces an answer.""" + config = PipelineConfig( + hopfield=HopfieldConfig(top_k=3), + retriever_type="faiss", + ) + faiss_ret = _make_faiss_retriever(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + faiss_retriever=faiss_ret, + ) + result = pipeline.run("What is the capital of France?") + + assert isinstance(result.answer, str) + assert len(result.answer) > 0 + assert len(result.retrieved_passages) == 3 + + def test_pipeline_retrieval_result_included(self) -> None: + """PipelineResult should contain retrieval metadata.""" + config = PipelineConfig( + hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3), + memory=MemoryBankConfig(embedding_dim=64, normalize=True), + retriever_type="hopfield", + ) + mb = _make_memory_bank(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + result = pipeline.run("Test question?") + + assert result.retrieval_result is not None + assert result.retrieval_result.scores is not None + assert result.retrieval_result.indices is not None + assert result.question == "Test question?" + + def test_pipeline_batch(self) -> None: + """run_batch should return results for each question.""" + config = PipelineConfig( + hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3), + memory=MemoryBankConfig(embedding_dim=64, normalize=True), + retriever_type="hopfield", + ) + mb = _make_memory_bank(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + questions = ["Q1?", "Q2?", "Q3?"] + results = pipeline.run_batch(questions) + assert len(results) == 3 + assert all(r.answer == "mock answer" for r in results) diff --git a/tests/test_retriever.py b/tests/test_retriever.py new file mode 100644 index 0000000..fa96dca --- /dev/null +++ b/tests/test_retriever.py @@ -0,0 +1,149 @@ +"""Unit tests for both FAISS and Hopfield retrievers.""" + +import numpy as np +import torch +import torch.nn.functional as F + +from hag.config import HopfieldConfig, MemoryBankConfig +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank +from hag.retriever_faiss import FAISSRetriever +from hag.retriever_hopfield import HopfieldRetriever + + +class TestHopfieldRetriever: + """Tests for the Hopfield-based retriever.""" + + def setup_method(self) -> None: + """Set up a small memory bank and Hopfield retriever.""" + torch.manual_seed(42) + self.d = 64 + self.N = 50 + self.top_k = 3 + + config = MemoryBankConfig(embedding_dim=self.d, normalize=True) + self.mb = MemoryBank(config) + embeddings = torch.randn(self.N, self.d) + self.passages = [f"passage {i}" for i in range(self.N)] + self.mb.build_from_embeddings(embeddings, self.passages) + + hopfield_config = HopfieldConfig(beta=2.0, max_iter=5) + hopfield = HopfieldRetrieval(hopfield_config) + self.retriever = HopfieldRetriever(hopfield, self.mb, top_k=self.top_k) + + def test_retrieves_correct_number_of_passages(self) -> None: + """top_k=3 should return exactly 3 passages.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query) + assert len(result.passages) == self.top_k + assert result.scores.shape == (self.top_k,) + assert result.indices.shape == (self.top_k,) + + def test_scores_are_sorted_descending(self) -> None: + """Returned scores should be in descending order.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query) + scores = result.scores.tolist() + assert scores == sorted(scores, reverse=True) + + def test_passages_match_indices(self) -> None: + """Returned passage texts should correspond to returned indices.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query) + expected_passages = [self.passages[i] for i in result.indices.tolist()] + assert result.passages == expected_passages + + def test_analysis_mode_returns_hopfield_result(self) -> None: + """When return_analysis=True, hopfield_result should be populated.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query, return_analysis=True) + assert result.hopfield_result is not None + assert result.hopfield_result.energy_curve is not None + assert result.hopfield_result.trajectory is not None + + +class TestFAISSRetriever: + """Tests for the FAISS baseline retriever.""" + + def setup_method(self) -> None: + """Set up a small FAISS index.""" + np.random.seed(42) + self.d = 64 + self.N = 50 + self.top_k = 3 + + self.embeddings = np.random.randn(self.N, self.d).astype(np.float32) + self.passages = [f"passage {i}" for i in range(self.N)] + + self.retriever = FAISSRetriever(top_k=self.top_k) + self.retriever.build_index(self.embeddings.copy(), self.passages) + + def test_retrieves_correct_number(self) -> None: + """top_k=3 should return exactly 3 passages.""" + query = np.random.randn(self.d).astype(np.float32) + result = self.retriever.retrieve(query) + assert len(result.passages) == self.top_k + + def test_nearest_neighbor_is_self(self) -> None: + """If query = passage_i's embedding, top-1 should be passage_i.""" + # Use the original embedding (before normalization in build_index) + # Rebuild with fresh copy so we know the normalization state + embeddings = self.embeddings.copy() + retriever = FAISSRetriever(top_k=1) + retriever.build_index(embeddings.copy(), self.passages) + + # Use the normalized version as query (FAISS normalizes internally) + target_idx = 10 + query = self.embeddings[target_idx].copy() + result = retriever.retrieve(query) + assert result.indices[0].item() == target_idx + + def test_scores_sorted_descending(self) -> None: + """Returned scores should be in descending order.""" + query = np.random.randn(self.d).astype(np.float32) + result = self.retriever.retrieve(query) + scores = result.scores.tolist() + assert scores == sorted(scores, reverse=True) + + +class TestRetrieverComparison: + """Compare FAISS and Hopfield retrievers.""" + + def test_same_top1_for_obvious_query(self) -> None: + """When query is very close to one memory, both should agree on top-1.""" + torch.manual_seed(42) + np.random.seed(42) + d = 64 + N = 50 + target_idx = 25 + + # Create embeddings + embeddings_np = np.random.randn(N, d).astype(np.float32) + # Normalize + norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True) + embeddings_np = embeddings_np / norms + + # Query is exactly the target embedding + query_np = embeddings_np[target_idx].copy() + query_torch = torch.from_numpy(query_np).unsqueeze(0) # (1, d) + + passages = [f"passage {i}" for i in range(N)] + + # FAISS retriever + faiss_ret = FAISSRetriever(top_k=1) + faiss_ret.build_index(embeddings_np.copy(), passages) + faiss_result = faiss_ret.retrieve(query_np.copy()) + + # Hopfield retriever + mb_config = MemoryBankConfig(embedding_dim=d, normalize=False) + mb = MemoryBank(mb_config) + mb.build_from_embeddings( + torch.from_numpy(embeddings_np), passages + ) + hop_config = HopfieldConfig(beta=10.0, max_iter=5) + hopfield = HopfieldRetrieval(hop_config) + hop_ret = HopfieldRetriever(hopfield, mb, top_k=1) + hop_result = hop_ret.retrieve(query_torch) + + assert faiss_result.indices[0].item() == target_idx + assert hop_result.indices[0].item() == target_idx |
