summaryrefslogtreecommitdiff
path: root/tests/test_hopfield.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_hopfield.py')
-rw-r--r--  tests/test_hopfield.py  147
1 file changed, 147 insertions(+), 0 deletions(-)
diff --git a/tests/test_hopfield.py b/tests/test_hopfield.py
new file mode 100644
index 0000000..246b120
--- /dev/null
+++ b/tests/test_hopfield.py
@@ -0,0 +1,147 @@
+"""Unit tests for the core Hopfield retrieval module."""
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import HopfieldConfig
+from hag.hopfield import HopfieldRetrieval
+
+
+class TestHopfieldRetrieval:
+ """Test the core Hopfield module with synthetic data on CPU."""
+
+ def setup_method(self) -> None:
+ """Create small synthetic memory bank and queries."""
+ torch.manual_seed(42)
+ self.d = 64 # embedding dim
+ self.N = 100 # number of memories
+ self.memory = F.normalize(torch.randn(self.d, self.N), dim=0) # (d, N)
+ self.query = F.normalize(torch.randn(1, self.d), dim=-1) # (1, d)
+ self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
+ self.hopfield = HopfieldRetrieval(self.config)
+
+ def test_output_shapes(self) -> None:
+ """attention_weights should be (1, N), converged_query should be (1, d)."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert result.attention_weights.shape == (1, self.N)
+ assert result.converged_query.shape == (1, self.d)
+
+ def test_attention_weights_sum_to_one(self) -> None:
+ """softmax output must sum to 1."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert torch.allclose(
+ result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5
+ )
+
+ def test_attention_weights_non_negative(self) -> None:
+ """All attention weights must be >= 0."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert (result.attention_weights >= 0).all()
+
+ def test_energy_monotonic_decrease(self) -> None:
+ """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
+ result = self.hopfield.retrieve(
+ self.query, self.memory, return_energy=True
+ )
+ energies = [e.item() for e in result.energy_curve]
+ for i in range(len(energies) - 1):
+ assert energies[i + 1] <= energies[i] + 1e-6, (
+ f"Energy increased at step {i}: {energies[i]} -> {energies[i + 1]}"
+ )
+
+ def test_convergence(self) -> None:
+ """With enough iterations, query should converge (delta < threshold)."""
+ config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(
+ self.query, self.memory, return_trajectory=True
+ )
+ # Final two queries should be very close
+ q_last = result.trajectory[-1]
+ q_prev = result.trajectory[-2]
+ delta = torch.norm(q_last - q_prev)
+ assert delta < 1e-4, f"Did not converge: delta={delta}"
+
+ def test_high_beta_sharp_retrieval(self) -> None:
+ """Higher beta should produce sharper (lower entropy) attention."""
+ low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
+ high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))
+
+ result_low = low_beta.retrieve(self.query, self.memory)
+ result_high = high_beta.retrieve(self.query, self.memory)
+
+ entropy_low = -(
+ result_low.attention_weights * result_low.attention_weights.log()
+ ).sum()
+ entropy_high = -(
+ result_high.attention_weights * result_high.attention_weights.log()
+ ).sum()
+
+ assert entropy_high < entropy_low, "Higher beta should give lower entropy"
+
+ def test_single_memory_converges_to_it(self) -> None:
+ """With N=1, retrieval should converge to the single memory."""
+ single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
+ result = self.hopfield.retrieve(self.query, single_memory)
+ assert torch.allclose(
+ result.attention_weights, torch.ones(1, 1), atol=1e-5
+ )
+
+ def test_query_near_memory_retrieves_it(self) -> None:
+ """If query ~= memory_i, attention should peak at index i."""
+ target_idx = 42
+ query = self.memory[:, target_idx].unsqueeze(0) # (1, d) — exact match
+ config = HopfieldConfig(beta=10.0, max_iter=5)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(query, self.memory)
+ top_idx = result.attention_weights.argmax(dim=-1).item()
+ assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"
+
+ def test_batch_retrieval(self) -> None:
+ """Should handle batch of queries."""
+ batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
+ result = self.hopfield.retrieve(batch_query, self.memory)
+ assert result.attention_weights.shape == (8, self.N)
+ assert result.converged_query.shape == (8, self.d)
+
+ def test_iteration_refines_query(self) -> None:
+ """Multi-hop test: query starts far from target, iteration should bring it closer.
+
+ Setup: memory has two clusters. Query is near cluster A but the "answer"
+ is in cluster B, reachable through an intermediate memory that bridges both.
+ After iteration, the query should drift toward cluster B.
+ """
+ torch.manual_seed(0)
+ d = 32
+
+ # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
+ # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
+ # Bridge memory 10: between A and B
+ center_a = torch.zeros(d)
+ center_a[0] = 1.0
+ center_b = torch.zeros(d)
+ center_b[1] = 1.0
+ bridge = F.normalize(center_a + center_b, dim=0)
+
+ memories = []
+ for _ in range(5):
+ memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
+ for _ in range(5):
+ memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
+ memories.append(bridge)
+
+ M = torch.stack(memories, dim=1) # (d, 11)
+ q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)
+
+ config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(q0, M, return_trajectory=True)
+
+ # After iteration, query should have drifted: its dot product with center_b
+ # should be higher than the initial query's dot product with center_b
+ initial_sim_b = (q0.squeeze() @ center_b).item()
+ final_sim_b = (result.converged_query.squeeze() @ center_b).item()
+ assert final_sim_b > initial_sim_b, (
+ f"Iteration should pull query toward cluster B: "
+ f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
+ )