path: root/scripts/pilot_runner_v1.py
Diffstat (limited to 'scripts/pilot_runner_v1.py')
-rw-r--r--  scripts/pilot_runner_v1.py  607
1 file changed, 607 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v1.py b/scripts/pilot_runner_v1.py
new file mode 100644
index 0000000..fbb2876
--- /dev/null
+++ b/scripts/pilot_runner_v1.py
@@ -0,0 +1,607 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v1 - Style-Aware Judge + Gating Logic
+
+Upgrades over v0:
+- StylePrefs: user style preferences (length, bullets, language)
+- style_judge: checks style conformance, not just that the answer is non-empty
+- compute_feedback_for_turn: sets gating=1 only for preference-related turns
+- Extended queries: 10 turns mixing preference declarations and regular tasks
+
+Goal: verify that:
+1. sat_t varies with style violations (it is not always 1)
+2. gating=1 only on preference turns, 0 on regular tasks
+3. RL updates happen when gating=1 and reward != baseline
+4. Over the turns, the model may adapt to the preferences (sat_t improves)
+"""
+
+import sys
+import os
+import json
+from collections import Counter
+from datetime import datetime
+from dataclasses import dataclass, asdict
+from typing import List, Dict, Any, Optional, Tuple
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """User's style preferences for the judge to check."""
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en" # "en" or "zh"
+
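+# Illustrative example (matches the defaults used in run_pilot_v1/main below):
+#   prefs = StylePrefs(require_short=True, max_chars=200,
+#                      require_bullets=True, lang="en")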
+
+# =============================================================================
+# Style-Aware Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float # Satisfaction score [0, 1]
+ sev_t: float # Severity of violations [0, 1]
+ prog_t: float # Task progress [0, 1]
+ violations: List[str] # List of violated constraints
+
+
+def style_judge(
+ query: str,
+ answer: str,
+ task_type: str,
+ prefs: StylePrefs,
+) -> JudgeResult:
+ """
+ Style-aware judge that checks:
+ - Empty/too short answer
+ - Length constraint (max_chars)
+ - Bullet point requirement
+ - Language preference
+ - Code block for code tasks
+
+ Returns:
+ JudgeResult with sat_t, sev_t, prog_t, and violations list.
+ """
+ violations: List[str] = []
+ text = (answer or "").strip()
+
+ # 0) Empty answer - immediate fail
+ if not text or len(text) < 5:
+ violations.append("empty_answer")
+ return JudgeResult(
+ sat_t=0.0,
+ sev_t=1.0,
+ prog_t=0.0,
+ violations=violations,
+ )
+
+ # 1) Length preference
+ if prefs.require_short:
+ if len(text) > prefs.max_chars:
+ violations.append("too_long")
+
+    # 2) Bullet preference (applied to "general" and "list" tasks; code tasks are exempt)
+ if prefs.require_bullets and task_type in ("general", "list"):
+ has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text)
+ if not has_bullets:
+ violations.append("no_bullets")
+
+    # 3) Language preference (rough heuristic based on the ASCII character ratio)
+    ascii_ratio = sum(c.isascii() for c in text) / max(1, len(text))
+    if prefs.lang == "zh" and ascii_ratio > 0.7:
+        # Too much ASCII = probably not Chinese
+        violations.append("wrong_lang")
+    elif prefs.lang == "en" and ascii_ratio < 0.5:
+        # Too little ASCII = probably not English
+        violations.append("wrong_lang")
+
+ # 4) Code task: must have code markers
+ prog_t = 1.0
+ if task_type == "code":
+ has_code = ("```" in text) or ("def " in text) or ("function " in text)
+ if not has_code:
+ violations.append("no_code_block")
+ prog_t = 0.0
+
+ # 5) Compute sat_t and sev_t from violations
+ if not violations:
+ sat_t = 1.0
+ sev_t = 0.0
+ else:
+ # Each violation costs 0.3, minimum 0
+ sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
+ # Hard violations trigger sev_t=1
+ hard_violations = {"empty_answer", "too_long", "wrong_lang"}
+ sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
+
+ return JudgeResult(
+ sat_t=sat_t,
+ sev_t=sev_t,
+ prog_t=prog_t,
+ violations=violations,
+ )
+
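+# Worked example of the scoring above (a sketch following the rules in
+# style_judge, not an external spec): with
+#   prefs = StylePrefs(require_short=True, max_chars=200, require_bullets=True)
+# a 350-character answer without bullets on a "list" task collects
+# violations = ["too_long", "no_bullets"], giving
+#   sat_t = max(0.0, 1.0 - 0.3 * 2) = 0.4
+#   sev_t = 1.0   (because "too_long" is a hard violation)
+#   prog_t = 1.0  (the task is not "code")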
+
+# =============================================================================
+# Feedback Computation (reward + gating)
+# =============================================================================
+
+def compute_feedback_for_turn(
+ turn_id: int,
+ query: str,
+ query_type: str,
+ task_type: str,
+ judge_result: JudgeResult,
+) -> Tuple[float, float]:
+ """
+ Convert JudgeResult into (reward, gating):
+ - reward = sat_t (style satisfaction)
+ - gating = 1 only if this turn is preference-related (declared or complained)
+
+ Args:
+ turn_id: The turn index
+ query: The user's query text
+ query_type: "preference" or "task" from query metadata
+ task_type: "general", "list", "code", etc.
+ judge_result: The judge's evaluation
+
+ Returns:
+ (reward, gating) tuple
+ """
+ reward = judge_result.sat_t
+
+ # Gating logic: only allow RL update on preference-related turns
+ # 1. Explicit preference declaration (query_type == "preference")
+ # 2. Complaint about not following preference
+ lower_q = (query or "").lower()
+
+ is_pref_turn = (
+ query_type == "preference"
+ or "i prefer" in lower_q
+ or "my preference" in lower_q
+ or "please use" in lower_q
+ or "please keep" in lower_q
+ or "you didn't follow" in lower_q
+ or "you forgot" in lower_q
+ or "remember that i" in lower_q
+ or "i told you" in lower_q
+ or "i asked for" in lower_q
+ )
+
+ if is_pref_turn:
+ gating = 1.0
+ else:
+ gating = 0.0
+
+ return reward, gating
+
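+# Example of the gating behaviour (sketch based on the heuristics above):
+# a "task" turn whose query contains no preference phrase and has sat_t=0.4
+# yields (reward=0.4, gating=0.0), so the RL update is skipped; the same judge
+# result on a turn containing "i prefer" yields (reward=0.4, gating=1.0).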
+
+# =============================================================================
+# Extended Queries for Pilot v1 (~10 turns)
+# =============================================================================
+
+def get_pilot_v1_queries() -> List[Dict[str, Any]]:
+ """
+ Extended query set for pilot v1.
+ Mix of preference declarations and tasks.
+ Tests: length constraint, bullet points, task completion.
+ """
+ return [
+ # Turn 0: Declare length preference
+ {
+ "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ # Turn 1: Task that should be short
+ {
+ "query": "What are three tips for better sleep?",
+ "type": "task",
+ "task_type": "list",
+ },
+ # Turn 2: Declare bullet preference
+ {
+ "query": "I also prefer bullet points when listing things. Please use bullet points.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ # Turn 3: Task that should use bullets
+ {
+ "query": "What are the main benefits of regular exercise?",
+ "type": "task",
+ "task_type": "list",
+ },
+ # Turn 4: Another task (test if preferences stick)
+ {
+ "query": "Name five popular programming languages.",
+ "type": "task",
+ "task_type": "list",
+ },
+        # Turn 5: Preference reminder phrased as a complaint (always included to test gating)
+ {
+ "query": "Remember that I asked for short answers with bullet points. Can you list three healthy breakfast ideas?",
+ "type": "preference",
+ "task_type": "list",
+ },
+ # Turn 6: Regular task
+ {
+ "query": "What is the capital of France?",
+ "type": "task",
+ "task_type": "general",
+ },
+ # Turn 7: Task requiring list
+ {
+            "query": "What are the four seasons of the year?",
+ "type": "task",
+ "task_type": "list",
+ },
+ # Turn 8: Another preference reminder
+ {
+ "query": "I prefer concise bullet points. Please list three types of renewable energy.",
+ "type": "preference",
+ "task_type": "list",
+ },
+ # Turn 9: Final task - test memory
+ {
+ "query": "Summarize what you know about my communication preferences.",
+ "type": "task",
+ "task_type": "general",
+ },
+ ]
+
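+# Each entry above uses three keys: "query" (the user's message), "type"
+# ("preference" or "task", used for gating), and "task_type" ("general",
+# "list", or "code", used by the style judge).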
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn."""
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ reward: float
+ gating: float
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ """Save logs to JSONL file."""
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
+
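+# Reading the JSONL back for offline analysis is symmetric. A minimal sketch
+# (not part of the pilot flow itself; the variable names are illustrative):
+#
+#   with open(filepath) as f:
+#       rows = [json.loads(line) for line in f]
+#   gated_rows = [r for r in rows if r["gating"] > 0]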
+
+# =============================================================================
+# Pilot Runner v1
+# =============================================================================
+
+def run_pilot_v1(
+ llm: PersonalizedLLM,
+ user_id: str = "pilot_user_v1",
+ prefs: Optional[StylePrefs] = None,
+ queries: Optional[List[Dict[str, Any]]] = None,
+) -> List[TurnLog]:
+ """
+ Run pilot v1 with style-aware judge and gating.
+
+ Args:
+ llm: PersonalizedLLM instance
+ user_id: User identifier
+ prefs: Style preferences for this user
+ queries: Query list (defaults to get_pilot_v1_queries)
+
+ Returns:
+ List of TurnLog entries
+ """
+ if prefs is None:
+ # Default preferences: short + bullets + English
+ prefs = StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ )
+
+ if queries is None:
+ queries = get_pilot_v1_queries()
+
+ logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"PILOT v1 SESSION: user_id={user_id}, turns={len(queries)}")
+ print(f"Preferences: short={prefs.require_short}, max_chars={prefs.max_chars}, bullets={prefs.require_bullets}, lang={prefs.lang}")
+ print(f"{'='*60}")
+
+ # Reset user for clean start
+ print(f"\n[Pilot] Resetting user: {user_id}")
+ llm.reset_user(user_id)
+
+ # Start session
+ print(f"[Pilot] Starting session")
+ llm.reset_session(user_id)
+
+ # Get initial state
+ state_before = llm.get_user_state_summary(user_id)
+ print(f"[Pilot] Initial state: z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}")
+
+ for turn_id, q_info in enumerate(queries):
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+
+ print(f"\n{'─'*60}")
+ print(f"Turn {turn_id} [{query_type}]")
+ print(f"{'─'*60}")
+ print(f"[Query] {query}")
+
+ # Get state before
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+        # Apply feedback for the previous turn (from turn 1 onwards).
+        # The judge result was already computed when that turn finished; here it
+        # is only packaged as Feedback and handed to the RL update.
+        if turn_id > 0 and len(logs) > 0:
+            prev_log = logs[-1]
+
+            feedback = Feedback(
+ user_id=user_id,
+ turn_id=turn_id - 1,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={
+ "sat_t": prev_log.sat_t,
+ "sev_t": prev_log.sev_t,
+ "prog_t": prev_log.prog_t,
+ "violations": prev_log.violations,
+ "task_type": prev_log.task_type,
+ "source": "pilot_v1",
+ }
+ )
+ print(f"[Feedback] turn={turn_id-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ # Truncate answer for display
+ answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer
+ print(f"[Answer] ({len(resp.answer)} chars) {answer_display}")
+ print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")
+
+ # Judge with style preferences
+ judge_result = style_judge(query, resp.answer, task_type, prefs)
+ print(f"[Judge] sat={judge_result.sat_t:.2f}, sev={judge_result.sev_t:.1f}, prog={judge_result.prog_t:.1f}")
+ if judge_result.violations:
+ print(f"[Judge] violations={judge_result.violations}")
+
+ # Compute feedback for THIS turn (will be applied next turn)
+ reward, gating = compute_feedback_for_turn(
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ judge_result=judge_result,
+ )
+ print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f} (computed for this turn)")
+
+ # Get state after
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ # Debug info
+ num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+ num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+ z_long_delta = z_long_after - z_long_before
+ z_short_delta = z_short_after - z_short_before
+ print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})")
+ print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})")
+ print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")
+
+ # Log
+ log = TurnLog(
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ reward=reward,
+ gating=gating,
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ )
+ logs.append(log)
+
+ # Apply final feedback for last turn
+ if len(logs) > 0:
+ last_log = logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=len(queries) - 1,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v1", "final": True}
+ )
+ print(f"\n[Final Feedback] turn={len(queries)-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ return logs
+
+
+def print_summary_v1(logs: List[TurnLog], prefs: StylePrefs):
+ """Print summary statistics for pilot v1."""
+ print(f"\n{'='*60}")
+ print("PILOT v1 SUMMARY")
+ print(f"{'='*60}")
+
+ total_turns = len(logs)
+ if total_turns == 0:
+ print("No turns to summarize.")
+ return
+
+ # Basic stats
+ avg_sat = sum(l.sat_t for l in logs) / total_turns
+ avg_prog = sum(l.prog_t for l in logs) / total_turns
+ total_tokens = sum(l.total_tokens for l in logs)
+ total_prompt = sum(l.prompt_tokens for l in logs)
+ total_completion = sum(l.completion_tokens for l in logs)
+
+ # Gating stats
+ gated_turns = [l for l in logs if l.gating > 0]
+ non_gated_turns = [l for l in logs if l.gating == 0]
+
+ print(f"\n--- Turn Statistics ---")
+ print(f"Total turns: {total_turns}")
+ print(f"Gated turns (RL active): {len(gated_turns)}")
+ print(f"Non-gated turns (RL skipped): {len(non_gated_turns)}")
+
+ print(f"\n--- Satisfaction ---")
+ print(f"Average sat_t (all): {avg_sat:.3f}")
+ if gated_turns:
+ avg_sat_gated = sum(l.sat_t for l in gated_turns) / len(gated_turns)
+ print(f"Average sat_t (gated only): {avg_sat_gated:.3f}")
+ print(f"Average prog_t: {avg_prog:.3f}")
+
+ print(f"\n--- Token Usage ---")
+ print(f"Total tokens: {total_tokens}")
+ print(f" Prompt: {total_prompt}")
+ print(f" Completion: {total_completion}")
+ print(f"Avg tokens/turn: {total_tokens / total_turns:.1f}")
+
+ # Violations breakdown
+ print(f"\n--- Violations ---")
+ all_violations = [v for l in logs for v in l.violations]
+ if all_violations:
+ print(f"Total violations: {len(all_violations)}")
+ for v, count in Counter(all_violations).most_common():
+ print(f" {v}: {count}")
+ else:
+ print("No violations")
+
+ # Answer length analysis
+ print(f"\n--- Answer Lengths (max_chars={prefs.max_chars}) ---")
+ lengths = [l.answer_length for l in logs]
+ over_limit = sum(1 for l in lengths if l > prefs.max_chars)
+ print(f"Min: {min(lengths)}, Max: {max(lengths)}, Avg: {sum(lengths)/len(lengths):.1f}")
+ print(f"Over limit: {over_limit}/{total_turns}")
+
+ # RL Health Check
+ print(f"\n--- RL Health Check ---")
+ z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs]
+ z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs]
+ any_z_long_change = any(c > 1e-6 for c in z_long_changes)
+ any_z_short_change = any(c > 1e-6 for c in z_short_changes)
+
+ print(f"z_long changed: {any_z_long_change} (max Δ: {max(z_long_changes):.6f})")
+ print(f"z_short changed: {any_z_short_change} (max Δ: {max(z_short_changes):.6f})")
+
+ if any_z_long_change or any_z_short_change:
+ print("✓ User vectors ARE being updated by RL")
+ else:
+ print("✗ WARNING: User vectors NOT changing")
+ print(" Check: gating=1 on some turns? reward != baseline?")
+
+ # Per-turn detail table
+ print(f"\n--- Turn-by-Turn Summary ---")
+ print(f"{'Turn':>4} {'Type':>10} {'Len':>5} {'sat':>5} {'gate':>5} {'violations'}")
+ print("-" * 60)
+ for l in logs:
+ viol_str = ",".join(l.violations) if l.violations else "-"
+ print(f"{l.turn_id:>4} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {l.gating:>5.1f} {viol_str}")
+
+
+def main():
+ print("=" * 60)
+ print("PILOT RUNNER v1 - Style-Aware Judge + Gating")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+
+ # Define user preferences
+ prefs = StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ )
+ print(f"\n[Config] User preferences: {prefs}")
+
+ # Initialize LLM
+ print("\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_pilot_v1.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Run pilot
+ user_id = "pilot_user_v1"
+ logs = run_pilot_v1(llm, user_id=user_id, prefs=prefs)
+
+ # Summary
+ print_summary_v1(logs, prefs)
+
+ # Save logs
+ log_path = f"data/logs/pilot_v1_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ # Final state
+ final_state = llm.get_user_state_summary(user_id)
+ print(f"\n[Final State] {final_state}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+