#!/usr/bin/env python3 """ Pilot Runner v1 - Style-Aware Judge + Gating Logic Upgrade from v0: - StylePrefs: User style preferences (length, bullets, language) - style_judge: Checks style conformance, not just non-empty - compute_feedback_for_turn: gating=1 only for preference-related turns - Extended queries: ~10 turns with preference/task mix Goal: Verify that: 1. sat_t varies based on style violations (not always 1) 2. gating=1 only on preference turns, 0 on regular tasks 3. RL updates happen when gating=1 and reward != baseline 4. Over turns, model may adapt to preferences (sat_t improves) """ import sys import os import json from datetime import datetime from dataclasses import dataclass, asdict, field from typing import List, Dict, Any, Optional, Tuple # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse # ============================================================================= # Style Preferences # ============================================================================= @dataclass class StylePrefs: """User's style preferences for the judge to check.""" require_short: bool = False max_chars: int = 300 require_bullets: bool = False lang: str = "en" # "en" or "zh" # ============================================================================= # Style-Aware Judge # ============================================================================= @dataclass class JudgeResult: """Output from the judge for one turn.""" sat_t: float # Satisfaction score [0, 1] sev_t: float # Severity of violations [0, 1] prog_t: float # Task progress [0, 1] violations: List[str] # List of violated constraints def style_judge( query: str, answer: str, task_type: str, prefs: StylePrefs, ) -> JudgeResult: """ Style-aware judge that checks: - Empty/too short answer - Length constraint (max_chars) - Bullet point requirement - Language preference - Code block for code tasks Returns: JudgeResult with sat_t, sev_t, prog_t, and violations list. """ violations: List[str] = [] text = (answer or "").strip() # 0) Empty answer - immediate fail if not text or len(text) < 5: violations.append("empty_answer") return JudgeResult( sat_t=0.0, sev_t=1.0, prog_t=0.0, violations=violations, ) # 1) Length preference if prefs.require_short: if len(text) > prefs.max_chars: violations.append("too_long") # 2) Bullet preference (only for general/list tasks, not pure preference statements) if prefs.require_bullets and task_type in ("general", "list"): has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text) if not has_bullets: violations.append("no_bullets") # 3) Language preference (rough heuristic) if prefs.lang == "zh": # For Chinese preference, check if answer has enough non-ASCII chars ascii_count = sum(c.isascii() for c in text) ascii_ratio = ascii_count / max(1, len(text)) if ascii_ratio > 0.7: # Too much ASCII = probably not Chinese violations.append("wrong_lang") elif prefs.lang == "en": # For English preference, check if answer is mostly ASCII ascii_count = sum(c.isascii() for c in text) ascii_ratio = ascii_count / max(1, len(text)) if ascii_ratio < 0.5: # Too little ASCII = probably not English violations.append("wrong_lang") # 4) Code task: must have code markers prog_t = 1.0 if task_type == "code": has_code = ("```" in text) or ("def " in text) or ("function " in text) if not has_code: violations.append("no_code_block") prog_t = 0.0 # 5) Compute sat_t and sev_t from violations if not violations: sat_t = 1.0 sev_t = 0.0 else: # Each violation costs 0.3, minimum 0 sat_t = max(0.0, 1.0 - 0.3 * float(len(violations))) # Hard violations trigger sev_t=1 hard_violations = {"empty_answer", "too_long", "wrong_lang"} sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0 return JudgeResult( sat_t=sat_t, sev_t=sev_t, prog_t=prog_t, violations=violations, ) # ============================================================================= # Feedback Computation (reward + gating) # ============================================================================= def compute_feedback_for_turn( turn_id: int, query: str, query_type: str, task_type: str, judge_result: JudgeResult, ) -> Tuple[float, float]: """ Convert JudgeResult into (reward, gating): - reward = sat_t (style satisfaction) - gating = 1 only if this turn is preference-related (declared or complained) Args: turn_id: The turn index query: The user's query text query_type: "preference" or "task" from query metadata task_type: "general", "list", "code", etc. judge_result: The judge's evaluation Returns: (reward, gating) tuple """ reward = judge_result.sat_t # Gating logic: only allow RL update on preference-related turns # 1. Explicit preference declaration (query_type == "preference") # 2. Complaint about not following preference lower_q = (query or "").lower() is_pref_turn = ( query_type == "preference" or "i prefer" in lower_q or "my preference" in lower_q or "please use" in lower_q or "please keep" in lower_q or "you didn't follow" in lower_q or "you forgot" in lower_q or "remember that i" in lower_q or "i told you" in lower_q or "i asked for" in lower_q ) if is_pref_turn: gating = 1.0 else: gating = 0.0 return reward, gating # ============================================================================= # Extended Queries for Pilot v1 (~10 turns) # ============================================================================= def get_pilot_v1_queries() -> List[Dict[str, Any]]: """ Extended query set for pilot v1. Mix of preference declarations and tasks. Tests: length constraint, bullet points, task completion. """ return [ # Turn 0: Declare length preference { "query": "I prefer short, concise answers. Please keep responses under 200 characters.", "type": "preference", "task_type": "general", }, # Turn 1: Task that should be short { "query": "What are three tips for better sleep?", "type": "task", "task_type": "list", }, # Turn 2: Declare bullet preference { "query": "I also prefer bullet points when listing things. Please use bullet points.", "type": "preference", "task_type": "general", }, # Turn 3: Task that should use bullets { "query": "What are the main benefits of regular exercise?", "type": "task", "task_type": "list", }, # Turn 4: Another task (test if preferences stick) { "query": "Name five popular programming languages.", "type": "task", "task_type": "list", }, # Turn 5: Complaint if needed (always include to test gating) { "query": "Remember that I asked for short answers with bullet points. Can you list three healthy breakfast ideas?", "type": "preference", "task_type": "list", }, # Turn 6: Regular task { "query": "What is the capital of France?", "type": "task", "task_type": "general", }, # Turn 7: Task requiring list { "query": "What are four seasons of the year?", "type": "task", "task_type": "list", }, # Turn 8: Another preference reminder { "query": "I prefer concise bullet points. Please list three types of renewable energy.", "type": "preference", "task_type": "list", }, # Turn 9: Final task - test memory { "query": "Summarize what you know about my communication preferences.", "type": "task", "task_type": "general", }, ] # ============================================================================= # Logging # ============================================================================= @dataclass class TurnLog: """Log entry for one turn.""" turn_id: int query: str query_type: str task_type: str answer: str answer_length: int sat_t: float sev_t: float prog_t: float violations: List[str] reward: float gating: float z_long_norm_before: float z_long_norm_after: float z_short_norm_before: float z_short_norm_after: float prompt_tokens: int completion_tokens: int total_tokens: int num_memories_retrieved: int num_prefs_extracted: int def log_to_jsonl(logs: List[TurnLog], filepath: str): """Save logs to JSONL file.""" os.makedirs(os.path.dirname(filepath), exist_ok=True) with open(filepath, "w") as f: for log in logs: f.write(json.dumps(asdict(log)) + "\n") # ============================================================================= # Pilot Runner v1 # ============================================================================= def run_pilot_v1( llm: PersonalizedLLM, user_id: str = "pilot_user_v1", prefs: Optional[StylePrefs] = None, queries: Optional[List[Dict[str, Any]]] = None, ) -> List[TurnLog]: """ Run pilot v1 with style-aware judge and gating. Args: llm: PersonalizedLLM instance user_id: User identifier prefs: Style preferences for this user queries: Query list (defaults to get_pilot_v1_queries) Returns: List of TurnLog entries """ if prefs is None: # Default preferences: short + bullets + English prefs = StylePrefs( require_short=True, max_chars=200, require_bullets=True, lang="en", ) if queries is None: queries = get_pilot_v1_queries() logs: List[TurnLog] = [] print(f"\n{'='*60}") print(f"PILOT v1 SESSION: user_id={user_id}, turns={len(queries)}") print(f"Preferences: short={prefs.require_short}, max_chars={prefs.max_chars}, bullets={prefs.require_bullets}, lang={prefs.lang}") print(f"{'='*60}") # Reset user for clean start print(f"\n[Pilot] Resetting user: {user_id}") llm.reset_user(user_id) # Start session print(f"[Pilot] Starting session") llm.reset_session(user_id) # Get initial state state_before = llm.get_user_state_summary(user_id) print(f"[Pilot] Initial state: z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}") for turn_id, q_info in enumerate(queries): query = q_info["query"] query_type = q_info.get("type", "task") task_type = q_info.get("task_type", "general") print(f"\n{'─'*60}") print(f"Turn {turn_id} [{query_type}]") print(f"{'─'*60}") print(f"[Query] {query}") # Get state before state_before = llm.get_user_state_summary(user_id) z_long_before = state_before["z_long_norm"] z_short_before = state_before["z_short_norm"] # Apply feedback for previous turn (from turn 1 onwards) if turn_id > 0 and len(logs) > 0: prev_log = logs[-1] prev_query = queries[turn_id - 1] # Re-judge the previous answer with current context # (In practice we already have the result, but this shows the flow) feedback = Feedback( user_id=user_id, turn_id=turn_id - 1, reward=prev_log.reward, gating=prev_log.gating, meta={ "sat_t": prev_log.sat_t, "sev_t": prev_log.sev_t, "prog_t": prev_log.prog_t, "violations": prev_log.violations, "task_type": prev_log.task_type, "source": "pilot_v1", } ) print(f"[Feedback] turn={turn_id-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") llm.apply_feedback(feedback) # Chat resp: AssistantResponse = llm.chat(user_id, query) # Truncate answer for display answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer print(f"[Answer] ({len(resp.answer)} chars) {answer_display}") print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}") # Judge with style preferences judge_result = style_judge(query, resp.answer, task_type, prefs) print(f"[Judge] sat={judge_result.sat_t:.2f}, sev={judge_result.sev_t:.1f}, prog={judge_result.prog_t:.1f}") if judge_result.violations: print(f"[Judge] violations={judge_result.violations}") # Compute feedback for THIS turn (will be applied next turn) reward, gating = compute_feedback_for_turn( turn_id=turn_id, query=query, query_type=query_type, task_type=task_type, judge_result=judge_result, ) print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f} (computed for this turn)") # Get state after state_after = llm.get_user_state_summary(user_id) z_long_after = state_after["z_long_norm"] z_short_after = state_after["z_short_norm"] # Debug info num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0 num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0 z_long_delta = z_long_after - z_long_before z_short_delta = z_short_after - z_short_before print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})") print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})") print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}") # Log log = TurnLog( turn_id=turn_id, query=query, query_type=query_type, task_type=task_type, answer=resp.answer, answer_length=len(resp.answer), sat_t=judge_result.sat_t, sev_t=judge_result.sev_t, prog_t=judge_result.prog_t, violations=judge_result.violations, reward=reward, gating=gating, z_long_norm_before=z_long_before, z_long_norm_after=z_long_after, z_short_norm_before=z_short_before, z_short_norm_after=z_short_after, prompt_tokens=resp.usage.prompt_tokens, completion_tokens=resp.usage.completion_tokens, total_tokens=resp.usage.total_tokens, num_memories_retrieved=num_memories, num_prefs_extracted=num_prefs, ) logs.append(log) # Apply final feedback for last turn if len(logs) > 0: last_log = logs[-1] feedback = Feedback( user_id=user_id, turn_id=len(queries) - 1, reward=last_log.reward, gating=last_log.gating, meta={"source": "pilot_v1", "final": True} ) print(f"\n[Final Feedback] turn={len(queries)-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") llm.apply_feedback(feedback) return logs def print_summary_v1(logs: List[TurnLog], prefs: StylePrefs): """Print summary statistics for pilot v1.""" print(f"\n{'='*60}") print("PILOT v1 SUMMARY") print(f"{'='*60}") total_turns = len(logs) if total_turns == 0: print("No turns to summarize.") return # Basic stats avg_sat = sum(l.sat_t for l in logs) / total_turns avg_prog = sum(l.prog_t for l in logs) / total_turns total_tokens = sum(l.total_tokens for l in logs) total_prompt = sum(l.prompt_tokens for l in logs) total_completion = sum(l.completion_tokens for l in logs) # Gating stats gated_turns = [l for l in logs if l.gating > 0] non_gated_turns = [l for l in logs if l.gating == 0] print(f"\n--- Turn Statistics ---") print(f"Total turns: {total_turns}") print(f"Gated turns (RL active): {len(gated_turns)}") print(f"Non-gated turns (RL skipped): {len(non_gated_turns)}") print(f"\n--- Satisfaction ---") print(f"Average sat_t (all): {avg_sat:.3f}") if gated_turns: avg_sat_gated = sum(l.sat_t for l in gated_turns) / len(gated_turns) print(f"Average sat_t (gated only): {avg_sat_gated:.3f}") print(f"Average prog_t: {avg_prog:.3f}") print(f"\n--- Token Usage ---") print(f"Total tokens: {total_tokens}") print(f" Prompt: {total_prompt}") print(f" Completion: {total_completion}") print(f"Avg tokens/turn: {total_tokens / total_turns:.1f}") # Violations breakdown print(f"\n--- Violations ---") from collections import Counter all_violations = [v for l in logs for v in l.violations] if all_violations: print(f"Total violations: {len(all_violations)}") for v, count in Counter(all_violations).most_common(): print(f" {v}: {count}") else: print("No violations") # Answer length analysis print(f"\n--- Answer Lengths (max_chars={prefs.max_chars}) ---") lengths = [l.answer_length for l in logs] over_limit = sum(1 for l in lengths if l > prefs.max_chars) print(f"Min: {min(lengths)}, Max: {max(lengths)}, Avg: {sum(lengths)/len(lengths):.1f}") print(f"Over limit: {over_limit}/{total_turns}") # RL Health Check print(f"\n--- RL Health Check ---") z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs] z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs] any_z_long_change = any(c > 1e-6 for c in z_long_changes) any_z_short_change = any(c > 1e-6 for c in z_short_changes) print(f"z_long changed: {any_z_long_change} (max Δ: {max(z_long_changes):.6f})") print(f"z_short changed: {any_z_short_change} (max Δ: {max(z_short_changes):.6f})") if any_z_long_change or any_z_short_change: print("✓ User vectors ARE being updated by RL") else: print("✗ WARNING: User vectors NOT changing") print(" Check: gating=1 on some turns? reward != baseline?") # Per-turn detail table print(f"\n--- Turn-by-Turn Summary ---") print(f"{'Turn':>4} {'Type':>10} {'Len':>5} {'sat':>5} {'gate':>5} {'violations'}") print("-" * 60) for l in logs: viol_str = ",".join(l.violations) if l.violations else "-" print(f"{l.turn_id:>4} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {l.gating:>5.1f} {viol_str}") def main(): print("=" * 60) print("PILOT RUNNER v1 - Style-Aware Judge + Gating") print("=" * 60) print(f"Started at: {datetime.now().isoformat()}") # Define user preferences prefs = StylePrefs( require_short=True, max_chars=200, require_bullets=True, lang="en", ) print(f"\n[Config] User preferences: {prefs}") # Initialize LLM print("\n[Init] Loading PersonalizedLLM...") llm = PersonalizedLLM( user_store_path="data/users/user_store_pilot_v1.npz", only_own_memories=True, enable_preference_extraction=True, enable_rl_updates=True, ) # Run pilot user_id = "pilot_user_v1" logs = run_pilot_v1(llm, user_id=user_id, prefs=prefs) # Summary print_summary_v1(logs, prefs) # Save logs log_path = f"data/logs/pilot_v1_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" log_to_jsonl(logs, log_path) print(f"\n[Logs] Saved to: {log_path}") # Final state final_state = llm.get_user_state_summary(user_id) print(f"\n[Final State] {final_state}") print(f"\nCompleted at: {datetime.now().isoformat()}") print("=" * 60) if __name__ == "__main__": main()