"""
LLM-as-Judge evaluation framework for personalization experiments.

Evaluates:
1. Task accuracy (did the agent solve the problem?)
2. Preference compliance (did the agent follow user preferences?)
3. Conflict resolution (did the agent pick the right preference in conflicts?)
4. User effort (how much did the user have to correct the agent?)
"""

import json
import re
from typing import List, Dict, Any, Optional, Tuple, Iterable
from dataclasses import dataclass, field
from enum import Enum
import sys
import os

# Add parent to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class PreferenceViolationType(Enum):
    """Types of preference violations."""
    WRONG_FORMAT = "wrong_format"  # bullets vs numbered, etc.
    WRONG_VERBOSITY = "wrong_verbosity"  # too long/short
    WRONG_STYLE = "wrong_style"  # code style, explanation style
    WRONG_TONE = "wrong_tone"  # too casual/formal
    CONFLICT_WRONG = "conflict_wrong"  # picked wrong preference in conflict
    OVER_PERSONALIZATION = "over_personalization"  # applied irrelevant preference


@dataclass
class JudgmentResult:
    """Result of LLM judge evaluation."""
    # Task success
    task_correct: bool
    task_confidence: float  # 0-1
    # Preference compliance
    preferences_followed: List[str]
    preferences_violated: List[Tuple[str, PreferenceViolationType]]
    compliance_score: float  # 0-1
    # Conflict resolution (if applicable)
    conflict_present: bool = False
    conflict_resolved_correctly: bool = False
    expected_preference: Optional[str] = None
    applied_preference: Optional[str] = None
    # Over-personalization detection
    over_personalized: bool = False
    irrelevant_preferences_applied: List[str] = field(default_factory=list)
    # Raw judge outputs
    raw_judgments: Dict[str, str] = field(default_factory=dict)


@dataclass
class ConversationMetrics:
    """Metrics for a full conversation."""
    # Task metrics
    task_success: bool
    turns_to_success: int  # LLMJudge sets len(conversation) on success, else -1
    total_turns: int
    # User effort
    user_token_count: int  # whitespace-separated words in user turns
    enforcement_count: int  # times user had to enforce preferences
    disappointment_count: int  # times user expressed disappointment
    # Efficiency
    total_token_count: int
    agent_token_count: int
    # Preference metrics
    preference_compliance_scores: List[float]  # per judged assistant turn
    conflict_resolution_accuracy: float  # across all conflicts
    over_personalization_rate: float


# =============================================================================
# Parsing helpers
# =============================================================================

def _strip_json_line_comments(text: str) -> str:
    """Remove ``// ...`` line comments that appear outside JSON string literals.

    The judge prompts show their expected JSON annotated with ``//`` comments,
    and models often echo them, which makes the output invalid JSON. Text
    inside double-quoted strings (respecting backslash escapes) is preserved,
    so valid JSON without comments passes through unchanged.
    """
    out: List[str] = []
    in_string = False
    escaped = False
    i = 0
    length = len(text)
    while i < length:
        ch = text[i]
        if in_string:
            out.append(ch)
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            i += 1
        elif ch == '"':
            in_string = True
            out.append(ch)
            i += 1
        elif text.startswith("//", i):
            newline = text.find("\n", i)
            if newline == -1:
                break
            i = newline  # keep the newline itself
        else:
            out.append(ch)
            i += 1
    return "".join(out)


def _extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in *text*, or None if there is none.

    Unlike a ``\\{[^}]*\\}`` regex, this lets the JSON decoder decide where
    the object ends, so nested objects such as
    ``"violation_types": {"3": "wrong_format"}`` are handled. The
    comment-stripped text is tried first, then the raw text.
    """
    decoder = json.JSONDecoder()
    stripped = _strip_json_line_comments(text)
    candidates = [stripped] if stripped == text else [stripped, text]
    for candidate in candidates:
        for match in re.finditer(r"\{", candidate):
            try:
                obj, _ = decoder.raw_decode(candidate[match.start():])
            except ValueError:
                continue
            if isinstance(obj, dict):
                return obj
    return None


def _coerce_bool(value: Any, default: bool = False) -> bool:
    """Interpret a judge-provided value as a bool.

    Strings such as ``"false"`` (which are truthy in Python) map to False.
    Unrecognised values return *default*.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("true", "yes", "1"):
            return True
        if lowered in ("false", "no", "0"):
            return False
        return default
    if isinstance(value, (int, float)):
        return bool(value)
    return default


def _coerce_float(value: Any, default: float) -> float:
    """Convert a judge-provided value to float, or return *default* if impossible."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _mean(values: Iterable[float], default: float = 0.0) -> float:
    """Arithmetic mean of *values*, or *default* when there are none."""
    items = list(values)
    return sum(items) / len(items) if items else default


class LLMJudge:
    """LLM-based judge for evaluating personalization quality."""

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.3-70B-Instruct",
        temperature: float = 0.0,
        max_tokens: int = 1024,
    ):
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._client = None
        # "vllm", "hf" or "mock"; set by _get_client once a backend is loaded.
        self._client_type: Optional[str] = None

    def _get_client(self):
        """Lazily initialise and return the judge backend.

        Tries vLLM first. If vLLM is unavailable (ImportError), falls back to
        HuggingFace ``transformers``. If that also fails, a warning is printed
        and the ``"mock"`` backend is used.
        """
        if self._client is None:
            try:
                self._load_vllm_client()
            except ImportError:
                try:
                    self._load_hf_client()
                except Exception as e:
                    print(f"Warning: Could not load LLM judge model: {e}")
                    self._client_type = "mock"
                    self._client = "mock"
        return self._client

    def _load_vllm_client(self) -> None:
        """Load the judge with vLLM; raises ImportError if vLLM is not installed."""
        from vllm import LLM, SamplingParams
        client = LLM(model=self.model_name)
        self._sampling_params = SamplingParams(
            temperature=self.temperature, max_tokens=self.max_tokens
        )
        # Record the backend only after construction succeeded.
        self._client = client
        self._client_type = "vllm"

    def _load_hf_client(self) -> None:
        """Load the judge with HuggingFace transformers in bfloat16."""
        from transformers import AutoTokenizer, AutoModelForCausalLM
        import torch
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self._tokenizer = tokenizer
        self._client = model
        self._client_type = "hf"

    def _generate(self, prompt: str) -> str:
        """Generate a response from the judge LLM for a single prompt."""
        client = self._get_client()
        if self._client_type == "vllm":
            outputs = client.generate([prompt], self._sampling_params)
            return outputs[0].outputs[0].text
        if self._client_type == "hf":
            inputs = self._tokenizer(prompt, return_tensors="pt").to(client.device)
            outputs = client.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                temperature=self.temperature if self.temperature > 0 else None,
                do_sample=self.temperature > 0,
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_len = inputs.input_ids.shape[1]
            return self._tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # Mock for testing; uses the key the task-correctness parser reads.
        return '{"is_correct": true, "confidence": 0.8}'

    # =========================================================================
    # Task Correctness Evaluation
    # =========================================================================

    def judge_task_correctness(
        self,
        problem: str,
        ground_truth: str,
        agent_final_answer: str,
        domain: str,
    ) -> Tuple[bool, float, str]:
        """Judge if the agent's answer is correct.

        Returns:
            (is_correct, confidence, reasoning)
        """
        prompt = self._build_task_correctness_prompt(
            problem, ground_truth, agent_final_answer, domain
        )
        response = self._generate(prompt)
        return self._parse_task_correctness_response(response)

    def _build_task_correctness_prompt(
        self,
        problem: str,
        ground_truth: str,
        agent_answer: str,
        domain: str,
    ) -> str:
        """Build prompt for task correctness evaluation."""
        return f"""You are an expert evaluator. Determine if the agent's answer is correct.

DOMAIN: {domain}

PROBLEM:
{problem}

GROUND TRUTH ANSWER:
{ground_truth}

AGENT'S ANSWER:
{agent_answer}

EVALUATION CRITERIA:
- For math: Check if the final numerical answer matches (accounting for equivalent forms)
- For code: Check if the logic is correct and would produce correct output
- For multiple choice: Check if the selected option matches
- For reasoning: Check if the conclusion is correct and reasoning is sound

Respond in JSON format:
{{
    "is_correct": true/false,
    "confidence": 0.0-1.0,
    "reasoning": "Brief explanation"
}}"""

    def _parse_task_correctness_response(self, response: str) -> Tuple[bool, float, str]:
        """Parse the task correctness judgment.

        Uses the embedded JSON object when one is present. Otherwise falls back
        to keyword heuristics that treat negations such as "incorrect" as
        negative. In the fallback case, confidence is 0.5 and the raw response
        is returned as the reasoning.
        """
        data = _extract_json_object(response)
        if data is not None:
            return (
                _coerce_bool(data.get("is_correct", False)),
                _coerce_float(data.get("confidence", 0.5), 0.5),
                str(data.get("reasoning", "")),
            )
        # Default parsing
        lowered = response.lower()
        keyed = re.search(r'is_correct"?\s*[:=]\s*"?(true|false)', lowered)
        if keyed:
            is_correct = keyed.group(1) == "true"
        else:
            negative = re.search(r"\b(incorrect|not correct|wrong|false)\b", lowered)
            positive = re.search(r"\b(correct|true)\b", lowered)
            is_correct = bool(positive) and not negative
        return is_correct, 0.5, response

    # =========================================================================
    # Preference Compliance Evaluation
    # =========================================================================

    def judge_preference_compliance(
        self,
        user_preferences: List[Dict[str, str]],
        query: str,
        agent_response: str,
        context_signals: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Judge if the agent followed applicable preferences.

        Args:
            user_preferences: List of {condition, preference} dicts
            query: The user's query
            agent_response: The agent's response
            context_signals: Optional context cues (rushed, frustrated, etc.)

        Returns:
            Dictionary with compliance details
        """
        prompt = self._build_preference_compliance_prompt(
            user_preferences, query, agent_response, context_signals
        )
        response = self._generate(prompt)
        return self._parse_preference_compliance_response(response, user_preferences)

    def _build_preference_compliance_prompt(
        self,
        preferences: List[Dict[str, str]],
        query: str,
        response: str,
        context: Optional[Dict[str, Any]],
    ) -> str:
        """Build prompt for preference compliance evaluation."""
        prefs_str = "\n".join([
            f"{i+1}. When: {p.get('condition', 'always')} -> {p['preference']}"
            for i, p in enumerate(preferences)
        ])

        context_str = ""
        if context:
            context_str = f"\nCONTEXT SIGNALS: {json.dumps(context)}"

        return f"""You are evaluating if an AI assistant followed user preferences.

USER'S KNOWN PREFERENCES:
{prefs_str}

USER QUERY: {query}{context_str}

AGENT RESPONSE:
{response}

TASK:
1. Identify which preferences APPLY to this query (based on conditions)
2. For each applicable preference, judge if it was FOLLOWED or VIOLATED
3. Check for OVER-PERSONALIZATION (applying preferences that don't apply here)

Respond in JSON format:
{{
    "applicable_preferences": [1, 3, 5],  // preference numbers that apply
    "followed": [1, 5],  // preferences that were correctly followed
    "violated": [3],  // preferences that were violated
    "violation_types": {{"3": "wrong_format"}},  // type of each violation
    "over_personalization": [],  // preferences applied but shouldn't have been
    "compliance_score": 0.67,  // followed / applicable
    "explanation": "Brief explanation of judgment"
}}"""

    def _parse_preference_compliance_response(
        self,
        response: str,
        preferences: List[Dict[str, str]],
    ) -> Dict[str, Any]:
        """Parse preference compliance judgment.

        Returns the judge's JSON object. Nested ``violation_types`` and echoed
        ``//`` comments are tolerated. If no object can be parsed, returns a
        neutral default (score 0.5) with the raw response as the explanation.
        ``preferences`` is currently unused.
        """
        data = _extract_json_object(response)
        if data is not None:
            return data

        # Default
        return {
            "applicable_preferences": [],
            "followed": [],
            "violated": [],
            "violation_types": {},
            "over_personalization": [],
            "compliance_score": 0.5,
            "explanation": response,
        }

    # =========================================================================
    # Conflict Resolution Evaluation
    # =========================================================================

    def judge_conflict_resolution(
        self,
        conflicting_preferences: List[Dict[str, str]],
        expected_preference_idx: int,
        query: str,
        agent_response: str,
    ) -> Dict[str, Any]:
        """Judge if the agent correctly resolved a preference conflict.

        Args:
            conflicting_preferences: List of preferences that conflict
            expected_preference_idx: Index of the preference that SHOULD be applied
            query: The user's query
            agent_response: The agent's response

        Returns:
            Dictionary with conflict resolution details
        """
        prompt = self._build_conflict_resolution_prompt(
            conflicting_preferences, expected_preference_idx, query, agent_response
        )
        response = self._generate(prompt)
        return self._parse_conflict_resolution_response(response)

    def _build_conflict_resolution_prompt(
        self,
        preferences: List[Dict[str, str]],
        expected_idx: int,
        query: str,
        response: str,
    ) -> str:
        """Build prompt for conflict resolution evaluation."""
        prefs_str = "\n".join([
            f"{i+1}. When: {p.get('condition', 'always')} -> {p['preference']}"
            for i, p in enumerate(preferences)
        ])

        return f"""You are evaluating how an AI handled CONFLICTING user preferences.

CONFLICTING PREFERENCES:
{prefs_str}

These preferences cannot all be followed simultaneously. Based on the query context, preference #{expected_idx + 1} is the most appropriate one to follow.

USER QUERY: {query}

AGENT RESPONSE:
{response}

TASK:
1. Determine which preference the agent actually followed
2. Judge if this was the correct choice given the context
3. Note if the agent tried to follow multiple conflicting preferences (confused behavior)

Respond in JSON format:
{{
    "preference_applied": 1,  // which preference number was actually followed (0 if unclear)
    "correct_resolution": true/false,  // did they pick the right one?
    "confused_behavior": false,  // did they try to follow multiple conflicting ones?
    "explanation": "Brief explanation"
}}"""

    def _parse_conflict_resolution_response(self, response: str) -> Dict[str, Any]:
        """Parse conflict resolution judgment; unparseable output counts as a failure."""
        data = _extract_json_object(response)
        if data is not None:
            return data

        return {
            "preference_applied": 0,
            "correct_resolution": False,
            "confused_behavior": True,
            "explanation": response,
        }

    # =========================================================================
    # Full Conversation Evaluation
    # =========================================================================

    def evaluate_conversation(
        self,
        conversation: List[Dict[str, str]],
        user_preferences: List[Dict[str, str]],
        problem: str,
        ground_truth: str,
        domain: str,
        conflict_scenarios: Optional[List[Dict]] = None,
    ) -> ConversationMetrics:
        """Evaluate an entire conversation.

        Each assistant turn after the first position is judged for preference
        compliance exactly once. The same judgment feeds both the per-turn
        compliance scores and the over-personalization rate.

        Args:
            conversation: List of {role, content} turns
            user_preferences: User's conditional preferences
            problem: The original problem
            ground_truth: Expected answer
            domain: Problem domain
            conflict_scenarios: Optional list of known conflict points. Each is
                a dict with "turn_idx", "preferences" and "expected_idx".

        Returns:
            ConversationMetrics with all evaluation results
        """
        # Count tokens (whitespace-separated words)
        user_tokens = sum(
            len(turn["content"].split()) for turn in conversation if turn["role"] == "user"
        )
        agent_tokens = sum(
            len(turn["content"].split()) for turn in conversation if turn["role"] == "assistant"
        )
        total_tokens = user_tokens + agent_tokens

        # Count user effort indicators. The phrases are lowercase because they
        # are matched against lowercased content.
        enforcement_phrases = [
            "please use", "i asked for", "remember that i",
            "as i mentioned", "like i said", "i prefer",
        ]
        disappointment_phrases = [
            "not what i wanted", "that's not", "i said",
            "too long", "too short", "wrong format",
        ]
        enforcement_count = 0
        disappointment_count = 0
        for turn in conversation:
            if turn["role"] != "user":
                continue
            content_lower = turn["content"].lower()
            if any(phrase in content_lower for phrase in enforcement_phrases):
                enforcement_count += 1
            if any(phrase in content_lower for phrase in disappointment_phrases):
                disappointment_count += 1

        # Find the final answer (last assistant turn)
        final_answer = next(
            (turn["content"] for turn in reversed(conversation) if turn["role"] == "assistant"),
            "",
        )

        # Judge task correctness
        task_correct, _task_conf, _ = self.judge_task_correctness(
            problem, ground_truth, final_answer, domain
        )

        # Judge preference compliance once per agent turn, and reuse the result
        # for both compliance and over-personalization.
        compliance_scores: List[float] = []
        over_personalization_count = 0
        for i, turn in enumerate(conversation):
            if turn["role"] != "assistant" or i == 0:
                continue
            # The query is the preceding turn only if the user wrote it.
            previous = conversation[i - 1]
            query = previous["content"] if previous["role"] == "user" else ""
            compliance = self.judge_preference_compliance(
                user_preferences, query, turn["content"]
            )
            compliance_scores.append(
                _coerce_float(compliance.get("compliance_score", 0.5), 0.5)
            )
            if compliance.get("over_personalization"):
                over_personalization_count += 1

        # Judge conflict resolution if applicable
        conflict_accuracy = 1.0
        if conflict_scenarios:
            correct_resolutions = 0
            for scenario in conflict_scenarios:
                target = scenario.get("turn_idx")
                for i, turn in enumerate(conversation):
                    if turn["role"] == "assistant" and target == i:
                        # Guard i == 0 so we never wrap around to conversation[-1].
                        query = conversation[i - 1]["content"] if i > 0 else ""
                        result = self.judge_conflict_resolution(
                            scenario["preferences"],
                            scenario["expected_idx"],
                            query,
                            turn["content"],
                        )
                        if _coerce_bool(result.get("correct_resolution")):
                            correct_resolutions += 1
                        break
            conflict_accuracy = correct_resolutions / len(conflict_scenarios)

        # Rate over the turns that were actually judged
        judged_turns = len(compliance_scores)
        over_personalization_rate = (
            over_personalization_count / judged_turns if judged_turns else 0.0
        )

        return ConversationMetrics(
            task_success=task_correct,
            turns_to_success=len(conversation) if task_correct else -1,
            total_turns=len(conversation),
            user_token_count=user_tokens,
            enforcement_count=enforcement_count,
            disappointment_count=disappointment_count,
            total_token_count=total_tokens,
            agent_token_count=agent_tokens,
            preference_compliance_scores=compliance_scores,
            conflict_resolution_accuracy=conflict_accuracy,
            over_personalization_rate=over_personalization_rate,
        )


class BatchEvaluator:
    """Batch evaluation across multiple conversations and methods."""

    # Metrics where a smaller value is better when choosing the best method.
    _LOWER_IS_BETTER = frozenset({
        "avg_user_tokens",
        "avg_total_tokens",
        "over_personalization_rate",
    })

    def __init__(self, judge: LLMJudge):
        self.judge = judge

    def evaluate_method(
        self,
        method_name: str,
        conversations: List[Dict],
        user_profiles: List[Dict],
        problems: List[Dict],
    ) -> Dict[str, Any]:
        """Evaluate a method across all conversations.

        ``conversations``, ``user_profiles`` and ``problems`` are paired
        positionally (``zip``); extra items in a longer list are ignored.
        With no conversations, every average is reported as 0.0.

        Returns aggregated metrics.
        """
        all_metrics: List[ConversationMetrics] = [
            self.judge.evaluate_conversation(
                conversation=conv["turns"],
                user_preferences=profile["preferences"],
                problem=problem["problem"],
                ground_truth=problem["solution"],
                domain=problem["domain"],
                conflict_scenarios=conv.get("conflict_scenarios"),
            )
            for conv, profile, problem in zip(conversations, user_profiles, problems)
        ]

        # Aggregate
        return {
            "method": method_name,
            "n_conversations": len(all_metrics),
            "task_success_rate": _mean(float(m.task_success) for m in all_metrics),
            "avg_turns": _mean(m.total_turns for m in all_metrics),
            "avg_user_tokens": _mean(m.user_token_count for m in all_metrics),
            "avg_total_tokens": _mean(m.total_token_count for m in all_metrics),
            "avg_enforcement_count": _mean(m.enforcement_count for m in all_metrics),
            "avg_disappointment_count": _mean(m.disappointment_count for m in all_metrics),
            # Per-conversation mean first (0.5 if a conversation had no judged turns)
            "avg_compliance_score": _mean(
                _mean(m.preference_compliance_scores, 0.5) for m in all_metrics
            ),
            "conflict_resolution_accuracy": _mean(
                m.conflict_resolution_accuracy for m in all_metrics
            ),
            "over_personalization_rate": _mean(
                m.over_personalization_rate for m in all_metrics
            ),
        }

    def compare_methods(
        self,
        results_by_method: Dict[str, List[Dict]],
        user_profiles: List[Dict],
        problems: List[Dict],
    ) -> Dict[str, Dict]:
        """Compare multiple methods.

        Args:
            results_by_method: {method_name: list of conversation results}

        Returns:
            Comparative analysis. For each metric, "best" is the method with
            the highest value, or the lowest for token counts and
            over-personalization. "best" is None when no methods are given.
        """
        method_metrics = {
            method_name: self.evaluate_method(
                method_name, conversations, user_profiles, problems
            )
            for method_name, conversations in results_by_method.items()
        }

        # Add comparative analysis
        metrics_to_compare = [
            "task_success_rate",
            "avg_user_tokens",
            "avg_total_tokens",
            "avg_compliance_score",
            "conflict_resolution_accuracy",
            "over_personalization_rate",
        ]

        comparison = {}
        for metric in metrics_to_compare:
            values = {name: metrics[metric] for name, metrics in method_metrics.items()}
            if not values:
                comparison[metric] = {"values": values, "best": None, "best_value": None}
                continue
            pick = min if metric in self._LOWER_IS_BETTER else max
            best_method = pick(values, key=values.get)
            comparison[metric] = {
                "values": values,
                "best": best_method,
                "best_value": values[best_method],
            }

        return {
            "per_method": method_metrics,
            "comparison": comparison,
        }


# =============================================================================
# User Effort Analysis
# =============================================================================

def analyze_user_effort(conversation: List[Dict[str, str]]) -> Dict[str, Any]:
    """Detailed analysis of user effort in a conversation.

    Each user turn is checked against regex heuristics per category. A turn is
    counted at most once per category but may appear in several categories.
    Matching is case-insensitive.

    Returns:
        Dictionary with effort metrics and categorization
    """
    effort_categories: Dict[str, List[str]] = {
        "preference_enforcement": [],  # explicit preference reminders
        "clarification_requests": [],  # asking to clarify
        "corrections": [],  # correcting mistakes
        "rephrasing": [],  # saying same thing differently (NOTE: no detector yet; stays empty)
        "frustration": [],  # expressing frustration
    }

    # Word boundaries keep short keywords from matching inside other words
    # (e.g. "ugh" in "thought", "use" in "because", "format" in "information").
    enforcement_patterns = [
        r"(please |can you |could you )?\b(use|give me|format|write)\b",
        r"\b(I |we )(prefer|want|need|asked)",
        r"(like I |as I )(said|mentioned|asked)",
        r"remember (that |to )",
    ]
    clarification_patterns = [
        r"what do you mean",
        r"(can you |could you )?(explain|clarify)",
        r"I don't understand",
        r"(not sure|unclear)",
    ]
    correction_patterns = [
        r"(that's |this is )(not |wrong|incorrect)",
        r"\b(no|actually),? (I |the |it )",
        r"you (missed|forgot|ignored)",
    ]
    frustration_patterns = [
        r"\b(ugh|sigh|argh)\b",
        r"\b(frustrat|annoy|confus)",
        r"why (can't|won't|don't) you",
        r"this is (hard|difficult|impossible)",
    ]

    # IGNORECASE: several patterns contain a capital "I".
    detectors = [
        (category, [re.compile(p, re.IGNORECASE) for p in patterns])
        for category, patterns in (
            ("preference_enforcement", enforcement_patterns),
            ("clarification_requests", clarification_patterns),
            ("corrections", correction_patterns),
            ("frustration", frustration_patterns),
        )
    ]

    for turn in conversation:
        if turn["role"] != "user":
            continue
        text = turn["content"]
        for category, patterns in detectors:
            if any(pattern.search(text) for pattern in patterns):
                effort_categories[category].append(text)

    # Calculate effort score (weighted)
    weights = {
        "preference_enforcement": 2.0,
        "clarification_requests": 1.0,
        "corrections": 3.0,
        "rephrasing": 1.5,
        "frustration": 2.5,
    }
    effort_score = sum(
        len(items) * weights[cat] for cat, items in effort_categories.items()
    )

    return {
        "categories": effort_categories,
        "counts": {cat: len(items) for cat, items in effort_categories.items()},
        "total_effort_instances": sum(len(items) for items in effort_categories.values()),
        "weighted_effort_score": effort_score,
    }


if __name__ == "__main__":
    # Test the judge
    judge = LLMJudge()

    # Test task correctness
    print("Testing task correctness judgment...")
    correct, conf, reason = judge.judge_task_correctness(
        problem="What is 2 + 2?",
        ground_truth="4",
        agent_final_answer="The answer is 4.",
        domain="math",
    )
    print(f"Correct: {correct}, Confidence: {conf}")

    # Test preference compliance
    print("\nTesting preference compliance...")
    preferences = [
        {"condition": "always", "preference": "Use bullet points for lists"},
        {"condition": "when explaining", "preference": "Include examples"},
    ]
    compliance = judge.judge_preference_compliance(
        user_preferences=preferences,
        query="Explain how to make coffee",
        agent_response="Here's how to make coffee:\n- Boil water\n- Add grounds\n- Pour and enjoy",
    )
    print(f"Compliance: {compliance}")