Diffstat (limited to 'scripts/analyze_user_similarity.py')
-rw-r--r--  scripts/analyze_user_similarity.py  445
1 file changed, 445 insertions(+), 0 deletions(-)
diff --git a/scripts/analyze_user_similarity.py b/scripts/analyze_user_similarity.py
new file mode 100644
index 0000000..538a89a
--- /dev/null
+++ b/scripts/analyze_user_similarity.py
@@ -0,0 +1,445 @@
+#!/usr/bin/env python3
+"""
+User Vector Similarity Analysis
+
+This script analyzes the similarity between user vectors (z_u) learned by the
+online personalization system. It computes:
+1. Cosine similarity matrix between all user vectors
+2. Ground truth similarity based on preference overlap
+3. Correlation between learned and expected similarities
+
+Usage:
+ python scripts/analyze_user_similarity.py \
+ --user-store data/users/user_store_pilot_v4_full-greedy.npz
+"""
+
+import argparse
+import numpy as np
+from typing import Dict, List, Tuple
+from dataclasses import dataclass
+
+
+# =============================================================================
+# Persona Definitions (must match pilot_runner_v4.py)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """User's TRUE style preferences."""
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en"
+
+
+# Ground truth personas
+PERSONAS = {
+ "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
+ "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
+ "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
+ "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
+ "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
+ "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
+}
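+# Under the ground-truth metric below, user_A and user_F match on all three
+# scored dimensions (short, bullets, lang) despite different max_chars, so
+# their expected similarity is 1.0; user_A and user_E match on none, so 0.0.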
+
+
+# =============================================================================
+# User Vector Loading
+# =============================================================================
+
+def load_user_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
+ """
+ Load user vectors from saved user store.
+
+ Returns:
+ {user_id: (z_long, z_short)}
+ """
+ data = np.load(user_store_path, allow_pickle=True)
+
+ user_vectors = {}
+
+ # UserTensorStore saves in format: {uid}_long, {uid}_short, {uid}_meta
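+    # e.g. uid "user_A_short_bullets_en" is stored under the keys
+    # "user_A_short_bullets_en_long", "user_A_short_bullets_en_short",
+    # and "user_A_short_bullets_en_meta"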
+ # First, find all unique user IDs
+ user_ids = set()
+ for key in data.files:
+ if key.endswith("_long"):
+ uid = key[:-5] # Remove "_long"
+ user_ids.add(uid)
+
+ # Load vectors for each user
+ for uid in user_ids:
+ long_key = f"{uid}_long"
+ short_key = f"{uid}_short"
+
+ if long_key in data.files and short_key in data.files:
+ z_long = data[long_key]
+ z_short = data[short_key]
+ user_vectors[uid] = (z_long, z_short)
+
+ return user_vectors
+
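+# A hypothetical quick check (path taken from the usage example in the module
+# docstring; the uid is one of the PERSONAS defined above):
+#
+#   vecs = load_user_vectors("data/users/user_store_pilot_v4_full-greedy.npz")
+#   z_long, z_short = vecs["user_A_short_bullets_en"]
+#   print(z_long.shape, z_short.shape)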
+
+def load_user_vectors_from_internal(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
+ """
+    Fallback loader that probes alternative key layouts (z_long_{uid} or {uid}_z_long).
+ """
+ data = np.load(user_store_path, allow_pickle=True)
+
+ print(f"[Debug] Available keys in npz: {list(data.files)}")
+
+ user_vectors = {}
+
+    # Dump every key and its shape so the actual on-disk layout is visible
+ for key in data.files:
+ print(f" {key}: shape={data[key].shape if hasattr(data[key], 'shape') else 'N/A'}")
+
+ # Format 1: Separate arrays per user
+ seen_users = set()
+ for key in data.files:
+ if "_z_long" in key or key.startswith("z_long_"):
+ # Extract user_id
+ if key.startswith("z_long_"):
+ user_id = key[7:] # Remove "z_long_"
+ else:
+ user_id = key.split("_z_long")[0]
+ seen_users.add(user_id)
+
+ for user_id in seen_users:
+ # Try different key formats
+ z_long_keys = [f"z_long_{user_id}", f"{user_id}_z_long"]
+ z_short_keys = [f"z_short_{user_id}", f"{user_id}_z_short"]
+
+ z_long = None
+ z_short = None
+
+ for k in z_long_keys:
+ if k in data.files:
+ z_long = data[k]
+ break
+
+ for k in z_short_keys:
+ if k in data.files:
+ z_short = data[k]
+ break
+
+ if z_long is not None and z_short is not None:
+ user_vectors[user_id] = (z_long, z_short)
+
+ return user_vectors
+
+
+# =============================================================================
+# Similarity Computation
+# =============================================================================
+
+def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
+ """Compute cosine similarity between two vectors."""
+ norm1 = np.linalg.norm(v1)
+ norm2 = np.linalg.norm(v2)
+
+ if norm1 < 1e-10 or norm2 < 1e-10:
+ return 0.0
+
+ return float(np.dot(v1, v2) / (norm1 * norm2))
+
+
+def compute_learned_similarity_matrix(
+ user_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]],
+ user_order: List[str]
+) -> np.ndarray:
+ """
+ Compute similarity matrix from learned user vectors.
+
+ Uses concatenated [z_long, z_short] as the user representation.
+ """
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ if u1 in user_vectors and u2 in user_vectors:
+ z1 = np.concatenate(user_vectors[u1])
+ z2 = np.concatenate(user_vectors[u2])
+ sim_matrix[i, j] = cosine_similarity(z1, z2)
+ elif i == j:
+ sim_matrix[i, j] = 1.0
+
+ return sim_matrix
+
+
+def compute_ground_truth_similarity(
+ personas: Dict[str, StylePrefs],
+ user_order: List[str]
+) -> np.ndarray:
+ """
+ Compute ground truth similarity based on preference overlap.
+
+    Uses a Jaccard-like similarity over three preference dimensions:
+    - short: +1 if require_short matches
+    - bullets: +1 if require_bullets matches
+    - lang: +1 if lang matches
+
+    The match count is divided by 3, giving values in {0, 1/3, 2/3, 1}.
+    (max_chars is not scored separately; the short/long distinction is
+    already captured by require_short.)
+ """
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ if u1 not in personas or u2 not in personas:
+ sim_matrix[i, j] = 0.0 if i != j else 1.0
+ continue
+
+ p1 = personas[u1]
+ p2 = personas[u2]
+
+ # Count matching dimensions
+ matches = 0
+ total = 3 # short, bullets, lang
+
+ if p1.require_short == p2.require_short:
+ matches += 1
+ if p1.require_bullets == p2.require_bullets:
+ matches += 1
+ if p1.lang == p2.lang:
+ matches += 1
+
+ sim_matrix[i, j] = matches / total
+
+ return sim_matrix
+
+
+def compute_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> Tuple[float, float]:
+ """
+ Compute Pearson and Spearman correlation between learned and ground truth similarity.
+    Only uses the upper triangle (excluding the diagonal), so symmetric pairs
+    are not double-counted and the all-ones diagonal does not inflate agreement.
+ """
+ n = learned.shape[0]
+
+ # Extract upper triangle (excluding diagonal)
+ learned_flat = []
+ gt_flat = []
+
+ for i in range(n):
+ for j in range(i + 1, n):
+ learned_flat.append(learned[i, j])
+ gt_flat.append(ground_truth[i, j])
+
+ learned_flat = np.array(learned_flat)
+ gt_flat = np.array(gt_flat)
+
+ # Pearson correlation
+ if np.std(learned_flat) < 1e-10 or np.std(gt_flat) < 1e-10:
+ pearson = 0.0
+ else:
+ pearson = float(np.corrcoef(learned_flat, gt_flat)[0, 1])
+
+    # Spearman correlation (rank-based); spearmanr returns nan when either
+    # input is constant, so guard that degenerate case too
+    from scipy.stats import spearmanr
+    spearman, _ = spearmanr(learned_flat, gt_flat)
+    spearman = 0.0 if np.isnan(spearman) else float(spearman)
+
+    return pearson, spearman
+
+
+# =============================================================================
+# Visualization
+# =============================================================================
+
+def print_similarity_matrix(matrix: np.ndarray, user_order: List[str], title: str):
+ """Print similarity matrix in ASCII format."""
+ print(f"\n{title}")
+ print("=" * 70)
+
+ # Short labels
+ labels = [u.replace("user_", "").replace("_", " ")[:15] for u in user_order]
+
+ # Header
+ print(f"{'':>16}", end="")
+ for label in labels:
+ print(f"{label[:8]:>10}", end="")
+ print()
+
+ # Rows
+ for i, label in enumerate(labels):
+ print(f"{label:>16}", end="")
+ for j in range(len(labels)):
+ print(f"{matrix[i, j]:>10.3f}", end="")
+ print()
+
+ print()
+
+
+def save_visualization(
+ learned: np.ndarray,
+ ground_truth: np.ndarray,
+ user_order: List[str],
+ output_path: str
+):
+ """Save similarity matrices as heatmap visualization."""
+ try:
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ except ImportError:
+ print("[Warning] matplotlib/seaborn not available, skipping visualization")
+ return
+
+ # Short labels
+ labels = [u.replace("user_", "")[:12] for u in user_order]
+
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
+
+ # Learned similarity
+ sns.heatmap(learned, annot=True, fmt=".2f",
+ xticklabels=labels, yticklabels=labels,
+ cmap="RdYlGn", vmin=-1, vmax=1,
+ ax=axes[0])
+ axes[0].set_title("Learned User Vector Similarity\n(cosine similarity)")
+ axes[0].tick_params(axis='x', rotation=45)
+ axes[0].tick_params(axis='y', rotation=0)
+
+ # Ground truth similarity
+ sns.heatmap(ground_truth, annot=True, fmt=".2f",
+ xticklabels=labels, yticklabels=labels,
+ cmap="RdYlGn", vmin=0, vmax=1,
+ ax=axes[1])
+ axes[1].set_title("Ground Truth Preference Overlap\n(Jaccard-like)")
+ axes[1].tick_params(axis='x', rotation=45)
+ axes[1].tick_params(axis='y', rotation=0)
+
+ plt.tight_layout()
+    plt.savefig(output_path, dpi=150, bbox_inches='tight')
+    plt.close(fig)  # free the figure so repeated calls don't accumulate memory
+ print(f"[Visualization] Saved to: {output_path}")
+
+
+# =============================================================================
+# Main Analysis
+# =============================================================================
+
+def analyze_user_similarity(user_store_path: str, output_dir: str = "data/analysis"):
+ """Run full user similarity analysis."""
+ import os
+ os.makedirs(output_dir, exist_ok=True)
+
+ print("=" * 70)
+ print("USER VECTOR SIMILARITY ANALYSIS")
+ print("=" * 70)
+ print(f"User store: {user_store_path}")
+
+ # Load user vectors
+ print("\n[1] Loading user vectors...")
+ user_vectors = load_user_vectors(user_store_path)
+
+ if not user_vectors:
+ print("[Warning] No user vectors found with standard format, trying alternative...")
+ user_vectors = load_user_vectors_from_internal(user_store_path)
+
+ if not user_vectors:
+ print("[Error] Could not load user vectors!")
+ return
+
+ print(f" Found {len(user_vectors)} users: {list(user_vectors.keys())}")
+
+ # Print vector norms
+ print("\n[2] User vector norms:")
+ for uid, (z_long, z_short) in user_vectors.items():
+ print(f" {uid}: ||z_long||={np.linalg.norm(z_long):.4f}, ||z_short||={np.linalg.norm(z_short):.4f}")
+
+ # Determine user order (intersection of loaded users and known personas)
+ user_order = [u for u in PERSONAS.keys() if u in user_vectors]
+ print(f"\n[3] Analyzing {len(user_order)} users: {user_order}")
+
+ if len(user_order) < 2:
+ print("[Error] Need at least 2 users for similarity analysis!")
+ return
+
+ # Compute similarity matrices
+ print("\n[4] Computing similarity matrices...")
+ learned_sim = compute_learned_similarity_matrix(user_vectors, user_order)
+ gt_sim = compute_ground_truth_similarity(PERSONAS, user_order)
+
+ # Print matrices
+ print_similarity_matrix(learned_sim, user_order, "LEARNED SIMILARITY (Cosine of z_u)")
+ print_similarity_matrix(gt_sim, user_order, "GROUND TRUTH SIMILARITY (Preference Overlap)")
+
+ # Compute correlation
+ print("\n[5] Correlation Analysis:")
+ print("-" * 50)
+ pearson, spearman = compute_correlation(learned_sim, gt_sim)
+ print(f" Pearson correlation: {pearson:.4f}")
+ print(f" Spearman correlation: {spearman:.4f}")
+
+ # Interpretation
+ print("\n[6] Interpretation:")
+ print("-" * 50)
+ if spearman > 0.7:
+ print(" ✅ STRONG correlation: User vectors encode preference similarity well!")
+ elif spearman > 0.4:
+ print(" ⚠️ MODERATE correlation: User vectors partially capture preferences.")
+ elif spearman > 0:
+ print(" ⚠️ WEAK correlation: User vectors weakly capture preferences.")
+ else:
+ print(" ❌ NO/NEGATIVE correlation: User vectors do not reflect preferences.")
+
+ # Key comparisons
+ print("\n[7] Key Similarity Comparisons:")
+ print("-" * 50)
+
+ def get_sim(u1, u2, matrix, user_order):
+ if u1 in user_order and u2 in user_order:
+ i, j = user_order.index(u1), user_order.index(u2)
+ return matrix[i, j]
+ return None
+
+ comparisons = [
+ ("user_A_short_bullets_en", "user_F_extreme_short_en", ">", "user_A_short_bullets_en", "user_E_long_no_bullets_zh",
+ "A~F (both short+bullets) should be > A~E (opposite)"),
+ ("user_A_short_bullets_en", "user_D_short_bullets_zh", ">", "user_A_short_bullets_en", "user_C_long_bullets_en",
+ "A~D (both short+bullets) should be > A~C (only bullets match)"),
+ ("user_B_short_no_bullets_en", "user_E_long_no_bullets_zh", ">", "user_B_short_no_bullets_en", "user_A_short_bullets_en",
+ "B~E (both no_bullets) should be > B~A (bullets differ)"),
+ ]
+
+ for u1, u2, op, u3, u4, desc in comparisons:
+ sim1 = get_sim(u1, u2, learned_sim, user_order)
+ sim2 = get_sim(u3, u4, learned_sim, user_order)
+
+ if sim1 is not None and sim2 is not None:
+ passed = sim1 > sim2 if op == ">" else sim1 < sim2
+ status = "✅ PASS" if passed else "❌ FAIL"
+ print(f" {status}: sim({u1[:6]},{u2[:6]})={sim1:.3f} {op} sim({u3[:6]},{u4[:6]})={sim2:.3f}")
+ print(f" ({desc})")
+
+ # Save visualization
+ print("\n[8] Saving visualization...")
+ output_path = os.path.join(output_dir, "user_similarity_matrix.png")
+ save_visualization(learned_sim, gt_sim, user_order, output_path)
+
+ # Save numerical results
+ results_path = os.path.join(output_dir, "user_similarity_results.npz")
+ np.savez(results_path,
+ learned_similarity=learned_sim,
+ ground_truth_similarity=gt_sim,
+ user_order=user_order,
+ pearson=pearson,
+ spearman=spearman)
+ print(f"[Results] Saved to: {results_path}")
+
+ print("\n" + "=" * 70)
+ print("ANALYSIS COMPLETE")
+ print("=" * 70)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="User Vector Similarity Analysis")
+ parser.add_argument("--user-store", type=str, required=True,
+ help="Path to user store npz file")
+ parser.add_argument("--output-dir", type=str, default="data/analysis",
+ help="Output directory for results")
+ args = parser.parse_args()
+
+ analyze_user_similarity(args.user_store, args.output_dir)
+
+
+if __name__ == "__main__":
+ main()
+