author    YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
commit    e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree      6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/pilot_runner_v3.py
Initial commit (clean history) (HEAD -> main)
Diffstat (limited to 'scripts/pilot_runner_v3.py')
-rw-r--r--  scripts/pilot_runner_v3.py  924
1 file changed, 924 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v3.py b/scripts/pilot_runner_v3.py
new file mode 100644
index 0000000..d232d10
--- /dev/null
+++ b/scripts/pilot_runner_v3.py
@@ -0,0 +1,924 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v3 - Multi-User Multi-Session with Personas
+
+Upgrades from v2:
+- Persona: Bundles StylePrefs into user types
+- 5 test personas (A-E) targeting different style combinations
+- Multi-user × multi-session evaluation
+- Refined judge: bullets only on list tasks, relaxed empty_answer
+- Baseline mode support (no-personalization comparison)
+
+5 Test Personas:
+- A: short + bullets + en (sanity check)
+- B: short + NO bullets + en (anti-bullet)
+- C: long + bullets + en (no length constraint)
+- D: short + bullets + zh (Chinese)
+- E: long + NO bullets + zh (most "anti-default")
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict
+from typing import List, Dict, Any, Optional, Tuple, Set
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences (True Preferences)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """User's TRUE style preferences."""
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en" # "en" or "zh"
+
+
+# =============================================================================
+# Persona Definition
+# =============================================================================
+
+@dataclass
+class Persona:
+ """
+ A user persona that bundles style preferences.
+ Each persona represents a distinct user type for testing.
+ """
+ persona_id: str
+ style_prefs: StylePrefs
+ description: str = ""
+ # Future extensions:
+ # task_preferences: Dict[str, float] # e.g., {"code": 0.3, "rewrite": 0.7}
+ # tone: str = "neutral" # "formal", "casual", etc.
+ # domain: str = "general" # "tech", "daily_life", etc.
+
+
+# =============================================================================
+# 5 Test Personas (A-E)
+# =============================================================================
+
+PERSONA_A = Persona(
+ persona_id="A_short_bullets_en",
+ style_prefs=StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ ),
+ description="Short + bullets + English (sanity check, same as v2)",
+)
+
+PERSONA_B = Persona(
+ persona_id="B_short_no_bullets_en",
+ style_prefs=StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=False,
+ lang="en",
+ ),
+ description="Short + NO bullets + English (anti-bullet test)",
+)
+
+PERSONA_C = Persona(
+ persona_id="C_long_bullets_en",
+ style_prefs=StylePrefs(
+ require_short=False,
+ max_chars=800,
+ require_bullets=True,
+ lang="en",
+ ),
+ description="Long + bullets + English (no length constraint)",
+)
+
+PERSONA_D = Persona(
+ persona_id="D_short_bullets_zh",
+ style_prefs=StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="zh",
+ ),
+ description="Short + bullets + Chinese (language test)",
+)
+
+PERSONA_E = Persona(
+ persona_id="E_long_no_bullets_zh",
+ style_prefs=StylePrefs(
+ require_short=False,
+ max_chars=800,
+ require_bullets=False,
+ lang="zh",
+ ),
+ description="Long + NO bullets + Chinese (most anti-default)",
+)
+
+ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E]
+
+
+def get_persona_by_id(persona_id: str) -> Optional[Persona]:
+ """Get persona by ID."""
+ for p in ALL_PERSONAS:
+ if p.persona_id == persona_id:
+ return p
+ return None
+
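+# Illustrative usage sketch (not executed here): look up a persona by id and
+# read its true preferences.
+#
+#   persona = get_persona_by_id("D_short_bullets_zh")
+#   assert persona is not None
+#   persona.style_prefs.require_short   # -> True
+#   persona.style_prefs.lang            # -> "zh"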
+
+# =============================================================================
+# Reveal State
+# =============================================================================
+
+@dataclass
+class RevealState:
+ """Tracks which preferences have been explicitly revealed."""
+ short_revealed: bool = False
+ bullets_revealed: bool = False
+ lang_revealed: bool = False
+
+ def reset(self):
+ self.short_revealed = False
+ self.bullets_revealed = False
+ self.lang_revealed = False
+
+ def to_dict(self) -> Dict[str, bool]:
+ return {
+ "short": self.short_revealed,
+ "bullets": self.bullets_revealed,
+ "lang": self.lang_revealed,
+ }
+
+ def __str__(self) -> str:
+ flags = []
+ if self.short_revealed:
+ flags.append("short")
+ if self.bullets_revealed:
+ flags.append("bullets")
+ if self.lang_revealed:
+ flags.append("lang")
+ return f"RevealState({', '.join(flags) if flags else 'none'})"
+
+
+class RevealStateManager:
+ """Manages reveal state for multiple users."""
+
+ def __init__(self):
+ self._states: Dict[str, RevealState] = {}
+
+ def get_state(self, user_id: str) -> RevealState:
+ if user_id not in self._states:
+ self._states[user_id] = RevealState()
+ return self._states[user_id]
+
+ def reset_user(self, user_id: str):
+ if user_id in self._states:
+ self._states[user_id].reset()
+ else:
+ self._states[user_id] = RevealState()
+
+ def reset_session(self, user_id: str):
+ pass # Reveal state persists across sessions
+
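+# Illustrative behavior sketch (not executed here): reveal state is keyed by
+# user_id and deliberately survives reset_session(), so a preference revealed
+# in session 1 still counts as revealed in session 2.
+#
+#   manager = RevealStateManager()
+#   manager.get_state("user_A").short_revealed = True
+#   manager.reset_session("user_A")              # no-op by design
+#   manager.get_state("user_A").short_revealed   # -> True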
+
+# =============================================================================
+# Preference Detection
+# =============================================================================
+
+def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]:
+ """
+ Detect which preferences are mentioned in a query.
+    Matches English and Chinese phrasings; `prefs` is reserved for future
+    preference-aware detection and is currently unused.
+ """
+ lower_q = (query or "").lower()
+
+ revealed = {
+ "short": False,
+ "bullets": False,
+ "lang": False,
+ }
+
+    # Short/length preference (English and Chinese phrasings used by the personas)
+    short_patterns = [
+        "short", "concise", "brief", "under ", "less than",
+        "keep it short", "keep responses", "keep answers",
+        "maximum ", "max ", "characters", "words or less",
+        "200 ", "100 ", "50 ", "300 ",
+        "简短", "简洁", "字以内",
+    ]
+ for pattern in short_patterns:
+ if pattern in lower_q:
+ revealed["short"] = True
+ break
+
+    # Bullet preference (both positive and negative, English and Chinese)
+    bullet_patterns = [
+        "bullet", "bullet point", "bullet-point",
+        "bulleted", "list format", "use bullets",
+        "no bullet", "don't use bullet", "without bullet",
+        "numbered list", "use numbers",
+        "项目符号",
+    ]
+ for pattern in bullet_patterns:
+ if pattern in lower_q:
+ revealed["bullets"] = True
+ break
+
+ # Language preference
+ lang_patterns_zh = [
+ "chinese", "中文", "in chinese", "用中文",
+ "speak chinese", "write chinese", "respond in chinese",
+ "please use chinese", "mandarin", "请用中文",
+ ]
+ lang_patterns_en = [
+ "english", "in english", "use english",
+ "speak english", "write english", "respond in english",
+ "please use english",
+ ]
+
+ for pattern in lang_patterns_zh + lang_patterns_en:
+ if pattern in lower_q:
+ revealed["lang"] = True
+ break
+
+ return revealed
+
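+# Illustrative examples (not executed here) of the keyword detector:
+#
+#   detect_revealed_preferences("Please keep it short and use bullet points", StylePrefs())
+#   # -> {"short": True, "bullets": True, "lang": False}
+#
+#   detect_revealed_preferences("What is the capital of France?", StylePrefs())
+#   # -> {"short": False, "bullets": False, "lang": False}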
+
+def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]:
+ """Update reveal state based on query content."""
+ detected = detect_revealed_preferences(query, prefs)
+ newly_revealed = set()
+
+ if detected["short"] and not reveal_state.short_revealed:
+ reveal_state.short_revealed = True
+ newly_revealed.add("short")
+
+ if detected["bullets"] and not reveal_state.bullets_revealed:
+ reveal_state.bullets_revealed = True
+ newly_revealed.add("bullets")
+
+ if detected["lang"] and not reveal_state.lang_revealed:
+ reveal_state.lang_revealed = True
+ newly_revealed.add("lang")
+
+ return newly_revealed
+
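+# Illustrative behavior sketch (not executed here): only first-time mentions are
+# reported as newly revealed; repeating a preference returns an empty set but the
+# state flags stay set.
+#
+#   state = RevealState()
+#   update_reveal_state(state, "Keep it short, please", StylePrefs())   # -> {"short"}
+#   update_reveal_state(state, "Short answers only", StylePrefs())      # -> set()
+#   state.short_revealed                                                # -> True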
+
+# =============================================================================
+# Refined Style Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+
+
+def style_judge_v3(
+ query: str,
+ answer: str,
+ task_type: str,
+ prefs: StylePrefs,
+ reveal_state: RevealState,
+) -> JudgeResult:
+ """
+ Refined style judge with:
+ - Bullets only enforced on list-type tasks
+    - Relaxed empty_answer (only truly empty answers are flagged)
+ - Reveal-aware enforcement
+ """
+ violations: List[str] = []
+ enforced: List[str] = []
+ text = (answer or "").strip()
+
+    # 0) Empty answer - only flag truly empty answers
+    # Relaxed vs v2: short factual answers like "4" or "Paris" are allowed
+ if len(text) == 0:
+ violations.append("empty_answer")
+ return JudgeResult(
+ sat_t=0.0,
+ sev_t=1.0,
+ prog_t=0.0,
+ violations=violations,
+ enforced_constraints=["non_empty"],
+ )
+
+ # 1) Length - enforce only if BOTH true AND revealed
+ if prefs.require_short and reveal_state.short_revealed:
+ enforced.append("short")
+ if len(text) > prefs.max_chars:
+ violations.append("too_long")
+
+ # 2) Bullets - enforce ONLY on list-type tasks AND if revealed
+ # task_type "list" = listing tasks (Name X things, What are the N...)
+ # task_type "qa" = factual QA (What is the capital...)
+ # task_type "general" = other general tasks
+ if prefs.require_bullets and reveal_state.bullets_revealed:
+ if task_type == "list": # Only enforce on list tasks
+ enforced.append("bullets")
+ has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text)
+ if not has_bullets:
+ violations.append("no_bullets")
+
+ # 3) Language - enforce only if revealed
+    if reveal_state.lang_revealed:
+        enforced.append("lang")
+        # Share of ASCII characters as a cheap language proxy
+        ascii_ratio = sum(c.isascii() for c in text) / max(1, len(text))
+        if prefs.lang == "zh":
+            # Chinese answers should contain substantial non-ASCII content
+            if ascii_ratio > 0.7:
+                violations.append("wrong_lang")
+        elif prefs.lang == "en":
+            # English answers should be mostly ASCII
+            if ascii_ratio < 0.5:
+                violations.append("wrong_lang")
+
+ # 4) Code task: always enforce
+ prog_t = 1.0
+ if task_type == "code":
+ enforced.append("code_block")
+ has_code = ("```" in text) or ("def " in text) or ("function " in text)
+ if not has_code:
+ violations.append("no_code_block")
+ prog_t = 0.0
+
+ # 5) Compute scores
+ if not violations:
+ sat_t = 1.0
+ sev_t = 0.0
+ else:
+ sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
+ hard_violations = {"empty_answer", "too_long", "wrong_lang"}
+ sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
+
+ return JudgeResult(
+ sat_t=sat_t,
+ sev_t=sev_t,
+ prog_t=prog_t,
+ violations=violations,
+ enforced_constraints=enforced,
+ )
+
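+# Illustrative example (not executed here): a 300-character English answer for a
+# persona that revealed a 200-character limit yields a hard violation. Bullets are
+# not enforced because the bullet preference was never revealed (and "qa" is not a
+# list task).
+#
+#   prefs = StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en")
+#   state = RevealState(short_revealed=True)
+#   result = style_judge_v3("Explain HTTP.", "x" * 300, task_type="qa", prefs=prefs, reveal_state=state)
+#   # result.violations            -> ["too_long"]
+#   # result.enforced_constraints  -> ["short"]
+#   # result.sat_t, result.sev_t   -> 0.7, 1.0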
+
+# =============================================================================
+# Feedback Computation
+# =============================================================================
+
+def compute_feedback_for_turn(
+ turn_id: int,
+ query: str,
+ query_type: str,
+ task_type: str,
+ judge_result: JudgeResult,
+) -> Tuple[float, float]:
+ """Convert JudgeResult into (reward, gating)."""
+ reward = judge_result.sat_t
+
+ lower_q = (query or "").lower()
+
+ is_pref_turn = (
+ query_type == "preference"
+ or "i prefer" in lower_q
+ or "my preference" in lower_q
+ or "please use" in lower_q
+ or "please keep" in lower_q
+ or "you didn't follow" in lower_q
+ or "you forgot" in lower_q
+ or "remember that i" in lower_q
+ or "i told you" in lower_q
+ or "i asked for" in lower_q
+ or "中文" in lower_q
+ or "用中文" in lower_q
+ )
+
+ gating = 1.0 if is_pref_turn else 0.0
+ return reward, gating
+
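+# Illustrative example (not executed here): preference turns gate the RL update,
+# plain task turns do not; the reward is always the judge's sat_t.
+#
+#   jr = JudgeResult(sat_t=0.7, sev_t=1.0, prog_t=1.0,
+#                    violations=["too_long"], enforced_constraints=["short"])
+#   compute_feedback_for_turn(0, "I prefer short answers", "preference", "general", jr)  # -> (0.7, 1.0)
+#   compute_feedback_for_turn(1, "Name five common fruits.", "task", "list", jr)         # -> (0.7, 0.0)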
+
+# =============================================================================
+# Query Generation per Persona
+# =============================================================================
+
+def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 1: Reveal preferences.
+ Customize based on persona's true preferences.
+ """
+ queries = []
+ prefs = persona.style_prefs
+
+ # Turn 0: Reveal length preference
+ if prefs.require_short:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "我喜欢简短的回答,请保持回复在200字以内。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ # Long preference - don't reveal (let short_revealed stay False)
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "你好,我想了解一些问题。",
+ "type": "task",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Hello, I have some questions for you.",
+ "type": "task",
+ "task_type": "general",
+ })
+
+ # Turn 1: First task
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "列出三个改善睡眠的建议。",
+ "type": "task",
+ "task_type": "list",
+ })
+ else:
+ queries.append({
+ "query": "List three tips for better sleep.",
+ "type": "task",
+ "task_type": "list",
+ })
+
+ # Turn 2: Reveal bullet preference
+ if prefs.require_bullets:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "我喜欢用项目符号列出要点,请使用bullet points。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "I prefer bullet points when listing things. Please use bullet points.",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ # Don't reveal bullet preference (or reveal anti-bullet)
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "请不要用项目符号,我更喜欢连续的句子。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Please don't use bullet points. I prefer continuous prose.",
+ "type": "preference",
+ "task_type": "general",
+ })
+
+ # Turn 3: Reveal language preference (for non-English personas)
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "请用中文回答我的问题。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Please respond in English.",
+ "type": "preference",
+ "task_type": "general",
+ })
+
+ # Turn 4-5: Tasks
+ if prefs.lang == "zh":
+ queries.extend([
+ {
+ "query": "锻炼有什么好处?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "列出五种流行的编程语言。",
+ "type": "task",
+ "task_type": "list",
+ },
+ ])
+ else:
+ queries.extend([
+ {
+ "query": "What are the benefits of exercise?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "Name five popular programming languages.",
+ "type": "task",
+ "task_type": "list",
+ },
+ ])
+
+ return queries
+
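+# Illustrative session-1 shape (not executed here) for PERSONA_B
+# (short + NO bullets + English): the anti-bullet preference is stated in
+# turn 2 and the language preference in turn 3.
+#
+#   [q["type"] for q in get_session_1_queries_for_persona(PERSONA_B)]
+#   # -> ["preference", "task", "preference", "preference", "task", "task"]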
+
+def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 2: NO preference restatement.
+ Tests cross-session retention.
+ """
+ prefs = persona.style_prefs
+
+ if prefs.lang == "zh":
+ return [
+ {"query": "推荐三种健康的早餐。", "type": "task", "task_type": "list"},
+ {"query": "一年有哪四个季节?", "type": "task", "task_type": "list"},
+ {"query": "法国的首都是哪里?", "type": "task", "task_type": "qa"},
+ {"query": "列出三种可再生能源。", "type": "task", "task_type": "list"},
+ ]
+ else:
+ return [
+ {"query": "What are three healthy breakfast ideas?", "type": "task", "task_type": "list"},
+ {"query": "What are the four seasons of the year?", "type": "task", "task_type": "list"},
+ {"query": "What is the capital of France?", "type": "task", "task_type": "qa"},
+ {"query": "Name three types of renewable energy.", "type": "task", "task_type": "list"},
+ ]
+
+
+def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 3: Mix of tasks and one reminder.
+ """
+ prefs = persona.style_prefs
+
+ if prefs.lang == "zh":
+ return [
+ {"query": "列出五种常见的水果。", "type": "task", "task_type": "list"},
+ {"query": "请记住我喜欢简短的回答。列出三种海洋动物。", "type": "preference", "task_type": "list"},
+ {"query": "2加2等于多少?", "type": "task", "task_type": "qa"},
+ ]
+ else:
+ return [
+ {"query": "Name five common fruits.", "type": "task", "task_type": "list"},
+ {"query": "Remember that I asked for short answers. List three ocean animals.", "type": "preference", "task_type": "list"},
+ {"query": "What is 2 + 2?", "type": "task", "task_type": "qa"},
+ ]
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn."""
+ user_id: str
+ persona_id: str
+ session_id: int
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+ reward: float
+ gating: float
+ reveal_state_before: Dict[str, bool]
+ reveal_state_after: Dict[str, bool]
+ newly_revealed: List[str]
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ """Save logs to JSONL file."""
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
+
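+# Minimal sketch (for offline analysis, not executed here) of reading a saved log
+# back; "<timestamp>" stands in for the actual run timestamp in the filename.
+#
+#   with open("data/logs/pilot_v3_<timestamp>.jsonl") as f:
+#       rows = [json.loads(line) for line in f]
+#   # rows[0]["persona_id"], rows[0]["sat_t"], rows[0]["violations"], ...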
+
+# =============================================================================
+# Session Runner
+# =============================================================================
+
+def run_session(
+ llm: PersonalizedLLM,
+ user_id: str,
+ persona: Persona,
+ session_id: int,
+ reveal_state: RevealState,
+ queries: List[Dict[str, Any]],
+ all_logs: List[TurnLog],
+) -> List[TurnLog]:
+ """Run a single session for a user."""
+
+ prefs = persona.style_prefs
+ session_logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"[{persona.persona_id}] Session {session_id}: {len(queries)} turns")
+ print(f"Reveal state (start): {reveal_state}")
+ print(f"{'='*60}")
+
+ # Reset session (clears history, z_short; keeps z_long and reveal state)
+ llm.reset_session(user_id)
+
+ for turn_id, q_info in enumerate(queries):
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+
+ print(f"\n--- S{session_id}/T{turn_id} [{query_type}] ---")
+ print(f"[Q] {query[:60]}{'...' if len(query) > 60 else ''}")
+
+ # Capture reveal state BEFORE
+ reveal_before = reveal_state.to_dict()
+
+ # Update reveal state
+ newly_revealed = update_reveal_state(reveal_state, query, prefs)
+ if newly_revealed:
+ print(f"[Reveal] Newly: {newly_revealed}")
+
+ reveal_after = reveal_state.to_dict()
+
+ # Get user state before
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+ # Apply feedback for previous turn
+ if turn_id > 0 and session_logs:
+ prev_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=prev_log.turn_id,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={"source": "pilot_v3", "session_id": session_id}
+ )
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer
+ print(f"[A] ({len(resp.answer)}c) {answer_display}")
+
+ # Judge
+ judge_result = style_judge_v3(query, resp.answer, task_type, prefs, reveal_state)
+ print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}")
+
+ # Compute feedback
+ reward, gating = compute_feedback_for_turn(turn_id, query, query_type, task_type, judge_result)
+
+ # Get state after
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ # Debug info
+ num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+ num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+ # Log
+ log = TurnLog(
+ user_id=user_id,
+ persona_id=persona.persona_id,
+ session_id=session_id,
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ enforced_constraints=judge_result.enforced_constraints,
+ reward=reward,
+ gating=gating,
+ reveal_state_before=reveal_before,
+ reveal_state_after=reveal_after,
+ newly_revealed=list(newly_revealed),
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ )
+ session_logs.append(log)
+ all_logs.append(log)
+
+ # Apply final feedback
+ if session_logs:
+ last_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=last_log.turn_id,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v3", "session_id": session_id, "final": True}
+ )
+ llm.apply_feedback(feedback)
+
+ return session_logs
+
+
+# =============================================================================
+# Multi-User Multi-Session Runner
+# =============================================================================
+
+def run_multi_user_pilot(
+ llm: PersonalizedLLM,
+ personas: List[Persona],
+ num_sessions: int = 3,
+ reveal_manager: Optional[RevealStateManager] = None,
+) -> List[TurnLog]:
+ """
+ Run multi-user multi-session pilot.
+
+ Args:
+ llm: PersonalizedLLM instance
+ personas: List of personas to test
+ num_sessions: Number of sessions per user
+ reveal_manager: Optional existing reveal manager
+ """
+ if reveal_manager is None:
+ reveal_manager = RevealStateManager()
+
+ all_logs: List[TurnLog] = []
+
+ print(f"\n{'#'*60}")
+ print(f"PILOT v3: MULTI-USER MULTI-SESSION")
+ print(f"Users: {len(personas)}, Sessions per user: {num_sessions}")
+ print(f"{'#'*60}")
+
+ for persona in personas:
+ user_id = f"user_{persona.persona_id}"
+ prefs = persona.style_prefs
+
+ print(f"\n{'*'*60}")
+ print(f"USER: {user_id}")
+ print(f"Persona: {persona.description}")
+ print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
+ print(f"{'*'*60}")
+
+ # Reset user completely
+ llm.reset_user(user_id)
+ reveal_manager.reset_user(user_id)
+ reveal_state = reveal_manager.get_state(user_id)
+
+ # Run sessions
+ for session_id in range(1, num_sessions + 1):
+ if session_id == 1:
+ queries = get_session_1_queries_for_persona(persona)
+ elif session_id == 2:
+ queries = get_session_2_queries_for_persona(persona)
+ else:
+ queries = get_session_3_queries_for_persona(persona)
+
+ reveal_manager.reset_session(user_id) # No-op, just for clarity
+ run_session(llm, user_id, persona, session_id, reveal_state, queries, all_logs)
+
+ return all_logs
+
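+# Illustrative usage sketch (not executed here): a smaller smoke run over two
+# personas and two sessions, assuming an already-initialized PersonalizedLLM.
+#
+#   logs = run_multi_user_pilot(llm, [PERSONA_A, PERSONA_E], num_sessions=2)
+#   print_summary_v3(logs)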
+
+# =============================================================================
+# Summary
+# =============================================================================
+
+def print_summary_v3(logs: List[TurnLog]):
+ """Print summary for pilot v3."""
+ print(f"\n{'='*60}")
+ print("PILOT v3 SUMMARY - Multi-User Multi-Session")
+ print(f"{'='*60}")
+
+ if not logs:
+ print("No logs.")
+ return
+
+ from collections import Counter, defaultdict
+
+ # Per-persona stats
+ personas = sorted(set(l.persona_id for l in logs))
+
+ print(f"\n--- Per-Persona Statistics ---")
+ for pid in personas:
+ p_logs = [l for l in logs if l.persona_id == pid]
+
+ # Per-session breakdown
+ sessions = sorted(set(l.session_id for l in p_logs))
+
+ print(f"\n{pid}:")
+ for sid in sessions:
+ s_logs = [l for l in p_logs if l.session_id == sid]
+ avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0
+ violations = [v for l in s_logs for v in l.violations]
+ enforced = set(c for l in s_logs for c in l.enforced_constraints)
+
+ print(f" Session {sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, enforced={enforced}")
+ if violations:
+ print(f" violations: {dict(Counter(violations))}")
+
+ # Cross-session retention check
+ print(f"\n--- Cross-Session Retention ---")
+ for pid in personas:
+ p_logs = [l for l in logs if l.persona_id == pid]
+ s1_logs = [l for l in p_logs if l.session_id == 1]
+ s2_logs = [l for l in p_logs if l.session_id == 2]
+
+ if s1_logs and s2_logs:
+ s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs)
+ s2_sat = sum(l.sat_t for l in s2_logs) / len(s2_logs)
+
+ # Check what was enforced in S2
+ s2_enforced = set(c for l in s2_logs for c in l.enforced_constraints)
+
+ print(f"{pid}: S1_sat={s1_sat:.3f} → S2_sat={s2_sat:.3f}, S2_enforced={s2_enforced}")
+
+ # Overall stats
+ total = len(logs)
+ avg_sat = sum(l.sat_t for l in logs) / total
+ total_tokens = sum(l.total_tokens for l in logs)
+
+ print(f"\n--- Overall ---")
+ print(f"Total turns: {total}")
+ print(f"Overall avg sat_t: {avg_sat:.3f}")
+ print(f"Total tokens: {total_tokens}")
+
+ # Violations by type
+ all_violations = [v for l in logs for v in l.violations]
+ if all_violations:
+ print(f"\nViolations: {dict(Counter(all_violations))}")
+
+
+def main():
+ print("=" * 60)
+ print("PILOT RUNNER v3 - Multi-User Multi-Session with Personas")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+
+ # Select personas
+ personas = ALL_PERSONAS # All 5 personas
+ print(f"\n[Config] Running {len(personas)} personas:")
+ for p in personas:
+ print(f" - {p.persona_id}: {p.description}")
+
+ # Initialize LLM
+ print("\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_pilot_v3.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Run pilot
+ logs = run_multi_user_pilot(llm, personas, num_sessions=3)
+
+ # Summary
+ print_summary_v3(logs)
+
+ # Save logs
+ log_path = f"data/logs/pilot_v3_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+