diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
| commit | 09d50e47860da0035e178a442dc936028808a0b3 (patch) | |
| tree | 9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/visualize_energy.py | |
| parent | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff) | |
- Add centering support to MemoryBank (center_query, apply_centering, mean
persistence in save/load) to remove centroid attractor in Hopfield dynamics
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids
with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap
comparing centered vs uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and root cause
analysis of centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; best configs beat
FAISS baseline by +3-4% F1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/visualize_energy.py')
| -rw-r--r-- | scripts/visualize_energy.py | 443 |
1 files changed, 443 insertions, 0 deletions
diff --git a/scripts/visualize_energy.py b/scripts/visualize_energy.py new file mode 100644 index 0000000..f39953c --- /dev/null +++ b/scripts/visualize_energy.py @@ -0,0 +1,443 @@ +"""Visualize Hopfield energy landscape: centered vs uncentered. + +Produces 4 figures, each with centered/uncentered side-by-side: + 1. 2D contour + Hopfield trajectory + 2. 1D energy profile along key directions + 3. UMAP of memories + query trajectories + 4. PCA top-2 energy heatmap + +Usage: + CUDA_VISIBLE_DEVICES=1 python -u scripts/visualize_energy.py \ + --memory-bank data/processed/hotpotqa_memory_bank.pt \ + --questions data/processed/hotpotqa_questions.jsonl \ + --device cuda --query-idx 0 +""" + +import argparse +import json +import sys +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import numpy as np +import torch +import torch.nn.functional as F + +sys.path.insert(0, "/home/yurenh2/HAG") + +from hag.config import MemoryBankConfig, EncoderConfig +from hag.memory_bank import MemoryBank +from hag.encoder import Encoder + + +# ── Helpers ────────────────────────────────────────────────────────── + +def compute_energy(q: torch.Tensor, M: torch.Tensor, beta: float) -> torch.Tensor: + """E(q) = -1/β · logsumexp(β · qᵀM) + 1/2 · ‖q‖² + + Args: + q: (..., d) + M: (d, N) + beta: inverse temperature + Returns: + energy: (...) + """ + logits = beta * (q @ M) # (..., N) + lse = torch.logsumexp(logits, dim=-1) # (...) + norm_sq = 0.5 * (q ** 2).sum(dim=-1) # (...) + return -1.0 / beta * lse + norm_sq + + +def hopfield_trajectory(q0: torch.Tensor, M: torch.Tensor, beta: float, + max_iter: int = 15) -> torch.Tensor: + """Run Hopfield and return full trajectory. Returns (T+1, d).""" + q = q0.clone().unsqueeze(0) if q0.dim() == 1 else q0.clone() # (1, d) + traj = [q.squeeze(0).clone()] + for _ in range(max_iter): + logits = beta * (q @ M) + alpha = torch.softmax(logits, dim=-1) + q_new = alpha @ M.T + traj.append(q_new.squeeze(0).clone()) + if (q_new - q).norm() < 1e-8: + break + q = q_new + return torch.stack(traj, dim=0) # (T+1, d) + + +def orthonormalize(v1: torch.Tensor, v2: torch.Tensor): + """Return two orthonormal vectors spanning the plane of v1, v2.""" + e1 = v1 / v1.norm() + v2_orth = v2 - (v2 @ e1) * e1 + if v2_orth.norm() < 1e-8: + # v1 and v2 are parallel, pick a random orthogonal direction + rand = torch.randn_like(v1) + v2_orth = rand - (rand @ e1) * e1 + e2 = v2_orth / v2_orth.norm() + return e1, e2 + + +def project_to_plane(points: torch.Tensor, e1: torch.Tensor, e2: torch.Tensor): + """Project (K, d) points onto 2D plane defined by e1, e2. Returns (K, 2).""" + return torch.stack([points @ e1, points @ e2], dim=-1) + + +# ── Figure 1: 2D Contour + Trajectory ─────────────────────────────── + +def fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): + """2D energy contour on the query-centroid plane, with Hopfield trajectories.""" + + centroid = M_raw.mean(dim=1) # (d,) + q0_cent = q0_raw - mu + + fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12), + squeeze=False) + + for col, beta in enumerate(betas_plot): + for row, (label, M, q0, ref_point, ref_label) in enumerate([ + ("Uncentered", M_raw, q0_raw, centroid, "centroid"), + ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"), + ]): + ax = axes[row, col] + + # Define 2D plane: query direction + centroid/origin direction + e1, e2 = orthonormalize(q0.to(device), ref_point.to(device) if ref_point.norm() > 1e-6 else M.to(device)[:, 0]) + + # Grid + grid_range = 1.5 + n_grid = 150 + xs = torch.linspace(-grid_range, grid_range, n_grid, device=device) + ys = torch.linspace(-grid_range, grid_range, n_grid, device=device) + xx, yy = torch.meshgrid(xs, ys, indexing='ij') + grid_points = xx.reshape(-1, 1) * e1.unsqueeze(0) + yy.reshape(-1, 1) * e2.unsqueeze(0) # (n^2, d) + + E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy() + + # Trajectory + traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15) + traj_2d = project_to_plane(traj, e1, e2).cpu().numpy() + + # Project memories + mem_2d = project_to_plane(M.T.to(device), e1, e2).cpu().numpy() + + # Project reference point + ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), e1, e2).cpu().numpy() + + # Plot + xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy() + # Clip energy for better visualization + E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95)) + cs = ax.contourf(xx_np, yy_np, E_clip, levels=40, cmap='viridis') + ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5) + + # Memories (small dots) + ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2) + + # Reference point + if ref_point.norm() > 1e-6: + ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*', + zorder=5, label=ref_label) + else: + ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin') + + # Trajectory + ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4, + linewidth=2, zorder=4, label='trajectory') + ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s', + zorder=5, label='q₀') + ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D', + zorder=5, label=f'q_T (T={len(traj_2d)-1})') + + ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold') + ax.set_xlabel("e₁ (query dir)") + ax.set_ylabel("e₂") + ax.legend(fontsize=7, loc='upper right') + plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)') + + fig.suptitle("Fig 1: 2D Energy Contour + Hopfield Trajectory", fontsize=15, fontweight='bold') + fig.tight_layout() + fig.savefig(outdir / "fig1_contour.png", dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {outdir / 'fig1_contour.png'}") + + +# ── Figure 2: 1D Energy Profile ───────────────────────────────────── + +def fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): + """1D energy along key directions.""" + + centroid = M_raw.mean(dim=1) + q0_cent = q0_raw - mu + + # Find top-1 memory for each + scores_raw = q0_raw @ M_raw + top1_raw_idx = scores_raw.argmax().item() + top1_raw = M_raw[:, top1_raw_idx] + + scores_cent = q0_cent @ M_cent + top1_cent_idx = scores_cent.argmax().item() + top1_cent = M_cent[:, top1_cent_idx] + + fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 10), + squeeze=False) + + ts = torch.linspace(-0.5, 2.0, 300, device=device) + + for col, beta in enumerate(betas_plot): + # Uncentered + ax = axes[0, col] + for target, name, color in [ + (centroid, "→ centroid", "red"), + (top1_raw, f"→ memory[{top1_raw_idx}]", "blue"), + (torch.zeros_like(q0_raw), "→ origin", "gray"), + ]: + direction = target - q0_raw.to(device) + if direction.norm() < 1e-8: + continue + points = q0_raw.unsqueeze(0).to(device) + ts.unsqueeze(1) * direction.unsqueeze(0) + E = compute_energy(points, M_raw.to(device), beta).cpu().numpy() + ax.plot(ts.cpu().numpy(), E, label=name, color=color, linewidth=2) + + # Mark t=0 (query) and t=1 (target) + E_q0 = compute_energy(q0_raw.unsqueeze(0).to(device), M_raw.to(device), beta).item() + ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀') + ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target') + ax.set_title(f"Uncentered, β={beta}", fontsize=13, fontweight='bold') + ax.set_xlabel("t (q₀ + t·(target - q₀))") + ax.set_ylabel("E(q)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Centered + ax = axes[1, col] + for target, name, color in [ + (torch.zeros_like(q0_cent), "→ origin", "red"), + (top1_cent, f"→ memory[{top1_cent_idx}]", "blue"), + ]: + direction = target.to(device) - q0_cent.to(device) + if direction.norm() < 1e-8: + continue + points = q0_cent.unsqueeze(0).to(device) + ts.unsqueeze(1) * direction.unsqueeze(0) + E = compute_energy(points, M_cent.to(device), beta).cpu().numpy() + ax.plot(ts.cpu().numpy(), E, label=name, color=color, linewidth=2) + + ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀') + ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target') + ax.set_title(f"Centered, β={beta}", fontsize=13, fontweight='bold') + ax.set_xlabel("t (q̃₀ + t·(target - q̃₀))") + ax.set_ylabel("E(q)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + fig.suptitle("Fig 2: 1D Energy Profile Along Key Directions", fontsize=15, fontweight='bold') + fig.tight_layout() + fig.savefig(outdir / "fig2_profile.png", dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {outdir / 'fig2_profile.png'}") + + +# ── Figure 3: UMAP + Trajectories ─────────────────────────────────── + +def fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): + """UMAP of memories + query trajectories.""" + try: + import umap + except ImportError: + print("umap-learn not installed, skipping fig3") + return + + centroid = M_raw.mean(dim=1) + q0_cent = q0_raw - mu + + fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12), + squeeze=False) + + for col, beta in enumerate(betas_plot): + for row, (label, M, q0) in enumerate([ + ("Uncentered", M_raw, q0_raw), + ("Centered", M_cent, q0_cent), + ]): + ax = axes[row, col] + + # Trajectory + traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15) + traj_cpu = traj.cpu() + + # Combine memories + trajectory for UMAP + mem_cpu = M.T.cpu() # (N, d) + all_points = torch.cat([mem_cpu, traj_cpu], dim=0).numpy() + + # Fit UMAP + reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42) + embedding = reducer.fit_transform(all_points) + + n_mem = mem_cpu.shape[0] + mem_emb = embedding[:n_mem] + traj_emb = embedding[n_mem:] + + # Energy for color + E_mem = compute_energy(mem_cpu.to(device), M.to(device), beta).cpu().numpy() + + # Plot memories colored by energy + sc = ax.scatter(mem_emb[:, 0], mem_emb[:, 1], c=E_mem, cmap='viridis', + s=10, alpha=0.5, zorder=1) + + # Plot trajectory + ax.plot(traj_emb[:, 0], traj_emb[:, 1], 'o-', color='#ff6600', + markersize=5, linewidth=2, zorder=3, label='trajectory') + ax.scatter(traj_emb[0, 0], traj_emb[0, 1], c='lime', s=100, + marker='s', zorder=4, label='q₀') + ax.scatter(traj_emb[-1, 0], traj_emb[-1, 1], c='magenta', s=100, + marker='D', zorder=4, label=f'q_T') + + ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold') + ax.legend(fontsize=8, loc='upper right') + plt.colorbar(sc, ax=ax, shrink=0.8, label='E(q)') + + fig.suptitle("Fig 3: UMAP of Memories + Hopfield Trajectory (color = energy)", + fontsize=15, fontweight='bold') + fig.tight_layout() + fig.savefig(outdir / "fig3_umap.png", dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {outdir / 'fig3_umap.png'}") + + +# ── Figure 4: PCA Top-2 Energy Heatmap ────────────────────────────── + +def fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device): + """Energy heatmap on PCA top-2 components of memory bank.""" + + centroid = M_raw.mean(dim=1) + q0_cent = q0_raw - mu + + fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12), + squeeze=False) + + for row, (label, M, q0, ref_point, ref_label) in enumerate([ + ("Uncentered", M_raw, q0_raw, centroid, "centroid"), + ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"), + ]): + # PCA on this memory bank + M_cpu = M.cpu() # (d, N) + # SVD of M to get top-2 directions + U, S, Vh = torch.linalg.svd(M_cpu, full_matrices=False) + pc1 = U[:, 0].to(device) # (d,) + pc2 = U[:, 1].to(device) # (d,) + + for col, beta in enumerate(betas_plot): + ax = axes[row, col] + + # Grid in PCA space + grid_range = 1.5 + n_grid = 150 + xs = torch.linspace(-grid_range, grid_range, n_grid, device=device) + ys = torch.linspace(-grid_range, grid_range, n_grid, device=device) + xx, yy = torch.meshgrid(xs, ys, indexing='ij') + grid_points = xx.reshape(-1, 1) * pc1.unsqueeze(0) + yy.reshape(-1, 1) * pc2.unsqueeze(0) + + E = compute_energy(grid_points, M.to(device), beta).reshape(n_grid, n_grid).cpu().numpy() + + # Trajectory + traj = hopfield_trajectory(q0.to(device), M.to(device), beta, max_iter=15) + traj_2d = project_to_plane(traj, pc1, pc2).cpu().numpy() + + # Memories projected + mem_2d = project_to_plane(M.T.to(device), pc1, pc2).cpu().numpy() + + # Reference point + ref_2d = project_to_plane(ref_point.unsqueeze(0).to(device), pc1, pc2).cpu().numpy() + + # Plot + xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy() + E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95)) + cs = ax.pcolormesh(xx_np, yy_np, E_clip, cmap='viridis', shading='auto') + ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white', linewidths=0.3, alpha=0.5) + + ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2) + + if ref_point.norm() > 1e-6: + ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*', + zorder=5, label=ref_label) + else: + ax.scatter(0, 0, c='red', s=100, marker='*', zorder=5, label='origin') + + ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4, + linewidth=2, zorder=4) + ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s', + zorder=5, label='q₀') + ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D', + zorder=5, label='q_T') + + ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold') + ax.set_xlabel("PC1") + ax.set_ylabel("PC2") + ax.legend(fontsize=7, loc='upper right') + plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)') + + fig.suptitle("Fig 4: PCA Top-2 Energy Heatmap + Trajectory", fontsize=15, fontweight='bold') + fig.tight_layout() + fig.savefig(outdir / "fig4_pca.png", dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {outdir / 'fig4_pca.png'}") + + +# ── Main ───────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--questions", type=str, required=True) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--query-idx", type=int, default=0) + parser.add_argument("--outdir", type=str, default="figures") + args = parser.parse_args() + + device = args.device + outdir = Path(args.outdir) + outdir.mkdir(parents=True, exist_ok=True) + + # Load memory bank + mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False)) + mb.load(args.memory_bank, device=device) + M_raw = mb.embeddings # (d, N) + d, N = M_raw.shape + print(f"Memory bank: d={d}, N={N}") + + # Center + mu = M_raw.mean(dim=1) # (d,) + M_cent = M_raw - mu.unsqueeze(1) + print(f"‖μ‖ = {mu.norm():.4f}") + + # Load one query + with open(args.questions) as f: + questions = [json.loads(line) for line in f] + + q_text = questions[args.query_idx]["question"] + print(f"Query [{args.query_idx}]: '{q_text}'") + + encoder = Encoder(EncoderConfig(model_name="facebook/contriever-msmarco"), device=device) + q0_raw = encoder.encode([q_text]).squeeze(0) # (d,) + print(f"‖q0_raw‖ = {q0_raw.norm():.4f}") + + # β values: below and above β_critical ≈ 37.6 + betas_plot = [5.0, 20.0, 50.0, 100.0] + + print("\n--- Generating Figure 1: 2D Contour ---") + fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) + + print("\n--- Generating Figure 2: 1D Profile ---") + fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) + + print("\n--- Generating Figure 3: UMAP ---") + fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) + + print("\n--- Generating Figure 4: PCA Heatmap ---") + fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device) + + print(f"\nAll figures saved to {outdir}/") + + +if __name__ == "__main__": + main() |
