From dc801c07cf38b0c495686463e6ca6f871a64440e Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 27 Jan 2026 09:57:37 -0600 Subject: Add collaborativeagents module and update gitignore - Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 --- .../evaluation/preference_bank/schemas.py | 147 +++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 src/personalization/evaluation/preference_bank/schemas.py (limited to 'src/personalization/evaluation/preference_bank/schemas.py') diff --git a/src/personalization/evaluation/preference_bank/schemas.py b/src/personalization/evaluation/preference_bank/schemas.py new file mode 100644 index 0000000..f219487 --- /dev/null +++ b/src/personalization/evaluation/preference_bank/schemas.py @@ -0,0 +1,147 @@ +""" +Preference Bank Schemas + +Defines the data structures for user preferences, organized by topic. +Each preference has a condition (when it applies), action (what the user wants), +and optional conflict group (preferences in the same group are mutually exclusive). +""" + +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any +import json + + +@dataclass +class PreferenceItem: + """A single user preference.""" + id: str # Unique ID, e.g., "math_fmt_001" + topic: str # Topic name, e.g., "math_formatting" + condition: str # When this preference applies + action: str # What the user prefers + conflict_group: Optional[str] # If set, only one pref from this group can be selected + enforce_description: str # Description for user simulator on how to enforce + example_violation: str # Example of agent response that violates this + example_compliance: str # Example that follows this preference + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "topic": self.topic, + "condition": self.condition, + "action": self.action, + "conflict_group": self.conflict_group, + "enforce_description": self.enforce_description, + "example_violation": self.example_violation, + "example_compliance": self.example_compliance, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PreferenceItem": + return cls(**data) + + def format_for_user(self) -> str: + """Format for user simulator prompt.""" + return f"When {self.condition}: {self.action}" + + def format_for_enforcement(self) -> str: + """Format with enforcement details.""" + return f"[{self.id}] When {self.condition}: {self.action}\n Enforce if: {self.enforce_description}" + + +@dataclass +class PreferenceTopic: + """A topic containing multiple related preferences.""" + name: str # Topic name, e.g., "math_formatting" + description: str # Description of this topic + related_datasets: List[str] # Datasets where this topic is relevant + preferences: List[PreferenceItem] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "related_datasets": self.related_datasets, + "preferences": [p.to_dict() for p in self.preferences], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PreferenceTopic": + prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])] + return cls( + name=data["name"], + description=data["description"], + related_datasets=data["related_datasets"], + preferences=prefs, + ) + + +@dataclass +class PreferenceBank: + """ + A bank of preferences organized by topic. + Used to generate user profiles by sampling preferences. + """ + topics: Dict[str, PreferenceTopic] = field(default_factory=dict) + version: str = "1.0" + + def add_topic(self, topic: PreferenceTopic): + self.topics[topic.name] = topic + + def get_all_preferences(self) -> List[PreferenceItem]: + """Get all preferences across all topics.""" + all_prefs = [] + for topic in self.topics.values(): + all_prefs.extend(topic.preferences) + return all_prefs + + def get_preferences_for_dataset(self, dataset: str) -> List[PreferenceItem]: + """Get preferences relevant to a specific dataset.""" + relevant = [] + for topic in self.topics.values(): + if dataset in topic.related_datasets or "all" in topic.related_datasets: + relevant.extend(topic.preferences) + return relevant + + def to_dict(self) -> Dict[str, Any]: + return { + "version": self.version, + "topics": {name: topic.to_dict() for name, topic in self.topics.items()}, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PreferenceBank": + bank = cls(version=data.get("version", "1.0")) + for name, topic_data in data.get("topics", {}).items(): + bank.topics[name] = PreferenceTopic.from_dict(topic_data) + return bank + + def save(self, path: str): + """Save bank to JSON file.""" + with open(path, "w", encoding="utf-8") as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) + + @classmethod + def load(cls, path: str) -> "PreferenceBank": + """Load bank from JSON file.""" + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return cls.from_dict(data) + + def stats(self) -> Dict[str, Any]: + """Get statistics about the bank.""" + total_prefs = 0 + conflict_groups = set() + for topic in self.topics.values(): + total_prefs += len(topic.preferences) + for pref in topic.preferences: + if pref.conflict_group: + conflict_groups.add(pref.conflict_group) + + return { + "num_topics": len(self.topics), + "total_preferences": total_prefs, + "num_conflict_groups": len(conflict_groups), + "prefs_per_topic": {name: len(t.preferences) for name, t in self.topics.items()}, + } + + -- cgit v1.2.3