diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/scripts/extend_profiles.py | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts/extend_profiles.py')
| -rw-r--r-- | collaborativeagents/scripts/extend_profiles.py | 195 |
1 files changed, 195 insertions, 0 deletions
#!/usr/bin/env python3
"""
Generate additional profiles by remixing preferences from existing profiles.
This creates diverse profile combinations without requiring LLM calls.
"""

import argparse
import hashlib
import json
import random
from pathlib import Path
from typing import Dict, List, Optional


def load_profiles(path: Path) -> List[Dict]:
    """Load profiles from a JSONL file, one JSON object per non-blank line.

    Args:
        path: Location of the JSONL file.

    Returns:
        The decoded profiles, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the file and the offending file line.
    """
    profiles = []
    with open(path, encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            try:
                profiles.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type, but say where in the
                # file the bad record lives.
                raise json.JSONDecodeError(
                    f"{err.msg} [{path}, file line {line_no}]", err.doc, err.pos
                ) from err
    return profiles


def extract_all_preferences(profiles: List[Dict]) -> Dict[str, List[Dict]]:
    """Extract all unique preferences grouped by category (prefix).

    Preferences are de-duplicated by their ``pref_id``; the first occurrence
    wins. A preference without a ``pref_id`` is treated as having the id
    ``"unknown"``, so at most one such preference is kept.

    Args:
        profiles: Profile dicts; each may carry a ``"preferences"`` list.

    Returns:
        Mapping from category to the preferences in that category. The
        category is the text before the first ``"_"`` of the id
        (e.g. ``"rf_001"`` -> ``"rf"``), or ``"other"`` when the id has no
        underscore.
    """
    categories: Dict[str, List[Dict]] = {}
    seen_ids = set()

    for profile in profiles:
        for pref in profile.get("preferences", []):
            pref_id = pref.get("pref_id", "unknown")
            if pref_id in seen_ids:
                continue
            seen_ids.add(pref_id)

            prefix, sep, _ = pref_id.partition("_")
            category = prefix if sep else "other"
            categories.setdefault(category, []).append(pref)

    return categories


def extract_personas(profiles: List[Dict]) -> List[str]:
    """Extract unique, non-empty personas from profiles in first-seen order."""
    personas = []
    seen = set()
    for profile in profiles:
        persona = profile.get("persona", "")
        if persona and persona not in seen:
            personas.append(persona)
            seen.add(persona)
    return personas


def generate_new_profile(
    user_id: str,
    preference_pool: Dict[str, List[Dict]],
    personas: List[str],
    target_prefs: int = 43,
    rng: Optional[random.Random] = None,
) -> Dict:
    """Generate a new profile by sampling from the preference pool.

    Every category contributes at least one preference (when it has any),
    then the selection is topped up or trimmed at random to reach
    ``target_prefs``. The selected preference dicts are the same objects as
    in ``preference_pool`` (they are not copied).

    Args:
        user_id: Identifier stored in the new profile.
        preference_pool: Preferences grouped by category.
        personas: Candidate personas; one is picked at random. When empty,
            the profile gets the empty persona ``""``.
        target_prefs: Desired number of preferences. Negative values are
            treated as 0. The result has fewer preferences when the pool is
            smaller than the target.
        rng: Random source; a fresh unseeded ``random.Random`` when omitted.

    Returns:
        A profile dict with ``user_id``, ``persona``, ``preferences``,
        ``conflict_groups`` (conflict group -> list of pref ids) and ``meta``.
    """
    if rng is None:
        rng = random.Random()
    target_prefs = max(0, target_prefs)

    selected_prefs: List[Dict] = []

    # Sample from each category to maintain diversity. An empty pool simply
    # contributes nothing (and must not divide by zero).
    if preference_pool:
        prefs_per_cat = max(1, target_prefs // len(preference_pool))
        for prefs in preference_pool.values():
            # Jitter the per-category count a little, but keep it in
            # [1, len(prefs)] (0 only when the category itself is empty).
            n_sample = min(len(prefs), prefs_per_cat + rng.randint(-1, 2))
            n_sample = max(1, n_sample)
            selected_prefs.extend(rng.sample(prefs, min(n_sample, len(prefs))))

    # Top up to the target. The candidate list is built once and then shrunk
    # by the picked preference, which yields the same candidates (in the same
    # order) as re-filtering the whole pool on every iteration.
    if len(selected_prefs) < target_prefs:
        remaining = [
            pref
            for prefs in preference_pool.values()
            for pref in prefs
            if pref not in selected_prefs
        ]
        while remaining and len(selected_prefs) < target_prefs:
            picked = rng.choice(remaining)
            selected_prefs.append(picked)
            remaining = [
                p for p in remaining if p is not picked and p != picked
            ]

    # Trim random entries when the per-category sampling overshot.
    while len(selected_prefs) > target_prefs:
        selected_prefs.pop(rng.randint(0, len(selected_prefs) - 1))

    # Build conflict groups: conflict_group -> ids of the selected members.
    conflict_groups: Dict[str, List[str]] = {}
    for pref in selected_prefs:
        cg = pref.get("conflict_group")
        if cg:
            # Same "unknown" fallback as extract_all_preferences uses.
            conflict_groups.setdefault(cg, []).append(
                pref.get("pref_id", "unknown")
            )

    persona = rng.choice(personas) if personas else ""

    return {
        "user_id": user_id,
        "persona": persona,
        "preferences": selected_prefs,
        "conflict_groups": conflict_groups,
        "meta": {
            "total_preferences": len(selected_prefs),
            "total_conflict_groups": len(conflict_groups),
            "generator": "extend_profiles.py",
        },
    }


def _preference_count(profile: Dict) -> int:
    """Return a profile's preference count, tolerating a missing ``meta``."""
    meta = profile.get("meta") or {}
    count = meta.get("total_preferences")
    if count is None:
        count = len(profile.get("preferences", []))
    return count


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``sys.argv[1:]`` when omitted.

    Raises:
        SystemExit: On invalid arguments, or when new profiles are requested
            but the input contains no preferences to remix.
    """
    parser = argparse.ArgumentParser(
        description="Generate additional profiles by remixing existing ones"
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Path to existing profiles JSONL")
    parser.add_argument("--output", type=str, required=True,
                        help="Path for output profiles JSONL")
    parser.add_argument("--num-new", type=int, default=100,
                        help="Number of new profiles to generate")
    parser.add_argument("--seed", type=int, default=142,
                        help="Random seed (use different from original)")
    parser.add_argument("--target-prefs", type=int, default=43,
                        help="Target number of preferences per profile")
    parser.add_argument("--merge", action="store_true",
                        help="Merge with existing profiles in output")

    args = parser.parse_args(argv)
    if args.num_new < 0:
        parser.error("--num-new must be non-negative")

    input_path = Path(args.input)
    output_path = Path(args.output)

    print(f"Loading profiles from: {input_path}")
    profiles = load_profiles(input_path)
    print(f"  Loaded {len(profiles)} profiles")

    # Extract preference pool and personas
    pref_pool = extract_all_preferences(profiles)
    personas = extract_personas(profiles)

    print("\nPreference pool:")
    for cat, prefs in pref_pool.items():
        print(f"  {cat}: {len(prefs)} preferences")
    print(f"  Total unique preferences: {sum(len(p) for p in pref_pool.values())}")
    print(f"  Unique personas: {len(personas)}")

    if args.num_new > 0 and not pref_pool:
        raise SystemExit(
            f"No preferences found in {input_path}; nothing to remix."
        )

    # Generate new profiles
    rng = random.Random(args.seed)
    new_profiles = []

    print(f"\nGenerating {args.num_new} new profiles...")
    for i in range(args.num_new):
        # md5 is used only to derive a short deterministic id from
        # (seed, index), not for anything security-related.
        user_id = f"user_{hashlib.md5(f'{args.seed}_{i}'.encode()).hexdigest()[:8]}"
        profile = generate_new_profile(
            user_id=user_id,
            preference_pool=pref_pool,
            personas=personas,
            target_prefs=args.target_prefs,
            rng=rng,
        )
        new_profiles.append(profile)

        if (i + 1) % 20 == 0:
            print(f"  Generated {i + 1}/{args.num_new}")

    # Optionally merge with original
    if args.merge:
        output_profiles = profiles + new_profiles
        print(f"\nMerging: {len(profiles)} original + {len(new_profiles)} new = {len(output_profiles)}")
    else:
        output_profiles = new_profiles

    # Save
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(profile) + "\n" for profile in output_profiles)

    print(f"\nSaved {len(output_profiles)} profiles to: {output_path}")

    # Summary stats. Original profiles may lack "meta", and the output may be
    # empty (e.g. --num-new 0 without --merge), so guard both.
    pref_counts = [_preference_count(p) for p in output_profiles]
    if pref_counts:
        print("\nProfile statistics:")
        print(f"  Min preferences: {min(pref_counts)}")
        print(f"  Max preferences: {max(pref_counts)}")
        print(f"  Avg preferences: {sum(pref_counts)/len(pref_counts):.1f}")


if __name__ == "__main__":
    main()
