diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /scripts/analyze_energy.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/analyze_energy.py')
| -rw-r--r-- | scripts/analyze_energy.py | 78 |
1 files changed, 78 insertions, 0 deletions
"""Analyze energy curves and convergence properties of Hopfield retrieval.

Usage:
    python scripts/analyze_energy.py --config configs/default.yaml --memory-bank data/memory_bank.pt --questions data/questions.jsonl --output energy_analysis.json
"""

import argparse
import json
import logging

import torch
import yaml

from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
from hag.encoder import Encoder
from hag.energy import (
    compute_attention_entropy,
    compute_energy_curve,
    compute_energy_gap,
    verify_monotonic_decrease,
)
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _load_config(path):
    """Return the YAML mapping stored at *path*, or ``{}`` for an empty file.

    ``yaml.safe_load`` yields ``None`` for an empty document; normalizing that
    to an empty dict lets every config section fall back to its defaults.
    """
    with open(path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    return cfg or {}


def _load_questions(path):
    """Return the ``"question"`` field of every record in the JSONL file *path*.

    Whitespace-only lines (typically a trailing blank line) are skipped.
    Malformed records still raise ``json.JSONDecodeError`` / ``KeyError``.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line)["question"] for line in f if line.strip()]


def main() -> None:
    """Run Hopfield retrieval for each question and dump energy statistics.

    Reads the CLI arguments, builds the Hopfield / memory-bank / encoder
    configs from the YAML file, loads the memory bank and the questions,
    then writes one analysis record per question to the ``--output`` JSON
    file.
    """
    parser = argparse.ArgumentParser(description="Analyze Hopfield energy curves")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--output", type=str, default="energy_analysis.json")
    args = parser.parse_args()

    cfg = _load_config(args.config)

    # ``or {}`` also covers a section that is present but empty in the YAML
    # (e.g. a bare ``hopfield:`` key), which would otherwise be ``None``.
    hopfield_config = HopfieldConfig(**(cfg.get("hopfield") or {}))
    memory_config = MemoryBankConfig(**(cfg.get("memory") or {}))
    encoder_config = EncoderConfig(**(cfg.get("encoder") or {}))

    # Load memory bank
    mb = MemoryBank(memory_config)
    mb.load(args.memory_bank)

    # Load questions
    questions = _load_questions(args.questions)

    encoder = Encoder(encoder_config)
    hopfield = HopfieldRetrieval(hopfield_config)

    # NOTE(review): `torch` is imported but not used in this script; if
    # `retrieve()` does not rely on autograd, wrapping this loop in
    # `torch.no_grad()` may be worthwhile -- confirm against hag.hopfield.
    analyses = []
    for q in questions:
        query_emb = encoder.encode(q)  # (1, d)
        result = hopfield.retrieve(
            query_emb, mb.embeddings, return_energy=True, return_trajectory=True
        )

        curve = compute_energy_curve(result)
        # NOTE(review): assumes the hag.energy helpers and `result.num_steps`
        # yield JSON-serializable values (plain floats/lists/bools/ints) --
        # confirm, otherwise `json.dump` below will raise TypeError.
        analyses.append({
            "question": q,
            "energy_curve": curve,
            "energy_gap": compute_energy_gap(curve),
            "monotonic": verify_monotonic_decrease(curve),
            "num_steps": result.num_steps,
            "attention_entropy": compute_attention_entropy(result.attention_weights),
        })

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(analyses, f, indent=2)
    logger.info("Energy analysis saved to %s (%d questions)", args.output, len(analyses))


if __name__ == "__main__":
    main()
