1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
|
"""Diagnose centering dynamics: trace q step by step, verify β_critical theory."""
import sys
import torch
import torch.nn.functional as F
import numpy as np
sys.path.insert(0, "/home/yurenh2/HAG")
from hag.memory_bank import MemoryBank
from hag.config import MemoryBankConfig
# ── Load memory bank ─────────────────────────────────────────────────
# The bank is built with center=False so the stored embeddings stay in raw
# space; centering is done by hand below so raw and centered views can be
# compared side by side.
# "cuda:0" means the first *visible* GPU (CUDA_VISIBLE_DEVICES remaps indices).
device = "cuda:0"
bank_cfg = MemoryBankConfig(embedding_dim=768, normalize=True, center=False)
bank_file = "/home/yurenh2/HAG/data/processed/hotpotqa_memory_bank.pt"
mb = MemoryBank(bank_cfg)
mb.load(bank_file, device=device)
# One memory per column: (d, N). Per the config these are expected to be
# L2-normalized and NOT centered (assumption from normalize=True /
# center=False — confirm in MemoryBank).
M_raw = mb.embeddings
d, N = M_raw.shape
print(f"Memory bank: d={d}, N={N}")
# ── Center manually ──────────────────────────────────────────────────
# μ is the centroid of all memories (mean over the N columns). Subtracting it
# from every column gives the centered bank M̃, whose column mean should
# then be ~0 up to floating-point error.
mu = M_raw.mean(dim=1)
M_cent = M_raw - mu[:, None]
centroid_norm = mu.norm()
residual_mean_norm = M_cent.mean(dim=1).norm()
print(f"‖μ‖ = {centroid_norm:.4f}")
print(f"‖M̃·1/N‖ = {residual_mean_norm:.2e} (should be ~0)")
# ── Column norms ─────────────────────────────────────────────────────
# Per-memory L2 norms (one value per column) before and after centering,
# summarized by their mean and standard deviation.
col_norms_raw, col_norms_cent = (mat.norm(dim=0) for mat in (M_raw, M_cent))
for header, norms in (("\nRaw column norms", col_norms_raw),
                      ("Centered column norms", col_norms_cent)):
    print(f"{header}: mean={norms.mean():.4f}, std={norms.std():.4f}")
# ── SVD and β_critical ──────────────────────────────────────────────
# C = M̃M̃ᵀ/N is the sample covariance of the memories. Its eigenvalues are
# σ_i²/N, where σ_i are the singular values of M̃, so an SVD of M̃ yields the
# top eigenvalue without forming the d×d matrix explicitly.
#
# Linearizing q ↦ M̃·softmax(β·M̃ᵀq) at q = 0 (where M̃·uniform = 0):
#   Jacobian DT = β/N · M̃M̃ᵀ,  spectral radius ρ = β/N · λ_max(M̃M̃ᵀ) = β · λ_max(C).
# The origin is unstable once ρ > 1, i.e. β > β_crit = 1/λ_max(C).
print("\nComputing SVD of centered memory...")
U, S, Vh = torch.linalg.svd(M_cent, full_matrices=False)  # S: (min(d, N),)
print(f"Top 10 singular values: {S[:10].cpu().tolist()}")
sigma_max = S[0].item()
lambda_max_MMT = sigma_max ** 2          # largest eigenvalue of M̃M̃ᵀ
lambda_max_C = lambda_max_MMT / N        # largest eigenvalue of C = M̃M̃ᵀ/N
print(f"\nλ_max(M̃M̃ᵀ) = {lambda_max_MMT:.2f}")
print(f"λ_max(C=M̃M̃ᵀ/N) = {lambda_max_C:.4f}")
beta_crit = 1.0 / lambda_max_C
print(f"\nβ_critical = 1/λ_max(C) = {beta_crit:.4f}")
print("For β > β_crit, origin is UNSTABLE (Jacobian spectral radius > 1)")
for beta in (0.5, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0):
    rho = beta * lambda_max_C
    verdict = "UNSTABLE" if rho > 1 else "stable"
    print(f" β={beta:6.1f}: Jacobian spectral radius = {rho:.2f} {verdict}")
# ── Load some real queries ───────────────────────────────────────────
import json
from itertools import islice

questions_path = "/home/yurenh2/HAG/data/processed/hotpotqa_questions.jsonl"
n_questions = 5  # only the first few questions are used below
# Parse only the records that are needed instead of the whole file, skip
# blank lines (json.loads("") raises), and decode explicitly as UTF-8 rather
# than relying on the platform's locale default encoding.
# Each record is presumably a dict with a "question" key (it is read that way
# below) — confirm against the file's producer.
with open(questions_path, encoding="utf-8") as f:
    non_blank_lines = (line for line in f if line.strip())
    questions = [json.loads(line) for line in islice(non_blank_lines, n_questions)]
from transformers import AutoTokenizer, AutoModel
print("\nLoading encoder...")
# Contriever (MS MARCO fine-tuned) tokenizer and encoder, on the same device
# as the memory bank, in eval mode.
encoder_name = "facebook/contriever-msmarco"
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
model = AutoModel.from_pretrained(encoder_name).to(device)
model.eval()
def encode(texts):
    """Embed *texts* with the module-level Contriever ``tokenizer``/``model``.

    Last-layer token states are mean-pooled over the attention mask (padding
    positions contribute nothing) and then L2-normalized row-wise, so the
    query embeddings live on the unit sphere like the normalized memory bank.

    Args:
        texts: List of strings to encode (truncated to 512 tokens each).

    Returns:
        Tensor of shape (len(texts), hidden_dim) on ``device``, unit-norm rows.
    """
    batch = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    with torch.no_grad():
        token_states = model(**batch).last_hidden_state
    # Contriever uses mean pooling: average token states, ignoring padding.
    keep = batch["attention_mask"].unsqueeze(-1).float()
    pooled = (token_states * keep).sum(dim=1) / keep.sum(dim=1)
    return F.normalize(pooled, dim=-1)
# Embed all loaded questions in one batch: one unit-norm row per question.
q_texts = [record["question"] for record in questions]
q_embs = encode(q_texts)
# ── Trace dynamics step by step (centered) ───────────────────────────
# For each β, iterate the centered update q_{t+1} = M̃·softmax(β·M̃ᵀq_t)
# starting from the centered query q_0 = q_raw − μ. Per step, log ‖q‖, the
# attention entropy H(α) against its uniform maximum log N, max(α), the step
# size Δ, and |cos| between the update and v1 = U[:, 0] (top principal
# direction of M̃). At t = 0 the centered top-5 is also compared with the raw
# (FAISS-like) top-5.
n_trace = min(3, len(q_texts))  # number of queries actually traced
print("\n" + "=" * 80)
print(f"CENTERED DYNAMICS TRACE ({n_trace} queries)")
print("=" * 80)
for beta in [5.0, 20.0, 50.0, 100.0]:
    print(f"\n--- β = {beta} ---")
    for qi in range(n_trace):
        q0_raw = q_embs[qi:qi+1]  # (1, d)
        q0 = q0_raw - mu.unsqueeze(0)  # center the query
        q = q0.clone()
        print(f"\n Q{qi}: '{q_texts[qi][:60]}...'")
        print(f" t=0: ‖q‖={q.norm():.4f}")
        for t in range(8):
            logits = beta * (q @ M_cent)  # (1, N)
            alpha = torch.softmax(logits, dim=-1)  # (1, N)
            # entr(x) = -x·ln(x) with entr(0) = 0. At large β many weights
            # underflow to exactly 0, where -(α·log α) would give 0·(-inf) = NaN.
            entropy = torch.special.entr(alpha).sum().item()
            max_alpha = alpha.max().item()
            q_new = alpha @ M_cent.T  # (1, d)
            # How close is q_new to the dominant eigenvector?
            cos_v1 = abs((q_new / (q_new.norm() + 1e-12)) @ U[:, 0]).item()
            if t == 0:
                # FAISS-equivalent: initial attention from raw query on raw memory
                logits_raw = beta * (q0_raw @ M_raw)
                alpha_raw = torch.softmax(logits_raw, dim=-1)
                _, top5_raw = alpha_raw.topk(5)
                # Centered attention top-5
                _, top5_cent = alpha.topk(5)
                overlap = len(set(top5_raw[0].cpu().tolist()) & set(top5_cent[0].cpu().tolist()))
                print(f" initial overlap(raw top5, centered top5) = {overlap}/5")
            delta = (q_new - q).norm().item()
            q = q_new
            print(f" t={t+1}: ‖q‖={q.norm():.6f}, H(α)={entropy:.2f}/{np.log(N):.2f}, "
                  f"max(α)={max_alpha:.2e}, Δ={delta:.2e}, cos(q,v1)={cos_v1:.4f}")
            if delta < 1e-8:
                print(" [converged]")
                break
        # Final top-5 from centered attention
        logits_final = beta * (q @ M_cent)
        alpha_final = torch.softmax(logits_final, dim=-1)
        _, top5_final = alpha_final.topk(5)
        # Compare to "iter=0" (just softmax on centered, no iteration)
        logits_iter0 = beta * (q0 @ M_cent)
        alpha_iter0 = torch.softmax(logits_iter0, dim=-1)
        _, top5_iter0 = alpha_iter0.topk(5)
        # And raw (FAISS-like)
        logits_faiss = beta * (q0_raw @ M_raw)
        alpha_faiss = torch.softmax(logits_faiss, dim=-1)
        _, top5_faiss = alpha_faiss.topk(5)
        print(f" Top-5 FAISS: {top5_faiss[0].cpu().tolist()}")
        print(f" Top-5 cent t=0: {top5_iter0[0].cpu().tolist()}")
        print(f" Top-5 cent final: {top5_final[0].cpu().tolist()}")
# ── KEY TEST: Does the Jacobian actually amplify near origin? ────────
print("\n" + "=" * 80)
print("JACOBIAN TEST: Start near origin, see if dynamics amplify")
print("=" * 80)
# Every update q_new = M̃·α (α ≥ 0, Σα = 1) is a convex combination of the
# centered columns, so ‖q_new‖ ≤ max_i ‖m̃_i‖. Centered columns are shorter
# than the unit-norm raw ones, so a fixed "escape" threshold of 1.0 may never
# be reachable. Declare escape at half of the largest reachable norm instead.
escape_radius = 0.5 * M_cent.norm(dim=0).max().item()
eps = 1e-4
for beta in [5.0, 50.0, 100.0]:
    # Start with a tiny perturbation along v1 = U[:, 0], the top principal
    # direction of M̃. Since M̃·uniform = 0, linear theory predicts each early
    # step scales such a q by ρ = β·λ_max(C).
    q_tiny = eps * U[:, 0].unsqueeze(0)  # (1, d)
    print(f"\n--- β = {beta}, q0 = {eps}·v1 ---")
    print(f" predicted amplification ρ = β·λ_max(C) = {beta * lambda_max_C:.2f}, "
          f"escape radius = {escape_radius:.4f}")
    q = q_tiny.clone()
    for t in range(10):
        logits = beta * (q @ M_cent)
        alpha = torch.softmax(logits, dim=-1)
        q_new = alpha @ M_cent.T
        # 1e-20 guards the division if q ever collapses to exactly zero.
        amplification = q_new.norm().item() / (q.norm().item() + 1e-20)
        print(f" t={t}: ‖q‖={q.norm():.6e} → ‖q_new‖={q_new.norm():.6e}, "
              f"amplification={amplification:.2f}")
        q = q_new
        if q.norm().item() > escape_radius:
            print(f" [escaped origin at t={t}]")
            break
# ── Alternative: What about removing the ‖q‖² term from energy? ─────
# The standard update q_{t+1} = M·softmax(β·Mᵀq) minimizes
# E(q) = -1/β·lse(β·Mᵀq) + 1/2·‖q‖²
# What if we don't want the ‖q‖² penalty? Then the fixed point equation
# is just q* = M·softmax(β·Mᵀq*), same update but different energy landscape.
# The issue is: with centering, M̃·uniform = 0 regardless of energy.
# The ‖q‖² penalty is NOT the problem for centering — the averaging is.
print("\n" + "=" * 80)
print("DIAGNOSIS: Why centering fails for iteration")
print("=" * 80)
# For β=50, compare the first query's attention at t=0 (before any
# iteration) in the centered space and in the raw space.
beta = 50.0
q0_raw = q_embs[0:1]
q0_cent = q0_raw - mu.unsqueeze(0)
logits_cent = beta * (q0_cent @ M_cent)  # (1, N)
alpha_cent = torch.softmax(logits_cent, dim=-1)
# entr(x) = -x·ln(x) with entr(0) = 0; weights that underflow to exactly 0
# would otherwise turn -(α·log α) into NaN.
entropy_cent = torch.special.entr(alpha_cent).sum().item()
max_alpha_cent = alpha_cent.max().item()
logits_raw = beta * (q0_raw @ M_raw)
alpha_raw = torch.softmax(logits_raw, dim=-1)
entropy_raw = torch.special.entr(alpha_raw).sum().item()
print(f"\nβ={beta}, Q0: '{q_texts[0][:60]}...'")
print(f"Raw attention: entropy={entropy_raw:.2f}, max={alpha_raw.max():.4f}")
print(f"Cent attention: entropy={entropy_cent:.2f}, max={max_alpha_cent:.4f}")
print(f"‖q0_cent‖ = {q0_cent.norm():.4f}")
print(f"‖q0_raw‖ = {q0_raw.norm():.4f}")
# Actual one-step output norm vs. the linearized prediction
# ‖q1‖ ≈ ρ·‖q0‖ with ρ = β·λ_max(C). This is only a rough bound: the
# linearization holds near q = 0, and q0_cent is not small.
q1_cent = alpha_cent @ M_cent.T  # (1, d)
q0_cent_norm = q0_cent.norm().item()
rho_origin = beta * lambda_max_C
predicted_norm = q0_cent_norm * rho_origin
print(f"\n‖q1_cent‖ actual = {q1_cent.norm():.6f}")
print(f"‖q0_cent‖ × Jacobian_spectral_radius ≈ {q0_cent_norm:.4f} × {rho_origin:.2f} = {predicted_norm:.4f}")
print("But the linearization only holds for q→0. q0 is NOT near zero.")
# The real issue: softmax(β·M̃ᵀ·q0) has some spread, but the weighted
# average of CENTERED vectors largely cancels out. q1_cent already is that
# weighted average.
weighted_avg = q1_cent  # (1, d)
unweighted_avg = M_cent.mean(dim=1)  # (d,)
print(f"\n‖weighted_avg‖ = {weighted_avg.norm():.6f}")
print(f"‖unweighted_avg‖ = {unweighted_avg.norm():.2e}")
# How concentrated is the attention?
mass_top50 = alpha_cent.topk(50).values.sum().item()
print(f"Mass in top-50 memories: {mass_top50:.4f}")
print(f"Mass in top-5 memories: {alpha_cent.topk(5).values.sum().item():.4f}")
# Why the weighted average of centered vectors is small:
# 1. Centered vectors m̃_i are shorter than the raw m_i (the raw bank is
#    configured with normalize=True). The centered mean norm is printed in
#    the column-norm section; it was ≈ 0.78 in an earlier run — re-check.
# 2. Centered vectors point in diverse directions (their mean is removed).
# 3. Even with non-uniform weights, the cancellation is severe unless
#    attention is extremely peaked on a few memories.
# So the output ‖q1‖ can be much smaller than ‖q0‖ even at large β.
# Key quantification: what fraction of ‖q0‖ is preserved?
preserve_ratio = q1_cent.norm().item() / q0_cent_norm
print(f"\n‖q1‖/‖q0‖ = {preserve_ratio:.4f} (fraction of query norm preserved)")
print("This ratio << 1 means the averaging contracts the query toward 0.")
print("For centering to work with iteration, this ratio must be > 1.")
# Print a summary of the diagnosis and the candidate remedies (text verbatim).
banner = "=" * 80
print(f"\n{banner}")
print("SOLUTION ANALYSIS")
print(banner)
print("""
The centering fix removes the centroid attractor: M̃·uniform = 0, not μ.
But the fundamental problem remains: ANY weighted average of centered vectors
is much shorter than the input query, because centered vectors cancel.
For the origin to be unstable, β must exceed β_critical so that the Jacobian
amplifies perturbations near zero. But the dynamics from a realistic starting
point (‖q‖≈0.5) don't behave like the linearization predicts.
The actual contraction ratio ‖q1‖/‖q0‖ is what matters, not the Jacobian
at origin. This ratio is small because softmax isn't peaked enough.
Possible fixes:
1. MUCH higher β (β > 500?) to make attention ultra-peaked → less cancellation
2. Residual connection with centering: q_{t+1} = λ·q_t + (1-λ)·M̃·softmax(...)
This explicitly preserves query norm while still benefiting from centering.
3. Normalize q_{t+1} after each step to prevent norm collapse.
4. Use centering only for the attention computation, not for the update target:
α = softmax(β · M̃ᵀ · q̃) but q_{t+1} = M_raw · α (update in original space)
""")
# Test option 4: centered attention, raw update.
# Attention weights are computed in the centered space,
#   α = softmax(β · M̃ᵀ (q − μ)),
# but the query is updated in the raw space, q ← M_raw · α, so the update is
# a convex combination of the raw memories rather than of the centered ones.
print("\n" + "=" * 80)
print("TEST: Centered attention + raw update (hybrid)")
print("=" * 80)
n_hybrid = min(3, len(q_texts))
for beta in [5.0, 20.0, 50.0]:
    print(f"\n--- β = {beta} ---")
    for qi in range(n_hybrid):
        q_query = q_embs[qi:qi+1]  # (1, d) raw query embedding
        # FAISS baseline: inner product of the raw query with raw memories.
        # It depends on neither β nor t, so compute it once per query.
        _, top5_faiss = (q_query @ M_raw).topk(5)
        faiss_ids = set(top5_faiss[0].cpu().tolist())
        q_raw = q_query.clone()
        print(f" Q{qi}: '{q_texts[qi][:50]}...'")
        for t in range(5):
            q_cent = q_raw - mu.unsqueeze(0)  # center query
            logits = beta * (q_cent @ M_cent)  # attention on centered space
            alpha = torch.softmax(logits, dim=-1)
            q_new = alpha @ M_raw.T  # update in RAW space
            # entr(0) = 0, so underflowed weights don't make the entropy NaN.
            entropy = torch.special.entr(alpha).sum().item()
            delta = (q_new - q_raw).norm().item()
            if t == 0:
                # Previously computed but never reported; print it like the
                # centered trace does.
                _, top5 = alpha.topk(5)
                overlap = len(set(top5[0].cpu().tolist()) & faiss_ids)
                print(f" initial overlap(hybrid top5, FAISS top5) = {overlap}/5")
            print(f" t={t}: ‖q‖={q_raw.norm():.4f} → {q_new.norm():.4f}, "
                  f"H={entropy:.2f}, Δ={delta:.4f}")
            q_raw = q_new
        # Final top-5: centered attention from the last raw-space query.
        q_cent_f = q_raw - mu.unsqueeze(0)
        logits_f = beta * (q_cent_f @ M_cent)
        _, top5_final = torch.softmax(logits_f, dim=-1).topk(5)
        overlap = len(set(top5_final[0].cpu().tolist()) & faiss_ids)
        print(f" final vs FAISS overlap: {overlap}/5")
        print(f" FAISS top5: {top5_faiss[0].cpu().tolist()}")
        print(f" Hybrid top5: {top5_final[0].cpu().tolist()}")
print("\nDone.")
|