summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/personalization/config/registry.py9
-rw-r--r--src/personalization/evaluation/baselines/__init__.py7
-rw-r--r--src/personalization/evaluation/baselines/base.py83
-rw-r--r--src/personalization/evaluation/baselines/no_memory.py143
-rw-r--r--src/personalization/evaluation/baselines/rag_memory.py204
-rw-r--r--src/personalization/evaluation/demo/__init__.py3
-rw-r--r--src/personalization/evaluation/demo/run_demo.py273
-rw-r--r--src/personalization/evaluation/pipeline/__init__.py6
-rw-r--r--src/personalization/evaluation/pipeline/evaluator.py353
-rw-r--r--src/personalization/evaluation/pipeline/runner.py333
-rw-r--r--src/personalization/evaluation/preference_bank/__init__.py6
-rw-r--r--src/personalization/evaluation/preference_bank/generator.py530
-rw-r--r--src/personalization/evaluation/preference_bank/schemas.py147
-rw-r--r--src/personalization/evaluation/profiles/__init__.py5
-rw-r--r--src/personalization/evaluation/profiles/generator.py351
-rw-r--r--src/personalization/evaluation/user_simulator/__init__.py5
-rw-r--r--src/personalization/evaluation/user_simulator/simulator.py310
-rw-r--r--src/personalization/feedback/handlers.py59
-rw-r--r--src/personalization/feedback/llm_reward.py253
-rw-r--r--src/personalization/retrieval/pipeline.py42
-rw-r--r--src/personalization/serving/personalized_llm.py721
21 files changed, 3732 insertions, 111 deletions
diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py
index d825ad3..6048044 100644
--- a/src/personalization/config/registry.py
+++ b/src/personalization/config/registry.py
@@ -49,6 +49,7 @@ def get_chat_model(name: str, device_override: Optional[str] = None):
from personalization.models.llm.base import ChatModel
from personalization.models.llm.qwen_instruct import QwenInstruct
from personalization.models.llm.llama_instruct import LlamaChatModel
+ from personalization.models.llm.vllm_chat import VLLMChatModel
cfg = settings.load_local_models_config()
@@ -81,6 +82,14 @@ def get_chat_model(name: str, device_override: Optional[str] = None):
dtype=choose_dtype(dtype), # Converts string to torch.dtype
max_context_length=max_len
)
+ elif backend == "vllm":
+ # Use vLLM HTTP API for high-throughput inference
+ vllm_url = spec.get("vllm_url", "http://localhost:8003/v1")
+ return VLLMChatModel(
+ vllm_url=vllm_url,
+ model_name=spec.get("model_name"),
+ max_context_length=max_len
+ )
# Fallback to legacy single config
return QwenInstruct.from_config(cfg)
diff --git a/src/personalization/evaluation/baselines/__init__.py b/src/personalization/evaluation/baselines/__init__.py
new file mode 100644
index 0000000..b6a5761
--- /dev/null
+++ b/src/personalization/evaluation/baselines/__init__.py
@@ -0,0 +1,7 @@
+from .base import BaselineAgent, AgentResponse
+from .no_memory import NoMemoryAgent
+from .rag_memory import RAGMemoryAgent
+
+__all__ = ["BaselineAgent", "AgentResponse", "NoMemoryAgent", "RAGMemoryAgent"]
+
+
diff --git a/src/personalization/evaluation/baselines/base.py b/src/personalization/evaluation/baselines/base.py
new file mode 100644
index 0000000..a3051bd
--- /dev/null
+++ b/src/personalization/evaluation/baselines/base.py
@@ -0,0 +1,83 @@
+"""
+Base class for all baseline agents.
+
+All agents must implement:
+- respond(): Generate a response to user query
+- end_session(): Called when a session ends (for memory updates)
+- reset_user(): Reset all state for a user
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import List, Dict, Any, Optional
+
+
@dataclass
class AgentResponse:
    """A single reply produced by a BaselineAgent."""
    # The answer text shown to the user.
    answer: str
    # Free-form diagnostics (token counts, retrieved memories, errors, ...);
    # a fresh dict per instance via default_factory.
    debug_info: Dict[str, Any] = field(default_factory=dict)
+
+
class BaselineAgent(ABC):
    """Common interface implemented by every baseline agent.

    Concrete agents must provide respond(), end_session() and reset_user();
    get_name() may be overridden for a friendlier label.
    """

    def __init__(self, model_name: str, **kwargs):
        """Store the model identifier and any extra configuration.

        Args:
            model_name: Name/path of the LLM to use
            **kwargs: Additional configuration, kept verbatim in ``self.config``
        """
        self.model_name = model_name
        self.config = kwargs

    @abstractmethod
    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """Produce an answer to *query* for *user_id*.

        Args:
            user_id: Unique identifier for the user
            query: Current user message
            conversation_history: Prior messages, each shaped like
                ``{"role": "user"|"assistant", "content": "..."}``
            **kwargs: Additional context (e.g., task info)

        Returns:
            AgentResponse carrying the answer plus debug details
        """

    @abstractmethod
    def end_session(self, user_id: str, conversation: List[Dict[str, str]]):
        """Hook invoked once a session (one task) finishes.

        Implementations use this to persist memory, notes, etc.

        Args:
            user_id: User identifier
            conversation: The session's complete message list
        """

    @abstractmethod
    def reset_user(self, user_id: str):
        """Drop every piece of state held for *user_id*.

        Invoked at the start of a new experiment.

        Args:
            user_id: User identifier
        """

    def get_name(self) -> str:
        """Return a human-readable label for this agent."""
        return type(self).__name__
+
+
diff --git a/src/personalization/evaluation/baselines/no_memory.py b/src/personalization/evaluation/baselines/no_memory.py
new file mode 100644
index 0000000..bf4a7cf
--- /dev/null
+++ b/src/personalization/evaluation/baselines/no_memory.py
@@ -0,0 +1,143 @@
+"""
+No Memory Baseline (T1)
+
+A simple agent that has no memory of previous sessions.
+Only sees the current conversation history within a session.
+"""
+
+from typing import List, Dict, Any, Optional
+import os
+
+from .base import BaselineAgent, AgentResponse
+
+
# System prompt sent as the first "system" message on every NoMemoryAgent
# request (see _build_messages). Runtime behavior depends on this exact text,
# so edit it deliberately.
AGENT_SYSTEM_PROMPT = """You are a helpful AI assistant helping users solve problems.

Guidelines:
- If the user's request is unclear, ask for clarification
- Provide clear, well-structured answers
- Adapt to user feedback and preferences expressed in the conversation
- Be helpful and do your best to solve the user's problem

Your output should be a direct response to the user."""
+
+
class NoMemoryAgent(BaselineAgent):
    """
    T1: Base model with no memory.

    This agent:
    - Has no memory across sessions
    - Only uses current conversation context
    - Represents the baseline "no personalization" case
    """

    def __init__(
        self,
        model_name: str = "llama-8b",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ):
        """
        Args:
            model_name: Model identifier sent to the chat-completions API.
            api_base: OpenAI-compatible endpoint; falls back to the
                OPENAI_API_BASE env var, then "http://localhost:8003/v1".
            api_key: API key; falls back to OPENAI_API_KEY, then "EMPTY"
                (local vLLM-style servers accept any key).
            max_new_tokens: Per-response completion token budget.
            temperature: Sampling temperature.
            **kwargs: Extra configuration stored by BaselineAgent.
        """
        super().__init__(model_name, **kwargs)

        self.api_base = api_base or os.getenv("OPENAI_API_BASE", "http://localhost:8003/v1")
        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "EMPTY")
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

        # Initialize client (self.client stays None on failure).
        self._init_client()

    def _init_client(self):
        """Initialize the LLM client; respond() degrades to a canned
        fallback answer when this fails and self.client stays None."""
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            print(f"Warning: Could not initialize OpenAI client: {e}")
            self.client = None

    def _build_messages(
        self,
        conversation_history: List[Dict[str, str]],
        query: str,
    ) -> List[Dict[str, str]]:
        """Build messages for the LLM: system prompt, the full history, then
        the current query (unless history already ends with that user turn)."""
        messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]

        # Add conversation history
        for msg in conversation_history:
            messages.append({
                "role": msg["role"],
                "content": msg["content"],
            })

        # Append the current query unless the history already ends with this
        # exact *user* message. Fix: the previous check ignored the role, so
        # an assistant message that merely echoed the query text would
        # silently suppress the user's actual turn.
        last = conversation_history[-1] if conversation_history else None
        if last is None or last.get("role") != "user" or last.get("content") != query:
            messages.append({"role": "user", "content": query})

        return messages

    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """Generate response using only current conversation context.

        Args:
            user_id: Unused; present to satisfy the BaselineAgent interface.
            query: Current user message.
            conversation_history: Prior messages from this session.
            **kwargs: Ignored.

        Returns:
            AgentResponse with the model answer (or fallback/error text) and
            token-usage debug info.
        """
        messages = self._build_messages(conversation_history, query)

        if self.client is None:
            # Fallback for testing without LLM
            return AgentResponse(
                answer=f"[NoMemoryAgent] Response to: {query[:50]}...",
                debug_info={"mode": "fallback", "num_messages": len(messages)},
            )

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_new_tokens,
                temperature=self.temperature,
            )

            # Guard against a null content field from the server.
            answer = response.choices[0].message.content or ""

            return AgentResponse(
                answer=answer,
                debug_info={
                    "num_messages": len(messages),
                    "prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
                    "completion_tokens": response.usage.completion_tokens if response.usage else 0,
                },
            )

        except Exception as e:
            # Best-effort recovery: keep the session going with an apology
            # and surface the failure in debug_info.
            print(f"Error calling LLM: {e}")
            return AgentResponse(
                answer=f"I apologize, but I encountered an error. Let me try again: {query[:100]}",
                debug_info={"error": str(e)},
            )

    def end_session(self, user_id: str, conversation: List[Dict[str, str]]):
        """No-op: this agent keeps no cross-session memory."""
        pass

    def reset_user(self, user_id: str):
        """No-op: there is no per-user state to clear."""
        pass

    def get_name(self) -> str:
        """Return a label like "NoMemory(llama-8b)"."""
        return f"NoMemory({self.model_name})"
+
+
diff --git a/src/personalization/evaluation/baselines/rag_memory.py b/src/personalization/evaluation/baselines/rag_memory.py
new file mode 100644
index 0000000..2b391c3
--- /dev/null
+++ b/src/personalization/evaluation/baselines/rag_memory.py
@@ -0,0 +1,204 @@
+"""
+RAG Memory Baseline (Y3/Y4)
+
+Wraps the PersonalizedLLM for use in the evaluation framework.
+Y3: Extractor + RAG (mode="nopersonal")
+Y4: Extractor + RAG + User Vector (mode="full")
+"""
+
+from typing import List, Dict, Any, Optional
+import os
+import sys
+
+from .base import BaselineAgent, AgentResponse
+
+# Add src to path for imports
+_src_path = os.path.join(os.path.dirname(__file__), "../../../..")
+if _src_path not in sys.path:
+ sys.path.insert(0, _src_path)
+
+
class RAGMemoryAgent(BaselineAgent):
    """
    Y3/Y4: RAG-based memory with optional user vector.

    This agent:
    - Extracts preferences from conversations using the extractor
    - Stores preferences as memory cards
    - Retrieves relevant memories using RAG for each query
    - (Y4 only) Uses user vector to personalize retrieval
    """

    def __init__(
        self,
        model_name: str = "llama-8b",
        mode: str = "nopersonal",  # "nopersonal" for Y3, "full" for Y4
        memory_cards_path: Optional[str] = None,
        memory_embeddings_path: Optional[str] = None,
        enable_preference_extraction: bool = True,
        enable_rl_updates: bool = False,
        only_own_memories: bool = True,
        **kwargs
    ):
        """
        Args:
            model_name: LLM model to use
            mode: "nopersonal" (Y3) or "full" (Y4)
            memory_cards_path: Path to memory cards file (defaults to
                data/eval/memory_cards.jsonl relative to base_dir below)
            memory_embeddings_path: Path to embeddings file (defaults to
                data/eval/memory_embeddings.npy relative to base_dir below)
            enable_preference_extraction: Whether to extract preferences
            enable_rl_updates: Whether to update user vectors (Y4 only)
            only_own_memories: Only retrieve user's own memories
        """
        super().__init__(model_name, **kwargs)

        self.mode = mode
        # RL/user-vector updates are only meaningful in "full" (Y4) mode.
        self.enable_rl_updates = enable_rl_updates and (mode == "full")

        # Default paths
        # NOTE(review): this climbs five directories while the module's
        # sys.path hack climbs four to reach src/ — if the layout is
        # src/personalization/evaluation/baselines, five hops lands one level
        # above the repository root. Confirm data/eval resolves as intended.
        base_dir = os.path.join(os.path.dirname(__file__), "../../../../..")
        self.memory_cards_path = memory_cards_path or os.path.join(
            base_dir, "data/eval/memory_cards.jsonl"
        )
        self.memory_embeddings_path = memory_embeddings_path or os.path.join(
            base_dir, "data/eval/memory_embeddings.npy"
        )

        self.enable_preference_extraction = enable_preference_extraction
        self.only_own_memories = only_own_memories

        # Lazy initialization: PersonalizedLLM construction is deferred to
        # the first call that needs it.
        self._llm = None
        self._initialized = False

    def _ensure_initialized(self):
        """Lazy initialization of PersonalizedLLM.

        On failure the agent remains usable: self._llm stays None and
        respond() returns canned fallback answers.
        """
        if self._initialized:
            return

        try:
            from personalization.serving.personalized_llm import PersonalizedLLM

            self._llm = PersonalizedLLM(
                mode=self.mode,
                enable_preference_extraction=self.enable_preference_extraction,
                enable_rl_updates=self.enable_rl_updates,
                only_own_memories=self.only_own_memories,
                memory_cards_path=self.memory_cards_path,
                memory_embeddings_path=self.memory_embeddings_path,
                eval_mode=True,  # Deterministic selection
            )
            self._initialized = True

        except Exception as e:
            print(f"Warning: Could not initialize PersonalizedLLM: {e}")
            print("Falling back to simple response mode.")
            self._llm = None
            self._initialized = True

    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """Generate response using RAG memory.

        Note: conversation_history is not forwarded here; PersonalizedLLM
        appears to track the session internally via chat() — confirm.
        """
        self._ensure_initialized()

        if self._llm is None:
            # Fallback mode
            return AgentResponse(
                answer=f"[RAGMemoryAgent-{self.mode}] Response to: {query[:50]}...",
                debug_info={"mode": "fallback"},
            )

        try:
            # Use PersonalizedLLM's chat interface
            response = self._llm.chat(user_id, query)

            debug_info = {
                "mode": self.mode,
                "num_memories_retrieved": len(response.debug.selected_memory_ids) if response.debug else 0,
                "selected_memories": response.debug.selected_memory_notes if response.debug else [],
                "extracted_preferences": response.debug.extracted_preferences if response.debug else [],
            }

            # Merge any extra debug payload provided by the LLM wrapper.
            if response.debug and response.debug.extra:
                debug_info.update(response.debug.extra)

            return AgentResponse(
                answer=response.answer,
                debug_info=debug_info,
            )

        except Exception as e:
            print(f"Error in RAGMemoryAgent.respond: {e}")
            return AgentResponse(
                answer=f"I apologize for the error. Regarding: {query[:100]}",
                debug_info={"error": str(e)},
            )

    def end_session(self, user_id: str, conversation: List[Dict[str, str]]):
        """
        Called at end of session.
        PersonalizedLLM already extracts preferences during chat(),
        so we just reset the session state. The `conversation` argument is
        intentionally unused here.
        """
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_session(user_id)

    def reset_user(self, user_id: str):
        """Reset all state for a user."""
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_user(user_id)

    def apply_feedback(self, user_id: str, reward: float, gating: float = 1.0):
        """
        Apply feedback for user vector updates (Y4 only).

        Silently ignored unless RL updates are enabled and the LLM is up.

        Args:
            user_id: User identifier
            reward: Reward signal (e.g., from preference satisfaction)
            gating: Gating signal (1.0 = use this feedback, 0.0 = skip)
        """
        if not self.enable_rl_updates or self._llm is None:
            return

        try:
            from personalization.serving.personalized_llm import Feedback

            feedback = Feedback(
                user_id=user_id,
                turn_id=0,  # Not used in current implementation
                reward=reward,
                gating=gating,
            )
            self._llm.apply_feedback(feedback)

        except Exception as e:
            print(f"Error applying feedback: {e}")

    def get_user_state(self, user_id: str) -> Dict[str, Any]:
        """Get user state summary (for Y4 analysis); empty dict on fallback."""
        self._ensure_initialized()

        if self._llm is not None:
            return self._llm.get_user_state_summary(user_id)
        return {}

    def persist(self):
        """Save all state to disk (no-op when the LLM never initialized)."""
        if self._llm is not None:
            self._llm.persist()

    def get_name(self) -> str:
        """Return "RAG(model)" for Y3 or "RAG+UV(model)" for Y4."""
        mode_name = "RAG" if self.mode == "nopersonal" else "RAG+UV"
        return f"{mode_name}({self.model_name})"
+
+
diff --git a/src/personalization/evaluation/demo/__init__.py b/src/personalization/evaluation/demo/__init__.py
new file mode 100644
index 0000000..7d50041
--- /dev/null
+++ b/src/personalization/evaluation/demo/__init__.py
@@ -0,0 +1,3 @@
+# Demo scripts for evaluation
+
+
diff --git a/src/personalization/evaluation/demo/run_demo.py b/src/personalization/evaluation/demo/run_demo.py
new file mode 100644
index 0000000..805d046
--- /dev/null
+++ b/src/personalization/evaluation/demo/run_demo.py
@@ -0,0 +1,273 @@
+#!/usr/bin/env python3
+"""
+Demo Runner Script
+
+A minimal demo to verify the evaluation pipeline works:
+- Generates preference bank (5 topics × 5 prefs = 25 total)
+- Creates 2 user profiles (10 prefs each)
+- Runs 3 tasks per user
+- Compares T1 (NoMemory) vs Y3 (RAG) agents
+
+Usage:
+ # With LLM servers running:
+ python run_demo.py
+
+ # Dry run (no LLM, uses fallback responses):
+ python run_demo.py --dry-run
+
+ # Specify output directory:
+ python run_demo.py --output-dir /path/to/output
+"""
+
+import argparse
+import os
+import sys
+
+# Add src to path
+_src_path = os.path.join(os.path.dirname(__file__), "../../../..")
+if _src_path not in sys.path:
+ sys.path.insert(0, _src_path)
+
+
def run_preference_bank_demo(output_dir: str = "data/eval/demo"):
    """Generate and display a demo preference bank.

    Fix: the output directory was hard-coded, so the step ignored a
    caller-chosen location (main() later loads the bank from
    args.output_dir, which could mismatch). Parameterized with the old
    value as default so existing zero-arg callers keep their behavior.

    Args:
        output_dir: Directory the bank JSON is written to.

    Returns:
        The generated preference bank.
    """
    print("\n" + "=" * 60)
    print("STEP 1: Generate Preference Bank")
    print("=" * 60)

    from personalization.evaluation.preference_bank.generator import generate_demo_bank

    os.makedirs(output_dir, exist_ok=True)

    bank_path = os.path.join(output_dir, "preference_bank.json")
    bank = generate_demo_bank(output_path=bank_path, use_llm=False)

    print(f"\nGenerated preference bank with {bank.stats()['total_preferences']} preferences")
    print(f"Topics: {list(bank.topics.keys())}")

    # Show sample preferences
    print("\nSample preferences:")
    for topic_name, topic in list(bank.topics.items())[:2]:
        print(f"\n {topic_name}:")
        for pref in topic.preferences[:2]:
            print(f" - When {pref.condition}: {pref.action}")

    return bank
+
+
def run_profile_demo(bank, output_dir: str = "data/eval/demo"):
    """Generate demo user profiles.

    Fix: the output directory was hard-coded; parameterized (old value as
    default) so the step can honor a caller-chosen output directory while
    existing one-arg callers behave identically.

    Args:
        bank: Preference bank to sample user preferences from.
        output_dir: Directory to write user_profiles.json into.

    Returns:
        The generated list of user profiles.
    """
    print("\n" + "=" * 60)
    print("STEP 2: Generate User Profiles")
    print("=" * 60)

    from personalization.evaluation.profiles.generator import generate_demo_profiles

    profiles_path = os.path.join(output_dir, "user_profiles.json")

    profiles = generate_demo_profiles(
        bank=bank,
        num_users=2,
        prefs_per_user=10,
        output_path=profiles_path,
        seed=42,
    )

    print(f"\nGenerated {len(profiles)} user profiles")

    for profile in profiles:
        print(f"\n {profile.user_id}:")
        print(f" Persona: {profile.persona}")
        print(f" Primary topics: {profile.primary_topics}")
        print(f" Num preferences: {len(profile.preferences)}")

    return profiles
+
+
def run_agent_demo(dry_run: bool = True):
    """Test agent response generation.

    Args:
        dry_run: When True, point the agent at an unreachable endpoint so
            no live LLM server is contacted.

    Returns:
        The constructed NoMemoryAgent.
    """
    print("\n" + "=" * 60)
    print("STEP 3: Test Agent Responses")
    print("=" * 60)

    from personalization.evaluation.baselines.no_memory import NoMemoryAgent

    # Create agent. Fix: NoMemoryAgent treats api_base=None as "fall back to
    # OPENAI_API_BASE or the default live URL", so the old dry-run path could
    # still reach a real server. Use the same unreachable sentinel endpoint
    # as run_full_demo instead.
    agent = NoMemoryAgent(
        model_name="llama-8b",
        api_base="http://localhost:8003/v1" if not dry_run else "http://localhost:9999/v1",
    )

    # Test response
    test_query = "What is 2 + 2?"
    response = agent.respond(
        user_id="test_user",
        query=test_query,
        conversation_history=[],
    )

    print(f"\nQuery: {test_query}")
    print(f"Response: {response.answer[:200]}...")
    print(f"Debug: {response.debug_info}")

    return agent
+
+
def run_user_simulator_demo(profiles, dry_run: bool = True):
    """Test user simulator.

    Args:
        profiles: User profiles; the first one drives the simulated user.
        dry_run: When True, target an unreachable endpoint so no live LLM
            server is contacted.

    Returns:
        The configured UserSimulator.
    """
    print("\n" + "=" * 60)
    print("STEP 4: Test User Simulator")
    print("=" * 60)

    from personalization.evaluation.user_simulator.simulator import UserSimulator
    from personalization.evaluation.pipeline.evaluator import Task

    # Create simulator. Fix: follow run_full_demo's dry-run convention of an
    # unreachable sentinel endpoint; passing None presumably lets the
    # simulator fall back to env/default URLs and reach a live server.
    simulator = UserSimulator(
        model_name="Llama-3.3-70B-Instruct",
        api_base="http://localhost:8004/v1" if not dry_run else "http://localhost:9999/v1",
    )

    # Setup with first profile
    profile = profiles[0]
    task = Task(
        task_id="test_001",
        dataset="test",
        problem="What is the derivative of x^2?",
        solution="2x",
        task_description="Solve this calculus problem:",
    )

    simulator.setup(
        profile=profile,
        task_description=task.task_description,
        problem=task.problem,
        solution=task.solution,
    )

    # Simulate first turn
    conversation = [
        {"role": "assistant", "content": "How can I help you?"}
    ]

    response = simulator.respond(conversation)

    print(f"\nUser profile: {profile.user_id}")
    print(f"Task: {task.problem}")
    print(f"\nUser response: {response.response[:200]}...")
    print(f"Enforcement needed: {response.enforcement_needed}")
    print(f"Draft answer: {response.draft_answer}")

    return simulator
+
+
def run_full_demo(dry_run: bool = True, output_dir: str = "data/eval/demo"):
    """Run the complete demo experiment end to end."""
    header = "=" * 60
    print("\n" + header)
    print("STEP 5: Run Full Demo Experiment")
    print(header)

    if dry_run:
        print("\n[DRY RUN MODE] Using fallback responses, no LLM calls\n")

    from personalization.evaluation.pipeline.runner import ExperimentRunner, ExperimentConfig

    live = not dry_run
    config = ExperimentConfig(
        name="demo_experiment",
        output_dir=output_dir,
        num_users=2,
        prefs_per_user=10,
        tasks_per_user=2,  # Just 2 tasks for quick demo
        max_turns=10,  # Short conversations
        run_no_memory=True,
        run_rag=False,  # Skip RAG for initial demo (needs more setup)
        run_rag_uv=False,
        # Dry runs target an unreachable port so no live server is hit.
        agent_api_base="http://localhost:8003/v1" if live else "http://localhost:9999/v1",
        user_sim_api_base="http://localhost:8004/v1" if live else "http://localhost:9999/v1",
    )

    runner = ExperimentRunner(config)
    runner.setup()
    return runner.run()
+
+
def main():
    """CLI entry point: parse arguments and run the selected demo step(s)."""
    parser = argparse.ArgumentParser(description="Run evaluation demo")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Run without LLM (uses fallback responses)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/eval/demo",
        help="Output directory for results",
    )
    parser.add_argument(
        "--step",
        type=str,
        choices=["bank", "profiles", "agent", "simulator", "full", "all"],
        default="all",
        help="Which step to run",
    )

    args = parser.parse_args()

    banner = "=" * 60
    print("\n" + banner)
    print("PERSONALIZATION EVALUATION DEMO")
    print(banner)
    print(f"Mode: {'DRY RUN (no LLM)' if args.dry_run else 'LIVE (requires LLM servers)'}")
    print(f"Output: {args.output_dir}")
    print(banner)

    os.makedirs(args.output_dir, exist_ok=True)

    step = args.step
    run_all = step == "all"

    if run_all or step == "bank":
        bank = run_preference_bank_demo()
    else:
        # Reuse a previously generated bank when one exists on disk.
        from personalization.evaluation.preference_bank.schemas import PreferenceBank
        bank_path = os.path.join(args.output_dir, "preference_bank.json")
        bank = (
            PreferenceBank.load(bank_path)
            if os.path.exists(bank_path)
            else run_preference_bank_demo()
        )

    if run_all or step == "profiles":
        profiles = run_profile_demo(bank)
    else:
        # Reuse previously generated profiles when available.
        from personalization.evaluation.profiles.generator import UserProfileGenerator
        profiles_path = os.path.join(args.output_dir, "user_profiles.json")
        profiles = (
            UserProfileGenerator.load_profiles(profiles_path)
            if os.path.exists(profiles_path)
            else run_profile_demo(bank)
        )

    if run_all or step == "agent":
        run_agent_demo(dry_run=args.dry_run)

    if run_all or step == "simulator":
        run_user_simulator_demo(profiles, dry_run=args.dry_run)

    if run_all or step == "full":
        run_full_demo(dry_run=args.dry_run, output_dir=args.output_dir)

    print("\n" + banner)
    print("DEMO COMPLETE!")
    print(banner)
    print(f"\nResults saved to: {args.output_dir}/")
    print("\nNext steps:")
    print(" 1. Start LLM servers (vLLM/SGLang)")
    print(" 2. Run without --dry-run flag")
    print(" 3. Enable RAG baseline for full comparison")


if __name__ == "__main__":
    main()
+
+
diff --git a/src/personalization/evaluation/pipeline/__init__.py b/src/personalization/evaluation/pipeline/__init__.py
new file mode 100644
index 0000000..183d0c5
--- /dev/null
+++ b/src/personalization/evaluation/pipeline/__init__.py
@@ -0,0 +1,6 @@
+from .evaluator import Evaluator, SessionResult, EvaluationMetrics
+from .runner import ExperimentRunner
+
+__all__ = ["Evaluator", "SessionResult", "EvaluationMetrics", "ExperimentRunner"]
+
+
diff --git a/src/personalization/evaluation/pipeline/evaluator.py b/src/personalization/evaluation/pipeline/evaluator.py
new file mode 100644
index 0000000..7304400
--- /dev/null
+++ b/src/personalization/evaluation/pipeline/evaluator.py
@@ -0,0 +1,353 @@
+"""
+Evaluation Pipeline
+
+Runs evaluation sessions between user simulator and agents.
+Computes metrics: Task Success (TS), User Effort (UE), Efficiency (Eff).
+"""
+
+import json
+import os
+from dataclasses import dataclass, field, asdict
+from typing import List, Dict, Any, Optional
+from datetime import datetime
+
+from ..profiles.generator import UserProfile
+from ..preference_bank.schemas import PreferenceBank
+from ..baselines.base import BaselineAgent
+from ..user_simulator.simulator import UserSimulator, UserSimulatorResponse
+
+
@dataclass
class Task:
    """A problem/task for evaluation."""
    # Unique identifier, e.g. "math_001".
    task_id: str
    # Dataset name; used for per-dataset metric breakdowns.
    dataset: str
    # Problem statement given to the user simulator.
    problem: str
    # Reference answer the judge compares against.
    solution: str
    # Framing instruction shown to the user simulator.
    task_description: str = "Work with the assistant to solve this problem:"
+
+
@dataclass
class SessionResult:
    """Result of a single evaluation session (one user x one task x one agent)."""
    user_id: str
    task_id: str
    dataset: str
    agent_name: str

    # Metrics
    task_success: bool  # TS: Was the task solved correctly (per the judge)?
    user_effort: int  # UE: Number of preference enforcements by the user
    efficiency: int  # Eff: Total number of messages in the conversation

    # Details
    conversation: List[Dict[str, str]]  # full transcript of role/content dicts
    preference_violations: List[Dict[str, Any]]  # per-turn violation records
    final_draft_answer: str  # user's last draft answer, judged for task_success

    # Debug
    debug_info: Dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict (used by Evaluator.save_results)."""
        return asdict(self)
+
+
@dataclass
class EvaluationMetrics:
    """Aggregated evaluation metrics for one agent across many sessions."""
    agent_name: str
    num_sessions: int

    # Average metrics
    avg_task_success: float  # Average TS (success rate in [0, 1])
    avg_user_effort: float  # Average UE per session
    avg_efficiency: float  # Average Eff (messages per session)

    # Breakdowns keyed by dataset name
    task_success_by_dataset: Dict[str, float] = field(default_factory=dict)
    user_effort_by_dataset: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict of all metrics."""
        return asdict(self)
+
+
class JudgeModel:
    """
    LLM judge for evaluating task success.
    Uses the same approach as collaborativeagents.
    """

    def __init__(
        self,
        model_name: str = "Llama-3.3-70B-Instruct",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """Configure the judge endpoint; env vars supply the fallbacks."""
        self.model_name = model_name
        self.api_base = api_base or os.getenv("JUDGE_API_BASE", "http://localhost:8004/v1")
        self.api_key = api_key or os.getenv("JUDGE_API_KEY", "EMPTY")

        self._init_client()

    def _init_client(self):
        """Create the OpenAI-compatible client; on failure leave it as None
        so evaluate_answer() falls back to substring matching."""
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            print(f"Warning: Could not initialize judge client: {e}")
            self.client = None

    def evaluate_answer(
        self,
        problem: str,
        correct_answer: str,
        user_draft_answer: str,
    ) -> bool:
        """
        Evaluate if the user's draft answer is correct.

        Returns:
            True if answer is correct, False otherwise
        """
        prompt = f"""You are an expert evaluator. Determine if the user's answer is correct.

# Problem
{problem}

# Correct Answer
{correct_answer}

# User's Answer
{user_draft_answer}

# Instructions
Determine if the user's answer is accurate and consistent with the correct answer.
Minor formatting differences are acceptable.
The core answer/solution must match.

Output JSON:
{{
    "reasoning": "Brief explanation",
    "is_correct": true or false
}}

Output only valid JSON."""

        if self.client is None:
            # Offline fallback: naive case-insensitive containment check.
            return correct_answer.lower().strip() in user_draft_answer.lower()

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=256,
            )

            raw = completion.choices[0].message.content.strip()

            # Strip an optional markdown code fence around the JSON payload.
            if "```" in raw:
                raw = raw.split("```")[1]
                if raw.startswith("json"):
                    raw = raw[4:]

            verdict = json.loads(raw)
            return verdict.get("is_correct", False)

        except Exception as e:
            print(f"Error in judge evaluation: {e}")
            # Fallback
            return correct_answer.lower().strip() in user_draft_answer.lower()
+
+
class Evaluator:
    """
    Main evaluator that runs sessions and computes metrics.
    """

    def __init__(
        self,
        user_simulator: Optional[UserSimulator] = None,
        judge: Optional[JudgeModel] = None,
    ):
        # Default-construct collaborators when none are injected.
        self.user_sim = user_simulator or UserSimulator()
        self.judge = judge or JudgeModel()

    def run_session(
        self,
        agent: BaselineAgent,
        user_profile: UserProfile,
        task: Task,
        max_turns: int = 30,
    ) -> SessionResult:
        """
        Run a single evaluation session.

        The agent opens with a greeting; then each loop iteration is one
        user turn followed (unless terminated) by one agent turn.

        Args:
            agent: The agent being evaluated
            user_profile: User with preferences
            task: Task to solve
            max_turns: Maximum conversation turns

        Returns:
            SessionResult with metrics and conversation
        """
        # Setup user simulator
        self.user_sim.setup(
            profile=user_profile,
            task_description=task.task_description,
            problem=task.problem,
            solution=task.solution,
        )

        conversation: List[Dict[str, str]] = []
        preference_violations: List[Dict[str, Any]] = []
        user_effort = 0
        final_draft_answer = "I don't know"

        # Agent opens the conversation
        conversation.append({
            "role": "assistant",
            "content": "How can I help you today?"
        })

        for turn in range(max_turns):
            # User responds
            user_response = self.user_sim.respond(conversation)

            conversation.append({
                "role": "user",
                "content": user_response.response,
            })

            # Track preference violations and enforcement.
            # `satisfied == False` deliberately excludes None/unknown checks.
            violations_this_turn = [
                {
                    "turn": turn,
                    "preference_id": check.preference_id,
                    "topic": check.topic,
                    "violation_detail": check.violation_detail,
                }
                for check in user_response.preference_checks
                if check.relevant and check.satisfied == False
            ]

            if violations_this_turn:
                preference_violations.extend(violations_this_turn)

            # UE metric: count turns where the user had to enforce a preference.
            if user_response.enforcement_needed:
                user_effort += 1

            # Keep the most recent draft; it becomes the judged final answer.
            final_draft_answer = user_response.draft_answer

            # Check termination before granting the agent another turn.
            if user_response.should_terminate or "TERMINATE" in user_response.response:
                break

            # Agent responds
            agent_response = agent.respond(
                user_id=user_profile.user_id,
                query=user_response.response,
                conversation_history=conversation,
            )

            conversation.append({
                "role": "assistant",
                "content": agent_response.answer,
            })

        # End session for agent (update memory, etc.)
        agent.end_session(user_profile.user_id, conversation)

        # Evaluate task success
        task_success = self.judge.evaluate_answer(
            problem=task.problem,
            correct_answer=task.solution,
            user_draft_answer=final_draft_answer,
        )

        return SessionResult(
            user_id=user_profile.user_id,
            task_id=task.task_id,
            dataset=task.dataset,
            agent_name=agent.get_name(),
            task_success=task_success,
            user_effort=user_effort,
            # Eff metric: total message count, including the seeded greeting.
            efficiency=len(conversation),
            conversation=conversation,
            preference_violations=preference_violations,
            final_draft_answer=final_draft_answer,
            debug_info={
                "num_turns": len(conversation) // 2,
                "num_violations": len(preference_violations),
            },
        )

    def aggregate_metrics(
        self,
        results: List[SessionResult],
        agent_name: str,
    ) -> EvaluationMetrics:
        """
        Aggregate metrics from multiple sessions.

        Returns zeroed metrics when `results` is empty.
        """
        if not results:
            return EvaluationMetrics(
                agent_name=agent_name,
                num_sessions=0,
                avg_task_success=0.0,
                avg_user_effort=0.0,
                avg_efficiency=0.0,
            )

        # Overall averages (bools sum as 0/1, so avg_ts is a success rate)
        avg_ts = sum(r.task_success for r in results) / len(results)
        avg_ue = sum(r.user_effort for r in results) / len(results)
        avg_eff = sum(r.efficiency for r in results) / len(results)

        # By dataset
        datasets = set(r.dataset for r in results)
        ts_by_ds = {}
        ue_by_ds = {}

        for ds in datasets:
            ds_results = [r for r in results if r.dataset == ds]
            if ds_results:
                ts_by_ds[ds] = sum(r.task_success for r in ds_results) / len(ds_results)
                ue_by_ds[ds] = sum(r.user_effort for r in ds_results) / len(ds_results)

        return EvaluationMetrics(
            agent_name=agent_name,
            num_sessions=len(results),
            avg_task_success=avg_ts,
            avg_user_effort=avg_ue,
            avg_efficiency=avg_eff,
            task_success_by_dataset=ts_by_ds,
            user_effort_by_dataset=ue_by_ds,
        )

    def save_results(self, results: List[SessionResult], path: str):
        """Save results to JSONL file (one session per line)."""
        with open(path, "w", encoding="utf-8") as f:
            for result in results:
                f.write(json.dumps(result.to_dict(), ensure_ascii=False) + "\n")

    @staticmethod
    def load_results(path: str) -> List[SessionResult]:
        """Load results from a JSONL file written by save_results()."""
        results = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    # Reconstruct SessionResult
                    results.append(SessionResult(**data))
        return results
+
+
diff --git a/src/personalization/evaluation/pipeline/runner.py b/src/personalization/evaluation/pipeline/runner.py
new file mode 100644
index 0000000..9971c7b
--- /dev/null
+++ b/src/personalization/evaluation/pipeline/runner.py
@@ -0,0 +1,333 @@
+"""
+Experiment Runner
+
+Orchestrates the full evaluation experiment:
+1. Generate/load preference bank and user profiles
+2. Load datasets
+3. Run sessions for all users × tasks × agents
+4. Aggregate and report metrics
+"""
+
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Dict, Any, Optional
+from datetime import datetime
+from tqdm import tqdm
+
+from ..preference_bank.schemas import PreferenceBank
+from ..preference_bank.generator import generate_demo_bank
+from ..profiles.generator import UserProfile, UserProfileGenerator, generate_demo_profiles
+from ..baselines.base import BaselineAgent
+from ..baselines.no_memory import NoMemoryAgent
+from ..baselines.rag_memory import RAGMemoryAgent
+from ..user_simulator.simulator import UserSimulator
+from .evaluator import Evaluator, Task, SessionResult, EvaluationMetrics
+
+
+# Demo tasks used when no real dataset is wired in: three math problems
+# (dataset "math-demo") and two coding problems (dataset "code-demo").
+# `solution` is the reference answer the evaluator compares against.
+DEMO_TASKS = [
+    Task(
+        task_id="math_001",
+        dataset="math-demo",
+        problem="What is the derivative of f(x) = x^3 + 2x^2 - 5x + 3?",
+        solution="f'(x) = 3x^2 + 4x - 5",
+        task_description="Work with the assistant to solve this calculus problem:",
+    ),
+    Task(
+        task_id="math_002",
+        dataset="math-demo",
+        problem="Solve for x: 2x + 5 = 3x - 7",
+        solution="x = 12",
+        task_description="Work with the assistant to solve this algebra problem:",
+    ),
+    Task(
+        task_id="math_003",
+        dataset="math-demo",
+        problem="Find the area of a circle with radius 5.",
+        solution="A = 25π ≈ 78.54 square units",
+        task_description="Work with the assistant to solve this geometry problem:",
+    ),
+    Task(
+        task_id="code_001",
+        dataset="code-demo",
+        problem="Write a Python function that checks if a string is a palindrome.",
+        solution="def is_palindrome(s): return s == s[::-1]",
+        task_description="Work with the assistant to write this Python function:",
+    ),
+    Task(
+        task_id="code_002",
+        dataset="code-demo",
+        problem="Write a function to find the nth Fibonacci number.",
+        solution="def fib(n): return n if n <= 1 else fib(n-1) + fib(n-2)",
+        task_description="Work with the assistant to implement this algorithm:",
+    ),
+]
+
+
+@dataclass
+class ExperimentConfig:
+    """Configuration for an experiment run."""
+    name: str        # Experiment name, recorded in the summary report
+    output_dir: str  # Directory for generated artifacts and result files
+
+    # Scale
+    num_users: int = 2        # Number of simulated user profiles
+    prefs_per_user: int = 10  # Preferences sampled into each profile
+    tasks_per_user: int = 3   # Sessions run per user per agent
+    max_turns: int = 25       # Hard cap on conversation turns per session
+
+    # Baselines to run
+    run_no_memory: bool = True  # T1: agent without persistent memory
+    run_rag: bool = True        # Y3: RAG memory, mode="nopersonal"
+    run_rag_uv: bool = False    # Y4: RAG in "full" (user vector) mode with RL updates
+
+    # Model configs
+    agent_model: str = "llama-8b"
+    user_sim_model: str = "Llama-3.3-70B-Instruct"
+    # NOTE(review): judge_model is not referenced anywhere in this module —
+    # confirm it is consumed by the evaluator, otherwise it is dead config.
+    judge_model: str = "Llama-3.3-70B-Instruct"
+
+    # API endpoints
+    agent_api_base: str = "http://localhost:8003/v1"
+    user_sim_api_base: str = "http://localhost:8004/v1"
+
+    seed: int = 42  # RNG seed for reproducible profile sampling
+
+
+class ExperimentRunner:
+    """
+    Runs a complete evaluation experiment.
+
+    Pipeline: setup() prepares the preference bank, user profiles, demo
+    tasks and the evaluator; run() executes every user x task session for
+    each configured baseline agent, saves raw per-session results, and
+    aggregates + reports metrics.
+    """
+
+    def __init__(self, config: ExperimentConfig):
+        self.config = config
+
+        # Create output directory
+        os.makedirs(config.output_dir, exist_ok=True)
+
+        # Will be initialized lazily (populated by setup())
+        self._bank: Optional[PreferenceBank] = None
+        self._profiles: Optional[List[UserProfile]] = None
+        self._tasks: Optional[List[Task]] = None
+        self._evaluator: Optional[Evaluator] = None
+
+    def setup(self):
+        """Initialize all components.
+
+        Reuses on-disk artifacts when present: the preference bank and the
+        user profiles are loaded from output_dir if they already exist,
+        otherwise generated and saved there.
+        """
+        print("=" * 60)
+        print(f"Setting up experiment: {self.config.name}")
+        print("=" * 60)
+
+        # 1. Generate/load preference bank
+        bank_path = os.path.join(self.config.output_dir, "preference_bank.json")
+        if os.path.exists(bank_path):
+            print(f"Loading existing preference bank from {bank_path}")
+            self._bank = PreferenceBank.load(bank_path)
+        else:
+            print("Generating new preference bank...")
+            # use_llm=False -> deterministic hardcoded fallback preferences.
+            self._bank = generate_demo_bank(output_path=bank_path, use_llm=False)
+
+        print(f"  Bank stats: {self._bank.stats()}")
+
+        # 2. Generate/load user profiles
+        profiles_path = os.path.join(self.config.output_dir, "user_profiles.json")
+        if os.path.exists(profiles_path):
+            print(f"Loading existing profiles from {profiles_path}")
+            self._profiles = UserProfileGenerator.load_profiles(profiles_path)
+        else:
+            print(f"Generating {self.config.num_users} user profiles...")
+            self._profiles = generate_demo_profiles(
+                bank=self._bank,
+                num_users=self.config.num_users,
+                prefs_per_user=self.config.prefs_per_user,
+                output_path=profiles_path,
+                seed=self.config.seed,
+            )
+
+        # 3. Load tasks
+        # Demo only: slice the hardcoded list. The "* 2" keeps both the math
+        # and code demo tasks available even though each user only runs
+        # tasks_per_user of them.
+        self._tasks = DEMO_TASKS[:self.config.tasks_per_user * 2]  # Use demo tasks
+        print(f"  Loaded {len(self._tasks)} tasks")
+
+        # 4. Initialize evaluator
+        user_sim = UserSimulator(
+            model_name=self.config.user_sim_model,
+            api_base=self.config.user_sim_api_base,
+        )
+        self._evaluator = Evaluator(user_simulator=user_sim)
+
+        print("Setup complete!\n")
+
+    def _create_agents(self) -> Dict[str, BaselineAgent]:
+        """Create agent instances based on config.
+
+        Returns a name -> agent mapping; the names (T1_NoMemory, Y3_RAG,
+        Y4_RAG_UV) are used as result-file suffixes and report labels.
+        """
+        agents = {}
+
+        if self.config.run_no_memory:
+            agents["T1_NoMemory"] = NoMemoryAgent(
+                model_name=self.config.agent_model,
+                api_base=self.config.agent_api_base,
+            )
+
+        if self.config.run_rag:
+            # Create directories for RAG memory
+            memory_dir = os.path.join(self.config.output_dir, "rag_memory")
+            os.makedirs(memory_dir, exist_ok=True)
+
+            agents["Y3_RAG"] = RAGMemoryAgent(
+                model_name=self.config.agent_model,
+                mode="nopersonal",
+                memory_cards_path=os.path.join(memory_dir, "memory_cards.jsonl"),
+                memory_embeddings_path=os.path.join(memory_dir, "embeddings.npy"),
+            )
+
+        if self.config.run_rag_uv:
+            # Separate memory store so the UV variant does not share state
+            # with the plain RAG baseline.
+            memory_dir = os.path.join(self.config.output_dir, "rag_uv_memory")
+            os.makedirs(memory_dir, exist_ok=True)
+
+            agents["Y4_RAG_UV"] = RAGMemoryAgent(
+                model_name=self.config.agent_model,
+                mode="full",
+                memory_cards_path=os.path.join(memory_dir, "memory_cards.jsonl"),
+                memory_embeddings_path=os.path.join(memory_dir, "embeddings.npy"),
+                enable_rl_updates=True,
+            )
+
+        return agents
+
+    def run(self) -> Dict[str, EvaluationMetrics]:
+        """
+        Run the full experiment.
+
+        Calls setup() lazily if it has not been run yet. For every agent,
+        iterates users x tasks, collects SessionResults, persists them to
+        JSONL, then aggregates and prints a summary.
+
+        Returns:
+            Dict mapping agent name to aggregated metrics
+        """
+        if self._evaluator is None:
+            self.setup()
+
+        agents = self._create_agents()
+        all_results: Dict[str, List[SessionResult]] = {name: [] for name in agents}
+
+        print("=" * 60)
+        print("Running experiment")
+        print("=" * 60)
+
+        # Run for each agent
+        for agent_name, agent in agents.items():
+            print(f"\n>>> Agent: {agent_name}")
+
+            # Run for each user
+            for profile in tqdm(self._profiles, desc=f"Users ({agent_name})"):
+                # Reset user state so sessions from the previous user do not
+                # leak into this user's memory.
+                agent.reset_user(profile.user_id)
+
+                # Get tasks for this user
+                # In demo, just cycle through available tasks
+                user_tasks = self._tasks[:self.config.tasks_per_user]
+
+                # Run sessions
+                for task in user_tasks:
+                    result = self._evaluator.run_session(
+                        agent=agent,
+                        user_profile=profile,
+                        task=task,
+                        max_turns=self.config.max_turns,
+                    )
+
+                    all_results[agent_name].append(result)
+
+                    # Print progress
+                    # NOTE(review): prints interleave with the tqdm bar and
+                    # will garble its rendering — consider tqdm.write().
+                    status = "✓" if result.task_success else "✗"
+                    print(f"  {profile.user_id} | {task.task_id} | "
+                          f"TS={status} | UE={result.user_effort} | Eff={result.efficiency}")
+
+        # Save raw results
+        for agent_name, results in all_results.items():
+            results_path = os.path.join(
+                self.config.output_dir,
+                f"results_{agent_name}.jsonl"
+            )
+            self._evaluator.save_results(results, results_path)
+
+        # Aggregate metrics
+        metrics = {}
+        for agent_name, results in all_results.items():
+            metrics[agent_name] = self._evaluator.aggregate_metrics(results, agent_name)
+
+        # Save and print summary
+        self._save_summary(metrics)
+        self._print_summary(metrics)
+
+        return metrics
+
+    def _save_summary(self, metrics: Dict[str, EvaluationMetrics]):
+        """Save experiment summary (config snapshot + metrics) as JSON."""
+        summary = {
+            "experiment_name": self.config.name,
+            "timestamp": datetime.now().isoformat(),
+            "config": {
+                "num_users": self.config.num_users,
+                "prefs_per_user": self.config.prefs_per_user,
+                "tasks_per_user": self.config.tasks_per_user,
+                "max_turns": self.config.max_turns,
+            },
+            "metrics": {name: m.to_dict() for name, m in metrics.items()},
+        }
+
+        summary_path = os.path.join(self.config.output_dir, "summary.json")
+        with open(summary_path, "w", encoding="utf-8") as f:
+            json.dump(summary, f, indent=2, ensure_ascii=False)
+
+        print(f"\nSummary saved to {summary_path}")
+
+    def _print_summary(self, metrics: Dict[str, EvaluationMetrics]):
+        """Print a fixed-width comparison table of all agents' metrics."""
+        print("\n" + "=" * 60)
+        print("EXPERIMENT SUMMARY")
+        print("=" * 60)
+
+        # Header: arrows mark the desired direction (TS higher is better,
+        # UE/Eff lower is better).
+        print(f"\n{'Agent':<20} {'TS ↑':>10} {'UE ↓':>10} {'Eff ↓':>10} {'Sessions':>10}")
+        print("-" * 60)
+
+        for agent_name, m in metrics.items():
+            print(f"{agent_name:<20} {m.avg_task_success:>10.2%} "
+                  f"{m.avg_user_effort:>10.2f} {m.avg_efficiency:>10.1f} "
+                  f"{m.num_sessions:>10}")
+
+        print("\n" + "=" * 60)
+
+
+def run_demo_experiment(output_dir: str = "data/eval/demo_experiment"):
+ """
+ Run a minimal demo experiment.
+
+ This is a quick sanity check with:
+ - 2 users
+ - 10 preferences per user
+ - 3 tasks per user
+ - T1 (NoMemory) vs Y3 (RAG) comparison
+ """
+ config = ExperimentConfig(
+ name="demo_experiment",
+ output_dir=output_dir,
+ num_users=2,
+ prefs_per_user=10,
+ tasks_per_user=3,
+ max_turns=15,
+ run_no_memory=True,
+ run_rag=True,
+ run_rag_uv=False,
+ )
+
+ runner = ExperimentRunner(config)
+ runner.setup()
+ metrics = runner.run()
+
+ return metrics
+
+
+if __name__ == "__main__":
+ import sys
+
+ output_dir = sys.argv[1] if len(sys.argv) > 1 else "data/eval/demo_experiment"
+ run_demo_experiment(output_dir)
+
+
diff --git a/src/personalization/evaluation/preference_bank/__init__.py b/src/personalization/evaluation/preference_bank/__init__.py
new file mode 100644
index 0000000..33f0ed2
--- /dev/null
+++ b/src/personalization/evaluation/preference_bank/__init__.py
@@ -0,0 +1,6 @@
+from .schemas import PreferenceItem, PreferenceTopic, PreferenceBank
+from .generator import PreferenceBankGenerator
+
+__all__ = ["PreferenceItem", "PreferenceTopic", "PreferenceBank", "PreferenceBankGenerator"]
+
+
diff --git a/src/personalization/evaluation/preference_bank/generator.py b/src/personalization/evaluation/preference_bank/generator.py
new file mode 100644
index 0000000..e256b86
--- /dev/null
+++ b/src/personalization/evaluation/preference_bank/generator.py
@@ -0,0 +1,530 @@
+"""
+Preference Bank Generator
+
+Uses LLM to automatically generate diverse user preferences for each topic.
+"""
+
+import json
+import os
+from typing import List, Dict, Any, Optional
+from dataclasses import dataclass
+
+from .schemas import PreferenceItem, PreferenceTopic, PreferenceBank
+
+
+# Topic definitions for the demo (5 topics).
+# Schema per topic (consumed by PreferenceBankGenerator.generate_bank):
+#   "description"      -> human-readable topic summary, fed to the LLM prompt
+#   "related_datasets" -> dataset names where the topic applies ("all" = every dataset)
+#   "generation_hints" -> seed ideas inserted into the generation prompt
+DEMO_TOPICS = {
+    "math_formatting": {
+        "description": "How mathematical content should be formatted (LaTeX, plain text, markdown)",
+        "related_datasets": ["math-hard", "math-500", "gpqa"],
+        "generation_hints": [
+            "LaTeX formatting for equations",
+            "Plain text vs mathematical notation",
+            "Inline vs block equations",
+            "Step-by-step calculation display",
+            "Variable naming conventions",
+        ],
+    },
+    "coding_style": {
+        "description": "Preferences for code formatting, language choice, and documentation",
+        "related_datasets": ["humaneval", "bigcodebench"],
+        "generation_hints": [
+            "Programming language preference (Python, JavaScript, etc.)",
+            "Type hints and annotations",
+            "Docstrings and comments",
+            "Code structure and organization",
+            "Naming conventions",
+        ],
+    },
+    "response_structure": {
+        "description": "How responses should be organized (bullets, numbered lists, prose)",
+        "related_datasets": ["all"],
+        "generation_hints": [
+            "Bullet points vs numbered lists vs prose",
+            "Headers and sections",
+            "TL;DR summaries",
+            "Outline before detailed explanation",
+            "Logical flow and transitions",
+        ],
+    },
+    "explanation_depth": {
+        "description": "Level of detail and thoroughness in explanations",
+        "related_datasets": ["all"],
+        "generation_hints": [
+            "Concise vs comprehensive",
+            "Examples and analogies",
+            "Background context",
+            "Assumptions stated explicitly",
+            "Multiple approaches/alternatives",
+        ],
+    },
+    "interaction_style": {
+        "description": "How the agent should interact (questions, confirmations, suggestions)",
+        "related_datasets": ["all"],
+        "generation_hints": [
+            "Asking clarifying questions",
+            "Step-by-step vs holistic answers",
+            "Proactive suggestions",
+            "Confidence levels in answers",
+            "Politeness and tone",
+        ],
+    },
+}
+
+
+# LLM prompt template for generating preferences.
+# Filled via str.format with: num_prefs, topic_name, topic_description, hints.
+# The model is instructed to emit ONLY a JSON array of preference objects;
+# generate_preferences_for_topic parses that array into PreferenceItems.
+GENERATION_PROMPT = '''You are helping design a user preference benchmark. Generate {num_prefs} diverse user preferences for the topic: "{topic_name}"
+
+Topic Description: {topic_description}
+
+Hints for preference types:
+{hints}
+
+For each preference, provide a JSON object with:
+1. "condition": When this preference applies (e.g., "when solving math problems", "when explaining code")
+2. "action": What the user prefers (be specific and enforceable)
+3. "conflict_group": If this preference conflicts with others in the list, give them the same group name (e.g., "notation_style"). Use null if no conflict.
+4. "enforce_description": How a user would detect violation and enforce this preference
+5. "example_violation": A concrete example of an agent response that violates this
+6. "example_compliance": A concrete example that follows this preference
+
+Requirements:
+- Make preferences SPECIFIC and ENFORCEABLE (not vague like "be helpful")
+- Include 2-3 pairs of CONFLICTING preferences (same conflict_group) - this is important for testing RAG
+- Vary specificity: some broad ("always use Python"), some narrow ("use f-strings for string formatting in Python")
+- Preferences should be realistic things users actually care about
+
+Output as a JSON array of objects. Only output the JSON array, no other text.
+'''
+
+
+class PreferenceBankGenerator:
+    """Generates a preference bank using LLM.
+
+    If no usable LLM client is available (self.client is None), every topic
+    falls back to the hardcoded preferences in
+    _generate_fallback_preferences, so generation always succeeds.
+    """
+
+    def __init__(
+        self,
+        llm_client: Any = None,
+        model_name: str = "gpt-4o-mini",  # Default to a capable but fast model
+    ):
+        """
+        Args:
+            llm_client: OpenAI-compatible client. If None, will create one.
+            model_name: Model to use for generation.
+
+        NOTE(review): passing llm_client=None does NOT disable LLM use — an
+        OpenAI client is auto-created whenever the `openai` package imports
+        successfully. Set self.client = None after construction to force
+        fallback-only generation.
+        """
+        self.model_name = model_name
+
+        if llm_client is None:
+            try:
+                import openai
+                self.client = openai.OpenAI()
+            except Exception as e:
+                # Missing package or missing API key: degrade to fallbacks.
+                print(f"Warning: Could not initialize OpenAI client: {e}")
+                self.client = None
+        else:
+            self.client = llm_client
+
+    def generate_preferences_for_topic(
+        self,
+        topic_name: str,
+        topic_description: str,
+        hints: List[str],
+        num_prefs: int = 5,
+    ) -> List[PreferenceItem]:
+        """Generate preferences for a single topic using LLM.
+
+        Falls back to hardcoded preferences when no client is configured or
+        when the LLM call / JSON parsing fails for any reason.
+        """
+
+        if self.client is None:
+            print(f"No LLM client available, using fallback for topic: {topic_name}")
+            return self._generate_fallback_preferences(topic_name, num_prefs)
+
+        hints_text = "\n".join(f"- {h}" for h in hints)
+
+        prompt = GENERATION_PROMPT.format(
+            num_prefs=num_prefs,
+            topic_name=topic_name,
+            topic_description=topic_description,
+            hints=hints_text,
+        )
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0.8,
+                max_tokens=4000,
+            )
+
+            content = response.choices[0].message.content.strip()
+
+            # Parse JSON
+            # Handle potential markdown code blocks
+            # NOTE(review): split("```")[1] keeps only the text between the
+            # first pair of fences and the "json" tag strip assumes it sits
+            # at the very start — brittle if the model adds prose before the
+            # fence, but any failure is caught below and falls back.
+            if content.startswith("```"):
+                content = content.split("```")[1]
+                if content.startswith("json"):
+                    content = content[4:]
+
+            prefs_data = json.loads(content)
+
+            # Convert to PreferenceItem objects
+            preferences = []
+            for i, pref_dict in enumerate(prefs_data):
+                # NOTE(review): 4-char topic prefix can collide across topics
+                # sharing a prefix (e.g. two "math*" topics) — ids would
+                # then be ambiguous; verify topic names stay distinct.
+                pref_id = f"{topic_name[:4]}_{i+1:03d}"
+                pref = PreferenceItem(
+                    id=pref_id,
+                    topic=topic_name,
+                    condition=pref_dict.get("condition", ""),
+                    action=pref_dict.get("action", ""),
+                    conflict_group=pref_dict.get("conflict_group"),
+                    enforce_description=pref_dict.get("enforce_description", ""),
+                    example_violation=pref_dict.get("example_violation", ""),
+                    example_compliance=pref_dict.get("example_compliance", ""),
+                )
+                preferences.append(pref)
+
+            return preferences
+
+        except Exception as e:
+            # Deliberate best-effort: any API/parsing error degrades to the
+            # hardcoded fallback set rather than aborting bank generation.
+            print(f"Error generating preferences for {topic_name}: {e}")
+            return self._generate_fallback_preferences(topic_name, num_prefs)
+
+    def _generate_fallback_preferences(
+        self,
+        topic_name: str,
+        num_prefs: int = 5,
+    ) -> List[PreferenceItem]:
+        """Generate hardcoded fallback preferences when LLM is not available.
+
+        Each known topic ships 5 curated preferences (including conflicting
+        pairs via conflict_group); unknown topics get generic placeholders.
+        Returns at most num_prefs items.
+        """
+
+        fallbacks = {
+            "math_formatting": [
+                PreferenceItem(
+                    id="math_001", topic="math_formatting",
+                    condition="solving math problems",
+                    action="use LaTeX for all formulas and equations",
+                    conflict_group="math_notation",
+                    enforce_description="Check if mathematical expressions use LaTeX syntax like $x^2$ or $$\\int$$",
+                    example_violation="The answer is x squared plus 2x plus 1",
+                    example_compliance="The answer is $x^2 + 2x + 1$",
+                ),
+                PreferenceItem(
+                    id="math_002", topic="math_formatting",
+                    condition="explaining mathematical concepts",
+                    action="use plain text only, avoid any mathematical notation",
+                    conflict_group="math_notation",
+                    enforce_description="Check if response contains any LaTeX or special math symbols",
+                    example_violation="We need to find $\\frac{d}{dx}(x^2)$",
+                    example_compliance="We need to find the derivative of x squared",
+                ),
+                PreferenceItem(
+                    id="math_003", topic="math_formatting",
+                    condition="showing multi-step calculations",
+                    action="display each step on a separate line with clear labels",
+                    conflict_group=None,
+                    enforce_description="Check if steps are on separate lines with labels like 'Step 1:'",
+                    example_violation="First we add 2+3=5, then multiply by 4 to get 20",
+                    example_compliance="Step 1: Add 2 + 3 = 5\nStep 2: Multiply by 4: 5 × 4 = 20",
+                ),
+                PreferenceItem(
+                    id="math_004", topic="math_formatting",
+                    condition="presenting final answers",
+                    action="clearly box or highlight the final answer",
+                    conflict_group=None,
+                    enforce_description="Check if final answer is visually distinguished",
+                    example_violation="So x equals 5.",
+                    example_compliance="**Final Answer: x = 5**",
+                ),
+                PreferenceItem(
+                    id="math_005", topic="math_formatting",
+                    condition="solving problems with multiple variables",
+                    action="use single-letter variables (x, y, z) rather than descriptive names",
+                    conflict_group="var_naming",
+                    enforce_description="Check if variables are single letters",
+                    example_violation="Let price = 100 and quantity = 5",
+                    example_compliance="Let p = 100 and q = 5",
+                ),
+            ],
+            "coding_style": [
+                PreferenceItem(
+                    id="code_001", topic="coding_style",
+                    condition="providing code examples",
+                    action="always use Python",
+                    conflict_group="language",
+                    enforce_description="Check if code is written in Python",
+                    example_violation="```javascript\nfunction add(a, b) { return a + b; }\n```",
+                    example_compliance="```python\ndef add(a, b):\n    return a + b\n```",
+                ),
+                PreferenceItem(
+                    id="code_002", topic="coding_style",
+                    condition="providing code examples",
+                    action="always use JavaScript or TypeScript",
+                    conflict_group="language",
+                    enforce_description="Check if code is written in JavaScript/TypeScript",
+                    example_violation="```python\ndef add(a, b): return a + b\n```",
+                    example_compliance="```javascript\nconst add = (a, b) => a + b;\n```",
+                ),
+                PreferenceItem(
+                    id="code_003", topic="coding_style",
+                    condition="writing Python functions",
+                    action="always include type hints for parameters and return values",
+                    conflict_group=None,
+                    enforce_description="Check if function has type hints",
+                    example_violation="def add(a, b):\n    return a + b",
+                    example_compliance="def add(a: int, b: int) -> int:\n    return a + b",
+                ),
+                PreferenceItem(
+                    id="code_004", topic="coding_style",
+                    condition="writing functions",
+                    action="include a docstring explaining the function",
+                    conflict_group=None,
+                    enforce_description="Check if function has a docstring",
+                    example_violation="def add(a, b):\n    return a + b",
+                    example_compliance='def add(a, b):\n    """Add two numbers and return the result."""\n    return a + b',
+                ),
+                PreferenceItem(
+                    id="code_005", topic="coding_style",
+                    condition="writing code",
+                    action="minimize comments, code should be self-documenting",
+                    conflict_group="comment_style",
+                    enforce_description="Check if there are excessive inline comments",
+                    example_violation="x = x + 1  # increment x by 1",
+                    example_compliance="x += 1",
+                ),
+            ],
+            "response_structure": [
+                PreferenceItem(
+                    id="struct_001", topic="response_structure",
+                    condition="providing multi-point answers",
+                    action="use bullet points with '-' or '*'",
+                    conflict_group="list_style",
+                    enforce_description="Check if response uses bullet points",
+                    example_violation="First, do X. Second, do Y. Third, do Z.",
+                    example_compliance="- First, do X\n- Second, do Y\n- Third, do Z",
+                ),
+                PreferenceItem(
+                    id="struct_002", topic="response_structure",
+                    condition="providing step-by-step instructions",
+                    action="use numbered lists",
+                    conflict_group="list_style",
+                    enforce_description="Check if response uses numbered lists",
+                    example_violation="First do X, then do Y, finally do Z.",
+                    example_compliance="1. Do X\n2. Do Y\n3. Do Z",
+                ),
+                PreferenceItem(
+                    id="struct_003", topic="response_structure",
+                    condition="writing explanations",
+                    action="use flowing prose paragraphs, avoid lists",
+                    conflict_group="list_style",
+                    enforce_description="Check if response uses prose instead of lists",
+                    example_violation="Key points:\n- Point 1\n- Point 2",
+                    example_compliance="The key insight here is that Point 1 connects to Point 2 through...",
+                ),
+                PreferenceItem(
+                    id="struct_004", topic="response_structure",
+                    condition="providing long explanations",
+                    action="include a TL;DR summary at the end",
+                    conflict_group=None,
+                    enforce_description="Check if response ends with TL;DR",
+                    example_violation="... and that's how it works.",
+                    example_compliance="... and that's how it works.\n\n**TL;DR:** X does Y by Z.",
+                ),
+                PreferenceItem(
+                    id="struct_005", topic="response_structure",
+                    condition="explaining complex topics",
+                    action="start with an outline of what will be covered",
+                    conflict_group=None,
+                    enforce_description="Check if response starts with an outline",
+                    example_violation="Let me explain recursion. First, understand that...",
+                    example_compliance="I'll cover: 1) What is recursion, 2) How it works, 3) Examples.\n\n**1) What is recursion**...",
+                ),
+            ],
+            "explanation_depth": [
+                PreferenceItem(
+                    id="depth_001", topic="explanation_depth",
+                    condition="answering questions",
+                    action="be concise, no more than 3 sentences",
+                    conflict_group="length",
+                    enforce_description="Count sentences, should be 3 or fewer",
+                    example_violation="Let me explain in detail. First... Second... Third... Fourth... Fifth...",
+                    example_compliance="The answer is X. This works because of Y. Here's how to apply it: Z.",
+                ),
+                PreferenceItem(
+                    id="depth_002", topic="explanation_depth",
+                    condition="explaining concepts",
+                    action="provide comprehensive, detailed explanations",
+                    conflict_group="length",
+                    enforce_description="Check if explanation is thorough with multiple aspects covered",
+                    example_violation="It's X. Done.",
+                    example_compliance="Let me explain X in detail. The concept originates from... It works by... Common applications include... Here's an example...",
+                ),
+                PreferenceItem(
+                    id="depth_003", topic="explanation_depth",
+                    condition="explaining anything",
+                    action="always include at least one concrete example",
+                    conflict_group=None,
+                    enforce_description="Check if at least one example is provided",
+                    example_violation="A binary tree is a data structure where each node has at most two children.",
+                    example_compliance="A binary tree is a data structure where each node has at most two children. For example, in [5, 3, 7], 5 is the root, 3 is left child, 7 is right child.",
+                ),
+                PreferenceItem(
+                    id="depth_004", topic="explanation_depth",
+                    condition="explaining technical concepts",
+                    action="use analogies from everyday life",
+                    conflict_group=None,
+                    enforce_description="Check if explanation includes an everyday analogy",
+                    example_violation="A stack is a LIFO data structure.",
+                    example_compliance="A stack is like a stack of plates - you can only take the top one (LIFO).",
+                ),
+                PreferenceItem(
+                    id="depth_005", topic="explanation_depth",
+                    condition="solving problems",
+                    action="state assumptions explicitly before solving",
+                    conflict_group=None,
+                    enforce_description="Check if assumptions are stated upfront",
+                    example_violation="The answer is 42.",
+                    example_compliance="Assuming n is positive and integer, the answer is 42.",
+                ),
+            ],
+            "interaction_style": [
+                PreferenceItem(
+                    id="inter_001", topic="interaction_style",
+                    condition="receiving unclear requests",
+                    action="ask clarifying questions before attempting to answer",
+                    conflict_group="clarification",
+                    enforce_description="Check if agent asks questions when request is ambiguous",
+                    example_violation="Here's a solution assuming you meant X...",
+                    example_compliance="Before I help, could you clarify: do you mean X or Y?",
+                ),
+                PreferenceItem(
+                    id="inter_002", topic="interaction_style",
+                    condition="receiving requests",
+                    action="make reasonable assumptions and proceed without asking",
+                    conflict_group="clarification",
+                    enforce_description="Check if agent proceeds with reasonable assumptions",
+                    example_violation="What exactly do you mean by 'large'? What size range?",
+                    example_compliance="Assuming you mean 'large' as over 1000 items, here's the solution...",
+                ),
+                PreferenceItem(
+                    id="inter_003", topic="interaction_style",
+                    condition="solving multi-step problems",
+                    action="present one step at a time and ask for confirmation before proceeding",
+                    conflict_group="pacing",
+                    enforce_description="Check if agent pauses after each step",
+                    example_violation="Step 1: X. Step 2: Y. Step 3: Z. Done!",
+                    example_compliance="Step 1: X. Does this make sense? Should I continue to Step 2?",
+                ),
+                PreferenceItem(
+                    id="inter_004", topic="interaction_style",
+                    condition="solving problems",
+                    action="provide the complete solution at once without pausing",
+                    conflict_group="pacing",
+                    enforce_description="Check if agent gives complete solution without asking to continue",
+                    example_violation="First, let me do step 1... Should I continue?",
+                    example_compliance="Here's the complete solution: Step 1: X, Step 2: Y, Step 3: Z.",
+                ),
+                PreferenceItem(
+                    id="inter_005", topic="interaction_style",
+                    condition="providing answers",
+                    action="include a confidence level (e.g., 'I'm 90% confident')",
+                    conflict_group=None,
+                    enforce_description="Check if response includes confidence level",
+                    example_violation="The answer is 42.",
+                    example_compliance="I'm about 95% confident the answer is 42.",
+                ),
+            ],
+        }
+
+        if topic_name in fallbacks:
+            return fallbacks[topic_name][:num_prefs]
+        else:
+            # Generic fallback for topics without a curated set — items are
+            # placeholders, useful only to keep the pipeline running.
+            return [
+                PreferenceItem(
+                    id=f"{topic_name[:4]}_{i+1:03d}",
+                    topic=topic_name,
+                    condition=f"interacting about {topic_name}",
+                    action=f"preference {i+1} for {topic_name}",
+                    conflict_group=None,
+                    enforce_description=f"Check preference {i+1}",
+                    example_violation=f"Violation example {i+1}",
+                    example_compliance=f"Compliance example {i+1}",
+                )
+                for i in range(num_prefs)
+            ]
+
+    def generate_bank(
+        self,
+        topics: Optional[Dict[str, Dict]] = None,
+        prefs_per_topic: int = 5,
+    ) -> PreferenceBank:
+        """Generate a complete preference bank.
+
+        Args:
+            topics: Mapping of topic name -> config dict with keys
+                "description", "related_datasets" and optional
+                "generation_hints". Defaults to DEMO_TOPICS.
+            prefs_per_topic: Number of preferences requested per topic.
+        """
+
+        if topics is None:
+            topics = DEMO_TOPICS
+
+        bank = PreferenceBank()
+
+        for topic_name, topic_config in topics.items():
+            print(f"Generating preferences for topic: {topic_name}...")
+
+            preferences = self.generate_preferences_for_topic(
+                topic_name=topic_name,
+                topic_description=topic_config["description"],
+                hints=topic_config.get("generation_hints", []),
+                num_prefs=prefs_per_topic,
+            )
+
+            topic = PreferenceTopic(
+                name=topic_name,
+                description=topic_config["description"],
+                related_datasets=topic_config["related_datasets"],
+                preferences=preferences,
+            )
+
+            bank.add_topic(topic)
+            print(f"  Generated {len(preferences)} preferences")
+
+        return bank
+
+
+def generate_demo_bank(
+ output_path: str = None,
+ use_llm: bool = False,
+ prefs_per_topic: int = 5,
+) -> PreferenceBank:
+ """
+ Generate a demo preference bank.
+
+ Args:
+ output_path: If provided, save bank to this path
+ use_llm: If True, use LLM to generate. If False, use hardcoded fallbacks.
+ prefs_per_topic: Number of preferences per topic
+
+ Returns:
+ Generated PreferenceBank
+ """
+ if use_llm:
+ generator = PreferenceBankGenerator()
+ else:
+ generator = PreferenceBankGenerator(llm_client=None) # Use fallbacks
+
+ bank = generator.generate_bank(
+ topics=DEMO_TOPICS,
+ prefs_per_topic=prefs_per_topic,
+ )
+
+ if output_path:
+ bank.save(output_path)
+ print(f"Saved bank to {output_path}")
+
+ print(f"\nBank Statistics: {bank.stats()}")
+
+ return bank
+
+
+if __name__ == "__main__":
+ # Generate demo bank with fallback preferences
+ import os
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ output_path = os.path.join(script_dir, "bank_demo.json")
+
+ bank = generate_demo_bank(output_path=output_path, use_llm=False)
+
+
diff --git a/src/personalization/evaluation/preference_bank/schemas.py b/src/personalization/evaluation/preference_bank/schemas.py
new file mode 100644
index 0000000..f219487
--- /dev/null
+++ b/src/personalization/evaluation/preference_bank/schemas.py
@@ -0,0 +1,147 @@
+"""
+Preference Bank Schemas
+
+Defines the data structures for user preferences, organized by topic.
+Each preference has a condition (when it applies), action (what the user wants),
+and optional conflict group (preferences in the same group are mutually exclusive).
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional, List, Dict, Any
+import json
+
+
+@dataclass
+class PreferenceItem:
+ """A single user preference."""
+ id: str # Unique ID, e.g., "math_fmt_001"
+ topic: str # Topic name, e.g., "math_formatting"
+ condition: str # When this preference applies
+ action: str # What the user prefers
+ conflict_group: Optional[str] # If set, only one pref from this group can be selected
+ enforce_description: str # Description for user simulator on how to enforce
+ example_violation: str # Example of agent response that violates this
+ example_compliance: str # Example that follows this preference
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "id": self.id,
+ "topic": self.topic,
+ "condition": self.condition,
+ "action": self.action,
+ "conflict_group": self.conflict_group,
+ "enforce_description": self.enforce_description,
+ "example_violation": self.example_violation,
+ "example_compliance": self.example_compliance,
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "PreferenceItem":
+ return cls(**data)
+
+ def format_for_user(self) -> str:
+ """Format for user simulator prompt."""
+ return f"When {self.condition}: {self.action}"
+
+ def format_for_enforcement(self) -> str:
+ """Format with enforcement details."""
+ return f"[{self.id}] When {self.condition}: {self.action}\n Enforce if: {self.enforce_description}"
+
+
+@dataclass
+class PreferenceTopic:
+ """A topic containing multiple related preferences."""
+ name: str # Topic name, e.g., "math_formatting"
+ description: str # Description of this topic
+ related_datasets: List[str] # Datasets where this topic is relevant
+ preferences: List[PreferenceItem] = field(default_factory=list)
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "name": self.name,
+ "description": self.description,
+ "related_datasets": self.related_datasets,
+ "preferences": [p.to_dict() for p in self.preferences],
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "PreferenceTopic":
+ prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])]
+ return cls(
+ name=data["name"],
+ description=data["description"],
+ related_datasets=data["related_datasets"],
+ preferences=prefs,
+ )
+
+
+@dataclass
+class PreferenceBank:
+ """
+ A bank of preferences organized by topic.
+ Used to generate user profiles by sampling preferences.
+ """
+ topics: Dict[str, PreferenceTopic] = field(default_factory=dict)
+ version: str = "1.0"
+
+ def add_topic(self, topic: PreferenceTopic):
+ self.topics[topic.name] = topic
+
+ def get_all_preferences(self) -> List[PreferenceItem]:
+ """Get all preferences across all topics."""
+ all_prefs = []
+ for topic in self.topics.values():
+ all_prefs.extend(topic.preferences)
+ return all_prefs
+
+ def get_preferences_for_dataset(self, dataset: str) -> List[PreferenceItem]:
+ """Get preferences relevant to a specific dataset."""
+ relevant = []
+ for topic in self.topics.values():
+ if dataset in topic.related_datasets or "all" in topic.related_datasets:
+ relevant.extend(topic.preferences)
+ return relevant
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "version": self.version,
+ "topics": {name: topic.to_dict() for name, topic in self.topics.items()},
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "PreferenceBank":
+ bank = cls(version=data.get("version", "1.0"))
+ for name, topic_data in data.get("topics", {}).items():
+ bank.topics[name] = PreferenceTopic.from_dict(topic_data)
+ return bank
+
+ def save(self, path: str):
+ """Save bank to JSON file."""
+ with open(path, "w", encoding="utf-8") as f:
+ json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
+
+ @classmethod
+ def load(cls, path: str) -> "PreferenceBank":
+ """Load bank from JSON file."""
+ with open(path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ return cls.from_dict(data)
+
+ def stats(self) -> Dict[str, Any]:
+ """Get statistics about the bank."""
+ total_prefs = 0
+ conflict_groups = set()
+ for topic in self.topics.values():
+ total_prefs += len(topic.preferences)
+ for pref in topic.preferences:
+ if pref.conflict_group:
+ conflict_groups.add(pref.conflict_group)
+
+ return {
+ "num_topics": len(self.topics),
+ "total_preferences": total_prefs,
+ "num_conflict_groups": len(conflict_groups),
+ "prefs_per_topic": {name: len(t.preferences) for name, t in self.topics.items()},
+ }
+
+
diff --git a/src/personalization/evaluation/profiles/__init__.py b/src/personalization/evaluation/profiles/__init__.py
new file mode 100644
index 0000000..8532af9
--- /dev/null
+++ b/src/personalization/evaluation/profiles/__init__.py
@@ -0,0 +1,5 @@
+from .generator import UserProfile, UserProfileGenerator
+
+__all__ = ["UserProfile", "UserProfileGenerator"]
+
+
diff --git a/src/personalization/evaluation/profiles/generator.py b/src/personalization/evaluation/profiles/generator.py
new file mode 100644
index 0000000..da847a0
--- /dev/null
+++ b/src/personalization/evaluation/profiles/generator.py
@@ -0,0 +1,351 @@
+"""
+User Profile Generator
+
+Generates user profiles by sampling preferences from the preference bank.
+Ensures no conflicting preferences within same conflict_group, but allows
+cross-topic scenario conflicts (which is desired for testing RAG).
+"""
+
+import json
+import random
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import List, Dict, Set, Optional, Any
+
+from ..preference_bank.schemas import PreferenceItem, PreferenceBank
+
+
@dataclass
class UserProfile:
    """A simulated user with specific preferences."""
    user_id: str
    persona: str  # Background description
    preferences: List[PreferenceItem]  # Selected preferences
    primary_topics: List[str]  # Topics this user cares most about
    preference_by_topic: Dict[str, List[PreferenceItem]] = field(default_factory=dict)

    def __post_init__(self):
        # Lazily build the topic index when the caller did not supply one.
        if not self.preference_by_topic:
            grouped: Dict[str, List[PreferenceItem]] = {}
            for pref in self.preferences:
                grouped.setdefault(pref.topic, []).append(pref)
            self.preference_by_topic = grouped

    def get_preferences_for_topic(self, topic: str) -> List[PreferenceItem]:
        """Preferences filed under *topic* (empty list when absent)."""
        return self.preference_by_topic.get(topic, [])

    def get_preferences_for_dataset(self, dataset: str, bank: PreferenceBank) -> List[PreferenceItem]:
        """Preferences whose topic is linked to *dataset* in *bank*.

        A topic qualifies when its ``related_datasets`` names the dataset or
        contains the wildcard ``"all"``.
        """
        eligible_topics = {
            name
            for name, topic in bank.topics.items()
            if dataset in topic.related_datasets or "all" in topic.related_datasets
        }
        return [pref for pref in self.preferences if pref.topic in eligible_topics]

    def format_preferences_grouped(self) -> str:
        """Render preferences under per-topic headings (for LLM prompts)."""
        rendered: List[str] = []
        for topic_name, topic_prefs in self.preference_by_topic.items():
            heading = topic_name.replace("_", " ").title()
            rendered.append(f"\n## {heading}")
            for p in topic_prefs:
                rendered.append(f"  [{p.id}] When {p.condition}: {p.action}")
                rendered.append(f"    Enforce if: {p.enforce_description}")
        return "\n".join(rendered)

    def format_preferences_flat(self) -> str:
        """Render preferences as a simple 1-based numbered list."""
        return "\n".join(
            f"{i}. When {p.condition}: {p.action}"
            for i, p in enumerate(self.preferences, 1)
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to JSON-compatible types (topic index is rebuilt on load)."""
        return {
            "user_id": self.user_id,
            "persona": self.persona,
            "preferences": [pref.to_dict() for pref in self.preferences],
            "primary_topics": self.primary_topics,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UserProfile":
        """Inverse of ``to_dict``; ``preference_by_topic`` comes from __post_init__."""
        return cls(
            user_id=data["user_id"],
            persona=data["persona"],
            preferences=[PreferenceItem.from_dict(p) for p in data.get("preferences", [])],
            primary_topics=data.get("primary_topics", []),
        )

    def stats(self) -> Dict[str, Any]:
        """Summary counts for logging/inspection."""
        groups = {p.conflict_group for p in self.preferences if p.conflict_group}
        return {
            "user_id": self.user_id,
            "num_preferences": len(self.preferences),
            "num_topics": len(self.preference_by_topic),
            "prefs_per_topic": {t: len(ps) for t, ps in self.preference_by_topic.items()},
            "num_conflict_groups_used": len(groups),
        }
+
+
# Persona templates for different user types.
# Each template is filled by UserProfileGenerator._generate_persona() with one
# entry each from FIELDS ({field}), TRAITS ({trait}) and STYLES ({style}).
PERSONA_TEMPLATES = [
    "A {field} professional who values {trait} and prefers {style} communication.",
    "A graduate student in {field} who appreciates {trait} and likes responses that are {style}.",
    "An experienced {field} practitioner who prioritizes {trait} and expects {style} explanations.",
    "A beginner learning {field} who needs {trait} and responds well to {style} guidance.",
    "A {field} enthusiast who cares about {trait} and prefers {style} interactions.",
]

# Candidate fillers for the {field} slot (also used as fallback when a primary
# topic has no mapping in _generate_persona's topic_to_field table).
FIELDS = [
    "software engineering", "data science", "mathematics", "physics",
    "medical research", "financial analysis", "machine learning",
    "web development", "systems programming", "algorithm design",
]

# Candidate fillers for the {trait} slot.
TRAITS = [
    "clarity", "precision", "efficiency", "thoroughness", "simplicity",
    "formality", "practicality", "theoretical depth", "hands-on examples",
]

# Candidate fillers for the {style} slot.
STYLES = [
    "concise", "detailed", "step-by-step", "example-driven", "formal",
    "conversational", "structured", "visual", "analytical",
]
+
+
class UserProfileGenerator:
    """Generates user profiles by sampling from a preference bank.

    Within one profile no two selected preferences may share a
    ``conflict_group``.  All randomness flows through a private
    ``random.Random`` instance so that seeding one generator is reproducible
    without mutating the process-wide ``random`` module state (the previous
    implementation called ``random.seed`` globally, which leaked into every
    other user of ``random`` in the process).
    """

    def __init__(
        self,
        preference_bank: PreferenceBank,
        target_num_prefs: int = 15,  # For demo, use smaller number
        seed: Optional[int] = None,
    ):
        self.bank = preference_bank
        self.target_num = target_num_prefs
        # Dedicated RNG: Random(None) seeds from OS entropy (matching the old
        # unseeded behaviour); Random(seed) is reproducible per instance.
        self._rng = random.Random(seed)

    def generate_profile(
        self,
        user_id: str,
        primary_topics: Optional[List[str]] = None,
        persona: Optional[str] = None,
    ) -> UserProfile:
        """
        Generate a user profile by sampling preferences.

        Args:
            user_id: Unique identifier for this user
            primary_topics: Topics this user cares most about (get more prefs from these)
            persona: Optional persona description. If None, will be generated.

        Returns:
            UserProfile with sampled preferences
        """
        selected: List[PreferenceItem] = []
        used_conflict_groups: Set[str] = set()

        # If no primary topics specified, randomly select 1-2
        if primary_topics is None:
            all_topics = list(self.bank.topics.keys())
            if all_topics:
                num_primary = self._rng.randint(1, min(2, len(all_topics)))
                primary_topics = self._rng.sample(all_topics, num_primary)
            else:
                # Empty bank: nothing to favour (and nothing to sample below).
                primary_topics = []

        # Compute quotas for each topic
        topic_quotas = self._compute_quotas(primary_topics)

        # Sample from each topic
        for topic_name, quota in topic_quotas.items():
            if topic_name not in self.bank.topics:
                continue

            topic = self.bank.topics[topic_name]

            # Filter out preferences with already-used conflict groups so a
            # profile never contains two mutually conflicting preferences.
            available = [
                p for p in topic.preferences
                if p.conflict_group is None or p.conflict_group not in used_conflict_groups
            ]

            # Sample up to quota
            to_select = min(quota, len(available))
            if to_select > 0:
                sampled = self._rng.sample(available, to_select)

                for pref in sampled:
                    selected.append(pref)
                    if pref.conflict_group:
                        used_conflict_groups.add(pref.conflict_group)

        # Generate persona if not provided
        if persona is None:
            persona = self._generate_persona(primary_topics)

        return UserProfile(
            user_id=user_id,
            persona=persona,
            preferences=selected,
            primary_topics=primary_topics,
        )

    def _compute_quotas(self, primary_topics: List[str]) -> Dict[str, int]:
        """Compute how many preferences to sample from each topic.

        Primary topics receive a random bonus on top of the even base quota;
        the remainder (if the total falls short of ``target_num``) is split
        evenly across the primary topics.
        """
        quotas: Dict[str, int] = {}
        all_topics = list(self.bank.topics.keys())
        if not all_topics:
            # Empty bank: avoid division by zero below; nothing to sample.
            return quotas

        # Base quota for all topics
        base_quota = max(1, self.target_num // len(all_topics))

        for topic_name in all_topics:
            if topic_name in primary_topics:
                # Primary topics get more preferences
                quotas[topic_name] = base_quota + self._rng.randint(1, 3)
            else:
                quotas[topic_name] = max(1, base_quota - self._rng.randint(0, 1))

        # Adjust to match target (guard: primary_topics may be empty)
        total = sum(quotas.values())
        if total < self.target_num and primary_topics:
            # Add more to primary topics
            for topic in primary_topics:
                if topic in quotas:
                    quotas[topic] += (self.target_num - total) // len(primary_topics)

        return quotas

    def _generate_persona(self, primary_topics: List[str]) -> str:
        """Generate a persona description based on primary topics."""
        template = self._rng.choice(PERSONA_TEMPLATES)

        # Map topics to fields
        topic_to_field = {
            "math_formatting": ["mathematics", "physics", "data science"],
            "coding_style": ["software engineering", "web development", "systems programming"],
            "response_structure": ["technical writing", "documentation", "education"],
            "explanation_depth": ["research", "teaching", "consulting"],
            "interaction_style": ["customer support", "mentoring", "collaboration"],
        }

        # Pick a field related to primary topics
        possible_fields = []
        for topic in primary_topics:
            possible_fields.extend(topic_to_field.get(topic, FIELDS[:3]))

        if not possible_fields:
            possible_fields = FIELDS

        field = self._rng.choice(possible_fields)
        trait = self._rng.choice(TRAITS)
        style = self._rng.choice(STYLES)

        return template.format(field=field, trait=trait, style=style)

    def generate_profiles(
        self,
        num_users: int,
        id_prefix: str = "user",
    ) -> List[UserProfile]:
        """Generate multiple user profiles with ids ``{id_prefix}_000`` ..."""
        profiles = []

        for i in range(num_users):
            user_id = f"{id_prefix}_{i:03d}"
            profile = self.generate_profile(user_id)
            profiles.append(profile)

        return profiles

    def save_profiles(self, profiles: List[UserProfile], path: str):
        """Save profiles to a JSON file (UTF-8, pretty-printed)."""
        data = [p.to_dict() for p in profiles]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    @staticmethod
    def load_profiles(path: str) -> List[UserProfile]:
        """Load profiles from a JSON file written by ``save_profiles``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [UserProfile.from_dict(d) for d in data]
+
+
def generate_demo_profiles(
    bank: PreferenceBank,
    num_users: int = 2,
    prefs_per_user: int = 10,
    output_path: str = None,
    seed: int = 42,
) -> List[UserProfile]:
    """
    Build a small set of demo users sampled from *bank*.

    Args:
        bank: Preference bank to sample from
        num_users: Number of users to generate
        prefs_per_user: Target preferences per user
        output_path: If provided, save profiles to this path
        seed: Random seed for reproducibility

    Returns:
        List of UserProfile objects
    """
    gen = UserProfileGenerator(
        preference_bank=bank,
        target_num_prefs=prefs_per_user,
        seed=seed,
    )
    profiles = gen.generate_profiles(num_users, id_prefix="demo_user")

    if output_path:
        gen.save_profiles(profiles, output_path)
        print(f"Saved {len(profiles)} profiles to {output_path}")

    # Echo per-user stats so demo runs are easy to eyeball.
    for p in profiles:
        print(f"\n{p.user_id}: {p.stats()}")
        print(f"  Persona: {p.persona}")

    return profiles
+
+
if __name__ == "__main__":
    # NOTE(review): the relative import below requires package context — run as
    # `python -m personalization.evaluation.profiles.generator`, not as a plain
    # script path, or the `..` import will fail.
    import os
    from ..preference_bank.generator import generate_demo_bank

    # Reuse a previously generated demo bank if one exists next to the
    # preference_bank package; otherwise build a fresh one in memory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    bank_path = os.path.join(script_dir, "..", "preference_bank", "bank_demo.json")

    if os.path.exists(bank_path):
        bank = PreferenceBank.load(bank_path)
    else:
        bank = generate_demo_bank()

    # Generate two 10-preference demo users and persist them beside this module.
    profiles_path = os.path.join(script_dir, "profiles_demo.json")
    profiles = generate_demo_profiles(
        bank=bank,
        num_users=2,
        prefs_per_user=10,
        output_path=profiles_path,
    )
+
+
diff --git a/src/personalization/evaluation/user_simulator/__init__.py b/src/personalization/evaluation/user_simulator/__init__.py
new file mode 100644
index 0000000..f7799d0
--- /dev/null
+++ b/src/personalization/evaluation/user_simulator/__init__.py
@@ -0,0 +1,5 @@
+from .simulator import UserSimulator, UserSimulatorResponse
+
+__all__ = ["UserSimulator", "UserSimulatorResponse"]
+
+
diff --git a/src/personalization/evaluation/user_simulator/simulator.py b/src/personalization/evaluation/user_simulator/simulator.py
new file mode 100644
index 0000000..5f5f701
--- /dev/null
+++ b/src/personalization/evaluation/user_simulator/simulator.py
@@ -0,0 +1,310 @@
+"""
+User Simulator
+
+Simulates a user with specific preferences who:
+1. Presents problems to the agent
+2. Checks if agent responses satisfy their preferences
+3. Enforces preferences when violated
+4. Tracks draft answer and decides when to terminate
+"""
+
+import json
+import os
+from dataclasses import dataclass, field
+from typing import List, Dict, Any, Optional
+
+from ..profiles.generator import UserProfile
+from ..preference_bank.schemas import PreferenceItem
+
+
# User simulator system prompt template.
# Placeholders ({task_description}, {problem}, {persona}, {preferences_grouped})
# are filled by UserSimulator._build_system_prompt(); the doubled braces in the
# "Output Format" section survive str.format() as literal JSON braces.  The
# model's reply is parsed by UserSimulator._parse_response().
USER_SYSTEM_PROMPT = """You are simulating a user who is collaborating with an AI assistant to solve a problem. You have specific preferences about how the assistant should respond.

# Problem to Solve
{task_description}
{problem}
Note: The assistant cannot see this problem description directly. You need to communicate with them.

# Your Persona
{persona}

# Your Preferences (Grouped by Topic)
{preferences_grouped}

# Preference Enforcement Rules
- For each assistant response, check which of YOUR preferences are RELEVANT to the current context
- A preference is relevant if the assistant's response touches on that topic/condition
- If a relevant preference is VIOLATED, you MUST enforce it before proceeding
- Do NOT update your draft answer or proceed until violated preferences are fixed
- Only check preferences that apply to the current response (e.g., coding preferences for code responses)

# Draft Answer Management
- Maintain a working draft answer to the problem
- Start with "I don't know"
- Update it based on helpful information from the assistant
- Do NOT update if you're enforcing preferences

# Conversation Guidelines
- Be somewhat vague initially, let the assistant ask clarifying questions
- Respond naturally like a real user
- Do not copy the problem description directly

# Termination
Terminate when:
- Your draft answer seems correct and complete
- The assistant cannot help further

When ready to terminate, include "TERMINATE" in your response.

# Output Format (JSON)
{{
  "preference_checks": [
    {{
      "preference_id": str,
      "topic": str,
      "relevant": bool,
      "satisfied": bool or null,
      "violation_detail": str
    }}
  ],
  "any_violation": bool,
  "enforcement_needed": bool,
  "reasoning": str,
  "draft_answer": str,
  "should_terminate": bool,
  "response": str
}}

IMPORTANT: Only include preferences that are RELEVANT to the current assistant response in preference_checks.
Output valid JSON only, no other text."""
+
+
@dataclass
class PreferenceCheck:
    """Result of checking one preference against a single assistant turn.

    Built from one entry of the ``preference_checks`` array emitted by the
    user-simulator LLM (see USER_SYSTEM_PROMPT's output schema).
    """
    preference_id: str  # ID of the preference that was checked
    topic: str  # Topic the preference belongs to
    relevant: bool  # Did this preference apply to the current response?
    satisfied: Optional[bool]  # None if not relevant
    violation_detail: str = ""  # Judge's note describing the violation, if any
+
+
@dataclass
class UserSimulatorResponse:
    """Structured response from one user-simulator turn.

    Parsed from the simulator LLM's JSON output by
    ``UserSimulator._parse_response``; on parse failure a degraded instance
    with empty checks and ``reasoning="Parse error"`` is produced instead.
    """
    response: str  # Text response to agent
    preference_checks: List[PreferenceCheck]  # Checked preferences
    any_violation: bool  # Any preference violated?
    enforcement_needed: bool  # Need to enforce?
    draft_answer: str  # Current draft answer
    should_terminate: bool  # Should end conversation?
    reasoning: str  # Internal reasoning
    raw_output: Dict[str, Any] = field(default_factory=dict)  # Raw parsed JSON (or error info)
+
+
class UserSimulator:
    """
    Simulates a user with preferences interacting with an agent.

    Lifecycle: construct once, call ``setup()`` per task to install the
    profile/problem, then call ``respond()`` after each agent turn with the
    running conversation.  Replies come from an OpenAI-compatible chat
    endpoint and are expected to follow the JSON schema described in
    ``USER_SYSTEM_PROMPT``; malformed output degrades to a generic reply
    instead of raising.  If the client cannot be created (or a call fails),
    a deterministic scripted fallback keeps evaluation runs alive.
    """

    def __init__(
        self,
        model_name: str = "Llama-3.3-70B-Instruct",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.8,
        max_tokens: int = 2048,
    ):
        self.model_name = model_name
        # Defaults point at a local OpenAI-compatible server; env vars override.
        self.api_base = api_base or os.getenv("USER_SIM_API_BASE", "http://localhost:8004/v1")
        self.api_key = api_key or os.getenv("USER_SIM_API_KEY", "EMPTY")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Current session state — populated by setup(), read by respond().
        self._profile: Optional[UserProfile] = None
        self._task_description: str = ""
        self._problem: str = ""
        self._solution: str = ""

        self._init_client()

    def _init_client(self):
        """Initialize OpenAI client.

        Import is deferred so the module loads even without the openai
        package; on any failure self.client is None and respond() falls
        back to the scripted responses.
        """
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            print(f"Warning: Could not initialize OpenAI client for user simulator: {e}")
            self.client = None

    def setup(
        self,
        profile: UserProfile,
        task_description: str,
        problem: str,
        solution: str = "",
    ):
        """
        Set up the simulator for a new task.

        Args:
            profile: User profile with preferences
            task_description: Description of the task type
            problem: The specific problem to solve
            solution: Ground truth solution (for evaluation)
        """
        self._profile = profile
        self._task_description = task_description
        self._problem = problem
        self._solution = solution

    def _build_system_prompt(self) -> str:
        """Build the system prompt with user profile and task."""
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        return USER_SYSTEM_PROMPT.format(
            task_description=self._task_description,
            problem=self._problem,
            persona=self._profile.persona,
            preferences_grouped=self._profile.format_preferences_grouped(),
        )

    def _parse_response(self, raw_text: str) -> UserSimulatorResponse:
        """Parse LLM output into structured response.

        Tolerates markdown code fences around the JSON; any parse failure
        yields a degraded-but-usable response rather than an exception.
        """
        try:
            # Try to extract JSON from response
            text = raw_text.strip()

            # Handle markdown code blocks (```json ... ``` or plain ``` ... ```)
            if "```json" in text:
                text = text.split("```json")[1].split("```")[0]
            elif "```" in text:
                text = text.split("```")[1].split("```")[0]

            data = json.loads(text)

            # Parse preference checks (missing fields default to falsy values)
            pref_checks = []
            for check in data.get("preference_checks", []):
                pref_checks.append(PreferenceCheck(
                    preference_id=check.get("preference_id", ""),
                    topic=check.get("topic", ""),
                    relevant=check.get("relevant", False),
                    satisfied=check.get("satisfied"),
                    violation_detail=check.get("violation_detail", ""),
                ))

            return UserSimulatorResponse(
                response=data.get("response", ""),
                preference_checks=pref_checks,
                any_violation=data.get("any_violation", False),
                enforcement_needed=data.get("enforcement_needed", False),
                draft_answer=data.get("draft_answer", "I don't know"),
                should_terminate=data.get("should_terminate", False),
                reasoning=data.get("reasoning", ""),
                raw_output=data,
            )

        except Exception as e:
            print(f"Error parsing user simulator response: {e}")
            print(f"Raw text: {raw_text[:500]}...")

            # Return a basic response; short raw text is passed through so the
            # conversation can still proceed.
            return UserSimulatorResponse(
                response=raw_text if len(raw_text) < 500 else "Could you please continue?",
                preference_checks=[],
                any_violation=False,
                enforcement_needed=False,
                draft_answer="I don't know",
                should_terminate=False,
                reasoning="Parse error",
                raw_output={"error": str(e), "raw": raw_text},
            )

    def respond(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """
        Generate user response based on conversation.

        Args:
            conversation_history: List of {"role": "user/assistant", "content": "..."}

        Returns:
            UserSimulatorResponse with user's reply and preference status
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        system_prompt = self._build_system_prompt()

        # Build messages - reverse roles (user simulator sees itself as user)
        messages = [{"role": "system", "content": system_prompt}]

        for msg in conversation_history:
            # Flip roles: agent's messages become user input to simulator
            if msg["role"] == "assistant":
                messages.append({"role": "user", "content": msg["content"]})
            else:
                messages.append({"role": "assistant", "content": msg["content"]})

        if self.client is None:
            # Fallback for testing
            return self._fallback_response(conversation_history)

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

            raw_text = response.choices[0].message.content
            return self._parse_response(raw_text)

        except Exception as e:
            # Best-effort: never crash the evaluation loop on an API error.
            print(f"Error calling user simulator LLM: {e}")
            return self._fallback_response(conversation_history)

    def _fallback_response(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """Generate a simple fallback response for testing.

        Scripted by turn count: present the problem, ask for more twice,
        then terminate.  No preference checking happens in this mode.
        """
        num_turns = len([m for m in conversation_history if m["role"] == "assistant"])

        if num_turns == 0:
            # First turn - present the problem
            response = f"Hi, I need help with this: {self._problem[:200]}..."
        elif num_turns < 3:
            response = "Thanks, that helps. Can you explain more?"
        else:
            response = "Got it, I think I understand now. TERMINATE"

        return UserSimulatorResponse(
            response=response,
            preference_checks=[],
            any_violation=False,
            enforcement_needed=False,
            draft_answer="Draft answer from fallback",
            should_terminate="TERMINATE" in response,
            reasoning="Fallback mode",
            raw_output={},
        )

    def get_solution(self) -> str:
        """Get the ground truth solution."""
        return self._solution

    def get_profile(self) -> Optional[UserProfile]:
        """Get the current user profile."""
        return self._profile
+
+
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
index 60a8d17..f0468b6 100644
--- a/src/personalization/feedback/handlers.py
+++ b/src/personalization/feedback/handlers.py
@@ -5,6 +5,10 @@ from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.feedback.schemas import TurnSample
from personalization.feedback.reward_model import estimate_reward
from personalization.feedback.gating import estimate_retrieval_gating
+from personalization.feedback.llm_reward import (
+ LLMRewardClient, LLMRewardConfig, RewardResult
+)
+
def eval_step(
q_t: str,
@@ -15,23 +19,18 @@ def eval_step(
query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
"""
- Unified evaluation interface.
+ Keyword-based evaluation (legacy).
Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
"""
-
- # Construct a lightweight TurnSample
- # We might need embeddings for gating. If not provided, gating might return default.
-
- # Ensure memories have embeddings for gating
mem_embs = None
if memories_t and memories_t[0].embedding_e:
try:
mem_embs = np.array([m.embedding_e for m in memories_t])
except:
pass
-
+
sample = TurnSample(
- user_id="", # Not needed for simple eval
+ user_id="",
session_id="",
turn_id=0,
query_t=q_t,
@@ -40,11 +39,49 @@ def eval_step(
memories=memories_t,
query_embedding_t=query_embedding_t,
query_embedding_t1=query_embedding_t1,
- memory_embeddings=mem_embs
+ memory_embeddings=mem_embs,
)
-
+
r_hat = estimate_reward(sample)
g_hat = estimate_retrieval_gating(sample, r_hat)
-
+
return r_hat, g_hat
+
async def eval_step_llm(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    client: LLMRewardClient,
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """
    LLM-as-judge evaluation (async).
    Returns (reward, gating) where gating=0.0 if update should be skipped.

    The gating signal is derived from the judge's confidence and label:
    - If confidence < tau_c or label == topic_shift: gating = 0.0
    - Otherwise: gating = confidence (continuous, in [tau_c, 1.0])

    This replaces the old heuristic gating with the judge's own confidence.
    """
    turn = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
    )

    verdict: RewardResult = await client.judge(turn)

    # Gated-out turns contribute neither reward nor gating mass.
    if not verdict.should_update:
        return 0.0, 0.0
    return verdict.reward, verdict.confidence
diff --git a/src/personalization/feedback/llm_reward.py b/src/personalization/feedback/llm_reward.py
new file mode 100644
index 0000000..6adcf98
--- /dev/null
+++ b/src/personalization/feedback/llm_reward.py
@@ -0,0 +1,253 @@
+"""
+LLM-as-Judge reward model using OpenAI GPT-5-nano (async for parallelism).
+
+Replaces keyword-based heuristic reward with structured LLM judgement.
+Judge receives only (q_t, a_t, q_{t+1}) — no oracle preference cards, no history.
+
+Label taxonomy → scalar reward mapping:
+ neg_constraint_restate → -1.0
+ neg_correction → -0.8
+ neg_confusion → -0.6
+ pos_praise → +0.8
+ pos_progress → +0.1
+ neutral → 0.0
+ topic_shift → 0.0 (update skipped)
+
+Confidence gating: if confidence < tau_c, reward is set to 0 and update is skipped.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+from openai import AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError
+
+from personalization.feedback.schemas import TurnSample
+
+
+# --- Label → Reward Mapping ---
+
# Maps each judge label to a scalar reward in [-1, 1].  Magnitudes encode
# severity: an explicit constraint restatement is the strongest negative
# signal, mere progress only a weak positive.  topic_shift maps to 0 and
# additionally causes the update to be skipped (see LLMRewardClient._parse_result).
REWARD_MAP: Dict[str, float] = {
    "neg_constraint_restate": -1.0,
    "neg_correction": -0.8,
    "neg_confusion": -0.6,
    "pos_praise": +0.8,
    "pos_progress": +0.1,
    "neutral": 0.0,
    "topic_shift": 0.0,
}

# Closed set of labels accepted from the judge; anything else is coerced to
# "neutral" with zero confidence during parsing.
VALID_LABELS = set(REWARD_MAP.keys())
+
+
+# --- Configuration ---
+
@dataclass
class LLMRewardConfig:
    """Tuning knobs for the LLM-as-judge reward client (see LLMRewardClient)."""
    model: str = "gpt-5-nano"
    api_key: Optional[str] = None  # Falls back to OPENAI_API_KEY env var
    base_url: Optional[str] = None  # For custom endpoints
    max_concurrent: int = 32  # Semaphore limit for parallel requests
    max_retries: int = 3  # Attempts per request before the last error is re-raised
    retry_base_delay: float = 1.0  # Exponential backoff base (seconds)
    timeout: float = 60.0  # Per-request timeout (reasoning models are slower)
    max_completion_tokens: int = 2048  # Must be high — reasoning models use internal tokens
    confidence_threshold: float = 0.6  # tau_c: skip update if confidence < this
    enable_cache: bool = True  # Cache by hash of (q_t, a_t, q_{t+1})
+
+
+# --- Prompt ---
+
# System prompt sent verbatim (never .format()ed, so its single braces are
# safe).  Defines the mutually exclusive label taxonomy that REWARD_MAP keys on.
JUDGE_SYSTEM_PROMPT = """\
You are a feedback classifier. Given a user query (q_t), the assistant's response (a_t), \
and the user's next message (q_{t+1}), classify the user's follow-up into exactly one label.

Labels (mutually exclusive):
- neg_constraint_restate: User reasserts constraints/preferences as correction (e.g., "as I said…", "remember…", "按我说的…").
- neg_correction: User indicates the content is wrong or the assistant failed to answer.
- neg_confusion: User indicates confusion or requests re-explanation.
- pos_praise: Explicit praise or satisfaction with the response.
- pos_progress: Constructive continuation (examples, extensions, what-if, next steps) without complaint.
- neutral: Ambiguous or minimal feedback, not clearly positive or negative.
- topic_shift: User switches to a new unrelated task/topic.

Output a JSON object with fields: label, confidence (0-1), rationale (one short sentence)."""

# User-message template filled by LLMRewardClient._build_messages().  The
# doubled braces in q_{{t+1}} are required because this string IS formatted —
# they render as a literal "q_{t+1}".
JUDGE_USER_TEMPLATE = """\
q_t: {query_t}

a_t: {answer_t}

q_{{t+1}}: {query_t1}"""
+
+
+# --- Result Dataclass ---
+
@dataclass
class RewardResult:
    """Judge verdict for one turn, with confidence gating already applied."""
    label: str  # One of VALID_LABELS ("neutral" on parse failure)
    confidence: float  # Judge's self-reported confidence in [0, 1]
    rationale: str  # One-sentence justification from the judge
    reward: float  # Mapped scalar reward; forced to 0.0 when gated
    should_update: bool  # False if gated by confidence or topic_shift
+
+
+# --- Async Client ---
+
class LLMRewardClient:
    """Async OpenAI client for LLM-as-judge reward estimation.

    Concurrency is bounded by an asyncio.Semaphore (config.max_concurrent).
    Verdicts are memoized in-process, keyed by a hash of (q_t, a_t, q_{t+1}),
    when config.enable_cache is set — the judge sees only that triple (no
    memory cards, no history), so the key fully determines the verdict.
    """

    def __init__(self, config: Optional[LLMRewardConfig] = None):
        self.config = config or LLMRewardConfig()
        self._client = AsyncOpenAI(
            api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
            base_url=self.config.base_url,
            timeout=self.config.timeout,
        )
        self._semaphore = asyncio.Semaphore(self.config.max_concurrent)
        self._cache: Dict[str, RewardResult] = {}

    def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str:
        """Deterministic hash of the judge input triple."""
        # NUL separators prevent ambiguous concatenations of the three parts.
        content = f"{query_t}\x00{answer_t}\x00{query_t1}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    async def _call_with_retry(self, messages: List[dict]) -> str:
        """Single LLM call with exponential backoff retry.

        Backs off on rate-limit/timeout/connection errors (re-raising on the
        final attempt) and immediately re-issues the request when the model
        returns empty content after spending its whole token budget on
        internal reasoning.  Returns "" if no attempt yields text.
        """
        for attempt in range(self.config.max_retries):
            try:
                # Semaphore bounds in-flight requests; it is released before
                # any backoff sleep so waiting does not hold a slot.
                async with self._semaphore:
                    response = await self._client.chat.completions.create(
                        model=self.config.model,
                        messages=messages,
                        max_completion_tokens=self.config.max_completion_tokens,
                        response_format={"type": "json_object"},
                    )
                content = response.choices[0].message.content
                if content:
                    return content.strip()
                # Reasoning model may exhaust tokens on thinking — retry
                if response.choices[0].finish_reason == "length":
                    continue
                return ""
            except (RateLimitError, APITimeoutError, APIConnectionError) as e:
                if attempt == self.config.max_retries - 1:
                    raise
                delay = self.config.retry_base_delay * (2 ** attempt)
                await asyncio.sleep(delay)
        return ""

    def _build_messages(self, sample: TurnSample) -> List[dict]:
        """Construct the judge prompt from (q_t, a_t, q_{t+1}) only."""
        user_content = JUDGE_USER_TEMPLATE.format(
            query_t=sample.query_t,
            answer_t=sample.answer_t,
            query_t1=sample.query_t1,
        )
        return [
            {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    def _parse_result(self, raw: str) -> RewardResult:
        """Parse structured JSON output into RewardResult.

        Malformed JSON, missing fields, or unknown labels degrade to a gated
        neutral result (reward 0, should_update False) instead of raising.
        """
        try:
            parsed = json.loads(raw)
            label = parsed["label"]
            confidence = float(parsed["confidence"])
            rationale = parsed.get("rationale", "")

            # Unknown labels are coerced to neutral with zero confidence,
            # which the gate below then filters out.
            if label not in VALID_LABELS:
                label = "neutral"
                confidence = 0.0

            reward = REWARD_MAP[label]

            # Confidence gating and topic_shift skip
            should_update = (
                confidence >= self.config.confidence_threshold
                and label != "topic_shift"
            )
            if not should_update:
                reward = 0.0

            return RewardResult(
                label=label,
                confidence=confidence,
                rationale=rationale,
                reward=reward,
                should_update=should_update,
            )
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            return RewardResult(
                label="neutral",
                confidence=0.0,
                rationale="parse_failure",
                reward=0.0,
                should_update=False,
            )

    async def judge(self, sample: TurnSample) -> RewardResult:
        """Judge a single turn (async). Returns RewardResult with gating applied."""
        # Cache lookup
        if self.config.enable_cache:
            key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1)
            if key in self._cache:
                return self._cache[key]

        messages = self._build_messages(sample)
        raw = await self._call_with_retry(messages)
        result = self._parse_result(raw)

        # Cache store ('key' is always bound here whenever enable_cache is True)
        if self.config.enable_cache:
            self._cache[key] = result

        return result

    async def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]:
        """Judge a batch of turns in parallel. Returns list of RewardResult.

        gather() preserves input order; per-call concurrency is still capped
        by the shared semaphore inside _call_with_retry.
        """
        tasks = [self.judge(s) for s in samples]
        return await asyncio.gather(*tasks)

    async def close(self):
        """Close the underlying HTTP client."""
        await self._client.close()
+
+
+# --- Synchronous Wrappers ---
+
def estimate_reward_llm(
    sample: TurnSample,
    config: Optional[LLMRewardConfig] = None,
) -> Tuple[float, bool]:
    """
    Synchronous single-sample reward estimation.
    Returns (reward, should_update).

    The judge call and the client shutdown run inside ONE event loop:
    each ``asyncio.run`` call creates and then closes a fresh loop, so
    closing the client in a *second* ``asyncio.run`` (as the previous
    version did) would touch HTTP connections bound to the already-closed
    first loop and can raise "Event loop is closed" or leak sockets.
    """
    async def _judge_and_close() -> Tuple[float, bool]:
        # Client is created inside the running loop so its transport is
        # bound to the same loop that closes it.
        client = LLMRewardClient(config)
        try:
            result = await client.judge(sample)
            return result.reward, result.should_update
        finally:
            await client.close()

    return asyncio.run(_judge_and_close())
+
+
def estimate_rewards_batch(
    samples: List[TurnSample],
    config: Optional[LLMRewardConfig] = None,
) -> List[Tuple[float, bool]]:
    """
    Synchronous batch reward estimation (runs async internally).
    Returns list of (reward, should_update) tuples, in input order.

    Judging and client shutdown happen inside a single ``asyncio.run``:
    splitting them across two runs (as before) would close the client's
    connections on a different, already-closed event loop.
    """
    async def _judge_all_and_close() -> List[Tuple[float, bool]]:
        # Client is created inside the running loop so its transport is
        # bound to the same loop that closes it.
        client = LLMRewardClient(config)
        try:
            results = await client.judge_batch(samples)
            return [(r.reward, r.should_update) for r in results]
        finally:
            await client.close()

    return asyncio.run(_judge_all_and_close())
diff --git a/src/personalization/retrieval/pipeline.py b/src/personalization/retrieval/pipeline.py
index 3d3eeb7..e83940d 100644
--- a/src/personalization/retrieval/pipeline.py
+++ b/src/personalization/retrieval/pipeline.py
@@ -110,10 +110,14 @@ def retrieve_with_policy(
candidates = [memory_cards[i] for i in dense_idx]
candidate_docs = [c.note_text for c in candidates]
-
+
# 2. Rerank base score (P(yes|q,m))
- base_scores = np.array(reranker.score(query, candidate_docs))
-
+ # Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory)
+ if len(candidates) <= topk_rerank:
+ base_scores = np.ones(len(candidates)) # Uniform scores
+ else:
+ base_scores = np.array(reranker.score(query, candidate_docs))
+
# 3. Policy Scoring (Softmax)
user_state: UserState = user_store.get_state(user_id)
candidate_vectors = item_vectors[dense_idx] # [K, k]
@@ -181,29 +185,35 @@ def retrieve_no_policy(
# 1. Dense retrieval
dense_idx = dense_topk_indices(
- query,
- embed_model,
- memory_embeddings,
+ query,
+ embed_model,
+ memory_embeddings,
valid_indices=valid_indices,
topk=topk_dense
)
-
+
if not dense_idx:
return [], np.array([]), np.array([]), [], np.array([])
candidates = [memory_cards[i] for i in dense_idx]
candidate_docs = [c.note_text for c in candidates]
-
+
# 2. Rerank base score (P(yes|q,m))
- base_scores = np.array(reranker.score(query, candidate_docs))
-
- # 3. Deterministic Top-K selection based on rerank scores ONLY (no policy)
- k = min(topk_rerank, len(base_scores))
- top_indices_local = base_scores.argsort()[-k:][::-1]
- chosen_indices = top_indices_local.tolist()
-
+ # Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory)
+ if len(candidates) <= topk_rerank:
+ # Just return all candidates without reranking
+ base_scores = np.ones(len(candidates)) # Uniform scores
+ chosen_indices = list(range(len(candidates)))
+ else:
+ base_scores = np.array(reranker.score(query, candidate_docs))
+
+ # 3. Deterministic Top-K selection based on rerank scores ONLY (no policy)
+ k = min(topk_rerank, len(base_scores))
+ top_indices_local = base_scores.argsort()[-k:][::-1]
+ chosen_indices = top_indices_local.tolist()
+
# Get scores for chosen items (for logging compatibility)
- chosen_scores = base_scores[top_indices_local]
+ chosen_scores = base_scores[chosen_indices]
# Return empty item vectors (not used in NoPersonal mode)
# Return rerank scores as the "probs" field for logging compatibility
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py
index 2c4d5a8..733ff87 100644
--- a/src/personalization/serving/personalized_llm.py
+++ b/src/personalization/serving/personalized_llm.py
@@ -33,6 +33,7 @@ from personalization.config.settings import load_local_models_config
from personalization.config.registry import get_preference_extractor, get_chat_model
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.models.reranker.bge_reranker import BGEReranker
from personalization.user_model.tensor_store import UserTensorStore, UserState
from personalization.user_model.session_state import OnlineSessionState
from personalization.user_model.features import ItemProjection
@@ -40,7 +41,8 @@ from personalization.retrieval.preference_store.schemas import (
MemoryCard, ChatTurn, PreferenceList, Preference
)
from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy
-from personalization.feedback.handlers import eval_step
+from personalization.feedback.handlers import eval_step, eval_step_llm
+from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig
from personalization.user_model.policy.reinforce import reinforce_update_user_state
@@ -113,6 +115,119 @@ class _SessionContext:
# =============================================================================
+# Shared Model Singletons for Multi-threaded Efficiency
+# =============================================================================
+
+_shared_embed_model = None
+_shared_reranker = None
+_shared_extractor = None
+_shared_models_lock = None # Will be initialized on first use
+
+
def _get_shared_models_lock():
    """Get or create the threading lock guarding the shared-model singletons.

    NOTE(review): lazily creating the guard lock is itself a check-then-set
    race -- two threads arriving before the first assignment can each build a
    lock and then synchronize on *different* objects. The re-check below only
    narrows that window; creating the lock eagerly at module import time
    would eliminate it entirely.
    """
    global _shared_models_lock
    if _shared_models_lock is None:
        import threading
        candidate = threading.Lock()
        # Re-check so that a lock installed by a concurrent thread wins and
        # every caller ends up sharing a single lock object.
        if _shared_models_lock is None:
            _shared_models_lock = candidate
    return _shared_models_lock
+
+
def get_shared_embedding_model(model_path: str, device_map: str = "auto"):
    """Get or create shared embedding model (thread-safe singleton).

    Args:
        model_path: Filesystem path of the Qwen3 embedding checkpoint.
        device_map: Device placement forwarded to the model loader
            ("auto" lets the loader decide).

    Returns:
        The process-wide ``Qwen3Embedding8B`` instance.

    NOTE(review): ``model_path`` and ``device_map`` only take effect on the
    FIRST call; later calls silently return the already-loaded model even if
    they pass different arguments -- confirm no caller relies on per-call
    paths or devices.
    """
    global _shared_embed_model
    import torch

    lock = _get_shared_models_lock()
    with lock:
        # Double-checked under the lock: only the first caller pays the load.
        if _shared_embed_model is None:
            print(f"[SharedModels] Loading shared embedding model on {device_map}...")
            _shared_embed_model = Qwen3Embedding8B(
                model_path=model_path,
                dtype=torch.bfloat16,
                device_map=device_map,
            )
            print("[SharedModels] Shared embedding model loaded.")
        return _shared_embed_model
+
+
def get_shared_reranker(model_path: str, device_map: str = "auto", reranker_type: str = "qwen3"):
    """Get or create shared reranker model (thread-safe singleton).

    Args:
        model_path: Checkpoint path for the selected reranker.
        device_map: Device placement forwarded to the model loader.
        reranker_type: "bge" loads ``BGEReranker`` (float16); any other value
            loads ``Qwen3Reranker`` (bfloat16).

    Returns:
        The process-wide reranker instance.

    NOTE(review): all arguments -- including ``reranker_type`` -- only take
    effect on the FIRST call. A later call requesting a different type
    silently receives whatever was loaded first; confirm callers never mix
    reranker types within one process (or call ``clear_shared_models`` in
    between).
    """
    global _shared_reranker
    import torch

    lock = _get_shared_models_lock()
    with lock:
        # Double-checked under the lock: only the first caller pays the load.
        if _shared_reranker is None:
            print(f"[SharedModels] Loading shared reranker ({reranker_type}) on {device_map}...")
            if reranker_type == "bge":
                _shared_reranker = BGEReranker(
                    model_path=model_path,
                    device_map=device_map,
                    dtype=torch.float16,
                )
            else:
                _shared_reranker = Qwen3Reranker(
                    model_path=model_path,
                    device_map=device_map,
                    dtype=torch.bfloat16,
                )
            print("[SharedModels] Shared reranker model loaded.")
        return _shared_reranker
+
+
def get_shared_extractor(model_path: str, device_map: str = "auto"):
    """Get or create shared preference extractor model (thread-safe singleton).

    Args:
        model_path: Checkpoint path of the preference-extractor model.
        device_map: Device placement forwarded to the model loader.

    Returns:
        The process-wide ``QwenRuleExtractor`` instance.

    NOTE(review): arguments only take effect on the FIRST call; subsequent
    calls return the already-loaded extractor regardless of what they pass.
    """
    global _shared_extractor
    import torch
    from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor

    lock = _get_shared_models_lock()
    with lock:
        # Double-checked under the lock: only the first caller pays the load.
        if _shared_extractor is None:
            print(f"[SharedModels] Loading shared preference extractor on {device_map}...")
            _shared_extractor = QwenRuleExtractor(
                model_path=model_path,
                dtype=torch.bfloat16,
                device_map=device_map,
            )
            print("[SharedModels] Shared preference extractor loaded.")
        return _shared_extractor
+
+
def clear_shared_models():
    """Free all shared singleton models to reclaim GPU memory between methods.

    Drops every loaded singleton (embedding, reranker, extractor), then runs
    garbage collection and empties the CUDA cache if torch with CUDA is
    available. Safe to call when nothing is loaded (no-op).
    """
    global _shared_embed_model, _shared_reranker, _shared_extractor
    import gc

    with _get_shared_models_lock():
        released = []

        # Rebinding each global to None drops the last strong reference,
        # which is what lets gc / CUDA actually reclaim the memory.
        if _shared_embed_model is not None:
            released.append("embedding")
            _shared_embed_model = None
        if _shared_reranker is not None:
            released.append("reranker")
            _shared_reranker = None
        if _shared_extractor is not None:
            released.append("extractor")
            _shared_extractor = None

        if released:
            gc.collect()
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass
            print(f"[SharedModels] Cleared: {', '.join(released)}")
+
+
+# =============================================================================
# PersonalizedLLM Class
# =============================================================================
@@ -163,6 +278,12 @@ class PersonalizedLLM:
mode: str = "full", # "full", "nopersonal", or "vanilla"
eval_mode: bool = True, # True = greedy selection, False = stochastic sampling
device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support
+ llm_name: Optional[str] = None, # Override LLM name (e.g., "llama_8b_vllm" for vLLM)
+ use_shared_models: bool = False, # Use shared singleton models for multi-threaded efficiency
+ reranker_type: str = "qwen3", # "qwen3" (8B) or "bge" (278M)
+ best_of_n: int = 1, # Generate N responses and pick best (for RAG methods)
+ reward_mode: str = "keyword", # "keyword" (legacy heuristic) or "llm" (GPT-5-nano judge)
+ llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge
):
"""
Initialize the PersonalizedLLM.
@@ -183,12 +304,25 @@ class PersonalizedLLM:
device_assignment: Optional dict to assign models to specific GPUs.
Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"}
If None, uses "auto" for all models.
+ use_shared_models: If True, use shared singleton models for embedding and reranker.
+ This is essential for multi-threaded/parallel profile processing to avoid
+ loading duplicate models. When enabled, the first thread loads the models,
+ and subsequent threads reuse the shared instances.
"""
self.only_own_memories = only_own_memories
+ self.use_shared_models = use_shared_models
self.enable_preference_extraction = enable_preference_extraction
self.enable_rl_updates = enable_rl_updates
self.mode = mode # "full" or "nopersonal"
self.eval_mode = eval_mode # True = greedy, False = sample
+ self.reranker_type = reranker_type # "qwen3" or "bge"
+ self.best_of_n = best_of_n # Generate N responses and pick best
+ self.reward_mode = reward_mode # "keyword" or "llm"
+
+ # Initialize LLM reward client if using LLM judge
+ self._llm_reward_client: Optional[LLMRewardClient] = None
+ if reward_mode == "llm":
+ self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig())
# Multi-GPU device assignment
self._device_assignment = device_assignment or {
@@ -219,6 +353,9 @@ class PersonalizedLLM:
"max_new_tokens": 512,
}
+ # Store llm_name before loading config (needed in _load_config)
+ self._llm_name_override = llm_name
+
# Load config and override RL params if available
self._load_config(config_path)
@@ -249,8 +386,8 @@ class PersonalizedLLM:
if config_path is None:
config_path = "configs/user_model.yaml"
- self._llm_name = "qwen_1_5b" # Default
-
+ self._llm_name = self._llm_name_override or "qwen_1_5b" # Default, can be overridden
+
try:
if os.path.exists(config_path):
with open(config_path, "r") as f:
@@ -260,8 +397,8 @@ class PersonalizedLLM:
for key in self._rl_cfg:
if key in user_cfg:
self._rl_cfg[key] = user_cfg[key]
- # LLM name
- if "llm_name" in user_cfg:
+ # LLM name (only from config if not already set via parameter)
+ if self._llm_name_override is None and "llm_name" in user_cfg:
self._llm_name = user_cfg["llm_name"]
except Exception as e:
print(f"[PersonalizedLLM] Warning: Failed to load config: {e}")
@@ -269,53 +406,110 @@ class PersonalizedLLM:
def _load_models(self):
"""Load all ML models with optional multi-GPU assignment."""
import torch
-
- # Report GPU availability
- num_gpus = torch.cuda.device_count()
- print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
- for i in range(num_gpus):
- mem = torch.cuda.get_device_properties(i).total_memory / 1e9
- print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
-
+
+ # Report GPU availability (only once, not for shared model instances)
+ if not self.use_shared_models:
+ num_gpus = torch.cuda.device_count()
+ print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
+ for i in range(num_gpus):
+ mem = torch.cuda.get_device_properties(i).total_memory / 1e9
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
+
embed_device = self._device_assignment.get("embed", "auto")
reranker_device = self._device_assignment.get("reranker", "auto")
chat_device = self._device_assignment.get("chat", "auto")
extractor_device = self._device_assignment.get("extractor", "auto")
-
- # Embedding model
- print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
- self._embed_model = Qwen3Embedding8B(
- model_path=self._cfg.embedding.qwen3.local_path,
- dtype=torch.bfloat16,
- device_map=embed_device,
- )
-
- # Reranker
- print(f"[PersonalizedLLM] Loading Reranker on {reranker_device}...")
- self._reranker = Qwen3Reranker(
- model_path=self._cfg.reranker.qwen3_8b.local_path,
- device_map=reranker_device,
- dtype=torch.bfloat16,
- )
-
+
+ # Embedding model - only load for modes that use RAG retrieval
+ # Vanilla and contextual modes don't need embedding/reranker
+ needs_retrieval = self.mode not in ("vanilla", "contextual")
+
+ if needs_retrieval:
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared embedding model...")
+ self._embed_model = get_shared_embedding_model(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
+ self._embed_model = Qwen3Embedding8B(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ dtype=torch.bfloat16,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping embedding model (not needed for {self.mode} mode)")
+ self._embed_model = None
+
+ # Reranker - only load for modes that use RAG retrieval
+ # Support both qwen3 (8B) and bge (278M) rerankers
+ if needs_retrieval:
+ if self.reranker_type == "bge":
+ reranker_path = getattr(self._cfg.reranker, "bge_base", None)
+ reranker_path = reranker_path.local_path if reranker_path else "BAAI/bge-reranker-base"
+ else:
+ reranker_path = self._cfg.reranker.qwen3_8b.local_path
+
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared reranker model ({self.reranker_type})...")
+ self._reranker = get_shared_reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ reranker_type=self.reranker_type,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Reranker ({self.reranker_type}) on {reranker_device}...")
+ if self.reranker_type == "bge":
+ self._reranker = BGEReranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.float16,
+ )
+ else:
+ self._reranker = Qwen3Reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.bfloat16,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping reranker (not needed for {self.mode} mode)")
+ self._reranker = None
+
# Chat model (via registry for backend switching)
print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...")
# Pass device override if specified (not "auto")
device_for_chat = chat_device if chat_device != "auto" else None
self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat)
-
- # Preference extractor
+
+ # Preference extractor - use shared singleton if enabled
if self.enable_preference_extraction:
extractor_name = "qwen3_0_6b_sft"
- print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
- try:
- self._extractor = get_preference_extractor(extractor_name)
- except Exception as e:
- print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.")
- self._extractor = get_preference_extractor("rule")
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared preference extractor...")
+ try:
+ extractor_path = self._cfg.preference_extractor.get("qwen3_0_6b_sft", {}).get("path", None)
+ if extractor_path:
+ self._extractor = get_shared_extractor(
+ model_path=extractor_path,
+ device_map=extractor_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Extractor path not found, using rule-based.")
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load shared extractor: {e}. Using rule-based.")
+ self._extractor = get_preference_extractor("rule")
+ else:
+ print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
+ try:
+ self._extractor = get_preference_extractor(extractor_name)
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.")
+ self._extractor = get_preference_extractor("rule")
else:
- print("[PersonalizedLLM] Preference extraction disabled, using rule-based extractor.")
- self._extractor = get_preference_extractor("rule")
+ print("[PersonalizedLLM] Preference extraction disabled, skipping extractor.")
+ self._extractor = None
def _load_memory_store(self):
"""Load memory cards and embeddings."""
@@ -396,33 +590,34 @@ class PersonalizedLLM:
Returns list of preference dicts for debug info.
"""
extracted = []
-
+
if not prefs.preferences or self._projection is None:
return extracted
-
- # Compute embedding for the query
- e_q = self._embed_model.encode([query], return_tensor=False)[0]
- v_q = self._projection.transform_vector(np.array(e_q))
-
+
for pref in prefs.preferences:
note_text = f"When {pref.condition}, {pref.action}."
-
+
# Record for debug
extracted.append({
"condition": pref.condition,
"action": pref.action,
"confidence": pref.confidence,
})
-
+
# Deduplication check
is_duplicate = any(
card.user_id == user_id and card.note_text == note_text
for card in self._memory_cards
)
-
+
if is_duplicate:
continue
-
+
+ # Compute embedding from note_text (NOT query) for proper semantic retrieval
+ # This ensures retrieval query "solve math problem" matches stored "When math problems..."
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
+ v_note = self._projection.transform_vector(np.array(e_note))
+
# Create new memory card
card = MemoryCard(
card_id=str(uuid.uuid4()),
@@ -432,21 +627,61 @@ class PersonalizedLLM:
raw_queries=[query],
preference_list=PreferenceList(preferences=[pref]),
note_text=note_text,
- embedding_e=list(e_q),
+ embedding_e=list(e_note),
kind="pref",
)
-
+
# Add to memory store
self._memory_cards.append(card)
- self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_q])])
- self._item_vectors = np.vstack([self._item_vectors, np.array([v_q])])
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])
return extracted
-
+
+ def _score_response(self, response: str) -> float:
+ """
+ Score a response for best-of-N selection.
+
+ Higher score = better response. Scoring heuristics:
+ 1. Length: Longer responses typically have more substance
+ 2. Solution indicators: Contains formulas, steps, answers
+ 3. Proactivity: Doesn't end with just a question
+
+ Returns:
+ Float score (higher is better)
+ """
+ score = 0.0
+ response_lower = response.lower()
+
+ # Length score (normalized, cap at 1000 chars)
+ score += min(len(response), 1000) / 1000 * 3.0
+
+ # Solution indicators (+1 each, max 5)
+ solution_indicators = ['=', 'step', 'answer', 'formula', 'result', 'therefore', 'solution']
+ indicator_count = sum(1 for ind in solution_indicators if ind in response_lower)
+ score += min(indicator_count, 5) * 0.5
+
+ # Structured content (+1 for numbered/bulleted lists)
+ if any(marker in response for marker in ['1.', '2.', '- ', '* ', '##']):
+ score += 1.0
+
+ # Penalty for ending with question (passive behavior)
+ # Check last 100 chars for question marks
+ if '?' in response[-100:]:
+ score -= 1.5
+
+ # Bonus for providing concrete values/numbers
+ import re
+ numbers = re.findall(r'\d+\.?\d*', response)
+ if len(numbers) >= 3:
+ score += 1.0
+
+ return score
+
# =========================================================================
# Public Interface
# =========================================================================
-
+
def chat(self, user_id: str, query: str) -> AssistantResponse:
"""
Main online chat interface.
@@ -465,34 +700,19 @@ class PersonalizedLLM:
ctx = self._get_or_create_session(user_id)
session = ctx.session_state
user_state = self._user_store.get_state(user_id)
-
+
# Record user vector before for debug
z_long_before = user_state.z_long.copy().tolist()
z_short_before = user_state.z_short.copy().tolist()
-
- # Compute query embedding
- e_q_t = np.array(self._embed_model.encode([query], return_tensor=False)[0])
-
- # Store pending RL update info from last turn (for apply_feedback)
- if session.last_query is not None and self.enable_rl_updates:
- ctx.pending_rl_update = {
- "last_query": session.last_query,
- "last_answer": session.last_answer,
- "last_memories": session.last_memories,
- "last_query_embedding": session.last_query_embedding,
- "current_query_embedding": e_q_t,
- "last_candidate_item_vectors": session.last_candidate_item_vectors,
- "last_policy_probs": session.last_policy_probs,
- "last_chosen_indices": session.last_chosen_indices,
- }
-
+
# Add user turn to history
user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
session.history.append(user_turn)
-
+
# Vanilla mode: pure LLM without any memory or preference extraction
if self.mode == "vanilla":
- # Skip preference extraction and memory retrieval entirely
+ # Skip embedding, preference extraction, and memory retrieval entirely
+ e_q_t = np.zeros(4096, dtype=np.float32) # Placeholder for vanilla mode
extracted_prefs = []
candidates = []
cand_item_vecs = np.array([])
@@ -502,13 +722,61 @@ class PersonalizedLLM:
memories_t = []
memory_notes = []
else:
+ # Compute query embedding (only needed for non-vanilla modes)
+ # Explicitly normalize for consistent cosine similarity with stored embeddings
+ embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
+ if embed_result is None or len(embed_result) == 0:
+ raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
+ e_q_t = np.array(embed_result[0])
+
+ # Store pending RL update info from last turn (for apply_feedback)
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Auto-compute reward via LLM judge if enabled
+ if self.reward_mode == "llm" and self._llm_reward_client is not None:
+ import asyncio
+ try:
+ reward, gating = asyncio.run(eval_step_llm(
+ q_t=session.last_query,
+ answer_t=session.last_answer,
+ q_t1=query,
+ memories_t=session.last_memories or [],
+ client=self._llm_reward_client,
+ ))
+ if gating > 0.0:
+ self.apply_feedback(Feedback(
+ user_id=user_id,
+ turn_id=ctx.turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ ))
+ except Exception as e:
+ # Graceful fallback: skip RL update if judge fails
+ print(f"[LLM-Reward] Judge call failed, skipping update: {e}")
+
# Extract preferences from conversation (if enabled)
+ # extract_turn processes only the last user turn - efficient since called each turn
+ # Preferences accumulate in _memory_cards across turns (dedup prevents duplicates)
extracted_prefs = []
if self.enable_preference_extraction:
prefs = self._extractor.extract_turn(session.history)
+ if prefs.preferences:
+ print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
extracted_prefs = self._add_preferences_as_memory(
prefs, query, user_id, ctx.turn_counter
)
+ if extracted_prefs:
+ print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
# Retrieve memories
# In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector
@@ -551,6 +819,14 @@ class PersonalizedLLM:
# Get selected memories
memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
memory_notes = [m.note_text for m in memories_t]
+
+ # Debug: show retrieval info
+ if memories_t:
+ print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
+ print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}")
+ for i, m in enumerate(memories_t[:3]): # Show top 3
+ score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
+ print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
# Build prompt and count tokens
prompt_tokens = self._count_tokens(query)
@@ -559,13 +835,34 @@ class PersonalizedLLM:
for note in memory_notes:
prompt_tokens += self._count_tokens(note)
- # Generate answer
- answer_t = self._chat_model.answer(
- history=session.history,
- memory_notes=memory_notes,
- max_new_tokens=self._rl_cfg["max_new_tokens"],
- )
-
+ # Generate answer (with best-of-N if enabled)
+ if self.best_of_n > 1:
+ # Generate N responses and pick the best one
+ candidates_responses = []
+ for i in range(self.best_of_n):
+ resp = self._chat_model.answer(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ temperature=0.8, # Slightly higher temp for diversity
+ )
+ score = self._score_response(resp)
+ candidates_responses.append((resp, score))
+
+ # Sort by score (descending) and pick best
+ candidates_responses.sort(key=lambda x: x[1], reverse=True)
+ answer_t = candidates_responses[0][0]
+ best_score = candidates_responses[0][1]
+
+ if len(candidates_responses) > 1:
+ print(f"[BEST-OF-{self.best_of_n}] Scores: {[f'{s:.2f}' for _, s in candidates_responses]}, picked score={best_score:.2f}")
+ else:
+ answer_t = self._chat_model.answer(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ )
+
completion_tokens = self._count_tokens(answer_t)
# Add assistant turn to history
@@ -612,7 +909,263 @@ class PersonalizedLLM:
usage=usage,
debug=debug,
)
-
+
def chat_prepare(self, user_id: str, query: str) -> dict:
    """
    Prepare for chat without calling the LLM.

    This does all the preparation work (embedding, memory retrieval, etc.)
    and returns the messages to send to the LLM along with context needed
    for post-processing.

    Used for batch processing where messages are collected first, then
    sent in batch to vLLM for concurrent processing.

    Side effects: appends the user turn to the session history, may apply a
    pending RL feedback update, and may add newly-extracted preferences to
    the shared memory store.

    Args:
        user_id: Unique identifier for the user.
        query: Current user query/message.

    Returns:
        Dict containing:
            - messages: List of messages to send to LLM
            - context: Dict with all state needed for chat_complete()
    """
    ctx = self._get_or_create_session(user_id)
    session = ctx.session_state
    user_state = self._user_store.get_state(user_id)

    # Record user vector before for debug
    z_long_before = user_state.z_long.copy().tolist()
    z_short_before = user_state.z_short.copy().tolist()

    # Add user turn to history
    user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
    session.history.append(user_turn)

    # Vanilla mode: pure LLM without any memory or preference extraction
    if self.mode == "vanilla":
        # NOTE(review): placeholder dimension 4096 is hard-coded here; it
        # must match the real embedding model's output size -- confirm.
        e_q_t = np.zeros(4096, dtype=np.float32)
        extracted_prefs = []
        candidates = []
        cand_item_vecs = np.array([])
        base_scores = np.array([])
        chosen_indices = []
        probs = np.array([])
        memories_t = []
        memory_notes = []
    else:
        # Compute query embedding (normalized so cosine similarity against
        # stored embeddings is consistent).
        embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
        if embed_result is None or len(embed_result) == 0:
            raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
        e_q_t = np.array(embed_result[0])

        # Store pending RL update info from last turn (consumed by
        # apply_feedback).
        if session.last_query is not None and self.enable_rl_updates:
            ctx.pending_rl_update = {
                "last_query": session.last_query,
                "last_answer": session.last_answer,
                "last_memories": session.last_memories,
                "last_query_embedding": session.last_query_embedding,
                "current_query_embedding": e_q_t,
                "last_candidate_item_vectors": session.last_candidate_item_vectors,
                "last_policy_probs": session.last_policy_probs,
                "last_chosen_indices": session.last_chosen_indices,
            }

        # Auto-compute reward via LLM judge if enabled.
        # NOTE(review): asyncio.run() raises RuntimeError if this thread is
        # already running an event loop -- chat_prepare must be called from
        # synchronous code; confirm for the vLLM batch driver.
        if self.reward_mode == "llm" and self._llm_reward_client is not None:
            import asyncio
            try:
                reward, gating = asyncio.run(eval_step_llm(
                    q_t=session.last_query,
                    answer_t=session.last_answer,
                    q_t1=query,
                    memories_t=session.last_memories or [],
                    client=self._llm_reward_client,
                ))
                if gating > 0.0:
                    self.apply_feedback(Feedback(
                        user_id=user_id,
                        turn_id=ctx.turn_counter - 1,
                        reward=reward,
                        gating=gating,
                    ))
            except Exception as e:
                # Graceful fallback: skip the RL update if the judge fails.
                print(f"[LLM-Reward] Judge call failed, skipping update: {e}")

        # Extract preferences from conversation; new ones are appended to
        # the shared memory store (deduplicated inside the helper).
        extracted_prefs = []
        if self.enable_preference_extraction:
            prefs = self._extractor.extract_turn(session.history)
            if prefs.preferences:
                print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
            extracted_prefs = self._add_preferences_as_memory(
                prefs, query, user_id, ctx.turn_counter
            )
            if extracted_prefs:
                print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")

        # Retrieve memories: "nopersonal" uses deterministic dense+rerank
        # top-k; every other non-vanilla mode uses the learned policy.
        if self.mode == "nopersonal":
            candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
                user_id=user_id,
                query=query,
                embed_model=self._embed_model,
                reranker=self._reranker,
                memory_cards=self._memory_cards,
                memory_embeddings=self._memory_embeddings,
                topk_dense=self._rl_cfg["dense_topk"],
                topk_rerank=self._rl_cfg["rerank_topk"],
                only_own_memories=self.only_own_memories,
            )
        else:
            beta_long = self._rl_cfg["beta_long"]
            beta_short = self._rl_cfg["beta_short"]
            candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
                user_id=user_id,
                query=query,
                embed_model=self._embed_model,
                reranker=self._reranker,
                memory_cards=self._memory_cards,
                memory_embeddings=self._memory_embeddings,
                user_store=self._user_store,
                item_vectors=self._item_vectors,
                topk_dense=self._rl_cfg["dense_topk"],
                topk_rerank=self._rl_cfg["rerank_topk"],
                beta_long=beta_long,
                beta_short=beta_short,
                tau=self._rl_cfg["tau"],
                only_own_memories=self.only_own_memories,
                sample=not self.eval_mode,  # greedy in eval mode, stochastic otherwise
            )

        memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
        memory_notes = [m.note_text for m in memories_t]

        if memories_t:
            print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
            print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}")
            for i, m in enumerate(memories_t[:3]):
                score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
                print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")

    # Approximate prompt token count (query + full history + memory notes).
    prompt_tokens = self._count_tokens(query)
    for turn in session.history:
        prompt_tokens += self._count_tokens(turn.text)
    for note in memory_notes:
        prompt_tokens += self._count_tokens(note)

    # Build messages for LLM
    messages = self._chat_model.build_messages(
        history=session.history,
        memory_notes=memory_notes,
        max_new_tokens=self._rl_cfg["max_new_tokens"],
    )

    # Return messages and the full context chat_complete() needs.
    # NOTE(review): base_scores is computed but intentionally not returned --
    # confirm nothing downstream needs the raw rerank scores.
    return {
        "messages": messages,
        "context": {
            "user_id": user_id,
            "query": query,
            "ctx": ctx,
            "session": session,
            "user_state": user_state,
            "z_long_before": z_long_before,
            "z_short_before": z_short_before,
            "e_q_t": e_q_t,
            "extracted_prefs": extracted_prefs,
            "candidates": candidates,
            "cand_item_vecs": cand_item_vecs,
            "chosen_indices": chosen_indices,
            "probs": probs,
            "memories_t": memories_t,
            "memory_notes": memory_notes,
            "prompt_tokens": prompt_tokens,
        }
    }
+
def chat_complete(self, answer_t: str, context: dict) -> AssistantResponse:
    """
    Complete chat with LLM response.

    This takes the LLM response and context from chat_prepare(), and
    does all post-processing (add to history, debug info, etc.).

    Side effects: appends the assistant turn to the session history, updates
    the session's ``last_*`` fields (consumed by the next chat_prepare call
    for RL feedback), and advances the turn counter.

    Args:
        answer_t: The LLM response text.
        context: Context dict from chat_prepare().

    Returns:
        AssistantResponse containing the answer, usage stats, and debug info.
    """
    # Unpack context produced by chat_prepare().
    user_id = context["user_id"]
    query = context["query"]
    ctx = context["ctx"]
    session = context["session"]
    user_state = context["user_state"]
    z_long_before = context["z_long_before"]
    z_short_before = context["z_short_before"]
    e_q_t = context["e_q_t"]
    extracted_prefs = context["extracted_prefs"]
    candidates = context["candidates"]
    cand_item_vecs = context["cand_item_vecs"]
    chosen_indices = context["chosen_indices"]
    probs = context["probs"]
    memories_t = context["memories_t"]
    # NOTE(review): memory_notes is unpacked but unused below -- kept for
    # parity with the context schema.
    memory_notes = context["memory_notes"]
    prompt_tokens = context["prompt_tokens"]

    completion_tokens = self._count_tokens(answer_t)

    # Add assistant turn to history
    assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
    session.history.append(assist_turn)

    # Update session state for next turn (read back by chat_prepare's
    # pending-RL-update and LLM-judge paths).
    session.last_query = query
    session.last_answer = answer_t
    session.last_memories = memories_t
    session.last_query_embedding = e_q_t
    session.last_candidate_item_vectors = cand_item_vecs
    session.last_policy_probs = probs
    session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []

    ctx.turn_counter += 1

    # Build debug info (scores are the policy probabilities of the chosen
    # candidates; indices past the probs array fall back to 0.0).
    debug = DebugInfo(
        selected_memory_ids=[m.card_id for m in memories_t],
        selected_memory_notes=[m.note_text for m in memories_t],
        selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
        user_vector_before=z_long_before + z_short_before,
        user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
        extracted_preferences=extracted_prefs,
        extra={
            "num_candidates": len(candidates),
            "num_total_memories": len(self._memory_cards),
            "z_long_norm": float(np.linalg.norm(user_state.z_long)),
            "z_short_norm": float(np.linalg.norm(user_state.z_short)),
        }
    )

    # Build usage stats
    usage = UsageStats(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        model=self._llm_name,
    )

    return AssistantResponse(
        answer=answer_t,
        usage=usage,
        debug=debug,
    )
+
def reset_session(self, user_id: str) -> None:
"""
Reset session for a user (new chat window).