"""Unit tests for the core Hopfield retrieval module."""

import torch
import torch.nn.functional as F

from hag.config import HopfieldConfig
from hag.hopfield import HopfieldRetrieval


class TestHopfieldRetrieval:
    """Test the core Hopfield module with synthetic data on CPU."""

    def setup_method(self) -> None:
        """Create a small synthetic memory bank and a single query."""
        torch.manual_seed(42)
        self.d = 64  # embedding dim
        self.N = 100  # number of memories
        # Memories are stored column-wise and unit-normalised per column.
        self.memory = F.normalize(torch.randn(self.d, self.N), dim=0)  # (d, N)
        self.query = F.normalize(torch.randn(1, self.d), dim=-1)  # (1, d)
        self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
        self.hopfield = HopfieldRetrieval(self.config)

    @staticmethod
    def _entropy(weights: torch.Tensor) -> torch.Tensor:
        """Return the Shannon entropy (in nats) summed over all of *weights*.

        ``torch.special.entr`` evaluates ``-w * log(w)`` element-wise and
        defines the value at ``w == 0`` as 0, so a weight that underflows to
        exactly zero does not turn the whole sum into NaN the way a manual
        ``w * w.log()`` would.
        """
        return torch.special.entr(weights).sum()

    def test_output_shapes(self) -> None:
        """attention_weights should be (1, N), converged_query should be (1, d)."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert result.attention_weights.shape == (1, self.N)
        assert result.converged_query.shape == (1, self.d)

    def test_attention_weights_sum_to_one(self) -> None:
        """Softmax output must sum to 1."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert torch.allclose(
            result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5
        )

    def test_attention_weights_non_negative(self) -> None:
        """All attention weights must be >= 0."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert (result.attention_weights >= 0).all()

    def test_energy_monotonic_decrease(self) -> None:
        """E(q_{t+1}) <= E(q_t) for all t.

        This is THE key theoretical property.
        """
        result = self.hopfield.retrieve(
            self.query, self.memory, return_energy=True
        )
        energies = [e.item() for e in result.energy_curve]
        # Walk consecutive (E_t, E_{t+1}) pairs; 1e-6 absorbs float round-off.
        for i, (current, following) in enumerate(zip(energies, energies[1:])):
            assert following <= current + 1e-6, (
                f"Energy increased at step {i}: {current} -> {following}"
            )

    def test_convergence(self) -> None:
        """With enough iterations, query should converge (delta < threshold)."""
        config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(
            self.query, self.memory, return_trajectory=True
        )
        # Final two queries should be very close
        q_last = result.trajectory[-1]
        q_prev = result.trajectory[-2]
        delta = torch.norm(q_last - q_prev)
        assert delta < 1e-4, f"Did not converge: delta={delta}"

    def test_high_beta_sharp_retrieval(self) -> None:
        """Higher beta should produce sharper (lower entropy) attention."""
        low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
        high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))

        result_low = low_beta.retrieve(self.query, self.memory)
        result_high = high_beta.retrieve(self.query, self.memory)

        entropy_low = self._entropy(result_low.attention_weights)
        entropy_high = self._entropy(result_high.attention_weights)

        assert entropy_high < entropy_low, "Higher beta should give lower entropy"

    def test_single_memory_converges_to_it(self) -> None:
        """With N=1, all attention mass must land on the single memory."""
        single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
        result = self.hopfield.retrieve(self.query, single_memory)
        assert torch.allclose(
            result.attention_weights, torch.ones(1, 1), atol=1e-5
        )

    def test_query_near_memory_retrieves_it(self) -> None:
        """If query ~= memory_i, attention should peak at index i."""
        target_idx = 42
        query = self.memory[:, target_idx].unsqueeze(0)  # (1, d) — exact match
        config = HopfieldConfig(beta=10.0, max_iter=5)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(query, self.memory)
        top_idx = result.attention_weights.argmax(dim=-1).item()
        assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"

    def test_batch_retrieval(self) -> None:
        """Should handle batch of queries."""
        batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
        result = self.hopfield.retrieve(batch_query, self.memory)
        assert result.attention_weights.shape == (8, self.N)
        assert result.converged_query.shape == (8, self.d)

    def test_iteration_refines_query(self) -> None:
        """Multi-hop test: query starts far from target, iteration should bring it closer.

        Setup: memory has two clusters. Query is near cluster A but the "answer"
        is in cluster B, reachable through an intermediate memory that bridges
        both. After iteration, the query should drift toward cluster B.
        """
        torch.manual_seed(0)
        d = 32

        # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
        # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
        # Bridge memory 10: between A and B
        center_a = torch.zeros(d)
        center_a[0] = 1.0
        center_b = torch.zeros(d)
        center_b[1] = 1.0
        bridge = F.normalize(center_a + center_b, dim=0)

        # Draw order matters for reproducibility under the fixed seed:
        # five cluster-A samples, then five cluster-B samples, then the query.
        memories = [
            F.normalize(center_a + 0.1 * torch.randn(d), dim=0) for _ in range(5)
        ]
        memories += [
            F.normalize(center_b + 0.1 * torch.randn(d), dim=0) for _ in range(5)
        ]
        memories.append(bridge)

        M = torch.stack(memories, dim=1)  # (d, 11)
        q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)

        config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(q0, M, return_trajectory=True)

        # After iteration, query should have drifted: its dot product with center_b
        # should be higher than the initial query's dot product with center_b
        initial_sim_b = (q0.squeeze() @ center_b).item()
        final_sim_b = (result.converged_query.squeeze() @ center_b).item()
        assert final_sim_b > initial_sim_b, (
            f"Iteration should pull query toward cluster B: "
            f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
        )