diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 20:16:36 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 20:16:36 +0000 |
| commit | 5626080ca4c4219aec4888d6b9406d0d3349fb55 (patch) | |
| tree | 86287d9fd5833e11ccd78566992540f2664fd195 /collaborativeagents/scripts | |
| parent | a2036838807428424bbbaff507a6563749a83145 (diff) | |
Add RAG rewrite, 60-session experiment scripts, and analysis tools
- RAG rewrite adapter and vector preference pipeline in personalized_llm
- 60-session experiment queue scripts (reflection, rag, rag_vector, rag_rewrite)
- Vector-preference correlation analysis and visualization scripts
- Local reward model batch processing improvements
- Updated CLAUDE.md with full experiment documentation and notes
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts')
| -rwxr-xr-x | collaborativeagents/scripts/analyze_vector_preference.py | 103 | ||||
| -rw-r--r-- | collaborativeagents/scripts/conflict_scenario_generator.py | 13 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_next_run.sh | 44 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_rag_60s.sh | 45 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_rag_rewrite.sh | 45 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_rag_rewrite_vector.sh | 45 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_rag_vector.sh | 45 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_reflection_60s.sh | 45 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_topk5_v2.sh | 44 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/queue_topk5_v3.sh | 52 | ||||
| -rw-r--r-- | collaborativeagents/scripts/run_experiments.py | 317 | ||||
| -rwxr-xr-x | collaborativeagents/scripts/test_new_rewrite.sh | 50 | ||||
| -rw-r--r-- | collaborativeagents/scripts/visualize_user_vectors.py | 407 |
13 files changed, 1231 insertions, 24 deletions
diff --git a/collaborativeagents/scripts/analyze_vector_preference.py b/collaborativeagents/scripts/analyze_vector_preference.py new file mode 100755 index 0000000..7079b26 --- /dev/null +++ b/collaborativeagents/scripts/analyze_vector_preference.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +""" +分析 user vector 与 revealed preference 之间的关联强度 +""" +import json +import numpy as np +from pathlib import Path +import sys + +def load_experiment(exp_dir): + """加载实验结果""" + exp_path = Path(exp_dir) + + # 找到结果目录 + for method in ["rag_vector", "rag_vector_fast", "rag_vector_balanced"]: + for sub in exp_path.iterdir(): + result_dir = sub / method + if result_dir.exists(): + vectors_path = result_dir / "user_vectors.npz" + results_path = result_dir / "results.json" + if vectors_path.exists() and results_path.exists(): + return { + "vectors": np.load(vectors_path, allow_pickle=True), + "results": json.load(open(results_path)), + "method": method + } + return None + +def analyze_vectors(data): + """分析user vectors""" + vectors = data["vectors"] + results = data["results"] + + user_ids = vectors["user_ids"] + z_long = vectors["z_long"] + z_short = vectors["z_short"] + + print(f"=== User Vector 分析 ===") + print(f"用户数: {len(user_ids)}") + print(f"Vector维度: {z_long.shape[1]}") + + # 计算非零vector数量 + z_long_norms = np.linalg.norm(z_long, axis=1) + z_short_norms = np.linalg.norm(z_short, axis=1) + + nonzero_long = np.count_nonzero(z_long_norms) + nonzero_short = np.count_nonzero(z_short_norms) + + print(f"\nz_long 非零用户: {nonzero_long}/{len(user_ids)}") + print(f"z_short 非零用户: {nonzero_short}/{len(user_ids)}") + print(f"z_long norm 均值: {np.mean(z_long_norms):.4f}") + print(f"z_short norm 均值: {np.mean(z_short_norms):.4f}") + + # 按用户分析性能与vector norm的关系 + print(f"\n=== Vector Norm vs 性能 ===") + + user_stats = {} + for s in results: + uid = s.get("profile_id", "") + if uid not in user_stats: + user_stats[uid] = {"success": 0, "total": 0, "enforce": 0} + m = s.get("metrics", {}) + 
user_stats[uid]["total"] += 1 + user_stats[uid]["success"] += 1 if m.get("task_success", False) else 0 + user_stats[uid]["enforce"] += m.get("enforcement_count", 0) + + # 计算相关性 + success_rates = [] + norms = [] + + for i, uid in enumerate(user_ids): + if uid in user_stats and user_stats[uid]["total"] > 0: + sr = user_stats[uid]["success"] / user_stats[uid]["total"] + success_rates.append(sr) + norms.append(z_long_norms[i]) + + if len(success_rates) > 5: + corr = np.corrcoef(success_rates, norms)[0, 1] + print(f"z_long norm vs 成功率 相关系数: {corr:.4f}") + + return { + "n_users": len(user_ids), + "nonzero_long": nonzero_long, + "nonzero_short": nonzero_short, + "mean_norm_long": float(np.mean(z_long_norms)), + "mean_norm_short": float(np.mean(z_short_norms)), + } + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python analyze_vector_preference.py <experiment_dir>") + print("Example: python analyze_vector_preference.py collaborativeagents/results/rag_vector_v3") + sys.exit(1) + + exp_dir = sys.argv[1] + data = load_experiment(exp_dir) + + if data is None: + print(f"未找到有效的rag_vector实验结果: {exp_dir}") + sys.exit(1) + + print(f"加载实验: {data['method']}") + analyze_vectors(data) diff --git a/collaborativeagents/scripts/conflict_scenario_generator.py b/collaborativeagents/scripts/conflict_scenario_generator.py index 9d00de8..eaf8ef2 100644 --- a/collaborativeagents/scripts/conflict_scenario_generator.py +++ b/collaborativeagents/scripts/conflict_scenario_generator.py @@ -367,11 +367,14 @@ class ConflictScenarioGenerator: # Find conflict groups in these preferences conflict_groups = {} for pref in preferences: - cg = pref.get('conflict_group') - if cg: - if cg not in conflict_groups: - conflict_groups[cg] = [] - conflict_groups[cg].append(pref) + # Handle both dict preferences (with conflict_group) and string preferences + if isinstance(pref, dict): + cg = pref.get('conflict_group') + if cg: + if cg not in conflict_groups: + conflict_groups[cg] = [] + 
conflict_groups[cg].append(pref) + # String preferences don't have conflict groups - skip them # Find a conflict group with at least 2 preferences for cg, prefs in conflict_groups.items(): diff --git a/collaborativeagents/scripts/queue_next_run.sh b/collaborativeagents/scripts/queue_next_run.sh new file mode 100755 index 0000000..524f0e2 --- /dev/null +++ b/collaborativeagents/scripts/queue_next_run.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Wait for current fullrun_4methods to finish, then start real-profile experiment + +LOG="/workspace/personalization-user-model/collaborativeagents/results/fullrun_4methods.log" +SCRIPTS_DIR="/workspace/personalization-user-model/collaborativeagents/scripts" + +echo "[$(date)] Waiting for fullrun_4methods to complete..." + +while true; do + if grep -q "EXPERIMENT COMPLETE" "$LOG" 2>/dev/null; then + echo "[$(date)] fullrun_4methods completed!" + break + fi + # Check if process died without completing + if ! pgrep -f "fullrun_4methods" > /dev/null 2>&1; then + if ! grep -q "EXPERIMENT COMPLETE" "$LOG" 2>/dev/null; then + echo "[$(date)] WARNING: process died before completion. Starting next run anyway." + break + fi + fi + sleep 60 +done + +echo "[$(date)] Starting real-profile experiment: 4 methods x 100 profiles x 20 sessions" + +cd "$SCRIPTS_DIR" +nohup python3 run_experiments.py \ + --methods vanilla,reflection,rag,rag_vector \ + --datasets math-hard \ + --n-profiles 100 \ + --n-sessions 20 \ + --max-turns 8 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --parallel-profiles 100 \ + --reward-mode llm_local \ + --reward-vllm-url http://localhost:8005/v1 \ + --profile-path ../data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir ../results/realprofile_4methods \ + > ../results/realprofile_4methods.log 2>&1 & + +echo "[$(date)] Experiment launched with PID $!" 
+echo "Log: /workspace/personalization-user-model/collaborativeagents/results/realprofile_4methods.log" diff --git a/collaborativeagents/scripts/queue_rag_60s.sh b/collaborativeagents/scripts/queue_rag_60s.sh new file mode 100755 index 0000000..87d6679 --- /dev/null +++ b/collaborativeagents/scripts/queue_rag_60s.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# 等待 reflection_60s 完成后启动 rag 60 session 实验 + +echo "等待 reflection_60s 完成..." + +while true; do + CKPT=$(ls collaborativeagents/results/reflection_60s/*/reflection/checkpoint.json 2>/dev/null | head -1) + if [ -n "$CKPT" ]; then + PROGRESS=$(cat "$CKPT" | python3 -c " +import json, sys +data = json.load(sys.stdin) +total = sum(data['sessions_per_profile'].values()) +print(total) +" 2>/dev/null) + + if [ "$PROGRESS" = "3600" ]; then + echo "$(date '+%H:%M:%S') reflection_60s 已完成" + break + fi + echo "$(date '+%H:%M:%S') reflection_60s 进度: $PROGRESS/3600" + else + echo "$(date '+%H:%M:%S') 等待reflection_60s启动..." + fi + sleep 60 +done + +echo "启动 rag_60s..." + +nohup python collaborativeagents/scripts/run_experiments.py \ + --methods rag \ + --datasets math-hard,math-500,bigcodebench \ + --n-profiles 60 \ + --n-sessions 60 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 20 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/rag_60s \ + > collaborativeagents/results/rag_60s.log 2>&1 & + +echo "rag_60s 已启动,PID: $!" diff --git a/collaborativeagents/scripts/queue_rag_rewrite.sh b/collaborativeagents/scripts/queue_rag_rewrite.sh new file mode 100755 index 0000000..a5b5b99 --- /dev/null +++ b/collaborativeagents/scripts/queue_rag_rewrite.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# 等待 rag_vector_60s 完成后启动 rag_rewrite 测试 + +echo "等待 rag_vector_60s 完成..." 
+ +while true; do + CKPT=$(ls collaborativeagents/results/rag_vector_60s/*/rag_vector/checkpoint.json 2>/dev/null | head -1) + if [ -n "$CKPT" ]; then + PROGRESS=$(cat "$CKPT" | python3 -c " +import json, sys +data = json.load(sys.stdin) +total = sum(data['sessions_per_profile'].values()) +print(total) +" 2>/dev/null) + + if [ "$PROGRESS" = "3600" ]; then + echo "$(date '+%H:%M:%S') rag_vector_60s 已完成" + break + fi + echo "$(date '+%H:%M:%S') rag_vector_60s 进度: $PROGRESS/3600" + else + echo "$(date '+%H:%M:%S') 等待rag_vector_60s启动..." + fi + sleep 60 +done + +echo "启动 rag_rewrite_60s..." + +nohup python collaborativeagents/scripts/run_experiments.py \ + --methods rag_rewrite \ + --datasets math-hard,math-500,bigcodebench \ + --n-profiles 60 \ + --n-sessions 60 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 20 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/rag_rewrite_60s \ + > collaborativeagents/results/rag_rewrite_60s.log 2>&1 & + +echo "rag_rewrite_60s 已启动,PID: $!" diff --git a/collaborativeagents/scripts/queue_rag_rewrite_vector.sh b/collaborativeagents/scripts/queue_rag_rewrite_vector.sh new file mode 100755 index 0000000..480e3ac --- /dev/null +++ b/collaborativeagents/scripts/queue_rag_rewrite_vector.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# 等待 rag_rewrite_60s 完成后启动 rag_rewrite_vector 测试 + +echo "等待 rag_rewrite_60s 完成..." 
+ +while true; do + CKPT=$(ls collaborativeagents/results/rag_rewrite_60s/*/rag_rewrite/checkpoint.json 2>/dev/null | head -1) + if [ -n "$CKPT" ]; then + PROGRESS=$(cat "$CKPT" | python3 -c " +import json, sys +data = json.load(sys.stdin) +total = sum(data['sessions_per_profile'].values()) +print(total) +" 2>/dev/null) + + if [ "$PROGRESS" = "3600" ]; then + echo "$(date '+%H:%M:%S') rag_rewrite_60s 已完成" + break + fi + echo "$(date '+%H:%M:%S') rag_rewrite_60s 进度: $PROGRESS/3600" + else + echo "$(date '+%H:%M:%S') 等待rag_rewrite_60s启动..." + fi + sleep 120 +done + +echo "启动 rag_rewrite_vector_60s..." + +nohup python collaborativeagents/scripts/run_experiments.py \ + --methods rag_rewrite_vector \ + --datasets math-hard,math-500,bigcodebench \ + --n-profiles 60 \ + --n-sessions 60 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 20 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/rag_rewrite_vector_60s \ + > collaborativeagents/results/rag_rewrite_vector_60s.log 2>&1 & + +echo "rag_rewrite_vector_60s 已启动,PID: $!" diff --git a/collaborativeagents/scripts/queue_rag_vector.sh b/collaborativeagents/scripts/queue_rag_vector.sh new file mode 100755 index 0000000..2760dc7 --- /dev/null +++ b/collaborativeagents/scripts/queue_rag_vector.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# 等待 topk5_v3 完成后启动 rag_vector 实验 + +echo "等待 topk5_v3 完成..." 
+ +while true; do + CKPT=$(ls collaborativeagents/results/rag_topk5_v3/*/rag/checkpoint.json 2>/dev/null | head -1) + if [ -n "$CKPT" ]; then + PROGRESS=$(cat "$CKPT" | python3 -c " +import json, sys +data = json.load(sys.stdin) +total = sum(data['sessions_per_profile'].values()) +print(total) +" 2>/dev/null) + + if [ "$PROGRESS" = "1800" ]; then + echo "$(date '+%H:%M:%S') topk5_v3 已完成" + break + fi + echo "$(date '+%H:%M:%S') topk5_v3 进度: $PROGRESS/1800" + else + echo "$(date '+%H:%M:%S') 等待topk5_v3启动..." + fi + sleep 60 +done + +echo "启动 rag_vector..." + +nohup python collaborativeagents/scripts/run_experiments.py \ + --methods rag_vector \ + --datasets math-hard,math-500,bigcodebench \ + --n-profiles 60 \ + --n-sessions 30 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 20 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/rag_vector_v3 \ + > collaborativeagents/results/rag_vector_v3.log 2>&1 & + +echo "rag_vector_v3 已启动,PID: $!" diff --git a/collaborativeagents/scripts/queue_reflection_60s.sh b/collaborativeagents/scripts/queue_reflection_60s.sh new file mode 100755 index 0000000..6d15073 --- /dev/null +++ b/collaborativeagents/scripts/queue_reflection_60s.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# 等待 rag_vector_v3 完成后启动 reflection 60 session 实验 + +echo "等待 rag_vector_v3 完成..." 
+ +while true; do + CKPT=$(ls collaborativeagents/results/rag_vector_v3/*/rag_vector/checkpoint.json 2>/dev/null | head -1) + if [ -n "$CKPT" ]; then + PROGRESS=$(cat "$CKPT" | python3 -c " +import json, sys +data = json.load(sys.stdin) +total = sum(data['sessions_per_profile'].values()) +print(total) +" 2>/dev/null) + + if [ "$PROGRESS" = "1800" ]; then + echo "$(date '+%H:%M:%S') rag_vector_v3 已完成" + break + fi + echo "$(date '+%H:%M:%S') rag_vector_v3 进度: $PROGRESS/1800" + else + echo "$(date '+%H:%M:%S') 等待rag_vector_v3启动..." + fi + sleep 60 +done + +echo "启动 reflection_60s..." + +nohup python collaborativeagents/scripts/run_experiments.py \ + --methods reflection \ + --datasets math-hard,math-500,bigcodebench \ + --n-profiles 60 \ + --n-sessions 60 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 20 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/reflection_60s \ + > collaborativeagents/results/reflection_60s.log 2>&1 & + +echo "reflection_60s 已启动,PID: $!" diff --git a/collaborativeagents/scripts/queue_topk5_v2.sh b/collaborativeagents/scripts/queue_topk5_v2.sh new file mode 100755 index 0000000..a116a28 --- /dev/null +++ b/collaborativeagents/scripts/queue_topk5_v2.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# 等待 rag_dynamic 实验完成后启动 topk5 新版本测试 + +echo "等待 rag_dynamic 实验完成..." 
+ +while true; do + # 检查是否有运行中的实验进程 + if pgrep -f "rag_dynamic" > /dev/null; then + # 检查checkpoint进度 + PROGRESS=$(cat collaborativeagents/results/rag_dynamic_test/*/rag_dynamic/checkpoint.json 2>/dev/null | python3 -c " +import json, sys +try: + data = json.load(sys.stdin) + total = sum(data['sessions_per_profile'].values()) + print(f'{total}/1800') +except: + print('0/1800') +" 2>/dev/null) + echo "$(date '+%H:%M:%S') rag_dynamic 进度: $PROGRESS" + sleep 60 + else + echo "rag_dynamic 已完成或未运行,启动 topk5_v2..." + break + fi +done + +# 启动新实验 +nohup python collaborativeagents/scripts/run_experiments.py \ + --methods rag \ + --datasets math-hard,math-500,bigcodebench \ + --n-profiles 60 \ + --n-sessions 30 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 20 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/rag_topk5_v2 \ + > collaborativeagents/results/rag_topk5_v2.log 2>&1 & + +echo "topk5_v2 已启动,PID: $!" diff --git a/collaborativeagents/scripts/queue_topk5_v3.sh b/collaborativeagents/scripts/queue_topk5_v3.sh new file mode 100755 index 0000000..e93b066 --- /dev/null +++ b/collaborativeagents/scripts/queue_topk5_v3.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# 等待 reflection_v2 实验完成后启动 topk5_v3 测试 + +echo "等待 reflection_v2 实验完成..." 
+ +while true; do + # 检查checkpoint进度 + CKPT=$(ls collaborativeagents/results/reflection_v2/*/reflection/checkpoint.json 2>/dev/null | head -1) + if [ -n "$CKPT" ]; then + PROGRESS=$(cat "$CKPT" | python3 -c " +import json, sys +try: + data = json.load(sys.stdin) + total = sum(data['sessions_per_profile'].values()) + print(f'{total}/1800 ({total*100//1800}%)') +except: + print('0/1800') +" 2>/dev/null) + + # 检查是否完成 + DONE=$(echo "$PROGRESS" | grep -c "1800/1800") + if [ "$DONE" -eq 1 ]; then + echo "$(date '+%H:%M:%S') reflection_v2 已完成" + break + fi + echo "$(date '+%H:%M:%S') reflection_v2 进度: $PROGRESS" + else + echo "$(date '+%H:%M:%S') 等待reflection_v2启动..." + fi + sleep 60 +done + +echo "启动 topk5_v3..." + +# 启动新实验 +nohup python collaborativeagents/scripts/run_experiments.py \ + --methods rag \ + --datasets math-hard,math-500,bigcodebench \ + --n-profiles 60 \ + --n-sessions 30 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 20 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/rag_topk5_v3 \ + > collaborativeagents/results/rag_topk5_v3.log 2>&1 & + +echo "topk5_v3 已启动,PID: $!" 
diff --git a/collaborativeagents/scripts/run_experiments.py b/collaborativeagents/scripts/run_experiments.py index e04680c..da3549b 100644 --- a/collaborativeagents/scripts/run_experiments.py +++ b/collaborativeagents/scripts/run_experiments.py @@ -15,6 +15,7 @@ import json import yaml import os import sys +import numpy as np from pathlib import Path from datetime import datetime from typing import List, Dict, Any, Optional @@ -113,7 +114,13 @@ AVAILABLE_METHODS = { "reflection_grpo": "Reflection + GRPO training", "all_memory": "All extracted memories in context (no retrieval)", "rag": "Extractor + RAG (no user vector)", + "rag_dynamic": "Extractor + RAG with dynamic topk (min=3, max=8, ratio=0.5)", + "rag_rewrite": "Extractor + RAG with LLM preference rewrite/merge", + "rag_rewrite_vector": "Extractor + RAG + user vector + LLM preference rewrite", "rag_vector": "Extractor + RAG + user vector (proposed method)", + "rag_vector_fast": "Extractor + RAG + user vector with 10x learning rate", + "rag_vector_consolidate": "Extractor + RAG + user vector with session-level preference consolidation", + "rag_vector_balanced": "Extractor + RAG + user vector with balanced rewards (10x LR + positive signal for good turns)", "rag_bge": "Extractor + RAG with BGE reranker (278M)", "rag_vector_bge": "Extractor + RAG + user vector with BGE reranker (278M)", } @@ -256,6 +263,68 @@ class ExperimentRunner: # Profile will be passed to start_session() when the conversation begins return adapter + def _export_user_vectors(self, method: str, adapters: Dict[int, Any]) -> None: + """ + Export user vectors from all adapters to disk for later analysis. + + Saves both .npz (efficient numpy format) and .json (human-readable). 
+ + Args: + method: Method name for the output directory + adapters: Dict mapping profile_idx to adapter instances + """ + method_dir = self.output_dir / method + + # Collect all user vectors from adapters + all_vectors = {} + for profile_idx, adapter in adapters.items(): + if hasattr(adapter, 'export_all_user_vectors'): + vectors = adapter.export_all_user_vectors() + all_vectors.update(vectors) + + if not all_vectors: + logger.info(f" No user vectors to export for {method}") + return + + # Save as .npz for efficient analysis + npz_path = method_dir / "user_vectors.npz" + user_ids = list(all_vectors.keys()) + k = len(all_vectors[user_ids[0]]["z_long"]) + z_long = np.zeros((len(user_ids), k), dtype=np.float32) + z_short = np.zeros((len(user_ids), k), dtype=np.float32) + reward_ma = np.zeros(len(user_ids), dtype=np.float32) + + for i, uid in enumerate(user_ids): + z_long[i] = all_vectors[uid]["z_long"] + z_short[i] = all_vectors[uid]["z_short"] + reward_ma[i] = all_vectors[uid]["reward_ma"] + + np.savez( + npz_path, + user_ids=np.array(user_ids), + z_long=z_long, + z_short=z_short, + reward_ma=reward_ma, + ) + + # Also save summary stats as JSON + summary = { + "n_users": len(user_ids), + "vector_dim": k, + "z_long_norms": {uid: all_vectors[uid]["z_long_norm"] for uid in user_ids}, + "z_short_norms": {uid: all_vectors[uid]["z_short_norm"] for uid in user_ids}, + "reward_mas": {uid: all_vectors[uid]["reward_ma"] for uid in user_ids}, + "stats": { + "z_long_norm_mean": float(np.mean([all_vectors[uid]["z_long_norm"] for uid in user_ids])), + "z_long_norm_max": float(np.max([all_vectors[uid]["z_long_norm"] for uid in user_ids])), + "z_long_norm_std": float(np.std([all_vectors[uid]["z_long_norm"] for uid in user_ids])), + } + } + with open(method_dir / "user_vectors_summary.json", "w") as f: + json.dump(summary, f, indent=2) + + logger.info(f" Exported {len(user_ids)} user vectors to {npz_path}") + def run_single_session( self, method: str, @@ -297,11 +366,11 @@ class 
ExperimentRunner: # Structured preferences with condition/action pref_str = "\n".join([ f"- When {p.get('condition', '')}, {p.get('action', '')}" - for p in user_prefs[:10] # Top 10 preferences + for p in user_prefs ]) else: # Simple string preferences - pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]]) + pref_str = "\n".join([f"- {p}" for p in user_prefs]) else: pref_str = str(user_prefs) @@ -619,6 +688,9 @@ class ExperimentRunner: json.dump(results, f, indent=2) logger.info(f" Profile {profile_idx + 1} completed and checkpointed") + # Export user vectors at the end of sequential processing + self._export_user_vectors(method, {0: adapter}) + return results def _run_method_parallel( @@ -690,6 +762,10 @@ class ExperimentRunner: except Exception as e: logger.error(f" Profile {profile_idx} failed: {e}") + # Note: Parallel mode doesn't export user vectors because adapters are + # created/destroyed per profile. Use batch mode for vector export. + logger.info(f" Parallel mode: user vectors not exported (use batch mode)") + def _run_method_batch( self, method: str, @@ -724,7 +800,7 @@ class ExperimentRunner: else: user_client = BatchVLLMClient( vllm_url=self.config.vllm_user_url, - max_tokens=4096, + max_tokens=1024, # User responses typically short, but allow for edge cases temperature=1.0, timeout=None, max_concurrent=100, @@ -799,21 +875,34 @@ class ExperimentRunner: adapters = {} profile_sessions = {} + # Build session problem list ONCE (shared across all profiles for controlled comparison) + # Each dataset contributes exactly n_per_dataset problems (front 10), no repeats + shared_sessions = [] + dataset_names = list(self.datasets.keys()) + n_per_dataset = self.config.n_sessions_per_profile // len(dataset_names) + remainder = self.config.n_sessions_per_profile % len(dataset_names) + + for i, ds_name in enumerate(dataset_names): + ds_obj = self.datasets[ds_name] + items = ds_obj.get_testset() + n_take = n_per_dataset + (1 if i < remainder else 0) + if n_take > 
len(items): + logger.warning(f" Dataset {ds_name} has only {len(items)} problems, need {n_take}") + for j in range(n_take): + item = items[j % len(items)] + shared_sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain}) + + n_conflict = int(len(shared_sessions) * self.config.conflict_ratio) + shared_session_list = [(s, idx < n_conflict) for idx, s in enumerate(shared_sessions)] + logger.info(f" Built shared session list: {len(shared_sessions)} problems from {len(dataset_names)} datasets ({n_per_dataset} each, same for all profiles)") + for profile_idx in profiles_to_run: profile = self.profiles[profile_idx] adapter = self._create_method_adapter(method, profile, use_shared_models=True) if hasattr(adapter, 'initialize'): adapter.initialize() adapters[profile_idx] = adapter - - sessions = [] - for ds_name, ds_obj in self.datasets.items(): - ds_items = ds_obj.get_testset() - for item in ds_items[:self.config.n_sessions_per_profile]: - sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain}) - sessions = sessions[:self.config.n_sessions_per_profile] - n_conflict = int(len(sessions) * self.config.conflict_ratio) - profile_sessions[profile_idx] = [(s, idx < n_conflict) for idx, s in enumerate(sessions)] + profile_sessions[profile_idx] = shared_session_list n_sessions = self.config.n_sessions_per_profile @@ -860,9 +949,9 @@ class ExperimentRunner: user_prefs = profile.get("preferences", []) if isinstance(user_prefs, list) and user_prefs: if isinstance(user_prefs[0], dict): - pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs[:10]]) + pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs]) else: - pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]]) + pref_str = "\n".join([f"- {p}" for p in user_prefs]) else: pref_str = str(user_prefs) @@ -916,21 +1005,105 @@ class ExperimentRunner: 
state["conversation"].append({"role": "user", "content": user_msg}) state["full_log"].append(parsed) - if parsed.get("enforce_preferences", False): + enforce = parsed.get("enforce_preferences", False) + if isinstance(enforce, str): + enforce = enforce.lower() == "true" + if enforce: state["enforcement_count"] += 1 + # Detect disappointment and satisfaction from user message + # Disappointment indicators (not quite right, could be better, etc.) + user_msg_lower = user_msg.lower() + disappointment = any(phrase in user_msg_lower for phrase in [ + "not quite", "not what i", "that's not", "incorrect", + "wrong", "mistake", "error", "confused", "doesn't make sense", + "try again", "not helpful", "not useful" + ]) + # Satisfaction indicators (explicit positive feedback) + satisfaction = parsed.get("should_terminate", False) or any(phrase in user_msg_lower for phrase in [ + "perfect", "exactly", "great", "thanks", "helpful", + "that's right", "correct", "good job", "well done", + "makes sense", "understand now", "got it" + ]) + + # Store parsed feedback for REINFORCE (applied AFTER prepare_prompt sets pending_rl_update) + state["_pending_feedback"] = { + "user_msg": user_msg, + "enforce": bool(enforce), + "disappointment": disappointment and not enforce, # Don't double-count + "satisfaction": satisfaction and not enforce, # Don't count if also enforcing + "draft_answer": bool(parsed.get("draft_answer")), + } + if parsed.get("should_terminate", False) or TERMINATION_SIGNAL in user_msg: to_remove.append(pidx) continue - # Prepare agent prompt for batching (don't call LLM yet) + # Batch preference extraction for PersonalizedLLM adapters + extraction_batch = [] # (pidx, query) + remaining_active = [pidx for pidx in active_list if pidx not in to_remove] + for pidx in remaining_active: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, '_llm') and hasattr(adapter._llm, 'enable_preference_extraction'): + if adapter._llm.enable_preference_extraction and 
adapter._llm._extractor is not None: + query = adapter._llm.get_last_user_query(adapter._current_user_id) if hasattr(adapter._llm, 'get_last_user_query') else None + if not query: + state = all_states[pidx] + query = state["conversation"][-1]["content"] if state["conversation"] else "" + if query: + extraction_batch.append((pidx, query)) + + if extraction_batch: + extractor = extraction_batch[0][1] # just need any adapter to get the extractor + adapter0 = adapters[extraction_batch[0][0]] + shared_extractor = adapter0._llm._extractor + if hasattr(shared_extractor, 'batch_extract_preferences'): + queries = [q for _, q in extraction_batch] + batch_results = shared_extractor.batch_extract_preferences(queries) + for (pidx, _), pref_dict in zip(extraction_batch, batch_results): + adapter = adapters[pidx] + adapter._llm.apply_extracted_preferences(adapter._current_user_id, pref_dict) + else: + # Fallback: sequential + for pidx, query in extraction_batch: + adapter = adapters[pidx] + adapter._llm._extractor.extract_turn(adapter._llm._sessions[adapter._current_user_id].session_state.history) + + # Batch scaffolding for reflection adapters before prepare_prompt + scaffolding_batch = [] # (pidx, prompt) + remaining_active = [pidx for pidx in active_list if pidx not in to_remove] + for pidx in remaining_active: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, 'get_scaffolding_prompt'): + state = all_states[pidx] + # Temporarily add user msg to history for scaffolding + agent_notes = adapter._user_notes.get(adapter._current_user_id, "No notes yet about this user.") + if adapter.with_scaffolding and agent_notes != "No notes yet about this user.": + prompt = adapter.get_scaffolding_prompt( + state["conversation"], agent_notes) + if prompt is not None: + scaffolding_batch.append((pidx, prompt)) + + if scaffolding_batch: + scaff_messages = [[{"role": "user", "content": p}] for _, p in scaffolding_batch] + scaff_responses = 
agent_client.batch_completion(scaff_messages) + for (pidx, _), resp in zip(scaffolding_batch, scaff_responses): + adapter = adapters[pidx] + adapter._scaffolding_result = resp if resp else None + + # Prepare agent prompts for batching + # NOTE: prepare_prompt calls chat_prepare which sets pending_rl_update + # from the previous turn's data. REINFORCE feedback must be applied + # AFTER this call so that pending_rl_update is available. + for pidx in remaining_active: + state = all_states[pidx] try: adapter = adapters[pidx] + user_msg = state["conversation"][-1]["content"] if hasattr(adapter, 'prepare_prompt'): messages, context = adapter.prepare_prompt(user_msg, state["conversation"][:-1]) agent_prompts_batch.append((pidx, messages, context)) elif hasattr(adapter, 'generate_response'): - # Fallback for adapters without prepare_prompt agent_prompts_batch.append((pidx, None, None)) else: state["conversation"].append({"role": "assistant", "content": "[Error: Adapter not configured]"}) @@ -938,6 +1111,53 @@ class ExperimentRunner: logger.error(f" Agent prepare error p{pidx} t{turn}: {e}") state["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. 
Could you rephrase?"}) + # Apply REINFORCE feedback NOW (after prepare_prompt set pending_rl_update) + for pidx in remaining_active: + state = all_states[pidx] + fb = state.pop("_pending_feedback", None) + if fb: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, 'process_user_turn'): + adapter.process_user_turn( + user_response=fb["user_msg"], + enforce_preferences=fb["enforce"], + express_disappointment=fb.get("disappointment", False), + express_satisfaction=fb["satisfaction"], + draft_answer_updated=fb["draft_answer"], + ) + + # Also apply feedback for terminated sessions (they skipped prepare_prompt + # but still need the reward signal from their last turn) + for pidx in to_remove: + state = all_states.get(pidx) + if not state: + continue + fb = state.pop("_pending_feedback", None) + if fb: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, 'process_user_turn'): + # For terminated sessions, we can't call prepare_prompt + # (no next turn), but we still want the reward applied. + # Call chat_prepare with a dummy to set pending_rl_update, + # then apply feedback. 
+ try: + if hasattr(adapter, '_llm') and hasattr(adapter._llm, 'chat_prepare'): + adapter._llm.chat_prepare( + adapter._current_user_id, + fb["user_msg"], + skip_extraction=True, + skip_auto_reward=True, + ) + adapter.process_user_turn( + user_response=fb["user_msg"], + enforce_preferences=fb["enforce"], + express_disappointment=fb.get("disappointment", False), + express_satisfaction=fb["satisfaction"], + draft_answer_updated=fb["draft_answer"], + ) + except Exception: + pass # Best effort for terminated sessions + # Batch vLLM call for all agent prompts if agent_prompts_batch: # Separate prompts that can be batched from fallback @@ -979,6 +1199,25 @@ class ExperimentRunner: active_set -= set(to_remove) + # Batch note-update for reflection adapters before end_session + note_update_batch = [] # (profile_idx, messages) + for profile_idx in profiles_to_run: + if profile_idx not in all_states: + continue + adapter = adapters.get(profile_idx) + if adapter and hasattr(adapter, 'get_note_update_prompt'): + prompt_msgs = adapter.get_note_update_prompt() + if prompt_msgs is not None: + note_update_batch.append((profile_idx, prompt_msgs)) + + if note_update_batch: + note_messages = [msgs for _, msgs in note_update_batch] + note_responses = agent_client.batch_completion(note_messages) + for (profile_idx, _), resp in zip(note_update_batch, note_responses): + if resp: + adapter = adapters[profile_idx] + adapter.apply_note_update_response(resp) + # Save results for this session round for profile_idx in profiles_to_run: if profile_idx not in all_states: @@ -995,10 +1234,20 @@ class ExperimentRunner: task_success = 0 for entry in full_log: if entry.get("should_terminate", False): - draft = entry.get("draft_answer", "") - if draft and "don't know" not in draft.lower() and len(draft) > 20: + draft = str(entry.get("draft_answer", "")) + if draft and "don't know" not in draft.lower(): task_success = 1 + # End session on adapter (applies task completion reward for REINFORCE) + adapter 
= adapters.get(profile_idx) + if adapter and hasattr(adapter, 'end_session'): + # Skip note update if batch already handled it + skip_notes = hasattr(adapter, 'get_note_update_prompt') + try: + adapter.end_session(task_success=bool(task_success), skip_note_update=skip_notes) + except TypeError: + adapter.end_session(task_success=bool(task_success)) + results.append({ "method": method, "profile_id": self.profiles[profile_idx].get("user_id", f"user_{profile_idx}"), @@ -1023,6 +1272,33 @@ class ExperimentRunner: "adapter_metrics": {}, }) + # Collect adapter metrics (e.g. user_vector_norm for rag_vector) + adapter = adapters.get(profile_idx) + if adapter and hasattr(adapter, 'get_user_vector'): + user_id = self.profiles[profile_idx].get("user_id", f"user_{profile_idx}") + vec = adapter.get_user_vector(user_id) + if vec is not None: + results[-1]["adapter_metrics"] = { + "user_vector_norm": float(np.linalg.norm(vec)), + } + + # Save user vector snapshots every 10 sessions + if (session_idx + 1) % 10 == 0: + vectors_dir = checkpoint_file.parent / "vectors" + vectors_dir.mkdir(parents=True, exist_ok=True) + user_vectors = {} + for profile_idx in profiles_to_run: + adapter = adapters.get(profile_idx) + if adapter and hasattr(adapter, 'get_user_vector'): + user_id = self.profiles[profile_idx].get("user_id", f"user_{profile_idx}") + vec = adapter.get_user_vector(user_id) + if vec is not None: + user_vectors[user_id] = vec + if user_vectors: + snapshot_path = vectors_dir / f"vectors_session_{session_idx+1}.npy" + np.save(snapshot_path, user_vectors) + logger.info(f" Saved {len(user_vectors)} user vectors to {snapshot_path}") + # Checkpoint after each session round with session-level tracking # Only increment for profiles that actually ran in this round (those in all_states) for profile_idx in all_states.keys(): @@ -1043,6 +1319,9 @@ class ExperimentRunner: rate = sessions_done / elapsed * 3600 if elapsed > 0 else 0 logger.info(f" Session round {session_idx+1}/{n_sessions}: 
{sessions_done} total, {rate:.0f} sessions/hr") + # Export user vectors before cleanup (for RAG methods with user vectors) + self._export_user_vectors(method, adapters) + # Explicitly free adapter models to prevent GPU OOM across methods for pidx, adapter in adapters.items(): if hasattr(adapter, 'cleanup'): diff --git a/collaborativeagents/scripts/test_new_rewrite.sh b/collaborativeagents/scripts/test_new_rewrite.sh new file mode 100755 index 0000000..1ade8ea --- /dev/null +++ b/collaborativeagents/scripts/test_new_rewrite.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# 小规模测试:验证新的rewrite prompt是否降低E/T +# 10 profiles × 10 sessions = 100 sessions + +echo "$(date '+%H:%M:%S') 启动 rag_rewrite 小规模测试 (新prompt)..." + +cd /workspace/personalization-user-model + +python collaborativeagents/scripts/run_experiments.py \ + --methods rag_rewrite \ + --datasets math-hard,bigcodebench \ + --n-profiles 10 \ + --n-sessions 10 \ + --max-turns 10 \ + --use-vllm \ + --vllm-agent-url http://localhost:8003/v1 \ + --vllm-user-url http://localhost:8004/v1 \ + --use-batch-processing \ + --batch-size 4 \ + --parallel-profiles 10 \ + --profile-path collaborativeagents/data/complex_profiles_v2/profiles_200.jsonl \ + --output-dir collaborativeagents/results/test_new_rewrite_10x10 + +echo "$(date '+%H:%M:%S') 测试完成" + +# 自动分析结果 +python3 << 'ANALYZE' +import json +import numpy as np + +result_path = "collaborativeagents/results/test_new_rewrite_10x10" +import glob +results_file = glob.glob(f"{result_path}/*/rag_rewrite/results.json") + +if results_file: + with open(results_file[0]) as f: + data = json.load(f) + + enforcements = sum(r["metrics"]["enforcement_count"] for r in data) + turns = sum(r["metrics"]["total_turns"] for r in data) + successes = sum(1 for r in data if r["metrics"]["task_success"]) + + print(f"\n=== 新Rewrite Prompt测试结果 ===") + print(f"Sessions: {len(data)}") + print(f"Success Rate: {100*successes/len(data):.1f}%") + print(f"E/T: {enforcements/turns:.4f}") + print(f"(对比旧rewrite E/T: 
0.194)") +else: + print("结果文件未找到") +ANALYZE diff --git a/collaborativeagents/scripts/visualize_user_vectors.py b/collaborativeagents/scripts/visualize_user_vectors.py new file mode 100644 index 0000000..203cb68 --- /dev/null +++ b/collaborativeagents/scripts/visualize_user_vectors.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python3 +""" +User Vector Visualization Script + +Visualizes learned user vectors using t-SNE and PCA for dimensionality reduction. +Supports multiple coloring schemes to analyze user clusters. + +Usage: + python visualize_user_vectors.py --results-dir ../results/fullrun_3methods + python visualize_user_vectors.py --vectors-file user_vectors.npy --profiles-file profiles.json +""" + +import argparse +import json +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from sklearn.manifold import TSNE +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +import warnings +warnings.filterwarnings('ignore') + + +def load_user_vectors(results_dir: Path) -> Tuple[np.ndarray, List[int]]: + """Load user vectors from experiment results.""" + vectors = [] + user_ids = [] + + # Try to find user vectors in different locations + possible_paths = [ + results_dir / "user_vectors.npy", + results_dir / "rag_vector" / "user_vectors.npy", + results_dir / "checkpoints" / "user_vectors.npy", + ] + + for path in possible_paths: + if path.exists(): + data = np.load(path, allow_pickle=True) + if isinstance(data, np.ndarray): + if data.dtype == object: + # Dictionary format + data = data.item() + for uid, vec in data.items(): + user_ids.append(int(uid)) + vectors.append(vec) + else: + # Direct array format + vectors = data + user_ids = list(range(len(data))) + print(f"Loaded {len(vectors)} user vectors from {path}") + return np.array(vectors), user_ids + + # Try to extract from results.json + results_files = list(results_dir.glob("**/results.json")) + for rf in 
def load_user_vectors(results_dir: Path) -> Tuple[np.ndarray, List[int]]:
    """Load learned user vectors from an experiment results directory.

    Searches known snapshot locations for a ``user_vectors.npy`` file
    (either a pickled ``{user_id: vector}`` dict or a plain 2-D array),
    then falls back to scanning ``results.json`` files for an embedded
    ``user_vectors`` mapping.

    Args:
        results_dir: Root directory of an experiment run.

    Returns:
        ``(vectors, user_ids)`` — an ``(n_users, dim)`` array and a list of
        integer ids. Ids are parsed from keys such as ``"user_7"`` (the
        snapshot format used by run_experiments.py); when a key carries no
        digits the enumeration index is used.

    Raises:
        FileNotFoundError: if no user vectors are found under ``results_dir``.
    """
    import re  # local: only needed to parse ids like "user_12"

    def _to_int_id(uid, fallback: int) -> int:
        # Snapshot dicts are keyed by user_id strings such as "user_12";
        # a bare int(uid) would raise ValueError on those keys.
        try:
            return int(uid)
        except (TypeError, ValueError):
            match = re.search(r"(\d+)$", str(uid))
            return int(match.group(1)) if match else fallback

    vectors: List[np.ndarray] = []
    user_ids: List[int] = []

    # Known snapshot locations, in priority order.
    possible_paths = [
        results_dir / "user_vectors.npy",
        results_dir / "rag_vector" / "user_vectors.npy",
        results_dir / "checkpoints" / "user_vectors.npy",
    ]

    for path in possible_paths:
        if not path.exists():
            continue
        data = np.load(path, allow_pickle=True)
        if isinstance(data, np.ndarray):
            if data.dtype == object:
                # Dict format: {user_id: vector}
                data = data.item()
                for idx, (uid, vec) in enumerate(data.items()):
                    user_ids.append(_to_int_id(uid, idx))
                    vectors.append(vec)
            else:
                # Direct (n_users, dim) array format
                vectors = data
                user_ids = list(range(len(data)))
            print(f"Loaded {len(vectors)} user vectors from {path}")
            return np.array(vectors), user_ids

    # Fallback: user vectors embedded in a results.json somewhere below.
    for rf in results_dir.glob("**/results.json"):
        try:
            with open(rf) as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError):
            continue  # unreadable or corrupt file: keep scanning
        if isinstance(data, dict) and "user_vectors" in data:
            for idx, (uid, vec) in enumerate(data["user_vectors"].items()):
                user_ids.append(_to_int_id(uid, idx))
                vectors.append(np.array(vec))
            print(f"Loaded {len(vectors)} user vectors from {rf}")
            return np.array(vectors), user_ids

    raise FileNotFoundError(f"No user vectors found in {results_dir}")


def load_profiles(profiles_path: Path) -> List[Dict]:
    """Load user profiles for labeling.

    Supports ``.jsonl`` (one JSON object per line, blank lines ignored)
    and plain ``.json`` (a single JSON array).
    """
    if profiles_path.suffix == '.jsonl':
        with open(profiles_path) as f:
            # Skip blank lines so a trailing newline doesn't crash json.loads.
            return [json.loads(line) for line in f if line.strip()]
    with open(profiles_path) as f:
        return json.load(f)


def extract_profile_features(profiles: List[Dict]) -> Dict[str, List]:
    """Extract per-profile features used to color the scatter plots.

    Returns a dict of parallel lists:
        categories:     first entry of each profile's "categories" ("unknown" if absent)
        n_preferences:  number of entries in each profile's "preferences"
        persona_length: character length of each profile's "persona" text
    """
    features = {
        "categories": [],
        "n_preferences": [],
        "persona_length": [],
    }

    for p in profiles:
        cats = p.get("categories", [])
        features["categories"].append(cats[0] if cats else "unknown")
        features["n_preferences"].append(len(p.get("preferences", [])))
        features["persona_length"].append(len(p.get("persona", "")))

    return features


def apply_tsne(vectors: np.ndarray, perplexity: int = 30, max_iter: int = 1000) -> np.ndarray:
    """Apply t-SNE dimensionality reduction to (n_users, dim) vectors.

    Vectors are standardized first; perplexity is clamped into the valid
    range ``[1, n_samples - 1]`` required by scikit-learn.
    """
    vectors_scaled = StandardScaler().fit_transform(vectors)

    # t-SNE requires 0 < perplexity < n_samples.
    n_samples = len(vectors)
    perplexity = max(1, min(perplexity, n_samples - 1))

    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        max_iter=max_iter,
        random_state=42,
        init='pca',
        learning_rate='auto'
    )
    return tsne.fit_transform(vectors_scaled)
def apply_pca(vectors: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Apply PCA dimensionality reduction.

    Args:
        vectors: ``(n_users, dim)`` array of user vectors.

    Returns:
        ``(projection, variance)`` — the first two principal components as an
        ``(n_users, 2)`` array (zero-padded if fewer than 2 components are
        available) and the explained-variance ratio of every fitted component.
    """
    vectors_scaled = StandardScaler().fit_transform(vectors)

    # n_components may not exceed either the sample count or the feature count.
    n_components = min(10, vectors.shape[0], vectors.shape[1])
    pca = PCA(n_components=n_components)
    transformed = pca.fit_transform(vectors_scaled)

    # Pad to 2 columns so downstream [:, 1] indexing is always safe.
    if transformed.shape[1] < 2:
        pad = np.zeros((transformed.shape[0], 2 - transformed.shape[1]))
        transformed = np.hstack([transformed, pad])

    return transformed[:, :2], pca.explained_variance_ratio_


def plot_comparison(
    vectors: np.ndarray,
    user_ids: List[int],
    profiles: Optional[List[Dict]] = None,
    output_path: Optional[Path] = None,
    title_prefix: str = ""
):
    """Create side-by-side t-SNE and PCA scatter plots of the user vectors.

    Points are colored by preference count when profiles cover all users,
    otherwise by user id.

    Returns:
        ``(tsne_2d, pca_2d, pca_variance)`` for further analysis.
    """
    print("Applying t-SNE...")
    tsne_2d = apply_tsne(vectors)

    print("Applying PCA...")
    pca_2d, pca_variance = apply_pca(vectors)

    # Choose the coloring signal: profile richness if available, else raw id.
    if profiles and len(profiles) >= len(user_ids):
        features = extract_profile_features(profiles)
        color_by = features["n_preferences"]
        color_label = "Number of Preferences"
    else:
        color_by = user_ids
        color_label = "User ID"

    # Guard: a degenerate embedding can yield a single PCA component.
    pc1_var = pca_variance[0]
    pc2_var = pca_variance[1] if len(pca_variance) > 1 else 0.0

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # t-SNE panel
    ax1 = axes[0]
    scatter1 = ax1.scatter(
        tsne_2d[:, 0], tsne_2d[:, 1],
        c=color_by, cmap='viridis', alpha=0.7, s=50
    )
    ax1.set_xlabel('t-SNE Dimension 1')
    ax1.set_ylabel('t-SNE Dimension 2')
    ax1.set_title(f'{title_prefix}t-SNE Visualization\n({len(vectors)} users)')
    plt.colorbar(scatter1, ax=ax1, label=color_label)

    # PCA panel
    ax2 = axes[1]
    scatter2 = ax2.scatter(
        pca_2d[:, 0], pca_2d[:, 1],
        c=color_by, cmap='viridis', alpha=0.7, s=50
    )
    ax2.set_xlabel(f'PC1 ({pc1_var*100:.1f}% variance)')
    ax2.set_ylabel(f'PC2 ({pc2_var*100:.1f}% variance)')
    ax2.set_title(f'{title_prefix}PCA Visualization\n(Top 2 components: {(pc1_var+pc2_var)*100:.1f}% variance)')
    plt.colorbar(scatter2, ax=ax2, label=color_label)

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {output_path}")

    plt.show()

    return tsne_2d, pca_2d, pca_variance


def plot_by_category(
    vectors: np.ndarray,
    user_ids: List[int],
    profiles: List[Dict],
    output_path: Optional[Path] = None
):
    """Create t-SNE and PCA scatter plots colored by preference category.

    Each user is colored by the first category of their profile; a legend
    (capped at 10 categories) maps colors back to category names.
    """
    features = extract_profile_features(profiles)
    categories = features["categories"]
    # Sorted for a deterministic category -> color assignment across runs
    # (list(set(...)) order varies between interpreter invocations).
    unique_cats = sorted(set(categories))
    cat_to_idx = {c: i for i, c in enumerate(unique_cats)}

    # Map categories to fixed tab10 colors so the scatter and the legend
    # agree. Passing integer codes with cmap= would normalize them and the
    # legend proxies would no longer match.
    cmap = plt.get_cmap('tab10')
    point_colors = [cmap(cat_to_idx[c] % 10) for c in categories[:len(user_ids)]]

    tsne_2d = apply_tsne(vectors)
    pca_2d, pca_variance = apply_pca(vectors)
    pc1_var = pca_variance[0]
    pc2_var = pca_variance[1] if len(pca_variance) > 1 else 0.0

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # t-SNE by category
    axes[0].scatter(tsne_2d[:, 0], tsne_2d[:, 1], c=point_colors, alpha=0.7, s=50)
    axes[0].set_xlabel('t-SNE Dimension 1')
    axes[0].set_ylabel('t-SNE Dimension 2')
    axes[0].set_title('t-SNE by Preference Category')

    # PCA by category
    axes[1].scatter(pca_2d[:, 0], pca_2d[:, 1], c=point_colors, alpha=0.7, s=50)
    axes[1].set_xlabel(f'PC1 ({pc1_var*100:.1f}%)')
    axes[1].set_ylabel(f'PC2 ({pc2_var*100:.1f}%)')
    axes[1].set_title('PCA by Preference Category')

    # Legend proxies: empty scatters with c=[value] + cmap raise a
    # c-size-mismatch ValueError in modern matplotlib, so use Line2D
    # marker-only artists instead.
    shown_cats = unique_cats[:10]  # limit legend to 10 categories
    handles = [
        plt.Line2D([], [], marker='o', linestyle='',
                   color=cmap(cat_to_idx[c] % 10), label=c)
        for c in shown_cats
    ]
    fig.legend(handles, shown_cats, loc='center right', title='Category')

    plt.tight_layout()
    plt.subplots_adjust(right=0.85)

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved category plot to {output_path}")

    plt.show()
def plot_pca_variance(vectors: np.ndarray, output_path: Optional[Path] = None):
    """Plot per-component and cumulative PCA explained variance.

    Helps judge the intrinsic dimensionality of the learned user vectors.

    Returns:
        The explained-variance ratio array of the fitted PCA.
    """
    vectors_scaled = StandardScaler().fit_transform(vectors)

    # PCA cannot fit more components than samples or features.
    n_components = min(50, vectors.shape[1], vectors.shape[0])
    pca = PCA(n_components=n_components)
    pca.fit(vectors_scaled)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Per-component variance
    ax1 = axes[0]
    ax1.bar(range(1, n_components + 1), pca.explained_variance_ratio_ * 100)
    ax1.set_xlabel('Principal Component')
    ax1.set_ylabel('Explained Variance (%)')
    ax1.set_title('PCA Explained Variance by Component')
    ax1.set_xlim(0, n_components + 1)

    # Cumulative variance with 90%/95% reference lines
    ax2 = axes[1]
    cumvar = np.cumsum(pca.explained_variance_ratio_) * 100
    ax2.plot(range(1, n_components + 1), cumvar, 'b-o', markersize=4)
    ax2.axhline(y=90, color='r', linestyle='--', label='90% variance')
    ax2.axhline(y=95, color='g', linestyle='--', label='95% variance')
    ax2.set_xlabel('Number of Components')
    ax2.set_ylabel('Cumulative Explained Variance (%)')
    ax2.set_title('PCA Cumulative Explained Variance')
    ax2.legend()
    ax2.set_xlim(0, n_components + 1)
    ax2.set_ylim(0, 105)

    # np.argmax on an all-False mask returns 0, which would misreport
    # "1 component" when a threshold is never reached — check reachability first.
    n_90 = int(np.argmax(cumvar >= 90)) + 1 if cumvar[-1] >= 90 else None
    n_95 = int(np.argmax(cumvar >= 95)) + 1 if cumvar[-1] >= 95 else None
    print(f"Components for 90% variance: {n_90 if n_90 is not None else f'>{n_components}'}")
    print(f"Components for 95% variance: {n_95 if n_95 is not None else f'>{n_components}'}")

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved variance plot to {output_path}")

    plt.show()

    return pca.explained_variance_ratio_


def generate_synthetic_vectors(n_users: int = 200, dim: int = 64) -> np.ndarray:
    """Generate synthetic user vectors for testing the visualizations.

    Produces 5 Gaussian clusters (plus uniform-noise leftovers when n_users
    is not divisible by 5). Seeded, so output is deterministic.
    """
    np.random.seed(42)

    n_clusters = 5
    cluster_size = n_users // n_clusters
    vectors = []

    for _ in range(n_clusters):
        # Each cluster: a random center with tighter per-user variation.
        center = np.random.randn(dim) * 2
        vectors.append(center + np.random.randn(cluster_size, dim) * 0.5)

    # Any remainder becomes unclustered noise users.
    remaining = n_users - n_clusters * cluster_size
    if remaining > 0:
        vectors.append(np.random.randn(remaining, dim))

    return np.vstack(vectors)


def main():
    """CLI entry point: load (or synthesize) user vectors and emit plots."""
    parser = argparse.ArgumentParser(description="Visualize user vectors with t-SNE and PCA")
    parser.add_argument("--results-dir", type=str, help="Path to experiment results directory")
    parser.add_argument("--vectors-file", type=str, help="Path to user vectors .npy file")
    parser.add_argument("--profiles-file", type=str, help="Path to user profiles JSON file")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory for plots")
    parser.add_argument("--demo", action="store_true", help="Run demo with synthetic data")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.demo:
        print("Running demo with synthetic user vectors...")
        vectors = generate_synthetic_vectors(200, 64)
        user_ids = list(range(200))
        profiles = None
        title_prefix = "[Demo] "
    elif args.vectors_file:
        # Snapshot files may be pickled {user_id: vector} dicts (the format
        # run_experiments.py writes) or plain 2-D arrays; handle both.
        raw = np.load(args.vectors_file, allow_pickle=True)
        if isinstance(raw, np.ndarray) and raw.dtype == object:
            raw = raw.item()
        if isinstance(raw, dict):
            vectors = np.array(list(raw.values()))
        else:
            vectors = raw
        user_ids = list(range(len(vectors)))
        profiles = None
        if args.profiles_file:
            profiles = load_profiles(Path(args.profiles_file))
        title_prefix = ""
    elif args.results_dir:
        results_dir = Path(args.results_dir)
        vectors, user_ids = load_user_vectors(results_dir)

        # Best-effort profile discovery for nicer coloring/labels.
        profiles = None
        profile_paths = [
            results_dir / "generated_profiles.json",
            results_dir.parent / "profiles.json",
            Path("../data/complex_profiles_v2/profiles_200.jsonl"),
        ]
        for pp in profile_paths:
            if pp.exists():
                profiles = load_profiles(pp)
                print(f"Loaded {len(profiles)} profiles from {pp}")
                break
        title_prefix = ""
    else:
        print("Please provide --results-dir, --vectors-file, or --demo")
        return

    print(f"\nUser vectors shape: {vectors.shape}")
    print(f"Number of users: {len(user_ids)}")

    print("\n=== Generating comparison plot ===")
    plot_comparison(
        vectors, user_ids, profiles,
        output_path=output_dir / "user_vectors_comparison.png",
        title_prefix=title_prefix
    )

    print("\n=== Generating PCA variance plot ===")
    plot_pca_variance(
        vectors,
        output_path=output_dir / "user_vectors_pca_variance.png"
    )

    if profiles and len(profiles) >= len(user_ids):
        print("\n=== Generating category plot ===")
        plot_by_category(
            vectors, user_ids, profiles,
            output_path=output_dir / "user_vectors_by_category.png"
        )

    print("\n=== Done! ===")
    print(f"Plots saved to {output_dir}")


if __name__ == "__main__":
    main()
