"""Visualize Hopfield energy landscape: centered vs uncentered. Produces 4 figures, each with centered/uncentered side-by-side: 1. 2D contour + Hopfield trajectory 2. 1D energy profile along key directions 3. UMAP of memories + query trajectories 4. PCA top-2 energy heatmap Usage: CUDA_VISIBLE_DEVICES=1 python -u scripts/visualize_energy.py \ --memory-bank data/processed/hotpotqa_memory_bank.pt \ --questions data/processed/hotpotqa_questions.jsonl \ --device cuda --query-idx 0 """ import argparse import json import sys from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.colors as mcolors import numpy as np import torch import torch.nn.functional as F sys.path.insert(0, "/home/yurenh2/HAG") from hag.config import MemoryBankConfig, EncoderConfig from hag.memory_bank import MemoryBank from hag.encoder import Encoder # ── Helpers ────────────────────────────────────────────────────────── def compute_energy(q: torch.Tensor, M: torch.Tensor, beta: float) -> torch.Tensor: """E(q) = -1/β · logsumexp(β · qᵀM) + 1/2 · ‖q‖² Args: q: (..., d) M: (d, N) beta: inverse temperature Returns: energy: (...) """ logits = beta * (q @ M) # (..., N) lse = torch.logsumexp(logits, dim=-1) # (...) norm_sq = 0.5 * (q ** 2).sum(dim=-1) # (...) return -1.0 / beta * lse + norm_sq def hopfield_trajectory(q0: torch.Tensor, M: torch.Tensor, beta: float, max_iter: int = 15) -> torch.Tensor: """Run Hopfield and return full trajectory. Returns (T+1, d).""" q = q0.clone().unsqueeze(0) if q0.dim() == 1 else q0.clone() # (1, d) traj = [q.squeeze(0).clone()] for _ in range(max_iter): logits = beta * (q @ M) alpha = torch.softmax(logits, dim=-1) q_new = alpha @ M.T traj.append(q_new.squeeze(0).clone()) if (q_new - q).norm() < 1e-8: break q = q_new return torch.stack(traj, dim=0) # (T+1, d) def orthonormalize(v1: torch.Tensor, v2: torch.Tensor): """Return two orthonormal vectors spanning the plane of v1, v2.""" e1 = v1 / v1.norm() v2_orth = v2 - (v2 @ e1) * e1 if v2_orth.norm() < 1e-8: # v1 and v2 are parallel, pick a random orthogonal direction rand = torch.randn_like(v1) v2_orth = rand - (rand @ e1) * e1 e2 = v2_orth / v2_orth.norm() return e1, e2 def project_to_plane(points: torch.Tensor, e1: torch.Tensor, e2: torch.Tensor): """Project (K, d) points onto 2D plane defined by e1, e2. Returns (K, 2).""" return torch.stack([points @ e1, points @ e2], dim=-1) # ── Figure 1: 2D Contour + Trajectory ─────────────────────────────── def fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): """2D energy contour on the query-centroid plane, with Hopfield trajectories.""" centroid = M_raw.mean(dim=1) # (d,) q0_cent = q0_raw - mu fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12), squeeze=False) for col, beta in enumerate(betas_plot): for row, (label, M, q0, ref_point, ref_label) in enumerate([ ("Uncentered", M_raw, q0_raw, centroid, "centroid"), ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"), ]): ax = axes[row, col] # Define 2D plane: query direction + centroid/origin direction e1, e2 = orthonormalize(q0.to(device), ref_point.to(device) if ref_point.norm() > 1e-6 else M.to(device)[:, 0]) # Grid grid_range = 1.5 n_grid = 150 xs = torch.linspace(-grid_range, grid_range, n_grid, device=device) ys = torch.linspace(-grid_range, grid_range, n_grid, device=device) xx, yy = torch.meshgrid(xs, ys, indexing='ij') grid_points = xx.reshape(-1, 1) * e1.unsqueeze(0) + yy.reshape(-1, 1) * e2.unsqueeze(0) # (n^2, d) E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy() # Trajectory traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15) traj_2d = project_to_plane(traj, e1, e2).cpu().numpy() # Project memories mem_2d = project_to_plane(M.T.to(device), e1, e2).cpu().numpy() # Project reference point ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), e1, e2).cpu().numpy() # Plot xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy() # Clip energy for better visualization E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95)) cs = ax.contourf(xx_np, yy_np, E_clip, levels=40, cmap='viridis') ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5) # Memories (small dots) ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2) # Reference point if ref_point.norm() > 1e-6: ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*', zorder=5, label=ref_label) else: ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin') # Trajectory ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4, linewidth=2, zorder=4, label='trajectory') ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s', zorder=5, label='q₀') ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D', zorder=5, label=f'q_T (T={len(traj_2d)-1})') ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold') ax.set_xlabel("e₁ (query dir)") ax.set_ylabel("e₂") ax.legend(fontsize=7, loc='upper right') plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)') fig.suptitle("Fig 1: 2D Energy Contour + Hopfield Trajectory", fontsize=15, fontweight='bold') fig.tight_layout() fig.savefig(outdir / "fig1_contour.png", dpi=150, bbox_inches='tight') plt.close(fig) print(f"Saved {outdir / 'fig1_contour.png'}") # ── Figure 2: 1D Energy Profile ───────────────────────────────────── def fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): """1D energy along key directions.""" centroid = M_raw.mean(dim=1) q0_cent = q0_raw - mu # Find top-1 memory for each scores_raw = q0_raw @ M_raw top1_raw_idx = scores_raw.argmax().item() top1_raw = M_raw[:, top1_raw_idx] scores_cent = q0_cent @ M_cent top1_cent_idx = scores_cent.argmax().item() top1_cent = M_cent[:, top1_cent_idx] fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 10), squeeze=False) ts = torch.linspace(-0.5, 2.0, 300, device=device) for col, beta in enumerate(betas_plot): # Uncentered ax = axes[0, col] for target, name, color in [ (centroid, "→ centroid", "red"), (top1_raw, f"→ memory[{top1_raw_idx}]", "blue"), (torch.zeros_like(q0_raw), "→ origin", "gray"), ]: direction = target - q0_raw.to(device) if direction.norm() < 1e-8: continue points = q0_raw.unsqueeze(0).to(device) + ts.unsqueeze(1) * direction.unsqueeze(0) E = compute_energy(points, M_raw.to(device), beta).cpu().numpy() ax.plot(ts.cpu().numpy(), E, label=name, color=color, linewidth=2) # Mark t=0 (query) and t=1 (target) E_q0 = compute_energy(q0_raw.unsqueeze(0).to(device), M_raw.to(device), beta).item() ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀') ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target') ax.set_title(f"Uncentered, β={beta}", fontsize=13, fontweight='bold') ax.set_xlabel("t (q₀ + t·(target - q₀))") ax.set_ylabel("E(q)") ax.legend(fontsize=8) ax.grid(True, alpha=0.3) # Centered ax = axes[1, col] for target, name, color in [ (torch.zeros_like(q0_cent), "→ origin", "red"), (top1_cent, f"→ memory[{top1_cent_idx}]", "blue"), ]: direction = target.to(device) - q0_cent.to(device) if direction.norm() < 1e-8: continue points = q0_cent.unsqueeze(0).to(device) + ts.unsqueeze(1) * direction.unsqueeze(0) E = compute_energy(points, M_cent.to(device), beta).cpu().numpy() ax.plot(ts.cpu().numpy(), E, label=name, color=color, linewidth=2) ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀') ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target') ax.set_title(f"Centered, β={beta}", fontsize=13, fontweight='bold') ax.set_xlabel("t (q̃₀ + t·(target - q̃₀))") ax.set_ylabel("E(q)") ax.legend(fontsize=8) ax.grid(True, alpha=0.3) fig.suptitle("Fig 2: 1D Energy Profile Along Key Directions", fontsize=15, fontweight='bold') fig.tight_layout() fig.savefig(outdir / "fig2_profile.png", dpi=150, bbox_inches='tight') plt.close(fig) print(f"Saved {outdir / 'fig2_profile.png'}") # ── Figure 3: UMAP + Trajectories ─────────────────────────────────── def fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): """UMAP of memories + query trajectories.""" try: import umap except ImportError: print("umap-learn not installed, skipping fig3") return centroid = M_raw.mean(dim=1) q0_cent = q0_raw - mu fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12), squeeze=False) for col, beta in enumerate(betas_plot): for row, (label, M, q0) in enumerate([ ("Uncentered", M_raw, q0_raw), ("Centered", M_cent, q0_cent), ]): ax = axes[row, col] # Trajectory traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15) traj_cpu = traj.cpu() # Combine memories + trajectory for UMAP mem_cpu = M.T.cpu() # (N, d) all_points = torch.cat([mem_cpu, traj_cpu], dim=0).numpy() # Fit UMAP reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42) embedding = reducer.fit_transform(all_points) n_mem = mem_cpu.shape[0] mem_emb = embedding[:n_mem] traj_emb = embedding[n_mem:] # Energy for color E_mem = compute_energy(mem_cpu.to(device), M.to(device), beta).cpu().numpy() # Plot memories colored by energy sc = ax.scatter(mem_emb[:, 0], mem_emb[:, 1], c=E_mem, cmap='viridis', s=10, alpha=0.5, zorder=1) # Plot trajectory ax.plot(traj_emb[:, 0], traj_emb[:, 1], 'o-', color='#ff6600', markersize=5, linewidth=2, zorder=3, label='trajectory') ax.scatter(traj_emb[0, 0], traj_emb[0, 1], c='lime', s=100, marker='s', zorder=4, label='q₀') ax.scatter(traj_emb[-1, 0], traj_emb[-1, 1], c='magenta', s=100, marker='D', zorder=4, label=f'q_T') ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold') ax.legend(fontsize=8, loc='upper right') plt.colorbar(sc, ax=ax, shrink=0.8, label='E(q)') fig.suptitle("Fig 3: UMAP of Memories + Hopfield Trajectory (color = energy)", fontsize=15, fontweight='bold') fig.tight_layout() fig.savefig(outdir / "fig3_umap.png", dpi=150, bbox_inches='tight') plt.close(fig) print(f"Saved {outdir / 'fig3_umap.png'}") # ── Figure 4: PCA Top-2 Energy Heatmap ────────────────────────────── def fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): """Energy heatmap on PCA top-2 components of memory bank.""" centroid = M_raw.mean(dim=1) q0_cent = q0_raw - mu fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12), squeeze=False) for row, (label, M, q0, ref_point, ref_label) in enumerate([ ("Uncentered", M_raw, q0_raw, centroid, "centroid"), ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"), ]): # PCA on this memory bank M_cpu = M.cpu() # (d, N) # SVD of M to get top-2 directions U, S, Vh = torch.linalg.svd(M_cpu, full_matrices=False) pc1 = U[:, 0].to(device) # (d,) pc2 = U[:, 1].to(device) # (d,) for col, beta in enumerate(betas_plot): ax = axes[row, col] # Grid in PCA space grid_range = 1.5 n_grid = 150 xs = torch.linspace(-grid_range, grid_range, n_grid, device=device) ys = torch.linspace(-grid_range, grid_range, n_grid, device=device) xx, yy = torch.meshgrid(xs, ys, indexing='ij') grid_points = xx.reshape(-1, 1) * pc1.unsqueeze(0) + yy.reshape(-1, 1) * pc2.unsqueeze(0) E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy() # Trajectory traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15) traj_2d = project_to_plane(traj, pc1, pc2).cpu().numpy() # Memories projected mem_2d = project_to_plane(M.T.to(device), pc1, pc2).cpu().numpy() # Reference point ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), pc1, pc2).cpu().numpy() # Plot xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy() E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95)) cs = ax.pcolormesh(xx_np, yy_np, E_clip, cmap='viridis', shading='auto') ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5) ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2) if ref_point.norm() > 1e-6: ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*', zorder=5, label=ref_label) else: ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin') ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4, linewidth=2, zorder=4) ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s', zorder=5, label='q₀') ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D', zorder=5, label='q_T') ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold') ax.set_xlabel("PC1") ax.set_ylabel("PC2") ax.legend(fontsize=7, loc='upper right') plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)') fig.suptitle("Fig 4: PCA Top-2 Energy Heatmap + Trajectory", fontsize=15, fontweight='bold') fig.tight_layout() fig.savefig(outdir / "fig4_pca.png", dpi=150, bbox_inches='tight') plt.close(fig) print(f"Saved {outdir / 'fig4_pca.png'}") # ── Main ───────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser() parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--questions", type=str, required=True) parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--query-idx", type=int, default=0) parser.add_argument("--outdir", type=str, default="figures") args = parser.parse_args() device = args.device outdir = Path(args.outdir) outdir.mkdir(parents=True, exist_ok=True) # Load memory bank mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False)) mb.load(args.memory_bank, device=device) M_raw = mb.embeddings # (d, N) d, N = M_raw.shape print(f"Memory bank: d={d}, N={N}") # Center mu = M_raw.mean(dim=1) # (d,) M_cent = M_raw - mu.unsqueeze(1) print(f"‖μ‖ = {mu.norm():.4f}") # Load one query with open(args.questions) as f: questions = [json.loads(line) for line in f] q_text = questions[args.query_idx]["question"] print(f"Query [{args.query_idx}]: '{q_text}'") encoder = Encoder(EncoderConfig(model_name="facebook/contriever-msmarco"), device=device) q0_raw = encoder.encode([q_text]).squeeze(0) # (d,) print(f"‖q0_raw‖ = {q0_raw.norm():.4f}") # β values: below and above β_critical ≈ 37.6 betas_plot = [5.0, 20.0, 50.0, 100.0] print("\n--- Generating Figure 1: 2D Contour ---") fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) print("\n--- Generating Figure 2: 1D Profile ---") fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) print("\n--- Generating Figure 3: UMAP ---") fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) print("\n--- Generating Figure 4: PCA Heatmap ---") fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) print(f"\nAll figures saved to {outdir}/") if __name__ == "__main__": main()