summaryrefslogtreecommitdiff
path: root/src/personalization/evaluation/preference_bank/schemas.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/evaluation/preference_bank/schemas.py')
-rw-r--r--src/personalization/evaluation/preference_bank/schemas.py147
1 files changed, 147 insertions, 0 deletions
diff --git a/src/personalization/evaluation/preference_bank/schemas.py b/src/personalization/evaluation/preference_bank/schemas.py
new file mode 100644
index 0000000..f219487
--- /dev/null
+++ b/src/personalization/evaluation/preference_bank/schemas.py
@@ -0,0 +1,147 @@
+"""
+Preference Bank Schemas
+
+Defines the data structures for user preferences, organized by topic.
+Each preference has a condition (when it applies), action (what the user wants),
+and optional conflict group (preferences in the same group are mutually exclusive).
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional, List, Dict, Any
+import json
+
+
@dataclass
class PreferenceItem:
    """A single user preference.

    A preference pairs a condition (when it applies) with an action (what
    the user wants), plus the text used to enforce and illustrate it.
    """

    id: str  # Unique ID, e.g., "math_fmt_001"
    topic: str  # Topic name, e.g., "math_formatting"
    condition: str  # When this preference applies
    action: str  # What the user prefers
    conflict_group: Optional[str]  # If set, only one pref from this group can be selected
    enforce_description: str  # Description for user simulator on how to enforce
    example_violation: str  # Example of agent response that violates this
    example_compliance: str  # Example that follows this preference

    def to_dict(self) -> Dict[str, Any]:
        """Return the preference as a plain dict keyed by field name."""
        return {
            "id": self.id,
            "topic": self.topic,
            "condition": self.condition,
            "action": self.action,
            "conflict_group": self.conflict_group,
            "enforce_description": self.enforce_description,
            "example_violation": self.example_violation,
            "example_compliance": self.example_compliance,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferenceItem":
        """Build a preference from a dict of field values.

        ``conflict_group`` is optional and defaults to ``None`` when the key
        is absent. ``data`` itself is never modified.

        Raises:
            TypeError: If a required field is missing or an unknown key is
                present (raised by the dataclass constructor).
        """
        # Work on a copy so that filling in the default does not leak back
        # into the caller's dict.
        kwargs = dict(data)
        kwargs.setdefault("conflict_group", None)
        return cls(**kwargs)

    def format_for_user(self) -> str:
        """Format for user simulator prompt."""
        return f"When {self.condition}: {self.action}"

    def format_for_enforcement(self) -> str:
        """Format with the preference ID and enforcement details."""
        return f"[{self.id}] When {self.condition}: {self.action}\n Enforce if: {self.enforce_description}"
+
+
@dataclass
class PreferenceTopic:
    """A topic containing multiple related preferences."""

    name: str  # Topic name, e.g., "math_formatting"
    description: str  # Description of this topic
    related_datasets: List[str]  # Datasets where this topic is relevant
    preferences: List["PreferenceItem"] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the topic, including its preferences, as plain dicts/lists."""
        return {
            "name": self.name,
            "description": self.description,
            # Copy so edits to the serialized form cannot alter this topic.
            "related_datasets": list(self.related_datasets),
            "preferences": [pref.to_dict() for pref in self.preferences],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferenceTopic":
        """Build a topic from a dict.

        ``name``, ``description`` and ``related_datasets`` are required;
        ``preferences`` is optional and defaults to an empty list.

        Raises:
            KeyError: If a required key is missing from ``data``.
        """
        items = [PreferenceItem.from_dict(raw) for raw in data.get("preferences", [])]
        return cls(
            name=data["name"],
            description=data["description"],
            # Copy so later edits to ``data`` cannot alter the new topic.
            related_datasets=list(data["related_datasets"]),
            preferences=items,
        )
+
+
@dataclass
class PreferenceBank:
    """
    A bank of preferences organized by topic.
    Used to generate user profiles by sampling preferences.
    """

    topics: Dict[str, "PreferenceTopic"] = field(default_factory=dict)  # Keyed by topic name
    version: str = "1.0"

    def add_topic(self, topic: "PreferenceTopic") -> None:
        """Register ``topic`` under its name, replacing any topic already stored there."""
        self.topics[topic.name] = topic

    def get_all_preferences(self) -> List["PreferenceItem"]:
        """Get all preferences across all topics, in topic insertion order."""
        return [pref for topic in self.topics.values() for pref in topic.preferences]

    def get_preferences_for_dataset(self, dataset: str) -> List["PreferenceItem"]:
        """Get preferences relevant to a specific dataset.

        A topic is relevant when ``dataset`` is listed in its
        ``related_datasets`` or when that list contains the wildcard "all".
        """
        relevant = []
        for topic in self.topics.values():
            if dataset in topic.related_datasets or "all" in topic.related_datasets:
                relevant.extend(topic.preferences)
        return relevant

    def to_dict(self) -> Dict[str, Any]:
        """Return the bank (version plus every topic) as plain dicts/lists."""
        return {
            "version": self.version,
            "topics": {name: topic.to_dict() for name, topic in self.topics.items()},
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferenceBank":
        """Build a bank from a dict.

        A missing "version" defaults to "1.0" and a missing "topics" yields
        an empty bank. Topics are stored under the keys used in ``data``.
        """
        bank = cls(version=data.get("version", "1.0"))
        for name, topic_data in data.get("topics", {}).items():
            bank.topics[name] = PreferenceTopic.from_dict(topic_data)
        return bank

    def save(self, path: str) -> None:
        """Save bank to a UTF-8 JSON file, creating parent directories if needed."""
        import os

        # Only path-like targets have a parent directory to create; anything
        # else is handed to open() unchanged.
        if isinstance(path, (str, os.PathLike)):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load(cls, path: str) -> "PreferenceBank":
        """Load bank from JSON file."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)

    def stats(self) -> Dict[str, Any]:
        """Get statistics about the bank.

        Returns a dict with the number of topics, the total number of
        preferences, the number of distinct non-empty conflict groups, and
        the preference count per topic.
        """
        all_prefs = self.get_all_preferences()
        # Preferences with no (or an empty) conflict group are not counted.
        conflict_groups = {pref.conflict_group for pref in all_prefs if pref.conflict_group}
        return {
            "num_topics": len(self.topics),
            "total_preferences": len(all_prefs),
            "num_conflict_groups": len(conflict_groups),
            "prefs_per_topic": {name: len(t.preferences) for name, t in self.topics.items()},
        }
+
+