summaryrefslogtreecommitdiff
path: root/src/personalization/evaluation/profiles
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /src/personalization/evaluation/profiles
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/evaluation/profiles')
-rw-r--r--src/personalization/evaluation/profiles/__init__.py5
-rw-r--r--src/personalization/evaluation/profiles/generator.py351
2 files changed, 356 insertions, 0 deletions
diff --git a/src/personalization/evaluation/profiles/__init__.py b/src/personalization/evaluation/profiles/__init__.py
new file mode 100644
index 0000000..8532af9
--- /dev/null
+++ b/src/personalization/evaluation/profiles/__init__.py
@@ -0,0 +1,5 @@
+from .generator import UserProfile, UserProfileGenerator
+
+__all__ = ["UserProfile", "UserProfileGenerator"]
+
+
diff --git a/src/personalization/evaluation/profiles/generator.py b/src/personalization/evaluation/profiles/generator.py
new file mode 100644
index 0000000..da847a0
--- /dev/null
+++ b/src/personalization/evaluation/profiles/generator.py
@@ -0,0 +1,351 @@
+"""
+User Profile Generator
+
+Generates user profiles by sampling preferences from the preference bank.
+Ensures no conflicting preferences within same conflict_group, but allows
+cross-topic scenario conflicts (which is desired for testing RAG).
+"""
+
+import json
+import random
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import List, Dict, Set, Optional, Any
+
+from ..preference_bank.schemas import PreferenceItem, PreferenceBank
+
+
+@dataclass
+class UserProfile:
+ """A simulated user with specific preferences."""
+ user_id: str
+ persona: str # Background description
+ preferences: List[PreferenceItem] # Selected preferences
+ primary_topics: List[str] # Topics this user cares most about
+ preference_by_topic: Dict[str, List[PreferenceItem]] = field(default_factory=dict)
+
+ def __post_init__(self):
+ # Build topic index if not provided
+ if not self.preference_by_topic:
+ self.preference_by_topic = defaultdict(list)
+ for pref in self.preferences:
+ self.preference_by_topic[pref.topic].append(pref)
+ self.preference_by_topic = dict(self.preference_by_topic)
+
+ def get_preferences_for_topic(self, topic: str) -> List[PreferenceItem]:
+ """Get preferences for a specific topic."""
+ return self.preference_by_topic.get(topic, [])
+
+ def get_preferences_for_dataset(self, dataset: str, bank: PreferenceBank) -> List[PreferenceItem]:
+ """Get preferences relevant to a specific dataset."""
+ relevant_topics = set()
+ for topic_name, topic in bank.topics.items():
+ if dataset in topic.related_datasets or "all" in topic.related_datasets:
+ relevant_topics.add(topic_name)
+
+ relevant_prefs = []
+ for pref in self.preferences:
+ if pref.topic in relevant_topics:
+ relevant_prefs.append(pref)
+ return relevant_prefs
+
+ def format_preferences_grouped(self) -> str:
+ """Format preferences grouped by topic for prompts."""
+ lines = []
+ for topic, prefs in self.preference_by_topic.items():
+ topic_title = topic.replace("_", " ").title()
+ lines.append(f"\n## {topic_title}")
+ for pref in prefs:
+ lines.append(f" [{pref.id}] When {pref.condition}: {pref.action}")
+ lines.append(f" Enforce if: {pref.enforce_description}")
+ return "\n".join(lines)
+
+ def format_preferences_flat(self) -> str:
+ """Format preferences as a flat list."""
+ lines = []
+ for i, pref in enumerate(self.preferences, 1):
+ lines.append(f"{i}. When {pref.condition}: {pref.action}")
+ return "\n".join(lines)
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "user_id": self.user_id,
+ "persona": self.persona,
+ "preferences": [p.to_dict() for p in self.preferences],
+ "primary_topics": self.primary_topics,
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "UserProfile":
+ prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])]
+ return cls(
+ user_id=data["user_id"],
+ persona=data["persona"],
+ preferences=prefs,
+ primary_topics=data.get("primary_topics", []),
+ )
+
+ def stats(self) -> Dict[str, Any]:
+ """Get statistics about this profile."""
+ conflict_groups = set()
+ for pref in self.preferences:
+ if pref.conflict_group:
+ conflict_groups.add(pref.conflict_group)
+
+ return {
+ "user_id": self.user_id,
+ "num_preferences": len(self.preferences),
+ "num_topics": len(self.preference_by_topic),
+ "prefs_per_topic": {t: len(ps) for t, ps in self.preference_by_topic.items()},
+ "num_conflict_groups_used": len(conflict_groups),
+ }
+
+
+# Persona templates for different user types
+PERSONA_TEMPLATES = [
+ "A {field} professional who values {trait} and prefers {style} communication.",
+ "A graduate student in {field} who appreciates {trait} and likes responses that are {style}.",
+ "An experienced {field} practitioner who prioritizes {trait} and expects {style} explanations.",
+ "A beginner learning {field} who needs {trait} and responds well to {style} guidance.",
+ "A {field} enthusiast who cares about {trait} and prefers {style} interactions.",
+]
+
+FIELDS = [
+ "software engineering", "data science", "mathematics", "physics",
+ "medical research", "financial analysis", "machine learning",
+ "web development", "systems programming", "algorithm design",
+]
+
+TRAITS = [
+ "clarity", "precision", "efficiency", "thoroughness", "simplicity",
+ "formality", "practicality", "theoretical depth", "hands-on examples",
+]
+
+STYLES = [
+ "concise", "detailed", "step-by-step", "example-driven", "formal",
+ "conversational", "structured", "visual", "analytical",
+]
+
+
+class UserProfileGenerator:
+ """Generates user profiles by sampling from preference bank."""
+
+ def __init__(
+ self,
+ preference_bank: PreferenceBank,
+ target_num_prefs: int = 15, # For demo, use smaller number
+ seed: Optional[int] = None,
+ ):
+ self.bank = preference_bank
+ self.target_num = target_num_prefs
+
+ if seed is not None:
+ random.seed(seed)
+
+ def generate_profile(
+ self,
+ user_id: str,
+ primary_topics: List[str] = None,
+ persona: str = None,
+ ) -> UserProfile:
+ """
+ Generate a user profile by sampling preferences.
+
+ Args:
+ user_id: Unique identifier for this user
+ primary_topics: Topics this user cares most about (get more prefs from these)
+ persona: Optional persona description. If None, will be generated.
+
+ Returns:
+ UserProfile with sampled preferences
+ """
+ selected: List[PreferenceItem] = []
+ used_conflict_groups: Set[str] = set()
+
+ # If no primary topics specified, randomly select 1-2
+ if primary_topics is None:
+ all_topics = list(self.bank.topics.keys())
+ num_primary = random.randint(1, min(2, len(all_topics)))
+ primary_topics = random.sample(all_topics, num_primary)
+
+ # Compute quotas for each topic
+ topic_quotas = self._compute_quotas(primary_topics)
+
+ # Sample from each topic
+ for topic_name, quota in topic_quotas.items():
+ if topic_name not in self.bank.topics:
+ continue
+
+ topic = self.bank.topics[topic_name]
+
+ # Filter out preferences with already-used conflict groups
+ available = [
+ p for p in topic.preferences
+ if p.conflict_group is None or p.conflict_group not in used_conflict_groups
+ ]
+
+ # Sample up to quota
+ to_select = min(quota, len(available))
+ if to_select > 0:
+ sampled = random.sample(available, to_select)
+
+ for pref in sampled:
+ selected.append(pref)
+ if pref.conflict_group:
+ used_conflict_groups.add(pref.conflict_group)
+
+ # Generate persona if not provided
+ if persona is None:
+ persona = self._generate_persona(primary_topics)
+
+ return UserProfile(
+ user_id=user_id,
+ persona=persona,
+ preferences=selected,
+ primary_topics=primary_topics,
+ )
+
+ def _compute_quotas(self, primary_topics: List[str]) -> Dict[str, int]:
+ """Compute how many preferences to sample from each topic."""
+ quotas = {}
+ all_topics = list(self.bank.topics.keys())
+
+ # Base quota for all topics
+ base_quota = max(1, self.target_num // len(all_topics))
+
+ for topic_name in all_topics:
+ if topic_name in primary_topics:
+ # Primary topics get more preferences
+ quotas[topic_name] = base_quota + random.randint(1, 3)
+ else:
+ quotas[topic_name] = max(1, base_quota - random.randint(0, 1))
+
+ # Adjust to match target
+ total = sum(quotas.values())
+ if total < self.target_num:
+ # Add more to primary topics
+ for topic in primary_topics:
+ if topic in quotas:
+ quotas[topic] += (self.target_num - total) // len(primary_topics)
+
+ return quotas
+
+ def _generate_persona(self, primary_topics: List[str]) -> str:
+ """Generate a persona description based on primary topics."""
+ template = random.choice(PERSONA_TEMPLATES)
+
+ # Map topics to fields
+ topic_to_field = {
+ "math_formatting": ["mathematics", "physics", "data science"],
+ "coding_style": ["software engineering", "web development", "systems programming"],
+ "response_structure": ["technical writing", "documentation", "education"],
+ "explanation_depth": ["research", "teaching", "consulting"],
+ "interaction_style": ["customer support", "mentoring", "collaboration"],
+ }
+
+ # Pick a field related to primary topics
+ possible_fields = []
+ for topic in primary_topics:
+ possible_fields.extend(topic_to_field.get(topic, FIELDS[:3]))
+
+ if not possible_fields:
+ possible_fields = FIELDS
+
+ field = random.choice(possible_fields)
+ trait = random.choice(TRAITS)
+ style = random.choice(STYLES)
+
+ return template.format(field=field, trait=trait, style=style)
+
+ def generate_profiles(
+ self,
+ num_users: int,
+ id_prefix: str = "user",
+ ) -> List[UserProfile]:
+ """Generate multiple user profiles."""
+ profiles = []
+
+ for i in range(num_users):
+ user_id = f"{id_prefix}_{i:03d}"
+ profile = self.generate_profile(user_id)
+ profiles.append(profile)
+
+ return profiles
+
+ def save_profiles(self, profiles: List[UserProfile], path: str):
+ """Save profiles to JSON file."""
+ data = [p.to_dict() for p in profiles]
+ with open(path, "w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2, ensure_ascii=False)
+
+ @staticmethod
+ def load_profiles(path: str) -> List[UserProfile]:
+ """Load profiles from JSON file."""
+ with open(path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ return [UserProfile.from_dict(d) for d in data]
+
+
+def generate_demo_profiles(
+ bank: PreferenceBank,
+ num_users: int = 2,
+ prefs_per_user: int = 10,
+ output_path: str = None,
+ seed: int = 42,
+) -> List[UserProfile]:
+ """
+ Generate demo user profiles.
+
+ Args:
+ bank: Preference bank to sample from
+ num_users: Number of users to generate
+ prefs_per_user: Target preferences per user
+ output_path: If provided, save profiles to this path
+ seed: Random seed for reproducibility
+
+ Returns:
+ List of UserProfile objects
+ """
+ generator = UserProfileGenerator(
+ preference_bank=bank,
+ target_num_prefs=prefs_per_user,
+ seed=seed,
+ )
+
+ profiles = generator.generate_profiles(num_users, id_prefix="demo_user")
+
+ if output_path:
+ generator.save_profiles(profiles, output_path)
+ print(f"Saved {len(profiles)} profiles to {output_path}")
+
+ # Print stats
+ for profile in profiles:
+ print(f"\n{profile.user_id}: {profile.stats()}")
+ print(f" Persona: {profile.persona}")
+
+ return profiles
+
+
+if __name__ == "__main__":
+ import os
+ from ..preference_bank.generator import generate_demo_bank
+
+ # Generate bank first
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ bank_path = os.path.join(script_dir, "..", "preference_bank", "bank_demo.json")
+
+ if os.path.exists(bank_path):
+ bank = PreferenceBank.load(bank_path)
+ else:
+ bank = generate_demo_bank()
+
+ # Generate profiles
+ profiles_path = os.path.join(script_dir, "profiles_demo.json")
+ profiles = generate_demo_profiles(
+ bank=bank,
+ num_users=2,
+ prefs_per_user=10,
+ output_path=profiles_path,
+ )
+
+