summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/extend_profiles.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/scripts/extend_profiles.py
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts/extend_profiles.py')
-rw-r--r--collaborativeagents/scripts/extend_profiles.py195
1 files changed, 195 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/extend_profiles.py b/collaborativeagents/scripts/extend_profiles.py
new file mode 100644
index 0000000..d780697
--- /dev/null
+++ b/collaborativeagents/scripts/extend_profiles.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python3
+"""
+Generate additional profiles by remixing preferences from existing profiles.
+This creates diverse profile combinations without requiring LLM calls.
+"""
+
+import json
+import random
+import hashlib
+from pathlib import Path
+from typing import List, Dict
+import argparse
+
+
+def load_profiles(path: Path) -> List[Dict]:
+ """Load profiles from JSONL file."""
+ profiles = []
+ with open(path) as f:
+ for line in f:
+ profiles.append(json.loads(line.strip()))
+ return profiles
+
+
+def extract_all_preferences(profiles: List[Dict]) -> Dict[str, List[Dict]]:
+ """Extract all unique preferences grouped by category (prefix)."""
+ categories = {}
+ seen_ids = set()
+
+ for profile in profiles:
+ for pref in profile.get("preferences", []):
+ pref_id = pref.get("pref_id", "unknown")
+ if pref_id in seen_ids:
+ continue
+ seen_ids.add(pref_id)
+
+ # Extract category from prefix (e.g., "rf_001" -> "rf")
+ prefix = pref_id.split("_")[0] if "_" in pref_id else "other"
+ if prefix not in categories:
+ categories[prefix] = []
+ categories[prefix].append(pref)
+
+ return categories
+
+
+def extract_personas(profiles: List[Dict]) -> List[str]:
+ """Extract unique personas from profiles."""
+ personas = []
+ seen = set()
+ for profile in profiles:
+ persona = profile.get("persona", "")
+ if persona and persona not in seen:
+ personas.append(persona)
+ seen.add(persona)
+ return personas
+
+
+def generate_new_profile(
+ user_id: str,
+ preference_pool: Dict[str, List[Dict]],
+ personas: List[str],
+ target_prefs: int = 43,
+ rng: random.Random = None
+) -> Dict:
+ """Generate a new profile by sampling from preference pool."""
+ if rng is None:
+ rng = random.Random()
+
+ selected_prefs = []
+
+ # Sample from each category to maintain diversity
+ prefs_per_cat = max(1, target_prefs // len(preference_pool))
+
+ for cat, prefs in preference_pool.items():
+ # Sample with some randomness
+ n_sample = min(len(prefs), prefs_per_cat + rng.randint(-1, 2))
+ n_sample = max(1, n_sample)
+ sampled = rng.sample(prefs, min(n_sample, len(prefs)))
+ selected_prefs.extend(sampled)
+
+ # Add/remove to hit target
+ all_prefs = []
+ for prefs in preference_pool.values():
+ all_prefs.extend(prefs)
+
+ while len(selected_prefs) < target_prefs:
+ remaining = [p for p in all_prefs if p not in selected_prefs]
+ if not remaining:
+ break
+ selected_prefs.append(rng.choice(remaining))
+
+ while len(selected_prefs) > target_prefs:
+ selected_prefs.pop(rng.randint(0, len(selected_prefs) - 1))
+
+ # Build conflict groups
+ conflict_groups = {}
+ for pref in selected_prefs:
+ cg = pref.get("conflict_group")
+ if cg:
+ if cg not in conflict_groups:
+ conflict_groups[cg] = []
+ conflict_groups[cg].append(pref["pref_id"])
+
+ return {
+ "user_id": user_id,
+ "persona": rng.choice(personas),
+ "preferences": selected_prefs,
+ "conflict_groups": conflict_groups,
+ "meta": {
+ "total_preferences": len(selected_prefs),
+ "total_conflict_groups": len(conflict_groups),
+ "generator": "extend_profiles.py"
+ }
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Generate additional profiles by remixing existing ones"
+ )
+ parser.add_argument("--input", type=str, required=True,
+ help="Path to existing profiles JSONL")
+ parser.add_argument("--output", type=str, required=True,
+ help="Path for output profiles JSONL")
+ parser.add_argument("--num-new", type=int, default=100,
+ help="Number of new profiles to generate")
+ parser.add_argument("--seed", type=int, default=142,
+ help="Random seed (use different from original)")
+ parser.add_argument("--target-prefs", type=int, default=43,
+ help="Target number of preferences per profile")
+ parser.add_argument("--merge", action="store_true",
+ help="Merge with existing profiles in output")
+
+ args = parser.parse_args()
+
+ input_path = Path(args.input)
+ output_path = Path(args.output)
+
+ print(f"Loading profiles from: {input_path}")
+ profiles = load_profiles(input_path)
+ print(f" Loaded {len(profiles)} profiles")
+
+ # Extract preference pool and personas
+ pref_pool = extract_all_preferences(profiles)
+ personas = extract_personas(profiles)
+
+ print(f"\nPreference pool:")
+ for cat, prefs in pref_pool.items():
+ print(f" {cat}: {len(prefs)} preferences")
+ print(f" Total unique preferences: {sum(len(p) for p in pref_pool.values())}")
+ print(f" Unique personas: {len(personas)}")
+
+ # Generate new profiles
+ rng = random.Random(args.seed)
+ new_profiles = []
+
+ print(f"\nGenerating {args.num_new} new profiles...")
+ for i in range(args.num_new):
+ user_id = f"user_{hashlib.md5(f'{args.seed}_{i}'.encode()).hexdigest()[:8]}"
+ profile = generate_new_profile(
+ user_id=user_id,
+ preference_pool=pref_pool,
+ personas=personas,
+ target_prefs=args.target_prefs,
+ rng=rng
+ )
+ new_profiles.append(profile)
+
+ if (i + 1) % 20 == 0:
+ print(f" Generated {i + 1}/{args.num_new}")
+
+ # Optionally merge with original
+ if args.merge:
+ output_profiles = profiles + new_profiles
+ print(f"\nMerging: {len(profiles)} original + {len(new_profiles)} new = {len(output_profiles)}")
+ else:
+ output_profiles = new_profiles
+
+ # Save
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(output_path, 'w') as f:
+ for profile in output_profiles:
+ f.write(json.dumps(profile) + '\n')
+
+ print(f"\nSaved {len(output_profiles)} profiles to: {output_path}")
+
+ # Summary stats
+ pref_counts = [p["meta"]["total_preferences"] for p in output_profiles]
+ print(f"\nProfile statistics:")
+ print(f" Min preferences: {min(pref_counts)}")
+ print(f" Max preferences: {max(pref_counts)}")
+ print(f" Avg preferences: {sum(pref_counts)/len(pref_counts):.1f}")
+
+
+if __name__ == "__main__":
+ main()