summaryrefslogtreecommitdiff
path: root/scripts/analyze_energy.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /scripts/analyze_energy.py
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/analyze_energy.py')
-rw-r--r--scripts/analyze_energy.py78
1 files changed, 78 insertions, 0 deletions
diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py
new file mode 100644
index 0000000..fd044a4
--- /dev/null
+++ b/scripts/analyze_energy.py
@@ -0,0 +1,78 @@
+"""Analyze energy curves and convergence properties of Hopfield retrieval.
+
+Usage:
+ python scripts/analyze_energy.py --config configs/default.yaml --memory-bank data/memory_bank.pt --questions data/questions.jsonl --output energy_analysis.json
+"""
+
+import argparse
+import json
+import logging
+
+import torch
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.energy import (
+ compute_attention_entropy,
+ compute_energy_curve,
+ compute_energy_gap,
+ verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Analyze Hopfield energy curves")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--questions", type=str, required=True)
+ parser.add_argument("--output", type=str, default="energy_analysis.json")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+ # Load memory bank
+ mb = MemoryBank(memory_config)
+ mb.load(args.memory_bank)
+
+ # Load questions
+ with open(args.questions) as f:
+ questions = [json.loads(line)["question"] for line in f]
+
+ encoder = Encoder(encoder_config)
+ hopfield = HopfieldRetrieval(hopfield_config)
+
+ analyses = []
+ for q in questions:
+ query_emb = encoder.encode(q) # (1, d)
+ result = hopfield.retrieve(
+ query_emb, mb.embeddings, return_energy=True, return_trajectory=True
+ )
+
+ curve = compute_energy_curve(result)
+ analyses.append({
+ "question": q,
+ "energy_curve": curve,
+ "energy_gap": compute_energy_gap(curve),
+ "monotonic": verify_monotonic_decrease(curve),
+ "num_steps": result.num_steps,
+ "attention_entropy": compute_attention_entropy(result.attention_weights),
+ })
+
+ with open(args.output, "w") as f:
+ json.dump(analyses, f, indent=2)
+ logger.info("Energy analysis saved to %s (%d questions)", args.output, len(analyses))
+
+
+if __name__ == "__main__":
+ main()