1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
|
"""UMAP visualization of query trajectory in Hopfield retrieval.
Usage:
python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
"""
import argparse
import logging
import numpy as np
import torch
import yaml
from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
from hag.encoder import Encoder
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
# Configure the root logger at INFO level so this script's messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def main() -> None:
    """Project a Hopfield query trajectory to 2-D with UMAP and save a plot.

    Parses the command line, builds the Hopfield, memory-bank and encoder
    configs from the YAML file, loads the memory bank, encodes the question,
    runs retrieval with ``return_trajectory=True``, fits UMAP on the memory
    embeddings together with the trajectory points, and writes the resulting
    figure to the ``--output`` path.

    Command-line arguments:
        --config: YAML config path (default ``configs/default.yaml``).
        --memory-bank: Path handed to ``MemoryBank.load`` (required).
        --question: Question text to encode (required).
        --output: Image path for the figure (default ``trajectory.png``).
    """
    parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--output", type=str, default="trajectory.png")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        # An empty YAML document parses to None; fall back to an empty
        # mapping so every section below resolves to its config defaults.
        cfg = yaml.safe_load(f) or {}

    # `or {}` also covers a section that is present but null (e.g. "hopfield:").
    hopfield_config = HopfieldConfig(**(cfg.get("hopfield") or {}))
    memory_config = MemoryBankConfig(**(cfg.get("memory") or {}))
    encoder_config = EncoderConfig(**(cfg.get("encoder") or {}))

    mb = MemoryBank(memory_config)
    mb.load(args.memory_bank)
    encoder = Encoder(encoder_config)
    hopfield = HopfieldRetrieval(hopfield_config)

    query_emb = encoder.encode(args.question)  # (1, d) per the original annotation
    result = hopfield.retrieve(
        query_emb, mb.embeddings, return_trajectory=True
    )

    # Gather all points for UMAP: memories + trajectory.
    # detach().cpu() makes the NumPy conversion work for tensors that live on
    # an accelerator or still carry autograd history; it is a no-op otherwise.
    # Shapes follow the original annotations (assumes embeddings are (d, N)).
    memories_np = mb.embeddings.T.detach().cpu().numpy()  # (N, d)
    trajectory_np = np.stack(
        [q.detach().cpu().squeeze().numpy() for q in result.trajectory]
    )  # (T+1, d)
    all_points = np.concatenate([memories_np, trajectory_np], axis=0)

    # UMAP projection. Imported here rather than at module top, as in the
    # original, so the dependency is only loaded when the script runs.
    import umap

    reducer = umap.UMAP(n_components=2, random_state=42)
    projected = reducer.fit_transform(all_points)
    # Rows were concatenated memories-first, so split at that boundary.
    n_memories = len(memories_np)
    mem_proj = projected[:n_memories]
    traj_proj = projected[n_memories:]

    # Plot
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories")
        ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory")
        ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)")
        ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)")
        ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)")
        ax.legend()
        fig.tight_layout()
        fig.savefig(args.output, dpi=150)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    logger.info("Trajectory visualization saved to %s", args.output)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|