diff options
Diffstat (limited to 'collaborativeagents/scripts/run_experiments.py')
| -rw-r--r-- | collaborativeagents/scripts/run_experiments.py | 1328 |
1 files changed, 1328 insertions, 0 deletions
@dataclass
class ExperimentConfig:
    """Settings that describe a single benchmark run.

    ``methods`` and ``datasets`` are mandatory; every other field carries a
    default. The declaration order below is the positional order of the
    generated ``__init__``, so it must not be rearranged.
    """

    # ---- What to evaluate ----
    methods: List[str]  # names of the methods to compare
    datasets: List[str]  # names of the datasets to load

    # ---- User profiles ----
    n_profiles: int = 200
    profile_path: Optional[str] = None

    # ---- Profile slice, so one experiment can be split across jobs ----
    start_profile: int = 0  # first profile index (0-based, inclusive)
    end_profile: Optional[int] = None  # stop index (exclusive); None = all

    # ---- Session shape ----
    n_sessions_per_profile: int = 30
    max_turns_per_session: int = 15  # raised to accommodate harder tasks

    # ---- Models ----
    user_model: str = "meta-llama/Llama-3.3-70B-Instruct"
    agent_model: str = "meta-llama/Llama-3.1-8B-Instruct"
    judge_model: str = "meta-llama/Llama-3.3-70B-Instruct"

    # ---- Output ----
    output_dir: str = "results"
    save_conversations: bool = True

    # ---- Conflict testing ----
    conflict_ratio: float = 0.3  # share of queries meant to trigger a conflict

    # ---- Compute ----
    batch_size: int = 4
    n_gpus: int = 4

    # ---- vLLM serving (high-performance inference) ----
    use_vllm: bool = False
    vllm_user_url: str = "http://localhost:8004/v1"  # 70B user simulator endpoint
    vllm_agent_url: str = "http://localhost:8003/v1"  # 8B agent endpoint

    # ---- OpenAI-hosted user simulator (alternative to the vLLM user agent) ----
    use_openai_user: bool = False
    openai_user_model: str = "gpt-5"  # model name handed to the OpenAI user agent

    # ---- Reward signal, applied globally to every method doing RL updates ----
    # "keyword" = implicit user signals, "llm" = GPT-5-nano judge.
    reward_mode: str = "keyword"

    # ---- Parallel / batched execution ----
    parallel_profiles: int = 50  # profiles processed concurrently
    use_batch_processing: bool = True  # enable turn-synchronous batch mode
    batch_size_conversations: int = 50  # conversations grouped per batch

    # ---- Resuming ----
    continue_from: Optional[str] = None  # existing output directory to extend
def _load_profiles(self) -> List[Dict]:
    """Return the user profiles for this run, loading or generating them.

    If ``config.profile_path`` is set and the file exists, it is read as
    JSON Lines (one JSON object per non-blank line) and truncated to the
    first ``config.n_profiles`` entries. Otherwise ``config.n_profiles``
    placeholder profiles are generated, written to
    ``<output_dir>/generated_profiles.json`` and returned.

    Returns:
        List of profile dicts.

    Raises:
        json.JSONDecodeError: If a non-blank line of the profile file is
            not valid JSON.
    """
    configured_path = self.config.profile_path
    logger.info(f"Profile path configured: {configured_path}")

    if configured_path:
        profile_path = Path(configured_path)
        if profile_path.exists():
            # Explicit encoding: do not depend on the platform locale.
            with open(profile_path, encoding="utf-8") as f:
                profiles = [
                    json.loads(line)
                    for line in map(str.strip, f)
                    if line  # skip blank lines
                ]
            logger.info(f"Loaded {len(profiles)} profiles from {configured_path}")
            return profiles[:self.config.n_profiles]
        logger.warning(f"Profile path does not exist: {configured_path}")

    # Generate simple placeholder profiles if no usable file was provided
    logger.info(f"Generating {self.config.n_profiles} placeholder profiles...")
    profiles = [
        {
            "id": i,
            # The session code looks profiles up by "user_id"; without this key
            # every placeholder user would fall back to the same default id.
            "user_id": f"user_{i}",
            "persona": f"User {i+1} is a curious individual seeking help with problem solving.",
            "preferences": [
                "Provide clear, step-by-step explanations",
                "Use simple language when possible",
                "Give examples to illustrate concepts",
                "Be concise but thorough",
                "Acknowledge when something is uncertain",
            ],
        }
        for i in range(self.config.n_profiles)
    ]

    # Save generated profiles next to the run's other outputs
    generated_path = self.output_dir / "generated_profiles.json"
    with open(generated_path, "w", encoding="utf-8") as f:
        json.dump(profiles, f, indent=2)

    logger.info(f"Generated and saved {len(profiles)} placeholder profiles")
    return profiles
def run_single_session(
    self,
    method: str,
    profile: Dict,
    problem: Dict,
    is_conflict_query: bool = False,
    adapter: Any = None,
    user_agent: Any = None
) -> Dict:
    """Simulate one multi-turn conversation and collect its metrics.

    A simulated user (role-playing the profile's persona and preferences)
    talks to the method's agent adapter for at most
    ``config.max_turns_per_session`` turns, starting from a fixed agent
    greeting. The session stops early when the user agent returns ``None``,
    sets ``should_terminate`` or emits ``TERMINATION_SIGNAL``.

    Args:
        method: Method name; used to build an adapter when none is passed
            and echoed in the result.
        profile: Profile dict. Reads ``"preferences"``, ``"persona"`` and
            ``"user_id"`` (all optional).
        problem: Problem dict. Reads ``"problem"`` (or ``"question"``),
            ``"domain"``, ``"solution"`` (or ``"answer"``) and
            ``"problem_id"``.
        is_conflict_query: If True, ask the conflict generator for a
            scenario and, when one is produced, use its query instead of
            the original problem text.
        adapter: Adapter to reuse; a new one is created when ``None``.
        user_agent: User simulator to reuse; one is created from the
            config when ``None``.

    Returns:
        Dict with the method, profile/problem identifiers, the original
        problem and its ground truth, conflict information, the
        conversation and raw user log (``None`` when
        ``config.save_conversations`` is off), the serialized
        ``ConversationMetrics`` and the adapter's end-of-session metrics.
        Errors raised during the conversation are logged and recorded as a
        final ``[Error: ...]`` assistant turn rather than propagated.
    """
    import hashlib  # function-scope import: only the fallback problem id needs it

    # Substrings whose presence in a user message counts as enforcement.
    enforcement_keywords = (
        "please", "i asked", "i said", "i prefer", "can you", "could you", "instead",
    )

    # Use provided adapter (reused across sessions) or create new one
    agent_adapter = adapter if adapter is not None else self._create_method_adapter(method, profile)

    # Read preferences tolerantly once; a profile without the key must not
    # crash conflict queries while working for regular ones.
    user_prefs = profile.get("preferences", [])

    # Prepare conflict scenario if needed
    conflict_scenario = None
    original_problem = problem.get("problem", problem.get("question", ""))
    if is_conflict_query:
        conflict_scenario = self.conflict_generator.generate_for_profile(
            user_prefs,
            problem.get("domain", "general")
        )
        if conflict_scenario:
            problem = dict(problem)  # copy so the caller's dict is untouched
            problem["problem"] = conflict_scenario["query"]

    query = problem.get("problem", problem.get("question", ""))

    # Extract user preferences as formatted string (top 10 only)
    if isinstance(user_prefs, list) and len(user_prefs) > 0:
        if isinstance(user_prefs[0], dict):
            # Structured preferences with condition/action
            pref_str = "\n".join([
                f"- When {p.get('condition', '')}, {p.get('action', '')}"
                for p in user_prefs[:10]
            ])
        else:
            # Simple string preferences
            pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]])
    else:
        pref_str = str(user_prefs)

    user_persona = profile.get("persona", "A user seeking help with problem solving.")

    # Create user agent for this session (or reuse provided one)
    if user_agent is None:
        if self.config.use_openai_user:
            user_agent = OpenAIUserAgent(
                user_task_description="Help the user solve their problem.",
                problem=query,
                user_persona=user_persona,
                user_preferences=pref_str,
                model=self.config.openai_user_model,
            )
        elif self.config.use_vllm:
            user_agent = VLLMUserAgent(
                user_task_description="Help the user solve their problem.",
                problem=query,
                user_persona=user_persona,
                user_preferences=pref_str,
                vllm_url=self.config.vllm_user_url,
            )
        else:
            user_agent = SharedLocalUserAgent(
                user_task_description="Help the user solve their problem.",
                problem=query,
                user_persona=user_persona,
                user_preferences=pref_str,
            )

    # Initialize conversation
    turns = []
    full_user_log = []  # Detailed user agent outputs

    # Metrics tracking
    enforcement_count = 0
    disappointment_count = 0  # reported as-is; nothing in this method increments it
    user_token_count = 0
    agent_token_count = 0
    preference_compliance_scores = []
    # Defined up front so the result is well-formed even if the session errors out.
    adapter_metrics = {}

    try:
        # Initialize adapter for this user
        if hasattr(agent_adapter, 'initialize'):
            agent_adapter.initialize()
        if hasattr(agent_adapter, 'start_session'):
            agent_adapter.start_session(
                user_id=profile.get("user_id", "test_user"),
                user_profile={"preferences": user_prefs, "persona": user_persona}
            )

        # Start with agent greeting
        conversation = [{"role": "assistant", "content": "How can I help you today?"}]

        # Multi-turn conversation loop
        for turn_num in range(self.config.max_turns_per_session):
            # === User Turn ===
            user_response = user_agent.generate_user_response(conversation)

            if user_response is None:
                logger.warning(f"User agent failed to respond at turn {turn_num}")
                break

            user_message = str(user_response.get("response", ""))
            # Token counts are whitespace-split word counts, not tokenizer counts.
            user_token_count += len(user_message.split())

            # Add to conversation
            conversation.append({"role": "user", "content": user_message})
            turns.append({"role": "user", "content": user_message})
            full_user_log.append(user_response)

            # Check for termination
            if user_response.get("should_terminate", False) or TERMINATION_SIGNAL in user_message:
                break

            # Detect preference enforcement (user correcting agent)
            lowered_message = user_message.lower()
            if any(kw in lowered_message for kw in enforcement_keywords):
                enforcement_count += 1

            # === Agent Turn ===
            if hasattr(agent_adapter, 'generate_response'):
                # History excludes the message just appended; it is passed separately.
                response = agent_adapter.generate_response(user_message, conversation[:-1])
                agent_content = response.get("response", str(response)) if isinstance(response, dict) else str(response)
            elif callable(agent_adapter):
                agent_content = agent_adapter(conversation)
            else:
                agent_content = "[Error: Adapter not properly configured]"

            agent_token_count += len(agent_content.split())

            # Add to conversation
            conversation.append({"role": "assistant", "content": agent_content})
            turns.append({"role": "assistant", "content": agent_content})

            # Heuristic per-turn compliance: 0.8 while the user has never
            # enforced, then decreasing by 0.2 per enforcement down to 0.2.
            compliance_score = 0.8 if enforcement_count == 0 else max(0.2, 1.0 - 0.2 * enforcement_count)
            preference_compliance_scores.append(compliance_score)

        # End session
        if hasattr(agent_adapter, 'end_session'):
            # NOTE(review): task_success is hard-coded to True here although the
            # real outcome is only derived below -- confirm this is intended.
            adapter_metrics = agent_adapter.end_session(task_success=True)

    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"Error in session: {e}")
        turns.append({"role": "assistant", "content": f"[Error: {e}]"})

    # Compute metrics
    total_turns = len(turns)
    total_token_count = user_token_count + agent_token_count

    # Check if user reached a satisfactory answer (from last user response)
    task_success = False
    if full_user_log:
        last_user = full_user_log[-1]
        if last_user.get("should_terminate", False):
            # The draft comes from model output and may not be a string.
            draft = str(last_user.get("draft_answer", "") or "").strip()
            # Consider success if draft answer is not empty/"I don't know"
            task_success = bool(draft) and draft.lower() != "i don't know"

    # Conflict resolution (if this was a conflict test)
    conflict_accuracy = 0.0
    if is_conflict_query and conflict_scenario:
        expected_pref = conflict_scenario.get("expected_preference", "")
        # Simple heuristic: any of the first three words of the expected
        # preference appearing in the agent's replies counts as a hit.
        agent_texts = " ".join([t["content"] for t in turns if t["role"] == "assistant"]).lower()
        if expected_pref and any(kw in agent_texts for kw in expected_pref.lower().split()[:3]):
            conflict_accuracy = 1.0

    # Over-personalization detection is not implemented; always reported as 0.
    over_personalization = 0.0

    metrics = ConversationMetrics(
        task_success=task_success,
        turns_to_success=total_turns if task_success else -1,
        total_turns=total_turns,
        user_token_count=user_token_count,
        enforcement_count=enforcement_count,
        disappointment_count=disappointment_count,
        total_token_count=total_token_count,
        agent_token_count=agent_token_count,
        preference_compliance_scores=preference_compliance_scores,
        conflict_resolution_accuracy=conflict_accuracy,
        over_personalization_rate=over_personalization,
    )

    # Built-in hash() of a str is salted per process, so it cannot serve as an
    # id that must match across runs; use a content digest instead.
    fallback_problem_id = hashlib.sha256(str(query).encode("utf-8")).hexdigest()[:8]

    return {
        "method": method,
        "profile_id": profile.get("user_id", "unknown"),
        "problem_id": problem.get("problem_id", fallback_problem_id),
        "problem": original_problem,
        "ground_truth_solution": problem.get("solution", problem.get("answer", "")),
        "is_conflict_test": is_conflict_query,
        "conflict_scenario": conflict_scenario,
        "conversation": {"turns": turns} if self.config.save_conversations else None,
        "full_user_log": full_user_log if self.config.save_conversations else None,
        "metrics": asdict(metrics),
        "adapter_metrics": adapter_metrics,
    }
def run_method(self, method: str) -> List[Dict]:
    """Run every pending session of one method, with checkpointing.

    Progress is kept in ``<output_dir>/<method>/checkpoint.json`` (completed
    profile indices and per-profile session counts) and
    ``<output_dir>/<method>/results.json``. Existing files are loaded first,
    and only profiles in ``[start_profile, end_profile)`` that have fewer
    than ``n_sessions_per_profile`` recorded sessions are scheduled.

    Execution mode:
      * batch: when ``use_batch_processing`` and ``use_vllm`` are both set;
      * parallel: otherwise, when ``use_vllm`` or ``use_openai_user`` is set
        and ``parallel_profiles`` > 1;
      * sequential: otherwise, one shared adapter, checkpoint per profile.

    Args:
        method: Name of the method to run.

    Returns:
        All session results for the method, previously saved ones included.
    """
    logger.info(f"Running method: {method}")

    # Setup method directory and checkpoint
    method_dir = self.output_dir / method
    method_dir.mkdir(exist_ok=True)
    checkpoint_file = method_dir / "checkpoint.json"
    results_file = method_dir / "results.json"

    def _atomic_dump(payload: Any, target: Path, **dump_kwargs: Any) -> None:
        """Write JSON to a sibling temp file, then rename it over ``target``."""
        # An interrupted run then leaves the previous file intact instead of
        # a half-written one that would break the next resume.
        tmp_path = target.with_name(target.name + ".tmp")
        with open(tmp_path, "w") as f:
            json.dump(payload, f, **dump_kwargs)
        tmp_path.replace(target)

    # Load existing results and checkpoint
    results = []
    completed_profiles = set()
    sessions_per_profile = {}  # session count per profile (keys are str indices)
    if checkpoint_file.exists():
        with open(checkpoint_file, "r") as f:
            checkpoint = json.load(f)
        completed_profiles = set(checkpoint.get("completed_profiles", []))
        sessions_per_profile = checkpoint.get("sessions_per_profile", {})
        logger.info(f" Resuming from checkpoint: {len(completed_profiles)} profiles completed")
        if sessions_per_profile:
            total_sessions = sum(sessions_per_profile.values())
            logger.info(f" Session-level tracking: {total_sessions} sessions across {len(sessions_per_profile)} profiles")
    if results_file.exists():
        with open(results_file, "r") as f:
            results = json.load(f)

    # Determine profile range. ``is None`` (not truthiness) so that an
    # explicit end_profile of 0 means "no profiles" rather than "all";
    # clamping keeps the logged range within the profiles that exist.
    n_available = len(self.profiles)
    start_idx = self.config.start_profile
    if self.config.end_profile is None:
        end_idx = n_available
    else:
        end_idx = min(self.config.end_profile, n_available)

    # Build list of profiles that need more sessions
    profiles_to_run = [
        idx
        for idx in range(start_idx, end_idx)
        if sessions_per_profile.get(str(idx), 0) < self.config.n_sessions_per_profile
    ]

    # Log what we're running
    if sessions_per_profile:
        total_existing = sum(sessions_per_profile.get(str(idx), 0) for idx in profiles_to_run)
        total_needed = len(profiles_to_run) * self.config.n_sessions_per_profile
        logger.info(f" Running profiles {start_idx} to {end_idx-1}: {len(profiles_to_run)} profiles need sessions")
        logger.info(f" Sessions: {total_existing} existing, {total_needed - total_existing} remaining")
    else:
        logger.info(f" Running profiles {start_idx} to {end_idx-1} ({len(profiles_to_run)} remaining)")

    # Turn-synchronous batch mode. It is gated on use_vllm; the user
    # simulator inside it may still be OpenAI or local vLLM.
    if self.config.use_batch_processing and self.config.use_vllm:
        user_type = "OpenAI" if self.config.use_openai_user else "local vLLM"
        logger.info(f" Using BATCH processing ({user_type} user) for {method}")
        return self._run_method_batch(
            method, profiles_to_run, results, completed_profiles,
            sessions_per_profile, checkpoint_file, results_file
        )

    # Decide on parallelization for sequential methods
    n_parallel = self.config.parallel_profiles if (self.config.use_vllm or self.config.use_openai_user) else 1

    if n_parallel > 1:
        logger.info(f" Using parallel processing with {n_parallel} workers")
        # Mutates results / completed_profiles / sessions_per_profile in place.
        self._run_method_parallel(
            method, profiles_to_run, results, completed_profiles,
            sessions_per_profile, checkpoint_file, results_file
        )
    else:
        # Sequential execution (original behavior)
        # Create ONE adapter per method and reuse it (avoids GPU OOM from repeated model loading)
        adapter = self._create_method_adapter(method, None)
        if hasattr(adapter, 'initialize'):
            adapter.initialize()

        for profile_idx in profiles_to_run:
            profile = self.profiles[profile_idx]
            logger.info(f" Profile {profile_idx + 1}/{len(self.profiles)}")

            # NOTE(review): sessions for the profile are run from the start;
            # if a checkpoint holds a partial count for it, earlier results may
            # be duplicated -- confirm against how partial counts are produced.
            profile_results = self._run_profile_sessions(method, profile_idx, profile, adapter)

            # Add profile results to overall results
            results.extend(profile_results)
            completed_profiles.add(profile_idx)
            sessions_per_profile[str(profile_idx)] = self.config.n_sessions_per_profile

            # Save checkpoint and results after each profile
            _atomic_dump(
                {
                    "completed_profiles": sorted(completed_profiles),
                    "sessions_per_profile": sessions_per_profile,
                },
                checkpoint_file,
            )
            _atomic_dump(results, results_file, indent=2)
            logger.info(f" Profile {profile_idx + 1} completed and checkpointed")

    return results
+ """ + n_parallel = self.config.parallel_profiles + results_lock = threading.Lock() + start_time = time.time() + profiles_completed = 0 + + def process_profile(profile_idx: int) -> tuple: + """Process a single profile and return (profile_idx, results).""" + profile = self.profiles[profile_idx] + # Create adapter with shared models to avoid OOM from duplicate model loading + adapter = self._create_method_adapter(method, profile, use_shared_models=True) + profile_results = self._run_profile_sessions(method, profile_idx, profile, adapter) + return profile_idx, profile_results + + with ThreadPoolExecutor(max_workers=n_parallel) as executor: + # Submit all profile jobs + future_to_profile = { + executor.submit(process_profile, idx): idx + for idx in profiles_to_run + } + + # Process completed profiles + for future in as_completed(future_to_profile): + profile_idx = future_to_profile[future] + try: + idx, profile_results = future.result() + + with results_lock: + results.extend(profile_results) + completed_profiles.add(idx) + sessions_per_profile[str(idx)] = self.config.n_sessions_per_profile + profiles_completed += 1 + + # Save checkpoint with session-level tracking + with open(checkpoint_file, "w") as f: + json.dump({ + "completed_profiles": sorted(list(completed_profiles)), + "sessions_per_profile": sessions_per_profile + }, f) + with open(results_file, "w") as f: + json.dump(results, f, indent=2) + + # Log progress with throughput estimate + elapsed = time.time() - start_time + profiles_per_hour = profiles_completed / elapsed * 3600 if elapsed > 0 else 0 + sessions_per_hour = len(results) / elapsed * 3600 if elapsed > 0 else 0 + logger.info( + f" Profile {idx + 1} completed " + f"({profiles_completed}/{len(profiles_to_run)}) - " + f"{profiles_per_hour:.1f} profiles/hr, {sessions_per_hour:.1f} sessions/hr" + ) + + except Exception as e: + logger.error(f" Profile {profile_idx} failed: {e}") + + def _run_method_batch( + self, + method: str, + profiles_to_run: 
List[int], + results: List[Dict], + completed_profiles: set, + sessions_per_profile: Dict[str, int], + checkpoint_file: Path, + results_file: Path + ) -> List[Dict]: + """ + Turn-synchronous batch processing for ALL methods. + + At each turn, user calls are batched concurrently via AsyncOpenAI, + then agent responses go through personalization adapters. + Sessions within a profile run sequentially (for stateful memory). + """ + from agents.batch_vllm_agent import BatchOpenAIClient, BatchVLLMClient, TERMINATION_SIGNAL + from json_repair import repair_json + + start_time = time.time() + + # Create user client (OpenAI API or local vLLM) + if self.config.use_openai_user: + user_client = BatchOpenAIClient( + model=self.config.openai_user_model, + max_tokens=4096, + max_concurrent=32, + api_key=os.environ.get("OPENAI_API_KEY"), + ) + logger.info(f" Using OpenAI user simulator: {self.config.openai_user_model}") + else: + user_client = BatchVLLMClient( + vllm_url=self.config.vllm_user_url, + max_tokens=4096, + temperature=1.0, + timeout=None, + max_concurrent=100, + json_mode=True, # User simulator needs JSON output + ) + logger.info(f" Using local vLLM user simulator: {self.config.vllm_user_url}") + + # Create async agent client for batched vLLM calls + agent_client = BatchVLLMClient( + vllm_url=self.config.vllm_agent_url, + max_tokens=2048, + temperature=0.7, + timeout=None, # Infinite timeout for long generations + max_concurrent=100, + ) + + USER_PROMPT_TEMPLATE = ( + "You are a user simulator collaborating with an agent to solve a problem. " + "You will be provided with a problem description, and you must get the agent to help you solve it. 
" + "You will also be provided with user preferences, which you must follow and actively enforce throughout the conversation.\n\n" + "# Problem Description\n{problem}\nNote: the agent cannot see this problem description.\n\n" + "# User Persona\n{user_persona}\n\n" + "# User Preferences\n{user_preferences}\n" + "These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:\n" + " - **Answer clarifying questions**: The agent may ask clarifying questions before attempting an answer. " + "Answer such questions, and do not enforce preferences about answer format or content while the agent is clarifying.\n" + " - **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed. " + "Explicitly ask the agent to adjust their response until it complies.\n" + " - **Never proceed without compliance**: Do NOT update your draft answer, do NOT consider terminating, " + "and do NOT move forward until the agent follows your preferences.\n\n" + "# Draft Answer Management\n" + "- **Maintain a working draft**: Start with \"I don't know\". 
Update your draft answer based on what you learn from agent responses.\n" + "- **Don't update when enforcing preferences**: If the agent response does not follow your preferences, " + "do NOT update your draft answer, regardless of whether the agent provides helpful information.\n\n" + "# Conversation Termination\n" + "Before generating your response, determine if you should terminate:\n" + " - Do you feel like your draft answer is a good answer to the problem?\n" + " - Do you feel like the agent cannot help further?\n" + "If the agent response does not follow your preferences, you must NOT terminate - instead, enforce the preferences.\n" + "When ready to terminate, respond with \"TERMINATE\".\n\n" + "# Output Format (respond in JSON):\n" + "{{\n" + " \"preferences_check\": \"For EACH relevant preference, evaluate: is it satisfied?\",\n" + " \"enforce_preferences\": true/false,\n" + " \"reasoning\": \"Brief reasoning (2-3 sentences). Does agent follow preferences? If no, enforce. If yes, update draft.\",\n" + " \"draft_answer\": \"Your current working draft answer\",\n" + " \"should_terminate\": true/false,\n" + " \"response\": \"Your response to the agent\"\n" + "}}" + ) + + def parse_user_response(content): + if not content: + return None + try: + parsed = repair_json(content, return_objects=True) + if isinstance(parsed, dict) and "response" in parsed: + return parsed + except: + pass + if TERMINATION_SIGNAL in (content or ""): + return {"reasoning": "", "draft_answer": "", "should_terminate": True, "response": TERMINATION_SIGNAL} + return {"reasoning": "", "draft_answer": "", "should_terminate": False, "response": content or ""} + + def reverse_roles(conversation): + return [ + {"role": "user" if m["role"] == "assistant" else "assistant", "content": m["content"]} + for m in conversation + ] + + # Create per-profile adapters + adapters = {} + profile_sessions = {} + + for profile_idx in profiles_to_run: + profile = self.profiles[profile_idx] + adapter = 
self._create_method_adapter(method, profile, use_shared_models=True) + if hasattr(adapter, 'initialize'): + adapter.initialize() + adapters[profile_idx] = adapter + + sessions = [] + for ds_name, ds_obj in self.datasets.items(): + ds_items = ds_obj.get_testset() + for item in ds_items[:self.config.n_sessions_per_profile]: + sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain}) + sessions = sessions[:self.config.n_sessions_per_profile] + n_conflict = int(len(sessions) * self.config.conflict_ratio) + profile_sessions[profile_idx] = [(s, idx < n_conflict) for idx, s in enumerate(sessions)] + + n_sessions = self.config.n_sessions_per_profile + + # Calculate sessions to run per profile (accounting for existing sessions) + sessions_to_run_per_profile = {} + for profile_idx in profiles_to_run: + existing = sessions_per_profile.get(str(profile_idx), 0) + remaining = n_sessions - existing + if remaining > 0: + sessions_to_run_per_profile[profile_idx] = (existing, remaining) # (start_session, count) + + if sessions_to_run_per_profile: + total_remaining = sum(v[1] for v in sessions_to_run_per_profile.values()) + logger.info(f" Batch: {len(sessions_to_run_per_profile)} profiles, {total_remaining} sessions remaining") + else: + logger.info(f" Batch: All sessions already completed") + return results + + # Process sessions in rounds + for session_idx in range(n_sessions): + # Initialize all conversations for this round + all_states = {} # profile_idx -> state dict + active_set = set() + + for profile_idx in profiles_to_run: + # Skip if this profile doesn't need this session + if profile_idx not in sessions_to_run_per_profile: + continue + start_session, _ = sessions_to_run_per_profile[profile_idx] + if session_idx < start_session: + continue # Already completed this session + if session_idx >= len(profile_sessions[profile_idx]): + continue + problem_dict, is_conflict = profile_sessions[profile_idx][session_idx] + profile = 
self.profiles[profile_idx] + query = problem_dict["problem"] + + if is_conflict: + cs = self.conflict_generator.generate_for_profile( + profile.get("preferences", []), problem_dict.get("domain", "general")) + if cs: + query = cs["query"] + + user_prefs = profile.get("preferences", []) + if isinstance(user_prefs, list) and user_prefs: + if isinstance(user_prefs[0], dict): + pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs[:10]]) + else: + pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]]) + else: + pref_str = str(user_prefs) + + user_persona = profile.get("persona", "A user seeking help with problem solving.") + adapter = adapters[profile_idx] + if hasattr(adapter, 'start_session'): + adapter.start_session( + user_id=profile.get("user_id", f"user_{profile_idx}"), + user_profile={"preferences": user_prefs, "persona": user_persona} + ) + + all_states[profile_idx] = { + "conversation": [{"role": "assistant", "content": "How can I help you today?"}], + "full_log": [], + "system_prompt": USER_PROMPT_TEMPLATE.format( + problem=query, user_persona=user_persona, user_preferences=pref_str), + "problem_dict": problem_dict, + "is_conflict": is_conflict, + "enforcement_count": 0, + } + active_set.add(profile_idx) + + # Turn-synchronous loop + for turn in range(self.config.max_turns_per_session): + if not active_set: + break + + # Batch user calls + active_list = sorted(active_set) + user_msgs_batch = [] + for pidx in active_list: + state = all_states[pidx] + msgs = [{"role": "system", "content": state["system_prompt"]}] + msgs.extend(reverse_roles(state["conversation"])) + user_msgs_batch.append(msgs) + + user_responses = user_client.batch_completion(user_msgs_batch) + + # Process user responses and prepare agent prompts for batching + to_remove = [] + agent_prompts_batch = [] # List of (pidx, messages, context) + for i, pidx in enumerate(active_list): + state = all_states[pidx] + parsed = 
parse_user_response(user_responses[i]) + + if parsed is None: + to_remove.append(pidx) + continue + + user_msg = str(parsed.get("response", "")) + state["conversation"].append({"role": "user", "content": user_msg}) + state["full_log"].append(parsed) + + if parsed.get("enforce_preferences", False): + state["enforcement_count"] += 1 + + if parsed.get("should_terminate", False) or TERMINATION_SIGNAL in user_msg: + to_remove.append(pidx) + continue + + # Prepare agent prompt for batching (don't call LLM yet) + try: + adapter = adapters[pidx] + if hasattr(adapter, 'prepare_prompt'): + messages, context = adapter.prepare_prompt(user_msg, state["conversation"][:-1]) + agent_prompts_batch.append((pidx, messages, context)) + elif hasattr(adapter, 'generate_response'): + # Fallback for adapters without prepare_prompt + agent_prompts_batch.append((pidx, None, None)) + else: + state["conversation"].append({"role": "assistant", "content": "[Error: Adapter not configured]"}) + except Exception as e: + logger.error(f" Agent prepare error p{pidx} t{turn}: {e}") + state["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. 
Could you rephrase?"}) + + # Batch vLLM call for all agent prompts + if agent_prompts_batch: + # Separate prompts that can be batched from fallback + batchable = [(pidx, msgs, ctx) for pidx, msgs, ctx in agent_prompts_batch if msgs is not None] + fallback = [(pidx, msgs, ctx) for pidx, msgs, ctx in agent_prompts_batch if msgs is None] + + # Batch call for batchable prompts + if batchable: + batch_messages = [msgs for _, msgs, _ in batchable] + batch_responses = agent_client.batch_completion(batch_messages) + + # Process batched responses + for (pidx, _, context), response in zip(batchable, batch_responses): + try: + adapter = adapters[pidx] + state = all_states[pidx] + if response is not None: + result = adapter.process_response(response, context) + agent_content = result.get("response", str(result)) if isinstance(result, dict) else str(result) + else: + agent_content = "I apologize, I encountered an error. Could you rephrase?" + state["conversation"].append({"role": "assistant", "content": agent_content}) + except Exception as e: + logger.error(f" Agent process error p{pidx} t{turn}: {e}") + all_states[pidx]["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. Could you rephrase?"}) + + # Handle fallback (adapters without prepare_prompt - sequential calls) + for pidx, _, _ in fallback: + try: + adapter = adapters[pidx] + state = all_states[pidx] + user_msg = state["conversation"][-1]["content"] + resp = adapter.generate_response(user_msg, state["conversation"][:-1]) + agent_content = resp.get("response", str(resp)) if isinstance(resp, dict) else str(resp) + state["conversation"].append({"role": "assistant", "content": agent_content}) + except Exception as e: + logger.error(f" Agent fallback error p{pidx} t{turn}: {e}") + all_states[pidx]["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. 
Could you rephrase?"}) + + active_set -= set(to_remove) + + # Save results for this session round + for profile_idx in profiles_to_run: + if profile_idx not in all_states: + continue + state = all_states[profile_idx] + problem_dict = state["problem_dict"] + conversation = state["conversation"] + full_log = state["full_log"] + + user_tokens = sum(len(m["content"].split()) for m in conversation if m["role"] == "user") + agent_tokens = sum(len(m["content"].split()) for m in conversation if m["role"] == "assistant") + + enforcement_count = state["enforcement_count"] + task_success = 0 + for entry in full_log: + if entry.get("should_terminate", False): + draft = entry.get("draft_answer", "") + if draft and "don't know" not in draft.lower() and len(draft) > 20: + task_success = 1 + + results.append({ + "method": method, + "profile_id": self.profiles[profile_idx].get("user_id", f"user_{profile_idx}"), + "problem_id": f"session_{session_idx}", + "problem": problem_dict.get("problem", ""), + "ground_truth_solution": problem_dict.get("solution", ""), + "is_conflict_test": state["is_conflict"], + "conversation": {"turns": conversation}, + "full_user_log": full_log, + "metrics": { + "task_success": bool(task_success), + "total_turns": len(conversation), + "user_token_count": user_tokens, + "agent_token_count": agent_tokens, + "total_token_count": user_tokens + agent_tokens, + "enforcement_count": enforcement_count, + "disappointment_count": 0, + "preference_compliance_scores": [], + "conflict_resolution_accuracy": 0, + "over_personalization_rate": 0, + }, + "adapter_metrics": {}, + }) + + # Checkpoint after each session round with session-level tracking + # Only increment for profiles that actually ran in this round (those in all_states) + for profile_idx in all_states.keys(): + sessions_per_profile[str(profile_idx)] = sessions_per_profile.get(str(profile_idx), 0) + 1 + if sessions_per_profile[str(profile_idx)] >= self.config.n_sessions_per_profile: + 
# NOTE(review): run_all, _analyze_results and _generate_report are methods of
# ExperimentRunner (main() below builds ExperimentRunner(config) and calls
# run_all() on it); the class header itself is outside this span.
def run_all(self) -> Dict[str, Any]:
    """Run every configured method, then analyze and report the results.

    Methods that are not in ``AVAILABLE_METHODS`` are skipped with a warning.
    After each method, shared models are cleared (when the ``personalization``
    package is importable), Python garbage is collected, and the CUDA cache
    is emptied (when ``torch`` is importable and CUDA is available) to
    prevent GPU OOM on later adapters.

    Side effects:
        Writes ``analysis.json`` into ``self.output_dir`` and calls
        ``self._generate_report`` with the analysis.

    Returns:
        The analysis dict produced by ``self._analyze_results``.
    """
    import gc

    all_results = {}

    for method in self.config.methods:
        if method not in AVAILABLE_METHODS:
            logger.warning(f"Unknown method: {method}, skipping")
            continue

        all_results[method] = self.run_method(method)

        # Free GPU memory between methods to prevent OOM on later adapters.
        try:
            from personalization.serving.personalized_llm import clear_shared_models
            clear_shared_models()
        except ImportError:
            # Optional dependency: nothing shared to clear when it is absent.
            pass

        # Collect garbage unconditionally. Previously this call lived behind
        # `import torch`, so it was skipped whenever torch was not installed.
        gc.collect()
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.info(f" GPU memory freed after {method}: {torch.cuda.memory_allocated()/1e9:.1f}GB allocated")
        except ImportError:
            # CPU-only environment without torch: no CUDA cache to release.
            pass

    # Comparative analysis
    analysis = self._analyze_results(all_results)

    # Save analysis
    with open(self.output_dir / "analysis.json", "w") as f:
        json.dump(analysis, f, indent=2)

    # Generate report
    self._generate_report(analysis)

    return analysis


def _analyze_results(self, all_results: Dict[str, List[Dict]]) -> Dict:
    """Aggregate per-session results into per-method averages and a ranking.

    Args:
        all_results: Maps method name to a list of session result dicts.
            Each session dict is read for ``"is_conflict_test"`` and for
            the ``"metrics"`` keys used below.

    Returns:
        A dict with two keys:

        * ``"per_method"``: method name -> averaged metrics. Methods with
          no sessions are omitted.
        * ``"comparison"``: metric name -> ``{"values", "best_method",
          "best_value"}`` across the methods present in ``"per_method"``.
          Empty when no method produced any session.
    """
    analysis = {
        "per_method": {},
        "comparison": {},
    }

    for method, results in all_results.items():
        n = len(results)
        if n == 0:
            # Nothing to average; leaving the method out keeps the
            # comparison below free of placeholder values.
            continue

        session_metrics = [r["metrics"] for r in results]

        # Aggregate metrics
        task_success = sum(m["task_success"] for m in session_metrics) / n
        avg_user_tokens = sum(m["user_token_count"] for m in session_metrics) / n
        avg_total_tokens = sum(m["total_token_count"] for m in session_metrics) / n
        avg_enforcement = sum(m["enforcement_count"] for m in session_metrics) / n
        avg_turns = sum(m["total_turns"] for m in session_metrics) / n

        # Per-session mean compliance; a session with no compliance scores
        # contributes the neutral value 0.5.
        compliance_scores = []
        for m in session_metrics:
            scores = m["preference_compliance_scores"]
            compliance_scores.append(sum(scores) / len(scores) if scores else 0.5)
        avg_compliance = sum(compliance_scores) / len(compliance_scores)

        # Conflict accuracy is averaged only over sessions flagged as
        # conflict tests; 0 when the method had none.
        conflict_results = [r for r in results if r["is_conflict_test"]]
        if conflict_results:
            conflict_accuracy = sum(
                r["metrics"]["conflict_resolution_accuracy"] for r in conflict_results
            ) / len(conflict_results)
        else:
            conflict_accuracy = 0

        over_personalization = sum(
            m["over_personalization_rate"] for m in session_metrics
        ) / n

        analysis["per_method"][method] = {
            "n_sessions": n,
            "task_success_rate": task_success,
            "avg_user_tokens": avg_user_tokens,
            "avg_total_tokens": avg_total_tokens,
            "avg_enforcement_count": avg_enforcement,
            "avg_turns": avg_turns,
            "avg_preference_compliance": avg_compliance,
            "conflict_resolution_accuracy": conflict_accuracy,
            "over_personalization_rate": over_personalization,
        }

    # Comparison: (metric key, whether a higher value is better)
    metrics_to_compare = [
        ("task_success_rate", True),
        ("avg_user_tokens", False),
        ("avg_total_tokens", False),
        ("avg_enforcement_count", False),
        ("avg_preference_compliance", True),
        ("conflict_resolution_accuracy", True),
        ("over_personalization_rate", False),
    ]

    for metric, higher_better in metrics_to_compare:
        values = {m: analysis["per_method"][m][metric] for m in analysis["per_method"]}
        if not values:
            logger.warning(f"No values for metric {metric}, skipping comparison")
            continue

        # On ties, max()/min() return the first method in insertion order.
        pick = max if higher_better else min
        best = pick(values, key=values.get)

        analysis["comparison"][metric] = {
            "values": values,
            "best_method": best,
            "best_value": values[best],
        }

    return analysis


def _generate_report(self, analysis: Dict) -> None:
    """Write a Markdown summary of ``analysis`` to ``report.md``.

    The report contains a header with the generation time and the
    configured profile/session counts, a metric-by-method comparison table
    with the best method per metric, and a "Key Findings" section comparing
    ``rag_vector`` against ``contextual`` and ``all_memory`` when those
    methods are present in ``analysis["per_method"]``.

    Args:
        analysis: The dict returned by ``_analyze_results``.

    Side effects:
        Writes ``report.md`` into ``self.output_dir`` and logs its path.
    """
    report_lines = [
        "# Personalization Experiment Report",
        f"\nGenerated: {datetime.now().isoformat()}",
        f"\nConfig: {self.config.n_profiles} profiles, {self.config.n_sessions_per_profile} sessions each",
        "\n## Method Comparison\n",
    ]

    # (display name, key in per_method, format spec) for each table row
    metrics_display = [
        ("Task Success", "task_success_rate", "{:.1%}"),
        ("User Effort (tokens)", "avg_user_tokens", "{:.0f}"),
        ("Total Tokens", "avg_total_tokens", "{:.0f}"),
        ("Enforcement Count", "avg_enforcement_count", "{:.2f}"),
        ("Preference Compliance", "avg_preference_compliance", "{:.1%}"),
        ("Conflict Resolution", "conflict_resolution_accuracy", "{:.1%}"),
        ("Over-personalization", "over_personalization_rate", "{:.1%}"),
    ]

    methods = list(analysis["per_method"].keys())

    # Header
    header = "| Metric |" + "|".join(f" {m} " for m in methods) + "| Best |"
    separator = "|" + "|".join(["-" * (len(m) + 2) for m in ["Metric"] + methods + ["Best"]]) + "|"

    report_lines.extend([header, separator])

    comparison = analysis.get("comparison", {})
    for display_name, metric_key, fmt in metrics_display:
        row = f"| {display_name} |"
        for m in methods:
            val = analysis["per_method"].get(m, {}).get(metric_key, 0)
            row += f" {fmt.format(val)} |"

        if metric_key in comparison:
            best = comparison[metric_key]["best_method"]
        else:
            best = "N/A"
        row += f" {best} |"
        report_lines.append(row)

    # Key findings
    report_lines.extend([
        "\n## Key Findings\n",
    ])

    # Find advantages of proposed methods
    rag_vector = analysis["per_method"].get("rag_vector", {})
    contextual = analysis["per_method"].get("contextual", {})
    all_memory = analysis["per_method"].get("all_memory", {})

    if rag_vector and contextual:
        baseline_tokens = contextual.get("avg_total_tokens", 1)
        if baseline_tokens:
            token_reduction = (contextual.get("avg_total_tokens", 0) - rag_vector.get("avg_total_tokens", 0)) / baseline_tokens * 100
            report_lines.append(f"- **Token Efficiency**: RAG+Vector uses {token_reduction:.1f}% fewer tokens than contextual memory")
        else:
            # A zero-token baseline makes the relative reduction undefined;
            # dividing by it used to raise ZeroDivisionError and lose the
            # whole report.
            report_lines.append("- **Token Efficiency**: n/a (contextual baseline recorded 0 tokens)")

    if rag_vector and all_memory:
        conflict_improvement = rag_vector.get("conflict_resolution_accuracy", 0) - all_memory.get("conflict_resolution_accuracy", 0)
        report_lines.append(f"- **Conflict Resolution**: RAG+Vector improves by {conflict_improvement:.1%} over all-memory baseline")

    if rag_vector:
        report_lines.append(f"- **Over-personalization**: RAG+Vector rate: {rag_vector.get('over_personalization_rate', 0):.1%}")

    # Save report
    report_path = self.output_dir / "report.md"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(report_lines))

    logger.info(f"Report saved to {report_path}")


def main():
    """Command-line entry point: build the config, run all methods, print a summary.

    When ``--config`` is given, the experiment configuration is read from
    that YAML file and the other command-line options are not used to build
    it. Otherwise it is built from the command-line options. A ``--config``
    path that does not exist, or whose content is not a YAML mapping, aborts
    through ``parser.error`` (exit status 2) instead of silently running
    with defaults.
    """
    parser = argparse.ArgumentParser(description="Run personalization experiments")
    parser.add_argument("--config", type=str, help="Path to config YAML file")
    parser.add_argument("--methods", type=str, default="vanilla,contextual,rag,rag_vector",
                        help="Comma-separated list of methods to compare")
    parser.add_argument("--datasets", type=str, default="math-hard,math-500,bigcodebench",
                        help="Comma-separated list of datasets")
    parser.add_argument("--n-profiles", type=int, default=200, help="Number of user profiles")
    parser.add_argument("--n-sessions", type=int, default=30, help="Sessions per profile")
    parser.add_argument("--max-turns", type=int, default=15, help="Max turns per session")
    parser.add_argument("--output-dir", type=str, default="results", help="Output directory")
    parser.add_argument("--profile-path", type=str, help="Path to pre-generated profiles")
    parser.add_argument("--start-profile", type=int, default=0,
                        help="Start profile index (inclusive, 0-indexed)")
    parser.add_argument("--end-profile", type=int, default=None,
                        help="End profile index (exclusive). If not set, runs all profiles from start")

    # vLLM and parallel processing options
    parser.add_argument("--use-vllm", action="store_true",
                        help="Use vLLM servers for inference (much faster)")
    parser.add_argument("--vllm-user-url", type=str, default="http://localhost:8004/v1",
                        help="vLLM server URL for user simulator (70B)")
    parser.add_argument("--vllm-agent-url", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL for agent (8B)")
    # OpenAI user agent options
    parser.add_argument("--use-openai-user", action="store_true",
                        help="Use OpenAI API (GPT-5) for user simulation instead of vLLM")
    parser.add_argument("--openai-user-model", type=str, default="gpt-5",
                        help="OpenAI model name for user simulator (default: gpt-5)")
    parser.add_argument("--reward-mode", type=str, default="keyword", choices=["keyword", "llm"],
                        help="Reward mode for RL updates: 'keyword' (user signals) or 'llm' (GPT-5-nano judge)")

    parser.add_argument("--parallel-profiles", type=int, default=50,
                        help="Number of profiles to process in parallel (requires --use-vllm)")
    # Batch processing is on by default; --no-batch-processing is the switch
    # that actually changes the value.
    parser.add_argument("--use-batch-processing", action="store_true", default=True,
                        help="Use turn-synchronous batch processing for vanilla/all_memory")
    parser.add_argument("--no-batch-processing", action="store_false", dest="use_batch_processing",
                        help="Disable batch processing")
    parser.add_argument("--batch-size", type=int, default=50,
                        help="Number of conversations to batch together")
    parser.add_argument("--continue-from", type=str, default=None,
                        help="Path to existing output directory to continue from (for extending sessions)")

    args = parser.parse_args()

    def split_csv(value):
        # Tolerate "a, b" and trailing commas so names are not later
        # rejected as unknown because of stray whitespace or empty entries.
        return [item.strip() for item in value.split(",") if item.strip()]

    # Load or create config
    if args.config:
        config_path = Path(args.config)
        if not config_path.exists():
            # Falling back to CLI defaults here would silently launch a
            # different (and potentially very long) experiment.
            parser.error(f"config file not found: {args.config}")
        with open(config_path, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        if not isinstance(config_dict, dict):
            # safe_load returns None for an empty file; ** needs a mapping.
            parser.error(f"config file must contain a YAML mapping: {args.config}")
        config = ExperimentConfig(**config_dict)
    else:
        config = ExperimentConfig(
            methods=split_csv(args.methods),
            datasets=split_csv(args.datasets),
            n_profiles=args.n_profiles,
            n_sessions_per_profile=args.n_sessions,
            max_turns_per_session=args.max_turns,
            output_dir=args.output_dir,
            profile_path=args.profile_path,
            start_profile=args.start_profile,
            end_profile=args.end_profile,
            use_vllm=args.use_vllm,
            vllm_user_url=args.vllm_user_url,
            vllm_agent_url=args.vllm_agent_url,
            use_openai_user=args.use_openai_user,
            openai_user_model=args.openai_user_model,
            reward_mode=args.reward_mode,
            parallel_profiles=args.parallel_profiles,
            use_batch_processing=args.use_batch_processing,
            batch_size_conversations=args.batch_size,
            continue_from=args.continue_from,
        )

    # Run experiments
    runner = ExperimentRunner(config)
    analysis = runner.run_all()

    print("\n" + "=" * 60)
    print("EXPERIMENT COMPLETE")
    print("=" * 60)
    print(f"\nResults saved to: {runner.output_dir}")
    if analysis.get("comparison"):
        print("\nBest methods per metric:")
        for metric, data in analysis["comparison"].items():
            print(f" {metric}: {data['best_method']} ({data['best_value']:.3f})")
    else:
        print("\nNo comparison data available (sessions may have failed)")


if __name__ == "__main__":
    main()
