summaryrefslogtreecommitdiff
path: root/scripts/visualize_energy.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
committerYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
commit09d50e47860da0035e178a442dc936028808a0b3 (patch)
tree9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/visualize_energy.py
parentc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff)
Add memory centering, grid search experiments, and energy visualizations (HEAD, master)
- Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove centroid attractor in Hopfield dynamics
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and root cause analysis of centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; best configs beat FAISS baseline by +3-4% F1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/visualize_energy.py')
-rw-r--r--scripts/visualize_energy.py443
1 files changed, 443 insertions, 0 deletions
diff --git a/scripts/visualize_energy.py b/scripts/visualize_energy.py
new file mode 100644
index 0000000..f39953c
--- /dev/null
+++ b/scripts/visualize_energy.py
@@ -0,0 +1,443 @@
+"""Visualize Hopfield energy landscape: centered vs uncentered.
+
+Produces 4 figures, each with centered/uncentered side-by-side:
+ 1. 2D contour + Hopfield trajectory
+ 2. 1D energy profile along key directions
+ 3. UMAP of memories + query trajectories
+ 4. PCA top-2 energy heatmap
+
+Usage:
+ CUDA_VISIBLE_DEVICES=1 python -u scripts/visualize_energy.py \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda --query-idx 0
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import matplotlib.colors as mcolors
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+sys.path.insert(0, "/home/yurenh2/HAG")
+
+from hag.config import MemoryBankConfig, EncoderConfig
+from hag.memory_bank import MemoryBank
+from hag.encoder import Encoder
+
+
+# ── Helpers ──────────────────────────────────────────────────────────
+
def compute_energy(q: torch.Tensor, M: torch.Tensor, beta: float) -> torch.Tensor:
    """Evaluate the Hopfield energy at one or more query points.

    E(q) = 1/2 · ‖q‖² - (1/β) · logsumexp(β · qᵀM)

    Args:
        q: query point(s), shape (..., d)
        M: memory matrix, shape (d, N)
        beta: inverse temperature (must be non-zero, it is divided by)
    Returns:
        energy, one value per query point, shape (...)
    """
    similarities = q @ M  # (..., N)
    soft_max_term = torch.logsumexp(beta * similarities, dim=-1)  # (...)
    quadratic_term = 0.5 * (q ** 2).sum(dim=-1)  # (...)
    return -1.0 / beta * soft_max_term + quadratic_term
+
+
def hopfield_trajectory(q0: torch.Tensor, M: torch.Tensor, beta: float,
                        max_iter: int = 15, tol: float = 1e-8) -> torch.Tensor:
    """Iterate the Hopfield update q ← M · softmax(β · Mᵀq) and record every state.

    Args:
        q0: starting query, shape (d,) or (1, d)
        M: memory matrix, shape (d, N)
        beta: inverse temperature
        max_iter: maximum number of update steps
        tol: stop early once an update moves the state by less than this
            L2 norm (the step that triggered the stop is still recorded)
    Returns:
        Stacked states, starting with q0. For a (d,) or (1, d) start this is
        (T+1, d), where T ≤ max_iter is the number of updates performed.
    """
    state = q0.clone()
    if state.dim() == 1:
        state = state.unsqueeze(0)  # (1, d) so the matmuls below are batched

    steps = [state.squeeze(0).clone()]
    for _ in range(max_iter):
        weights = torch.softmax(beta * (state @ M), dim=-1)  # (1, N)
        updated = weights @ M.T  # (1, d)
        steps.append(updated.squeeze(0).clone())
        if (updated - state).norm() < tol:
            break
        state = updated
    return torch.stack(steps, dim=0)
+
+
def orthonormalize(v1: torch.Tensor, v2: torch.Tensor):
    """Return two orthonormal vectors (e1, e2) spanning the plane of v1 and v2.

    e1 is v1 normalised; e2 is the normalised component of v2 orthogonal to
    e1. When v2 is zero or (numerically) parallel to v1, a random direction
    orthogonal to e1 is used for e2 instead.

    Args:
        v1: first direction, shape (d,); must be non-zero
        v2: second direction, shape (d,)
    Returns:
        (e1, e2), each of shape (d,)
    Raises:
        ValueError: if v1 is the zero vector, or if no direction orthogonal
            to v1 could be found (e.g. d == 1).
    """
    norm1 = v1.norm()
    if norm1 == 0:
        # Dividing by a zero norm would silently fill the result with NaN.
        raise ValueError("orthonormalize: v1 must be a non-zero vector")
    e1 = v1 / norm1

    def reject(v: torch.Tensor) -> torch.Tensor:
        """Remove the component of v along e1."""
        return v - (v @ e1) * e1

    v2_orth = reject(v2)
    residual = v2_orth.norm()
    # The relative test catches parallel inputs whose rounding residue is
    # larger than the absolute 1e-8 floor (typical for float32 vectors).
    if residual < 1e-8 or residual <= 1e-6 * v2.norm():
        for _ in range(8):
            rand = torch.randn_like(v1)
            v2_orth = reject(rand)
            if v2_orth.norm() > 1e-6 * rand.norm():
                break
        else:
            raise ValueError("orthonormalize: no direction orthogonal to v1 found")
    e2 = v2_orth / v2_orth.norm()
    return e1, e2
+
+
def project_to_plane(points: torch.Tensor, e1: torch.Tensor, e2: torch.Tensor):
    """Express (K, d) points in the 2D coordinates of the plane spanned by e1, e2.

    Each output row holds the dot products of a point with e1 and with e2,
    so the result has shape (K, 2).
    """
    coord_e1 = points @ e1
    coord_e2 = points @ e2
    return torch.stack((coord_e1, coord_e2), dim=-1)
+
+
+# ── Figure 1: 2D Contour + Trajectory ───────────────────────────────
+
def fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """Plot the energy on a 2D slice through the query, with Hopfield trajectories.

    The figure has one row per memory variant (uncentered on top, centered
    below) and one column per β. In each row the slice plane is spanned by
    the query direction (e1) and a second direction: the memory centroid for
    the uncentered row, and the first memory column for the centered row
    (whose reference point, the origin, defines no direction).

    Args:
        q0_raw: raw query embedding, shape (d,)
        M_raw: uncentered memory matrix, shape (d, N)
        M_cent: centered memory matrix, shape (d, N)
        mu: vector subtracted from q0_raw to obtain the centered query, shape (d,)
        betas_plot: sequence of β values, one subplot column each
        outdir: output directory (joined with ``/``); receives fig1_contour.png
        device: torch device on which the energies are evaluated
    """
    # Transfer once so every operation below runs on a single device.
    q0_raw = q0_raw.to(device)
    M_raw = M_raw.to(device)
    M_cent = M_cent.to(device)
    mu = mu.to(device)

    centroid = M_raw.mean(dim=1)  # (d,)
    q0_cent = q0_raw - mu

    # The sampling grid in plane coordinates is shared by all subplots.
    grid_range = 1.5
    n_grid = 150
    xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
    ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
    xx, yy = torch.meshgrid(xs, ys, indexing='ij')
    xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    variants = [
        ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
        ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
    ]
    for row, (label, M, q0, ref_point, ref_label) in enumerate(variants):
        # Plane, lifted grid and projections do not depend on β: build them
        # once per row instead of once per subplot.
        has_direction = bool(ref_point.norm() > 1e-6)
        second_dir = ref_point if has_direction else M[:, 0]
        e1, e2 = orthonormalize(q0, second_dir)

        grid_points = (xx.reshape(-1, 1) * e1.unsqueeze(0)
                       + yy.reshape(-1, 1) * e2.unsqueeze(0))  # (n^2, d)
        mem_2d = project_to_plane(M.T, e1, e2).cpu().numpy()
        # A zero reference point projects to (0, 0), so the origin marker
        # needs no special case.
        ref_2d = project_to_plane(ref_point.unsqueeze(0), e1, e2).cpu().numpy()

        for col, beta in enumerate(betas_plot):
            ax = axes[row, col]

            E = compute_energy(grid_points, M, beta).reshape(n_grid, n_grid).cpu().numpy()
            traj = hopfield_trajectory(q0, M, beta, max_iter=15)
            traj_2d = project_to_plane(traj, e1, e2).cpu().numpy()

            # Clip energy for better visualization
            E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95))
            cs = ax.contourf(xx_np, yy_np, E_clip, levels=40, cmap='viridis')
            ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white',
                       linewidths=0.3, alpha=0.5)

            # Memories (small dots)
            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            # Reference point (centroid or origin)
            ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                       zorder=5, label=ref_label)

            # Trajectory with start and end markers
            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                    linewidth=2, zorder=4, label='trajectory')
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                       zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                       zorder=5, label=f'q_T (T={len(traj_2d)-1})')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("e₁ (query dir)")
            ax.set_ylabel("e₂")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 1: 2D Energy Contour + Hopfield Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig1_contour.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig1_contour.png'}")
+
+
+# ── Figure 2: 1D Energy Profile ─────────────────────────────────────
+
def fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """Plot the energy along straight lines from the query to key targets.

    Each line is q(t) = q₀ + t·(target - q₀) for t in [-0.5, 2.0], so t=0 is
    the query and t=1 the target. The top row uses the uncentered memory
    (targets: centroid, best-scoring memory, origin), the bottom row the
    centered memory (targets: origin, best-scoring memory). One column per β.
    Targets that coincide with the query are skipped.

    Args:
        q0_raw: raw query embedding, shape (d,)
        M_raw: uncentered memory matrix, shape (d, N)
        M_cent: centered memory matrix, shape (d, N)
        mu: vector subtracted from q0_raw to obtain the centered query, shape (d,)
        betas_plot: sequence of β values, one subplot column each
        outdir: output directory (joined with ``/``); receives fig2_profile.png
        device: torch device on which the energies are evaluated
    """
    # Transfer once so queries, targets and memories share a device.
    q0_raw = q0_raw.to(device)
    M_raw = M_raw.to(device)
    M_cent = M_cent.to(device)
    mu = mu.to(device)

    centroid = M_raw.mean(dim=1)
    q0_cent = q0_raw - mu

    # Best-scoring (largest dot product) memory for each variant.
    top1_raw_idx = (q0_raw @ M_raw).argmax().item()
    top1_raw = M_raw[:, top1_raw_idx]
    top1_cent_idx = (q0_cent @ M_cent).argmax().item()
    top1_cent = M_cent[:, top1_cent_idx]

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 10),
                             squeeze=False)

    ts = torch.linspace(-0.5, 2.0, 300, device=device)
    ts_np = ts.cpu().numpy()

    panels = [
        ("Uncentered", M_raw, q0_raw, "t (q₀ + t·(target - q₀))", [
            (centroid, "→ centroid", "red"),
            (top1_raw, f"→ memory[{top1_raw_idx}]", "blue"),
            (torch.zeros_like(q0_raw), "→ origin", "gray"),
        ]),
        ("Centered", M_cent, q0_cent, "t (q̃₀ + t·(target - q̃₀))", [
            (torch.zeros_like(q0_cent), "→ origin", "red"),
            (top1_cent, f"→ memory[{top1_cent_idx}]", "blue"),
        ]),
    ]

    for row, (label, M, q0, xlabel, targets) in enumerate(panels):
        # The sample points on each line do not depend on β.
        lines = []
        for target, name, color in targets:
            direction = target - q0
            if direction.norm() < 1e-8:
                continue  # target coincides with the query: no line to draw
            points = q0.unsqueeze(0) + ts.unsqueeze(1) * direction.unsqueeze(0)
            lines.append((points, name, color))

        for col, beta in enumerate(betas_plot):
            ax = axes[row, col]
            for points, name, color in lines:
                E = compute_energy(points, M, beta).cpu().numpy()
                ax.plot(ts_np, E, label=name, color=color, linewidth=2)

            # Mark t=0 (query) and t=1 (target)
            ax.axvline(0, color='lime', linestyle='--', alpha=0.5, label='q₀')
            ax.axvline(1, color='black', linestyle=':', alpha=0.5, label='target')
            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel(xlabel)
            ax.set_ylabel("E(q)")
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)

    fig.suptitle("Fig 2: 1D Energy Profile Along Key Directions", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig2_profile.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig2_profile.png'}")
+
+
+# ── Figure 3: UMAP + Trajectories ───────────────────────────────────
+
def fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """Plot a UMAP embedding of the memories together with the Hopfield trajectory.

    One row per memory variant (uncentered / centered), one column per β.
    For every subplot UMAP is fitted on the memories and the trajectory
    states jointly, so the 2D layout differs between subplots. Memories are
    coloured by their energy under the subplot's β. The figure is skipped
    with a message when ``umap`` cannot be imported.

    Args:
        q0_raw: raw query embedding, shape (d,)
        M_raw: uncentered memory matrix, shape (d, N)
        M_cent: centered memory matrix, shape (d, N)
        mu: vector subtracted from q0_raw to obtain the centered query, shape (d,)
        betas_plot: sequence of β values, one subplot column each
        outdir: output directory (joined with ``/``); receives fig3_umap.png
        device: torch device on which trajectories and energies are computed
    """
    try:
        import umap
    except ImportError:
        print("umap-learn not installed, skipping fig3")
        return

    # Transfer once so every operation below runs on a single device.
    q0_raw = q0_raw.to(device)
    M_raw = M_raw.to(device)
    M_cent = M_cent.to(device)
    mu = mu.to(device)

    q0_cent = q0_raw - mu

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    variants = [
        ("Uncentered", M_raw, q0_raw),
        ("Centered", M_cent, q0_cent),
    ]
    for row, (label, M, q0) in enumerate(variants):
        # The CPU copy of the memories is the same for every β in this row.
        mem_cpu = M.T.cpu()  # (N, d)
        n_mem = mem_cpu.shape[0]

        for col, beta in enumerate(betas_plot):
            ax = axes[row, col]

            # Trajectory
            traj = hopfield_trajectory(q0, M, beta, max_iter=15)

            # Memories and trajectory states are embedded together so they
            # share one 2D layout.
            all_points = torch.cat([mem_cpu, traj.cpu()], dim=0).numpy()
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
            embedding = reducer.fit_transform(all_points)
            mem_emb = embedding[:n_mem]
            traj_emb = embedding[n_mem:]

            # Energy of each memory, used as the scatter colour
            E_mem = compute_energy(M.T, M, beta).cpu().numpy()

            sc = ax.scatter(mem_emb[:, 0], mem_emb[:, 1], c=E_mem, cmap='viridis',
                            s=10, alpha=0.5, zorder=1)

            # Trajectory with start and end markers
            ax.plot(traj_emb[:, 0], traj_emb[:, 1], 'o-', color='#ff6600',
                    markersize=5, linewidth=2, zorder=3, label='trajectory')
            ax.scatter(traj_emb[0, 0], traj_emb[0, 1], c='lime', s=100,
                       marker='s', zorder=4, label='q₀')
            ax.scatter(traj_emb[-1, 0], traj_emb[-1, 1], c='magenta', s=100,
                       marker='D', zorder=4, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.legend(fontsize=8, loc='upper right')
            plt.colorbar(sc, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 3: UMAP of Memories + Hopfield Trajectory (color = energy)",
                 fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig3_umap.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig3_umap.png'}")
+
+
+# ── Figure 4: PCA Top-2 Energy Heatmap ──────────────────────────────
+
def fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device):
    """Plot the energy on the plane of the memory bank's top-2 singular directions.

    One row per memory variant (uncentered / centered), one column per β.
    The two axes are the leading left singular vectors of that row's memory
    matrix. NOTE: the SVD is taken of the matrix as given, with no mean
    subtraction, so for the uncentered row this is not a mean-centred PCA.

    Args:
        q0_raw: raw query embedding, shape (d,)
        M_raw: uncentered memory matrix, shape (d, N); needs min(d, N) >= 2
        M_cent: centered memory matrix, shape (d, N); needs min(d, N) >= 2
        mu: vector subtracted from q0_raw to obtain the centered query, shape (d,)
        betas_plot: sequence of β values, one subplot column each
        outdir: output directory (joined with ``/``); receives fig4_pca.png
        device: torch device on which the energies are evaluated
    """
    # Transfer once so every operation below runs on a single device.
    q0_raw = q0_raw.to(device)
    M_raw = M_raw.to(device)
    M_cent = M_cent.to(device)
    mu = mu.to(device)

    centroid = M_raw.mean(dim=1)
    q0_cent = q0_raw - mu

    # The sampling grid in plane coordinates is shared by all subplots.
    grid_range = 1.5
    n_grid = 150
    xs = torch.linspace(-grid_range, grid_range, n_grid, device=device)
    ys = torch.linspace(-grid_range, grid_range, n_grid, device=device)
    xx, yy = torch.meshgrid(xs, ys, indexing='ij')
    xx_np, yy_np = xx.cpu().numpy(), yy.cpu().numpy()

    fig, axes = plt.subplots(2, len(betas_plot), figsize=(6 * len(betas_plot), 12),
                             squeeze=False)

    variants = [
        ("Uncentered", M_raw, q0_raw, centroid, "centroid"),
        ("Centered", M_cent, q0_cent, torch.zeros_like(centroid), "origin"),
    ]
    for row, (label, M, q0, ref_point, ref_label) in enumerate(variants):
        # Top-2 left singular vectors of this memory bank (SVD run on CPU).
        U, _, _ = torch.linalg.svd(M.cpu(), full_matrices=False)
        pc1 = U[:, 0].to(device)  # (d,)
        pc2 = U[:, 1].to(device)  # (d,)

        # Lifted grid and projections do not depend on β: build once per row.
        grid_points = (xx.reshape(-1, 1) * pc1.unsqueeze(0)
                       + yy.reshape(-1, 1) * pc2.unsqueeze(0))
        mem_2d = project_to_plane(M.T, pc1, pc2).cpu().numpy()
        # A zero reference point projects to (0, 0), so the origin marker
        # needs no special case.
        ref_2d = project_to_plane(ref_point.unsqueeze(0), pc1, pc2).cpu().numpy()

        for col, beta in enumerate(betas_plot):
            ax = axes[row, col]

            E = compute_energy(grid_points, M, beta).reshape(n_grid, n_grid).cpu().numpy()
            traj = hopfield_trajectory(q0, M, beta, max_iter=15)
            traj_2d = project_to_plane(traj, pc1, pc2).cpu().numpy()

            # Clip energy for better visualization
            E_clip = np.clip(E, np.percentile(E, 1), np.percentile(E, 95))
            cs = ax.pcolormesh(xx_np, yy_np, E_clip, cmap='viridis', shading='auto')
            ax.contour(xx_np, yy_np, E_clip, levels=15, colors='white',
                       linewidths=0.3, alpha=0.5)

            ax.scatter(mem_2d[:, 0], mem_2d[:, 1], c='white', s=3, alpha=0.3, zorder=2)

            ax.scatter(ref_2d[:, 0], ref_2d[:, 1], c='red', s=100, marker='*',
                       zorder=5, label=ref_label)

            # Labelled like the trajectory in fig1_contour so both legends agree.
            ax.plot(traj_2d[:, 0], traj_2d[:, 1], 'o-', color='#ff6600', markersize=4,
                    linewidth=2, zorder=4, label='trajectory')
            ax.scatter(traj_2d[0, 0], traj_2d[0, 1], c='lime', s=80, marker='s',
                       zorder=5, label='q₀')
            ax.scatter(traj_2d[-1, 0], traj_2d[-1, 1], c='magenta', s=80, marker='D',
                       zorder=5, label='q_T')

            ax.set_title(f"{label}, β={beta}", fontsize=13, fontweight='bold')
            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ax.legend(fontsize=7, loc='upper right')
            plt.colorbar(cs, ax=ax, shrink=0.8, label='E(q)')

    fig.suptitle("Fig 4: PCA Top-2 Energy Heatmap + Trajectory", fontsize=15, fontweight='bold')
    fig.tight_layout()
    fig.savefig(outdir / "fig4_pca.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {outdir / 'fig4_pca.png'}")
+
+
+# ── Main ─────────────────────────────────────────────────────────────
+
def main():
    """Command-line entry point: load data, encode one query, render all four figures.

    Reads the memory bank and the JSONL question file given on the command
    line, encodes the question selected by --query-idx, and writes
    fig1..fig4 PNG files into --outdir (created if missing).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--query-idx", type=int, default=0)
    parser.add_argument("--outdir", type=str, default="figures")
    args = parser.parse_args()

    device = args.device
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Load memory bank
    mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False))
    mb.load(args.memory_bank, device=device)
    M_raw = mb.embeddings  # (d, N)
    d, N = M_raw.shape
    print(f"Memory bank: d={d}, N={N}")

    # Center the memories on their mean
    mu = M_raw.mean(dim=1)  # (d,)
    M_cent = M_raw - mu.unsqueeze(1)
    print(f"‖μ‖ = {mu.norm():.4f}")

    # Load one query. Explicit UTF-8 keeps the read independent of the
    # platform's default encoding; blank lines are not valid JSON records.
    with open(args.questions, encoding="utf-8") as f:
        questions = [json.loads(line) for line in f if line.strip()]

    # Negative indices remain valid (Python list semantics); anything else
    # out of range gets a usage error instead of a bare IndexError.
    if not -len(questions) <= args.query_idx < len(questions):
        parser.error(
            f"--query-idx {args.query_idx} is out of range for "
            f"{len(questions)} question(s) in {args.questions}"
        )

    q_text = questions[args.query_idx]["question"]
    print(f"Query [{args.query_idx}]: '{q_text}'")

    encoder = Encoder(EncoderConfig(model_name="facebook/contriever-msmarco"), device=device)
    q0_raw = encoder.encode([q_text]).squeeze(0)  # (d,)
    # NOTE(review): assumes encode() returns a torch.Tensor (it is already
    # used with .squeeze/.norm). Detach and align it with the memory bank so
    # q0_raw can be combined with mu and converted to NumPy by the figures.
    q0_raw = q0_raw.detach().to(M_raw.device)
    print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")

    # β values: below and above β_critical ≈ 37.6
    betas_plot = [5.0, 20.0, 50.0, 100.0]

    print("\n--- Generating Figure 1: 2D Contour ---")
    fig1_contour(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 2: 1D Profile ---")
    fig2_profile(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 3: UMAP ---")
    fig3_umap(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print("\n--- Generating Figure 4: PCA Heatmap ---")
    fig4_pca(q0_raw, M_raw, M_cent, mu, betas_plot, outdir, device)

    print(f"\nAll figures saved to {outdir}/")
+
+
+if __name__ == "__main__":
+ main()