summaryrefslogtreecommitdiff
path: root/scripts/diagnose_centering.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/diagnose_centering.py')
-rw-r--r--scripts/diagnose_centering.py301
1 files changed, 301 insertions, 0 deletions
diff --git a/scripts/diagnose_centering.py b/scripts/diagnose_centering.py
new file mode 100644
index 0000000..5b9b4ee
--- /dev/null
+++ b/scripts/diagnose_centering.py
@@ -0,0 +1,301 @@
+"""Diagnose centering dynamics: trace q step by step, verify β_critical theory."""
+
+import sys
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+sys.path.insert(0, "/home/yurenh2/HAG")
+
+from hag.memory_bank import MemoryBank
+from hag.config import MemoryBankConfig
+
# ── Load memory bank ─────────────────────────────────────────────────
# Load the precomputed HotpotQA memory bank with embeddings L2-normalized
# but deliberately NOT centered, so centering can be done manually below.
device = "cuda:0"  # CUDA_VISIBLE_DEVICES remaps
mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False))
mb.load("/home/yurenh2/HAG/data/processed/hotpotqa_memory_bank.pt", device=device)
M_raw = mb.embeddings  # (d, N), L2-normalized, NOT centered
d, N = M_raw.shape
print(f"Memory bank: d={d}, N={N}")

# ── Center manually ──────────────────────────────────────────────────
# Subtract the column mean μ; the second print sanity-checks that the
# centered bank M̃ really has (numerically) zero mean.
mu = M_raw.mean(dim=1)  # (d,)
M_cent = M_raw - mu.unsqueeze(1)  # (d, N) centered
print(f"‖μ‖ = {mu.norm():.4f}")
print(f"‖M̃·1/N‖ = {(M_cent.mean(dim=1)).norm():.2e} (should be ~0)")

# ── Column norms ─────────────────────────────────────────────────────
# Centering shrinks per-column norms below 1 (raw columns are unit norm);
# report mean/std of both to quantify the shrinkage.
col_norms_raw = M_raw.norm(dim=0)
col_norms_cent = M_cent.norm(dim=0)
print(f"\nRaw column norms: mean={col_norms_raw.mean():.4f}, std={col_norms_raw.std():.4f}")
print(f"Centered column norms: mean={col_norms_cent.mean():.4f}, std={col_norms_cent.std():.4f}")
+
# ── SVD and β_critical ──────────────────────────────────────────────
# M̃M̃ᵀ/N is the sample covariance. β_crit = N / λ_max(M̃M̃ᵀ) = 1/λ_max(C)
# where C = M̃M̃ᵀ/N
# But let's compute via SVD of M̃
# (λ_max(M̃M̃ᵀ) = σ₁², where σ₁ is the top singular value of M̃)
print("\nComputing SVD of centered memory...")
U, S, Vh = torch.linalg.svd(M_cent, full_matrices=False)  # S shape: (min(d,N),)
print(f"Top 10 singular values: {S[:10].cpu().tolist()}")
lambda_max_MMT = S[0].item() ** 2  # largest eigenvalue of M̃M̃ᵀ
lambda_max_C = lambda_max_MMT / N  # largest eigenvalue of M̃M̃ᵀ/N

print(f"\nλ_max(M̃M̃ᵀ) = {lambda_max_MMT:.2f}")
print(f"λ_max(C=M̃M̃ᵀ/N) = {lambda_max_C:.4f}")

# Jacobian at origin: DT = β/N · M̃M̃ᵀ
# Spectral radius = β/N · λ_max(M̃M̃ᵀ) = β · λ_max(C)
# For instability: β · λ_max(C) > 1 → β > 1/λ_max(C)
beta_crit = 1.0 / lambda_max_C
print(f"\nβ_critical = 1/λ_max(C) = {beta_crit:.4f}")
print("For β > β_crit, origin is UNSTABLE (Jacobian spectral radius > 1)")

# Tabulate the spectral radius for a sweep of β values around β_crit.
for beta in [0.5, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0]:
    rho = beta * lambda_max_C
    print(f"  β={beta:6.1f}: Jacobian spectral radius = {rho:.2f} {'UNSTABLE' if rho > 1 else 'stable'}")
+
# ── Load some real queries ───────────────────────────────────────────
# NOTE: imports kept local to this diagnostic section (script style).
import json
questions_path = "/home/yurenh2/HAG/data/processed/hotpotqa_questions.jsonl"
with open(questions_path) as f:
    questions = [json.loads(line) for line in f][:5]  # only the first 5 questions

from transformers import AutoTokenizer, AutoModel
print("\nLoading encoder...")
# Contriever (MS MARCO fine-tuned) — the encoder the memory bank was
# presumably built with; TODO confirm against the bank-building script.
tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
model = AutoModel.from_pretrained("facebook/contriever-msmarco").to(device)
model.eval()  # inference only: disables dropout
+
def encode(texts):
    """Embed *texts* with Contriever-style mean pooling.

    Returns an L2-normalized (batch, d) tensor on `device`.
    """
    batch = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
    # Contriever uses mean pooling over non-padding positions.
    attn = batch["attention_mask"].unsqueeze(-1).float()
    pooled = (hidden * attn).sum(dim=1) / attn.sum(dim=1)
    return F.normalize(pooled, dim=-1)
+
q_texts = [q["question"] for q in questions]
q_embs = encode(q_texts)  # (5, d) L2-normalized query embeddings

# ── Trace dynamics step by step (centered) ───────────────────────────
# For each β and the first 3 queries, iterate the Hopfield-style update
# q ← M̃·softmax(β·M̃ᵀq) in centered space, logging query norm, attention
# entropy/peak, step size Δ, and alignment with the top singular vector v1.
print("\n" + "=" * 80)
print("CENTERED DYNAMICS TRACE (5 queries)")
print("=" * 80)

for beta in [5.0, 20.0, 50.0, 100.0]:
    print(f"\n--- β = {beta} ---")

    for qi in range(min(3, len(q_texts))):
        q0_raw = q_embs[qi:qi+1]  # (1, d)
        q0 = q0_raw - mu.unsqueeze(0)  # center the query

        q = q0.clone()
        print(f"\n  Q{qi}: '{q_texts[qi][:60]}...'")
        print(f"    t=0: ‖q‖={q.norm():.4f}")

        for t in range(8):
            logits = beta * (q @ M_cent)  # (1, N)
            alpha = torch.softmax(logits, dim=-1)  # (1, N)
            # BUG FIX: the naive -(α·log α).sum() yields NaN once softmax
            # underflows to exact zeros at large β (0 * -inf = nan).
            # torch.special.entr defines entr(0) = 0, keeping H(α) finite.
            entropy = torch.special.entr(alpha).sum().item()
            max_alpha = alpha.max().item()
            q_new = alpha @ M_cent.T  # (1, d)

            # How close is q_new to the dominant eigenvector?
            cos_v1 = abs((q_new / (q_new.norm() + 1e-12)) @ U[:, 0]).item()

            # FAISS-equivalent: initial attention from raw query on raw memory
            if t == 0:
                logits_raw = beta * (q0_raw @ M_raw)
                alpha_raw = torch.softmax(logits_raw, dim=-1)
                _, top5_raw = alpha_raw.topk(5)

                # Centered attention top-5
                _, top5_cent = alpha.topk(5)
                overlap = len(set(top5_raw[0].cpu().tolist()) & set(top5_cent[0].cpu().tolist()))
                print(f"    initial overlap(raw top5, centered top5) = {overlap}/5")

            delta = (q_new - q).norm().item()
            q = q_new
            print(f"    t={t+1}: ‖q‖={q.norm():.6f}, H(α)={entropy:.2f}/{np.log(N):.2f}, "
                  f"max(α)={max_alpha:.2e}, Δ={delta:.2e}, cos(q,v1)={cos_v1:.4f}")

            if delta < 1e-8:
                print(f"    [converged]")
                break

        # Final top-5 from centered attention
        logits_final = beta * (q @ M_cent)
        alpha_final = torch.softmax(logits_final, dim=-1)
        _, top5_final = alpha_final.topk(5)

        # Compare to "iter=0" (just softmax on centered, no iteration)
        logits_iter0 = beta * (q0 @ M_cent)
        alpha_iter0 = torch.softmax(logits_iter0, dim=-1)
        _, top5_iter0 = alpha_iter0.topk(5)

        # And raw (FAISS-like)
        logits_faiss = beta * (q0_raw @ M_raw)
        alpha_faiss = torch.softmax(logits_faiss, dim=-1)
        _, top5_faiss = alpha_faiss.topk(5)

        print(f"    Top-5 FAISS: {top5_faiss[0].cpu().tolist()}")
        print(f"    Top-5 cent t=0: {top5_iter0[0].cpu().tolist()}")
        print(f"    Top-5 cent final: {top5_final[0].cpu().tolist()}")
+
# ── KEY TEST: Does the Jacobian actually amplify near origin? ────────
print("\n" + "=" * 80)
print("JACOBIAN TEST: Start near origin, see if dynamics amplify")
print("=" * 80)

for beta in [5.0, 50.0, 100.0]:
    # Seed the dynamics with a tiny vector along the dominant singular
    # direction v1, where linear amplification should be strongest.
    eps = 1e-4
    seed = eps * U[:, 0].unsqueeze(0)  # (1, d)

    print(f"\n--- β = {beta}, q0 = {eps}·v1 ---")
    state = seed.clone()
    for t in range(10):
        attn = torch.softmax(beta * (state @ M_cent), dim=-1)
        successor = attn @ M_cent.T
        # Per-step gain ‖q_{t+1}‖/‖q_t‖; near the origin this should match
        # the Jacobian spectral radius β·λ_max(C).
        gain = successor.norm().item() / (state.norm().item() + 1e-20)
        print(f"  t={t}: ‖q‖={state.norm():.6e} → ‖q_new‖={successor.norm():.6e}, "
              f"amplification={gain:.2f}")
        state = successor
        if state.norm().item() > 1.0:
            print(f"  [escaped origin at t={t}]")
            break
+
# ── Alternative: What about removing the ‖q‖² term from energy? ─────
# The standard update q_{t+1} = M·softmax(β·Mᵀq) minimizes
# E(q) = -1/β·lse(β·Mᵀq) + 1/2·‖q‖²
# What if we don't want the ‖q‖² penalty? Then the fixed point equation
# is just q* = M·softmax(β·Mᵀq*), same update but different energy landscape.
# The issue is: with centering, M̃·uniform = 0 regardless of energy.
# The ‖q‖² penalty is NOT the problem for centering — the averaging is.

print("\n" + "=" * 80)
print("DIAGNOSIS: Why centering fails for iteration")
print("=" * 80)

# For β=50, show the attention distribution at t=0 (before any iteration)
beta = 50.0
q0_raw = q_embs[0:1]
q0_cent = q0_raw - mu.unsqueeze(0)
logits_cent = beta * (q0_cent @ M_cent)  # (1, N)
alpha_cent = torch.softmax(logits_cent, dim=-1)
# BUG FIX: the naive -(α·log α).sum() is NaN when softmax underflows to
# exact zeros at large β; torch.special.entr defines entr(0) = 0.
entropy_cent = torch.special.entr(alpha_cent).sum().item()
max_alpha_cent = alpha_cent.max().item()

logits_raw = beta * (q0_raw @ M_raw)
alpha_raw = torch.softmax(logits_raw, dim=-1)
entropy_raw = torch.special.entr(alpha_raw).sum().item()

print(f"\nβ={beta}, Q0: '{q_texts[0][:60]}...'")
print(f"Raw attention: entropy={entropy_raw:.2f}, max={alpha_raw.max():.4f}")
print(f"Cent attention: entropy={entropy_cent:.2f}, max={max_alpha_cent:.4f}")
print(f"‖q0_cent‖ = {q0_cent.norm():.4f}")
print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")

# Show: what's the actual q1 norm vs predicted from Jacobian?
q1_cent = alpha_cent @ M_cent.T  # (1, d)
# Linearized prediction ‖q1‖ ≈ β·λ_max(C)·‖q0‖ (β/N·λ_max(M̃M̃ᵀ) ≡ β·λ_max(C)).
predicted_norm = (beta / N * lambda_max_MMT) * q0_cent.norm().item()  # rough bound
print(f"\n‖q1_cent‖ actual = {q1_cent.norm():.6f}")
# BUG FIX: `predicted_norm` was computed but never used; the old print
# recomputed the identical expression inline.
print(f"‖q0_cent‖ × Jacobian_spectral_radius ≈ {q0_cent.norm():.4f} × {beta*lambda_max_C:.2f} = {predicted_norm:.4f}")
print(f"But the linearization only holds for q→0. q0 is NOT near zero.")

# The real issue: softmax(β·M̃ᵀ·q0) when q0 has ‖q‖=0.5
# The logits have some spread, but the weighted average of CENTERED vectors
# inherently cancels out.
weighted_avg = (alpha_cent @ M_cent.T)  # (1, d)
unweighted_avg = M_cent.mean(dim=1)  # (d,)
print(f"\n‖weighted_avg‖ = {weighted_avg.norm():.6f}")
print(f"‖unweighted_avg‖ = {unweighted_avg.norm():.2e}")

# How concentrated is the attention?
top50_vals, top50_idx = alpha_cent.topk(50)
mass_top50 = top50_vals.sum().item()
print(f"Mass in top-50 memories: {mass_top50:.4f}")
print(f"Mass in top-5 memories: {alpha_cent.topk(5)[0].sum().item():.4f}")

# The weighted average of centered vectors is small because:
# 1. Centered vectors m̃_i have ‖m̃_i‖ ≈ 0.78 (smaller than raw ‖m_i‖=1)
# 2. Centered vectors point in diverse directions (they have mean removed)
# 3. Even with non-uniform weights, the cancellation is severe unless
#    attention is extremely peaked on a few memories
# So the output ‖q1‖ << ‖q0‖, even though β is large

# Key quantification: what fraction of ‖q0‖ is preserved?
preserve_ratio = q1_cent.norm().item() / q0_cent.norm().item()
print(f"\n‖q1‖/‖q0‖ = {preserve_ratio:.4f} (fraction of query norm preserved)")
print("This ratio << 1 means the averaging contracts the query toward 0.")
print("For centering to work with iteration, this ratio must be > 1.")
+
+print("\n" + "=" * 80)
+print("SOLUTION ANALYSIS")
+print("=" * 80)
+print("""
+The centering fix removes the centroid attractor: M̃·uniform = 0, not μ.
+But the fundamental problem remains: ANY weighted average of centered vectors
+is much shorter than the input query, because centered vectors cancel.
+
+For the origin to be unstable, β must exceed β_critical so that the Jacobian
+amplifies perturbations near zero. But the dynamics from a realistic starting
+point (‖q‖≈0.5) don't behave like the linearization predicts.
+
+The actual contraction ratio ‖q1‖/‖q0‖ is what matters, not the Jacobian
+at origin. This ratio is small because softmax isn't peaked enough.
+
+Possible fixes:
+1. MUCH higher β (β > 500?) to make attention ultra-peaked → less cancellation
+2. Residual connection with centering: q_{t+1} = λ·q_t + (1-λ)·M̃·softmax(...)
+ This explicitly preserves query norm while still benefiting from centering.
+3. Normalize q_{t+1} after each step to prevent norm collapse.
+4. Use centering only for the attention computation, not for the update target:
+ α = softmax(β · M̃ᵀ · q̃) but q_{t+1} = M_raw · α (update in original space)
+""")
+
# Test option 4: centered attention, raw update
# Attention weights come from the centered space (no centroid attractor),
# but the query update averages the RAW unit-norm memories, so the norm
# does not collapse.
print("\n" + "=" * 80)
print("TEST: Centered attention + raw update (hybrid)")
print("=" * 80)

for beta in [5.0, 20.0, 50.0]:
    print(f"\n--- β = {beta} ---")
    for qi in range(min(3, len(q_texts))):
        q_raw = q_embs[qi:qi+1].clone()  # (1, d) raw query
        print(f"  Q{qi}: '{q_texts[qi][:50]}...'")

        for t in range(5):
            q_cent = q_raw - mu.unsqueeze(0)  # center query
            logits = beta * (q_cent @ M_cent)  # attention on centered space
            alpha = torch.softmax(logits, dim=-1)
            q_new = alpha @ M_raw.T  # update in RAW space

            # BUG FIX: naive -(α·log α).sum() is NaN once softmax underflows
            # to exact zeros at large β; torch.special.entr sets entr(0) = 0.
            entropy = torch.special.entr(alpha).sum().item()
            delta = (q_new - q_raw).norm().item()

            if t == 0:
                _, top5 = alpha.topk(5)
                # FAISS top5 (dropping the β scale does not change topk order)
                logits_f = q_embs[qi:qi+1] @ M_raw
                _, top5_f = logits_f.topk(5)
                overlap = len(set(top5[0].cpu().tolist()) & set(top5_f[0].cpu().tolist()))
                # BUG FIX: this overlap was computed but never reported
                # (cf. the centered-trace section, which prints it).
                print(f"    initial overlap(FAISS top5, hybrid top5) = {overlap}/5")

            print(f"    t={t}: ‖q‖={q_raw.norm():.4f} → {q_new.norm():.4f}, "
                  f"H={entropy:.2f}, Δ={delta:.4f}")
            q_raw = q_new

        # Final top-5 under the hybrid scheme vs the plain FAISS ranking
        q_cent_f = q_raw - mu.unsqueeze(0)
        logits_f = beta * (q_cent_f @ M_cent)
        _, top5_final = torch.softmax(logits_f, dim=-1).topk(5)
        logits_faiss = q_embs[qi:qi+1] @ M_raw
        _, top5_faiss = logits_faiss.topk(5)
        overlap = len(set(top5_final[0].cpu().tolist()) & set(top5_faiss[0].cpu().tolist()))
        print(f"    final vs FAISS overlap: {overlap}/5")
        print(f"    FAISS top5: {top5_faiss[0].cpu().tolist()}")
        print(f"    Hybrid top5: {top5_final[0].cpu().tolist()}")

print("\nDone.")