| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/analyze_user_similarity.py | |
Diffstat (limited to 'scripts/analyze_user_similarity.py')
| -rw-r--r-- | scripts/analyze_user_similarity.py | 445 |
1 file changed, 445 insertions, 0 deletions
diff --git a/scripts/analyze_user_similarity.py b/scripts/analyze_user_similarity.py
new file mode 100644
index 0000000..538a89a
--- /dev/null
+++ b/scripts/analyze_user_similarity.py
@@ -0,0 +1,445 @@

#!/usr/bin/env python3
"""
User Vector Similarity Analysis

This script analyzes the similarity between user vectors (z_u) learned by the
online personalization system. It computes:
1. Cosine similarity matrix between all user vectors
2. Ground truth similarity based on preference overlap
3. Correlation between learned and expected similarities

Usage:
    python scripts/analyze_user_similarity.py \
        --user-store data/users/user_store_pilot_v4_full-greedy.npz
"""

import argparse
import numpy as np
from typing import Dict, List, Tuple
from dataclasses import dataclass


# =============================================================================
# Persona Definitions (must match pilot_runner_v4.py)
# =============================================================================

@dataclass
class StylePrefs:
    """User's TRUE style preferences."""
    require_short: bool = False
    max_chars: int = 300
    require_bullets: bool = False
    lang: str = "en"


# Ground truth personas
PERSONAS = {
    "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
    "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
    "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
    "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
    "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
    "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
}


# =============================================================================
# User Vector Loading
# =============================================================================

def load_user_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Load user vectors from a saved user store.

    Returns:
        {user_id: (z_long, z_short)}
    """
    data = np.load(user_store_path, allow_pickle=True)

    user_vectors = {}

    # UserTensorStore saves in format: {uid}_long, {uid}_short, {uid}_meta
    # First, find all unique user IDs
    user_ids = set()
    for key in data.files:
        if key.endswith("_long"):
            uid = key[:-5]  # Remove "_long"
            user_ids.add(uid)

    # Load vectors for each user
    for uid in user_ids:
        long_key = f"{uid}_long"
        short_key = f"{uid}_short"

        if long_key in data.files and short_key in data.files:
            z_long = data[long_key]
            z_short = data[short_key]
            user_vectors[uid] = (z_long, z_short)

    return user_vectors
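
# Hypothetical illustration of the npz layout load_user_vectors() expects
# (key names follow the {uid}_long / {uid}_short convention above; the
# 16-dim zero vectors are placeholders, not real pilot output):
#   np.savez("user_store.npz",
#            user_A_short_bullets_en_long=np.zeros(16),
#            user_A_short_bullets_en_short=np.zeros(16))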
+ """ + data = np.load(user_store_path, allow_pickle=True) + + print(f"[Debug] Available keys in npz: {list(data.files)}") + + user_vectors = {} + + # Try to find user vectors in various formats + for key in data.files: + print(f" {key}: shape={data[key].shape if hasattr(data[key], 'shape') else 'N/A'}") + + # Format 1: Separate arrays per user + seen_users = set() + for key in data.files: + if "_z_long" in key or key.startswith("z_long_"): + # Extract user_id + if key.startswith("z_long_"): + user_id = key[7:] # Remove "z_long_" + else: + user_id = key.split("_z_long")[0] + seen_users.add(user_id) + + for user_id in seen_users: + # Try different key formats + z_long_keys = [f"z_long_{user_id}", f"{user_id}_z_long"] + z_short_keys = [f"z_short_{user_id}", f"{user_id}_z_short"] + + z_long = None + z_short = None + + for k in z_long_keys: + if k in data.files: + z_long = data[k] + break + + for k in z_short_keys: + if k in data.files: + z_short = data[k] + break + + if z_long is not None and z_short is not None: + user_vectors[user_id] = (z_long, z_short) + + return user_vectors + + +# ============================================================================= +# Similarity Computation +# ============================================================================= + +def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float: + """Compute cosine similarity between two vectors.""" + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + + if norm1 < 1e-10 or norm2 < 1e-10: + return 0.0 + + return float(np.dot(v1, v2) / (norm1 * norm2)) + + +def compute_learned_similarity_matrix( + user_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]], + user_order: List[str] +) -> np.ndarray: + """ + Compute similarity matrix from learned user vectors. + + Uses concatenated [z_long, z_short] as the user representation. + """ + n = len(user_order) + sim_matrix = np.zeros((n, n)) + + for i, u1 in enumerate(user_order): + for j, u2 in enumerate(user_order): + if u1 in user_vectors and u2 in user_vectors: + z1 = np.concatenate(user_vectors[u1]) + z2 = np.concatenate(user_vectors[u2]) + sim_matrix[i, j] = cosine_similarity(z1, z2) + elif i == j: + sim_matrix[i, j] = 1.0 + + return sim_matrix + + +def compute_ground_truth_similarity( + personas: Dict[str, StylePrefs], + user_order: List[str] +) -> np.ndarray: + """ + Compute ground truth similarity based on preference overlap. + + Uses Jaccard-like similarity: + - short: +1 if both require_short or both don't + - bullets: +1 if both require_bullets match + - lang: +1 if both lang match + + Then normalize to [0, 1]. + """ + n = len(user_order) + sim_matrix = np.zeros((n, n)) + + for i, u1 in enumerate(user_order): + for j, u2 in enumerate(user_order): + if u1 not in personas or u2 not in personas: + sim_matrix[i, j] = 0.0 if i != j else 1.0 + continue + + p1 = personas[u1] + p2 = personas[u2] + + # Count matching dimensions + matches = 0 + total = 3 # short, bullets, lang + + if p1.require_short == p2.require_short: + matches += 1 + if p1.require_bullets == p2.require_bullets: + matches += 1 + if p1.lang == p2.lang: + matches += 1 + + sim_matrix[i, j] = matches / total + + return sim_matrix + + +def compute_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> Tuple[float, float]: + """ + Compute Pearson and Spearman correlation between learned and ground truth similarity. + Only uses upper triangle (excluding diagonal) to avoid bias. 
+ """ + n = learned.shape[0] + + # Extract upper triangle (excluding diagonal) + learned_flat = [] + gt_flat = [] + + for i in range(n): + for j in range(i + 1, n): + learned_flat.append(learned[i, j]) + gt_flat.append(ground_truth[i, j]) + + learned_flat = np.array(learned_flat) + gt_flat = np.array(gt_flat) + + # Pearson correlation + if np.std(learned_flat) < 1e-10 or np.std(gt_flat) < 1e-10: + pearson = 0.0 + else: + pearson = float(np.corrcoef(learned_flat, gt_flat)[0, 1]) + + # Spearman correlation (rank-based) + from scipy.stats import spearmanr + spearman, _ = spearmanr(learned_flat, gt_flat) + + return pearson, float(spearman) + + +# ============================================================================= +# Visualization +# ============================================================================= + +def print_similarity_matrix(matrix: np.ndarray, user_order: List[str], title: str): + """Print similarity matrix in ASCII format.""" + print(f"\n{title}") + print("=" * 70) + + # Short labels + labels = [u.replace("user_", "").replace("_", " ")[:15] for u in user_order] + + # Header + print(f"{'':>16}", end="") + for label in labels: + print(f"{label[:8]:>10}", end="") + print() + + # Rows + for i, label in enumerate(labels): + print(f"{label:>16}", end="") + for j in range(len(labels)): + print(f"{matrix[i, j]:>10.3f}", end="") + print() + + print() + + +def save_visualization( + learned: np.ndarray, + ground_truth: np.ndarray, + user_order: List[str], + output_path: str +): + """Save similarity matrices as heatmap visualization.""" + try: + import matplotlib.pyplot as plt + import seaborn as sns + except ImportError: + print("[Warning] matplotlib/seaborn not available, skipping visualization") + return + + # Short labels + labels = [u.replace("user_", "")[:12] for u in user_order] + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + # Learned similarity + sns.heatmap(learned, annot=True, fmt=".2f", + xticklabels=labels, yticklabels=labels, + cmap="RdYlGn", vmin=-1, vmax=1, + ax=axes[0]) + axes[0].set_title("Learned User Vector Similarity\n(cosine similarity)") + axes[0].tick_params(axis='x', rotation=45) + axes[0].tick_params(axis='y', rotation=0) + + # Ground truth similarity + sns.heatmap(ground_truth, annot=True, fmt=".2f", + xticklabels=labels, yticklabels=labels, + cmap="RdYlGn", vmin=0, vmax=1, + ax=axes[1]) + axes[1].set_title("Ground Truth Preference Overlap\n(Jaccard-like)") + axes[1].tick_params(axis='x', rotation=45) + axes[1].tick_params(axis='y', rotation=0) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"[Visualization] Saved to: {output_path}") + + +# ============================================================================= +# Main Analysis +# ============================================================================= + +def analyze_user_similarity(user_store_path: str, output_dir: str = "data/analysis"): + """Run full user similarity analysis.""" + import os + os.makedirs(output_dir, exist_ok=True) + + print("=" * 70) + print("USER VECTOR SIMILARITY ANALYSIS") + print("=" * 70) + print(f"User store: {user_store_path}") + + # Load user vectors + print("\n[1] Loading user vectors...") + user_vectors = load_user_vectors(user_store_path) + + if not user_vectors: + print("[Warning] No user vectors found with standard format, trying alternative...") + user_vectors = load_user_vectors_from_internal(user_store_path) + + if not user_vectors: + print("[Error] Could not load user vectors!") + return + + print(f" Found 


# =============================================================================
# Main Analysis
# =============================================================================

def analyze_user_similarity(user_store_path: str, output_dir: str = "data/analysis"):
    """Run the full user similarity analysis."""
    import os
    os.makedirs(output_dir, exist_ok=True)

    print("=" * 70)
    print("USER VECTOR SIMILARITY ANALYSIS")
    print("=" * 70)
    print(f"User store: {user_store_path}")

    # Load user vectors
    print("\n[1] Loading user vectors...")
    user_vectors = load_user_vectors(user_store_path)

    if not user_vectors:
        print("[Warning] No user vectors found in the standard format, trying alternative...")
        user_vectors = load_user_vectors_from_internal(user_store_path)

    if not user_vectors:
        print("[Error] Could not load user vectors!")
        return

    print(f"    Found {len(user_vectors)} users: {list(user_vectors.keys())}")

    # Print vector norms
    print("\n[2] User vector norms:")
    for uid, (z_long, z_short) in user_vectors.items():
        print(f"    {uid}: ||z_long||={np.linalg.norm(z_long):.4f}, ||z_short||={np.linalg.norm(z_short):.4f}")

    # Determine user order (intersection of loaded users and known personas)
    user_order = [u for u in PERSONAS.keys() if u in user_vectors]
    print(f"\n[3] Analyzing {len(user_order)} users: {user_order}")

    if len(user_order) < 2:
        print("[Error] Need at least 2 users for similarity analysis!")
        return

    # Compute similarity matrices
    print("\n[4] Computing similarity matrices...")
    learned_sim = compute_learned_similarity_matrix(user_vectors, user_order)
    gt_sim = compute_ground_truth_similarity(PERSONAS, user_order)

    # Print matrices
    print_similarity_matrix(learned_sim, user_order, "LEARNED SIMILARITY (Cosine of z_u)")
    print_similarity_matrix(gt_sim, user_order, "GROUND TRUTH SIMILARITY (Preference Overlap)")

    # Compute correlation
    print("\n[5] Correlation Analysis:")
    print("-" * 50)
    pearson, spearman = compute_correlation(learned_sim, gt_sim)
    print(f"    Pearson correlation:  {pearson:.4f}")
    print(f"    Spearman correlation: {spearman:.4f}")

    # Interpretation
    print("\n[6] Interpretation:")
    print("-" * 50)
    if spearman > 0.7:
        print("  ✅ STRONG correlation: user vectors encode preference similarity well!")
    elif spearman > 0.4:
        print("  ⚠️ MODERATE correlation: user vectors partially capture preferences.")
    elif spearman > 0:
        print("  ⚠️ WEAK correlation: user vectors only weakly capture preferences.")
    else:
        print("  ❌ NO/NEGATIVE correlation: user vectors do not reflect preferences.")

    # Key comparisons
    print("\n[7] Key Similarity Comparisons:")
    print("-" * 50)

    def get_sim(u1, u2, matrix, user_order):
        if u1 in user_order and u2 in user_order:
            i, j = user_order.index(u1), user_order.index(u2)
            return matrix[i, j]
        return None

    comparisons = [
        ("user_A_short_bullets_en", "user_F_extreme_short_en", ">",
         "user_A_short_bullets_en", "user_E_long_no_bullets_zh",
         "A~F (match on all three scored dimensions) should be > A~E (match on none)"),
        ("user_A_short_bullets_en", "user_D_short_bullets_zh", ">",
         "user_A_short_bullets_en", "user_C_long_bullets_en",
         "A~D (short+bullets match, lang differs) should be > A~C (bullets+lang match, length differs)"),
        ("user_B_short_no_bullets_en", "user_E_long_no_bullets_zh", ">",
         "user_B_short_no_bullets_en", "user_A_short_bullets_en",
         "B~E (both no_bullets) should be > B~A (bullets differ)"),
    ]

    for u1, u2, op, u3, u4, desc in comparisons:
        sim1 = get_sim(u1, u2, learned_sim, user_order)
        sim2 = get_sim(u3, u4, learned_sim, user_order)

        if sim1 is not None and sim2 is not None:
            passed = sim1 > sim2 if op == ">" else sim1 < sim2
            status = "✅ PASS" if passed else "❌ FAIL"
            print(f"  {status}: sim({u1[:6]},{u2[:6]})={sim1:.3f} {op} sim({u3[:6]},{u4[:6]})={sim2:.3f}")
            print(f"          ({desc})")

    # Save visualization
    print("\n[8] Saving visualization...")
    output_path = os.path.join(output_dir, "user_similarity_matrix.png")
    save_visualization(learned_sim, gt_sim, user_order, output_path)

    # Save numerical results
    results_path = os.path.join(output_dir, "user_similarity_results.npz")
    np.savez(results_path,
             learned_similarity=learned_sim,
             ground_truth_similarity=gt_sim,
             user_order=user_order,
             pearson=pearson,
             spearman=spearman)
    print(f"[Results] Saved to: {results_path}")

    print("\n" + "=" * 70)
    print("ANALYSIS COMPLETE")
    print("=" * 70)


def main():
    parser = argparse.ArgumentParser(description="User Vector Similarity Analysis")
    parser.add_argument("--user-store", type=str, required=True,
                        help="Path to user store npz file")
    parser.add_argument("--output-dir", type=str, default="data/analysis",
                        help="Output directory for results")
    args = parser.parse_args()

    analyze_user_similarity(args.user_store, args.output_dir)


if __name__ == "__main__":
    main()
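
To smoke-test the script without a real pilot run, one can generate a synthetic store in the {uid}_long / {uid}_short layout that load_user_vectors expects. A minimal sketch, assuming hypothetical placeholders (the output filename and the 16-dimensional random vectors are illustrative, not real pilot artifacts):

#!/usr/bin/env python3
"""Generate a synthetic user store for smoke-testing analyze_user_similarity.py."""
import os
import numpy as np

rng = np.random.default_rng(0)
user_ids = [
    "user_A_short_bullets_en", "user_B_short_no_bullets_en",
    "user_C_long_bullets_en", "user_D_short_bullets_zh",
    "user_E_long_no_bullets_zh", "user_F_extreme_short_en",
]

arrays = {}
for uid in user_ids:
    # Random vectors stand in for the learned z_long / z_short (dimension is arbitrary here)
    arrays[f"{uid}_long"] = rng.normal(size=16)
    arrays[f"{uid}_short"] = rng.normal(size=16)

os.makedirs("data/users", exist_ok=True)
np.savez("data/users/user_store_synthetic.npz", **arrays)

Then run:

    python scripts/analyze_user_similarity.py --user-store data/users/user_store_synthetic.npz

Since random vectors carry no preference signal, the reported correlations should come out near zero; a genuinely trained store should score visibly higher.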
