"""Run evaluation on a dataset with either FAISS or Hopfield retrieval. Usage: python scripts/run_eval.py --config configs/hotpotqa.yaml --memory-bank data/memory_bank.pt --dataset hotpotqa --split validation --max-samples 500 """ import argparse import json import logging import yaml from hag.config import ( EncoderConfig, GeneratorConfig, HopfieldConfig, MemoryBankConfig, PipelineConfig, ) from hag.encoder import Encoder from hag.generator import Generator from hag.memory_bank import MemoryBank from hag.metrics import evaluate_dataset from hag.pipeline import RAGPipeline from hag.retriever_faiss import FAISSRetriever logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def main() -> None: parser = argparse.ArgumentParser(description="Run HAG/RAG evaluation") parser.add_argument("--config", type=str, default="configs/default.yaml") parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--dataset", type=str, default="hotpotqa") parser.add_argument("--split", type=str, default="validation") parser.add_argument("--max-samples", type=int, default=500) parser.add_argument("--output", type=str, default="results.json") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) pipeline_config = PipelineConfig( hopfield=HopfieldConfig(**cfg.get("hopfield", {})), memory=MemoryBankConfig(**cfg.get("memory", {})), encoder=EncoderConfig(**cfg.get("encoder", {})), generator=GeneratorConfig(**cfg.get("generator", {})), retriever_type=cfg.get("retriever_type", "hopfield"), ) # Load memory bank mb = MemoryBank(pipeline_config.memory) mb.load(args.memory_bank) # Build pipeline encoder = Encoder(pipeline_config.encoder) generator = Generator(pipeline_config.generator) pipeline = RAGPipeline( config=pipeline_config, encoder=encoder, generator=generator, memory_bank=mb, ) # Load dataset from datasets import load_dataset logger.info("Loading dataset: %s / %s", args.dataset, args.split) ds = load_dataset(args.dataset, split=args.split) if args.max_samples and len(ds) > args.max_samples: ds = ds.select(range(args.max_samples)) questions = [ex["question"] for ex in ds] gold_answers = [ex["answer"] for ex in ds] # Run evaluation logger.info("Running evaluation on %d questions", len(questions)) results = pipeline.run_batch(questions) metrics = evaluate_dataset(results, gold_answers) logger.info("Results: %s", metrics) with open(args.output, "w") as f: json.dump(metrics, f, indent=2) logger.info("Results saved to %s", args.output) if __name__ == "__main__": main()