Diffstat (limited to 'scripts/pilot_runner_v4.py')
| -rw-r--r-- | scripts/pilot_runner_v4.py | 1230 |
1 file changed, 1230 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v4.py b/scripts/pilot_runner_v4.py
new file mode 100644
index 0000000..b3e2058
--- /dev/null
+++ b/scripts/pilot_runner_v4.py
@@ -0,0 +1,1230 @@
#!/usr/bin/env python3
"""
Pilot Runner v4 - Critical Fixes for Baseline Comparison

Fixes from v3:
1. Chinese short-reveal detection (简短 / 字以内 / 不超过, etc.)
2. Symmetric bullets constraint (has_bullets violation when require_bullets=False)
3. Better wrong_lang detection using the CJK ratio, with a math exemption
4. Persona-conditional query templates (no self-contradiction)
5. Violation-triggered complaint mechanism (for an online RL signal)

This version is ready for a proper baseline comparison.
"""

import sys
import os
import re
import json
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple, Set

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))

from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse


# =============================================================================
# Style Preferences (True Preferences)
# =============================================================================

@dataclass
class StylePrefs:
    """User's TRUE style preferences."""
    require_short: bool = False
    max_chars: int = 300
    require_bullets: bool = False  # True = want bullets, False = don't want bullets
    lang: str = "en"  # "en" or "zh"


# =============================================================================
# Persona Definition
# =============================================================================

@dataclass
class Persona:
    """A user persona that bundles style preferences."""
    persona_id: str
    style_prefs: StylePrefs
    description: str = ""


# 5 test personas, plus one extreme persona for the case study
PERSONA_A = Persona(
    persona_id="A_short_bullets_en",
    style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
    description="Short + bullets + English",
)

PERSONA_B = Persona(
    persona_id="B_short_no_bullets_en",
    style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
    description="Short + NO bullets + English (anti-bullet)",
)

PERSONA_C = Persona(
    persona_id="C_long_bullets_en",
    style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
    description="Long + bullets + English",
)

PERSONA_D = Persona(
    persona_id="D_short_bullets_zh",
    style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
    description="Short + bullets + Chinese",
)

PERSONA_E = Persona(
    persona_id="E_long_no_bullets_zh",
    style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
    description="Long + NO bullets + Chinese (most anti-default)",
)

# Extreme short persona for the case study - the LLM default is much longer
PERSONA_F = Persona(
    persona_id="F_extreme_short_en",
    style_prefs=StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
    description="EXTREME short (100 chars) + bullets + English",
)

ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E, PERSONA_F]


# =============================================================================
# Reveal State
# =============================================================================

@dataclass
class RevealState:
    """Tracks which preferences have been explicitly revealed."""
    short_revealed: bool = False
    bullets_revealed: bool = False
    lang_revealed: bool = False

    def reset(self):
        self.short_revealed = False
        self.bullets_revealed = False
        self.lang_revealed = False

    def to_dict(self) -> Dict[str, bool]:
        return {"short": self.short_revealed, "bullets": self.bullets_revealed, "lang": self.lang_revealed}

    def __str__(self) -> str:
        flags = [k for k, v in self.to_dict().items() if v]
        return f"RevealState({', '.join(flags) if flags else 'none'})"


class RevealStateManager:
    """Manages reveal state for multiple users."""

    def __init__(self):
        self._states: Dict[str, RevealState] = {}

    def get_state(self, user_id: str) -> RevealState:
        if user_id not in self._states:
            self._states[user_id] = RevealState()
        return self._states[user_id]

    def reset_user(self, user_id: str):
        self._states[user_id] = RevealState()

    def reset_session(self, user_id: str):
        pass  # Reveal state persists across sessions


# =============================================================================
# FIX 1: Improved Preference Detection (with Chinese support)
# =============================================================================

def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]:
    """
    Detect which preferences are mentioned in a query.
    FIX: Added Chinese keywords for short detection.
    """
    lower_q = (query or "").lower()
    original_q = query or ""

    revealed = {"short": False, "bullets": False, "lang": False}

    # Short/length preference - English patterns
    short_patterns_en = [
        "short", "concise", "brief", "under ", "less than",
        "keep it short", "keep responses", "keep answers",
        "maximum ", "max ", "characters", "words or less",
    ]

    # FIX: Chinese patterns for the short preference
    short_patterns_zh = [
        "简短", "精简", "尽量短", "不要太长", "字以内", "不超过",
        "少于", "控制在", "简洁", "简明",
    ]

    # Regex patterns for number-based length constraints
    short_regex_patterns = [
        r"(\d+)\s*字以内",       # "200字以内" ("within 200 characters")
        r"不超过\s*(\d+)\s*字",  # "不超过200字" ("no more than 200 characters")
        r"under\s*(\d+)",        # "under 200"
        r"less\s*than\s*(\d+)",  # "less than 200"
    ]

    for pattern in short_patterns_en:
        if pattern in lower_q:
            revealed["short"] = True
            break

    if not revealed["short"]:
        for pattern in short_patterns_zh:
            if pattern in original_q:
                revealed["short"] = True
                break

    if not revealed["short"]:
        for regex in short_regex_patterns:
            if re.search(regex, original_q, re.IGNORECASE):
                revealed["short"] = True
                break

    # Bullet preference - both positive and negative
    bullet_patterns_positive = [
        "bullet", "bullet point", "bullet-point", "bulleted",
        "list format", "use bullets", "use bullet",
        "项目符号", "要点", "用bullet",
    ]
    bullet_patterns_negative = [
        "no bullet", "don't use bullet", "without bullet",
        "不要bullet", "不要项目符号", "不用bullet",
        "continuous prose", "paragraph form", "flowing text",
        "连续句子", "段落形式",
    ]

    for pattern in bullet_patterns_positive + bullet_patterns_negative:
        if pattern in lower_q or pattern in original_q:
            revealed["bullets"] = True
            break

    # Language preference
    lang_patterns_zh = [
        "chinese", "中文", "in chinese", "用中文",
        "speak chinese", "respond in chinese", "请用中文",
    ]
    lang_patterns_en = [
        "english", "in english", "use english",
        "speak english", "respond in english",
    ]

    for pattern in lang_patterns_zh + lang_patterns_en:
        if pattern in lower_q or pattern in original_q:
            revealed["lang"] = True
            break

    return revealed
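
# Illustrative sanity sketch of the detector above; the _demo_* helper is
# not called anywhere in the runner. It shows the FIX 1 behavior: a Chinese
# length constraint ("控制在...字以内", "keep it within ... characters")
# now flips the `short` flag with no English keywords present.
def _demo_detect_revealed_preferences() -> None:
    revealed = detect_revealed_preferences("请把回复控制在200字以内", StylePrefs(lang="zh"))
    assert revealed == {"short": True, "bullets": False, "lang": False}
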
def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]:
    """Update reveal state based on query content."""
    detected = detect_revealed_preferences(query, prefs)
    newly_revealed = set()

    if detected["short"] and not reveal_state.short_revealed:
        reveal_state.short_revealed = True
        newly_revealed.add("short")

    if detected["bullets"] and not reveal_state.bullets_revealed:
        reveal_state.bullets_revealed = True
        newly_revealed.add("bullets")

    if detected["lang"] and not reveal_state.lang_revealed:
        reveal_state.lang_revealed = True
        newly_revealed.add("lang")

    return newly_revealed


# =============================================================================
# FIX 3: Better Language Detection
# =============================================================================

def is_math_or_symbol_only(text: str) -> bool:
    """Check if text is purely math/symbols (language-neutral)."""
    # Pattern: only digits, operators, whitespace, and basic punctuation
    math_pattern = r'^[\d+\-*/=().,%\s]+$'
    return bool(re.match(math_pattern, text.strip()))


def count_cjk_chars(text: str) -> int:
    """Count CJK (Chinese/Japanese/Korean) ideographs."""
    # CJK Unified Ideographs plus Extension A
    cjk_pattern = re.compile(r'[\u4e00-\u9fff\u3400-\u4dbf]')
    return len(cjk_pattern.findall(text))


def count_latin_letters(text: str) -> int:
    """Count Latin letters (a-z, A-Z)."""
    return sum(1 for c in text if c.isalpha() and c.isascii())


def check_language_violation(text: str, target_lang: str) -> bool:
    """
    FIX: Better language-violation check using the CJK ratio.
    Returns True if there is a violation.
    """
    text = text.strip()

    # Exempt pure math/symbols
    if is_math_or_symbol_only(text):
        return False

    cjk_count = count_cjk_chars(text)
    latin_count = count_latin_letters(text)
    total = cjk_count + latin_count

    if total == 0:
        return False  # No meaningful text to judge

    if target_lang == "zh":
        # For Chinese: want a high CJK ratio. Allow some English proper
        # nouns - only flag if the CJK share is very low.
        cjk_ratio = cjk_count / total
        return cjk_ratio < 0.2  # Less than 20% CJK = wrong language

    elif target_lang == "en":
        # For English: want a high Latin ratio
        latin_ratio = latin_count / total
        return latin_ratio < 0.5  # Less than 50% Latin = wrong language

    return False
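
# Illustrative checks for the CJK-ratio logic above (not wired into the
# runner): pure arithmetic is exempt, an all-English answer to a
# Chinese-preferring user falls below the 20% CJK threshold and is flagged,
# and a mostly-Chinese answer passes.
def _demo_check_language_violation() -> None:
    assert check_language_violation("42 * 3 = 126", "zh") is False       # math exemption
    assert check_language_violation("Sure! Here are three tips.", "zh") is True   # 0% CJK
    assert check_language_violation("好的,给你三条建议。", "zh") is False  # high CJK ratio
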
# =============================================================================
# FIX 2: Symmetric Bullets Constraint + FIX 3: Language
# =============================================================================

@dataclass
class JudgeResult:
    """Output from the judge for one turn."""
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    enforced_constraints: List[str]


def has_bullet_markers(text: str) -> bool:
    """Check if text contains bullet-point markers."""
    return bool(re.search(r'(^|\n)\s*[-•*]\s', text))


def style_judge_v4(
    query: str,
    answer: str,
    task_type: str,
    prefs: StylePrefs,
    reveal_state: RevealState,
) -> JudgeResult:
    """
    Style judge v4 with:
    - FIX 2: Symmetric bullets (has_bullets violation when require_bullets=False)
    - FIX 3: Better wrong_lang via the CJK ratio with a math exemption
    """
    violations: List[str] = []
    enforced: List[str] = []
    text = (answer or "").strip()

    # 0) Empty answer
    if len(text) == 0:
        violations.append("empty_answer")
        return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0,
                           violations=violations, enforced_constraints=["non_empty"])

    # 1) Length - enforce only if the preference is true AND revealed
    if prefs.require_short and reveal_state.short_revealed:
        enforced.append("short")
        if len(text) > prefs.max_chars:
            violations.append("too_long")

    # 2) FIX 2: Symmetric bullets constraint (only for list tasks)
    if reveal_state.bullets_revealed and task_type == "list":
        has_bullets = has_bullet_markers(text)

        if prefs.require_bullets:
            # Wants bullets but the answer has none
            enforced.append("require_bullets")
            if not has_bullets:
                violations.append("no_bullets")
        else:
            # Doesn't want bullets but the answer has them
            enforced.append("no_bullets_pref")
            if has_bullets:
                violations.append("has_bullets")

    # 3) FIX 3: Language via the CJK ratio with a math exemption
    if reveal_state.lang_revealed:
        enforced.append("lang")
        if check_language_violation(text, prefs.lang):
            violations.append("wrong_lang")

    # 4) Code task
    prog_t = 1.0
    if task_type == "code":
        enforced.append("code_block")
        has_code = ("```" in text) or ("def " in text) or ("function " in text)
        if not has_code:
            violations.append("no_code_block")
            prog_t = 0.0

    # 5) Compute scores
    if not violations:
        sat_t = 1.0
        sev_t = 0.0
    else:
        sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
        hard_violations = {"empty_answer", "too_long", "wrong_lang"}
        sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0

    return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t,
                       violations=violations, enforced_constraints=enforced)
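
# A small worked example of the symmetric bullets fix (illustrative only,
# not part of the pipeline): for an anti-bullet persona whose preference is
# already revealed, a bulleted answer on a list task now scores a soft
# `has_bullets` violation (sat drops to 0.7) instead of passing silently
# as it did in v3.
def _demo_style_judge_v4_symmetric_bullets() -> None:
    prefs = StylePrefs(require_bullets=False, lang="en")
    state = RevealState(bullets_revealed=True)
    result = style_judge_v4("List three fruits.", "- apple\n- pear\n- plum",
                            "list", prefs, state)
    assert result.violations == ["has_bullets"]
    assert abs(result.sat_t - 0.7) < 1e-9 and result.sev_t == 0.0
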
# =============================================================================
# Feedback Computation
# =============================================================================

def compute_feedback_for_turn(
    query: str,
    query_type: str,
    judge_result: JudgeResult,
) -> Tuple[float, float]:
    """Convert a JudgeResult into (reward, gating)."""
    reward = judge_result.sat_t

    lower_q = (query or "").lower()
    original_q = query or ""

    is_pref_turn = (
        query_type == "preference"
        or "i prefer" in lower_q
        or "my preference" in lower_q
        or "please use" in lower_q
        or "please keep" in lower_q
        or "you didn't follow" in lower_q
        or "you forgot" in lower_q
        or "remember that i" in lower_q
        or "i told you" in lower_q
        or "i asked for" in lower_q
        or "that was too" in lower_q
        or "too long" in lower_q
        or "请用中文" in original_q
        or "不要" in original_q
        or "简短" in original_q
    )

    gating = 1.0 if is_pref_turn else 0.0
    return reward, gating


# =============================================================================
# FIX 5: Violation-Triggered Complaint Generation
# =============================================================================

def generate_complaint_query(violations: List[str], prefs: StylePrefs) -> Optional[Dict[str, Any]]:
    """
    Generate a complaint query based on violations.
    Returns None if no complaint is needed.
    """
    if not violations:
        return None

    # Priority: address the most severe violation first
    complaint = None

    if "too_long" in violations:
        if prefs.lang == "zh":
            complaint = {
                "query": f"回答太长了。请保持回复在{prefs.max_chars}字以内。",
                "type": "preference",
                "task_type": "general",
            }
        else:
            complaint = {
                "query": f"That was too long. Please keep responses under {prefs.max_chars} characters.",
                "type": "preference",
                "task_type": "general",
            }

    elif "wrong_lang" in violations:
        if prefs.lang == "zh":
            complaint = {
                "query": "请用中文回答。",
                "type": "preference",
                "task_type": "general",
            }
        else:
            complaint = {
                "query": "Please respond in English.",
                "type": "preference",
                "task_type": "general",
            }

    elif "no_bullets" in violations:
        if prefs.lang == "zh":
            complaint = {
                "query": "请在列出内容时使用项目符号(bullet points)。",
                "type": "preference",
                "task_type": "general",
            }
        else:
            complaint = {
                "query": "Please use bullet points when listing things.",
                "type": "preference",
                "task_type": "general",
            }

    elif "has_bullets" in violations:
        if prefs.lang == "zh":
            complaint = {
                "query": "请不要使用项目符号,用连续的句子来表达。",
                "type": "preference",
                "task_type": "general",
            }
        else:
            complaint = {
                "query": "Please don't use bullet points. Use continuous prose instead.",
                "type": "preference",
                "task_type": "general",
            }

    return complaint
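
# Priority sketch (illustrative, not called by the runner): when one turn
# trips several violations at once, only the most severe is complained
# about - `too_long` outranks the bullet complaints, mirroring the if/elif
# chain above.
def _demo_generate_complaint_query() -> None:
    prefs = StylePrefs(require_short=True, max_chars=200, lang="en")
    complaint = generate_complaint_query(["too_long", "no_bullets"], prefs)
    assert complaint is not None
    assert complaint["query"].startswith("That was too long.")
    assert generate_complaint_query([], prefs) is None
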
# =============================================================================
# FIX 4: Persona-Conditional Query Templates
# =============================================================================

def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 1: Reveal preferences (persona-conditional).
    FIX: Only reveal preferences that match the persona's true prefs.
    """
    queries = []
    prefs = persona.style_prefs

    # Turn 0: Reveal the length preference (only if require_short=True)
    if prefs.require_short:
        if prefs.lang == "zh":
            queries.append({
                "query": f"我喜欢简短的回答,请保持回复在{prefs.max_chars}字以内。",
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": f"I prefer short, concise answers. Please keep responses under {prefs.max_chars} characters.",
                "type": "preference",
                "task_type": "general",
            })
    else:
        # Don't reveal a short preference for long-preferring personas
        if prefs.lang == "zh":
            queries.append({
                "query": "你好,我有一些问题想问你。",
                "type": "task",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "Hello, I have some questions for you.",
                "type": "task",
                "task_type": "general",
            })

    # Turn 1: First task
    if prefs.lang == "zh":
        queries.append({"query": "列出三个改善睡眠的建议。", "type": "task", "task_type": "list"})
    else:
        queries.append({"query": "List three tips for better sleep.", "type": "task", "task_type": "list"})

    # Turn 2: Reveal the bullet preference (conditional on require_bullets)
    if prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "我喜欢用项目符号列出要点,请使用bullet points。",
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "I prefer bullet points when listing things. Please use bullet points.",
                "type": "preference",
                "task_type": "general",
            })
    else:
        # Explicitly say NO bullets
        if prefs.lang == "zh":
            queries.append({
                "query": "请不要用项目符号,我更喜欢连续的句子来表达。",
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "Please don't use bullet points. I prefer continuous prose.",
                "type": "preference",
                "task_type": "general",
            })

    # Turn 3: Reveal the language preference
    if prefs.lang == "zh":
        queries.append({"query": "请用中文回答我的问题。", "type": "preference", "task_type": "general"})
    else:
        queries.append({"query": "Please respond in English.", "type": "preference", "task_type": "general"})

    # Turns 4-5: Tasks
    if prefs.lang == "zh":
        queries.extend([
            {"query": "锻炼有什么好处?", "type": "task", "task_type": "list"},
            {"query": "列出五种流行的编程语言。", "type": "task", "task_type": "list"},
        ])
    else:
        queries.extend([
            {"query": "What are the benefits of exercise?", "type": "task", "task_type": "list"},
            {"query": "Name five popular programming languages.", "type": "task", "task_type": "list"},
        ])

    return queries
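
# Illustrative trace of the persona-conditional fix (not used by the
# runner): for the anti-bullet persona B, session 1 opens with a short
# reveal and then states "no bullets" explicitly at turn 2 - it never asks
# for bullets it does not want, which was the v3 self-contradiction.
def _demo_session_1_persona_b() -> None:
    queries = get_session_1_queries_for_persona(PERSONA_B)
    assert queries[0]["type"] == "preference"                  # short reveal
    assert "don't use bullet points" in queries[2]["query"]    # anti-bullet reveal
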
def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """Session 2: NO preference restatement."""
    prefs = persona.style_prefs

    if prefs.lang == "zh":
        return [
            {"query": "推荐三种健康的早餐。", "type": "task", "task_type": "list"},
            {"query": "一年有哪四个季节?", "type": "task", "task_type": "list"},
            {"query": "法国的首都是哪里?", "type": "task", "task_type": "qa"},
            {"query": "列出三种可再生能源。", "type": "task", "task_type": "list"},
        ]
    else:
        return [
            {"query": "What are three healthy breakfast ideas?", "type": "task", "task_type": "list"},
            {"query": "What are the four seasons of the year?", "type": "task", "task_type": "list"},
            {"query": "What is the capital of France?", "type": "task", "task_type": "qa"},
            {"query": "Name three types of renewable energy.", "type": "task", "task_type": "list"},
        ]


def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 3: Tasks with ONE persona-conditional reminder.
    FIX: The reminder matches the persona's actual preferences.
    """
    prefs = persona.style_prefs
    queries = []

    # First task
    if prefs.lang == "zh":
        queries.append({"query": "列出五种常见的水果。", "type": "task", "task_type": "list"})
    else:
        queries.append({"query": "Name five common fruits.", "type": "task", "task_type": "list"})

    # Persona-conditional reminder
    if prefs.require_short and prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我喜欢简短的回答和项目符号。列出三种海洋动物。",
                "type": "preference", "task_type": "list",
            })
        else:
            queries.append({
                "query": "Remember I prefer short answers with bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list",
            })
    elif prefs.require_short and not prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我喜欢简短的回答,不要用项目符号。列出三种海洋动物。",
                "type": "preference", "task_type": "list",
            })
        else:
            queries.append({
                "query": "Remember I prefer short answers without bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list",
            })
    elif not prefs.require_short and prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我喜欢用项目符号列出要点。列出三种海洋动物。",
                "type": "preference", "task_type": "list",
            })
        else:
            queries.append({
                "query": "Remember I prefer bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list",
            })
    else:  # neither short nor bullets
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我不喜欢用项目符号,喜欢连续的句子。列出三种海洋动物。",
                "type": "preference", "task_type": "list",
            })
        else:
            queries.append({
                "query": "Remember I prefer continuous prose without bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list",
            })

    # Final task
    if prefs.lang == "zh":
        queries.append({"query": "2加2等于多少?", "type": "task", "task_type": "qa"})
    else:
        queries.append({"query": "What is 2 + 2?", "type": "task", "task_type": "qa"})

    return queries


def get_pure_task_queries_for_persona(persona: Persona, session_idx: int) -> List[Dict[str, Any]]:
    """
    Pure task sessions (S4+): NO preference reminders at all.
    Used for testing long-term retention without any in-context hints.
    Different task sets per session to avoid repetition.
    """
    prefs = persona.style_prefs

    # Task pools for variety
    zh_task_pools = [
        # Pool 1
        [
            {"query": "列出三种热带水果。", "type": "task", "task_type": "list"},
            {"query": "列出三种常见的编程语言。", "type": "task", "task_type": "list"},
            {"query": "什么是光合作用?", "type": "task", "task_type": "qa"},
            {"query": "太阳系有几颗行星?", "type": "task", "task_type": "qa"},
        ],
        # Pool 2
        [
            {"query": "列出三种室内植物。", "type": "task", "task_type": "list"},
            {"query": "列出三种运动项目。", "type": "task", "task_type": "list"},
            {"query": "什么是人工智能?", "type": "task", "task_type": "qa"},
            {"query": "地球的自转周期是多少?", "type": "task", "task_type": "qa"},
        ],
        # Pool 3
        [
            {"query": "列出三种乐器。", "type": "task", "task_type": "list"},
            {"query": "列出三种社交媒体平台。", "type": "task", "task_type": "list"},
            {"query": "什么是区块链?", "type": "task", "task_type": "qa"},
            {"query": "月球绕地球一周需要多长时间?", "type": "task", "task_type": "qa"},
        ],
        # Pool 4
        [
            {"query": "列出三种鸟类。", "type": "task", "task_type": "list"},
            {"query": "列出三种数据库系统。", "type": "task", "task_type": "list"},
            {"query": "什么是机器学习?", "type": "task", "task_type": "qa"},
            {"query": "水的沸点是多少?", "type": "task", "task_type": "qa"},
        ],
    ]

    en_task_pools = [
        # Pool 1
        [
            {"query": "List three tropical fruits.", "type": "task", "task_type": "list"},
            {"query": "List three popular programming languages.", "type": "task", "task_type": "list"},
            {"query": "What is photosynthesis?", "type": "task", "task_type": "qa"},
            {"query": "How many planets are in our solar system?", "type": "task", "task_type": "qa"},
        ],
        # Pool 2
        [
            {"query": "List three indoor plants.", "type": "task", "task_type": "list"},
            {"query": "List three types of sports.", "type": "task", "task_type": "list"},
            {"query": "What is artificial intelligence?", "type": "task", "task_type": "qa"},
            {"query": "How long is a day on Earth?", "type": "task", "task_type": "qa"},
        ],
        # Pool 3
        [
            {"query": "List three musical instruments.", "type": "task", "task_type": "list"},
            {"query": "List three social media platforms.", "type": "task", "task_type": "list"},
            {"query": "What is blockchain?", "type": "task", "task_type": "qa"},
            {"query": "How long does it take the Moon to orbit Earth?", "type": "task", "task_type": "qa"},
        ],
        # Pool 4
        [
            {"query": "List three types of birds.", "type": "task", "task_type": "list"},
            {"query": "List three database systems.", "type": "task", "task_type": "list"},
            {"query": "What is machine learning?", "type": "task", "task_type": "qa"},
            {"query": "What is the boiling point of water?", "type": "task", "task_type": "qa"},
        ],
    ]

    pools = zh_task_pools if prefs.lang == "zh" else en_task_pools
    # Rotate through the pools based on the session index
    pool_idx = (session_idx - 4) % len(pools)
    return pools[pool_idx]
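
# Worked example of the pool rotation (illustrative only): with four pools,
# (session_idx - 4) % 4 maps S4 -> Pool 1, S5 -> Pool 2, S6 -> Pool 3,
# S7 -> Pool 4, and S8 wraps back around to Pool 1.
def _demo_pool_rotation() -> None:
    s4 = get_pure_task_queries_for_persona(PERSONA_A, session_idx=4)
    s8 = get_pure_task_queries_for_persona(PERSONA_A, session_idx=8)
    assert s4 == s8  # same pool after the wrap-around
    assert s4[0]["query"] == "List three tropical fruits."
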
def get_queries_for_session(persona: Persona, session_id: int) -> List[Dict[str, Any]]:
    """
    Get the queries for a specific session.

    S1:  Preference reveal
    S2:  Pure task (no reminder)
    S3:  Tasks with ONE reminder
    S4+: Pure task (testing long-term retention)
    """
    if session_id == 1:
        return get_session_1_queries_for_persona(persona)
    elif session_id == 2:
        return get_session_2_queries_for_persona(persona)
    elif session_id == 3:
        return get_session_3_queries_for_persona(persona)
    else:
        return get_pure_task_queries_for_persona(persona, session_id)


# =============================================================================
# Logging
# =============================================================================

@dataclass
class TurnLog:
    """Log entry for one turn."""
    user_id: str
    persona_id: str
    session_id: int
    turn_id: int
    query: str
    query_type: str
    task_type: str
    answer: str
    answer_length: int
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    enforced_constraints: List[str]
    reward: float
    gating: float
    is_complaint: bool
    reveal_state_before: Dict[str, bool]
    reveal_state_after: Dict[str, bool]
    newly_revealed: List[str]
    z_long_norm_before: float
    z_long_norm_after: float
    z_short_norm_before: float
    z_short_norm_after: float
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Memory retrieval details
    num_memories_retrieved: int
    num_prefs_extracted: int
    selected_memory_ids: List[str]
    selected_memory_notes: List[str]
    selected_memory_scores: List[float]
    num_candidates: int
    num_total_memories: int
    # Mode indicators
    mode: str        # "full" or "nopersonal"
    eval_mode: bool  # True = greedy, False = sample


def log_to_jsonl(logs: List[TurnLog], filepath: str):
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, "w") as f:
        for log in logs:
            f.write(json.dumps(asdict(log)) + "\n")
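
# Illustrative read-back sketch (an assumed analysis helper, not used by
# the runner): each JSONL line round-trips to a plain dict of TurnLog
# fields, so per-session satisfaction can be recomputed offline.
def _demo_load_logs(filepath: str) -> List[Dict[str, Any]]:
    with open(filepath) as f:
        return [json.loads(line) for line in f if line.strip()]
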
# =============================================================================
# Session Runner with Complaint Injection (FIX 5)
# =============================================================================

def run_session_v4(
    llm: PersonalizedLLM,
    user_id: str,
    persona: Persona,
    session_id: int,
    reveal_state: RevealState,
    base_queries: List[Dict[str, Any]],
    all_logs: List[TurnLog],
    enable_complaints: bool = True,
) -> List[TurnLog]:
    """Run a session with violation-triggered complaint injection."""
    prefs = persona.style_prefs
    session_logs: List[TurnLog] = []

    print(f"\n{'='*60}")
    print(f"[{persona.persona_id}] Session {session_id}: base queries={len(base_queries)}")
    print(f"Reveal state (start): {reveal_state}")
    print(f"{'='*60}")

    llm.reset_session(user_id)

    # Build a dynamic query queue (complaints may be pushed to the front)
    query_queue = list(base_queries)
    turn_id = 0

    while query_queue:
        q_info = query_queue.pop(0)
        query = q_info["query"]
        query_type = q_info.get("type", "task")
        task_type = q_info.get("task_type", "general")
        is_complaint = q_info.get("is_complaint", False)

        print(f"\n--- S{session_id}/T{turn_id} [{query_type}]{' [COMPLAINT]' if is_complaint else ''} ---")
        print(f"[Q] {query[:60]}{'...' if len(query) > 60 else ''}")

        reveal_before = reveal_state.to_dict()
        newly_revealed = update_reveal_state(reveal_state, query, prefs)
        if newly_revealed:
            print(f"[Reveal] Newly: {newly_revealed}")
        reveal_after = reveal_state.to_dict()

        state_before = llm.get_user_state_summary(user_id)
        z_long_before = state_before["z_long_norm"]
        z_short_before = state_before["z_short_norm"]

        # Apply the feedback for the previous turn
        if turn_id > 0 and session_logs:
            prev_log = session_logs[-1]
            feedback = Feedback(
                user_id=user_id,
                turn_id=prev_log.turn_id,
                reward=prev_log.reward,
                gating=prev_log.gating,
                meta={"source": "pilot_v4", "session_id": session_id},
            )
            llm.apply_feedback(feedback)

        # Chat
        resp: AssistantResponse = llm.chat(user_id, query)

        answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer
        print(f"[A] ({len(resp.answer)}c) {answer_display}")

        # Judge
        judge_result = style_judge_v4(query, resp.answer, task_type, prefs, reveal_state)
        print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}")

        # Compute feedback
        reward, gating = compute_feedback_for_turn(query, query_type, judge_result)

        state_after = llm.get_user_state_summary(user_id)
        z_long_after = state_after["z_long_norm"]
        z_short_after = state_after["z_short_norm"]

        # Extract memory info from the debug payload
        if resp.debug:
            num_memories = len(resp.debug.selected_memory_ids)
            num_prefs = len(resp.debug.extracted_preferences)
            selected_memory_ids = resp.debug.selected_memory_ids
            selected_memory_notes = resp.debug.selected_memory_notes
            selected_memory_scores = resp.debug.selected_memory_scores
            num_candidates = resp.debug.extra.get("num_candidates", 0)
            num_total_memories = resp.debug.extra.get("num_total_memories", 0)
        else:
            num_memories = 0
            num_prefs = 0
            selected_memory_ids = []
            selected_memory_notes = []
            selected_memory_scores = []
            num_candidates = 0
            num_total_memories = 0

        # Log
        log = TurnLog(
            user_id=user_id,
            persona_id=persona.persona_id,
            session_id=session_id,
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            task_type=task_type,
            answer=resp.answer,
            answer_length=len(resp.answer),
            sat_t=judge_result.sat_t,
            sev_t=judge_result.sev_t,
            prog_t=judge_result.prog_t,
            violations=judge_result.violations,
            enforced_constraints=judge_result.enforced_constraints,
            reward=reward,
            gating=gating,
            is_complaint=is_complaint,
            reveal_state_before=reveal_before,
            reveal_state_after=reveal_after,
            newly_revealed=list(newly_revealed),
            z_long_norm_before=z_long_before,
            z_long_norm_after=z_long_after,
            z_short_norm_before=z_short_before,
            z_short_norm_after=z_short_after,
            prompt_tokens=resp.usage.prompt_tokens,
            completion_tokens=resp.usage.completion_tokens,
            total_tokens=resp.usage.total_tokens,
            num_memories_retrieved=num_memories,
            num_prefs_extracted=num_prefs,
            selected_memory_ids=selected_memory_ids,
            selected_memory_notes=selected_memory_notes,
            selected_memory_scores=selected_memory_scores,
            num_candidates=num_candidates,
            num_total_memories=num_total_memories,
            mode=llm.mode,
            eval_mode=llm.eval_mode,
        )
        session_logs.append(log)
        all_logs.append(log)

        # FIX 5: Inject a complaint if there were violations and this turn
        # wasn't already a complaint
        if enable_complaints and judge_result.violations and not is_complaint:
            complaint = generate_complaint_query(judge_result.violations, prefs)
            if complaint:
                complaint["is_complaint"] = True
                query_queue.insert(0, complaint)  # Insert at the front
                print(f"[Complaint Injected] Will complain about: {judge_result.violations}")

        turn_id += 1

    # Final feedback
    if session_logs:
        last_log = session_logs[-1]
        feedback = Feedback(
            user_id=user_id,
            turn_id=last_log.turn_id,
            reward=last_log.reward,
            gating=last_log.gating,
            meta={"source": "pilot_v4", "session_id": session_id, "final": True},
        )
        llm.apply_feedback(feedback)

    print(f"\n[Session {session_id} End] Reveal: {reveal_state}, Turns: {turn_id}")
    return session_logs
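
# Illustrative timeline of FIX 5 as implemented above (comment only, not
# executed): suppose the turn-0 answer violates the revealed length
# preference. The complaint is pushed to the FRONT of the queue, so it is
# answered before the next base query, and the violating turn's low reward
# (with gating 0.0 if turn 0 was a plain task) is applied as feedback at
# the start of the complaint turn.
#
#   T0 [task]       "List three tips for better sleep."  -> too_long
#   T1 [COMPLAINT]  "That was too long. Please keep responses under 200
#                    characters."  -> matches the preference patterns, so
#                    its own reward carries gating 1.0
#   T2 [task]       next base query from the session template
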
complaint["is_complaint"] = True + query_queue.insert(0, complaint) # Insert at front + print(f"[Complaint Injected] Will complain about: {judge_result.violations}") + + turn_id += 1 + + # Final feedback + if session_logs: + last_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=last_log.turn_id, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v4", "session_id": session_id, "final": True} + ) + llm.apply_feedback(feedback) + + print(f"\n[Session {session_id} End] Reveal: {reveal_state}, Turns: {turn_id}") + return session_logs + + +# ============================================================================= +# Multi-User Multi-Session Runner +# ============================================================================= + +def run_multi_user_pilot_v4( + llm: PersonalizedLLM, + personas: List[Persona], + num_sessions: int = 3, + enable_complaints: bool = True, +) -> List[TurnLog]: + """Run multi-user multi-session pilot v4.""" + reveal_manager = RevealStateManager() + all_logs: List[TurnLog] = [] + + print(f"\n{'#'*60}") + print(f"PILOT v4: MULTI-USER MULTI-SESSION (Fixed)") + print(f"Users: {len(personas)}, Sessions: {num_sessions}, Complaints: {enable_complaints}") + print(f"{'#'*60}") + + for persona in personas: + user_id = f"user_{persona.persona_id}" + prefs = persona.style_prefs + + print(f"\n{'*'*60}") + print(f"USER: {user_id}") + print(f"Persona: {persona.description}") + print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}") + print(f"{'*'*60}") + + llm.reset_user(user_id) + reveal_manager.reset_user(user_id) + reveal_state = reveal_manager.get_state(user_id) + + for session_id in range(1, num_sessions + 1): + queries = get_queries_for_session(persona, session_id) + + reveal_manager.reset_session(user_id) + run_session_v4(llm, user_id, persona, session_id, reveal_state, queries, all_logs, enable_complaints) + + return all_logs + + +# ============================================================================= +# Summary +# ============================================================================= + +def print_summary_v4(logs: List[TurnLog]): + """Print summary for pilot v4.""" + print(f"\n{'='*60}") + print("PILOT v4 SUMMARY") + print(f"{'='*60}") + + if not logs: + print("No logs.") + return + + from collections import Counter + + personas = sorted(set(l.persona_id for l in logs)) + + print(f"\n--- Per-Persona Statistics ---") + for pid in personas: + p_logs = [l for l in logs if l.persona_id == pid] + sessions = sorted(set(l.session_id for l in p_logs)) + + print(f"\n{pid}:") + for sid in sessions: + s_logs = [l for l in p_logs if l.session_id == sid] + avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0 + violations = [v for l in s_logs for v in l.violations] + enforced = set(c for l in s_logs for c in l.enforced_constraints) + complaints = sum(1 for l in s_logs if l.is_complaint) + + print(f" S{sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, complaints={complaints}") + print(f" enforced={enforced}") + if violations: + print(f" violations: {dict(Counter(violations))}") + + # Cross-session retention + print(f"\n--- Cross-Session Retention (S2 without preferences) ---") + for pid in personas: + p_logs = [l for l in logs if l.persona_id == pid] + s1_logs = [l for l in p_logs if l.session_id == 1] + s2_logs = [l for l in p_logs if l.session_id == 2] + + if s1_logs and s2_logs: + s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs) + s2_sat = sum(l.sat_t 

# =============================================================================
# Summary
# =============================================================================

def print_summary_v4(logs: List[TurnLog]):
    """Print a summary for pilot v4."""
    print(f"\n{'='*60}")
    print("PILOT v4 SUMMARY")
    print(f"{'='*60}")

    if not logs:
        print("No logs.")
        return

    from collections import Counter

    personas = sorted(set(l.persona_id for l in logs))

    print("\n--- Per-Persona Statistics ---")
    for pid in personas:
        p_logs = [l for l in logs if l.persona_id == pid]
        sessions = sorted(set(l.session_id for l in p_logs))

        print(f"\n{pid}:")
        for sid in sessions:
            s_logs = [l for l in p_logs if l.session_id == sid]
            avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0
            violations = [v for l in s_logs for v in l.violations]
            enforced = set(c for l in s_logs for c in l.enforced_constraints)
            complaints = sum(1 for l in s_logs if l.is_complaint)

            print(f"  S{sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, complaints={complaints}")
            print(f"    enforced={enforced}")
            if violations:
                print(f"    violations: {dict(Counter(violations))}")

    # Cross-session retention
    print("\n--- Cross-Session Retention (S2 without preferences) ---")
    for pid in personas:
        p_logs = [l for l in logs if l.persona_id == pid]
        s1_logs = [l for l in p_logs if l.session_id == 1]
        s2_logs = [l for l in p_logs if l.session_id == 2]

        if s1_logs and s2_logs:
            s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs)
            s2_sat = sum(l.sat_t for l in s2_logs) / len(s2_logs)
            s2_enforced = set(c for l in s2_logs for c in l.enforced_constraints)
            print(f"{pid}: S1={s1_sat:.3f} → S2={s2_sat:.3f}, enforced={s2_enforced}")

    # Violation rates
    print("\n--- Violation Rates by Type ---")
    all_violations = [v for l in logs for v in l.violations]
    total_turns = len(logs)
    if all_violations:
        for v, count in Counter(all_violations).most_common():
            rate = count / total_turns * 100
            print(f"  {v}: {count} ({rate:.1f}%)")
    else:
        print("  No violations")

    # Complaint effectiveness
    print("\n--- Complaint Effectiveness ---")
    complaint_logs = [l for l in logs if l.is_complaint]
    if complaint_logs:
        print(f"Total complaints: {len(complaint_logs)}")
        avg_sat_complaint = sum(l.sat_t for l in complaint_logs) / len(complaint_logs)
        print(f"Avg sat on complaint turns: {avg_sat_complaint:.3f}")
    else:
        print("No complaints generated")

    # Overall
    total = len(logs)
    avg_sat = sum(l.sat_t for l in logs) / total
    total_tokens = sum(l.total_tokens for l in logs)
    print("\n--- Overall ---")
    print(f"Total turns: {total}, Avg sat: {avg_sat:.3f}, Total tokens: {total_tokens}")

def main():
    import argparse
    parser = argparse.ArgumentParser(description="Pilot Runner v4 - Full vs Vanilla Comparison")
    parser.add_argument("--mode", type=str,
                        choices=["full", "full-greedy", "full-sample", "nopersonal", "vanilla", "compare", "all"],
                        default="compare",
                        help="Mode: 'full-greedy' (personalized, deterministic), "
                             "'full-sample' (personalized, stochastic), "
                             "'nopersonal' (retrieval baseline without z_u), "
                             "'vanilla' (pure LLM, no memory), "
                             "'compare' (full-greedy vs vanilla), "
                             "'all' (run all modes)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--sessions", type=int, default=3, help="Number of sessions per user")
    parser.add_argument("--no-complaints", action="store_true", help="Disable complaint injection")
    args = parser.parse_args()

    # Set seeds for reproducibility
    import random
    import numpy as np
    random.seed(args.seed)
    np.random.seed(args.seed)

    print("=" * 60)
    print("PILOT RUNNER v4 - Full vs Vanilla Comparison")
    print("=" * 60)
    print(f"Started at: {datetime.now().isoformat()}")
    print(f"Mode: {args.mode}, Seed: {args.seed}, Sessions: {args.sessions}")

    personas = ALL_PERSONAS
    print(f"\n[Config] {len(personas)} personas:")
    for p in personas:
        print(f"  - {p.persona_id}: {p.description}")

    enable_complaints = not args.no_complaints

    # Map the mode argument to actual run configurations.
    # Each config: (run_name, llm_mode, eval_mode)
    #   llm_mode:  "full", "nopersonal", or "vanilla"
    #   eval_mode: True = greedy/deterministic, False = stochastic sampling
    if args.mode == "all":
        run_configs = [
            ("full-greedy", "full", True),
            ("full-sample", "full", False),
            ("nopersonal", "nopersonal", True),
            ("vanilla", "vanilla", True),
        ]
    elif args.mode == "compare":
        # Main comparison: Full (with memory) vs Vanilla (no memory)
        run_configs = [
            ("full-greedy", "full", True),
            ("vanilla", "vanilla", True),
        ]
    elif args.mode in ("full", "full-greedy"):
        run_configs = [("full-greedy", "full", True)]
    elif args.mode == "full-sample":
        run_configs = [("full-sample", "full", False)]
    elif args.mode == "vanilla":
        run_configs = [("vanilla", "vanilla", True)]
    elif args.mode == "nopersonal":
        run_configs = [("nopersonal", "nopersonal", True)]
    else:
        run_configs = [(args.mode, args.mode, True)]

    for run_name, llm_mode, eval_mode in run_configs:
        print(f"\n{'#'*60}")
        print(f"RUNNING: {run_name.upper()}")
        print(f"  llm_mode={llm_mode}, eval_mode={eval_mode} ({'greedy' if eval_mode else 'sample'})")
        print(f"{'#'*60}")

        # Reset seeds before each run for exact reproducibility
        random.seed(args.seed)
        np.random.seed(args.seed)

        print("\n[Init] Loading PersonalizedLLM...")
        llm = PersonalizedLLM(
            user_store_path=f"data/users/user_store_pilot_v4_{run_name}.npz",
            only_own_memories=True,
            enable_preference_extraction=True,
            enable_rl_updates=(llm_mode == "full"),  # RL updates only in full mode
            mode=llm_mode,
            eval_mode=eval_mode,
            device_assignment={
                "embed": "cuda:0",
                "reranker": "cuda:1",
                "chat": "cuda:2",
                "extractor": "cuda:3",
            },
        )

        logs = run_multi_user_pilot_v4(llm, personas, num_sessions=args.sessions,
                                       enable_complaints=enable_complaints)

        print_summary_v4(logs)

        log_path = f"data/logs/pilot_v4_{run_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
        log_to_jsonl(logs, log_path)
        print(f"\n[Logs] Saved to: {log_path}")

        # Save user vectors for similarity analysis
        if llm_mode == "full":
            llm.persist()
            print(f"[Persist] User vectors saved to: {llm._user_store.path}")

    print(f"\nCompleted at: {datetime.now().isoformat()}")
    print("=" * 60)


if __name__ == "__main__":
    main()
