diff options
Diffstat (limited to 'hag/pipeline.py')
| -rw-r--r-- | hag/pipeline.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/hag/pipeline.py b/hag/pipeline.py index 1fefb84..086b3be 100644 --- a/hag/pipeline.py +++ b/hag/pipeline.py @@ -82,7 +82,8 @@ class RAGPipeline: if self.retriever_type == "hopfield": retrieval_result = self.hopfield_retriever.retrieve(query_emb) else: - query_np = query_emb.detach().numpy().astype(np.float32) + # FAISS requires CPU numpy arrays + query_np = query_emb.detach().cpu().numpy().astype(np.float32) retrieval_result = self.faiss_retriever.retrieve(query_np) # Generate |
