summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/extend_profiles.py
blob: d7806976cad79d9361c7e98c50f5087fcf784a84 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#!/usr/bin/env python3
"""
Generate additional profiles by remixing preferences from existing profiles.
This creates diverse profile combinations without requiring LLM calls.
"""

import json
import random
import hashlib
from pathlib import Path
from typing import List, Dict
import argparse


def load_profiles(path: Path) -> List[Dict]:
    """Load profiles from a JSONL file.

    Each non-blank line is parsed as one JSON object. Blank or
    whitespace-only lines (common after concatenating or hand-editing JSONL
    files) are skipped instead of making ``json.loads`` fail on an empty
    string.

    Args:
        path: Path to the JSONL file (read as UTF-8, the JSON encoding).

    Returns:
        The parsed profiles, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    profiles = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            profiles.append(json.loads(line))
    return profiles


def extract_all_preferences(profiles: List[Dict]) -> Dict[str, List[Dict]]:
    """Collect every distinct preference, bucketed by its ID prefix.

    A preference's category is the text before the first ``_`` in its
    ``pref_id`` (``"rf_001"`` -> ``"rf"``); IDs without an underscore land in
    ``"other"``. Only the first preference seen for each ``pref_id`` is
    kept. Preferences with no ``pref_id`` are all treated as ID
    ``"unknown"`` (which has no underscore), so at most one of them survives,
    under ``"other"``.

    Args:
        profiles: Parsed profile dicts, each optionally holding a
            ``"preferences"`` list.

    Returns:
        Mapping of category prefix -> preferences, both in first-seen order.
    """
    buckets: Dict[str, List[Dict]] = {}
    known_ids = set()

    every_pref = (
        pref
        for profile in profiles
        for pref in profile.get("preferences", [])
    )
    for pref in every_pref:
        pid = pref.get("pref_id", "unknown")
        if pid not in known_ids:
            known_ids.add(pid)
            category = pid.split("_", 1)[0] if "_" in pid else "other"
            buckets.setdefault(category, []).append(pref)

    return buckets


def extract_personas(profiles: List[Dict]) -> List[str]:
    """Return the distinct non-empty personas, in order of first appearance.

    Profiles without a ``"persona"`` key, or whose persona is empty/falsy,
    are ignored.
    """
    non_empty = (
        persona
        for persona in (profile.get("persona", "") for profile in profiles)
        if persona
    )
    # dicts keep insertion order, so fromkeys() drops repeats while leaving
    # each persona at the position of its first occurrence.
    return list(dict.fromkeys(non_empty))


def generate_new_profile(
    user_id: str,
    preference_pool: Dict[str, List[Dict]],
    personas: List[str],
    target_prefs: int = 43,
    rng: random.Random = None
) -> Dict:
    """Generate a new profile by sampling from a categorized preference pool.

    Sampling runs in three passes, all driven by ``rng``:

    1. From each category, draw about ``target_prefs // len(preference_pool)``
       preferences without replacement, jittered by -1..+2, clamped to at
       least one and at most the category size.
    2. If that gives fewer than ``target_prefs``, top up with random
       preferences from the whole pool that are not already selected.
    3. If it gives more, drop randomly chosen preferences until the target
       is met.

    Args:
        user_id: ID stored on the new profile.
        preference_pool: Mapping of category -> preference dicts (as built
            by ``extract_all_preferences``). May be empty, in which case the
            profile has no preferences.
        personas: Candidate personas; one is picked at random. If empty, the
            persona is ``""``.
        target_prefs: Desired number of preferences. The result has fewer
            only when the pool itself runs out.
        rng: Random source; a fresh unseeded ``random.Random`` if omitted.

    Returns:
        Profile dict with ``user_id``, ``persona``, ``preferences``,
        ``conflict_groups`` (conflict group -> selected pref_ids) and ``meta``.
    """
    if rng is None:
        rng = random.Random()

    selected_prefs = []

    # Pass 1: sample from each category to maintain diversity. The guard
    # avoids a ZeroDivisionError when the pool is empty.
    if preference_pool:
        prefs_per_cat = max(1, target_prefs // len(preference_pool))
    else:
        prefs_per_cat = 1

    for prefs in preference_pool.values():
        n_sample = min(len(prefs), prefs_per_cat + rng.randint(-1, 2))
        n_sample = max(1, n_sample)
        selected_prefs.extend(rng.sample(prefs, min(n_sample, len(prefs))))

    # Pass 2: top up from the whole pool. `remaining` is built once and then
    # shrunk after each pick instead of being rebuilt by rescanning the pool
    # against every selected preference. Its order is exactly what a full
    # rebuild would produce, so rng.choice() makes the same picks for a seed.
    all_prefs = [pref for prefs in preference_pool.values() for pref in prefs]
    if len(selected_prefs) < target_prefs:
        remaining = [p for p in all_prefs if p not in selected_prefs]
        while len(selected_prefs) < target_prefs and remaining:
            pick = rng.choice(remaining)
            selected_prefs.append(pick)
            remaining = [p for p in remaining if p != pick]

    # Pass 3: trim randomly down to the target.
    while len(selected_prefs) > target_prefs:
        selected_prefs.pop(rng.randint(0, len(selected_prefs) - 1))

    # Map each conflict group to the selected pref_ids that belong to it.
    # Missing IDs fall back to "unknown", matching extract_all_preferences.
    conflict_groups = {}
    for pref in selected_prefs:
        cg = pref.get("conflict_group")
        if cg:
            conflict_groups.setdefault(cg, []).append(
                pref.get("pref_id", "unknown")
            )

    return {
        "user_id": user_id,
        "persona": rng.choice(personas) if personas else "",
        "preferences": selected_prefs,
        "conflict_groups": conflict_groups,
        "meta": {
            "total_preferences": len(selected_prefs),
            "total_conflict_groups": len(conflict_groups),
            "generator": "extend_profiles.py"
        }
    }


def main():
    """Command-line entry point: remix the profiles in --input into new ones.

    Loads the input JSONL, pools every distinct preference (by category) and
    persona, generates ``--num-new`` profiles from a ``random.Random`` seeded
    with ``--seed``, and writes them as JSONL to ``--output``. With
    ``--merge``, the input profiles are written first. Finally it prints
    summary statistics.
    """
    parser = argparse.ArgumentParser(
        description="Generate additional profiles by remixing existing ones"
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Path to existing profiles JSONL")
    parser.add_argument("--output", type=str, required=True,
                        help="Path for output profiles JSONL")
    parser.add_argument("--num-new", type=int, default=100,
                        help="Number of new profiles to generate")
    parser.add_argument("--seed", type=int, default=142,
                        help="Random seed (use different from original)")
    parser.add_argument("--target-prefs", type=int, default=43,
                        help="Target number of preferences per profile")
    parser.add_argument("--merge", action="store_true",
                        help="Also write the input profiles (before the new "
                             "ones) to the output")

    args = parser.parse_args()
    if args.num_new < 0:
        parser.error("--num-new must be non-negative")

    input_path = Path(args.input)
    output_path = Path(args.output)

    print(f"Loading profiles from: {input_path}")
    profiles = load_profiles(input_path)
    print(f"  Loaded {len(profiles)} profiles")

    # Extract preference pool and personas
    pref_pool = extract_all_preferences(profiles)
    personas = extract_personas(profiles)

    print("\nPreference pool:")
    for cat, prefs in pref_pool.items():
        print(f"  {cat}: {len(prefs)} preferences")
    print(f"  Total unique preferences: {sum(len(p) for p in pref_pool.values())}")
    print(f"  Unique personas: {len(personas)}")

    # Fail early with a clear CLI error when the input offers nothing to
    # remix, rather than producing degenerate profiles or failing deep
    # inside generate_new_profile.
    if args.num_new > 0:
        if not pref_pool:
            parser.error(f"no preferences found in {input_path}; nothing to remix")
        if not personas:
            parser.error(f"no personas found in {input_path}; nothing to remix")

    # Generate new profiles
    rng = random.Random(args.seed)
    new_profiles = []

    print(f"\nGenerating {args.num_new} new profiles...")
    for i in range(args.num_new):
        # IDs depend only on (seed, index), so reruns with the same seed
        # reproduce the same IDs.
        user_id = f"user_{hashlib.md5(f'{args.seed}_{i}'.encode()).hexdigest()[:8]}"
        profile = generate_new_profile(
            user_id=user_id,
            preference_pool=pref_pool,
            personas=personas,
            target_prefs=args.target_prefs,
            rng=rng
        )
        new_profiles.append(profile)

        if (i + 1) % 20 == 0:
            print(f"  Generated {i + 1}/{args.num_new}")

    # Optionally merge with original
    if args.merge:
        output_profiles = profiles + new_profiles
        print(f"\nMerging: {len(profiles)} original + {len(new_profiles)} new = {len(output_profiles)}")
        # Because IDs are seed-derived, feeding a previous merged output back
        # in with the same seed would silently duplicate user_ids.
        existing_ids = {p.get("user_id") for p in profiles}
        clashes = sum(1 for p in new_profiles if p["user_id"] in existing_ids)
        if clashes:
            print(f"  WARNING: {clashes} new user_id(s) already exist in the input; "
                  "use a different --seed to avoid duplicates")
    else:
        output_profiles = new_profiles

    # Save
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(profile) + "\n" for profile in output_profiles)

    print(f"\nSaved {len(output_profiles)} profiles to: {output_path}")

    # Summary stats. Count preferences directly: merged input profiles are
    # not guaranteed to carry meta.total_preferences, and min()/max() would
    # fail on an empty output.
    pref_counts = [len(p.get("preferences", [])) for p in output_profiles]
    if pref_counts:
        print("\nProfile statistics:")
        print(f"  Min preferences: {min(pref_counts)}")
        print(f"  Max preferences: {max(pref_counts)}")
        print(f"  Avg preferences: {sum(pref_counts)/len(pref_counts):.1f}")


if __name__ == "__main__":
    main()