"""Diagnose centering dynamics: trace q step by step, verify β_critical theory.""" import sys import torch import torch.nn.functional as F import numpy as np sys.path.insert(0, "/home/yurenh2/HAG") from hag.memory_bank import MemoryBank from hag.config import MemoryBankConfig # ── Load memory bank ───────────────────────────────────────────────── device = "cuda:0" # CUDA_VISIBLE_DEVICES remaps mb = MemoryBank(MemoryBankConfig(embedding_dim=768, normalize=True, center=False)) mb.load("/home/yurenh2/HAG/data/processed/hotpotqa_memory_bank.pt", device=device) M_raw = mb.embeddings # (d, N), L2-normalized, NOT centered d, N = M_raw.shape print(f"Memory bank: d={d}, N={N}") # ── Center manually ────────────────────────────────────────────────── mu = M_raw.mean(dim=1) # (d,) M_cent = M_raw - mu.unsqueeze(1) # (d, N) centered print(f"‖μ‖ = {mu.norm():.4f}") print(f"‖M̃·1/N‖ = {(M_cent.mean(dim=1)).norm():.2e} (should be ~0)") # ── Column norms ───────────────────────────────────────────────────── col_norms_raw = M_raw.norm(dim=0) col_norms_cent = M_cent.norm(dim=0) print(f"\nRaw column norms: mean={col_norms_raw.mean():.4f}, std={col_norms_raw.std():.4f}") print(f"Centered column norms: mean={col_norms_cent.mean():.4f}, std={col_norms_cent.std():.4f}") # ── SVD and β_critical ────────────────────────────────────────────── # M̃M̃ᵀ/N is the sample covariance. 
# ── SVD and β_critical ──────────────────────────────────────────────
# M̃M̃ᵀ/N is the sample covariance C, and
#     β_crit = N / λ_max(M̃M̃ᵀ) = 1 / λ_max(C).
# λ_max(M̃M̃ᵀ) is the squared largest singular value of M̃, so get it via SVD.
print("\nComputing SVD of centered memory...")
# Only U (left singular vectors) and S are used below; Vh is discarded.
U, S, _ = torch.linalg.svd(M_cent, full_matrices=False)
# S shape: (min(d,N),)
print(f"Top 10 singular values: {S[:10].cpu().tolist()}")
lambda_max_MMT = S[0].item() ** 2  # largest eigenvalue of M̃M̃ᵀ
lambda_max_C = lambda_max_MMT / N  # largest eigenvalue of M̃M̃ᵀ/N
print(f"\nλ_max(M̃M̃ᵀ) = {lambda_max_MMT:.2f}")
print(f"λ_max(C=M̃M̃ᵀ/N) = {lambda_max_C:.4f}")

# Jacobian of q ↦ M̃·softmax(β·M̃ᵀq) at the origin (uniform attention, M̃·1 = 0):
#     DT = β/N · M̃M̃ᵀ
# Spectral radius = β/N · λ_max(M̃M̃ᵀ) = β · λ_max(C)
# For instability: β · λ_max(C) > 1  →  β > 1/λ_max(C)
beta_crit = 1.0 / lambda_max_C
print(f"\nβ_critical = 1/λ_max(C) = {beta_crit:.4f}")
print("For β > β_crit, origin is UNSTABLE (Jacobian spectral radius > 1)")
for beta in [0.5, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0]:
    rho = beta * lambda_max_C
    print(f" β={beta:6.1f}: Jacobian spectral radius = {rho:.2f} {'UNSTABLE' if rho > 1 else 'stable'}")

# ── Load some real queries ───────────────────────────────────────────
import itertools
import json

questions_path = "/home/yurenh2/HAG/data/processed/hotpotqa_questions.jsonl"
with open(questions_path, encoding="utf-8") as f:
    # Parse only the first 5 non-blank records instead of the whole file;
    # blank lines (e.g. a trailing newline) would make json.loads raise.
    non_blank = (line for line in f if line.strip())
    questions = [json.loads(line) for line in itertools.islice(non_blank, 5)]

from transformers import AutoModel, AutoTokenizer

print("\nLoading encoder...")
tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
model = AutoModel.from_pretrained("facebook/contriever-msmarco").to(device)
model.eval()


def encode(texts):
    """Embed ``texts`` with Contriever and return one L2-normalized row per text.

    Token states are mean-pooled over the attention mask (padding excluded),
    then normalized to unit length. Runs without gradient tracking.
    """
    inputs = tokenizer(
        texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Contriever uses mean pooling over non-padding tokens
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    emb = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(emb, dim=-1)


q_texts = [q["question"] for q in questions]
q_embs = encode(q_texts)  # (5, d)

# ── Trace dynamics step by step (centered) ───────────────────────────
# Iterate q_{t+1} = M̃·softmax(β·M̃ᵀ·q_t) starting from the centered query.
n_trace = min(3, len(q_texts))
log_N = np.log(N)  # entropy of the uniform distribution over N memories
print("\n" + "=" * 80)
print(f"CENTERED DYNAMICS TRACE ({n_trace} queries)")
print("=" * 80)
for beta in [5.0, 20.0, 50.0, 100.0]:
    print(f"\n--- β = {beta} ---")
    for qi in range(n_trace):
        q0_raw = q_embs[qi:qi + 1]  # (1, d)
        q0 = q0_raw - mu.unsqueeze(0)  # center the query
        q = q0.clone()
        print(f"\n Q{qi}: '{q_texts[qi][:60]}...'")
        print(f" t=0: ‖q‖={q.norm():.4f}")
        for t in range(8):
            logits = beta * (q @ M_cent)  # (1, N)
            alpha = torch.softmax(logits, dim=-1)  # (1, N)
            # entr() defines 0·log 0 = 0; the naive alpha * alpha.log() yields
            # NaN as soon as a large β underflows some weights to exactly 0.
            entropy = torch.special.entr(alpha).sum().item()
            max_alpha = alpha.max().item()
            q_new = alpha @ M_cent.T  # (1, d)
            # How close is q_new to the dominant eigenvector?
            cos_v1 = abs((q_new / (q_new.norm() + 1e-12)) @ U[:, 0]).item()
            # FAISS-equivalent: initial attention from raw query on raw memory
            if t == 0:
                logits_raw = beta * (q0_raw @ M_raw)
                alpha_raw = torch.softmax(logits_raw, dim=-1)
                _, top5_raw = alpha_raw.topk(5)
                # Centered attention top-5
                _, top5_cent = alpha.topk(5)
                overlap = len(set(top5_raw[0].cpu().tolist()) & set(top5_cent[0].cpu().tolist()))
                print(f" initial overlap(raw top5, centered top5) = {overlap}/5")
            delta = (q_new - q).norm().item()
            q = q_new
            print(f" t={t+1}: ‖q‖={q.norm():.6f}, H(α)={entropy:.2f}/{log_N:.2f}, "
                  f"max(α)={max_alpha:.2e}, Δ={delta:.2e}, cos(q,v1)={cos_v1:.4f}")
            if delta < 1e-8:
                print(" [converged]")
                break
        # Final top-5 from centered attention
        logits_final = beta * (q @ M_cent)
        alpha_final = torch.softmax(logits_final, dim=-1)
        _, top5_final = alpha_final.topk(5)
        # Compare to "iter=0" (just softmax on centered, no iteration)
        logits_iter0 = beta * (q0 @ M_cent)
        alpha_iter0 = torch.softmax(logits_iter0, dim=-1)
        _, top5_iter0 = alpha_iter0.topk(5)
        # And raw (FAISS-like)
        logits_faiss = beta * (q0_raw @ M_raw)
        alpha_faiss = torch.softmax(logits_faiss, dim=-1)
        _, top5_faiss = alpha_faiss.topk(5)
        print(f" Top-5 FAISS: {top5_faiss[0].cpu().tolist()}")
        print(f" Top-5 cent t=0: {top5_iter0[0].cpu().tolist()}")
        print(f" Top-5 cent final: {top5_final[0].cpu().tolist()}")
# ── KEY TEST: Does the Jacobian actually amplify near origin? ────────
print("\n" + "=" * 80)
print("JACOBIAN TEST: Start near origin, see if dynamics amplify")
print("=" * 80)
# Each iterate q_{t+1} = M̃·α is a convex combination of centered columns, so
# ‖q_{t+1}‖ ≤ max_i ‖m̃_i‖. With centered columns typically shorter than 1, a
# fixed threshold of 1.0 can never be crossed. Treat the origin as "escaped"
# once ‖q‖ reaches half the mean centered column norm (memory scale).
escape_norm = 0.5 * col_norms_cent.mean().item()
print(f"(escape threshold: ‖q‖ > {escape_norm:.4f})")
eps = 1e-4
for beta in [5.0, 50.0, 100.0]:
    # Start with a tiny perturbation in direction of top eigenvector
    q_tiny = eps * U[:, 0].unsqueeze(0)  # (1, d), tiny perturbation along v1
    print(f"\n--- β = {beta}, q0 = {eps}·v1 ---")
    q = q_tiny.clone()
    for t in range(10):
        logits = beta * (q @ M_cent)
        alpha = torch.softmax(logits, dim=-1)
        q_new = alpha @ M_cent.T
        # Near the origin the linearization predicts amplification ≈ β·λ_max(C).
        amplification = q_new.norm().item() / (q.norm().item() + 1e-20)
        print(f" t={t}: ‖q‖={q.norm():.6e} → ‖q_new‖={q_new.norm():.6e}, "
              f"amplification={amplification:.2f}")
        q = q_new
        if q.norm().item() > escape_norm:
            print(f" [escaped origin at t={t}]")
            break

# ── Alternative: What about removing the ‖q‖² term from energy? ─────
# The standard update q_{t+1} = M·softmax(β·Mᵀq) minimizes
#     E(q) = -1/β·lse(β·Mᵀq) + 1/2·‖q‖²
# What if we don't want the ‖q‖² penalty? Then the fixed-point equation
# is still q* = M·softmax(β·Mᵀq*): same update, different energy landscape.
# The issue is: with centering, M̃·uniform = 0 regardless of energy.
# The ‖q‖² penalty is NOT the problem for centering — the averaging is.
print("\n" + "=" * 80)
print("DIAGNOSIS: Why centering fails for iteration")
print("=" * 80)

# For β=50, show the attention distribution at t=0 (before any iteration)
beta = 50.0
q0_raw = q_embs[0:1]
q0_cent = q0_raw - mu.unsqueeze(0)
logits_cent = beta * (q0_cent @ M_cent)  # (1, N)
alpha_cent = torch.softmax(logits_cent, dim=-1)
# entr() treats 0·log 0 as 0, so weights that underflow to 0 don't make H NaN.
entropy_cent = torch.special.entr(alpha_cent).sum().item()
max_alpha_cent = alpha_cent.max().item()
logits_raw = beta * (q0_raw @ M_raw)
alpha_raw = torch.softmax(logits_raw, dim=-1)
entropy_raw = torch.special.entr(alpha_raw).sum().item()
print(f"\nβ={beta}, Q0: '{q_texts[0][:60]}...'")
print(f"Raw attention: entropy={entropy_raw:.2f}, max={alpha_raw.max():.4f}")
print(f"Cent attention: entropy={entropy_cent:.2f}, max={max_alpha_cent:.4f}")
print(f"‖q0_cent‖ = {q0_cent.norm():.4f}")
print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")

# Show: what's the actual q1 norm vs predicted from Jacobian?
q1_cent = alpha_cent @ M_cent.T  # (1, d)
# Linearized estimate ‖q1‖ ≈ β·λ_max(C)·‖q0‖ (rough bound, valid only near q=0).
jacobian_radius = beta * lambda_max_C
predicted_norm = jacobian_radius * q0_cent.norm().item()
print(f"\n‖q1_cent‖ actual = {q1_cent.norm():.6f}")
print(f"‖q0_cent‖ × Jacobian_spectral_radius ≈ {q0_cent.norm():.4f} × {jacobian_radius:.2f} = {predicted_norm:.4f}")
print("But the linearization only holds for q→0. q0 is NOT near zero.")

# The real issue: softmax(β·M̃ᵀ·q0) when q0 is not small.
# The logits have some spread, but the weighted average of CENTERED vectors
# inherently cancels out.
weighted_avg = q1_cent  # same quantity: α-weighted average of centered columns
unweighted_avg = M_cent.mean(dim=1)  # (d,)
print(f"\n‖weighted_avg‖ = {weighted_avg.norm():.6f}")
print(f"‖unweighted_avg‖ = {unweighted_avg.norm():.2e}")
# How concentrated is the attention?
mass_top50 = alpha_cent.topk(50).values.sum().item()
print(f"Mass in top-50 memories: {mass_top50:.4f}")
print(f"Mass in top-5 memories: {alpha_cent.topk(5)[0].sum().item():.4f}")
# The weighted average of centered vectors is small because:
# 1. Centered vectors m̃_i are shorter than the raw unit vectors on average
#    (mean ‖m̃_i‖² = 1 − ‖μ‖² for unit-norm columns).
# 2. Centered vectors point in diverse directions (they have the mean removed).
# 3. Even with non-uniform weights, the cancellation is severe unless
#    attention is extremely peaked on a few memories.
# So the output ‖q1‖ << ‖q0‖, even though β is large.

# Key quantification: what fraction of ‖q0‖ is preserved?
preserve_ratio = q1_cent.norm().item() / q0_cent.norm().item()
print(f"\n‖q1‖/‖q0‖ = {preserve_ratio:.4f} (fraction of query norm preserved)")
print("This ratio << 1 means the averaging contracts the query toward 0.")
print("For centering to work with iteration, this ratio must be > 1.")

print("\n" + "=" * 80)
print("SOLUTION ANALYSIS")
print("=" * 80)
print("""
The centering fix removes the centroid attractor: M̃·uniform = 0, not μ.
But the fundamental problem remains: ANY weighted average of centered vectors
is much shorter than the input query, because centered vectors cancel.

For the origin to be unstable, β must exceed β_critical so that the Jacobian
amplifies perturbations near zero. But the dynamics from a realistic starting
point (‖q‖≈0.5) don't behave like the linearization predicts.

The actual contraction ratio ‖q1‖/‖q0‖ is what matters, not the Jacobian at origin.
This ratio is small because softmax isn't peaked enough.

Possible fixes:
1. MUCH higher β (β > 500?) to make attention ultra-peaked → less cancellation
2. Residual connection with centering: q_{t+1} = λ·q_t + (1-λ)·M̃·softmax(...)
   This explicitly preserves query norm while still benefiting from centering.
3. Normalize q_{t+1} after each step to prevent norm collapse.
4. Use centering only for the attention computation, not for the update target:
   α = softmax(β · M̃ᵀ · q̃) but q_{t+1} = M_raw · α (update in original space)
""")

# Test option 4: centered attention, raw update
print("\n" + "=" * 80)
print("TEST: Centered attention + raw update (hybrid)")
print("=" * 80)
for beta in [5.0, 20.0, 50.0]:
    print(f"\n--- β = {beta} ---")
    for qi in range(min(3, len(q_texts))):
        q_query = q_embs[qi:qi + 1]  # (1, d) original raw query
        # FAISS top-5 (inner product on raw memory) does not depend on the
        # iteration, so compute it once per query.
        top5_faiss = (q_query @ M_raw).topk(5).indices[0].cpu().tolist()
        q_raw = q_query.clone()
        print(f" Q{qi}: '{q_texts[qi][:50]}...'")
        for t in range(5):
            q_cent = q_raw - mu.unsqueeze(0)  # center query
            logits = beta * (q_cent @ M_cent)  # attention on centered space
            alpha = torch.softmax(logits, dim=-1)
            q_new = alpha @ M_raw.T  # update in RAW space
            entropy = torch.special.entr(alpha).sum().item()
            delta = (q_new - q_raw).norm().item()
            if t == 0:
                top5 = alpha.topk(5).indices[0].cpu().tolist()
                overlap = len(set(top5) & set(top5_faiss))
                print(f" initial overlap(hybrid top5, FAISS top5) = {overlap}/5")
            print(f" t={t}: ‖q‖={q_raw.norm():.4f} → {q_new.norm():.4f}, "
                  f"H={entropy:.2f}, Δ={delta:.4f}")
            q_raw = q_new
        # Final top-5
        q_cent_f = q_raw - mu.unsqueeze(0)
        logits_f = beta * (q_cent_f @ M_cent)
        top5_final = torch.softmax(logits_f, dim=-1).topk(5).indices[0].cpu().tolist()
        overlap = len(set(top5_final) & set(top5_faiss))
        print(f" final vs FAISS overlap: {overlap}/5")
        print(f" FAISS top5: {top5_faiss}")
        print(f" Hybrid top5: {top5_final}")

print("\nDone.")