#!/usr/bin/env python3
"""
User Vector Similarity Analysis

This script analyzes the similarity between user vectors (z_u) learned by the
online personalization system. It computes:
1. Cosine similarity matrix between all user vectors
2. Ground truth similarity based on preference overlap
3. Correlation between learned and expected similarities

Usage:
    python scripts/analyze_user_similarity.py \
        --user-store data/users/user_store_pilot_v4_full-greedy.npz
"""

import argparse
import numpy as np
from typing import Dict, List, Tuple
from dataclasses import dataclass


# =============================================================================
# Persona Definitions (must match pilot_runner_v4.py)
# =============================================================================

@dataclass
class StylePrefs:
    """User's TRUE style preferences."""
    require_short: bool = False
    max_chars: int = 300
    require_bullets: bool = False
    lang: str = "en"


# Ground truth personas
PERSONAS = {
    "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
    "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
    "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
    "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
    "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
    "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
}
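# NOTE: max_chars is informational here; the ground-truth similarity below
# compares only the short / bullets / lang dimensions.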


# =============================================================================
# User Vector Loading
# =============================================================================

def load_user_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Load user vectors from saved user store.
    
    Returns:
        {user_id: (z_long, z_short)}
    """
    data = np.load(user_store_path, allow_pickle=True)
    
    user_vectors = {}
    
    # UserTensorStore saves in format: {uid}_long, {uid}_short, {uid}_meta
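    # e.g. "user_A_short_bullets_en_long" holds that user's z_long array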
    # First, find all unique user IDs
    user_ids = set()
    for key in data.files:
        if key.endswith("_long"):
            uid = key[:-5]  # Remove "_long"
            user_ids.add(uid)
    
    # Load vectors for each user
    for uid in user_ids:
        long_key = f"{uid}_long"
        short_key = f"{uid}_short"
        
        if long_key in data.files and short_key in data.files:
            z_long = data[long_key]
            z_short = data[short_key]
            user_vectors[uid] = (z_long, z_short)
    
    return user_vectors


def load_user_vectors_from_internal(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Alternative loader that understands the internal format.
    """
    data = np.load(user_store_path, allow_pickle=True)
    
    print(f"[Debug] Available keys in npz: {list(data.files)}")
    
    user_vectors = {}
    
    # Try to find user vectors in various formats
    for key in data.files:
        print(f"  {key}: shape={data[key].shape if hasattr(data[key], 'shape') else 'N/A'}")
    
    # Collect user IDs from any per-user z_long key
    seen_users = set()
    for key in data.files:
        if "_z_long" in key or key.startswith("z_long_"):
            # Extract user_id
            if key.startswith("z_long_"):
                user_id = key[7:]  # Remove "z_long_"
            else:
                user_id = key.split("_z_long")[0]
            seen_users.add(user_id)
    
    for user_id in seen_users:
        # Try different key formats
        z_long_keys = [f"z_long_{user_id}", f"{user_id}_z_long"]
        z_short_keys = [f"z_short_{user_id}", f"{user_id}_z_short"]
        
        z_long = None
        z_short = None
        
        for k in z_long_keys:
            if k in data.files:
                z_long = data[k]
                break
        
        for k in z_short_keys:
            if k in data.files:
                z_short = data[k]
                break
        
        if z_long is not None and z_short is not None:
            user_vectors[user_id] = (z_long, z_short)
    
    return user_vectors


# =============================================================================
# Similarity Computation
# =============================================================================

def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
    """Compute cosine similarity between two vectors."""
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    
    if norm1 < 1e-10 or norm2 < 1e-10:
        return 0.0
    
    return float(np.dot(v1, v2) / (norm1 * norm2))


def compute_learned_similarity_matrix(
    user_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]],
    user_order: List[str]
) -> np.ndarray:
    """
    Compute similarity matrix from learned user vectors.
    
    Uses concatenated [z_long, z_short] as the user representation.
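
    Users missing from `user_vectors` get zero similarity entries,
    except the diagonal, which is forced to 1.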
    """
    n = len(user_order)
    sim_matrix = np.zeros((n, n))
    
    for i, u1 in enumerate(user_order):
        for j, u2 in enumerate(user_order):
            if u1 in user_vectors and u2 in user_vectors:
                z1 = np.concatenate(user_vectors[u1])
                z2 = np.concatenate(user_vectors[u2])
                sim_matrix[i, j] = cosine_similarity(z1, z2)
            elif i == j:
                sim_matrix[i, j] = 1.0
    
    return sim_matrix


def compute_ground_truth_similarity(
    personas: Dict[str, StylePrefs],
    user_order: List[str]
) -> np.ndarray:
    """
    Compute ground truth similarity based on preference overlap.
    
    Uses a simple matching coefficient over three dimensions:
    - short: +1 if the require_short values match
    - bullets: +1 if the require_bullets values match
    - lang: +1 if the lang values match

    The match count is divided by 3, giving a score in [0, 1].
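
    Example: user_A (short, bullets, en) vs. user_D (short, bullets, zh)
    match on short and bullets but differ on lang, giving 2/3.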
    """
    n = len(user_order)
    sim_matrix = np.zeros((n, n))
    
    for i, u1 in enumerate(user_order):
        for j, u2 in enumerate(user_order):
            if u1 not in personas or u2 not in personas:
                sim_matrix[i, j] = 0.0 if i != j else 1.0
                continue
            
            p1 = personas[u1]
            p2 = personas[u2]
            
            # Count matching dimensions
            matches = 0
            total = 3  # short, bullets, lang
            
            if p1.require_short == p2.require_short:
                matches += 1
            if p1.require_bullets == p2.require_bullets:
                matches += 1
            if p1.lang == p2.lang:
                matches += 1
            
            sim_matrix[i, j] = matches / total
    
    return sim_matrix


def compute_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> Tuple[float, float]:
    """
    Compute Pearson and Spearman correlation between learned and ground truth similarity.
    Only uses the upper triangle (excluding the diagonal), so symmetric
    pairs are not double-counted and the trivial diagonal is ignored.
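    With n users this gives n*(n-1)/2 pairs (15 for the six personas).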
    """
    n = learned.shape[0]
    
    # Extract upper triangle (excluding diagonal)
    iu = np.triu_indices(n, k=1)
    learned_flat = learned[iu]
    gt_flat = ground_truth[iu]
    
    # Pearson correlation
    if np.std(learned_flat) < 1e-10 or np.std(gt_flat) < 1e-10:
        pearson = 0.0
    else:
        pearson = float(np.corrcoef(learned_flat, gt_flat)[0, 1])
    
    # Spearman correlation (rank-based); constant inputs would make
    # spearmanr return nan, so reuse the same degeneracy guard
    if np.std(learned_flat) < 1e-10 or np.std(gt_flat) < 1e-10:
        spearman = 0.0
    else:
        from scipy.stats import spearmanr
        spearman, _ = spearmanr(learned_flat, gt_flat)

    return pearson, float(spearman)


# =============================================================================
# Visualization
# =============================================================================

def print_similarity_matrix(matrix: np.ndarray, user_order: List[str], title: str):
    """Print similarity matrix in ASCII format."""
    print(f"\n{title}")
    print("=" * 70)
    
    # Short labels
    labels = [u.replace("user_", "").replace("_", " ")[:15] for u in user_order]
    
    # Header
    print(f"{'':>16}", end="")
    for label in labels:
        print(f"{label[:8]:>10}", end="")
    print()
    
    # Rows
    for i, label in enumerate(labels):
        print(f"{label:>16}", end="")
        for j in range(len(labels)):
            print(f"{matrix[i, j]:>10.3f}", end="")
        print()
    
    print()


def save_visualization(
    learned: np.ndarray,
    ground_truth: np.ndarray,
    user_order: List[str],
    output_path: str
):
    """Save similarity matrices as heatmap visualization."""
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        print("[Warning] matplotlib/seaborn not available, skipping visualization")
        return
    
    # Short labels
    labels = [u.replace("user_", "")[:12] for u in user_order]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Learned similarity
    sns.heatmap(learned, annot=True, fmt=".2f", 
                xticklabels=labels, yticklabels=labels,
                cmap="RdYlGn", vmin=-1, vmax=1,
                ax=axes[0])
    axes[0].set_title("Learned User Vector Similarity\n(cosine similarity)")
    axes[0].tick_params(axis='x', rotation=45)
    axes[0].tick_params(axis='y', rotation=0)
    
    # Ground truth similarity
    sns.heatmap(ground_truth, annot=True, fmt=".2f",
                xticklabels=labels, yticklabels=labels,
                cmap="RdYlGn", vmin=0, vmax=1,
                ax=axes[1])
    axes[1].set_title("Ground Truth Preference Overlap\n(matching coefficient)")
    axes[1].tick_params(axis='x', rotation=45)
    axes[1].tick_params(axis='y', rotation=0)
    
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Visualization] Saved to: {output_path}")


# =============================================================================
# Main Analysis
# =============================================================================

def analyze_user_similarity(user_store_path: str, output_dir: str = "data/analysis"):
    """Run full user similarity analysis."""
    import os
    os.makedirs(output_dir, exist_ok=True)
    
    print("=" * 70)
    print("USER VECTOR SIMILARITY ANALYSIS")
    print("=" * 70)
    print(f"User store: {user_store_path}")
    
    # Load user vectors
    print("\n[1] Loading user vectors...")
    user_vectors = load_user_vectors(user_store_path)
    
    if not user_vectors:
        print("[Warning] No user vectors found with standard format, trying alternative...")
        user_vectors = load_user_vectors_from_internal(user_store_path)
    
    if not user_vectors:
        print("[Error] Could not load user vectors!")
        return
    
    print(f"  Found {len(user_vectors)} users: {list(user_vectors.keys())}")
    
    # Print vector norms
    print("\n[2] User vector norms:")
    for uid, (z_long, z_short) in user_vectors.items():
        print(f"  {uid}: ||z_long||={np.linalg.norm(z_long):.4f}, ||z_short||={np.linalg.norm(z_short):.4f}")
    
    # Determine user order (intersection of loaded users and known personas)
    user_order = [u for u in PERSONAS.keys() if u in user_vectors]
    print(f"\n[3] Analyzing {len(user_order)} users: {user_order}")
    
    if len(user_order) < 2:
        print("[Error] Need at least 2 users for similarity analysis!")
        return
    
    # Compute similarity matrices
    print("\n[4] Computing similarity matrices...")
    learned_sim = compute_learned_similarity_matrix(user_vectors, user_order)
    gt_sim = compute_ground_truth_similarity(PERSONAS, user_order)
    
    # Print matrices
    print_similarity_matrix(learned_sim, user_order, "LEARNED SIMILARITY (Cosine of z_u)")
    print_similarity_matrix(gt_sim, user_order, "GROUND TRUTH SIMILARITY (Preference Overlap)")
    
    # Compute correlation
    print("\n[5] Correlation Analysis:")
    print("-" * 50)
    pearson, spearman = compute_correlation(learned_sim, gt_sim)
    print(f"  Pearson correlation:  {pearson:.4f}")
    print(f"  Spearman correlation: {spearman:.4f}")
    
    # Interpretation
    print("\n[6] Interpretation:")
    print("-" * 50)
    if spearman > 0.7:
        print("  ✅ STRONG correlation: User vectors encode preference similarity well!")
    elif spearman > 0.4:
        print("  ⚠️  MODERATE correlation: User vectors partially capture preferences.")
    elif spearman > 0:
        print("  ⚠️  WEAK correlation: User vectors weakly capture preferences.")
    else:
        print("  ❌ NO/NEGATIVE correlation: User vectors do not reflect preferences.")
    
    # Key comparisons
    print("\n[7] Key Similarity Comparisons:")
    print("-" * 50)
    
    def get_sim(u1, u2, matrix, user_order):
        if u1 in user_order and u2 in user_order:
            i, j = user_order.index(u1), user_order.index(u2)
            return matrix[i, j]
        return None
    
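    # Each tuple is (u1, u2, op, u3, u4, description): the check asserts
    # sim(u1, u2) op sim(u3, u4) on the learned similarity matrix.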
    comparisons = [
        ("user_A_short_bullets_en", "user_F_extreme_short_en", ">", "user_A_short_bullets_en", "user_E_long_no_bullets_zh",
         "A~F (both short+bullets) should be > A~E (opposite)"),
        ("user_A_short_bullets_en", "user_D_short_bullets_zh", ">", "user_A_short_bullets_en", "user_C_long_bullets_en",
         "A~D (both short+bullets) should be > A~C (only bullets match)"),
        ("user_B_short_no_bullets_en", "user_E_long_no_bullets_zh", ">", "user_B_short_no_bullets_en", "user_A_short_bullets_en",
         "B~E (both no_bullets) should be > B~A (bullets differ)"),
    ]
    
    for u1, u2, op, u3, u4, desc in comparisons:
        sim1 = get_sim(u1, u2, learned_sim, user_order)
        sim2 = get_sim(u3, u4, learned_sim, user_order)
        
        if sim1 is not None and sim2 is not None:
            passed = sim1 > sim2 if op == ">" else sim1 < sim2
            status = "✅ PASS" if passed else "❌ FAIL"
            print(f"  {status}: sim({u1[:6]},{u2[:6]})={sim1:.3f} {op} sim({u3[:6]},{u4[:6]})={sim2:.3f}")
            print(f"          ({desc})")
    
    # Save visualization
    print("\n[8] Saving visualization...")
    output_path = os.path.join(output_dir, "user_similarity_matrix.png")
    save_visualization(learned_sim, gt_sim, user_order, output_path)
    
    # Save numerical results
    results_path = os.path.join(output_dir, "user_similarity_results.npz")
    np.savez(results_path,
             learned_similarity=learned_sim,
             ground_truth_similarity=gt_sim,
             user_order=user_order,
             pearson=pearson,
             spearman=spearman)
    print(f"[Results] Saved to: {results_path}")
    
    print("\n" + "=" * 70)
    print("ANALYSIS COMPLETE")
    print("=" * 70)


def main():
    parser = argparse.ArgumentParser(description="User Vector Similarity Analysis")
    parser.add_argument("--user-store", type=str, required=True,
                        help="Path to user store npz file")
    parser.add_argument("--output-dir", type=str, default="data/analysis",
                        help="Output directory for results")
    args = parser.parse_args()
    
    analyze_user_similarity(args.user_store, args.output_dir)


if __name__ == "__main__":
    main()