1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
|
"""Unit tests for the core Hopfield retrieval module."""
import torch
import torch.nn.functional as F
from hag.config import HopfieldConfig
from hag.hopfield import HopfieldRetrieval
class TestHopfieldRetrieval:
    """Test the core Hopfield module with synthetic data on CPU."""

    def setup_method(self) -> None:
        """Create a small synthetic memory bank, a query, and a default retriever.

        The RNG is re-seeded before each test so every test sees the same
        fixtures regardless of execution order.
        """
        torch.manual_seed(42)
        self.d = 64  # embedding dim
        self.N = 100  # number of memories
        # Memories are stored column-wise; each column is unit-normalized.
        self.memory = F.normalize(torch.randn(self.d, self.N), dim=0)  # (d, N)
        self.query = F.normalize(torch.randn(1, self.d), dim=-1)  # (1, d)
        self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
        self.hopfield = HopfieldRetrieval(self.config)

    @staticmethod
    def _entropy(weights: torch.Tensor) -> torch.Tensor:
        """Return the total Shannon entropy (in nats) of ``weights``.

        ``torch.special.entr`` defines 0 * log(0) as 0, so weights that
        underflow to exactly zero contribute nothing instead of producing
        NaN (a NaN would make any ``<`` comparison silently False).
        """
        return torch.special.entr(weights).sum()

    def test_output_shapes(self) -> None:
        """attention_weights should be (1, N), converged_query should be (1, d)."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert result.attention_weights.shape == (1, self.N)
        assert result.converged_query.shape == (1, self.d)

    def test_attention_weights_sum_to_one(self) -> None:
        """softmax output must sum to 1."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert torch.allclose(
            result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5
        )

    def test_attention_weights_non_negative(self) -> None:
        """All attention weights must be >= 0."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert (result.attention_weights >= 0).all()

    def test_energy_monotonic_decrease(self) -> None:
        """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
        result = self.hopfield.retrieve(
            self.query, self.memory, return_energy=True
        )
        energies = [e.item() for e in result.energy_curve]
        # Guard against a vacuous pass: an empty curve has no pairs to check.
        assert energies, "energy_curve is empty despite return_energy=True"
        for step, (e_curr, e_next) in enumerate(zip(energies, energies[1:])):
            # The small tolerance absorbs float rounding on a flat plateau.
            assert e_next <= e_curr + 1e-6, (
                f"Energy increased at step {step}: {e_curr} -> {e_next}"
            )

    def test_convergence(self) -> None:
        """With enough iterations, successive queries should stop moving."""
        config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(
            self.query, self.memory, return_trajectory=True
        )
        trajectory = result.trajectory
        # Fail with a clear message rather than an IndexError below.
        assert len(trajectory) >= 2, (
            f"Need at least two trajectory points, got {len(trajectory)}"
        )
        # Final two queries should be very close.
        delta = torch.norm(trajectory[-1] - trajectory[-2])
        assert delta < 1e-4, f"Did not converge: delta={delta}"

    def test_high_beta_sharp_retrieval(self) -> None:
        """Higher beta should produce sharper (lower entropy) attention."""
        low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
        high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))
        result_low = low_beta.retrieve(self.query, self.memory)
        result_high = high_beta.retrieve(self.query, self.memory)
        entropy_low = self._entropy(result_low.attention_weights)
        entropy_high = self._entropy(result_high.attention_weights)
        assert entropy_high < entropy_low, "Higher beta should give lower entropy"

    def test_single_memory_converges_to_it(self) -> None:
        """With N=1, retrieval should converge to the single memory."""
        single_memory = F.normalize(torch.randn(self.d, 1), dim=0)  # (d, 1)
        result = self.hopfield.retrieve(self.query, single_memory)
        assert torch.allclose(
            result.attention_weights, torch.ones(1, 1), atol=1e-5
        )
        # A softmax over a single logit is always 1, so the check above is
        # trivially true; comparing the converged query against the memory
        # is what actually tests the docstring's claim.
        assert torch.allclose(
            result.converged_query, single_memory.T, atol=1e-5
        ), "Converged query should equal the single stored memory"

    def test_query_near_memory_retrieves_it(self) -> None:
        """If query ~= memory_i, attention should peak at index i."""
        target_idx = 42
        query = self.memory[:, target_idx].unsqueeze(0)  # (1, d) — exact match
        config = HopfieldConfig(beta=10.0, max_iter=5)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(query, self.memory)
        top_idx = result.attention_weights.argmax(dim=-1).item()
        assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"

    def test_batch_retrieval(self) -> None:
        """Should handle a batch of queries."""
        batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
        result = self.hopfield.retrieve(batch_query, self.memory)
        assert result.attention_weights.shape == (8, self.N)
        assert result.converged_query.shape == (8, self.d)

    def test_iteration_refines_query(self) -> None:
        """Iteration should pull a query from cluster A toward cluster B.

        Multi-hop setup: the memory has two clusters. The query starts near
        cluster A, but the "answer" is in cluster B, reachable through a bridge
        memory that sits between both. After iteration, the query should drift
        toward cluster B.
        """
        torch.manual_seed(0)
        d = 32
        # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
        # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
        # Bridge memory 10: between A and B
        center_a = torch.zeros(d)
        center_a[0] = 1.0
        center_b = torch.zeros(d)
        center_b[1] = 1.0
        bridge = F.normalize(center_a + center_b, dim=0)
        memories = []
        for _ in range(5):
            memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
        for _ in range(5):
            memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
        memories.append(bridge)
        M = torch.stack(memories, dim=1)  # (d, 11)
        q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)
        config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(q0, M, return_trajectory=True)
        # After iteration the query should have drifted: its dot product with
        # center_b should exceed the initial query's dot product with center_b.
        initial_sim_b = (q0.squeeze() @ center_b).item()
        final_sim_b = (result.converged_query.squeeze() @ center_b).item()
        assert final_sim_b > initial_sim_b, (
            f"Iteration should pull query toward cluster B: "
            f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
        )
|