diff options
Diffstat (limited to 'scripts/run_baseline.py')
| -rw-r--r-- | scripts/run_baseline.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py index 74c4710..beef76b 100644 --- a/scripts/run_baseline.py +++ b/scripts/run_baseline.py @@ -27,6 +27,7 @@ def main() -> None: parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--question", type=str, required=True) parser.add_argument("--top-k", type=int, default=5) + parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() with open(args.config) as f: @@ -44,17 +45,17 @@ def main() -> None: from hag.config import MemoryBankConfig mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) - mb.load(args.memory_bank) + mb.load(args.memory_bank) # FAISS needs CPU, load on CPU # Build FAISS index from memory bank embeddings import numpy as np - embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d) + embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32) # (N, d) faiss_ret = FAISSRetriever(top_k=args.top_k) faiss_ret.build_index(embeddings_np, mb.passages) - encoder = Encoder(pipeline_config.encoder) - generator = Generator(pipeline_config.generator) + encoder = Encoder(pipeline_config.encoder, device=args.device) + generator = Generator(pipeline_config.generator, device=args.device) pipeline = RAGPipeline( config=pipeline_config, |
