#!/usr/bin/env python3
"""
Pilot Runner v3 - Multi-User Multi-Session with Personas

What changed relative to v2:
- Persona: bundles StylePrefs into reusable user types
- Five test personas (A-E) covering different style combinations
- Evaluation across several users, each with several sessions
- Refined judge: bullets checked only on list tasks, relaxed empty_answer
- Baseline (no-personalization) comparison: not exposed by main() in this file

The five test personas:
- A: short + bullets + en (sanity check)
- B: short + NO bullets + en (anti-bullet)
- C: long + bullets + en (no length constraint)
- D: short + bullets + zh (Chinese)
- E: long + NO bullets + zh (most "anti-default")
"""

import json
import os
import sys
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Set, Tuple

# Put the project's ../src directory (relative to this script) first on the
# import path so the `personalization` package resolves when run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))

from personalization.serving import AssistantResponse, Feedback, PersonalizedLLM  # noqa: E402


# =============================================================================
# Style Preferences (True Preferences)
# =============================================================================

@dataclass
class StylePrefs:
    """Ground-truth style preferences held by a simulated user."""

    # When True, answers should not exceed `max_chars` characters.
    require_short: bool = False
    # Character budget that applies when `require_short` is set.
    max_chars: int = 300
    # When True, the user wants bullet-point formatting.
    require_bullets: bool = False
    # Preferred answer language: "en" or "zh".
    lang: str = "en"


# =============================================================================
# Persona Definition
# =============================================================================

@dataclass
class Persona:
    """
    A named test user that wraps one set of true style preferences.

    Each persona stands for a distinct user type in the pilot.
    """

    # Stable identifier; also used to derive the simulated user_id.
    persona_id: str
    # The persona's ground-truth preferences.
    style_prefs: StylePrefs
    # Human-readable summary shown in console output.
    description: str = ""
    # Possible future extensions:
    #   task_preferences: Dict[str, float]  # e.g. {"code": 0.3, "rewrite": 0.7}
    #   tone: str = "neutral"                # "formal", "casual", etc.
    #   domain: str = "general"              # "tech", "daily_life", etc.
# =============================================================================
# 5 Test Personas (A-E)
# =============================================================================

PERSONA_A = Persona(
    persona_id="A_short_bullets_en",
    style_prefs=StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=True,
        lang="en",
    ),
    description="Short + bullets + English (sanity check, same as v2)",
)

PERSONA_B = Persona(
    persona_id="B_short_no_bullets_en",
    style_prefs=StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=False,
        lang="en",
    ),
    description="Short + NO bullets + English (anti-bullet test)",
)

PERSONA_C = Persona(
    persona_id="C_long_bullets_en",
    style_prefs=StylePrefs(
        require_short=False,
        max_chars=800,
        require_bullets=True,
        lang="en",
    ),
    description="Long + bullets + English (no length constraint)",
)

PERSONA_D = Persona(
    persona_id="D_short_bullets_zh",
    style_prefs=StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=True,
        lang="zh",
    ),
    description="Short + bullets + Chinese (language test)",
)

PERSONA_E = Persona(
    persona_id="E_long_no_bullets_zh",
    style_prefs=StylePrefs(
        require_short=False,
        max_chars=800,
        require_bullets=False,
        lang="zh",
    ),
    description="Long + NO bullets + Chinese (most anti-default)",
)

ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E]


def get_persona_by_id(persona_id: str) -> Optional[Persona]:
    """Return the persona in ALL_PERSONAS whose id is `persona_id`, or None."""
    return next((p for p in ALL_PERSONAS if p.persona_id == persona_id), None)


# =============================================================================
# Reveal State
# =============================================================================

@dataclass
class RevealState:
    """Tracks which preference dimensions the user has explicitly mentioned."""

    short_revealed: bool = False
    bullets_revealed: bool = False
    lang_revealed: bool = False

    def reset(self):
        """Mark every dimension as not yet revealed (in place)."""
        self.short_revealed = False
        self.bullets_revealed = False
        self.lang_revealed = False

    def to_dict(self) -> Dict[str, bool]:
        """Return the flags keyed by "short", "bullets", "lang"."""
        return {
            "short": self.short_revealed,
            "bullets": self.bullets_revealed,
            "lang": self.lang_revealed,
        }

    def __str__(self) -> str:
        flags = []
        if self.short_revealed:
            flags.append("short")
        if self.bullets_revealed:
            flags.append("bullets")
        if self.lang_revealed:
            flags.append("lang")
        return f"RevealState({', '.join(flags) if flags else 'none'})"


class RevealStateManager:
    """Holds one RevealState per user_id, created lazily."""

    def __init__(self):
        self._states: Dict[str, RevealState] = {}

    def get_state(self, user_id: str) -> RevealState:
        """Return the user's RevealState, creating an empty one on first use."""
        if user_id not in self._states:
            self._states[user_id] = RevealState()
        return self._states[user_id]

    def reset_user(self, user_id: str):
        """Clear a user's reveal flags, resetting in place so held references stay valid."""
        if user_id in self._states:
            self._states[user_id].reset()
        else:
            self._states[user_id] = RevealState()

    def reset_session(self, user_id: str):
        """Intentional no-op: reveal state persists across sessions."""
        pass


# =============================================================================
# Preference Detection
# =============================================================================

def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]:
    """
    Detect which preference dimensions a query mentions.

    Matching is a case-insensitive substring search over English and Chinese
    keyword lists. Polarity is ignored: "don't use bullets" reveals the bullet
    dimension just as "use bullets" does.

    Args:
        query: The user's message; None is treated as "".
        prefs: The user's true preferences. Currently unused; kept so the
            signature stays compatible with callers.

    Returns:
        {"short": bool, "bullets": bool, "lang": bool}
    """
    lower_q = (query or "").lower()

    # Length preference. The Chinese entries matter because "200 " (with a
    # trailing space) never matches text such as "200字以内".
    short_patterns = [
        "short", "concise", "brief", "under ", "less than",
        "keep it short", "keep responses", "keep answers",
        "maximum ", "max ", "characters", "words or less",
        "200 ", "100 ", "50 ", "300 ",
        "简短", "简洁", "字以内", "不超过",
    ]

    # Bullet preference, positive or negative.
    bullet_patterns = [
        "bullet", "bullet point", "bullet-point", "bulleted",
        "list format", "use bullets", "no bullet", "don't use bullet",
        "without bullet", "numbered list", "use numbers",
        "项目符号",
    ]

    # Language preference.
    lang_patterns_zh = [
        "chinese", "中文", "in chinese", "用中文",
        "speak chinese", "write chinese", "respond in chinese",
        "please use chinese", "mandarin", "请用中文",
    ]
    lang_patterns_en = [
        "english", "in english", "use english",
        "speak english", "write english", "respond in english",
        "please use english",
    ]

    return {
        "short": any(p in lower_q for p in short_patterns),
        "bullets": any(p in lower_q for p in bullet_patterns),
        "lang": any(p in lower_q for p in lang_patterns_zh + lang_patterns_en),
    }


def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]:
    """
    Set reveal flags for every dimension the query mentions.

    Mutates `reveal_state` in place. Flags are only ever switched on, never off.

    Returns:
        The dimensions ("short", "bullets", "lang") that were newly revealed
        by this query.
    """
    detected = detect_revealed_preferences(query, prefs)
    newly_revealed = set()

    if detected["short"] and not reveal_state.short_revealed:
        reveal_state.short_revealed = True
        newly_revealed.add("short")

    if detected["bullets"] and not reveal_state.bullets_revealed:
        reveal_state.bullets_revealed = True
        newly_revealed.add("bullets")

    if detected["lang"] and not reveal_state.lang_revealed:
        reveal_state.lang_revealed = True
        newly_revealed.add("lang")

    return newly_revealed


# =============================================================================
# Refined Style Judge
# =============================================================================

@dataclass
class JudgeResult:
    """Output from the judge for one turn."""

    sat_t: float  # satisfaction in [0, 1]
    sev_t: float  # 1.0 if any hard violation occurred, else 0.0
    prog_t: float  # task progress; 0.0 only for a code task lacking code
    violations: List[str]
    enforced_constraints: List[str]


def _has_bullet_lines(text: str) -> bool:
    """Return True if any line, ignoring leading whitespace, starts with a bullet marker."""
    # Anchored to line starts on purpose: a plain substring test also fires
    # on prose like "2 - 1" and on markdown bold ("**Sleep:** ...").
    return any(
        line.lstrip().startswith(("- ", "* ", "• "))
        for line in text.splitlines()
    )


def style_judge_v3(
    query: str,
    answer: str,
    task_type: str,
    prefs: StylePrefs,
    reveal_state: RevealState,
) -> JudgeResult:
    """
    Score one assistant answer against the user's revealed style preferences.

    A style rule applies only when the user has revealed that dimension
    (`reveal_state`). Its direction and threshold come from `prefs`.
      - Empty answer (after strip): immediate failure (sat 0, sev 1).
      - Length: for short personas, more than `prefs.max_chars` characters
        -> "too_long".
      - Bullets (list tasks only):
          * a bullets persona needs bullet lines -> "no_bullets";
          * a no-bullets persona must avoid them -> "unwanted_bullets".
      - Language: ASCII-ratio heuristic. For "zh", more than 70% ASCII
        -> "wrong_lang". For "en", less than 50% ASCII -> "wrong_lang".
      - Code tasks always need "```", "def " or "function ". A failure is
        recorded as "no_code_block" and sets prog_t to 0.

    `query` is accepted for interface compatibility but is not inspected.

    Returns:
        JudgeResult with
          - sat_t = max(0, 1 - 0.3 * #violations);
          - sev_t = 1.0 if any violation is empty_answer, too_long or
            wrong_lang, else 0.0.
    """
    violations: List[str] = []
    enforced: List[str] = []
    text = (answer or "").strip()

    # 0) Empty answer. Only truly empty text fails, so terse factual
    #    answers such as "4" or "Paris" still pass.
    if not text:
        return JudgeResult(
            sat_t=0.0,
            sev_t=1.0,
            prog_t=0.0,
            violations=["empty_answer"],
            enforced_constraints=["non_empty"],
        )

    # 1) Length: the persona must want short answers AND have said so.
    if prefs.require_short and reveal_state.short_revealed:
        enforced.append("short")
        if len(text) > prefs.max_chars:
            violations.append("too_long")

    # 2) Bullets: only on list tasks, and only once the preference has been
    #    stated. Both directions count; detection already treats "no bullets"
    #    as revealing the bullet dimension.
    if reveal_state.bullets_revealed and task_type == "list":
        enforced.append("bullets")
        has_bullets = _has_bullet_lines(text)
        if prefs.require_bullets and not has_bullets:
            violations.append("no_bullets")
        elif not prefs.require_bullets and has_bullets:
            violations.append("unwanted_bullets")

    # 3) Language, once revealed.
    if reveal_state.lang_revealed:
        enforced.append("lang")
        ascii_ratio = sum(c.isascii() for c in text) / max(1, len(text))
        if prefs.lang == "zh" and ascii_ratio > 0.7:
            violations.append("wrong_lang")
        elif prefs.lang == "en" and ascii_ratio < 0.5:
            violations.append("wrong_lang")

    # 4) Code tasks: always enforced, independent of reveal state.
    prog_t = 1.0
    if task_type == "code":
        enforced.append("code_block")
        if not (("```" in text) or ("def " in text) or ("function " in text)):
            violations.append("no_code_block")
            prog_t = 0.0

    # 5) Scores.
    if violations:
        sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
        hard_violations = {"empty_answer", "too_long", "wrong_lang"}
        sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
    else:
        sat_t = 1.0
        sev_t = 0.0

    return JudgeResult(
        sat_t=sat_t,
        sev_t=sev_t,
        prog_t=prog_t,
        violations=violations,
        enforced_constraints=enforced,
    )


# =============================================================================
# Feedback Computation
# =============================================================================

def compute_feedback_for_turn(
    turn_id: int,
    query: str,
    query_type: str,
    task_type: str,
    judge_result: JudgeResult,
) -> Tuple[float, float]:
    """
    Convert a JudgeResult into (reward, gating).

    reward is the judge's sat_t. gating is 1.0 for preference-bearing turns,
    meaning either query_type == "preference" or one of a fixed set of
    preference/complaint phrases appears in the query. Otherwise gating is
    0.0. `turn_id` and `task_type` are accepted but unused.
    """
    reward = judge_result.sat_t
    lower_q = (query or "").lower()

    preference_markers = (
        "i prefer", "my preference", "please use", "please keep",
        "you didn't follow", "you forgot", "remember that i",
        "i told you", "i asked for",
        "中文",  # also covers "用中文"
    )
    is_pref_turn = query_type == "preference" or any(m in lower_q for m in preference_markers)

    gating = 1.0 if is_pref_turn else 0.0
    return reward, gating


# =============================================================================
# Query Generation per Persona
# =============================================================================

def _query(text: str, query_type: str, task_type: str) -> Dict[str, Any]:
    """Build one scripted query dict in the shape run_session expects."""
    return {"query": text, "type": query_type, "task_type": task_type}


def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 1: reveal the persona's preferences between a few list tasks.

    Queries are written in the persona's language. Turn 0 reveals a length
    limit only for short personas. Turn 2 reveals the bullet preference in
    whichever direction the persona holds. Turn 3 states the language.
    """
    prefs = persona.style_prefs
    zh = prefs.lang == "zh"
    queries: List[Dict[str, Any]] = []

    # Turn 0: length preference. Long personas get a neutral greeting instead,
    # so short_revealed stays False.
    if prefs.require_short:
        if zh:
            text = f"我喜欢简短的回答,请保持回复在{prefs.max_chars}字以内。"
        else:
            text = (
                "I prefer short, concise answers. "
                f"Please keep responses under {prefs.max_chars} characters."
            )
        queries.append(_query(text, "preference", "general"))
    else:
        text = "你好,我想了解一些问题。" if zh else "Hello, I have some questions for you."
        queries.append(_query(text, "task", "general"))

    # Turn 1: first list task, asked before the bullet preference is known.
    text = "列出三个改善睡眠的建议。" if zh else "List three tips for better sleep."
    queries.append(_query(text, "task", "list"))

    # Turn 2: bullet preference (pro-bullets or anti-bullets).
    if prefs.require_bullets:
        if zh:
            text = "我喜欢用项目符号列出要点,请使用bullet points。"
        else:
            text = "I prefer bullet points when listing things. Please use bullet points."
    else:
        if zh:
            text = "请不要用项目符号,我更喜欢连续的句子。"
        else:
            text = "Please don't use bullet points. I prefer continuous prose."
    queries.append(_query(text, "preference", "general"))

    # Turn 3: language preference.
    text = "请用中文回答我的问题。" if zh else "Please respond in English."
    queries.append(_query(text, "preference", "general"))

    # Turns 4-5: list tasks after everything has been revealed.
    if zh:
        tasks = ["锻炼有什么好处?", "列出五种流行的编程语言。"]
    else:
        tasks = ["What are the benefits of exercise?", "Name five popular programming languages."]
    queries.extend(_query(t, "task", "list") for t in tasks)

    return queries


def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 2: tasks only, with no preference restatement.

    Tests whether preferences revealed in session 1 are retained across
    sessions.
    """
    if persona.style_prefs.lang == "zh":
        return [
            _query("推荐三种健康的早餐。", "task", "list"),
            _query("一年有哪四个季节?", "task", "list"),
            _query("法国的首都是哪里?", "task", "qa"),
            _query("列出三种可再生能源。", "task", "list"),
        ]
    return [
        _query("What are three healthy breakfast ideas?", "task", "list"),
        _query("What are the four seasons of the year?", "task", "list"),
        _query("What is the capital of France?", "task", "qa"),
        _query("Name three types of renewable energy.", "task", "list"),
    ]


def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 3: tasks plus one reminder turn.

    The reminder restates a preference the persona actually holds:
      - short personas are reminded of brevity;
      - other personas are reminded of their bullet preference.
    Telling a long-answer persona "I asked for short answers" would contradict
    its ground truth.
    """
    prefs = persona.style_prefs
    zh = prefs.lang == "zh"

    if prefs.require_short:
        reminder = "请记住我喜欢简短的回答。" if zh else "Remember that I asked for short answers."
    elif prefs.require_bullets:
        reminder = "请记住我喜欢用项目符号。" if zh else "Remember that I asked for bullet points."
    else:
        reminder = "请记住我不喜欢用项目符号。" if zh else "Remember that I asked for no bullet points."

    if zh:
        return [
            _query("列出五种常见的水果。", "task", "list"),
            _query(f"{reminder}列出三种海洋动物。", "preference", "list"),
            _query("2加2等于多少?", "task", "qa"),
        ]
    return [
        _query("Name five common fruits.", "task", "list"),
        _query(f"{reminder} List three ocean animals.", "preference", "list"),
        _query("What is 2 + 2?", "task", "qa"),
    ]


# =============================================================================
# Logging
# =============================================================================

@dataclass
class TurnLog:
    """Log entry for one turn."""

    user_id: str
    persona_id: str
    session_id: int
    turn_id: int
    query: str
    query_type: str
    task_type: str
    answer: str
    answer_length: int
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    enforced_constraints: List[str]
    reward: float
    gating: float
    reveal_state_before: Dict[str, bool]
    reveal_state_after: Dict[str, bool]
    newly_revealed: List[str]
    z_long_norm_before: float
    z_long_norm_after: float
    z_short_norm_before: float
    z_short_norm_after: float
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    num_memories_retrieved: int
    num_prefs_extracted: int


def log_to_jsonl(logs: List[TurnLog], filepath: str):
    """
    Write `logs` to `filepath` as UTF-8 JSON Lines, one TurnLog per line.

    Missing parent directories are created. An existing file is overwritten.
    Non-ASCII text, such as Chinese queries and answers, is written verbatim
    rather than as ASCII escape sequences.
    """
    parent = os.path.dirname(filepath)
    if parent:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        for log in logs:
            f.write(json.dumps(asdict(log), ensure_ascii=False) + "\n")


# =============================================================================
# Session Runner
# =============================================================================

def run_session(
    llm: PersonalizedLLM,
    user_id: str,
    persona: Persona,
    session_id: int,
    reveal_state: RevealState,
    queries: List[Dict[str, Any]],
    all_logs: List[TurnLog],
) -> List[TurnLog]:
    """
    Run one scripted session for a user and log every turn.

    For each turn:
      - update the reveal state from the query;
      - apply the previous turn's feedback;
      - chat, judge the answer and record a TurnLog.
    Feedback for the last turn is applied after the loop. Each TurnLog is
    appended both to `all_logs` (mutated in place) and to the returned list.
    """
    prefs = persona.style_prefs
    session_logs: List[TurnLog] = []

    print(f"\n{'='*60}")
    print(f"[{persona.persona_id}] Session {session_id}: {len(queries)} turns")
    print(f"Reveal state (start): {reveal_state}")
    print(f"{'='*60}")

    # Start a fresh session on the LLM side. Reveal state is owned by the
    # caller and is deliberately not reset here.
    llm.reset_session(user_id)

    for turn_id, q_info in enumerate(queries):
        query = q_info["query"]
        query_type = q_info.get("type", "task")
        task_type = q_info.get("task_type", "general")

        print(f"\n--- S{session_id}/T{turn_id} [{query_type}] ---")
        print(f"[Q] {query[:60]}{'...' if len(query) > 60 else ''}")

        reveal_before = reveal_state.to_dict()
        newly_revealed = update_reveal_state(reveal_state, query, prefs)
        if newly_revealed:
            print(f"[Reveal] Newly: {newly_revealed}")
        reveal_after = reveal_state.to_dict()

        state_before = llm.get_user_state_summary(user_id)
        z_long_before = state_before["z_long_norm"]
        z_short_before = state_before["z_short_norm"]

        # Feedback for turn t is delivered just before turn t+1 is generated.
        if turn_id > 0 and session_logs:
            prev_log = session_logs[-1]
            llm.apply_feedback(Feedback(
                user_id=user_id,
                turn_id=prev_log.turn_id,
                reward=prev_log.reward,
                gating=prev_log.gating,
                meta={"source": "pilot_v3", "session_id": session_id},
            ))

        resp: AssistantResponse = llm.chat(user_id, query)
        answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer
        print(f"[A] ({len(resp.answer)}c) {answer_display}")

        judge_result = style_judge_v3(query, resp.answer, task_type, prefs, reveal_state)
        print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}")

        reward, gating = compute_feedback_for_turn(turn_id, query, query_type, task_type, judge_result)

        state_after = llm.get_user_state_summary(user_id)
        z_long_after = state_after["z_long_norm"]
        z_short_after = state_after["z_short_norm"]

        num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
        num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0

        log = TurnLog(
            user_id=user_id,
            persona_id=persona.persona_id,
            session_id=session_id,
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            task_type=task_type,
            answer=resp.answer,
            answer_length=len(resp.answer),
            sat_t=judge_result.sat_t,
            sev_t=judge_result.sev_t,
            prog_t=judge_result.prog_t,
            violations=judge_result.violations,
            enforced_constraints=judge_result.enforced_constraints,
            reward=reward,
            gating=gating,
            reveal_state_before=reveal_before,
            reveal_state_after=reveal_after,
            newly_revealed=sorted(newly_revealed),  # sorted for deterministic logs
            z_long_norm_before=z_long_before,
            z_long_norm_after=z_long_after,
            z_short_norm_before=z_short_before,
            z_short_norm_after=z_short_after,
            prompt_tokens=resp.usage.prompt_tokens,
            completion_tokens=resp.usage.completion_tokens,
            total_tokens=resp.usage.total_tokens,
            num_memories_retrieved=num_memories,
            num_prefs_extracted=num_prefs,
        )
        session_logs.append(log)
        all_logs.append(log)

    # The last turn has no successor, so deliver its feedback here.
    if session_logs:
        last_log = session_logs[-1]
        llm.apply_feedback(Feedback(
            user_id=user_id,
            turn_id=last_log.turn_id,
            reward=last_log.reward,
            gating=last_log.gating,
            meta={"source": "pilot_v3", "session_id": session_id, "final": True},
        ))

    return session_logs


# =============================================================================
# Multi-User Multi-Session Runner
# =============================================================================

def run_multi_user_pilot(
    llm: PersonalizedLLM,
    personas: List[Persona],
    num_sessions: int = 3,
    reveal_manager: Optional[RevealStateManager] = None,
) -> List[TurnLog]:
    """
    Run the multi-user, multi-session pilot.

    Each persona becomes user "user_<persona_id>". The user is fully reset
    and then run through `num_sessions` sessions. Session 1 and session 2 use
    their own scripts; every session from 3 onward uses the session-3 script.

    Args:
        llm: PersonalizedLLM instance
        personas: List of personas to test
        num_sessions: Number of sessions per user
        reveal_manager: Optional existing reveal manager; a fresh one is
            created if omitted.

    Returns:
        All TurnLogs from all users and sessions, in execution order.
    """
    if reveal_manager is None:
        reveal_manager = RevealStateManager()

    all_logs: List[TurnLog] = []

    print(f"\n{'#'*60}")
    print("PILOT v3: MULTI-USER MULTI-SESSION")
    print(f"Users: {len(personas)}, Sessions per user: {num_sessions}")
    print(f"{'#'*60}")

    for persona in personas:
        user_id = f"user_{persona.persona_id}"
        prefs = persona.style_prefs

        print(f"\n{'*'*60}")
        print(f"USER: {user_id}")
        print(f"Persona: {persona.description}")
        print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
        print(f"{'*'*60}")

        llm.reset_user(user_id)
        reveal_manager.reset_user(user_id)
        reveal_state = reveal_manager.get_state(user_id)

        for session_id in range(1, num_sessions + 1):
            if session_id == 1:
                queries = get_session_1_queries_for_persona(persona)
            elif session_id == 2:
                queries = get_session_2_queries_for_persona(persona)
            else:
                queries = get_session_3_queries_for_persona(persona)

            reveal_manager.reset_session(user_id)  # no-op; reveal state persists
            run_session(llm, user_id, persona, session_id, reveal_state, queries, all_logs)

    return all_logs


# =============================================================================
# Summary
# =============================================================================

def print_summary_v3(logs: List[TurnLog]):
    """Print per-persona, cross-session (S1 vs S2) and overall statistics."""
    print(f"\n{'='*60}")
    print("PILOT v3 SUMMARY - Multi-User Multi-Session")
    print(f"{'='*60}")

    if not logs:
        print("No logs.")
        return

    from collections import Counter

    persona_ids = sorted({entry.persona_id for entry in logs})

    print("\n--- Per-Persona Statistics ---")
    for pid in persona_ids:
        p_logs = [entry for entry in logs if entry.persona_id == pid]
        print(f"\n{pid}:")
        for sid in sorted({entry.session_id for entry in p_logs}):
            s_logs = [entry for entry in p_logs if entry.session_id == sid]
            avg_sat = sum(entry.sat_t for entry in s_logs) / len(s_logs) if s_logs else 0
            violations = [v for entry in s_logs for v in entry.violations]
            enforced = {c for entry in s_logs for c in entry.enforced_constraints}
            print(f" Session {sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, enforced={enforced}")
            if violations:
                print(f" violations: {dict(Counter(violations))}")

    print("\n--- Cross-Session Retention ---")
    for pid in persona_ids:
        p_logs = [entry for entry in logs if entry.persona_id == pid]
        s1_logs = [entry for entry in p_logs if entry.session_id == 1]
        s2_logs = [entry for entry in p_logs if entry.session_id == 2]
        if s1_logs and s2_logs:
            s1_sat = sum(entry.sat_t for entry in s1_logs) / len(s1_logs)
            s2_sat = sum(entry.sat_t for entry in s2_logs) / len(s2_logs)
            s2_enforced = {c for entry in s2_logs for c in entry.enforced_constraints}
            print(f"{pid}: S1_sat={s1_sat:.3f} → S2_sat={s2_sat:.3f}, S2_enforced={s2_enforced}")

    total = len(logs)
    avg_sat = sum(entry.sat_t for entry in logs) / total
    total_tokens = sum(entry.total_tokens for entry in logs)

    print("\n--- Overall ---")
    print(f"Total turns: {total}")
    print(f"Overall avg sat_t: {avg_sat:.3f}")
    print(f"Total tokens: {total_tokens}")

    all_violations = [v for entry in logs for v in entry.violations]
    if all_violations:
        print(f"\nViolations: {dict(Counter(all_violations))}")


def main():
    """Run all five personas for three sessions, print a summary and save JSONL logs."""
    print("=" * 60)
    print("PILOT RUNNER v3 - Multi-User Multi-Session with Personas")
    print("=" * 60)
    print(f"Started at: {datetime.now().isoformat()}")

    personas = ALL_PERSONAS

    print(f"\n[Config] Running {len(personas)} personas:")
    for p in personas:
        print(f" - {p.persona_id}: {p.description}")

    print("\n[Init] Loading PersonalizedLLM...")
    llm = PersonalizedLLM(
        user_store_path="data/users/user_store_pilot_v3.npz",
        only_own_memories=True,
        enable_preference_extraction=True,
        enable_rl_updates=True,
    )

    logs = run_multi_user_pilot(llm, personas, num_sessions=3)

    print_summary_v3(logs)

    log_path = f"data/logs/pilot_v3_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
    log_to_jsonl(logs, log_path)
    print(f"\n[Logs] Saved to: {log_path}")

    print(f"\nCompleted at: {datetime.now().isoformat()}")
    print("=" * 60)


if __name__ == "__main__":
    main()