summaryrefslogtreecommitdiff
path: root/scripts/pilot_runner_v4.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/pilot_runner_v4.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/pilot_runner_v4.py')
-rw-r--r--scripts/pilot_runner_v4.py1230
1 files changed, 1230 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v4.py b/scripts/pilot_runner_v4.py
new file mode 100644
index 0000000..b3e2058
--- /dev/null
+++ b/scripts/pilot_runner_v4.py
@@ -0,0 +1,1230 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v4 - Critical Fixes for Baseline Comparison
+
+Fixes from v3:
+1. Chinese short reveal detection (简短/字以内/不超过 etc.)
+2. Symmetric bullets constraint (has_bullets violation for require_bullets=False)
+3. Better wrong_lang with CJK ratio + math exemption
+4. Persona-conditional query templates (no self-contradiction)
+5. Violation-triggered complaint mechanism (for online RL signal)
+
+This version is ready for proper baseline comparison.
+"""
+
+import sys
+import os
+import re
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict, field
+from typing import List, Dict, Any, Optional, Tuple, Set, Literal
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences (True Preferences)
+# =============================================================================
+
@dataclass
class StylePrefs:
    """A user's TRUE style preferences (ground truth for the judge).

    These are never shown to the assistant directly; the judge enforces each
    one only after the user has explicitly revealed it in conversation.
    """
    require_short: bool = False   # prefer answers capped at max_chars
    max_chars: int = 300          # length cap, enforced only when require_short is True
    require_bullets: bool = False  # True = want bullets, False = don't want bullets
    lang: str = "en"  # "en" or "zh"
+
+
+# =============================================================================
+# Persona Definition
+# =============================================================================
+
@dataclass
class Persona:
    """A user persona bundling an id, ground-truth style preferences, and a
    human-readable description used in printouts."""
    persona_id: str          # stable identifier, e.g. "A_short_bullets_en"
    style_prefs: StylePrefs  # the TRUE preferences the judge will enforce
    description: str = ""    # free-text summary for logs
+
+
# 6 test personas covering the preference grid:
# (short vs long) x (bullets vs no-bullets) x (en vs zh), plus one extreme case.
PERSONA_A = Persona(
    persona_id="A_short_bullets_en",
    style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
    description="Short + bullets + English",
)

PERSONA_B = Persona(
    persona_id="B_short_no_bullets_en",
    style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
    description="Short + NO bullets + English (anti-bullet)",
)

PERSONA_C = Persona(
    persona_id="C_long_bullets_en",
    style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
    description="Long + bullets + English",
)

PERSONA_D = Persona(
    persona_id="D_short_bullets_zh",
    style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
    description="Short + bullets + Chinese",
)

PERSONA_E = Persona(
    persona_id="E_long_no_bullets_zh",
    style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
    description="Long + NO bullets + Chinese (most anti-default)",
)

# Extreme short persona for case study - typical LLM default output is much longer.
PERSONA_F = Persona(
    persona_id="F_extreme_short_en",
    style_prefs=StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
    description="EXTREME short (100 chars) + bullets + English",
)

ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E, PERSONA_F]
+
+
+# =============================================================================
+# Reveal State
+# =============================================================================
+
@dataclass
class RevealState:
    """Tracks which of the user's true preferences have been explicitly stated
    in conversation so far (the judge only enforces revealed preferences)."""
    short_revealed: bool = False
    bullets_revealed: bool = False
    lang_revealed: bool = False

    def reset(self):
        """Forget all reveals (used when a user is fully reset)."""
        self.short_revealed = False
        self.bullets_revealed = False
        self.lang_revealed = False

    def to_dict(self) -> Dict[str, bool]:
        """Snapshot the three flags under 'short'/'bullets'/'lang' keys."""
        return {
            "short": self.short_revealed,
            "bullets": self.bullets_revealed,
            "lang": self.lang_revealed,
        }

    def __str__(self) -> str:
        revealed = [name for name, is_set in self.to_dict().items() if is_set]
        inner = ", ".join(revealed) if revealed else "none"
        return f"RevealState({inner})"
+
+
class RevealStateManager:
    """Holds one RevealState per user id, created lazily on first access."""

    def __init__(self):
        self._states: Dict[str, RevealState] = {}

    def get_state(self, user_id: str) -> RevealState:
        """Return the user's reveal state, creating a fresh one if needed."""
        return self._states.setdefault(user_id, RevealState())

    def reset_user(self, user_id: str):
        """Discard everything the user has revealed so far."""
        self._states[user_id] = RevealState()

    def reset_session(self, user_id: str):
        """Intentionally a no-op: reveals persist across sessions."""
        pass
+
+
+# =============================================================================
+# FIX 1: Improved Preference Detection (with Chinese support)
+# =============================================================================
+
def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]:
    """
    Return which preference families ('short', 'bullets', 'lang') the query
    mentions, via keyword and regex matching with both English and Chinese
    patterns (FIX: Chinese short-preference keywords added in v4).

    NOTE: `prefs` is accepted for interface symmetry but is not consulted —
    detection is purely textual.
    """
    lower_q = (query or "").lower()
    original_q = query or ""

    # ---- length / shortness ------------------------------------------------
    short_en = (
        "short", "concise", "brief", "under ", "less than",
        "keep it short", "keep responses", "keep answers",
        "maximum ", "max ", "characters", "words or less",
    )
    short_zh = (
        "简短", "精简", "尽量短", "不要太长", "字以内", "不超过",
        "少于", "控制在", "简洁", "简明",
    )
    # Number-based length constraints, e.g. "200字以内" / "under 200".
    short_regexes = (
        r"(\d+)\s*字以内",
        r"不超过\s*(\d+)\s*字",
        r"under\s*(\d+)",
        r"less\s*than\s*(\d+)",
    )
    short_hit = (
        any(p in lower_q for p in short_en)
        or any(p in original_q for p in short_zh)
        or any(re.search(rx, original_q, re.IGNORECASE) for rx in short_regexes)
    )

    # ---- bullets: both "want bullets" and "no bullets" count as revealed ---
    bullet_markers = (
        # positive
        "bullet", "bullet point", "bullet-point", "bulleted",
        "list format", "use bullets", "use bullet",
        "项目符号", "要点", "用bullet",
        # negative
        "no bullet", "don't use bullet", "without bullet",
        "不要bullet", "不要项目符号", "不用bullet",
        "continuous prose", "paragraph form", "flowing text",
        "连续句子", "段落形式",
    )
    bullets_hit = any(p in lower_q or p in original_q for p in bullet_markers)

    # ---- language ----------------------------------------------------------
    lang_markers = (
        "chinese", "中文", "in chinese", "用中文",
        "speak chinese", "respond in chinese", "请用中文",
        "english", "in english", "use english",
        "speak english", "respond in english",
    )
    lang_hit = any(p in lower_q or p in original_q for p in lang_markers)

    return {"short": short_hit, "bullets": bullets_hit, "lang": lang_hit}
+
+
def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]:
    """Mark any preferences mentioned in `query` as revealed (in place) and
    return the set of preference names that were newly revealed this turn."""
    detected = detect_revealed_preferences(query, prefs)
    newly_revealed: Set[str] = set()

    flag_attrs = (
        ("short", "short_revealed"),
        ("bullets", "bullets_revealed"),
        ("lang", "lang_revealed"),
    )
    for key, attr in flag_attrs:
        # Only report a reveal the first time it happens.
        if detected[key] and not getattr(reveal_state, attr):
            setattr(reveal_state, attr, True)
            newly_revealed.add(key)

    return newly_revealed
+
+
+# =============================================================================
+# FIX 3: Better Language Detection
+# =============================================================================
+
def is_math_or_symbol_only(text: str) -> bool:
    """Return True when the stripped text consists solely of digits, basic
    arithmetic operators, simple punctuation and whitespace — i.e. it is
    language-neutral. Empty text yields False (the '+' needs one char)."""
    return re.fullmatch(r'[\d+\-*/=().,%\s\n\r]+', text.strip()) is not None
+
+
def count_cjk_chars(text: str) -> int:
    """Count characters in the CJK Unified Ideographs block (U+4E00–U+9FFF)
    or Extension A (U+3400–U+4DBF)."""
    return sum(
        1
        for ch in text
        if '\u4e00' <= ch <= '\u9fff' or '\u3400' <= ch <= '\u4dbf'
    )
+
+
def count_latin_letters(text: str) -> int:
    """Count ASCII letters (a-z, A-Z). Digits, CJK characters and accented
    (non-ASCII) letters are all excluded."""
    return len(re.findall(r"[A-Za-z]", text))
+
+
def check_language_violation(text: str, target_lang: str) -> bool:
    """
    Return True when `text` appears to be written in the wrong language for
    `target_lang` ("zh" or "en"), judged by the CJK-vs-Latin letter ratio.

    Exemptions: pure math/symbol answers and text containing no letters at
    all are never flagged. Unknown target languages are never flagged.
    """
    text = text.strip()

    # Arithmetic like "2 + 2 = 4" is language-neutral.
    if is_math_or_symbol_only(text):
        return False

    cjk = count_cjk_chars(text)
    latin = count_latin_letters(text)
    letters = cjk + latin
    if letters == 0:
        return False  # no meaningful text to judge

    if target_lang == "zh":
        # Allow some English proper nouns — flag only when CJK share is very low.
        return cjk / (letters + 1e-9) < 0.2
    if target_lang == "en":
        return latin / (letters + 1e-9) < 0.5
    return False
+
+
+# =============================================================================
+# FIX 2: Symmetric Bullets Constraint + FIX 3: Language
+# =============================================================================
+
@dataclass
class JudgeResult:
    """Output from the judge for one turn."""
    sat_t: float                     # satisfaction in [0, 1]; 1.0 when no violations
    sev_t: float                     # 1.0 iff a hard violation (empty/too_long/wrong_lang) occurred
    prog_t: float                    # task progress; 0.0 when a code task lacks any code
    violations: List[str]            # violation tags, e.g. "too_long", "has_bullets"
    enforced_constraints: List[str]  # which constraints were actually checked this turn
+
+
def has_bullet_markers(text: str) -> bool:
    """True when any line, after optional leading whitespace, starts with
    '-', '•' or '*' followed by a whitespace character."""
    return bool(re.search(r'^\s*[-•*]\s', text, re.MULTILINE))
+
+
def style_judge_v4(
    query: str,
    answer: str,
    task_type: str,
    prefs: StylePrefs,
    reveal_state: RevealState,
) -> JudgeResult:
    """
    Judge one answer against the user's TRUE preferences, enforcing only
    constraints that have already been revealed.

    Checks (v4):
      - length cap ("too_long") when the short preference is true AND revealed
      - symmetric bullets on list tasks: "no_bullets" when bullets are wanted
        but missing, "has_bullets" when they are unwanted but present (FIX 2)
      - language via CJK ratio with a math exemption ("wrong_lang", FIX 3)
      - code tasks must contain code ("no_code_block"), which zeroes prog_t
    """
    violations: List[str] = []
    enforced: List[str] = []
    text = (answer or "").strip()

    # Empty answers short-circuit with the worst possible scores.
    if not text:
        return JudgeResult(
            sat_t=0.0, sev_t=1.0, prog_t=0.0,
            violations=["empty_answer"],
            enforced_constraints=["non_empty"],
        )

    # Length: enforced only when the preference is both true and revealed.
    if prefs.require_short and reveal_state.short_revealed:
        enforced.append("short")
        if len(text) > prefs.max_chars:
            violations.append("too_long")

    # Bullets: symmetric constraint, list tasks only (FIX 2).
    if reveal_state.bullets_revealed and task_type == "list":
        bullets_present = has_bullet_markers(text)
        if prefs.require_bullets:
            enforced.append("require_bullets")
            if not bullets_present:
                violations.append("no_bullets")
        else:
            enforced.append("no_bullets_pref")
            if bullets_present:
                violations.append("has_bullets")

    # Language: CJK-ratio heuristic with math exemption (FIX 3).
    if reveal_state.lang_revealed:
        enforced.append("lang")
        if check_language_violation(text, prefs.lang):
            violations.append("wrong_lang")

    # Code tasks must show actual code; failure also zeroes task progress.
    prog_t = 1.0
    if task_type == "code":
        enforced.append("code_block")
        if not any(marker in text for marker in ("```", "def ", "function ")):
            violations.append("no_code_block")
            prog_t = 0.0

    # Scores: each violation costs 0.3 satisfaction (floored at 0);
    # severity flips to 1.0 only on a hard violation.
    if violations:
        sat_t = max(0.0, 1.0 - 0.3 * len(violations))
        sev_t = 1.0 if {"empty_answer", "too_long", "wrong_lang"} & set(violations) else 0.0
    else:
        sat_t = 1.0
        sev_t = 0.0

    return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t,
                       violations=violations, enforced_constraints=enforced)
+
+
+# =============================================================================
+# Feedback Computation
+# =============================================================================
+
def compute_feedback_for_turn(
    query: str,
    query_type: str,
    judge_result: JudgeResult,
) -> Tuple[float, float]:
    """Map a JudgeResult to (reward, gating).

    reward is the raw satisfaction score. gating is 1.0 only on turns that
    carry preference signal — either labelled `query_type == "preference"`
    or matching a preference/complaint phrase (English or Chinese).
    """
    lower_q = (query or "").lower()
    original_q = query or ""

    en_markers = (
        "i prefer", "my preference", "please use", "please keep",
        "you didn't follow", "you forgot", "remember that i", "i told you",
        "i asked for", "that was too", "too long",
    )
    zh_markers = ("请用中文", "不要", "简短")

    is_pref_turn = (
        query_type == "preference"
        or any(m in lower_q for m in en_markers)
        or any(m in original_q for m in zh_markers)
    )

    return judge_result.sat_t, (1.0 if is_pref_turn else 0.0)
+
+
+# =============================================================================
+# FIX 5: Violation-Triggered Complaint Generation
+# =============================================================================
+
def generate_complaint_query(violations: List[str], prefs: StylePrefs) -> Optional[Dict[str, Any]]:
    """
    Build a single user complaint addressing the highest-priority violation.

    Priority: too_long > wrong_lang > no_bullets > has_bullets. Returns None
    when there are no violations, or only violations without a complaint
    template (e.g. "no_code_block"). Complaint text follows prefs.lang.
    """
    if not violations:
        return None

    zh = prefs.lang == "zh"
    templates = {
        "too_long": (
            f"回答太长了。请保持回复在{prefs.max_chars}字以内。"
            if zh else
            f"That was too long. Please keep responses under {prefs.max_chars} characters."
        ),
        "wrong_lang": (
            "请用中文回答。" if zh else "Please respond in English."
        ),
        "no_bullets": (
            "请在列出内容时使用项目符号(bullet points)。"
            if zh else
            "Please use bullet points when listing things."
        ),
        "has_bullets": (
            "请不要使用项目符号,用连续的句子来表达。"
            if zh else
            "Please don't use bullet points. Use continuous prose instead."
        ),
    }

    # Address only the most severe violation this turn.
    for tag in ("too_long", "wrong_lang", "no_bullets", "has_bullets"):
        if tag in violations:
            return {
                "query": templates[tag],
                "type": "preference",
                "task_type": "general",
            }
    return None
+
+
+# =============================================================================
+# FIX 4: Persona-Conditional Query Templates
+# =============================================================================
+
def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 1: reveal the persona's preferences, interleaved with tasks.

    FIX (v4): templates are persona-conditional — the script only reveals
    preferences that are actually TRUE for the persona, so it never
    contradicts the judge's ground truth (e.g. a long-preferring persona is
    never scripted to demand short answers).

    Turn layout: length pref (or a neutral greeting) -> list task ->
    bullet pref (positive, or an explicit anti-bullet request) ->
    language pref -> two more list tasks.
    """
    queries = []
    prefs = persona.style_prefs

    # Turn 0: reveal the length preference — but only if require_short=True.
    if prefs.require_short:
        if prefs.lang == "zh":
            queries.append({
                "query": f"我喜欢简短的回答,请保持回复在{prefs.max_chars}字以内。",
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": f"I prefer short, concise answers. Please keep responses under {prefs.max_chars} characters.",
                "type": "preference",
                "task_type": "general",
            })
    else:
        # Long-preferring personas get a neutral greeting instead — revealing
        # a short preference here would contradict their ground truth.
        if prefs.lang == "zh":
            queries.append({
                "query": "你好,我有一些问题想问你。",
                "type": "task",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "Hello, I have some questions for you.",
                "type": "task",
                "task_type": "general",
            })

    # Turn 1: first task (list type, so the bullet constraint can apply later).
    if prefs.lang == "zh":
        queries.append({"query": "列出三个改善睡眠的建议。", "type": "task", "task_type": "list"})
    else:
        queries.append({"query": "List three tips for better sleep.", "type": "task", "task_type": "list"})

    # Turn 2: reveal the bullet preference (conditional on require_bullets).
    if prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "我喜欢用项目符号列出要点,请使用bullet points。",
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "I prefer bullet points when listing things. Please use bullet points.",
                "type": "preference",
                "task_type": "general",
            })
    else:
        # Anti-bullet personas explicitly ask for NO bullets.
        if prefs.lang == "zh":
            queries.append({
                "query": "请不要用项目符号,我更喜欢连续的句子来表达。",
                "type": "preference",
                "task_type": "general",
            })
        else:
            queries.append({
                "query": "Please don't use bullet points. I prefer continuous prose.",
                "type": "preference",
                "task_type": "general",
            })

    # Turn 3: reveal the language preference.
    if prefs.lang == "zh":
        queries.append({"query": "请用中文回答我的问题。", "type": "preference", "task_type": "general"})
    else:
        queries.append({"query": "Please respond in English.", "type": "preference", "task_type": "general"})

    # Turns 4-5: plain tasks with all preferences now revealed.
    if prefs.lang == "zh":
        queries.extend([
            {"query": "锻炼有什么好处?", "type": "task", "task_type": "list"},
            {"query": "列出五种流行的编程语言。", "type": "task", "task_type": "list"},
        ])
    else:
        queries.extend([
            {"query": "What are the benefits of exercise?", "type": "task", "task_type": "list"},
            {"query": "Name five popular programming languages.", "type": "task", "task_type": "list"},
        ])

    return queries
+
+
def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """Session 2: pure tasks, deliberately free of any preference restatement,
    to test whether the preferences revealed in session 1 are retained."""
    zh_tasks = (
        ("推荐三种健康的早餐。", "list"),
        ("一年有哪四个季节?", "list"),
        ("法国的首都是哪里?", "qa"),
        ("列出三种可再生能源。", "list"),
    )
    en_tasks = (
        ("What are three healthy breakfast ideas?", "list"),
        ("What are the four seasons of the year?", "list"),
        ("What is the capital of France?", "qa"),
        ("Name three types of renewable energy.", "list"),
    )
    chosen = zh_tasks if persona.style_prefs.lang == "zh" else en_tasks
    return [
        {"query": text, "type": "task", "task_type": task_type}
        for text, task_type in chosen
    ]
+
+
def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
    """
    Session 3: tasks with exactly ONE preference reminder in the middle.

    FIX (v4): the reminder text is picked from a 2x2 matrix over
    (require_short, require_bullets) — and per language — so it restates only
    preferences that are actually true for this persona.
    """
    prefs = persona.style_prefs
    queries = []

    # Opening task, no reminder.
    if prefs.lang == "zh":
        queries.append({"query": "列出五种常见的水果。", "type": "task", "task_type": "list"})
    else:
        queries.append({"query": "Name five common fruits.", "type": "task", "task_type": "list"})

    # The single persona-conditional reminder, combined with a list task.
    if prefs.require_short and prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我喜欢简短的回答和项目符号。列出三种海洋动物。",
                "type": "preference", "task_type": "list"
            })
        else:
            queries.append({
                "query": "Remember I prefer short answers with bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list"
            })
    elif prefs.require_short and not prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我喜欢简短的回答,不要用项目符号。列出三种海洋动物。",
                "type": "preference", "task_type": "list"
            })
        else:
            queries.append({
                "query": "Remember I prefer short answers without bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list"
            })
    elif not prefs.require_short and prefs.require_bullets:
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我喜欢用项目符号列出要点。列出三种海洋动物。",
                "type": "preference", "task_type": "list"
            })
        else:
            queries.append({
                "query": "Remember I prefer bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list"
            })
    else:  # not short and not bullets
        if prefs.lang == "zh":
            queries.append({
                "query": "记住我不喜欢用项目符号,喜欢连续的句子。列出三种海洋动物。",
                "type": "preference", "task_type": "list"
            })
        else:
            queries.append({
                "query": "Remember I prefer continuous prose without bullet points. List three ocean animals.",
                "type": "preference", "task_type": "list"
            })

    # Closing task — arithmetic, which the judge's language check exempts.
    if prefs.lang == "zh":
        queries.append({"query": "2加2等于多少?", "type": "task", "task_type": "qa"})
    else:
        queries.append({"query": "What is 2 + 2?", "type": "task", "task_type": "qa"})

    return queries
+
+
def get_pure_task_queries_for_persona(persona: Persona, session_idx: int) -> List[Dict[str, Any]]:
    """
    Pure task sessions (S4+): NO preference reminders at all.
    Used for testing long-term retention without any in-context hints.
    Sessions rotate through four task pools (keyed off `session_idx`) to
    avoid repeating the same queries.
    """
    prefs = persona.style_prefs

    # Four pools per language, each mixing list tasks (bullet constraint can
    # apply) with QA tasks.
    zh_task_pools = [
        # Pool 1
        [
            {"query": "列出三种热带水果。", "type": "task", "task_type": "list"},
            {"query": "列出三种常见的编程语言。", "type": "task", "task_type": "list"},
            {"query": "什么是光合作用?", "type": "task", "task_type": "qa"},
            {"query": "太阳系有几颗行星?", "type": "task", "task_type": "qa"},
        ],
        # Pool 2
        [
            {"query": "列出三种室内植物。", "type": "task", "task_type": "list"},
            {"query": "列出三种运动项目。", "type": "task", "task_type": "list"},
            {"query": "什么是人工智能?", "type": "task", "task_type": "qa"},
            {"query": "地球的自转周期是多少?", "type": "task", "task_type": "qa"},
        ],
        # Pool 3
        [
            {"query": "列出三种乐器。", "type": "task", "task_type": "list"},
            {"query": "列出三种社交媒体平台。", "type": "task", "task_type": "list"},
            {"query": "什么是区块链?", "type": "task", "task_type": "qa"},
            {"query": "月球绕地球一周需要多长时间?", "type": "task", "task_type": "qa"},
        ],
        # Pool 4
        [
            {"query": "列出三种鸟类。", "type": "task", "task_type": "list"},
            {"query": "列出三种数据库系统。", "type": "task", "task_type": "list"},
            {"query": "什么是机器学习?", "type": "task", "task_type": "qa"},
            {"query": "水的沸点是多少?", "type": "task", "task_type": "qa"},
        ],
    ]

    en_task_pools = [
        # Pool 1
        [
            {"query": "List three tropical fruits.", "type": "task", "task_type": "list"},
            {"query": "List three popular programming languages.", "type": "task", "task_type": "list"},
            {"query": "What is photosynthesis?", "type": "task", "task_type": "qa"},
            {"query": "How many planets are in our solar system?", "type": "task", "task_type": "qa"},
        ],
        # Pool 2
        [
            {"query": "List three indoor plants.", "type": "task", "task_type": "list"},
            {"query": "List three types of sports.", "type": "task", "task_type": "list"},
            {"query": "What is artificial intelligence?", "type": "task", "task_type": "qa"},
            {"query": "How long is a day on Earth?", "type": "task", "task_type": "qa"},
        ],
        # Pool 3
        [
            {"query": "List three musical instruments.", "type": "task", "task_type": "list"},
            {"query": "List three social media platforms.", "type": "task", "task_type": "list"},
            {"query": "What is blockchain?", "type": "task", "task_type": "qa"},
            {"query": "How long does it take the Moon to orbit Earth?", "type": "task", "task_type": "qa"},
        ],
        # Pool 4
        [
            {"query": "List three types of birds.", "type": "task", "task_type": "list"},
            {"query": "List three database systems.", "type": "task", "task_type": "list"},
            {"query": "What is machine learning?", "type": "task", "task_type": "qa"},
            {"query": "What is the boiling point of water?", "type": "task", "task_type": "qa"},
        ],
    ]

    pools = zh_task_pools if prefs.lang == "zh" else en_task_pools
    # Sessions 4, 5, 6, ... rotate through the pools; Python's modulo keeps
    # the index in range even for session_idx < 4 (callers only pass >= 4).
    pool_idx = (session_idx - 4) % len(pools)
    return pools[pool_idx]
+
+
def get_queries_for_session(persona: Persona, session_id: int) -> List[Dict[str, Any]]:
    """
    Dispatch to the session script for `session_id`:
      S1  - preference reveal
      S2  - pure task (no reminder)
      S3  - tasks with ONE reminder
      S4+ - pure task (long-term retention probe)
    """
    scripted_sessions = {
        1: get_session_1_queries_for_persona,
        2: get_session_2_queries_for_persona,
        3: get_session_3_queries_for_persona,
    }
    builder = scripted_sessions.get(session_id)
    if builder is not None:
        return builder(persona)
    return get_pure_task_queries_for_persona(persona, session_id)
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
@dataclass
class TurnLog:
    """Flat log record for one turn; serialized to JSONL via `asdict`."""
    # Identity
    user_id: str
    persona_id: str
    session_id: int
    turn_id: int
    # Query / answer
    query: str
    query_type: str        # "task" or "preference"
    task_type: str         # "general" / "list" / "qa" / "code"
    answer: str
    answer_length: int     # len(answer) in characters
    # Judge outputs
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    enforced_constraints: List[str]
    # Learning signal derived from the judge
    reward: float
    gating: float
    is_complaint: bool     # True when this turn was an injected complaint query
    # Reveal-state bookkeeping (snapshots before/after this query was seen)
    reveal_state_before: Dict[str, bool]
    reveal_state_after: Dict[str, bool]
    newly_revealed: List[str]
    # Learner state norms before/after the turn (from get_user_state_summary)
    z_long_norm_before: float
    z_long_norm_after: float
    z_short_norm_before: float
    z_short_norm_after: float
    # Token accounting from the LLM response
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Memory retrieval details (from the response debug payload, 0/empty if absent)
    num_memories_retrieved: int
    num_prefs_extracted: int
    selected_memory_ids: List[str]
    selected_memory_notes: List[str]
    selected_memory_scores: List[float]
    num_candidates: int
    num_total_memories: int
    # Mode indicators
    mode: str  # "full" or "nopersonal"
    eval_mode: bool  # True = greedy, False = sample
+
+
def log_to_jsonl(logs: List[TurnLog], filepath: str):
    """Write one JSON object per TurnLog to `filepath` in JSONL format,
    overwriting any existing file.

    Fixes vs the original: `os.makedirs(os.path.dirname(filepath))` raised
    FileNotFoundError for a bare filename (dirname is ""), so the directory
    is now created only when the path actually has one. The file is written
    as UTF-8 with ensure_ascii=False so the Chinese queries/answers stay
    readable in the log instead of being \\uXXXX-escaped.
    """
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        for log in logs:
            f.write(json.dumps(asdict(log), ensure_ascii=False) + "\n")
+
+
+# =============================================================================
+# Session Runner with Complaint Injection (FIX 5)
+# =============================================================================
+
def run_session_v4(
    llm: PersonalizedLLM,
    user_id: str,
    persona: Persona,
    session_id: int,
    reveal_state: RevealState,
    base_queries: List[Dict[str, Any]],
    all_logs: List[TurnLog],
    enable_complaints: bool = True,
) -> List[TurnLog]:
    """
    Run one session with violation-triggered complaint injection (FIX 5).

    Per turn:
      1. pop the next query (injected complaints sit at the queue front),
      2. update the reveal state from the query text,
      3. apply the PREVIOUS turn's feedback — learner updates lag one turn,
      4. chat, then judge the answer against revealed true preferences,
      5. log, and if the answer violated something (and the current turn was
         not itself a complaint) insert a complaint at the front of the queue.
    The final turn's feedback is applied after the loop.

    Returns this session's logs; the same entries are also appended to
    `all_logs` in place.
    """
    prefs = persona.style_prefs
    session_logs: List[TurnLog] = []

    print(f"\n{'='*60}")
    print(f"[{persona.persona_id}] Session {session_id}: base queries={len(base_queries)}")
    print(f"Reveal state (start): {reveal_state}")
    print(f"{'='*60}")

    llm.reset_session(user_id)

    # Dynamic query queue — complaints get inserted at the front as triggered.
    query_queue = list(base_queries)
    turn_id = 0

    while query_queue:
        q_info = query_queue.pop(0)
        query = q_info["query"]
        query_type = q_info.get("type", "task")
        task_type = q_info.get("task_type", "general")
        is_complaint = q_info.get("is_complaint", False)

        print(f"\n--- S{session_id}/T{turn_id} [{query_type}]{' [COMPLAINT]' if is_complaint else ''} ---")
        print(f"[Q] {query[:60]}{'...' if len(query) > 60 else ''}")

        # Reveal bookkeeping: snapshot before/after for the log.
        reveal_before = reveal_state.to_dict()
        newly_revealed = update_reveal_state(reveal_state, query, prefs)
        if newly_revealed:
            print(f"[Reveal] Newly: {newly_revealed}")
        reveal_after = reveal_state.to_dict()

        state_before = llm.get_user_state_summary(user_id)
        z_long_before = state_before["z_long_norm"]
        z_short_before = state_before["z_short_norm"]

        # Apply feedback for the previous turn (one-turn delay: the judge
        # must see an answer before its reward can be applied).
        if turn_id > 0 and session_logs:
            prev_log = session_logs[-1]
            feedback = Feedback(
                user_id=user_id,
                turn_id=prev_log.turn_id,
                reward=prev_log.reward,
                gating=prev_log.gating,
                meta={"source": "pilot_v4", "session_id": session_id}
            )
            llm.apply_feedback(feedback)

        # Chat
        resp: AssistantResponse = llm.chat(user_id, query)

        answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer
        print(f"[A] ({len(resp.answer)}c) {answer_display}")

        # Judge the answer against the (revealed) true preferences.
        judge_result = style_judge_v4(query, resp.answer, task_type, prefs, reveal_state)
        print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}")

        # Convert judge output into the (reward, gating) learning signal.
        reward, gating = compute_feedback_for_turn(query, query_type, judge_result)

        state_after = llm.get_user_state_summary(user_id)
        z_long_after = state_after["z_long_norm"]
        z_short_after = state_after["z_short_norm"]

        # Extract memory-retrieval info from the debug payload, if present.
        if resp.debug:
            num_memories = len(resp.debug.selected_memory_ids)
            num_prefs = len(resp.debug.extracted_preferences)
            selected_memory_ids = resp.debug.selected_memory_ids
            selected_memory_notes = resp.debug.selected_memory_notes
            selected_memory_scores = resp.debug.selected_memory_scores
            num_candidates = resp.debug.extra.get("num_candidates", 0)
            num_total_memories = resp.debug.extra.get("num_total_memories", 0)
        else:
            num_memories = 0
            num_prefs = 0
            selected_memory_ids = []
            selected_memory_notes = []
            selected_memory_scores = []
            num_candidates = 0
            num_total_memories = 0

        # Log the turn.
        log = TurnLog(
            user_id=user_id,
            persona_id=persona.persona_id,
            session_id=session_id,
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            task_type=task_type,
            answer=resp.answer,
            answer_length=len(resp.answer),
            sat_t=judge_result.sat_t,
            sev_t=judge_result.sev_t,
            prog_t=judge_result.prog_t,
            violations=judge_result.violations,
            enforced_constraints=judge_result.enforced_constraints,
            reward=reward,
            gating=gating,
            is_complaint=is_complaint,
            reveal_state_before=reveal_before,
            reveal_state_after=reveal_after,
            newly_revealed=list(newly_revealed),
            z_long_norm_before=z_long_before,
            z_long_norm_after=z_long_after,
            z_short_norm_before=z_short_before,
            z_short_norm_after=z_short_after,
            prompt_tokens=resp.usage.prompt_tokens,
            completion_tokens=resp.usage.completion_tokens,
            total_tokens=resp.usage.total_tokens,
            num_memories_retrieved=num_memories,
            num_prefs_extracted=num_prefs,
            selected_memory_ids=selected_memory_ids,
            selected_memory_notes=selected_memory_notes,
            selected_memory_scores=selected_memory_scores,
            num_candidates=num_candidates,
            num_total_memories=num_total_memories,
            mode=llm.mode,
            eval_mode=llm.eval_mode,
        )
        session_logs.append(log)
        all_logs.append(log)

        # FIX 5: inject a complaint when the answer violated something and the
        # current turn was not itself a complaint (avoids complaint loops).
        if enable_complaints and judge_result.violations and not is_complaint:
            complaint = generate_complaint_query(judge_result.violations, prefs)
            if complaint:
                complaint["is_complaint"] = True
                query_queue.insert(0, complaint)  # Insert at front
                print(f"[Complaint Injected] Will complain about: {judge_result.violations}")

        turn_id += 1

    # Apply feedback for the last turn (the loop only applies it lagged).
    if session_logs:
        last_log = session_logs[-1]
        feedback = Feedback(
            user_id=user_id,
            turn_id=last_log.turn_id,
            reward=last_log.reward,
            gating=last_log.gating,
            meta={"source": "pilot_v4", "session_id": session_id, "final": True}
        )
        llm.apply_feedback(feedback)

    print(f"\n[Session {session_id} End] Reveal: {reveal_state}, Turns: {turn_id}")
    return session_logs
+
+
+# =============================================================================
+# Multi-User Multi-Session Runner
+# =============================================================================
+
def run_multi_user_pilot_v4(
    llm: PersonalizedLLM,
    personas: List[Persona],
    num_sessions: int = 3,
    enable_complaints: bool = True,
) -> List[TurnLog]:
    """Run the v4 pilot: every persona plays `num_sessions` sessions in turn.

    Each persona maps to its own user id ("user_<persona_id>"). Both the LLM
    user state and the reveal state are reset before a persona starts; within
    a persona, reveal state persists across sessions (reset_session is a
    deliberate no-op). Returns the flat list of all turn logs.
    """
    reveal_manager = RevealStateManager()
    all_logs: List[TurnLog] = []

    print(f"\n{'#'*60}")
    print(f"PILOT v4: MULTI-USER MULTI-SESSION (Fixed)")
    print(f"Users: {len(personas)}, Sessions: {num_sessions}, Complaints: {enable_complaints}")
    print(f"{'#'*60}")

    for persona in personas:
        user_id = f"user_{persona.persona_id}"
        prefs = persona.style_prefs

        print(f"\n{'*'*60}")
        print(f"USER: {user_id}")
        print(f"Persona: {persona.description}")
        print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
        print(f"{'*'*60}")

        # Fresh start per persona: learner state and reveal state both cleared.
        llm.reset_user(user_id)
        reveal_manager.reset_user(user_id)
        reveal_state = reveal_manager.get_state(user_id)

        for session_id in range(1, num_sessions + 1):
            queries = get_queries_for_session(persona, session_id)

            # No-op by design: reveals persist across sessions.
            reveal_manager.reset_session(user_id)
            run_session_v4(llm, user_id, persona, session_id, reveal_state, queries, all_logs, enable_complaints)

    return all_logs
+
+
+# =============================================================================
+# Summary
+# =============================================================================
+
def print_summary_v4(logs: List[TurnLog]):
    """Print an aggregate report for a pilot-v4 run.

    Sections: per-persona/per-session statistics, cross-session retention
    (session 1 vs session 2 satisfaction), violation rates by type,
    complaint effectiveness, and overall totals. Prints "No logs." and
    returns early when `logs` is empty.
    """
    print(f"\n{'='*60}")
    print("PILOT v4 SUMMARY")
    print(f"{'='*60}")

    if not logs:
        print("No logs.")
        return

    from collections import Counter

    def mean_sat(entries):
        # Average per-turn satisfaction; 0 for an empty slice.
        return sum(e.sat_t for e in entries) / len(entries) if entries else 0

    persona_ids = sorted({log.persona_id for log in logs})

    print(f"\n--- Per-Persona Statistics ---")
    for pid in persona_ids:
        per_persona = [log for log in logs if log.persona_id == pid]

        print(f"\n{pid}:")
        for sid in sorted({log.session_id for log in per_persona}):
            per_session = [log for log in per_persona if log.session_id == sid]
            session_violations = [v for log in per_session for v in log.violations]
            session_enforced = {c for log in per_session for c in log.enforced_constraints}
            num_complaints = sum(1 for log in per_session if log.is_complaint)

            print(f"  S{sid}: {len(per_session)} turns, avg_sat={mean_sat(per_session):.3f}, complaints={num_complaints}")
            print(f"    enforced={session_enforced}")
            if session_violations:
                print(f"    violations: {dict(Counter(session_violations))}")

    # Cross-session retention: did satisfaction hold once preferences were
    # no longer restated in session 2?
    print(f"\n--- Cross-Session Retention (S2 without preferences) ---")
    for pid in persona_ids:
        per_persona = [log for log in logs if log.persona_id == pid]
        session_one = [log for log in per_persona if log.session_id == 1]
        session_two = [log for log in per_persona if log.session_id == 2]

        if session_one and session_two:
            s2_enforced = {c for log in session_two for c in log.enforced_constraints}
            print(f"{pid}: S1={mean_sat(session_one):.3f} → S2={mean_sat(session_two):.3f}, enforced={s2_enforced}")

    # Violation rates over all turns, most frequent first.
    print(f"\n--- Violation Rates by Type ---")
    every_violation = [v for log in logs for v in log.violations]
    turn_count = len(logs)
    if every_violation:
        for vtype, count in Counter(every_violation).most_common():
            rate = count / turn_count * 100
            print(f"  {vtype}: {count} ({rate:.1f}%)")
    else:
        print("  No violations")

    # Satisfaction on the turns that were injected complaints.
    print(f"\n--- Complaint Effectiveness ---")
    complaint_turns = [log for log in logs if log.is_complaint]
    if complaint_turns:
        print(f"Total complaints: {len(complaint_turns)}")
        print(f"Avg sat on complaint turns: {mean_sat(complaint_turns):.3f}")
    else:
        print("No complaints generated")

    # Overall roll-up.
    print(f"\n--- Overall ---")
    print(f"Total turns: {turn_count}, Avg sat: {mean_sat(logs):.3f}, Total tokens: {sum(log.total_tokens for log in logs)}")
+
+
def main():
    """CLI entry point: run the pilot in one or more modes and summarize.

    Parses command-line arguments, seeds the RNGs for reproducibility, maps
    the requested --mode onto one or more (run_name, llm_mode, eval_mode)
    configurations, and for each configuration builds a fresh
    PersonalizedLLM, runs the multi-user pilot, prints a summary, and writes
    JSONL logs (plus persisted user vectors for the "full" llm_mode).

    Improvements over the previous version: the if/elif ladder over
    args.mode is replaced by a dispatch dict, and the trailing `else`
    fallback is removed — it was unreachable because argparse `choices`
    already rejects any mode not listed below.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Pilot Runner v4 - Full vs Vanilla Comparison")
    parser.add_argument("--mode", type=str,
                        choices=["full", "full-greedy", "full-sample", "nopersonal", "vanilla", "compare", "all"],
                        default="compare",
                        help="Mode: 'full-greedy' (personalized, deterministic), "
                             "'full-sample' (personalized, stochastic), "
                             "'nopersonal' (retrieval baseline without z_u), "
                             "'vanilla' (pure LLM, no memory), "
                             "'compare' (full-greedy vs vanilla), "
                             "'all' (run all modes)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--sessions", type=int, default=3, help="Number of sessions per user")
    parser.add_argument("--no-complaints", action="store_true", help="Disable complaint injection")
    args = parser.parse_args()

    # Set seeds for reproducibility
    import random
    import numpy as np
    random.seed(args.seed)
    np.random.seed(args.seed)

    print("=" * 60)
    print("PILOT RUNNER v4 - Full vs Vanilla Comparison")
    print("=" * 60)
    print(f"Started at: {datetime.now().isoformat()}")
    print(f"Mode: {args.mode}, Seed: {args.seed}, Sessions: {args.sessions}")

    personas = ALL_PERSONAS
    print(f"\n[Config] {len(personas)} personas:")
    for p in personas:
        print(f"  - {p.persona_id}: {p.description}")

    enable_complaints = not args.no_complaints

    # Dispatch table: mode argument -> list of run configurations.
    # Each config is (run_name, llm_mode, eval_mode):
    #   llm_mode:  "full", "nopersonal", or "vanilla"
    #   eval_mode: True = greedy/deterministic, False = stochastic sampling
    # argparse `choices` guarantees args.mode is one of these keys.
    mode_to_configs = {
        "all": [
            ("full-greedy", "full", True),
            ("full-sample", "full", False),
            ("nopersonal", "nopersonal", True),
            ("vanilla", "vanilla", True),
        ],
        # Main comparison: Full (with memory) vs Vanilla (no memory)
        "compare": [
            ("full-greedy", "full", True),
            ("vanilla", "vanilla", True),
        ],
        "full": [("full-greedy", "full", True)],  # "full" is an alias for full-greedy
        "full-greedy": [("full-greedy", "full", True)],
        "full-sample": [("full-sample", "full", False)],
        "vanilla": [("vanilla", "vanilla", True)],
        "nopersonal": [("nopersonal", "nopersonal", True)],
    }
    run_configs = mode_to_configs[args.mode]

    for run_name, llm_mode, eval_mode in run_configs:
        print(f"\n{'#'*60}")
        print(f"RUNNING: {run_name.upper()}")
        print(f"  llm_mode={llm_mode}, eval_mode={eval_mode} ({'greedy' if eval_mode else 'sample'})")
        print(f"{'#'*60}")

        # Reset seeds before each run for exact reproducibility
        random.seed(args.seed)
        np.random.seed(args.seed)

        print(f"\n[Init] Loading PersonalizedLLM...")
        llm = PersonalizedLLM(
            user_store_path=f"data/users/user_store_pilot_v4_{run_name}.npz",
            only_own_memories=True,
            enable_preference_extraction=True,
            enable_rl_updates=(llm_mode == "full"),  # RL updates only in the full personalized mode
            mode=llm_mode,
            eval_mode=eval_mode,
            device_assignment={
                "embed": "cuda:0",
                "reranker": "cuda:1",
                "chat": "cuda:2",
                "extractor": "cuda:3",
            },
        )

        logs = run_multi_user_pilot_v4(llm, personas, num_sessions=args.sessions, enable_complaints=enable_complaints)

        print_summary_v4(logs)

        log_path = f"data/logs/pilot_v4_{run_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
        log_to_jsonl(logs, log_path)
        print(f"\n[Logs] Saved to: {log_path}")

        # Save user vectors for similarity analysis
        if llm_mode == "full":
            llm.persist()
            print(f"[Persist] User vectors saved to: {llm._user_store.path}")

    print(f"\nCompleted at: {datetime.now().isoformat()}")
    print("=" * 60)
+
+
# Script entry point: run the pilot only when executed directly, not on import.
if __name__ == "__main__":
    main()
+