diff options
Diffstat (limited to 'scripts/diagnose_centering.py')
| -rw-r--r-- | scripts/diagnose_centering.py | 301 |
1 file changed, 301 insertions, 0 deletions
"""Diagnose centering dynamics: trace q step by step, verify β_critical theory.

Diagnostic script (not a library module). It loads an L2-normalized memory
bank, centers it, and empirically checks the linear-stability theory of the
Hopfield-style update q_{t+1} = M̃·softmax(β·M̃ᵀ·q_t):

  1. computes β_critical = 1/λ_max(C) from the SVD of the centered memory,
  2. traces the centered dynamics on real queries,
  3. tests Jacobian amplification starting from a tiny perturbation,
  4. quantifies why the centered average contracts realistic queries,
  5. tests a hybrid scheme (centered attention, raw-space update).

Assumes CUDA is available and the HAG data files exist at the hard-coded
paths below — this is a one-off experiment script, not production code.
"""

import sys
import torch
import torch.nn.functional as F
import numpy as np

sys.path.insert(0, "/home/yurenh2/HAG")

from hag.memory_bank import MemoryBank
from hag.config import MemoryBankConfig


def _attn_entropy(alpha):
    """Shannon entropy of an attention distribution, NaN-safe.

    With large β the softmax underflows to exact zeros; the naive
    -(α·log α).sum() then evaluates 0·(-inf) = NaN under IEEE semantics.
    Clamping zero entries before the log is exact in the limit (their true
    entropy contribution is 0) and negligible for nonzero entries.
    """
    return -(alpha * alpha.clamp_min(1e-30).log()).sum().item()


# ── Load memory bank ─────────────────────────────────────────────────
device = "cuda:0"  # CUDA_VISIBLE_DEVICES remaps
mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False))
mb.load("/home/yurenh2/HAG/data/processed/hotpotqa_memory_bank.pt", device=device)
M_raw = mb.embeddings  # (d, N), L2-normalized, NOT centered
d, N = M_raw.shape
print(f"Memory bank: d={d}, N={N}")

# ── Center manually ──────────────────────────────────────────────────
mu = M_raw.mean(dim=1)              # (d,) centroid of the memory bank
M_cent = M_raw - mu.unsqueeze(1)    # (d, N) centered
print(f"‖μ‖ = {mu.norm():.4f}")
print(f"‖M̃·1/N‖ = {(M_cent.mean(dim=1)).norm():.2e} (should be ~0)")

# ── Column norms ─────────────────────────────────────────────────────
col_norms_raw = M_raw.norm(dim=0)
col_norms_cent = M_cent.norm(dim=0)
print(f"\nRaw column norms: mean={col_norms_raw.mean():.4f}, std={col_norms_raw.std():.4f}")
print(f"Centered column norms: mean={col_norms_cent.mean():.4f}, std={col_norms_cent.std():.4f}")

# ── SVD and β_critical ──────────────────────────────────────────────
# M̃M̃ᵀ/N is the sample covariance. β_crit = N / λ_max(M̃M̃ᵀ) = 1/λ_max(C)
# where C = M̃M̃ᵀ/N
# But let's compute via SVD of M̃
print("\nComputing SVD of centered memory...")
U, S, _ = torch.linalg.svd(M_cent, full_matrices=False)  # S shape: (min(d,N),)
print(f"Top 10 singular values: {S[:10].cpu().tolist()}")
lambda_max_MMT = S[0].item() ** 2   # largest eigenvalue of M̃M̃ᵀ (= σ_max²)
lambda_max_C = lambda_max_MMT / N   # largest eigenvalue of M̃M̃ᵀ/N

print(f"\nλ_max(M̃M̃ᵀ) = {lambda_max_MMT:.2f}")
print(f"λ_max(C=M̃M̃ᵀ/N) = {lambda_max_C:.4f}")

# Jacobian at origin: DT = β/N · M̃M̃ᵀ
# Spectral radius = β/N · λ_max(M̃M̃ᵀ) = β · λ_max(C)
# For instability: β · λ_max(C) > 1 → β > 1/λ_max(C)
beta_crit = 1.0 / lambda_max_C
print(f"\nβ_critical = 1/λ_max(C) = {beta_crit:.4f}")
print("For β > β_crit, origin is UNSTABLE (Jacobian spectral radius > 1)")

for beta in [0.5, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0]:
    rho = beta * lambda_max_C
    print(f"  β={beta:6.1f}: Jacobian spectral radius = {rho:.2f} {'UNSTABLE' if rho > 1 else 'stable'}")

# ── Load some real queries ───────────────────────────────────────────
import json

questions_path = "/home/yurenh2/HAG/data/processed/hotpotqa_questions.jsonl"
with open(questions_path) as f:
    questions = [json.loads(line) for line in f][:5]

from transformers import AutoTokenizer, AutoModel

print("\nLoading encoder...")
tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
model = AutoModel.from_pretrained("facebook/contriever-msmarco").to(device)
model.eval()


def encode(texts):
    """Encode a list of strings into L2-normalized (len(texts), d) embeddings."""
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=512,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Contriever uses mean pooling over non-padding tokens
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    emb = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(emb, dim=-1)


q_texts = [q["question"] for q in questions]
q_embs = encode(q_texts)  # (5, d)

# ── Trace dynamics step by step (centered) ───────────────────────────
print("\n" + "=" * 80)
print("CENTERED DYNAMICS TRACE (5 queries)")
print("=" * 80)

for beta in [5.0, 20.0, 50.0, 100.0]:
    print(f"\n--- β = {beta} ---")

    for qi in range(min(3, len(q_texts))):
        q0_raw = q_embs[qi:qi + 1]          # (1, d)
        q0 = q0_raw - mu.unsqueeze(0)       # center the query

        q = q0.clone()
        print(f"\n  Q{qi}: '{q_texts[qi][:60]}...'")
        print(f"    t=0: ‖q‖={q.norm():.4f}")

        for t in range(8):
            logits = beta * (q @ M_cent)            # (1, N)
            alpha = torch.softmax(logits, dim=-1)   # (1, N)
            entropy = _attn_entropy(alpha)          # NaN-safe (β=100 underflows)
            max_alpha = alpha.max().item()
            q_new = alpha @ M_cent.T                # (1, d)

            # How close is q_new to the dominant eigenvector?
            cos_v1 = abs((q_new / (q_new.norm() + 1e-12)) @ U[:, 0]).item()

            # FAISS-equivalent: initial attention from raw query on raw memory
            if t == 0:
                logits_raw = beta * (q0_raw @ M_raw)
                alpha_raw = torch.softmax(logits_raw, dim=-1)
                _, top5_raw = alpha_raw.topk(5)

                # Centered attention top-5
                _, top5_cent = alpha.topk(5)
                overlap = len(set(top5_raw[0].cpu().tolist()) & set(top5_cent[0].cpu().tolist()))
                print(f"    initial overlap(raw top5, centered top5) = {overlap}/5")

            delta = (q_new - q).norm().item()
            q = q_new
            print(f"    t={t + 1}: ‖q‖={q.norm():.6f}, H(α)={entropy:.2f}/{np.log(N):.2f}, "
                  f"max(α)={max_alpha:.2e}, Δ={delta:.2e}, cos(q,v1)={cos_v1:.4f}")

            if delta < 1e-8:
                print("    [converged]")
                break

        # Final top-5 from centered attention
        logits_final = beta * (q @ M_cent)
        alpha_final = torch.softmax(logits_final, dim=-1)
        _, top5_final = alpha_final.topk(5)

        # Compare to "iter=0" (just softmax on centered, no iteration)
        logits_iter0 = beta * (q0 @ M_cent)
        alpha_iter0 = torch.softmax(logits_iter0, dim=-1)
        _, top5_iter0 = alpha_iter0.topk(5)

        # And raw (FAISS-like)
        logits_faiss = beta * (q0_raw @ M_raw)
        alpha_faiss = torch.softmax(logits_faiss, dim=-1)
        _, top5_faiss = alpha_faiss.topk(5)

        print(f"    Top-5 FAISS:      {top5_faiss[0].cpu().tolist()}")
        print(f"    Top-5 cent t=0:   {top5_iter0[0].cpu().tolist()}")
        print(f"    Top-5 cent final: {top5_final[0].cpu().tolist()}")

# ── KEY TEST: Does the Jacobian actually amplify near origin? ────────
print("\n" + "=" * 80)
print("JACOBIAN TEST: Start near origin, see if dynamics amplify")
print("=" * 80)

for beta in [5.0, 50.0, 100.0]:
    # Start with a tiny perturbation in direction of top eigenvector
    eps = 1e-4
    q_tiny = eps * U[:, 0].unsqueeze(0)  # (1, d), tiny perturbation along v1

    print(f"\n--- β = {beta}, q0 = {eps}·v1 ---")
    q = q_tiny.clone()
    for t in range(10):
        logits = beta * (q @ M_cent)
        alpha = torch.softmax(logits, dim=-1)
        q_new = alpha @ M_cent.T
        # 1e-20 guards against division by zero if the iterate collapses exactly
        amplification = q_new.norm().item() / (q.norm().item() + 1e-20)
        print(f"  t={t}: ‖q‖={q.norm():.6e} → ‖q_new‖={q_new.norm():.6e}, "
              f"amplification={amplification:.2f}")
        q = q_new
        if q.norm().item() > 1.0:
            print(f"  [escaped origin at t={t}]")
            break

# ── Alternative: What about removing the ‖q‖² term from energy? ─────
# The standard update q_{t+1} = M·softmax(β·Mᵀq) minimizes
#   E(q) = -1/β·lse(β·Mᵀq) + 1/2·‖q‖²
# What if we don't want the ‖q‖² penalty? Then the fixed point equation
# is just q* = M·softmax(β·Mᵀq*), same update but different energy landscape.
# The issue is: with centering, M̃·uniform = 0 regardless of energy.
# The ‖q‖² penalty is NOT the problem for centering — the averaging is.

print("\n" + "=" * 80)
print("DIAGNOSIS: Why centering fails for iteration")
print("=" * 80)

# For β=50, show the attention distribution at t=0 (before any iteration)
beta = 50.0
q0_raw = q_embs[0:1]
q0_cent = q0_raw - mu.unsqueeze(0)
logits_cent = beta * (q0_cent @ M_cent)  # (1, N)
alpha_cent = torch.softmax(logits_cent, dim=-1)
entropy_cent = _attn_entropy(alpha_cent)
max_alpha_cent = alpha_cent.max().item()

logits_raw = beta * (q0_raw @ M_raw)
alpha_raw = torch.softmax(logits_raw, dim=-1)
entropy_raw = _attn_entropy(alpha_raw)

print(f"\nβ={beta}, Q0: '{q_texts[0][:60]}...'")
print(f"Raw attention:  entropy={entropy_raw:.2f}, max={alpha_raw.max():.4f}")
print(f"Cent attention: entropy={entropy_cent:.2f}, max={max_alpha_cent:.4f}")
print(f"‖q0_cent‖ = {q0_cent.norm():.4f}")
print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")

# Show: what's the actual q1 norm vs predicted from Jacobian?
# (Linearized bound would be ‖q0‖·β·λ_max(C); printed inline below.)
q1_cent = alpha_cent @ M_cent.T  # (1, d)
print(f"\n‖q1_cent‖ actual = {q1_cent.norm():.6f}")
print(f"‖q0_cent‖ × Jacobian_spectral_radius ≈ {q0_cent.norm():.4f} × "
      f"{beta * lambda_max_C:.2f} = {q0_cent.norm().item() * beta * lambda_max_C:.4f}")
print("But the linearization only holds for q→0. q0 is NOT near zero.")

# The real issue: softmax(β·M̃ᵀ·q0) when q0 has ‖q‖=0.5
# The logits have some spread, but the weighted average of CENTERED vectors
# inherently cancels out.
weighted_avg = (alpha_cent @ M_cent.T)   # (1, d)
unweighted_avg = M_cent.mean(dim=1)      # (d,) ≈ 0 by construction
print(f"\n‖weighted_avg‖ = {weighted_avg.norm():.6f}")
print(f"‖unweighted_avg‖ = {unweighted_avg.norm():.2e}")

# How concentrated is the attention?
top50_vals, _ = alpha_cent.topk(50)
mass_top50 = top50_vals.sum().item()
print(f"Mass in top-50 memories: {mass_top50:.4f}")
print(f"Mass in top-5 memories: {alpha_cent.topk(5)[0].sum().item():.4f}")

# The weighted average of centered vectors is small because:
#   1. Centered vectors m̃_i have ‖m̃_i‖ ≈ 0.78 (smaller than raw ‖m_i‖=1)
#   2. Centered vectors point in diverse directions (they have mean removed)
#   3. Even with non-uniform weights, the cancellation is severe unless
#      attention is extremely peaked on a few memories
# So the output ‖q1‖ << ‖q0‖, even though β is large

# Key quantification: what fraction of ‖q0‖ is preserved?
preserve_ratio = q1_cent.norm().item() / q0_cent.norm().item()
print(f"\n‖q1‖/‖q0‖ = {preserve_ratio:.4f} (fraction of query norm preserved)")
print("This ratio << 1 means the averaging contracts the query toward 0.")
print("For centering to work with iteration, this ratio must be > 1.")

print("\n" + "=" * 80)
print("SOLUTION ANALYSIS")
print("=" * 80)
print("""
The centering fix removes the centroid attractor: M̃·uniform = 0, not μ.
But the fundamental problem remains: ANY weighted average of centered vectors
is much shorter than the input query, because centered vectors cancel.

For the origin to be unstable, β must exceed β_critical so that the Jacobian
amplifies perturbations near zero. But the dynamics from a realistic starting
point (‖q‖≈0.5) don't behave like the linearization predicts.

The actual contraction ratio ‖q1‖/‖q0‖ is what matters, not the Jacobian
at origin. This ratio is small because softmax isn't peaked enough.

Possible fixes:
1. MUCH higher β (β > 500?) to make attention ultra-peaked → less cancellation
2. Residual connection with centering: q_{t+1} = λ·q_t + (1-λ)·M̃·softmax(...)
   This explicitly preserves query norm while still benefiting from centering.
3. Normalize q_{t+1} after each step to prevent norm collapse.
4. Use centering only for the attention computation, not for the update target:
   α = softmax(β · M̃ᵀ · q̃) but q_{t+1} = M_raw · α (update in original space)
""")

# Test option 4: centered attention, raw update
print("\n" + "=" * 80)
print("TEST: Centered attention + raw update (hybrid)")
print("=" * 80)

for beta in [5.0, 20.0, 50.0]:
    print(f"\n--- β = {beta} ---")
    for qi in range(min(3, len(q_texts))):
        q_raw = q_embs[qi:qi + 1].clone()  # (1, d) raw query
        print(f"  Q{qi}: '{q_texts[qi][:50]}...'")

        for t in range(5):
            q_cent = q_raw - mu.unsqueeze(0)     # center query
            logits = beta * (q_cent @ M_cent)    # attention on centered space
            alpha = torch.softmax(logits, dim=-1)
            q_new = alpha @ M_raw.T              # update in RAW space

            entropy = _attn_entropy(alpha)       # NaN-safe
            delta = (q_new - q_raw).norm().item()

            if t == 0:
                _, top5 = alpha.topk(5)
                # FAISS top5
                logits_f = q_embs[qi:qi + 1] @ M_raw
                _, top5_f = logits_f.topk(5)
                overlap = len(set(top5[0].cpu().tolist()) & set(top5_f[0].cpu().tolist()))

            print(f"    t={t}: ‖q‖={q_raw.norm():.4f} → {q_new.norm():.4f}, "
                  f"H={entropy:.2f}, Δ={delta:.4f}")
            q_raw = q_new

        # Final top-5
        q_cent_f = q_raw - mu.unsqueeze(0)
        logits_f = beta * (q_cent_f @ M_cent)
        _, top5_final = torch.softmax(logits_f, dim=-1).topk(5)
        logits_faiss = q_embs[qi:qi + 1] @ M_raw
        _, top5_faiss = logits_faiss.topk(5)
        overlap = len(set(top5_final[0].cpu().tolist()) & set(top5_faiss[0].cpu().tolist()))
        print(f"    final vs FAISS overlap: {overlap}/5")
        print(f"    FAISS top5:  {top5_faiss[0].cpu().tolist()}")
        print(f"    Hybrid top5: {top5_final[0].cpu().tolist()}")

print("\nDone.")
