"""
User Profile Generator

Generates user profiles by sampling preferences from the preference bank.
Guarantees that a single user never holds two preferences from the same
conflict_group, while still allowing cross-topic scenario conflicts (which
are desirable when testing RAG).
"""

import json
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Dict, Set, Optional, Any

from ..preference_bank.schemas import PreferenceItem, PreferenceBank


@dataclass
class UserProfile:
    """A simulated user with specific preferences."""

    user_id: str
    persona: str  # Background description
    preferences: List[PreferenceItem]  # Selected preferences
    primary_topics: List[str]  # Topics this user cares most about
    # Index of `preferences` keyed by `PreferenceItem.topic`; built in
    # __post_init__ when left empty.
    preference_by_topic: Dict[str, List[PreferenceItem]] = field(default_factory=dict)

    def __post_init__(self):
        # Build topic index if not provided
        if not self.preference_by_topic:
            by_topic: Dict[str, List[PreferenceItem]] = defaultdict(list)
            for pref in self.preferences:
                by_topic[pref.topic].append(pref)
            # Convert to a plain dict so lookups of unknown topics don't
            # silently insert empty entries.
            self.preference_by_topic = dict(by_topic)

    def get_preferences_for_topic(self, topic: str) -> List[PreferenceItem]:
        """Return this user's preferences for `topic` (empty list if none)."""
        return self.preference_by_topic.get(topic, [])

    def get_preferences_for_dataset(self, dataset: str, bank: PreferenceBank) -> List[PreferenceItem]:
        """Return preferences whose topic is related to `dataset`.

        A topic is related when `dataset` or the wildcard "all" appears in
        the topic's `related_datasets` in `bank`.
        """
        relevant_topics = {
            topic_name
            for topic_name, topic in bank.topics.items()
            if dataset in topic.related_datasets or "all" in topic.related_datasets
        }
        return [pref for pref in self.preferences if pref.topic in relevant_topics]

    def format_preferences_grouped(self) -> str:
        """Format preferences grouped by topic for prompts.

        Each topic becomes a "## <Topic Title>" heading followed, for each
        preference, by a "[id] When <condition>: <action>" line and an
        "Enforce if: <enforce_description>" line.
        """
        lines = []
        for topic, prefs in self.preference_by_topic.items():
            topic_title = topic.replace("_", " ").title()
            lines.append(f"\n## {topic_title}")
            for pref in prefs:
                lines.append(f" [{pref.id}] When {pref.condition}: {pref.action}")
                lines.append(f" Enforce if: {pref.enforce_description}")
        return "\n".join(lines)

    def format_preferences_flat(self) -> str:
        """Format preferences as a 1-based numbered flat list."""
        return "\n".join(
            f"{i}. When {pref.condition}: {pref.action}"
            for i, pref in enumerate(self.preferences, 1)
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (topic index is not stored)."""
        return {
            "user_id": self.user_id,
            "persona": self.persona,
            "preferences": [p.to_dict() for p in self.preferences],
            "primary_topics": self.primary_topics,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UserProfile":
        """Rebuild a profile from `to_dict` output; the topic index is recomputed."""
        prefs = [PreferenceItem.from_dict(p) for p in data.get("preferences", [])]
        return cls(
            user_id=data["user_id"],
            persona=data["persona"],
            preferences=prefs,
            primary_topics=data.get("primary_topics", []),
        )

    def stats(self) -> Dict[str, Any]:
        """Get statistics about this profile."""
        conflict_groups = {p.conflict_group for p in self.preferences if p.conflict_group}
        return {
            "user_id": self.user_id,
            "num_preferences": len(self.preferences),
            "num_topics": len(self.preference_by_topic),
            "prefs_per_topic": {t: len(ps) for t, ps in self.preference_by_topic.items()},
            "num_conflict_groups_used": len(conflict_groups),
        }


# Persona templates for different user types
PERSONA_TEMPLATES = [
    "A {field} professional who values {trait} and prefers {style} communication.",
    "A graduate student in {field} who appreciates {trait} and likes responses that are {style}.",
    "An experienced {field} practitioner who prioritizes {trait} and expects {style} explanations.",
    "A beginner learning {field} who needs {trait} and responds well to {style} guidance.",
    "A {field} enthusiast who cares about {trait} and prefers {style} interactions.",
]

FIELDS = [
    "software engineering", "data science", "mathematics", "physics",
    "medical research", "financial analysis", "machine learning",
    "web development", "systems programming", "algorithm design",
]

TRAITS = [
    "clarity", "precision", "efficiency", "thoroughness", "simplicity",
    "formality", "practicality", "theoretical depth", "hands-on examples",
]

STYLES = [
    "concise", "detailed", "step-by-step", "example-driven", "formal",
    "conversational", "structured", "visual", "analytical",
]

# Candidate persona fields for known bank topics. Topics not listed here
# fall back to FIELDS[:3] in UserProfileGenerator._generate_persona.
_TOPIC_TO_FIELDS: Dict[str, List[str]] = {
    "math_formatting": ["mathematics", "physics", "data science"],
    "coding_style": ["software engineering", "web development", "systems programming"],
    "response_structure": ["technical writing", "documentation", "education"],
    "explanation_depth": ["research", "teaching", "consulting"],
    "interaction_style": ["customer support", "mentoring", "collaboration"],
}


class UserProfileGenerator:
    """Generates user profiles by sampling from preference bank."""

    def __init__(
        self,
        preference_bank: PreferenceBank,
        target_num_prefs: int = 15,  # For demo, use smaller number
        seed: Optional[int] = None,
    ):
        self.bank = preference_bank
        self.target_num = target_num_prefs
        # With a seed, use a private RNG: it yields the same sequence that
        # `random.seed(seed)` would, without reseeding the process-wide
        # `random` module under other code. Without a seed, keep using the
        # shared module-level generator.
        self._rng = random.Random(seed) if seed is not None else random

    def generate_profile(
        self,
        user_id: str,
        primary_topics: Optional[List[str]] = None,
        persona: Optional[str] = None,
    ) -> UserProfile:
        """
        Generate a user profile by sampling preferences.

        No two selected preferences share a (truthy) conflict_group, either
        within a topic or across topics. Per-topic quotas are targets. A topic
        may contribute fewer preferences when conflicts exhaust its candidates.

        Args:
            user_id: Unique identifier for this user
            primary_topics: Topics this user cares most about (get more prefs
                from these). If None, 1-2 bank topics are chosen at random.
            persona: Optional persona description. If None, will be generated.

        Returns:
            UserProfile with sampled preferences
        """
        selected: List[PreferenceItem] = []
        used_conflict_groups: Set[str] = set()

        # If no primary topics specified, randomly select 1-2
        if primary_topics is None:
            all_topics = list(self.bank.topics.keys())
            if all_topics:
                num_primary = self._rng.randint(1, min(2, len(all_topics)))
                primary_topics = self._rng.sample(all_topics, num_primary)
            else:
                primary_topics = []  # empty bank: randint(1, 0) would raise

        # Compute quotas for each topic
        topic_quotas = self._compute_quotas(primary_topics)

        # Sample from each topic
        for topic_name, quota in topic_quotas.items():
            if topic_name not in self.bank.topics:
                continue
            topic = self.bank.topics[topic_name]

            # Drop preferences whose conflict group was claimed by an earlier topic
            candidates = [
                p for p in topic.preferences
                if p.conflict_group is None or p.conflict_group not in used_conflict_groups
            ]

            # Walk candidates in random order and take them greedily,
            # re-checking conflict groups after each pick. A plain
            # random.sample(candidates, quota) could return two preferences
            # from the same conflict group within this topic.
            self._rng.shuffle(candidates)
            taken = 0
            for pref in candidates:
                if taken >= quota:
                    break
                if pref.conflict_group and pref.conflict_group in used_conflict_groups:
                    continue
                selected.append(pref)
                taken += 1
                if pref.conflict_group:
                    used_conflict_groups.add(pref.conflict_group)

        # Generate persona if not provided
        if persona is None:
            persona = self._generate_persona(primary_topics)

        return UserProfile(
            user_id=user_id,
            persona=persona,
            preferences=selected,
            primary_topics=primary_topics,
        )

    def _compute_quotas(self, primary_topics: List[str]) -> Dict[str, int]:
        """Compute how many preferences to sample from each bank topic.

        Every topic gets at least 1. Primary topics get 1-3 more than the
        base quota. If the total falls short of `target_num`, the shortfall
        is split across the primary topics present in the bank, with the
        remainder spread over the first few. Returns {} for an empty bank.
        """
        all_topics = list(self.bank.topics.keys())
        if not all_topics:
            return {}  # avoids division by zero below

        # Base quota for all topics
        base_quota = max(1, self.target_num // len(all_topics))
        primary = set(primary_topics)

        quotas: Dict[str, int] = {}
        for topic_name in all_topics:
            if topic_name in primary:
                # Primary topics get more preferences
                quotas[topic_name] = base_quota + self._rng.randint(1, 3)
            else:
                quotas[topic_name] = max(1, base_quota - self._rng.randint(0, 1))

        # Top up toward the target using only primary topics that exist in
        # the bank (de-duplicated, original order preserved).
        shortfall = self.target_num - sum(quotas.values())
        receivers = [t for t in dict.fromkeys(primary_topics) if t in quotas]
        if shortfall > 0 and receivers:
            share, remainder = divmod(shortfall, len(receivers))
            for i, topic_name in enumerate(receivers):
                quotas[topic_name] += share + (1 if i < remainder else 0)

        return quotas

    def _generate_persona(self, primary_topics: List[str]) -> str:
        """Generate a persona description based on primary topics."""
        template = self._rng.choice(PERSONA_TEMPLATES)

        # Pick a field related to primary topics; unknown topics fall back to
        # the first three generic FIELDS.
        possible_fields: List[str] = []
        for topic in primary_topics:
            possible_fields.extend(_TOPIC_TO_FIELDS.get(topic, FIELDS[:3]))
        if not possible_fields:
            possible_fields = FIELDS

        chosen_field = self._rng.choice(possible_fields)
        trait = self._rng.choice(TRAITS)
        style = self._rng.choice(STYLES)
        return template.format(field=chosen_field, trait=trait, style=style)

    def generate_profiles(
        self,
        num_users: int,
        id_prefix: str = "user",
    ) -> List[UserProfile]:
        """Generate `num_users` profiles with ids "<id_prefix>_000", "<id_prefix>_001", ..."""
        return [self.generate_profile(f"{id_prefix}_{i:03d}") for i in range(num_users)]

    def save_profiles(self, profiles: List[UserProfile], path: str):
        """Save profiles to a UTF-8 JSON file."""
        data = [p.to_dict() for p in profiles]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    @staticmethod
    def load_profiles(path: str) -> List[UserProfile]:
        """Load profiles from a JSON file written by `save_profiles`."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [UserProfile.from_dict(d) for d in data]


def generate_demo_profiles(
    bank: PreferenceBank,
    num_users: int = 2,
    prefs_per_user: int = 10,
    output_path: Optional[str] = None,
    seed: int = 42,
) -> List[UserProfile]:
    """
    Generate demo user profiles.

    Args:
        bank: Preference bank to sample from
        num_users: Number of users to generate
        prefs_per_user: Target preferences per user
        output_path: If provided, save profiles to this path
        seed: Random seed for reproducibility

    Returns:
        List of UserProfile objects
    """
    generator = UserProfileGenerator(
        preference_bank=bank,
        target_num_prefs=prefs_per_user,
        seed=seed,
    )

    profiles = generator.generate_profiles(num_users, id_prefix="demo_user")

    if output_path:
        generator.save_profiles(profiles, output_path)
        print(f"Saved {len(profiles)} profiles to {output_path}")

    # Print stats
    for profile in profiles:
        print(f"\n{profile.user_id}: {profile.stats()}")
        print(f"  Persona: {profile.persona}")

    return profiles


if __name__ == "__main__":
    # This module uses relative imports, so it must be run as part of its
    # package (e.g. `python -m <package>.<module>`).
    import os

    from ..preference_bank.generator import generate_demo_bank

    # Load the demo bank if present; otherwise generate one
    script_dir = os.path.dirname(os.path.abspath(__file__))
    bank_path = os.path.join(script_dir, "..", "preference_bank", "bank_demo.json")

    if os.path.exists(bank_path):
        bank = PreferenceBank.load(bank_path)
    else:
        bank = generate_demo_bank()

    # Generate profiles
    profiles_path = os.path.join(script_dir, "profiles_demo.json")
    generate_demo_profiles(
        bank=bank,
        num_users=2,
        prefs_per_user=10,
        output_path=profiles_path,
    )