diff options
Diffstat (limited to 'scripts/run_eval.py')
| -rw-r--r-- | scripts/run_eval.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/scripts/run_eval.py b/scripts/run_eval.py index 713b3c2..144fc2f 100644 --- a/scripts/run_eval.py +++ b/scripts/run_eval.py @@ -36,6 +36,7 @@ def main() -> None: parser.add_argument("--split", type=str, default="validation") parser.add_argument("--max-samples", type=int, default=500) parser.add_argument("--output", type=str, default="results.json") + parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() with open(args.config) as f: @@ -47,15 +48,16 @@ def main() -> None: encoder=EncoderConfig(**cfg.get("encoder", {})), generator=GeneratorConfig(**cfg.get("generator", {})), retriever_type=cfg.get("retriever_type", "hopfield"), + device=args.device, ) # Load memory bank mb = MemoryBank(pipeline_config.memory) - mb.load(args.memory_bank) + mb.load(args.memory_bank, device=args.device) # Build pipeline - encoder = Encoder(pipeline_config.encoder) - generator = Generator(pipeline_config.generator) + encoder = Encoder(pipeline_config.encoder, device=args.device) + generator = Generator(pipeline_config.generator, device=args.device) pipeline = RAGPipeline( config=pipeline_config, encoder=encoder, |
