diff options
Diffstat (limited to 'scripts/visualize_trajectory.py')
| -rw-r--r-- | scripts/visualize_trajectory.py | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py new file mode 100644 index 0000000..e4ba902 --- /dev/null +++ b/scripts/visualize_trajectory.py @@ -0,0 +1,80 @@ +"""UMAP visualization of query trajectory in Hopfield retrieval. + +Usage: + python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?" +""" + +import argparse +import logging + +import numpy as np +import torch +import yaml + +from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig +from hag.encoder import Encoder +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory") + parser.add_argument("--config", type=str, default="configs/default.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--question", type=str, required=True) + parser.add_argument("--output", type=str, default="trajectory.png") + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + hopfield_config = HopfieldConfig(**cfg.get("hopfield", {})) + memory_config = MemoryBankConfig(**cfg.get("memory", {})) + encoder_config = EncoderConfig(**cfg.get("encoder", {})) + + mb = MemoryBank(memory_config) + mb.load(args.memory_bank) + + encoder = Encoder(encoder_config) + hopfield = HopfieldRetrieval(hopfield_config) + + query_emb = encoder.encode(args.question) # (1, d) + result = hopfield.retrieve( + query_emb, mb.embeddings, return_trajectory=True + ) + + # Gather all points for UMAP: memories + trajectory + memories_np = mb.embeddings.T.numpy() # (N, d) + trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d) + all_points = np.concatenate([memories_np, trajectory_np], axis=0) + + # UMAP projection + import umap + + reducer = umap.UMAP(n_components=2, random_state=42) + projected = reducer.fit_transform(all_points) + + mem_proj = projected[: len(memories_np)] + traj_proj = projected[len(memories_np) :] + + # Plot + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 8)) + ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories") + ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory") + ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)") + ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)") + + ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)") + ax.legend() + plt.tight_layout() + plt.savefig(args.output, dpi=150) + logger.info("Trajectory visualization saved to %s", args.output) + + +if __name__ == "__main__": + main() |
