"""Energy computation and analysis utilities for Hopfield retrieval."""
import logging
from typing import List
import torch
from hag.datatypes import HopfieldResult
logger = logging.getLogger(__name__)
def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
"""Extract energy values at each iteration step.
Args:
hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True
Returns:
List of energy values (floats) at each step.
"""
if hopfield_result.energy_curve is None:
return []
return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve]
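
# Example (illustrative): a run whose per-step energies are the scalar tensors
# [-1.0, -2.5, -3.0] yields the floats [-1.0, -2.5, -3.0]; batched (B,)-shaped
# energies are averaged over the batch before conversion.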


def compute_energy_gap(energy_curve: List[float]) -> float:
    """Compute the energy gap: Delta_E = E(q_0) - E(q_T).

    A larger gap means more refinement happened during iteration.

    Args:
        energy_curve: list of energy values at each step

    Returns:
        Energy gap (float). Positive if energy decreased.
    """
    if len(energy_curve) < 2:
        return 0.0
    return energy_curve[0] - energy_curve[-1]
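
# Worked example (illustrative values): for the curve [-1.0, -2.5, -3.0],
# Delta_E = E(q_0) - E(q_T) = -1.0 - (-3.0) = 2.0, i.e. the energy dropped
# by 2.0 over the trajectory.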


def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
    """Check that E(q_{t+1}) <= E(q_t) for all t.

    This should always be True for the Modern Hopfield Network.

    Args:
        energy_curve: list of energy values at each step
        tol: numerical tolerance for comparison

    Returns:
        True if energy decreases monotonically (within tolerance).
    """
    for i in range(len(energy_curve) - 1):
        if energy_curve[i + 1] > energy_curve[i] + tol:
            return False
    return True
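
# Example: a tiny numerical rise is tolerated: [-1.0, -1.0 + 1e-9, -2.0]
# passes (the 1e-9 uptick is within tol), while [-1.0, -0.9, -2.0] fails.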


def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    """Compute the entropy of attention weights.

    H(alpha) = -sum_i alpha_i * log(alpha_i)

    Low entropy = sharp retrieval (confident).
    High entropy = diffuse retrieval (uncertain).

    Args:
        attention_weights: (N,) or (batch, N) attention distribution

    Returns:
        Entropy value (float). Averaged over batch if batched.
    """
    if attention_weights.dim() == 1:
        attention_weights = attention_weights.unsqueeze(0)  # (1, N)
    # Clamp to avoid log(0).
    eps = 1e-12
    alpha = attention_weights.clamp(min=eps)
    entropy = -(alpha * alpha.log()).sum(dim=-1)  # (batch,)
    return entropy.mean().item()
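

if __name__ == "__main__":
    # Minimal usage sketch on synthetic data; not part of the library API.
    # compute_energy_curve is skipped here because building a HopfieldResult
    # depends on hag.datatypes internals; a real curve would come from
    # HopfieldRetrieval.retrieve(..., return_energy=True).
    curve = [-1.0, -2.5, -3.0]  # a decreasing energy trajectory
    print("energy gap:", compute_energy_gap(curve))        # 2.0
    print("monotonic:", verify_monotonic_decrease(curve))  # True

    # Sharp vs. diffuse attention: entropy near 0 vs. near log(N).
    n = 16
    sharp = torch.zeros(n)
    sharp[0] = 1.0
    uniform = torch.full((n,), 1.0 / n)
    print("sharp entropy:", compute_attention_entropy(sharp))      # ~0.0
    print("uniform entropy:", compute_attention_entropy(uniform))  # ~2.77 (log 16)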