"""Core Modern Continuous Hopfield Network retrieval module.

Implements the iterative retrieval dynamics from:
    Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)

Update rule: q_{t+1} = M * softmax(beta * M^T * q_t)
Energy:      E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
"""

import logging
from typing import Optional

import torch

from hag.config import HopfieldConfig
from hag.datatypes import HopfieldResult

logger = logging.getLogger(__name__)


class HopfieldRetrieval:
    """Modern Continuous Hopfield Network for memory retrieval.

    Given memory bank M in R^{d x N} and query q in R^d:
        1. Compute attention: alpha = softmax(beta * M^T @ q)
        2. Update query:      q_new = M @ alpha
        3. Repeat until convergence or max_iter

    The energy function is:
        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2

    Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)

    The config object is read for three attributes: ``beta``, ``max_iter``
    and ``conv_threshold``.
    """

    def __init__(self, config: HopfieldConfig) -> None:
        """Store the configuration; nothing is validated or precomputed."""
        self.config = config

    def _attention(self, query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Return softmax(beta * query @ memory) over the memory axis.

        Args:
            query: (batch, d) query embeddings.
            memory: (d, N) memory bank.

        Returns:
            (batch, N) attention weights; each row sums to 1.
        """
        logits = self.config.beta * (query @ memory)  # (batch, N)
        return torch.softmax(logits, dim=-1)

    @torch.no_grad()
    def retrieve(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
        return_trajectory: bool = False,
        return_energy: bool = False,
    ) -> HopfieldResult:
        """Run iterative Hopfield retrieval.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.

        Args:
            query: (d,) or (batch, d) — query embedding(s). A 1-D query is
                promoted to a batch of one, so every returned tensor keeps a
                leading batch dimension of size 1 in that case.
            memory: (d, N) — memory bank of passage embeddings
            return_trajectory: if True, store q_t at each step (including the
                initial query q_0)
            return_energy: if True, store E(q_t) at each step (including the
                energy of the initial query)

        Returns:
            HopfieldResult with attention_weights (batch, N) computed from
            the final query, converged_query (batch, d), num_steps (number of
            updates actually applied), and optionally trajectory and
            energy_curve (both None when not requested).
        """
        # Ensure query is 2D: (batch, d)
        if query.dim() == 1:
            query = query.unsqueeze(0)  # (1, d)

        # Work on a copy so the caller's tensor is never aliased in the result.
        q = query.clone()  # (batch, d)

        trajectory = [q.clone()] if return_trajectory else None
        energies = [self.compute_energy(q, memory)] if return_energy else None

        num_steps = 0
        for t in range(self.config.max_iter):
            # Core Hopfield update: q <- softmax(beta * q M) M^T
            q_new = self._attention(q, memory) @ memory.T  # (batch, d)

            # Per-query size of the update, used for the convergence test.
            step_norms = torch.norm(q_new - q, dim=-1)  # (batch,)
            q = q_new
            num_steps = t + 1

            if return_trajectory:
                trajectory.append(q.clone())
            if return_energy:
                energies.append(self.compute_energy(q, memory))

            # Stop once *every* query in the batch moved less than the
            # threshold. `.all()` gives the same decision as comparing the
            # batch maximum against the threshold, but is also defined for an
            # empty batch, where `Tensor.max()` would raise.
            if bool((step_norms < self.config.conv_threshold).all()):
                break

        # Final attention weights, recomputed from the final query so that
        # they are consistent with `converged_query`.
        alpha = self._attention(q, memory)  # (batch, N)

        return HopfieldResult(
            attention_weights=alpha,
            converged_query=q,
            num_steps=num_steps,
            trajectory=trajectory,
            energy_curve=energies,
        )

    def compute_energy(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the Hopfield energy function.

        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2

        ``beta`` must be non-zero because the expression divides by it.

        Args:
            query: (batch, d) or (d,) — query embedding(s)
            memory: (d, N) — memory bank

        Returns:
            (batch,) tensor of energies. A 1-D query is treated as a batch of
            one, so the result then has shape (1,) rather than being a 0-dim
            scalar.
        """
        if query.dim() == 1:
            query = query.unsqueeze(0)  # (1, d)

        logits = self.config.beta * (query @ memory)  # (batch, N)
        # logsumexp evaluates log(sum(exp(.))) in a numerically stable way.
        lse = torch.logsumexp(logits, dim=-1)  # (batch,)
        norm_sq = 0.5 * (query**2).sum(dim=-1)  # (batch,)
        energy = -1.0 / self.config.beta * lse + norm_sq  # (batch,)
        return energy