summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
committerYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
commit09d50e47860da0035e178a442dc936028808a0b3 (patch)
tree9d651b0c7d289a9a0405953f2da989a3c431f147
parentc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff)
Add memory centering, grid search experiments, and energy visualizationsHEADmaster
- Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove centroid attractor in Hopfield dynamics - Add center flag to MemoryBankConfig, device field to PipelineConfig - Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings) - Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs uncentered dynamics - Experiment log (note.md) documenting 4 rounds of results and root cause analysis of centroid attractor problem - Key finding: β_critical ≈ 37.6 for centered memory; best configs beat FAISS baseline by +3-4% F1 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
-rw-r--r--data/processed/grid_search_results.json449
-rw-r--r--data/processed/highbeta_grid_results.json828
-rw-r--r--data/processed/residual_grid_results.json1016
-rw-r--r--figures/fig1_contour.pngbin0 -> 1543206 bytes
-rw-r--r--figures/fig2_profile.pngbin0 -> 413833 bytes
-rw-r--r--figures/fig3_umap.pngbin0 -> 2132956 bytes
-rw-r--r--figures/fig4_pca.pngbin0 -> 1691214 bytes
-rw-r--r--hag/config.py4
-rw-r--r--hag/encoder.py11
-rw-r--r--hag/generator.py22
-rw-r--r--hag/memory_bank.py61
-rw-r--r--hag/pipeline.py3
-rw-r--r--note.md181
-rw-r--r--scripts/analyze_energy.py5
-rw-r--r--scripts/build_memory_bank.py4
-rw-r--r--scripts/diagnose_centering.py301
-rw-r--r--scripts/eval_centered_grid.py313
-rw-r--r--scripts/eval_highbeta_grid.py301
-rw-r--r--scripts/eval_residual_grid.py298
-rw-r--r--scripts/prepare_corpus.py127
-rw-r--r--scripts/run_baseline.py9
-rw-r--r--scripts/run_comparison.py186
-rw-r--r--scripts/run_eval.py8
-rw-r--r--scripts/run_grid_search.py552
-rw-r--r--scripts/run_hag.py8
-rw-r--r--scripts/visualize_energy.py443
-rw-r--r--scripts/visualize_trajectory.py11
27 files changed, 5108 insertions, 33 deletions
diff --git a/data/processed/grid_search_results.json b/data/processed/grid_search_results.json
new file mode 100644
index 0000000..cdbd3ca
--- /dev/null
+++ b/data/processed/grid_search_results.json
@@ -0,0 +1,449 @@
+{
+ "meta": {
+ "grid_size": 42,
+ "n_questions": 100,
+ "total_grid_evaluations": 4200,
+ "unique_llm_calls": 281,
+ "faiss_llm_calls": 100,
+ "total_llm_calls": 381,
+ "savings_pct": 91.1,
+ "retrieval_time_s": 0.93,
+ "generation_time_s": 735.92,
+ "total_time_s": 1049.0
+ },
+ "faiss_baseline": {
+ "em": 0.32,
+ "f1": 0.4380753968253968
+ },
+ "grid_results": [
+ {
+ "beta": 0.25,
+ "max_iter": 1,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1785,
+ "avg_energy_gap": 2.7302,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 1.0
+ },
+ {
+ "beta": 0.25,
+ "max_iter": 2,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1785,
+ "avg_energy_gap": 2.7302,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 2.0
+ },
+ {
+ "beta": 0.25,
+ "max_iter": 3,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1785,
+ "avg_energy_gap": 2.7302,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 0.25,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1785,
+ "avg_energy_gap": 2.7302,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 0.25,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1785,
+ "avg_energy_gap": 2.7302,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 0.25,
+ "max_iter": 15,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1785,
+ "avg_energy_gap": 2.7302,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 0.5,
+ "max_iter": 1,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1784,
+ "avg_energy_gap": 2.7292,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 1.0
+ },
+ {
+ "beta": 0.5,
+ "max_iter": 2,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1784,
+ "avg_energy_gap": 2.7292,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 2.0
+ },
+ {
+ "beta": 0.5,
+ "max_iter": 3,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1784,
+ "avg_energy_gap": 2.7292,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 0.5,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1784,
+ "avg_energy_gap": 2.7292,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 0.5,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1784,
+ "avg_energy_gap": 2.7292,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 0.5,
+ "max_iter": 15,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1784,
+ "avg_energy_gap": 2.7292,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 1.0,
+ "max_iter": 1,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1781,
+ "avg_energy_gap": 2.7273,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 1.0
+ },
+ {
+ "beta": 1.0,
+ "max_iter": 2,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1781,
+ "avg_energy_gap": 2.7273,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 2.0
+ },
+ {
+ "beta": 1.0,
+ "max_iter": 3,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1781,
+ "avg_energy_gap": 2.7273,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 1.0,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1781,
+ "avg_energy_gap": 2.7273,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 4.0
+ },
+ {
+ "beta": 1.0,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1781,
+ "avg_energy_gap": 2.7273,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 4.0
+ },
+ {
+ "beta": 1.0,
+ "max_iter": 15,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1781,
+ "avg_energy_gap": 2.7273,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 4.0
+ },
+ {
+ "beta": 2.0,
+ "max_iter": 1,
+ "em": 0.15,
+ "f1": 0.1977,
+ "avg_entropy": 7.1767,
+ "avg_energy_gap": 2.7234,
+ "avg_faiss_overlap": 0.004,
+ "avg_steps": 1.0
+ },
+ {
+ "beta": 2.0,
+ "max_iter": 2,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1767,
+ "avg_energy_gap": 2.7235,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 2.0
+ },
+ {
+ "beta": 2.0,
+ "max_iter": 3,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1767,
+ "avg_energy_gap": 2.7235,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 2.0,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1767,
+ "avg_energy_gap": 2.7235,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 4.0
+ },
+ {
+ "beta": 2.0,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1767,
+ "avg_energy_gap": 2.7235,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 4.0
+ },
+ {
+ "beta": 2.0,
+ "max_iter": 15,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1767,
+ "avg_energy_gap": 2.7235,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 4.0
+ },
+ {
+ "beta": 3.0,
+ "max_iter": 1,
+ "em": 0.14,
+ "f1": 0.2016,
+ "avg_entropy": 7.1742,
+ "avg_energy_gap": 2.7193,
+ "avg_faiss_overlap": 0.006,
+ "avg_steps": 1.0
+ },
+ {
+ "beta": 3.0,
+ "max_iter": 2,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1742,
+ "avg_energy_gap": 2.7196,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 2.0
+ },
+ {
+ "beta": 3.0,
+ "max_iter": 3,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1742,
+ "avg_energy_gap": 2.7196,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 3.0,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1742,
+ "avg_energy_gap": 2.7196,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 5.0
+ },
+ {
+ "beta": 3.0,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1742,
+ "avg_energy_gap": 2.7196,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 5.0
+ },
+ {
+ "beta": 3.0,
+ "max_iter": 15,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1742,
+ "avg_energy_gap": 2.7196,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 5.0
+ },
+ {
+ "beta": 5.0,
+ "max_iter": 1,
+ "em": 0.16,
+ "f1": 0.2207,
+ "avg_entropy": 7.1659,
+ "avg_energy_gap": 2.7105,
+ "avg_faiss_overlap": 0.018,
+ "avg_steps": 1.0
+ },
+ {
+ "beta": 5.0,
+ "max_iter": 2,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1661,
+ "avg_energy_gap": 2.7114,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 2.0
+ },
+ {
+ "beta": 5.0,
+ "max_iter": 3,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1661,
+ "avg_energy_gap": 2.7114,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 5.0,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1661,
+ "avg_energy_gap": 2.7114,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 5.0
+ },
+ {
+ "beta": 5.0,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1661,
+ "avg_energy_gap": 2.7114,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 5.0
+ },
+ {
+ "beta": 5.0,
+ "max_iter": 15,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1661,
+ "avg_energy_gap": 2.7114,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 5.0
+ },
+ {
+ "beta": 8.0,
+ "max_iter": 1,
+ "em": 0.21,
+ "f1": 0.2938,
+ "avg_entropy": 7.1438,
+ "avg_energy_gap": 2.6938,
+ "avg_faiss_overlap": 0.068,
+ "avg_steps": 1.0
+ },
+ {
+ "beta": 8.0,
+ "max_iter": 2,
+ "em": 0.12,
+ "f1": 0.1766,
+ "avg_entropy": 7.145,
+ "avg_energy_gap": 2.6972,
+ "avg_faiss_overlap": 0.004,
+ "avg_steps": 2.0
+ },
+ {
+ "beta": 8.0,
+ "max_iter": 3,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1449,
+ "avg_energy_gap": 2.6972,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 3.0
+ },
+ {
+ "beta": 8.0,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1448,
+ "avg_energy_gap": 2.6972,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 5.0
+ },
+ {
+ "beta": 8.0,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1448,
+ "avg_energy_gap": 2.6972,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 7.0
+ },
+ {
+ "beta": 8.0,
+ "max_iter": 15,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_entropy": 7.1448,
+ "avg_energy_gap": 2.6972,
+ "avg_faiss_overlap": 0.002,
+ "avg_steps": 7.0
+ }
+ ],
+ "best_config": {
+ "beta": 8.0,
+ "max_iter": 1,
+ "em": 0.21,
+ "f1": 0.2938,
+ "avg_entropy": 7.1438,
+ "avg_energy_gap": 2.6938,
+ "avg_faiss_overlap": 0.068
+ }
+} \ No newline at end of file
diff --git a/data/processed/highbeta_grid_results.json b/data/processed/highbeta_grid_results.json
new file mode 100644
index 0000000..a04895a
--- /dev/null
+++ b/data/processed/highbeta_grid_results.json
@@ -0,0 +1,828 @@
+{
+ "meta": {
+ "n_questions": 100,
+ "total_configs": 105,
+ "unique_llm_calls": 1379,
+ "total_time_s": 4571.0
+ },
+ "faiss_baseline": {
+ "em": 0.32,
+ "f1": 0.4381
+ },
+ "grid_results": [
+ {
+ "config": "\u03b2=20.0_iter=1_standard",
+ "em": 0.38,
+ "f1": 0.4691,
+ "avg_faiss_overlap": 0.48,
+ "avg_entropy": 4.5305
+ },
+ {
+ "config": "\u03b2=50.0_iter=1_standard",
+ "em": 0.36,
+ "f1": 0.4565,
+ "avg_faiss_overlap": 0.508,
+ "avg_entropy": 0.3196
+ },
+ {
+ "config": "\u03b2=20.0_iter=1_residual_0.9",
+ "em": 0.34,
+ "f1": 0.4552,
+ "avg_faiss_overlap": 0.966,
+ "avg_entropy": 3.5526
+ },
+ {
+ "config": "\u03b2=20.0_iter=2_residual_0.95",
+ "em": 0.34,
+ "f1": 0.4552,
+ "avg_faiss_overlap": 0.966,
+ "avg_entropy": 3.5503
+ },
+ {
+ "config": "\u03b2=500.0_iter=1_residual_0.9",
+ "em": 0.36,
+ "f1": 0.4545,
+ "avg_faiss_overlap": 0.692,
+ "avg_entropy": 0.0074
+ },
+ {
+ "config": "\u03b2=50.0_iter=1_normalized",
+ "em": 0.37,
+ "f1": 0.4539,
+ "avg_faiss_overlap": 0.464,
+ "avg_entropy": 1.7333
+ },
+ {
+ "config": "\u03b2=500.0_iter=1_residual_0.95",
+ "em": 0.36,
+ "f1": 0.4536,
+ "avg_faiss_overlap": 0.748,
+ "avg_entropy": 0.013
+ },
+ {
+ "config": "\u03b2=20.0_iter=1_residual_0.95",
+ "em": 0.34,
+ "f1": 0.4511,
+ "avg_faiss_overlap": 0.98,
+ "avg_entropy": 3.5011
+ },
+ {
+ "config": "\u03b2=50.0_iter=2_normalized",
+ "em": 0.37,
+ "f1": 0.4498,
+ "avg_faiss_overlap": 0.38,
+ "avg_entropy": 1.0954
+ },
+ {
+ "config": "\u03b2=20.0_iter=3_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4494,
+ "avg_faiss_overlap": 0.946,
+ "avg_entropy": 3.5994
+ },
+ {
+ "config": "\u03b2=50.0_iter=1_residual_0.95",
+ "em": 0.34,
+ "f1": 0.4486,
+ "avg_faiss_overlap": 0.974,
+ "avg_entropy": 0.677
+ },
+ {
+ "config": "\u03b2=100.0_iter=1_residual_0.95",
+ "em": 0.34,
+ "f1": 0.4464,
+ "avg_faiss_overlap": 0.97,
+ "avg_entropy": 0.2081
+ },
+ {
+ "config": "\u03b2=200.0_iter=1_residual_0.95",
+ "em": 0.34,
+ "f1": 0.4464,
+ "avg_faiss_overlap": 0.968,
+ "avg_entropy": 0.0722
+ },
+ {
+ "config": "\u03b2=100.0_iter=1_normalized",
+ "em": 0.35,
+ "f1": 0.446,
+ "avg_faiss_overlap": 0.508,
+ "avg_entropy": 0.1139
+ },
+ {
+ "config": "\u03b2=500.0_iter=2_residual_0.95",
+ "em": 0.35,
+ "f1": 0.4445,
+ "avg_faiss_overlap": 0.692,
+ "avg_entropy": 0.0037
+ },
+ {
+ "config": "\u03b2=50.0_iter=2_residual_0.9",
+ "em": 0.33,
+ "f1": 0.4441,
+ "avg_faiss_overlap": 0.886,
+ "avg_entropy": 0.4749
+ },
+ {
+ "config": "\u03b2=100.0_iter=2_residual_0.9",
+ "em": 0.33,
+ "f1": 0.4441,
+ "avg_faiss_overlap": 0.878,
+ "avg_entropy": 0.0623
+ },
+ {
+ "config": "\u03b2=50.0_iter=8_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4431,
+ "avg_faiss_overlap": 0.808,
+ "avg_entropy": 0.2083
+ },
+ {
+ "config": "\u03b2=100.0_iter=8_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4431,
+ "avg_faiss_overlap": 0.794,
+ "avg_entropy": 0.0019
+ },
+ {
+ "config": "\u03b2=100.0_iter=3_residual_0.95",
+ "em": 0.33,
+ "f1": 0.4427,
+ "avg_faiss_overlap": 0.9,
+ "avg_entropy": 0.0754
+ },
+ {
+ "config": "\u03b2=20.0_iter=8_residual_0.95",
+ "em": 0.33,
+ "f1": 0.442,
+ "avg_faiss_overlap": 0.878,
+ "avg_entropy": 3.8365
+ },
+ {
+ "config": "\u03b2=50.0_iter=1_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4419,
+ "avg_faiss_overlap": 0.946,
+ "avg_entropy": 0.6043
+ },
+ {
+ "config": "\u03b2=50.0_iter=2_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4419,
+ "avg_faiss_overlap": 0.944,
+ "avg_entropy": 0.5941
+ },
+ {
+ "config": "\u03b2=50.0_iter=5_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4407,
+ "avg_faiss_overlap": 0.87,
+ "avg_entropy": 0.3877
+ },
+ {
+ "config": "\u03b2=200.0_iter=1_normalized",
+ "em": 0.33,
+ "f1": 0.4404,
+ "avg_faiss_overlap": 0.484,
+ "avg_entropy": 0.0109
+ },
+ {
+ "config": "\u03b2=500.0_iter=5_residual_0.95",
+ "em": 0.35,
+ "f1": 0.4404,
+ "avg_faiss_overlap": 0.546,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=50.0_iter=8_residual_0.9",
+ "em": 0.34,
+ "f1": 0.4402,
+ "avg_faiss_overlap": 0.666,
+ "avg_entropy": 0.0306
+ },
+ {
+ "config": "\u03b2=100.0_iter=1_standard",
+ "em": 0.33,
+ "f1": 0.44,
+ "avg_faiss_overlap": 0.464,
+ "avg_entropy": 0.0196
+ },
+ {
+ "config": "\u03b2=100.0_iter=1_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4397,
+ "avg_faiss_overlap": 0.932,
+ "avg_entropy": 0.1471
+ },
+ {
+ "config": "\u03b2=100.0_iter=2_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4397,
+ "avg_faiss_overlap": 0.932,
+ "avg_entropy": 0.1268
+ },
+ {
+ "config": "\u03b2=100.0_iter=5_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4391,
+ "avg_faiss_overlap": 0.858,
+ "avg_entropy": 0.019
+ },
+ {
+ "config": "\u03b2=20.0_iter=0_standard",
+ "em": 0.32,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 1.0,
+ "avg_entropy": 3.452
+ },
+ {
+ "config": "\u03b2=50.0_iter=0_standard",
+ "em": 0.32,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 1.0,
+ "avg_entropy": 0.7723
+ },
+ {
+ "config": "\u03b2=50.0_iter=5_standard",
+ "em": 0.33,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.408,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=50.0_iter=8_standard",
+ "em": 0.33,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.408,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=50.0_iter=8_normalized",
+ "em": 0.34,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.358,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=0_standard",
+ "em": 0.32,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 1.0,
+ "avg_entropy": 0.3128
+ },
+ {
+ "config": "\u03b2=100.0_iter=2_normalized",
+ "em": 0.33,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.422,
+ "avg_entropy": 0.0143
+ },
+ {
+ "config": "\u03b2=100.0_iter=3_normalized",
+ "em": 0.33,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.412,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=5_normalized",
+ "em": 0.33,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.41,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=8_normalized",
+ "em": 0.33,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.41,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=0_standard",
+ "em": 0.32,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 1.0,
+ "avg_entropy": 0.1535
+ },
+ {
+ "config": "\u03b2=50.0_iter=3_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4377,
+ "avg_faiss_overlap": 0.912,
+ "avg_entropy": 0.5215
+ },
+ {
+ "config": "\u03b2=20.0_iter=5_residual_0.9",
+ "em": 0.32,
+ "f1": 0.437,
+ "avg_faiss_overlap": 0.846,
+ "avg_entropy": 3.9311
+ },
+ {
+ "config": "\u03b2=20.0_iter=2_residual_0.9",
+ "em": 0.31,
+ "f1": 0.4349,
+ "avg_faiss_overlap": 0.926,
+ "avg_entropy": 3.6521
+ },
+ {
+ "config": "\u03b2=200.0_iter=1_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4347,
+ "avg_faiss_overlap": 0.928,
+ "avg_entropy": 0.0466
+ },
+ {
+ "config": "\u03b2=200.0_iter=2_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4347,
+ "avg_faiss_overlap": 0.928,
+ "avg_entropy": 0.0274
+ },
+ {
+ "config": "\u03b2=50.0_iter=3_residual_0.9",
+ "em": 0.31,
+ "f1": 0.4344,
+ "avg_faiss_overlap": 0.842,
+ "avg_entropy": 0.3574
+ },
+ {
+ "config": "\u03b2=200.0_iter=2_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4341,
+ "avg_faiss_overlap": 0.876,
+ "avg_entropy": 0.008
+ },
+ {
+ "config": "\u03b2=200.0_iter=5_residual_0.95",
+ "em": 0.31,
+ "f1": 0.4341,
+ "avg_faiss_overlap": 0.852,
+ "avg_entropy": 0.0005
+ },
+ {
+ "config": "\u03b2=200.0_iter=1_standard",
+ "em": 0.33,
+ "f1": 0.4331,
+ "avg_faiss_overlap": 0.43,
+ "avg_entropy": 0.0061
+ },
+ {
+ "config": "\u03b2=200.0_iter=3_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4327,
+ "avg_faiss_overlap": 0.898,
+ "avg_entropy": 0.0082
+ },
+ {
+ "config": "\u03b2=20.0_iter=3_residual_0.9",
+ "em": 0.31,
+ "f1": 0.4313,
+ "avg_faiss_overlap": 0.9,
+ "avg_entropy": 3.7492
+ },
+ {
+ "config": "\u03b2=500.0_iter=3_residual_0.95",
+ "em": 0.35,
+ "f1": 0.4312,
+ "avg_faiss_overlap": 0.636,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=20.0_iter=8_residual_0.9",
+ "em": 0.31,
+ "f1": 0.4302,
+ "avg_faiss_overlap": 0.748,
+ "avg_entropy": 4.1579
+ },
+ {
+ "config": "\u03b2=20.0_iter=5_residual_0.95",
+ "em": 0.31,
+ "f1": 0.4299,
+ "avg_faiss_overlap": 0.91,
+ "avg_entropy": 3.6963
+ },
+ {
+ "config": "\u03b2=50.0_iter=3_normalized",
+ "em": 0.33,
+ "f1": 0.4291,
+ "avg_faiss_overlap": 0.368,
+ "avg_entropy": 0.6769
+ },
+ {
+ "config": "\u03b2=50.0_iter=5_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4291,
+ "avg_faiss_overlap": 0.778,
+ "avg_entropy": 0.1599
+ },
+ {
+ "config": "\u03b2=50.0_iter=3_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.41,
+ "avg_entropy": 0.0016
+ },
+ {
+ "config": "\u03b2=50.0_iter=5_normalized",
+ "em": 0.33,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.356,
+ "avg_entropy": 0.1417
+ },
+ {
+ "config": "\u03b2=100.0_iter=3_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=5_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=8_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=2_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.408,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=2_normalized",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.408,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=3_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=3_normalized",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=5_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=5_normalized",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=8_standard",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=8_normalized",
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.406,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=50.0_iter=2_standard",
+ "em": 0.32,
+ "f1": 0.4265,
+ "avg_faiss_overlap": 0.422,
+ "avg_entropy": 0.0481
+ },
+ {
+ "config": "\u03b2=500.0_iter=3_residual_0.9",
+ "em": 0.33,
+ "f1": 0.4265,
+ "avg_faiss_overlap": 0.51,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=3_residual_0.9",
+ "em": 0.29,
+ "f1": 0.4244,
+ "avg_faiss_overlap": 0.828,
+ "avg_entropy": 0.019
+ },
+ {
+ "config": "\u03b2=200.0_iter=3_residual_0.9",
+ "em": 0.29,
+ "f1": 0.4244,
+ "avg_faiss_overlap": 0.824,
+ "avg_entropy": 0.0021
+ },
+ {
+ "config": "\u03b2=500.0_iter=0_standard",
+ "em": 0.33,
+ "f1": 0.4236,
+ "avg_faiss_overlap": 0.798,
+ "avg_entropy": 0.0746
+ },
+ {
+ "config": "\u03b2=200.0_iter=8_residual_0.95",
+ "em": 0.3,
+ "f1": 0.4231,
+ "avg_faiss_overlap": 0.786,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=2_standard",
+ "em": 0.31,
+ "f1": 0.4181,
+ "avg_faiss_overlap": 0.41,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=8_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4152,
+ "avg_faiss_overlap": 0.636,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=200.0_iter=8_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4152,
+ "avg_faiss_overlap": 0.634,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=100.0_iter=5_residual_0.9",
+ "em": 0.3,
+ "f1": 0.4141,
+ "avg_faiss_overlap": 0.764,
+ "avg_entropy": 0.0014
+ },
+ {
+ "config": "\u03b2=500.0_iter=2_residual_0.9",
+ "em": 0.32,
+ "f1": 0.4098,
+ "avg_faiss_overlap": 0.584,
+ "avg_entropy": 0.0001
+ },
+ {
+ "config": "\u03b2=200.0_iter=5_residual_0.9",
+ "em": 0.29,
+ "f1": 0.4041,
+ "avg_faiss_overlap": 0.754,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=5_residual_0.9",
+ "em": 0.31,
+ "f1": 0.3949,
+ "avg_faiss_overlap": 0.34,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=8_residual_0.95",
+ "em": 0.29,
+ "f1": 0.3839,
+ "avg_faiss_overlap": 0.424,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=8_residual_0.9",
+ "em": 0.27,
+ "f1": 0.3743,
+ "avg_faiss_overlap": 0.228,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=20.0_iter=2_standard",
+ "em": 0.29,
+ "f1": 0.3713,
+ "avg_faiss_overlap": 0.248,
+ "avg_entropy": 4.6996
+ },
+ {
+ "config": "\u03b2=500.0_iter=1_normalized",
+ "em": 0.27,
+ "f1": 0.3693,
+ "avg_faiss_overlap": 0.236,
+ "avg_entropy": 0.0018
+ },
+ {
+ "config": "\u03b2=500.0_iter=1_standard",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.22,
+ "avg_entropy": 0.0003
+ },
+ {
+ "config": "\u03b2=500.0_iter=2_standard",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=2_normalized",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=3_standard",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=3_normalized",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=5_standard",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=5_normalized",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=8_standard",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=500.0_iter=8_normalized",
+ "em": 0.25,
+ "f1": 0.35,
+ "avg_faiss_overlap": 0.202,
+ "avg_entropy": 0.0
+ },
+ {
+ "config": "\u03b2=20.0_iter=1_normalized",
+ "em": 0.27,
+ "f1": 0.3367,
+ "avg_faiss_overlap": 0.14,
+ "avg_entropy": 6.5894
+ },
+ {
+ "config": "\u03b2=20.0_iter=3_standard",
+ "em": 0.27,
+ "f1": 0.332,
+ "avg_faiss_overlap": 0.184,
+ "avg_entropy": 4.6947
+ },
+ {
+ "config": "\u03b2=20.0_iter=5_standard",
+ "em": 0.24,
+ "f1": 0.3116,
+ "avg_faiss_overlap": 0.162,
+ "avg_entropy": 4.6875
+ },
+ {
+ "config": "\u03b2=20.0_iter=8_standard",
+ "em": 0.23,
+ "f1": 0.3016,
+ "avg_faiss_overlap": 0.16,
+ "avg_entropy": 4.684
+ },
+ {
+ "config": "\u03b2=20.0_iter=2_normalized",
+ "em": 0.18,
+ "f1": 0.2255,
+ "avg_faiss_overlap": 0.044,
+ "avg_entropy": 6.4837
+ },
+ {
+ "config": "\u03b2=20.0_iter=8_normalized",
+ "em": 0.14,
+ "f1": 0.2221,
+ "avg_faiss_overlap": 0.02,
+ "avg_entropy": 6.3573
+ },
+ {
+ "config": "\u03b2=20.0_iter=3_normalized",
+ "em": 0.14,
+ "f1": 0.2155,
+ "avg_faiss_overlap": 0.024,
+ "avg_entropy": 6.4174
+ },
+ {
+ "config": "\u03b2=20.0_iter=5_normalized",
+ "em": 0.13,
+ "f1": 0.1961,
+ "avg_faiss_overlap": 0.02,
+ "avg_entropy": 6.3707
+ }
+ ],
+ "best_config": {
+ "config": "\u03b2=20.0_iter=1_standard",
+ "em": 0.38,
+ "f1": 0.4691,
+ "avg_faiss_overlap": 0.48,
+ "avg_entropy": 4.5305
+ },
+ "top10": [
+ {
+ "config": "\u03b2=20.0_iter=1_standard",
+ "em": 0.38,
+ "f1": 0.4691,
+ "avg_faiss_overlap": 0.48,
+ "avg_entropy": 4.5305
+ },
+ {
+ "config": "\u03b2=50.0_iter=1_standard",
+ "em": 0.36,
+ "f1": 0.4565,
+ "avg_faiss_overlap": 0.508,
+ "avg_entropy": 0.3196
+ },
+ {
+ "config": "\u03b2=20.0_iter=1_residual_0.9",
+ "em": 0.34,
+ "f1": 0.4552,
+ "avg_faiss_overlap": 0.966,
+ "avg_entropy": 3.5526
+ },
+ {
+ "config": "\u03b2=20.0_iter=2_residual_0.95",
+ "em": 0.34,
+ "f1": 0.4552,
+ "avg_faiss_overlap": 0.966,
+ "avg_entropy": 3.5503
+ },
+ {
+ "config": "\u03b2=500.0_iter=1_residual_0.9",
+ "em": 0.36,
+ "f1": 0.4545,
+ "avg_faiss_overlap": 0.692,
+ "avg_entropy": 0.0074
+ },
+ {
+ "config": "\u03b2=50.0_iter=1_normalized",
+ "em": 0.37,
+ "f1": 0.4539,
+ "avg_faiss_overlap": 0.464,
+ "avg_entropy": 1.7333
+ },
+ {
+ "config": "\u03b2=500.0_iter=1_residual_0.95",
+ "em": 0.36,
+ "f1": 0.4536,
+ "avg_faiss_overlap": 0.748,
+ "avg_entropy": 0.013
+ },
+ {
+ "config": "\u03b2=20.0_iter=1_residual_0.95",
+ "em": 0.34,
+ "f1": 0.4511,
+ "avg_faiss_overlap": 0.98,
+ "avg_entropy": 3.5011
+ },
+ {
+ "config": "\u03b2=50.0_iter=2_normalized",
+ "em": 0.37,
+ "f1": 0.4498,
+ "avg_faiss_overlap": 0.38,
+ "avg_entropy": 1.0954
+ },
+ {
+ "config": "\u03b2=20.0_iter=3_residual_0.95",
+ "em": 0.32,
+ "f1": 0.4494,
+ "avg_faiss_overlap": 0.946,
+ "avg_entropy": 3.5994
+ }
+ ]
+} \ No newline at end of file
diff --git a/data/processed/residual_grid_results.json b/data/processed/residual_grid_results.json
new file mode 100644
index 0000000..c5b6d28
--- /dev/null
+++ b/data/processed/residual_grid_results.json
@@ -0,0 +1,1016 @@
+{
+ "meta": {
+ "n_questions": 100,
+ "total_configs": 100,
+ "unique_llm_calls": 1666,
+ "faiss_llm_calls": 100,
+ "total_time_s": 5113.4
+ },
+ "faiss_baseline": {
+ "em": 0.32,
+ "f1": 0.4381
+ },
+ "grid_results": [
+ {
+ "beta": 5.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.36,
+ "f1": 0.4809,
+ "avg_faiss_overlap": 0.902,
+ "avg_entropy": 7.1163
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.9,
+ "max_iter": 3,
+ "em": 0.36,
+ "f1": 0.4809,
+ "avg_faiss_overlap": 0.912,
+ "avg_entropy": 7.1122
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.36,
+ "f1": 0.4809,
+ "avg_faiss_overlap": 0.9,
+ "avg_entropy": 6.8422
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.95,
+ "max_iter": 8,
+ "em": 0.36,
+ "f1": 0.4797,
+ "avg_faiss_overlap": 0.886,
+ "avg_entropy": 6.8941
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.95,
+ "max_iter": 8,
+ "em": 0.35,
+ "f1": 0.4697,
+ "avg_faiss_overlap": 0.886,
+ "avg_entropy": 7.1219
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.95,
+ "max_iter": 3,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.956,
+ "avg_entropy": 7.09
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.95,
+ "max_iter": 5,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.928,
+ "avg_entropy": 7.105
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.9,
+ "max_iter": 1,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.966,
+ "avg_entropy": 6.6138
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.95,
+ "max_iter": 3,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.956,
+ "avg_entropy": 6.6763
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.9,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4672,
+ "avg_faiss_overlap": 0.97,
+ "avg_entropy": 7.0815
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.9,
+ "max_iter": 5,
+ "em": 0.34,
+ "f1": 0.4637,
+ "avg_faiss_overlap": 0.856,
+ "avg_entropy": 6.9499
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.5,
+ "max_iter": 1,
+ "em": 0.33,
+ "f1": 0.4627,
+ "avg_faiss_overlap": 0.786,
+ "avg_entropy": 7.1407
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.9,
+ "max_iter": 8,
+ "em": 0.34,
+ "f1": 0.4625,
+ "avg_faiss_overlap": 0.724,
+ "avg_entropy": 7.1479
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.5,
+ "max_iter": 1,
+ "em": 0.33,
+ "f1": 0.4622,
+ "avg_faiss_overlap": 0.808,
+ "avg_entropy": 6.9842
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.5,
+ "max_iter": 1,
+ "em": 0.36,
+ "f1": 0.4608,
+ "avg_faiss_overlap": 0.722,
+ "avg_entropy": 0.0442
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.9,
+ "max_iter": 8,
+ "em": 0.33,
+ "f1": 0.46,
+ "avg_faiss_overlap": 0.744,
+ "avg_entropy": 7.0404
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.8,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4592,
+ "avg_faiss_overlap": 0.944,
+ "avg_entropy": 7.1003
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.8,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4592,
+ "avg_faiss_overlap": 0.938,
+ "avg_entropy": 6.7407
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.95,
+ "max_iter": 5,
+ "em": 0.34,
+ "f1": 0.4592,
+ "avg_faiss_overlap": 0.928,
+ "avg_entropy": 6.7819
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.95,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4556,
+ "avg_faiss_overlap": 0.984,
+ "avg_entropy": 7.0709
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.95,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4556,
+ "avg_faiss_overlap": 0.984,
+ "avg_entropy": 6.5398
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.9,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4552,
+ "avg_faiss_overlap": 0.966,
+ "avg_entropy": 3.5526
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.5,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4548,
+ "avg_faiss_overlap": 0.796,
+ "avg_entropy": 4.0138
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.8,
+ "max_iter": 3,
+ "em": 0.32,
+ "f1": 0.4527,
+ "avg_faiss_overlap": 0.8,
+ "avg_entropy": 7.14
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.8,
+ "max_iter": 3,
+ "em": 0.32,
+ "f1": 0.4527,
+ "avg_faiss_overlap": 0.8,
+ "avg_entropy": 6.9964
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.8,
+ "max_iter": 3,
+ "em": 0.34,
+ "f1": 0.4524,
+ "avg_faiss_overlap": 0.74,
+ "avg_entropy": 0.1516
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.95,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4511,
+ "avg_faiss_overlap": 0.98,
+ "avg_entropy": 3.5011
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.9,
+ "max_iter": 3,
+ "em": 0.33,
+ "f1": 0.4509,
+ "avg_faiss_overlap": 0.918,
+ "avg_entropy": 6.828
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.7,
+ "max_iter": 5,
+ "em": 0.34,
+ "f1": 0.4509,
+ "avg_faiss_overlap": 0.492,
+ "avg_entropy": 0.0065
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.5,
+ "max_iter": 3,
+ "em": 0.35,
+ "f1": 0.4501,
+ "avg_faiss_overlap": 0.472,
+ "avg_entropy": 0.0137
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.8,
+ "max_iter": 3,
+ "em": 0.33,
+ "f1": 0.4498,
+ "avg_faiss_overlap": 0.798,
+ "avg_entropy": 4.0296
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.5,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4496,
+ "avg_faiss_overlap": 0.752,
+ "avg_entropy": 0.3624
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.95,
+ "max_iter": 3,
+ "em": 0.32,
+ "f1": 0.4494,
+ "avg_faiss_overlap": 0.946,
+ "avg_entropy": 3.5994
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.9,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.4487,
+ "avg_faiss_overlap": 0.842,
+ "avg_entropy": 7.1313
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.95,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4486,
+ "avg_faiss_overlap": 0.974,
+ "avg_entropy": 0.677
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.95,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4464,
+ "avg_faiss_overlap": 0.97,
+ "avg_entropy": 0.2081
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.8,
+ "max_iter": 8,
+ "em": 0.33,
+ "f1": 0.4459,
+ "avg_faiss_overlap": 0.482,
+ "avg_entropy": 0.001
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.8,
+ "max_iter": 3,
+ "em": 0.34,
+ "f1": 0.4458,
+ "avg_faiss_overlap": 0.704,
+ "avg_entropy": 0.0034
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.8,
+ "max_iter": 1,
+ "em": 0.33,
+ "f1": 0.4441,
+ "avg_faiss_overlap": 0.878,
+ "avg_entropy": 0.0927
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.7,
+ "max_iter": 3,
+ "em": 0.33,
+ "f1": 0.4433,
+ "avg_faiss_overlap": 0.676,
+ "avg_entropy": 4.2538
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.95,
+ "max_iter": 8,
+ "em": 0.32,
+ "f1": 0.4431,
+ "avg_faiss_overlap": 0.808,
+ "avg_entropy": 0.2083
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.95,
+ "max_iter": 8,
+ "em": 0.32,
+ "f1": 0.4431,
+ "avg_faiss_overlap": 0.794,
+ "avg_entropy": 0.0019
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.8,
+ "max_iter": 1,
+ "em": 0.32,
+ "f1": 0.443,
+ "avg_faiss_overlap": 0.894,
+ "avg_entropy": 0.5053
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.95,
+ "max_iter": 3,
+ "em": 0.33,
+ "f1": 0.4427,
+ "avg_faiss_overlap": 0.9,
+ "avg_entropy": 0.0754
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.7,
+ "max_iter": 3,
+ "em": 0.31,
+ "f1": 0.4426,
+ "avg_faiss_overlap": 0.654,
+ "avg_entropy": 7.0699
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.95,
+ "max_iter": 8,
+ "em": 0.33,
+ "f1": 0.442,
+ "avg_faiss_overlap": 0.878,
+ "avg_entropy": 3.8365
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.9,
+ "max_iter": 1,
+ "em": 0.32,
+ "f1": 0.4419,
+ "avg_faiss_overlap": 0.946,
+ "avg_entropy": 0.6043
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.32,
+ "f1": 0.4415,
+ "avg_faiss_overlap": 0.836,
+ "avg_entropy": 0.4413
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.32,
+ "f1": 0.4415,
+ "avg_faiss_overlap": 0.824,
+ "avg_entropy": 0.069
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.32,
+ "f1": 0.4413,
+ "avg_faiss_overlap": 0.888,
+ "avg_entropy": 3.7761
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.95,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.4407,
+ "avg_faiss_overlap": 0.87,
+ "avg_entropy": 0.3877
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.9,
+ "max_iter": 8,
+ "em": 0.34,
+ "f1": 0.4402,
+ "avg_faiss_overlap": 0.666,
+ "avg_entropy": 0.0306
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.9,
+ "max_iter": 1,
+ "em": 0.32,
+ "f1": 0.4397,
+ "avg_faiss_overlap": 0.932,
+ "avg_entropy": 0.1471
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.95,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.4391,
+ "avg_faiss_overlap": 0.858,
+ "avg_entropy": 0.019
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.8,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.439,
+ "avg_faiss_overlap": 0.592,
+ "avg_entropy": 0.024
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.5,
+ "max_iter": 8,
+ "em": 0.33,
+ "f1": 0.4381,
+ "avg_faiss_overlap": 0.41,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.95,
+ "max_iter": 3,
+ "em": 0.32,
+ "f1": 0.4377,
+ "avg_faiss_overlap": 0.912,
+ "avg_entropy": 0.5215
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.9,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.437,
+ "avg_faiss_overlap": 0.846,
+ "avg_entropy": 3.9311
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.7,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.4359,
+ "avg_faiss_overlap": 0.474,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.8,
+ "max_iter": 8,
+ "em": 0.32,
+ "f1": 0.4359,
+ "avg_faiss_overlap": 0.472,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.8,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.4356,
+ "avg_faiss_overlap": 0.658,
+ "avg_entropy": 4.2886
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.5,
+ "max_iter": 3,
+ "em": 0.33,
+ "f1": 0.4351,
+ "avg_faiss_overlap": 0.456,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.8,
+ "max_iter": 1,
+ "em": 0.31,
+ "f1": 0.4349,
+ "avg_faiss_overlap": 0.924,
+ "avg_entropy": 3.6613
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.9,
+ "max_iter": 3,
+ "em": 0.31,
+ "f1": 0.4344,
+ "avg_faiss_overlap": 0.842,
+ "avg_entropy": 0.3574
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.9,
+ "max_iter": 3,
+ "em": 0.31,
+ "f1": 0.4313,
+ "avg_faiss_overlap": 0.9,
+ "avg_entropy": 3.7492
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.9,
+ "max_iter": 8,
+ "em": 0.31,
+ "f1": 0.4302,
+ "avg_faiss_overlap": 0.748,
+ "avg_entropy": 4.1579
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.95,
+ "max_iter": 5,
+ "em": 0.31,
+ "f1": 0.4299,
+ "avg_faiss_overlap": 0.91,
+ "avg_entropy": 3.6963
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.7,
+ "max_iter": 3,
+ "em": 0.32,
+ "f1": 0.4299,
+ "avg_faiss_overlap": 0.614,
+ "avg_entropy": 0.0708
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.9,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.4291,
+ "avg_faiss_overlap": 0.778,
+ "avg_entropy": 0.1599
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.5,
+ "max_iter": 8,
+ "em": 0.32,
+ "f1": 0.4281,
+ "avg_faiss_overlap": 0.408,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.7,
+ "max_iter": 8,
+ "em": 0.33,
+ "f1": 0.4257,
+ "avg_faiss_overlap": 0.426,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.8,
+ "max_iter": 5,
+ "em": 0.29,
+ "f1": 0.4249,
+ "avg_faiss_overlap": 0.632,
+ "avg_entropy": 7.0765
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.9,
+ "max_iter": 3,
+ "em": 0.29,
+ "f1": 0.4244,
+ "avg_faiss_overlap": 0.828,
+ "avg_entropy": 0.019
+ },
+ {
+ "beta": 50.0,
+ "lambda": 0.5,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.4225,
+ "avg_faiss_overlap": 0.42,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.8,
+ "max_iter": 5,
+ "em": 0.31,
+ "f1": 0.4165,
+ "avg_faiss_overlap": 0.562,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.7,
+ "max_iter": 8,
+ "em": 0.32,
+ "f1": 0.4157,
+ "avg_faiss_overlap": 0.422,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.9,
+ "max_iter": 8,
+ "em": 0.32,
+ "f1": 0.4152,
+ "avg_faiss_overlap": 0.636,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.9,
+ "max_iter": 5,
+ "em": 0.3,
+ "f1": 0.4141,
+ "avg_faiss_overlap": 0.764,
+ "avg_entropy": 0.0014
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.7,
+ "max_iter": 5,
+ "em": 0.3,
+ "f1": 0.4127,
+ "avg_faiss_overlap": 0.498,
+ "avg_entropy": 4.4813
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.5,
+ "max_iter": 5,
+ "em": 0.31,
+ "f1": 0.4125,
+ "avg_faiss_overlap": 0.416,
+ "avg_entropy": 0.0
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.8,
+ "max_iter": 5,
+ "em": 0.27,
+ "f1": 0.4054,
+ "avg_faiss_overlap": 0.614,
+ "avg_entropy": 7.1555
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.7,
+ "max_iter": 3,
+ "em": 0.27,
+ "f1": 0.4051,
+ "avg_faiss_overlap": 0.64,
+ "avg_entropy": 7.1544
+ },
+ {
+ "beta": 100.0,
+ "lambda": 0.7,
+ "max_iter": 3,
+ "em": 0.3,
+ "f1": 0.4049,
+ "avg_faiss_overlap": 0.582,
+ "avg_entropy": 0.0002
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.7,
+ "max_iter": 8,
+ "em": 0.31,
+ "f1": 0.3988,
+ "avg_faiss_overlap": 0.28,
+ "avg_entropy": 4.5698
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.5,
+ "max_iter": 5,
+ "em": 0.32,
+ "f1": 0.3946,
+ "avg_faiss_overlap": 0.258,
+ "avg_entropy": 4.6045
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.8,
+ "max_iter": 8,
+ "em": 0.28,
+ "f1": 0.3927,
+ "avg_faiss_overlap": 0.48,
+ "avg_entropy": 4.4867
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.5,
+ "max_iter": 3,
+ "em": 0.29,
+ "f1": 0.3925,
+ "avg_faiss_overlap": 0.452,
+ "avg_entropy": 4.5183
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.7,
+ "max_iter": 5,
+ "em": 0.3,
+ "f1": 0.379,
+ "avg_faiss_overlap": 0.334,
+ "avg_entropy": 7.112
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.8,
+ "max_iter": 8,
+ "em": 0.28,
+ "f1": 0.3659,
+ "avg_faiss_overlap": 0.318,
+ "avg_entropy": 7.1124
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.5,
+ "max_iter": 3,
+ "em": 0.28,
+ "f1": 0.353,
+ "avg_faiss_overlap": 0.222,
+ "avg_entropy": 7.1169
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.5,
+ "max_iter": 3,
+ "em": 0.27,
+ "f1": 0.3436,
+ "avg_faiss_overlap": 0.156,
+ "avg_entropy": 7.1645
+ },
+ {
+ "beta": 20.0,
+ "lambda": 0.5,
+ "max_iter": 8,
+ "em": 0.25,
+ "f1": 0.3298,
+ "avg_faiss_overlap": 0.186,
+ "avg_entropy": 4.565
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.7,
+ "max_iter": 5,
+ "em": 0.24,
+ "f1": 0.3274,
+ "avg_faiss_overlap": 0.278,
+ "avg_entropy": 7.1633
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.8,
+ "max_iter": 8,
+ "em": 0.24,
+ "f1": 0.3274,
+ "avg_faiss_overlap": 0.276,
+ "avg_entropy": 7.1634
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.7,
+ "max_iter": 8,
+ "em": 0.15,
+ "f1": 0.1995,
+ "avg_faiss_overlap": 0.044,
+ "avg_entropy": 7.123
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.5,
+ "max_iter": 5,
+ "em": 0.15,
+ "f1": 0.1983,
+ "avg_faiss_overlap": 0.022,
+ "avg_entropy": 7.1239
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.5,
+ "max_iter": 5,
+ "em": 0.14,
+ "f1": 0.1895,
+ "avg_faiss_overlap": 0.018,
+ "avg_entropy": 7.166
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.5,
+ "max_iter": 8,
+ "em": 0.14,
+ "f1": 0.1877,
+ "avg_faiss_overlap": 0.002,
+ "avg_entropy": 7.1661
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.7,
+ "max_iter": 8,
+ "em": 0.12,
+ "f1": 0.1824,
+ "avg_faiss_overlap": 0.034,
+ "avg_entropy": 7.1658
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.5,
+ "max_iter": 8,
+ "em": 0.12,
+ "f1": 0.1766,
+ "avg_faiss_overlap": 0.002,
+ "avg_entropy": 7.1239
+ }
+ ],
+ "best_config": {
+ "beta": 5.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.36,
+ "f1": 0.4809,
+ "avg_faiss_overlap": 0.902,
+ "avg_entropy": 7.1163
+ },
+ "top10": [
+ {
+ "beta": 5.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.36,
+ "f1": 0.4809,
+ "avg_faiss_overlap": 0.902,
+ "avg_entropy": 7.1163
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.9,
+ "max_iter": 3,
+ "em": 0.36,
+ "f1": 0.4809,
+ "avg_faiss_overlap": 0.912,
+ "avg_entropy": 7.1122
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.7,
+ "max_iter": 1,
+ "em": 0.36,
+ "f1": 0.4809,
+ "avg_faiss_overlap": 0.9,
+ "avg_entropy": 6.8422
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.95,
+ "max_iter": 8,
+ "em": 0.36,
+ "f1": 0.4797,
+ "avg_faiss_overlap": 0.886,
+ "avg_entropy": 6.8941
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.95,
+ "max_iter": 8,
+ "em": 0.35,
+ "f1": 0.4697,
+ "avg_faiss_overlap": 0.886,
+ "avg_entropy": 7.1219
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.95,
+ "max_iter": 3,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.956,
+ "avg_entropy": 7.09
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.95,
+ "max_iter": 5,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.928,
+ "avg_entropy": 7.105
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.9,
+ "max_iter": 1,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.966,
+ "avg_entropy": 6.6138
+ },
+ {
+ "beta": 10.0,
+ "lambda": 0.95,
+ "max_iter": 3,
+ "em": 0.35,
+ "f1": 0.4692,
+ "avg_faiss_overlap": 0.956,
+ "avg_entropy": 6.6763
+ },
+ {
+ "beta": 5.0,
+ "lambda": 0.9,
+ "max_iter": 1,
+ "em": 0.34,
+ "f1": 0.4672,
+ "avg_faiss_overlap": 0.97,
+ "avg_entropy": 7.0815
+ }
+ ]
+} \ No newline at end of file
diff --git a/figures/fig1_contour.png b/figures/fig1_contour.png
new file mode 100644
index 0000000..87aca0b
--- /dev/null
+++ b/figures/fig1_contour.png
Binary files differ
diff --git a/figures/fig2_profile.png b/figures/fig2_profile.png
new file mode 100644
index 0000000..eb25eae
--- /dev/null
+++ b/figures/fig2_profile.png
Binary files differ
diff --git a/figures/fig3_umap.png b/figures/fig3_umap.png
new file mode 100644
index 0000000..681ccb7
--- /dev/null
+++ b/figures/fig3_umap.png
Binary files differ
diff --git a/figures/fig4_pca.png b/figures/fig4_pca.png
new file mode 100644
index 0000000..35dee46
--- /dev/null
+++ b/figures/fig4_pca.png
Binary files differ
diff --git a/hag/config.py b/hag/config.py
index 793e3a6..10d0aff 100644
--- a/hag/config.py
+++ b/hag/config.py
@@ -19,6 +19,7 @@ class MemoryBankConfig:
embedding_dim: int = 768 # Must match encoder output dim
normalize: bool = True # L2-normalize embeddings in memory bank
+ center: bool = False # Mean-center embeddings to remove centroid attractor
@dataclass
@@ -35,7 +36,7 @@ class GeneratorConfig:
"""Configuration for the LLM generator."""
model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
- max_new_tokens: int = 128
+ max_new_tokens: int = 32
temperature: float = 0.0 # Greedy decoding for reproducibility
@@ -48,3 +49,4 @@ class PipelineConfig:
encoder: EncoderConfig = field(default_factory=EncoderConfig)
generator: GeneratorConfig = field(default_factory=GeneratorConfig)
retriever_type: str = "hopfield" # "hopfield" or "faiss"
+ device: str = "cpu" # "cpu", "cuda", "cuda:0", etc.
diff --git a/hag/encoder.py b/hag/encoder.py
index 7e103f3..c380ad1 100644
--- a/hag/encoder.py
+++ b/hag/encoder.py
@@ -17,18 +17,20 @@ class Encoder:
For testing, use FakeEncoder instead.
"""
- def __init__(self, config: EncoderConfig) -> None:
+ def __init__(self, config: EncoderConfig, device: str = "cpu") -> None:
self.config = config
+ self.device = torch.device(device)
self._tokenizer = None
self._model = None
def _load_model(self) -> None:
- """Lazy-load the model and tokenizer."""
+ """Lazy-load the model and tokenizer, placing model on device."""
from transformers import AutoModel, AutoTokenizer
- logger.info("Loading encoder model: %s", self.config.model_name)
+ logger.info("Loading encoder model: %s (device=%s)", self.config.model_name, self.device)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self._model = AutoModel.from_pretrained(self.config.model_name)
+ self._model.to(self.device)
self._model.eval()
@torch.no_grad()
@@ -39,7 +41,7 @@ class Encoder:
texts: single string or list of strings
Returns:
- (1, d) tensor for single input, (N, d) for list input.
+ (1, d) tensor for single input, (N, d) for list input. On self.device.
"""
if self._model is None:
self._load_model()
@@ -54,6 +56,7 @@ class Encoder:
truncation=True,
return_tensors="pt",
)
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self._model(**inputs)
# Mean pooling over token embeddings
embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d)
diff --git a/hag/generator.py b/hag/generator.py
index 2142e0c..d0de468 100644
--- a/hag/generator.py
+++ b/hag/generator.py
@@ -3,11 +3,13 @@
import logging
from typing import List
+import torch
+
from hag.config import GeneratorConfig
logger = logging.getLogger(__name__)
-PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation.
Context:
{context}
@@ -24,20 +26,22 @@ class Generator:
For testing, use FakeGenerator instead.
"""
- def __init__(self, config: GeneratorConfig) -> None:
+ def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None:
self.config = config
+ self.device = torch.device(device)
self._tokenizer = None
self._model = None
def _load_model(self) -> None:
- """Lazy-load the model and tokenizer."""
+ """Lazy-load the model and tokenizer, placing model on device."""
from transformers import AutoModelForCausalLM, AutoTokenizer
- logger.info("Loading generator model: %s", self.config.model_name)
+ logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self._model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
torch_dtype="auto",
+ device_map=self.device,
)
self._model.eval()
@@ -60,15 +64,23 @@ class Generator:
prompt = PROMPT_TEMPLATE.format(context=context, question=question)
inputs = self._tokenizer(prompt, return_tensors="pt")
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self._model.generate(
**inputs,
max_new_tokens=self.config.max_new_tokens,
temperature=self.config.temperature if self.config.temperature > 0 else None,
do_sample=self.config.temperature > 0,
+ repetition_penalty=1.2,
)
# Decode only the generated tokens (skip the prompt)
generated = outputs[0][inputs["input_ids"].shape[1]:]
- return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+ answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+ # Take only the first sentence/line as the answer
+ for sep in ["\n", ". ", ".\n"]:
+ if sep in answer:
+ answer = answer.split(sep)[0].strip()
+ break
+ return answer
class FakeGenerator:
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
index 42dcc73..0a0a87c 100644
--- a/hag/memory_bank.py
+++ b/hag/memory_bank.py
@@ -16,12 +16,17 @@ class MemoryBank:
The memory bank is M in R^{d x N} where each column is a passage embedding.
Also maintains a mapping from column index to passage text for final retrieval.
+
+ When config.center=True, embeddings are mean-centered to remove the centroid
+ attractor in Hopfield dynamics. The mean is saved so queries can be centered
+ with the same offset via center_query().
"""
def __init__(self, config: MemoryBankConfig) -> None:
self.config = config
self.embeddings: Optional[torch.Tensor] = None # (d, N)
self.passages: List[str] = []
+ self.mean: Optional[torch.Tensor] = None # (d,) — saved for query centering
def build_from_embeddings(
self, embeddings: torch.Tensor, passages: List[str]
@@ -38,10 +43,42 @@ class MemoryBank:
)
if self.config.normalize:
embeddings = F.normalize(embeddings, dim=-1)
+ if self.config.center:
+ self.mean = embeddings.mean(dim=0) # (d,)
+ embeddings = embeddings - self.mean.unsqueeze(0) # (N, d)
+ logger.info("Centered memory bank (removed mean)")
self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
self.passages = list(passages)
logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+ def center_query(self, query: torch.Tensor) -> torch.Tensor:
+ """Center a query embedding using the saved memory mean.
+
+ Must be called before Hopfield retrieval when config.center=True.
+
+ Args:
+ query: (d,) or (batch, d) — query embedding(s)
+
+ Returns:
+ Centered query tensor, same shape as input.
+ """
+ if self.mean is None:
+ return query
+ return query - self.mean.to(query.device)
+
+ def apply_centering(self) -> None:
+ """Center an already-loaded (uncentered) memory bank in-place.
+
+ Useful when loading a memory bank that was saved without centering.
+ Computes and stores the mean, then subtracts it from embeddings.
+ """
+ if self.embeddings is None:
+ return
+ # embeddings is (d, N), mean over columns
+ self.mean = self.embeddings.mean(dim=1) # (d,)
+ self.embeddings = self.embeddings - self.mean.unsqueeze(1) # (d, N)
+ logger.info("Applied centering to loaded memory bank")
+
def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
"""Given top-k indices, return corresponding passage texts.
@@ -67,20 +104,38 @@ class MemoryBank:
"embedding_dim": self.config.embedding_dim,
"normalize": self.config.normalize,
},
+ "mean": self.mean,
}
torch.save(data, path)
logger.info("Saved memory bank to %s", path)
- def load(self, path: str) -> None:
+ def load(self, path: str, device: str = "cpu") -> None:
"""Load memory bank from disk.
Args:
path: file path to load from
+ device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.)
"""
- data = torch.load(path, weights_only=False)
+ data = torch.load(path, weights_only=False, map_location=device)
self.embeddings = data["embeddings"]
self.passages = data["passages"]
- logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+ self.mean = data.get("mean", None)
+ logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device)
+
+ def to(self, device: str) -> "MemoryBank":
+ """Move memory bank embeddings to the specified device.
+
+ Args:
+ device: target device ("cpu", "cuda", "cuda:0", etc.)
+
+ Returns:
+ self (for chaining).
+ """
+ if self.embeddings is not None:
+ self.embeddings = self.embeddings.to(device)
+ if self.mean is not None:
+ self.mean = self.mean.to(device)
+ return self
@property
def size(self) -> int:
diff --git a/hag/pipeline.py b/hag/pipeline.py
index 1fefb84..086b3be 100644
--- a/hag/pipeline.py
+++ b/hag/pipeline.py
@@ -82,7 +82,8 @@ class RAGPipeline:
if self.retriever_type == "hopfield":
retrieval_result = self.hopfield_retriever.retrieve(query_emb)
else:
- query_np = query_emb.detach().numpy().astype(np.float32)
+ # FAISS requires CPU numpy arrays
+ query_np = query_emb.detach().cpu().numpy().astype(np.float32)
retrieval_result = self.faiss_retriever.retrieve(query_np)
# Generate
diff --git a/note.md b/note.md
new file mode 100644
index 0000000..5abf740
--- /dev/null
+++ b/note.md
@@ -0,0 +1,181 @@
+# HAG 实验记录
+
+## 实验环境
+
+- Memory bank: HotpotQA, 1311 passages, dim=768 (Contriever-MSMARCO)
+- Generator: Llama-3.1-8B-Instruct, temperature=0 (greedy)
+- Evaluation: 100 questions, EM + F1
+- Hardware: NVIDIA RTX A6000 × 4
+
+---
+
+## Round 1: 初始 Grid Search (β ≤ 8)
+
+**日期**: 2026-02-15
+**脚本**: `scripts/run_grid_search.py`
+**Grid**: β=[0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 8.0] × iter=[1,2,3,5,8,15], top_k=5
+**Dedup**: 381 LLM calls / 4300 grid evals (91.1% saving)
+
+| 方法 | EM | F1 |
+|---|---|---|
+| FAISS baseline | 0.320 | 0.438 |
+| Best HAG (β=8, iter=1) | 0.210 | 0.294 |
+
+**结论**: 全军覆没。所有 42 个 HAG 配置均低于 FAISS baseline。
+
+**问题诊断**:
+- Entropy ≈ 7.17 / 7.18 (≈log(1311)),attention 接近均匀分布
+- FAISS overlap 极低 (0.2%-6.8%)
+- iter>1 普遍更差,所有 query 被吸到同一个 centroid
+
+---
+
+## Root Cause Analysis: Centroid Attractor
+
+### 发现 1: 能量面结构问题
+
+E(q) = -1/β · logsumexp(β · qᵀM) + 1/2 · ‖q‖²
+
+| β | E(centroid) | E(memory_i) | 谁更低? |
+|---|---|---|---|
+| 1.0 | **-7.375** | -7.068 | centroid |
+| 8.0 | **-1.097** | -0.812 | centroid |
+| 50.0 | -0.357 | **-0.500** | memory |
+
+**原因**: ‖centroid‖=0.63 vs ‖memory‖=1.0,norm 惩罚项 1/2‖q‖² 让 centroid 白省 0.3 能量。β<50 时 centroid 是全局最低能量点。
+
+### 发现 2: 一步迭代即崩溃
+
+β=8 时:
+- t=0: query 正常,top3 是正确 passages
+- t=1: cos(q, centroid) = 0.993,‖q‖从 1.36 骤降到 0.64
+- t=2: cos(q, centroid) = 0.999,所有 query 收敛到同一点
+
+**机制**: softmax(β·qᵀM) 近均匀 → q_new = Σ αᵢmᵢ ≈ centroid → norm 缩小 → 下一步更均匀 → 恶性循环
+
+### 发现 3: 去掉 norm 项不够
+
+即使 normalize 每步 update(‖q‖=1),β≤8 时仍收敛到 normalized centroid 方向。因为 softmax averaging 本身就把 query 拉向 centroid 方向,与 norm 项无关。
+
+### 发现 4: Memory bank 结构
+
+- 1311 passages pairwise cosine sim: mean=0.392, std=0.065
+- Centroid 与所有 memory 的 cosine sim: mean=0.627 (很高)
+- β 需要 ≈50+ 才能让 softmax 区分 passages（归一化熵 H/log N 才从 ~99% 降到 ~10%）
+
+### 关键洞察
+
+**max_iter=1 在代码中意味着做 1 次 Hopfield 更新**(不是 0 次)。Grid search 从未测试 iter=0(纯 softmax top-k = FAISS)。这解释了为什么 "iter=1" 已经比 FAISS 差很多。
+
+---
+
+## Round 2: 20-question 方案探索
+
+在 20 个问题上快速测试 4 种修复方案(带 LLM 生成):
+
+### A. 高 β(50, 100)纯 Hopfield
+
+| 配置 | EM | F1 |
+|---|---|---|
+| FAISS | 0.300 | 0.456 |
+| β=50 iter=1 | 0.300 | 0.417 |
+| β=100 iter=1 | 0.300 | 0.417 |
+
+高 β 使初始 attention 变尖锐,但迭代后 query 仍偏移,F1 反而下降。
+
+### B. Pre-filter K + Hopfield
+
+| 配置 | EM | F1 |
+|---|---|---|
+| PreFilter K=20 → β=5 iter=1 | **0.350** | **0.500** |
+| PreFilter K=20 → β=20 iter=1 | 0.350 | 0.462 |
+| PreFilter K=50 → β=5 iter=1 | 0.350 | 0.462 |
+
+最优方案,但本质是 FAISS 初筛 + Hopfield 重排,削弱了 "用 Hopfield 替代 FAISS" 的叙事。
+
+### C. Residual Connection
+
+q_{t+1} = λ · q_t + (1-λ) · M @ softmax(β · Mᵀ · q_t)
+
+| 配置 | EM | F1 |
+|---|---|---|
+| β=20 λ=0.9 iter=3 | 0.350 | 0.473 |
+| β=50 λ=0.9 iter=3 | 0.350 | 0.473 |
+| β=50 λ=0.7 iter=3 | 0.350 | 0.433 |
+
+λ=0.9(保留 90% 原始 query)有效防止 centroid 崩塌。
+
+### D. Pre-filter + Residual
+
+| 配置 | EM | F1 |
+|---|---|---|
+| PF K=20 + β=5 λ=0.9 iter=3 | 0.350 | 0.473 |
+| PF K=50 + β=10 λ=0.9 iter=3 | 0.350 | 0.473 |
+
+没有比单独用 Residual 更好。
+
+**决策**: 选择 pure Hopfield + Residual 路线(不依赖 FAISS pre-filter)。
+
+---
+
+## Round 3: Residual Grid (100 questions)
+
+**脚本**: `scripts/eval_residual_grid.py`
+**Grid**: β=[5,10,20,50,100] × λ=[0.5,0.7,0.8,0.9,0.95] × iter=[1,3,5,8]
+**Dedup**: 1666 unique LLM calls
+
+| 排名 | 配置 | EM | F1 | FAISS overlap |
+|---|---|---|---|---|
+| - | FAISS baseline | 0.320 | 0.438 | 1.000 |
+| 1 | **β=5 λ=0.7 iter=1** | **0.360** | **0.481** | 0.902 |
+| 2 | β=5 λ=0.9 iter=3 | 0.360 | 0.481 | 0.912 |
+| 3 | β=10 λ=0.7 iter=1 | 0.360 | 0.481 | 0.900 |
+| 4 | β=10 λ=0.95 iter=8 | 0.360 | 0.480 | 0.886 |
+| 5 | β=5 λ=0.95 iter=8 | 0.350 | 0.470 | 0.886 |
+
+**55/100 configs beat FAISS F1**。Residual 路线 robust。
+
+---
+
+## Round 4: High-β Grid (100 questions)
+
+**脚本**: `scripts/eval_highbeta_grid.py`
+**Grid**: β=[20,50,100,200,500] × iter=[0,1,2,3,5,8] × mode=[standard, normalized, residual_0.9, residual_0.95]
+**Dedup**: 1379 unique LLM calls
+
+| 排名 | 配置 | EM | F1 | FAISS overlap |
+|---|---|---|---|---|
+| - | FAISS baseline | 0.320 | 0.438 | 1.000 |
+| 1 | **β=20 iter=1 standard** | **0.380** | **0.469** | 0.480 |
+| 2 | β=50 iter=1 standard | 0.360 | 0.457 | 0.508 |
+| 3 | β=20 iter=1 residual_0.9 | 0.340 | 0.455 | 0.966 |
+| 4 | β=20 iter=2 residual_0.95 | 0.340 | 0.455 | 0.966 |
+| 5 | β=500 iter=1 residual_0.9 | 0.360 | 0.455 | 0.692 |
+| 6 | β=50 iter=1 normalized | 0.370 | 0.454 | 0.464 |
+
+**31/105 configs beat FAISS F1**。Standard 高 β 也能 work 但 overlap 低(~0.5,换了一半 passages)。
+
+---
+
+## 综合结论
+
+### 最优配置
+
+| 方法 | EM | F1 | vs FAISS |
+|---|---|---|---|
+| FAISS baseline | 0.320 | 0.438 | - |
+| Residual β=5 λ=0.7 iter=1 | 0.360 | 0.481 | **F1 +4.3 个百分点** |
+| Standard β=20 iter=1 | 0.380 | 0.469 | **F1 +3.1 个百分点** |
+
+### 规律
+
+1. **iter=1 普遍最优**,多次迭代有害(centroid 吸引)
+2. **低 β + residual > 高 β + standard**:保守修正比激进替换好
+3. **Residual 更 robust**: 55% configs beat FAISS vs 30% for high-β
+4. **FAISS overlap 高 = 好**: 最优 residual configs overlap ≈ 0.9(只换 10% passages 即提升)
+
+### 未解决问题
+
+- 只做了 1 步迭代就最优,"iterative refinement" 的叙事受限
+- Centroid attractor 在低 β 时仍存在,需要根本性解决方案
+- 100 questions 样本量偏小,需要在 500+ 上验证
diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py
index fd044a4..cd93b15 100644
--- a/scripts/analyze_energy.py
+++ b/scripts/analyze_energy.py
@@ -32,6 +32,7 @@ def main() -> None:
parser.add_argument("--memory-bank", type=str, required=True)
parser.add_argument("--questions", type=str, required=True)
parser.add_argument("--output", type=str, default="energy_analysis.json")
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -43,13 +44,13 @@ def main() -> None:
# Load memory bank
mb = MemoryBank(memory_config)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
# Load questions
with open(args.questions) as f:
questions = [json.loads(line)["question"] for line in f]
- encoder = Encoder(encoder_config)
+ encoder = Encoder(encoder_config, device=args.device)
hopfield = HopfieldRetrieval(hopfield_config)
analyses = []
diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
index 2aff828..0fc1c51 100644
--- a/scripts/build_memory_bank.py
+++ b/scripts/build_memory_bank.py
@@ -50,13 +50,13 @@ def main() -> None:
logger.info("Loaded %d passages", len(passages))
# Encode passages in batches
- encoder = Encoder(encoder_config)
+ encoder = Encoder(encoder_config, device=args.device)
all_embeddings = []
for i in tqdm(range(0, len(passages), encoder_config.batch_size), desc="Encoding"):
batch = passages[i : i + encoder_config.batch_size]
emb = encoder.encode(batch) # (batch_size, d)
- all_embeddings.append(emb.cpu())
+ all_embeddings.append(emb.cpu()) # Always store on CPU for saving
embeddings = torch.cat(all_embeddings, dim=0) # (N, d)
logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)
diff --git a/scripts/diagnose_centering.py b/scripts/diagnose_centering.py
new file mode 100644
index 0000000..5b9b4ee
--- /dev/null
+++ b/scripts/diagnose_centering.py
@@ -0,0 +1,301 @@
+"""Diagnose centering dynamics: trace q step by step, verify β_critical theory."""
+
+import sys
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+sys.path.insert(0, "/home/yurenh2/HAG")
+
+from hag.memory_bank import MemoryBank
+from hag.config import MemoryBankConfig
+
+# ── Load memory bank ─────────────────────────────────────────────────
+device = "cuda:0" # CUDA_VISIBLE_DEVICES remaps
+mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False))
+mb.load("/home/yurenh2/HAG/data/processed/hotpotqa_memory_bank.pt", device=device)
+M_raw = mb.embeddings # (d, N), L2-normalized, NOT centered
+d, N = M_raw.shape
+print(f"Memory bank: d={d}, N={N}")
+
+# ── Center manually ──────────────────────────────────────────────────
+mu = M_raw.mean(dim=1) # (d,)
+M_cent = M_raw - mu.unsqueeze(1) # (d, N) centered
+print(f"‖μ‖ = {mu.norm():.4f}")
+print(f"‖M̃·1/N‖ = {(M_cent.mean(dim=1)).norm():.2e} (should be ~0)")
+
+# ── Column norms ─────────────────────────────────────────────────────
+col_norms_raw = M_raw.norm(dim=0)
+col_norms_cent = M_cent.norm(dim=0)
+print(f"\nRaw column norms: mean={col_norms_raw.mean():.4f}, std={col_norms_raw.std():.4f}")
+print(f"Centered column norms: mean={col_norms_cent.mean():.4f}, std={col_norms_cent.std():.4f}")
+
+# ── SVD and β_critical ──────────────────────────────────────────────
+# M̃M̃ᵀ/N is the sample covariance. β_crit = N / λ_max(M̃M̃ᵀ) = 1/λ_max(C)
+# where C = M̃M̃ᵀ/N
+# But let's compute via SVD of M̃
+print("\nComputing SVD of centered memory...")
+U, S, Vh = torch.linalg.svd(M_cent, full_matrices=False) # S shape: (min(d,N),)
+print(f"Top 10 singular values: {S[:10].cpu().tolist()}")
+lambda_max_MMT = S[0].item() ** 2 # largest eigenvalue of M̃M̃ᵀ
+lambda_max_C = lambda_max_MMT / N # largest eigenvalue of M̃M̃ᵀ/N
+
+print(f"\nλ_max(M̃M̃ᵀ) = {lambda_max_MMT:.2f}")
+print(f"λ_max(C=M̃M̃ᵀ/N) = {lambda_max_C:.4f}")
+
+# Jacobian at origin: DT = β/N · M̃M̃ᵀ
+# Spectral radius = β/N · λ_max(M̃M̃ᵀ) = β · λ_max(C)
+# For instability: β · λ_max(C) > 1 → β > 1/λ_max(C)
+beta_crit = 1.0 / lambda_max_C
+print(f"\nβ_critical = 1/λ_max(C) = {beta_crit:.4f}")
+print("For β > β_crit, origin is UNSTABLE (Jacobian spectral radius > 1)")
+
+for beta in [0.5, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0]:
+ rho = beta * lambda_max_C
+ print(f" β={beta:6.1f}: Jacobian spectral radius = {rho:.2f} {'UNSTABLE' if rho > 1 else 'stable'}")
+
+# ── Load some real queries ───────────────────────────────────────────
+import json
+questions_path = "/home/yurenh2/HAG/data/processed/hotpotqa_questions.jsonl"
+with open(questions_path) as f:
+ questions = [json.loads(line) for line in f][:5]
+
+from transformers import AutoTokenizer, AutoModel
+print("\nLoading encoder...")
+tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
+model = AutoModel.from_pretrained("facebook/contriever-msmarco").to(device)
+model.eval()
+
+def encode(texts):
+ inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
+ with torch.no_grad():
+ outputs = model(**inputs)
+ # Contriever uses mean pooling
+ mask = inputs["attention_mask"].unsqueeze(-1).float()
+ emb = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
+ return F.normalize(emb, dim=-1)
+
+q_texts = [q["question"] for q in questions]
+q_embs = encode(q_texts) # (5, d)
+
+# ── Trace dynamics step by step (centered) ───────────────────────────
+print("\n" + "=" * 80)
+print("CENTERED DYNAMICS TRACE (5 queries)")
+print("=" * 80)
+
+for beta in [5.0, 20.0, 50.0, 100.0]:
+ print(f"\n--- β = {beta} ---")
+
+ for qi in range(min(3, len(q_texts))):
+ q0_raw = q_embs[qi:qi+1] # (1, d)
+ q0 = q0_raw - mu.unsqueeze(0) # center the query
+
+ q = q0.clone()
+ print(f"\n Q{qi}: '{q_texts[qi][:60]}...'")
+ print(f" t=0: ‖q‖={q.norm():.4f}")
+
+ for t in range(8):
+ logits = beta * (q @ M_cent) # (1, N)
+ alpha = torch.softmax(logits, dim=-1) # (1, N)
+ entropy = -(alpha * alpha.log()).sum().item()
+ max_alpha = alpha.max().item()
+ q_new = alpha @ M_cent.T # (1, d)
+
+ # How close is q_new to the dominant eigenvector?
+ cos_v1 = abs((q_new / (q_new.norm() + 1e-12)) @ U[:, 0]).item()
+
+ # FAISS-equivalent: initial attention from raw query on raw memory
+ if t == 0:
+ logits_raw = beta * (q0_raw @ M_raw)
+ alpha_raw = torch.softmax(logits_raw, dim=-1)
+ _, top5_raw = alpha_raw.topk(5)
+
+ # Centered attention top-5
+ _, top5_cent = alpha.topk(5)
+ overlap = len(set(top5_raw[0].cpu().tolist()) & set(top5_cent[0].cpu().tolist()))
+ print(f" initial overlap(raw top5, centered top5) = {overlap}/5")
+
+ delta = (q_new - q).norm().item()
+ q = q_new
+ print(f" t={t+1}: ‖q‖={q.norm():.6f}, H(α)={entropy:.2f}/{np.log(N):.2f}, "
+ f"max(α)={max_alpha:.2e}, Δ={delta:.2e}, cos(q,v1)={cos_v1:.4f}")
+
+ if delta < 1e-8:
+ print(f" [converged]")
+ break
+
+ # Final top-5 from centered attention
+ logits_final = beta * (q @ M_cent)
+ alpha_final = torch.softmax(logits_final, dim=-1)
+ _, top5_final = alpha_final.topk(5)
+
+ # Compare to "iter=0" (just softmax on centered, no iteration)
+ logits_iter0 = beta * (q0 @ M_cent)
+ alpha_iter0 = torch.softmax(logits_iter0, dim=-1)
+ _, top5_iter0 = alpha_iter0.topk(5)
+
+ # And raw (FAISS-like)
+ logits_faiss = beta * (q0_raw @ M_raw)
+ alpha_faiss = torch.softmax(logits_faiss, dim=-1)
+ _, top5_faiss = alpha_faiss.topk(5)
+
+ print(f" Top-5 FAISS: {top5_faiss[0].cpu().tolist()}")
+ print(f" Top-5 cent t=0: {top5_iter0[0].cpu().tolist()}")
+ print(f" Top-5 cent final: {top5_final[0].cpu().tolist()}")
+
+# ── KEY TEST: Does the Jacobian actually amplify near origin? ────────
+print("\n" + "=" * 80)
+print("JACOBIAN TEST: Start near origin, see if dynamics amplify")
+print("=" * 80)
+
+for beta in [5.0, 50.0, 100.0]:
+ # Start with a tiny perturbation in direction of top eigenvector
+ eps = 1e-4
+ q_tiny = eps * U[:, 0].unsqueeze(0) # (1, d), tiny perturbation along v1
+
+ print(f"\n--- β = {beta}, q0 = {eps}·v1 ---")
+ q = q_tiny.clone()
+ for t in range(10):
+ logits = beta * (q @ M_cent)
+ alpha = torch.softmax(logits, dim=-1)
+ q_new = alpha @ M_cent.T
+ amplification = q_new.norm().item() / (q.norm().item() + 1e-20)
+ print(f" t={t}: ‖q‖={q.norm():.6e} → ‖q_new‖={q_new.norm():.6e}, "
+ f"amplification={amplification:.2f}")
+ q = q_new
+ if q.norm().item() > 1.0:
+ print(f" [escaped origin at t={t}]")
+ break
+
+# ── Alternative: What about removing the ‖q‖² term from energy? ─────
+# The standard update q_{t+1} = M·softmax(β·Mᵀq) minimizes
+# E(q) = -1/β·lse(β·Mᵀq) + 1/2·‖q‖²
+# What if we don't want the ‖q‖² penalty? Then the fixed point equation
+# is just q* = M·softmax(β·Mᵀq*), same update but different energy landscape.
+# The issue is: with centering, M̃·uniform = 0 regardless of energy.
+# The ‖q‖² penalty is NOT the problem for centering — the averaging is.
+
+print("\n" + "=" * 80)
+print("DIAGNOSIS: Why centering fails for iteration")
+print("=" * 80)
+
+# For β=50, show the attention distribution at t=0 (before any iteration)
+beta = 50.0
+q0_raw = q_embs[0:1]
+q0_cent = q0_raw - mu.unsqueeze(0)
+logits_cent = beta * (q0_cent @ M_cent) # (1, N)
+alpha_cent = torch.softmax(logits_cent, dim=-1)
+entropy_cent = -(alpha_cent * alpha_cent.log()).sum().item()
+max_alpha_cent = alpha_cent.max().item()
+
+logits_raw = beta * (q0_raw @ M_raw)
+alpha_raw = torch.softmax(logits_raw, dim=-1)
+entropy_raw = -(alpha_raw * alpha_raw.log()).sum().item()
+
+print(f"\nβ={beta}, Q0: '{q_texts[0][:60]}...'")
+print(f"Raw attention: entropy={entropy_raw:.2f}, max={alpha_raw.max():.4f}")
+print(f"Cent attention: entropy={entropy_cent:.2f}, max={max_alpha_cent:.4f}")
+print(f"‖q0_cent‖ = {q0_cent.norm():.4f}")
+print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")
+
+# Show: what's the actual q1 norm vs predicted from Jacobian?
+q1_cent = alpha_cent @ M_cent.T # (1, d)
+predicted_norm = (beta / N * lambda_max_MMT) * q0_cent.norm().item() # rough bound
+print(f"\n‖q1_cent‖ actual = {q1_cent.norm():.6f}")
+print(f"‖q0_cent‖ × Jacobian_spectral_radius ≈ {q0_cent.norm():.4f} × {beta*lambda_max_C:.2f} = {q0_cent.norm().item()*beta*lambda_max_C:.4f}")
+print(f"But the linearization only holds for q→0. q0 is NOT near zero.")
+
+# The real issue: softmax(β·M̃ᵀ·q0) when q0 has ‖q‖=0.5
+# The logits have some spread, but the weighted average of CENTERED vectors
+# inherently cancels out.
+weighted_avg = (alpha_cent @ M_cent.T) # (1, d)
+unweighted_avg = M_cent.mean(dim=1) # (d,)
+print(f"\n‖weighted_avg‖ = {weighted_avg.norm():.6f}")
+print(f"‖unweighted_avg‖ = {unweighted_avg.norm():.2e}")
+
+# How concentrated is the attention?
+top50_vals, top50_idx = alpha_cent.topk(50)
+mass_top50 = top50_vals.sum().item()
+print(f"Mass in top-50 memories: {mass_top50:.4f}")
+print(f"Mass in top-5 memories: {alpha_cent.topk(5)[0].sum().item():.4f}")
+
+# The weighted average of centered vectors is small because:
+# 1. Centered vectors m̃_i have ‖m̃_i‖ ≈ 0.78 (smaller than raw ‖m_i‖=1)
+# 2. Centered vectors point in diverse directions (they have mean removed)
+# 3. Even with non-uniform weights, the cancellation is severe unless
+# attention is extremely peaked on a few memories
+# So the output ‖q1‖ << ‖q0‖, even though β is large
+
+# Key quantification: what fraction of ‖q0‖ is preserved?
+preserve_ratio = q1_cent.norm().item() / q0_cent.norm().item()
+print(f"\n‖q1‖/‖q0‖ = {preserve_ratio:.4f} (fraction of query norm preserved)")
+print("This ratio << 1 means the averaging contracts the query toward 0.")
+print("For centering to work with iteration, this ratio must be > 1.")
+
+print("\n" + "=" * 80)
+print("SOLUTION ANALYSIS")
+print("=" * 80)
+print("""
+The centering fix removes the centroid attractor: M̃·uniform = 0, not μ.
+But the fundamental problem remains: ANY weighted average of centered vectors
+is much shorter than the input query, because centered vectors cancel.
+
+For the origin to be unstable, β must exceed β_critical so that the Jacobian
+amplifies perturbations near zero. But the dynamics from a realistic starting
+point (‖q‖≈0.5) don't behave like the linearization predicts.
+
+The actual contraction ratio ‖q1‖/‖q0‖ is what matters, not the Jacobian
+at origin. This ratio is small because softmax isn't peaked enough.
+
+Possible fixes:
+1. MUCH higher β (β > 500?) to make attention ultra-peaked → less cancellation
+2. Residual connection with centering: q_{t+1} = λ·q_t + (1-λ)·M̃·softmax(...)
+ This explicitly preserves query norm while still benefiting from centering.
+3. Normalize q_{t+1} after each step to prevent norm collapse.
+4. Use centering only for the attention computation, not for the update target:
+ α = softmax(β · M̃ᵀ · q̃) but q_{t+1} = M_raw · α (update in original space)
+""")
+
+# Test option 4: centered attention, raw update
+print("\n" + "=" * 80)
+print("TEST: Centered attention + raw update (hybrid)")
+print("=" * 80)
+
+for beta in [5.0, 20.0, 50.0]:
+ print(f"\n--- β = {beta} ---")
+ for qi in range(min(3, len(q_texts))):
+ q_raw = q_embs[qi:qi+1].clone() # (1, d) raw query
+ print(f" Q{qi}: '{q_texts[qi][:50]}...'")
+
+ for t in range(5):
+ q_cent = q_raw - mu.unsqueeze(0) # center query
+ logits = beta * (q_cent @ M_cent) # attention on centered space
+ alpha = torch.softmax(logits, dim=-1)
+ q_new = alpha @ M_raw.T # update in RAW space
+
+ entropy = -(alpha * alpha.log()).sum().item()
+ delta = (q_new - q_raw).norm().item()
+
+ if t == 0:
+ _, top5 = alpha.topk(5)
+ # FAISS top5
+ logits_f = q_embs[qi:qi+1] @ M_raw
+ _, top5_f = logits_f.topk(5)
+ overlap = len(set(top5[0].cpu().tolist()) & set(top5_f[0].cpu().tolist()))
+
+ print(f" t={t}: ‖q‖={q_raw.norm():.4f} → {q_new.norm():.4f}, "
+ f"H={entropy:.2f}, Δ={delta:.4f}")
+ q_raw = q_new
+
+ # Final top-5
+ q_cent_f = q_raw - mu.unsqueeze(0)
+ logits_f = beta * (q_cent_f @ M_cent)
+ _, top5_final = torch.softmax(logits_f, dim=-1).topk(5)
+ logits_faiss = q_embs[qi:qi+1] @ M_raw
+ _, top5_faiss = logits_faiss.topk(5)
+ overlap = len(set(top5_final[0].cpu().tolist()) & set(top5_faiss[0].cpu().tolist()))
+ print(f" final vs FAISS overlap: {overlap}/5")
+ print(f" FAISS top5: {top5_faiss[0].cpu().tolist()}")
+ print(f" Hybrid top5: {top5_final[0].cpu().tolist()}")
+
+print("\nDone.")
diff --git a/scripts/eval_centered_grid.py b/scripts/eval_centered_grid.py
new file mode 100644
index 0000000..4581251
--- /dev/null
+++ b/scripts/eval_centered_grid.py
@@ -0,0 +1,313 @@
+"""Evaluate centered Hopfield on 100 questions.
+
+Memory bank is mean-centered (M̃ = M - μ), query is centered (q̃ = q - μ).
+β_critical ≈ 37.6: below it the origin is a stable attractor; above it the dynamics escape.
+
+Grid: β spanning both sides of β_critical, iter = [0, 1, 2, 3, 5, 8].
+No residual — pure Hopfield update on centered space.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=1 nohup python -u scripts/eval_centered_grid.py \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda --max-samples 100 \
+ > data/processed/centered_grid.log 2>&1 &
+"""
+
+import argparse
+import json
+import logging
+import sys
+import time
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import yaml
+
+from hag.config import EncoderConfig, GeneratorConfig, MemoryBankConfig
+from hag.energy import compute_attention_entropy
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import exact_match, f1_score
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__name__)
+
+
+def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]:
+ questions, gold_answers = [], []
+ with open(path) as f:
+ for line in f:
+ r = json.loads(line)
+ questions.append(r["question"])
+ gold_answers.append(r["answer"])
+ if len(questions) >= max_samples:
+ break
+ return questions, gold_answers
+
+
+@torch.no_grad()
+def hopfield_retrieve_centered(
+ query: torch.Tensor,
+ memory_centered: torch.Tensor,
+ mean: torch.Tensor,
+ beta: float,
+ max_iter: int,
+ top_k: int,
+) -> Tuple[torch.Tensor, float]:
+ """Pure Hopfield retrieval on centered memory bank.
+
+ Args:
+ query: (batch, d) raw query embeddings
+ memory_centered: (d, N) centered memory bank (M̃ = M - μ)
+ mean: (d,) memory bank mean
+ beta, max_iter, top_k: Hopfield params
+
+ Returns:
+ (top_k_indices (batch, top_k), avg_entropy)
+ """
+ # Center the query
+ q = query - mean.unsqueeze(0) # (batch, d)
+
+ for _ in range(max_iter):
+ logits = beta * (q @ memory_centered) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+ q = alpha @ memory_centered.T # (batch, d)
+
+ # Final attention
+ logits = beta * (q @ memory_centered)
+ alpha = torch.softmax(logits, dim=-1)
+ _, indices = torch.topk(alpha, top_k, dim=-1)
+ entropy = compute_attention_entropy(alpha)
+ return indices, entropy
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Centered Hopfield grid evaluation")
+ parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--questions", type=str, required=True)
+ parser.add_argument("--device", type=str, default="cpu")
+ parser.add_argument("--max-samples", type=int, default=100)
+ parser.add_argument("--output", type=str, default="data/processed/centered_grid_results.json")
+ parser.add_argument("--top-k", type=int, default=5)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ # Grid: span β_critical ≈ 37.6
+ betas = [10.0, 20.0, 30.0, 38.0, 40.0, 45.0, 50.0, 60.0, 75.0, 100.0, 150.0, 200.0]
+ max_iters_list = [0, 1, 2, 3, 5, 8]
+ top_k = args.top_k
+
+ total_configs = len(betas) * len(max_iters_list)
+ logger.info("=" * 60)
+ logger.info("Centered Hopfield Grid Search")
+ logger.info(" β_critical ≈ 37.6")
+ logger.info(" betas: %s", betas)
+ logger.info(" max_iters: %s", max_iters_list)
+ logger.info(" total configs: %d", total_configs)
+ logger.info("=" * 60)
+
+ t_start = time.time()
+
+ questions, gold_answers = load_questions(args.questions, args.max_samples)
+ n = len(questions)
+ logger.info("Loaded %d questions", n)
+
+ # Load memory bank (uncentered)
+ mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
+ mb.load(args.memory_bank, device=args.device)
+ M_raw = mb.embeddings # (d, N)
+ d, N = M_raw.shape
+ logger.info("Memory bank: %d passages, dim=%d", N, d)
+
+ # Center the memory bank
+ mu = M_raw.mean(dim=1) # (d,)
+ M_cent = M_raw - mu.unsqueeze(1) # (d, N)
+ logger.info("Centered memory bank. ‖μ‖=%.4f, ‖M̃·1/N‖=%.2e",
+ mu.norm().item(), M_cent.mean(dim=1).norm().item())
+
+ # Compute β_critical
+ S = torch.linalg.svdvals(M_cent)
+ lambda_max_C = (S[0].item() ** 2) / N
+ beta_crit = 1.0 / lambda_max_C
+ logger.info("β_critical = %.2f (λ_max(C)=%.4f, σ_max=%.4f)", beta_crit, lambda_max_C, S[0].item())
+
+ encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device)
+ generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device)
+
+ logger.info("Encoding questions...")
+ all_embs = []
+ batch_size = cfg.get("encoder", {}).get("batch_size", 64)
+ for i in range(0, n, batch_size):
+ all_embs.append(encoder.encode(questions[i : i + batch_size]))
+ Q = torch.cat(all_embs, dim=0) # (n, d)
+ logger.info("Encoded, shape=%s", Q.shape)
+
+ # FAISS baseline (on raw embeddings)
+ logger.info("Running FAISS baseline...")
+ emb_np = M_raw.T.cpu().numpy().astype(np.float32)
+ faiss_ret = FAISSRetriever(top_k=top_k)
+ faiss_ret.build_index(emb_np, mb.passages)
+
+ faiss_indices: Dict[int, Tuple[int, ...]] = {}
+ llm_cache: Dict[Tuple[int, frozenset], str] = {}
+
+ for i in range(n):
+ q_np = Q[i].cpu().numpy().astype(np.float32)
+ result = faiss_ret.retrieve(q_np)
+ idx_tuple = tuple(sorted(result.indices.tolist()))
+ faiss_indices[i] = idx_tuple
+ cache_key = (i, frozenset(idx_tuple))
+ answer = generator.generate(questions[i], result.passages)
+ llm_cache[cache_key] = answer
+ if (i + 1) % 20 == 0:
+ ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
+ f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
+ logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s))
+
+ faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
+ faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
+ logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1)
+
+ # Phase 2: Retrieve all configs (centered)
+ logger.info("Phase 2: Retrieving all %d configs (centered)...", total_configs)
+ t_ret = time.time()
+
+ retrieval_data: Dict[str, List[Tuple[Tuple[int, ...], float]]] = {}
+
+ for beta in betas:
+ for max_iter in max_iters_list:
+ config_key = f"β={beta}_iter={max_iter}"
+ indices_batch, entropy = hopfield_retrieve_centered(
+ Q, M_cent, mu, beta=beta, max_iter=max_iter, top_k=top_k,
+ )
+ per_q = []
+ for i in range(n):
+ idx_tuple = tuple(sorted(indices_batch[i].tolist()))
+ per_q.append((idx_tuple, entropy))
+ retrieval_data[config_key] = per_q
+
+ logger.info("Retrieval done in %.1fs, %d configs", time.time() - t_ret, len(retrieval_data))
+
+ # Phase 3: Dedup + generate
+ needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
+ for key, per_q in retrieval_data.items():
+ for i, (idx_tuple, _) in enumerate(per_q):
+ cache_key = (i, frozenset(idx_tuple))
+ if cache_key not in llm_cache and cache_key not in needed:
+ needed[cache_key] = (i, idx_tuple)
+
+ logger.info("Unique LLM calls needed: %d (cache has %d)", len(needed), len(llm_cache))
+
+ t_gen = time.time()
+ for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()):
+ passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long))
+ answer = generator.generate(questions[q_idx], passages)
+ llm_cache[cache_key] = answer
+ if (call_idx + 1) % 50 == 0:
+ elapsed = time.time() - t_gen
+ rate = (call_idx + 1) / elapsed
+ logger.info(" Generated %d/%d (%.1f/s, ~%.0fs left)",
+ call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate)
+ logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen)
+
+ # Phase 4: Evaluate
+ logger.info("Phase 4: Evaluating...")
+ results = []
+ for config_key, per_q in retrieval_data.items():
+ ems, f1s, overlaps = [], [], []
+ for i, (idx_tuple, ent) in enumerate(per_q):
+ cache_key = (i, frozenset(idx_tuple))
+ answer = llm_cache[cache_key]
+ ems.append(exact_match(answer, gold_answers[i]))
+ f1s.append(f1_score(answer, gold_answers[i]))
+ overlaps.append(len(set(idx_tuple) & set(faiss_indices[i])) / top_k)
+
+ em, f1 = np.mean(ems), np.mean(f1s)
+ # Parse beta from config_key
+ beta_val = float(config_key.split("_")[0].split("=")[1])
+ iter_val = int(config_key.split("_")[1].split("=")[1])
+ r = {
+ "config": config_key,
+ "beta": beta_val,
+ "max_iter": iter_val,
+ "em": round(em, 4),
+ "f1": round(f1, 4),
+ "avg_faiss_overlap": round(np.mean(overlaps), 4),
+ "avg_entropy": round(per_q[0][1], 4),
+ "above_beta_crit": beta_val > beta_crit,
+ }
+ results.append(r)
+
+ results.sort(key=lambda x: x["f1"], reverse=True)
+
+ # Count how many beat FAISS
+ n_beat = sum(1 for r in results if r["f1"] > faiss_f1)
+ logger.info("\n%d/%d configs beat FAISS F1=%.3f", n_beat, len(results), faiss_f1)
+
+ # Log top 20
+ logger.info("\nTop 20 configs:")
+ for i, r in enumerate(results[:20]):
+ marker = " ***" if r["f1"] > faiss_f1 else ""
+ crit = ">" if r["above_beta_crit"] else "<"
+ logger.info(" %2d. %-25s EM=%.3f F1=%.3f overlap=%.3f H=%.2f β%sβ_c%s",
+ i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"],
+ r["avg_entropy"], crit, marker)
+
+ # Summary by β: best iter for each β
+ logger.info("\nBest iter per β:")
+ for beta in betas:
+ beta_results = [r for r in results if r["beta"] == beta]
+ if beta_results:
+ best = beta_results[0]
+ crit = ">" if best["above_beta_crit"] else "<"
+ logger.info(" β=%6.1f (β%sβ_c): best iter=%d EM=%.3f F1=%.3f overlap=%.3f",
+ beta, crit, best["max_iter"], best["em"], best["f1"], best["avg_faiss_overlap"])
+
+ t_total = time.time() - t_start
+ output = {
+ "meta": {
+ "n_questions": n,
+ "total_configs": len(retrieval_data),
+ "unique_llm_calls": len(needed),
+ "total_time_s": round(t_total, 1),
+ "beta_critical": round(beta_crit, 2),
+ },
+ "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)},
+ "grid_results": results,
+ "best_config": results[0],
+ "top10": results[:10],
+ }
+
+ Path(args.output).parent.mkdir(parents=True, exist_ok=True)
+ with open(args.output, "w") as f:
+ json.dump(output, f, indent=2)
+
+ logger.info("=" * 60)
+ logger.info("RESULTS SUMMARY")
+ logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1)
+ logger.info(" β_critical = %.2f", beta_crit)
+ logger.info(" Configs beating FAISS: %d/%d", n_beat, len(results))
+ logger.info(" Top 5:")
+ for i, r in enumerate(results[:5]):
+ logger.info(" %d. %-25s EM=%.3f F1=%.3f overlap=%.3f",
+ i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"])
+ logger.info(" Total time: %.1fs", t_total)
+ logger.info(" Saved to: %s", args.output)
+ logger.info("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/eval_highbeta_grid.py b/scripts/eval_highbeta_grid.py
new file mode 100644
index 0000000..bf1d97d
--- /dev/null
+++ b/scripts/eval_highbeta_grid.py
@@ -0,0 +1,301 @@
+"""Evaluate high-β Hopfield (standard, normalized, and residual updates) on 100 questions.
+
+Tests whether high β (≥50) allows standard Hopfield to work without residual.
+Also tests normalized update q → normalize(M @ softmax(β * M^T @ q)) and residual updates (λ ∈ {0.9, 0.95}).
+
+Usage:
+ CUDA_VISIBLE_DEVICES=0 python -u scripts/eval_highbeta_grid.py \
+ --config configs/hotpotqa.yaml \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda --max-samples 100
+"""
+
+import argparse
+import json
+import logging
+import sys
+import time
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import yaml
+
+from hag.config import EncoderConfig, GeneratorConfig, MemoryBankConfig
+from hag.energy import compute_attention_entropy
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import exact_match, f1_score
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__name__)
+
+
+def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]:
+ questions, gold_answers = [], []
+ with open(path) as f:
+ for line in f:
+ r = json.loads(line)
+ questions.append(r["question"])
+ gold_answers.append(r["answer"])
+ if len(questions) >= max_samples:
+ break
+ return questions, gold_answers
+
+
+@torch.no_grad()
+def hopfield_retrieve(
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ beta: float,
+ max_iter: int,
+ top_k: int,
+ mode: str = "standard",
+ lam: float = 0.0,
+) -> Tuple[torch.Tensor, float]:
+ """Hopfield retrieval with different update modes.
+
+ Args:
+ query: (batch, d)
+ memory: (d, N)
+ beta, max_iter, top_k: Hopfield params
+ mode: "standard" | "normalized" | "residual"
+ lam: residual weight (only for mode="residual")
+
+ Returns:
+ (top_k_indices (batch, top_k), avg_entropy)
+ """
+ q = query.clone()
+ if mode == "normalized":
+ q = F.normalize(q, dim=-1)
+
+ for _ in range(max_iter):
+ logits = beta * (q @ memory) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+ q_hop = alpha @ memory.T # (batch, d)
+
+ if mode == "standard":
+ q = q_hop
+ elif mode == "normalized":
+ q = F.normalize(q_hop, dim=-1)
+ elif mode == "residual":
+ q = lam * q + (1.0 - lam) * q_hop
+
+ # Final attention
+ logits = beta * (q @ memory)
+ alpha = torch.softmax(logits, dim=-1)
+ _, indices = torch.topk(alpha, top_k, dim=-1)
+ entropy = compute_attention_entropy(alpha)
+ return indices, entropy
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="High-β Hopfield grid evaluation")
+ parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--questions", type=str, required=True)
+ parser.add_argument("--device", type=str, default="cpu")
+ parser.add_argument("--max-samples", type=int, default=100)
+ parser.add_argument("--output", type=str, default="data/processed/highbeta_grid_results.json")
+ parser.add_argument("--top-k", type=int, default=5)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ # Grid: high β focus
+ betas = [20.0, 50.0, 100.0, 200.0, 500.0]
+ max_iters_list = [0, 1, 2, 3, 5, 8]
+ # Modes: standard, normalized, residual(λ=0.9)
+ modes = [
+ ("standard", 0.0),
+ ("normalized", 0.0),
+ ("residual_0.9", 0.9),
+ ("residual_0.95", 0.95),
+ ]
+ top_k = args.top_k
+
+ total_configs = len(betas) * len(max_iters_list) * len(modes)
+ logger.info("=" * 60)
+ logger.info("High-β Hopfield Grid Search")
+ logger.info(" betas: %s", betas)
+ logger.info(" max_iters: %s", max_iters_list)
+ logger.info(" modes: %s", [m[0] for m in modes])
+ logger.info(" total configs: %d", total_configs)
+ logger.info("=" * 60)
+
+ t_start = time.time()
+
+ questions, gold_answers = load_questions(args.questions, args.max_samples)
+ n = len(questions)
+ logger.info("Loaded %d questions", n)
+
+ mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
+ mb.load(args.memory_bank, device=args.device)
+ M = mb.embeddings
+ logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)
+
+ encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device)
+ generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device)
+
+ logger.info("Encoding questions...")
+ all_embs = []
+ batch_size = cfg.get("encoder", {}).get("batch_size", 64)
+ for i in range(0, n, batch_size):
+ all_embs.append(encoder.encode(questions[i : i + batch_size]))
+ Q = torch.cat(all_embs, dim=0)
+ logger.info("Encoded, shape=%s", Q.shape)
+
+ # FAISS baseline
+ logger.info("Running FAISS baseline...")
+ emb_np = mb.embeddings.T.cpu().numpy().astype(np.float32)
+ faiss_ret = FAISSRetriever(top_k=top_k)
+ faiss_ret.build_index(emb_np, mb.passages)
+
+ faiss_indices: Dict[int, Tuple[int, ...]] = {}
+ llm_cache: Dict[Tuple[int, frozenset], str] = {}
+
+ for i in range(n):
+ q_np = Q[i].cpu().numpy().astype(np.float32)
+ result = faiss_ret.retrieve(q_np)
+ idx_tuple = tuple(sorted(result.indices.tolist()))
+ faiss_indices[i] = idx_tuple
+ cache_key = (i, frozenset(idx_tuple))
+ answer = generator.generate(questions[i], result.passages)
+ llm_cache[cache_key] = answer
+ if (i + 1) % 20 == 0:
+ ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
+ f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
+ logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s))
+
+ faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
+ faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
+ logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1)
+
+ # Phase 2: Retrieve all configs
+ logger.info("Phase 2: Retrieving all %d configs...", total_configs)
+ t_ret = time.time()
+
+ retrieval_data: Dict[str, List[Tuple[Tuple[int, ...], float]]] = {}
+
+ for beta in betas:
+ for max_iter in max_iters_list:
+ for mode_name, lam in modes:
+ if max_iter == 0:
+ # iter=0: just use initial query's softmax top-k (same for all modes)
+ if mode_name != "standard":
+ continue # skip duplicates for iter=0
+ indices_batch = (beta * (Q @ M)).softmax(dim=-1).topk(top_k, dim=-1).indices
+ entropy = compute_attention_entropy((beta * (Q @ M)).softmax(dim=-1))
+ config_key = f"β={beta}_iter=0_standard"
+ else:
+ actual_mode = "residual" if mode_name.startswith("residual") else mode_name
+ indices_batch, entropy = hopfield_retrieve(
+ Q, M, beta=beta, max_iter=max_iter, top_k=top_k,
+ mode=actual_mode, lam=lam,
+ )
+ config_key = f"β={beta}_iter={max_iter}_{mode_name}"
+
+ per_q = []
+ for i in range(n):
+ idx_tuple = tuple(sorted(indices_batch[i].tolist()))
+ per_q.append((idx_tuple, entropy))
+ retrieval_data[config_key] = per_q
+
+ logger.info("Retrieval done in %.1fs, %d configs", time.time() - t_ret, len(retrieval_data))
+
+ # Phase 3: Dedup + generate
+ needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
+ for key, per_q in retrieval_data.items():
+ for i, (idx_tuple, _) in enumerate(per_q):
+ cache_key = (i, frozenset(idx_tuple))
+ if cache_key not in llm_cache and cache_key not in needed:
+ needed[cache_key] = (i, idx_tuple)
+
+ logger.info("Unique LLM calls needed: %d (cache has %d)", len(needed), len(llm_cache))
+
+ t_gen = time.time()
+ for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()):
+ passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long))
+ answer = generator.generate(questions[q_idx], passages)
+ llm_cache[cache_key] = answer
+ if (call_idx + 1) % 50 == 0:
+ elapsed = time.time() - t_gen
+ rate = (call_idx + 1) / elapsed
+ logger.info(" Generated %d/%d (%.1f/s, ~%.0fs left)",
+ call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate)
+ logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen)
+
+ # Phase 4: Evaluate
+ logger.info("Phase 4: Evaluating...")
+ results = []
+ for config_key, per_q in retrieval_data.items():
+ ems, f1s, overlaps = [], [], []
+ for i, (idx_tuple, ent) in enumerate(per_q):
+ cache_key = (i, frozenset(idx_tuple))
+ answer = llm_cache[cache_key]
+ ems.append(exact_match(answer, gold_answers[i]))
+ f1s.append(f1_score(answer, gold_answers[i]))
+ overlaps.append(len(set(idx_tuple) & set(faiss_indices[i])) / top_k)
+
+ em, f1 = np.mean(ems), np.mean(f1s)
+ r = {
+ "config": config_key,
+ "em": round(em, 4),
+ "f1": round(f1, 4),
+ "avg_faiss_overlap": round(np.mean(overlaps), 4),
+ "avg_entropy": round(per_q[0][1], 4),
+ }
+ results.append(r)
+
+ results.sort(key=lambda x: x["f1"], reverse=True)
+
+ # Log all that beat or match FAISS
+ logger.info("\nConfigs matching or beating FAISS (F1≥%.3f):", faiss_f1)
+ for r in results:
+ if r["f1"] >= faiss_f1 - 0.005:
+ marker = " ***" if r["f1"] > faiss_f1 else ""
+ logger.info(" %s: EM=%.3f F1=%.3f overlap=%.3f%s",
+ r["config"], r["em"], r["f1"], r["avg_faiss_overlap"], marker)
+
+ t_total = time.time() - t_start
+ output = {
+ "meta": {
+ "n_questions": n,
+ "total_configs": len(retrieval_data),
+ "unique_llm_calls": len(needed),
+ "total_time_s": round(t_total, 1),
+ },
+ "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)},
+ "grid_results": results,
+ "best_config": results[0],
+ "top10": results[:10],
+ }
+
+ Path(args.output).parent.mkdir(parents=True, exist_ok=True)
+ with open(args.output, "w") as f:
+ json.dump(output, f, indent=2)
+
+ logger.info("=" * 60)
+ logger.info("RESULTS")
+ logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1)
+ logger.info(" Top 10:")
+ for i, r in enumerate(results[:10]):
+ logger.info(" %2d. %-40s EM=%.3f F1=%.3f overlap=%.3f",
+ i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"])
+ logger.info(" Total time: %.1fs", t_total)
+ logger.info(" Saved to: %s", args.output)
+ logger.info("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/eval_residual_grid.py b/scripts/eval_residual_grid.py
new file mode 100644
index 0000000..f716c51
--- /dev/null
+++ b/scripts/eval_residual_grid.py
@@ -0,0 +1,298 @@
+"""Evaluate Residual Hopfield configs on 100 questions with dedup-based LLM caching.
+
+Residual update: q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t)
+
+Usage:
+ CUDA_VISIBLE_DEVICES=1 python scripts/eval_residual_grid.py \
+ --config configs/hotpotqa.yaml \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda --max-samples 100
+"""
+
+import argparse
+import json
+import logging
+import sys
+import time
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+import yaml
+
+from hag.config import EncoderConfig, GeneratorConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.energy import compute_attention_entropy
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import exact_match, f1_score
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__name__)
+
+
def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]:
    """Load up to *max_samples* question/answer pairs from a JSONL file.

    Args:
        path: JSONL file with one {"question": ..., "answer": ...} object per line.
        max_samples: maximum number of records to read; a non-positive value
            means "read everything" (consistent with run_grid_search.py's
            load_questions, which treats a falsy limit as unlimited).

    Returns:
        Tuple of (questions, gold_answers), parallel lists of equal length.
    """
    questions: List[str] = []
    gold_answers: List[str] = []
    with open(path) as f:
        for line in f:
            # Check the quota *before* parsing the next record. The previous
            # version checked after appending, so max_samples=0 incorrectly
            # returned one record instead of all (or none).
            if max_samples > 0 and len(questions) >= max_samples:
                break
            r = json.loads(line)
            questions.append(r["question"])
            gold_answers.append(r["answer"])
    return questions, gold_answers
+
+
@torch.no_grad()
def residual_hopfield_retrieve(
    query: torch.Tensor,
    memory: torch.Tensor,
    beta: float,
    lam: float,
    max_iter: int,
    top_k: int,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """Run residual Hopfield dynamics and rank memory slots by final attention.

    Iterates the residual update
        q <- lam * q + (1 - lam) * M @ softmax(beta * M^T @ q)
    for max_iter steps, then takes the top-k attention weights of the final
    state over the full memory bank.

    Args:
        query: (batch, d) query embeddings.
        memory: (d, N) memory bank.
        beta: inverse temperature of the softmax.
        lam: residual mixing weight (0 = pure Hopfield step, 1 = frozen query).
        max_iter: number of update iterations.
        top_k: number of passages to return per query.

    Returns:
        Tuple (indices, scores, entropy): indices (batch, top_k),
        scores (batch, top_k), and the attention entropy as a float.
    """
    state = query.clone()
    step = 0
    while step < max_iter:
        attn = torch.softmax(beta * (state @ memory), dim=-1)  # (batch, N)
        retrieved = attn @ memory.T                            # (batch, d)
        state = lam * state + (1.0 - lam) * retrieved          # residual mix
        step += 1

    # Rank memory slots by the attention of the converged state.
    final_attn = torch.softmax(beta * (state @ memory), dim=-1)
    top_scores, top_indices = torch.topk(final_attn, top_k, dim=-1)
    return top_indices, top_scores, compute_attention_entropy(final_attn)
+
+
def main() -> None:
    """Grid-search residual Hopfield retrieval over (beta, lambda, max_iter).

    Pipeline: load questions, run a FAISS top-k baseline (which also seeds
    the LLM answer cache), batch-retrieve every grid config, generate
    answers only for unique (question, passage-set) pairs, then score every
    config against the FAISS baseline and dump everything to JSON.
    """
    parser = argparse.ArgumentParser(description="Residual Hopfield grid evaluation")
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--output", type=str, default="data/processed/residual_grid_results.json")
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    # Grid (hard-coded sweep: 5 betas x 5 lambdas x 4 iteration counts = 100 configs)
    betas = [5.0, 10.0, 20.0, 50.0, 100.0]
    lambdas = [0.5, 0.7, 0.8, 0.9, 0.95]
    max_iters_list = [1, 3, 5, 8]
    top_k = args.top_k

    total_configs = len(betas) * len(lambdas) * len(max_iters_list)
    logger.info("=" * 60)
    logger.info("Residual Hopfield Grid Search")
    logger.info("  betas: %s", betas)
    logger.info("  lambdas: %s", lambdas)
    logger.info("  max_iters: %s", max_iters_list)
    logger.info("  total configs: %d", total_configs)
    logger.info("=" * 60)

    t_start = time.time()

    # Load questions, memory bank, and models
    questions, gold_answers = load_questions(args.questions, args.max_samples)
    n = len(questions)
    logger.info("Loaded %d questions", n)

    mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
    mb.load(args.memory_bank, device=args.device)
    M = mb.embeddings  # (d, N)
    logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)

    encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device)
    generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device)

    # Encode all questions once; every grid config reuses the same embeddings.
    logger.info("Encoding questions...")
    all_embs = []
    batch_size = cfg.get("encoder", {}).get("batch_size", 64)
    for i in range(0, n, batch_size):
        all_embs.append(encoder.encode(questions[i : i + batch_size]))
    Q = torch.cat(all_embs, dim=0)  # (n, d)
    logger.info("Encoded, shape=%s", Q.shape)

    # FAISS baseline — also seeds llm_cache so identical passage sets found by
    # Hopfield configs later do not trigger a second LLM call.
    logger.info("Running FAISS baseline...")
    emb_np = mb.embeddings.T.cpu().numpy().astype(np.float32)
    faiss_ret = FAISSRetriever(top_k=top_k)
    faiss_ret.build_index(emb_np, mb.passages)

    # faiss_indices: question idx -> sorted top-k indices tuple
    # llm_cache: (question idx, frozenset(top-k indices)) -> generated answer
    faiss_indices: Dict[int, Tuple[int, ...]] = {}
    llm_cache: Dict[Tuple[int, frozenset], str] = {}

    for i in range(n):
        q_np = Q[i].cpu().numpy().astype(np.float32)
        result = faiss_ret.retrieve(q_np)
        idx_tuple = tuple(sorted(result.indices.tolist()))
        faiss_indices[i] = idx_tuple
        cache_key = (i, frozenset(idx_tuple))
        answer = generator.generate(questions[i], result.passages)
        llm_cache[cache_key] = answer
        if (i + 1) % 20 == 0:
            # Running EM/F1 over the questions answered so far.
            ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
            f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
            logger.info("  FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s))

    faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
    faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
    logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1)

    # Phase 2: Retrieve all configs (batched, fast; no LLM involved)
    logger.info("Phase 2: Retrieving all %d configs...", total_configs)
    t_ret = time.time()

    # config_key -> list of (sorted_indices_tuple, entropy) per question
    retrieval_data: Dict[Tuple[float, float, int], List[Tuple[Tuple[int, ...], float]]] = {}

    for beta in betas:
        for lam in lambdas:
            for max_iter in max_iters_list:
                indices, scores, entropy = residual_hopfield_retrieve(
                    Q, M, beta=beta, lam=lam, max_iter=max_iter, top_k=top_k
                )
                per_q = []
                for i in range(n):
                    idx_tuple = tuple(sorted(indices[i].tolist()))
                    # NOTE(review): `entropy` is a single float returned for the
                    # whole batch by residual_hopfield_retrieve, so every question
                    # here stores the same value and the later "avg_entropy" just
                    # equals it. A true per-question entropy would need the alpha
                    # rows — confirm whether that was the intent.
                    per_q.append((idx_tuple, entropy))
                retrieval_data[(beta, lam, max_iter)] = per_q

    logger.info("Retrieval done in %.1fs", time.time() - t_ret)

    # Phase 3: Dedup + generate — one LLM call per unique (question, passage set)
    needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
    for key, per_q in retrieval_data.items():
        for i, (idx_tuple, _) in enumerate(per_q):
            cache_key = (i, frozenset(idx_tuple))
            if cache_key not in llm_cache and cache_key not in needed:
                needed[cache_key] = (i, idx_tuple)

    total_grid_evals = total_configs * n
    logger.info(
        "Unique LLM calls needed: %d / %d grid evals (%.1f%% saving)",
        len(needed), total_grid_evals,
        (1 - len(needed) / total_grid_evals) * 100,
    )

    t_gen = time.time()
    for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()):
        passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long))
        answer = generator.generate(questions[q_idx], passages)
        llm_cache[cache_key] = answer
        if (call_idx + 1) % 50 == 0:
            # Progress with a naive ETA based on the average rate so far.
            elapsed = time.time() - t_gen
            rate = (call_idx + 1) / elapsed
            logger.info(
                "  Generated %d/%d (%.1f/s, ~%.0fs left)",
                call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate,
            )
    logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen)

    # Phase 4: Evaluate every config from the answer cache (no further LLM calls)
    logger.info("Phase 4: Evaluating...")
    results = []
    for beta in betas:
        for lam in lambdas:
            for max_iter in max_iters_list:
                per_q = retrieval_data[(beta, lam, max_iter)]
                ems, f1s, overlaps, entropies = [], [], [], []
                for i, (idx_tuple, ent) in enumerate(per_q):
                    cache_key = (i, frozenset(idx_tuple))
                    answer = llm_cache[cache_key]
                    ems.append(exact_match(answer, gold_answers[i]))
                    f1s.append(f1_score(answer, gold_answers[i]))
                    # Fraction of this config's top-k shared with the FAISS top-k.
                    overlap = len(set(idx_tuple) & set(faiss_indices[i])) / top_k
                    overlaps.append(overlap)
                    entropies.append(ent)

                em, f1 = np.mean(ems), np.mean(f1s)
                r = {
                    "beta": beta, "lambda": lam, "max_iter": max_iter,
                    "em": round(em, 4), "f1": round(f1, 4),
                    "avg_faiss_overlap": round(np.mean(overlaps), 4),
                    "avg_entropy": round(np.mean(entropies), 4),
                }
                results.append(r)
                # Log configs within 0.01 F1 of the baseline; *** marks a strict win.
                if f1 >= faiss_f1 - 0.01:
                    marker = " ***" if f1 > faiss_f1 else ""
                    logger.info(
                        "  β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f%s",
                        beta, lam, max_iter, em, f1, np.mean(overlaps), marker,
                    )

    # Sort by F1 (descending) so results[0] is the best config
    results.sort(key=lambda x: x["f1"], reverse=True)
    best = results[0]

    t_total = time.time() - t_start

    output = {
        "meta": {
            "n_questions": n,
            "total_configs": total_configs,
            "unique_llm_calls": len(needed),
            "faiss_llm_calls": n,
            "total_time_s": round(t_total, 1),
        },
        "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)},
        "grid_results": results,
        "best_config": best,
        "top10": results[:10],
    }

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2)

    logger.info("=" * 60)
    logger.info("RESULTS")
    logger.info("  FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1)
    logger.info(
        "  Best: β=%.1f λ=%.2f iter=%d => EM=%.4f F1=%.4f",
        best["beta"], best["lambda"], best["max_iter"], best["em"], best["f1"],
    )
    logger.info("  Top 5:")
    for i, r in enumerate(results[:5]):
        logger.info(
            "    %d. β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f",
            i + 1, r["beta"], r["lambda"], r["max_iter"], r["em"], r["f1"], r["avg_faiss_overlap"],
        )
    logger.info("  Total time: %.1fs", t_total)
    logger.info("  Saved to: %s", args.output)
    logger.info("=" * 60)
+
+
# Script entry point: run the residual-grid evaluation when invoked directly.
if __name__ == "__main__":
    main()
diff --git a/scripts/prepare_corpus.py b/scripts/prepare_corpus.py
new file mode 100644
index 0000000..93fc0ce
--- /dev/null
+++ b/scripts/prepare_corpus.py
@@ -0,0 +1,127 @@
+"""Convert linear-rag chunks.json to JSONL corpus for build_memory_bank.py.
+
+The linear-rag dataset stores chunks as a list of strings with format "idx:text...".
+This script strips the index prefix and outputs one {"text": "..."} per line.
+
+Usage:
+ python scripts/prepare_corpus.py --dataset hotpotqa
+ python scripts/prepare_corpus.py --dataset hotpotqa --dataset musique --dataset 2wikimultihop
+"""
+
+import argparse
+import json
+import logging
+from pathlib import Path
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+DATASETS = ["hotpotqa", "musique", "2wikimultihop", "medical"]
+
+
def convert_chunks(dataset: str, data_root: Path, output_dir: Path) -> Path:
    """Convert a single dataset's chunks.json to corpus JSONL.

    Each chunk is expected to look like "idx:text...". The numeric index
    prefix is stripped; chunks whose text merely *contains* a ':' but has no
    numeric prefix are kept intact (the previous version split on the first
    ':' unconditionally and could truncate such chunks). Empty chunks are
    dropped.

    Args:
        dataset: dataset name (e.g., "hotpotqa")
        data_root: path to linear-rag clone
        output_dir: directory to write output JSONL

    Returns:
        Path to the output JSONL file.

    Raises:
        FileNotFoundError: if <data_root>/<dataset>/chunks.json is missing.
    """
    chunks_path = data_root / dataset / "chunks.json"
    if not chunks_path.exists():
        raise FileNotFoundError(f"Not found: {chunks_path}")

    with open(chunks_path) as f:
        chunks = json.load(f)

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{dataset}_corpus.jsonl"

    count = 0
    with open(output_path, "w") as out:
        for chunk in chunks:
            # Strip the "idx:" prefix only when it really is a numeric index.
            prefix, sep, rest = chunk.partition(":")
            text = rest if sep and prefix.strip().isdigit() else chunk
            text = text.strip()
            if text:
                out.write(json.dumps({"text": text}) + "\n")
                count += 1

    logger.info("%s: %d chunks -> %s", dataset, count, output_path)
    return output_path
+
+
def convert_questions(dataset: str, data_root: Path, output_dir: Path) -> Path:
    """Write a dataset's questions.json as standardized JSONL records.

    Each output line carries id, question, answer, and question_type, with
    the optional fields (id, question_type) defaulting to "".

    Args:
        dataset: dataset name
        data_root: path to linear-rag clone
        output_dir: directory to write output JSONL

    Returns:
        Path to the output JSONL file.
    """
    src = data_root / dataset / "questions.json"
    if not src.exists():
        raise FileNotFoundError(f"Not found: {src}")

    with open(src) as f:
        raw = json.load(f)

    output_dir.mkdir(parents=True, exist_ok=True)
    dst = output_dir / f"{dataset}_questions.jsonl"

    written = 0
    with open(dst, "w") as out:
        for item in raw:
            # question/answer are required; missing keys raise KeyError here.
            line = json.dumps({
                "id": item.get("id", ""),
                "question": item["question"],
                "answer": item["answer"],
                "question_type": item.get("question_type", ""),
            })
            out.write(line + "\n")
            written += 1

    logger.info("%s: %d questions -> %s", dataset, written, dst)
    return dst
+
+
def main() -> None:
    """CLI entry: convert chunks and questions for each selected dataset."""
    parser = argparse.ArgumentParser(description="Prepare linear-rag data for HAG")
    parser.add_argument(
        "--dataset",
        type=str,
        action="append",
        default=None,
        help=f"Dataset(s) to process. Choices: {DATASETS}. Can specify multiple times.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/linear-rag",
        help="Path to linear-rag clone",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/processed",
        help="Output directory for processed files",
    )
    args = parser.parse_args()

    # No --dataset flags given means "process every known dataset".
    selected = args.dataset or DATASETS
    root = Path(args.data_root)
    out_dir = Path(args.output_dir)

    for name in selected:
        convert_chunks(name, root, out_dir)
        convert_questions(name, root, out_dir)
+
+
# Script entry point: run corpus preparation when invoked directly.
if __name__ == "__main__":
    main()
diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py
index 74c4710..beef76b 100644
--- a/scripts/run_baseline.py
+++ b/scripts/run_baseline.py
@@ -27,6 +27,7 @@ def main() -> None:
parser.add_argument("--memory-bank", type=str, required=True)
parser.add_argument("--question", type=str, required=True)
parser.add_argument("--top-k", type=int, default=5)
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -44,17 +45,17 @@ def main() -> None:
from hag.config import MemoryBankConfig
mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank) # FAISS needs CPU, load on CPU
# Build FAISS index from memory bank embeddings
import numpy as np
- embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d)
+ embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32) # (N, d)
faiss_ret = FAISSRetriever(top_k=args.top_k)
faiss_ret.build_index(embeddings_np, mb.passages)
- encoder = Encoder(pipeline_config.encoder)
- generator = Generator(pipeline_config.generator)
+ encoder = Encoder(pipeline_config.encoder, device=args.device)
+ generator = Generator(pipeline_config.generator, device=args.device)
pipeline = RAGPipeline(
config=pipeline_config,
diff --git a/scripts/run_comparison.py b/scripts/run_comparison.py
new file mode 100644
index 0000000..29f23f8
--- /dev/null
+++ b/scripts/run_comparison.py
@@ -0,0 +1,186 @@
+"""Run side-by-side comparison of FAISS (baseline) vs Hopfield (HAG) retrieval.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=1 python scripts/run_comparison.py \
+ --config configs/hotpotqa.yaml \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda \
+ --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+import time
+
+import numpy as np
+import torch
+import yaml
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset, exact_match, f1_score
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+
def main() -> None:
    """Run FAISS and Hopfield RAG pipelines side by side on the same questions.

    Loads config/questions/memory bank, builds both pipelines around a shared
    encoder and generator, evaluates EM/F1 and wall time for each, logs a
    summary table, and writes per-question details to --output as JSON.
    """
    parser = argparse.ArgumentParser(description="Compare FAISS vs Hopfield retrieval")
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--output", type=str, default="data/processed/comparison_results.json")
    args = parser.parse_args()

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
    memory_config = MemoryBankConfig(**cfg.get("memory", {}))
    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
    generator_config = GeneratorConfig(**cfg.get("generator", {}))

    # Load questions (JSONL, optionally truncated to --max-samples)
    with open(args.questions) as f:
        questions_data = [json.loads(line) for line in f]
    if args.max_samples and len(questions_data) > args.max_samples:
        questions_data = questions_data[: args.max_samples]

    questions = [q["question"] for q in questions_data]
    gold_answers = [q["answer"] for q in questions_data]
    logger.info("Loaded %d questions", len(questions))

    # Load memory bank
    mb = MemoryBank(memory_config)
    mb.load(args.memory_bank, device=args.device)
    logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)

    # Shared encoder and generator (both pipelines reuse the same models)
    encoder = Encoder(encoder_config, device=args.device)
    generator = Generator(generator_config, device=args.device)

    # --- Build FAISS retriever ---
    embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32)  # (N, d)
    faiss_retriever = FAISSRetriever(top_k=hopfield_config.top_k)
    faiss_retriever.build_index(embeddings_np, mb.passages)

    # --- Build Hopfield retriever ---
    # NOTE(review): hopfield_retriever is constructed here but never passed to
    # a pipeline below (the hopfield pipeline receives memory_bank=mb instead)
    # — confirm whether this object is actually needed.
    hopfield = HopfieldRetrieval(hopfield_config)
    hopfield_retriever = HopfieldRetriever(hopfield, mb, top_k=hopfield_config.top_k)

    # --- Build pipelines ---
    faiss_pipeline_cfg = PipelineConfig(
        hopfield=hopfield_config,
        memory=memory_config,
        encoder=encoder_config,
        generator=generator_config,
        retriever_type="faiss",
        device=args.device,
    )
    faiss_pipeline = RAGPipeline(
        config=faiss_pipeline_cfg,
        encoder=encoder,
        generator=generator,
        faiss_retriever=faiss_retriever,
    )

    hopfield_pipeline_cfg = PipelineConfig(
        hopfield=hopfield_config,
        memory=memory_config,
        encoder=encoder_config,
        generator=generator_config,
        retriever_type="hopfield",
        device=args.device,
    )
    hopfield_pipeline = RAGPipeline(
        config=hopfield_pipeline_cfg,
        encoder=encoder,
        generator=generator,
        memory_bank=mb,
    )

    # --- Run FAISS baseline ---
    logger.info("=" * 60)
    logger.info("Running FAISS baseline (%d questions)...", len(questions))
    t0 = time.time()
    faiss_results = faiss_pipeline.run_batch(questions)
    faiss_time = time.time() - t0
    faiss_metrics = evaluate_dataset(faiss_results, gold_answers)
    logger.info("FAISS done in %.1fs | EM=%.4f | F1=%.4f", faiss_time, faiss_metrics["em"], faiss_metrics["f1"])

    # --- Run HAG ---
    logger.info("=" * 60)
    logger.info("Running HAG (beta=%.1f, max_iter=%d, top_k=%d) (%d questions)...",
                hopfield_config.beta, hopfield_config.max_iter, hopfield_config.top_k, len(questions))
    t0 = time.time()
    hag_results = hopfield_pipeline.run_batch(questions)
    hag_time = time.time() - t0
    hag_metrics = evaluate_dataset(hag_results, gold_answers)
    logger.info("HAG done in %.1fs | EM=%.4f | F1=%.4f", hag_time, hag_metrics["em"], hag_metrics["f1"])

    # --- Summary ---
    logger.info("=" * 60)
    logger.info("COMPARISON SUMMARY")
    logger.info("%-20s %10s %10s", "", "FAISS", "HAG")
    logger.info("%-20s %10.4f %10.4f", "Exact Match", faiss_metrics["em"], hag_metrics["em"])
    logger.info("%-20s %10.4f %10.4f", "F1 Score", faiss_metrics["f1"], hag_metrics["f1"])
    logger.info("%-20s %10.1fs %10.1fs", "Time", faiss_time, hag_time)
    em_delta = hag_metrics["em"] - faiss_metrics["em"]
    f1_delta = hag_metrics["f1"] - faiss_metrics["f1"]
    logger.info("%-20s %+10.4f %+10.4f", "Delta (HAG - FAISS)", em_delta, f1_delta)

    # --- Per-question details (for the JSON dump) ---
    per_question = []
    for i, (fq, hq, gold) in enumerate(zip(faiss_results, hag_results, gold_answers)):
        per_question.append({
            "id": questions_data[i].get("id", i),
            "question": questions[i],
            "gold_answer": gold,
            "faiss_answer": fq.answer,
            "hag_answer": hq.answer,
            "faiss_em": exact_match(fq.answer, gold),
            "hag_em": exact_match(hq.answer, gold),
            "faiss_f1": f1_score(fq.answer, gold),
            "hag_f1": f1_score(hq.answer, gold),
            "faiss_passages": fq.retrieved_passages,
            "hag_passages": hq.retrieved_passages,
        })

    output = {
        "config": {
            "hopfield_beta": hopfield_config.beta,
            "hopfield_max_iter": hopfield_config.max_iter,
            "top_k": hopfield_config.top_k,
            "encoder": encoder_config.model_name,
            "generator": generator_config.model_name,
            "num_questions": len(questions),
            "num_passages": mb.size,
        },
        "faiss_metrics": {**faiss_metrics, "time_seconds": faiss_time},
        "hag_metrics": {**hag_metrics, "time_seconds": hag_time},
        "per_question": per_question,
    }

    with open(args.output, "w") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    logger.info("Full results saved to %s", args.output)
+
+
# Script entry point: run the FAISS-vs-HAG comparison when invoked directly.
if __name__ == "__main__":
    main()
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
index 713b3c2..144fc2f 100644
--- a/scripts/run_eval.py
+++ b/scripts/run_eval.py
@@ -36,6 +36,7 @@ def main() -> None:
parser.add_argument("--split", type=str, default="validation")
parser.add_argument("--max-samples", type=int, default=500)
parser.add_argument("--output", type=str, default="results.json")
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -47,15 +48,16 @@ def main() -> None:
encoder=EncoderConfig(**cfg.get("encoder", {})),
generator=GeneratorConfig(**cfg.get("generator", {})),
retriever_type=cfg.get("retriever_type", "hopfield"),
+ device=args.device,
)
# Load memory bank
mb = MemoryBank(pipeline_config.memory)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
# Build pipeline
- encoder = Encoder(pipeline_config.encoder)
- generator = Generator(pipeline_config.generator)
+ encoder = Encoder(pipeline_config.encoder, device=args.device)
+ generator = Generator(pipeline_config.generator, device=args.device)
pipeline = RAGPipeline(
config=pipeline_config,
encoder=encoder,
diff --git a/scripts/run_grid_search.py b/scripts/run_grid_search.py
new file mode 100644
index 0000000..ddd5a8d
--- /dev/null
+++ b/scripts/run_grid_search.py
@@ -0,0 +1,552 @@
+"""Grid search over HAG hyperparameters (beta, max_iter) with dedup-based LLM caching.
+
+Key insight: many (beta, max_iter) combos retrieve the same top-k passages for a given
+question. By deduplicating on (question_idx, frozenset(top_k_indices)), we call the LLM
+only for unique passage sets, saving ~80-89% of generation calls.
+
+Usage:
+ python scripts/run_grid_search.py \
+ --config configs/hotpotqa.yaml \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda \
+ --max-samples 100
+"""
+
+import argparse
+import json
+import logging
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.energy import compute_attention_entropy, compute_energy_curve, compute_energy_gap
+from hag.generator import Generator
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.metrics import exact_match, f1_score
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+
@dataclass
class GridPoint:
    """Aggregated evaluation results for a single (beta, max_iter) grid cell."""

    beta: float  # Hopfield inverse temperature for this configuration
    max_iter: int  # number of Hopfield update iterations
    em: float  # mean exact-match score over the question set
    f1: float  # mean token-level F1 over the question set
    avg_entropy: float  # mean attention entropy of the final state
    avg_energy_gap: float  # mean energy drop along the update trajectory
    avg_faiss_overlap: float  # mean fraction of top-k shared with FAISS
    avg_steps: float  # mean number of Hopfield steps actually taken
+
+
def load_questions(path: str, max_samples: Optional[int] = None) -> Tuple[List[str], List[str]]:
    """Read question/answer pairs from a JSONL file.

    Args:
        path: JSONL file with 'question' and 'answer' fields per line
        max_samples: if truthy, stop after this many records

    Returns:
        Tuple of (questions, gold_answers), parallel lists.
    """
    qs: List[str] = []
    golds: List[str] = []
    with open(path) as fh:
        for raw in fh:
            rec = json.loads(raw)
            qs.append(rec["question"])
            golds.append(rec["answer"])
            # Falsy max_samples (None or 0) means read the whole file.
            if max_samples and len(qs) >= max_samples:
                break
    return qs, golds
+
+
def encode_questions_batched(
    encoder: Encoder, questions: List[str], batch_size: int = 32
) -> torch.Tensor:
    """Embed all questions, encoding in fixed-size batches.

    Args:
        encoder: encoder exposing .encode(list_of_str) -> (b, d) tensor
        questions: question strings to embed
        batch_size: number of questions per encode call

    Returns:
        (N, d) tensor of query embeddings, in input order.
    """
    chunks = [
        encoder.encode(questions[start : start + batch_size])
        for start in range(0, len(questions), batch_size)
    ]
    return torch.cat(chunks, dim=0)  # (N, d)
+
+
def run_faiss_baseline(
    query_embeddings: torch.Tensor,
    memory_bank: MemoryBank,
    generator: Generator,
    questions: List[str],
    gold_answers: List[str],
    top_k: int,
) -> Tuple[Dict[str, float], Dict[int, Tuple[str, Tuple[int, ...]]]]:
    """Answer every question via FAISS top-k retrieval and cache the results.

    Args:
        query_embeddings: (N, d) tensor of query embeddings
        memory_bank: the memory bank providing passage embeddings and texts
        generator: LLM generator
        questions: list of question strings
        gold_answers: list of gold answer strings
        top_k: number of passages to retrieve per question

    Returns:
        Tuple of (metrics_dict, faiss_cache), where metrics_dict holds mean
        "em"/"f1" and faiss_cache maps question_idx -> (answer, sorted top-k
        indices tuple).
    """
    logger.info("Building FAISS index...")
    passage_matrix = memory_bank.embeddings.T.cpu().numpy().astype(np.float32)  # (N_passages, d)
    retriever = FAISSRetriever(top_k=top_k)
    retriever.build_index(passage_matrix, memory_bank.passages)

    cache: Dict[int, Tuple[str, Tuple[int, ...]]] = {}
    em_scores: List[float] = []
    f1_scores: List[float] = []

    logger.info("Running FAISS baseline on %d questions...", len(questions))
    for idx, q_text in enumerate(questions):
        q_vec = query_embeddings[idx].cpu().numpy().astype(np.float32)  # (d,)
        hit = retriever.retrieve(q_vec)
        prediction = generator.generate(q_text, hit.passages)
        sorted_indices = tuple(sorted(hit.indices.tolist()))

        cache[idx] = (prediction, sorted_indices)
        em_scores.append(exact_match(prediction, gold_answers[idx]))
        f1_scores.append(f1_score(prediction, gold_answers[idx]))

        # Periodic progress report with running averages.
        if (idx + 1) % 20 == 0:
            logger.info(
                "  FAISS baseline: %d/%d (EM=%.3f, F1=%.3f)",
                idx + 1,
                len(questions),
                sum(em_scores) / len(em_scores),
                sum(f1_scores) / len(f1_scores),
            )

    metrics = {
        "em": sum(em_scores) / len(em_scores),
        "f1": sum(f1_scores) / len(f1_scores),
    }
    logger.info("FAISS baseline: EM=%.4f, F1=%.4f", metrics["em"], metrics["f1"])
    return metrics, cache
+
+
+def run_hopfield_grid(
+ query_embeddings: torch.Tensor,
+ memory_bank: MemoryBank,
+ generator: Generator,
+ questions: List[str],
+ gold_answers: List[str],
+ faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]],
+ betas: List[float],
+ max_iters: List[int],
+ top_k: int,
+ device: str,
+) -> Tuple[List[GridPoint], Dict]:
+ """Run grid search over (beta, max_iter) with dedup-based LLM caching.
+
+ Phase 2: Retrieve all configs (fast, batched).
+ Phase 3: Deduplicate and generate (LLM calls only for unique passage sets).
+ Phase 4: Evaluate and collect results.
+
+ Args:
+ query_embeddings: (N, d) tensor on device
+ memory_bank: memory bank (embeddings on device)
+ generator: LLM generator
+ questions: list of question strings
+ gold_answers: list of gold answer strings
+ faiss_cache: maps question_idx -> (answer, sorted_indices_tuple)
+ betas: list of beta values to sweep
+ max_iters: list of max_iter values to sweep
+ top_k: fixed top_k for retrieval
+ device: computation device
+
+ Returns:
+ Tuple of (grid_results, meta_dict).
+ """
+ n_questions = len(questions)
+ memory = memory_bank.embeddings # (d, N_passages) on device
+
+ # =========================================================================
+ # Phase 2: Retrieve all configurations (batched, milliseconds each)
+ # =========================================================================
+ logger.info("Phase 2: Running %d retrieval configs...", len(betas) * len(max_iters))
+
+ # Structure: config_key -> per-question retrieval data
+ # retrieval_data[config_key][q_idx] = {indices_tuple, entropy, energy_gap, steps, faiss_overlap}
+ @dataclass
+ class RetrievalInfo:
+ indices_tuple: Tuple[int, ...]
+ entropy: float
+ energy_gap: float
+ steps: int
+ faiss_overlap: float
+
+ retrieval_data: Dict[Tuple[float, int], List[RetrievalInfo]] = {}
+
+ t_retrieve_start = time.time()
+ for beta in betas:
+ for max_iter in max_iters:
+ config = HopfieldConfig(beta=beta, max_iter=max_iter, top_k=top_k)
+ hopfield = HopfieldRetrieval(config)
+
+ # Batched retrieval: all questions at once
+ result = hopfield.retrieve(
+ query_embeddings, memory, return_energy=True
+ ) # attention_weights: (N_questions, N_passages)
+
+ alpha = result.attention_weights # (N_questions, N_passages)
+ k = min(top_k, alpha.shape[-1])
+ scores, indices = torch.topk(alpha, k, dim=-1) # (N_questions, k)
+
+ # Compute energy curve per-question (energy_curve contains batch tensors)
+ energy_curves_raw = result.energy_curve # list of (N_questions,) tensors
+
+ infos = []
+ for q_idx in range(n_questions):
+ q_indices = sorted(indices[q_idx].tolist())
+ q_indices_tuple = tuple(q_indices)
+
+ # Per-question entropy
+ q_entropy = compute_attention_entropy(alpha[q_idx])
+
+ # Per-question energy gap
+ if energy_curves_raw is not None:
+ q_energies = [e[q_idx].item() for e in energy_curves_raw]
+ q_energy_gap = compute_energy_gap(q_energies)
+ else:
+ q_energy_gap = 0.0
+
+ # FAISS overlap: fraction of top-k indices shared with FAISS
+ faiss_indices_set = set(faiss_cache[q_idx][1])
+ hopfield_indices_set = set(q_indices)
+ overlap = len(faiss_indices_set & hopfield_indices_set) / k
+
+ infos.append(RetrievalInfo(
+ indices_tuple=q_indices_tuple,
+ entropy=q_entropy,
+ energy_gap=q_energy_gap,
+ steps=result.num_steps,
+ faiss_overlap=overlap,
+ ))
+
+ retrieval_data[(beta, max_iter)] = infos
+
+ t_retrieve_end = time.time()
+ logger.info("Phase 2 complete: %.2fs for all retrieval configs", t_retrieve_end - t_retrieve_start)
+
+ # =========================================================================
+ # Phase 3: Deduplicate and generate
+ # =========================================================================
+ logger.info("Phase 3: Deduplicating and generating...")
+
+ # Build set of unique (question_idx, passage_set) combos needing LLM calls
+ # Cache key: (question_idx, frozenset(top_k_indices))
+ llm_cache: Dict[Tuple[int, frozenset], str] = {}
+
+ # Seed cache with FAISS answers (same passage sets don't need re-generation)
+ for q_idx, (answer, indices_tuple) in faiss_cache.items():
+ cache_key = (q_idx, frozenset(indices_tuple))
+ llm_cache[cache_key] = answer
+
+ # Collect all unique keys we need
+ needed_keys: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
+ for (beta, max_iter), infos in retrieval_data.items():
+ for q_idx, info in enumerate(infos):
+ cache_key = (q_idx, frozenset(info.indices_tuple))
+ if cache_key not in llm_cache and cache_key not in needed_keys:
+ needed_keys[cache_key] = (q_idx, info.indices_tuple)
+
+ total_grid_calls = n_questions * len(betas) * len(max_iters)
+ already_cached = total_grid_calls - len(needed_keys) # rough; some may still be unique
+ logger.info(
+ "Unique LLM calls needed: %d (out of %d grid points, %.1f%% saving)",
+ len(needed_keys),
+ total_grid_calls,
+ (1 - len(needed_keys) / total_grid_calls) * 100 if total_grid_calls > 0 else 0,
+ )
+
+ # Generate answers for unique passage sets
+ t_gen_start = time.time()
+ for call_idx, (cache_key, (q_idx, indices_tuple)) in enumerate(needed_keys.items()):
+ # Look up passages by sorted indices
+ indices_tensor = torch.tensor(list(indices_tuple), dtype=torch.long)
+ passages = memory_bank.get_passages_by_indices(indices_tensor)
+ answer = generator.generate(questions[q_idx], passages)
+ llm_cache[cache_key] = answer
+
+ if (call_idx + 1) % 20 == 0:
+ elapsed = time.time() - t_gen_start
+ rate = (call_idx + 1) / elapsed
+ remaining = (len(needed_keys) - call_idx - 1) / rate
+ logger.info(
+ " Generated %d/%d (%.1f calls/s, ~%.0fs remaining)",
+ call_idx + 1,
+ len(needed_keys),
+ rate,
+ remaining,
+ )
+
+ t_gen_end = time.time()
+ logger.info("Phase 3 complete: %d LLM calls in %.1fs", len(needed_keys), t_gen_end - t_gen_start)
+
+ # =========================================================================
+ # Phase 4: Evaluate all grid points
+ # =========================================================================
+ logger.info("Phase 4: Evaluating all grid points...")
+
+ grid_results: List[GridPoint] = []
+ for beta in betas:
+ for max_iter in max_iters:
+ infos = retrieval_data[(beta, max_iter)]
+ em_scores = []
+ f1_scores = []
+ entropies = []
+ energy_gaps = []
+ faiss_overlaps = []
+ steps_list = []
+
+ for q_idx, info in enumerate(infos):
+ cache_key = (q_idx, frozenset(info.indices_tuple))
+ answer = llm_cache[cache_key]
+
+ em_scores.append(exact_match(answer, gold_answers[q_idx]))
+ f1_scores.append(f1_score(answer, gold_answers[q_idx]))
+ entropies.append(info.entropy)
+ energy_gaps.append(info.energy_gap)
+ faiss_overlaps.append(info.faiss_overlap)
+ steps_list.append(info.steps)
+
+ gp = GridPoint(
+ beta=beta,
+ max_iter=max_iter,
+ em=sum(em_scores) / len(em_scores),
+ f1=sum(f1_scores) / len(f1_scores),
+ avg_entropy=sum(entropies) / len(entropies),
+ avg_energy_gap=sum(energy_gaps) / len(energy_gaps),
+ avg_faiss_overlap=sum(faiss_overlaps) / len(faiss_overlaps),
+ avg_steps=sum(steps_list) / len(steps_list),
+ )
+ grid_results.append(gp)
+ logger.info(
+ " beta=%.2f max_iter=%2d => EM=%.3f F1=%.3f entropy=%.3f energy_gap=%.3f faiss_overlap=%.3f",
+ beta,
+ max_iter,
+ gp.em,
+ gp.f1,
+ gp.avg_entropy,
+ gp.avg_energy_gap,
+ gp.avg_faiss_overlap,
+ )
+
+ total_llm_calls = len(faiss_cache) + len(needed_keys)
+ meta = {
+ "grid_size": len(betas) * len(max_iters),
+ "n_questions": n_questions,
+ "total_grid_evaluations": total_grid_calls,
+ "unique_llm_calls": len(needed_keys),
+ "faiss_llm_calls": len(faiss_cache),
+ "total_llm_calls": total_llm_calls,
+ "savings_pct": round(
+ (1 - total_llm_calls / (total_grid_calls + len(faiss_cache))) * 100, 1
+ )
+ if (total_grid_calls + len(faiss_cache)) > 0
+ else 0,
+ "retrieval_time_s": round(t_retrieve_end - t_retrieve_start, 2),
+ "generation_time_s": round(t_gen_end - t_gen_start, 2),
+ }
+
+ return grid_results, meta
+
+
def main() -> None:
    """CLI entry point: grid search over Hopfield hyperparameters (beta, max_iter).

    Pipeline: parse args -> load config/questions/memory bank/encoder/generator
    once -> encode all questions in one batch -> run the FAISS baseline ->
    run the Hopfield grid (reusing the FAISS answer cache to dedupe LLM
    calls) -> pick the best config by F1 -> dump a JSON summary.
    """
    parser = argparse.ArgumentParser(
        description="Grid search over HAG hyperparameters (beta, max_iter)"
    )
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output JSON path (default: data/processed/grid_search_results.json)",
    )
    parser.add_argument(
        "--betas",
        type=float,
        nargs="+",
        default=[0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 8.0],
    )
    parser.add_argument(
        "--max-iters",
        type=int,
        nargs="+",
        default=[1, 2, 3, 5, 8, 15],
    )
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    # Local import: yaml is only needed to parse the config file.
    import yaml

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    output_path = args.output or "data/processed/grid_search_results.json"

    # =========================================================================
    # Phase 1: Load everything once
    # =========================================================================
    logger.info("=" * 60)
    logger.info("HAG Grid Search")
    logger.info("  betas: %s", args.betas)
    logger.info("  max_iters: %s", args.max_iters)
    logger.info("  top_k: %d", args.top_k)
    logger.info("  grid points: %d", len(args.betas) * len(args.max_iters))
    logger.info("  max_samples: %d", args.max_samples)
    logger.info("  device: %s", args.device)
    logger.info("=" * 60)

    t_start = time.time()

    # Load questions (capped at max_samples)
    logger.info("Loading questions from %s...", args.questions)
    questions, gold_answers = load_questions(args.questions, args.max_samples)
    logger.info("Loaded %d questions", len(questions))

    # Load memory bank onto the requested device
    logger.info("Loading memory bank from %s...", args.memory_bank)
    mb_config = MemoryBankConfig(**cfg.get("memory", {}))
    memory_bank = MemoryBank(mb_config)
    memory_bank.load(args.memory_bank, device=args.device)
    logger.info("Memory bank: %d passages, dim=%d", memory_bank.size, memory_bank.dim)

    # Load encoder
    logger.info("Loading encoder...")
    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
    encoder = Encoder(encoder_config, device=args.device)

    # Load generator
    logger.info("Loading generator...")
    generator_config = GeneratorConfig(**cfg.get("generator", {}))
    generator = Generator(generator_config, device=args.device)

    # Encode all questions once -- the same embeddings are reused for the
    # FAISS baseline and every (beta, max_iter) grid point.
    logger.info("Encoding %d questions...", len(questions))
    t_enc_start = time.time()
    query_embeddings = encode_questions_batched(
        encoder, questions, batch_size=encoder_config.batch_size
    )  # (N, d) on device
    t_enc_end = time.time()
    logger.info("Encoded in %.2fs, shape=%s", t_enc_end - t_enc_start, query_embeddings.shape)

    # =========================================================================
    # Run FAISS baseline (also seeds the LLM answer cache used by the grid)
    # =========================================================================
    faiss_metrics, faiss_cache = run_faiss_baseline(
        query_embeddings, memory_bank, generator, questions, gold_answers, args.top_k
    )

    # =========================================================================
    # Run Hopfield grid search
    # =========================================================================
    grid_results, meta = run_hopfield_grid(
        query_embeddings,
        memory_bank,
        generator,
        questions,
        gold_answers,
        faiss_cache,
        betas=args.betas,
        max_iters=args.max_iters,
        top_k=args.top_k,
        device=args.device,
    )

    # =========================================================================
    # Find best config and save results
    # =========================================================================
    # Best grid point selected by F1 (ties resolved by iteration order).
    best = max(grid_results, key=lambda gp: gp.f1)

    t_total = time.time() - t_start
    meta["total_time_s"] = round(t_total, 1)

    output = {
        "meta": meta,
        "faiss_baseline": faiss_metrics,
        "grid_results": [
            {
                "beta": gp.beta,
                "max_iter": gp.max_iter,
                "em": round(gp.em, 4),
                "f1": round(gp.f1, 4),
                "avg_entropy": round(gp.avg_entropy, 4),
                "avg_energy_gap": round(gp.avg_energy_gap, 4),
                "avg_faiss_overlap": round(gp.avg_faiss_overlap, 4),
                "avg_steps": round(gp.avg_steps, 2),
            }
            for gp in grid_results
        ],
        "best_config": {
            "beta": best.beta,
            "max_iter": best.max_iter,
            "em": round(best.em, 4),
            "f1": round(best.f1, 4),
            "avg_entropy": round(best.avg_entropy, 4),
            "avg_energy_gap": round(best.avg_energy_gap, 4),
            "avg_faiss_overlap": round(best.avg_faiss_overlap, 4),
        },
    }

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(output, f, indent=2)

    logger.info("=" * 60)
    logger.info("RESULTS SUMMARY")
    logger.info("  FAISS baseline: EM=%.4f, F1=%.4f", faiss_metrics["em"], faiss_metrics["f1"])
    logger.info(
        "  Best HAG config: beta=%.2f, max_iter=%d => EM=%.4f, F1=%.4f",
        best.beta,
        best.max_iter,
        best.em,
        best.f1,
    )
    logger.info("  Total LLM calls: %d (saved %.1f%%)", meta["total_llm_calls"], meta["savings_pct"])
    logger.info("  Total time: %.1fs", t_total)
    logger.info("  Results saved to: %s", output_path)
    logger.info("=" * 60)
diff --git a/scripts/run_hag.py b/scripts/run_hag.py
index 4cacd1a..b6c9004 100644
--- a/scripts/run_hag.py
+++ b/scripts/run_hag.py
@@ -33,6 +33,7 @@ def main() -> None:
parser.add_argument("--beta", type=float, default=None)
parser.add_argument("--max-iter", type=int, default=None)
parser.add_argument("--top-k", type=int, default=None)
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -52,13 +53,14 @@ def main() -> None:
encoder=EncoderConfig(**cfg.get("encoder", {})),
generator=GeneratorConfig(**cfg.get("generator", {})),
retriever_type="hopfield",
+ device=args.device,
)
mb = MemoryBank(pipeline_config.memory)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
- encoder = Encoder(pipeline_config.encoder)
- generator = Generator(pipeline_config.generator)
+ encoder = Encoder(pipeline_config.encoder, device=args.device)
+ generator = Generator(pipeline_config.generator, device=args.device)
pipeline = RAGPipeline(
config=pipeline_config,
diff --git a/scripts/visualize_energy.py b/scripts/visualize_energy.py
new file mode 100644
index 0000000..f39953c
--- /dev/null
+++ b/scripts/visualize_energy.py
@@ -0,0 +1,443 @@
+"""Visualize Hopfield energy landscape: centered vs uncentered.
+
+Produces 4 figures, each with centered/uncentered side-by-side:
+ 1. 2D contour + Hopfield trajectory
+ 2. 1D energy profile along key directions
+ 3. UMAP of memories + query trajectories
+ 4. PCA top-2 energy heatmap
+
+Usage:
+ CUDA_VISIBLE_DEVICES=1 python -u scripts/visualize_energy.py \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda --query-idx 0
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import matplotlib.colors as mcolors
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+sys.path.insert(0, "/home/yurenh2/HAG")
+
+from hag.config import MemoryBankConfig, EncoderConfig
+from hag.memory_bank import MemoryBank
+from hag.encoder import Encoder
+
+
+# ── Helpers ──────────────────────────────────────────────────────────
+
def compute_energy(q: torch.Tensor, M: torch.Tensor, beta: float) -> torch.Tensor:
    """Modern-Hopfield energy: E(q) = -1/β · logsumexp(β · qᵀM) + ½‖q‖².

    Args:
        q: query point(s), shape (..., d).
        M: memory matrix, shape (d, N).
        beta: inverse temperature.

    Returns:
        Energy value(s) with shape (...).
    """
    # Attraction term: smooth max over memory similarities.
    logits = beta * q.matmul(M)                       # (..., N)
    attraction = -1.0 / beta * torch.logsumexp(logits, dim=-1)
    # Quadratic regularizer keeping the query bounded.
    regularizer = 0.5 * q.pow(2).sum(dim=-1)
    return attraction + regularizer
+
+
def hopfield_trajectory(q0: torch.Tensor, M: torch.Tensor, beta: float,
                        max_iter: int = 15) -> torch.Tensor:
    """Iterate the softmax Hopfield update from q0, recording every state.

    Args:
        q0: initial query, shape (d,) or (1, d).
        M: memory matrix, shape (d, N).
        beta: inverse temperature.
        max_iter: maximum number of update steps.

    Returns:
        Trajectory tensor of shape (T+1, d): the initial query followed by
        up to ``max_iter`` updated states (stops early on convergence).
    """
    state = q0.clone() if q0.dim() != 1 else q0.clone().unsqueeze(0)  # (1, d)
    history = [state.squeeze(0).clone()]
    for _ in range(max_iter):
        weights = torch.softmax(beta * state.matmul(M), dim=-1)
        nxt = weights.matmul(M.T)
        history.append(nxt.squeeze(0).clone())
        # Fixed point reached: further updates would repeat the same state.
        if torch.norm(nxt - state) < 1e-8:
            break
        state = nxt
    return torch.stack(history, dim=0)  # (T+1, d)
+
+
def orthonormalize(v1: torch.Tensor, v2: torch.Tensor):
    """Gram-Schmidt: an orthonormal basis (e1, e2) for the span of v1, v2.

    If v2 is numerically parallel to v1, a random direction orthogonal to
    v1 is substituted so a valid 2D plane is always returned.
    """
    e1 = v1 / torch.norm(v1)
    # Remove the e1 component from v2.
    residual = v2 - (v2 @ e1) * e1
    if torch.norm(residual) < 1e-8:
        # Degenerate case: v1 ∥ v2 — fall back to a random orthogonal direction.
        noise = torch.randn_like(v1)
        residual = noise - (noise @ e1) * e1
    e2 = residual / torch.norm(residual)
    return e1, e2
+
+
def project_to_plane(points: torch.Tensor, e1: torch.Tensor, e2: torch.Tensor):
    """Coordinates of (K, d) points in the plane spanned by e1, e2.

    Returns a (K, 2) tensor of (e1, e2) coordinates.
    """
    coord1 = points @ e1
    coord2 = points @ e2
    return torch.stack([coord1, coord2], dim=-1)
+
+
+# ── Figure 1: 2D Contour + Trajectory ───────────────────────────────
+
def fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """2D energy contour on the query-centroid plane, with Hopfield trajectories.

    Rows compare uncentered vs centered memory; columns sweep over beta.

    Args:
        q0_raw: raw (uncentered) query embedding, shape (d,).
        M_raw: raw memory matrix, shape (d, N).
        M_cent: centered memory matrix, shape (d, N).
        mu: centroid subtracted to produce M_cent, shape (d,).
        betas_plot: beta values, one subplot column each.
        outdir: Path of the directory receiving fig1_contour.png.
        device: torch device string for the energy-grid computation.
    """

    # NOTE(review): recomputed here even though mu is passed in — presumably
    # identical to mu; confirm against the caller.
    centroid = M_raw.mean(dim=1)  # (d,)
    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for col, beta in enumerate(betas_plot):
        for row, (label, M, q0, ref_point, ref_label) in enumerate([
            ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
            ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
        ]):
            ax = axes[row, col]

            # Define 2D plane: query direction + centroid/origin direction.
            # When the reference is the origin (centered row) its norm is ~0,
            # so the first memory column is used as the second spanning vector.
            e1, e2 = orthonormalize(q0.to(device), ref_point.to(device) if ref_point.norm() > 1e-6 else M.to(device)[:, 0])

            # Grid of points in the (e1, e2) plane on which energy is evaluated
            grid_range = 1.5
            n_grid = 150
            xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            xx, yy = torch.meshgrid(xs, ys, indexing='ij')
            grid_points = xx.reshape(-1, 1) * e1.unsqueeze(0) + yy.reshape(-1, 1) * e2.unsqueeze(0)  # (n^2, d)

            E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy()

            # Hopfield trajectory projected onto the plot plane
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_2d = project_to_plane(traj, e1, e2).cpu().numpy()

            # Project memories
            mem_2d = project_to_plane(M.T.to(device), e1, e2).cpu().numpy()

            # Project reference point
            ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), e1, e2).cpu().numpy()

            # Plot
            xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()
            # Clip energy to the [1st, 95th] percentile so outliers don't
            # flatten the contour colormap
            E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95))
            cs = ax.contourf(xx_np, yy_np, E_clip, levels=40, cmap='viridis')
            ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5)

            # Memories (small dots)
            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            # Reference point (centroid or origin star)
            if ref_point.norm() > 1e-6:
                ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                           zorder=5, label=ref_label)
            else:
                ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin')

            # Trajectory: start marked in green, final state in magenta
            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                    linewidth=2, zorder=4, label='trajectory')
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                       zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                       zorder=5, label=f'q_T (T={len(traj_2d)-1})')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("e₁ (query dir)")
            ax.set_ylabel("e₂")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 1: 2D Energy Contour + Hopfield Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig1_contour.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig1_contour.png'}")
+
+
+# ── Figure 2: 1D Energy Profile ─────────────────────────────────────
+
def fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """1D energy profiles along key directions, centered vs uncentered.

    For each beta, plots E(q) along the straight line from the query toward
    (a) the centroid/origin, (b) the query's top-1 memory, and (for the
    uncentered case) (c) the origin; t=0 is the query, t=1 the target.

    Args:
        q0_raw: raw (uncentered) query embedding, shape (d,).
        M_raw: raw memory matrix, shape (d, N).
        M_cent: centered memory matrix, shape (d, N).
        mu: centroid subtracted to produce M_cent, shape (d,).
        betas_plot: beta values, one subplot column each.
        outdir: Path of the directory receiving fig2_profile.png.
        device: torch device string.
    """

    centroid = M_raw.mean(dim=1)
    q0_cent = q0_raw - mu

    # Top-1 memory (highest inner-product score) for each variant
    scores_raw = q0_raw @ M_raw
    top1_raw_idx = scores_raw.argmax().item()
    top1_raw = M_raw[:, top1_raw_idx]

    scores_cent = q0_cent @ M_cent
    top1_cent_idx = scores_cent.argmax().item()
    top1_cent = M_cent[:, top1_cent_idx]

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 10),
                             squeeze=False)

    # Interpolation parameter; extends past both endpoints to show context
    ts = torch.linspace(-0.5, 2.0, 300, device=device)

    for col, beta in enumerate(betas_plot):
        # Uncentered
        ax = axes[0, col]
        for target, name, color in [
            (centroid, "→ centroid", "red"),
            (top1_raw, f"→ memory[{top1_raw_idx}]", "blue"),
            (torch.zeros_like(q0_raw), "→ origin", "gray"),
        ]:
            direction = target - q0_raw.to(device)
            if direction.norm() < 1e-8:
                continue  # query already sits at this target; no line to draw
            points = q0_raw.unsqueeze(0).to(device) + ts.unsqueeze(1) * direction.unsqueeze(0)
            E = compute_energy(points, M_raw.to(device), beta).cpu().numpy()
            ax.plot(ts.cpu().numpy(), E, label=name, color=color, linewidth=2)

        # Mark t=0 (query) and t=1 (target).
        # (Removed a dead compute_energy(q0) call that was never used.)
        ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀')
        ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target')
        ax.set_title(f"Uncentered, β={beta}", fontsize=13, fontweight='bold')
        ax.set_xlabel("t (q₀ + t·(target - q₀))")
        ax.set_ylabel("E(q)")
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

        # Centered
        ax = axes[1, col]
        for target, name, color in [
            (torch.zeros_like(q0_cent), "→ origin", "red"),
            (top1_cent, f"→ memory[{top1_cent_idx}]", "blue"),
        ]:
            direction = target.to(device) - q0_cent.to(device)
            if direction.norm() < 1e-8:
                continue
            points = q0_cent.unsqueeze(0).to(device) + ts.unsqueeze(1) * direction.unsqueeze(0)
            E = compute_energy(points, M_cent.to(device), beta).cpu().numpy()
            ax.plot(ts.cpu().numpy(), E, label=name, color=color, linewidth=2)

        ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀')
        ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target')
        ax.set_title(f"Centered, β={beta}", fontsize=13, fontweight='bold')
        ax.set_xlabel("t (q̃₀ + t·(target - q̃₀))")
        ax.set_ylabel("E(q)")
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    fig.suptitle("Fig 2: 1D Energy Profile Along Key Directions", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig2_profile.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig2_profile.png'}")
+
+
+# ── Figure 3: UMAP + Trajectories ───────────────────────────────────
+
def fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """UMAP embedding of memories plus query trajectories, colored by energy.

    Memories and each trajectory are embedded jointly so the trajectory is
    placed consistently within the memory cloud. Skips gracefully when
    umap-learn is not installed.

    Args:
        q0_raw: raw (uncentered) query embedding, shape (d,).
        M_raw: raw memory matrix, shape (d, N).
        M_cent: centered memory matrix, shape (d, N).
        mu: centroid subtracted to produce M_cent, shape (d,).
        betas_plot: beta values, one subplot column each.
        outdir: Path of the directory receiving fig3_umap.png.
        device: torch device string.
    """
    try:
        import umap
    except ImportError:
        print("umap-learn not installed, skipping fig3")
        return

    # (Removed an unused centroid computation present in an earlier version.)
    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for col, beta in enumerate(betas_plot):
        for row, (label, M, q0) in enumerate([
            ("Uncentered", M_raw, q0_raw),
            ("Centered", M_cent, q0_cent),
        ]):
            ax = axes[row, col]

            # Trajectory
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_cpu = traj.cpu()

            # Combine memories + trajectory for a joint UMAP fit
            mem_cpu = M.T.cpu()  # (N, d)
            all_points = torch.cat([mem_cpu, traj_cpu], dim=0).numpy()

            # Fit UMAP (fixed seed for reproducible layouts)
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
            embedding = reducer.fit_transform(all_points)

            n_mem = mem_cpu.shape[0]
            mem_emb = embedding[:n_mem]
            traj_emb = embedding[n_mem:]

            # Energy of each memory, used as the scatter color
            E_mem = compute_energy(mem_cpu.to(device), M.to(device), beta).cpu().numpy()

            # Plot memories colored by energy
            sc = ax.scatter(mem_emb[:, 0], mem_emb[:, 1], c=E_mem, cmap='viridis',
                            s=10, alpha=0.5, zorder=1)

            # Plot trajectory: start in green, final state in magenta
            ax.plot(traj_emb[:, 0], traj_emb[:, 1], 'o-', color='#ff6600',
                    markersize=5, linewidth=2, zorder=3, label='trajectory')
            ax.scatter(traj_emb[0, 0], traj_emb[0, 1], c='lime', s=100,
                       marker='s', zorder=4, label='q₀')
            ax.scatter(traj_emb[-1, 0], traj_emb[-1, 1], c='magenta', s=100,
                       marker='D', zorder=4, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.legend(fontsize=8, loc='upper right')
            plt.colorbar(sc, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 3: UMAP of Memories + Hopfield Trajectory (color = energy)",
                 fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig3_umap.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig3_umap.png'}")
+
+
+# ── Figure 4: PCA Top-2 Energy Heatmap ──────────────────────────────
+
def fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """Energy heatmap on PCA top-2 components of memory bank.

    One SVD per row (uncentered / centered memory), then one heatmap per
    beta column with the Hopfield trajectory overlaid.

    Args:
        q0_raw: raw (uncentered) query embedding, shape (d,).
        M_raw: raw memory matrix, shape (d, N).
        M_cent: centered memory matrix, shape (d, N).
        mu: centroid subtracted to produce M_cent, shape (d,).
        betas_plot: beta values, one subplot column each.
        outdir: Path of the directory receiving fig4_pca.png.
        device: torch device string.
    """

    # NOTE(review): recomputed even though mu is passed in — presumably equal
    # to mu; confirm against the caller.
    centroid = M_raw.mean(dim=1)
    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for row, (label, M, q0, ref_point, ref_label) in enumerate([
        ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
        ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
    ]):
        # PCA on this memory bank: left singular vectors of M give the
        # principal directions in embedding space. (S and Vh are unused.)
        M_cpu = M.cpu()  # (d, N)
        # SVD of M to get top-2 directions
        U, S, Vh = torch.linalg.svd(M_cpu, full_matrices=False)
        pc1 = U[:, 0].to(device)  # (d,)
        pc2 = U[:, 1].to(device)  # (d,)

        for col, beta in enumerate(betas_plot):
            ax = axes[row, col]

            # Grid of points in the PC1/PC2 plane
            grid_range = 1.5
            n_grid = 150
            xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            xx, yy = torch.meshgrid(xs, ys, indexing='ij')
            grid_points = xx.reshape(-1, 1) * pc1.unsqueeze(0) + yy.reshape(-1, 1) * pc2.unsqueeze(0)

            E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy()

            # Hopfield trajectory projected onto the PCA plane
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_2d = project_to_plane(traj, pc1, pc2).cpu().numpy()

            # Memories projected
            mem_2d = project_to_plane(M.T.to(device), pc1, pc2).cpu().numpy()

            # Reference point (centroid or origin)
            ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), pc1, pc2).cpu().numpy()

            # Plot; clip energy to [1st, 95th] percentile so outliers don't
            # flatten the colormap
            xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()
            E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95))
            cs = ax.pcolormesh(xx_np, yy_np, E_clip, cmap='viridis', shading='auto')
            ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5)

            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            if ref_point.norm() > 1e-6:
                ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                           zorder=5, label=ref_label)
            else:
                ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin')

            # Trajectory: start in green, final state in magenta
            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                    linewidth=2, zorder=4)
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                       zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                       zorder=5, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 4: PCA Top-2 Energy Heatmap + Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig4_pca.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig4_pca.png'}")
+
+
+# ── Main ─────────────────────────────────────────────────────────────
+
def main():
    """CLI entry point: generate all four energy-landscape figures.

    Loads the memory bank, builds its centered counterpart, encodes one
    query, and renders the contour / profile / UMAP / PCA figures.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--query-idx", type=int, default=0)
    parser.add_argument("--outdir", type=str, default="figures")
    # Generalized from hard-coded values; defaults preserve prior behavior.
    parser.add_argument("--embedding-dim", type=int, default=768,
                        help="Dimensionality of the stored embeddings")
    parser.add_argument("--betas", type=float, nargs="+",
                        default=[5.0, 20.0, 50.0, 100.0],
                        help="Beta values to plot (default spans beta_critical ≈ 37.6)")
    args = parser.parse_args()

    device = args.device
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Load memory bank (uncentered; centering is done manually below)
    mb = MemoryBank(MemoryBankConfig(embedding_dim=args.embedding_dim, normalize=True, center=False))
    mb.load(args.memory_bank, device=device)
    M_raw = mb.embeddings  # (d, N)
    d, N = M_raw.shape
    print(f"Memory bank: d={d}, N={N}")

    # Centered variant: subtract the memory centroid
    mu = M_raw.mean(dim=1)  # (d,)
    M_cent = M_raw - mu.unsqueeze(1)
    print(f"‖μ‖ = {mu.norm():.4f}")

    # Load one query (JSONL, one question object per line)
    with open(args.questions) as f:
        questions = [json.loads(line) for line in f]

    q_text = questions[args.query_idx]["question"]
    print(f"Query [{args.query_idx}]: '{q_text}'")

    encoder = Encoder(EncoderConfig(model_name="facebook/contriever-msmarco"), device=device)
    q0_raw = encoder.encode([q_text]).squeeze(0)  # (d,)
    print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")

    # β values: default spans below and above β_critical ≈ 37.6
    betas_plot = args.betas

    print("\n--- Generating Figure 1: 2D Contour ---")
    fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 2: 1D Profile ---")
    fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 3: UMAP ---")
    fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 4: PCA Heatmap ---")
    fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print(f"\nAll figures saved to {outdir}/")
diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py
index e4ba902..0087563 100644
--- a/scripts/visualize_trajectory.py
+++ b/scripts/visualize_trajectory.py
@@ -26,6 +26,7 @@ def main() -> None:
parser.add_argument("--memory-bank", type=str, required=True)
parser.add_argument("--question", type=str, required=True)
parser.add_argument("--output", type=str, default="trajectory.png")
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -36,9 +37,9 @@ def main() -> None:
encoder_config = EncoderConfig(**cfg.get("encoder", {}))
mb = MemoryBank(memory_config)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
- encoder = Encoder(encoder_config)
+ encoder = Encoder(encoder_config, device=args.device)
hopfield = HopfieldRetrieval(hopfield_config)
query_emb = encoder.encode(args.question) # (1, d)
@@ -46,9 +47,9 @@ def main() -> None:
query_emb, mb.embeddings, return_trajectory=True
)
- # Gather all points for UMAP: memories + trajectory
- memories_np = mb.embeddings.T.numpy() # (N, d)
- trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d)
+ # Gather all points for UMAP: memories + trajectory (must be on CPU for numpy)
+ memories_np = mb.embeddings.T.cpu().numpy() # (N, d)
+ trajectory_np = np.stack([q.squeeze().cpu().numpy() for q in result.trajectory]) # (T+1, d)
all_points = np.concatenate([memories_np, trajectory_np], axis=0)
# UMAP projection