summaryrefslogtreecommitdiff
path: root/scripts/visualize_trajectory.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /scripts/visualize_trajectory.py
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/visualize_trajectory.py')
-rw-r--r--scripts/visualize_trajectory.py80
1 files changed, 80 insertions, 0 deletions
diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py
new file mode 100644
index 0000000..e4ba902
--- /dev/null
+++ b/scripts/visualize_trajectory.py
@@ -0,0 +1,80 @@
+"""UMAP visualization of query trajectory in Hopfield retrieval.
+
+Usage:
+ python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
+"""
+
+import argparse
+import logging
+
+import numpy as np
+import torch
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--question", type=str, required=True)
+ parser.add_argument("--output", type=str, default="trajectory.png")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+ mb = MemoryBank(memory_config)
+ mb.load(args.memory_bank)
+
+ encoder = Encoder(encoder_config)
+ hopfield = HopfieldRetrieval(hopfield_config)
+
+ query_emb = encoder.encode(args.question) # (1, d)
+ result = hopfield.retrieve(
+ query_emb, mb.embeddings, return_trajectory=True
+ )
+
+ # Gather all points for UMAP: memories + trajectory
+ memories_np = mb.embeddings.T.numpy() # (N, d)
+ trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d)
+ all_points = np.concatenate([memories_np, trajectory_np], axis=0)
+
+ # UMAP projection
+ import umap
+
+ reducer = umap.UMAP(n_components=2, random_state=42)
+ projected = reducer.fit_transform(all_points)
+
+ mem_proj = projected[: len(memories_np)]
+ traj_proj = projected[len(memories_np) :]
+
+ # Plot
+ import matplotlib.pyplot as plt
+
+ fig, ax = plt.subplots(figsize=(10, 8))
+ ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories")
+ ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory")
+ ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)")
+ ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)")
+
+ ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)")
+ ax.legend()
+ plt.tight_layout()
+ plt.savefig(args.output, dpi=150)
+ logger.info("Trajectory visualization saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()