summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/visualize_user_vectors.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/scripts/visualize_user_vectors.py')
-rw-r--r--collaborativeagents/scripts/visualize_user_vectors.py407
1 files changed, 407 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/visualize_user_vectors.py b/collaborativeagents/scripts/visualize_user_vectors.py
new file mode 100644
index 0000000..203cb68
--- /dev/null
+++ b/collaborativeagents/scripts/visualize_user_vectors.py
@@ -0,0 +1,407 @@
+#!/usr/bin/env python3
+"""
+User Vector Visualization Script
+
+Visualizes learned user vectors using t-SNE and PCA for dimensionality reduction.
+Supports multiple coloring schemes to analyze user clusters.
+
+Usage:
+ python visualize_user_vectors.py --results-dir ../results/fullrun_3methods
+ python visualize_user_vectors.py --vectors-file user_vectors.npy --profiles-file profiles.json
+"""
+
+import argparse
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+from sklearn.manifold import TSNE
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+import warnings
+warnings.filterwarnings('ignore')
+
+
def load_user_vectors(results_dir: Path) -> Tuple[np.ndarray, List[int]]:
    """Load user vectors from an experiment results directory.

    Sources are tried in this order and the first usable one wins:

    1. ``user_vectors.npy`` directly in ``results_dir``, then in its
       ``rag_vector`` and ``checkpoints`` sub-directories. An object-dtype
       file is unpacked with ``.item()`` and iterated as a mapping whose keys
       are converted with ``int()`` (assumes the file holds a pickled
       ``{user_id: vector}`` dict -- TODO confirm against the writer). Any
       other dtype is used as-is, with row indices as user ids.
    2. Every ``results.json`` found recursively below ``results_dir``
       (visited in sorted path order) whose top-level object is a dict
       containing a ``"user_vectors"`` mapping.

    Args:
        results_dir: Root directory of the experiment results.

    Returns:
        ``(vectors, user_ids)`` where ``vectors`` is ``np.array`` of the
        loaded vectors and ``user_ids`` is the parallel list of integer ids.

    Raises:
        FileNotFoundError: If no source yields user vectors.
    """
    possible_paths = [
        results_dir / "user_vectors.npy",
        results_dir / "rag_vector" / "user_vectors.npy",
        results_dir / "checkpoints" / "user_vectors.npy",
    ]

    for path in possible_paths:
        if not path.exists():
            continue
        # NOTE(security): allow_pickle=True executes pickled payloads while
        # loading; only point this script at result files you trust.
        data = np.load(path, allow_pickle=True)
        if not isinstance(data, np.ndarray):
            # Not a plain array (e.g. an archive); try the next location.
            continue
        if data.dtype == object:
            # Dictionary format: a 0-d object array wrapping a mapping.
            mapping = data.item()
            user_ids = [int(uid) for uid in mapping]
            vectors = list(mapping.values())
        else:
            # Direct array format: one row per user, ids are row indices.
            vectors = data
            user_ids = list(range(len(data)))
        print(f"Loaded {len(vectors)} user vectors from {path}")
        return np.array(vectors), user_ids

    # Fall back to vectors embedded in results.json files. Sorting makes the
    # choice deterministic when several files qualify.
    for rf in sorted(results_dir.glob("**/results.json")):
        try:
            with open(rf, encoding="utf-8") as f:
                data = json.load(f)
            if not (isinstance(data, dict) and "user_vectors" in data):
                continue
            # Accumulate per file so a file that fails halfway through
            # cannot leak partial entries into the next file's result.
            file_ids = []
            file_vectors = []
            for uid, vec in data["user_vectors"].items():
                file_ids.append(int(uid))
                file_vectors.append(np.array(vec))
            stacked = np.array(file_vectors)
        except (OSError, ValueError, TypeError, AttributeError) as exc:
            # Unreadable or malformed file: report it and keep searching.
            print(f"Skipping {rf}: {exc}")
            continue
        print(f"Loaded {len(file_vectors)} user vectors from {rf}")
        return stacked, file_ids

    raise FileNotFoundError(f"No user vectors found in {results_dir}")
+
+
def load_profiles(profiles_path: Path) -> List[Dict]:
    """Load user profiles used for labeling and coloring plots.

    Args:
        profiles_path: Path to the profiles file. A ``.jsonl`` suffix selects
            JSON Lines parsing (one JSON document per non-blank line); any
            other suffix is parsed as a single JSON document.

    Returns:
        For ``.jsonl`` files, the list of parsed lines. Otherwise whatever
        ``json.load`` returns for the file (presumably a list of profile
        dicts -- not validated here).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file (or a non-blank line) is not
            valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(profiles_path, encoding="utf-8") as f:
        if profiles_path.suffix == '.jsonl':
            # Blank or whitespace-only lines are skipped rather than handed
            # to json.loads, which would reject them.
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)
+
+
def extract_profile_features(profiles: List[Dict]) -> Dict[str, List]:
    """Extract per-profile features used to color the plots.

    Missing keys and explicit ``None`` values (JSON ``null``) are treated as
    empty, so a sparse profile contributes ``"unknown"`` / ``0`` / ``0``
    instead of raising.

    Args:
        profiles: Profile dicts; the keys read are ``"categories"``,
            ``"preferences"`` and ``"persona"``.

    Returns:
        A dict of three lists, each parallel to ``profiles``:
        ``"categories"`` (first entry of the profile's categories, or
        ``"unknown"``), ``"n_preferences"`` (length of the preferences) and
        ``"persona_length"`` (length of the persona).
    """
    features = {
        "categories": [],
        "n_preferences": [],
        "persona_length": [],
    }

    for profile in profiles:
        # ``or`` maps both a missing key and an explicit null to empty.
        cats = profile.get("categories") or []
        prefs = profile.get("preferences") or []
        persona = profile.get("persona") or ""

        features["categories"].append(cats[0] if cats else "unknown")
        features["n_preferences"].append(len(prefs))
        features["persona_length"].append(len(persona))

    return features
+
+
def apply_tsne(vectors: np.ndarray, perplexity: int = 30, max_iter: int = 1000) -> np.ndarray:
    """Reduce ``vectors`` to two dimensions with t-SNE.

    The input is standardized per feature first. The requested perplexity is
    capped at ``len(vectors) - 1`` before being passed to ``TSNE``; the
    embedding uses PCA initialization and a fixed ``random_state`` of 42.

    Args:
        vectors: Array of user vectors, one row per user.
        perplexity: Requested t-SNE perplexity (upper bound).
        max_iter: Forwarded to ``TSNE`` as ``max_iter``.

    Returns:
        The result of ``TSNE.fit_transform`` on the standardized vectors.
    """
    standardized = StandardScaler().fit_transform(vectors)

    # Cap the perplexity by the sample count.
    effective_perplexity = min(perplexity, len(vectors) - 1)

    reducer = TSNE(
        n_components=2,
        perplexity=effective_perplexity,
        max_iter=max_iter,
        random_state=42,
        init='pca',
        learning_rate='auto',
    )
    embedding = reducer.fit_transform(standardized)
    return embedding
+
+
def apply_pca(vectors: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Project standardized ``vectors`` with PCA.

    Up to 10 components are fitted. The count is additionally bounded by
    both the number of samples and the number of features, because PCA
    cannot extract more components than either dimension allows.

    Args:
        vectors: 2-D array of user vectors, one row per user.

    Returns:
        ``(projection, explained_variance_ratio)`` where ``projection`` holds
        the first two columns of the transformed data (fewer if fewer
        components were fitted) and ``explained_variance_ratio`` covers every
        fitted component.
    """
    vectors_scaled = StandardScaler().fit_transform(vectors)

    n_samples, n_features = vectors.shape[0], vectors.shape[1]
    # Bounding by n_samples as well keeps small cohorts (< 10 users) from
    # making PCA reject the component count.
    pca = PCA(n_components=min(10, n_samples, n_features))
    transformed = pca.fit_transform(vectors_scaled)

    return transformed[:, :2], pca.explained_variance_ratio_
+
+
def plot_comparison(
    vectors: np.ndarray,
    user_ids: List[int],
    profiles: Optional[List[Dict]] = None,
    output_path: Optional[Path] = None,
    title_prefix: str = ""
):
    """Create side-by-side t-SNE and PCA scatter plots of the user vectors.

    Points are colored by each profile's number of preferences when at least
    as many profiles as users are supplied; profiles are matched to vectors
    by position. Otherwise points are colored by user id.

    Args:
        vectors: Array of user vectors, one row per user.
        user_ids: Ids parallel to ``vectors``.
        profiles: Optional profile dicts used for coloring.
        output_path: If given, the figure is saved there before being shown.
        title_prefix: Text prepended to both subplot titles.

    Returns:
        ``(tsne_2d, pca_2d, pca_variance)`` as produced by ``apply_tsne`` and
        ``apply_pca``.
    """
    # Apply dimensionality reduction
    print("Applying t-SNE...")
    tsne_2d = apply_tsne(vectors)

    print("Applying PCA...")
    pca_2d, pca_variance = apply_pca(vectors)

    # Prepare coloring
    if profiles and len(profiles) >= len(user_ids):
        features = extract_profile_features(profiles)
        # There may be more profiles than vectors; scatter needs exactly one
        # color value per point, so keep only the leading entries.
        color_by = features["n_preferences"][:len(user_ids)]
        color_label = "Number of Preferences"
    else:
        color_by = user_ids
        color_label = "User ID"

    # Create figure
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # t-SNE plot
    ax1 = axes[0]
    scatter1 = ax1.scatter(
        tsne_2d[:, 0], tsne_2d[:, 1],
        c=color_by, cmap='viridis', alpha=0.7, s=50
    )
    ax1.set_xlabel('t-SNE Dimension 1')
    ax1.set_ylabel('t-SNE Dimension 2')
    ax1.set_title(f'{title_prefix}t-SNE Visualization\n({len(vectors)} users)')
    plt.colorbar(scatter1, ax=ax1, label=color_label)

    # PCA plot
    ax2 = axes[1]
    scatter2 = ax2.scatter(
        pca_2d[:, 0], pca_2d[:, 1],
        c=color_by, cmap='viridis', alpha=0.7, s=50
    )
    ax2.set_xlabel(f'PC1 ({pca_variance[0]*100:.1f}% variance)')
    ax2.set_ylabel(f'PC2 ({pca_variance[1]*100:.1f}% variance)')
    ax2.set_title(f'{title_prefix}PCA Visualization\n(Top 2 components: {(pca_variance[0]+pca_variance[1])*100:.1f}% variance)')
    plt.colorbar(scatter2, ax=ax2, label=color_label)

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {output_path}")

    plt.show()

    return tsne_2d, pca_2d, pca_variance
+
+
def plot_by_category(
    vectors: np.ndarray,
    user_ids: List[int],
    profiles: List[Dict],
    output_path: Optional[Path] = None
):
    """Create t-SNE and PCA scatter plots colored by preference category.

    Each user's category is the first category of the profile at the same
    position. Categories are numbered in sorted order so colors are stable
    across runs, and both panels and the legend share one colormap
    normalization so legend swatches match the plotted points.

    Args:
        vectors: Array of user vectors, one row per user.
        user_ids: Ids parallel to ``vectors``.
        profiles: Profile dicts; at least ``len(user_ids)`` are required.
        output_path: If given, the figure is saved there before being shown.

    Raises:
        ValueError: If fewer profiles than users are supplied.
    """
    # Imported locally: only this function needs legend proxy artists.
    from matplotlib.colors import Normalize
    from matplotlib.lines import Line2D

    features = extract_profile_features(profiles)
    categories = features["categories"][:len(user_ids)]
    if len(categories) < len(user_ids):
        raise ValueError(
            f"Need at least {len(user_ids)} profiles to color by category, "
            f"got {len(categories)}"
        )

    # Sorted (not raw set order) so the category -> color mapping does not
    # change from run to run.
    unique_cats = sorted(set(categories), key=str)
    cat_to_idx = {c: i for i, c in enumerate(unique_cats)}
    cat_colors = [cat_to_idx[c] for c in categories]

    cmap = plt.get_cmap('tab10')
    # One explicit normalization for the scatters and the legend; vmax is
    # floored at 1 to avoid a degenerate range with a single category.
    norm = Normalize(vmin=0, vmax=max(len(unique_cats) - 1, 1))

    # Apply reductions
    tsne_2d = apply_tsne(vectors)
    pca_2d, pca_variance = apply_pca(vectors)

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # t-SNE by category
    ax1 = axes[0]
    ax1.scatter(
        tsne_2d[:, 0], tsne_2d[:, 1],
        c=cat_colors, cmap=cmap, norm=norm, alpha=0.7, s=50
    )
    ax1.set_xlabel('t-SNE Dimension 1')
    ax1.set_ylabel('t-SNE Dimension 2')
    ax1.set_title('t-SNE by Preference Category')

    # PCA by category
    ax2 = axes[1]
    ax2.scatter(
        pca_2d[:, 0], pca_2d[:, 1],
        c=cat_colors, cmap=cmap, norm=norm, alpha=0.7, s=50
    )
    ax2.set_xlabel(f'PC1 ({pca_variance[0]*100:.1f}%)')
    ax2.set_ylabel(f'PC2 ({pca_variance[1]*100:.1f}%)')
    ax2.set_title('PCA by Preference Category')

    # Add legend (limited to 10 categories). Proxy Line2D artists are used
    # instead of empty scatter calls, which reject a color array whose size
    # differs from the (empty) point list.
    legend_cats = unique_cats[:10]
    handles = [
        Line2D(
            [], [],
            linestyle='', marker='o', markersize=8, alpha=0.7,
            color=cmap(norm(cat_to_idx[c])),
            label=str(c),
        )
        for c in legend_cats
    ]
    fig.legend(
        handles, [str(c) for c in legend_cats],
        loc='center right', title='Category'
    )

    plt.tight_layout()
    plt.subplots_adjust(right=0.85)

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved category plot to {output_path}")

    plt.show()
+
+
def plot_pca_variance(vectors: np.ndarray, output_path: Optional[Path] = None):
    """Plot per-component and cumulative PCA explained variance.

    Fits PCA on the standardized vectors with
    ``min(50, n_features, n_samples)`` components, draws a bar chart of the
    individual ratios next to a cumulative curve with 90% / 95% guide lines,
    and prints how many components are needed to reach each guide line.

    Args:
        vectors: 2-D array of user vectors, one row per user.
        output_path: If given, the figure is saved there before being shown.

    Returns:
        The fitted ``explained_variance_ratio_`` array.
    """
    vectors_scaled = StandardScaler().fit_transform(vectors)

    n_components = min(50, vectors.shape[1], vectors.shape[0])
    pca = PCA(n_components=n_components)
    pca.fit(vectors_scaled)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    component_numbers = range(1, n_components + 1)

    # Individual variance
    ax1 = axes[0]
    ax1.bar(component_numbers, pca.explained_variance_ratio_ * 100)
    ax1.set_xlabel('Principal Component')
    ax1.set_ylabel('Explained Variance (%)')
    ax1.set_title('PCA Explained Variance by Component')
    ax1.set_xlim(0, n_components + 1)

    # Cumulative variance
    ax2 = axes[1]
    cumvar = np.cumsum(pca.explained_variance_ratio_) * 100
    ax2.plot(component_numbers, cumvar, 'b-o', markersize=4)
    ax2.axhline(y=90, color='r', linestyle='--', label='90% variance')
    ax2.axhline(y=95, color='g', linestyle='--', label='95% variance')
    ax2.set_xlabel('Number of Components')
    ax2.set_ylabel('Cumulative Explained Variance (%)')
    ax2.set_title('PCA Cumulative Explained Variance')
    ax2.legend()
    ax2.set_xlim(0, n_components + 1)
    ax2.set_ylim(0, 105)

    # Find components needed for 90% and 95% variance
    for threshold in (90, 95):
        reached = cumvar >= threshold
        if reached.any():
            # argmax returns the first True, i.e. the first component count
            # at which the threshold is met.
            needed = int(np.argmax(reached)) + 1
            print(f"Components for {threshold}% variance: {needed}")
        else:
            # argmax of an all-False mask is 0, which would be misreported
            # as "1 component"; say explicitly that it was not reached.
            print(
                f"Components for {threshold}% variance: "
                f"not reached within the first {n_components} components"
            )

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved variance plot to {output_path}")

    plt.show()

    return pca.explained_variance_ratio_
+
+
def generate_synthetic_vectors(n_users: int = 200, dim: int = 64) -> np.ndarray:
    """Generate deterministic synthetic user vectors for testing the plots.

    Users are split into 5 equally sized clusters; each cluster is a random
    center (standard normal scaled by 2) plus per-user noise (standard
    normal scaled by 0.5). Users left over when ``n_users`` is not divisible
    by 5 are drawn from an unscaled standard normal.

    A private ``RandomState(42)`` is used instead of seeding the global
    NumPy generator: the produced values are the same as with
    ``np.random.seed(42)``, but the caller's random state is left untouched.

    Args:
        n_users: Total number of vectors to generate.
        dim: Dimensionality of each vector.

    Returns:
        Array of shape ``(n_users, dim)``.
    """
    rng = np.random.RandomState(42)

    # Create 5 clusters of users
    n_clusters = 5
    cluster_size = n_users // n_clusters
    vectors = []

    for _ in range(n_clusters):
        # Each cluster has a different center
        center = rng.randn(dim) * 2
        # Users in cluster are variations around center
        vectors.append(center + rng.randn(cluster_size, dim) * 0.5)

    # Add remaining users
    remaining = n_users - n_clusters * cluster_size
    if remaining > 0:
        vectors.append(rng.randn(remaining, dim))

    return np.vstack(vectors)
+
+
def main():
    """Command-line entry point: load user vectors and render every plot.

    Exactly one source is used, in priority order: ``--demo`` (synthetic
    vectors), ``--vectors-file`` (a ``.npy`` array, optionally paired with
    ``--profiles-file``), or ``--results-dir`` (vectors discovered via
    ``load_user_vectors``, profiles looked up in a few known locations).
    With no source given, a usage hint is printed and nothing is plotted.
    The output directory is created before the source is inspected.
    """
    cli = argparse.ArgumentParser(description="Visualize user vectors with t-SNE and PCA")
    for flag, help_text in (
        ("--results-dir", "Path to experiment results directory"),
        ("--vectors-file", "Path to user vectors .npy file"),
        ("--profiles-file", "Path to user profiles JSON file"),
    ):
        cli.add_argument(flag, type=str, help=help_text)
    cli.add_argument("--output-dir", type=str, default=".", help="Output directory for plots")
    cli.add_argument("--demo", action="store_true", help="Run demo with synthetic data")
    args = cli.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Guard clause: bail out early when no data source was selected.
    if not (args.demo or args.vectors_file or args.results_dir):
        print("Please provide --results-dir, --vectors-file, or --demo")
        return

    profiles = None
    title_prefix = "[Demo] " if args.demo else ""

    if args.demo:
        print("Running demo with synthetic user vectors...")
        vectors = generate_synthetic_vectors(200, 64)
        user_ids = list(range(200))
    elif args.vectors_file:
        vectors = np.load(args.vectors_file)
        user_ids = list(range(len(vectors)))
        if args.profiles_file:
            profiles = load_profiles(Path(args.profiles_file))
    else:
        results_dir = Path(args.results_dir)
        vectors, user_ids = load_user_vectors(results_dir)

        # Use the first profile file that exists, if any.
        candidates = (
            results_dir / "generated_profiles.json",
            results_dir.parent / "profiles.json",
            Path("../data/complex_profiles_v2/profiles_200.jsonl"),
        )
        found = next((pp for pp in candidates if pp.exists()), None)
        if found is not None:
            profiles = load_profiles(found)
            print(f"Loaded {len(profiles)} profiles from {found}")

    print(f"\nUser vectors shape: {vectors.shape}")
    print(f"Number of users: {len(user_ids)}")

    comparison_png = output_dir / "user_vectors_comparison.png"
    variance_png = output_dir / "user_vectors_pca_variance.png"
    category_png = output_dir / "user_vectors_by_category.png"

    print("\n=== Generating comparison plot ===")
    plot_comparison(
        vectors, user_ids, profiles,
        output_path=comparison_png,
        title_prefix=title_prefix,
    )

    print("\n=== Generating PCA variance plot ===")
    plot_pca_variance(vectors, output_path=variance_png)

    # The category plot needs one profile per user.
    if profiles and len(profiles) >= len(user_ids):
        print("\n=== Generating category plot ===")
        plot_by_category(vectors, user_ids, profiles, output_path=category_png)

    print("\n=== Done! ===")
    print(f"Plots saved to {output_dir}")


if __name__ == "__main__":
    main()