summaryrefslogtreecommitdiff
path: root/notebooks/demo.ipynb
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /notebooks/demo.ipynb
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'notebooks/demo.ipynb')
-rw-r--r--notebooks/demo.ipynb62
1 files changed, 62 insertions, 0 deletions
diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb
new file mode 100644
index 0000000..caace5b
--- /dev/null
+++ b/notebooks/demo.ipynb
@@ -0,0 +1,62 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# HAG: Hopfield-Augmented Generation Demo\n",
+ "\n",
+ "This notebook demonstrates the core Hopfield retrieval mechanism with synthetic data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from hag.config import HopfieldConfig\n",
+ "from hag.hopfield import HopfieldRetrieval\n",
+ "from hag.energy import compute_energy_curve, verify_monotonic_decrease, compute_attention_entropy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create synthetic memory bank and query\n",
+ "torch.manual_seed(42)\n",
+ "d, N = 64, 200\n",
+ "memory = F.normalize(torch.randn(d, N), dim=0)\n",
+ "query = F.normalize(torch.randn(1, d), dim=-1)\n",
+ "\n",
+ "# Run Hopfield retrieval with different beta values\n",
+ "for beta in [0.5, 1.0, 2.0, 5.0]:\n",
+ " config = HopfieldConfig(beta=beta, max_iter=20, conv_threshold=1e-6)\n",
+ " hopfield = HopfieldRetrieval(config)\n",
+ " result = hopfield.retrieve(query, memory, return_energy=True)\n",
+ " curve = compute_energy_curve(result)\n",
+ " entropy = compute_attention_entropy(result.attention_weights)\n",
+ " print(f'beta={beta}: steps={result.num_steps}, entropy={entropy:.4f}, monotonic={verify_monotonic_decrease(curve)}')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}