diff options
| field | value | details |
|---|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /src/personalization/evaluation | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/evaluation')
16 files changed, 2759 insertions, 0 deletions
"""
Baseline-agent contract and the T1 "no memory" baseline.

``AgentResponse``/``BaselineAgent`` define the interface every evaluated
agent implements; ``NoMemoryAgent`` is the simplest concrete agent: it
keeps no state across sessions and answers from the current
conversation alone.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import os


# System prompt for the agent
AGENT_SYSTEM_PROMPT = """You are a helpful AI assistant helping users solve problems.

Guidelines:
- If the user's request is unclear, ask for clarification
- Provide clear, well-structured answers
- Adapt to user feedback and preferences expressed in the conversation
- Be helpful and do your best to solve the user's problem

Your output should be a direct response to the user."""


@dataclass
class AgentResponse:
    """Response from an agent: the answer text plus an optional debug payload."""
    answer: str
    debug_info: Dict[str, Any] = field(default_factory=dict)


class BaselineAgent(ABC):
    """Abstract base class for all baseline agents."""

    def __init__(self, model_name: str, **kwargs):
        """
        Args:
            model_name: Name/path of the LLM to use
            **kwargs: Additional configuration
        """
        self.model_name = model_name
        self.config = kwargs

    @abstractmethod
    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """
        Generate a response to the user's query.

        Args:
            user_id: Unique identifier for the user
            query: Current user message
            conversation_history: Previous messages, each shaped like
                {"role": "user/assistant", "content": "..."}
            **kwargs: Additional context (e.g., task info)

        Returns:
            AgentResponse with answer and debug info
        """

    @abstractmethod
    def end_session(self, user_id: str, conversation: List[Dict[str, str]]):
        """
        Called when a session (one task) ends; update memory, notes, etc.

        Args:
            user_id: User identifier
            conversation: Complete conversation from this session
        """

    @abstractmethod
    def reset_user(self, user_id: str):
        """
        Completely reset all state for a user (start of a new experiment).

        Args:
            user_id: User identifier
        """

    def get_name(self) -> str:
        """Get a descriptive name for this agent."""
        return self.__class__.__name__


class NoMemoryAgent(BaselineAgent):
    """
    T1: Base model with no memory.

    This agent:
    - Has no memory across sessions
    - Only uses current conversation context
    - Represents the baseline "no personalization" case
    """

    def __init__(
        self,
        model_name: str = "llama-8b",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ):
        super().__init__(model_name, **kwargs)

        # Environment variables supply the endpoint/key when not passed in.
        self.api_base = api_base or os.getenv("OPENAI_API_BASE", "http://localhost:8003/v1")
        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "EMPTY")
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

        self._init_client()

    def _init_client(self):
        """Initialize the LLM client; on any failure fall back to client=None."""
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            # Missing package or bad config: keep running in fallback mode.
            print(f"Warning: Could not initialize OpenAI client: {e}")
            self.client = None

    def _build_messages(
        self,
        conversation_history: List[Dict[str, str]],
        query: str,
    ) -> List[Dict[str, str]]:
        """Assemble the chat-completion message list for the LLM."""
        msgs = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]
        msgs += [
            {"role": entry["role"], "content": entry["content"]}
            for entry in conversation_history
        ]

        # Append the current query unless it is already the final history entry.
        if not conversation_history or conversation_history[-1]["content"] != query:
            msgs.append({"role": "user", "content": query})

        return msgs

    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """Generate a response using only the current conversation context."""
        msgs = self._build_messages(conversation_history, query)

        if self.client is None:
            # Fallback for testing without an LLM backend.
            return AgentResponse(
                answer=f"[NoMemoryAgent] Response to: {query[:50]}...",
                debug_info={"mode": "fallback", "num_messages": len(msgs)},
            )

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=msgs,
                max_tokens=self.max_new_tokens,
                temperature=self.temperature,
            )
            usage = completion.usage
            return AgentResponse(
                answer=completion.choices[0].message.content,
                debug_info={
                    "num_messages": len(msgs),
                    "prompt_tokens": usage.prompt_tokens if usage else 0,
                    "completion_tokens": usage.completion_tokens if usage else 0,
                },
            )
        except Exception as e:
            print(f"Error calling LLM: {e}")
            return AgentResponse(
                answer=f"I apologize, but I encountered an error. Let me try again: {query[:100]}",
                debug_info={"error": str(e)},
            )

    def end_session(self, user_id: str, conversation: List[Dict[str, str]]):
        """No-op for no-memory agent."""

    def reset_user(self, user_id: str):
        """No-op for no-memory agent."""

    def get_name(self) -> str:
        return f"NoMemory({self.model_name})"
class RAGMemoryAgent(BaselineAgent):
    """
    Y3/Y4: RAG-based memory with optional user vector.

    This agent:
    - Extracts preferences from conversations using the extractor
    - Stores preferences as memory cards
    - Retrieves relevant memories using RAG for each query
    - (Y4 only) Uses user vector to personalize retrieval
    """

    def __init__(
        self,
        model_name: str = "llama-8b",
        mode: str = "nopersonal",  # "nopersonal" for Y3, "full" for Y4
        memory_cards_path: str = None,
        memory_embeddings_path: str = None,
        enable_preference_extraction: bool = True,
        enable_rl_updates: bool = False,
        only_own_memories: bool = True,
        **kwargs
    ):
        """
        Args:
            model_name: LLM model to use
            mode: "nopersonal" (Y3) or "full" (Y4)
            memory_cards_path: Path to memory cards file
            memory_embeddings_path: Path to embeddings file
            enable_preference_extraction: Whether to extract preferences
            enable_rl_updates: Whether to update user vectors (Y4 only)
            only_own_memories: Only retrieve user's own memories
        """
        super().__init__(model_name, **kwargs)

        self.mode = mode
        # RL updates are only meaningful in full (user-vector) mode.
        self.enable_rl_updates = enable_rl_updates and (mode == "full")

        # Default artifact locations relative to the repository root.
        base_dir = os.path.join(os.path.dirname(__file__), "../../../../..")
        self.memory_cards_path = memory_cards_path or os.path.join(
            base_dir, "data/eval/memory_cards.jsonl"
        )
        self.memory_embeddings_path = memory_embeddings_path or os.path.join(
            base_dir, "data/eval/memory_embeddings.npy"
        )

        self.enable_preference_extraction = enable_preference_extraction
        self.only_own_memories = only_own_memories

        # PersonalizedLLM is constructed lazily on first use.
        self._llm = None
        self._initialized = False

    def _ensure_initialized(self):
        """Lazy initialization of PersonalizedLLM."""
        if self._initialized:
            return

        try:
            from personalization.serving.personalized_llm import PersonalizedLLM

            self._llm = PersonalizedLLM(
                mode=self.mode,
                enable_preference_extraction=self.enable_preference_extraction,
                enable_rl_updates=self.enable_rl_updates,
                only_own_memories=self.only_own_memories,
                memory_cards_path=self.memory_cards_path,
                memory_embeddings_path=self.memory_embeddings_path,
                eval_mode=True,  # Deterministic selection
            )
            self._initialized = True

        except Exception as e:
            # Stay usable without the backend: respond() takes the fallback path.
            print(f"Warning: Could not initialize PersonalizedLLM: {e}")
            print("Falling back to simple response mode.")
            self._llm = None
            self._initialized = True

    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """Generate a response using RAG memory."""
        self._ensure_initialized()

        if self._llm is None:
            # Fallback mode (no PersonalizedLLM available).
            return AgentResponse(
                answer=f"[RAGMemoryAgent-{self.mode}] Response to: {query[:50]}...",
                debug_info={"mode": "fallback"},
            )

        try:
            # Use PersonalizedLLM's chat interface.
            response = self._llm.chat(user_id, query)

            dbg = response.debug
            debug_info = {
                "mode": self.mode,
                "num_memories_retrieved": len(dbg.selected_memory_ids) if dbg else 0,
                "selected_memories": dbg.selected_memory_notes if dbg else [],
                "extracted_preferences": dbg.extracted_preferences if dbg else [],
            }
            if dbg and dbg.extra:
                debug_info.update(dbg.extra)

            return AgentResponse(
                answer=response.answer,
                debug_info=debug_info,
            )

        except Exception as e:
            print(f"Error in RAGMemoryAgent.respond: {e}")
            return AgentResponse(
                answer=f"I apologize for the error. Regarding: {query[:100]}",
                debug_info={"error": str(e)},
            )

    def end_session(self, user_id: str, conversation: List[Dict[str, str]]):
        """
        Called at end of session. PersonalizedLLM already extracts
        preferences during chat(), so only the session state is reset here.
        """
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_session(user_id)

    def reset_user(self, user_id: str):
        """Reset all state for a user."""
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_user(user_id)

    def apply_feedback(self, user_id: str, reward: float, gating: float = 1.0):
        """
        Apply feedback for user vector updates (Y4 only).

        Args:
            user_id: User identifier
            reward: Reward signal (e.g., from preference satisfaction)
            gating: Gating signal (1.0 = use this feedback, 0.0 = skip)
        """
        if not self.enable_rl_updates or self._llm is None:
            return

        try:
            from personalization.serving.personalized_llm import Feedback

            self._llm.apply_feedback(Feedback(
                user_id=user_id,
                turn_id=0,  # Not used in current implementation
                reward=reward,
                gating=gating,
            ))

        except Exception as e:
            print(f"Error applying feedback: {e}")

    def get_user_state(self, user_id: str) -> Dict[str, Any]:
        """Get user state summary (for Y4 analysis)."""
        self._ensure_initialized()

        if self._llm is not None:
            return self._llm.get_user_state_summary(user_id)
        return {}

    def persist(self):
        """Save all state to disk."""
        if self._llm is not None:
            self._llm.persist()

    def get_name(self) -> str:
        mode_name = "RAG" if self.mode == "nopersonal" else "RAG+UV"
        return f"{mode_name}({self.model_name})"
def run_preference_bank_demo():
    """Generate and display a demo preference bank."""
    print("\n" + "=" * 60)
    print("STEP 1: Generate Preference Bank")
    print("=" * 60)

    from personalization.evaluation.preference_bank.generator import generate_demo_bank

    output_dir = "data/eval/demo"
    os.makedirs(output_dir, exist_ok=True)

    bank_path = os.path.join(output_dir, "preference_bank.json")
    bank = generate_demo_bank(output_path=bank_path, use_llm=False)

    print(f"\nGenerated preference bank with {bank.stats()['total_preferences']} preferences")
    print(f"Topics: {list(bank.topics.keys())}")

    # Show a couple of sample preferences from the first two topics.
    print("\nSample preferences:")
    for topic_name, topic in list(bank.topics.items())[:2]:
        print(f"\n {topic_name}:")
        for pref in topic.preferences[:2]:
            print(f" - When {pref.condition}: {pref.action}")

    return bank


def run_profile_demo(bank):
    """Generate demo user profiles."""
    print("\n" + "=" * 60)
    print("STEP 2: Generate User Profiles")
    print("=" * 60)

    from personalization.evaluation.profiles.generator import generate_demo_profiles

    output_dir = "data/eval/demo"
    profiles_path = os.path.join(output_dir, "user_profiles.json")

    profiles = generate_demo_profiles(
        bank=bank,
        num_users=2,
        prefs_per_user=10,
        output_path=profiles_path,
        seed=42,
    )

    print(f"\nGenerated {len(profiles)} user profiles")

    for profile in profiles:
        print(f"\n {profile.user_id}:")
        print(f" Persona: {profile.persona}")
        print(f" Primary topics: {profile.primary_topics}")
        print(f" Num preferences: {len(profile.preferences)}")

    return profiles


def run_agent_demo(dry_run: bool = True):
    """Test agent response generation."""
    print("\n" + "=" * 60)
    print("STEP 3: Test Agent Responses")
    print("=" * 60)

    from personalization.evaluation.baselines.no_memory import NoMemoryAgent

    # Create agent (will use fallback if no LLM available)
    agent = NoMemoryAgent(
        model_name="llama-8b",
        api_base="http://localhost:8003/v1" if not dry_run else None,
    )

    # One smoke-test query through the agent.
    test_query = "What is 2 + 2?"
    response = agent.respond(
        user_id="test_user",
        query=test_query,
        conversation_history=[],
    )

    print(f"\nQuery: {test_query}")
    print(f"Response: {response.answer[:200]}...")
    print(f"Debug: {response.debug_info}")

    return agent


def run_user_simulator_demo(profiles, dry_run: bool = True):
    """Test user simulator."""
    print("\n" + "=" * 60)
    print("STEP 4: Test User Simulator")
    print("=" * 60)

    from personalization.evaluation.user_simulator.simulator import UserSimulator
    from personalization.evaluation.pipeline.evaluator import Task

    simulator = UserSimulator(
        model_name="Llama-3.3-70B-Instruct",
        api_base="http://localhost:8004/v1" if not dry_run else None,
    )

    # Drive the simulator with the first profile and a tiny calculus task.
    profile = profiles[0]
    task = Task(
        task_id="test_001",
        dataset="test",
        problem="What is the derivative of x^2?",
        solution="2x",
        task_description="Solve this calculus problem:",
    )

    simulator.setup(
        profile=profile,
        task_description=task.task_description,
        problem=task.problem,
        solution=task.solution,
    )

    # Simulate the first turn against a canned assistant opener.
    conversation = [
        {"role": "assistant", "content": "How can I help you?"}
    ]

    response = simulator.respond(conversation)

    print(f"\nUser profile: {profile.user_id}")
    print(f"Task: {task.problem}")
    print(f"\nUser response: {response.response[:200]}...")
    print(f"Enforcement needed: {response.enforcement_needed}")
    print(f"Draft answer: {response.draft_answer}")

    return simulator


def run_full_demo(dry_run: bool = True, output_dir: str = "data/eval/demo"):
    """Run complete demo experiment."""
    print("\n" + "=" * 60)
    print("STEP 5: Run Full Demo Experiment")
    print("=" * 60)

    if dry_run:
        print("\n[DRY RUN MODE] Using fallback responses, no LLM calls\n")

    from personalization.evaluation.pipeline.runner import ExperimentRunner, ExperimentConfig

    config = ExperimentConfig(
        name="demo_experiment",
        output_dir=output_dir,
        num_users=2,
        prefs_per_user=10,
        tasks_per_user=2,  # Just 2 tasks for quick demo
        max_turns=10,  # Short conversations
        run_no_memory=True,
        run_rag=False,  # Skip RAG for initial demo (needs more setup)
        run_rag_uv=False,
        # In dry-run mode point at an unreachable port so clients fall back.
        agent_api_base="http://localhost:8003/v1" if not dry_run else "http://localhost:9999/v1",
        user_sim_api_base="http://localhost:8004/v1" if not dry_run else "http://localhost:9999/v1",
    )

    runner = ExperimentRunner(config)
    runner.setup()
    metrics = runner.run()

    return metrics
@dataclass
class Task:
    """A problem/task for evaluation."""
    task_id: str
    dataset: str
    problem: str
    solution: str
    task_description: str = "Work with the assistant to solve this problem:"


@dataclass
class SessionResult:
    """Result of a single evaluation session."""
    user_id: str
    task_id: str
    dataset: str
    agent_name: str

    # Metrics
    task_success: bool   # TS: Was the task solved correctly?
    user_effort: int     # UE: Number of preference enforcements
    efficiency: int      # Eff: Total number of messages

    # Details
    conversation: List[Dict[str, str]]
    preference_violations: List[Dict[str, Any]]
    final_draft_answer: str

    # Debug
    debug_info: Dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class EvaluationMetrics:
    """Aggregated evaluation metrics."""
    agent_name: str
    num_sessions: int

    # Average metrics across sessions
    avg_task_success: float  # Average TS
    avg_user_effort: float   # Average UE
    avg_efficiency: float    # Average Eff

    # Per-dataset breakdowns
    task_success_by_dataset: Dict[str, float] = field(default_factory=dict)
    user_effort_by_dataset: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


class JudgeModel:
    """
    LLM judge for evaluating task success.
    Uses the same approach as collaborativeagents.
    """

    def __init__(
        self,
        model_name: str = "Llama-3.3-70B-Instruct",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        self.model_name = model_name
        self.api_base = api_base or os.getenv("JUDGE_API_BASE", "http://localhost:8004/v1")
        self.api_key = api_key or os.getenv("JUDGE_API_KEY", "EMPTY")

        self._init_client()

    def _init_client(self):
        """Create the OpenAI-compatible client; fall back to None on failure."""
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            print(f"Warning: Could not initialize judge client: {e}")
            self.client = None

    def evaluate_answer(
        self,
        problem: str,
        correct_answer: str,
        user_draft_answer: str,
    ) -> bool:
        """
        Evaluate if the user's draft answer is correct.

        Returns:
            True if answer is correct, False otherwise
        """
        prompt = f"""You are an expert evaluator. Determine if the user's answer is correct.

# Problem
{problem}

# Correct Answer
{correct_answer}

# User's Answer
{user_draft_answer}

# Instructions
Determine if the user's answer is accurate and consistent with the correct answer.
Minor formatting differences are acceptable.
The core answer/solution must match.

Output JSON:
{{
    "reasoning": "Brief explanation",
    "is_correct": true or false
}}

Output only valid JSON."""

        if self.client is None:
            # Fallback - simple string matching
            return correct_answer.lower().strip() in user_draft_answer.lower()

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=256,
            )

            text = response.choices[0].message.content.strip()

            # Strip a markdown code fence if the judge wrapped its JSON in one.
            if "```" in text:
                text = text.split("```")[1]
                if text.startswith("json"):
                    text = text[4:]

            data = json.loads(text)
            return data.get("is_correct", False)

        except Exception as e:
            print(f"Error in judge evaluation: {e}")
            # Fallback
            return correct_answer.lower().strip() in user_draft_answer.lower()
class Evaluator:
    """
    Main evaluator that runs sessions and computes metrics.
    """

    def __init__(
        self,
        user_simulator: Optional[UserSimulator] = None,
        judge: Optional[JudgeModel] = None,
    ):
        self.user_sim = user_simulator or UserSimulator()
        self.judge = judge or JudgeModel()

    def run_session(
        self,
        agent: BaselineAgent,
        user_profile: UserProfile,
        task: Task,
        max_turns: int = 30,
    ) -> SessionResult:
        """
        Run a single evaluation session.

        Args:
            agent: The agent being evaluated
            user_profile: User with preferences
            task: Task to solve
            max_turns: Maximum conversation turns

        Returns:
            SessionResult with metrics and conversation
        """
        # Point the simulator at this user/task before the dialogue starts.
        self.user_sim.setup(
            profile=user_profile,
            task_description=task.task_description,
            problem=task.problem,
            solution=task.solution,
        )

        conversation: List[Dict[str, str]] = []
        preference_violations: List[Dict[str, Any]] = []
        user_effort = 0
        final_draft_answer = "I don't know"

        # Agent opens the conversation.
        conversation.append({
            "role": "assistant",
            "content": "How can I help you today?"
        })

        for turn in range(max_turns):
            # Simulated user replies to everything said so far.
            user_response = self.user_sim.respond(conversation)
            conversation.append({
                "role": "user",
                "content": user_response.response,
            })

            # Record preference violations flagged on this turn.
            # Explicit `== False` keeps None ("not assessed") out of the list.
            violations_this_turn = [
                {
                    "turn": turn,
                    "preference_id": check.preference_id,
                    "topic": check.topic,
                    "violation_detail": check.violation_detail,
                }
                for check in user_response.preference_checks
                if check.relevant and check.satisfied == False
            ]
            if violations_this_turn:
                preference_violations.extend(violations_this_turn)

            # Each enforcement the user had to make counts toward effort (UE).
            if user_response.enforcement_needed:
                user_effort += 1

            final_draft_answer = user_response.draft_answer

            # Stop when the simulator signals completion.
            if user_response.should_terminate or "TERMINATE" in user_response.response:
                break

            # Agent takes its turn.
            agent_response = agent.respond(
                user_id=user_profile.user_id,
                query=user_response.response,
                conversation_history=conversation,
            )
            conversation.append({
                "role": "assistant",
                "content": agent_response.answer,
            })

        # End session for agent (update memory, etc.)
        agent.end_session(user_profile.user_id, conversation)

        # Judge the user's final draft against the reference solution.
        task_success = self.judge.evaluate_answer(
            problem=task.problem,
            correct_answer=task.solution,
            user_draft_answer=final_draft_answer,
        )

        return SessionResult(
            user_id=user_profile.user_id,
            task_id=task.task_id,
            dataset=task.dataset,
            agent_name=agent.get_name(),
            task_success=task_success,
            user_effort=user_effort,
            efficiency=len(conversation),
            conversation=conversation,
            preference_violations=preference_violations,
            final_draft_answer=final_draft_answer,
            debug_info={
                "num_turns": len(conversation) // 2,
                "num_violations": len(preference_violations),
            },
        )

    def aggregate_metrics(
        self,
        results: List[SessionResult],
        agent_name: str,
    ) -> EvaluationMetrics:
        """
        Aggregate metrics from multiple sessions.
        """
        if not results:
            return EvaluationMetrics(
                agent_name=agent_name,
                num_sessions=0,
                avg_task_success=0.0,
                avg_user_effort=0.0,
                avg_efficiency=0.0,
            )

        n = len(results)
        avg_ts = sum(r.task_success for r in results) / n
        avg_ue = sum(r.user_effort for r in results) / n
        avg_eff = sum(r.efficiency for r in results) / n

        # Per-dataset breakdowns of TS and UE.
        ts_by_ds: Dict[str, float] = {}
        ue_by_ds: Dict[str, float] = {}
        for ds in set(r.dataset for r in results):
            ds_results = [r for r in results if r.dataset == ds]
            if ds_results:
                ts_by_ds[ds] = sum(r.task_success for r in ds_results) / len(ds_results)
                ue_by_ds[ds] = sum(r.user_effort for r in ds_results) / len(ds_results)

        return EvaluationMetrics(
            agent_name=agent_name,
            num_sessions=n,
            avg_task_success=avg_ts,
            avg_user_effort=avg_ue,
            avg_efficiency=avg_eff,
            task_success_by_dataset=ts_by_ds,
            user_effort_by_dataset=ue_by_ds,
        )

    def save_results(self, results: List[SessionResult], path: str):
        """Save results to JSONL file."""
        with open(path, "w", encoding="utf-8") as f:
            for result in results:
                f.write(json.dumps(result.to_dict(), ensure_ascii=False) + "\n")

    @staticmethod
    def load_results(path: str) -> List[SessionResult]:
        """Load results from JSONL file."""
        results = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    # Reconstruct SessionResult from the stored dict.
                    results.append(SessionResult(**json.loads(line)))
        return results
    ),
    Task(
        task_id="math_003",
        dataset="math-demo",
        problem="Find the area of a circle with radius 5.",
        solution="A = 25π ≈ 78.54 square units",
        task_description="Work with the assistant to solve this geometry problem:",
    ),
    Task(
        task_id="code_001",
        dataset="code-demo",
        problem="Write a Python function that checks if a string is a palindrome.",
        solution="def is_palindrome(s): return s == s[::-1]",
        task_description="Work with the assistant to write this Python function:",
    ),
    Task(
        task_id="code_002",
        dataset="code-demo",
        problem="Write a function to find the nth Fibonacci number.",
        solution="def fib(n): return n if n <= 1 else fib(n-1) + fib(n-2)",
        task_description="Work with the assistant to implement this algorithm:",
    ),
]


@dataclass
class ExperimentConfig:
    """Configuration for an experiment run."""
    name: str        # experiment name, recorded in the saved summary
    output_dir: str  # directory for bank, profiles, results, and summary files

    # Scale
    num_users: int = 2         # number of simulated user profiles
    prefs_per_user: int = 10   # preferences sampled into each profile
    tasks_per_user: int = 3    # sessions run per user, per agent
    max_turns: int = 25        # hard cap on conversation turns per session

    # Baselines to run
    run_no_memory: bool = True
    run_rag: bool = True
    run_rag_uv: bool = False  # User vector mode

    # Model configs
    agent_model: str = "llama-8b"
    user_sim_model: str = "Llama-3.3-70B-Instruct"
    judge_model: str = "Llama-3.3-70B-Instruct"

    # API endpoints
    # NOTE(review): assumes local OpenAI-compatible servers on these ports —
    # confirm against deployment scripts.
    agent_api_base: str = "http://localhost:8003/v1"
    user_sim_api_base: str = "http://localhost:8004/v1"

    seed: int = 42


class ExperimentRunner:
    """
    Runs a complete evaluation experiment.
+ """ + + def __init__(self, config: ExperimentConfig): + self.config = config + + # Create output directory + os.makedirs(config.output_dir, exist_ok=True) + + # Will be initialized lazily + self._bank: Optional[PreferenceBank] = None + self._profiles: Optional[List[UserProfile]] = None + self._tasks: Optional[List[Task]] = None + self._evaluator: Optional[Evaluator] = None + + def setup(self): + """Initialize all components.""" + print("=" * 60) + print(f"Setting up experiment: {self.config.name}") + print("=" * 60) + + # 1. Generate/load preference bank + bank_path = os.path.join(self.config.output_dir, "preference_bank.json") + if os.path.exists(bank_path): + print(f"Loading existing preference bank from {bank_path}") + self._bank = PreferenceBank.load(bank_path) + else: + print("Generating new preference bank...") + self._bank = generate_demo_bank(output_path=bank_path, use_llm=False) + + print(f" Bank stats: {self._bank.stats()}") + + # 2. Generate/load user profiles + profiles_path = os.path.join(self.config.output_dir, "user_profiles.json") + if os.path.exists(profiles_path): + print(f"Loading existing profiles from {profiles_path}") + self._profiles = UserProfileGenerator.load_profiles(profiles_path) + else: + print(f"Generating {self.config.num_users} user profiles...") + self._profiles = generate_demo_profiles( + bank=self._bank, + num_users=self.config.num_users, + prefs_per_user=self.config.prefs_per_user, + output_path=profiles_path, + seed=self.config.seed, + ) + + # 3. Load tasks + self._tasks = DEMO_TASKS[:self.config.tasks_per_user * 2] # Use demo tasks + print(f" Loaded {len(self._tasks)} tasks") + + # 4. 
Initialize evaluator + user_sim = UserSimulator( + model_name=self.config.user_sim_model, + api_base=self.config.user_sim_api_base, + ) + self._evaluator = Evaluator(user_simulator=user_sim) + + print("Setup complete!\n") + + def _create_agents(self) -> Dict[str, BaselineAgent]: + """Create agent instances based on config.""" + agents = {} + + if self.config.run_no_memory: + agents["T1_NoMemory"] = NoMemoryAgent( + model_name=self.config.agent_model, + api_base=self.config.agent_api_base, + ) + + if self.config.run_rag: + # Create directories for RAG memory + memory_dir = os.path.join(self.config.output_dir, "rag_memory") + os.makedirs(memory_dir, exist_ok=True) + + agents["Y3_RAG"] = RAGMemoryAgent( + model_name=self.config.agent_model, + mode="nopersonal", + memory_cards_path=os.path.join(memory_dir, "memory_cards.jsonl"), + memory_embeddings_path=os.path.join(memory_dir, "embeddings.npy"), + ) + + if self.config.run_rag_uv: + memory_dir = os.path.join(self.config.output_dir, "rag_uv_memory") + os.makedirs(memory_dir, exist_ok=True) + + agents["Y4_RAG_UV"] = RAGMemoryAgent( + model_name=self.config.agent_model, + mode="full", + memory_cards_path=os.path.join(memory_dir, "memory_cards.jsonl"), + memory_embeddings_path=os.path.join(memory_dir, "embeddings.npy"), + enable_rl_updates=True, + ) + + return agents + + def run(self) -> Dict[str, EvaluationMetrics]: + """ + Run the full experiment. 
+ + Returns: + Dict mapping agent name to aggregated metrics + """ + if self._evaluator is None: + self.setup() + + agents = self._create_agents() + all_results: Dict[str, List[SessionResult]] = {name: [] for name in agents} + + print("=" * 60) + print("Running experiment") + print("=" * 60) + + # Run for each agent + for agent_name, agent in agents.items(): + print(f"\n>>> Agent: {agent_name}") + + # Run for each user + for profile in tqdm(self._profiles, desc=f"Users ({agent_name})"): + # Reset user state + agent.reset_user(profile.user_id) + + # Get tasks for this user + # In demo, just cycle through available tasks + user_tasks = self._tasks[:self.config.tasks_per_user] + + # Run sessions + for task in user_tasks: + result = self._evaluator.run_session( + agent=agent, + user_profile=profile, + task=task, + max_turns=self.config.max_turns, + ) + + all_results[agent_name].append(result) + + # Print progress + status = "✓" if result.task_success else "✗" + print(f" {profile.user_id} | {task.task_id} | " + f"TS={status} | UE={result.user_effort} | Eff={result.efficiency}") + + # Save raw results + for agent_name, results in all_results.items(): + results_path = os.path.join( + self.config.output_dir, + f"results_{agent_name}.jsonl" + ) + self._evaluator.save_results(results, results_path) + + # Aggregate metrics + metrics = {} + for agent_name, results in all_results.items(): + metrics[agent_name] = self._evaluator.aggregate_metrics(results, agent_name) + + # Save and print summary + self._save_summary(metrics) + self._print_summary(metrics) + + return metrics + + def _save_summary(self, metrics: Dict[str, EvaluationMetrics]): + """Save experiment summary.""" + summary = { + "experiment_name": self.config.name, + "timestamp": datetime.now().isoformat(), + "config": { + "num_users": self.config.num_users, + "prefs_per_user": self.config.prefs_per_user, + "tasks_per_user": self.config.tasks_per_user, + "max_turns": self.config.max_turns, + }, + "metrics": {name: 
m.to_dict() for name, m in metrics.items()}, + } + + summary_path = os.path.join(self.config.output_dir, "summary.json") + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + + print(f"\nSummary saved to {summary_path}") + + def _print_summary(self, metrics: Dict[str, EvaluationMetrics]): + """Print experiment summary.""" + print("\n" + "=" * 60) + print("EXPERIMENT SUMMARY") + print("=" * 60) + + # Header + print(f"\n{'Agent':<20} {'TS ↑':>10} {'UE ↓':>10} {'Eff ↓':>10} {'Sessions':>10}") + print("-" * 60) + + for agent_name, m in metrics.items(): + print(f"{agent_name:<20} {m.avg_task_success:>10.2%} " + f"{m.avg_user_effort:>10.2f} {m.avg_efficiency:>10.1f} " + f"{m.num_sessions:>10}") + + print("\n" + "=" * 60) + + +def run_demo_experiment(output_dir: str = "data/eval/demo_experiment"): + """ + Run a minimal demo experiment. + + This is a quick sanity check with: + - 2 users + - 10 preferences per user + - 3 tasks per user + - T1 (NoMemory) vs Y3 (RAG) comparison + """ + config = ExperimentConfig( + name="demo_experiment", + output_dir=output_dir, + num_users=2, + prefs_per_user=10, + tasks_per_user=3, + max_turns=15, + run_no_memory=True, + run_rag=True, + run_rag_uv=False, + ) + + runner = ExperimentRunner(config) + runner.setup() + metrics = runner.run() + + return metrics + + +if __name__ == "__main__": + import sys + + output_dir = sys.argv[1] if len(sys.argv) > 1 else "data/eval/demo_experiment" + run_demo_experiment(output_dir) + + diff --git a/src/personalization/evaluation/preference_bank/__init__.py b/src/personalization/evaluation/preference_bank/__init__.py new file mode 100644 index 0000000..33f0ed2 --- /dev/null +++ b/src/personalization/evaluation/preference_bank/__init__.py @@ -0,0 +1,6 @@ +from .schemas import PreferenceItem, PreferenceTopic, PreferenceBank +from .generator import PreferenceBankGenerator + +__all__ = ["PreferenceItem", "PreferenceTopic", "PreferenceBank", 
"PreferenceBankGenerator"] + + diff --git a/src/personalization/evaluation/preference_bank/generator.py b/src/personalization/evaluation/preference_bank/generator.py new file mode 100644 index 0000000..e256b86 --- /dev/null +++ b/src/personalization/evaluation/preference_bank/generator.py @@ -0,0 +1,530 @@ +""" +Preference Bank Generator + +Uses LLM to automatically generate diverse user preferences for each topic. +""" + +import json +import os +from typing import List, Dict, Any, Optional +from dataclasses import dataclass + +from .schemas import PreferenceItem, PreferenceTopic, PreferenceBank + + +# Topic definitions for the demo (5 topics) +DEMO_TOPICS = { + "math_formatting": { + "description": "How mathematical content should be formatted (LaTeX, plain text, markdown)", + "related_datasets": ["math-hard", "math-500", "gpqa"], + "generation_hints": [ + "LaTeX formatting for equations", + "Plain text vs mathematical notation", + "Inline vs block equations", + "Step-by-step calculation display", + "Variable naming conventions", + ], + }, + "coding_style": { + "description": "Preferences for code formatting, language choice, and documentation", + "related_datasets": ["humaneval", "bigcodebench"], + "generation_hints": [ + "Programming language preference (Python, JavaScript, etc.)", + "Type hints and annotations", + "Docstrings and comments", + "Code structure and organization", + "Naming conventions", + ], + }, + "response_structure": { + "description": "How responses should be organized (bullets, numbered lists, prose)", + "related_datasets": ["all"], + "generation_hints": [ + "Bullet points vs numbered lists vs prose", + "Headers and sections", + "TL;DR summaries", + "Outline before detailed explanation", + "Logical flow and transitions", + ], + }, + "explanation_depth": { + "description": "Level of detail and thoroughness in explanations", + "related_datasets": ["all"], + "generation_hints": [ + "Concise vs comprehensive", + "Examples and analogies", + 
"Background context", + "Assumptions stated explicitly", + "Multiple approaches/alternatives", + ], + }, + "interaction_style": { + "description": "How the agent should interact (questions, confirmations, suggestions)", + "related_datasets": ["all"], + "generation_hints": [ + "Asking clarifying questions", + "Step-by-step vs holistic answers", + "Proactive suggestions", + "Confidence levels in answers", + "Politeness and tone", + ], + }, +} + + +# LLM prompt template for generating preferences +GENERATION_PROMPT = '''You are helping design a user preference benchmark. Generate {num_prefs} diverse user preferences for the topic: "{topic_name}" + +Topic Description: {topic_description} + +Hints for preference types: +{hints} + +For each preference, provide a JSON object with: +1. "condition": When this preference applies (e.g., "when solving math problems", "when explaining code") +2. "action": What the user prefers (be specific and enforceable) +3. "conflict_group": If this preference conflicts with others in the list, give them the same group name (e.g., "notation_style"). Use null if no conflict. +4. "enforce_description": How a user would detect violation and enforce this preference +5. "example_violation": A concrete example of an agent response that violates this +6. "example_compliance": A concrete example that follows this preference + +Requirements: +- Make preferences SPECIFIC and ENFORCEABLE (not vague like "be helpful") +- Include 2-3 pairs of CONFLICTING preferences (same conflict_group) - this is important for testing RAG +- Vary specificity: some broad ("always use Python"), some narrow ("use f-strings for string formatting in Python") +- Preferences should be realistic things users actually care about + +Output as a JSON array of objects. Only output the JSON array, no other text. 
+''' + + +class PreferenceBankGenerator: + """Generates a preference bank using LLM.""" + + def __init__( + self, + llm_client: Any = None, + model_name: str = "gpt-4o-mini", # Default to a capable but fast model + ): + """ + Args: + llm_client: OpenAI-compatible client. If None, will create one. + model_name: Model to use for generation. + """ + self.model_name = model_name + + if llm_client is None: + try: + import openai + self.client = openai.OpenAI() + except Exception as e: + print(f"Warning: Could not initialize OpenAI client: {e}") + self.client = None + else: + self.client = llm_client + + def generate_preferences_for_topic( + self, + topic_name: str, + topic_description: str, + hints: List[str], + num_prefs: int = 5, + ) -> List[PreferenceItem]: + """Generate preferences for a single topic using LLM.""" + + if self.client is None: + print(f"No LLM client available, using fallback for topic: {topic_name}") + return self._generate_fallback_preferences(topic_name, num_prefs) + + hints_text = "\n".join(f"- {h}" for h in hints) + + prompt = GENERATION_PROMPT.format( + num_prefs=num_prefs, + topic_name=topic_name, + topic_description=topic_description, + hints=hints_text, + ) + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "user", "content": prompt}], + temperature=0.8, + max_tokens=4000, + ) + + content = response.choices[0].message.content.strip() + + # Parse JSON + # Handle potential markdown code blocks + if content.startswith("```"): + content = content.split("```")[1] + if content.startswith("json"): + content = content[4:] + + prefs_data = json.loads(content) + + # Convert to PreferenceItem objects + preferences = [] + for i, pref_dict in enumerate(prefs_data): + pref_id = f"{topic_name[:4]}_{i+1:03d}" + pref = PreferenceItem( + id=pref_id, + topic=topic_name, + condition=pref_dict.get("condition", ""), + action=pref_dict.get("action", ""), + conflict_group=pref_dict.get("conflict_group"), + 
enforce_description=pref_dict.get("enforce_description", ""), + example_violation=pref_dict.get("example_violation", ""), + example_compliance=pref_dict.get("example_compliance", ""), + ) + preferences.append(pref) + + return preferences + + except Exception as e: + print(f"Error generating preferences for {topic_name}: {e}") + return self._generate_fallback_preferences(topic_name, num_prefs) + + def _generate_fallback_preferences( + self, + topic_name: str, + num_prefs: int = 5, + ) -> List[PreferenceItem]: + """Generate hardcoded fallback preferences when LLM is not available.""" + + fallbacks = { + "math_formatting": [ + PreferenceItem( + id="math_001", topic="math_formatting", + condition="solving math problems", + action="use LaTeX for all formulas and equations", + conflict_group="math_notation", + enforce_description="Check if mathematical expressions use LaTeX syntax like $x^2$ or $$\\int$$", + example_violation="The answer is x squared plus 2x plus 1", + example_compliance="The answer is $x^2 + 2x + 1$", + ), + PreferenceItem( + id="math_002", topic="math_formatting", + condition="explaining mathematical concepts", + action="use plain text only, avoid any mathematical notation", + conflict_group="math_notation", + enforce_description="Check if response contains any LaTeX or special math symbols", + example_violation="We need to find $\\frac{d}{dx}(x^2)$", + example_compliance="We need to find the derivative of x squared", + ), + PreferenceItem( + id="math_003", topic="math_formatting", + condition="showing multi-step calculations", + action="display each step on a separate line with clear labels", + conflict_group=None, + enforce_description="Check if steps are on separate lines with labels like 'Step 1:'", + example_violation="First we add 2+3=5, then multiply by 4 to get 20", + example_compliance="Step 1: Add 2 + 3 = 5\nStep 2: Multiply by 4: 5 × 4 = 20", + ), + PreferenceItem( + id="math_004", topic="math_formatting", + condition="presenting final 
answers", + action="clearly box or highlight the final answer", + conflict_group=None, + enforce_description="Check if final answer is visually distinguished", + example_violation="So x equals 5.", + example_compliance="**Final Answer: x = 5**", + ), + PreferenceItem( + id="math_005", topic="math_formatting", + condition="solving problems with multiple variables", + action="use single-letter variables (x, y, z) rather than descriptive names", + conflict_group="var_naming", + enforce_description="Check if variables are single letters", + example_violation="Let price = 100 and quantity = 5", + example_compliance="Let p = 100 and q = 5", + ), + ], + "coding_style": [ + PreferenceItem( + id="code_001", topic="coding_style", + condition="providing code examples", + action="always use Python", + conflict_group="language", + enforce_description="Check if code is written in Python", + example_violation="```javascript\nfunction add(a, b) { return a + b; }\n```", + example_compliance="```python\ndef add(a, b):\n return a + b\n```", + ), + PreferenceItem( + id="code_002", topic="coding_style", + condition="providing code examples", + action="always use JavaScript or TypeScript", + conflict_group="language", + enforce_description="Check if code is written in JavaScript/TypeScript", + example_violation="```python\ndef add(a, b): return a + b\n```", + example_compliance="```javascript\nconst add = (a, b) => a + b;\n```", + ), + PreferenceItem( + id="code_003", topic="coding_style", + condition="writing Python functions", + action="always include type hints for parameters and return values", + conflict_group=None, + enforce_description="Check if function has type hints", + example_violation="def add(a, b):\n return a + b", + example_compliance="def add(a: int, b: int) -> int:\n return a + b", + ), + PreferenceItem( + id="code_004", topic="coding_style", + condition="writing functions", + action="include a docstring explaining the function", + conflict_group=None, + 
enforce_description="Check if function has a docstring", + example_violation="def add(a, b):\n return a + b", + example_compliance='def add(a, b):\n """Add two numbers and return the result."""\n return a + b', + ), + PreferenceItem( + id="code_005", topic="coding_style", + condition="writing code", + action="minimize comments, code should be self-documenting", + conflict_group="comment_style", + enforce_description="Check if there are excessive inline comments", + example_violation="x = x + 1 # increment x by 1", + example_compliance="x += 1", + ), + ], + "response_structure": [ + PreferenceItem( + id="struct_001", topic="response_structure", + condition="providing multi-point answers", + action="use bullet points with '-' or '*'", + conflict_group="list_style", + enforce_description="Check if response uses bullet points", + example_violation="First, do X. Second, do Y. Third, do Z.", + example_compliance="- First, do X\n- Second, do Y\n- Third, do Z", + ), + PreferenceItem( + id="struct_002", topic="response_structure", + condition="providing step-by-step instructions", + action="use numbered lists", + conflict_group="list_style", + enforce_description="Check if response uses numbered lists", + example_violation="First do X, then do Y, finally do Z.", + example_compliance="1. Do X\n2. Do Y\n3. 
Do Z", + ), + PreferenceItem( + id="struct_003", topic="response_structure", + condition="writing explanations", + action="use flowing prose paragraphs, avoid lists", + conflict_group="list_style", + enforce_description="Check if response uses prose instead of lists", + example_violation="Key points:\n- Point 1\n- Point 2", + example_compliance="The key insight here is that Point 1 connects to Point 2 through...", + ), + PreferenceItem( + id="struct_004", topic="response_structure", + condition="providing long explanations", + action="include a TL;DR summary at the end", + conflict_group=None, + enforce_description="Check if response ends with TL;DR", + example_violation="... and that's how it works.", + example_compliance="... and that's how it works.\n\n**TL;DR:** X does Y by Z.", + ), + PreferenceItem( + id="struct_005", topic="response_structure", + condition="explaining complex topics", + action="start with an outline of what will be covered", + conflict_group=None, + enforce_description="Check if response starts with an outline", + example_violation="Let me explain recursion. First, understand that...", + example_compliance="I'll cover: 1) What is recursion, 2) How it works, 3) Examples.\n\n**1) What is recursion**...", + ), + ], + "explanation_depth": [ + PreferenceItem( + id="depth_001", topic="explanation_depth", + condition="answering questions", + action="be concise, no more than 3 sentences", + conflict_group="length", + enforce_description="Count sentences, should be 3 or fewer", + example_violation="Let me explain in detail. First... Second... Third... Fourth... Fifth...", + example_compliance="The answer is X. This works because of Y. 
Here's how to apply it: Z.", + ), + PreferenceItem( + id="depth_002", topic="explanation_depth", + condition="explaining concepts", + action="provide comprehensive, detailed explanations", + conflict_group="length", + enforce_description="Check if explanation is thorough with multiple aspects covered", + example_violation="It's X. Done.", + example_compliance="Let me explain X in detail. The concept originates from... It works by... Common applications include... Here's an example...", + ), + PreferenceItem( + id="depth_003", topic="explanation_depth", + condition="explaining anything", + action="always include at least one concrete example", + conflict_group=None, + enforce_description="Check if at least one example is provided", + example_violation="A binary tree is a data structure where each node has at most two children.", + example_compliance="A binary tree is a data structure where each node has at most two children. For example, in [5, 3, 7], 5 is the root, 3 is left child, 7 is right child.", + ), + PreferenceItem( + id="depth_004", topic="explanation_depth", + condition="explaining technical concepts", + action="use analogies from everyday life", + conflict_group=None, + enforce_description="Check if explanation includes an everyday analogy", + example_violation="A stack is a LIFO data structure.", + example_compliance="A stack is like a stack of plates - you can only take the top one (LIFO).", + ), + PreferenceItem( + id="depth_005", topic="explanation_depth", + condition="solving problems", + action="state assumptions explicitly before solving", + conflict_group=None, + enforce_description="Check if assumptions are stated upfront", + example_violation="The answer is 42.", + example_compliance="Assuming n is positive and integer, the answer is 42.", + ), + ], + "interaction_style": [ + PreferenceItem( + id="inter_001", topic="interaction_style", + condition="receiving unclear requests", + action="ask clarifying questions before attempting to answer", + 
conflict_group="clarification", + enforce_description="Check if agent asks questions when request is ambiguous", + example_violation="Here's a solution assuming you meant X...", + example_compliance="Before I help, could you clarify: do you mean X or Y?", + ), + PreferenceItem( + id="inter_002", topic="interaction_style", + condition="receiving requests", + action="make reasonable assumptions and proceed without asking", + conflict_group="clarification", + enforce_description="Check if agent proceeds with reasonable assumptions", + example_violation="What exactly do you mean by 'large'? What size range?", + example_compliance="Assuming you mean 'large' as over 1000 items, here's the solution...", + ), + PreferenceItem( + id="inter_003", topic="interaction_style", + condition="solving multi-step problems", + action="present one step at a time and ask for confirmation before proceeding", + conflict_group="pacing", + enforce_description="Check if agent pauses after each step", + example_violation="Step 1: X. Step 2: Y. Step 3: Z. Done!", + example_compliance="Step 1: X. Does this make sense? Should I continue to Step 2?", + ), + PreferenceItem( + id="inter_004", topic="interaction_style", + condition="solving problems", + action="provide the complete solution at once without pausing", + conflict_group="pacing", + enforce_description="Check if agent gives complete solution without asking to continue", + example_violation="First, let me do step 1... 
Should I continue?", + example_compliance="Here's the complete solution: Step 1: X, Step 2: Y, Step 3: Z.", + ), + PreferenceItem( + id="inter_005", topic="interaction_style", + condition="providing answers", + action="include a confidence level (e.g., 'I'm 90% confident')", + conflict_group=None, + enforce_description="Check if response includes confidence level", + example_violation="The answer is 42.", + example_compliance="I'm about 95% confident the answer is 42.", + ), + ], + } + + if topic_name in fallbacks: + return fallbacks[topic_name][:num_prefs] + else: + # Generic fallback + return [ + PreferenceItem( + id=f"{topic_name[:4]}_{i+1:03d}", + topic=topic_name, + condition=f"interacting about {topic_name}", + action=f"preference {i+1} for {topic_name}", + conflict_group=None, + enforce_description=f"Check preference {i+1}", + example_violation=f"Violation example {i+1}", + example_compliance=f"Compliance example {i+1}", + ) + for i in range(num_prefs) + ] + + def generate_bank( + self, + topics: Dict[str, Dict] = None, + prefs_per_topic: int = 5, + ) -> PreferenceBank: + """Generate a complete preference bank.""" + + if topics is None: + topics = DEMO_TOPICS + + bank = PreferenceBank() + + for topic_name, topic_config in topics.items(): + print(f"Generating preferences for topic: {topic_name}...") + + preferences = self.generate_preferences_for_topic( + topic_name=topic_name, + topic_description=topic_config["description"], + hints=topic_config.get("generation_hints", []), + num_prefs=prefs_per_topic, + ) + + topic = PreferenceTopic( + name=topic_name, + description=topic_config["description"], + related_datasets=topic_config["related_datasets"], + preferences=preferences, + ) + + bank.add_topic(topic) + print(f" Generated {len(preferences)} preferences") + + return bank + + +def generate_demo_bank( + output_path: str = None, + use_llm: bool = False, + prefs_per_topic: int = 5, +) -> PreferenceBank: + """ + Generate a demo preference bank. 

    Args:
        output_path: If provided, save bank to this path
        use_llm: If True, use LLM to generate. If False, use hardcoded fallbacks.
        prefs_per_topic: Number of preferences per topic

    Returns:
        Generated PreferenceBank
    """
    if use_llm:
        generator = PreferenceBankGenerator()
    else:
        generator = PreferenceBankGenerator(llm_client=None)  # Use fallbacks

    bank = generator.generate_bank(
        topics=DEMO_TOPICS,
        prefs_per_topic=prefs_per_topic,
    )

    if output_path:
        bank.save(output_path)
        print(f"Saved bank to {output_path}")

    print(f"\nBank Statistics: {bank.stats()}")

    return bank


if __name__ == "__main__":
    # Generate demo bank with fallback preferences
    import os  # NOTE(review): redundant — os is already imported at module level
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.join(script_dir, "bank_demo.json")

    bank = generate_demo_bank(output_path=output_path, use_llm=False)


diff --git a/src/personalization/evaluation/preference_bank/schemas.py b/src/personalization/evaluation/preference_bank/schemas.py
new file mode 100644
index 0000000..f219487
--- /dev/null
+++ b/src/personalization/evaluation/preference_bank/schemas.py
@@ -0,0 +1,147 @@
"""
Preference Bank Schemas

Defines the data structures for user preferences, organized by topic.
Each preference has a condition (when it applies), action (what the user wants),
and optional conflict group (preferences in the same group are mutually exclusive).
"""

from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
import json


@dataclass
class PreferenceItem:
    """A single user preference."""
    id: str                         # Unique ID, e.g., "math_fmt_001"
    topic: str                      # Topic name, e.g., "math_formatting"
    condition: str                  # When this preference applies
    action: str                     # What the user prefers
    conflict_group: Optional[str]   # If set, only one pref from this group can be selected
    enforce_description: str        # Description for user simulator on how to enforce
    example_violation: str          # Example of agent response that violates this
    example_compliance: str         # Example that follows this preference

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (inverse of from_dict)."""
        return {
            "id": self.id,
            "topic": self.topic,
            "condition": self.condition,
            "action": self.action,
            "conflict_group": self.conflict_group,
            "enforce_description": self.enforce_description,
            "example_violation": self.example_violation,
            "example_compliance": self.example_compliance,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferenceItem":
        # NOTE(review): raises TypeError if data has extra/missing keys —
        # acceptable while to_dict() is the only producer.
        return cls(**data)

    def format_for_user(self) -> str:
        """Format for user simulator prompt."""
        return f"When {self.condition}: {self.action}"

    def format_for_enforcement(self) -> str:
        """Format with enforcement details."""
        return f"[{self.id}] When {self.condition}: {self.action}\n Enforce if: {self.enforce_description}"


@dataclass
class PreferenceTopic:
    """A topic containing multiple related preferences."""
    name: str                      # Topic name, e.g., "math_formatting"
    description: str               # Description of this topic
    related_datasets: List[str]    # Datasets where this topic is relevant ("all" = every dataset)
    preferences: List[PreferenceItem] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, recursing into preferences."""
        return {
            "name": self.name,
            "description": self.description,
            "related_datasets": self.related_datasets,
            "preferences": [p.to_dict() for p in self.preferences],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferenceTopic":
        """Deserialize from a dict; missing 'preferences' yields an empty list."""
        prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])]
        return cls(
            name=data["name"],
            description=data["description"],
            related_datasets=data["related_datasets"],
            preferences=prefs,
        )


@dataclass
class PreferenceBank:
    """
    A bank of preferences organized by topic.
    Used to generate user profiles by sampling preferences.
    """
    topics: Dict[str, PreferenceTopic] = field(default_factory=dict)
    version: str = "1.0"

    def add_topic(self, topic: PreferenceTopic):
        # Keyed by topic name; re-adding a name replaces the existing topic.
        self.topics[topic.name] = topic

    def get_all_preferences(self) -> List[PreferenceItem]:
        """Get all preferences across all topics."""
        all_prefs = []
        for topic in self.topics.values():
            all_prefs.extend(topic.preferences)
        return all_prefs

    def get_preferences_for_dataset(self, dataset: str) -> List[PreferenceItem]:
        """Get preferences relevant to a specific dataset.

        A topic matches when it lists the dataset explicitly or uses the
        "all" wildcard in related_datasets.
        """
        relevant = []
        for topic in self.topics.values():
            if dataset in topic.related_datasets or "all" in topic.related_datasets:
                relevant.extend(topic.preferences)
        return relevant

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the whole bank (version + all topics)."""
        return {
            "version": self.version,
            "topics": {name: topic.to_dict() for name, topic in self.topics.items()},
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferenceBank":
        """Deserialize a bank; unknown version defaults to "1.0"."""
        bank = cls(version=data.get("version", "1.0"))
        for name, topic_data in data.get("topics", {}).items():
            bank.topics[name] = PreferenceTopic.from_dict(topic_data)
        return bank

    def save(self, path: str):
        """Save bank to JSON file."""
        # NOTE(review): assumes the parent directory already exists.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load(cls, path: str) -> "PreferenceBank":
        """Load bank from JSON file."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)

    def stats(self) -> Dict[str, Any]:
        """Get statistics about the bank: topic count, total preferences,
        distinct conflict groups, and per-topic preference counts."""
        total_prefs = 0
        conflict_groups = set()
        for topic in self.topics.values():
            total_prefs += len(topic.preferences)
            for pref in topic.preferences:
                if pref.conflict_group:
                    conflict_groups.add(pref.conflict_group)

        return {
            "num_topics": len(self.topics),
            "total_preferences": total_prefs,
            "num_conflict_groups": len(conflict_groups),
            "prefs_per_topic": {name: len(t.preferences) for name, t in self.topics.items()},
        }


diff --git a/src/personalization/evaluation/profiles/__init__.py b/src/personalization/evaluation/profiles/__init__.py
new file mode 100644
index 0000000..8532af9
--- /dev/null
+++ b/src/personalization/evaluation/profiles/__init__.py
@@ -0,0 +1,5 @@
from .generator import UserProfile, UserProfileGenerator

__all__ = ["UserProfile", "UserProfileGenerator"]


diff --git a/src/personalization/evaluation/profiles/generator.py b/src/personalization/evaluation/profiles/generator.py
new file mode 100644
index 0000000..da847a0
--- /dev/null
+++ b/src/personalization/evaluation/profiles/generator.py
@@ -0,0 +1,351 @@
"""
User Profile Generator

Generates user profiles by sampling preferences from the preference bank.
Ensures no conflicting preferences within same conflict_group, but allows
cross-topic scenario conflicts (which is desired for testing RAG).
+""" + +import json +import random +from collections import defaultdict +from dataclasses import dataclass, field +from typing import List, Dict, Set, Optional, Any + +from ..preference_bank.schemas import PreferenceItem, PreferenceBank + + +@dataclass +class UserProfile: + """A simulated user with specific preferences.""" + user_id: str + persona: str # Background description + preferences: List[PreferenceItem] # Selected preferences + primary_topics: List[str] # Topics this user cares most about + preference_by_topic: Dict[str, List[PreferenceItem]] = field(default_factory=dict) + + def __post_init__(self): + # Build topic index if not provided + if not self.preference_by_topic: + self.preference_by_topic = defaultdict(list) + for pref in self.preferences: + self.preference_by_topic[pref.topic].append(pref) + self.preference_by_topic = dict(self.preference_by_topic) + + def get_preferences_for_topic(self, topic: str) -> List[PreferenceItem]: + """Get preferences for a specific topic.""" + return self.preference_by_topic.get(topic, []) + + def get_preferences_for_dataset(self, dataset: str, bank: PreferenceBank) -> List[PreferenceItem]: + """Get preferences relevant to a specific dataset.""" + relevant_topics = set() + for topic_name, topic in bank.topics.items(): + if dataset in topic.related_datasets or "all" in topic.related_datasets: + relevant_topics.add(topic_name) + + relevant_prefs = [] + for pref in self.preferences: + if pref.topic in relevant_topics: + relevant_prefs.append(pref) + return relevant_prefs + + def format_preferences_grouped(self) -> str: + """Format preferences grouped by topic for prompts.""" + lines = [] + for topic, prefs in self.preference_by_topic.items(): + topic_title = topic.replace("_", " ").title() + lines.append(f"\n## {topic_title}") + for pref in prefs: + lines.append(f" [{pref.id}] When {pref.condition}: {pref.action}") + lines.append(f" Enforce if: {pref.enforce_description}") + return "\n".join(lines) + + def 
format_preferences_flat(self) -> str: + """Format preferences as a flat list.""" + lines = [] + for i, pref in enumerate(self.preferences, 1): + lines.append(f"{i}. When {pref.condition}: {pref.action}") + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + return { + "user_id": self.user_id, + "persona": self.persona, + "preferences": [p.to_dict() for p in self.preferences], + "primary_topics": self.primary_topics, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "UserProfile": + prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])] + return cls( + user_id=data["user_id"], + persona=data["persona"], + preferences=prefs, + primary_topics=data.get("primary_topics", []), + ) + + def stats(self) -> Dict[str, Any]: + """Get statistics about this profile.""" + conflict_groups = set() + for pref in self.preferences: + if pref.conflict_group: + conflict_groups.add(pref.conflict_group) + + return { + "user_id": self.user_id, + "num_preferences": len(self.preferences), + "num_topics": len(self.preference_by_topic), + "prefs_per_topic": {t: len(ps) for t, ps in self.preference_by_topic.items()}, + "num_conflict_groups_used": len(conflict_groups), + } + + +# Persona templates for different user types +PERSONA_TEMPLATES = [ + "A {field} professional who values {trait} and prefers {style} communication.", + "A graduate student in {field} who appreciates {trait} and likes responses that are {style}.", + "An experienced {field} practitioner who prioritizes {trait} and expects {style} explanations.", + "A beginner learning {field} who needs {trait} and responds well to {style} guidance.", + "A {field} enthusiast who cares about {trait} and prefers {style} interactions.", +] + +FIELDS = [ + "software engineering", "data science", "mathematics", "physics", + "medical research", "financial analysis", "machine learning", + "web development", "systems programming", "algorithm design", +] + +TRAITS = [ + "clarity", "precision", 
"efficiency", "thoroughness", "simplicity", + "formality", "practicality", "theoretical depth", "hands-on examples", +] + +STYLES = [ + "concise", "detailed", "step-by-step", "example-driven", "formal", + "conversational", "structured", "visual", "analytical", +] + + +class UserProfileGenerator: + """Generates user profiles by sampling from preference bank.""" + + def __init__( + self, + preference_bank: PreferenceBank, + target_num_prefs: int = 15, # For demo, use smaller number + seed: Optional[int] = None, + ): + self.bank = preference_bank + self.target_num = target_num_prefs + + if seed is not None: + random.seed(seed) + + def generate_profile( + self, + user_id: str, + primary_topics: List[str] = None, + persona: str = None, + ) -> UserProfile: + """ + Generate a user profile by sampling preferences. + + Args: + user_id: Unique identifier for this user + primary_topics: Topics this user cares most about (get more prefs from these) + persona: Optional persona description. If None, will be generated. 
+ + Returns: + UserProfile with sampled preferences + """ + selected: List[PreferenceItem] = [] + used_conflict_groups: Set[str] = set() + + # If no primary topics specified, randomly select 1-2 + if primary_topics is None: + all_topics = list(self.bank.topics.keys()) + num_primary = random.randint(1, min(2, len(all_topics))) + primary_topics = random.sample(all_topics, num_primary) + + # Compute quotas for each topic + topic_quotas = self._compute_quotas(primary_topics) + + # Sample from each topic + for topic_name, quota in topic_quotas.items(): + if topic_name not in self.bank.topics: + continue + + topic = self.bank.topics[topic_name] + + # Filter out preferences with already-used conflict groups + available = [ + p for p in topic.preferences + if p.conflict_group is None or p.conflict_group not in used_conflict_groups + ] + + # Sample up to quota + to_select = min(quota, len(available)) + if to_select > 0: + sampled = random.sample(available, to_select) + + for pref in sampled: + selected.append(pref) + if pref.conflict_group: + used_conflict_groups.add(pref.conflict_group) + + # Generate persona if not provided + if persona is None: + persona = self._generate_persona(primary_topics) + + return UserProfile( + user_id=user_id, + persona=persona, + preferences=selected, + primary_topics=primary_topics, + ) + + def _compute_quotas(self, primary_topics: List[str]) -> Dict[str, int]: + """Compute how many preferences to sample from each topic.""" + quotas = {} + all_topics = list(self.bank.topics.keys()) + + # Base quota for all topics + base_quota = max(1, self.target_num // len(all_topics)) + + for topic_name in all_topics: + if topic_name in primary_topics: + # Primary topics get more preferences + quotas[topic_name] = base_quota + random.randint(1, 3) + else: + quotas[topic_name] = max(1, base_quota - random.randint(0, 1)) + + # Adjust to match target + total = sum(quotas.values()) + if total < self.target_num: + # Add more to primary topics + for topic in 
primary_topics: + if topic in quotas: + quotas[topic] += (self.target_num - total) // len(primary_topics) + + return quotas + + def _generate_persona(self, primary_topics: List[str]) -> str: + """Generate a persona description based on primary topics.""" + template = random.choice(PERSONA_TEMPLATES) + + # Map topics to fields + topic_to_field = { + "math_formatting": ["mathematics", "physics", "data science"], + "coding_style": ["software engineering", "web development", "systems programming"], + "response_structure": ["technical writing", "documentation", "education"], + "explanation_depth": ["research", "teaching", "consulting"], + "interaction_style": ["customer support", "mentoring", "collaboration"], + } + + # Pick a field related to primary topics + possible_fields = [] + for topic in primary_topics: + possible_fields.extend(topic_to_field.get(topic, FIELDS[:3])) + + if not possible_fields: + possible_fields = FIELDS + + field = random.choice(possible_fields) + trait = random.choice(TRAITS) + style = random.choice(STYLES) + + return template.format(field=field, trait=trait, style=style) + + def generate_profiles( + self, + num_users: int, + id_prefix: str = "user", + ) -> List[UserProfile]: + """Generate multiple user profiles.""" + profiles = [] + + for i in range(num_users): + user_id = f"{id_prefix}_{i:03d}" + profile = self.generate_profile(user_id) + profiles.append(profile) + + return profiles + + def save_profiles(self, profiles: List[UserProfile], path: str): + """Save profiles to JSON file.""" + data = [p.to_dict() for p in profiles] + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + @staticmethod + def load_profiles(path: str) -> List[UserProfile]: + """Load profiles from JSON file.""" + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return [UserProfile.from_dict(d) for d in data] + + +def generate_demo_profiles( + bank: PreferenceBank, + num_users: int = 2, + prefs_per_user: 
int = 10, + output_path: str = None, + seed: int = 42, +) -> List[UserProfile]: + """ + Generate demo user profiles. + + Args: + bank: Preference bank to sample from + num_users: Number of users to generate + prefs_per_user: Target preferences per user + output_path: If provided, save profiles to this path + seed: Random seed for reproducibility + + Returns: + List of UserProfile objects + """ + generator = UserProfileGenerator( + preference_bank=bank, + target_num_prefs=prefs_per_user, + seed=seed, + ) + + profiles = generator.generate_profiles(num_users, id_prefix="demo_user") + + if output_path: + generator.save_profiles(profiles, output_path) + print(f"Saved {len(profiles)} profiles to {output_path}") + + # Print stats + for profile in profiles: + print(f"\n{profile.user_id}: {profile.stats()}") + print(f" Persona: {profile.persona}") + + return profiles + + +if __name__ == "__main__": + import os + from ..preference_bank.generator import generate_demo_bank + + # Generate bank first + script_dir = os.path.dirname(os.path.abspath(__file__)) + bank_path = os.path.join(script_dir, "..", "preference_bank", "bank_demo.json") + + if os.path.exists(bank_path): + bank = PreferenceBank.load(bank_path) + else: + bank = generate_demo_bank() + + # Generate profiles + profiles_path = os.path.join(script_dir, "profiles_demo.json") + profiles = generate_demo_profiles( + bank=bank, + num_users=2, + prefs_per_user=10, + output_path=profiles_path, + ) + + diff --git a/src/personalization/evaluation/user_simulator/__init__.py b/src/personalization/evaluation/user_simulator/__init__.py new file mode 100644 index 0000000..f7799d0 --- /dev/null +++ b/src/personalization/evaluation/user_simulator/__init__.py @@ -0,0 +1,5 @@ +from .simulator import UserSimulator, UserSimulatorResponse + +__all__ = ["UserSimulator", "UserSimulatorResponse"] + + diff --git a/src/personalization/evaluation/user_simulator/simulator.py b/src/personalization/evaluation/user_simulator/simulator.py new file 
"""
User Simulator

Simulates a user with specific preferences who:
1. Presents problems to the agent
2. Checks if agent responses satisfy their preferences
3. Enforces preferences when violated
4. Tracks draft answer and decides when to terminate
"""

import json
import os
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

from ..profiles.generator import UserProfile
from ..preference_bank.schemas import PreferenceItem


# System prompt template for the simulated user.
USER_SYSTEM_PROMPT = """You are simulating a user who is collaborating with an AI assistant to solve a problem. You have specific preferences about how the assistant should respond.

# Problem to Solve
{task_description}
{problem}
Note: The assistant cannot see this problem description directly. You need to communicate with them.

# Your Persona
{persona}

# Your Preferences (Grouped by Topic)
{preferences_grouped}

# Preference Enforcement Rules
- For each assistant response, check which of YOUR preferences are RELEVANT to the current context
- A preference is relevant if the assistant's response touches on that topic/condition
- If a relevant preference is VIOLATED, you MUST enforce it before proceeding
- Do NOT update your draft answer or proceed until violated preferences are fixed
- Only check preferences that apply to the current response (e.g., coding preferences for code responses)

# Draft Answer Management
- Maintain a working draft answer to the problem
- Start with "I don't know"
- Update it based on helpful information from the assistant
- Do NOT update if you're enforcing preferences

# Conversation Guidelines
- Be somewhat vague initially, let the assistant ask clarifying questions
- Respond naturally like a real user
- Do not copy the problem description directly

# Termination
Terminate when:
- Your draft answer seems correct and complete
- The assistant cannot help further

When ready to terminate, include "TERMINATE" in your response.

# Output Format (JSON)
{{
  "preference_checks": [
    {{
      "preference_id": str,
      "topic": str,
      "relevant": bool,
      "satisfied": bool or null,
      "violation_detail": str
    }}
  ],
  "any_violation": bool,
  "enforcement_needed": bool,
  "reasoning": str,
  "draft_answer": str,
  "should_terminate": bool,
  "response": str
}}

IMPORTANT: Only include preferences that are RELEVANT to the current assistant response in preference_checks.
Output valid JSON only, no other text."""


def _strip_code_fence(text: str) -> str:
    """Return the payload inside a markdown code fence, or *text* unchanged."""
    if "```json" in text:
        return text.split("```json")[1].split("```")[0]
    if "```" in text:
        return text.split("```")[1].split("```")[0]
    return text


@dataclass
class PreferenceCheck:
    """Result of checking one preference against one assistant response."""

    preference_id: str
    topic: str
    relevant: bool
    satisfied: Optional[bool]  # None if not relevant
    violation_detail: str = ""


@dataclass
class UserSimulatorResponse:
    """Structured output of one simulator turn."""

    response: str                            # text response to agent
    preference_checks: List[PreferenceCheck] # checked preferences
    any_violation: bool                      # any preference violated?
    enforcement_needed: bool                 # need to enforce?
    draft_answer: str                        # current draft answer
    should_terminate: bool                   # should end conversation?
    reasoning: str                           # internal reasoning
    raw_output: Dict[str, Any] = field(default_factory=dict)


class UserSimulator:
    """
    Simulates a user with preferences interacting with an agent.
    """

    def __init__(
        self,
        model_name: str = "Llama-3.3-70B-Instruct",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.8,
        max_tokens: int = 2048,
    ):
        """API settings fall back to USER_SIM_API_BASE / USER_SIM_API_KEY env vars."""
        self.model_name = model_name
        self.api_base = api_base or os.getenv("USER_SIM_API_BASE", "http://localhost:8004/v1")
        self.api_key = api_key or os.getenv("USER_SIM_API_KEY", "EMPTY")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Per-session state, populated by setup().
        self._profile: Optional[UserProfile] = None
        self._task_description: str = ""
        self._problem: str = ""
        self._solution: str = ""

        self._init_client()

    def _init_client(self):
        """Initialize the OpenAI-compatible client; None on failure (fallback mode)."""
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            print(f"Warning: Could not initialize OpenAI client for user simulator: {e}")
            self.client = None

    def setup(
        self,
        profile: UserProfile,
        task_description: str,
        problem: str,
        solution: str = "",
    ):
        """
        Set up the simulator for a new task.

        Args:
            profile: User profile with preferences
            task_description: Description of the task type
            problem: The specific problem to solve
            solution: Ground truth solution (for evaluation)
        """
        self._profile = profile
        self._task_description = task_description
        self._problem = problem
        self._solution = solution

    def _build_system_prompt(self) -> str:
        """Fill USER_SYSTEM_PROMPT from the current profile and task."""
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        return USER_SYSTEM_PROMPT.format(
            task_description=self._task_description,
            problem=self._problem,
            persona=self._profile.persona,
            preferences_grouped=self._profile.format_preferences_grouped(),
        )

    def _parse_response(self, raw_text: str) -> UserSimulatorResponse:
        """Parse the LLM's JSON output; degrade gracefully on malformed output."""
        try:
            data = json.loads(_strip_code_fence(raw_text.strip()))

            checks = [
                PreferenceCheck(
                    preference_id=entry.get("preference_id", ""),
                    topic=entry.get("topic", ""),
                    relevant=entry.get("relevant", False),
                    satisfied=entry.get("satisfied"),
                    violation_detail=entry.get("violation_detail", ""),
                )
                for entry in data.get("preference_checks", [])
            ]

            return UserSimulatorResponse(
                response=data.get("response", ""),
                preference_checks=checks,
                any_violation=data.get("any_violation", False),
                enforcement_needed=data.get("enforcement_needed", False),
                draft_answer=data.get("draft_answer", "I don't know"),
                should_terminate=data.get("should_terminate", False),
                reasoning=data.get("reasoning", ""),
                raw_output=data,
            )

        except Exception as e:
            print(f"Error parsing user simulator response: {e}")
            print(f"Raw text: {raw_text[:500]}...")

            # Best-effort: short raw text is passed through as the reply.
            reply = raw_text if len(raw_text) < 500 else "Could you please continue?"
            return UserSimulatorResponse(
                response=reply,
                preference_checks=[],
                any_violation=False,
                enforcement_needed=False,
                draft_answer="I don't know",
                should_terminate=False,
                reasoning="Parse error",
                raw_output={"error": str(e), "raw": raw_text},
            )

    def respond(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """
        Generate user response based on conversation.

        Args:
            conversation_history: List of {"role": "user/assistant", "content": "..."}

        Returns:
            UserSimulatorResponse with user's reply and preference status
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        # The simulator plays "the user", so agent turns arrive as role "user"
        # and the simulator's own past turns as role "assistant".
        chat = [{"role": "system", "content": self._build_system_prompt()}]
        for turn in conversation_history:
            flipped = "user" if turn["role"] == "assistant" else "assistant"
            chat.append({"role": flipped, "content": turn["content"]})

        if self.client is None:
            # No LLM client available: canned responses for testing.
            return self._fallback_response(conversation_history)

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=chat,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            return self._parse_response(completion.choices[0].message.content)

        except Exception as e:
            print(f"Error calling user simulator LLM: {e}")
            return self._fallback_response(conversation_history)

    def _fallback_response(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """Generate a simple scripted response for testing (no LLM involved)."""
        assistant_turns = sum(1 for m in conversation_history if m["role"] == "assistant")

        if assistant_turns == 0:
            # First turn - present (a prefix of) the problem.
            text = f"Hi, I need help with this: {self._problem[:200]}..."
        elif assistant_turns < 3:
            text = "Thanks, that helps. Can you explain more?"
        else:
            text = "Got it, I think I understand now. TERMINATE"

        return UserSimulatorResponse(
            response=text,
            preference_checks=[],
            any_violation=False,
            enforcement_needed=False,
            draft_answer="Draft answer from fallback",
            should_terminate="TERMINATE" in text,
            reasoning="Fallback mode",
            raw_output={},
        )

    def get_solution(self) -> str:
        """Get the ground truth solution."""
        return self._solution

    def get_profile(self) -> Optional[UserProfile]:
        """Get the current user profile."""
        return self._profile
