"""Analyze energy curves and convergence properties of Hopfield retrieval. Usage: python scripts/analyze_energy.py --config configs/default.yaml --memory-bank data/memory_bank.pt --questions data/questions.jsonl --output energy_analysis.json """ import argparse import json import logging import torch import yaml from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig from hag.encoder import Encoder from hag.energy import ( compute_attention_entropy, compute_energy_curve, compute_energy_gap, verify_monotonic_decrease, ) from hag.hopfield import HopfieldRetrieval from hag.memory_bank import MemoryBank logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def main() -> None: parser = argparse.ArgumentParser(description="Analyze Hopfield energy curves") parser.add_argument("--config", type=str, default="configs/default.yaml") parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--questions", type=str, required=True) parser.add_argument("--output", type=str, default="energy_analysis.json") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) hopfield_config = HopfieldConfig(**cfg.get("hopfield", {})) memory_config = MemoryBankConfig(**cfg.get("memory", {})) encoder_config = EncoderConfig(**cfg.get("encoder", {})) # Load memory bank mb = MemoryBank(memory_config) mb.load(args.memory_bank) # Load questions with open(args.questions) as f: questions = [json.loads(line)["question"] for line in f] encoder = Encoder(encoder_config) hopfield = HopfieldRetrieval(hopfield_config) analyses = [] for q in questions: query_emb = encoder.encode(q) # (1, d) result = hopfield.retrieve( query_emb, mb.embeddings, return_energy=True, return_trajectory=True ) curve = compute_energy_curve(result) analyses.append({ "question": q, "energy_curve": curve, "energy_gap": compute_energy_gap(curve), "monotonic": verify_monotonic_decrease(curve), "num_steps": result.num_steps, "attention_entropy": compute_attention_entropy(result.attention_weights), }) with open(args.output, "w") as f: json.dump(analyses, f, indent=2) logger.info("Energy analysis saved to %s (%d questions)", args.output, len(analyses)) if __name__ == "__main__": main()