"""Run HAG (Hopfield-Augmented Generation) on a question.

Usage:
    python scripts/run_hag.py --config configs/default.yaml
        --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
"""

import argparse
import logging

import yaml

from hag.config import (
    EncoderConfig,
    GeneratorConfig,
    HopfieldConfig,
    MemoryBankConfig,
    PipelineConfig,
)
from hag.encoder import Encoder
from hag.generator import Generator
from hag.memory_bank import MemoryBank
from hag.pipeline import RAGPipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maximum number of characters of each retrieved passage echoed to stdout.
_PREVIEW_CHARS = 200


def _load_config(path: str) -> dict:
    """Load the YAML config file at *path*, treating an empty file as ``{}``."""
    with open(path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty document; callers expect a mapping.
    return cfg or {}


def _section(cfg: dict, key: str) -> dict:
    """Return a copy of config section *key*, or ``{}`` if absent or null."""
    # `or {}` also covers a key that is present with a null value (e.g. `hopfield:`).
    return dict(cfg.get(key) or {})


def _preview(text: str, limit: int = _PREVIEW_CHARS) -> str:
    """Return *text* cut to *limit* characters, adding "..." only if it was cut."""
    return text if len(text) <= limit else f"{text[:limit]}..."


def main() -> None:
    """Parse CLI arguments, build a Hopfield-retrieval RAG pipeline, answer one question.

    Settings are read from the YAML file given by ``--config``. When supplied,
    ``--beta``, ``--max-iter`` and ``--top-k`` override the matching keys of its
    ``hopfield`` section. The memory bank is loaded from ``--memory-bank``. The
    question, the answer and a short preview of each retrieved passage are
    printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Run HAG")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--beta", type=float, default=None)
    parser.add_argument("--max-iter", type=int, default=None)
    parser.add_argument("--top-k", type=int, default=None)
    args = parser.parse_args()

    cfg = _load_config(args.config)

    # CLI flags win over YAML values; None means "flag not given on the CLI".
    hopfield_cfg = _section(cfg, "hopfield")
    overrides = {"beta": args.beta, "max_iter": args.max_iter, "top_k": args.top_k}
    hopfield_cfg.update({k: v for k, v in overrides.items() if v is not None})

    pipeline_config = PipelineConfig(
        hopfield=HopfieldConfig(**hopfield_cfg),
        memory=MemoryBankConfig(**_section(cfg, "memory")),
        encoder=EncoderConfig(**_section(cfg, "encoder")),
        generator=GeneratorConfig(**_section(cfg, "generator")),
        retriever_type="hopfield",
    )

    mb = MemoryBank(pipeline_config.memory)
    mb.load(args.memory_bank)
    encoder = Encoder(pipeline_config.encoder)
    generator = Generator(pipeline_config.generator)
    pipeline = RAGPipeline(
        config=pipeline_config,
        encoder=encoder,
        generator=generator,
        memory_bank=mb,
    )

    result = pipeline.run(args.question)

    print(f"\nQuestion: {result.question}")
    print(f"Answer: {result.answer}")
    print("\nRetrieved passages:")
    for i, p in enumerate(result.retrieved_passages, start=1):
        print(f" [{i}] {_preview(p)}")


if __name__ == "__main__":
    main()