1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
|
"""Unit tests for the energy computation module."""
import torch
from hag.config import HopfieldConfig
from hag.energy import (
compute_attention_entropy,
compute_energy_curve,
compute_energy_gap,
verify_monotonic_decrease,
)
from hag.hopfield import HopfieldRetrieval
class TestEnergy:
    """Unit tests for the Hopfield energy and the helpers in ``hag.energy``."""

    def test_energy_computation_matches_formula(self) -> None:
        """Check ``HopfieldRetrieval.compute_energy`` against a hand-written E(q)."""
        torch.manual_seed(42)
        dim, num_patterns = 16, 10
        # Draw memory before the query so the seeded RNG yields the same tensors.
        memory = torch.randn(dim, num_patterns)
        query = torch.randn(1, dim)
        beta = 2.0

        retriever = HopfieldRetrieval(HopfieldConfig(beta=beta))
        actual = retriever.compute_energy(query, memory)

        # Reference value:
        #   E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
        scores = beta * (query @ memory)  # (1, num_patterns)
        lse_term = -1.0 / beta * torch.logsumexp(scores, dim=-1)
        norm_term = 0.5 * (query**2).sum(dim=-1)
        reference = lse_term + norm_term

        assert torch.allclose(actual, reference, atol=1e-5)

    def test_verify_monotonic(self) -> None:
        """A sequence that only ever goes down is accepted as monotonic."""
        assert verify_monotonic_decrease([10.0, 8.0, 6.5, 6.4, 6.39]) is True

    def test_verify_monotonic_fails(self) -> None:
        """A sequence that rises at any step is rejected."""
        # The 8.0 -> 9.0 step is an increase.
        curve = [10.0, 8.0, 9.0, 6.0]
        assert verify_monotonic_decrease(curve) is False

    def test_attention_entropy(self) -> None:
        """Uniform attention has higher entropy than one-hot, which is ~0."""
        spread = torch.ones(100) / 100
        peaked = torch.zeros(100)
        peaked[0] = 1.0

        spread_entropy = compute_attention_entropy(spread)
        peaked_entropy = compute_attention_entropy(peaked)

        assert spread_entropy > peaked_entropy
        assert peaked_entropy < 0.01

    def test_energy_curve_extraction(self) -> None:
        """``compute_energy_curve`` turns a HopfieldResult into a list of floats."""
        torch.manual_seed(42)
        dim, num_patterns = 32, 50
        memory = torch.randn(dim, num_patterns)
        query = torch.randn(1, dim)

        retriever = HopfieldRetrieval(HopfieldConfig(beta=1.0, max_iter=5))
        outcome = retriever.retrieve(query, memory, return_energy=True)
        energies = compute_energy_curve(outcome)

        assert len(energies) > 1
        assert all(isinstance(e, float) for e in energies)

    def test_energy_gap(self) -> None:
        """A decreasing curve 10 -> 8 -> 6 yields a positive gap of 4.0."""
        assert compute_energy_gap([10.0, 8.0, 6.0]) == 4.0

    def test_energy_gap_single_point(self) -> None:
        """With fewer than two energies there is no gap, so 0.0 is returned."""
        for too_short in ([5.0], []):
            assert compute_energy_gap(too_short) == 0.0
|