diff --git a/hag/hopfield.py b/hag/hopfield.py
new file mode 100644
index 0000000..287e4af
--- /dev/null
+++ b/hag/hopfield.py
@@ -0,0 +1,124 @@
+"""Core Modern Continuous Hopfield Network retrieval module.
+
+Implements the iterative retrieval dynamics from:
+ Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
+
+Update rule: q_{t+1} = M * softmax(beta * M^T * q_t)
+Energy: E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+(the paper's additive constants, 1/beta * log N and the max-pattern-norm
+term, are dropped here since they do not depend on q)
+"""
+
+import logging
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetrieval:
+ """Modern Continuous Hopfield Network for memory retrieval.
+
+ Given memory bank M in R^{d x N} and query q in R^d:
+ 1. Compute attention: alpha = softmax(beta * M^T @ q)
+ 2. Update query: q_new = M @ alpha
+ 3. Repeat until convergence or max_iter
+
+ The energy function is:
+ E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+ Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
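+
+    Example (a minimal usage sketch; HopfieldConfig is assumed to accept
+    beta, max_iter, and conv_threshold as keyword arguments, matching how
+    this module reads them):
+
+        config = HopfieldConfig(beta=8.0, max_iter=10, conv_threshold=1e-4)
+        hopfield = HopfieldRetrieval(config)
+        memory = torch.randn(64, 100)   # (d, N) bank of stored patterns
+        query = torch.randn(64)         # (d,) query embedding
+        result = hopfield.retrieve(query, memory)
+        best = result.attention_weights.argmax(dim=-1)  # most-attended slot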
+ """
+
+ def __init__(self, config: HopfieldConfig) -> None:
+ self.config = config
+
+ @torch.no_grad()
+ def retrieve(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ return_trajectory: bool = False,
+ return_energy: bool = False,
+ ) -> HopfieldResult:
+ """Run iterative Hopfield retrieval.
+
+ Args:
+ query: (d,) or (batch, d) — query embedding(s)
+ memory: (d, N) — memory bank of passage embeddings
+ return_trajectory: if True, store q_t at each step
+ return_energy: if True, store E(q_t) at each step
+
+ Returns:
+            HopfieldResult with attention_weights (batch, N), converged_query
+            (batch, d), num_steps, and optionally trajectory and energy_curve.
+            A 1-D query is promoted to a batch of size 1, so outputs are
+            always batched.
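+
+        Example (a sketch; `hopfield` names a HopfieldRetrieval instance):
+
+            queries = torch.randn(4, 64)    # (batch, d)
+            memory = torch.randn(64, 100)   # (d, N)
+            result = hopfield.retrieve(queries, memory,
+                                       return_trajectory=True)
+            # trajectory holds q_0..q_T as (batch, d) tensors, so
+            # len(result.trajectory) == result.num_steps + 1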
+ """
+ # Ensure query is 2D: (batch, d)
+ if query.dim() == 1:
+ query = query.unsqueeze(0) # (1, d)
+
+ q = query.clone() # (batch, d)
+
+ trajectory = [q.clone()] if return_trajectory else None
+ energies = [self.compute_energy(q, memory)] if return_energy else None
+
+ num_steps = 0
+ for t in range(self.config.max_iter):
+            # Core Hopfield update: q <- M @ softmax(beta * M^T @ q)
+ logits = self.config.beta * (q @ memory) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+ q_new = alpha @ memory.T # (batch, d)
+
+ # Check convergence
+ delta = torch.norm(q_new - q, dim=-1).max() # scalar
+ q = q_new
+
+ if return_trajectory:
+ trajectory.append(q.clone())
+ if return_energy:
+ energies.append(self.compute_energy(q, memory))
+
+ num_steps = t + 1
+
+ if delta < self.config.conv_threshold:
+ break
+
+        # Final attention weights, recomputed from the converged q so that
+        # alpha and converged_query are mutually consistent
+ logits = self.config.beta * (q @ memory) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+
+ return HopfieldResult(
+ attention_weights=alpha,
+ converged_query=q,
+ num_steps=num_steps,
+ trajectory=trajectory,
+ energy_curve=energies,
+ )
+
+ def compute_energy(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute the Hopfield energy function.
+
+ E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+ Args:
+ query: (batch, d) or (d,) — query embedding(s)
+ memory: (d, N) — memory bank
+
+ Returns:
+            (batch,) tensor of energies; a 1-D query is treated as a
+            batch of size 1.
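+
+        Example (a sketch that checks the documented monotonic decrease
+        against the curve recorded by retrieve):
+
+            result = hopfield.retrieve(query, memory, return_energy=True)
+            curve = torch.stack(result.energy_curve)  # (num_steps + 1, batch)
+            assert (curve[1:] <= curve[:-1] + 1e-5).all()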
+ """
+ if query.dim() == 1:
+ query = query.unsqueeze(0) # (1, d)
+
+ logits = self.config.beta * (query @ memory) # (batch, N)
+ lse = torch.logsumexp(logits, dim=-1) # (batch,)
+ norm_sq = 0.5 * (query**2).sum(dim=-1) # (batch,)
+ energy = -1.0 / self.config.beta * lse + norm_sq # (batch,)
+
+ return energy
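+
+
+if __name__ == "__main__":
+    # Smoke-test sketch: assumes HopfieldConfig accepts beta, max_iter, and
+    # conv_threshold as keyword arguments; everything else uses only this
+    # module's own API. Stores random unit-norm patterns, queries with a
+    # noisy copy of one of them, and checks convergence plus the
+    # energy-decrease property documented above.
+    torch.manual_seed(0)
+    config = HopfieldConfig(beta=8.0, max_iter=50, conv_threshold=1e-6)
+    hopfield = HopfieldRetrieval(config)
+
+    d, n = 64, 100
+    memory = torch.nn.functional.normalize(torch.randn(d, n), dim=0)
+    query = memory[:, 0] + 0.1 * torch.randn(d)  # noisy copy of pattern 0
+
+    result = hopfield.retrieve(query, memory, return_energy=True)
+    curve = torch.stack(result.energy_curve).squeeze(-1)  # (num_steps + 1,)
+    assert (curve[1:] <= curve[:-1] + 1e-5).all(), "energy increased"
+    print(f"converged in {result.num_steps} steps; "
+          f"top match: pattern {result.attention_weights.argmax().item()}")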