"""Core Modern Continuous Hopfield Network retrieval module.
Implements the iterative retrieval dynamics from:
Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
Update rule: q_{t+1} = M * softmax(beta * M^T * q_t)
Energy: E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
"""
import logging
from typing import Optional
import torch
from hag.config import HopfieldConfig
from hag.datatypes import HopfieldResult
logger = logging.getLogger(__name__)
class HopfieldRetrieval:
    """Modern Continuous Hopfield Network for memory retrieval.

    Given memory bank M in R^{d x N} and query q in R^d:
        1. Compute attention: alpha = softmax(beta * M^T @ q)
        2. Update query:      q_new = M @ alpha
        3. Repeat until convergence or max_iter

    The energy function is:
        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2

    Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease).
    """

    def __init__(self, config: HopfieldConfig) -> None:
        self.config = config
    @torch.no_grad()
    def retrieve(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
        return_trajectory: bool = False,
        return_energy: bool = False,
    ) -> HopfieldResult:
        """Run iterative Hopfield retrieval.

        Args:
            query: (d,) or (batch, d) — query embedding(s). A 1-D query is
                promoted to shape (1, d), so all outputs are batched.
            memory: (d, N) — memory bank of passage embeddings.
            return_trajectory: if True, store q_t at each step.
            return_energy: if True, store E(q_t) at each step.

        Returns:
            HopfieldResult with attention_weights, converged_query, num_steps,
            and optionally trajectory and energy_curve.
        """
        # Ensure query is 2D: (batch, d)
        if query.dim() == 1:
            query = query.unsqueeze(0)  # (1, d)
        q = query.clone()  # (batch, d)

        trajectory: Optional[list] = [q.clone()] if return_trajectory else None
        energies: Optional[list] = (
            [self.compute_energy(q, memory)] if return_energy else None
        )

        num_steps = 0
        for t in range(self.config.max_iter):
            # Core Hopfield update: q_new = M @ softmax(beta * M^T @ q)
            logits = self.config.beta * (q @ memory)  # (batch, N)
            alpha = torch.softmax(logits, dim=-1)  # (batch, N)
            q_new = alpha @ memory.T  # (batch, d)

            # Check convergence before overwriting q
            delta = torch.norm(q_new - q, dim=-1).max()  # scalar
            q = q_new

            if return_trajectory:
                trajectory.append(q.clone())
            if return_energy:
                energies.append(self.compute_energy(q, memory))

            num_steps = t + 1
            if delta < self.config.conv_threshold:
                break

        # Final attention weights (recompute to ensure consistency with q)
        logits = self.config.beta * (q @ memory)  # (batch, N)
        alpha = torch.softmax(logits, dim=-1)  # (batch, N)

        return HopfieldResult(
            attention_weights=alpha,
            converged_query=q,
            num_steps=num_steps,
            trajectory=trajectory,
            energy_curve=energies,
        )
    def compute_energy(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the Hopfield energy function.

        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2

        Args:
            query: (batch, d) or (d,) — query embedding(s).
            memory: (d, N) — memory bank.

        Returns:
            Energy tensor of shape (batch,); a 1-D query yields shape (1,).
        """
        if query.dim() == 1:
            query = query.unsqueeze(0)  # (1, d)

        logits = self.config.beta * (query @ memory)  # (batch, N)
        lse = torch.logsumexp(logits, dim=-1)  # (batch,)
        norm_sq = 0.5 * (query**2).sum(dim=-1)  # (batch,)
        energy = -1.0 / self.config.beta * lse + norm_sq  # (batch,)
        return energy
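

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This block is not part of the retrieval
# logic; it shows how the class might be exercised on random embeddings.
# The HopfieldConfig keyword arguments (beta, max_iter, conv_threshold) are
# assumed to match the attributes read above, and HopfieldResult is assumed
# to expose its fields as attributes; adjust to the actual definitions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    d, n = 64, 128

    # Random unit-norm memory bank (d, N) and a single query (d,)
    memory = torch.nn.functional.normalize(torch.randn(d, n), dim=0)
    query = torch.nn.functional.normalize(torch.randn(d), dim=0)

    config = HopfieldConfig(beta=8.0, max_iter=10, conv_threshold=1e-4)  # assumed signature
    retriever = HopfieldRetrieval(config)
    result = retriever.retrieve(query, memory, return_energy=True)

    energies = [float(e) for e in result.energy_curve]
    print(f"converged in {result.num_steps} step(s)")
    print(f"top-3 attention weights: {result.attention_weights.topk(3).values}")
    print(f"energy curve: {[round(e, 4) for e in energies]}")

    # Sanity check for the documented property E(q_{t+1}) <= E(q_t)
    assert all(b <= a + 1e-5 for a, b in zip(energies, energies[1:]))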