summaryrefslogtreecommitdiff
path: root/scripts/run_hag.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/run_hag.py')
-rw-r--r--scripts/run_hag.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/scripts/run_hag.py b/scripts/run_hag.py
index 4cacd1a..b6c9004 100644
--- a/scripts/run_hag.py
+++ b/scripts/run_hag.py
@@ -33,6 +33,7 @@ def main() -> None:
parser.add_argument("--beta", type=float, default=None)
parser.add_argument("--max-iter", type=int, default=None)
parser.add_argument("--top-k", type=int, default=None)
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -52,13 +53,14 @@ def main() -> None:
encoder=EncoderConfig(**cfg.get("encoder", {})),
generator=GeneratorConfig(**cfg.get("generator", {})),
retriever_type="hopfield",
+ device=args.device,
)
mb = MemoryBank(pipeline_config.memory)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
- encoder = Encoder(pipeline_config.encoder)
- generator = Generator(pipeline_config.generator)
+ encoder = Encoder(pipeline_config.encoder, device=args.device)
+ generator = Generator(pipeline_config.generator, device=args.device)
pipeline = RAGPipeline(
config=pipeline_config,