summaryrefslogtreecommitdiff
path: root/tests/test_energy.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /tests/test_energy.py
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'tests/test_energy.py')
-rw-r--r--tests/test_energy.py80
1 files changed, 80 insertions, 0 deletions
diff --git a/tests/test_energy.py b/tests/test_energy.py
new file mode 100644
index 0000000..256f2d8
--- /dev/null
+++ b/tests/test_energy.py
@@ -0,0 +1,80 @@
+"""Unit tests for the energy computation module."""
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.energy import (
+ compute_attention_entropy,
+ compute_energy_curve,
+ compute_energy_gap,
+ verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+
+
+class TestEnergy:
+ """Tests for energy computation and analysis utilities."""
+
+ def test_energy_computation_matches_formula(self) -> None:
+ """Manually compute E(q) and verify it matches the module output."""
+ torch.manual_seed(42)
+ d, N = 16, 10
+ memory = torch.randn(d, N)
+ query = torch.randn(1, d)
+ beta = 2.0
+
+ config = HopfieldConfig(beta=beta)
+ hopfield = HopfieldRetrieval(config)
+ computed = hopfield.compute_energy(query, memory)
+
+ # Manual computation:
+ # E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+ logits = beta * (query @ memory) # (1, N)
+ expected = (
+ -1.0 / beta * torch.logsumexp(logits, dim=-1)
+ + 0.5 * (query**2).sum(dim=-1)
+ )
+ assert torch.allclose(computed, expected, atol=1e-5)
+
+ def test_verify_monotonic(self) -> None:
+ """Monotonically decreasing energy curve should pass."""
+ energies = [10.0, 8.0, 6.5, 6.4, 6.39]
+ assert verify_monotonic_decrease(energies) is True
+
+ def test_verify_monotonic_fails(self) -> None:
+ """Non-monotonic energy curve should fail."""
+ energies = [10.0, 8.0, 9.0, 6.0]
+ assert verify_monotonic_decrease(energies) is False
+
+ def test_attention_entropy(self) -> None:
+ """Uniform attention should have max entropy, one-hot should have zero."""
+ uniform = torch.ones(100) / 100
+ one_hot = torch.zeros(100)
+ one_hot[0] = 1.0
+ assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot)
+ assert compute_attention_entropy(one_hot) < 0.01
+
+ def test_energy_curve_extraction(self) -> None:
+ """Energy curve should be extractable from a HopfieldResult."""
+ torch.manual_seed(42)
+ d, N = 32, 50
+ memory = torch.randn(d, N)
+ query = torch.randn(1, d)
+ config = HopfieldConfig(beta=1.0, max_iter=5)
+ hopfield = HopfieldRetrieval(config)
+
+ result = hopfield.retrieve(query, memory, return_energy=True)
+ curve = compute_energy_curve(result)
+ assert len(curve) > 1
+ assert all(isinstance(v, float) for v in curve)
+
+ def test_energy_gap(self) -> None:
+ """Energy gap should be positive when energy decreases."""
+ energies = [10.0, 8.0, 6.0]
+ gap = compute_energy_gap(energies)
+ assert gap == 4.0
+
+ def test_energy_gap_single_point(self) -> None:
+ """Energy gap with fewer than 2 points should be 0."""
+ assert compute_energy_gap([5.0]) == 0.0
+ assert compute_energy_gap([]) == 0.0