Diffstat (limited to 'scripts/run_eval.py')
| -rw-r--r-- | scripts/run_eval.py | 90 |
1 file changed, 90 insertions, 0 deletions
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
new file mode 100644
index 0000000..713b3c2
--- /dev/null
+++ b/scripts/run_eval.py
@@ -0,0 +1,90 @@
+"""Run evaluation on a dataset with either FAISS or Hopfield retrieval.
+
+Usage:
+    python scripts/run_eval.py --config configs/hotpotqa.yaml --memory-bank data/memory_bank.pt --dataset hotpotqa --split validation --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Run HAG/RAG evaluation")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--dataset", type=str, default="hotpotqa")
+    parser.add_argument("--split", type=str, default="validation")
+    parser.add_argument("--max-samples", type=int, default=500)
+    parser.add_argument("--output", type=str, default="results.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    pipeline_config = PipelineConfig(
+        hopfield=HopfieldConfig(**cfg.get("hopfield", {})),
+        memory=MemoryBankConfig(**cfg.get("memory", {})),
+        encoder=EncoderConfig(**cfg.get("encoder", {})),
+        generator=GeneratorConfig(**cfg.get("generator", {})),
+        retriever_type=cfg.get("retriever_type", "hopfield"),
+    )
+
+    # Load memory bank
+    mb = MemoryBank(pipeline_config.memory)
+    mb.load(args.memory_bank)
+
+    # Build pipeline
+    encoder = Encoder(pipeline_config.encoder)
+    generator = Generator(pipeline_config.generator)
+    pipeline = RAGPipeline(
+        config=pipeline_config,
+        encoder=encoder,
+        generator=generator,
+        memory_bank=mb,
+    )
+
+    # Load dataset
+    from datasets import load_dataset
+
+    logger.info("Loading dataset: %s / %s", args.dataset, args.split)
+    ds = load_dataset(args.dataset, split=args.split)
+    if args.max_samples and len(ds) > args.max_samples:
+        ds = ds.select(range(args.max_samples))
+
+    questions = [ex["question"] for ex in ds]
+    gold_answers = [ex["answer"] for ex in ds]
+
+    # Run evaluation
+    logger.info("Running evaluation on %d questions", len(questions))
+    results = pipeline.run_batch(questions)
+    metrics = evaluate_dataset(results, gold_answers)
+
+    logger.info("Results: %s", metrics)
+
+    with open(args.output, "w") as f:
+        json.dump(metrics, f, indent=2)
+    logger.info("Results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+    main()
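The script reads its YAML config through the cfg.get(...) calls above, so the file is expected to carry hopfield, memory, encoder, generator, and retriever_type entries. A minimal sketch of what configs/hotpotqa.yaml could look like follows; the section names are taken from the script, but every key inside a section is a hypothetical placeholder that would have to match the actual dataclass fields in hag.config:

    # Sketch of configs/hotpotqa.yaml — section names come from the cfg.get()
    # calls in run_eval.py; the per-section keys below are assumptions only.
    retriever_type: hopfield        # or "faiss"; defaults to "hopfield" if omitted

    hopfield:                       # hypothetical fields for HopfieldConfig
      beta: 2.0
      max_iterations: 3

    memory:                         # hypothetical fields for MemoryBankConfig
      dim: 768

    encoder:                        # hypothetical fields for EncoderConfig
      model_name: sentence-transformers/all-mpnet-base-v2

    generator:                      # hypothetical fields for GeneratorConfig
      model_name: google/flan-t5-base
      max_new_tokens: 64

Any section left out of the file falls back to the corresponding config's defaults, since cfg.get(..., {}) passes an empty dict to the dataclass constructor.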

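One caveat on the --dataset default: on the Hugging Face Hub, HotpotQA is published as hotpot_qa and load_dataset requires a configuration name, e.g. load_dataset("hotpot_qa", "distractor", split="validation"). The bare hotpotqa default presumably refers to a local or project-specific dataset with question and answer fields; the plain load_dataset(args.dataset, split=args.split) call above would not resolve the Hub version as written.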