1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
|
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# HAG: Hopfield-Augmented Generation Demo\n",
"\n",
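"This notebook demonstrates the core Hopfield retrieval mechanism on synthetic data. A random memory bank of unit-norm vectors is queried with a random unit-norm vector while `beta` is swept over several values. For each value, the notebook reports the number of retrieval steps, the entropy of the attention weights, and whether the energy curve decreases monotonically."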
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"from hag.config import HopfieldConfig\n",
"from hag.hopfield import HopfieldRetrieval\n",
"from hag.energy import compute_energy_curve, verify_monotonic_decrease, compute_attention_entropy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Synthetic setup: a random memory bank and a single query, both unit-normalized.\n",
"SEED = 42\n",
"d, N = 64, 200  # vector dimension, number of memory columns\n",
"BETAS = [0.5, 1.0, 2.0, 5.0]  # beta values to sweep\n",
"MAX_ITER = 20\n",
"CONV_THRESHOLD = 1e-6\n",
"\n",
"torch.manual_seed(SEED)\n",
"memory = F.normalize(torch.randn(d, N), dim=0)  # (d, N): every column has unit L2 norm\n",
"query = F.normalize(torch.randn(1, d), dim=-1)  # (1, d): a single unit-norm row\n",
"\n",
"# Run Hopfield retrieval with different beta values\n",
"for beta in BETAS:\n",
"    config = HopfieldConfig(beta=beta, max_iter=MAX_ITER, conv_threshold=CONV_THRESHOLD)\n",
"    hopfield = HopfieldRetrieval(config)\n",
"    result = hopfield.retrieve(query, memory, return_energy=True)\n",
"    curve = compute_energy_curve(result)\n",
"    entropy = compute_attention_entropy(result.attention_weights)\n",
"    # float() accepts a Python float, a 0-d tensor or a one-element tensor, whereas a\n",
"    # ':.4f' format spec raises TypeError on a shape-(1,) tensor. NOTE(review): the\n",
"    # return shape of compute_attention_entropy is not visible here - confirm.\n",
"    print(f'beta={beta}: steps={result.num_steps}, entropy={float(entropy):.4f}, monotonic={verify_monotonic_decrease(curve)}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|