"""Energy computation and analysis utilities for Hopfield retrieval.""" import logging from typing import List import torch from hag.datatypes import HopfieldResult logger = logging.getLogger(__name__) def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]: """Extract energy values at each iteration step. Args: hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True Returns: List of energy values (floats) at each step. """ if hopfield_result.energy_curve is None: return [] return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve] def compute_energy_gap(energy_curve: List[float]) -> float: """Compute the energy gap: Delta_E = E(q_0) - E(q_T). Larger gap means more refinement happened during iteration. Args: energy_curve: list of energy values at each step Returns: Energy gap (float). Positive if energy decreased. """ if len(energy_curve) < 2: return 0.0 return energy_curve[0] - energy_curve[-1] def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool: """Check that E(q_{t+1}) <= E(q_t) for all t. This should always be True for the Modern Hopfield Network. Args: energy_curve: list of energy values at each step tol: numerical tolerance for comparison Returns: True if energy decreases monotonically (within tolerance). """ for i in range(len(energy_curve) - 1): if energy_curve[i + 1] > energy_curve[i] + tol: return False return True def compute_attention_entropy(attention_weights: torch.Tensor) -> float: """Compute the entropy of attention weights. H(alpha) = -sum_i alpha_i * log(alpha_i) Low entropy = sharp retrieval (confident). High entropy = diffuse retrieval (uncertain). Args: attention_weights: (N,) or (batch, N) — attention distribution Returns: Entropy value (float). Averaged over batch if batched. """ if attention_weights.dim() == 1: attention_weights = attention_weights.unsqueeze(0) # (1, N) # Clamp to avoid log(0) eps = 1e-12 alpha = attention_weights.clamp(min=eps) entropy = -(alpha * alpha.log()).sum(dim=-1) # (batch,) return entropy.mean().item()