summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/test_energy.py80
-rw-r--r--tests/test_hopfield.py147
-rw-r--r--tests/test_memory_bank.py65
-rw-r--r--tests/test_metrics.py54
-rw-r--r--tests/test_pipeline.py132
-rw-r--r--tests/test_retriever.py149
7 files changed, 627 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/test_energy.py b/tests/test_energy.py
new file mode 100644
index 0000000..256f2d8
--- /dev/null
+++ b/tests/test_energy.py
@@ -0,0 +1,80 @@
+"""Unit tests for the energy computation module."""
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.energy import (
+ compute_attention_entropy,
+ compute_energy_curve,
+ compute_energy_gap,
+ verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+
+
class TestEnergy:
    """Tests for energy computation and analysis utilities."""

    def test_energy_computation_matches_formula(self) -> None:
        """Manually compute E(q) and verify it matches the module output."""
        torch.manual_seed(42)
        dim, n_mem = 16, 10
        patterns = torch.randn(dim, n_mem)
        q = torch.randn(1, dim)
        beta = 2.0

        retrieval = HopfieldRetrieval(HopfieldConfig(beta=beta))
        module_energy = retrieval.compute_energy(q, patterns)

        # Reference formula:
        # E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
        scaled_scores = beta * (q @ patterns)  # (1, n_mem)
        lse_term = torch.logsumexp(scaled_scores, dim=-1) / beta
        quadratic_term = 0.5 * q.pow(2).sum(dim=-1)
        assert torch.allclose(module_energy, quadratic_term - lse_term, atol=1e-5)

    def test_verify_monotonic(self) -> None:
        """Monotonically decreasing energy curve should pass."""
        assert verify_monotonic_decrease([10.0, 8.0, 6.5, 6.4, 6.39]) is True

    def test_verify_monotonic_fails(self) -> None:
        """Non-monotonic energy curve should fail."""
        assert verify_monotonic_decrease([10.0, 8.0, 9.0, 6.0]) is False

    def test_attention_entropy(self) -> None:
        """Uniform attention should have max entropy, one-hot should have zero."""
        flat = torch.ones(100) / 100
        peaked = torch.zeros(100)
        peaked[0] = 1.0
        assert compute_attention_entropy(flat) > compute_attention_entropy(peaked)
        assert compute_attention_entropy(peaked) < 0.01

    def test_energy_curve_extraction(self) -> None:
        """Energy curve should be extractable from a HopfieldResult."""
        torch.manual_seed(42)
        dim, n_mem = 32, 50
        patterns = torch.randn(dim, n_mem)
        q = torch.randn(1, dim)
        retrieval = HopfieldRetrieval(HopfieldConfig(beta=1.0, max_iter=5))

        outcome = retrieval.retrieve(q, patterns, return_energy=True)
        curve = compute_energy_curve(outcome)
        assert len(curve) > 1
        assert all(isinstance(point, float) for point in curve)

    def test_energy_gap(self) -> None:
        """Energy gap should be positive when energy decreases."""
        assert compute_energy_gap([10.0, 8.0, 6.0]) == 4.0

    def test_energy_gap_single_point(self) -> None:
        """Energy gap with fewer than 2 points should be 0."""
        assert compute_energy_gap([5.0]) == 0.0
        assert compute_energy_gap([]) == 0.0
diff --git a/tests/test_hopfield.py b/tests/test_hopfield.py
new file mode 100644
index 0000000..246b120
--- /dev/null
+++ b/tests/test_hopfield.py
@@ -0,0 +1,147 @@
+"""Unit tests for the core Hopfield retrieval module."""
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import HopfieldConfig
+from hag.hopfield import HopfieldRetrieval
+
+
class TestHopfieldRetrieval:
    """Test the core Hopfield module with synthetic data on CPU."""

    def setup_method(self) -> None:
        """Create small synthetic memory bank and queries."""
        torch.manual_seed(42)
        self.d = 64  # embedding dim
        self.N = 100  # number of memories
        self.memory = F.normalize(torch.randn(self.d, self.N), dim=0)  # (d, N)
        self.query = F.normalize(torch.randn(1, self.d), dim=-1)  # (1, d)
        self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
        self.hopfield = HopfieldRetrieval(self.config)

    def test_output_shapes(self) -> None:
        """attention_weights should be (1, N), converged_query should be (1, d)."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert result.attention_weights.shape == (1, self.N)
        assert result.converged_query.shape == (1, self.d)

    def test_attention_weights_sum_to_one(self) -> None:
        """softmax output must sum to 1."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert torch.allclose(
            result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5
        )

    def test_attention_weights_non_negative(self) -> None:
        """All attention weights must be >= 0."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert (result.attention_weights >= 0).all()

    def test_energy_monotonic_decrease(self) -> None:
        """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
        result = self.hopfield.retrieve(
            self.query, self.memory, return_energy=True
        )
        energies = [e.item() for e in result.energy_curve]
        for i in range(len(energies) - 1):
            assert energies[i + 1] <= energies[i] + 1e-6, (
                f"Energy increased at step {i}: {energies[i]} -> {energies[i + 1]}"
            )

    def test_convergence(self) -> None:
        """With enough iterations, query should converge (delta < threshold)."""
        config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(
            self.query, self.memory, return_trajectory=True
        )
        # Final two queries should be very close
        q_last = result.trajectory[-1]
        q_prev = result.trajectory[-2]
        delta = torch.norm(q_last - q_prev)
        assert delta < 1e-4, f"Did not converge: delta={delta}"

    def test_high_beta_sharp_retrieval(self) -> None:
        """Higher beta should produce sharper (lower entropy) attention."""
        low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
        high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))

        result_low = low_beta.retrieve(self.query, self.memory)
        result_high = high_beta.retrieve(self.query, self.memory)

        # NaN-safe Shannon entropy: torch.special.entr computes -w * log(w)
        # elementwise with entr(0) defined as 0, so any weight that underflows
        # to exactly zero under a large beta no longer poisons the sum with
        # 0 * log(0) = 0 * (-inf) = NaN (which would make the assert below
        # silently pass/fail on NaN comparison semantics).
        entropy_low = torch.special.entr(result_low.attention_weights).sum()
        entropy_high = torch.special.entr(result_high.attention_weights).sum()

        assert entropy_high < entropy_low, "Higher beta should give lower entropy"

    def test_single_memory_converges_to_it(self) -> None:
        """With N=1, retrieval should converge to the single memory."""
        single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
        result = self.hopfield.retrieve(self.query, single_memory)
        assert torch.allclose(
            result.attention_weights, torch.ones(1, 1), atol=1e-5
        )

    def test_query_near_memory_retrieves_it(self) -> None:
        """If query ~= memory_i, attention should peak at index i."""
        target_idx = 42
        query = self.memory[:, target_idx].unsqueeze(0)  # (1, d) — exact match
        config = HopfieldConfig(beta=10.0, max_iter=5)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(query, self.memory)
        top_idx = result.attention_weights.argmax(dim=-1).item()
        assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"

    def test_batch_retrieval(self) -> None:
        """Should handle batch of queries."""
        batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
        result = self.hopfield.retrieve(batch_query, self.memory)
        assert result.attention_weights.shape == (8, self.N)
        assert result.converged_query.shape == (8, self.d)

    def test_iteration_refines_query(self) -> None:
        """Multi-hop test: query starts far from target, iteration should bring it closer.

        Setup: memory has two clusters. Query is near cluster A but the "answer"
        is in cluster B, reachable through an intermediate memory that bridges both.
        After iteration, the query should drift toward cluster B.
        """
        torch.manual_seed(0)
        d = 32

        # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
        # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
        # Bridge memory 10: between A and B
        center_a = torch.zeros(d)
        center_a[0] = 1.0
        center_b = torch.zeros(d)
        center_b[1] = 1.0
        bridge = F.normalize(center_a + center_b, dim=0)

        memories = []
        for _ in range(5):
            memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
        for _ in range(5):
            memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
        memories.append(bridge)

        M = torch.stack(memories, dim=1)  # (d, 11)
        q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)

        config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(q0, M, return_trajectory=True)

        # After iteration, query should have drifted: its dot product with center_b
        # should be higher than the initial query's dot product with center_b
        initial_sim_b = (q0.squeeze() @ center_b).item()
        final_sim_b = (result.converged_query.squeeze() @ center_b).item()
        assert final_sim_b > initial_sim_b, (
            f"Iteration should pull query toward cluster B: "
            f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
        )
diff --git a/tests/test_memory_bank.py b/tests/test_memory_bank.py
new file mode 100644
index 0000000..0087bbd
--- /dev/null
+++ b/tests/test_memory_bank.py
@@ -0,0 +1,65 @@
+"""Unit tests for the memory bank module."""
+
+import torch
+
+from hag.config import MemoryBankConfig
+from hag.memory_bank import MemoryBank
+
+
class TestMemoryBank:
    """Tests for MemoryBank construction, lookup, and persistence."""

    def test_build_and_size(self) -> None:
        """Build memory bank and verify size."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=True))
        vectors = torch.randn(100, 64)
        texts = [f"passage {i}" for i in range(100)]
        bank.build_from_embeddings(vectors, texts)
        assert bank.size == 100
        assert bank.dim == 64

    def test_normalization(self) -> None:
        """If normalize=True, stored embeddings should have unit norm."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=True))
        vectors = torch.randn(50, 64) * 5  # deliberately non-unit norm
        bank.build_from_embeddings(vectors, [f"p{i}" for i in range(50)])
        # One norm per stored memory (taken over dim 0 of the stored layout).
        per_memory_norms = torch.norm(bank.embeddings, dim=0)
        assert torch.allclose(per_memory_norms, torch.ones(50), atol=1e-5)

    def test_no_normalization(self) -> None:
        """If normalize=False, stored embeddings keep original norms."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=False))
        vectors = torch.randn(50, 64) * 5
        expected_norms = torch.norm(vectors, dim=-1)
        bank.build_from_embeddings(vectors, [f"p{i}" for i in range(50)])
        assert torch.allclose(
            torch.norm(bank.embeddings, dim=0), expected_norms, atol=1e-5
        )

    def test_get_passages_by_indices(self) -> None:
        """Index -> passage text lookup."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=False))
        texts = [f"passage {i}" for i in range(100)]
        bank.build_from_embeddings(torch.randn(100, 64), texts)
        picked = bank.get_passages_by_indices(torch.tensor([0, 50, 99]))
        assert picked == ["passage 0", "passage 50", "passage 99"]

    def test_save_and_load(self, tmp_path) -> None:  # type: ignore[no-untyped-def]
        """Save and reload memory bank, verify contents match."""
        cfg = MemoryBankConfig(embedding_dim=32, normalize=True)
        original = MemoryBank(cfg)
        texts = [f"text {i}" for i in range(20)]
        original.build_from_embeddings(torch.randn(20, 32), texts)

        target = str(tmp_path / "mb.pt")
        original.save(target)

        restored = MemoryBank(cfg)
        restored.load(target)
        assert restored.size == 20
        assert restored.passages == texts
        assert torch.allclose(original.embeddings, restored.embeddings, atol=1e-6)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000..4a18bec
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,54 @@
+"""Unit tests for evaluation metrics."""
+
+from hag.metrics import exact_match, f1_score, retrieval_recall_at_k
+
+
class TestMetrics:
    """Tests for EM, F1, and retrieval recall metrics."""

    def test_exact_match_basic(self) -> None:
        """Basic exact match: case insensitive."""
        assert exact_match("Paris", "paris") == 1.0
        assert exact_match("Paris", "London") == 0.0

    def test_exact_match_normalization(self) -> None:
        """Should strip whitespace, lowercase, remove articles."""
        assert exact_match(" The Paris ", "paris") == 1.0
        assert exact_match("A dog", "dog") == 1.0

    def test_exact_match_empty(self) -> None:
        """Empty strings should match."""
        assert exact_match("", "") == 1.0

    def test_f1_score_perfect(self) -> None:
        """Identical strings should have F1 = 1.0."""
        assert f1_score("the cat sat", "the cat sat") == 1.0

    def test_f1_score_partial(self) -> None:
        """Partial overlap should give 0 < F1 < 1."""
        partial = f1_score("the cat sat on the mat", "the cat sat")
        assert 0.5 < partial < 1.0

    def test_f1_score_no_overlap(self) -> None:
        """No common tokens should give F1 = 0."""
        assert f1_score("hello world", "foo bar") == 0.0

    def test_f1_score_empty(self) -> None:
        """Two empty strings should have F1 = 1.0."""
        assert f1_score("", "") == 1.0

    def test_retrieval_recall(self) -> None:
        """Standard retrieval recall computation."""
        recall = retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3)
        assert recall == 2 / 3

    def test_retrieval_recall_perfect(self) -> None:
        """All gold passages retrieved."""
        assert retrieval_recall_at_k([1, 2, 3], [1, 2, 3], k=3) == 1.0

    def test_retrieval_recall_none(self) -> None:
        """No gold passages retrieved."""
        assert retrieval_recall_at_k([4, 5, 6], [1, 2, 3], k=3) == 0.0

    def test_retrieval_recall_empty_gold(self) -> None:
        """No gold passages means perfect recall by convention."""
        assert retrieval_recall_at_k([1, 2, 3], [], k=3) == 1.0
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..3ac81ce
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,132 @@
+"""Integration tests for the RAG/HAG pipeline using mock encoder and generator."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from hag.config import (
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import FakeEncoder
+from hag.generator import FakeGenerator
+from hag.memory_bank import MemoryBank
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+
def _make_memory_bank(n: int = 20, d: int = 64) -> MemoryBank:
    """Create a small memory bank with random embeddings."""
    torch.manual_seed(123)
    bank = MemoryBank(MemoryBankConfig(embedding_dim=d, normalize=True))
    bank.build_from_embeddings(
        torch.randn(n, d), [f"passage {i} content" for i in range(n)]
    )
    return bank
+
+
def _make_faiss_retriever(n: int = 20, d: int = 64) -> FAISSRetriever:
    """Create a small FAISS retriever with random embeddings."""
    np.random.seed(123)
    vectors = np.random.randn(n, d).astype(np.float32)
    index = FAISSRetriever(top_k=3)
    index.build_index(vectors, [f"passage {i} content" for i in range(n)])
    return index
+
+
class TestPipeline:
    """Integration tests using mock encoder and generator."""

    @staticmethod
    def _hopfield_pipeline(n: int = 20, d: int = 64) -> RAGPipeline:
        """Build a Hopfield-backed pipeline with mock encoder and generator.

        Deduplicates the identical setup previously copy-pasted into three
        tests: same HopfieldConfig(beta=1.0, max_iter=3, top_k=3), same
        normalized memory config, same fake encoder/generator.
        """
        config = PipelineConfig(
            hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3),
            memory=MemoryBankConfig(embedding_dim=d, normalize=True),
            retriever_type="hopfield",
        )
        return RAGPipeline(
            config=config,
            encoder=FakeEncoder(dim=d),
            generator=FakeGenerator(),
            memory_bank=_make_memory_bank(n=n, d=d),
        )

    def test_pipeline_runs_end_to_end_hopfield(self) -> None:
        """Mock encoder + mock LLM + Hopfield retriever -> produces an answer."""
        pipeline = self._hopfield_pipeline()
        result = pipeline.run("What is the capital of France?")

        assert isinstance(result.answer, str)
        assert len(result.answer) > 0
        assert len(result.retrieved_passages) == 3

    def test_pipeline_runs_end_to_end_faiss(self) -> None:
        """Mock encoder + mock LLM + FAISS retriever -> produces an answer."""
        config = PipelineConfig(
            hopfield=HopfieldConfig(top_k=3),
            retriever_type="faiss",
        )
        pipeline = RAGPipeline(
            config=config,
            encoder=FakeEncoder(dim=64),
            generator=FakeGenerator(),
            faiss_retriever=_make_faiss_retriever(n=20, d=64),
        )
        result = pipeline.run("What is the capital of France?")

        assert isinstance(result.answer, str)
        assert len(result.answer) > 0
        assert len(result.retrieved_passages) == 3

    def test_pipeline_retrieval_result_included(self) -> None:
        """PipelineResult should contain retrieval metadata."""
        pipeline = self._hopfield_pipeline()
        result = pipeline.run("Test question?")

        assert result.retrieval_result is not None
        assert result.retrieval_result.scores is not None
        assert result.retrieval_result.indices is not None
        assert result.question == "Test question?"

    def test_pipeline_batch(self) -> None:
        """run_batch should return results for each question."""
        pipeline = self._hopfield_pipeline()
        results = pipeline.run_batch(["Q1?", "Q2?", "Q3?"])
        assert len(results) == 3
        assert all(r.answer == "mock answer" for r in results)
diff --git a/tests/test_retriever.py b/tests/test_retriever.py
new file mode 100644
index 0000000..fa96dca
--- /dev/null
+++ b/tests/test_retriever.py
@@ -0,0 +1,149 @@
+"""Unit tests for both FAISS and Hopfield retrievers."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from hag.config import HopfieldConfig, MemoryBankConfig
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+
class TestHopfieldRetriever:
    """Tests for the Hopfield-based retriever."""

    def setup_method(self) -> None:
        """Set up a small memory bank and Hopfield retriever."""
        torch.manual_seed(42)
        self.d = 64
        self.N = 50
        self.top_k = 3

        self.mb = MemoryBank(MemoryBankConfig(embedding_dim=self.d, normalize=True))
        vectors = torch.randn(self.N, self.d)
        self.passages = [f"passage {i}" for i in range(self.N)]
        self.mb.build_from_embeddings(vectors, self.passages)

        dynamics = HopfieldRetrieval(HopfieldConfig(beta=2.0, max_iter=5))
        self.retriever = HopfieldRetriever(dynamics, self.mb, top_k=self.top_k)

    def test_retrieves_correct_number_of_passages(self) -> None:
        """top_k=3 should return exactly 3 passages."""
        probe = F.normalize(torch.randn(1, self.d), dim=-1)
        out = self.retriever.retrieve(probe)
        assert len(out.passages) == self.top_k
        assert out.scores.shape == (self.top_k,)
        assert out.indices.shape == (self.top_k,)

    def test_scores_are_sorted_descending(self) -> None:
        """Returned scores should be in descending order."""
        probe = F.normalize(torch.randn(1, self.d), dim=-1)
        ranked = self.retriever.retrieve(probe).scores.tolist()
        assert all(a >= b for a, b in zip(ranked, ranked[1:]))

    def test_passages_match_indices(self) -> None:
        """Returned passage texts should correspond to returned indices."""
        probe = F.normalize(torch.randn(1, self.d), dim=-1)
        out = self.retriever.retrieve(probe)
        assert out.passages == [self.passages[i] for i in out.indices.tolist()]

    def test_analysis_mode_returns_hopfield_result(self) -> None:
        """When return_analysis=True, hopfield_result should be populated."""
        probe = F.normalize(torch.randn(1, self.d), dim=-1)
        out = self.retriever.retrieve(probe, return_analysis=True)
        inner = out.hopfield_result
        assert inner is not None
        assert inner.energy_curve is not None
        assert inner.trajectory is not None
+
+
class TestFAISSRetriever:
    """Tests for the FAISS baseline retriever."""

    def setup_method(self) -> None:
        """Set up a small FAISS index."""
        np.random.seed(42)
        self.d = 64
        self.N = 50
        self.top_k = 3

        self.embeddings = np.random.randn(self.N, self.d).astype(np.float32)
        self.passages = [f"passage {i}" for i in range(self.N)]

        self.retriever = FAISSRetriever(top_k=self.top_k)
        self.retriever.build_index(self.embeddings.copy(), self.passages)

    def test_retrieves_correct_number(self) -> None:
        """top_k=3 should return exactly 3 passages."""
        probe = np.random.randn(self.d).astype(np.float32)
        assert len(self.retriever.retrieve(probe).passages) == self.top_k

    def test_nearest_neighbor_is_self(self) -> None:
        """If query = passage_i's embedding, top-1 should be passage_i."""
        # Build a fresh top-1 index from a copy of the raw embeddings so we
        # control the normalization state of what gets stored.
        top1 = FAISSRetriever(top_k=1)
        top1.build_index(self.embeddings.copy(), self.passages)

        # Query with the raw stored embedding (FAISS normalizes internally).
        anchor = 10
        hit = top1.retrieve(self.embeddings[anchor].copy())
        assert hit.indices[0].item() == anchor

    def test_scores_sorted_descending(self) -> None:
        """Returned scores should be in descending order."""
        probe = np.random.randn(self.d).astype(np.float32)
        ranked = self.retriever.retrieve(probe).scores.tolist()
        assert ranked == sorted(ranked, reverse=True)
+
+
class TestRetrieverComparison:
    """Compare FAISS and Hopfield retrievers."""

    def test_same_top1_for_obvious_query(self) -> None:
        """When query is very close to one memory, both should agree on top-1."""
        torch.manual_seed(42)
        np.random.seed(42)
        dim, count, anchor = 64, 50, 25

        # Shared unit-normalized random embedding matrix for both retrievers.
        matrix = np.random.randn(count, dim).astype(np.float32)
        matrix = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)

        # The query is exactly the anchor row.
        q_np = matrix[anchor].copy()
        q_t = torch.from_numpy(q_np).unsqueeze(0)  # (1, dim)

        texts = [f"passage {i}" for i in range(count)]

        # FAISS baseline.
        flat = FAISSRetriever(top_k=1)
        flat.build_index(matrix.copy(), texts)
        flat_hit = flat.retrieve(q_np.copy())

        # Hopfield retriever over the same, already-normalized, vectors.
        bank = MemoryBank(MemoryBankConfig(embedding_dim=dim, normalize=False))
        bank.build_from_embeddings(torch.from_numpy(matrix), texts)
        dynamics = HopfieldRetrieval(HopfieldConfig(beta=10.0, max_iter=5))
        assoc = HopfieldRetriever(dynamics, bank, top_k=1)
        assoc_hit = assoc.retrieve(q_t)

        assert flat_hit.indices[0].item() == anchor
        assert assoc_hit.indices[0].item() == anchor