summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/run_experiments.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/scripts/run_experiments.py')
-rw-r--r--collaborativeagents/scripts/run_experiments.py1328
1 files changed, 1328 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/run_experiments.py b/collaborativeagents/scripts/run_experiments.py
new file mode 100644
index 0000000..0ba0ba0
--- /dev/null
+++ b/collaborativeagents/scripts/run_experiments.py
@@ -0,0 +1,1328 @@
+#!/usr/bin/env python3
+"""
+Main experiment orchestrator for personalization benchmark.
+
+This script runs all baselines and the proposed methods with PROPER multi-turn
+conversation simulation, user preference enforcement, and LLM-based evaluation.
+
+Usage:
+ python run_experiments.py --config config.yaml
+ python run_experiments.py --methods vanilla,rag,rag_vector --datasets gpqa,aime
+"""
+
+import argparse
+import json
+import yaml
+import os
+import sys
+from pathlib import Path
+from datetime import datetime
+from typing import List, Dict, Any, Optional
+from dataclasses import dataclass, asdict
+import logging
+import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import threading
+import time
+
+# Add paths
+sys.path.insert(0, str(Path(__file__).parent.parent))
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from datasets_extended import get_dataset, get_all_datasets, get_challenging_datasets
+from evaluation.llm_judge import LLMJudge, BatchEvaluator, ConversationMetrics
+from conflict_scenario_generator import ConflictScenarioGenerator
+from adapters.personalized_llm_adapter import PersonalizedLLMAdapter, create_baseline_adapter
+from agents.local_user_agent import LocalUserAgent, SharedLocalUserAgent, TERMINATION_SIGNAL
+from agents.vllm_user_agent import VLLMUserAgent, VLLMAgentClient
+from agents.openai_user_agent import OpenAIUserAgent
+from agents.batch_vllm_agent import BatchConversationGenerator, BatchVLLMClient
+
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
@dataclass
class ExperimentConfig:
    """Configuration for an experiment run.

    Fields are grouped by concern (selection, profiles, sessions, models,
    output, conflicts, compute, resume). All defaults can be overridden from
    the YAML config or CLI.
    """
    # Methods to compare (must be keys of AVAILABLE_METHODS)
    methods: List[str]

    # Datasets to use (names resolvable by get_dataset)
    datasets: List[str]

    # User profiles
    n_profiles: int = 200
    profile_path: Optional[str] = None  # JSONL file; placeholders are generated when absent

    # Profile range (for splitting jobs across machines)
    start_profile: int = 0  # Inclusive, 0-indexed
    end_profile: Optional[int] = None  # Exclusive, None means all

    # Session settings
    n_sessions_per_profile: int = 30
    max_turns_per_session: int = 15  # Increased for harder tasks

    # Model settings
    user_model: str = "meta-llama/Llama-3.3-70B-Instruct"
    agent_model: str = "meta-llama/Llama-3.1-8B-Instruct"
    judge_model: str = "meta-llama/Llama-3.3-70B-Instruct"

    # Output settings
    output_dir: str = "results"
    save_conversations: bool = True  # If False, raw transcripts are omitted from results

    # Conflict testing
    conflict_ratio: float = 0.3  # proportion of queries that trigger conflicts

    # Compute settings
    batch_size: int = 4
    n_gpus: int = 4

    # vLLM settings (for high-performance inference)
    use_vllm: bool = False
    vllm_user_url: str = "http://localhost:8004/v1"  # 70B user simulator
    vllm_agent_url: str = "http://localhost:8003/v1"  # 8B agent

    # OpenAI user simulator (alternative to vLLM user agent)
    use_openai_user: bool = False
    openai_user_model: str = "gpt-5"  # Model name for OpenAI user agent

    # Reward mode: "keyword" (implicit user signals) or "llm" (GPT-5-nano judge)
    # This is a global option applied to ALL methods that use RL updates
    reward_mode: str = "keyword"

    # Parallel/Batch processing
    parallel_profiles: int = 50  # Number of profiles to process in parallel
    use_batch_processing: bool = True  # Use turn-synchronous batch processing for vanilla/all_memory
    batch_size_conversations: int = 50  # Number of conversations to batch together

    # Continue from existing experiment (for extending sessions)
    continue_from: Optional[str] = None  # Path to existing output directory to continue from
+
# Available methods
# Registry of method name -> human-readable description. These names are what
# the "methods" config list may contain, and each name is forwarded to
# create_baseline_adapter() when its adapter is constructed.
AVAILABLE_METHODS = {
    "vanilla": "No memory, no personalization",
    "contextual": "Full history in context, summarize when overflow",
    "reflection": "CollaborativeAgents' agent_notes approach",
    "reflection_grpo": "Reflection + GRPO training",
    "all_memory": "All extracted memories in context (no retrieval)",
    "rag": "Extractor + RAG (no user vector)",
    "rag_vector": "Extractor + RAG + user vector (proposed method)",
    "rag_bge": "Extractor + RAG with BGE reranker (278M)",
    "rag_vector_bge": "Extractor + RAG + user vector with BGE reranker (278M)",
}
+
+
+class ExperimentRunner:
+ """Main experiment runner."""
+
+ def __init__(self, config: ExperimentConfig):
+ self.config = config
+
+ # Use existing directory if continuing, otherwise create new timestamped one
+ if config.continue_from:
+ self.output_dir = Path(config.continue_from)
+ if not self.output_dir.exists():
+ raise ValueError(f"Continue-from directory does not exist: {config.continue_from}")
+ logger.info(f"Continuing from existing experiment: {self.output_dir}")
+ else:
+ self.output_dir = Path(config.output_dir) / datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Save/update config
+ with open(self.output_dir / "config.yaml", "w") as f:
+ yaml.dump(asdict(config), f)
+
+ # Initialize components
+ self.judge = LLMJudge(model_name=config.judge_model)
+ self.batch_evaluator = BatchEvaluator(self.judge)
+ self.conflict_generator = ConflictScenarioGenerator()
+
+ # Load datasets
+ self.datasets = {}
+ for ds_name in config.datasets:
+ try:
+ self.datasets[ds_name] = get_dataset(ds_name)
+ logger.info(f"Loaded dataset: {ds_name}")
+ except Exception as e:
+ logger.warning(f"Failed to load dataset {ds_name}: {e}")
+
+ # Load or generate profiles
+ self.profiles = self._load_profiles()
+
+ def _load_profiles(self) -> List[Dict]:
+ """Load user profiles from file or generate."""
+ logger.info(f"Profile path configured: {self.config.profile_path}")
+
+ if self.config.profile_path:
+ profile_path = Path(self.config.profile_path)
+ if profile_path.exists():
+ profiles = []
+ with open(profile_path) as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ profiles.append(json.loads(line))
+ logger.info(f"Loaded {len(profiles)} profiles from {self.config.profile_path}")
+ return profiles[:self.config.n_profiles]
+ else:
+ logger.warning(f"Profile path does not exist: {self.config.profile_path}")
+
+ # Generate simple placeholder profiles if no file provided
+ logger.info(f"Generating {self.config.n_profiles} placeholder profiles...")
+ profiles = []
+ for i in range(self.config.n_profiles):
+ profiles.append({
+ "id": i,
+ "persona": f"User {i+1} is a curious individual seeking help with problem solving.",
+ "preferences": [
+ "Provide clear, step-by-step explanations",
+ "Use simple language when possible",
+ "Give examples to illustrate concepts",
+ "Be concise but thorough",
+ "Acknowledge when something is uncertain"
+ ]
+ })
+
+ # Save generated profiles
+ profile_path = self.output_dir / "generated_profiles.json"
+ with open(profile_path, "w") as f:
+ json.dump(profiles, f, indent=2)
+
+ logger.info(f"Generated and saved {len(profiles)} placeholder profiles")
+ return profiles
+
+ def _create_method_adapter(self, method: str, profile: Dict, use_shared_models: bool = False) -> Any:
+ """Create adapter for a specific method.
+
+ Args:
+ method: One of the baseline method names
+ profile: User profile dict (used later in start_session, not constructor)
+ use_shared_models: If True, share embedding/reranker models across parallel
+ workers. ESSENTIAL for parallel profile processing to avoid OOM.
+
+ Returns:
+ Configured adapter instance
+ """
+ # Auto-detect available GPUs and set device assignment accordingly
+ # Layout with local 70B user (4 GPUs):
+ # GPU 0-1: 70B user simulator (TP=2)
+ # GPU 2: 8B agent vLLM server
+ # GPU 3: Embedding + Reranker + Extractor
+ # Layout with OpenAI user (2 GPUs):
+ # GPU 0: 8B agent vLLM server
+ # GPU 1: Embedding + Reranker + Extractor
+ device_assignment = None
+ try:
+ import torch
+ n_gpus = torch.cuda.device_count()
+ if n_gpus >= 4:
+ # 4 GPU layout: 70B user on 0-1, agent on 2, adapters on 3
+ device_assignment = {
+ "embed": "cuda:3",
+ "reranker": "cuda:3",
+ "extractor": "cuda:3",
+ }
+ elif n_gpus >= 2:
+ # 2 GPU layout: agent on 0, adapters on 1
+ device_assignment = {
+ "embed": "cuda:1",
+ "reranker": "cuda:1",
+ "extractor": "cuda:1",
+ }
+ elif n_gpus == 1:
+ device_assignment = {
+ "embed": "cuda:0",
+ "reranker": "cuda:0",
+ "extractor": "cuda:0",
+ }
+ except ImportError:
+ pass
+
+ adapter = create_baseline_adapter(
+ method,
+ device_assignment=device_assignment,
+ use_vllm=self.config.use_vllm,
+ use_shared_models=use_shared_models,
+ reward_mode=self.config.reward_mode,
+ )
+ # Profile will be passed to start_session() when the conversation begins
+ return adapter
+
+ def run_single_session(
+ self,
+ method: str,
+ profile: Dict,
+ problem: Dict,
+ is_conflict_query: bool = False,
+ adapter: Any = None,
+ user_agent: Any = None
+ ) -> Dict:
+ """Run a single session with PROPER multi-turn conversation and user simulation.
+
+ This implements:
+ 1. User simulator that role-plays with preferences
+ 2. Multi-turn conversation (up to max_turns)
+ 3. Preference enforcement by simulated user
+ 4. Proper metrics extraction from conversation
+ """
+ # Use provided adapter (reused across sessions) or create new one
+ agent_adapter = adapter if adapter else self._create_method_adapter(method, profile)
+
+ # Prepare conflict scenario if needed
+ conflict_scenario = None
+ original_problem = problem.get("problem", problem.get("question", ""))
+ if is_conflict_query:
+ conflict_scenario = self.conflict_generator.generate_for_profile(
+ profile["preferences"],
+ problem.get("domain", "general")
+ )
+ if conflict_scenario:
+ problem = dict(problem)
+ problem["problem"] = conflict_scenario["query"]
+
+ query = problem.get("problem", problem.get("question", ""))
+
+ # Extract user preferences as formatted string
+ user_prefs = profile.get("preferences", [])
+ if isinstance(user_prefs, list) and len(user_prefs) > 0:
+ if isinstance(user_prefs[0], dict):
+ # Structured preferences with condition/action
+ pref_str = "\n".join([
+ f"- When {p.get('condition', '')}, {p.get('action', '')}"
+ for p in user_prefs[:10] # Top 10 preferences
+ ])
+ else:
+ # Simple string preferences
+ pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]])
+ else:
+ pref_str = str(user_prefs)
+
+ user_persona = profile.get("persona", "A user seeking help with problem solving.")
+
+ # Create user agent for this session (or reuse provided one)
+ if user_agent is None:
+ if self.config.use_openai_user:
+ user_agent = OpenAIUserAgent(
+ user_task_description="Help the user solve their problem.",
+ problem=query,
+ user_persona=user_persona,
+ user_preferences=pref_str,
+ model=self.config.openai_user_model,
+ )
+ elif self.config.use_vllm:
+ user_agent = VLLMUserAgent(
+ user_task_description="Help the user solve their problem.",
+ problem=query,
+ user_persona=user_persona,
+ user_preferences=pref_str,
+ vllm_url=self.config.vllm_user_url,
+ )
+ else:
+ user_agent = SharedLocalUserAgent(
+ user_task_description="Help the user solve their problem.",
+ problem=query,
+ user_persona=user_persona,
+ user_preferences=pref_str,
+ )
+
+ # Initialize conversation
+ turns = []
+ full_user_log = [] # Detailed user agent outputs
+
+ # Metrics tracking
+ enforcement_count = 0
+ disappointment_count = 0
+ user_token_count = 0
+ agent_token_count = 0
+ preference_compliance_scores = []
+
+ try:
+ # Initialize adapter for this user
+ if hasattr(agent_adapter, 'initialize'):
+ agent_adapter.initialize()
+ if hasattr(agent_adapter, 'start_session'):
+ agent_adapter.start_session(
+ user_id=profile.get("user_id", "test_user"),
+ user_profile={"preferences": user_prefs, "persona": user_persona}
+ )
+
+ # Start with agent greeting
+ conversation = [{"role": "assistant", "content": "How can I help you today?"}]
+
+ # Multi-turn conversation loop
+ for turn_num in range(self.config.max_turns_per_session):
+ # === User Turn ===
+ user_response = user_agent.generate_user_response(conversation)
+
+ if user_response is None:
+ logger.warning(f"User agent failed to respond at turn {turn_num}")
+ break
+
+ user_message = str(user_response.get("response", ""))
+ user_token_count += len(user_message.split())
+
+ # Add to conversation
+ conversation.append({"role": "user", "content": user_message})
+ turns.append({"role": "user", "content": user_message})
+ full_user_log.append(user_response)
+
+ # Check for termination
+ if user_response.get("should_terminate", False) or TERMINATION_SIGNAL in user_message:
+ break
+
+ # Detect preference enforcement (user correcting agent)
+ enforcement_keywords = ["please", "i asked", "i said", "i prefer", "can you", "could you", "instead"]
+ if any(kw in user_message.lower() for kw in enforcement_keywords):
+ enforcement_count += 1
+
+ # === Agent Turn ===
+ if hasattr(agent_adapter, 'generate_response'):
+ response = agent_adapter.generate_response(user_message, conversation[:-1])
+ agent_content = response.get("response", str(response)) if isinstance(response, dict) else str(response)
+ elif callable(agent_adapter):
+ agent_content = agent_adapter(conversation)
+ else:
+ agent_content = "[Error: Adapter not properly configured]"
+
+ agent_token_count += len(agent_content.split())
+
+ # Add to conversation
+ conversation.append({"role": "assistant", "content": agent_content})
+ turns.append({"role": "assistant", "content": agent_content})
+
+ # Estimate preference compliance for this turn (heuristic based on user satisfaction)
+ # If user doesn't enforce in next turn, assume compliance
+ # This is a simplified heuristic - LLM judge would be more accurate
+ compliance_score = 0.8 if enforcement_count == 0 else max(0.2, 1.0 - 0.2 * enforcement_count)
+ preference_compliance_scores.append(compliance_score)
+
+ # End session
+ if hasattr(agent_adapter, 'end_session'):
+ adapter_metrics = agent_adapter.end_session(task_success=True)
+ else:
+ adapter_metrics = {}
+
+ except Exception as e:
+ import traceback
+ logger.error(f"Error in session: {e}")
+ logger.error(f"Full traceback:\n{traceback.format_exc()}")
+ turns.append({"role": "assistant", "content": f"[Error: {e}]"})
+
+ # Compute metrics
+ total_turns = len(turns)
+ total_token_count = user_token_count + agent_token_count
+
+ # Check if user reached a satisfactory answer (from last user response)
+ task_success = False
+ if full_user_log:
+ last_user = full_user_log[-1]
+ if last_user.get("should_terminate", False):
+ draft = last_user.get("draft_answer", "")
+ # Consider success if draft answer is not empty/"I don't know"
+ task_success = bool(draft) and draft.lower() != "i don't know"
+
+ # Compute average compliance
+ avg_compliance = sum(preference_compliance_scores) / len(preference_compliance_scores) if preference_compliance_scores else 0.5
+
+ # Conflict resolution (if this was a conflict test)
+ conflict_accuracy = 0.0
+ if is_conflict_query and conflict_scenario:
+ # Check if the correct preference was applied
+ expected_pref = conflict_scenario.get("expected_preference", "")
+ # Simple heuristic: check if expected preference keywords appear in agent responses
+ agent_texts = " ".join([t["content"] for t in turns if t["role"] == "assistant"])
+ if expected_pref and any(kw in agent_texts.lower() for kw in expected_pref.lower().split()[:3]):
+ conflict_accuracy = 1.0
+
+ # Over-personalization detection (heuristic: if agent mentions preferences not in profile)
+ over_personalization = 0.0
+
+ metrics = ConversationMetrics(
+ task_success=task_success,
+ turns_to_success=total_turns if task_success else -1,
+ total_turns=total_turns,
+ user_token_count=user_token_count,
+ enforcement_count=enforcement_count,
+ disappointment_count=disappointment_count,
+ total_token_count=total_token_count,
+ agent_token_count=agent_token_count,
+ preference_compliance_scores=preference_compliance_scores,
+ conflict_resolution_accuracy=conflict_accuracy,
+ over_personalization_rate=over_personalization,
+ )
+
+ return {
+ "method": method,
+ "profile_id": profile.get("user_id", "unknown"),
+ "problem_id": problem.get("problem_id", str(hash(query))[:8]),
+ "problem": original_problem,
+ "ground_truth_solution": problem.get("solution", problem.get("answer", "")),
+ "is_conflict_test": is_conflict_query,
+ "conflict_scenario": conflict_scenario,
+ "conversation": {"turns": turns} if self.config.save_conversations else None,
+ "full_user_log": full_user_log if self.config.save_conversations else None,
+ "metrics": asdict(metrics),
+ "adapter_metrics": adapter_metrics if 'adapter_metrics' in dir() else {},
+ }
+
+ def _run_profile_sessions(
+ self,
+ method: str,
+ profile_idx: int,
+ profile: Dict,
+ adapter: Any = None
+ ) -> List[Dict]:
+ """Run all sessions for a single profile. Thread-safe for parallel execution."""
+ profile_results = []
+
+ # Create vLLM-based agent client if using vLLM (for methods that need it)
+ vllm_agent = None
+ if self.config.use_vllm and method == "vanilla":
+ vllm_agent = VLLMAgentClient(
+ vllm_url=self.config.vllm_agent_url,
+ system_prompt="You are a helpful AI assistant for problem-solving tasks."
+ )
+
+ # Run sessions across datasets
+ session_idx = 0
+ for ds_name, dataset in self.datasets.items():
+ samples = dataset.get_testset()
+
+ for sample in samples:
+ if session_idx >= self.config.n_sessions_per_profile:
+ break
+
+ # Decide if this is a conflict query
+ is_conflict = (session_idx % int(1 / self.config.conflict_ratio)) == 0
+
+ problem = {
+ "problem": sample.problem,
+ "solution": sample.solution,
+ "problem_id": sample.problem_id,
+ "domain": sample.domain,
+ }
+
+ try:
+ result = self.run_single_session(
+ method=method,
+ profile=profile,
+ problem=problem,
+ is_conflict_query=is_conflict,
+ adapter=vllm_agent if vllm_agent else adapter
+ )
+ profile_results.append(result)
+ except Exception as e:
+ logger.error(f"Error in session for profile {profile_idx}: {e}")
+
+ session_idx += 1
+
+ return profile_results
+
+ def run_method(self, method: str) -> List[Dict]:
+ """Run all sessions for a single method with checkpointing and parallel processing."""
+ logger.info(f"Running method: {method}")
+
+ # Setup method directory and checkpoint
+ method_dir = self.output_dir / method
+ method_dir.mkdir(exist_ok=True)
+ checkpoint_file = method_dir / "checkpoint.json"
+ results_file = method_dir / "results.json"
+
+ # Load existing results and checkpoint
+ results = []
+ completed_profiles = set()
+ sessions_per_profile = {} # Track session count per profile for continue functionality
+ if checkpoint_file.exists():
+ with open(checkpoint_file, "r") as f:
+ checkpoint = json.load(f)
+ completed_profiles = set(checkpoint.get("completed_profiles", []))
+ sessions_per_profile = checkpoint.get("sessions_per_profile", {})
+ logger.info(f" Resuming from checkpoint: {len(completed_profiles)} profiles completed")
+ if sessions_per_profile:
+ total_sessions = sum(sessions_per_profile.values())
+ logger.info(f" Session-level tracking: {total_sessions} sessions across {len(sessions_per_profile)} profiles")
+ if results_file.exists():
+ with open(results_file, "r") as f:
+ results = json.load(f)
+
+ # Determine profile range
+ start_idx = self.config.start_profile
+ end_idx = self.config.end_profile if self.config.end_profile else len(self.profiles)
+
+ # Build list of profiles that need more sessions
+ profiles_to_run = []
+ for idx in range(start_idx, min(end_idx, len(self.profiles))):
+ existing_sessions = sessions_per_profile.get(str(idx), 0)
+ if existing_sessions < self.config.n_sessions_per_profile:
+ profiles_to_run.append(idx)
+
+ # Log what we're running
+ if sessions_per_profile:
+ total_existing = sum(sessions_per_profile.get(str(idx), 0) for idx in profiles_to_run)
+ total_needed = len(profiles_to_run) * self.config.n_sessions_per_profile
+ logger.info(f" Running profiles {start_idx} to {end_idx-1}: {len(profiles_to_run)} profiles need sessions")
+ logger.info(f" Sessions: {total_existing} existing, {total_needed - total_existing} remaining")
+ else:
+ logger.info(f" Running profiles {start_idx} to {end_idx-1} ({len(profiles_to_run)} remaining)")
+
+ # When using batch processing with vLLM or OpenAI user: use turn-synchronous batch mode
+ # This batches both user and agent calls for maximum throughput
+ if self.config.use_batch_processing and self.config.use_vllm:
+ user_type = "OpenAI" if self.config.use_openai_user else "local vLLM"
+ logger.info(f" Using BATCH processing ({user_type} user) for {method}")
+ return self._run_method_batch(
+ method, profiles_to_run, results, completed_profiles,
+ sessions_per_profile, checkpoint_file, results_file
+ )
+
+ # Decide on parallelization for sequential methods
+ n_parallel = self.config.parallel_profiles if (self.config.use_vllm or self.config.use_openai_user) else 1
+
+ if n_parallel > 1:
+ logger.info(f" Using parallel processing with {n_parallel} workers")
+ self._run_method_parallel(
+ method, profiles_to_run, results, completed_profiles,
+ sessions_per_profile, checkpoint_file, results_file
+ )
+ else:
+ # Sequential execution (original behavior)
+ # Create ONE adapter per method and reuse it (avoids GPU OOM from repeated model loading)
+ adapter = self._create_method_adapter(method, None)
+ adapter.initialize()
+
+ for profile_idx in profiles_to_run:
+ profile = self.profiles[profile_idx]
+ logger.info(f" Profile {profile_idx + 1}/{len(self.profiles)}")
+
+ profile_results = self._run_profile_sessions(method, profile_idx, profile, adapter)
+
+ # Add profile results to overall results
+ results.extend(profile_results)
+ completed_profiles.add(profile_idx)
+ sessions_per_profile[str(profile_idx)] = self.config.n_sessions_per_profile
+
+ # Save checkpoint and results after each profile
+ with open(checkpoint_file, "w") as f:
+ json.dump({
+ "completed_profiles": sorted(list(completed_profiles)),
+ "sessions_per_profile": sessions_per_profile
+ }, f)
+ with open(results_file, "w") as f:
+ json.dump(results, f, indent=2)
+ logger.info(f" Profile {profile_idx + 1} completed and checkpointed")
+
+ return results
+
+ def _run_method_parallel(
+ self,
+ method: str,
+ profiles_to_run: List[int],
+ results: List[Dict],
+ completed_profiles: set,
+ sessions_per_profile: Dict[str, int],
+ checkpoint_file: Path,
+ results_file: Path
+ ):
+ """Run profiles in parallel using ThreadPoolExecutor.
+
+ Uses shared model singletons for embedding/reranker to avoid OOM
+ when multiple workers try to load their own copies.
+ """
+ n_parallel = self.config.parallel_profiles
+ results_lock = threading.Lock()
+ start_time = time.time()
+ profiles_completed = 0
+
+ def process_profile(profile_idx: int) -> tuple:
+ """Process a single profile and return (profile_idx, results)."""
+ profile = self.profiles[profile_idx]
+ # Create adapter with shared models to avoid OOM from duplicate model loading
+ adapter = self._create_method_adapter(method, profile, use_shared_models=True)
+ profile_results = self._run_profile_sessions(method, profile_idx, profile, adapter)
+ return profile_idx, profile_results
+
+ with ThreadPoolExecutor(max_workers=n_parallel) as executor:
+ # Submit all profile jobs
+ future_to_profile = {
+ executor.submit(process_profile, idx): idx
+ for idx in profiles_to_run
+ }
+
+ # Process completed profiles
+ for future in as_completed(future_to_profile):
+ profile_idx = future_to_profile[future]
+ try:
+ idx, profile_results = future.result()
+
+ with results_lock:
+ results.extend(profile_results)
+ completed_profiles.add(idx)
+ sessions_per_profile[str(idx)] = self.config.n_sessions_per_profile
+ profiles_completed += 1
+
+ # Save checkpoint with session-level tracking
+ with open(checkpoint_file, "w") as f:
+ json.dump({
+ "completed_profiles": sorted(list(completed_profiles)),
+ "sessions_per_profile": sessions_per_profile
+ }, f)
+ with open(results_file, "w") as f:
+ json.dump(results, f, indent=2)
+
+ # Log progress with throughput estimate
+ elapsed = time.time() - start_time
+ profiles_per_hour = profiles_completed / elapsed * 3600 if elapsed > 0 else 0
+ sessions_per_hour = len(results) / elapsed * 3600 if elapsed > 0 else 0
+ logger.info(
+ f" Profile {idx + 1} completed "
+ f"({profiles_completed}/{len(profiles_to_run)}) - "
+ f"{profiles_per_hour:.1f} profiles/hr, {sessions_per_hour:.1f} sessions/hr"
+ )
+
+ except Exception as e:
+ logger.error(f" Profile {profile_idx} failed: {e}")
+
+ def _run_method_batch(
+ self,
+ method: str,
+ profiles_to_run: List[int],
+ results: List[Dict],
+ completed_profiles: set,
+ sessions_per_profile: Dict[str, int],
+ checkpoint_file: Path,
+ results_file: Path
+ ) -> List[Dict]:
+ """
+ Turn-synchronous batch processing for ALL methods.
+
+ At each turn, user calls are batched concurrently via AsyncOpenAI,
+ then agent responses go through personalization adapters.
+ Sessions within a profile run sequentially (for stateful memory).
+ """
+ from agents.batch_vllm_agent import BatchOpenAIClient, BatchVLLMClient, TERMINATION_SIGNAL
+ from json_repair import repair_json
+
+ start_time = time.time()
+
+ # Create user client (OpenAI API or local vLLM)
+ if self.config.use_openai_user:
+ user_client = BatchOpenAIClient(
+ model=self.config.openai_user_model,
+ max_tokens=4096,
+ max_concurrent=32,
+ api_key=os.environ.get("OPENAI_API_KEY"),
+ )
+ logger.info(f" Using OpenAI user simulator: {self.config.openai_user_model}")
+ else:
+ user_client = BatchVLLMClient(
+ vllm_url=self.config.vllm_user_url,
+ max_tokens=4096,
+ temperature=1.0,
+ timeout=None,
+ max_concurrent=100,
+ json_mode=True, # User simulator needs JSON output
+ )
+ logger.info(f" Using local vLLM user simulator: {self.config.vllm_user_url}")
+
+ # Create async agent client for batched vLLM calls
+ agent_client = BatchVLLMClient(
+ vllm_url=self.config.vllm_agent_url,
+ max_tokens=2048,
+ temperature=0.7,
+ timeout=None, # Infinite timeout for long generations
+ max_concurrent=100,
+ )
+
+ USER_PROMPT_TEMPLATE = (
+ "You are a user simulator collaborating with an agent to solve a problem. "
+ "You will be provided with a problem description, and you must get the agent to help you solve it. "
+ "You will also be provided with user preferences, which you must follow and actively enforce throughout the conversation.\n\n"
+ "# Problem Description\n{problem}\nNote: the agent cannot see this problem description.\n\n"
+ "# User Persona\n{user_persona}\n\n"
+ "# User Preferences\n{user_preferences}\n"
+ "These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:\n"
+ " - **Answer clarifying questions**: The agent may ask clarifying questions before attempting an answer. "
+ "Answer such questions, and do not enforce preferences about answer format or content while the agent is clarifying.\n"
+ " - **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed. "
+ "Explicitly ask the agent to adjust their response until it complies.\n"
+ " - **Never proceed without compliance**: Do NOT update your draft answer, do NOT consider terminating, "
+ "and do NOT move forward until the agent follows your preferences.\n\n"
+ "# Draft Answer Management\n"
+ "- **Maintain a working draft**: Start with \"I don't know\". Update your draft answer based on what you learn from agent responses.\n"
+ "- **Don't update when enforcing preferences**: If the agent response does not follow your preferences, "
+ "do NOT update your draft answer, regardless of whether the agent provides helpful information.\n\n"
+ "# Conversation Termination\n"
+ "Before generating your response, determine if you should terminate:\n"
+ " - Do you feel like your draft answer is a good answer to the problem?\n"
+ " - Do you feel like the agent cannot help further?\n"
+ "If the agent response does not follow your preferences, you must NOT terminate - instead, enforce the preferences.\n"
+ "When ready to terminate, respond with \"TERMINATE\".\n\n"
+ "# Output Format (respond in JSON):\n"
+ "{{\n"
+ " \"preferences_check\": \"For EACH relevant preference, evaluate: is it satisfied?\",\n"
+ " \"enforce_preferences\": true/false,\n"
+ " \"reasoning\": \"Brief reasoning (2-3 sentences). Does agent follow preferences? If no, enforce. If yes, update draft.\",\n"
+ " \"draft_answer\": \"Your current working draft answer\",\n"
+ " \"should_terminate\": true/false,\n"
+ " \"response\": \"Your response to the agent\"\n"
+ "}}"
+ )
+
+ def parse_user_response(content):
+ if not content:
+ return None
+ try:
+ parsed = repair_json(content, return_objects=True)
+ if isinstance(parsed, dict) and "response" in parsed:
+ return parsed
+ except:
+ pass
+ if TERMINATION_SIGNAL in (content or ""):
+ return {"reasoning": "", "draft_answer": "", "should_terminate": True, "response": TERMINATION_SIGNAL}
+ return {"reasoning": "", "draft_answer": "", "should_terminate": False, "response": content or ""}
+
+ def reverse_roles(conversation):
+ return [
+ {"role": "user" if m["role"] == "assistant" else "assistant", "content": m["content"]}
+ for m in conversation
+ ]
+
+ # Create per-profile adapters
+ adapters = {}
+ profile_sessions = {}
+
+ for profile_idx in profiles_to_run:
+ profile = self.profiles[profile_idx]
+ adapter = self._create_method_adapter(method, profile, use_shared_models=True)
+ if hasattr(adapter, 'initialize'):
+ adapter.initialize()
+ adapters[profile_idx] = adapter
+
+ sessions = []
+ for ds_name, ds_obj in self.datasets.items():
+ ds_items = ds_obj.get_testset()
+ for item in ds_items[:self.config.n_sessions_per_profile]:
+ sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain})
+ sessions = sessions[:self.config.n_sessions_per_profile]
+ n_conflict = int(len(sessions) * self.config.conflict_ratio)
+ profile_sessions[profile_idx] = [(s, idx < n_conflict) for idx, s in enumerate(sessions)]
+
+ n_sessions = self.config.n_sessions_per_profile
+
+ # Calculate sessions to run per profile (accounting for existing sessions)
+ sessions_to_run_per_profile = {}
+ for profile_idx in profiles_to_run:
+ existing = sessions_per_profile.get(str(profile_idx), 0)
+ remaining = n_sessions - existing
+ if remaining > 0:
+ sessions_to_run_per_profile[profile_idx] = (existing, remaining) # (start_session, count)
+
+ if sessions_to_run_per_profile:
+ total_remaining = sum(v[1] for v in sessions_to_run_per_profile.values())
+ logger.info(f" Batch: {len(sessions_to_run_per_profile)} profiles, {total_remaining} sessions remaining")
+ else:
+ logger.info(f" Batch: All sessions already completed")
+ return results
+
+ # Process sessions in rounds
+ for session_idx in range(n_sessions):
+ # Initialize all conversations for this round
+ all_states = {} # profile_idx -> state dict
+ active_set = set()
+
+ for profile_idx in profiles_to_run:
+ # Skip if this profile doesn't need this session
+ if profile_idx not in sessions_to_run_per_profile:
+ continue
+ start_session, _ = sessions_to_run_per_profile[profile_idx]
+ if session_idx < start_session:
+ continue # Already completed this session
+ if session_idx >= len(profile_sessions[profile_idx]):
+ continue
+ problem_dict, is_conflict = profile_sessions[profile_idx][session_idx]
+ profile = self.profiles[profile_idx]
+ query = problem_dict["problem"]
+
+ if is_conflict:
+ cs = self.conflict_generator.generate_for_profile(
+ profile.get("preferences", []), problem_dict.get("domain", "general"))
+ if cs:
+ query = cs["query"]
+
+ user_prefs = profile.get("preferences", [])
+ if isinstance(user_prefs, list) and user_prefs:
+ if isinstance(user_prefs[0], dict):
+ pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs[:10]])
+ else:
+ pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]])
+ else:
+ pref_str = str(user_prefs)
+
+ user_persona = profile.get("persona", "A user seeking help with problem solving.")
+ adapter = adapters[profile_idx]
+ if hasattr(adapter, 'start_session'):
+ adapter.start_session(
+ user_id=profile.get("user_id", f"user_{profile_idx}"),
+ user_profile={"preferences": user_prefs, "persona": user_persona}
+ )
+
+ all_states[profile_idx] = {
+ "conversation": [{"role": "assistant", "content": "How can I help you today?"}],
+ "full_log": [],
+ "system_prompt": USER_PROMPT_TEMPLATE.format(
+ problem=query, user_persona=user_persona, user_preferences=pref_str),
+ "problem_dict": problem_dict,
+ "is_conflict": is_conflict,
+ "enforcement_count": 0,
+ }
+ active_set.add(profile_idx)
+
+ # Turn-synchronous loop
+ for turn in range(self.config.max_turns_per_session):
+ if not active_set:
+ break
+
+ # Batch user calls
+ active_list = sorted(active_set)
+ user_msgs_batch = []
+ for pidx in active_list:
+ state = all_states[pidx]
+ msgs = [{"role": "system", "content": state["system_prompt"]}]
+ msgs.extend(reverse_roles(state["conversation"]))
+ user_msgs_batch.append(msgs)
+
+ user_responses = user_client.batch_completion(user_msgs_batch)
+
+ # Process user responses and prepare agent prompts for batching
+ to_remove = []
+ agent_prompts_batch = [] # List of (pidx, messages, context)
+ for i, pidx in enumerate(active_list):
+ state = all_states[pidx]
+ parsed = parse_user_response(user_responses[i])
+
+ if parsed is None:
+ to_remove.append(pidx)
+ continue
+
+ user_msg = str(parsed.get("response", ""))
+ state["conversation"].append({"role": "user", "content": user_msg})
+ state["full_log"].append(parsed)
+
+ if parsed.get("enforce_preferences", False):
+ state["enforcement_count"] += 1
+
+ if parsed.get("should_terminate", False) or TERMINATION_SIGNAL in user_msg:
+ to_remove.append(pidx)
+ continue
+
+ # Prepare agent prompt for batching (don't call LLM yet)
+ try:
+ adapter = adapters[pidx]
+ if hasattr(adapter, 'prepare_prompt'):
+ messages, context = adapter.prepare_prompt(user_msg, state["conversation"][:-1])
+ agent_prompts_batch.append((pidx, messages, context))
+ elif hasattr(adapter, 'generate_response'):
+ # Fallback for adapters without prepare_prompt
+ agent_prompts_batch.append((pidx, None, None))
+ else:
+ state["conversation"].append({"role": "assistant", "content": "[Error: Adapter not configured]"})
+ except Exception as e:
+ logger.error(f" Agent prepare error p{pidx} t{turn}: {e}")
+ state["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. Could you rephrase?"})
+
+ # Batch vLLM call for all agent prompts
+ if agent_prompts_batch:
+ # Separate prompts that can be batched from fallback
+ batchable = [(pidx, msgs, ctx) for pidx, msgs, ctx in agent_prompts_batch if msgs is not None]
+ fallback = [(pidx, msgs, ctx) for pidx, msgs, ctx in agent_prompts_batch if msgs is None]
+
+ # Batch call for batchable prompts
+ if batchable:
+ batch_messages = [msgs for _, msgs, _ in batchable]
+ batch_responses = agent_client.batch_completion(batch_messages)
+
+ # Process batched responses
+ for (pidx, _, context), response in zip(batchable, batch_responses):
+ try:
+ adapter = adapters[pidx]
+ state = all_states[pidx]
+ if response is not None:
+ result = adapter.process_response(response, context)
+ agent_content = result.get("response", str(result)) if isinstance(result, dict) else str(result)
+ else:
+ agent_content = "I apologize, I encountered an error. Could you rephrase?"
+ state["conversation"].append({"role": "assistant", "content": agent_content})
+ except Exception as e:
+ logger.error(f" Agent process error p{pidx} t{turn}: {e}")
+ all_states[pidx]["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. Could you rephrase?"})
+
+ # Handle fallback (adapters without prepare_prompt - sequential calls)
+ for pidx, _, _ in fallback:
+ try:
+ adapter = adapters[pidx]
+ state = all_states[pidx]
+ user_msg = state["conversation"][-1]["content"]
+ resp = adapter.generate_response(user_msg, state["conversation"][:-1])
+ agent_content = resp.get("response", str(resp)) if isinstance(resp, dict) else str(resp)
+ state["conversation"].append({"role": "assistant", "content": agent_content})
+ except Exception as e:
+ logger.error(f" Agent fallback error p{pidx} t{turn}: {e}")
+ all_states[pidx]["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. Could you rephrase?"})
+
+ active_set -= set(to_remove)
+
+ # Save results for this session round
+ for profile_idx in profiles_to_run:
+ if profile_idx not in all_states:
+ continue
+ state = all_states[profile_idx]
+ problem_dict = state["problem_dict"]
+ conversation = state["conversation"]
+ full_log = state["full_log"]
+
+ user_tokens = sum(len(m["content"].split()) for m in conversation if m["role"] == "user")
+ agent_tokens = sum(len(m["content"].split()) for m in conversation if m["role"] == "assistant")
+
+ enforcement_count = state["enforcement_count"]
+ task_success = 0
+ for entry in full_log:
+ if entry.get("should_terminate", False):
+ draft = entry.get("draft_answer", "")
+ if draft and "don't know" not in draft.lower() and len(draft) > 20:
+ task_success = 1
+
+ results.append({
+ "method": method,
+ "profile_id": self.profiles[profile_idx].get("user_id", f"user_{profile_idx}"),
+ "problem_id": f"session_{session_idx}",
+ "problem": problem_dict.get("problem", ""),
+ "ground_truth_solution": problem_dict.get("solution", ""),
+ "is_conflict_test": state["is_conflict"],
+ "conversation": {"turns": conversation},
+ "full_user_log": full_log,
+ "metrics": {
+ "task_success": bool(task_success),
+ "total_turns": len(conversation),
+ "user_token_count": user_tokens,
+ "agent_token_count": agent_tokens,
+ "total_token_count": user_tokens + agent_tokens,
+ "enforcement_count": enforcement_count,
+ "disappointment_count": 0,
+ "preference_compliance_scores": [],
+ "conflict_resolution_accuracy": 0,
+ "over_personalization_rate": 0,
+ },
+ "adapter_metrics": {},
+ })
+
+ # Checkpoint after each session round with session-level tracking
+ # Only increment for profiles that actually ran in this round (those in all_states)
+ for profile_idx in all_states.keys():
+ sessions_per_profile[str(profile_idx)] = sessions_per_profile.get(str(profile_idx), 0) + 1
+ if sessions_per_profile[str(profile_idx)] >= self.config.n_sessions_per_profile:
+ completed_profiles.add(profile_idx)
+
+ with open(checkpoint_file, "w") as f:
+ json.dump({
+ "completed_profiles": sorted(list(completed_profiles)),
+ "sessions_per_profile": sessions_per_profile
+ }, f)
+ with open(results_file, "w") as f:
+ json.dump(results, f, indent=2)
+
+ elapsed = time.time() - start_time
+ sessions_done = len(results)
+ rate = sessions_done / elapsed * 3600 if elapsed > 0 else 0
+ logger.info(f" Session round {session_idx+1}/{n_sessions}: {sessions_done} total, {rate:.0f} sessions/hr")
+
+ # Explicitly free adapter models to prevent GPU OOM across methods
+ for pidx, adapter in adapters.items():
+ if hasattr(adapter, 'cleanup'):
+ adapter.cleanup()
+ del adapters
+
+ return results
+
+ def run_all(self) -> Dict[str, Any]:
+ """Run all methods and generate comparative analysis."""
+ all_results = {}
+
+ for method in self.config.methods:
+ if method not in AVAILABLE_METHODS:
+ logger.warning(f"Unknown method: {method}, skipping")
+ continue
+
+ results = self.run_method(method)
+ all_results[method] = results
+
+ # Free GPU memory between methods to prevent OOM on later adapters
+ try:
+ from personalization.serving.personalized_llm import clear_shared_models
+ clear_shared_models()
+ except ImportError:
+ pass
+ try:
+ import gc
+ import torch
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ logger.info(f" GPU memory freed after {method}: {torch.cuda.memory_allocated()/1e9:.1f}GB allocated")
+ except ImportError:
+ pass
+
+ # Comparative analysis
+ analysis = self._analyze_results(all_results)
+
+ # Save analysis
+ with open(self.output_dir / "analysis.json", "w") as f:
+ json.dump(analysis, f, indent=2)
+
+ # Generate report
+ self._generate_report(analysis)
+
+ return analysis
+
+ def _analyze_results(self, all_results: Dict[str, List[Dict]]) -> Dict:
+ """Analyze results across all methods."""
+ analysis = {
+ "per_method": {},
+ "comparison": {},
+ }
+
+ for method, results in all_results.items():
+ n = len(results)
+ if n == 0:
+ continue
+
+ # Aggregate metrics
+ task_success = sum(r["metrics"]["task_success"] for r in results) / n
+ avg_user_tokens = sum(r["metrics"]["user_token_count"] for r in results) / n
+ avg_total_tokens = sum(r["metrics"]["total_token_count"] for r in results) / n
+ avg_enforcement = sum(r["metrics"]["enforcement_count"] for r in results) / n
+ avg_turns = sum(r["metrics"]["total_turns"] for r in results) / n
+
+ # Compliance and conflict metrics
+ compliance_scores = [
+ sum(r["metrics"]["preference_compliance_scores"]) / len(r["metrics"]["preference_compliance_scores"])
+ if r["metrics"]["preference_compliance_scores"] else 0.5
+ for r in results
+ ]
+ avg_compliance = sum(compliance_scores) / len(compliance_scores)
+
+ conflict_results = [r for r in results if r["is_conflict_test"]]
+ conflict_accuracy = sum(
+ r["metrics"]["conflict_resolution_accuracy"] for r in conflict_results
+ ) / len(conflict_results) if conflict_results else 0
+
+ over_personalization = sum(
+ r["metrics"]["over_personalization_rate"] for r in results
+ ) / n
+
+ analysis["per_method"][method] = {
+ "n_sessions": n,
+ "task_success_rate": task_success,
+ "avg_user_tokens": avg_user_tokens,
+ "avg_total_tokens": avg_total_tokens,
+ "avg_enforcement_count": avg_enforcement,
+ "avg_turns": avg_turns,
+ "avg_preference_compliance": avg_compliance,
+ "conflict_resolution_accuracy": conflict_accuracy,
+ "over_personalization_rate": over_personalization,
+ }
+
+ # Comparison
+ metrics_to_compare = [
+ ("task_success_rate", True), # higher is better
+ ("avg_user_tokens", False), # lower is better
+ ("avg_total_tokens", False), # lower is better
+ ("avg_enforcement_count", False), # lower is better
+ ("avg_preference_compliance", True), # higher is better
+ ("conflict_resolution_accuracy", True), # higher is better
+ ("over_personalization_rate", False), # lower is better
+ ]
+
+ for metric, higher_better in metrics_to_compare:
+ values = {m: analysis["per_method"][m][metric] for m in analysis["per_method"]}
+ if not values:
+ logger.warning(f"No values for metric {metric}, skipping comparison")
+ continue
+ if higher_better:
+ best = max(values, key=values.get)
+ else:
+ best = min(values, key=values.get)
+
+ analysis["comparison"][metric] = {
+ "values": values,
+ "best_method": best,
+ "best_value": values[best],
+ }
+
+ return analysis
+
+ def _generate_report(self, analysis: Dict) -> None:
+ """Generate a human-readable report."""
+ report_lines = [
+ "# Personalization Experiment Report",
+ f"\nGenerated: {datetime.now().isoformat()}",
+ f"\nConfig: {self.config.n_profiles} profiles, {self.config.n_sessions_per_profile} sessions each",
+ "\n## Method Comparison\n",
+ ]
+
+ # Create comparison table
+ metrics_display = [
+ ("Task Success", "task_success_rate", "{:.1%}"),
+ ("User Effort (tokens)", "avg_user_tokens", "{:.0f}"),
+ ("Total Tokens", "avg_total_tokens", "{:.0f}"),
+ ("Enforcement Count", "avg_enforcement_count", "{:.2f}"),
+ ("Preference Compliance", "avg_preference_compliance", "{:.1%}"),
+ ("Conflict Resolution", "conflict_resolution_accuracy", "{:.1%}"),
+ ("Over-personalization", "over_personalization_rate", "{:.1%}"),
+ ]
+
+ methods = list(analysis["per_method"].keys())
+
+ # Header
+ header = "| Metric |" + "|".join(f" {m} " for m in methods) + "| Best |"
+ separator = "|" + "|".join(["-" * (len(m) + 2) for m in ["Metric"] + methods + ["Best"]]) + "|"
+
+ report_lines.extend([header, separator])
+
+ for display_name, metric_key, fmt in metrics_display:
+ row = f"| {display_name} |"
+ for m in methods:
+ val = analysis["per_method"].get(m, {}).get(metric_key, 0)
+ row += f" {fmt.format(val)} |"
+
+ if metric_key in analysis.get("comparison", {}):
+ best = analysis["comparison"][metric_key]["best_method"]
+ else:
+ best = "N/A"
+ row += f" {best} |"
+ report_lines.append(row)
+
+ # Key findings
+ report_lines.extend([
+ "\n## Key Findings\n",
+ ])
+
+ # Find advantages of proposed methods
+ rag_vector = analysis["per_method"].get("rag_vector", {})
+ rag = analysis["per_method"].get("rag", {})
+ contextual = analysis["per_method"].get("contextual", {})
+ all_memory = analysis["per_method"].get("all_memory", {})
+
+ if rag_vector and contextual:
+ token_reduction = (contextual.get("avg_total_tokens", 0) - rag_vector.get("avg_total_tokens", 0)) / contextual.get("avg_total_tokens", 1) * 100
+ report_lines.append(f"- **Token Efficiency**: RAG+Vector uses {token_reduction:.1f}% fewer tokens than contextual memory")
+
+ if rag_vector and all_memory:
+ conflict_improvement = rag_vector.get("conflict_resolution_accuracy", 0) - all_memory.get("conflict_resolution_accuracy", 0)
+ report_lines.append(f"- **Conflict Resolution**: RAG+Vector improves by {conflict_improvement:.1%} over all-memory baseline")
+
+ if rag_vector:
+ report_lines.append(f"- **Over-personalization**: RAG+Vector rate: {rag_vector.get('over_personalization_rate', 0):.1%}")
+
+ # Save report
+ report_path = self.output_dir / "report.md"
+ with open(report_path, "w") as f:
+ f.write("\n".join(report_lines))
+
+ logger.info(f"Report saved to {report_path}")
+
+
def main():
    """CLI entry point: parse arguments, build an ExperimentConfig, run all methods.

    A YAML file passed via ``--config`` takes precedence over individual CLI
    flags. Results, the analysis JSON, and a markdown report are written under
    the configured output directory.
    """
    parser = argparse.ArgumentParser(description="Run personalization experiments")
    parser.add_argument("--config", type=str, help="Path to config YAML file")
    parser.add_argument("--methods", type=str, default="vanilla,contextual,rag,rag_vector",
                        help="Comma-separated list of methods to compare")
    parser.add_argument("--datasets", type=str, default="math-hard,math-500,bigcodebench",
                        help="Comma-separated list of datasets")
    parser.add_argument("--n-profiles", type=int, default=200, help="Number of user profiles")
    parser.add_argument("--n-sessions", type=int, default=30, help="Sessions per profile")
    parser.add_argument("--max-turns", type=int, default=15, help="Max turns per session")
    parser.add_argument("--output-dir", type=str, default="results", help="Output directory")
    parser.add_argument("--profile-path", type=str, help="Path to pre-generated profiles")
    parser.add_argument("--start-profile", type=int, default=0,
                        help="Start profile index (inclusive, 0-indexed)")
    parser.add_argument("--end-profile", type=int, default=None,
                        help="End profile index (exclusive). If not set, runs all profiles from start")

    # vLLM and parallel processing options
    parser.add_argument("--use-vllm", action="store_true",
                        help="Use vLLM servers for inference (much faster)")
    parser.add_argument("--vllm-user-url", type=str, default="http://localhost:8004/v1",
                        help="vLLM server URL for user simulator (70B)")
    parser.add_argument("--vllm-agent-url", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL for agent (8B)")
    # OpenAI user agent options
    parser.add_argument("--use-openai-user", action="store_true",
                        help="Use OpenAI API (GPT-5) for user simulation instead of vLLM")
    parser.add_argument("--openai-user-model", type=str, default="gpt-5",
                        help="OpenAI model name for user simulator (default: gpt-5)")
    parser.add_argument("--reward-mode", type=str, default="keyword", choices=["keyword", "llm"],
                        help="Reward mode for RL updates: 'keyword' (user signals) or 'llm' (GPT-5-nano judge)")

    parser.add_argument("--parallel-profiles", type=int, default=50,
                        help="Number of profiles to process in parallel (requires --use-vllm)")
    parser.add_argument("--use-batch-processing", action="store_true", default=True,
                        help="Use turn-synchronous batch processing for vanilla/all_memory")
    parser.add_argument("--no-batch-processing", action="store_false", dest="use_batch_processing",
                        help="Disable batch processing")
    parser.add_argument("--batch-size", type=int, default=50,
                        help="Number of conversations to batch together")
    parser.add_argument("--continue-from", type=str, default=None,
                        help="Path to existing output directory to continue from (for extending sessions)")

    args = parser.parse_args()

    # Load or create config (YAML file wins over CLI flags when it exists)
    if args.config and Path(args.config).exists():
        with open(args.config) as f:
            config_dict = yaml.safe_load(f)
        config = ExperimentConfig(**config_dict)
    else:
        if args.config:
            # BUGFIX: a mistyped --config path used to fall back to CLI
            # defaults silently; make the fallback visible in the logs.
            logger.warning(f"Config file not found: {args.config}, using CLI arguments instead")
        config = ExperimentConfig(
            # BUGFIX: strip whitespace and drop empty entries so values like
            # "vanilla, rag" or a trailing comma don't produce bogus names
            # that are silently skipped later by run_all.
            methods=[m.strip() for m in args.methods.split(",") if m.strip()],
            datasets=[d.strip() for d in args.datasets.split(",") if d.strip()],
            n_profiles=args.n_profiles,
            n_sessions_per_profile=args.n_sessions,
            max_turns_per_session=args.max_turns,
            output_dir=args.output_dir,
            profile_path=args.profile_path,
            start_profile=args.start_profile,
            end_profile=args.end_profile,
            use_vllm=args.use_vllm,
            vllm_user_url=args.vllm_user_url,
            vllm_agent_url=args.vllm_agent_url,
            use_openai_user=args.use_openai_user,
            openai_user_model=args.openai_user_model,
            reward_mode=args.reward_mode,
            parallel_profiles=args.parallel_profiles,
            use_batch_processing=args.use_batch_processing,
            batch_size_conversations=args.batch_size,
            continue_from=args.continue_from,
        )

    # Run experiments
    runner = ExperimentRunner(config)
    analysis = runner.run_all()

    print("\n" + "=" * 60)
    print("EXPERIMENT COMPLETE")
    print("=" * 60)
    print(f"\nResults saved to: {runner.output_dir}")
    if analysis.get("comparison"):
        print("\nBest methods per metric:")
        for metric, data in analysis["comparison"].items():
            print(f" {metric}: {data['best_method']} ({data['best_value']:.3f})")
    else:
        print("\nNo comparison data available (sessions may have failed)")
+
+
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()