summaryrefslogtreecommitdiff
path: root/hag/energy.py
blob: 62a39e9d5370037b4503122f5fcf8bff647bfaca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Energy computation and analysis utilities for Hopfield retrieval."""

import logging
from typing import List

import torch

from hag.datatypes import HopfieldResult

logger = logging.getLogger(__name__)


def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
    """Convert the recorded per-iteration energies into plain Python floats.

    Args:
        hopfield_result: result from HopfieldRetrieval.retrieve() with
            return_energy=True. Only its ``energy_curve`` attribute is read.

    Returns:
        One float per recorded step, in recorded order. An empty list when
        no energy curve was recorded (``energy_curve`` is None).
    """
    recorded = hopfield_result.energy_curve
    if recorded is None:
        return []

    values: List[float] = []
    for step_energy in recorded:
        # 0-dim entries convert directly; entries with dimensions are
        # collapsed to their mean before conversion.
        scalar = step_energy if step_energy.dim() == 0 else step_energy.mean()
        values.append(scalar.item())
    return values


def compute_energy_gap(energy_curve: List[float]) -> float:
    """Return the drop in energy from the first step to the last.

    Delta_E = E(q_0) - E(q_T). A larger gap means the iteration refined
    the query more.

    Args:
        energy_curve: list of energy values at each step

    Returns:
        Energy gap (float), positive if energy decreased. 0.0 when the
        curve has fewer than two entries.
    """
    if len(energy_curve) >= 2:
        initial_energy, final_energy = energy_curve[0], energy_curve[-1]
        return initial_energy - final_energy
    # With zero or one sample there is no pair of endpoints to compare.
    return 0.0


def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
    """Check that E(q_{t+1}) <= E(q_t) + tol for all t.

    This should always be True for the Modern Hopfield Network.

    Args:
        energy_curve: list of energy values at each step
        tol: numerical tolerance for comparison

    Returns:
        True if energy decreases monotonically (within tolerance). Curves
        with fewer than two entries are trivially monotonic. A curve with
        a NaN in any compared pair returns False.
    """
    # Require the "<=" relation to hold positively rather than rejecting on
    # ">": every comparison involving NaN is False, so a NaN energy fails
    # the check here instead of slipping through as "not an increase".
    return all(
        following <= current + tol
        for current, following in zip(energy_curve, energy_curve[1:])
    )


def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    """Compute the entropy of attention weights.

    H(alpha) = -sum_i alpha_i * log(alpha_i), with 0 * log(0) taken as 0.

    Low entropy = sharp retrieval (confident).
    High entropy = diffuse retrieval (uncertain).

    Args:
        attention_weights: (N,) or (batch, N) — attention distribution

    Returns:
        Entropy value (float). Averaged over batch if batched.
    """
    if attention_weights.dim() == 1:
        attention_weights = attention_weights.unsqueeze(0)  # (1, N)

    # Negative entries carry no probability mass; floor them at zero so they
    # contribute nothing instead of producing NaN in the log term.
    alpha = attention_weights.clamp(min=0.0)

    # xlogy(0, 0) is defined as 0, so exact zeros contribute exactly nothing.
    # This avoids an epsilon clamp, which would add a small spurious term per
    # zero entry and can underflow to 0 in reduced-precision dtypes.
    entropy = -torch.xlogy(alpha, alpha).sum(dim=-1)  # (batch,)

    return entropy.mean().item()