diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /src/personalization/evaluation/profiles/generator.py | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/evaluation/profiles/generator.py')
| -rw-r--r-- | src/personalization/evaluation/profiles/generator.py | 351 |
1 files changed, 351 insertions, 0 deletions
# src/personalization/evaluation/profiles/generator.py
"""
User Profile Generator

Generates user profiles by sampling preferences from the preference bank.
Ensures no conflicting preferences within same conflict_group, but allows
cross-topic scenario conflicts (which is desired for testing RAG).
"""

import json
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Dict, Set, Optional, Any

from ..preference_bank.schemas import PreferenceItem, PreferenceBank


@dataclass
class UserProfile:
    """A simulated user with specific preferences."""
    user_id: str
    persona: str  # Background description
    preferences: List[PreferenceItem]  # Selected preferences
    primary_topics: List[str]  # Topics this user cares most about
    preference_by_topic: Dict[str, List[PreferenceItem]] = field(default_factory=dict)

    def __post_init__(self):
        # Build the topic index only when the caller did not supply one.
        if not self.preference_by_topic:
            index = defaultdict(list)
            for pref in self.preferences:
                index[pref.topic].append(pref)
            # Store a plain dict so later lookups of unknown topics do not
            # silently create empty entries.
            self.preference_by_topic = dict(index)

    def get_preferences_for_topic(self, topic: str) -> List[PreferenceItem]:
        """Get preferences for a specific topic (empty list if none)."""
        return self.preference_by_topic.get(topic, [])

    def get_preferences_for_dataset(self, dataset: str, bank: PreferenceBank) -> List[PreferenceItem]:
        """Get preferences relevant to a specific dataset.

        A topic is relevant when its ``related_datasets`` contains either
        ``dataset`` or the wildcard entry ``"all"``.
        """
        relevant_topics = {
            topic_name
            for topic_name, topic in bank.topics.items()
            if dataset in topic.related_datasets or "all" in topic.related_datasets
        }
        return [pref for pref in self.preferences if pref.topic in relevant_topics]

    def format_preferences_grouped(self) -> str:
        """Format preferences grouped by topic for prompts."""
        lines = []
        for topic, prefs in self.preference_by_topic.items():
            topic_title = topic.replace("_", " ").title()
            lines.append(f"\n## {topic_title}")
            for pref in prefs:
                lines.append(f" [{pref.id}] When {pref.condition}: {pref.action}")
                lines.append(f" Enforce if: {pref.enforce_description}")
        return "\n".join(lines)

    def format_preferences_flat(self) -> str:
        """Format preferences as a flat, 1-based numbered list."""
        return "\n".join(
            f"{i}. When {pref.condition}: {pref.action}"
            for i, pref in enumerate(self.preferences, 1)
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the profile; the topic index is rebuilt on load."""
        return {
            "user_id": self.user_id,
            "persona": self.persona,
            "preferences": [p.to_dict() for p in self.preferences],
            "primary_topics": self.primary_topics,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UserProfile":
        """Build a profile from the mapping produced by ``to_dict``.

        Raises:
            KeyError: If ``user_id`` or ``persona`` is missing.
        """
        prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])]
        return cls(
            user_id=data["user_id"],
            persona=data["persona"],
            preferences=prefs,
            primary_topics=data.get("primary_topics", []),
        )

    def stats(self) -> Dict[str, Any]:
        """Get statistics about this profile."""
        conflict_groups = {
            pref.conflict_group for pref in self.preferences if pref.conflict_group
        }
        return {
            "user_id": self.user_id,
            "num_preferences": len(self.preferences),
            "num_topics": len(self.preference_by_topic),
            "prefs_per_topic": {t: len(ps) for t, ps in self.preference_by_topic.items()},
            "num_conflict_groups_used": len(conflict_groups),
        }


# Persona templates for different user types
PERSONA_TEMPLATES = [
    "A {field} professional who values {trait} and prefers {style} communication.",
    "A graduate student in {field} who appreciates {trait} and likes responses that are {style}.",
    "An experienced {field} practitioner who prioritizes {trait} and expects {style} explanations.",
    "A beginner learning {field} who needs {trait} and responds well to {style} guidance.",
    "A {field} enthusiast who cares about {trait} and prefers {style} interactions.",
]

FIELDS = [
    "software engineering", "data science", "mathematics", "physics",
    "medical research", "financial analysis", "machine learning",
    "web development", "systems programming", "algorithm design",
]

TRAITS = [
    "clarity", "precision", "efficiency", "thoroughness", "simplicity",
    "formality", "practicality", "theoretical depth", "hands-on examples",
]

STYLES = [
    "concise", "detailed", "step-by-step", "example-driven", "formal",
    "conversational", "structured", "visual", "analytical",
]

# Candidate persona fields per preference topic; topics missing from this
# table fall back to the first three entries of FIELDS.
_TOPIC_TO_FIELDS = {
    "math_formatting": ["mathematics", "physics", "data science"],
    "coding_style": ["software engineering", "web development", "systems programming"],
    "response_structure": ["technical writing", "documentation", "education"],
    "explanation_depth": ["research", "teaching", "consulting"],
    "interaction_style": ["customer support", "mentoring", "collaboration"],
}


class UserProfileGenerator:
    """Generates user profiles by sampling from preference bank."""

    def __init__(
        self,
        preference_bank: PreferenceBank,
        target_num_prefs: int = 15,  # For demo, use smaller number
        seed: Optional[int] = None,
    ):
        """
        Args:
            preference_bank: Bank whose topics/preferences are sampled.
            target_num_prefs: Approximate number of preferences per profile.
            seed: If given, sampling uses a private RNG seeded with this
                value; otherwise the module-level ``random`` functions are used.
        """
        self.bank = preference_bank
        self.target_num = target_num_prefs
        # A private generator keeps seeded runs reproducible without reseeding
        # the process-wide RNG that unrelated code may share. It yields the
        # same stream as random.seed(seed) followed by module-level calls.
        self._rng = random.Random(seed) if seed is not None else random

    def generate_profile(
        self,
        user_id: str,
        primary_topics: Optional[List[str]] = None,
        persona: Optional[str] = None,
    ) -> UserProfile:
        """
        Generate a user profile by sampling preferences.

        At most one preference is taken from any conflict group, both across
        topics and within a single topic.

        Args:
            user_id: Unique identifier for this user
            primary_topics: Topics this user cares most about (get more prefs from these)
            persona: Optional persona description. If None, will be generated.

        Returns:
            UserProfile with sampled preferences
        """
        rng = self._rng
        selected: List[PreferenceItem] = []
        used_conflict_groups: Set[str] = set()

        # If no primary topics specified, randomly select 1-2
        if primary_topics is None:
            all_topics = list(self.bank.topics.keys())
            if all_topics:
                num_primary = rng.randint(1, min(2, len(all_topics)))
                primary_topics = rng.sample(all_topics, num_primary)
            else:
                # An empty bank has nothing to pick from; randint(1, 0) would raise.
                primary_topics = []

        # Compute quotas for each topic
        topic_quotas = self._compute_quotas(primary_topics)

        # Sample from each topic
        for topic_name, quota in topic_quotas.items():
            if topic_name not in self.bank.topics:
                continue

            topic = self.bank.topics[topic_name]

            # Filter out preferences with already-used conflict groups
            available = [
                p for p in topic.preferences
                if p.conflict_group is None or p.conflict_group not in used_conflict_groups
            ]

            # Visit candidates in random order and accept them one at a time.
            # Sampling `quota` items in one shot could return two preferences
            # from the same conflict group within this topic.
            taken = 0
            for pref in rng.sample(available, len(available)):
                if taken >= quota:
                    break
                group = pref.conflict_group
                if group:
                    if group in used_conflict_groups:
                        continue
                    used_conflict_groups.add(group)
                selected.append(pref)
                taken += 1

        # Generate persona if not provided
        if persona is None:
            persona = self._generate_persona(primary_topics)

        return UserProfile(
            user_id=user_id,
            persona=persona,
            preferences=selected,
            primary_topics=primary_topics,
        )

    def _compute_quotas(self, primary_topics: List[str]) -> Dict[str, int]:
        """Compute how many preferences to sample from each topic.

        Every bank topic gets at least one slot and primary topics get 1-3
        extra. If the total falls short of the target, the shortfall is
        spread over the primary topics that exist in the bank. A total above
        the target is left as is, so the target is approximate.

        Returns:
            Mapping of topic name to quota; empty when the bank has no topics.
        """
        all_topics = list(self.bank.topics.keys())
        if not all_topics:
            # Avoid dividing by zero below.
            return {}

        rng = self._rng
        primary = set(primary_topics)

        # Base quota for all topics
        base_quota = max(1, self.target_num // len(all_topics))

        quotas: Dict[str, int] = {}
        for topic_name in all_topics:
            if topic_name in primary:
                # Primary topics get more preferences
                quotas[topic_name] = base_quota + rng.randint(1, 3)
            else:
                quotas[topic_name] = max(1, base_quota - rng.randint(0, 1))

        # Adjust to match target: only primary topics that are actually in the
        # bank can absorb the deficit, and the remainder of the division is
        # handed out one-by-one so none of it is lost.
        deficit = self.target_num - sum(quotas.values())
        boostable = [t for t in dict.fromkeys(primary_topics) if t in quotas]
        if deficit > 0 and boostable:
            share, leftover = divmod(deficit, len(boostable))
            for position, topic_name in enumerate(boostable):
                quotas[topic_name] += share + (1 if position < leftover else 0)

        return quotas

    def _generate_persona(self, primary_topics: List[str]) -> str:
        """Generate a persona description based on primary topics."""
        rng = self._rng
        template = rng.choice(PERSONA_TEMPLATES)

        # Pick a field related to primary topics
        possible_fields: List[str] = []
        for topic in primary_topics:
            possible_fields.extend(_TOPIC_TO_FIELDS.get(topic, FIELDS[:3]))

        if not possible_fields:
            possible_fields = FIELDS

        # Named chosen_field so the imported dataclasses.field is not shadowed.
        chosen_field = rng.choice(possible_fields)
        trait = rng.choice(TRAITS)
        style = rng.choice(STYLES)

        return template.format(field=chosen_field, trait=trait, style=style)

    def generate_profiles(
        self,
        num_users: int,
        id_prefix: str = "user",
    ) -> List[UserProfile]:
        """Generate multiple user profiles with ids ``<id_prefix>_000``, ``_001``, ..."""
        return [
            self.generate_profile(f"{id_prefix}_{i:03d}")
            for i in range(num_users)
        ]

    def save_profiles(self, profiles: List[UserProfile], path: str):
        """Save profiles to JSON file."""
        data = [p.to_dict() for p in profiles]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    @staticmethod
    def load_profiles(path: str) -> List[UserProfile]:
        """Load profiles from JSON file."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [UserProfile.from_dict(d) for d in data]


def generate_demo_profiles(
    bank: PreferenceBank,
    num_users: int = 2,
    prefs_per_user: int = 10,
    output_path: Optional[str] = None,
    seed: int = 42,
) -> List[UserProfile]:
    """
    Generate demo user profiles.

    Args:
        bank: Preference bank to sample from
        num_users: Number of users to generate
        prefs_per_user: Target preferences per user
        output_path: If provided, save profiles to this path
        seed: Random seed for reproducibility

    Returns:
        List of UserProfile objects
    """
    generator = UserProfileGenerator(
        preference_bank=bank,
        target_num_prefs=prefs_per_user,
        seed=seed,
    )

    profiles = generator.generate_profiles(num_users, id_prefix="demo_user")

    if output_path:
        generator.save_profiles(profiles, output_path)
        print(f"Saved {len(profiles)} profiles to {output_path}")

    # Print stats
    for profile in profiles:
        print(f"\n{profile.user_id}: {profile.stats()}")
        print(f" Persona: {profile.persona}")

    return profiles


if __name__ == "__main__":
    # NOTE: the relative imports in this module only resolve when it is run as
    # part of its package (python -m <package>.profiles.generator), not when
    # the file is executed directly.
    import os
    from ..preference_bank.generator import generate_demo_bank

    # Generate bank first
    script_dir = os.path.dirname(os.path.abspath(__file__))
    bank_path = os.path.join(script_dir, "..", "preference_bank", "bank_demo.json")

    if os.path.exists(bank_path):
        bank = PreferenceBank.load(bank_path)
    else:
        bank = generate_demo_bank()

    # Generate profiles
    profiles_path = os.path.join(script_dir, "profiles_demo.json")
    profiles = generate_demo_profiles(
        bank=bank,
        num_users=2,
        prefs_per_user=10,
        output_path=profiles_path,
    )
