summaryrefslogtreecommitdiff
path: root/scripts/build_memory_bank.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/build_memory_bank.py')
-rw-r--r--scripts/build_memory_bank.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
index 2aff828..0fc1c51 100644
--- a/scripts/build_memory_bank.py
+++ b/scripts/build_memory_bank.py
@@ -50,13 +50,13 @@ def main() -> None:
logger.info("Loaded %d passages", len(passages))
# Encode passages in batches
- encoder = Encoder(encoder_config)
+ encoder = Encoder(encoder_config, device=args.device)
all_embeddings = []
for i in tqdm(range(0, len(passages), encoder_config.batch_size), desc="Encoding"):
batch = passages[i : i + encoder_config.batch_size]
emb = encoder.encode(batch) # (batch_size, d)
- all_embeddings.append(emb.cpu())
+ all_embeddings.append(emb.cpu()) # Always store on CPU for saving
embeddings = torch.cat(all_embeddings, dim=0) # (N, d)
logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)