summaryrefslogtreecommitdiff
path: root/scripts/visualize_energy.py
blob: f39953ca62980627dff8c03d5962b89b74b35822 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
"""Visualize Hopfield energy landscape: centered vs uncentered.

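Produces 4 figures (written to --outdir, default ``figures/``), each a grid with
the uncentered memory bank in the top row, the mean-centered bank in the bottom
row, and one column per β:
  1. 2D contour + Hopfield trajectory
  2. 1D energy profile along key directions
  3. UMAP of memories + query trajectories (skipped if umap-learn is missing)
  4. PCA top-2 energy heatmap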

Usage:
    CUDA_VISIBLE_DEVICES=1 python -u scripts/visualize_energy.py \
        --memory-bank data/processed/hotpotqa_memory_bank.pt \
        --questions data/processed/hotpotqa_questions.jsonl \
        --device cuda --query-idx 0
"""

import argparse
import json
import sys
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import torch
import torch.nn.functional as F

sys.path.insert(0, "/home/yurenh2/HAG")

from hag.config import MemoryBankConfig, EncoderConfig
from hag.memory_bank import MemoryBank
from hag.encoder import Encoder


# ── Helpers ──────────────────────────────────────────────────────────

def compute_energy(q: torch.Tensor, M: torch.Tensor, beta: float) -> torch.Tensor:
    """Modern-Hopfield energy of one or many query vectors.

    E(q) = -(1/β) · logsumexp_i(β · qᵀ m_i) + ½ · ‖q‖²,
    where m_i is the i-th column of M.

    Args:
        q: (..., d) query vector(s); leading batch dimensions are preserved.
        M: (d, N) memory matrix, one memory per column.
        beta: inverse temperature β.
    Returns:
        (...) tensor with one energy per query vector.
    """
    half_sq_norm = 0.5 * q.pow(2).sum(dim=-1)
    # Soft maximum of the similarity scores q·m_i, sharpened by β.
    smooth_max = torch.logsumexp(beta * (q @ M), dim=-1)
    return half_sq_norm + (-1.0 / beta) * smooth_max


def hopfield_trajectory(q0: torch.Tensor, M: torch.Tensor, beta: float,
                        max_iter: int = 15, tol: float = 1e-8) -> torch.Tensor:
    """Run the Hopfield update q ← softmax(β·qᵀM)·Mᵀ and record every iterate.

    Args:
        q0: (d,) or (1, d) initial query (not modified).
        M: (d, N) memory matrix, one memory per column.
        beta: inverse temperature.
        max_iter: maximum number of update steps.
        tol: iteration stops once ‖q_new − q‖ < tol; the converged iterate is
            still appended to the trajectory.
    Returns:
        (T+1, d) tensor: q0 followed by the T ≤ max_iter iterates computed.
    """
    q = q0.clone().unsqueeze(0) if q0.dim() == 1 else q0.clone()  # (1, d)
    traj = [q.squeeze(0).clone()]
    for _ in range(max_iter):
        alpha = torch.softmax(beta * (q @ M), dim=-1)  # attention over memories
        q_new = alpha @ M.T
        traj.append(q_new.squeeze(0).clone())
        if (q_new - q).norm() < tol:
            break
        q = q_new
    return torch.stack(traj, dim=0)  # (T+1, d)


def orthonormalize(v1: torch.Tensor, v2: torch.Tensor, rel_tol: float = 1e-6):
    """Return two orthonormal vectors spanning the plane of v1, v2.

    e1 is v1 normalised; e2 is the component of v2 orthogonal to e1, normalised.
    If v2 is (numerically) parallel to v1, or zero, a random direction
    orthogonal to e1 is used instead (seed the torch RNG for reproducibility).

    Args:
        v1: (d,) first direction; must be non-zero.
        v2: (d,) second direction.
        rel_tol: v2 counts as parallel to v1 when its orthogonal residual is
            below ``rel_tol * ‖v2‖`` (with an absolute floor of 1e-8).
    Returns:
        (e1, e2): unit-norm (d,) tensors with e1 · e2 ≈ 0.
    Raises:
        ValueError: if v1 has zero norm.
    """
    n1 = v1.norm()
    if n1 == 0:
        raise ValueError("orthonormalize: v1 must be non-zero")
    e1 = v1 / n1

    def _reject(v):
        # Two Gram-Schmidt passes: a single pass can leave rounding-level
        # leakage along e1, which dominates once the residual is normalised.
        v = v - (v @ e1) * e1
        return v - (v @ e1) * e1

    v2_orth = _reject(v2)
    # Relative threshold: float32 rounding leaves a residual ~1e-7·‖v2‖ even
    # for exactly parallel inputs, which an absolute 1e-8 would not catch.
    threshold = max(1e-8, rel_tol * float(v2.norm()))
    if v2_orth.norm() < threshold:
        # v1 and v2 are parallel, pick a random orthogonal direction
        v2_orth = _reject(torch.randn_like(v1))
    e2 = v2_orth / v2_orth.norm()
    return e1, e2


def project_to_plane(points: torch.Tensor, e1: torch.Tensor, e2: torch.Tensor):
    """Express each row of ``points`` in the (e1, e2) coordinate frame.

    Coordinates are plain dot products, so they are true in-plane coordinates
    only when e1 and e2 are orthonormal.

    Args:
        points: (K, d) vectors to project.
        e1, e2: (d,) plane direction vectors.
    Returns:
        (K, 2) tensor; column 0 holds the e1 coordinates, column 1 the e2 ones.
    """
    coords = tuple(points @ axis for axis in (e1, e2))
    return torch.stack(coords, dim=-1)


# ── Figure 1: 2D Contour + Trajectory ───────────────────────────────

def fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """2D energy contour on the query-centroid plane, with Hopfield trajectories.

    Saves ``outdir / "fig1_contour.png"``: a 2 x len(betas_plot) grid with the
    uncentered bank (M_raw, query q0_raw) in row 0 and the centered bank
    (M_cent, query q0_raw - mu) in row 1, one column per β.  Each panel shows
    E(q) on a 150x150 grid over [-1.5, 1.5]² in the plane spanned by the query
    and a reference direction: the memory centroid for the uncentered bank.
    The reference falls back to memory column 0 whenever the reference point
    has norm <= 1e-6, which is always the case for the centered bank's origin.
    Projected memories, the reference point, and the projected Hopfield
    trajectory are overlaid.

    Args:
        q0_raw: (d,) query embedding (uncentered).
        M_raw: (d, N) memory bank, one memory per column.
        M_cent: (d, N) centered memory bank.
        mu: (d,) vector subtracted from q0_raw to form the centered query.
        betas_plot: inverse temperatures, one column each.
        outdir: output directory supporting ``/`` (e.g. pathlib.Path).
        device: torch device for the grid, energy and trajectory computations.
    """

    centroid = M_raw.mean(dim=1)  # (d,)
    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for col, beta in enumerate(betas_plot):
        for row, (label, M, q0, ref_point, ref_label) in enumerate([
            ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
            ("Centered",   M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
        ]):
            ax = axes[row, col]

            # Define 2D plane: query direction + centroid/origin direction.
            # A (near-)zero reference point cannot define a direction, so the
            # first memory column is used as the second spanning vector instead.
            e1, e2 = orthonormalize(q0.to(device), ref_point.to(device) if ref_point.norm() > 1e-6 else M.to(device)[:, 0])

            # Grid: (n_grid x n_grid) points q = x·e1 + y·e2 in the plane
            grid_range = 1.5
            n_grid = 150
            xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            xx, yy = torch.meshgrid(xs, ys, indexing='ij')
            grid_points = xx.reshape(-1, 1) * e1.unsqueeze(0) + yy.reshape(-1, 1) * e2.unsqueeze(0)  # (n^2, d)

            E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy()

            # Trajectory (full d-dim run; the projection below drops any
            # component outside the plane, so the plotted path is a shadow)
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_2d = project_to_plane(traj, e1, e2).cpu().numpy()

            # Project memories
            mem_2d = project_to_plane(M.T.to(device), e1, e2).cpu().numpy()

            # Project reference point (only drawn when it is non-zero, below)
            ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), e1, e2).cpu().numpy()

            # Plot
            xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()
            # Clip energy for better visualization: values outside the
            # 1st-95th percentile are clamped so extremes don't dominate the colour scale
            E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95))
            cs = ax.contourf(xx_np, yy_np, E_clip, levels=40, cmap='viridis')
            ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5)

            # Memories (small dots)
            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            # Reference point
            if ref_point.norm() > 1e-6:
                ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                          zorder=5, label=ref_label)
            else:
                ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin')

            # Trajectory: path, start point (q₀) and final iterate (q_T)
            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                   linewidth=2, zorder=4, label='trajectory')
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                      zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                      zorder=5, label=f'q_T (T={len(traj_2d)-1})')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("e₁ (query dir)")
            ax.set_ylabel("e₂")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 1: 2D Energy Contour + Hopfield Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig1_contour.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig1_contour.png'}")


# ── Figure 2: 1D Energy Profile ─────────────────────────────────────

def _draw_profile_panel(ax, q0, M, targets, beta, ts, title, xlabel):
    """Plot E(q0 + t·(target − q0)) over ``ts`` for each (target, name, color).

    All tensors are expected on ``ts.device``.  Targets whose direction from
    q0 has norm < 1e-8 are skipped.  Vertical guides mark t=0 (q₀) and t=1
    (the target).
    """
    ts_np = ts.cpu().numpy()
    for target, name, color in targets:
        direction = target - q0
        if direction.norm() < 1e-8:
            continue
        points = q0.unsqueeze(0) + ts.unsqueeze(1) * direction.unsqueeze(0)
        E = compute_energy(points, M, beta).cpu().numpy()
        ax.plot(ts_np, E, label=name, color=color, linewidth=2)

    # Mark t=0 (query) and t=1 (target)
    ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀')
    ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_xlabel(xlabel)
    ax.set_ylabel("E(q)")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)


def fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """1D energy along key directions.

    Saves ``outdir / "fig2_profile.png"``: a 2 x len(betas_plot) grid, one
    column per β.  Row 0 (uncentered bank) plots E along the lines from
    q0_raw toward the memory centroid, its top-1 memory (largest q·m score)
    and the origin.  Row 1 (centered bank, query q0_raw - mu) plots E toward
    the origin and its own top-1 memory.  t runs over [-0.5, 2.0] (300 steps).

    Args:
        q0_raw: (d,) query embedding (uncentered).
        M_raw: (d, N) memory bank, one memory per column.
        M_cent: (d, N) centered memory bank.
        mu: (d,) vector subtracted from q0_raw to form the centered query.
        betas_plot: inverse temperatures, one column each.
        outdir: output directory supporting ``/`` (e.g. pathlib.Path).
        device: torch device for all energy computations.
    """
    # Move every input to the plotting device once so all targets, queries
    # and banks are guaranteed to live on the same device.
    q0_raw = q0_raw.to(device)
    M_raw = M_raw.to(device)
    M_cent = M_cent.to(device)
    mu = mu.to(device)

    centroid = M_raw.mean(dim=1)
    q0_cent = q0_raw - mu

    # Find top-1 memory (by dot-product score) for each frame
    top1_raw_idx = (q0_raw @ M_raw).argmax().item()
    top1_raw = M_raw[:, top1_raw_idx]
    top1_cent_idx = (q0_cent @ M_cent).argmax().item()
    top1_cent = M_cent[:, top1_cent_idx]

    raw_targets = [
        (centroid, "→ centroid", "red"),
        (top1_raw, f"→ memory[{top1_raw_idx}]", "blue"),
        (torch.zeros_like(q0_raw), "→ origin", "gray"),
    ]
    cent_targets = [
        (torch.zeros_like(q0_cent), "→ origin", "red"),
        (top1_cent, f"→ memory[{top1_cent_idx}]", "blue"),
    ]

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 10),
                             squeeze=False)

    ts = torch.linspace(-0.5, 2.0, 300, device=device)

    for col, beta in enumerate(betas_plot):
        _draw_profile_panel(axes[0, col], q0_raw, M_raw, raw_targets, beta, ts,
                            f"Uncentered, β={beta}", "t  (q₀ + t·(target - q₀))")
        _draw_profile_panel(axes[1, col], q0_cent, M_cent, cent_targets, beta, ts,
                            f"Centered, β={beta}", "t  (q̃₀ + t·(target - q̃₀))")

    fig.suptitle("Fig 2: 1D Energy Profile Along Key Directions", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig2_profile.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig2_profile.png'}")


# ── Figure 3: UMAP + Trajectories ───────────────────────────────────

def fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """UMAP of memories + query trajectories.

    Saves ``outdir / "fig3_umap.png"``: a 2 x len(betas_plot) grid with the
    uncentered bank (query q0_raw) in row 0 and the centered bank (query
    q0_raw - mu) in row 1, one column per β.  Each panel fits a UMAP on the
    memories plus that panel's Hopfield trajectory and colours every memory by
    its energy E(m_i) under the same bank and β.  Prints a message and returns
    without writing anything when ``umap`` (umap-learn) cannot be imported.

    NOTE(review): UMAP is refitted per panel, so 2-D layouts are not directly
    comparable across columns or rows.

    Args:
        q0_raw: (d,) query embedding (uncentered).
        M_raw: (d, N) memory bank, one memory per column.
        M_cent: (d, N) centered memory bank.
        mu: (d,) vector subtracted from q0_raw to form the centered query.
        betas_plot: inverse temperatures, one column each.
        outdir: output directory supporting ``/`` (e.g. pathlib.Path).
        device: torch device for trajectory and energy computations.
    """
    try:
        import umap
    except ImportError:
        print("umap-learn not installed, skipping fig3")
        return

    q0_cent = q0_raw - mu

    # β-independent per-row data, prepared once instead of once per panel.
    rows = []
    for label, M, q0 in [("Uncentered", M_raw, q0_raw), ("Centered", M_cent, q0_cent)]:
        mem_cpu = M.T.cpu()  # (N, d): one row per memory
        rows.append((label, M.to(device), q0.to(device), mem_cpu, mem_cpu.to(device)))

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for col, beta in enumerate(betas_plot):
        for row, (label, M_dev, q0_dev, mem_cpu, mem_dev) in enumerate(rows):
            ax = axes[row, col]

            traj_cpu = hopfield_trajectory(q0_dev, M_dev, beta, max_iter=15).cpu()

            # Embed memories and trajectory jointly so they share one 2-D layout.
            all_points = torch.cat([mem_cpu, traj_cpu], dim=0).numpy()
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
            embedding = reducer.fit_transform(all_points)

            n_mem = mem_cpu.shape[0]
            mem_emb = embedding[:n_mem]
            traj_emb = embedding[n_mem:]

            # Colour each memory by its own energy under this bank and β.
            E_mem = compute_energy(mem_dev, M_dev, beta).cpu().numpy()

            sc = ax.scatter(mem_emb[:, 0], mem_emb[:, 1], c=E_mem, cmap='viridis',
                            s=10, alpha=0.5, zorder=1)

            # Trajectory: path, start point (q₀) and final iterate (q_T)
            ax.plot(traj_emb[:, 0], traj_emb[:, 1], 'o-', color='#ff6600',
                    markersize=5, linewidth=2, zorder=3, label='trajectory')
            ax.scatter(traj_emb[0, 0], traj_emb[0, 1], c='lime', s=100,
                       marker='s', zorder=4, label='q₀')
            ax.scatter(traj_emb[-1, 0], traj_emb[-1, 1], c='magenta', s=100,
                       marker='D', zorder=4, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.legend(fontsize=8, loc='upper right')
            plt.colorbar(sc, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 3: UMAP of Memories + Hopfield Trajectory (color = energy)",
                 fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig3_umap.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig3_umap.png'}")


# ── Figure 4: PCA Top-2 Energy Heatmap ──────────────────────────────

def fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """Energy heatmap on PCA top-2 components of memory bank.

    Saves ``outdir / "fig4_pca.png"``: a 2 x len(betas_plot) grid with the
    uncentered bank (query q0_raw) in row 0 and the centered bank (query
    q0_raw - mu) in row 1, one column per β.  For each bank the top-2 left
    singular vectors of M define an orthonormal plane through the origin.
    E(q) is drawn on a 150x150 grid over [-1.5, 1.5]² in that plane, with
    projected memories, the reference point (centroid / origin) and the
    projected Hopfield trajectory overlaid.

    NOTE(review): the SVD is taken of M as given.  For the uncentered bank
    this is not mean-centred PCA, so pc1 is likely to align with the mean
    direction when ‖μ‖ is large relative to the spread.  Confirm that this
    is the intended comparison.

    Args:
        q0_raw: (d,) query embedding (uncentered).
        M_raw: (d, N) memory bank, one memory per column.
        M_cent: (d, N) centered memory bank.
        mu: (d,) vector subtracted from q0_raw to form the centered query.
        betas_plot: inverse temperatures, one column each.
        outdir: output directory supporting ``/`` (e.g. pathlib.Path).
        device: torch device for the grid, energy and trajectory computations.
    """

    centroid = M_raw.mean(dim=1)
    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    for row, (label, M, q0, ref_point, ref_label) in enumerate([
        ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
        ("Centered",   M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
    ]):
        # PCA on this memory bank (β-independent, so computed once per row)
        M_cpu = M.cpu()  # (d, N)
        # SVD of M to get top-2 directions; columns of U are orthonormal
        U, S, Vh = torch.linalg.svd(M_cpu, full_matrices=False)
        pc1 = U[:, 0].to(device)  # (d,)
        pc2 = U[:, 1].to(device)  # (d,)

        for col, beta in enumerate(betas_plot):
            ax = axes[row, col]

            # Grid in PCA space: points q = x·pc1 + y·pc2
            grid_range = 1.5
            n_grid = 150
            xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
            xx, yy = torch.meshgrid(xs, ys, indexing='ij')
            grid_points = xx.reshape(-1, 1) * pc1.unsqueeze(0) + yy.reshape(-1, 1) * pc2.unsqueeze(0)

            E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy()

            # Trajectory (full d-dim run, shown as its projection onto the plane)
            traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15)
            traj_2d = project_to_plane(traj, pc1, pc2).cpu().numpy()

            # Memories projected
            mem_2d = project_to_plane(M.T.to(device), pc1, pc2).cpu().numpy()

            # Reference point (only drawn when it is non-zero, below)
            ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), pc1, pc2).cpu().numpy()

            # Plot; energy clamped to its 1st-95th percentile for the colour scale
            xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()
            E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95))
            cs = ax.pcolormesh(xx_np, yy_np, E_clip, cmap='viridis', shading='auto')
            ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5)

            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            if ref_point.norm() > 1e-6:
                ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                          zorder=5, label=ref_label)
            else:
                ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin')

            # Trajectory path (unlabelled, so it has no legend entry), start and end
            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                   linewidth=2, zorder=4)
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                      zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                      zorder=5, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 4: PCA Top-2 Energy Heatmap + Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig4_pca.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig4_pca.png'}")


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Command-line entry point.

    Loads the memory bank and one query from the questions file, encodes the
    query, builds the mean-centered bank, and renders all four figures into
    ``--outdir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--query-idx", type=int, default=0)
    parser.add_argument("--outdir", type=str, default="figures")
    # β values: below and above β_critical ≈ 37.6 (value from the original
    # analysis; it is not computed in this script).
    parser.add_argument("--betas", type=float, nargs="+",
                        default=[5.0, 20.0, 50.0, 100.0],
                        help="inverse temperatures to plot, one figure column each")
    args = parser.parse_args()

    device = args.device
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Load memory bank with center=False; the centered variant is built explicitly below
    mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False))
    mb.load(args.memory_bank, device=device)
    M_raw = mb.embeddings  # (d, N)
    d, N = M_raw.shape
    print(f"Memory bank: d={d}, N={N}")

    # Center
    mu = M_raw.mean(dim=1)  # (d,)
    M_cent = M_raw - mu.unsqueeze(1)
    print(f"‖μ‖ = {mu.norm():.4f}")

    # Load questions: JSON Lines read as UTF-8 regardless of locale; blank lines skipped
    with open(args.questions, encoding="utf-8") as f:
        questions = [json.loads(line) for line in f if line.strip()]

    # Negative indices keep their usual Python meaning (count from the end)
    if not -len(questions) <= args.query_idx < len(questions):
        parser.error(f"--query-idx {args.query_idx} is out of range "
                     f"for {len(questions)} questions in {args.questions}")

    q_text = questions[args.query_idx]["question"]
    print(f"Query [{args.query_idx}]: '{q_text}'")

    encoder = Encoder(EncoderConfig(model_name="facebook/contriever-msmarco"), device=device)
    q0_raw = encoder.encode([q_text]).squeeze(0)  # (d,)
    print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")

    betas_plot = list(args.betas)

    print("\n--- Generating Figure 1: 2D Contour ---")
    fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 2: 1D Profile ---")
    fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 3: UMAP ---")
    fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 4: PCA Heatmap ---")
    fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print(f"\nAll figures saved to {outdir}/")


if __name__ == "__main__":
    # Generate the figures only when run as a script, not when imported.
    main()