summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/generate_complex_profiles.py
blob: 3838413a55738cd57e6645d077d821c5c3d5b2c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
"""
Generate complex user profiles with conditional preferences using LLM.

This script generates user profiles with ~40 situation-dependent preferences
designed to stress-test retrieval-based personalization systems.
"""

import json
import random
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field, asdict
import hashlib

# Will use litellm for generation
try:
    import litellm
except ImportError:
    litellm = None


# ============================================================================
# Schema Definitions
# ============================================================================

@dataclass
class ConditionalPreference:
    """A single user preference that only applies in a given situation."""
    pref_id: str
    condition: str  # When this preference applies
    action: str     # What the user prefers
    conflict_group: Optional[str] = None  # Which preferences this might conflict with
    priority_context: list = field(default_factory=list)  # Keywords that trigger this pref

    def to_natural_language(self) -> str:
        """Render the preference as one natural-language sentence."""
        return "When {}, {}.".format(self.condition, self.action)

    def to_memory_card_format(self) -> dict:
        """Render as a dict compatible with the personalization system's MemoryCard."""
        # confidence is fixed at 1.0: profile-sourced preferences are ground truth.
        return dict(
            condition=self.condition,
            action=self.action,
            confidence=1.0,
            source="user_profile",
            pref_id=self.pref_id,
            conflict_group=self.conflict_group,
            priority_context=self.priority_context,
        )


@dataclass
class ConflictGroup:
    """Defines a group of preferences that may conflict."""
    group_id: str  # Matches ConditionalPreference.conflict_group values
    description: str  # Human-readable summary of what the conflict is about
    resolution_rule: str  # How to programmatically resolve which preference wins
    member_pref_ids: list = field(default_factory=list)  # pref_ids of preferences in this group


@dataclass
class UserProfile:
    """A complex user profile with conditional preferences."""
    user_id: str
    persona: str  # High-level description of the user
    preferences: list  # List of ConditionalPreference
    conflict_groups: dict = field(default_factory=dict)  # group_id -> ConflictGroup

    def get_preferences_by_category(self) -> dict:
        """Group preferences by category (the pref_id prefix before '_').

        Returns:
            Mapping of category prefix -> list of ConditionalPreference.
        """
        categories = {}
        for pref in self.preferences:
            cat = pref.pref_id.split('_')[0]
            categories.setdefault(cat, []).append(pref)
        return categories

    def get_conflicting_preferences(self, query: str) -> dict:
        """Find conflict groups where the query triggers multiple preferences.

        Args:
            query: User query, matched case-insensitively against each
                preference's priority_context keywords.

        Returns:
            Mapping of conflict_group id -> list of triggered preferences,
            restricted to groups where more than one preference fired.
            (Fixed annotation: this method returns a dict, not a list.)
        """
        # Simple keyword matching - in practice, use embeddings
        query_lower = query.lower()
        triggered = [
            pref for pref in self.preferences
            if any(kw.lower() in query_lower for kw in pref.priority_context)
        ]

        # Bucket triggered preferences by their conflict group.
        conflicts = {}
        for pref in triggered:
            if pref.conflict_group:
                conflicts.setdefault(pref.conflict_group, []).append(pref)

        # Only groups with more than one triggered preference are true conflicts.
        return {k: v for k, v in conflicts.items() if len(v) > 1}

    def to_dict(self) -> dict:
        """Serialize the profile (preferences, groups, meta counts) to plain dicts."""
        return {
            "user_id": self.user_id,
            "persona": self.persona,
            "preferences": [asdict(p) for p in self.preferences],
            "conflict_groups": {k: asdict(v) for k, v in self.conflict_groups.items()},
            "meta": {
                "total_preferences": len(self.preferences),
                "total_conflict_groups": len(self.conflict_groups)
            }
        }


# ============================================================================
# Preference Templates for LLM Generation
# ============================================================================

# Catalogue of preference categories used to drive LLM-based generation.
# For each category:
#   - "description": short summary, interpolated into LLM_GENERATION_PROMPT
#   - "num_preferences": how many preferences to request for this category
#   - "example_conflicts": hints so the LLM creates conflicting pairs
#   - "generation_prompt": category-specific instructions with an {n} slot
# NOTE(review): "generation_prompt" is not referenced by
# generate_preferences_with_llm in this file — confirm its intended consumer.
PREFERENCE_CATEGORIES = {
    "response_format": {
        "description": "How responses should be structured",
        "num_preferences": 4,
        "example_conflicts": ["bullets vs numbered", "answer-first vs build-up"],
        "generation_prompt": """Generate {n} preferences about response formatting.
Include conflicting pairs like:
- When to use bullet points vs numbered lists
- When to give answer first vs build up to it
Each preference must have a specific condition (when it applies) and action (what to do)."""
    },
    "verbosity": {
        "description": "How detailed responses should be",
        "num_preferences": 5,
        "example_conflicts": ["concise vs detailed", "explain why vs just answer"],
        "generation_prompt": """Generate {n} preferences about response verbosity.
Include conflicting pairs like:
- Brief responses vs detailed explanations
- When to explain reasoning vs just give answer
Conditions should include cue phrases like 'quick question', 'briefly', etc."""
    },
    "code_style": {
        "description": "Programming and code preferences",
        "num_preferences": 8,
        "example_conflicts": ["naming conventions by language", "comment styles", "review focus"],
        "generation_prompt": """Generate {n} preferences about code style.
Include:
- Language-specific naming conventions (Python snake_case, JS camelCase, etc.)
- Comment styles for different code lengths
- Code review focus (bugs only vs style too)
- Error handling preferences"""
    },
    "math_style": {
        "description": "Mathematical explanation preferences",
        "num_preferences": 6,
        "example_conflicts": ["step-by-step vs intuition", "formal vs informal"],
        "generation_prompt": """Generate {n} preferences about mathematical explanations.
Include:
- When to show detailed steps vs high-level approach
- Intuition-first vs formula-first for statistics
- How to structure proofs
- Verification requests"""
    },
    "interaction_pattern": {
        "description": "How to interact with user",
        "num_preferences": 6,
        "example_conflicts": ["confirm vs execute", "recommend vs list options"],
        "generation_prompt": """Generate {n} preferences about interaction patterns.
Include:
- When to confirm before acting vs execute directly
- When to recommend vs present options
- How to handle user emotions (frustration, gratitude)"""
    },
    "domain_specific": {
        "description": "Preferences for specific technical domains",
        "num_preferences": 6,
        "example_conflicts": ["example-first vs definition-first"],
        "generation_prompt": """Generate {n} domain-specific preferences for:
- Machine learning explanations
- System design discussions
- API/library usage
- Data structures (include complexity)"""
    },
    "error_correction": {
        "description": "How to handle user mistakes",
        "num_preferences": 4,
        "example_conflicts": ["gentle vs direct correction"],
        "generation_prompt": """Generate {n} preferences about error correction.
Include:
- Minor terminology errors vs fundamental misconceptions
- Code bugs
- Correcting own previous responses"""
    },
    "output_artifacts": {
        "description": "How to present code and commands",
        "num_preferences": 4,
        "example_conflicts": ["single block vs chunked"],
        "generation_prompt": """Generate {n} preferences about output artifacts.
Include:
- Copyable code blocks vs explained chunks
- Command presentation
- Language specification in code fences"""
    }
}


# Prompt template for generating one category's preferences.
# Placeholders filled by generate_preferences_with_llm:
#   {num_prefs}, {category_name}, {category_description},
#   {example_conflicts}, {extra_context}, {category_prefix}.
# Doubled braces ({{ }}) escape literal JSON braces in the example output.
LLM_GENERATION_PROMPT = """You are generating user preferences for a personalization benchmark.

## Task
Generate {num_prefs} conditional preferences for the category: {category_name}
Description: {category_description}

## Requirements
1. Each preference must have:
   - A specific CONDITION (when it applies, including trigger phrases/situations)
   - An ACTION (what the user prefers to happen)
   - A CONFLICT_GROUP (if this preference might conflict with another)
   - PRIORITY_CONTEXT (list of keywords that trigger this preference)

2. Include at least one pair of CONFLICTING preferences that could both be triggered
   by different aspects of the same query. The conflict should be resolvable by
   looking at the specific context.

3. Conditions should be:
   - Specific and observable (not vague like "when appropriate")
   - Include trigger phrases users might say
   - Cover different situations within this category

4. Example conflicts for this category: {example_conflicts}

## Additional Context (if any)
{extra_context}

## Output Format
Return a JSON array of preferences:
```json
[
  {{
    "pref_id": "{category_prefix}_001",
    "condition": "specific situation or trigger phrase",
    "action": "what the user prefers",
    "conflict_group": "group_name or null",
    "priority_context": ["keyword1", "keyword2"]
  }},
  ...
]
```

Generate exactly {num_prefs} preferences."""


# Prompt template for generating a persona consistent with a preference set.
# {preference_summary} is filled by generate_persona_with_llm with a bulleted
# per-category summary of truncated preference actions.
PERSONA_GENERATION_PROMPT = """Generate a realistic user persona for a software developer/researcher.

## Requirements
1. The persona should feel like a real person with:
   - A professional background (role, experience level, domain)
   - Communication style tendencies
   - Learning preferences
   - Work context (startup vs enterprise, solo vs team)

2. The persona should naturally motivate the preferences that will be assigned.

3. Keep it to 2-3 sentences.

## Preference Summary
This user will have preferences in these areas:
{preference_summary}

## Examples of good personas:
- "A senior backend engineer at a fintech startup who values efficiency and directness. Prefers practical solutions over theoretical discussions, and likes to understand the 'why' behind recommendations."
- "A PhD student in machine learning who is meticulous about mathematical rigor. Appreciates step-by-step derivations and often cross-references multiple sources before accepting an explanation."
- "A junior developer transitioning from frontend to full-stack. Learns best through examples and appreciates patient, incremental explanations without condescension."

## Output
Return only the persona text (2-3 sentences), no JSON or formatting."""


# ============================================================================
# Conflict Resolution Logic
# ============================================================================

# Programmatic conflict-resolution rules, keyed by conflict group id.
# Each rule has:
#   - "signals": signal-name -> keyword list; a keyword appearing in the query
#     counts as evidence for that signal (consumed by resolve_conflict)
#   - "resolution": human-readable description of which signal should win
CONFLICT_RESOLUTION_RULES = {
    "format_structure": {
        "signals": {
            "bullets": ["options", "alternatives", "list", "multiple", "comparison", "pros and cons"],
            "numbered": ["steps", "procedure", "how to", "setup", "install", "first", "then", "sequence"]
        },
        "resolution": "sequential_process -> numbered; parallel_items -> bullets"
    },
    "answer_position": {
        "signals": {
            "answer_first": ["what is", "what's", "tell me", "give me", "?"],
            "build_up": ["explain", "why", "how does", "teach", "help me understand"]
        },
        "resolution": "direct_question -> answer_first; learning_intent -> build_up"
    },
    "response_length": {
        "signals": {
            "concise": ["quick", "brief", "short", "tldr", "in a nutshell", "one line"],
            "detailed": ["explain", "elaborate", "in detail", "thoroughly", "complex", "proof"]
        },
        "resolution": "explicit_brevity_cue -> concise (overrides topic complexity)"
    },
    "naming_convention": {
        "signals": {
            "snake_case": ["python", ".py", "def ", "import "],
            "camelCase": ["javascript", "typescript", ".js", ".ts", "const ", "let ", "function "],
            "UPPER_keywords": ["sql", "SELECT", "FROM", "WHERE", "database"]
        },
        "resolution": "determined by programming language detection"
    },
    "autonomy": {
        "signals": {
            "confirm": ["should I", "would you like", "complex", "multiple parts", "project"],
            "execute": ["do this", "make this", "just", "please", "now"]
        },
        "resolution": "ambiguous_task -> confirm; clear_instruction -> execute"
    },
    "code_presentation": {
        "signals": {
            "single_block": ["copy", "paste", "use this", "give me the code", "full code"],
            "chunked": ["teach", "explain", "understand", "walk through", "learn"]
        },
        "resolution": "copy_intent -> single_block; learning_intent -> chunked"
    }
}


def resolve_conflict(conflict_group: str, query: str, candidates: list) -> Optional[str]:
    """
    Programmatically resolve which preference wins in a conflict.

    Args:
        conflict_group: The conflict group ID
        query: The user query
        candidates: List of ConditionalPreference objects in this conflict

    Returns:
        pref_id of the winning preference, or None if cannot resolve
    """
    rules = CONFLICT_RESOLUTION_RULES.get(conflict_group)
    if rules is None:
        return None

    normalized_query = query.lower()

    # Precompute, per signal category, which keywords actually occur in the query.
    matched_keywords = {
        signal: [kw for kw in keywords if kw.lower() in normalized_query]
        for signal, keywords in rules["signals"].items()
    }

    # Score each candidate: a point when its priority_context overlaps the
    # signal-category name, and a point when it contains a matched keyword
    # (counted once per matched keyword, as in the original scheme).
    scores = {}
    for pref in candidates:
        total = 0
        for signal, hits in matched_keywords.items():
            signal_l = signal.lower()
            for kw in hits:
                kw_l = kw.lower()
                for ctx in pref.priority_context:
                    ctx_l = ctx.lower()
                    if ctx_l in signal_l or signal_l in ctx_l:
                        total += 1
                    if kw_l in ctx_l:
                        total += 1
        scores[pref.pref_id] = total

    if not scores:
        return None

    # Highest score wins; a zero score means we could not resolve.
    best = max(scores, key=scores.get)
    return best if scores[best] > 0 else None


def create_conflict_test_case(conflict_group: str, preferences: list) -> Optional[dict]:
    """
    Create a test case that triggers a specific conflict.

    Args:
        conflict_group: Conflict group id; must be a known resolution rule
            and have a canned query below, otherwise None is returned.
        preferences: All ConditionalPreference objects for the profile.

    Returns:
        None for unknown groups, otherwise a dict with:
        - query: A query that triggers multiple preferences
        - ambiguity: Why the query is ambiguous
        - triggered_pref_ids: List of preference IDs triggered
        - correct_pref_id: The preference that should win (None if unresolvable)
        - resolution_reason: Why this preference wins
    """
    # Guard against groups with no resolution rule.
    # (Fixed: dropped an unused `rules` lookup; annotated the Optional return.)
    if conflict_group not in CONFLICT_RESOLUTION_RULES:
        return None

    # Canned conflict-triggering queries, keyed by conflict group.
    test_cases = {
        "format_structure": {
            "query": "How do I set up a Python virtual environment? List the main options.",
            "ambiguity": "Both 'set up' (procedure->numbered) and 'list options' (parallel->bullets)",
            "resolution": "Primary intent is setup procedure -> numbered steps"
        },
        "response_length": {
            "query": "Quick question - how does backpropagation work?",
            "ambiguity": "'Quick question' (concise) vs 'how does X work' (complex topic)",
            "resolution": "Explicit brevity cue 'quick question' overrides topic complexity"
        },
        "answer_position": {
            "query": "What is gradient descent and why is it used?",
            "ambiguity": "'What is' (answer first) vs 'why' (build up explanation)",
            "resolution": "Combined question: give brief answer, then explain why"
        },
        "naming_convention": {
            "query": "Write a function to parse JSON in both Python and JavaScript",
            "ambiguity": "Two languages with different conventions",
            "resolution": "Use appropriate convention for each: snake_case for Python, camelCase for JS"
        },
        "autonomy": {
            "query": "Refactor this authentication module to use JWT",
            "ambiguity": "'Refactor' is complex, but instruction is specific",
            "resolution": "Should confirm approach before major refactor"
        },
        "code_presentation": {
            "query": "I want to understand how this sorting algorithm works, give me the code",
            "ambiguity": "'understand' (chunked) vs 'give me the code' (single block)",
            "resolution": "Learning intent detected -> chunked with explanations"
        }
    }

    if conflict_group in test_cases:
        tc = test_cases[conflict_group]
        # All preferences in this group are candidates for the conflict.
        triggered = [p for p in preferences if p.conflict_group == conflict_group]
        winner = resolve_conflict(conflict_group, tc["query"], triggered)

        return {
            "conflict_group": conflict_group,
            "query": tc["query"],
            "ambiguity": tc["ambiguity"],
            "triggered_pref_ids": [p.pref_id for p in triggered],
            "correct_pref_id": winner,
            "resolution_reason": tc["resolution"]
        }

    return None


# ============================================================================
# LLM-based Profile Generation
# ============================================================================

def generate_preferences_with_llm(
    category: str,
    model: str = "gpt-4o-mini",
    extra_context: str = ""
) -> list:
    """Generate preferences for a category using LLM.

    Args:
        category: Key into PREFERENCE_CATEGORIES.
        model: litellm model identifier.
        extra_context: Optional extra instructions injected into the prompt.

    Returns:
        List of ConditionalPreference parsed from the model's JSON output.

    Raises:
        ImportError: If litellm is not installed.
        json.JSONDecodeError: If no parseable JSON could be extracted.
    """
    if litellm is None:
        raise ImportError("litellm required for LLM generation")

    info = PREFERENCE_CATEGORIES[category]
    prompt = LLM_GENERATION_PROMPT.format(
        num_prefs=info["num_preferences"],
        category_name=category,
        category_description=info["description"],
        example_conflicts=", ".join(info["example_conflicts"]),
        category_prefix=category[:2],
        extra_context=extra_context or "None"
    )

    reply = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"}
    )
    raw = reply.choices[0].message.content

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # The model sometimes wraps the array in prose or a markdown fence;
        # pull out the first bracketed span and parse that instead.
        import re
        found = re.search(r'\[[\s\S]*\]', raw)
        if found is None:
            raise
        parsed = json.loads(found.group())
    else:
        # json_object mode may wrap the array as {"preferences": [...]}.
        if isinstance(parsed, dict) and "preferences" in parsed:
            parsed = parsed["preferences"]

    return [ConditionalPreference(**item) for item in parsed]


def generate_persona_with_llm(
    preferences: list,
    model: str = "gpt-4o-mini"
) -> str:
    """Generate a persona that matches the preferences.

    Builds a per-category summary of truncated preference actions, feeds it
    through PERSONA_GENERATION_PROMPT, and returns the model's persona text.
    """
    if litellm is None:
        raise ImportError("litellm required for LLM generation")

    # Bucket truncated action snippets by the pref_id's category prefix.
    grouped = {}
    for pref in preferences:
        key = pref.pref_id.split('_')[0]
        grouped.setdefault(key, []).append(pref.action[:50] + "...")

    # At most three snippets per category keep the prompt compact.
    summary = "\n".join(
        f"- {key}: {', '.join(snippets[:3])}" for key, snippets in grouped.items()
    )

    reply = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": PERSONA_GENERATION_PROMPT.format(preference_summary=summary)}]
    )

    return reply.choices[0].message.content.strip()


def generate_full_profile(
    user_id: str,
    model: str = "gpt-4o-mini",
    categories: list = None
) -> UserProfile:
    """Generate a complete user profile with all preferences.

    Args:
        user_id: Identifier assigned to the resulting profile.
        model: litellm model used for both preference and persona generation.
        categories: Category names to generate; defaults to every key in
            PREFERENCE_CATEGORIES.
    """
    if categories is None:
        categories = list(PREFERENCE_CATEGORIES.keys())

    prefs = []
    for name in categories:
        prefs.extend(generate_preferences_with_llm(name, model))

    persona = generate_persona_with_llm(prefs, model)

    # Collect conflict groups declared by the generated preferences.
    groups = {}
    for pref in prefs:
        gid = pref.conflict_group
        if not gid:
            continue
        if gid not in groups:
            # The resolution rule doubles as the description when known.
            rule = CONFLICT_RESOLUTION_RULES.get(gid, {}).get("resolution", "")
            groups[gid] = ConflictGroup(
                group_id=gid,
                description=rule,
                resolution_rule=rule,
                member_pref_ids=[]
            )
        groups[gid].member_pref_ids.append(pref.pref_id)

    return UserProfile(
        user_id=user_id,
        persona=persona,
        preferences=prefs,
        conflict_groups=groups
    )


# ============================================================================
# Dataset Loading and Challenging Question Selection
# ============================================================================

# Candidate datasets for sourcing challenging evaluation questions.
# Each entry: a dataset "source" id (HuggingFace-style — presumably; TODO
# confirm against the loader), an optional per-example "filter" predicate
# (None = keep all), and whether prompts should push step-by-step reasoning.
# NOTE(review): this table is not consumed anywhere in this file — verify
# its consumer before changing keys.
CHALLENGING_DATASETS = {
    # Existing datasets with difficulty filtering
    "math-hard": {
        "source": "lighteval/MATH-Hard",
        "filter": lambda x: x.get("level") in ["Level 4", "Level 5"],
        "encourage_step_by_step": True
    },
    "humaneval-hard": {
        "source": "openai_humaneval",
        "filter": lambda x: len(x.get("prompt", "")) > 200,  # Longer problems
        "encourage_step_by_step": True
    },

    # New challenging datasets to add
    "gpqa": {
        "source": "Idavidrein/gpqa",
        "description": "PhD-level science questions",
        "filter": lambda x: x.get("difficulty") == "hard",
        "encourage_step_by_step": True
    },
    "theoremqa": {
        "source": "wenhu/TheoremQA",
        "description": "Theorem-based math requiring multi-step proofs",
        "filter": None,
        "encourage_step_by_step": True
    },
    "livecodebench": {
        "source": "livecodebench/livecodebench",
        "description": "Recent competitive programming problems",
        "filter": lambda x: x.get("difficulty") in ["medium", "hard"],
        "encourage_step_by_step": True
    },
    "aime": {
        "source": "AI-MO/aimo-progress-prize",
        "description": "American Invitational Mathematics Examination",
        "filter": None,
        "encourage_step_by_step": True
    },
    "scicode": {
        "source": "scicode-bench/SciCode",
        "description": "Scientific computing problems",
        "filter": None,
        "encourage_step_by_step": True
    }
}


# Prompt suffixes that encourage explicit step-by-step reasoning, keyed by
# problem type ("math" / "code" / "reasoning").
# NOTE(review): not referenced elsewhere in this file; confirm the consumer
# before renaming keys or editing the text.
STEP_BY_STEP_PROMPT_ADDITIONS = {
    "math": """
When solving this problem:
1. First identify what type of problem this is
2. State the key concepts/theorems needed
3. Work through the solution step by step
4. Verify your answer
Take your time and show your reasoning at each step.""",

    "code": """
When solving this problem:
1. First understand the requirements and edge cases
2. Outline your approach before writing code
3. Implement step by step, explaining your logic
4. Consider time/space complexity
5. Test with example inputs
Show your reasoning throughout.""",

    "reasoning": """
When solving this problem:
1. Carefully read and identify the key information
2. State any assumptions you're making
3. Work through the logic step by step
4. Check for any flaws in your reasoning
5. State your conclusion clearly
Take your time and explain your thought process."""
}


# ============================================================================
# Batch Generation Script
# ============================================================================

def generate_profiles_batch(
    num_profiles: int,
    output_path: Path,
    model: str = "gpt-4o-mini",
    seed: int = 42
) -> list:
    """Generate multiple user profiles and save them as JSONL.

    Args:
        num_profiles: How many profiles to attempt to generate.
        output_path: Destination .jsonl file; parent dirs are created.
        model: LLM model name passed through to the generators.
        seed: Seeds the RNG and derives stable user ids.

    Returns:
        List of successfully generated UserProfile objects. Failed
        generations are logged and skipped, not raised.
    """
    random.seed(seed)
    profiles = []

    for i in range(num_profiles):
        # Stable, seed-dependent id so reruns with the same seed match.
        user_id = f"user_{hashlib.md5(f'{seed}_{i}'.encode()).hexdigest()[:8]}"

        # TODO: vary which categories are emphasized per user (e.g. stronger
        # code preferences vs math). A previous sketch computed random
        # per-category weights but never used them; removed as dead code.

        try:
            profile = generate_full_profile(user_id, model)
            profiles.append(profile)
            print(f"Generated profile {i+1}/{num_profiles}: {user_id}")
        except Exception as e:
            # Best-effort batch: skip failed profiles rather than aborting.
            print(f"Error generating profile {i+1}: {e}")
            continue

    # Save profiles
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as f:
        for profile in profiles:
            f.write(json.dumps(profile.to_dict()) + '\n')

    print(f"Saved {len(profiles)} profiles to {output_path}")
    return profiles


def generate_conflict_test_suite(profiles: list, output_path: Path):
    """Generate test cases for conflict resolution evaluation.

    For every conflict group on every profile, builds a canned test case via
    create_conflict_test_case, tags it with the owning user_id, and writes
    the full list as pretty-printed JSON to output_path.
    """
    cases = []

    for profile in profiles:
        for group_id in profile.conflict_groups:
            case = create_conflict_test_case(group_id, profile.preferences)
            if not case:
                # Unknown group: no canned query exists, skip it.
                continue
            case["user_id"] = profile.user_id
            cases.append(case)

    with open(output_path, 'w') as f:
        json.dump(cases, f, indent=2)

    print(f"Generated {len(cases)} conflict test cases")
    return cases


# ============================================================================
# Main
# ============================================================================

if __name__ == "__main__":
    import argparse

    # CLI entry point: generate a batch of profiles and (optionally) the
    # conflict-resolution test suite derived from them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_profiles", type=int, default=10)
    parser.add_argument("--output_dir", type=str, default="collaborativeagents/data/complex_profiles")
    parser.add_argument("--model", type=str, default="gpt-4o-mini")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--generate_conflicts", action="store_true")

    args = parser.parse_args()

    output_dir = Path(args.output_dir)

    # Generate profiles (saved as JSONL under output_dir/profiles.jsonl)
    profiles = generate_profiles_batch(
        num_profiles=args.num_profiles,
        output_path=output_dir / "profiles.jsonl",
        model=args.model,
        seed=args.seed
    )

    # Generate conflict test cases (opt-in via --generate_conflicts)
    if args.generate_conflicts:
        generate_conflict_test_suite(
            profiles,
            output_path=output_dir / "conflict_tests.json"
        )