diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /tests/test_hopfield.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'tests/test_hopfield.py')
| -rw-r--r-- | tests/test_hopfield.py | 147 |
1 file changed, 147 insertions, 0 deletions
"""Unit tests for the core Hopfield retrieval module."""

import torch
import torch.nn.functional as F

from hag.config import HopfieldConfig
from hag.hopfield import HopfieldRetrieval


class TestHopfieldRetrieval:
    """Test the core Hopfield module with synthetic data on CPU.

    Memory banks are stored column-wise as (d, N) tensors; queries are
    row-wise (B, d). All fixtures are seeded for determinism.
    """

    def setup_method(self) -> None:
        """Create a small synthetic memory bank and a single unit query."""
        torch.manual_seed(42)
        self.d = 64   # embedding dim
        self.N = 100  # number of memories
        # Columns are unit vectors: normalize along dim=0 of the (d, N) bank.
        self.memory = F.normalize(torch.randn(self.d, self.N), dim=0)  # (d, N)
        self.query = F.normalize(torch.randn(1, self.d), dim=-1)  # (1, d)
        self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
        self.hopfield = HopfieldRetrieval(self.config)

    def test_output_shapes(self) -> None:
        """attention_weights should be (1, N), converged_query should be (1, d)."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert result.attention_weights.shape == (1, self.N)
        assert result.converged_query.shape == (1, self.d)

    def test_attention_weights_sum_to_one(self) -> None:
        """softmax output must sum to 1."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert torch.allclose(
            result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5
        )

    def test_attention_weights_non_negative(self) -> None:
        """All attention weights must be >= 0."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert (result.attention_weights >= 0).all()

    def test_energy_monotonic_decrease(self) -> None:
        """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
        result = self.hopfield.retrieve(
            self.query, self.memory, return_energy=True
        )
        energies = [e.item() for e in result.energy_curve]
        # Small additive tolerance absorbs float32 rounding in the energy eval.
        for i in range(len(energies) - 1):
            assert energies[i + 1] <= energies[i] + 1e-6, (
                f"Energy increased at step {i}: {energies[i]} -> {energies[i + 1]}"
            )

    def test_convergence(self) -> None:
        """With enough iterations, query should converge (delta < threshold)."""
        config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(
            self.query, self.memory, return_trajectory=True
        )
        # Guard: comparing the final two iterates requires at least two of them.
        assert len(result.trajectory) >= 2, "Trajectory too short to assess convergence"
        # Final two queries should be very close
        q_last = result.trajectory[-1]
        q_prev = result.trajectory[-2]
        delta = torch.norm(q_last - q_prev)
        assert delta < 1e-4, f"Did not converge: delta={delta}"

    def test_high_beta_sharp_retrieval(self) -> None:
        """Higher beta should produce sharper (lower entropy) attention."""
        low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
        high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))

        result_low = low_beta.retrieve(self.query, self.memory)
        result_high = high_beta.retrieve(self.query, self.memory)

        def _entropy(weights: torch.Tensor) -> torch.Tensor:
            # Clamp before log: a weight that underflows to exactly 0 would
            # otherwise produce 0 * log(0) = NaN and poison the comparison.
            return -(weights * weights.clamp_min(1e-12).log()).sum()

        entropy_low = _entropy(result_low.attention_weights)
        entropy_high = _entropy(result_high.attention_weights)

        assert entropy_high < entropy_low, "Higher beta should give lower entropy"

    def test_single_memory_converges_to_it(self) -> None:
        """With N=1, retrieval should converge to the single memory."""
        single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
        result = self.hopfield.retrieve(self.query, single_memory)
        # softmax over a single logit is identically 1
        assert torch.allclose(
            result.attention_weights, torch.ones(1, 1), atol=1e-5
        )

    def test_query_near_memory_retrieves_it(self) -> None:
        """If query ~= memory_i, attention should peak at index i."""
        target_idx = 42
        query = self.memory[:, target_idx].unsqueeze(0)  # (1, d) — exact match
        config = HopfieldConfig(beta=10.0, max_iter=5)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(query, self.memory)
        top_idx = result.attention_weights.argmax(dim=-1).item()
        assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"

    def test_batch_retrieval(self) -> None:
        """Should handle batch of queries."""
        batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
        result = self.hopfield.retrieve(batch_query, self.memory)
        assert result.attention_weights.shape == (8, self.N)
        assert result.converged_query.shape == (8, self.d)

    def test_iteration_refines_query(self) -> None:
        """Multi-hop test: query starts far from target, iteration should bring it closer.

        Setup: memory has two clusters. Query is near cluster A but the "answer"
        is in cluster B, reachable through an intermediate memory that bridges both.
        After iteration, the query should drift toward cluster B.
        """
        torch.manual_seed(0)
        d = 32

        # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
        # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
        # Bridge memory 10: between A and B
        center_a = torch.zeros(d)
        center_a[0] = 1.0
        center_b = torch.zeros(d)
        center_b[1] = 1.0
        bridge = F.normalize(center_a + center_b, dim=0)

        memories = []
        for _ in range(5):
            memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
        for _ in range(5):
            memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
        memories.append(bridge)

        M = torch.stack(memories, dim=1)  # (d, 11)
        q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)

        config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(q0, M, return_trajectory=True)

        # After iteration, query should have drifted: its dot product with center_b
        # should be higher than the initial query's dot product with center_b
        initial_sim_b = (q0.squeeze() @ center_b).item()
        final_sim_b = (result.converged_query.squeeze() @ center_b).item()
        assert final_sim_b > initial_sim_b, (
            f"Iteration should pull query toward cluster B: "
            f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
        )
