"""Unit tests for the energy computation module."""

import torch

from hag.config import HopfieldConfig
from hag.energy import (
    compute_attention_entropy,
    compute_energy_curve,
    compute_energy_gap,
    verify_monotonic_decrease,
)
from hag.hopfield import HopfieldRetrieval


class TestEnergy:
    """Checks for the Hopfield energy function and its analysis helpers."""

    def test_energy_computation_matches_formula(self) -> None:
        """HopfieldRetrieval.compute_energy must agree with a hand-written E(q).

        E(q) = -1/beta * logsumexp(beta * q^T M) + 1/2 * ||q||^2
        """
        torch.manual_seed(42)
        dim, num_memories = 16, 10
        # Draw order (memory first, then query) is fixed so the seed
        # reproduces the same tensors on every run.
        memory = torch.randn(dim, num_memories)
        query = torch.randn(1, dim)
        beta = 2.0

        retriever = HopfieldRetrieval(HopfieldConfig(beta=beta))
        actual = retriever.compute_energy(query, memory)

        # Reference computation, one term at a time.
        scores = beta * (query @ memory)  # (1, num_memories)
        lse_term = torch.logsumexp(scores, dim=-1)
        norm_term = 0.5 * query.pow(2).sum(dim=-1)
        reference = -lse_term / beta + norm_term

        assert torch.allclose(actual, reference, atol=1e-5)

    def test_verify_monotonic(self) -> None:
        """A curve that never rises is reported as monotonic."""
        assert verify_monotonic_decrease([10.0, 8.0, 6.5, 6.4, 6.39]) is True

    def test_verify_monotonic_fails(self) -> None:
        """A curve that goes back up at any step is rejected."""
        assert verify_monotonic_decrease([10.0, 8.0, 9.0, 6.0]) is False

    def test_attention_entropy(self) -> None:
        """Spread-out attention scores higher entropy than fully peaked attention."""
        size = 100
        flat = torch.full((size,), 1.0 / size)
        peaked = torch.zeros(size)
        peaked[0] = 1.0

        flat_entropy = compute_attention_entropy(flat)
        peaked_entropy = compute_attention_entropy(peaked)

        assert flat_entropy > peaked_entropy
        # A one-hot distribution carries (essentially) no uncertainty.
        assert peaked_entropy < 0.01

    def test_energy_curve_extraction(self) -> None:
        """A retrieval run with energy tracking yields a list of float energies."""
        torch.manual_seed(42)
        dim, num_memories = 32, 50
        memory = torch.randn(dim, num_memories)
        query = torch.randn(1, dim)

        retriever = HopfieldRetrieval(HopfieldConfig(beta=1.0, max_iter=5))
        outcome = retriever.retrieve(query, memory, return_energy=True)
        energies = compute_energy_curve(outcome)

        assert len(energies) > 1
        for value in energies:
            assert isinstance(value, float)

    def test_energy_gap(self) -> None:
        """The gap is the total drop from the first to the last energy."""
        assert compute_energy_gap([10.0, 8.0, 6.0]) == 4.0

    def test_energy_gap_single_point(self) -> None:
        """Curves with fewer than two points have no measurable gap."""
        for too_short in ([5.0], []):
            assert compute_energy_gap(too_short) == 0.0