Diffstat (limited to 'hag/hopfield.py')
| -rw-r--r-- | hag/hopfield.py | 124 |
1 file changed, 124 insertions, 0 deletions
diff --git a/hag/hopfield.py b/hag/hopfield.py
new file mode 100644
index 0000000..287e4af
--- /dev/null
+++ b/hag/hopfield.py
@@ -0,0 +1,124 @@
+"""Core Modern Continuous Hopfield Network retrieval module.
+
+Implements the iterative retrieval dynamics from:
+    Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
+
+Update rule: q_{t+1} = M * softmax(beta * M^T * q_t)
+Energy: E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+"""
+
+import logging
+from typing import Optional
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetrieval:
+    """Modern Continuous Hopfield Network for memory retrieval.
+
+    Given memory bank M in R^{d x N} and query q in R^d:
+      1. Compute attention: alpha = softmax(beta * M^T @ q)
+      2. Update query: q_new = M @ alpha
+      3. Repeat until convergence or max_iter
+
+    The energy function is:
+        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+    Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
+    """
+
+    def __init__(self, config: HopfieldConfig) -> None:
+        self.config = config
+
+    @torch.no_grad()
+    def retrieve(
+        self,
+        query: torch.Tensor,
+        memory: torch.Tensor,
+        return_trajectory: bool = False,
+        return_energy: bool = False,
+    ) -> HopfieldResult:
+        """Run iterative Hopfield retrieval.
+
+        Args:
+            query: (d,) or (batch, d) — query embedding(s)
+            memory: (d, N) — memory bank of passage embeddings
+            return_trajectory: if True, store q_t at each step
+            return_energy: if True, store E(q_t) at each step
+
+        Returns:
+            HopfieldResult with attention_weights, converged_query, num_steps,
+            and optionally trajectory and energy_curve.
+        """
+        # Ensure query is 2D: (batch, d)
+        if query.dim() == 1:
+            query = query.unsqueeze(0)  # (1, d)
+
+        q = query.clone()  # (batch, d)
+
+        trajectory = [q.clone()] if return_trajectory else None
+        energies = [self.compute_energy(q, memory)] if return_energy else None
+
+        num_steps = 0
+        for t in range(self.config.max_iter):
+            # Core Hopfield update
+            logits = self.config.beta * (q @ memory)  # (batch, N)
+            alpha = torch.softmax(logits, dim=-1)  # (batch, N)
+            q_new = alpha @ memory.T  # (batch, d)
+
+            # Check convergence
+            delta = torch.norm(q_new - q, dim=-1).max()  # scalar
+            q = q_new
+
+            if return_trajectory:
+                trajectory.append(q.clone())
+            if return_energy:
+                energies.append(self.compute_energy(q, memory))
+
+            num_steps = t + 1
+
+            if delta < self.config.conv_threshold:
+                break
+
+        # Final attention weights (recompute to ensure consistency)
+        logits = self.config.beta * (q @ memory)  # (batch, N)
+        alpha = torch.softmax(logits, dim=-1)  # (batch, N)
+
+        return HopfieldResult(
+            attention_weights=alpha,
+            converged_query=q,
+            num_steps=num_steps,
+            trajectory=trajectory,
+            energy_curve=energies,
+        )
+
+    def compute_energy(
+        self,
+        query: torch.Tensor,
+        memory: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute the Hopfield energy function.
+
+        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+        Args:
+            query: (batch, d) or (d,) — query embedding(s)
+            memory: (d, N) — memory bank
+
+        Returns:
+            Energy scalar or (batch,) tensor.
+        """
+        if query.dim() == 1:
+            query = query.unsqueeze(0)  # (1, d)
+
+        logits = self.config.beta * (query @ memory)  # (batch, N)
+        lse = torch.logsumexp(logits, dim=-1)  # (batch,)
+        norm_sq = 0.5 * (query**2).sum(dim=-1)  # (batch,)
+        energy = -1.0 / self.config.beta * lse + norm_sq  # (batch,)
+
+        return energy
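
To make the update rule in the module docstring concrete, here is a minimal sketch of a single retrieval step transcribed directly from q_{t+1} = M * softmax(beta * M^T * q_t), checked against the batched form used inside retrieve(). The dimensions and the value of beta are arbitrary illustration choices, not values taken from this repository.

# One Hopfield update step, written out from the docstring formula:
# a softmax-attention readout with the memory columns acting as both
# keys and values.
import torch

torch.manual_seed(0)
beta = 8.0                    # illustrative inverse temperature
d, N = 64, 512                # illustrative embedding dim / memory size
M = torch.randn(d, N)         # memory bank, columns are stored patterns
q = torch.randn(d)            # query vector

alpha = torch.softmax(beta * (M.T @ q), dim=-1)   # (N,) attention over patterns
q_next = M @ alpha                                # (d,) updated query

# The batched form used in HopfieldRetrieval.retrieve is equivalent:
alpha_b = torch.softmax(beta * (q.unsqueeze(0) @ M), dim=-1)  # (1, N)
q_next_b = alpha_b @ M.T                                      # (1, d)
assert torch.allclose(q_next, q_next_b.squeeze(0), atol=1e-5)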

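And a minimal end-to-end usage sketch of the class added in this diff. It assumes HopfieldConfig can be constructed with beta, max_iter, and conv_threshold keywords (those attribute names appear in retrieve() and compute_energy(), but the actual constructor lives in hag/config.py and is not shown here); the concrete values are illustrative only. The final assertion exercises the monotonic energy-decrease property noted in the class docstring.

# Hypothetical usage; HopfieldConfig's constructor signature is assumed.
import torch

from hag.config import HopfieldConfig
from hag.hopfield import HopfieldRetrieval

torch.manual_seed(0)
d, N = 64, 512
memory = torch.randn(d, N)
memory = memory / memory.norm(dim=0, keepdim=True)   # unit-norm stored patterns

# Query with a noisy copy of the first stored pattern.
query = memory[:, 0] + 0.1 * torch.randn(d)

config = HopfieldConfig(beta=8.0, max_iter=50, conv_threshold=1e-6)  # assumed fields
retrieval = HopfieldRetrieval(config)
result = retrieval.retrieve(query, memory, return_energy=True)

print(result.num_steps, result.attention_weights.argmax(dim=-1))  # expect index 0

# E(q_t) should be non-increasing across iterations.
energies = torch.stack(result.energy_curve).squeeze(-1)   # (num_steps + 1,)
assert (energies[1:] <= energies[:-1] + 1e-5).all()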