#!/usr/bin/env python3
"""
Pilot Runner v0 - Minimal End-to-End Test

Goal: Prove the chat → judge → apply_feedback → next query loop works.

Setup:
- 1 user × 1 session × 5 turns
- Fixed queries (no fancy user simulator yet)
- Rule-based judge: empty answer → sat=0; otherwise sat=1 minus 0.3 per
  violation (floored at 0)
- reward = sat, gating = 1 always

What we're checking:
1. No crashes (KeyError, NoneType, etc.)
2. User vector norms change after feedback (RL is being called)
3. resp.usage returns reasonable numbers
4. Logs are generated correctly
"""

import sys
import os
import json
from collections import Counter
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional

# Add src to path so the local `personalization` package is importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))

from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse  # noqa: E402


# =============================================================================
# Minimal Judge
# =============================================================================

@dataclass
class JudgeResult:
    """Output from the judge for one turn."""
    sat_t: float  # Satisfaction score [0, 1]
    sev_t: float  # Severity of violations [0, 1]
    prog_t: float  # Task progress [0, 1]
    violations: List[str]  # List of violated constraints


def minimal_judge(query: str, answer: str, task_type: str = "general") -> JudgeResult:
    """
    Minimal rule-based judge for pilot.

    Rules:
    - Missing answer, or fewer than 5 characters after stripping whitespace:
      violation "empty_answer", sat_t=0, sev_t=1, prog_t=0.
    - Otherwise each violation lowers sat_t by 0.3 (floored at 0), with
      sev_t=0 and prog_t=1:
        * "too_short": fewer than 20 characters after stripping.
        * "no_code_block": task_type == "code" and the answer contains none
          of "```", "def ", or "function".

    Args:
        query: The user query (currently unused; kept for richer judges).
        answer: The assistant's answer text (None is treated as empty).
        task_type: "code" enables the code-marker check.

    Returns:
        JudgeResult with the three scores and the list of violation tags.
    """
    stripped = answer.strip() if answer else ""

    # Check 1: Answer is non-empty (a handful of characters counts as empty)
    if len(stripped) < 5:
        return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0, violations=["empty_answer"])

    violations: List[str] = []

    # Check 2: Answer is not too short (at least 20 chars for real content)
    if len(stripped) < 20:
        violations.append("too_short")

    # Check 3: For code tasks, look for code markers
    if task_type == "code":
        has_code = "```" in answer or "def " in answer or "function" in answer
        if not has_code:
            violations.append("no_code_block")

    # Empty answers returned early above, so severity is 0 and progress is 1 here.
    sat_t = max(0.0, 1.0 - 0.3 * len(violations))
    return JudgeResult(sat_t=sat_t, sev_t=0.0, prog_t=1.0, violations=violations)


# =============================================================================
# Minimal User Simulator (Fixed Queries)
# =============================================================================

def get_fixed_queries() -> List[Dict[str, Any]]:
    """
    Return fixed queries for pilot test.

    Mix of preference statements and tasks. Each dict has keys "query",
    "type" ("preference" or "task") and "task_type".
    """
    return [
        {
            "query": "I prefer short, concise answers. Please keep responses under 100 words.",
            "type": "preference",
            "task_type": "general",
        },
        {
            "query": "What are three tips for better sleep?",
            "type": "task",
            "task_type": "general",
        },
        {
            "query": "I also prefer bullet points when listing things.",
            "type": "preference",
            "task_type": "general",
        },
        {
            "query": "What are the main benefits of exercise?",
            "type": "task",
            "task_type": "general",
        },
        {
            "query": "Summarize what you know about my preferences.",
            "type": "task",
            "task_type": "general",
        },
    ]


# =============================================================================
# Logging
# =============================================================================

@dataclass
class TurnLog:
    """Log entry for one turn."""
    turn_id: int
    query: str
    query_type: str
    answer: str
    answer_length: int
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    reward: float
    gating: float
    # Norms snapshotted before the previous turn's feedback is applied and
    # after this turn's chat() returns.
    z_long_norm_before: float
    z_long_norm_after: float
    z_short_norm_before: float
    z_short_norm_after: float
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    num_memories_retrieved: int
    num_prefs_extracted: int


def log_to_jsonl(logs: List[TurnLog], filepath: str) -> None:
    """
    Save logs to a JSONL file, one JSON object per TurnLog.

    The file is overwritten. Parent directories are created when the path
    has one; a bare filename is written relative to the current directory.
    """
    parent_dir = os.path.dirname(filepath)
    if parent_dir:  # os.makedirs("") raises FileNotFoundError
        os.makedirs(parent_dir, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        for log in logs:
            f.write(json.dumps(asdict(log)) + "\n")


# =============================================================================
# Pilot Runner
# =============================================================================

def run_pilot(
    llm: PersonalizedLLM,
    user_id: str = "pilot_user_0",
    queries: Optional[List[Dict[str, Any]]] = None,
) -> List[TurnLog]:
    """
    Run a single pilot session.

    The user and session are reset first. Each turn then:
    1. applies feedback for the previous turn,
    2. chats,
    3. judges the answer,
    4. records a TurnLog.

    Feedback for the final turn is applied after the loop, so its effect on
    the user vectors is not reflected in any TurnLog.

    Args:
        llm: The PersonalizedLLM under test.
        user_id: User to run the session as.
        queries: Dicts with key "query" and optional "type" (default "task")
            and "task_type" (default "general"). Defaults to
            get_fixed_queries().

    Returns:
        List of turn logs, one per query.
    """
    if queries is None:
        queries = get_fixed_queries()

    logs: List[TurnLog] = []

    print(f"\n{'='*60}")
    print(f"PILOT SESSION: user_id={user_id}, turns={len(queries)}")
    print(f"{'='*60}")

    # Reset user for clean start
    print(f"\n[Pilot] Resetting user: {user_id}")
    llm.reset_user(user_id)

    # Start session
    print("[Pilot] Starting session")
    llm.reset_session(user_id)

    initial_state = llm.get_user_state_summary(user_id)
    print(f"[Pilot] Initial state: z_long_norm={initial_state['z_long_norm']:.6f}, z_short_norm={initial_state['z_short_norm']:.6f}")

    for turn_id, q_info in enumerate(queries):
        query = q_info["query"]
        query_type = q_info.get("type", "task")
        task_type = q_info.get("task_type", "general")

        print(f"\n--- Turn {turn_id} ---")
        query_preview = f"{query[:80]}..." if len(query) > 80 else query
        print(f"[Query] ({query_type}) {query_preview}")

        # Snapshot BEFORE applying the previous turn's feedback, so this turn's
        # before/after delta includes whatever apply_feedback changed. The
        # summary's RL health check relies on this ordering.
        state_before = llm.get_user_state_summary(user_id)
        z_long_before = state_before["z_long_norm"]
        z_short_before = state_before["z_short_norm"]

        # Apply feedback for the previous turn (from turn 1 onwards)
        if logs:
            prev_log = logs[-1]
            feedback = Feedback(
                user_id=user_id,
                turn_id=prev_log.turn_id,
                reward=prev_log.reward,
                gating=prev_log.gating,
                meta={"source": "pilot_v0"}
            )
            print(f"[Feedback] Applying: reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
            llm.apply_feedback(feedback)

        # Chat
        resp: AssistantResponse = llm.chat(user_id, query)

        print(f"[Answer] {resp.answer[:100]}..." if len(resp.answer) > 100 else f"[Answer] {resp.answer}")
        print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")

        # Judge
        judge_result = minimal_judge(query, resp.answer, task_type)
        print(f"[Judge] sat={judge_result.sat_t:.2f}, prog={judge_result.prog_t:.2f}, violations={judge_result.violations}")

        # Compute reward and gating
        reward = judge_result.sat_t  # Simple: reward = satisfaction
        gating = 1.0  # Always allow learning for pilot

        # Get state after
        state_after = llm.get_user_state_summary(user_id)
        z_long_after = state_after["z_long_norm"]
        z_short_after = state_after["z_short_norm"]

        # Debug info (resp.debug may be absent)
        num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
        num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0

        print(f"[State] z_long: {z_long_before:.6f} -> {z_long_after:.6f}, z_short: {z_short_before:.6f} -> {z_short_after:.6f}")
        print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")

        logs.append(TurnLog(
            turn_id=turn_id,
            query=query,
            query_type=query_type,
            answer=resp.answer,
            answer_length=len(resp.answer),
            sat_t=judge_result.sat_t,
            sev_t=judge_result.sev_t,
            prog_t=judge_result.prog_t,
            violations=judge_result.violations,
            reward=reward,
            gating=gating,
            z_long_norm_before=z_long_before,
            z_long_norm_after=z_long_after,
            z_short_norm_before=z_short_before,
            z_short_norm_after=z_short_after,
            prompt_tokens=resp.usage.prompt_tokens,
            completion_tokens=resp.usage.completion_tokens,
            total_tokens=resp.usage.total_tokens,
            num_memories_retrieved=num_memories,
            num_prefs_extracted=num_prefs,
        ))

    # Apply final feedback for the last turn
    if logs:
        last_log = logs[-1]
        feedback = Feedback(
            user_id=user_id,
            turn_id=last_log.turn_id,
            reward=last_log.reward,
            gating=last_log.gating,
            meta={"source": "pilot_v0", "final": True}
        )
        print(f"\n[Final Feedback] reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
        llm.apply_feedback(feedback)

    return logs


def print_summary(logs: List[TurnLog]) -> None:
    """
    Print summary statistics and an RL health check for a pilot run.

    Safe to call with an empty list: averages and max deltas report 0.
    """
    print(f"\n{'='*60}")
    print("PILOT SUMMARY")
    print(f"{'='*60}")

    total_turns = len(logs)
    avg_sat = sum(t.sat_t for t in logs) / total_turns if total_turns else 0.0
    avg_prog = sum(t.prog_t for t in logs) / total_turns if total_turns else 0.0
    total_tokens = sum(t.total_tokens for t in logs)
    total_prompt = sum(t.prompt_tokens for t in logs)
    total_completion = sum(t.completion_tokens for t in logs)

    # Check if RL updates happened (vector norms changed)
    z_long_changes = [abs(t.z_long_norm_after - t.z_long_norm_before) for t in logs]
    z_short_changes = [abs(t.z_short_norm_after - t.z_short_norm_before) for t in logs]
    any_z_long_change = any(c > 1e-6 for c in z_long_changes)
    any_z_short_change = any(c > 1e-6 for c in z_short_changes)
    # default= keeps an empty run from raising ValueError
    max_z_long_change = max(z_long_changes, default=0.0)
    max_z_short_change = max(z_short_changes, default=0.0)

    print(f"Total turns: {total_turns}")
    print(f"Average satisfaction: {avg_sat:.3f}")
    print(f"Average progress: {avg_prog:.3f}")
    print(f"Total tokens: {total_tokens} (prompt: {total_prompt}, completion: {total_completion})")
    print(f"z_long changed: {any_z_long_change} (max delta: {max_z_long_change:.6f})")
    print(f"z_short changed: {any_z_short_change} (max delta: {max_z_short_change:.6f})")

    # Violations breakdown
    all_violations = [v for t in logs for v in t.violations]
    if all_violations:
        print(f"Violations: {dict(Counter(all_violations))}")
    else:
        print("Violations: None")

    # RL Health Check
    print("\n--- RL Health Check ---")
    if any_z_long_change or any_z_short_change:
        print("✓ User vectors ARE being updated by RL")
    else:
        print("✗ WARNING: User vectors NOT changing - check apply_feedback")


def main():
    """Run one pilot session end to end, print a summary, and save JSONL logs."""
    print("=" * 60)
    print("PILOT RUNNER v0")
    print("=" * 60)
    print(f"Started at: {datetime.now().isoformat()}")

    # Initialize LLM
    print("\n[Init] Loading PersonalizedLLM...")
    llm = PersonalizedLLM(
        user_store_path="data/users/user_store_pilot.npz",
        only_own_memories=True,
        enable_preference_extraction=True,
        enable_rl_updates=True,
    )

    # Run pilot
    user_id = "pilot_user_0"
    logs = run_pilot(llm, user_id=user_id)

    # Summary
    print_summary(logs)

    # Save logs
    log_path = f"data/logs/pilot_v0_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
    log_to_jsonl(logs, log_path)
    print(f"\n[Logs] Saved to: {log_path}")

    # Final state
    final_state = llm.get_user_state_summary(user_id)
    print(f"\n[Final State] {final_state}")

    print(f"\nCompleted at: {datetime.now().isoformat()}")
    print("=" * 60)


if __name__ == "__main__":
    main()