summaryrefslogtreecommitdiff
path: root/scripts/pilot_runner_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/pilot_runner_v2.py')
-rw-r--r--scripts/pilot_runner_v2.py852
1 files changed, 852 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v2.py b/scripts/pilot_runner_v2.py
new file mode 100644
index 0000000..d3c2aa8
--- /dev/null
+++ b/scripts/pilot_runner_v2.py
@@ -0,0 +1,852 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v2 - Cross-Session Preference Reveal Mechanism
+
+Upgrade from v1:
+- RevealState: Tracks which preferences have been explicitly revealed by the user
+- pref_true[k] vs pref_revealed_global[k] distinction
+- Style constraints only enforced AFTER user reveals them
+- Reveal state persists across sessions, resets on reset_user()
+
+Key concepts:
+- pref_true[k]: User's true preference (from StylePrefs)
+- pref_revealed_global[k]: Whether preference k has been revealed at least once
+
+Enforcement rule:
+- A style constraint is enforced only when BOTH pref_true[k] AND pref_revealed_global[k]
+
+Session semantics:
+- reset_user(): Clears ALL state including reveal flags
+- reset_session(): Keeps reveal flags (cross-session memory)
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict, field
+from typing import List, Dict, Any, Optional, Tuple, Set
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences (True Preferences)
+# =============================================================================
+
@dataclass
class StylePrefs:
    """
    User's TRUE style preferences.
    These are the ground truth preferences that the user actually has,
    but they may not have revealed all of them to the system yet.
    """
    require_short: bool = False  # user wants answers capped at max_chars
    max_chars: int = 300  # length cap; only checked when require_short is enforced
    require_bullets: bool = False  # user wants bullet formatting on list-type tasks
    lang: str = "en" # "en" or "zh"
+
+
+# =============================================================================
+# Reveal State (What has been explicitly revealed)
+# =============================================================================
+
@dataclass
class RevealState:
    """
    Per-user record of which style preferences have been explicitly
    mentioned by the user in a query.

    The record survives session resets and is cleared only by reset_user().
    A preference counts as revealed once the user mentions it explicitly.
    """
    short_revealed: bool = False    # "short", "concise", "brief", length constraints
    bullets_revealed: bool = False  # "bullet", "bullet points", "list format"
    lang_revealed: bool = False     # Language preference mentioned

    def reset(self):
        """Clear every reveal flag (invoked on reset_user)."""
        self.short_revealed = False
        self.bullets_revealed = False
        self.lang_revealed = False

    def to_dict(self) -> Dict[str, bool]:
        """Snapshot the flags keyed by their short names (insertion order fixed)."""
        return {
            "short": self.short_revealed,
            "bullets": self.bullets_revealed,
            "lang": self.lang_revealed,
        }

    def __str__(self) -> str:
        # Render only the flags that are set, in the to_dict() order.
        active = [name for name, on in self.to_dict().items() if on]
        return f"RevealState({', '.join(active) if active else 'none'})"
+
+
class RevealStateManager:
    """
    Registry of RevealState objects keyed by user id.

    Reveal state survives reset_session() and is wiped only by reset_user().
    """

    def __init__(self):
        self._states: Dict[str, RevealState] = {}

    def get_state(self, user_id: str) -> RevealState:
        """Return the user's reveal state, creating a fresh one on first access."""
        state = self._states.get(user_id)
        if state is None:
            state = RevealState()
            self._states[user_id] = state
        return state

    def reset_user(self, user_id: str):
        """Clear the user's reveal flags (mirrors a full user reset)."""
        existing = self._states.get(user_id)
        if existing is None:
            self._states[user_id] = RevealState()
        else:
            # Reset in place so callers holding a reference see the change.
            existing.reset()

    def reset_session(self, user_id: str):
        """
        Session-reset hook that deliberately does NOT touch reveal state.
        Cross-session persistence of reveals is the point of this class.
        """
        pass
+
+
+# =============================================================================
+# Preference Detection from Queries
+# =============================================================================
+
def detect_revealed_preferences(query: str) -> Dict[str, bool]:
    """
    Scan a query for explicit mentions of style preferences.

    Returns a dict with keys "short", "bullets", and "lang"; a value is
    True when the corresponding preference was mentioned in the query.
    Matching is case-insensitive substring search; None is treated as "".
    """
    text = (query or "").lower()

    # Substrings signalling a length/conciseness preference.
    short_markers = (
        "short", "concise", "brief", "under ", "less than",
        "keep it short", "keep responses", "keep answers",
        "maximum ", "max ", "characters", "words or less",
        "200 ", "100 ", "50 ", "300 ",  # Common char limits
    )

    # Substrings signalling a bullet/list-format preference.
    bullet_markers = (
        "bullet", "bullet point", "bullet-point",
        "bulleted", "list format", "use bullets",
        "use bullet", "with bullets", "in bullets",
        "- format", "• ", "numbered list",
    )

    # Substrings signalling a language preference (Chinese or English).
    lang_markers = (
        "chinese", "中文", "in chinese", "用中文",
        "speak chinese", "write chinese", "respond in chinese",
        "please use chinese", "mandarin",
        "english", "in english", "use english",
        "speak english", "write english", "respond in english",
        "please use english",
    )

    return {
        "short": any(marker in text for marker in short_markers),
        "bullets": any(marker in text for marker in bullet_markers),
        "lang": any(marker in text for marker in lang_markers),
    }
+
+
def update_reveal_state(reveal_state: RevealState, query: str) -> Set[str]:
    """
    Fold the preferences mentioned in *query* into *reveal_state* (in place).

    Returns the set of preference keys that flipped from hidden to revealed
    on this call; preferences that were already revealed are not reported.
    """
    detected = detect_revealed_preferences(query)
    newly_revealed: Set[str] = set()

    # Map each detection key to its flag attribute so all three checks
    # share a single code path.
    flag_attrs = {
        "short": "short_revealed",
        "bullets": "bullets_revealed",
        "lang": "lang_revealed",
    }
    for key, attr in flag_attrs.items():
        if detected[key] and not getattr(reveal_state, attr):
            setattr(reveal_state, attr, True)
            newly_revealed.add(key)

    return newly_revealed
+
+
+# =============================================================================
+# Style-Aware Judge with Reveal State
+# =============================================================================
+
@dataclass
class JudgeResult:
    """Output from the judge for one turn."""
    sat_t: float  # Satisfaction score [0, 1]
    sev_t: float  # Severity of violations [0, 1]
    prog_t: float  # Task progress [0, 1]
    violations: List[str]  # List of violated constraints
    enforced_constraints: List[str]  # Which constraints were actually enforced


def style_judge_with_reveal(
    query: str,
    answer: str,
    task_type: str,
    prefs: StylePrefs,
    reveal_state: RevealState,
) -> JudgeResult:
    """
    Style-aware judge that ONLY enforces revealed preferences.

    A constraint is enforced only when:
    - pref_true[k] is True (user has this preference)
    - pref_revealed_global[k] is True (user has revealed this preference)

    Args:
        query: User's query (not inspected by the checks themselves)
        answer: Assistant's answer
        task_type: Type of task ("general", "list", "code")
        prefs: User's TRUE preferences (StylePrefs)
        reveal_state: Which preferences have been revealed

    Returns:
        JudgeResult with sat_t, sev_t, prog_t, violations, and enforced_constraints
    """
    violations: List[str] = []
    enforced: List[str] = []
    text = (answer or "").strip()

    # 0) Empty answer - always a violation regardless of reveal state
    if not text or len(text) < 5:
        violations.append("empty_answer")
        return JudgeResult(
            sat_t=0.0,
            sev_t=1.0,
            prog_t=0.0,
            violations=violations,
            enforced_constraints=["non_empty"],
        )

    # 1) Length preference - enforce only if BOTH true AND revealed
    if prefs.require_short and reveal_state.short_revealed:
        enforced.append("short")
        if len(text) > prefs.max_chars:
            violations.append("too_long")

    # 2) Bullet preference - enforce only if BOTH true AND revealed
    # Also only for list-type tasks
    if prefs.require_bullets and reveal_state.bullets_revealed:
        if task_type in ("general", "list"):
            enforced.append("bullets")
            # "- " already matches "\n- ", so the former separate check
            # for "\n- " was redundant and has been dropped.
            has_bullets = any(marker in text for marker in ("- ", "• ", "* "))
            if not has_bullets:
                violations.append("no_bullets")

    # 3) Language preference - enforce only if revealed
    # (prefs.lang always holds a value, so "true AND revealed" reduces to
    # the revealed flag here)
    if reveal_state.lang_revealed:
        enforced.append("lang")
        # Hoisted: both language branches need the same ASCII-character ratio.
        ascii_ratio = sum(c.isascii() for c in text) / max(1, len(text))
        if prefs.lang == "zh" and ascii_ratio > 0.7:
            violations.append("wrong_lang")
        elif prefs.lang == "en" and ascii_ratio < 0.5:
            violations.append("wrong_lang")

    # 4) Code task: always enforce code markers (not a user preference)
    prog_t = 1.0
    if task_type == "code":
        enforced.append("code_block")
        has_code = ("```" in text) or ("def " in text) or ("function " in text)
        if not has_code:
            violations.append("no_code_block")
            prog_t = 0.0

    # 5) Compute sat_t and sev_t from violations:
    # each violation costs 0.3 sat; any "hard" violation maxes out severity.
    if not violations:
        sat_t = 1.0
        sev_t = 0.0
    else:
        sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
        hard_violations = {"empty_answer", "too_long", "wrong_lang"}
        sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0

    return JudgeResult(
        sat_t=sat_t,
        sev_t=sev_t,
        prog_t=prog_t,
        violations=violations,
        enforced_constraints=enforced,
    )
+
+
+# =============================================================================
+# Feedback Computation (reward + gating)
+# =============================================================================
+
def compute_feedback_for_turn(
    turn_id: int,
    query: str,
    query_type: str,
    task_type: str,
    judge_result: JudgeResult,
) -> Tuple[float, float]:
    """
    Convert JudgeResult into (reward, gating).
    Same as v1 - reward = sat_t, gating = 1 for preference turns.
    """
    lower_q = (query or "").lower()

    # Phrases that mark a turn as preference-related even when the caller
    # did not label it "preference".
    preference_phrases = (
        "i prefer",
        "my preference",
        "please use",
        "please keep",
        "you didn't follow",
        "you forgot",
        "remember that i",
        "i told you",
        "i asked for",
    )
    is_pref_turn = query_type == "preference" or any(
        phrase in lower_q for phrase in preference_phrases
    )

    return judge_result.sat_t, (1.0 if is_pref_turn else 0.0)
+
+
+# =============================================================================
+# Multi-Session Queries for Pilot v2
+# =============================================================================
+
def get_session_1_queries() -> List[Dict[str, Any]]:
    """
    Session 1: User reveals preferences and does some tasks.
    """
    # (query, type, task_type) triples, expanded into dicts below.
    specs = [
        ("I prefer short, concise answers. Please keep responses under 200 characters.",
         "preference", "general"),
        ("What are three tips for better sleep?", "task", "list"),
        ("I also prefer bullet points when listing things.", "preference", "general"),
        ("What are the main benefits of exercise?", "task", "list"),
        ("Name five programming languages.", "task", "list"),
    ]
    return [
        {"query": query, "type": query_type, "task_type": task_type}
        for query, query_type, task_type in specs
    ]
+
+
def get_session_2_queries() -> List[Dict[str, Any]]:
    """
    Session 2: User does NOT restate preferences.
    Tests cross-session preference retention.
    """
    # Every session-2 turn is a plain task; the (query, task_type) pairs
    # are expanded into full dicts below.
    specs = [
        ("What are three healthy breakfast ideas?", "list"),
        ("List four seasons of the year.", "list"),
        ("What is the capital of France?", "general"),
        ("Name three types of renewable energy.", "list"),
    ]
    return [
        {"query": query, "type": "task", "task_type": task_type}
        for query, task_type in specs
    ]
+
+
def get_session_3_queries() -> List[Dict[str, Any]]:
    """
    Session 3: Mix of tasks and one complaint/reminder.
    """
    # (query, type, task_type) triples, expanded into dicts below.
    specs = [
        ("What are five common fruits?", "task", "list"),
        ("Remember that I asked for short bullet points. List three ocean animals.",
         "preference", "list"),
        ("What is 2 + 2?", "task", "general"),
    ]
    return [
        {"query": query, "type": query_type, "task_type": task_type}
        for query, query_type, task_type in specs
    ]
+
+
+# =============================================================================
+# Logging (Extended for v2)
+# =============================================================================
+
@dataclass
class TurnLog:
    """Log entry for one turn (extended for v2)."""
    session_id: int  # 1-based session index
    turn_id: int  # 0-based turn index within the session
    query: str  # raw user query text
    query_type: str  # "preference" or "task" (from the query spec)
    task_type: str  # "general", "list", or "code"
    answer: str  # assistant's full answer text
    answer_length: int  # len(answer) in characters
    sat_t: float  # judge satisfaction score [0, 1]
    sev_t: float  # judge violation severity [0, 1]
    prog_t: float  # judge task-progress score [0, 1]
    violations: List[str]  # constraint names violated this turn
    enforced_constraints: List[str]  # constraints actually enforced this turn
    reward: float  # feedback reward (equals sat_t)
    gating: float  # 1.0 on preference turns, else 0.0
    reveal_state_before: Dict[str, bool]  # reveal flags before this query was processed
    reveal_state_after: Dict[str, bool]  # reveal flags after processing this query
    newly_revealed: List[str]  # preference keys first revealed by this query
    z_long_norm_before: float  # z_long norm before the turn (presumably long-horizon user state - confirm in serving module)
    z_long_norm_after: float  # z_long norm after the turn
    z_short_norm_before: float  # z_short norm before the turn (presumably session-scoped state - confirm in serving module)
    z_short_norm_after: float  # z_short norm after the turn
    prompt_tokens: int  # LLM prompt tokens used this turn
    completion_tokens: int  # LLM completion tokens used this turn
    total_tokens: int  # total LLM tokens used this turn
    num_memories_retrieved: int  # memories selected for the prompt (from debug info)
    num_prefs_extracted: int  # preferences extracted this turn (from debug info)
+
+
def log_to_jsonl(logs: List[TurnLog], filepath: str):
    """
    Save logs to a JSONL file (one JSON object per line).

    Creates the parent directory if needed. A bare filename with no
    directory component is written to the current working directory
    (the previous `os.makedirs(os.path.dirname(filepath), ...)` raised
    FileNotFoundError on the empty dirname in that case).
    """
    parent = os.path.dirname(filepath)
    if parent:  # os.makedirs("") raises, so guard the no-directory case
        os.makedirs(parent, exist_ok=True)
    # Explicit utf-8 so output does not depend on the platform's locale.
    with open(filepath, "w", encoding="utf-8") as f:
        for log in logs:
            f.write(json.dumps(asdict(log)) + "\n")
+
+
+# =============================================================================
+# Pilot Runner v2 (Multi-Session with Reveal State)
+# =============================================================================
+
def run_session(
    llm: PersonalizedLLM,
    user_id: str,
    session_id: int,
    prefs: StylePrefs,
    reveal_state: RevealState,
    queries: List[Dict[str, Any]],
) -> List[TurnLog]:
    """
    Run a single session with reveal-aware judging.

    Feedback is applied one turn late: each turn's (reward, gating) is sent
    to the LLM at the start of the NEXT turn, and the final turn's feedback
    is flushed after the loop.

    Args:
        llm: PersonalizedLLM under test.
        user_id: User whose state is exercised.
        session_id: 1-based session index, recorded in each TurnLog.
        prefs: User's TRUE style preferences.
        reveal_state: Mutable reveal flags, updated in place per query.
        queries: Dicts with "query" and optional "type"/"task_type" keys.

    Returns:
        One TurnLog per query, in turn order.
    """
    logs: List[TurnLog] = []

    print(f"\n{'='*60}")
    print(f"SESSION {session_id}: user_id={user_id}, turns={len(queries)}")
    print(f"Reveal state (start): {reveal_state}")
    print(f"{'='*60}")

    # Reset session (clears history, z_short; keeps z_long and reveal state)
    llm.reset_session(user_id)

    state_before = llm.get_user_state_summary(user_id)
    print(f"[Session] z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}")

    for turn_id, q_info in enumerate(queries):
        query = q_info["query"]
        query_type = q_info.get("type", "task")
        task_type = q_info.get("task_type", "general")

        print(f"\n{'─'*60}")
        print(f"Session {session_id} / Turn {turn_id} [{query_type}]")
        print(f"{'─'*60}")
        print(f"[Query] {query}")

        # Capture reveal state BEFORE this turn
        reveal_before = reveal_state.to_dict()

        # Update reveal state based on query content
        newly_revealed = update_reveal_state(reveal_state, query)
        if newly_revealed:
            print(f"[Reveal] Newly revealed: {newly_revealed}")
            print(f"[Reveal] State: {reveal_state}")

        # Capture reveal state AFTER update
        reveal_after = reveal_state.to_dict()

        # Get user state before
        # NOTE: captured before the deferred feedback below is applied, so
        # the printed deltas span both the feedback update and the chat turn.
        state_before = llm.get_user_state_summary(user_id)
        z_long_before = state_before["z_long_norm"]
        z_short_before = state_before["z_short_norm"]

        # Apply feedback for previous turn (from turn 1 onwards in this session)
        if turn_id > 0 and len(logs) > 0:
            # Find the last log from THIS session
            session_logs = [l for l in logs if l.session_id == session_id]
            if session_logs:
                prev_log = session_logs[-1]
                feedback = Feedback(
                    user_id=user_id,
                    turn_id=prev_log.turn_id,
                    reward=prev_log.reward,
                    gating=prev_log.gating,
                    meta={
                        "sat_t": prev_log.sat_t,
                        "violations": prev_log.violations,
                        "source": "pilot_v2",
                        "session_id": session_id,
                    }
                )
                print(f"[Feedback] turn={prev_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
                llm.apply_feedback(feedback)

        # Chat
        resp: AssistantResponse = llm.chat(user_id, query)

        # Truncate long answers for console display only; logs keep the full text.
        answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer
        print(f"[Answer] ({len(resp.answer)} chars) {answer_display}")
        print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")

        # Judge with reveal-aware logic
        judge_result = style_judge_with_reveal(query, resp.answer, task_type, prefs, reveal_state)
        print(f"[Judge] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}")
        if judge_result.violations:
            print(f"[Judge] violations={judge_result.violations}")

        # Compute feedback (applied at the start of the NEXT turn, or after the loop)
        reward, gating = compute_feedback_for_turn(
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            task_type=task_type,
            judge_result=judge_result,
        )
        print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f}")

        # Get state after
        state_after = llm.get_user_state_summary(user_id)
        z_long_after = state_after["z_long_norm"]
        z_short_after = state_after["z_short_norm"]

        z_long_delta = z_long_after - z_long_before
        z_short_delta = z_short_after - z_short_before
        print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})")
        print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})")

        # Debug info (resp.debug may be absent; fall back to zero counts)
        num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
        num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
        print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")

        # Log
        log = TurnLog(
            session_id=session_id,
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            task_type=task_type,
            answer=resp.answer,
            answer_length=len(resp.answer),
            sat_t=judge_result.sat_t,
            sev_t=judge_result.sev_t,
            prog_t=judge_result.prog_t,
            violations=judge_result.violations,
            enforced_constraints=judge_result.enforced_constraints,
            reward=reward,
            gating=gating,
            reveal_state_before=reveal_before,
            reveal_state_after=reveal_after,
            newly_revealed=list(newly_revealed),
            z_long_norm_before=z_long_before,
            z_long_norm_after=z_long_after,
            z_short_norm_before=z_short_before,
            z_short_norm_after=z_short_after,
            prompt_tokens=resp.usage.prompt_tokens,
            completion_tokens=resp.usage.completion_tokens,
            total_tokens=resp.usage.total_tokens,
            num_memories_retrieved=num_memories,
            num_prefs_extracted=num_prefs,
        )
        logs.append(log)

    # Apply final feedback for this session (the loop applies one turn late,
    # so the last turn's feedback has not been sent yet)
    session_logs = [l for l in logs if l.session_id == session_id]
    if session_logs:
        last_log = session_logs[-1]
        feedback = Feedback(
            user_id=user_id,
            turn_id=last_log.turn_id,
            reward=last_log.reward,
            gating=last_log.gating,
            meta={"source": "pilot_v2", "session_id": session_id, "final": True}
        )
        print(f"\n[Final Feedback] turn={last_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
        llm.apply_feedback(feedback)

    print(f"\n[Session {session_id} End] Reveal state: {reveal_state}")

    return logs
+
+
def run_pilot_v2(
    llm: PersonalizedLLM,
    user_id: str = "pilot_user_v2",
    prefs: Optional[StylePrefs] = None,
) -> List[TurnLog]:
    """
    Run multi-session pilot with reveal state tracking.

    Session 1: User reveals preferences
    Session 2: User does NOT restate preferences (tests cross-session retention)
    Session 3: Mix of tasks and reminders
    """
    if prefs is None:
        prefs = StylePrefs(
            require_short=True,
            max_chars=200,
            require_bullets=True,
            lang="en",
        )

    # Reveal state lives outside the LLM so the pilot can judge against it.
    reveal_manager = RevealStateManager()

    print(f"\n{'#'*60}")
    print(f"PILOT v2: CROSS-SESSION PREFERENCE REVEAL TEST")
    print(f"User: {user_id}")
    print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
    print(f"{'#'*60}")

    # Full reset: wipe LLM state and reveal flags before the pilot starts.
    print(f"\n[Pilot] Resetting user: {user_id}")
    llm.reset_user(user_id)
    reveal_manager.reset_user(user_id)

    reveal_state = reveal_manager.get_state(user_id)
    all_logs: List[TurnLog] = []

    # Session 1 reveals preferences; sessions 2 and 3 test whether the
    # reveal state carries over (reset_session never touches it).
    session_plan = [
        (1, get_session_1_queries),
        (2, get_session_2_queries),
        (3, get_session_3_queries),
    ]
    for session_id, query_builder in session_plan:
        if session_id > 1:
            # No-op for reveal state - documents the cross-session semantics.
            reveal_manager.reset_session(user_id)
        session_logs = run_session(
            llm, user_id, session_id, prefs, reveal_state, query_builder()
        )
        all_logs.extend(session_logs)

    return all_logs
+
+
def print_summary_v2(logs: List[TurnLog], prefs: StylePrefs):
    """Print summary for pilot v2.

    Reports per-session and overall sat_t averages, a violations breakdown,
    per-constraint enforcement counts, and a cross-session check that the
    constraints revealed in Session 1 were still enforced in Session 2.

    NOTE(review): `prefs` is currently unused in the body; kept for caller
    compatibility.
    """
    print(f"\n{'='*60}")
    print("PILOT v2 SUMMARY - Cross-Session Reveal")
    print(f"{'='*60}")

    if not logs:
        print("No logs to summarize.")
        return

    # Per-session stats
    sessions = sorted(set(l.session_id for l in logs))

    print(f"\n--- Per-Session Statistics ---")
    for sid in sessions:
        session_logs = [l for l in logs if l.session_id == sid]
        avg_sat = sum(l.sat_t for l in session_logs) / len(session_logs)
        violations = [v for l in session_logs for v in l.violations]

        # What was revealed at session end
        if session_logs:
            final_reveal = session_logs[-1].reveal_state_after
        else:
            final_reveal = {}

        print(f"\nSession {sid}: {len(session_logs)} turns")
        print(f"  Avg sat_t: {avg_sat:.3f}")
        print(f"  Violations: {len(violations)} ({violations if violations else 'none'})")
        print(f"  Reveal state at end: {final_reveal}")

    # Overall stats
    total = len(logs)
    avg_sat = sum(l.sat_t for l in logs) / total
    total_tokens = sum(l.total_tokens for l in logs)

    print(f"\n--- Overall Statistics ---")
    print(f"Total turns: {total}")
    print(f"Overall avg sat_t: {avg_sat:.3f}")
    print(f"Total tokens: {total_tokens}")

    # Violations by type
    print(f"\n--- Violations Breakdown ---")
    # Local import: Counter is only needed by this summary helper.
    from collections import Counter
    all_violations = [v for l in logs for v in l.violations]
    if all_violations:
        for v, count in Counter(all_violations).most_common():
            print(f"  {v}: {count}")
    else:
        print("  No violations")

    # Enforcement tracking
    print(f"\n--- Constraint Enforcement ---")
    for constraint in ["short", "bullets", "lang"]:
        enforced_count = sum(1 for l in logs if constraint in l.enforced_constraints)
        print(f"  {constraint}: enforced in {enforced_count}/{total} turns")

    # Cross-session reveal verification
    print(f"\n--- Cross-Session Reveal Verification ---")

    # Session 1: Should have some reveals
    s1_logs = [l for l in logs if l.session_id == 1]
    s1_reveals = set()
    for l in s1_logs:
        s1_reveals.update(l.newly_revealed)
    print(f"Session 1 revealed: {s1_reveals if s1_reveals else 'none'}")

    # Session 2: Should NOT have new reveals (no preference queries)
    s2_logs = [l for l in logs if l.session_id == 2]
    s2_reveals = set()
    for l in s2_logs:
        s2_reveals.update(l.newly_revealed)
    print(f"Session 2 revealed: {s2_reveals if s2_reveals else 'none (expected)'}")

    # But Session 2 should still ENFORCE the constraints revealed in Session 1
    if s2_logs:
        s2_enforced = set()
        for l in s2_logs:
            s2_enforced.update(l.enforced_constraints)
        print(f"Session 2 enforced: {s2_enforced}")

        # Retention holds when every Session-1 reveal shows up enforced in Session 2.
        if s1_reveals and s1_reveals.issubset(s2_enforced):
            print("✓ Cross-session retention VERIFIED: Session 1 reveals enforced in Session 2")
        else:
            print("✗ Cross-session retention issue: some reveals not enforced")

    # Turn-by-turn table
    print(f"\n--- Turn-by-Turn Summary ---")
    print(f"{'S':>2} {'T':>2} {'Type':>10} {'Len':>5} {'sat':>5} {'enforced':<20} {'violations'}")
    print("-" * 70)
    for l in logs:
        enforced_str = ",".join(l.enforced_constraints) if l.enforced_constraints else "-"
        viol_str = ",".join(l.violations) if l.violations else "-"
        print(f"{l.session_id:>2} {l.turn_id:>2} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {enforced_str:<20} {viol_str}")
+
+
def main():
    """CLI entry point: run the 3-session pilot, print a summary, save JSONL logs."""
    print("=" * 60)
    print("PILOT RUNNER v2 - Cross-Session Preference Reveal")
    print("=" * 60)
    print(f"Started at: {datetime.now().isoformat()}")

    # Define user's TRUE preferences
    prefs = StylePrefs(
        require_short=True,
        max_chars=200,
        require_bullets=True,
        lang="en",
    )
    print(f"\n[Config] True preferences: {prefs}")
    print("[Config] Note: Constraints only enforced AFTER user reveals them")

    # Initialize LLM
    # NOTE(review): constructor kwargs assumed to match
    # personalization.serving.PersonalizedLLM - confirm against that module.
    print("\n[Init] Loading PersonalizedLLM...")
    llm = PersonalizedLLM(
        user_store_path="data/users/user_store_pilot_v2.npz",
        only_own_memories=True,
        enable_preference_extraction=True,
        enable_rl_updates=True,
    )

    # Run pilot
    user_id = "pilot_user_v2"
    logs = run_pilot_v2(llm, user_id=user_id, prefs=prefs)

    # Summary
    print_summary_v2(logs, prefs)

    # Save logs (timestamped so repeated runs do not overwrite each other)
    log_path = f"data/logs/pilot_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
    log_to_jsonl(logs, log_path)
    print(f"\n[Logs] Saved to: {log_path}")

    # Final state
    final_state = llm.get_user_state_summary(user_id)
    print(f"\n[Final State] {final_state}")

    print(f"\nCompleted at: {datetime.now().isoformat()}")
    print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+
+