diff options
27 files changed, 5108 insertions, 33 deletions
diff --git a/data/processed/grid_search_results.json b/data/processed/grid_search_results.json new file mode 100644 index 0000000..cdbd3ca --- /dev/null +++ b/data/processed/grid_search_results.json @@ -0,0 +1,449 @@ +{ + "meta": { + "grid_size": 42, + "n_questions": 100, + "total_grid_evaluations": 4200, + "unique_llm_calls": 281, + "faiss_llm_calls": 100, + "total_llm_calls": 381, + "savings_pct": 91.1, + "retrieval_time_s": 0.93, + "generation_time_s": 735.92, + "total_time_s": 1049.0 + }, + "faiss_baseline": { + "em": 0.32, + "f1": 0.4380753968253968 + }, + "grid_results": [ + { + "beta": 0.25, + "max_iter": 1, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1785, + "avg_energy_gap": 2.7302, + "avg_faiss_overlap": 0.002, + "avg_steps": 1.0 + }, + { + "beta": 0.25, + "max_iter": 2, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1785, + "avg_energy_gap": 2.7302, + "avg_faiss_overlap": 0.002, + "avg_steps": 2.0 + }, + { + "beta": 0.25, + "max_iter": 3, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1785, + "avg_energy_gap": 2.7302, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 0.25, + "max_iter": 5, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1785, + "avg_energy_gap": 2.7302, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 0.25, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1785, + "avg_energy_gap": 2.7302, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 0.25, + "max_iter": 15, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1785, + "avg_energy_gap": 2.7302, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 0.5, + "max_iter": 1, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1784, + "avg_energy_gap": 2.7292, + "avg_faiss_overlap": 0.002, + "avg_steps": 1.0 + }, + { + "beta": 0.5, + "max_iter": 2, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1784, + "avg_energy_gap": 2.7292, + "avg_faiss_overlap": 0.002, + "avg_steps": 2.0 + }, + { + "beta": 0.5, + 
"max_iter": 3, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1784, + "avg_energy_gap": 2.7292, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 0.5, + "max_iter": 5, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1784, + "avg_energy_gap": 2.7292, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 0.5, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1784, + "avg_energy_gap": 2.7292, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 0.5, + "max_iter": 15, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1784, + "avg_energy_gap": 2.7292, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 1.0, + "max_iter": 1, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1781, + "avg_energy_gap": 2.7273, + "avg_faiss_overlap": 0.002, + "avg_steps": 1.0 + }, + { + "beta": 1.0, + "max_iter": 2, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1781, + "avg_energy_gap": 2.7273, + "avg_faiss_overlap": 0.002, + "avg_steps": 2.0 + }, + { + "beta": 1.0, + "max_iter": 3, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1781, + "avg_energy_gap": 2.7273, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 1.0, + "max_iter": 5, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1781, + "avg_energy_gap": 2.7273, + "avg_faiss_overlap": 0.002, + "avg_steps": 4.0 + }, + { + "beta": 1.0, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1781, + "avg_energy_gap": 2.7273, + "avg_faiss_overlap": 0.002, + "avg_steps": 4.0 + }, + { + "beta": 1.0, + "max_iter": 15, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1781, + "avg_energy_gap": 2.7273, + "avg_faiss_overlap": 0.002, + "avg_steps": 4.0 + }, + { + "beta": 2.0, + "max_iter": 1, + "em": 0.15, + "f1": 0.1977, + "avg_entropy": 7.1767, + "avg_energy_gap": 2.7234, + "avg_faiss_overlap": 0.004, + "avg_steps": 1.0 + }, + { + "beta": 2.0, + "max_iter": 2, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1767, + "avg_energy_gap": 
2.7235, + "avg_faiss_overlap": 0.002, + "avg_steps": 2.0 + }, + { + "beta": 2.0, + "max_iter": 3, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1767, + "avg_energy_gap": 2.7235, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 2.0, + "max_iter": 5, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1767, + "avg_energy_gap": 2.7235, + "avg_faiss_overlap": 0.002, + "avg_steps": 4.0 + }, + { + "beta": 2.0, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1767, + "avg_energy_gap": 2.7235, + "avg_faiss_overlap": 0.002, + "avg_steps": 4.0 + }, + { + "beta": 2.0, + "max_iter": 15, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1767, + "avg_energy_gap": 2.7235, + "avg_faiss_overlap": 0.002, + "avg_steps": 4.0 + }, + { + "beta": 3.0, + "max_iter": 1, + "em": 0.14, + "f1": 0.2016, + "avg_entropy": 7.1742, + "avg_energy_gap": 2.7193, + "avg_faiss_overlap": 0.006, + "avg_steps": 1.0 + }, + { + "beta": 3.0, + "max_iter": 2, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1742, + "avg_energy_gap": 2.7196, + "avg_faiss_overlap": 0.002, + "avg_steps": 2.0 + }, + { + "beta": 3.0, + "max_iter": 3, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1742, + "avg_energy_gap": 2.7196, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 3.0, + "max_iter": 5, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1742, + "avg_energy_gap": 2.7196, + "avg_faiss_overlap": 0.002, + "avg_steps": 5.0 + }, + { + "beta": 3.0, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1742, + "avg_energy_gap": 2.7196, + "avg_faiss_overlap": 0.002, + "avg_steps": 5.0 + }, + { + "beta": 3.0, + "max_iter": 15, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1742, + "avg_energy_gap": 2.7196, + "avg_faiss_overlap": 0.002, + "avg_steps": 5.0 + }, + { + "beta": 5.0, + "max_iter": 1, + "em": 0.16, + "f1": 0.2207, + "avg_entropy": 7.1659, + "avg_energy_gap": 2.7105, + "avg_faiss_overlap": 0.018, + "avg_steps": 1.0 + }, + { + "beta": 5.0, + "max_iter": 
2, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1661, + "avg_energy_gap": 2.7114, + "avg_faiss_overlap": 0.002, + "avg_steps": 2.0 + }, + { + "beta": 5.0, + "max_iter": 3, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1661, + "avg_energy_gap": 2.7114, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 5.0, + "max_iter": 5, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1661, + "avg_energy_gap": 2.7114, + "avg_faiss_overlap": 0.002, + "avg_steps": 5.0 + }, + { + "beta": 5.0, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1661, + "avg_energy_gap": 2.7114, + "avg_faiss_overlap": 0.002, + "avg_steps": 5.0 + }, + { + "beta": 5.0, + "max_iter": 15, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1661, + "avg_energy_gap": 2.7114, + "avg_faiss_overlap": 0.002, + "avg_steps": 5.0 + }, + { + "beta": 8.0, + "max_iter": 1, + "em": 0.21, + "f1": 0.2938, + "avg_entropy": 7.1438, + "avg_energy_gap": 2.6938, + "avg_faiss_overlap": 0.068, + "avg_steps": 1.0 + }, + { + "beta": 8.0, + "max_iter": 2, + "em": 0.12, + "f1": 0.1766, + "avg_entropy": 7.145, + "avg_energy_gap": 2.6972, + "avg_faiss_overlap": 0.004, + "avg_steps": 2.0 + }, + { + "beta": 8.0, + "max_iter": 3, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1449, + "avg_energy_gap": 2.6972, + "avg_faiss_overlap": 0.002, + "avg_steps": 3.0 + }, + { + "beta": 8.0, + "max_iter": 5, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1448, + "avg_energy_gap": 2.6972, + "avg_faiss_overlap": 0.002, + "avg_steps": 5.0 + }, + { + "beta": 8.0, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1448, + "avg_energy_gap": 2.6972, + "avg_faiss_overlap": 0.002, + "avg_steps": 7.0 + }, + { + "beta": 8.0, + "max_iter": 15, + "em": 0.14, + "f1": 0.1877, + "avg_entropy": 7.1448, + "avg_energy_gap": 2.6972, + "avg_faiss_overlap": 0.002, + "avg_steps": 7.0 + } + ], + "best_config": { + "beta": 8.0, + "max_iter": 1, + "em": 0.21, + "f1": 0.2938, + "avg_entropy": 7.1438, + 
"avg_energy_gap": 2.6938, + "avg_faiss_overlap": 0.068 + } +}
\ No newline at end of file diff --git a/data/processed/highbeta_grid_results.json b/data/processed/highbeta_grid_results.json new file mode 100644 index 0000000..a04895a --- /dev/null +++ b/data/processed/highbeta_grid_results.json @@ -0,0 +1,828 @@ +{ + "meta": { + "n_questions": 100, + "total_configs": 105, + "unique_llm_calls": 1379, + "total_time_s": 4571.0 + }, + "faiss_baseline": { + "em": 0.32, + "f1": 0.4381 + }, + "grid_results": [ + { + "config": "\u03b2=20.0_iter=1_standard", + "em": 0.38, + "f1": 0.4691, + "avg_faiss_overlap": 0.48, + "avg_entropy": 4.5305 + }, + { + "config": "\u03b2=50.0_iter=1_standard", + "em": 0.36, + "f1": 0.4565, + "avg_faiss_overlap": 0.508, + "avg_entropy": 0.3196 + }, + { + "config": "\u03b2=20.0_iter=1_residual_0.9", + "em": 0.34, + "f1": 0.4552, + "avg_faiss_overlap": 0.966, + "avg_entropy": 3.5526 + }, + { + "config": "\u03b2=20.0_iter=2_residual_0.95", + "em": 0.34, + "f1": 0.4552, + "avg_faiss_overlap": 0.966, + "avg_entropy": 3.5503 + }, + { + "config": "\u03b2=500.0_iter=1_residual_0.9", + "em": 0.36, + "f1": 0.4545, + "avg_faiss_overlap": 0.692, + "avg_entropy": 0.0074 + }, + { + "config": "\u03b2=50.0_iter=1_normalized", + "em": 0.37, + "f1": 0.4539, + "avg_faiss_overlap": 0.464, + "avg_entropy": 1.7333 + }, + { + "config": "\u03b2=500.0_iter=1_residual_0.95", + "em": 0.36, + "f1": 0.4536, + "avg_faiss_overlap": 0.748, + "avg_entropy": 0.013 + }, + { + "config": "\u03b2=20.0_iter=1_residual_0.95", + "em": 0.34, + "f1": 0.4511, + "avg_faiss_overlap": 0.98, + "avg_entropy": 3.5011 + }, + { + "config": "\u03b2=50.0_iter=2_normalized", + "em": 0.37, + "f1": 0.4498, + "avg_faiss_overlap": 0.38, + "avg_entropy": 1.0954 + }, + { + "config": "\u03b2=20.0_iter=3_residual_0.95", + "em": 0.32, + "f1": 0.4494, + "avg_faiss_overlap": 0.946, + "avg_entropy": 3.5994 + }, + { + "config": "\u03b2=50.0_iter=1_residual_0.95", + "em": 0.34, + "f1": 0.4486, + "avg_faiss_overlap": 0.974, + "avg_entropy": 0.677 + }, + { + "config": 
"\u03b2=100.0_iter=1_residual_0.95", + "em": 0.34, + "f1": 0.4464, + "avg_faiss_overlap": 0.97, + "avg_entropy": 0.2081 + }, + { + "config": "\u03b2=200.0_iter=1_residual_0.95", + "em": 0.34, + "f1": 0.4464, + "avg_faiss_overlap": 0.968, + "avg_entropy": 0.0722 + }, + { + "config": "\u03b2=100.0_iter=1_normalized", + "em": 0.35, + "f1": 0.446, + "avg_faiss_overlap": 0.508, + "avg_entropy": 0.1139 + }, + { + "config": "\u03b2=500.0_iter=2_residual_0.95", + "em": 0.35, + "f1": 0.4445, + "avg_faiss_overlap": 0.692, + "avg_entropy": 0.0037 + }, + { + "config": "\u03b2=50.0_iter=2_residual_0.9", + "em": 0.33, + "f1": 0.4441, + "avg_faiss_overlap": 0.886, + "avg_entropy": 0.4749 + }, + { + "config": "\u03b2=100.0_iter=2_residual_0.9", + "em": 0.33, + "f1": 0.4441, + "avg_faiss_overlap": 0.878, + "avg_entropy": 0.0623 + }, + { + "config": "\u03b2=50.0_iter=8_residual_0.95", + "em": 0.32, + "f1": 0.4431, + "avg_faiss_overlap": 0.808, + "avg_entropy": 0.2083 + }, + { + "config": "\u03b2=100.0_iter=8_residual_0.95", + "em": 0.32, + "f1": 0.4431, + "avg_faiss_overlap": 0.794, + "avg_entropy": 0.0019 + }, + { + "config": "\u03b2=100.0_iter=3_residual_0.95", + "em": 0.33, + "f1": 0.4427, + "avg_faiss_overlap": 0.9, + "avg_entropy": 0.0754 + }, + { + "config": "\u03b2=20.0_iter=8_residual_0.95", + "em": 0.33, + "f1": 0.442, + "avg_faiss_overlap": 0.878, + "avg_entropy": 3.8365 + }, + { + "config": "\u03b2=50.0_iter=1_residual_0.9", + "em": 0.32, + "f1": 0.4419, + "avg_faiss_overlap": 0.946, + "avg_entropy": 0.6043 + }, + { + "config": "\u03b2=50.0_iter=2_residual_0.95", + "em": 0.32, + "f1": 0.4419, + "avg_faiss_overlap": 0.944, + "avg_entropy": 0.5941 + }, + { + "config": "\u03b2=50.0_iter=5_residual_0.95", + "em": 0.32, + "f1": 0.4407, + "avg_faiss_overlap": 0.87, + "avg_entropy": 0.3877 + }, + { + "config": "\u03b2=200.0_iter=1_normalized", + "em": 0.33, + "f1": 0.4404, + "avg_faiss_overlap": 0.484, + "avg_entropy": 0.0109 + }, + { + "config": 
"\u03b2=500.0_iter=5_residual_0.95", + "em": 0.35, + "f1": 0.4404, + "avg_faiss_overlap": 0.546, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=50.0_iter=8_residual_0.9", + "em": 0.34, + "f1": 0.4402, + "avg_faiss_overlap": 0.666, + "avg_entropy": 0.0306 + }, + { + "config": "\u03b2=100.0_iter=1_standard", + "em": 0.33, + "f1": 0.44, + "avg_faiss_overlap": 0.464, + "avg_entropy": 0.0196 + }, + { + "config": "\u03b2=100.0_iter=1_residual_0.9", + "em": 0.32, + "f1": 0.4397, + "avg_faiss_overlap": 0.932, + "avg_entropy": 0.1471 + }, + { + "config": "\u03b2=100.0_iter=2_residual_0.95", + "em": 0.32, + "f1": 0.4397, + "avg_faiss_overlap": 0.932, + "avg_entropy": 0.1268 + }, + { + "config": "\u03b2=100.0_iter=5_residual_0.95", + "em": 0.32, + "f1": 0.4391, + "avg_faiss_overlap": 0.858, + "avg_entropy": 0.019 + }, + { + "config": "\u03b2=20.0_iter=0_standard", + "em": 0.32, + "f1": 0.4381, + "avg_faiss_overlap": 1.0, + "avg_entropy": 3.452 + }, + { + "config": "\u03b2=50.0_iter=0_standard", + "em": 0.32, + "f1": 0.4381, + "avg_faiss_overlap": 1.0, + "avg_entropy": 0.7723 + }, + { + "config": "\u03b2=50.0_iter=5_standard", + "em": 0.33, + "f1": 0.4381, + "avg_faiss_overlap": 0.408, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=50.0_iter=8_standard", + "em": 0.33, + "f1": 0.4381, + "avg_faiss_overlap": 0.408, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=50.0_iter=8_normalized", + "em": 0.34, + "f1": 0.4381, + "avg_faiss_overlap": 0.358, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=0_standard", + "em": 0.32, + "f1": 0.4381, + "avg_faiss_overlap": 1.0, + "avg_entropy": 0.3128 + }, + { + "config": "\u03b2=100.0_iter=2_normalized", + "em": 0.33, + "f1": 0.4381, + "avg_faiss_overlap": 0.422, + "avg_entropy": 0.0143 + }, + { + "config": "\u03b2=100.0_iter=3_normalized", + "em": 0.33, + "f1": 0.4381, + "avg_faiss_overlap": 0.412, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=5_normalized", + "em": 0.33, + "f1": 0.4381, + 
"avg_faiss_overlap": 0.41, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=8_normalized", + "em": 0.33, + "f1": 0.4381, + "avg_faiss_overlap": 0.41, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=0_standard", + "em": 0.32, + "f1": 0.4381, + "avg_faiss_overlap": 1.0, + "avg_entropy": 0.1535 + }, + { + "config": "\u03b2=50.0_iter=3_residual_0.95", + "em": 0.32, + "f1": 0.4377, + "avg_faiss_overlap": 0.912, + "avg_entropy": 0.5215 + }, + { + "config": "\u03b2=20.0_iter=5_residual_0.9", + "em": 0.32, + "f1": 0.437, + "avg_faiss_overlap": 0.846, + "avg_entropy": 3.9311 + }, + { + "config": "\u03b2=20.0_iter=2_residual_0.9", + "em": 0.31, + "f1": 0.4349, + "avg_faiss_overlap": 0.926, + "avg_entropy": 3.6521 + }, + { + "config": "\u03b2=200.0_iter=1_residual_0.9", + "em": 0.32, + "f1": 0.4347, + "avg_faiss_overlap": 0.928, + "avg_entropy": 0.0466 + }, + { + "config": "\u03b2=200.0_iter=2_residual_0.95", + "em": 0.32, + "f1": 0.4347, + "avg_faiss_overlap": 0.928, + "avg_entropy": 0.0274 + }, + { + "config": "\u03b2=50.0_iter=3_residual_0.9", + "em": 0.31, + "f1": 0.4344, + "avg_faiss_overlap": 0.842, + "avg_entropy": 0.3574 + }, + { + "config": "\u03b2=200.0_iter=2_residual_0.9", + "em": 0.32, + "f1": 0.4341, + "avg_faiss_overlap": 0.876, + "avg_entropy": 0.008 + }, + { + "config": "\u03b2=200.0_iter=5_residual_0.95", + "em": 0.31, + "f1": 0.4341, + "avg_faiss_overlap": 0.852, + "avg_entropy": 0.0005 + }, + { + "config": "\u03b2=200.0_iter=1_standard", + "em": 0.33, + "f1": 0.4331, + "avg_faiss_overlap": 0.43, + "avg_entropy": 0.0061 + }, + { + "config": "\u03b2=200.0_iter=3_residual_0.95", + "em": 0.32, + "f1": 0.4327, + "avg_faiss_overlap": 0.898, + "avg_entropy": 0.0082 + }, + { + "config": "\u03b2=20.0_iter=3_residual_0.9", + "em": 0.31, + "f1": 0.4313, + "avg_faiss_overlap": 0.9, + "avg_entropy": 3.7492 + }, + { + "config": "\u03b2=500.0_iter=3_residual_0.95", + "em": 0.35, + "f1": 0.4312, + "avg_faiss_overlap": 0.636, + "avg_entropy": 0.0 + 
}, + { + "config": "\u03b2=20.0_iter=8_residual_0.9", + "em": 0.31, + "f1": 0.4302, + "avg_faiss_overlap": 0.748, + "avg_entropy": 4.1579 + }, + { + "config": "\u03b2=20.0_iter=5_residual_0.95", + "em": 0.31, + "f1": 0.4299, + "avg_faiss_overlap": 0.91, + "avg_entropy": 3.6963 + }, + { + "config": "\u03b2=50.0_iter=3_normalized", + "em": 0.33, + "f1": 0.4291, + "avg_faiss_overlap": 0.368, + "avg_entropy": 0.6769 + }, + { + "config": "\u03b2=50.0_iter=5_residual_0.9", + "em": 0.32, + "f1": 0.4291, + "avg_faiss_overlap": 0.778, + "avg_entropy": 0.1599 + }, + { + "config": "\u03b2=50.0_iter=3_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.41, + "avg_entropy": 0.0016 + }, + { + "config": "\u03b2=50.0_iter=5_normalized", + "em": 0.33, + "f1": 0.4281, + "avg_faiss_overlap": 0.356, + "avg_entropy": 0.1417 + }, + { + "config": "\u03b2=100.0_iter=3_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=5_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=8_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=2_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.408, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=2_normalized", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.408, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=3_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=3_normalized", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=5_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=5_normalized", + "em": 0.32, + "f1": 0.4281, + 
"avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=8_standard", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=8_normalized", + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.406, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=50.0_iter=2_standard", + "em": 0.32, + "f1": 0.4265, + "avg_faiss_overlap": 0.422, + "avg_entropy": 0.0481 + }, + { + "config": "\u03b2=500.0_iter=3_residual_0.9", + "em": 0.33, + "f1": 0.4265, + "avg_faiss_overlap": 0.51, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=3_residual_0.9", + "em": 0.29, + "f1": 0.4244, + "avg_faiss_overlap": 0.828, + "avg_entropy": 0.019 + }, + { + "config": "\u03b2=200.0_iter=3_residual_0.9", + "em": 0.29, + "f1": 0.4244, + "avg_faiss_overlap": 0.824, + "avg_entropy": 0.0021 + }, + { + "config": "\u03b2=500.0_iter=0_standard", + "em": 0.33, + "f1": 0.4236, + "avg_faiss_overlap": 0.798, + "avg_entropy": 0.0746 + }, + { + "config": "\u03b2=200.0_iter=8_residual_0.95", + "em": 0.3, + "f1": 0.4231, + "avg_faiss_overlap": 0.786, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=2_standard", + "em": 0.31, + "f1": 0.4181, + "avg_faiss_overlap": 0.41, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=8_residual_0.9", + "em": 0.32, + "f1": 0.4152, + "avg_faiss_overlap": 0.636, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=200.0_iter=8_residual_0.9", + "em": 0.32, + "f1": 0.4152, + "avg_faiss_overlap": 0.634, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=100.0_iter=5_residual_0.9", + "em": 0.3, + "f1": 0.4141, + "avg_faiss_overlap": 0.764, + "avg_entropy": 0.0014 + }, + { + "config": "\u03b2=500.0_iter=2_residual_0.9", + "em": 0.32, + "f1": 0.4098, + "avg_faiss_overlap": 0.584, + "avg_entropy": 0.0001 + }, + { + "config": "\u03b2=200.0_iter=5_residual_0.9", + "em": 0.29, + "f1": 0.4041, + "avg_faiss_overlap": 0.754, + "avg_entropy": 0.0 + }, + { + "config": 
"\u03b2=500.0_iter=5_residual_0.9", + "em": 0.31, + "f1": 0.3949, + "avg_faiss_overlap": 0.34, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=8_residual_0.95", + "em": 0.29, + "f1": 0.3839, + "avg_faiss_overlap": 0.424, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=8_residual_0.9", + "em": 0.27, + "f1": 0.3743, + "avg_faiss_overlap": 0.228, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=20.0_iter=2_standard", + "em": 0.29, + "f1": 0.3713, + "avg_faiss_overlap": 0.248, + "avg_entropy": 4.6996 + }, + { + "config": "\u03b2=500.0_iter=1_normalized", + "em": 0.27, + "f1": 0.3693, + "avg_faiss_overlap": 0.236, + "avg_entropy": 0.0018 + }, + { + "config": "\u03b2=500.0_iter=1_standard", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.22, + "avg_entropy": 0.0003 + }, + { + "config": "\u03b2=500.0_iter=2_standard", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=2_normalized", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=3_standard", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=3_normalized", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=5_standard", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=5_normalized", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=8_standard", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=500.0_iter=8_normalized", + "em": 0.25, + "f1": 0.35, + "avg_faiss_overlap": 0.202, + "avg_entropy": 0.0 + }, + { + "config": "\u03b2=20.0_iter=1_normalized", + "em": 0.27, + "f1": 0.3367, + "avg_faiss_overlap": 0.14, + "avg_entropy": 
6.5894 + }, + { + "config": "\u03b2=20.0_iter=3_standard", + "em": 0.27, + "f1": 0.332, + "avg_faiss_overlap": 0.184, + "avg_entropy": 4.6947 + }, + { + "config": "\u03b2=20.0_iter=5_standard", + "em": 0.24, + "f1": 0.3116, + "avg_faiss_overlap": 0.162, + "avg_entropy": 4.6875 + }, + { + "config": "\u03b2=20.0_iter=8_standard", + "em": 0.23, + "f1": 0.3016, + "avg_faiss_overlap": 0.16, + "avg_entropy": 4.684 + }, + { + "config": "\u03b2=20.0_iter=2_normalized", + "em": 0.18, + "f1": 0.2255, + "avg_faiss_overlap": 0.044, + "avg_entropy": 6.4837 + }, + { + "config": "\u03b2=20.0_iter=8_normalized", + "em": 0.14, + "f1": 0.2221, + "avg_faiss_overlap": 0.02, + "avg_entropy": 6.3573 + }, + { + "config": "\u03b2=20.0_iter=3_normalized", + "em": 0.14, + "f1": 0.2155, + "avg_faiss_overlap": 0.024, + "avg_entropy": 6.4174 + }, + { + "config": "\u03b2=20.0_iter=5_normalized", + "em": 0.13, + "f1": 0.1961, + "avg_faiss_overlap": 0.02, + "avg_entropy": 6.3707 + } + ], + "best_config": { + "config": "\u03b2=20.0_iter=1_standard", + "em": 0.38, + "f1": 0.4691, + "avg_faiss_overlap": 0.48, + "avg_entropy": 4.5305 + }, + "top10": [ + { + "config": "\u03b2=20.0_iter=1_standard", + "em": 0.38, + "f1": 0.4691, + "avg_faiss_overlap": 0.48, + "avg_entropy": 4.5305 + }, + { + "config": "\u03b2=50.0_iter=1_standard", + "em": 0.36, + "f1": 0.4565, + "avg_faiss_overlap": 0.508, + "avg_entropy": 0.3196 + }, + { + "config": "\u03b2=20.0_iter=1_residual_0.9", + "em": 0.34, + "f1": 0.4552, + "avg_faiss_overlap": 0.966, + "avg_entropy": 3.5526 + }, + { + "config": "\u03b2=20.0_iter=2_residual_0.95", + "em": 0.34, + "f1": 0.4552, + "avg_faiss_overlap": 0.966, + "avg_entropy": 3.5503 + }, + { + "config": "\u03b2=500.0_iter=1_residual_0.9", + "em": 0.36, + "f1": 0.4545, + "avg_faiss_overlap": 0.692, + "avg_entropy": 0.0074 + }, + { + "config": "\u03b2=50.0_iter=1_normalized", + "em": 0.37, + "f1": 0.4539, + "avg_faiss_overlap": 0.464, + "avg_entropy": 1.7333 + }, + { + "config": 
"\u03b2=500.0_iter=1_residual_0.95", + "em": 0.36, + "f1": 0.4536, + "avg_faiss_overlap": 0.748, + "avg_entropy": 0.013 + }, + { + "config": "\u03b2=20.0_iter=1_residual_0.95", + "em": 0.34, + "f1": 0.4511, + "avg_faiss_overlap": 0.98, + "avg_entropy": 3.5011 + }, + { + "config": "\u03b2=50.0_iter=2_normalized", + "em": 0.37, + "f1": 0.4498, + "avg_faiss_overlap": 0.38, + "avg_entropy": 1.0954 + }, + { + "config": "\u03b2=20.0_iter=3_residual_0.95", + "em": 0.32, + "f1": 0.4494, + "avg_faiss_overlap": 0.946, + "avg_entropy": 3.5994 + } + ] +}
\ No newline at end of file diff --git a/data/processed/residual_grid_results.json b/data/processed/residual_grid_results.json new file mode 100644 index 0000000..c5b6d28 --- /dev/null +++ b/data/processed/residual_grid_results.json @@ -0,0 +1,1016 @@ +{ + "meta": { + "n_questions": 100, + "total_configs": 100, + "unique_llm_calls": 1666, + "faiss_llm_calls": 100, + "total_time_s": 5113.4 + }, + "faiss_baseline": { + "em": 0.32, + "f1": 0.4381 + }, + "grid_results": [ + { + "beta": 5.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.36, + "f1": 0.4809, + "avg_faiss_overlap": 0.902, + "avg_entropy": 7.1163 + }, + { + "beta": 5.0, + "lambda": 0.9, + "max_iter": 3, + "em": 0.36, + "f1": 0.4809, + "avg_faiss_overlap": 0.912, + "avg_entropy": 7.1122 + }, + { + "beta": 10.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.36, + "f1": 0.4809, + "avg_faiss_overlap": 0.9, + "avg_entropy": 6.8422 + }, + { + "beta": 10.0, + "lambda": 0.95, + "max_iter": 8, + "em": 0.36, + "f1": 0.4797, + "avg_faiss_overlap": 0.886, + "avg_entropy": 6.8941 + }, + { + "beta": 5.0, + "lambda": 0.95, + "max_iter": 8, + "em": 0.35, + "f1": 0.4697, + "avg_faiss_overlap": 0.886, + "avg_entropy": 7.1219 + }, + { + "beta": 5.0, + "lambda": 0.95, + "max_iter": 3, + "em": 0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.956, + "avg_entropy": 7.09 + }, + { + "beta": 5.0, + "lambda": 0.95, + "max_iter": 5, + "em": 0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.928, + "avg_entropy": 7.105 + }, + { + "beta": 10.0, + "lambda": 0.9, + "max_iter": 1, + "em": 0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.966, + "avg_entropy": 6.6138 + }, + { + "beta": 10.0, + "lambda": 0.95, + "max_iter": 3, + "em": 0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.956, + "avg_entropy": 6.6763 + }, + { + "beta": 5.0, + "lambda": 0.9, + "max_iter": 1, + "em": 0.34, + "f1": 0.4672, + "avg_faiss_overlap": 0.97, + "avg_entropy": 7.0815 + }, + { + "beta": 10.0, + "lambda": 0.9, + "max_iter": 5, + "em": 0.34, + "f1": 0.4637, + 
"avg_faiss_overlap": 0.856, + "avg_entropy": 6.9499 + }, + { + "beta": 5.0, + "lambda": 0.5, + "max_iter": 1, + "em": 0.33, + "f1": 0.4627, + "avg_faiss_overlap": 0.786, + "avg_entropy": 7.1407 + }, + { + "beta": 5.0, + "lambda": 0.9, + "max_iter": 8, + "em": 0.34, + "f1": 0.4625, + "avg_faiss_overlap": 0.724, + "avg_entropy": 7.1479 + }, + { + "beta": 10.0, + "lambda": 0.5, + "max_iter": 1, + "em": 0.33, + "f1": 0.4622, + "avg_faiss_overlap": 0.808, + "avg_entropy": 6.9842 + }, + { + "beta": 100.0, + "lambda": 0.5, + "max_iter": 1, + "em": 0.36, + "f1": 0.4608, + "avg_faiss_overlap": 0.722, + "avg_entropy": 0.0442 + }, + { + "beta": 10.0, + "lambda": 0.9, + "max_iter": 8, + "em": 0.33, + "f1": 0.46, + "avg_faiss_overlap": 0.744, + "avg_entropy": 7.0404 + }, + { + "beta": 5.0, + "lambda": 0.8, + "max_iter": 1, + "em": 0.34, + "f1": 0.4592, + "avg_faiss_overlap": 0.944, + "avg_entropy": 7.1003 + }, + { + "beta": 10.0, + "lambda": 0.8, + "max_iter": 1, + "em": 0.34, + "f1": 0.4592, + "avg_faiss_overlap": 0.938, + "avg_entropy": 6.7407 + }, + { + "beta": 10.0, + "lambda": 0.95, + "max_iter": 5, + "em": 0.34, + "f1": 0.4592, + "avg_faiss_overlap": 0.928, + "avg_entropy": 6.7819 + }, + { + "beta": 5.0, + "lambda": 0.95, + "max_iter": 1, + "em": 0.34, + "f1": 0.4556, + "avg_faiss_overlap": 0.984, + "avg_entropy": 7.0709 + }, + { + "beta": 10.0, + "lambda": 0.95, + "max_iter": 1, + "em": 0.34, + "f1": 0.4556, + "avg_faiss_overlap": 0.984, + "avg_entropy": 6.5398 + }, + { + "beta": 20.0, + "lambda": 0.9, + "max_iter": 1, + "em": 0.34, + "f1": 0.4552, + "avg_faiss_overlap": 0.966, + "avg_entropy": 3.5526 + }, + { + "beta": 20.0, + "lambda": 0.5, + "max_iter": 1, + "em": 0.34, + "f1": 0.4548, + "avg_faiss_overlap": 0.796, + "avg_entropy": 4.0138 + }, + { + "beta": 5.0, + "lambda": 0.8, + "max_iter": 3, + "em": 0.32, + "f1": 0.4527, + "avg_faiss_overlap": 0.8, + "avg_entropy": 7.14 + }, + { + "beta": 10.0, + "lambda": 0.8, + "max_iter": 3, + "em": 0.32, + "f1": 0.4527, + 
"avg_faiss_overlap": 0.8, + "avg_entropy": 6.9964 + }, + { + "beta": 50.0, + "lambda": 0.8, + "max_iter": 3, + "em": 0.34, + "f1": 0.4524, + "avg_faiss_overlap": 0.74, + "avg_entropy": 0.1516 + }, + { + "beta": 20.0, + "lambda": 0.95, + "max_iter": 1, + "em": 0.34, + "f1": 0.4511, + "avg_faiss_overlap": 0.98, + "avg_entropy": 3.5011 + }, + { + "beta": 10.0, + "lambda": 0.9, + "max_iter": 3, + "em": 0.33, + "f1": 0.4509, + "avg_faiss_overlap": 0.918, + "avg_entropy": 6.828 + }, + { + "beta": 50.0, + "lambda": 0.7, + "max_iter": 5, + "em": 0.34, + "f1": 0.4509, + "avg_faiss_overlap": 0.492, + "avg_entropy": 0.0065 + }, + { + "beta": 50.0, + "lambda": 0.5, + "max_iter": 3, + "em": 0.35, + "f1": 0.4501, + "avg_faiss_overlap": 0.472, + "avg_entropy": 0.0137 + }, + { + "beta": 20.0, + "lambda": 0.8, + "max_iter": 3, + "em": 0.33, + "f1": 0.4498, + "avg_faiss_overlap": 0.798, + "avg_entropy": 4.0296 + }, + { + "beta": 50.0, + "lambda": 0.5, + "max_iter": 1, + "em": 0.34, + "f1": 0.4496, + "avg_faiss_overlap": 0.752, + "avg_entropy": 0.3624 + }, + { + "beta": 20.0, + "lambda": 0.95, + "max_iter": 3, + "em": 0.32, + "f1": 0.4494, + "avg_faiss_overlap": 0.946, + "avg_entropy": 3.5994 + }, + { + "beta": 5.0, + "lambda": 0.9, + "max_iter": 5, + "em": 0.32, + "f1": 0.4487, + "avg_faiss_overlap": 0.842, + "avg_entropy": 7.1313 + }, + { + "beta": 50.0, + "lambda": 0.95, + "max_iter": 1, + "em": 0.34, + "f1": 0.4486, + "avg_faiss_overlap": 0.974, + "avg_entropy": 0.677 + }, + { + "beta": 100.0, + "lambda": 0.95, + "max_iter": 1, + "em": 0.34, + "f1": 0.4464, + "avg_faiss_overlap": 0.97, + "avg_entropy": 0.2081 + }, + { + "beta": 50.0, + "lambda": 0.8, + "max_iter": 8, + "em": 0.33, + "f1": 0.4459, + "avg_faiss_overlap": 0.482, + "avg_entropy": 0.001 + }, + { + "beta": 100.0, + "lambda": 0.8, + "max_iter": 3, + "em": 0.34, + "f1": 0.4458, + "avg_faiss_overlap": 0.704, + "avg_entropy": 0.0034 + }, + { + "beta": 100.0, + "lambda": 0.8, + "max_iter": 1, + "em": 0.33, + "f1": 0.4441, + 
"avg_faiss_overlap": 0.878, + "avg_entropy": 0.0927 + }, + { + "beta": 20.0, + "lambda": 0.7, + "max_iter": 3, + "em": 0.33, + "f1": 0.4433, + "avg_faiss_overlap": 0.676, + "avg_entropy": 4.2538 + }, + { + "beta": 50.0, + "lambda": 0.95, + "max_iter": 8, + "em": 0.32, + "f1": 0.4431, + "avg_faiss_overlap": 0.808, + "avg_entropy": 0.2083 + }, + { + "beta": 100.0, + "lambda": 0.95, + "max_iter": 8, + "em": 0.32, + "f1": 0.4431, + "avg_faiss_overlap": 0.794, + "avg_entropy": 0.0019 + }, + { + "beta": 50.0, + "lambda": 0.8, + "max_iter": 1, + "em": 0.32, + "f1": 0.443, + "avg_faiss_overlap": 0.894, + "avg_entropy": 0.5053 + }, + { + "beta": 100.0, + "lambda": 0.95, + "max_iter": 3, + "em": 0.33, + "f1": 0.4427, + "avg_faiss_overlap": 0.9, + "avg_entropy": 0.0754 + }, + { + "beta": 10.0, + "lambda": 0.7, + "max_iter": 3, + "em": 0.31, + "f1": 0.4426, + "avg_faiss_overlap": 0.654, + "avg_entropy": 7.0699 + }, + { + "beta": 20.0, + "lambda": 0.95, + "max_iter": 8, + "em": 0.33, + "f1": 0.442, + "avg_faiss_overlap": 0.878, + "avg_entropy": 3.8365 + }, + { + "beta": 50.0, + "lambda": 0.9, + "max_iter": 1, + "em": 0.32, + "f1": 0.4419, + "avg_faiss_overlap": 0.946, + "avg_entropy": 0.6043 + }, + { + "beta": 50.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.32, + "f1": 0.4415, + "avg_faiss_overlap": 0.836, + "avg_entropy": 0.4413 + }, + { + "beta": 100.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.32, + "f1": 0.4415, + "avg_faiss_overlap": 0.824, + "avg_entropy": 0.069 + }, + { + "beta": 20.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.32, + "f1": 0.4413, + "avg_faiss_overlap": 0.888, + "avg_entropy": 3.7761 + }, + { + "beta": 50.0, + "lambda": 0.95, + "max_iter": 5, + "em": 0.32, + "f1": 0.4407, + "avg_faiss_overlap": 0.87, + "avg_entropy": 0.3877 + }, + { + "beta": 50.0, + "lambda": 0.9, + "max_iter": 8, + "em": 0.34, + "f1": 0.4402, + "avg_faiss_overlap": 0.666, + "avg_entropy": 0.0306 + }, + { + "beta": 100.0, + "lambda": 0.9, + "max_iter": 1, + "em": 0.32, + "f1": 
0.4397, + "avg_faiss_overlap": 0.932, + "avg_entropy": 0.1471 + }, + { + "beta": 100.0, + "lambda": 0.95, + "max_iter": 5, + "em": 0.32, + "f1": 0.4391, + "avg_faiss_overlap": 0.858, + "avg_entropy": 0.019 + }, + { + "beta": 50.0, + "lambda": 0.8, + "max_iter": 5, + "em": 0.32, + "f1": 0.439, + "avg_faiss_overlap": 0.592, + "avg_entropy": 0.024 + }, + { + "beta": 50.0, + "lambda": 0.5, + "max_iter": 8, + "em": 0.33, + "f1": 0.4381, + "avg_faiss_overlap": 0.41, + "avg_entropy": 0.0 + }, + { + "beta": 50.0, + "lambda": 0.95, + "max_iter": 3, + "em": 0.32, + "f1": 0.4377, + "avg_faiss_overlap": 0.912, + "avg_entropy": 0.5215 + }, + { + "beta": 20.0, + "lambda": 0.9, + "max_iter": 5, + "em": 0.32, + "f1": 0.437, + "avg_faiss_overlap": 0.846, + "avg_entropy": 3.9311 + }, + { + "beta": 100.0, + "lambda": 0.7, + "max_iter": 5, + "em": 0.32, + "f1": 0.4359, + "avg_faiss_overlap": 0.474, + "avg_entropy": 0.0 + }, + { + "beta": 100.0, + "lambda": 0.8, + "max_iter": 8, + "em": 0.32, + "f1": 0.4359, + "avg_faiss_overlap": 0.472, + "avg_entropy": 0.0 + }, + { + "beta": 20.0, + "lambda": 0.8, + "max_iter": 5, + "em": 0.32, + "f1": 0.4356, + "avg_faiss_overlap": 0.658, + "avg_entropy": 4.2886 + }, + { + "beta": 100.0, + "lambda": 0.5, + "max_iter": 3, + "em": 0.33, + "f1": 0.4351, + "avg_faiss_overlap": 0.456, + "avg_entropy": 0.0 + }, + { + "beta": 20.0, + "lambda": 0.8, + "max_iter": 1, + "em": 0.31, + "f1": 0.4349, + "avg_faiss_overlap": 0.924, + "avg_entropy": 3.6613 + }, + { + "beta": 50.0, + "lambda": 0.9, + "max_iter": 3, + "em": 0.31, + "f1": 0.4344, + "avg_faiss_overlap": 0.842, + "avg_entropy": 0.3574 + }, + { + "beta": 20.0, + "lambda": 0.9, + "max_iter": 3, + "em": 0.31, + "f1": 0.4313, + "avg_faiss_overlap": 0.9, + "avg_entropy": 3.7492 + }, + { + "beta": 20.0, + "lambda": 0.9, + "max_iter": 8, + "em": 0.31, + "f1": 0.4302, + "avg_faiss_overlap": 0.748, + "avg_entropy": 4.1579 + }, + { + "beta": 20.0, + "lambda": 0.95, + "max_iter": 5, + "em": 0.31, + "f1": 0.4299, + 
"avg_faiss_overlap": 0.91, + "avg_entropy": 3.6963 + }, + { + "beta": 50.0, + "lambda": 0.7, + "max_iter": 3, + "em": 0.32, + "f1": 0.4299, + "avg_faiss_overlap": 0.614, + "avg_entropy": 0.0708 + }, + { + "beta": 50.0, + "lambda": 0.9, + "max_iter": 5, + "em": 0.32, + "f1": 0.4291, + "avg_faiss_overlap": 0.778, + "avg_entropy": 0.1599 + }, + { + "beta": 100.0, + "lambda": 0.5, + "max_iter": 8, + "em": 0.32, + "f1": 0.4281, + "avg_faiss_overlap": 0.408, + "avg_entropy": 0.0 + }, + { + "beta": 50.0, + "lambda": 0.7, + "max_iter": 8, + "em": 0.33, + "f1": 0.4257, + "avg_faiss_overlap": 0.426, + "avg_entropy": 0.0 + }, + { + "beta": 10.0, + "lambda": 0.8, + "max_iter": 5, + "em": 0.29, + "f1": 0.4249, + "avg_faiss_overlap": 0.632, + "avg_entropy": 7.0765 + }, + { + "beta": 100.0, + "lambda": 0.9, + "max_iter": 3, + "em": 0.29, + "f1": 0.4244, + "avg_faiss_overlap": 0.828, + "avg_entropy": 0.019 + }, + { + "beta": 50.0, + "lambda": 0.5, + "max_iter": 5, + "em": 0.32, + "f1": 0.4225, + "avg_faiss_overlap": 0.42, + "avg_entropy": 0.0 + }, + { + "beta": 100.0, + "lambda": 0.8, + "max_iter": 5, + "em": 0.31, + "f1": 0.4165, + "avg_faiss_overlap": 0.562, + "avg_entropy": 0.0 + }, + { + "beta": 100.0, + "lambda": 0.7, + "max_iter": 8, + "em": 0.32, + "f1": 0.4157, + "avg_faiss_overlap": 0.422, + "avg_entropy": 0.0 + }, + { + "beta": 100.0, + "lambda": 0.9, + "max_iter": 8, + "em": 0.32, + "f1": 0.4152, + "avg_faiss_overlap": 0.636, + "avg_entropy": 0.0 + }, + { + "beta": 100.0, + "lambda": 0.9, + "max_iter": 5, + "em": 0.3, + "f1": 0.4141, + "avg_faiss_overlap": 0.764, + "avg_entropy": 0.0014 + }, + { + "beta": 20.0, + "lambda": 0.7, + "max_iter": 5, + "em": 0.3, + "f1": 0.4127, + "avg_faiss_overlap": 0.498, + "avg_entropy": 4.4813 + }, + { + "beta": 100.0, + "lambda": 0.5, + "max_iter": 5, + "em": 0.31, + "f1": 0.4125, + "avg_faiss_overlap": 0.416, + "avg_entropy": 0.0 + }, + { + "beta": 5.0, + "lambda": 0.8, + "max_iter": 5, + "em": 0.27, + "f1": 0.4054, + 
"avg_faiss_overlap": 0.614, + "avg_entropy": 7.1555 + }, + { + "beta": 5.0, + "lambda": 0.7, + "max_iter": 3, + "em": 0.27, + "f1": 0.4051, + "avg_faiss_overlap": 0.64, + "avg_entropy": 7.1544 + }, + { + "beta": 100.0, + "lambda": 0.7, + "max_iter": 3, + "em": 0.3, + "f1": 0.4049, + "avg_faiss_overlap": 0.582, + "avg_entropy": 0.0002 + }, + { + "beta": 20.0, + "lambda": 0.7, + "max_iter": 8, + "em": 0.31, + "f1": 0.3988, + "avg_faiss_overlap": 0.28, + "avg_entropy": 4.5698 + }, + { + "beta": 20.0, + "lambda": 0.5, + "max_iter": 5, + "em": 0.32, + "f1": 0.3946, + "avg_faiss_overlap": 0.258, + "avg_entropy": 4.6045 + }, + { + "beta": 20.0, + "lambda": 0.8, + "max_iter": 8, + "em": 0.28, + "f1": 0.3927, + "avg_faiss_overlap": 0.48, + "avg_entropy": 4.4867 + }, + { + "beta": 20.0, + "lambda": 0.5, + "max_iter": 3, + "em": 0.29, + "f1": 0.3925, + "avg_faiss_overlap": 0.452, + "avg_entropy": 4.5183 + }, + { + "beta": 10.0, + "lambda": 0.7, + "max_iter": 5, + "em": 0.3, + "f1": 0.379, + "avg_faiss_overlap": 0.334, + "avg_entropy": 7.112 + }, + { + "beta": 10.0, + "lambda": 0.8, + "max_iter": 8, + "em": 0.28, + "f1": 0.3659, + "avg_faiss_overlap": 0.318, + "avg_entropy": 7.1124 + }, + { + "beta": 10.0, + "lambda": 0.5, + "max_iter": 3, + "em": 0.28, + "f1": 0.353, + "avg_faiss_overlap": 0.222, + "avg_entropy": 7.1169 + }, + { + "beta": 5.0, + "lambda": 0.5, + "max_iter": 3, + "em": 0.27, + "f1": 0.3436, + "avg_faiss_overlap": 0.156, + "avg_entropy": 7.1645 + }, + { + "beta": 20.0, + "lambda": 0.5, + "max_iter": 8, + "em": 0.25, + "f1": 0.3298, + "avg_faiss_overlap": 0.186, + "avg_entropy": 4.565 + }, + { + "beta": 5.0, + "lambda": 0.7, + "max_iter": 5, + "em": 0.24, + "f1": 0.3274, + "avg_faiss_overlap": 0.278, + "avg_entropy": 7.1633 + }, + { + "beta": 5.0, + "lambda": 0.8, + "max_iter": 8, + "em": 0.24, + "f1": 0.3274, + "avg_faiss_overlap": 0.276, + "avg_entropy": 7.1634 + }, + { + "beta": 10.0, + "lambda": 0.7, + "max_iter": 8, + "em": 0.15, + "f1": 0.1995, + 
"avg_faiss_overlap": 0.044, + "avg_entropy": 7.123 + }, + { + "beta": 10.0, + "lambda": 0.5, + "max_iter": 5, + "em": 0.15, + "f1": 0.1983, + "avg_faiss_overlap": 0.022, + "avg_entropy": 7.1239 + }, + { + "beta": 5.0, + "lambda": 0.5, + "max_iter": 5, + "em": 0.14, + "f1": 0.1895, + "avg_faiss_overlap": 0.018, + "avg_entropy": 7.166 + }, + { + "beta": 5.0, + "lambda": 0.5, + "max_iter": 8, + "em": 0.14, + "f1": 0.1877, + "avg_faiss_overlap": 0.002, + "avg_entropy": 7.1661 + }, + { + "beta": 5.0, + "lambda": 0.7, + "max_iter": 8, + "em": 0.12, + "f1": 0.1824, + "avg_faiss_overlap": 0.034, + "avg_entropy": 7.1658 + }, + { + "beta": 10.0, + "lambda": 0.5, + "max_iter": 8, + "em": 0.12, + "f1": 0.1766, + "avg_faiss_overlap": 0.002, + "avg_entropy": 7.1239 + } + ], + "best_config": { + "beta": 5.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.36, + "f1": 0.4809, + "avg_faiss_overlap": 0.902, + "avg_entropy": 7.1163 + }, + "top10": [ + { + "beta": 5.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.36, + "f1": 0.4809, + "avg_faiss_overlap": 0.902, + "avg_entropy": 7.1163 + }, + { + "beta": 5.0, + "lambda": 0.9, + "max_iter": 3, + "em": 0.36, + "f1": 0.4809, + "avg_faiss_overlap": 0.912, + "avg_entropy": 7.1122 + }, + { + "beta": 10.0, + "lambda": 0.7, + "max_iter": 1, + "em": 0.36, + "f1": 0.4809, + "avg_faiss_overlap": 0.9, + "avg_entropy": 6.8422 + }, + { + "beta": 10.0, + "lambda": 0.95, + "max_iter": 8, + "em": 0.36, + "f1": 0.4797, + "avg_faiss_overlap": 0.886, + "avg_entropy": 6.8941 + }, + { + "beta": 5.0, + "lambda": 0.95, + "max_iter": 8, + "em": 0.35, + "f1": 0.4697, + "avg_faiss_overlap": 0.886, + "avg_entropy": 7.1219 + }, + { + "beta": 5.0, + "lambda": 0.95, + "max_iter": 3, + "em": 0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.956, + "avg_entropy": 7.09 + }, + { + "beta": 5.0, + "lambda": 0.95, + "max_iter": 5, + "em": 0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.928, + "avg_entropy": 7.105 + }, + { + "beta": 10.0, + "lambda": 0.9, + "max_iter": 1, + "em": 
0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.966, + "avg_entropy": 6.6138 + }, + { + "beta": 10.0, + "lambda": 0.95, + "max_iter": 3, + "em": 0.35, + "f1": 0.4692, + "avg_faiss_overlap": 0.956, + "avg_entropy": 6.6763 + }, + { + "beta": 5.0, + "lambda": 0.9, + "max_iter": 1, + "em": 0.34, + "f1": 0.4672, + "avg_faiss_overlap": 0.97, + "avg_entropy": 7.0815 + } + ] +}
\ No newline at end of file diff --git a/figures/fig1_contour.png b/figures/fig1_contour.png Binary files differnew file mode 100644 index 0000000..87aca0b --- /dev/null +++ b/figures/fig1_contour.png diff --git a/figures/fig2_profile.png b/figures/fig2_profile.png Binary files differnew file mode 100644 index 0000000..eb25eae --- /dev/null +++ b/figures/fig2_profile.png diff --git a/figures/fig3_umap.png b/figures/fig3_umap.png Binary files differnew file mode 100644 index 0000000..681ccb7 --- /dev/null +++ b/figures/fig3_umap.png diff --git a/figures/fig4_pca.png b/figures/fig4_pca.png Binary files differnew file mode 100644 index 0000000..35dee46 --- /dev/null +++ b/figures/fig4_pca.png diff --git a/hag/config.py b/hag/config.py index 793e3a6..10d0aff 100644 --- a/hag/config.py +++ b/hag/config.py @@ -19,6 +19,7 @@ class MemoryBankConfig: embedding_dim: int = 768 # Must match encoder output dim normalize: bool = True # L2-normalize embeddings in memory bank + center: bool = False # Mean-center embeddings to remove centroid attractor @dataclass @@ -35,7 +36,7 @@ class GeneratorConfig: """Configuration for the LLM generator.""" model_name: str = "meta-llama/Llama-3.1-8B-Instruct" - max_new_tokens: int = 128 + max_new_tokens: int = 32 temperature: float = 0.0 # Greedy decoding for reproducibility @@ -48,3 +49,4 @@ class PipelineConfig: encoder: EncoderConfig = field(default_factory=EncoderConfig) generator: GeneratorConfig = field(default_factory=GeneratorConfig) retriever_type: str = "hopfield" # "hopfield" or "faiss" + device: str = "cpu" # "cpu", "cuda", "cuda:0", etc. diff --git a/hag/encoder.py b/hag/encoder.py index 7e103f3..c380ad1 100644 --- a/hag/encoder.py +++ b/hag/encoder.py @@ -17,18 +17,20 @@ class Encoder: For testing, use FakeEncoder instead. 
""" - def __init__(self, config: EncoderConfig) -> None: + def __init__(self, config: EncoderConfig, device: str = "cpu") -> None: self.config = config + self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: - """Lazy-load the model and tokenizer.""" + """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModel, AutoTokenizer - logger.info("Loading encoder model: %s", self.config.model_name) + logger.info("Loading encoder model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModel.from_pretrained(self.config.model_name) + self._model.to(self.device) self._model.eval() @torch.no_grad() @@ -39,7 +41,7 @@ class Encoder: texts: single string or list of strings Returns: - (1, d) tensor for single input, (N, d) for list input. + (1, d) tensor for single input, (N, d) for list input. On self.device. """ if self._model is None: self._load_model() @@ -54,6 +56,7 @@ class Encoder: truncation=True, return_tensors="pt", ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model(**inputs) # Mean pooling over token embeddings embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d) diff --git a/hag/generator.py b/hag/generator.py index 2142e0c..d0de468 100644 --- a/hag/generator.py +++ b/hag/generator.py @@ -3,11 +3,13 @@ import logging from typing import List +import torch + from hag.config import GeneratorConfig logger = logging.getLogger(__name__) -PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. +PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation. Context: {context} @@ -24,20 +26,22 @@ class Generator: For testing, use FakeGenerator instead. 
""" - def __init__(self, config: GeneratorConfig) -> None: + def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None: self.config = config + self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: - """Lazy-load the model and tokenizer.""" + """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModelForCausalLM, AutoTokenizer - logger.info("Loading generator model: %s", self.config.model_name) + logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModelForCausalLM.from_pretrained( self.config.model_name, torch_dtype="auto", + device_map=self.device, ) self._model.eval() @@ -60,15 +64,23 @@ class Generator: prompt = PROMPT_TEMPLATE.format(context=context, question=question) inputs = self._tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model.generate( **inputs, max_new_tokens=self.config.max_new_tokens, temperature=self.config.temperature if self.config.temperature > 0 else None, do_sample=self.config.temperature > 0, + repetition_penalty=1.2, ) # Decode only the generated tokens (skip the prompt) generated = outputs[0][inputs["input_ids"].shape[1]:] - return self._tokenizer.decode(generated, skip_special_tokens=True).strip() + answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip() + # Take only the first sentence/line as the answer + for sep in ["\n", ". ", ".\n"]: + if sep in answer: + answer = answer.split(sep)[0].strip() + break + return answer class FakeGenerator: diff --git a/hag/memory_bank.py b/hag/memory_bank.py index 42dcc73..0a0a87c 100644 --- a/hag/memory_bank.py +++ b/hag/memory_bank.py @@ -16,12 +16,17 @@ class MemoryBank: The memory bank is M in R^{d x N} where each column is a passage embedding. 
Also maintains a mapping from column index to passage text for final retrieval. + + When config.center=True, embeddings are mean-centered to remove the centroid + attractor in Hopfield dynamics. The mean is saved so queries can be centered + with the same offset via center_query(). """ def __init__(self, config: MemoryBankConfig) -> None: self.config = config self.embeddings: Optional[torch.Tensor] = None # (d, N) self.passages: List[str] = [] + self.mean: Optional[torch.Tensor] = None # (d,) — saved for query centering def build_from_embeddings( self, embeddings: torch.Tensor, passages: List[str] @@ -38,10 +43,42 @@ class MemoryBank: ) if self.config.normalize: embeddings = F.normalize(embeddings, dim=-1) + if self.config.center: + self.mean = embeddings.mean(dim=0) # (d,) + embeddings = embeddings - self.mean.unsqueeze(0) # (N, d) + logger.info("Centered memory bank (removed mean)") self.embeddings = embeddings.T # Store as (d, N) for efficient matmul self.passages = list(passages) logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim) + def center_query(self, query: torch.Tensor) -> torch.Tensor: + """Center a query embedding using the saved memory mean. + + Must be called before Hopfield retrieval when config.center=True. + + Args: + query: (d,) or (batch, d) — query embedding(s) + + Returns: + Centered query tensor, same shape as input. + """ + if self.mean is None: + return query + return query - self.mean.to(query.device) + + def apply_centering(self) -> None: + """Center an already-loaded (uncentered) memory bank in-place. + + Useful when loading a memory bank that was saved without centering. + Computes and stores the mean, then subtracts it from embeddings. 
+ """ + if self.embeddings is None: + return + # embeddings is (d, N), mean over columns + self.mean = self.embeddings.mean(dim=1) # (d,) + self.embeddings = self.embeddings - self.mean.unsqueeze(1) # (d, N) + logger.info("Applied centering to loaded memory bank") + def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]: """Given top-k indices, return corresponding passage texts. @@ -67,20 +104,38 @@ class MemoryBank: "embedding_dim": self.config.embedding_dim, "normalize": self.config.normalize, }, + "mean": self.mean, } torch.save(data, path) logger.info("Saved memory bank to %s", path) - def load(self, path: str) -> None: + def load(self, path: str, device: str = "cpu") -> None: """Load memory bank from disk. Args: path: file path to load from + device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.) """ - data = torch.load(path, weights_only=False) + data = torch.load(path, weights_only=False, map_location=device) self.embeddings = data["embeddings"] self.passages = data["passages"] - logger.info("Loaded memory bank from %s (%d passages)", path, self.size) + self.mean = data.get("mean", None) + logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device) + + def to(self, device: str) -> "MemoryBank": + """Move memory bank embeddings to the specified device. + + Args: + device: target device ("cpu", "cuda", "cuda:0", etc.) + + Returns: + self (for chaining). 
+ """ + if self.embeddings is not None: + self.embeddings = self.embeddings.to(device) + if self.mean is not None: + self.mean = self.mean.to(device) + return self @property def size(self) -> int: diff --git a/hag/pipeline.py b/hag/pipeline.py index 1fefb84..086b3be 100644 --- a/hag/pipeline.py +++ b/hag/pipeline.py @@ -82,7 +82,8 @@ class RAGPipeline: if self.retriever_type == "hopfield": retrieval_result = self.hopfield_retriever.retrieve(query_emb) else: - query_np = query_emb.detach().numpy().astype(np.float32) + # FAISS requires CPU numpy arrays + query_np = query_emb.detach().cpu().numpy().astype(np.float32) retrieval_result = self.faiss_retriever.retrieve(query_np) # Generate @@ -0,0 +1,181 @@ +# HAG 实验记录 + +## 实验环境 + +- Memory bank: HotpotQA, 1311 passages, dim=768 (Contriever-MSMARCO) +- Generator: Llama-3.1-8B-Instruct, temperature=0 (greedy) +- Evaluation: 100 questions, EM + F1 +- Hardware: NVIDIA RTX A6000 × 4 + +--- + +## Round 1: 初始 Grid Search (β ≤ 8) + +**日期**: 2026-02-15 +**脚本**: `scripts/run_grid_search.py` +**Grid**: β=[0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 8.0] × iter=[1,2,3,5,8,15], top_k=5 +**Dedup**: 381 LLM calls / 4300 grid evals (91.1% saving) + +| 方法 | EM | F1 | +|---|---|---| +| FAISS baseline | 0.320 | 0.438 | +| Best HAG (β=8, iter=1) | 0.210 | 0.294 | + +**结论**: 全军覆没。所有 42 个 HAG 配置均低于 FAISS baseline。 + +**问题诊断**: +- Entropy ≈ 7.17 / 7.18 (≈log(1311)),attention 接近均匀分布 +- FAISS overlap 极低 (0.2%-6.8%) +- iter>1 普遍更差,所有 query 被吸到同一个 centroid + +--- + +## Root Cause Analysis: Centroid Attractor + +### 发现 1: 能量面结构问题 + +E(q) = -1/β · logsumexp(β · qᵀM) + 1/2 · ‖q‖² + +| β | E(centroid) | E(memory_i) | 谁更低? 
| +|---|---|---|---| +| 1.0 | **-7.375** | -7.068 | centroid | +| 8.0 | **-1.097** | -0.812 | centroid | +| 50.0 | -0.357 | **-0.500** | memory | + +**原因**: ‖centroid‖=0.63 vs ‖memory‖=1.0,norm 惩罚项 1/2‖q‖² 让 centroid 白省 0.3 能量。β<50 时 centroid 是全局最低能量点。 + +### 发现 2: 一步迭代即崩溃 + +β=8 时: +- t=0: query 正常,top3 是正确 passages +- t=1: cos(q, centroid) = 0.993,‖q‖从 1.36 骤降到 0.64 +- t=2: cos(q, centroid) = 0.999,所有 query 收敛到同一点 + +**机制**: softmax(β·qᵀM) 近均匀 → q_new = Σ αᵢmᵢ ≈ centroid → norm 缩小 → 下一步更均匀 → 恶性循环 + +### 发现 3: 去掉 norm 项不够 + +即使 normalize 每步 update(‖q‖=1),β≤8 时仍收敛到 normalized centroid 方向。因为 softmax averaging 本身就把 query 拉向 centroid 方向,与 norm 项无关。 + +### 发现 4: Memory bank 结构 + +- 1311 passages pairwise cosine sim: mean=0.392, std=0.065 +- Centroid 与所有 memory 的 cosine sim: mean=0.627 (很高) +- β 需要 ≈50+ 才能让 softmax 区分 passages(entropy 从 99% 降到 10%) + +### 关键洞察 + +**max_iter=1 在代码中意味着做 1 次 Hopfield 更新**(不是 0 次)。Grid search 从未测试 iter=0(纯 softmax top-k = FAISS)。这解释了为什么 "iter=1" 已经比 FAISS 差很多。 + +--- + +## Round 2: 20-question 方案探索 + +在 20 个问题上快速测试 4 种修复方案(带 LLM 生成): + +### A. 高 β(50, 100)纯 Hopfield + +| 配置 | EM | F1 | +|---|---|---| +| FAISS | 0.300 | 0.456 | +| β=50 iter=1 | 0.300 | 0.417 | +| β=100 iter=1 | 0.300 | 0.417 | + +高 β 使初始 attention 变尖锐,但迭代后 query 仍偏移,F1 反而下降。 + +### B. Pre-filter K + Hopfield + +| 配置 | EM | F1 | +|---|---|---| +| PreFilter K=20 → β=5 iter=1 | **0.350** | **0.500** | +| PreFilter K=20 → β=20 iter=1 | 0.350 | 0.462 | +| PreFilter K=50 → β=5 iter=1 | 0.350 | 0.462 | + +最优方案,但本质是 FAISS 初筛 + Hopfield 重排,削弱了 "用 Hopfield 替代 FAISS" 的叙事。 + +### C. Residual Connection + +q_{t+1} = λ · q_t + (1-λ) · M @ softmax(β · Mᵀ · q_t) + +| 配置 | EM | F1 | +|---|---|---| +| β=20 λ=0.9 iter=3 | 0.350 | 0.473 | +| β=50 λ=0.9 iter=3 | 0.350 | 0.473 | +| β=50 λ=0.7 iter=3 | 0.350 | 0.433 | + +λ=0.9(保留 90% 原始 query)有效防止 centroid 崩塌。 + +### D. 
Pre-filter + Residual + +| 配置 | EM | F1 | +|---|---|---| +| PF K=20 + β=5 λ=0.9 iter=3 | 0.350 | 0.473 | +| PF K=50 + β=10 λ=0.9 iter=3 | 0.350 | 0.473 | + +没有比单独用 Residual 更好。 + +**决策**: 选择 pure Hopfield + Residual 路线(不依赖 FAISS pre-filter)。 + +--- + +## Round 3: Residual Grid (100 questions) + +**脚本**: `scripts/eval_residual_grid.py` +**Grid**: β=[5,10,20,50,100] × λ=[0.5,0.7,0.8,0.9,0.95] × iter=[1,3,5,8] +**Dedup**: 1666 unique LLM calls + +| 排名 | 配置 | EM | F1 | FAISS overlap | +|---|---|---|---|---| +| - | FAISS baseline | 0.320 | 0.438 | 1.000 | +| 1 | **β=5 λ=0.7 iter=1** | **0.360** | **0.481** | 0.902 | +| 2 | β=5 λ=0.9 iter=3 | 0.360 | 0.481 | 0.912 | +| 3 | β=10 λ=0.7 iter=1 | 0.360 | 0.481 | 0.900 | +| 4 | β=10 λ=0.95 iter=8 | 0.360 | 0.480 | 0.886 | +| 5 | β=5 λ=0.95 iter=8 | 0.350 | 0.470 | 0.886 | + +**55/100 configs beat FAISS F1**。Residual 路线 robust。 + +--- + +## Round 4: High-β Grid (100 questions) + +**脚本**: `scripts/eval_highbeta_grid.py` +**Grid**: β=[20,50,100,200,500] × iter=[0,1,2,3,5,8] × mode=[standard, normalized, residual_0.9, residual_0.95] +**Dedup**: 1379 unique LLM calls + +| 排名 | 配置 | EM | F1 | FAISS overlap | +|---|---|---|---|---| +| - | FAISS baseline | 0.320 | 0.438 | 1.000 | +| 1 | **β=20 iter=1 standard** | **0.380** | **0.469** | 0.480 | +| 2 | β=50 iter=1 standard | 0.360 | 0.457 | 0.508 | +| 3 | β=20 iter=1 residual_0.9 | 0.340 | 0.455 | 0.966 | +| 4 | β=20 iter=2 residual_0.95 | 0.340 | 0.455 | 0.966 | +| 5 | β=500 iter=1 residual_0.9 | 0.360 | 0.455 | 0.692 | +| 6 | β=50 iter=1 normalized | 0.370 | 0.454 | 0.464 | + +**31/105 configs beat FAISS F1**。Standard 高 β 也能 work 但 overlap 低(~0.5,换了一半 passages)。 + +--- + +## 综合结论 + +### 最优配置 + +| 方法 | EM | F1 | vs FAISS | +|---|---|---|---| +| FAISS baseline | 0.320 | 0.438 | - | +| Residual β=5 λ=0.7 iter=1 | 0.360 | 0.481 | **+4.3% F1** | +| Standard β=20 iter=1 | 0.380 | 0.469 | **+3.1% F1** | + +### 规律 + +1. **iter=1 普遍最优**,多次迭代有害(centroid 吸引) +2. 
**低 β + residual > 高 β + standard**:保守修正比激进替换好 +3. **Residual 更 robust**: 55% configs beat FAISS vs 30% for high-β +4. **FAISS overlap 高 = 好**: 最优 residual configs overlap ≈ 0.9(只换 10% passages 即提升) + +### 未解决问题 + +- 只做了 1 步迭代就最优,"iterative refinement" 的叙事受限 +- Centroid attractor 在低 β 时仍存在,需要根本性解决方案 +- 100 questions 样本量偏小,需要在 500+ 上验证 diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py index fd044a4..cd93b15 100644 --- a/scripts/analyze_energy.py +++ b/scripts/analyze_energy.py @@ -32,6 +32,7 @@ def main() -> None: parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--questions", type=str, required=True) parser.add_argument("--output", type=str, default="energy_analysis.json") + parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() with open(args.config) as f: @@ -43,13 +44,13 @@ def main() -> None: # Load memory bank mb = MemoryBank(memory_config) - mb.load(args.memory_bank) + mb.load(args.memory_bank, device=args.device) # Load questions with open(args.questions) as f: questions = [json.loads(line)["question"] for line in f] - encoder = Encoder(encoder_config) + encoder = Encoder(encoder_config, device=args.device) hopfield = HopfieldRetrieval(hopfield_config) analyses = [] diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py index 2aff828..0fc1c51 100644 --- a/scripts/build_memory_bank.py +++ b/scripts/build_memory_bank.py @@ -50,13 +50,13 @@ def main() -> None: logger.info("Loaded %d passages", len(passages)) # Encode passages in batches - encoder = Encoder(encoder_config) + encoder = Encoder(encoder_config, device=args.device) all_embeddings = [] for i in tqdm(range(0, len(passages), encoder_config.batch_size), desc="Encoding"): batch = passages[i : i + encoder_config.batch_size] emb = encoder.encode(batch) # (batch_size, d) - all_embeddings.append(emb.cpu()) + all_embeddings.append(emb.cpu()) # Always store on CPU for saving embeddings = torch.cat(all_embeddings, 
dim=0) # (N, d) logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape) diff --git a/scripts/diagnose_centering.py b/scripts/diagnose_centering.py new file mode 100644 index 0000000..5b9b4ee --- /dev/null +++ b/scripts/diagnose_centering.py @@ -0,0 +1,301 @@ +"""Diagnose centering dynamics: trace q step by step, verify β_critical theory.""" + +import sys +import torch +import torch.nn.functional as F +import numpy as np + +sys.path.insert(0, "/home/yurenh2/HAG") + +from hag.memory_bank import MemoryBank +from hag.config import MemoryBankConfig + +# ── Load memory bank ───────────────────────────────────────────────── +device = "cuda:0" # CUDA_VISIBLE_DEVICES remaps +mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False)) +mb.load("/home/yurenh2/HAG/data/processed/hotpotqa_memory_bank.pt", device=device) +M_raw = mb.embeddings # (d, N), L2-normalized, NOT centered +d, N = M_raw.shape +print(f"Memory bank: d={d}, N={N}") + +# ── Center manually ────────────────────────────────────────────────── +mu = M_raw.mean(dim=1) # (d,) +M_cent = M_raw - mu.unsqueeze(1) # (d, N) centered +print(f"‖μ‖ = {mu.norm():.4f}") +print(f"‖M̃·1/N‖ = {(M_cent.mean(dim=1)).norm():.2e} (should be ~0)") + +# ── Column norms ───────────────────────────────────────────────────── +col_norms_raw = M_raw.norm(dim=0) +col_norms_cent = M_cent.norm(dim=0) +print(f"\nRaw column norms: mean={col_norms_raw.mean():.4f}, std={col_norms_raw.std():.4f}") +print(f"Centered column norms: mean={col_norms_cent.mean():.4f}, std={col_norms_cent.std():.4f}") + +# ── SVD and β_critical ────────────────────────────────────────────── +# M̃M̃ᵀ/N is the sample covariance. 
β_crit = N / λ_max(M̃M̃ᵀ) = 1/λ_max(C) +# where C = M̃M̃ᵀ/N +# But let's compute via SVD of M̃ +print("\nComputing SVD of centered memory...") +U, S, Vh = torch.linalg.svd(M_cent, full_matrices=False) # S shape: (min(d,N),) +print(f"Top 10 singular values: {S[:10].cpu().tolist()}") +lambda_max_MMT = S[0].item() ** 2 # largest eigenvalue of M̃M̃ᵀ +lambda_max_C = lambda_max_MMT / N # largest eigenvalue of M̃M̃ᵀ/N + +print(f"\nλ_max(M̃M̃ᵀ) = {lambda_max_MMT:.2f}") +print(f"λ_max(C=M̃M̃ᵀ/N) = {lambda_max_C:.4f}") + +# Jacobian at origin: DT = β/N · M̃M̃ᵀ +# Spectral radius = β/N · λ_max(M̃M̃ᵀ) = β · λ_max(C) +# For instability: β · λ_max(C) > 1 → β > 1/λ_max(C) +beta_crit = 1.0 / lambda_max_C +print(f"\nβ_critical = 1/λ_max(C) = {beta_crit:.4f}") +print("For β > β_crit, origin is UNSTABLE (Jacobian spectral radius > 1)") + +for beta in [0.5, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0]: + rho = beta * lambda_max_C + print(f" β={beta:6.1f}: Jacobian spectral radius = {rho:.2f} {'UNSTABLE' if rho > 1 else 'stable'}") + +# ── Load some real queries ─────────────────────────────────────────── +import json +questions_path = "/home/yurenh2/HAG/data/processed/hotpotqa_questions.jsonl" +with open(questions_path) as f: + questions = [json.loads(line) for line in f][:5] + +from transformers import AutoTokenizer, AutoModel +print("\nLoading encoder...") +tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco") +model = AutoModel.from_pretrained("facebook/contriever-msmarco").to(device) +model.eval() + +def encode(texts): + inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device) + with torch.no_grad(): + outputs = model(**inputs) + # Contriever uses mean pooling + mask = inputs["attention_mask"].unsqueeze(-1).float() + emb = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1) + return F.normalize(emb, dim=-1) + +q_texts = [q["question"] for q in questions] +q_embs = encode(q_texts) # (5, d) + +# ── Trace dynamics 
step by step (centered) ─────────────────────────── +print("\n" + "=" * 80) +print("CENTERED DYNAMICS TRACE (5 queries)") +print("=" * 80) + +for beta in [5.0, 20.0, 50.0, 100.0]: + print(f"\n--- β = {beta} ---") + + for qi in range(min(3, len(q_texts))): + q0_raw = q_embs[qi:qi+1] # (1, d) + q0 = q0_raw - mu.unsqueeze(0) # center the query + + q = q0.clone() + print(f"\n Q{qi}: '{q_texts[qi][:60]}...'") + print(f" t=0: ‖q‖={q.norm():.4f}") + + for t in range(8): + logits = beta * (q @ M_cent) # (1, N) + alpha = torch.softmax(logits, dim=-1) # (1, N) + entropy = -(alpha * alpha.log()).sum().item() + max_alpha = alpha.max().item() + q_new = alpha @ M_cent.T # (1, d) + + # How close is q_new to the dominant eigenvector? + cos_v1 = abs((q_new / (q_new.norm() + 1e-12)) @ U[:, 0]).item() + + # FAISS-equivalent: initial attention from raw query on raw memory + if t == 0: + logits_raw = beta * (q0_raw @ M_raw) + alpha_raw = torch.softmax(logits_raw, dim=-1) + _, top5_raw = alpha_raw.topk(5) + + # Centered attention top-5 + _, top5_cent = alpha.topk(5) + overlap = len(set(top5_raw[0].cpu().tolist()) & set(top5_cent[0].cpu().tolist())) + print(f" initial overlap(raw top5, centered top5) = {overlap}/5") + + delta = (q_new - q).norm().item() + q = q_new + print(f" t={t+1}: ‖q‖={q.norm():.6f}, H(α)={entropy:.2f}/{np.log(N):.2f}, " + f"max(α)={max_alpha:.2e}, Δ={delta:.2e}, cos(q,v1)={cos_v1:.4f}") + + if delta < 1e-8: + print(f" [converged]") + break + + # Final top-5 from centered attention + logits_final = beta * (q @ M_cent) + alpha_final = torch.softmax(logits_final, dim=-1) + _, top5_final = alpha_final.topk(5) + + # Compare to "iter=0" (just softmax on centered, no iteration) + logits_iter0 = beta * (q0 @ M_cent) + alpha_iter0 = torch.softmax(logits_iter0, dim=-1) + _, top5_iter0 = alpha_iter0.topk(5) + + # And raw (FAISS-like) + logits_faiss = beta * (q0_raw @ M_raw) + alpha_faiss = torch.softmax(logits_faiss, dim=-1) + _, top5_faiss = alpha_faiss.topk(5) + + print(f" 
Top-5 FAISS: {top5_faiss[0].cpu().tolist()}") + print(f" Top-5 cent t=0: {top5_iter0[0].cpu().tolist()}") + print(f" Top-5 cent final: {top5_final[0].cpu().tolist()}") + +# ── KEY TEST: Does the Jacobian actually amplify near origin? ──────── +print("\n" + "=" * 80) +print("JACOBIAN TEST: Start near origin, see if dynamics amplify") +print("=" * 80) + +for beta in [5.0, 50.0, 100.0]: + # Start with a tiny perturbation in direction of top eigenvector + eps = 1e-4 + q_tiny = eps * U[:, 0].unsqueeze(0) # (1, d), tiny perturbation along v1 + + print(f"\n--- β = {beta}, q0 = {eps}·v1 ---") + q = q_tiny.clone() + for t in range(10): + logits = beta * (q @ M_cent) + alpha = torch.softmax(logits, dim=-1) + q_new = alpha @ M_cent.T + amplification = q_new.norm().item() / (q.norm().item() + 1e-20) + print(f" t={t}: ‖q‖={q.norm():.6e} → ‖q_new‖={q_new.norm():.6e}, " + f"amplification={amplification:.2f}") + q = q_new + if q.norm().item() > 1.0: + print(f" [escaped origin at t={t}]") + break + +# ── Alternative: What about removing the ‖q‖² term from energy? ───── +# The standard update q_{t+1} = M·softmax(β·Mᵀq) minimizes +# E(q) = -1/β·lse(β·Mᵀq) + 1/2·‖q‖² +# What if we don't want the ‖q‖² penalty? Then the fixed point equation +# is just q* = M·softmax(β·Mᵀq*), same update but different energy landscape. +# The issue is: with centering, M̃·uniform = 0 regardless of energy. +# The ‖q‖² penalty is NOT the problem for centering — the averaging is. 
+ +print("\n" + "=" * 80) +print("DIAGNOSIS: Why centering fails for iteration") +print("=" * 80) + +# For β=50, show the attention distribution at t=0 (before any iteration) +beta = 50.0 +q0_raw = q_embs[0:1] +q0_cent = q0_raw - mu.unsqueeze(0) +logits_cent = beta * (q0_cent @ M_cent) # (1, N) +alpha_cent = torch.softmax(logits_cent, dim=-1) +entropy_cent = -(alpha_cent * alpha_cent.log()).sum().item() +max_alpha_cent = alpha_cent.max().item() + +logits_raw = beta * (q0_raw @ M_raw) +alpha_raw = torch.softmax(logits_raw, dim=-1) +entropy_raw = -(alpha_raw * alpha_raw.log()).sum().item() + +print(f"\nβ={beta}, Q0: '{q_texts[0][:60]}...'") +print(f"Raw attention: entropy={entropy_raw:.2f}, max={alpha_raw.max():.4f}") +print(f"Cent attention: entropy={entropy_cent:.2f}, max={max_alpha_cent:.4f}") +print(f"‖q0_cent‖ = {q0_cent.norm():.4f}") +print(f"‖q0_raw‖ = {q0_raw.norm():.4f}") + +# Show: what's the actual q1 norm vs predicted from Jacobian? +q1_cent = alpha_cent @ M_cent.T # (1, d) +predicted_norm = (beta / N * lambda_max_MMT) * q0_cent.norm().item() # rough bound +print(f"\n‖q1_cent‖ actual = {q1_cent.norm():.6f}") +print(f"‖q0_cent‖ × Jacobian_spectral_radius ≈ {q0_cent.norm():.4f} × {beta*lambda_max_C:.2f} = {q0_cent.norm().item()*beta*lambda_max_C:.4f}") +print(f"But the linearization only holds for q→0. q0 is NOT near zero.") + +# The real issue: softmax(β·M̃ᵀ·q0) when q0 has ‖q‖=0.5 +# The logits have some spread, but the weighted average of CENTERED vectors +# inherently cancels out. +weighted_avg = (alpha_cent @ M_cent.T) # (1, d) +unweighted_avg = M_cent.mean(dim=1) # (d,) +print(f"\n‖weighted_avg‖ = {weighted_avg.norm():.6f}") +print(f"‖unweighted_avg‖ = {unweighted_avg.norm():.2e}") + +# How concentrated is the attention? 
+top50_vals, top50_idx = alpha_cent.topk(50) +mass_top50 = top50_vals.sum().item() +print(f"Mass in top-50 memories: {mass_top50:.4f}") +print(f"Mass in top-5 memories: {alpha_cent.topk(5)[0].sum().item():.4f}") + +# The weighted average of centered vectors is small because: +# 1. Centered vectors m̃_i have ‖m̃_i‖ ≈ 0.78 (smaller than raw ‖m_i‖=1) +# 2. Centered vectors point in diverse directions (they have mean removed) +# 3. Even with non-uniform weights, the cancellation is severe unless +# attention is extremely peaked on a few memories +# So the output ‖q1‖ << ‖q0‖, even though β is large + +# Key quantification: what fraction of ‖q0‖ is preserved? +preserve_ratio = q1_cent.norm().item() / q0_cent.norm().item() +print(f"\n‖q1‖/‖q0‖ = {preserve_ratio:.4f} (fraction of query norm preserved)") +print("This ratio << 1 means the averaging contracts the query toward 0.") +print("For centering to work with iteration, this ratio must be > 1.") + +print("\n" + "=" * 80) +print("SOLUTION ANALYSIS") +print("=" * 80) +print(""" +The centering fix removes the centroid attractor: M̃·uniform = 0, not μ. +But the fundamental problem remains: ANY weighted average of centered vectors +is much shorter than the input query, because centered vectors cancel. + +For the origin to be unstable, β must exceed β_critical so that the Jacobian +amplifies perturbations near zero. But the dynamics from a realistic starting +point (‖q‖≈0.5) don't behave like the linearization predicts. + +The actual contraction ratio ‖q1‖/‖q0‖ is what matters, not the Jacobian +at origin. This ratio is small because softmax isn't peaked enough. + +Possible fixes: +1. MUCH higher β (β > 500?) to make attention ultra-peaked → less cancellation +2. Residual connection with centering: q_{t+1} = λ·q_t + (1-λ)·M̃·softmax(...) + This explicitly preserves query norm while still benefiting from centering. +3. Normalize q_{t+1} after each step to prevent norm collapse. +4. 
Use centering only for the attention computation, not for the update target: + α = softmax(β · M̃ᵀ · q̃) but q_{t+1} = M_raw · α (update in original space) +""") + +# Test option 4: centered attention, raw update +print("\n" + "=" * 80) +print("TEST: Centered attention + raw update (hybrid)") +print("=" * 80) + +for beta in [5.0, 20.0, 50.0]: + print(f"\n--- β = {beta} ---") + for qi in range(min(3, len(q_texts))): + q_raw = q_embs[qi:qi+1].clone() # (1, d) raw query + print(f" Q{qi}: '{q_texts[qi][:50]}...'") + + for t in range(5): + q_cent = q_raw - mu.unsqueeze(0) # center query + logits = beta * (q_cent @ M_cent) # attention on centered space + alpha = torch.softmax(logits, dim=-1) + q_new = alpha @ M_raw.T # update in RAW space + + entropy = -(alpha * alpha.log()).sum().item() + delta = (q_new - q_raw).norm().item() + + if t == 0: + _, top5 = alpha.topk(5) + # FAISS top5 + logits_f = q_embs[qi:qi+1] @ M_raw + _, top5_f = logits_f.topk(5) + overlap = len(set(top5[0].cpu().tolist()) & set(top5_f[0].cpu().tolist())) + + print(f" t={t}: ‖q‖={q_raw.norm():.4f} → {q_new.norm():.4f}, " + f"H={entropy:.2f}, Δ={delta:.4f}") + q_raw = q_new + + # Final top-5 + q_cent_f = q_raw - mu.unsqueeze(0) + logits_f = beta * (q_cent_f @ M_cent) + _, top5_final = torch.softmax(logits_f, dim=-1).topk(5) + logits_faiss = q_embs[qi:qi+1] @ M_raw + _, top5_faiss = logits_faiss.topk(5) + overlap = len(set(top5_final[0].cpu().tolist()) & set(top5_faiss[0].cpu().tolist())) + print(f" final vs FAISS overlap: {overlap}/5") + print(f" FAISS top5: {top5_faiss[0].cpu().tolist()}") + print(f" Hybrid top5: {top5_final[0].cpu().tolist()}") + +print("\nDone.") diff --git a/scripts/eval_centered_grid.py b/scripts/eval_centered_grid.py new file mode 100644 index 0000000..4581251 --- /dev/null +++ b/scripts/eval_centered_grid.py @@ -0,0 +1,313 @@ +"""Evaluate centered Hopfield on 100 questions. + +Memory bank is mean-centered (M̃ = M - μ), query is centered (q̃ = q - μ). 
+β_critical = 37.6: below it origin is stable attractor, above it dynamics escape. + +Grid: β spanning both sides of β_critical, iter = [0, 1, 2, 3, 5, 8]. +No residual — pure Hopfield update on centered space. + +Usage: + CUDA_VISIBLE_DEVICES=1 nohup python -u scripts/eval_centered_grid.py \ + --memory-bank data/processed/hotpotqa_memory_bank.pt \ + --questions data/processed/hotpotqa_questions.jsonl \ + --device cuda --max-samples 100 \ + > data/processed/centered_grid.log 2>&1 & +""" + +import argparse +import json +import logging +import sys +import time +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import yaml + +from hag.config import EncoderConfig, GeneratorConfig, MemoryBankConfig +from hag.energy import compute_attention_entropy +from hag.encoder import Encoder +from hag.generator import Generator +from hag.memory_bank import MemoryBank +from hag.metrics import exact_match, f1_score +from hag.retriever_faiss import FAISSRetriever + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + + +def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]: + questions, gold_answers = [], [] + with open(path) as f: + for line in f: + r = json.loads(line) + questions.append(r["question"]) + gold_answers.append(r["answer"]) + if len(questions) >= max_samples: + break + return questions, gold_answers + + +@torch.no_grad() +def hopfield_retrieve_centered( + query: torch.Tensor, + memory_centered: torch.Tensor, + mean: torch.Tensor, + beta: float, + max_iter: int, + top_k: int, +) -> Tuple[torch.Tensor, float]: + """Pure Hopfield retrieval on centered memory bank. 
+ + Args: + query: (batch, d) raw query embeddings + memory_centered: (d, N) centered memory bank (M̃ = M - μ) + mean: (d,) memory bank mean + beta, max_iter, top_k: Hopfield params + + Returns: + (top_k_indices (batch, top_k), avg_entropy) + """ + # Center the query + q = query - mean.unsqueeze(0) # (batch, d) + + for _ in range(max_iter): + logits = beta * (q @ memory_centered) # (batch, N) + alpha = torch.softmax(logits, dim=-1) # (batch, N) + q = alpha @ memory_centered.T # (batch, d) + + # Final attention + logits = beta * (q @ memory_centered) + alpha = torch.softmax(logits, dim=-1) + _, indices = torch.topk(alpha, top_k, dim=-1) + entropy = compute_attention_entropy(alpha) + return indices, entropy + + +def main() -> None: + parser = argparse.ArgumentParser(description="Centered Hopfield grid evaluation") + parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--questions", type=str, required=True) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--max-samples", type=int, default=100) + parser.add_argument("--output", type=str, default="data/processed/centered_grid_results.json") + parser.add_argument("--top-k", type=int, default=5) + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + # Grid: span β_critical ≈ 37.6 + betas = [10.0, 20.0, 30.0, 38.0, 40.0, 45.0, 50.0, 60.0, 75.0, 100.0, 150.0, 200.0] + max_iters_list = [0, 1, 2, 3, 5, 8] + top_k = args.top_k + + total_configs = len(betas) * len(max_iters_list) + logger.info("=" * 60) + logger.info("Centered Hopfield Grid Search") + logger.info(" β_critical ≈ 37.6") + logger.info(" betas: %s", betas) + logger.info(" max_iters: %s", max_iters_list) + logger.info(" total configs: %d", total_configs) + logger.info("=" * 60) + + t_start = time.time() + + questions, gold_answers = load_questions(args.questions, args.max_samples) + n = 
len(questions) + logger.info("Loaded %d questions", n) + + # Load memory bank (uncentered) + mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) + mb.load(args.memory_bank, device=args.device) + M_raw = mb.embeddings # (d, N) + d, N = M_raw.shape + logger.info("Memory bank: %d passages, dim=%d", N, d) + + # Center the memory bank + mu = M_raw.mean(dim=1) # (d,) + M_cent = M_raw - mu.unsqueeze(1) # (d, N) + logger.info("Centered memory bank. ‖μ‖=%.4f, ‖M̃·1/N‖=%.2e", + mu.norm().item(), M_cent.mean(dim=1).norm().item()) + + # Compute β_critical + S = torch.linalg.svdvals(M_cent) + lambda_max_C = (S[0].item() ** 2) / N + beta_crit = 1.0 / lambda_max_C + logger.info("β_critical = %.2f (λ_max(C)=%.4f, σ_max=%.4f)", beta_crit, lambda_max_C, S[0].item()) + + encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device) + generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device) + + logger.info("Encoding questions...") + all_embs = [] + batch_size = cfg.get("encoder", {}).get("batch_size", 64) + for i in range(0, n, batch_size): + all_embs.append(encoder.encode(questions[i : i + batch_size])) + Q = torch.cat(all_embs, dim=0) # (n, d) + logger.info("Encoded, shape=%s", Q.shape) + + # FAISS baseline (on raw embeddings) + logger.info("Running FAISS baseline...") + emb_np = M_raw.T.cpu().numpy().astype(np.float32) + faiss_ret = FAISSRetriever(top_k=top_k) + faiss_ret.build_index(emb_np, mb.passages) + + faiss_indices: Dict[int, Tuple[int, ...]] = {} + llm_cache: Dict[Tuple[int, frozenset], str] = {} + + for i in range(n): + q_np = Q[i].cpu().numpy().astype(np.float32) + result = faiss_ret.retrieve(q_np) + idx_tuple = tuple(sorted(result.indices.tolist())) + faiss_indices[i] = idx_tuple + cache_key = (i, frozenset(idx_tuple)) + answer = generator.generate(questions[i], result.passages) + llm_cache[cache_key] = answer + if (i + 1) % 20 == 0: + ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) 
for j in range(i + 1)] + f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] + logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s)) + + faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) + faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) + logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1) + + # Phase 2: Retrieve all configs (centered) + logger.info("Phase 2: Retrieving all %d configs (centered)...", total_configs) + t_ret = time.time() + + retrieval_data: Dict[str, List[Tuple[Tuple[int, ...], float]]] = {} + + for beta in betas: + for max_iter in max_iters_list: + config_key = f"β={beta}_iter={max_iter}" + indices_batch, entropy = hopfield_retrieve_centered( + Q, M_cent, mu, beta=beta, max_iter=max_iter, top_k=top_k, + ) + per_q = [] + for i in range(n): + idx_tuple = tuple(sorted(indices_batch[i].tolist())) + per_q.append((idx_tuple, entropy)) + retrieval_data[config_key] = per_q + + logger.info("Retrieval done in %.1fs, %d configs", time.time() - t_ret, len(retrieval_data)) + + # Phase 3: Dedup + generate + needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {} + for key, per_q in retrieval_data.items(): + for i, (idx_tuple, _) in enumerate(per_q): + cache_key = (i, frozenset(idx_tuple)) + if cache_key not in llm_cache and cache_key not in needed: + needed[cache_key] = (i, idx_tuple) + + logger.info("Unique LLM calls needed: %d (cache has %d)", len(needed), len(llm_cache)) + + t_gen = time.time() + for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()): + passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long)) + answer = generator.generate(questions[q_idx], passages) + llm_cache[cache_key] = answer + if (call_idx + 1) % 50 == 0: + elapsed = time.time() - t_gen + rate = (call_idx + 
1) / elapsed + logger.info(" Generated %d/%d (%.1f/s, ~%.0fs left)", + call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate) + logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen) + + # Phase 4: Evaluate + logger.info("Phase 4: Evaluating...") + results = [] + for config_key, per_q in retrieval_data.items(): + ems, f1s, overlaps = [], [], [] + for i, (idx_tuple, ent) in enumerate(per_q): + cache_key = (i, frozenset(idx_tuple)) + answer = llm_cache[cache_key] + ems.append(exact_match(answer, gold_answers[i])) + f1s.append(f1_score(answer, gold_answers[i])) + overlaps.append(len(set(idx_tuple) & set(faiss_indices[i])) / top_k) + + em, f1 = np.mean(ems), np.mean(f1s) + # Parse beta from config_key + beta_val = float(config_key.split("_")[0].split("=")[1]) + iter_val = int(config_key.split("_")[1].split("=")[1]) + r = { + "config": config_key, + "beta": beta_val, + "max_iter": iter_val, + "em": round(em, 4), + "f1": round(f1, 4), + "avg_faiss_overlap": round(np.mean(overlaps), 4), + "avg_entropy": round(per_q[0][1], 4), + "above_beta_crit": beta_val > beta_crit, + } + results.append(r) + + results.sort(key=lambda x: x["f1"], reverse=True) + + # Count how many beat FAISS + n_beat = sum(1 for r in results if r["f1"] > faiss_f1) + logger.info("\n%d/%d configs beat FAISS F1=%.3f", n_beat, len(results), faiss_f1) + + # Log top 20 + logger.info("\nTop 20 configs:") + for i, r in enumerate(results[:20]): + marker = " ***" if r["f1"] > faiss_f1 else "" + crit = ">" if r["above_beta_crit"] else "<" + logger.info(" %2d. 
%-25s EM=%.3f F1=%.3f overlap=%.3f H=%.2f β%sβ_c%s", + i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"], + r["avg_entropy"], crit, marker) + + # Summary by β: best iter for each β + logger.info("\nBest iter per β:") + for beta in betas: + beta_results = [r for r in results if r["beta"] == beta] + if beta_results: + best = beta_results[0] + crit = ">" if best["above_beta_crit"] else "<" + logger.info(" β=%6.1f (β%sβ_c): best iter=%d EM=%.3f F1=%.3f overlap=%.3f", + beta, crit, best["max_iter"], best["em"], best["f1"], best["avg_faiss_overlap"]) + + t_total = time.time() - t_start + output = { + "meta": { + "n_questions": n, + "total_configs": len(retrieval_data), + "unique_llm_calls": len(needed), + "total_time_s": round(t_total, 1), + "beta_critical": round(beta_crit, 2), + }, + "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)}, + "grid_results": results, + "best_config": results[0], + "top10": results[:10], + } + + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as f: + json.dump(output, f, indent=2) + + logger.info("=" * 60) + logger.info("RESULTS SUMMARY") + logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1) + logger.info(" β_critical = %.2f", beta_crit) + logger.info(" Configs beating FAISS: %d/%d", n_beat, len(results)) + logger.info(" Top 5:") + for i, r in enumerate(results[:5]): + logger.info(" %d. %-25s EM=%.3f F1=%.3f overlap=%.3f", + i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"]) + logger.info(" Total time: %.1fs", t_total) + logger.info(" Saved to: %s", args.output) + logger.info("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/eval_highbeta_grid.py b/scripts/eval_highbeta_grid.py new file mode 100644 index 0000000..bf1d97d --- /dev/null +++ b/scripts/eval_highbeta_grid.py @@ -0,0 +1,301 @@ +"""Evaluate high-β Hopfield (standard + normalized update) on 100 questions. 
+ +Tests whether high β (≥50) allows standard Hopfield to work without residual. +Also tests normalized update: q → normalize(M @ softmax(β * M^T @ q)). + +Usage: + CUDA_VISIBLE_DEVICES=0 python -u scripts/eval_highbeta_grid.py \ + --config configs/hotpotqa.yaml \ + --memory-bank data/processed/hotpotqa_memory_bank.pt \ + --questions data/processed/hotpotqa_questions.jsonl \ + --device cuda --max-samples 100 +""" + +import argparse +import json +import logging +import sys +import time +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import yaml + +from hag.config import EncoderConfig, GeneratorConfig, MemoryBankConfig +from hag.energy import compute_attention_entropy +from hag.encoder import Encoder +from hag.generator import Generator +from hag.memory_bank import MemoryBank +from hag.metrics import exact_match, f1_score +from hag.retriever_faiss import FAISSRetriever + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + + +def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]: + questions, gold_answers = [], [] + with open(path) as f: + for line in f: + r = json.loads(line) + questions.append(r["question"]) + gold_answers.append(r["answer"]) + if len(questions) >= max_samples: + break + return questions, gold_answers + + +@torch.no_grad() +def hopfield_retrieve( + query: torch.Tensor, + memory: torch.Tensor, + beta: float, + max_iter: int, + top_k: int, + mode: str = "standard", + lam: float = 0.0, +) -> Tuple[torch.Tensor, float]: + """Hopfield retrieval with different update modes. 
+ + Args: + query: (batch, d) + memory: (d, N) + beta, max_iter, top_k: Hopfield params + mode: "standard" | "normalized" | "residual" + lam: residual weight (only for mode="residual") + + Returns: + (top_k_indices (batch, top_k), avg_entropy) + """ + q = query.clone() + if mode == "normalized": + q = F.normalize(q, dim=-1) + + for _ in range(max_iter): + logits = beta * (q @ memory) # (batch, N) + alpha = torch.softmax(logits, dim=-1) # (batch, N) + q_hop = alpha @ memory.T # (batch, d) + + if mode == "standard": + q = q_hop + elif mode == "normalized": + q = F.normalize(q_hop, dim=-1) + elif mode == "residual": + q = lam * q + (1.0 - lam) * q_hop + + # Final attention + logits = beta * (q @ memory) + alpha = torch.softmax(logits, dim=-1) + _, indices = torch.topk(alpha, top_k, dim=-1) + entropy = compute_attention_entropy(alpha) + return indices, entropy + + +def main() -> None: + parser = argparse.ArgumentParser(description="High-β Hopfield grid evaluation") + parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--questions", type=str, required=True) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--max-samples", type=int, default=100) + parser.add_argument("--output", type=str, default="data/processed/highbeta_grid_results.json") + parser.add_argument("--top-k", type=int, default=5) + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + # Grid: high β focus + betas = [20.0, 50.0, 100.0, 200.0, 500.0] + max_iters_list = [0, 1, 2, 3, 5, 8] + # Modes: standard, normalized, residual(λ=0.9) + modes = [ + ("standard", 0.0), + ("normalized", 0.0), + ("residual_0.9", 0.9), + ("residual_0.95", 0.95), + ] + top_k = args.top_k + + total_configs = len(betas) * len(max_iters_list) * len(modes) + logger.info("=" * 60) + logger.info("High-β Hopfield Grid Search") + logger.info(" betas: %s", betas) + 
logger.info(" max_iters: %s", max_iters_list) + logger.info(" modes: %s", [m[0] for m in modes]) + logger.info(" total configs: %d", total_configs) + logger.info("=" * 60) + + t_start = time.time() + + questions, gold_answers = load_questions(args.questions, args.max_samples) + n = len(questions) + logger.info("Loaded %d questions", n) + + mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) + mb.load(args.memory_bank, device=args.device) + M = mb.embeddings + logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim) + + encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device) + generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device) + + logger.info("Encoding questions...") + all_embs = [] + batch_size = cfg.get("encoder", {}).get("batch_size", 64) + for i in range(0, n, batch_size): + all_embs.append(encoder.encode(questions[i : i + batch_size])) + Q = torch.cat(all_embs, dim=0) + logger.info("Encoded, shape=%s", Q.shape) + + # FAISS baseline + logger.info("Running FAISS baseline...") + emb_np = mb.embeddings.T.cpu().numpy().astype(np.float32) + faiss_ret = FAISSRetriever(top_k=top_k) + faiss_ret.build_index(emb_np, mb.passages) + + faiss_indices: Dict[int, Tuple[int, ...]] = {} + llm_cache: Dict[Tuple[int, frozenset], str] = {} + + for i in range(n): + q_np = Q[i].cpu().numpy().astype(np.float32) + result = faiss_ret.retrieve(q_np) + idx_tuple = tuple(sorted(result.indices.tolist())) + faiss_indices[i] = idx_tuple + cache_key = (i, frozenset(idx_tuple)) + answer = generator.generate(questions[i], result.passages) + llm_cache[cache_key] = answer + if (i + 1) % 20 == 0: + ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] + f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] + logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s)) + + faiss_em = 
np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) + faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) + logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1) + + # Phase 2: Retrieve all configs + logger.info("Phase 2: Retrieving all %d configs...", total_configs) + t_ret = time.time() + + retrieval_data: Dict[str, List[Tuple[Tuple[int, ...], float]]] = {} + + for beta in betas: + for max_iter in max_iters_list: + for mode_name, lam in modes: + if max_iter == 0: + # iter=0: just use initial query's softmax top-k (same for all modes) + if mode_name != "standard": + continue # skip duplicates for iter=0 + indices_batch = (beta * (Q @ M)).softmax(dim=-1).topk(top_k, dim=-1).indices + entropy = compute_attention_entropy((beta * (Q @ M)).softmax(dim=-1)) + config_key = f"β={beta}_iter=0_standard" + else: + actual_mode = "residual" if mode_name.startswith("residual") else mode_name + indices_batch, entropy = hopfield_retrieve( + Q, M, beta=beta, max_iter=max_iter, top_k=top_k, + mode=actual_mode, lam=lam, + ) + config_key = f"β={beta}_iter={max_iter}_{mode_name}" + + per_q = [] + for i in range(n): + idx_tuple = tuple(sorted(indices_batch[i].tolist())) + per_q.append((idx_tuple, entropy)) + retrieval_data[config_key] = per_q + + logger.info("Retrieval done in %.1fs, %d configs", time.time() - t_ret, len(retrieval_data)) + + # Phase 3: Dedup + generate + needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {} + for key, per_q in retrieval_data.items(): + for i, (idx_tuple, _) in enumerate(per_q): + cache_key = (i, frozenset(idx_tuple)) + if cache_key not in llm_cache and cache_key not in needed: + needed[cache_key] = (i, idx_tuple) + + logger.info("Unique LLM calls needed: %d (cache has %d)", len(needed), len(llm_cache)) + + t_gen = time.time() + for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()): + passages = 
mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long)) + answer = generator.generate(questions[q_idx], passages) + llm_cache[cache_key] = answer + if (call_idx + 1) % 50 == 0: + elapsed = time.time() - t_gen + rate = (call_idx + 1) / elapsed + logger.info(" Generated %d/%d (%.1f/s, ~%.0fs left)", + call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate) + logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen) + + # Phase 4: Evaluate + logger.info("Phase 4: Evaluating...") + results = [] + for config_key, per_q in retrieval_data.items(): + ems, f1s, overlaps = [], [], [] + for i, (idx_tuple, ent) in enumerate(per_q): + cache_key = (i, frozenset(idx_tuple)) + answer = llm_cache[cache_key] + ems.append(exact_match(answer, gold_answers[i])) + f1s.append(f1_score(answer, gold_answers[i])) + overlaps.append(len(set(idx_tuple) & set(faiss_indices[i])) / top_k) + + em, f1 = np.mean(ems), np.mean(f1s) + r = { + "config": config_key, + "em": round(em, 4), + "f1": round(f1, 4), + "avg_faiss_overlap": round(np.mean(overlaps), 4), + "avg_entropy": round(per_q[0][1], 4), + } + results.append(r) + + results.sort(key=lambda x: x["f1"], reverse=True) + + # Log all that beat or match FAISS + logger.info("\nConfigs matching or beating FAISS (F1≥%.3f):", faiss_f1) + for r in results: + if r["f1"] >= faiss_f1 - 0.005: + marker = " ***" if r["f1"] > faiss_f1 else "" + logger.info(" %s: EM=%.3f F1=%.3f overlap=%.3f%s", + r["config"], r["em"], r["f1"], r["avg_faiss_overlap"], marker) + + t_total = time.time() - t_start + output = { + "meta": { + "n_questions": n, + "total_configs": len(retrieval_data), + "unique_llm_calls": len(needed), + "total_time_s": round(t_total, 1), + }, + "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)}, + "grid_results": results, + "best_config": results[0], + "top10": results[:10], + } + + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") 
as f: + json.dump(output, f, indent=2) + + logger.info("=" * 60) + logger.info("RESULTS") + logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1) + logger.info(" Top 10:") + for i, r in enumerate(results[:10]): + logger.info(" %2d. %-40s EM=%.3f F1=%.3f overlap=%.3f", + i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"]) + logger.info(" Total time: %.1fs", t_total) + logger.info(" Saved to: %s", args.output) + logger.info("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/eval_residual_grid.py b/scripts/eval_residual_grid.py new file mode 100644 index 0000000..f716c51 --- /dev/null +++ b/scripts/eval_residual_grid.py @@ -0,0 +1,298 @@ +"""Evaluate Residual Hopfield configs on 100 questions with dedup-based LLM caching. + +Residual update: q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t) + +Usage: + CUDA_VISIBLE_DEVICES=1 python scripts/eval_residual_grid.py \ + --config configs/hotpotqa.yaml \ + --memory-bank data/processed/hotpotqa_memory_bank.pt \ + --questions data/processed/hotpotqa_questions.jsonl \ + --device cuda --max-samples 100 +""" + +import argparse +import json +import logging +import sys +import time +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import torch +import yaml + +from hag.config import EncoderConfig, GeneratorConfig, HopfieldConfig, MemoryBankConfig +from hag.encoder import Encoder +from hag.energy import compute_attention_entropy +from hag.generator import Generator +from hag.memory_bank import MemoryBank +from hag.metrics import exact_match, f1_score +from hag.retriever_faiss import FAISSRetriever + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + + +def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]: + questions, gold_answers = [], [] + with open(path) as f: + for line in f: + r = json.loads(line) + 
questions.append(r["question"]) + gold_answers.append(r["answer"]) + if len(questions) >= max_samples: + break + return questions, gold_answers + + +@torch.no_grad() +def residual_hopfield_retrieve( + query: torch.Tensor, + memory: torch.Tensor, + beta: float, + lam: float, + max_iter: int, + top_k: int, +) -> Tuple[torch.Tensor, torch.Tensor, float]: + """Residual Hopfield retrieval on full memory bank. + + q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t) + + Args: + query: (batch, d) + memory: (d, N) + beta: inverse temperature + lam: residual weight (0=pure Hopfield, 1=no update) + max_iter: number of iterations + top_k: number of passages to return + + Returns: + (top_k_indices, top_k_scores, avg_entropy) for the batch. + indices: (batch, top_k), scores: (batch, top_k), entropy: float + """ + q = query.clone() + for _ in range(max_iter): + logits = beta * (q @ memory) # (batch, N) + alpha = torch.softmax(logits, dim=-1) # (batch, N) + q_hop = alpha @ memory.T # (batch, d) + q = lam * q + (1.0 - lam) * q_hop # (batch, d) + + # Final attention + logits = beta * (q @ memory) # (batch, N) + alpha = torch.softmax(logits, dim=-1) # (batch, N) + scores, indices = torch.topk(alpha, top_k, dim=-1) # (batch, top_k) + + entropy = compute_attention_entropy(alpha) + return indices, scores, entropy + + +def main() -> None: + parser = argparse.ArgumentParser(description="Residual Hopfield grid evaluation") + parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--questions", type=str, required=True) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--max-samples", type=int, default=100) + parser.add_argument("--output", type=str, default="data/processed/residual_grid_results.json") + parser.add_argument("--top-k", type=int, default=5) + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + # Grid 
+ betas = [5.0, 10.0, 20.0, 50.0, 100.0] + lambdas = [0.5, 0.7, 0.8, 0.9, 0.95] + max_iters_list = [1, 3, 5, 8] + top_k = args.top_k + + total_configs = len(betas) * len(lambdas) * len(max_iters_list) + logger.info("=" * 60) + logger.info("Residual Hopfield Grid Search") + logger.info(" betas: %s", betas) + logger.info(" lambdas: %s", lambdas) + logger.info(" max_iters: %s", max_iters_list) + logger.info(" total configs: %d", total_configs) + logger.info("=" * 60) + + t_start = time.time() + + # Load + questions, gold_answers = load_questions(args.questions, args.max_samples) + n = len(questions) + logger.info("Loaded %d questions", n) + + mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) + mb.load(args.memory_bank, device=args.device) + M = mb.embeddings # (d, N) + logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim) + + encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device) + generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device) + + logger.info("Encoding questions...") + all_embs = [] + batch_size = cfg.get("encoder", {}).get("batch_size", 64) + for i in range(0, n, batch_size): + all_embs.append(encoder.encode(questions[i : i + batch_size])) + Q = torch.cat(all_embs, dim=0) # (n, d) + logger.info("Encoded, shape=%s", Q.shape) + + # FAISS baseline + logger.info("Running FAISS baseline...") + emb_np = mb.embeddings.T.cpu().numpy().astype(np.float32) + faiss_ret = FAISSRetriever(top_k=top_k) + faiss_ret.build_index(emb_np, mb.passages) + + faiss_indices: Dict[int, Tuple[int, ...]] = {} + llm_cache: Dict[Tuple[int, frozenset], str] = {} + + for i in range(n): + q_np = Q[i].cpu().numpy().astype(np.float32) + result = faiss_ret.retrieve(q_np) + idx_tuple = tuple(sorted(result.indices.tolist())) + faiss_indices[i] = idx_tuple + cache_key = (i, frozenset(idx_tuple)) + answer = generator.generate(questions[i], result.passages) + llm_cache[cache_key] = answer + if (i + 1) % 20 == 0: + ems = 
[exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] + f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] + logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s)) + + faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) + faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) + logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1) + + # Phase 2: Retrieve all configs (batched, fast) + logger.info("Phase 2: Retrieving all %d configs...", total_configs) + t_ret = time.time() + + # config_key -> list of (sorted_indices_tuple, entropy) per question + retrieval_data: Dict[Tuple[float, float, int], List[Tuple[Tuple[int, ...], float]]] = {} + + for beta in betas: + for lam in lambdas: + for max_iter in max_iters_list: + indices, scores, entropy = residual_hopfield_retrieve( + Q, M, beta=beta, lam=lam, max_iter=max_iter, top_k=top_k + ) + per_q = [] + for i in range(n): + idx_tuple = tuple(sorted(indices[i].tolist())) + # per-question entropy + per_q.append((idx_tuple, entropy)) + retrieval_data[(beta, lam, max_iter)] = per_q + + logger.info("Retrieval done in %.1fs", time.time() - t_ret) + + # Phase 3: Dedup + generate + needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {} + for key, per_q in retrieval_data.items(): + for i, (idx_tuple, _) in enumerate(per_q): + cache_key = (i, frozenset(idx_tuple)) + if cache_key not in llm_cache and cache_key not in needed: + needed[cache_key] = (i, idx_tuple) + + total_grid_evals = total_configs * n + logger.info( + "Unique LLM calls needed: %d / %d grid evals (%.1f%% saving)", + len(needed), total_grid_evals, + (1 - len(needed) / total_grid_evals) * 100, + ) + + t_gen = time.time() + for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()): + passages = 
mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long)) + answer = generator.generate(questions[q_idx], passages) + llm_cache[cache_key] = answer + if (call_idx + 1) % 50 == 0: + elapsed = time.time() - t_gen + rate = (call_idx + 1) / elapsed + logger.info( + " Generated %d/%d (%.1f/s, ~%.0fs left)", + call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate, + ) + logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen) + + # Phase 4: Evaluate + logger.info("Phase 4: Evaluating...") + results = [] + for beta in betas: + for lam in lambdas: + for max_iter in max_iters_list: + per_q = retrieval_data[(beta, lam, max_iter)] + ems, f1s, overlaps, entropies = [], [], [], [] + for i, (idx_tuple, ent) in enumerate(per_q): + cache_key = (i, frozenset(idx_tuple)) + answer = llm_cache[cache_key] + ems.append(exact_match(answer, gold_answers[i])) + f1s.append(f1_score(answer, gold_answers[i])) + overlap = len(set(idx_tuple) & set(faiss_indices[i])) / top_k + overlaps.append(overlap) + entropies.append(ent) + + em, f1 = np.mean(ems), np.mean(f1s) + r = { + "beta": beta, "lambda": lam, "max_iter": max_iter, + "em": round(em, 4), "f1": round(f1, 4), + "avg_faiss_overlap": round(np.mean(overlaps), 4), + "avg_entropy": round(np.mean(entropies), 4), + } + results.append(r) + if f1 >= faiss_f1 - 0.01: + marker = " ***" if f1 > faiss_f1 else "" + logger.info( + " β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f%s", + beta, lam, max_iter, em, f1, np.mean(overlaps), marker, + ) + + # Sort by F1 + results.sort(key=lambda x: x["f1"], reverse=True) + best = results[0] + + t_total = time.time() - t_start + + output = { + "meta": { + "n_questions": n, + "total_configs": total_configs, + "unique_llm_calls": len(needed), + "faiss_llm_calls": n, + "total_time_s": round(t_total, 1), + }, + "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)}, + "grid_results": results, + "best_config": best, + "top10": 
results[:10], + } + + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as f: + json.dump(output, f, indent=2) + + logger.info("=" * 60) + logger.info("RESULTS") + logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1) + logger.info( + " Best: β=%.1f λ=%.2f iter=%d => EM=%.4f F1=%.4f", + best["beta"], best["lambda"], best["max_iter"], best["em"], best["f1"], + ) + logger.info(" Top 5:") + for i, r in enumerate(results[:5]): + logger.info( + " %d. β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f", + i + 1, r["beta"], r["lambda"], r["max_iter"], r["em"], r["f1"], r["avg_faiss_overlap"], + ) + logger.info(" Total time: %.1fs", t_total) + logger.info(" Saved to: %s", args.output) + logger.info("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/prepare_corpus.py b/scripts/prepare_corpus.py new file mode 100644 index 0000000..93fc0ce --- /dev/null +++ b/scripts/prepare_corpus.py @@ -0,0 +1,127 @@ +"""Convert linear-rag chunks.json to JSONL corpus for build_memory_bank.py. + +The linear-rag dataset stores chunks as a list of strings with format "idx:text...". +This script strips the index prefix and outputs one {"text": "..."} per line. + +Usage: + python scripts/prepare_corpus.py --dataset hotpotqa + python scripts/prepare_corpus.py --dataset hotpotqa --dataset musique --dataset 2wikimultihop +""" + +import argparse +import json +import logging +from pathlib import Path + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +DATASETS = ["hotpotqa", "musique", "2wikimultihop", "medical"] + + +def convert_chunks(dataset: str, data_root: Path, output_dir: Path) -> Path: + """Convert a single dataset's chunks.json to corpus JSONL. + + Args: + dataset: dataset name (e.g., "hotpotqa") + data_root: path to linear-rag clone + output_dir: directory to write output JSONL + + Returns: + Path to the output JSONL file. 
+ """ + chunks_path = data_root / dataset / "chunks.json" + if not chunks_path.exists(): + raise FileNotFoundError(f"Not found: {chunks_path}") + + with open(chunks_path) as f: + chunks = json.load(f) + + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"{dataset}_corpus.jsonl" + + count = 0 + with open(output_path, "w") as out: + for chunk in chunks: + # Strip the "idx:" prefix + text = chunk.split(":", 1)[1] if ":" in chunk else chunk + text = text.strip() + if text: + out.write(json.dumps({"text": text}) + "\n") + count += 1 + + logger.info("%s: %d chunks -> %s", dataset, count, output_path) + return output_path + + +def convert_questions(dataset: str, data_root: Path, output_dir: Path) -> Path: + """Convert questions.json to a standardized JSONL format. + + Args: + dataset: dataset name + data_root: path to linear-rag clone + output_dir: directory to write output JSONL + + Returns: + Path to the output JSONL file. + """ + questions_path = data_root / dataset / "questions.json" + if not questions_path.exists(): + raise FileNotFoundError(f"Not found: {questions_path}") + + with open(questions_path) as f: + questions = json.load(f) + + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"{dataset}_questions.jsonl" + + count = 0 + with open(output_path, "w") as out: + for q in questions: + record = { + "id": q.get("id", ""), + "question": q["question"], + "answer": q["answer"], + "question_type": q.get("question_type", ""), + } + out.write(json.dumps(record) + "\n") + count += 1 + + logger.info("%s: %d questions -> %s", dataset, count, output_path) + return output_path + + +def main() -> None: + parser = argparse.ArgumentParser(description="Prepare linear-rag data for HAG") + parser.add_argument( + "--dataset", + type=str, + action="append", + default=None, + help=f"Dataset(s) to process. Choices: {DATASETS}. 
Can specify multiple times.", + ) + parser.add_argument( + "--data-root", + type=str, + default="data/linear-rag", + help="Path to linear-rag clone", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/processed", + help="Output directory for processed files", + ) + args = parser.parse_args() + + datasets = args.dataset if args.dataset else DATASETS + data_root = Path(args.data_root) + output_dir = Path(args.output_dir) + + for ds in datasets: + convert_chunks(ds, data_root, output_dir) + convert_questions(ds, data_root, output_dir) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py index 74c4710..beef76b 100644 --- a/scripts/run_baseline.py +++ b/scripts/run_baseline.py @@ -27,6 +27,7 @@ def main() -> None: parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--question", type=str, required=True) parser.add_argument("--top-k", type=int, default=5) + parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() with open(args.config) as f: @@ -44,17 +45,17 @@ def main() -> None: from hag.config import MemoryBankConfig mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) - mb.load(args.memory_bank) + mb.load(args.memory_bank) # FAISS needs CPU, load on CPU # Build FAISS index from memory bank embeddings import numpy as np - embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d) + embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32) # (N, d) faiss_ret = FAISSRetriever(top_k=args.top_k) faiss_ret.build_index(embeddings_np, mb.passages) - encoder = Encoder(pipeline_config.encoder) - generator = Generator(pipeline_config.generator) + encoder = Encoder(pipeline_config.encoder, device=args.device) + generator = Generator(pipeline_config.generator, device=args.device) pipeline = RAGPipeline( config=pipeline_config, diff --git a/scripts/run_comparison.py b/scripts/run_comparison.py new file mode 100644 index 
0000000..29f23f8
--- /dev/null
+++ b/scripts/run_comparison.py
@@ -0,0 +1,186 @@
"""Run side-by-side comparison of FAISS (baseline) vs Hopfield (HAG) retrieval.

Usage:
    CUDA_VISIBLE_DEVICES=1 python scripts/run_comparison.py \
        --config configs/hotpotqa.yaml \
        --memory-bank data/processed/hotpotqa_memory_bank.pt \
        --questions data/processed/hotpotqa_questions.jsonl \
        --device cuda \
        --max-samples 500
"""

import argparse
import json
import logging
import time

import numpy as np
# NOTE(review): torch does not appear to be used directly in this script's
# visible code — verify before removing (hag.* modules may rely on it being
# importable here only indirectly).
import torch
import yaml

from hag.config import (
    EncoderConfig,
    GeneratorConfig,
    HopfieldConfig,
    MemoryBankConfig,
    PipelineConfig,
)
from hag.encoder import Encoder
from hag.generator import Generator
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.metrics import evaluate_dataset, exact_match, f1_score
from hag.pipeline import RAGPipeline
from hag.retriever_faiss import FAISSRetriever
from hag.retriever_hopfield import HopfieldRetriever

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


def main() -> None:
    # Entry point: loads one shared encoder/generator/memory bank, builds two
    # pipelines (FAISS and Hopfield) over the same data, runs both on the same
    # question set, and writes aggregate + per-question results to JSON.
    parser = argparse.ArgumentParser(description="Compare FAISS vs Hopfield retrieval")
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--output", type=str, default="data/processed/comparison_results.json")
    args = parser.parse_args()

    # NOTE(review): open() without an explicit encoding uses the platform
    # default — confirm configs/questions are ASCII-safe or add encoding="utf-8".
    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    # Missing config sections fall back to each dataclass's defaults.
    hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
    memory_config = MemoryBankConfig(**cfg.get("memory", {}))
    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
    generator_config = GeneratorConfig(**cfg.get("generator", {}))

    # Load questions (JSONL: one record per line with "question"/"answer").
    with open(args.questions) as f:
        questions_data = [json.loads(line) for line in f]
    if args.max_samples and len(questions_data) > args.max_samples:
        questions_data = questions_data[: args.max_samples]

    questions = [q["question"] for q in questions_data]
    gold_answers = [q["answer"] for q in questions_data]
    logger.info("Loaded %d questions", len(questions))

    # Load memory bank
    mb = MemoryBank(memory_config)
    mb.load(args.memory_bank, device=args.device)
    logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)

    # Shared encoder and generator — both pipelines reuse the same instances
    # so the comparison measures retrieval only.
    encoder = Encoder(encoder_config, device=args.device)
    generator = Generator(generator_config, device=args.device)

    # --- Build FAISS retriever ---
    # FAISS wants float32 numpy on CPU; mb.embeddings is transposed first,
    # presumably stored (d, N) — TODO confirm against MemoryBank.
    embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32)  # (N, d)
    faiss_retriever = FAISSRetriever(top_k=hopfield_config.top_k)
    faiss_retriever.build_index(embeddings_np, mb.passages)

    # --- Build Hopfield retriever ---
    hopfield = HopfieldRetrieval(hopfield_config)
    hopfield_retriever = HopfieldRetriever(hopfield, mb, top_k=hopfield_config.top_k)

    # --- Build pipelines ---
    faiss_pipeline_cfg = PipelineConfig(
        hopfield=hopfield_config,
        memory=memory_config,
        encoder=encoder_config,
        generator=generator_config,
        retriever_type="faiss",
        device=args.device,
    )
    faiss_pipeline = RAGPipeline(
        config=faiss_pipeline_cfg,
        encoder=encoder,
        generator=generator,
        faiss_retriever=faiss_retriever,
    )

    hopfield_pipeline_cfg = PipelineConfig(
        hopfield=hopfield_config,
        memory=memory_config,
        encoder=encoder_config,
        generator=generator_config,
        retriever_type="hopfield",
        device=args.device,
    )
    # NOTE(review): hopfield_retriever built above is not passed here — the
    # pipeline is given memory_bank=mb instead; confirm RAGPipeline constructs
    # its own Hopfield retriever internally for retriever_type="hopfield".
    hopfield_pipeline = RAGPipeline(
        config=hopfield_pipeline_cfg,
        encoder=encoder,
        generator=generator,
        memory_bank=mb,
    )

    # --- Run FAISS baseline ---
    logger.info("=" * 60)
    logger.info("Running FAISS baseline (%d questions)...", len(questions))
    t0 = time.time()
    faiss_results = faiss_pipeline.run_batch(questions)
    faiss_time = time.time() - t0
    faiss_metrics = evaluate_dataset(faiss_results, gold_answers)
    logger.info("FAISS done in %.1fs | EM=%.4f | F1=%.4f", faiss_time, faiss_metrics["em"], faiss_metrics["f1"])

    # --- Run HAG ---
    logger.info("=" * 60)
    logger.info("Running HAG (beta=%.1f, max_iter=%d, top_k=%d) (%d questions)...",
                hopfield_config.beta, hopfield_config.max_iter, hopfield_config.top_k, len(questions))
    t0 = time.time()
    hag_results = hopfield_pipeline.run_batch(questions)
    hag_time = time.time() - t0
    hag_metrics = evaluate_dataset(hag_results, gold_answers)
    logger.info("HAG done in %.1fs | EM=%.4f | F1=%.4f", hag_time, hag_metrics["em"], hag_metrics["f1"])

    # --- Summary ---
    logger.info("=" * 60)
    logger.info("COMPARISON SUMMARY")
    logger.info("%-20s %10s %10s", "", "FAISS", "HAG")
    logger.info("%-20s %10.4f %10.4f", "Exact Match", faiss_metrics["em"], hag_metrics["em"])
    logger.info("%-20s %10.4f %10.4f", "F1 Score", faiss_metrics["f1"], hag_metrics["f1"])
    logger.info("%-20s %10.1fs %10.1fs", "Time", faiss_time, hag_time)
    em_delta = hag_metrics["em"] - faiss_metrics["em"]
    f1_delta = hag_metrics["f1"] - faiss_metrics["f1"]
    logger.info("%-20s %+10.4f %+10.4f", "Delta (HAG - FAISS)", em_delta, f1_delta)

    # --- Per-question details ---
    # One record per question with both systems' answers, per-question scores,
    # and the retrieved passages, for offline error analysis.
    per_question = []
    for i, (fq, hq, gold) in enumerate(zip(faiss_results, hag_results, gold_answers)):
        per_question.append({
            "id": questions_data[i].get("id", i),
            "question": questions[i],
            "gold_answer": gold,
            "faiss_answer": fq.answer,
            "hag_answer": hq.answer,
            "faiss_em": exact_match(fq.answer, gold),
            "hag_em": exact_match(hq.answer, gold),
            "faiss_f1": f1_score(fq.answer, gold),
            "hag_f1": f1_score(hq.answer, gold),
            "faiss_passages": fq.retrieved_passages,
            "hag_passages": hq.retrieved_passages,
        })

    output = {
        "config": {
            "hopfield_beta": hopfield_config.beta,
            "hopfield_max_iter": hopfield_config.max_iter,
            "top_k": hopfield_config.top_k,
            "encoder": encoder_config.model_name,
            "generator": generator_config.model_name,
            "num_questions": len(questions),
            "num_passages": mb.size,
        },
        "faiss_metrics": {**faiss_metrics, "time_seconds": faiss_time},
        "hag_metrics": {**hag_metrics, "time_seconds": hag_time},
        "per_question": per_question,
    }

    # ensure_ascii=False keeps non-ASCII answers readable in the output JSON.
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    logger.info("Full results saved to %s", args.output)


if __name__ == "__main__":
    main()
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
index 713b3c2..144fc2f 100644
--- a/scripts/run_eval.py
+++ b/scripts/run_eval.py
@@ -36,6 +36,7 @@ def main() -> None:
     parser.add_argument("--split", type=str, default="validation")
     parser.add_argument("--max-samples", type=int, default=500)
     parser.add_argument("--output", type=str, default="results.json")
+    parser.add_argument("--device", type=str, default="cpu")
     args = parser.parse_args()
 
     with open(args.config) as f:
@@ -47,15 +48,16 @@
         encoder=EncoderConfig(**cfg.get("encoder", {})),
         generator=GeneratorConfig(**cfg.get("generator", {})),
         retriever_type=cfg.get("retriever_type", "hopfield"),
+        device=args.device,
     )
 
     # Load memory bank
     mb = MemoryBank(pipeline_config.memory)
-    mb.load(args.memory_bank)
+    mb.load(args.memory_bank, device=args.device)
 
     # Build pipeline
-    encoder = Encoder(pipeline_config.encoder)
-    generator = Generator(pipeline_config.generator)
+    encoder = Encoder(pipeline_config.encoder, device=args.device)
+    generator = Generator(pipeline_config.generator, device=args.device)
     pipeline = RAGPipeline(
         config=pipeline_config,
         encoder=encoder,
diff --git a/scripts/run_grid_search.py b/scripts/run_grid_search.py
new file mode 100644
index 0000000..ddd5a8d
--- /dev/null
+++ b/scripts/run_grid_search.py
@@ -0,0 +1,552 @@
"""Grid search over HAG hyperparameters (beta, max_iter) with dedup-based LLM caching.

Key insight: many (beta, max_iter) combos retrieve the same top-k passages for a given
question. By deduplicating on (question_idx, frozenset(top_k_indices)), we call the LLM
only for unique passage sets, saving ~80-89% of generation calls.

Usage:
    python scripts/run_grid_search.py \
        --config configs/hotpotqa.yaml \
        --memory-bank data/processed/hotpotqa_memory_bank.pt \
        --questions data/processed/hotpotqa_questions.jsonl \
        --device cuda \
        --max-samples 100
"""

import argparse
import json
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from hag.config import (
    EncoderConfig,
    GeneratorConfig,
    HopfieldConfig,
    MemoryBankConfig,
    # NOTE(review): PipelineConfig appears unused in this script's visible
    # code — verify before removing.
    PipelineConfig,
)
from hag.encoder import Encoder
# NOTE(review): compute_energy_curve appears unused here (only
# compute_attention_entropy and compute_energy_gap are called) — verify.
from hag.energy import compute_attention_entropy, compute_energy_curve, compute_energy_gap
from hag.generator import Generator
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.metrics import exact_match, f1_score
from hag.retriever_faiss import FAISSRetriever

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


@dataclass
class GridPoint:
    """Results for a single (beta, max_iter) configuration."""

    # Swept hyperparameters.
    beta: float
    max_iter: int
    # Answer-quality metrics averaged over questions.
    em: float
    f1: float
    # Retrieval diagnostics averaged over questions.
    avg_entropy: float
    avg_energy_gap: float
    avg_faiss_overlap: float
    avg_steps: float


def load_questions(path: str, max_samples: Optional[int] = None) -> Tuple[List[str], List[str]]:
    """Load questions and gold answers from JSONL file.

    Args:
        path: path to JSONL file with 'question' and 'answer' fields
        max_samples: if set, limit to first N samples

    Returns:
        Tuple of (questions, gold_answers).
    """
    questions = []
    gold_answers = []
    # NOTE(review): no explicit encoding — assumes UTF-8-compatible platform
    # default; confirm or add encoding="utf-8".
    with open(path) as f:
        for line in f:
            record = json.loads(line)
            questions.append(record["question"])
            gold_answers.append(record["answer"])
            # Stop as soon as the cap is reached (checked after appending).
            if max_samples and len(questions) >= max_samples:
                break
    return questions, gold_answers


def encode_questions_batched(
    encoder: Encoder, questions: List[str], batch_size: int = 32
) -> torch.Tensor:
    """Encode all questions into embeddings, batched for efficiency.

    Args:
        encoder: the encoder instance
        questions: list of question strings
        batch_size: encoding batch size

    Returns:
        (N, d) tensor of query embeddings.
    """
    all_embeddings = []
    for i in range(0, len(questions), batch_size):
        batch = questions[i : i + batch_size]
        embs = encoder.encode(batch)  # (batch_size, d)
        all_embeddings.append(embs)
    # Concatenate per-batch outputs back into one (N, d) tensor.
    return torch.cat(all_embeddings, dim=0)  # (N, d)


def run_faiss_baseline(
    query_embeddings: torch.Tensor,
    memory_bank: MemoryBank,
    generator: Generator,
    questions: List[str],
    gold_answers: List[str],
    top_k: int,
) -> Tuple[Dict[str, float], Dict[int, Tuple[str, Tuple[int, ...]]]]:
    """Run FAISS baseline and cache results.

    Args:
        query_embeddings: (N, d) tensor
        memory_bank: the memory bank
        generator: LLM generator
        questions: list of question strings
        gold_answers: list of gold answer strings
        top_k: number of passages to retrieve

    Returns:
        Tuple of (metrics_dict, faiss_cache).
        faiss_cache maps question_idx -> (answer, top_k_indices_tuple).
    """
    logger.info("Building FAISS index...")
    # FAISS requires float32 numpy on CPU; embeddings are transposed first,
    # presumably stored (d, N_passages) — TODO confirm against MemoryBank.
    embeddings_np = memory_bank.embeddings.T.cpu().numpy().astype(np.float32)  # (N_passages, d)
    faiss_ret = FAISSRetriever(top_k=top_k)
    faiss_ret.build_index(embeddings_np, memory_bank.passages)

    faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]] = {}
    em_scores = []
    f1_scores = []

    logger.info("Running FAISS baseline on %d questions...", len(questions))
    for i, question in enumerate(questions):
        query_np = query_embeddings[i].cpu().numpy().astype(np.float32)  # (d,)
        result = faiss_ret.retrieve(query_np)
        answer = generator.generate(question, result.passages)
        # Sorted tuple gives a canonical key so identical passage SETS hit the
        # same cache entry regardless of retrieval order.
        indices_tuple = tuple(sorted(result.indices.tolist()))

        faiss_cache[i] = (answer, indices_tuple)
        em_scores.append(exact_match(answer, gold_answers[i]))
        f1_scores.append(f1_score(answer, gold_answers[i]))

        if (i + 1) % 20 == 0:
            # Running averages logged every 20 questions.
            logger.info(
                "  FAISS baseline: %d/%d (EM=%.3f, F1=%.3f)",
                i + 1,
                len(questions),
                sum(em_scores) / len(em_scores),
                sum(f1_scores) / len(f1_scores),
            )

    metrics = {
        "em": sum(em_scores) / len(em_scores),
        "f1": sum(f1_scores) / len(f1_scores),
    }
    logger.info("FAISS baseline: EM=%.4f, F1=%.4f", metrics["em"], metrics["f1"])
    return metrics, faiss_cache


def run_hopfield_grid(
    query_embeddings: torch.Tensor,
    memory_bank: MemoryBank,
    generator: Generator,
    questions: List[str],
    gold_answers: List[str],
    faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]],
    betas: List[float],
    max_iters: List[int],
    top_k: int,
    device: str,
) -> Tuple[List[GridPoint], Dict]:
    """Run grid search over (beta, max_iter) with dedup-based LLM caching.

    Phase 2: Retrieve all configs (fast, batched).
    Phase 3: Deduplicate and generate (LLM calls only for unique passage sets).
    Phase 4: Evaluate and collect results.

    Args:
        query_embeddings: (N, d) tensor on device
        memory_bank: memory bank (embeddings on device)
        generator: LLM generator
        questions: list of question strings
        gold_answers: list of gold answer strings
        faiss_cache: maps question_idx -> (answer, sorted_indices_tuple)
        betas: list of beta values to sweep
        max_iters: list of max_iter values to sweep
        top_k: fixed top_k for retrieval
        device: computation device

    Returns:
        Tuple of (grid_results, meta_dict).
    """
    n_questions = len(questions)
    memory = memory_bank.embeddings  # (d, N_passages) on device

    # =========================================================================
    # Phase 2: Retrieve all configurations (batched, milliseconds each)
    # =========================================================================
    logger.info("Phase 2: Running %d retrieval configs...", len(betas) * len(max_iters))

    # Structure: config_key -> per-question retrieval data
    # retrieval_data[config_key][q_idx] = {indices_tuple, entropy, energy_gap, steps, faiss_overlap}
    # Local dataclass: this record type is only used inside this function.
    @dataclass
    class RetrievalInfo:
        indices_tuple: Tuple[int, ...]
        entropy: float
        energy_gap: float
        steps: int
        faiss_overlap: float

    retrieval_data: Dict[Tuple[float, int], List[RetrievalInfo]] = {}

    t_retrieve_start = time.time()
    for beta in betas:
        for max_iter in max_iters:
            config = HopfieldConfig(beta=beta, max_iter=max_iter, top_k=top_k)
            hopfield = HopfieldRetrieval(config)

            # Batched retrieval: all questions at once
            result = hopfield.retrieve(
                query_embeddings, memory, return_energy=True
            )  # attention_weights: (N_questions, N_passages)

            alpha = result.attention_weights  # (N_questions, N_passages)
            # Guard against top_k exceeding the number of passages.
            k = min(top_k, alpha.shape[-1])
            # NOTE(review): `scores` is never read afterwards — only the
            # indices are used.
            scores, indices = torch.topk(alpha, k, dim=-1)  # (N_questions, k)

            # Compute energy curve per-question (energy_curve contains batch tensors)
            energy_curves_raw = result.energy_curve  # list of (N_questions,) tensors

            infos = []
            for q_idx in range(n_questions):
                # Sorted tuple = canonical cache key for the passage SET.
                q_indices = sorted(indices[q_idx].tolist())
                q_indices_tuple = tuple(q_indices)

                # Per-question entropy
                q_entropy = compute_attention_entropy(alpha[q_idx])

                # Per-question energy gap
                if energy_curves_raw is not None:
                    q_energies = [e[q_idx].item() for e in energy_curves_raw]
                    q_energy_gap = compute_energy_gap(q_energies)
                else:
                    q_energy_gap = 0.0

                # FAISS overlap: fraction of top-k indices shared with FAISS
                faiss_indices_set = set(faiss_cache[q_idx][1])
                hopfield_indices_set = set(q_indices)
                overlap = len(faiss_indices_set & hopfield_indices_set) / k

                infos.append(RetrievalInfo(
                    indices_tuple=q_indices_tuple,
                    entropy=q_entropy,
                    energy_gap=q_energy_gap,
                    # NOTE(review): num_steps is batch-level, so all questions
                    # in a config share the same step count — confirm intended.
                    steps=result.num_steps,
                    faiss_overlap=overlap,
                ))

            retrieval_data[(beta, max_iter)] = infos

    t_retrieve_end = time.time()
    logger.info("Phase 2 complete: %.2fs for all retrieval configs", t_retrieve_end - t_retrieve_start)

    # =========================================================================
    # Phase 3: Deduplicate and generate
    # =========================================================================
    logger.info("Phase 3: Deduplicating and generating...")

    # Build set of unique (question_idx, passage_set) combos needing LLM calls
    # Cache key: (question_idx, frozenset(top_k_indices))
    llm_cache: Dict[Tuple[int, frozenset], str] = {}

    # Seed cache with FAISS answers (same passage sets don't need re-generation)
    for q_idx, (answer, indices_tuple) in faiss_cache.items():
        cache_key = (q_idx, frozenset(indices_tuple))
        llm_cache[cache_key] = answer

    # Collect all unique keys we need
    needed_keys: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
    for (beta, max_iter), infos in retrieval_data.items():
        for q_idx, info in enumerate(infos):
            cache_key = (q_idx, frozenset(info.indices_tuple))
            if cache_key not in llm_cache and cache_key not in needed_keys:
                needed_keys[cache_key] = (q_idx, info.indices_tuple)

    total_grid_calls = n_questions * len(betas) * len(max_iters)
    # NOTE(review): `already_cached` is computed but never used afterwards.
    already_cached = total_grid_calls - len(needed_keys)  # rough; some may still be unique
    logger.info(
        "Unique LLM calls needed: %d (out of %d grid points, %.1f%% saving)",
        len(needed_keys),
        total_grid_calls,
        (1 - len(needed_keys) / total_grid_calls) * 100 if total_grid_calls > 0 else 0,
    )

    # Generate answers for unique passage sets
    t_gen_start = time.time()
    for call_idx, (cache_key, (q_idx, indices_tuple)) in enumerate(needed_keys.items()):
        # Look up passages by sorted indices
        indices_tensor = torch.tensor(list(indices_tuple), dtype=torch.long)
        passages = memory_bank.get_passages_by_indices(indices_tensor)
        answer = generator.generate(questions[q_idx], passages)
        llm_cache[cache_key] = answer

        if (call_idx + 1) % 20 == 0:
            # Throughput + ETA logging every 20 generations.
            elapsed = time.time() - t_gen_start
            rate = (call_idx + 1) / elapsed
            remaining = (len(needed_keys) - call_idx - 1) / rate
            logger.info(
                "  Generated %d/%d (%.1f calls/s, ~%.0fs remaining)",
                call_idx + 1,
                len(needed_keys),
                rate,
                remaining,
            )

    t_gen_end = time.time()
    logger.info("Phase 3 complete: %d LLM calls in %.1fs", len(needed_keys), t_gen_end - t_gen_start)

    # =========================================================================
    # Phase 4: Evaluate all grid points
    # =========================================================================
    logger.info("Phase 4: Evaluating all grid points...")

    grid_results: List[GridPoint] = []
    for beta in betas:
        for max_iter in max_iters:
            infos = retrieval_data[(beta, max_iter)]
            em_scores = []
            f1_scores = []
            entropies = []
            energy_gaps = []
            faiss_overlaps = []
            steps_list = []

            for q_idx, info in enumerate(infos):
                # Every answer is served from the cache — no LLM calls here.
                cache_key = (q_idx, frozenset(info.indices_tuple))
                answer = llm_cache[cache_key]

                em_scores.append(exact_match(answer, gold_answers[q_idx]))
                f1_scores.append(f1_score(answer, gold_answers[q_idx]))
                entropies.append(info.entropy)
                energy_gaps.append(info.energy_gap)
                faiss_overlaps.append(info.faiss_overlap)
                steps_list.append(info.steps)

            gp = GridPoint(
                beta=beta,
                max_iter=max_iter,
                em=sum(em_scores) / len(em_scores),
                f1=sum(f1_scores) / len(f1_scores),
                avg_entropy=sum(entropies) / len(entropies),
                avg_energy_gap=sum(energy_gaps) / len(energy_gaps),
                avg_faiss_overlap=sum(faiss_overlaps) / len(faiss_overlaps),
                avg_steps=sum(steps_list) / len(steps_list),
            )
            grid_results.append(gp)
            logger.info(
                "  beta=%.2f max_iter=%2d => EM=%.3f F1=%.3f entropy=%.3f energy_gap=%.3f faiss_overlap=%.3f",
                beta,
                max_iter,
                gp.em,
                gp.f1,
                gp.avg_entropy,
                gp.avg_energy_gap,
                gp.avg_faiss_overlap,
            )

    total_llm_calls = len(faiss_cache) + len(needed_keys)
    meta = {
        "grid_size": len(betas) * len(max_iters),
        "n_questions": n_questions,
        "total_grid_evaluations": total_grid_calls,
        "unique_llm_calls": len(needed_keys),
        "faiss_llm_calls": len(faiss_cache),
        "total_llm_calls": total_llm_calls,
        # Savings relative to naively generating for every grid point plus
        # the FAISS baseline.
        "savings_pct": round(
            (1 - total_llm_calls / (total_grid_calls + len(faiss_cache))) * 100, 1
        )
        if (total_grid_calls + len(faiss_cache)) > 0
        else 0,
        "retrieval_time_s": round(t_retrieve_end - t_retrieve_start, 2),
        "generation_time_s": round(t_gen_end - t_gen_start, 2),
    }

    return grid_results, meta


def main() -> None:
    # CLI entry point: parse sweep ranges, load data/models once (Phase 1),
    # then delegate to run_faiss_baseline and run_hopfield_grid.
    parser = argparse.ArgumentParser(
        description="Grid search over HAG hyperparameters (beta, max_iter)"
    )
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output JSON path (default: data/processed/grid_search_results.json)",
    )
    parser.add_argument(
        "--betas",
        type=float,
        nargs="+",
        default=[0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 8.0],
    )
    parser.add_argument(
        "--max-iters",
        type=int,
        nargs="+",
        default=[1, 2, 3, 5, 8, 15],
    )
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    # Function-local import: yaml is only needed by the CLI path.
    import yaml

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    output_path = args.output or "data/processed/grid_search_results.json"

    # =========================================================================
    # Phase 1: Load everything once
    # =========================================================================
    logger.info("=" * 60)
    logger.info("HAG Grid Search")
    logger.info("  betas:       %s", args.betas)
    logger.info("  max_iters:   %s", args.max_iters)
    logger.info("  top_k:       %d", args.top_k)
    logger.info("  grid points: %d", len(args.betas) * len(args.max_iters))
    logger.info("  max_samples: %d", args.max_samples)
    logger.info("  device:      %s", args.device)
    logger.info("=" * 60)

    t_start = time.time()

    # Load questions
    logger.info("Loading questions from %s...", args.questions)
    questions, gold_answers = load_questions(args.questions, args.max_samples)
    logger.info("Loaded %d questions", len(questions))

    # Load memory bank
    logger.info("Loading memory bank from %s...", args.memory_bank)
    mb_config = MemoryBankConfig(**cfg.get("memory", {}))
    memory_bank = MemoryBank(mb_config)
    memory_bank.load(args.memory_bank, device=args.device)
    logger.info("Memory bank: %d passages, dim=%d", memory_bank.size, memory_bank.dim)

    # Load encoder
    logger.info("Loading encoder...")
    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
    encoder = Encoder(encoder_config, device=args.device)

    # Load generator
    logger.info("Loading generator...")
    generator_config = GeneratorConfig(**cfg.get("generator", {}))
    generator = Generator(generator_config, device=args.device)

    # Encode all questions once — both the baseline and every grid config
    # reuse these embeddings.
    logger.info("Encoding %d questions...", len(questions))
    t_enc_start = time.time()
    query_embeddings = encode_questions_batched(
        encoder, questions, batch_size=encoder_config.batch_size
    )  # (N, d) on device
    t_enc_end = time.time()
    logger.info("Encoded in %.2fs, shape=%s", t_enc_end - t_enc_start, query_embeddings.shape)

    # =========================================================================
    # Run FAISS baseline
    # =========================================================================
    faiss_metrics, faiss_cache = run_faiss_baseline(
        query_embeddings, memory_bank, generator, questions, gold_answers, args.top_k
    )

    # =========================================================================
    # Run Hopfield grid search
    # =========================================================================
    grid_results, meta = run_hopfield_grid(
        query_embeddings,
        memory_bank,
        generator,
        questions,
        gold_answers,
        faiss_cache,
        betas=args.betas,
        max_iters=args.max_iters,
        top_k=args.top_k,
        device=args.device,
    )

    # =========================================================================
    # Find best config and save results
    # =========================================================================
    # "Best" is ranked by F1 only (EM is reported but not used for selection).
    best = max(grid_results, key=lambda gp: gp.f1)

    t_total = time.time() - t_start
    meta["total_time_s"] = round(t_total, 1)

    output = {
        "meta": meta,
        "faiss_baseline": faiss_metrics,
        "grid_results": [
            {
                "beta": gp.beta,
                "max_iter": gp.max_iter,
                "em": round(gp.em, 4),
                "f1": round(gp.f1, 4),
                "avg_entropy": round(gp.avg_entropy, 4),
                "avg_energy_gap": round(gp.avg_energy_gap, 4),
                "avg_faiss_overlap": round(gp.avg_faiss_overlap, 4),
                "avg_steps": round(gp.avg_steps, 2),
            }
            for gp in grid_results
        ],
        "best_config": {
            "beta": best.beta,
            "max_iter": best.max_iter,
            "em": round(best.em, 4),
            "f1": round(best.f1, 4),
            "avg_entropy": round(best.avg_entropy, 4),
            "avg_energy_gap": round(best.avg_energy_gap, 4),
            "avg_faiss_overlap": round(best.avg_faiss_overlap, 4),
        },
    }

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(output, f, indent=2)

    logger.info("=" * 60)
    logger.info("RESULTS SUMMARY")
    logger.info("  FAISS baseline:  EM=%.4f, F1=%.4f", faiss_metrics["em"], faiss_metrics["f1"])
    logger.info(
        "  Best HAG config: beta=%.2f, max_iter=%d => EM=%.4f, F1=%.4f",
        best.beta,
        best.max_iter,
        best.em,
        best.f1,
    )
    logger.info("  Total LLM calls: %d (saved %.1f%%)", meta["total_llm_calls"], meta["savings_pct"])
    logger.info("  Total time: %.1fs", t_total)
    logger.info("  Results saved to: %s", output_path)
    logger.info("=" * 60)


if __name__ == "__main__":
    main()
diff --git a/scripts/run_hag.py b/scripts/run_hag.py
index 4cacd1a..b6c9004 100644
--- a/scripts/run_hag.py
+++ b/scripts/run_hag.py
@@ -33,6 +33,7 @@ def main() -> None:
     parser.add_argument("--beta", type=float, default=None)
     parser.add_argument("--max-iter", type=int, default=None)
     parser.add_argument("--top-k", type=int, default=None)
+    parser.add_argument("--device", type=str, default="cpu")
     args = 
parser.parse_args() with open(args.config) as f: @@ -52,13 +53,14 @@ def main() -> None: encoder=EncoderConfig(**cfg.get("encoder", {})), generator=GeneratorConfig(**cfg.get("generator", {})), retriever_type="hopfield", + device=args.device, ) mb = MemoryBank(pipeline_config.memory) - mb.load(args.memory_bank) + mb.load(args.memory_bank, device=args.device) - encoder = Encoder(pipeline_config.encoder) - generator = Generator(pipeline_config.generator) + encoder = Encoder(pipeline_config.encoder, device=args.device) + generator = Generator(pipeline_config.generator, device=args.device) pipeline = RAGPipeline( config=pipeline_config, diff --git a/scripts/visualize_energy.py b/scripts/visualize_energy.py new file mode 100644 index 0000000..f39953c --- /dev/null +++ b/scripts/visualize_energy.py @@ -0,0 +1,443 @@ +"""Visualize Hopfield energy landscape: centered vs uncentered. + +Produces 4 figures, each with centered/uncentered side-by-side: + 1. 2D contour + Hopfield trajectory + 2. 1D energy profile along key directions + 3. UMAP of memories + query trajectories + 4. 
def compute_energy(q: torch.Tensor, M: torch.Tensor, beta: float) -> torch.Tensor:
    """Hopfield energy E(q) = -1/β · logsumexp(β · qᵀM) + ½‖q‖².

    Args:
        q: query point(s), shape (..., d).
        M: memory matrix, shape (d, N).
        beta: inverse temperature.

    Returns:
        Energy value(s), shape (...).
    """
    similarity = beta * torch.matmul(q, M)                 # (..., N)
    attraction = torch.logsumexp(similarity, dim=-1) / beta  # (...)
    regularizer = 0.5 * q.pow(2).sum(dim=-1)               # (...)
    return regularizer - attraction
Returns (T+1, d).""" + q = q0.clone().unsqueeze(0) if q0.dim() == 1 else q0.clone() # (1, d) + traj = [q.squeeze(0).clone()] + for _ in range(max_iter): + logits = beta * (q @ M) + alpha = torch.softmax(logits, dim=-1) + q_new = alpha @ M.T + traj.append(q_new.squeeze(0).clone()) + if (q_new - q).norm() < 1e-8: + break + q = q_new + return torch.stack(traj, dim=0) # (T+1, d) + + +def orthonormalize(v1: torch.Tensor, v2: torch.Tensor): + """Return two orthonormal vectors spanning the plane of v1, v2.""" + e1 = v1 / v1.norm() + v2_orth = v2 - (v2 @ e1) * e1 + if v2_orth.norm() < 1e-8: + # v1 and v2 are parallel, pick a random orthogonal direction + rand = torch.randn_like(v1) + v2_orth = rand - (rand @ e1) * e1 + e2 = v2_orth / v2_orth.norm() + return e1, e2 + + +def project_to_plane(points: torch.Tensor, e1: torch.Tensor, e2: torch.Tensor): + """Project (K, d) points onto 2D plane defined by e1, e2. Returns (K, 2).""" + return torch.stack([points @ e1, points @ e2], dim=-1) + + +# ── Figure 1: 2D Contour + Trajectory ─────────────────────────────── + +def fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): + """2D energy contour on the query-centroid plane, with Hopfield trajectories.""" + + centroid = M_raw.mean(dim=1) # (d,) + q0_cent = q0_raw - mu + + fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12), + squeeze=False) + + for col, beta in enumerate(betas_plot): + for row, (label, M, q0, ref_point, ref_label) in enumerate([ + ("Uncentered", M_raw, q0_raw, centroid, "centroid"), + ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"), + ]): + ax = axes[row, col] + + # Define 2D plane: query direction + centroid/origin direction + e1, e2 = orthonormalize(q0.to(device), ref_point.to(device) if ref_point.norm() > 1e-6 else M.to(device)[:, 0]) + + # Grid + grid_range = 1.5 + n_grid = 150 + xs = torch.linspace(-grid_range, grid_range, n_grid, device=device) + ys = torch.linspace(-grid_range, grid_range, n_grid, 
def fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """2D energy contour on the query-centroid plane, with Hopfield trajectories.

    One column per beta; top row uses the raw memories, bottom row the
    mean-centered memories. Saves fig1_contour.png into *outdir*.
    """
    centroid = M_raw.mean(dim=1)  # (d,)
    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot),
                             figsize=(6 * len(betas_plot), 12), squeeze=False)

    panel_specs = [
        ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
        ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
    ]

    for col, beta in enumerate(betas_plot):
        for row, (label, M, q0, ref_point, ref_label) in enumerate(panel_specs):
            ax = axes[row, col]

            # Plane spanned by the query direction and the centroid/origin
            # direction; fall back to the first memory when the reference
            # point is numerically zero.
            second_dir = (ref_point.to(device) if ref_point.norm() > 1e-6
                          else M.to(device)[:, 0])
            e1, e2 = orthonormalize(q0.to(device), second_dir)

            # Evaluate energy on a regular grid in the plane.
            grid_range = 1.5
            n_grid = 150
            xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            xx, yy = torch.meshgrid(xs, ys, indexing='ij')
            plane_points = (xx.reshape(-1, 1) * e1.unsqueeze(0)
                            + yy.reshape(-1, 1) * e2.unsqueeze(0))  # (n^2, d)
            energies = compute_energy(plane_points, M.to(device), beta)
            energies = energies.reshape(n_grid, n_grid).cpu().numpy()

            # Hopfield trajectory, memories, and reference point in 2D.
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_2d = project_to_plane(traj, e1, e2).cpu().numpy()
            mem_2d = project_to_plane(M.T.to(device), e1, e2).cpu().numpy()
            ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), e1, e2).cpu().numpy()

            xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()
            # Clip extreme energies so the colormap stays informative.
            clipped = np.clip(energies, np.percentile(energies, 1),
                              np.percentile(energies, 95))
            contour_fill = ax.contourf(xx_np, yy_np, clipped, levels=40, cmap='viridis')
            ax.contour(xx_np, yy_np, clipped, levels=15, colors='white',
                       linewidths=0.3, alpha=0.5)

            # Memories as faint dots.
            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            # Reference point (centroid or origin).
            if ref_point.norm() > 1e-6:
                ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                           zorder=5, label=ref_label)
            else:
                ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin')

            # Trajectory with start/end markers.
            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                    linewidth=2, zorder=4, label='trajectory')
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                       zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                       zorder=5, label=f'q_T (T={len(traj_2d)-1})')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("e₁ (query dir)")
            ax.set_ylabel("e₂")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(contour_fill, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 1: 2D Energy Contour + Hopfield Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig1_contour.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig1_contour.png'}")
"fig1_contour.png", dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {outdir / 'fig1_contour.png'}") + + +# ── Figure 2: 1D Energy Profile ───────────────────────────────────── + +def fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): + """1D energy along key directions.""" + + centroid = M_raw.mean(dim=1) + q0_cent = q0_raw - mu + + # Find top-1 memory for each + scores_raw = q0_raw @ M_raw + top1_raw_idx = scores_raw.argmax().item() + top1_raw = M_raw[:, top1_raw_idx] + + scores_cent = q0_cent @ M_cent + top1_cent_idx = scores_cent.argmax().item() + top1_cent = M_cent[:, top1_cent_idx] + + fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 10), + squeeze=False) + + ts = torch.linspace(-0.5, 2.0, 300, device=device) + + for col, beta in enumerate(betas_plot): + # Uncentered + ax = axes[0, col] + for target, name, color in [ + (centroid, "→ centroid", "red"), + (top1_raw, f"→ memory[{top1_raw_idx}]", "blue"), + (torch.zeros_like(q0_raw), "→ origin", "gray"), + ]: + direction = target - q0_raw.to(device) + if direction.norm() < 1e-8: + continue + points = q0_raw.unsqueeze(0).to(device) + ts.unsqueeze(1) * direction.unsqueeze(0) + E = compute_energy(points, M_raw.to(device), beta).cpu().numpy() + ax.plot(ts.cpu().numpy(), E, label=name, color=color, linewidth=2) + + # Mark t=0 (query) and t=1 (target) + E_q0 = compute_energy(q0_raw.unsqueeze(0).to(device), M_raw.to(device), beta).item() + ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀') + ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target') + ax.set_title(f"Uncentered, β={beta}", fontsize=13, fontweight='bold') + ax.set_xlabel("t (q₀ + t·(target - q₀))") + ax.set_ylabel("E(q)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Centered + ax = axes[1, col] + for target, name, color in [ + (torch.zeros_like(q0_cent), "→ origin", "red"), + (top1_cent, f"→ memory[{top1_cent_idx}]", "blue"), + ]: + direction = 
def fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """UMAP embedding of memories plus Hopfield trajectories (color = energy).

    Skips gracefully when umap-learn is not installed. Saves fig3_umap.png
    into *outdir*.

    Fixes over the original: removed the dead local `centroid` (never used in
    this figure) and the placeholder-free f-string in the q_T label.
    """
    try:
        import umap
    except ImportError:
        print("umap-learn not installed, skipping fig3")
        return

    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for col, beta in enumerate(betas_plot):
        for row, (label, M, q0) in enumerate([
            ("Uncentered", M_raw, q0_raw),
            ("Centered", M_cent, q0_cent),
        ]):
            ax = axes[row, col]

            # Run the Hopfield dynamics on-device, then move to CPU for UMAP.
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_cpu = traj.cpu()
            mem_cpu = M.T.cpu()  # (N, d)

            # Fit UMAP jointly on memories + trajectory so both live in the
            # same 2D embedding; fixed seed for reproducibility.
            all_points = torch.cat([mem_cpu, traj_cpu], dim=0).numpy()
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
            embedding = reducer.fit_transform(all_points)

            n_mem = mem_cpu.shape[0]
            mem_emb = embedding[:n_mem]
            traj_emb = embedding[n_mem:]

            # Color memories by their energy under this beta.
            E_mem = compute_energy(mem_cpu.to(device), M.to(device), beta).cpu().numpy()
            sc = ax.scatter(mem_emb[:, 0], mem_emb[:, 1], c=E_mem, cmap='viridis',
                            s=10, alpha=0.5, zorder=1)

            # Trajectory with start/end markers.
            ax.plot(traj_emb[:, 0], traj_emb[:, 1], 'o-', color='#ff6600',
                    markersize=5, linewidth=2, zorder=3, label='trajectory')
            ax.scatter(traj_emb[0, 0], traj_emb[0, 1], c='lime', s=100,
                       marker='s', zorder=4, label='q₀')
            ax.scatter(traj_emb[-1, 0], traj_emb[-1, 1], c='magenta', s=100,
                       marker='D', zorder=4, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.legend(fontsize=8, loc='upper right')
            plt.colorbar(sc, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 3: UMAP of Memories + Hopfield Trajectory (color = energy)",
                 fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig3_umap.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig3_umap.png'}")
def fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """Energy heatmap on the top-2 PCA components of each memory bank.

    Top row: uncentered memories; bottom row: centered. The SVD (and hence
    the plotting plane) is computed once per row, then reused across betas.
    Saves fig4_pca.png into *outdir*.

    Fix over the original: discard the unused `S`, `Vh` outputs of
    `torch.linalg.svd` instead of binding them to dead locals.
    """
    centroid = M_raw.mean(dim=1)
    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for row, (label, M, q0, ref_point, ref_label) in enumerate([
        ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
        ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
    ]):
        # Top-2 principal directions of this memory bank via thin SVD of (d, N).
        M_cpu = M.cpu()
        U, _, _ = torch.linalg.svd(M_cpu, full_matrices=False)
        pc1 = U[:, 0].to(device)  # (d,)
        pc2 = U[:, 1].to(device)  # (d,)

        for col, beta in enumerate(betas_plot):
            ax = axes[row, col]

            # Energy on a regular grid in the PCA plane.
            grid_range = 1.5
            n_grid = 150
            xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            xx, yy = torch.meshgrid(xs, ys, indexing='ij')
            grid_points = (xx.reshape(-1, 1) * pc1.unsqueeze(0)
                           + yy.reshape(-1, 1) * pc2.unsqueeze(0))

            E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy()

            # Trajectory, memories, and reference point projected to the plane.
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_2d = project_to_plane(traj, pc1, pc2).cpu().numpy()
            mem_2d = project_to_plane(M.T.to(device), pc1, pc2).cpu().numpy()
            ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), pc1, pc2).cpu().numpy()

            xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()
            # Clip extreme energies so the heatmap colormap stays informative.
            E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95))
            cs = ax.pcolormesh(xx_np, yy_np, E_clip, cmap='viridis', shading='auto')
            ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white',
                       linewidths=0.3, alpha=0.5)

            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            if ref_point.norm() > 1e-6:
                ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                           zorder=5, label=ref_label)
            else:
                ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin')

            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                    linewidth=2, zorder=4)
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                       zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                       zorder=5, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 4: PCA Top-2 Energy Heatmap + Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig4_pca.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig4_pca.png'}")
def main():
    """CLI entry point: load memory bank + one query, render all four figures."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--query-idx", type=int, default=0)
    parser.add_argument("--outdir", type=str, default="figures")
    args = parser.parse_args()

    device = args.device
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Memory bank (raw, i.e. without centering applied at load time).
    bank = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False))
    bank.load(args.memory_bank, device=device)
    M_raw = bank.embeddings  # (d, N)
    d, N = M_raw.shape
    print(f"Memory bank: d={d}, N={N}")

    # Centered variant: subtract the memory mean from every column.
    mu = M_raw.mean(dim=1)  # (d,)
    M_cent = M_raw - mu.unsqueeze(1)
    print(f"‖μ‖ = {mu.norm():.4f}")

    # Encode a single question as the probe query.
    with open(args.questions) as f:
        questions = [json.loads(line) for line in f]
    q_text = questions[args.query_idx]["question"]
    print(f"Query [{args.query_idx}]: '{q_text}'")

    encoder = Encoder(EncoderConfig(model_name="facebook/contriever-msmarco"), device=device)
    q0_raw = encoder.encode([q_text]).squeeze(0)  # (d,)
    print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")

    # β values: below and above β_critical ≈ 37.6
    betas_plot = [5.0, 20.0, 50.0, 100.0]

    for title, render in [
        ("Figure 1: 2D Contour", fig1_contour),
        ("Figure 2: 1D Profile", fig2_profile),
        ("Figure 3: UMAP", fig3_umap),
        ("Figure 4: PCA Heatmap", fig4_pca),
    ]:
        print(f"\n--- Generating {title} ---")
        render(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print(f"\nAll figures saved to {outdir}/")
main() diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py index e4ba902..0087563 100644 --- a/scripts/visualize_trajectory.py +++ b/scripts/visualize_trajectory.py @@ -26,6 +26,7 @@ def main() -> None: parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--question", type=str, required=True) parser.add_argument("--output", type=str, default="trajectory.png") + parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() with open(args.config) as f: @@ -36,9 +37,9 @@ def main() -> None: encoder_config = EncoderConfig(**cfg.get("encoder", {})) mb = MemoryBank(memory_config) - mb.load(args.memory_bank) + mb.load(args.memory_bank, device=args.device) - encoder = Encoder(encoder_config) + encoder = Encoder(encoder_config, device=args.device) hopfield = HopfieldRetrieval(hopfield_config) query_emb = encoder.encode(args.question) # (1, d) @@ -46,9 +47,9 @@ def main() -> None: query_emb, mb.embeddings, return_trajectory=True ) - # Gather all points for UMAP: memories + trajectory - memories_np = mb.embeddings.T.numpy() # (N, d) - trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d) + # Gather all points for UMAP: memories + trajectory (must be on CPU for numpy) + memories_np = mb.embeddings.T.cpu().numpy() # (N, d) + trajectory_np = np.stack([q.squeeze().cpu().numpy() for q in result.trajectory]) # (T+1, d) all_points = np.concatenate([memories_np, trajectory_np], axis=0) # UMAP projection |
