"""UMAP visualization of query trajectory in Hopfield retrieval. Usage: python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?" """ import argparse import logging import numpy as np import torch import yaml from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig from hag.encoder import Encoder from hag.hopfield import HopfieldRetrieval from hag.memory_bank import MemoryBank logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def main() -> None: parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory") parser.add_argument("--config", type=str, default="configs/default.yaml") parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--question", type=str, required=True) parser.add_argument("--output", type=str, default="trajectory.png") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) hopfield_config = HopfieldConfig(**cfg.get("hopfield", {})) memory_config = MemoryBankConfig(**cfg.get("memory", {})) encoder_config = EncoderConfig(**cfg.get("encoder", {})) mb = MemoryBank(memory_config) mb.load(args.memory_bank) encoder = Encoder(encoder_config) hopfield = HopfieldRetrieval(hopfield_config) query_emb = encoder.encode(args.question) # (1, d) result = hopfield.retrieve( query_emb, mb.embeddings, return_trajectory=True ) # Gather all points for UMAP: memories + trajectory memories_np = mb.embeddings.T.numpy() # (N, d) trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d) all_points = np.concatenate([memories_np, trajectory_np], axis=0) # UMAP projection import umap reducer = umap.UMAP(n_components=2, random_state=42) projected = reducer.fit_transform(all_points) mem_proj = projected[: len(memories_np)] traj_proj = projected[len(memories_np) :] # Plot import matplotlib.pyplot as plt fig, ax = plt.subplots(figsize=(10, 8)) ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories") ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory") ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)") ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)") ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)") ax.legend() plt.tight_layout() plt.savefig(args.output, dpi=150) logger.info("Trajectory visualization saved to %s", args.output) if __name__ == "__main__": main()