summaryrefslogtreecommitdiff
path: root/scripts/visualize_trajectory.py
blob: e4ba902f2a9af71da0f3e5701d651f061e85b1a0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""UMAP visualization of query trajectory in Hopfield retrieval.

Usage:
    python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
"""

import argparse
import logging

import numpy as np
import torch
import yaml

from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
from hag.encoder import Encoder
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank

# Per-module logger for this script. The root logger is configured at import
# time with level INFO, so the final "saved to" message below is emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def _load_config(path: str) -> dict:
    """Load the YAML config file at ``path``.

    ``yaml.safe_load`` returns ``None`` for an empty document; that is
    normalised to ``{}`` so per-section lookups keep working.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def _section(cfg: dict, name: str) -> dict:
    """Return config section ``name``, treating a missing or null section as ``{}``."""
    # ``cfg.get(name, {})`` would return None for ``name:`` with no value,
    # and ``**None`` raises a TypeError.
    return cfg.get(name) or {}


def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Copy ``tensor`` to a float32 NumPy array on the host.

    ``Tensor.numpy()`` raises for tensors on an accelerator, tensors that
    require grad, and dtypes NumPy has no equivalent for (e.g. bfloat16), so
    detach, upcast and move to CPU first. For a float32 CPU tensor that does
    not require grad these steps are no-ops.
    """
    return tensor.detach().float().cpu().numpy()


def _project_2d(points: np.ndarray, seed: int) -> np.ndarray:
    """Fit UMAP on ``points`` (one row per sample) and return the 2-D embedding.

    Args:
        points: Array with one row per point to embed.
        seed: Passed to UMAP as ``random_state`` so layouts are reproducible.

    Returns:
        Array with one row per input point and 2 columns.
    """
    # Heavy optional dependency: imported only when a projection is needed.
    import umap

    reducer = umap.UMAP(n_components=2, random_state=seed)
    return reducer.fit_transform(points)


def _plot_trajectory(mem_proj: np.ndarray, traj_proj: np.ndarray, title: str, output: str) -> None:
    """Draw memories as a gray cloud with the query path on top and save to ``output``.

    Args:
        mem_proj: 2-D coordinates of the memories, one row each.
        traj_proj: 2-D coordinates of the trajectory, first row is q_0 and
            last row is the final query.
        title: Axes title.
        output: Path passed to ``Figure.savefig``.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories")
        ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory")
        # zorder keeps the start/end markers above the trajectory line.
        ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)")
        ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)")

        ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output, dpi=150)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def main() -> None:
    """Plot a question's Hopfield retrieval trajectory projected to 2-D with UMAP.

    Steps:
    1. Read the YAML config given by ``--config``.
    2. Load the memory bank from ``--memory-bank``.
    3. Encode ``--question``.
    4. Run Hopfield retrieval with the trajectory recorded.
    5. Fit UMAP jointly on the memories and the trajectory points.
    6. Save the plot to ``--output``.
    """
    parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--output", type=str, default="trajectory.png")
    parser.add_argument("--seed", type=int, default=42, help="UMAP random_state for reproducible layouts")
    args = parser.parse_args()

    cfg = _load_config(args.config)
    hopfield_config = HopfieldConfig(**_section(cfg, "hopfield"))
    memory_config = MemoryBankConfig(**_section(cfg, "memory"))
    encoder_config = EncoderConfig(**_section(cfg, "encoder"))

    mb = MemoryBank(memory_config)
    mb.load(args.memory_bank)

    encoder = Encoder(encoder_config)
    hopfield = HopfieldRetrieval(hopfield_config)

    query_emb = encoder.encode(args.question)  # (1, d)
    result = hopfield.retrieve(query_emb, mb.embeddings, return_trajectory=True)

    # Memories and trajectory are embedded in one UMAP fit so they share a
    # single 2-D space; the rows are split apart again afterwards.
    # NOTE(review): assumes mb.embeddings is (d, N), hence the transpose.
    memories_np = _to_numpy(mb.embeddings.T)  # (N, d)
    trajectory_np = np.stack([_to_numpy(q.squeeze()) for q in result.trajectory])  # (T+1, d)
    all_points = np.concatenate([memories_np, trajectory_np], axis=0)

    projected = _project_2d(all_points, seed=args.seed)
    n_memories = len(memories_np)
    mem_proj = projected[:n_memories]
    traj_proj = projected[n_memories:]

    _plot_trajectory(
        mem_proj,
        traj_proj,
        title=f"Hopfield Query Trajectory ({result.num_steps} steps)",
        output=args.output,
    )
    logger.info("Trajectory visualization saved to %s", args.output)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()