summaryrefslogtreecommitdiff
path: root/scripts/visualize_trajectory.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/visualize_trajectory.py')
-rw-r--r--scripts/visualize_trajectory.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py
index e4ba902..0087563 100644
--- a/scripts/visualize_trajectory.py
+++ b/scripts/visualize_trajectory.py
@@ -26,6 +26,7 @@ def main() -> None:
parser.add_argument("--memory-bank", type=str, required=True)
parser.add_argument("--question", type=str, required=True)
parser.add_argument("--output", type=str, default="trajectory.png")
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -36,9 +37,9 @@ def main() -> None:
encoder_config = EncoderConfig(**cfg.get("encoder", {}))
mb = MemoryBank(memory_config)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
- encoder = Encoder(encoder_config)
+ encoder = Encoder(encoder_config, device=args.device)
hopfield = HopfieldRetrieval(hopfield_config)
query_emb = encoder.encode(args.question) # (1, d)
@@ -46,9 +47,9 @@ def main() -> None:
query_emb, mb.embeddings, return_trajectory=True
)
- # Gather all points for UMAP: memories + trajectory
- memories_np = mb.embeddings.T.numpy() # (N, d)
- trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d)
+ # Gather all points for UMAP: memories + trajectory (must be on CPU for numpy)
+ memories_np = mb.embeddings.T.cpu().numpy() # (N, d)
+ trajectory_np = np.stack([q.squeeze().cpu().numpy() for q in result.trajectory]) # (T+1, d)
all_points = np.concatenate([memories_np, trajectory_np], axis=0)
# UMAP projection