diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/evaluation/llm_judge.py | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/evaluation/llm_judge.py')
| -rw-r--r-- | collaborativeagents/evaluation/llm_judge.py | 748 |
1 files changed, 748 insertions, 0 deletions
diff --git a/collaborativeagents/evaluation/llm_judge.py b/collaborativeagents/evaluation/llm_judge.py new file mode 100644 index 0000000..e6b90e8 --- /dev/null +++ b/collaborativeagents/evaluation/llm_judge.py @@ -0,0 +1,748 @@ +""" +LLM-as-Judge evaluation framework for personalization experiments. + +Evaluates: +1. Task accuracy (did the agent solve the problem?) +2. Preference compliance (did the agent follow user preferences?) +3. Conflict resolution (did the agent pick the right preference in conflicts?) +4. User effort (how much did the user have to correct the agent?) +""" + +import json +import re +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass, field +from enum import Enum +import sys +import os + +# Add parent to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class PreferenceViolationType(Enum): + """Types of preference violations.""" + WRONG_FORMAT = "wrong_format" # bullets vs numbered, etc. + WRONG_VERBOSITY = "wrong_verbosity" # too long/short + WRONG_STYLE = "wrong_style" # code style, explanation style + WRONG_TONE = "wrong_tone" # too casual/formal + CONFLICT_WRONG = "conflict_wrong" # picked wrong preference in conflict + OVER_PERSONALIZATION = "over_personalization" # applied irrelevant preference + + +@dataclass +class JudgmentResult: + """Result of LLM judge evaluation.""" + # Task success + task_correct: bool + task_confidence: float # 0-1 + + # Preference compliance + preferences_followed: List[str] + preferences_violated: List[Tuple[str, PreferenceViolationType]] + compliance_score: float # 0-1 + + # Conflict resolution (if applicable) + conflict_present: bool = False + conflict_resolved_correctly: bool = False + expected_preference: Optional[str] = None + applied_preference: Optional[str] = None + + # Over-personalization detection + over_personalized: bool = False + irrelevant_preferences_applied: List[str] = field(default_factory=list) + + # Raw judge outputs + raw_judgments: Dict[str, str] = field(default_factory=dict) + + +@dataclass +class ConversationMetrics: + """Metrics for a full conversation.""" + # Task metrics + task_success: bool + turns_to_success: int + total_turns: int + + # User effort + user_token_count: int + enforcement_count: int # times user had to enforce preferences + disappointment_count: int # times user expressed disappointment + + # Efficiency + total_token_count: int + agent_token_count: int + + # Preference metrics + preference_compliance_scores: List[float] # per-turn + conflict_resolution_accuracy: float # across all conflicts + over_personalization_rate: float + + +class LLMJudge: + """LLM-based judge for evaluating personalization quality.""" + + def __init__( + self, + model_name: str = "meta-llama/Llama-3.3-70B-Instruct", + temperature: float = 0.0, + max_tokens: int = 1024, + ): + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + self._client = None + + def _get_client(self): + """Lazy initialization of LLM client.""" + if self._client is None: + # Try vLLM first, then fall back to HuggingFace + try: + from vllm import LLM, SamplingParams + self._client_type = "vllm" + self._client = LLM(model=self.model_name) + self._sampling_params = SamplingParams( + temperature=self.temperature, + max_tokens=self.max_tokens + ) + except ImportError: + try: + from transformers import AutoTokenizer, AutoModelForCausalLM + import torch + self._client_type = "hf" + self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self._client = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + device_map="auto" + ) + except Exception as e: + print(f"Warning: Could not load LLM judge model: {e}") + self._client_type = "mock" + self._client = "mock" + + return self._client + + def _generate(self, prompt: str) -> str: + """Generate response from judge LLM.""" + client = self._get_client() + + if self._client_type == "vllm": + outputs = client.generate([prompt], self._sampling_params) + return outputs[0].outputs[0].text + + elif self._client_type == "hf": + inputs = self._tokenizer(prompt, return_tensors="pt").to(client.device) + outputs = client.generate( + **inputs, + max_new_tokens=self.max_tokens, + temperature=self.temperature if self.temperature > 0 else None, + do_sample=self.temperature > 0 + ) + return self._tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + + else: + # Mock for testing + return '{"task_correct": true, "confidence": 0.8}' + + # ========================================================================= + # Task Correctness Evaluation + # ========================================================================= + + def judge_task_correctness( + self, + problem: str, + ground_truth: str, + agent_final_answer: str, + domain: str + ) -> Tuple[bool, float, str]: + """Judge if the agent's answer is correct. + + Returns: + (is_correct, confidence, reasoning) + """ + prompt = self._build_task_correctness_prompt( + problem, ground_truth, agent_final_answer, domain + ) + + response = self._generate(prompt) + return self._parse_task_correctness_response(response) + + def _build_task_correctness_prompt( + self, + problem: str, + ground_truth: str, + agent_answer: str, + domain: str + ) -> str: + """Build prompt for task correctness evaluation.""" + return f"""You are an expert evaluator. Determine if the agent's answer is correct. + +DOMAIN: {domain} + +PROBLEM: +{problem} + +GROUND TRUTH ANSWER: +{ground_truth} + +AGENT'S ANSWER: +{agent_answer} + +EVALUATION CRITERIA: +- For math: Check if the final numerical answer matches (accounting for equivalent forms) +- For code: Check if the logic is correct and would produce correct output +- For multiple choice: Check if the selected option matches +- For reasoning: Check if the conclusion is correct and reasoning is sound + +Respond in JSON format: +{{ + "is_correct": true/false, + "confidence": 0.0-1.0, + "reasoning": "Brief explanation" +}}""" + + def _parse_task_correctness_response(self, response: str) -> Tuple[bool, float, str]: + """Parse the task correctness judgment.""" + try: + # Extract JSON from response + json_match = re.search(r'\{[^}]+\}', response, re.DOTALL) + if json_match: + data = json.loads(json_match.group()) + return ( + data.get("is_correct", False), + data.get("confidence", 0.5), + data.get("reasoning", "") + ) + except Exception: + pass + + # Default parsing + is_correct = "true" in response.lower() or "correct" in response.lower() + return is_correct, 0.5, response + + # ========================================================================= + # Preference Compliance Evaluation + # ========================================================================= + + def judge_preference_compliance( + self, + user_preferences: List[Dict[str, str]], + query: str, + agent_response: str, + context_signals: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Judge if the agent followed applicable preferences. + + Args: + user_preferences: List of {condition, preference} dicts + query: The user's query + agent_response: The agent's response + context_signals: Optional context cues (rushed, frustrated, etc.) + + Returns: + Dictionary with compliance details + """ + prompt = self._build_preference_compliance_prompt( + user_preferences, query, agent_response, context_signals + ) + + response = self._generate(prompt) + return self._parse_preference_compliance_response(response, user_preferences) + + def _build_preference_compliance_prompt( + self, + preferences: List[Dict[str, str]], + query: str, + response: str, + context: Optional[Dict[str, Any]] + ) -> str: + """Build prompt for preference compliance evaluation.""" + prefs_str = "\n".join([ + f"{i+1}. When: {p.get('condition', 'always')} -> {p['preference']}" + for i, p in enumerate(preferences) + ]) + + context_str = "" + if context: + context_str = f"\nCONTEXT SIGNALS: {json.dumps(context)}" + + return f"""You are evaluating if an AI assistant followed user preferences. + +USER'S KNOWN PREFERENCES: +{prefs_str} + +USER QUERY: +{query} +{context_str} + +AGENT RESPONSE: +{response} + +TASK: +1. Identify which preferences APPLY to this query (based on conditions) +2. For each applicable preference, judge if it was FOLLOWED or VIOLATED +3. Check for OVER-PERSONALIZATION (applying preferences that don't apply here) + +Respond in JSON format: +{{ + "applicable_preferences": [1, 3, 5], // preference numbers that apply + "followed": [1, 5], // preferences that were correctly followed + "violated": [3], // preferences that were violated + "violation_types": {{"3": "wrong_format"}}, // type of each violation + "over_personalization": [], // preferences applied but shouldn't have been + "compliance_score": 0.67, // followed / applicable + "explanation": "Brief explanation of judgment" +}}""" + + def _parse_preference_compliance_response( + self, + response: str, + preferences: List[Dict[str, str]] + ) -> Dict[str, Any]: + """Parse preference compliance judgment.""" + try: + json_match = re.search(r'\{[^}]*\}', response, re.DOTALL) + if json_match: + return json.loads(json_match.group()) + except Exception: + pass + + # Default + return { + "applicable_preferences": [], + "followed": [], + "violated": [], + "violation_types": {}, + "over_personalization": [], + "compliance_score": 0.5, + "explanation": response + } + + # ========================================================================= + # Conflict Resolution Evaluation + # ========================================================================= + + def judge_conflict_resolution( + self, + conflicting_preferences: List[Dict[str, str]], + expected_preference_idx: int, + query: str, + agent_response: str + ) -> Dict[str, Any]: + """Judge if the agent correctly resolved a preference conflict. + + Args: + conflicting_preferences: List of preferences that conflict + expected_preference_idx: Index of the preference that SHOULD be applied + query: The user's query + agent_response: The agent's response + + Returns: + Dictionary with conflict resolution details + """ + prompt = self._build_conflict_resolution_prompt( + conflicting_preferences, expected_preference_idx, query, agent_response + ) + + response = self._generate(prompt) + return self._parse_conflict_resolution_response(response) + + def _build_conflict_resolution_prompt( + self, + preferences: List[Dict[str, str]], + expected_idx: int, + query: str, + response: str + ) -> str: + """Build prompt for conflict resolution evaluation.""" + prefs_str = "\n".join([ + f"{i+1}. When: {p.get('condition', 'always')} -> {p['preference']}" + for i, p in enumerate(preferences) + ]) + + return f"""You are evaluating how an AI handled CONFLICTING user preferences. + +CONFLICTING PREFERENCES: +{prefs_str} + +These preferences cannot all be followed simultaneously. Based on the query context, +preference #{expected_idx + 1} is the most appropriate one to follow. + +USER QUERY: +{query} + +AGENT RESPONSE: +{response} + +TASK: +1. Determine which preference the agent actually followed +2. Judge if this was the correct choice given the context +3. Note if the agent tried to follow multiple conflicting preferences (confused behavior) + +Respond in JSON format: +{{ + "preference_applied": 1, // which preference number was actually followed (0 if unclear) + "correct_resolution": true/false, // did they pick the right one? + "confused_behavior": false, // did they try to follow multiple conflicting ones? + "explanation": "Brief explanation" +}}""" + + def _parse_conflict_resolution_response(self, response: str) -> Dict[str, Any]: + """Parse conflict resolution judgment.""" + try: + json_match = re.search(r'\{[^}]*\}', response, re.DOTALL) + if json_match: + return json.loads(json_match.group()) + except Exception: + pass + + return { + "preference_applied": 0, + "correct_resolution": False, + "confused_behavior": True, + "explanation": response + } + + # ========================================================================= + # Full Conversation Evaluation + # ========================================================================= + + def evaluate_conversation( + self, + conversation: List[Dict[str, str]], + user_preferences: List[Dict[str, str]], + problem: str, + ground_truth: str, + domain: str, + conflict_scenarios: Optional[List[Dict]] = None + ) -> ConversationMetrics: + """Evaluate an entire conversation. + + Args: + conversation: List of {role, content} turns + user_preferences: User's conditional preferences + problem: The original problem + ground_truth: Expected answer + domain: Problem domain + conflict_scenarios: Optional list of known conflict points + + Returns: + ConversationMetrics with all evaluation results + """ + # Count tokens + user_tokens = sum( + len(turn["content"].split()) + for turn in conversation if turn["role"] == "user" + ) + agent_tokens = sum( + len(turn["content"].split()) + for turn in conversation if turn["role"] == "assistant" + ) + total_tokens = user_tokens + agent_tokens + + # Count user effort indicators + enforcement_count = 0 + disappointment_count = 0 + enforcement_phrases = [ + "please use", "I asked for", "remember that I", + "as I mentioned", "like I said", "I prefer" + ] + disappointment_phrases = [ + "not what I wanted", "that's not", "I said", + "too long", "too short", "wrong format" + ] + + for turn in conversation: + if turn["role"] == "user": + content_lower = turn["content"].lower() + if any(phrase in content_lower for phrase in enforcement_phrases): + enforcement_count += 1 + if any(phrase in content_lower for phrase in disappointment_phrases): + disappointment_count += 1 + + # Find the final answer + final_answer = "" + for turn in reversed(conversation): + if turn["role"] == "assistant": + final_answer = turn["content"] + break + + # Judge task correctness + task_correct, task_conf, _ = self.judge_task_correctness( + problem, ground_truth, final_answer, domain + ) + + # Judge preference compliance for each agent turn + compliance_scores = [] + for i, turn in enumerate(conversation): + if turn["role"] == "assistant" and i > 0: + # Get the query that preceded this response + query = conversation[i-1]["content"] if conversation[i-1]["role"] == "user" else "" + compliance = self.judge_preference_compliance( + user_preferences, query, turn["content"] + ) + compliance_scores.append(compliance.get("compliance_score", 0.5)) + + # Judge conflict resolution if applicable + conflict_accuracy = 1.0 + if conflict_scenarios: + correct_resolutions = 0 + for scenario in conflict_scenarios: + # Find the relevant turn + for i, turn in enumerate(conversation): + if turn["role"] == "assistant" and scenario.get("turn_idx") == i: + result = self.judge_conflict_resolution( + scenario["preferences"], + scenario["expected_idx"], + conversation[i-1]["content"], + turn["content"] + ) + if result.get("correct_resolution"): + correct_resolutions += 1 + break + + conflict_accuracy = correct_resolutions / len(conflict_scenarios) if conflict_scenarios else 1.0 + + # Calculate over-personalization rate + over_personalization_count = 0 + for i, turn in enumerate(conversation): + if turn["role"] == "assistant" and i > 0: + query = conversation[i-1]["content"] if conversation[i-1]["role"] == "user" else "" + compliance = self.judge_preference_compliance( + user_preferences, query, turn["content"] + ) + if compliance.get("over_personalization"): + over_personalization_count += 1 + + agent_turns = sum(1 for t in conversation if t["role"] == "assistant") + over_personalization_rate = over_personalization_count / agent_turns if agent_turns > 0 else 0 + + return ConversationMetrics( + task_success=task_correct, + turns_to_success=len(conversation) if task_correct else -1, + total_turns=len(conversation), + user_token_count=user_tokens, + enforcement_count=enforcement_count, + disappointment_count=disappointment_count, + total_token_count=total_tokens, + agent_token_count=agent_tokens, + preference_compliance_scores=compliance_scores, + conflict_resolution_accuracy=conflict_accuracy, + over_personalization_rate=over_personalization_rate + ) + + +class BatchEvaluator: + """Batch evaluation across multiple conversations and methods.""" + + def __init__(self, judge: LLMJudge): + self.judge = judge + + def evaluate_method( + self, + method_name: str, + conversations: List[Dict], + user_profiles: List[Dict], + problems: List[Dict] + ) -> Dict[str, Any]: + """Evaluate a method across all conversations. + + Returns aggregated metrics. + """ + all_metrics = [] + + for conv, profile, problem in zip(conversations, user_profiles, problems): + metrics = self.judge.evaluate_conversation( + conversation=conv["turns"], + user_preferences=profile["preferences"], + problem=problem["problem"], + ground_truth=problem["solution"], + domain=problem["domain"], + conflict_scenarios=conv.get("conflict_scenarios") + ) + all_metrics.append(metrics) + + # Aggregate + n = len(all_metrics) + return { + "method": method_name, + "n_conversations": n, + "task_success_rate": sum(m.task_success for m in all_metrics) / n, + "avg_turns": sum(m.total_turns for m in all_metrics) / n, + "avg_user_tokens": sum(m.user_token_count for m in all_metrics) / n, + "avg_total_tokens": sum(m.total_token_count for m in all_metrics) / n, + "avg_enforcement_count": sum(m.enforcement_count for m in all_metrics) / n, + "avg_disappointment_count": sum(m.disappointment_count for m in all_metrics) / n, + "avg_compliance_score": sum( + sum(m.preference_compliance_scores) / len(m.preference_compliance_scores) + if m.preference_compliance_scores else 0.5 + for m in all_metrics + ) / n, + "conflict_resolution_accuracy": sum(m.conflict_resolution_accuracy for m in all_metrics) / n, + "over_personalization_rate": sum(m.over_personalization_rate for m in all_metrics) / n, + } + + def compare_methods( + self, + results_by_method: Dict[str, List[Dict]], + user_profiles: List[Dict], + problems: List[Dict] + ) -> Dict[str, Dict]: + """Compare multiple methods. + + Args: + results_by_method: {method_name: list of conversation results} + + Returns: + Comparative analysis + """ + method_metrics = {} + + for method_name, conversations in results_by_method.items(): + method_metrics[method_name] = self.evaluate_method( + method_name, conversations, user_profiles, problems + ) + + # Add comparative analysis + metrics_to_compare = [ + "task_success_rate", "avg_user_tokens", "avg_total_tokens", + "avg_compliance_score", "conflict_resolution_accuracy", + "over_personalization_rate" + ] + + comparison = {} + for metric in metrics_to_compare: + values = {m: method_metrics[m][metric] for m in method_metrics} + best_method = max(values, key=values.get) if "rate" in metric or "score" in metric or "accuracy" in metric else min(values, key=values.get) + comparison[metric] = { + "values": values, + "best": best_method, + "best_value": values[best_method] + } + + return { + "per_method": method_metrics, + "comparison": comparison + } + + +# ============================================================================= +# User Effort Analysis +# ============================================================================= + +def analyze_user_effort(conversation: List[Dict[str, str]]) -> Dict[str, Any]: + """Detailed analysis of user effort in a conversation. + + Returns: + Dictionary with effort metrics and categorization + """ + effort_categories = { + "preference_enforcement": [], # explicit preference reminders + "clarification_requests": [], # asking to clarify + "corrections": [], # correcting mistakes + "rephrasing": [], # saying same thing differently + "frustration": [], # expressing frustration + } + + enforcement_patterns = [ + r"(please |can you |could you )?(use|give me|format|write)", + r"(I |we )(prefer|want|need|asked)", + r"(like I |as I )(said|mentioned|asked)", + r"remember (that |to )", + ] + + clarification_patterns = [ + r"what do you mean", + r"(can you |could you )?(explain|clarify)", + r"I don't understand", + r"(not sure|unclear)", + ] + + correction_patterns = [ + r"(that's |this is )(not |wrong|incorrect)", + r"(no|actually),? (I |the |it )", + r"you (missed|forgot|ignored)", + ] + + frustration_patterns = [ + r"(ugh|sigh|argh)", + r"(frustrat|annoy|confus)", + r"why (can't|won't|don't) you", + r"this is (hard|difficult|impossible)", + ] + + for turn in conversation: + if turn["role"] != "user": + continue + + content = turn["content"].lower() + + for pattern in enforcement_patterns: + if re.search(pattern, content): + effort_categories["preference_enforcement"].append(turn["content"]) + break + + for pattern in clarification_patterns: + if re.search(pattern, content): + effort_categories["clarification_requests"].append(turn["content"]) + break + + for pattern in correction_patterns: + if re.search(pattern, content): + effort_categories["corrections"].append(turn["content"]) + break + + for pattern in frustration_patterns: + if re.search(pattern, content): + effort_categories["frustration"].append(turn["content"]) + break + + # Calculate effort score (weighted) + weights = { + "preference_enforcement": 2.0, + "clarification_requests": 1.0, + "corrections": 3.0, + "rephrasing": 1.5, + "frustration": 2.5, + } + + effort_score = sum( + len(items) * weights[cat] + for cat, items in effort_categories.items() + ) + + return { + "categories": effort_categories, + "counts": {cat: len(items) for cat, items in effort_categories.items()}, + "total_effort_instances": sum(len(items) for items in effort_categories.values()), + "weighted_effort_score": effort_score, + } + + +if __name__ == "__main__": + # Test the judge + judge = LLMJudge() + + # Test task correctness + print("Testing task correctness judgment...") + correct, conf, reason = judge.judge_task_correctness( + problem="What is 2 + 2?", + ground_truth="4", + agent_final_answer="The answer is 4.", + domain="math" + ) + print(f"Correct: {correct}, Confidence: {conf}") + + # Test preference compliance + print("\nTesting preference compliance...") + preferences = [ + {"condition": "always", "preference": "Use bullet points for lists"}, + {"condition": "when explaining", "preference": "Include examples"}, + ] + compliance = judge.judge_preference_compliance( + user_preferences=preferences, + query="Explain how to make coffee", + agent_response="Here's how to make coffee:\n- Boil water\n- Add grounds\n- Pour and enjoy" + ) + print(f"Compliance: {compliance}") |
