summaryrefslogtreecommitdiff
path: root/tests/test_energy.py
blob: 256f2d8b93d72ab4e335af718d84b26fd46639bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Unit tests for the energy computation module."""

import torch

from hag.config import HopfieldConfig
from hag.energy import (
    compute_attention_entropy,
    compute_energy_curve,
    compute_energy_gap,
    verify_monotonic_decrease,
)
from hag.hopfield import HopfieldRetrieval


class TestEnergy:
    """Checks for the Hopfield energy function and the ``hag.energy`` helpers."""

    def test_energy_computation_matches_formula(self) -> None:
        """Compare ``compute_energy`` with a hand-written reference energy.

        Reference value:
            E(q) = -(1/beta) * logsumexp(beta * q @ M) + 0.5 * ||q||^2
        """
        torch.manual_seed(42)
        dim, num_patterns = 16, 10
        patterns = torch.randn(dim, num_patterns)
        probe = torch.randn(1, dim)
        inverse_temp = 2.0

        retrieval = HopfieldRetrieval(HopfieldConfig(beta=inverse_temp))
        actual = retrieval.compute_energy(probe, patterns)

        # Build the reference term by term; shapes follow from (1, dim) @ (dim, N).
        scores = inverse_temp * (probe @ patterns)  # (1, num_patterns)
        lse_term = -1.0 / inverse_temp * torch.logsumexp(scores, dim=-1)
        norm_term = 0.5 * probe.pow(2).sum(dim=-1)
        reference = lse_term + norm_term

        assert torch.allclose(actual, reference, atol=1e-5)

    def test_verify_monotonic(self) -> None:
        """A steadily decreasing energy trace is accepted as monotonic."""
        descending = [10.0, 8.0, 6.5, 6.4, 6.39]
        outcome = verify_monotonic_decrease(descending)
        assert outcome is True

    def test_verify_monotonic_fails(self) -> None:
        """A trace that rises at any step is rejected."""
        with_bump = [10.0, 8.0, 9.0, 6.0]  # goes back up from 8.0 to 9.0
        outcome = verify_monotonic_decrease(with_bump)
        assert outcome is False

    def test_attention_entropy(self) -> None:
        """A flat distribution scores higher entropy than a one-hot one (~0)."""
        size = 100
        spread = torch.ones(size) / size
        peaked = torch.zeros(size)
        peaked[0] = 1.0

        spread_entropy = compute_attention_entropy(spread)
        peaked_entropy = compute_attention_entropy(peaked)

        assert spread_entropy > peaked_entropy
        assert peaked_entropy < 0.01

    def test_energy_curve_extraction(self) -> None:
        """``compute_energy_curve`` yields several floats from a retrieval result."""
        torch.manual_seed(42)
        dim, num_patterns = 32, 50
        patterns = torch.randn(dim, num_patterns)
        probe = torch.randn(1, dim)
        retrieval = HopfieldRetrieval(HopfieldConfig(beta=1.0, max_iter=5))

        outcome = retrieval.retrieve(probe, patterns, return_energy=True)
        energies = compute_energy_curve(outcome)

        assert len(energies) > 1
        for value in energies:
            assert isinstance(value, float)

    def test_energy_gap(self) -> None:
        """The trace 10 -> 8 -> 6 has an energy gap of exactly 4.0."""
        trace = [10.0, 8.0, 6.0]
        assert compute_energy_gap(trace) == 4.0

    def test_energy_gap_single_point(self) -> None:
        """Traces with fewer than two entries report a gap of 0."""
        for too_short in ([5.0], []):
            assert compute_energy_gap(too_short) == 0.0