#!/usr/bin/env python3
"""
Pilot Runner v2 - Cross-Session Preference Reveal Mechanism

Upgrade from v1:
- RevealState: Tracks which preferences have been explicitly revealed by the user
- pref_true[k] vs pref_revealed_global[k] distinction
- Style constraints only enforced AFTER user reveals them
- Reveal state persists across sessions, resets on reset_user()

Key concepts:
- pref_true[k]: User's true preference (from StylePrefs)
- pref_revealed_global[k]: Whether preference k has been revealed at least once

Enforcement rule:
- A style constraint is enforced only when BOTH pref_true[k] AND pref_revealed_global[k]

Session semantics:
- reset_user(): Clears ALL state including reveal flags
- reset_session(): Keeps reveal flags (cross-session memory)
"""

import sys
import os
import json
from collections import Counter
from datetime import datetime
from dataclasses import dataclass, asdict, field
from typing import List, Dict, Any, Optional, Tuple, Set

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))

from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse


# =============================================================================
# Style Preferences (True Preferences)
# =============================================================================

@dataclass
class StylePrefs:
    """
    User's TRUE style preferences.

    These are the ground truth preferences that the user actually has,
    but they may not have revealed all of them to the system yet.
    """

    require_short: bool = False   # User truly wants short answers
    max_chars: int = 300          # Length cap enforced when "short" is revealed
    require_bullets: bool = False # User truly wants bulleted lists
    lang: str = "en"              # "en" or "zh"


# =============================================================================
# Reveal State (What has been explicitly revealed)
# =============================================================================

@dataclass
class RevealState:
    """
    Tracks which preferences have been explicitly revealed by the user.

    This persists across sessions for the same user but resets on reset_user().
    A preference is revealed when the user explicitly mentions it in a query.
    """

    short_revealed: bool = False    # "short", "concise", "brief", length constraints
    bullets_revealed: bool = False  # "bullet", "bullet points", "list format"
    lang_revealed: bool = False     # Language preference mentioned

    def reset(self):
        """Reset all reveal flags (called on reset_user)."""
        self.short_revealed = False
        self.bullets_revealed = False
        self.lang_revealed = False

    def to_dict(self) -> Dict[str, bool]:
        """Return a plain-dict snapshot of the reveal flags (used for logging)."""
        return {
            "short": self.short_revealed,
            "bullets": self.bullets_revealed,
            "lang": self.lang_revealed,
        }

    def __str__(self) -> str:
        flags = []
        if self.short_revealed:
            flags.append("short")
        if self.bullets_revealed:
            flags.append("bullets")
        if self.lang_revealed:
            flags.append("lang")
        return f"RevealState({', '.join(flags) if flags else 'none'})"


class RevealStateManager:
    """
    Manages reveal state for multiple users.

    Persists across sessions, resets on reset_user().
    """

    def __init__(self):
        # user_id -> RevealState; states are created lazily in get_state()
        self._states: Dict[str, RevealState] = {}

    def get_state(self, user_id: str) -> RevealState:
        """Get or create reveal state for a user."""
        if user_id not in self._states:
            self._states[user_id] = RevealState()
        return self._states[user_id]

    def reset_user(self, user_id: str):
        """Reset reveal state for a user (called on reset_user)."""
        if user_id in self._states:
            self._states[user_id].reset()
        else:
            self._states[user_id] = RevealState()

    def reset_session(self, user_id: str):
        """
        Called on reset_session - does NOT reset reveal state.

        Reveal state persists across sessions.
        """
        # Intentionally do nothing - reveal state persists
        pass


# =============================================================================
# Preference Detection from Queries
# =============================================================================

def detect_revealed_preferences(query: str) -> Dict[str, bool]:
    """
    Detect which preferences are mentioned in a query.

    Detection is a simple case-insensitive substring match against
    hand-curated pattern lists, so it can over-trigger (e.g. "short"
    inside another word) — acceptable for a pilot.

    Returns a dict with keys: "short", "bullets", "lang".
    Each value is True if that preference was mentioned.
    """
    lower_q = (query or "").lower()
    revealed = {
        "short": False,
        "bullets": False,
        "lang": False,
    }

    # Short/length preference detection
    short_patterns = [
        "short", "concise", "brief", "under ", "less than",
        "keep it short", "keep responses", "keep answers",
        "maximum ", "max ", "characters", "words or less",
        "200 ", "100 ", "50 ", "300 ",  # Common char limits
    ]
    for pattern in short_patterns:
        if pattern in lower_q:
            revealed["short"] = True
            break

    # Bullet preference detection
    bullet_patterns = [
        "bullet", "bullet point", "bullet-point", "bulleted",
        "list format", "use bullets", "use bullet", "with bullets",
        "in bullets", "- format", "• ", "numbered list",
    ]
    for pattern in bullet_patterns:
        if pattern in lower_q:
            revealed["bullets"] = True
            break

    # Language preference detection
    lang_patterns_zh = [
        "chinese", "中文", "in chinese", "用中文", "speak chinese",
        "write chinese", "respond in chinese", "please use chinese",
        "mandarin",
    ]
    lang_patterns_en = [
        "english", "in english", "use english", "speak english",
        "write english", "respond in english", "please use english",
    ]
    for pattern in lang_patterns_zh + lang_patterns_en:
        if pattern in lower_q:
            revealed["lang"] = True
            break

    return revealed


def update_reveal_state(reveal_state: RevealState, query: str) -> Set[str]:
    """
    Update reveal state based on query content.

    Mutates `reveal_state` in place: a flag, once set, is never cleared here.

    Returns set of newly revealed preferences (empty if nothing new).
    """
    detected = detect_revealed_preferences(query)
    newly_revealed = set()

    if detected["short"] and not reveal_state.short_revealed:
        reveal_state.short_revealed = True
        newly_revealed.add("short")
    if detected["bullets"] and not reveal_state.bullets_revealed:
        reveal_state.bullets_revealed = True
        newly_revealed.add("bullets")
    if detected["lang"] and not reveal_state.lang_revealed:
        reveal_state.lang_revealed = True
        newly_revealed.add("lang")

    return newly_revealed


# =============================================================================
# Style-Aware Judge with Reveal State
# =============================================================================

@dataclass
class JudgeResult:
    """Output from the judge for one turn."""

    sat_t: float                       # Satisfaction score [0, 1]
    sev_t: float                       # Severity of violations [0, 1]
    prog_t: float                      # Task progress [0, 1]
    violations: List[str]              # List of violated constraints
    enforced_constraints: List[str]    # Which constraints were actually enforced


def style_judge_with_reveal(
    query: str,
    answer: str,
    task_type: str,
    prefs: StylePrefs,
    reveal_state: RevealState,
) -> JudgeResult:
    """
    Style-aware judge that ONLY enforces revealed preferences.

    A constraint is enforced only when:
    - pref_true[k] is True (user has this preference)
    - pref_revealed_global[k] is True (user has revealed this preference)

    Args:
        query: User's query
        answer: Assistant's answer
        task_type: Type of task ("general", "list", "code")
        prefs: User's TRUE preferences (StylePrefs)
        reveal_state: Which preferences have been revealed

    Returns:
        JudgeResult with sat_t, sev_t, prog_t, violations, and enforced_constraints
    """
    violations: List[str] = []
    enforced: List[str] = []
    text = (answer or "").strip()

    # 0) Empty answer - always a violation regardless of reveal state
    if not text or len(text) < 5:
        violations.append("empty_answer")
        return JudgeResult(
            sat_t=0.0,
            sev_t=1.0,
            prog_t=0.0,
            violations=violations,
            enforced_constraints=["non_empty"],
        )

    # 1) Length preference - enforce only if BOTH true AND revealed
    if prefs.require_short and reveal_state.short_revealed:
        enforced.append("short")
        if len(text) > prefs.max_chars:
            violations.append("too_long")

    # 2) Bullet preference - enforce only if BOTH true AND revealed
    #    Also only for list-type tasks
    if prefs.require_bullets and reveal_state.bullets_revealed:
        if task_type in ("general", "list"):
            enforced.append("bullets")
            has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text)
            if not has_bullets:
                violations.append("no_bullets")

    # 3) Language preference - enforce only if BOTH true AND revealed
    #    (lang always has a true value, so reveal alone gates enforcement;
    #    language detection is a crude ASCII-ratio heuristic)
    if reveal_state.lang_revealed:
        enforced.append("lang")
        if prefs.lang == "zh":
            ascii_count = sum(c.isascii() for c in text)
            ascii_ratio = ascii_count / max(1, len(text))
            if ascii_ratio > 0.7:
                violations.append("wrong_lang")
        elif prefs.lang == "en":
            ascii_count = sum(c.isascii() for c in text)
            ascii_ratio = ascii_count / max(1, len(text))
            if ascii_ratio < 0.5:
                violations.append("wrong_lang")

    # 4) Code task: always enforce code markers (not a user preference)
    prog_t = 1.0
    if task_type == "code":
        enforced.append("code_block")
        has_code = ("```" in text) or ("def " in text) or ("function " in text)
        if not has_code:
            violations.append("no_code_block")
            prog_t = 0.0

    # 5) Compute sat_t and sev_t from violations
    if not violations:
        sat_t = 1.0
        sev_t = 0.0
    else:
        sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
        # Hard violations carry full severity; soft ones (e.g. no_bullets) do not
        hard_violations = {"empty_answer", "too_long", "wrong_lang"}
        sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0

    return JudgeResult(
        sat_t=sat_t,
        sev_t=sev_t,
        prog_t=prog_t,
        violations=violations,
        enforced_constraints=enforced,
    )


# =============================================================================
# Feedback Computation (reward + gating)
# =============================================================================

def compute_feedback_for_turn(
    turn_id: int,
    query: str,
    query_type: str,
    task_type: str,
    judge_result: JudgeResult,
) -> Tuple[float, float]:
    """
    Convert JudgeResult into (reward, gating).

    Same as v1 - reward = sat_t, gating = 1 for preference turns.
    A turn counts as a preference turn if the caller labeled it so, or if
    the query contains an explicit preference/complaint phrase.
    """
    reward = judge_result.sat_t

    lower_q = (query or "").lower()
    is_pref_turn = (
        query_type == "preference"
        or "i prefer" in lower_q
        or "my preference" in lower_q
        or "please use" in lower_q
        or "please keep" in lower_q
        or "you didn't follow" in lower_q
        or "you forgot" in lower_q
        or "remember that i" in lower_q
        or "i told you" in lower_q
        or "i asked for" in lower_q
    )
    gating = 1.0 if is_pref_turn else 0.0

    return reward, gating


# =============================================================================
# Multi-Session Queries for Pilot v2
# =============================================================================

def get_session_1_queries() -> List[Dict[str, Any]]:
    """
    Session 1: User reveals preferences and does some tasks.
    """
    return [
        {
            "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
            "type": "preference",
            "task_type": "general",
        },
        {
            "query": "What are three tips for better sleep?",
            "type": "task",
            "task_type": "list",
        },
        {
            "query": "I also prefer bullet points when listing things.",
            "type": "preference",
            "task_type": "general",
        },
        {
            "query": "What are the main benefits of exercise?",
            "type": "task",
            "task_type": "list",
        },
        {
            "query": "Name five programming languages.",
            "type": "task",
            "task_type": "list",
        },
    ]


def get_session_2_queries() -> List[Dict[str, Any]]:
    """
    Session 2: User does NOT restate preferences.

    Tests cross-session preference retention.
    """
    return [
        {
            "query": "What are three healthy breakfast ideas?",
            "type": "task",
            "task_type": "list",
        },
        {
            "query": "List four seasons of the year.",
            "type": "task",
            "task_type": "list",
        },
        {
            "query": "What is the capital of France?",
            "type": "task",
            "task_type": "general",
        },
        {
            "query": "Name three types of renewable energy.",
            "type": "task",
            "task_type": "list",
        },
    ]


def get_session_3_queries() -> List[Dict[str, Any]]:
    """
    Session 3: Mix of tasks and one complaint/reminder.
    """
    return [
        {
            "query": "What are five common fruits?",
            "type": "task",
            "task_type": "list",
        },
        {
            "query": "Remember that I asked for short bullet points. List three ocean animals.",
            "type": "preference",
            "task_type": "list",
        },
        {
            "query": "What is 2 + 2?",
            "type": "task",
            "task_type": "general",
        },
    ]


# =============================================================================
# Logging (Extended for v2)
# =============================================================================

@dataclass
class TurnLog:
    """Log entry for one turn (extended for v2)."""

    session_id: int
    turn_id: int
    query: str
    query_type: str
    task_type: str
    answer: str
    answer_length: int
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    enforced_constraints: List[str]
    reward: float
    gating: float
    # Reveal flags snapshotted before/after this turn's query was processed
    reveal_state_before: Dict[str, bool]
    reveal_state_after: Dict[str, bool]
    newly_revealed: List[str]
    # User embedding norms around the turn (for drift inspection)
    z_long_norm_before: float
    z_long_norm_after: float
    z_short_norm_before: float
    z_short_norm_after: float
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    num_memories_retrieved: int
    num_prefs_extracted: int


def log_to_jsonl(logs: List[TurnLog], filepath: str):
    """
    Save logs to a JSONL file (one JSON object per TurnLog).

    Creates the parent directory if needed. Writes UTF-8 explicitly so
    non-ASCII answers (e.g. Chinese) do not crash on platforms whose
    default encoding is not UTF-8.
    """
    # os.path.dirname returns "" for a bare filename; os.makedirs("") raises,
    # so only create the directory when there is one.
    dirpath = os.path.dirname(filepath)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        for log in logs:
            f.write(json.dumps(asdict(log)) + "\n")


# =============================================================================
# Pilot Runner v2 (Multi-Session with Reveal State)
# =============================================================================

def run_session(
    llm: PersonalizedLLM,
    user_id: str,
    session_id: int,
    prefs: StylePrefs,
    reveal_state: RevealState,
    queries: List[Dict[str, Any]],
) -> List[TurnLog]:
    """
    Run a single session with reveal-aware judging.

    Feedback for a turn is applied lazily at the START of the next turn
    (and once more after the loop for the session's last turn), so the
    judge can score the answer before the model state is updated.

    Returns the list of TurnLog entries produced by this session.
    """
    logs: List[TurnLog] = []

    print(f"\n{'='*60}")
    print(f"SESSION {session_id}: user_id={user_id}, turns={len(queries)}")
    print(f"Reveal state (start): {reveal_state}")
    print(f"{'='*60}")

    # Reset session (clears history, z_short; keeps z_long and reveal state)
    llm.reset_session(user_id)
    state_before = llm.get_user_state_summary(user_id)
    print(f"[Session] z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}")

    for turn_id, q_info in enumerate(queries):
        query = q_info["query"]
        query_type = q_info.get("type", "task")
        task_type = q_info.get("task_type", "general")

        print(f"\n{'─'*60}")
        print(f"Session {session_id} / Turn {turn_id} [{query_type}]")
        print(f"{'─'*60}")
        print(f"[Query] {query}")

        # Capture reveal state BEFORE this turn
        reveal_before = reveal_state.to_dict()

        # Update reveal state based on query content
        newly_revealed = update_reveal_state(reveal_state, query)
        if newly_revealed:
            print(f"[Reveal] Newly revealed: {newly_revealed}")
            print(f"[Reveal] State: {reveal_state}")

        # Capture reveal state AFTER update
        reveal_after = reveal_state.to_dict()

        # Get user state before
        state_before = llm.get_user_state_summary(user_id)
        z_long_before = state_before["z_long_norm"]
        z_short_before = state_before["z_short_norm"]

        # Apply feedback for previous turn (from turn 1 onwards in this session)
        if turn_id > 0 and len(logs) > 0:
            # Find the last log from THIS session
            session_logs = [entry for entry in logs if entry.session_id == session_id]
            if session_logs:
                prev_log = session_logs[-1]
                feedback = Feedback(
                    user_id=user_id,
                    turn_id=prev_log.turn_id,
                    reward=prev_log.reward,
                    gating=prev_log.gating,
                    meta={
                        "sat_t": prev_log.sat_t,
                        "violations": prev_log.violations,
                        "source": "pilot_v2",
                        "session_id": session_id,
                    }
                )
                print(f"[Feedback] turn={prev_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
                llm.apply_feedback(feedback)

        # Chat
        resp: AssistantResponse = llm.chat(user_id, query)
        answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer
        print(f"[Answer] ({len(resp.answer)} chars) {answer_display}")
        print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")

        # Judge with reveal-aware logic
        judge_result = style_judge_with_reveal(query, resp.answer, task_type, prefs, reveal_state)
        print(f"[Judge] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}")
        if judge_result.violations:
            print(f"[Judge] violations={judge_result.violations}")

        # Compute feedback
        reward, gating = compute_feedback_for_turn(
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            task_type=task_type,
            judge_result=judge_result,
        )
        print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f}")

        # Get state after
        state_after = llm.get_user_state_summary(user_id)
        z_long_after = state_after["z_long_norm"]
        z_short_after = state_after["z_short_norm"]
        z_long_delta = z_long_after - z_long_before
        z_short_delta = z_short_after - z_short_before
        print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})")
        print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})")

        # Debug info
        num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
        num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
        print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")

        # Log
        log = TurnLog(
            session_id=session_id,
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            task_type=task_type,
            answer=resp.answer,
            answer_length=len(resp.answer),
            sat_t=judge_result.sat_t,
            sev_t=judge_result.sev_t,
            prog_t=judge_result.prog_t,
            violations=judge_result.violations,
            enforced_constraints=judge_result.enforced_constraints,
            reward=reward,
            gating=gating,
            reveal_state_before=reveal_before,
            reveal_state_after=reveal_after,
            newly_revealed=list(newly_revealed),
            z_long_norm_before=z_long_before,
            z_long_norm_after=z_long_after,
            z_short_norm_before=z_short_before,
            z_short_norm_after=z_short_after,
            prompt_tokens=resp.usage.prompt_tokens,
            completion_tokens=resp.usage.completion_tokens,
            total_tokens=resp.usage.total_tokens,
            num_memories_retrieved=num_memories,
            num_prefs_extracted=num_prefs,
        )
        logs.append(log)

    # Apply final feedback for this session
    session_logs = [entry for entry in logs if entry.session_id == session_id]
    if session_logs:
        last_log = session_logs[-1]
        feedback = Feedback(
            user_id=user_id,
            turn_id=last_log.turn_id,
            reward=last_log.reward,
            gating=last_log.gating,
            meta={"source": "pilot_v2", "session_id": session_id, "final": True}
        )
        print(f"\n[Final Feedback] turn={last_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
        llm.apply_feedback(feedback)

    print(f"\n[Session {session_id} End] Reveal state: {reveal_state}")
    return logs


def run_pilot_v2(
    llm: PersonalizedLLM,
    user_id: str = "pilot_user_v2",
    prefs: Optional[StylePrefs] = None,
) -> List[TurnLog]:
    """
    Run multi-session pilot with reveal state tracking.

    Session 1: User reveals preferences
    Session 2: User does NOT restate preferences (tests cross-session retention)
    Session 3: Mix of tasks and reminders

    Args:
        llm: Initialized PersonalizedLLM serving instance.
        user_id: User identity; fully reset (including reveal flags) at start.
        prefs: Ground-truth StylePrefs; defaults to short+bullets, English.

    Returns:
        Concatenated TurnLog entries from all three sessions.
    """
    if prefs is None:
        prefs = StylePrefs(
            require_short=True,
            max_chars=200,
            require_bullets=True,
            lang="en",
        )

    # Initialize reveal state manager
    reveal_manager = RevealStateManager()

    print(f"\n{'#'*60}")
    print(f"PILOT v2: CROSS-SESSION PREFERENCE REVEAL TEST")
    print(f"User: {user_id}")
    print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
    print(f"{'#'*60}")

    # Reset user completely (clears all state including reveal)
    print(f"\n[Pilot] Resetting user: {user_id}")
    llm.reset_user(user_id)
    reveal_manager.reset_user(user_id)

    all_logs: List[TurnLog] = []
    reveal_state = reveal_manager.get_state(user_id)

    # Session 1: Reveal preferences
    session_1_queries = get_session_1_queries()
    logs_s1 = run_session(llm, user_id, 1, prefs, reveal_state, session_1_queries)
    all_logs.extend(logs_s1)

    # Session 2: NO preference restatement (test cross-session retention)
    # Note: reveal_state persists, but reset_session clears history
    reveal_manager.reset_session(user_id)  # Does nothing to reveal state
    session_2_queries = get_session_2_queries()
    logs_s2 = run_session(llm, user_id, 2, prefs, reveal_state, session_2_queries)
    all_logs.extend(logs_s2)

    # Session 3: Reminder and more tasks
    reveal_manager.reset_session(user_id)
    session_3_queries = get_session_3_queries()
    logs_s3 = run_session(llm, user_id, 3, prefs, reveal_state, session_3_queries)
    all_logs.extend(logs_s3)

    return all_logs


def print_summary_v2(logs: List[TurnLog], prefs: StylePrefs):
    """
    Print summary for pilot v2.

    Covers per-session stats, overall stats, a violations breakdown,
    constraint-enforcement counts, a cross-session reveal-retention check,
    and a turn-by-turn table.
    """
    print(f"\n{'='*60}")
    print("PILOT v2 SUMMARY - Cross-Session Reveal")
    print(f"{'='*60}")

    if not logs:
        print("No logs to summarize.")
        return

    # Per-session stats
    sessions = sorted(set(entry.session_id for entry in logs))
    print(f"\n--- Per-Session Statistics ---")
    for sid in sessions:
        session_logs = [entry for entry in logs if entry.session_id == sid]
        avg_sat = sum(entry.sat_t for entry in session_logs) / len(session_logs)
        violations = [v for entry in session_logs for v in entry.violations]
        # What was revealed at session end
        if session_logs:
            final_reveal = session_logs[-1].reveal_state_after
        else:
            final_reveal = {}
        print(f"\nSession {sid}: {len(session_logs)} turns")
        print(f"  Avg sat_t: {avg_sat:.3f}")
        print(f"  Violations: {len(violations)} ({violations if violations else 'none'})")
        print(f"  Reveal state at end: {final_reveal}")

    # Overall stats
    total = len(logs)
    avg_sat = sum(entry.sat_t for entry in logs) / total
    total_tokens = sum(entry.total_tokens for entry in logs)
    print(f"\n--- Overall Statistics ---")
    print(f"Total turns: {total}")
    print(f"Overall avg sat_t: {avg_sat:.3f}")
    print(f"Total tokens: {total_tokens}")

    # Violations by type
    print(f"\n--- Violations Breakdown ---")
    all_violations = [v for entry in logs for v in entry.violations]
    if all_violations:
        for v, count in Counter(all_violations).most_common():
            print(f"  {v}: {count}")
    else:
        print("  No violations")

    # Enforcement tracking
    print(f"\n--- Constraint Enforcement ---")
    for constraint in ["short", "bullets", "lang"]:
        enforced_count = sum(1 for entry in logs if constraint in entry.enforced_constraints)
        print(f"  {constraint}: enforced in {enforced_count}/{total} turns")

    # Cross-session reveal verification
    print(f"\n--- Cross-Session Reveal Verification ---")
    # Session 1: Should have some reveals
    s1_logs = [entry for entry in logs if entry.session_id == 1]
    s1_reveals = set()
    for entry in s1_logs:
        s1_reveals.update(entry.newly_revealed)
    print(f"Session 1 revealed: {s1_reveals if s1_reveals else 'none'}")

    # Session 2: Should NOT have new reveals (no preference queries)
    s2_logs = [entry for entry in logs if entry.session_id == 2]
    s2_reveals = set()
    for entry in s2_logs:
        s2_reveals.update(entry.newly_revealed)
    print(f"Session 2 revealed: {s2_reveals if s2_reveals else 'none (expected)'}")

    # But Session 2 should still ENFORCE the constraints revealed in Session 1
    if s2_logs:
        s2_enforced = set()
        for entry in s2_logs:
            s2_enforced.update(entry.enforced_constraints)
        print(f"Session 2 enforced: {s2_enforced}")
        if s1_reveals and s1_reveals.issubset(s2_enforced):
            print("✓ Cross-session retention VERIFIED: Session 1 reveals enforced in Session 2")
        else:
            print("✗ Cross-session retention issue: some reveals not enforced")

    # Turn-by-turn table
    print(f"\n--- Turn-by-Turn Summary ---")
    print(f"{'S':>2} {'T':>2} {'Type':>10} {'Len':>5} {'sat':>5} {'enforced':<20} {'violations'}")
    print("-" * 70)
    for entry in logs:
        enforced_str = ",".join(entry.enforced_constraints) if entry.enforced_constraints else "-"
        viol_str = ",".join(entry.violations) if entry.violations else "-"
        print(f"{entry.session_id:>2} {entry.turn_id:>2} {entry.query_type:>10} {entry.answer_length:>5} {entry.sat_t:>5.2f} {enforced_str:<20} {viol_str}")


def main():
    """Entry point: configure preferences, run the 3-session pilot, save logs."""
    print("=" * 60)
    print("PILOT RUNNER v2 - Cross-Session Preference Reveal")
    print("=" * 60)
    print(f"Started at: {datetime.now().isoformat()}")

    # Define user's TRUE preferences
    prefs = StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=True,
        lang="en",
    )
    print(f"\n[Config] True preferences: {prefs}")
    print("[Config] Note: Constraints only enforced AFTER user reveals them")

    # Initialize LLM
    print("\n[Init] Loading PersonalizedLLM...")
    llm = PersonalizedLLM(
        user_store_path="data/users/user_store_pilot_v2.npz",
        only_own_memories=True,
        enable_preference_extraction=True,
        enable_rl_updates=True,
    )

    # Run pilot
    user_id = "pilot_user_v2"
    logs = run_pilot_v2(llm, user_id=user_id, prefs=prefs)

    # Summary
    print_summary_v2(logs, prefs)

    # Save logs
    log_path = f"data/logs/pilot_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
    log_to_jsonl(logs, log_path)
    print(f"\n[Logs] Saved to: {log_path}")

    # Final state
    final_state = llm.get_user_state_summary(user_id)
    print(f"\n[Final State] {final_state}")

    print(f"\nCompleted at: {datetime.now().isoformat()}")
    print("=" * 60)


if __name__ == "__main__":
    main()