Diffstat (limited to 'scripts/run_eval.py')
-rw-r--r--  scripts/run_eval.py  101
1 file changed, 101 insertions(+), 0 deletions(-)
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
new file mode 100644
index 0000000..713b3c2
--- /dev/null
+++ b/scripts/run_eval.py
@@ -0,0 +1,101 @@
+"""Run evaluation on a dataset with either FAISS or Hopfield retrieval.
+
+Usage:
+    python scripts/run_eval.py --config configs/hotpotqa.yaml --memory-bank data/memory_bank.pt --dataset hotpotqa --split validation --max-samples 500
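+
+The YAML config may contain any of the top-level sections read below
+(hopfield, memory, encoder, generator, retriever_type); missing sections
+fall back to the defaults defined in hag.config. A minimal sketch:
+
+    retriever_type: hopfield
+    hopfield: {}
+    memory: {}
+    encoder: {}
+    generator: {}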
+"""
+
+import argparse
+import json
+import logging
+
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset
+from hag.pipeline import RAGPipeline
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Run HAG/RAG evaluation")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--dataset", type=str, default="hotpotqa")
+    parser.add_argument("--split", type=str, default="validation")
+    parser.add_argument("--max-samples", type=int, default=500)
+    parser.add_argument("--output", type=str, default="results.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f) or {}  # tolerate an empty config file
+
+    pipeline_config = PipelineConfig(
+        hopfield=HopfieldConfig(**cfg.get("hopfield", {})),
+        memory=MemoryBankConfig(**cfg.get("memory", {})),
+        encoder=EncoderConfig(**cfg.get("encoder", {})),
+        generator=GeneratorConfig(**cfg.get("generator", {})),
+        retriever_type=cfg.get("retriever_type", "hopfield"),
+    )
+
+    # Load memory bank
+    mb = MemoryBank(pipeline_config.memory)
+    mb.load(args.memory_bank)
+
+    # Build pipeline
+    encoder = Encoder(pipeline_config.encoder)
+    generator = Generator(pipeline_config.generator)
+    pipeline = RAGPipeline(
+        config=pipeline_config,
+        encoder=encoder,
+        generator=generator,
+        memory_bank=mb,
+    )
+
+    # Load dataset
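+    # Imported lazily so `--help` and config errors don't require `datasets`.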
+    from datasets import load_dataset
+
+    logger.info("Loading dataset: %s / %s", args.dataset, args.split)
+    ds = load_dataset(args.dataset, split=args.split)
+    if args.max_samples and len(ds) > args.max_samples:
+        ds = ds.select(range(args.max_samples))
+
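+    # Assumes each example exposes "question" and "answer" fields.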
+    questions = [ex["question"] for ex in ds]
+    gold_answers = [ex["answer"] for ex in ds]
+
+    # Run evaluation
+    logger.info("Running evaluation on %d questions", len(questions))
+    results = pipeline.run_batch(questions)
+    metrics = evaluate_dataset(results, gold_answers)
+
+    logger.info("Results: %s", metrics)
+
+    with open(args.output, "w") as f:
+        json.dump(metrics, f, indent=2)
+    logger.info("Results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+    main()