diff options
Diffstat (limited to 'src/personalization/evaluation/preference_bank')
3 files changed, 683 insertions, 0 deletions
diff --git a/src/personalization/evaluation/preference_bank/__init__.py b/src/personalization/evaluation/preference_bank/__init__.py new file mode 100644 index 0000000..33f0ed2 --- /dev/null +++ b/src/personalization/evaluation/preference_bank/__init__.py @@ -0,0 +1,6 @@ +from .schemas import PreferenceItem, PreferenceTopic, PreferenceBank +from .generator import PreferenceBankGenerator + +__all__ = ["PreferenceItem", "PreferenceTopic", "PreferenceBank", "PreferenceBankGenerator"] + + diff --git a/src/personalization/evaluation/preference_bank/generator.py b/src/personalization/evaluation/preference_bank/generator.py new file mode 100644 index 0000000..e256b86 --- /dev/null +++ b/src/personalization/evaluation/preference_bank/generator.py @@ -0,0 +1,530 @@ +""" +Preference Bank Generator + +Uses LLM to automatically generate diverse user preferences for each topic. +""" + +import json +import os +from typing import List, Dict, Any, Optional +from dataclasses import dataclass + +from .schemas import PreferenceItem, PreferenceTopic, PreferenceBank + + +# Topic definitions for the demo (5 topics) +DEMO_TOPICS = { + "math_formatting": { + "description": "How mathematical content should be formatted (LaTeX, plain text, markdown)", + "related_datasets": ["math-hard", "math-500", "gpqa"], + "generation_hints": [ + "LaTeX formatting for equations", + "Plain text vs mathematical notation", + "Inline vs block equations", + "Step-by-step calculation display", + "Variable naming conventions", + ], + }, + "coding_style": { + "description": "Preferences for code formatting, language choice, and documentation", + "related_datasets": ["humaneval", "bigcodebench"], + "generation_hints": [ + "Programming language preference (Python, JavaScript, etc.)", + "Type hints and annotations", + "Docstrings and comments", + "Code structure and organization", + "Naming conventions", + ], + }, + "response_structure": { + "description": "How responses should be organized (bullets, numbered lists, prose)", + "related_datasets": ["all"], + "generation_hints": [ + "Bullet points vs numbered lists vs prose", + "Headers and sections", + "TL;DR summaries", + "Outline before detailed explanation", + "Logical flow and transitions", + ], + }, + "explanation_depth": { + "description": "Level of detail and thoroughness in explanations", + "related_datasets": ["all"], + "generation_hints": [ + "Concise vs comprehensive", + "Examples and analogies", + "Background context", + "Assumptions stated explicitly", + "Multiple approaches/alternatives", + ], + }, + "interaction_style": { + "description": "How the agent should interact (questions, confirmations, suggestions)", + "related_datasets": ["all"], + "generation_hints": [ + "Asking clarifying questions", + "Step-by-step vs holistic answers", + "Proactive suggestions", + "Confidence levels in answers", + "Politeness and tone", + ], + }, +} + + +# LLM prompt template for generating preferences +GENERATION_PROMPT = '''You are helping design a user preference benchmark. Generate {num_prefs} diverse user preferences for the topic: "{topic_name}" + +Topic Description: {topic_description} + +Hints for preference types: +{hints} + +For each preference, provide a JSON object with: +1. "condition": When this preference applies (e.g., "when solving math problems", "when explaining code") +2. "action": What the user prefers (be specific and enforceable) +3. "conflict_group": If this preference conflicts with others in the list, give them the same group name (e.g., "notation_style"). Use null if no conflict. +4. "enforce_description": How a user would detect violation and enforce this preference +5. "example_violation": A concrete example of an agent response that violates this +6. "example_compliance": A concrete example that follows this preference + +Requirements: +- Make preferences SPECIFIC and ENFORCEABLE (not vague like "be helpful") +- Include 2-3 pairs of CONFLICTING preferences (same conflict_group) - this is important for testing RAG +- Vary specificity: some broad ("always use Python"), some narrow ("use f-strings for string formatting in Python") +- Preferences should be realistic things users actually care about + +Output as a JSON array of objects. Only output the JSON array, no other text. +''' + + +class PreferenceBankGenerator: + """Generates a preference bank using LLM.""" + + def __init__( + self, + llm_client: Any = None, + model_name: str = "gpt-4o-mini", # Default to a capable but fast model + ): + """ + Args: + llm_client: OpenAI-compatible client. If None, will create one. + model_name: Model to use for generation. + """ + self.model_name = model_name + + if llm_client is None: + try: + import openai + self.client = openai.OpenAI() + except Exception as e: + print(f"Warning: Could not initialize OpenAI client: {e}") + self.client = None + else: + self.client = llm_client + + def generate_preferences_for_topic( + self, + topic_name: str, + topic_description: str, + hints: List[str], + num_prefs: int = 5, + ) -> List[PreferenceItem]: + """Generate preferences for a single topic using LLM.""" + + if self.client is None: + print(f"No LLM client available, using fallback for topic: {topic_name}") + return self._generate_fallback_preferences(topic_name, num_prefs) + + hints_text = "\n".join(f"- {h}" for h in hints) + + prompt = GENERATION_PROMPT.format( + num_prefs=num_prefs, + topic_name=topic_name, + topic_description=topic_description, + hints=hints_text, + ) + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "user", "content": prompt}], + temperature=0.8, + max_tokens=4000, + ) + + content = response.choices[0].message.content.strip() + + # Parse JSON + # Handle potential markdown code blocks + if content.startswith("```"): + content = content.split("```")[1] + if content.startswith("json"): + content = content[4:] + + prefs_data = json.loads(content) + + # Convert to PreferenceItem objects + preferences = [] + for i, pref_dict in enumerate(prefs_data): + pref_id = f"{topic_name[:4]}_{i+1:03d}" + pref = PreferenceItem( + id=pref_id, + topic=topic_name, + condition=pref_dict.get("condition", ""), + action=pref_dict.get("action", ""), + conflict_group=pref_dict.get("conflict_group"), + enforce_description=pref_dict.get("enforce_description", ""), + example_violation=pref_dict.get("example_violation", ""), + example_compliance=pref_dict.get("example_compliance", ""), + ) + preferences.append(pref) + + return preferences + + except Exception as e: + print(f"Error generating preferences for {topic_name}: {e}") + return self._generate_fallback_preferences(topic_name, num_prefs) + + def _generate_fallback_preferences( + self, + topic_name: str, + num_prefs: int = 5, + ) -> List[PreferenceItem]: + """Generate hardcoded fallback preferences when LLM is not available.""" + + fallbacks = { + "math_formatting": [ + PreferenceItem( + id="math_001", topic="math_formatting", + condition="solving math problems", + action="use LaTeX for all formulas and equations", + conflict_group="math_notation", + enforce_description="Check if mathematical expressions use LaTeX syntax like $x^2$ or $$\\int$$", + example_violation="The answer is x squared plus 2x plus 1", + example_compliance="The answer is $x^2 + 2x + 1$", + ), + PreferenceItem( + id="math_002", topic="math_formatting", + condition="explaining mathematical concepts", + action="use plain text only, avoid any mathematical notation", + conflict_group="math_notation", + enforce_description="Check if response contains any LaTeX or special math symbols", + example_violation="We need to find $\\frac{d}{dx}(x^2)$", + example_compliance="We need to find the derivative of x squared", + ), + PreferenceItem( + id="math_003", topic="math_formatting", + condition="showing multi-step calculations", + action="display each step on a separate line with clear labels", + conflict_group=None, + enforce_description="Check if steps are on separate lines with labels like 'Step 1:'", + example_violation="First we add 2+3=5, then multiply by 4 to get 20", + example_compliance="Step 1: Add 2 + 3 = 5\nStep 2: Multiply by 4: 5 × 4 = 20", + ), + PreferenceItem( + id="math_004", topic="math_formatting", + condition="presenting final answers", + action="clearly box or highlight the final answer", + conflict_group=None, + enforce_description="Check if final answer is visually distinguished", + example_violation="So x equals 5.", + example_compliance="**Final Answer: x = 5**", + ), + PreferenceItem( + id="math_005", topic="math_formatting", + condition="solving problems with multiple variables", + action="use single-letter variables (x, y, z) rather than descriptive names", + conflict_group="var_naming", + enforce_description="Check if variables are single letters", + example_violation="Let price = 100 and quantity = 5", + example_compliance="Let p = 100 and q = 5", + ), + ], + "coding_style": [ + PreferenceItem( + id="code_001", topic="coding_style", + condition="providing code examples", + action="always use Python", + conflict_group="language", + enforce_description="Check if code is written in Python", + example_violation="```javascript\nfunction add(a, b) { return a + b; }\n```", + example_compliance="```python\ndef add(a, b):\n return a + b\n```", + ), + PreferenceItem( + id="code_002", topic="coding_style", + condition="providing code examples", + action="always use JavaScript or TypeScript", + conflict_group="language", + enforce_description="Check if code is written in JavaScript/TypeScript", + example_violation="```python\ndef add(a, b): return a + b\n```", + example_compliance="```javascript\nconst add = (a, b) => a + b;\n```", + ), + PreferenceItem( + id="code_003", topic="coding_style", + condition="writing Python functions", + action="always include type hints for parameters and return values", + conflict_group=None, + enforce_description="Check if function has type hints", + example_violation="def add(a, b):\n return a + b", + example_compliance="def add(a: int, b: int) -> int:\n return a + b", + ), + PreferenceItem( + id="code_004", topic="coding_style", + condition="writing functions", + action="include a docstring explaining the function", + conflict_group=None, + enforce_description="Check if function has a docstring", + example_violation="def add(a, b):\n return a + b", + example_compliance='def add(a, b):\n """Add two numbers and return the result."""\n return a + b', + ), + PreferenceItem( + id="code_005", topic="coding_style", + condition="writing code", + action="minimize comments, code should be self-documenting", + conflict_group="comment_style", + enforce_description="Check if there are excessive inline comments", + example_violation="x = x + 1 # increment x by 1", + example_compliance="x += 1", + ), + ], + "response_structure": [ + PreferenceItem( + id="struct_001", topic="response_structure", + condition="providing multi-point answers", + action="use bullet points with '-' or '*'", + conflict_group="list_style", + enforce_description="Check if response uses bullet points", + example_violation="First, do X. Second, do Y. Third, do Z.", + example_compliance="- First, do X\n- Second, do Y\n- Third, do Z", + ), + PreferenceItem( + id="struct_002", topic="response_structure", + condition="providing step-by-step instructions", + action="use numbered lists", + conflict_group="list_style", + enforce_description="Check if response uses numbered lists", + example_violation="First do X, then do Y, finally do Z.", + example_compliance="1. Do X\n2. Do Y\n3. Do Z", + ), + PreferenceItem( + id="struct_003", topic="response_structure", + condition="writing explanations", + action="use flowing prose paragraphs, avoid lists", + conflict_group="list_style", + enforce_description="Check if response uses prose instead of lists", + example_violation="Key points:\n- Point 1\n- Point 2", + example_compliance="The key insight here is that Point 1 connects to Point 2 through...", + ), + PreferenceItem( + id="struct_004", topic="response_structure", + condition="providing long explanations", + action="include a TL;DR summary at the end", + conflict_group=None, + enforce_description="Check if response ends with TL;DR", + example_violation="... and that's how it works.", + example_compliance="... and that's how it works.\n\n**TL;DR:** X does Y by Z.", + ), + PreferenceItem( + id="struct_005", topic="response_structure", + condition="explaining complex topics", + action="start with an outline of what will be covered", + conflict_group=None, + enforce_description="Check if response starts with an outline", + example_violation="Let me explain recursion. First, understand that...", + example_compliance="I'll cover: 1) What is recursion, 2) How it works, 3) Examples.\n\n**1) What is recursion**...", + ), + ], + "explanation_depth": [ + PreferenceItem( + id="depth_001", topic="explanation_depth", + condition="answering questions", + action="be concise, no more than 3 sentences", + conflict_group="length", + enforce_description="Count sentences, should be 3 or fewer", + example_violation="Let me explain in detail. First... Second... Third... Fourth... Fifth...", + example_compliance="The answer is X. This works because of Y. Here's how to apply it: Z.", + ), + PreferenceItem( + id="depth_002", topic="explanation_depth", + condition="explaining concepts", + action="provide comprehensive, detailed explanations", + conflict_group="length", + enforce_description="Check if explanation is thorough with multiple aspects covered", + example_violation="It's X. Done.", + example_compliance="Let me explain X in detail. The concept originates from... It works by... Common applications include... Here's an example...", + ), + PreferenceItem( + id="depth_003", topic="explanation_depth", + condition="explaining anything", + action="always include at least one concrete example", + conflict_group=None, + enforce_description="Check if at least one example is provided", + example_violation="A binary tree is a data structure where each node has at most two children.", + example_compliance="A binary tree is a data structure where each node has at most two children. For example, in [5, 3, 7], 5 is the root, 3 is left child, 7 is right child.", + ), + PreferenceItem( + id="depth_004", topic="explanation_depth", + condition="explaining technical concepts", + action="use analogies from everyday life", + conflict_group=None, + enforce_description="Check if explanation includes an everyday analogy", + example_violation="A stack is a LIFO data structure.", + example_compliance="A stack is like a stack of plates - you can only take the top one (LIFO).", + ), + PreferenceItem( + id="depth_005", topic="explanation_depth", + condition="solving problems", + action="state assumptions explicitly before solving", + conflict_group=None, + enforce_description="Check if assumptions are stated upfront", + example_violation="The answer is 42.", + example_compliance="Assuming n is positive and integer, the answer is 42.", + ), + ], + "interaction_style": [ + PreferenceItem( + id="inter_001", topic="interaction_style", + condition="receiving unclear requests", + action="ask clarifying questions before attempting to answer", + conflict_group="clarification", + enforce_description="Check if agent asks questions when request is ambiguous", + example_violation="Here's a solution assuming you meant X...", + example_compliance="Before I help, could you clarify: do you mean X or Y?", + ), + PreferenceItem( + id="inter_002", topic="interaction_style", + condition="receiving requests", + action="make reasonable assumptions and proceed without asking", + conflict_group="clarification", + enforce_description="Check if agent proceeds with reasonable assumptions", + example_violation="What exactly do you mean by 'large'? What size range?", + example_compliance="Assuming you mean 'large' as over 1000 items, here's the solution...", + ), + PreferenceItem( + id="inter_003", topic="interaction_style", + condition="solving multi-step problems", + action="present one step at a time and ask for confirmation before proceeding", + conflict_group="pacing", + enforce_description="Check if agent pauses after each step", + example_violation="Step 1: X. Step 2: Y. Step 3: Z. Done!", + example_compliance="Step 1: X. Does this make sense? Should I continue to Step 2?", + ), + PreferenceItem( + id="inter_004", topic="interaction_style", + condition="solving problems", + action="provide the complete solution at once without pausing", + conflict_group="pacing", + enforce_description="Check if agent gives complete solution without asking to continue", + example_violation="First, let me do step 1... Should I continue?", + example_compliance="Here's the complete solution: Step 1: X, Step 2: Y, Step 3: Z.", + ), + PreferenceItem( + id="inter_005", topic="interaction_style", + condition="providing answers", + action="include a confidence level (e.g., 'I'm 90% confident')", + conflict_group=None, + enforce_description="Check if response includes confidence level", + example_violation="The answer is 42.", + example_compliance="I'm about 95% confident the answer is 42.", + ), + ], + } + + if topic_name in fallbacks: + return fallbacks[topic_name][:num_prefs] + else: + # Generic fallback + return [ + PreferenceItem( + id=f"{topic_name[:4]}_{i+1:03d}", + topic=topic_name, + condition=f"interacting about {topic_name}", + action=f"preference {i+1} for {topic_name}", + conflict_group=None, + enforce_description=f"Check preference {i+1}", + example_violation=f"Violation example {i+1}", + example_compliance=f"Compliance example {i+1}", + ) + for i in range(num_prefs) + ] + + def generate_bank( + self, + topics: Dict[str, Dict] = None, + prefs_per_topic: int = 5, + ) -> PreferenceBank: + """Generate a complete preference bank.""" + + if topics is None: + topics = DEMO_TOPICS + + bank = PreferenceBank() + + for topic_name, topic_config in topics.items(): + print(f"Generating preferences for topic: {topic_name}...") + + preferences = self.generate_preferences_for_topic( + topic_name=topic_name, + topic_description=topic_config["description"], + hints=topic_config.get("generation_hints", []), + num_prefs=prefs_per_topic, + ) + + topic = PreferenceTopic( + name=topic_name, + description=topic_config["description"], + related_datasets=topic_config["related_datasets"], + preferences=preferences, + ) + + bank.add_topic(topic) + print(f" Generated {len(preferences)} preferences") + + return bank + + +def generate_demo_bank( + output_path: str = None, + use_llm: bool = False, + prefs_per_topic: int = 5, +) -> PreferenceBank: + """ + Generate a demo preference bank. + + Args: + output_path: If provided, save bank to this path + use_llm: If True, use LLM to generate. If False, use hardcoded fallbacks. + prefs_per_topic: Number of preferences per topic + + Returns: + Generated PreferenceBank + """ + if use_llm: + generator = PreferenceBankGenerator() + else: + generator = PreferenceBankGenerator(llm_client=None) # Use fallbacks + + bank = generator.generate_bank( + topics=DEMO_TOPICS, + prefs_per_topic=prefs_per_topic, + ) + + if output_path: + bank.save(output_path) + print(f"Saved bank to {output_path}") + + print(f"\nBank Statistics: {bank.stats()}") + + return bank + + +if __name__ == "__main__": + # Generate demo bank with fallback preferences + import os + script_dir = os.path.dirname(os.path.abspath(__file__)) + output_path = os.path.join(script_dir, "bank_demo.json") + + bank = generate_demo_bank(output_path=output_path, use_llm=False) + + diff --git a/src/personalization/evaluation/preference_bank/schemas.py b/src/personalization/evaluation/preference_bank/schemas.py new file mode 100644 index 0000000..f219487 --- /dev/null +++ b/src/personalization/evaluation/preference_bank/schemas.py @@ -0,0 +1,147 @@ +""" +Preference Bank Schemas + +Defines the data structures for user preferences, organized by topic. +Each preference has a condition (when it applies), action (what the user wants), +and optional conflict group (preferences in the same group are mutually exclusive). +""" + +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any +import json + + +@dataclass +class PreferenceItem: + """A single user preference.""" + id: str # Unique ID, e.g., "math_fmt_001" + topic: str # Topic name, e.g., "math_formatting" + condition: str # When this preference applies + action: str # What the user prefers + conflict_group: Optional[str] # If set, only one pref from this group can be selected + enforce_description: str # Description for user simulator on how to enforce + example_violation: str # Example of agent response that violates this + example_compliance: str # Example that follows this preference + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "topic": self.topic, + "condition": self.condition, + "action": self.action, + "conflict_group": self.conflict_group, + "enforce_description": self.enforce_description, + "example_violation": self.example_violation, + "example_compliance": self.example_compliance, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PreferenceItem": + return cls(**data) + + def format_for_user(self) -> str: + """Format for user simulator prompt.""" + return f"When {self.condition}: {self.action}" + + def format_for_enforcement(self) -> str: + """Format with enforcement details.""" + return f"[{self.id}] When {self.condition}: {self.action}\n Enforce if: {self.enforce_description}" + + +@dataclass +class PreferenceTopic: + """A topic containing multiple related preferences.""" + name: str # Topic name, e.g., "math_formatting" + description: str # Description of this topic + related_datasets: List[str] # Datasets where this topic is relevant + preferences: List[PreferenceItem] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "related_datasets": self.related_datasets, + "preferences": [p.to_dict() for p in self.preferences], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PreferenceTopic": + prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])] + return cls( + name=data["name"], + description=data["description"], + related_datasets=data["related_datasets"], + preferences=prefs, + ) + + +@dataclass +class PreferenceBank: + """ + A bank of preferences organized by topic. + Used to generate user profiles by sampling preferences. + """ + topics: Dict[str, PreferenceTopic] = field(default_factory=dict) + version: str = "1.0" + + def add_topic(self, topic: PreferenceTopic): + self.topics[topic.name] = topic + + def get_all_preferences(self) -> List[PreferenceItem]: + """Get all preferences across all topics.""" + all_prefs = [] + for topic in self.topics.values(): + all_prefs.extend(topic.preferences) + return all_prefs + + def get_preferences_for_dataset(self, dataset: str) -> List[PreferenceItem]: + """Get preferences relevant to a specific dataset.""" + relevant = [] + for topic in self.topics.values(): + if dataset in topic.related_datasets or "all" in topic.related_datasets: + relevant.extend(topic.preferences) + return relevant + + def to_dict(self) -> Dict[str, Any]: + return { + "version": self.version, + "topics": {name: topic.to_dict() for name, topic in self.topics.items()}, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PreferenceBank": + bank = cls(version=data.get("version", "1.0")) + for name, topic_data in data.get("topics", {}).items(): + bank.topics[name] = PreferenceTopic.from_dict(topic_data) + return bank + + def save(self, path: str): + """Save bank to JSON file.""" + with open(path, "w", encoding="utf-8") as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) + + @classmethod + def load(cls, path: str) -> "PreferenceBank": + """Load bank from JSON file.""" + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return cls.from_dict(data) + + def stats(self) -> Dict[str, Any]: + """Get statistics about the bank.""" + total_prefs = 0 + conflict_groups = set() + for topic in self.topics.values(): + total_prefs += len(topic.preferences) + for pref in topic.preferences: + if pref.conflict_group: + conflict_groups.add(pref.conflict_group) + + return { + "num_topics": len(self.topics), + "total_preferences": total_prefs, + "num_conflict_groups": len(conflict_groups), + "prefs_per_topic": {name: len(t.preferences) for name, t in self.topics.items()}, + } + + |
