""" Preference Bank Schemas Defines the data structures for user preferences, organized by topic. Each preference has a condition (when it applies), action (what the user wants), and optional conflict group (preferences in the same group are mutually exclusive). """ from dataclasses import dataclass, field from typing import Optional, List, Dict, Any import json @dataclass class PreferenceItem: """A single user preference.""" id: str # Unique ID, e.g., "math_fmt_001" topic: str # Topic name, e.g., "math_formatting" condition: str # When this preference applies action: str # What the user prefers conflict_group: Optional[str] # If set, only one pref from this group can be selected enforce_description: str # Description for user simulator on how to enforce example_violation: str # Example of agent response that violates this example_compliance: str # Example that follows this preference def to_dict(self) -> Dict[str, Any]: return { "id": self.id, "topic": self.topic, "condition": self.condition, "action": self.action, "conflict_group": self.conflict_group, "enforce_description": self.enforce_description, "example_violation": self.example_violation, "example_compliance": self.example_compliance, } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PreferenceItem": return cls(**data) def format_for_user(self) -> str: """Format for user simulator prompt.""" return f"When {self.condition}: {self.action}" def format_for_enforcement(self) -> str: """Format with enforcement details.""" return f"[{self.id}] When {self.condition}: {self.action}\n Enforce if: {self.enforce_description}" @dataclass class PreferenceTopic: """A topic containing multiple related preferences.""" name: str # Topic name, e.g., "math_formatting" description: str # Description of this topic related_datasets: List[str] # Datasets where this topic is relevant preferences: List[PreferenceItem] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { "name": self.name, "description": self.description, "related_datasets": self.related_datasets, "preferences": [p.to_dict() for p in self.preferences], } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PreferenceTopic": prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])] return cls( name=data["name"], description=data["description"], related_datasets=data["related_datasets"], preferences=prefs, ) @dataclass class PreferenceBank: """ A bank of preferences organized by topic. Used to generate user profiles by sampling preferences. """ topics: Dict[str, PreferenceTopic] = field(default_factory=dict) version: str = "1.0" def add_topic(self, topic: PreferenceTopic): self.topics[topic.name] = topic def get_all_preferences(self) -> List[PreferenceItem]: """Get all preferences across all topics.""" all_prefs = [] for topic in self.topics.values(): all_prefs.extend(topic.preferences) return all_prefs def get_preferences_for_dataset(self, dataset: str) -> List[PreferenceItem]: """Get preferences relevant to a specific dataset.""" relevant = [] for topic in self.topics.values(): if dataset in topic.related_datasets or "all" in topic.related_datasets: relevant.extend(topic.preferences) return relevant def to_dict(self) -> Dict[str, Any]: return { "version": self.version, "topics": {name: topic.to_dict() for name, topic in self.topics.items()}, } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PreferenceBank": bank = cls(version=data.get("version", "1.0")) for name, topic_data in data.get("topics", {}).items(): bank.topics[name] = PreferenceTopic.from_dict(topic_data) return bank def save(self, path: str): """Save bank to JSON file.""" with open(path, "w", encoding="utf-8") as f: json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) @classmethod def load(cls, path: str) -> "PreferenceBank": """Load bank from JSON file.""" with open(path, "r", encoding="utf-8") as f: data = json.load(f) return cls.from_dict(data) def stats(self) -> Dict[str, Any]: """Get statistics about the bank.""" total_prefs = 0 conflict_groups = set() for topic in self.topics.values(): total_prefs += len(topic.preferences) for pref in topic.preferences: if pref.conflict_group: conflict_groups.add(pref.conflict_group) return { "num_topics": len(self.topics), "total_preferences": total_prefs, "num_conflict_groups": len(conflict_groups), "prefs_per_topic": {name: len(t.preferences) for name, t in self.topics.items()}, }