diff options
Diffstat (limited to 'tests/test_energy.py')
| -rw-r--r-- | tests/test_energy.py | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/tests/test_energy.py b/tests/test_energy.py new file mode 100644 index 0000000..256f2d8 --- /dev/null +++ b/tests/test_energy.py @@ -0,0 +1,80 @@ +"""Unit tests for the energy computation module.""" + +import torch + +from hag.config import HopfieldConfig +from hag.energy import ( + compute_attention_entropy, + compute_energy_curve, + compute_energy_gap, + verify_monotonic_decrease, +) +from hag.hopfield import HopfieldRetrieval + + +class TestEnergy: + """Tests for energy computation and analysis utilities.""" + + def test_energy_computation_matches_formula(self) -> None: + """Manually compute E(q) and verify it matches the module output.""" + torch.manual_seed(42) + d, N = 16, 10 + memory = torch.randn(d, N) + query = torch.randn(1, d) + beta = 2.0 + + config = HopfieldConfig(beta=beta) + hopfield = HopfieldRetrieval(config) + computed = hopfield.compute_energy(query, memory) + + # Manual computation: + # E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2 + logits = beta * (query @ memory) # (1, N) + expected = ( + -1.0 / beta * torch.logsumexp(logits, dim=-1) + + 0.5 * (query**2).sum(dim=-1) + ) + assert torch.allclose(computed, expected, atol=1e-5) + + def test_verify_monotonic(self) -> None: + """Monotonically decreasing energy curve should pass.""" + energies = [10.0, 8.0, 6.5, 6.4, 6.39] + assert verify_monotonic_decrease(energies) is True + + def test_verify_monotonic_fails(self) -> None: + """Non-monotonic energy curve should fail.""" + energies = [10.0, 8.0, 9.0, 6.0] + assert verify_monotonic_decrease(energies) is False + + def test_attention_entropy(self) -> None: + """Uniform attention should have max entropy, one-hot should have zero.""" + uniform = torch.ones(100) / 100 + one_hot = torch.zeros(100) + one_hot[0] = 1.0 + assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot) + assert compute_attention_entropy(one_hot) < 0.01 + + def test_energy_curve_extraction(self) -> None: + """Energy curve should be extractable from a HopfieldResult.""" + torch.manual_seed(42) + d, N = 32, 50 + memory = torch.randn(d, N) + query = torch.randn(1, d) + config = HopfieldConfig(beta=1.0, max_iter=5) + hopfield = HopfieldRetrieval(config) + + result = hopfield.retrieve(query, memory, return_energy=True) + curve = compute_energy_curve(result) + assert len(curve) > 1 + assert all(isinstance(v, float) for v in curve) + + def test_energy_gap(self) -> None: + """Energy gap should be positive when energy decreases.""" + energies = [10.0, 8.0, 6.0] + gap = compute_energy_gap(energies) + assert gap == 4.0 + + def test_energy_gap_single_point(self) -> None: + """Energy gap with fewer than 2 points should be 0.""" + assert compute_energy_gap([5.0]) == 0.0 + assert compute_energy_gap([]) == 0.0 |
