summary | refs | log | tree | commit | diff
path: root/scripts/run_baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/run_baseline.py')
-rw-r--r-- scripts/run_baseline.py 9
1 files changed, 5 insertions, 4 deletions
diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py
index 74c4710..beef76b 100644
--- a/scripts/run_baseline.py
+++ b/scripts/run_baseline.py
@@ -27,6 +27,7 @@ def main() -> None:
parser.add_argument("--memory-bank", type=str, required=True)
parser.add_argument("--question", type=str, required=True)
parser.add_argument("--top-k", type=int, default=5)
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -44,17 +45,17 @@ def main() -> None:
from hag.config import MemoryBankConfig
mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank) # FAISS needs CPU, load on CPU
# Build FAISS index from memory bank embeddings
import numpy as np
- embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d)
+ embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32) # (N, d)
faiss_ret = FAISSRetriever(top_k=args.top_k)
faiss_ret.build_index(embeddings_np, mb.passages)
- encoder = Encoder(pipeline_config.encoder)
- generator = Generator(pipeline_config.generator)
+ encoder = Encoder(pipeline_config.encoder, device=args.device)
+ generator = Generator(pipeline_config.generator, device=args.device)
pipeline = RAGPipeline(
config=pipeline_config,