summaryrefslogtreecommitdiff
path: root/hag/energy.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/energy.py
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/energy.py')
-rw-r--r--hag/energy.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/hag/energy.py b/hag/energy.py
new file mode 100644
index 0000000..62a39e9
--- /dev/null
+++ b/hag/energy.py
@@ -0,0 +1,83 @@
+"""Energy computation and analysis utilities for Hopfield retrieval."""
+
+import logging
+from typing import List
+
+import torch
+
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
+ """Extract energy values at each iteration step.
+
+ Args:
+ hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True
+
+ Returns:
+ List of energy values (floats) at each step.
+ """
+ if hopfield_result.energy_curve is None:
+ return []
+ return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve]
+
+
+def compute_energy_gap(energy_curve: List[float]) -> float:
+ """Compute the energy gap: Delta_E = E(q_0) - E(q_T).
+
+ Larger gap means more refinement happened during iteration.
+
+ Args:
+ energy_curve: list of energy values at each step
+
+ Returns:
+ Energy gap (float). Positive if energy decreased.
+ """
+ if len(energy_curve) < 2:
+ return 0.0
+ return energy_curve[0] - energy_curve[-1]
+
+
+def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
+ """Check that E(q_{t+1}) <= E(q_t) for all t.
+
+ This should always be True for the Modern Hopfield Network.
+
+ Args:
+ energy_curve: list of energy values at each step
+ tol: numerical tolerance for comparison
+
+ Returns:
+ True if energy decreases monotonically (within tolerance).
+ """
+ for i in range(len(energy_curve) - 1):
+ if energy_curve[i + 1] > energy_curve[i] + tol:
+ return False
+ return True
+
+
+def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
+ """Compute the entropy of attention weights.
+
+ H(alpha) = -sum_i alpha_i * log(alpha_i)
+
+ Low entropy = sharp retrieval (confident).
+ High entropy = diffuse retrieval (uncertain).
+
+ Args:
+ attention_weights: (N,) or (batch, N) — attention distribution
+
+ Returns:
+ Entropy value (float). Averaged over batch if batched.
+ """
+ if attention_weights.dim() == 1:
+ attention_weights = attention_weights.unsqueeze(0) # (1, N)
+
+ # Clamp to avoid log(0)
+ eps = 1e-12
+ alpha = attention_weights.clamp(min=eps)
+ entropy = -(alpha * alpha.log()).sum(dim=-1) # (batch,)
+
+ return entropy.mean().item()