Diffstat (limited to 'scripts/pilot_runner_v3.py')
 -rw-r--r--   scripts/pilot_runner_v3.py   924
1 file changed, 924 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v3.py b/scripts/pilot_runner_v3.py
new file mode 100644
index 0000000..d232d10
--- /dev/null
+++ b/scripts/pilot_runner_v3.py
@@ -0,0 +1,924 @@

#!/usr/bin/env python3
"""
Pilot Runner v3 - Multi-User Multi-Session with Personas

Upgrades from v2:
- Persona: bundles StylePrefs into user types
- 5 test personas (A-E) targeting different style combinations
- Multi-user × multi-session evaluation
- Refined judge: bullets enforced only on list tasks, relaxed empty_answer
- Baseline mode support (no-personalization comparison)

5 Test Personas:
- A: short + bullets + en (sanity check)
- B: short + NO bullets + en (anti-bullet)
- C: long + bullets + en (no length constraint)
- D: short + bullets + zh (Chinese)
- E: long + NO bullets + zh (most "anti-default")
"""

import sys
import os
import json
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple, Set

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))

from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse


# =============================================================================
# Style Preferences (True Preferences)
# =============================================================================

@dataclass
class StylePrefs:
    """User's TRUE style preferences."""
    require_short: bool = False
    max_chars: int = 300
    require_bullets: bool = False
    lang: str = "en"  # "en" or "zh"


# =============================================================================
# Persona Definition
# =============================================================================

@dataclass
class Persona:
    """
    A user persona that bundles style preferences.
    Each persona represents a distinct user type for testing.
    """
    persona_id: str
    style_prefs: StylePrefs
    description: str = ""
    # Future extensions:
    # task_preferences: Dict[str, float]  # e.g., {"code": 0.3, "rewrite": 0.7}
    # tone: str = "neutral"               # "formal", "casual", etc.
    # domain: str = "general"             # "tech", "daily_life", etc.
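

# Illustrative sketch (not part of the pilot flow): Persona nests StylePrefs,
# so dataclasses.asdict recurses into the nested dataclass; log_to_jsonl below
# relies on the same behavior when serializing TurnLog entries. The persona_id
# and helper name here are hypothetical.
def _demo_persona_roundtrip() -> Dict[str, Any]:
    demo = Persona(
        persona_id="X_demo_short_en",
        style_prefs=StylePrefs(require_short=True, max_chars=150),
        description="Hypothetical demo persona",
    )
    # -> {"persona_id": "X_demo_short_en",
    #     "style_prefs": {"require_short": True, "max_chars": 150, ...},
    #     "description": "Hypothetical demo persona"}
    return asdict(demo)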


# =============================================================================
# 5 Test Personas (A-E)
# =============================================================================

PERSONA_A = Persona(
    persona_id="A_short_bullets_en",
    style_prefs=StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=True,
        lang="en",
    ),
    description="Short + bullets + English (sanity check, same as v2)",
)

PERSONA_B = Persona(
    persona_id="B_short_no_bullets_en",
    style_prefs=StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=False,
        lang="en",
    ),
    description="Short + NO bullets + English (anti-bullet test)",
)

PERSONA_C = Persona(
    persona_id="C_long_bullets_en",
    style_prefs=StylePrefs(
        require_short=False,
        max_chars=800,
        require_bullets=True,
        lang="en",
    ),
    description="Long + bullets + English (no length constraint)",
)

PERSONA_D = Persona(
    persona_id="D_short_bullets_zh",
    style_prefs=StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=True,
        lang="zh",
    ),
    description="Short + bullets + Chinese (language test)",
)

PERSONA_E = Persona(
    persona_id="E_long_no_bullets_zh",
    style_prefs=StylePrefs(
        require_short=False,
        max_chars=800,
        require_bullets=False,
        lang="zh",
    ),
    description="Long + NO bullets + Chinese (most anti-default)",
)

ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E]


def get_persona_by_id(persona_id: str) -> Optional[Persona]:
    """Get persona by ID."""
    for p in ALL_PERSONAS:
        if p.persona_id == persona_id:
            return p
    return None


# =============================================================================
# Reveal State
# =============================================================================

@dataclass
class RevealState:
    """Tracks which preferences have been explicitly revealed."""
    short_revealed: bool = False
    bullets_revealed: bool = False
    lang_revealed: bool = False

    def reset(self):
        self.short_revealed = False
        self.bullets_revealed = False
        self.lang_revealed = False

    def to_dict(self) -> Dict[str, bool]:
        return {
            "short": self.short_revealed,
            "bullets": self.bullets_revealed,
            "lang": self.lang_revealed,
        }

    def __str__(self) -> str:
        flags = []
        if self.short_revealed:
            flags.append("short")
        if self.bullets_revealed:
            flags.append("bullets")
        if self.lang_revealed:
            flags.append("lang")
        return f"RevealState({', '.join(flags) if flags else 'none'})"


class RevealStateManager:
    """Manages reveal state for multiple users."""

    def __init__(self):
        self._states: Dict[str, RevealState] = {}

    def get_state(self, user_id: str) -> RevealState:
        if user_id not in self._states:
            self._states[user_id] = RevealState()
        return self._states[user_id]

    def reset_user(self, user_id: str):
        if user_id in self._states:
            self._states[user_id].reset()
        else:
            self._states[user_id] = RevealState()

    def reset_session(self, user_id: str):
        pass  # Reveal state persists across sessions
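

# Illustrative sketch (not exercised by the pilot itself): reveal state
# survives reset_session but not reset_user, mirroring the assumption that
# "what the user has already told us" is long-lived knowledge, independent of
# per-session chat state. The helper name is hypothetical.
def _demo_reveal_persistence() -> Tuple[bool, bool]:
    mgr = RevealStateManager()
    mgr.get_state("demo_user").short_revealed = True
    mgr.reset_session("demo_user")  # no-op: reveal state persists
    persisted = mgr.get_state("demo_user").short_revealed  # True
    mgr.reset_user("demo_user")     # full reset
    cleared = mgr.get_state("demo_user").short_revealed    # False
    return persisted, cleared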
+ """ + lower_q = (query or "").lower() + + revealed = { + "short": False, + "bullets": False, + "lang": False, + } + + # Short/length preference + short_patterns = [ + "short", "concise", "brief", "under ", "less than", + "keep it short", "keep responses", "keep answers", + "maximum ", "max ", "characters", "words or less", + "200 ", "100 ", "50 ", "300 ", + ] + for pattern in short_patterns: + if pattern in lower_q: + revealed["short"] = True + break + + # Bullet preference (both positive and negative) + bullet_patterns = [ + "bullet", "bullet point", "bullet-point", + "bulleted", "list format", "use bullets", + "no bullet", "don't use bullet", "without bullet", + "numbered list", "use numbers", + ] + for pattern in bullet_patterns: + if pattern in lower_q: + revealed["bullets"] = True + break + + # Language preference + lang_patterns_zh = [ + "chinese", "中文", "in chinese", "用中文", + "speak chinese", "write chinese", "respond in chinese", + "please use chinese", "mandarin", "请用中文", + ] + lang_patterns_en = [ + "english", "in english", "use english", + "speak english", "write english", "respond in english", + "please use english", + ] + + for pattern in lang_patterns_zh + lang_patterns_en: + if pattern in lower_q: + revealed["lang"] = True + break + + return revealed + + +def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]: + """Update reveal state based on query content.""" + detected = detect_revealed_preferences(query, prefs) + newly_revealed = set() + + if detected["short"] and not reveal_state.short_revealed: + reveal_state.short_revealed = True + newly_revealed.add("short") + + if detected["bullets"] and not reveal_state.bullets_revealed: + reveal_state.bullets_revealed = True + newly_revealed.add("bullets") + + if detected["lang"] and not reveal_state.lang_revealed: + reveal_state.lang_revealed = True + newly_revealed.add("lang") + + return newly_revealed + + +# ============================================================================= +# Refined Style Judge +# ============================================================================= + +@dataclass +class JudgeResult: + """Output from the judge for one turn.""" + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + + +def style_judge_v3( + query: str, + answer: str, + task_type: str, + prefs: StylePrefs, + reveal_state: RevealState, +) -> JudgeResult: + """ + Refined style judge with: + - Bullets only enforced on list-type tasks + - Relaxed empty_answer (only truly empty or single char) + - Reveal-aware enforcement + """ + violations: List[str] = [] + enforced: List[str] = [] + text = (answer or "").strip() + + # 0) Empty answer - only truly empty or single non-meaningful char + # Relaxed: allow short factual answers like "4", "Paris" + if len(text) == 0: + violations.append("empty_answer") + return JudgeResult( + sat_t=0.0, + sev_t=1.0, + prog_t=0.0, + violations=violations, + enforced_constraints=["non_empty"], + ) + + # 1) Length - enforce only if BOTH true AND revealed + if prefs.require_short and reveal_state.short_revealed: + enforced.append("short") + if len(text) > prefs.max_chars: + violations.append("too_long") + + # 2) Bullets - enforce ONLY on list-type tasks AND if revealed + # task_type "list" = listing tasks (Name X things, What are the N...) + # task_type "qa" = factual QA (What is the capital...) 


# =============================================================================
# Refined Style Judge
# =============================================================================

@dataclass
class JudgeResult:
    """Output from the judge for one turn."""
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    enforced_constraints: List[str]


def style_judge_v3(
    query: str,
    answer: str,
    task_type: str,
    prefs: StylePrefs,
    reveal_state: RevealState,
) -> JudgeResult:
    """
    Refined style judge with:
    - Bullets enforced only on list-type tasks
    - Relaxed empty_answer (only a truly empty answer is penalized)
    - Reveal-aware enforcement
    """
    violations: List[str] = []
    enforced: List[str] = []
    text = (answer or "").strip()

    # 0) Empty answer - only a truly empty answer counts.
    #    Relaxed: short factual answers like "4" or "Paris" are allowed.
    if len(text) == 0:
        violations.append("empty_answer")
        return JudgeResult(
            sat_t=0.0,
            sev_t=1.0,
            prog_t=0.0,
            violations=violations,
            enforced_constraints=["non_empty"],
        )

    # 1) Length - enforce only if BOTH true AND revealed
    if prefs.require_short and reveal_state.short_revealed:
        enforced.append("short")
        if len(text) > prefs.max_chars:
            violations.append("too_long")

    # 2) Bullets - enforce ONLY on list-type tasks AND if revealed
    #    task_type "list"    = listing tasks (Name X things, What are the N...)
    #    task_type "qa"      = factual QA (What is the capital...)
    #    task_type "general" = other general tasks
    if prefs.require_bullets and reveal_state.bullets_revealed:
        if task_type == "list":  # Only enforce on list tasks
            enforced.append("bullets")
            has_bullets = ("- " in text) or ("• " in text) or ("* " in text)
            if not has_bullets:
                violations.append("no_bullets")

    # 3) Language - enforce only if revealed
    if reveal_state.lang_revealed:
        enforced.append("lang")
        ascii_count = sum(c.isascii() for c in text)
        ascii_ratio = ascii_count / max(1, len(text))
        if prefs.lang == "zh":
            # For Chinese: should have significant non-ASCII content
            if ascii_ratio > 0.7:
                violations.append("wrong_lang")
        elif prefs.lang == "en":
            # For English: should be mostly ASCII
            if ascii_ratio < 0.5:
                violations.append("wrong_lang")

    # 4) Code task: always enforce
    prog_t = 1.0
    if task_type == "code":
        enforced.append("code_block")
        has_code = ("```" in text) or ("def " in text) or ("function " in text)
        if not has_code:
            violations.append("no_code_block")
            prog_t = 0.0

    # 5) Compute scores
    if not violations:
        sat_t = 1.0
        sev_t = 0.0
    else:
        sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
        hard_violations = {"empty_answer", "too_long", "wrong_lang"}
        sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0

    return JudgeResult(
        sat_t=sat_t,
        sev_t=sev_t,
        prog_t=prog_t,
        violations=violations,
        enforced_constraints=enforced,
    )


# =============================================================================
# Feedback Computation
# =============================================================================

def compute_feedback_for_turn(
    turn_id: int,
    query: str,
    query_type: str,
    task_type: str,
    judge_result: JudgeResult,
) -> Tuple[float, float]:
    """Convert a JudgeResult into (reward, gating) for the learner."""
    reward = judge_result.sat_t

    lower_q = (query or "").lower()

    # Gate updates to turns that plausibly carry preference signal, detected
    # either from the query label or from preference-related phrasing.
    is_pref_turn = (
        query_type == "preference"
        or "i prefer" in lower_q
        or "my preference" in lower_q
        or "please use" in lower_q
        or "please keep" in lower_q
        or "you didn't follow" in lower_q
        or "you forgot" in lower_q
        or "remember that i" in lower_q
        or "i told you" in lower_q
        or "i asked for" in lower_q
        or "中文" in lower_q
        or "用中文" in lower_q
    )

    gating = 1.0 if is_pref_turn else 0.0
    return reward, gating


# =============================================================================
# Query Generation per Persona
# =============================================================================

def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 1: Reveal preferences.
    Customized based on the persona's true preferences.
    """
    queries = []
    prefs = persona.style_prefs

    # Turn 0: Reveal length preference
    if prefs.require_short:
        if prefs.lang == "zh":
            queries.append({
                "query": "我喜欢简短的回答,请保持回复在200字以内。",  # "I prefer short answers; keep replies under 200 characters."
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
                "type": "preference",
                "task_type": "general",
            })
    else:
        # Long preference - don't reveal (let short_revealed stay False)
        if prefs.lang == "zh":
            queries.append({
                "query": "你好,我想了解一些问题。",  # "Hello, I have some questions."
                "type": "task",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "Hello, I have some questions for you.",
                "type": "task",
                "task_type": "general",
            })

    # Turn 1: First task
    if prefs.lang == "zh":
        queries.append({
            "query": "列出三个改善睡眠的建议。",  # "List three tips for better sleep."
            "type": "task",
            "task_type": "list",
        })
    else:
        queries.append({
            "query": "List three tips for better sleep.",
            "type": "task",
            "task_type": "list",
        })

    # Turn 2: Reveal bullet preference (pro- or anti-bullet, per persona)
    if prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "我喜欢用项目符号列出要点,请使用bullet points。",  # "I like key points as bullets; please use bullet points."
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "I prefer bullet points when listing things. Please use bullet points.",
                "type": "preference",
                "task_type": "general",
            })
    else:
        # Reveal the anti-bullet preference
        if prefs.lang == "zh":
            queries.append({
                "query": "请不要用项目符号,我更喜欢连续的句子。",  # "Please don't use bullet points; I prefer continuous prose."
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "Please don't use bullet points. I prefer continuous prose.",
                "type": "preference",
                "task_type": "general",
            })

    # Turn 3: Reveal language preference
    if prefs.lang == "zh":
        queries.append({
            "query": "请用中文回答我的问题。",  # "Please answer my questions in Chinese."
            "type": "preference",
            "task_type": "general",
        })
    else:
        queries.append({
            "query": "Please respond in English.",
            "type": "preference",
            "task_type": "general",
        })

    # Turns 4-5: Tasks
    if prefs.lang == "zh":
        queries.extend([
            {
                "query": "锻炼有什么好处?",  # "What are the benefits of exercise?"
                "type": "task",
                "task_type": "list",
            },
            {
                "query": "列出五种流行的编程语言。",  # "Name five popular programming languages."
                "type": "task",
                "task_type": "list",
            },
        ])
    else:
        queries.extend([
            {
                "query": "What are the benefits of exercise?",
                "type": "task",
                "task_type": "list",
            },
            {
                "query": "Name five popular programming languages.",
                "type": "task",
                "task_type": "list",
            },
        ])

    return queries


def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 2: NO preference restatement.
    Tests cross-session retention.
    """
    prefs = persona.style_prefs

    if prefs.lang == "zh":
        return [
            {"query": "推荐三种健康的早餐。", "type": "task", "task_type": "list"},  # "Recommend three healthy breakfasts."
            {"query": "一年有哪四个季节?", "type": "task", "task_type": "list"},  # "What are the four seasons of the year?"
            {"query": "法国的首都是哪里?", "type": "task", "task_type": "qa"},  # "What is the capital of France?"
            {"query": "列出三种可再生能源。", "type": "task", "task_type": "list"},  # "Name three types of renewable energy."
        ]
    else:
        return [
            {"query": "What are three healthy breakfast ideas?", "type": "task", "task_type": "list"},
            {"query": "What are the four seasons of the year?", "type": "task", "task_type": "list"},
            {"query": "What is the capital of France?", "type": "task", "task_type": "qa"},
            {"query": "Name three types of renewable energy.", "type": "task", "task_type": "list"},
        ]
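

# Illustrative check of the session design: session 2 deliberately contains no
# "preference" turns, so any style adherence there must come from retained
# state rather than in-context restatement. The helper name is hypothetical.
def _demo_session2_has_no_preference_turns() -> bool:
    return all(
        q["type"] == "task"
        for persona in ALL_PERSONAS
        for q in get_session_2_queries_for_persona(persona)
    )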
+ """ + prefs = persona.style_prefs + + if prefs.lang == "zh": + return [ + {"query": "列出五种常见的水果。", "type": "task", "task_type": "list"}, + {"query": "请记住我喜欢简短的回答。列出三种海洋动物。", "type": "preference", "task_type": "list"}, + {"query": "2加2等于多少?", "type": "task", "task_type": "qa"}, + ] + else: + return [ + {"query": "Name five common fruits.", "type": "task", "task_type": "list"}, + {"query": "Remember that I asked for short answers. List three ocean animals.", "type": "preference", "task_type": "list"}, + {"query": "What is 2 + 2?", "type": "task", "task_type": "qa"}, + ] + + +# ============================================================================= +# Logging +# ============================================================================= + +@dataclass +class TurnLog: + """Log entry for one turn.""" + user_id: str + persona_id: str + session_id: int + turn_id: int + query: str + query_type: str + task_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + reward: float + gating: float + reveal_state_before: Dict[str, bool] + reveal_state_after: Dict[str, bool] + newly_revealed: List[str] + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + num_memories_retrieved: int + num_prefs_extracted: int + + +def log_to_jsonl(logs: List[TurnLog], filepath: str): + """Save logs to JSONL file.""" + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as f: + for log in logs: + f.write(json.dumps(asdict(log)) + "\n") + + +# ============================================================================= +# Session Runner +# ============================================================================= + +def run_session( + llm: PersonalizedLLM, + user_id: str, + persona: Persona, + session_id: int, + reveal_state: RevealState, + queries: List[Dict[str, Any]], + all_logs: List[TurnLog], +) -> List[TurnLog]: + """Run a single session for a user.""" + + prefs = persona.style_prefs + session_logs: List[TurnLog] = [] + + print(f"\n{'='*60}") + print(f"[{persona.persona_id}] Session {session_id}: {len(queries)} turns") + print(f"Reveal state (start): {reveal_state}") + print(f"{'='*60}") + + # Reset session (clears history, z_short; keeps z_long and reveal state) + llm.reset_session(user_id) + + for turn_id, q_info in enumerate(queries): + query = q_info["query"] + query_type = q_info.get("type", "task") + task_type = q_info.get("task_type", "general") + + print(f"\n--- S{session_id}/T{turn_id} [{query_type}] ---") + print(f"[Q] {query[:60]}{'...' 
if len(query) > 60 else ''}") + + # Capture reveal state BEFORE + reveal_before = reveal_state.to_dict() + + # Update reveal state + newly_revealed = update_reveal_state(reveal_state, query, prefs) + if newly_revealed: + print(f"[Reveal] Newly: {newly_revealed}") + + reveal_after = reveal_state.to_dict() + + # Get user state before + state_before = llm.get_user_state_summary(user_id) + z_long_before = state_before["z_long_norm"] + z_short_before = state_before["z_short_norm"] + + # Apply feedback for previous turn + if turn_id > 0 and session_logs: + prev_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=prev_log.turn_id, + reward=prev_log.reward, + gating=prev_log.gating, + meta={"source": "pilot_v3", "session_id": session_id} + ) + llm.apply_feedback(feedback) + + # Chat + resp: AssistantResponse = llm.chat(user_id, query) + + answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer + print(f"[A] ({len(resp.answer)}c) {answer_display}") + + # Judge + judge_result = style_judge_v3(query, resp.answer, task_type, prefs, reveal_state) + print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}") + + # Compute feedback + reward, gating = compute_feedback_for_turn(turn_id, query, query_type, task_type, judge_result) + + # Get state after + state_after = llm.get_user_state_summary(user_id) + z_long_after = state_after["z_long_norm"] + z_short_after = state_after["z_short_norm"] + + # Debug info + num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0 + num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0 + + # Log + log = TurnLog( + user_id=user_id, + persona_id=persona.persona_id, + session_id=session_id, + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + answer=resp.answer, + answer_length=len(resp.answer), + sat_t=judge_result.sat_t, + sev_t=judge_result.sev_t, + prog_t=judge_result.prog_t, + violations=judge_result.violations, + enforced_constraints=judge_result.enforced_constraints, + reward=reward, + gating=gating, + reveal_state_before=reveal_before, + reveal_state_after=reveal_after, + newly_revealed=list(newly_revealed), + z_long_norm_before=z_long_before, + z_long_norm_after=z_long_after, + z_short_norm_before=z_short_before, + z_short_norm_after=z_short_after, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + total_tokens=resp.usage.total_tokens, + num_memories_retrieved=num_memories, + num_prefs_extracted=num_prefs, + ) + session_logs.append(log) + all_logs.append(log) + + # Apply final feedback + if session_logs: + last_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=last_log.turn_id, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v3", "session_id": session_id, "final": True} + ) + llm.apply_feedback(feedback) + + return session_logs + + +# ============================================================================= +# Multi-User Multi-Session Runner +# ============================================================================= + +def run_multi_user_pilot( + llm: PersonalizedLLM, + personas: List[Persona], + num_sessions: int = 3, + reveal_manager: Optional[RevealStateManager] = None, +) -> List[TurnLog]: + """ + Run multi-user multi-session pilot. 
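

# Illustrative sketch of the feedback schedule used by run_session: the reward
# for turn t is applied just before turn t+1's chat call, with one extra
# application after the loop for the final turn. The helper name is
# hypothetical; it only models the indices, not the LLM calls.
def _demo_feedback_schedule(num_turns: int) -> List[int]:
    applied: List[int] = []
    for t in range(num_turns):
        if t > 0:
            applied.append(t - 1)  # feedback for the previous turn
        # ... llm.chat(...) happens here in run_session ...
    if num_turns > 0:
        applied.append(num_turns - 1)  # final-turn feedback after the loop
    return applied  # e.g., num_turns=3 -> [0, 1, 2]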


# =============================================================================
# Multi-User Multi-Session Runner
# =============================================================================

def run_multi_user_pilot(
    llm: PersonalizedLLM,
    personas: List[Persona],
    num_sessions: int = 3,
    reveal_manager: Optional[RevealStateManager] = None,
) -> List[TurnLog]:
    """
    Run the multi-user multi-session pilot.

    Args:
        llm: PersonalizedLLM instance
        personas: List of personas to test
        num_sessions: Number of sessions per user
        reveal_manager: Optional existing reveal manager
    """
    if reveal_manager is None:
        reveal_manager = RevealStateManager()

    all_logs: List[TurnLog] = []

    print(f"\n{'#'*60}")
    print("PILOT v3: MULTI-USER MULTI-SESSION")
    print(f"Users: {len(personas)}, Sessions per user: {num_sessions}")
    print(f"{'#'*60}")

    for persona in personas:
        user_id = f"user_{persona.persona_id}"
        prefs = persona.style_prefs

        print(f"\n{'*'*60}")
        print(f"USER: {user_id}")
        print(f"Persona: {persona.description}")
        print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
        print(f"{'*'*60}")

        # Reset user completely
        llm.reset_user(user_id)
        reveal_manager.reset_user(user_id)
        reveal_state = reveal_manager.get_state(user_id)

        # Run sessions
        for session_id in range(1, num_sessions + 1):
            if session_id == 1:
                queries = get_session_1_queries_for_persona(persona)
            elif session_id == 2:
                queries = get_session_2_queries_for_persona(persona)
            else:
                queries = get_session_3_queries_for_persona(persona)

            reveal_manager.reset_session(user_id)  # No-op, just for clarity
            run_session(llm, user_id, persona, session_id, reveal_state, queries, all_logs)

    return all_logs


# =============================================================================
# Summary
# =============================================================================

def print_summary_v3(logs: List[TurnLog]):
    """Print summary for pilot v3."""
    print(f"\n{'='*60}")
    print("PILOT v3 SUMMARY - Multi-User Multi-Session")
    print(f"{'='*60}")

    if not logs:
        print("No logs.")
        return

    from collections import Counter

    # Per-persona stats
    personas = sorted(set(l.persona_id for l in logs))

    print("\n--- Per-Persona Statistics ---")
    for pid in personas:
        p_logs = [l for l in logs if l.persona_id == pid]

        # Per-session breakdown
        sessions = sorted(set(l.session_id for l in p_logs))

        print(f"\n{pid}:")
        for sid in sessions:
            s_logs = [l for l in p_logs if l.session_id == sid]
            avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0
            violations = [v for l in s_logs for v in l.violations]
            enforced = set(c for l in s_logs for c in l.enforced_constraints)

            print(f"  Session {sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, enforced={enforced}")
            if violations:
                print(f"    violations: {dict(Counter(violations))}")

    # Cross-session retention check
    print("\n--- Cross-Session Retention ---")
    for pid in personas:
        p_logs = [l for l in logs if l.persona_id == pid]
        s1_logs = [l for l in p_logs if l.session_id == 1]
        s2_logs = [l for l in p_logs if l.session_id == 2]

        if s1_logs and s2_logs:
            s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs)
            s2_sat = sum(l.sat_t for l in s2_logs) / len(s2_logs)

            # Check what was enforced in S2
            s2_enforced = set(c for l in s2_logs for c in l.enforced_constraints)

            print(f"{pid}: S1_sat={s1_sat:.3f} → S2_sat={s2_sat:.3f}, S2_enforced={s2_enforced}")

    # Overall stats
    total = len(logs)
    avg_sat = sum(l.sat_t for l in logs) / total
    total_tokens = sum(l.total_tokens for l in logs)

    print("\n--- Overall ---")
    print(f"Total turns: {total}")
    print(f"Overall avg sat_t: {avg_sat:.3f}")
    print(f"Total tokens: {total_tokens}")

    # Violations by type
    all_violations = [v for l in logs for v in l.violations]
    if all_violations:
        print(f"\nViolations: {dict(Counter(all_violations))}")
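

# Illustrative sketch (hypothetical helper, not called by main): the JSONL file
# written by log_to_jsonl can be re-loaded one dict per line for offline
# analysis of the summary statistics above.
def _demo_load_logs(filepath: str) -> List[Dict[str, Any]]:
    rows: List[Dict[str, Any]] = []
    with open(filepath) as f:
        for line in f:
            if line.strip():
                rows.append(json.loads(line))
    return rows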
+ print(f"\nViolations: {dict(Counter(all_violations))}") + + +def main(): + print("=" * 60) + print("PILOT RUNNER v3 - Multi-User Multi-Session with Personas") + print("=" * 60) + print(f"Started at: {datetime.now().isoformat()}") + + # Select personas + personas = ALL_PERSONAS # All 5 personas + print(f"\n[Config] Running {len(personas)} personas:") + for p in personas: + print(f" - {p.persona_id}: {p.description}") + + # Initialize LLM + print("\n[Init] Loading PersonalizedLLM...") + llm = PersonalizedLLM( + user_store_path="data/users/user_store_pilot_v3.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=True, + ) + + # Run pilot + logs = run_multi_user_pilot(llm, personas, num_sessions=3) + + # Summary + print_summary_v3(logs) + + # Save logs + log_path = f"data/logs/pilot_v3_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" + log_to_jsonl(logs, log_path) + print(f"\n[Logs] Saved to: {log_path}") + + print(f"\nCompleted at: {datetime.now().isoformat()}") + print("=" * 60) + + +if __name__ == "__main__": + main() + |
