Diffstat (limited to 'hag/energy.py')
| -rw-r--r-- | hag/energy.py | 83 |
1 file changed, 83 insertions, 0 deletions
diff --git a/hag/energy.py b/hag/energy.py
new file mode 100644
index 0000000..62a39e9
--- /dev/null
+++ b/hag/energy.py
@@ -0,0 +1,83 @@
+"""Energy computation and analysis utilities for Hopfield retrieval."""
+
+import logging
+from typing import List
+
+import torch
+
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
+    """Extract energy values at each iteration step.
+
+    Args:
+        hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True
+
+    Returns:
+        List of energy values (floats) at each step.
+    """
+    if hopfield_result.energy_curve is None:
+        return []
+    return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve]
+
+
+def compute_energy_gap(energy_curve: List[float]) -> float:
+    """Compute the energy gap: Delta_E = E(q_0) - E(q_T).
+
+    Larger gap means more refinement happened during iteration.
+
+    Args:
+        energy_curve: list of energy values at each step
+
+    Returns:
+        Energy gap (float). Positive if energy decreased.
+    """
+    if len(energy_curve) < 2:
+        return 0.0
+    return energy_curve[0] - energy_curve[-1]
+
+
+def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
+    """Check that E(q_{t+1}) <= E(q_t) for all t.
+
+    This should always be True for the Modern Hopfield Network.
+
+    Args:
+        energy_curve: list of energy values at each step
+        tol: numerical tolerance for comparison
+
+    Returns:
+        True if energy decreases monotonically (within tolerance).
+    """
+    for i in range(len(energy_curve) - 1):
+        if energy_curve[i + 1] > energy_curve[i] + tol:
+            return False
+    return True
+
+
+def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
+    """Compute the entropy of attention weights.
+
+    H(alpha) = -sum_i alpha_i * log(alpha_i)
+
+    Low entropy = sharp retrieval (confident).
+    High entropy = diffuse retrieval (uncertain).
+
+    Args:
+        attention_weights: (N,) or (batch, N) — attention distribution
+
+    Returns:
+        Entropy value (float). Averaged over batch if batched.
+    """
+    if attention_weights.dim() == 1:
+        attention_weights = attention_weights.unsqueeze(0)  # (1, N)
+
+    # Clamp to avoid log(0)
+    eps = 1e-12
+    alpha = attention_weights.clamp(min=eps)
+    entropy = -(alpha * alpha.log()).sum(dim=-1)  # (batch,)
+
+    return entropy.mean().item()
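
A minimal usage sketch of the energy-curve utilities above. `HopfieldResult`'s constructor is not shown in this diff, so a `SimpleNamespace` with an `energy_curve` attribute (a list of scalar tensors, which is what `compute_energy_curve` expects) stands in for a real retrieval result:

```python
# Sketch only: SimpleNamespace fakes a HopfieldResult, whose constructor
# is not part of this diff. energy_curve holds one scalar tensor per step.
from types import SimpleNamespace

import torch

from hag.energy import (
    compute_energy_curve,
    compute_energy_gap,
    verify_monotonic_decrease,
)

# A descending energy trajectory over four iterations.
result = SimpleNamespace(
    energy_curve=[torch.tensor(v) for v in (-1.0, -2.5, -3.1, -3.2)]
)

curve = compute_energy_curve(result)   # [-1.0, -2.5, -3.1, -3.2]
gap = compute_energy_gap(curve)        # 2.2: positive, so energy decreased
ok = verify_monotonic_decrease(curve)  # True for a descending curve
print(f"gap={gap:.2f}, monotonic={ok}")
```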
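
A quick check of `compute_attention_entropy` on two extreme distributions over N = 4 stored patterns: a one-hot (sharp) retrieval gives near-zero entropy, while a uniform (diffuse) one gives log(N) ≈ 1.386, the maximum:

```python
import torch

from hag.energy import compute_attention_entropy

sharp = torch.tensor([1.0, 0.0, 0.0, 0.0])  # confident: all mass on one pattern
diffuse = torch.full((4,), 0.25)            # uncertain: uniform over patterns

print(compute_attention_entropy(sharp))    # ~0.0 (eps clamp keeps it finite)
print(compute_attention_entropy(diffuse))  # ~1.3863 == log(4)
```

Since the entropy ranges over [0, log(N)], the raw value is only comparable across queries when N is fixed; dividing by log(N) would normalize it to [0, 1].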
