Diffstat (limited to 'scripts/pilot_runner_v0.py')
| -rw-r--r-- | scripts/pilot_runner_v0.py | 362 |
1 file changed, 362 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v0.py b/scripts/pilot_runner_v0.py
new file mode 100644
index 0000000..8b7773a
--- /dev/null
+++ b/scripts/pilot_runner_v0.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v0 - Minimal End-to-End Test
+
+Goal: Prove the chat → judge → apply_feedback → next query loop works.
+
+Setup:
+- 1 user × 1 session × 5 turns
+- Fixed queries (no fancy user simulator yet)
+- Rule-based judge: answer non-empty → sat=1, else 0
+- reward = sat, gating = 1 always
+
+What we're checking:
+1. No crashes (KeyError, NoneType, etc.)
+2. User vector norms change after feedback (RL is being called)
+3. resp.usage returns reasonable numbers
+4. Logs are generated correctly
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict
+from typing import List, Dict, Any, Optional
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Minimal Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+    """Output from the judge for one turn."""
+    sat_t: float           # Satisfaction score [0, 1]
+    sev_t: float           # Severity of violations [0, 1]
+    prog_t: float          # Task progress [0, 1]
+    violations: List[str]  # List of violated constraints
+
+
+def minimal_judge(query: str, answer: str, task_type: str = "general") -> JudgeResult:
+    """
+    Minimal rule-based judge for pilot.
+
+    For now:
+    - sat_t = 1 if answer is non-empty, else 0
+    - sev_t = 0 (no severity tracking yet)
+    - prog_t = 1 if answer looks reasonable, else 0
+    """
+    violations = []
+
+    # Check 1: Answer is non-empty
+    if not answer or len(answer.strip()) < 5:
+        violations.append("empty_answer")
+        return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0, violations=violations)
+
+    # Check 2: Answer is not too short (at least 20 chars for real content)
+    if len(answer.strip()) < 20:
+        violations.append("too_short")
+
+    # Check 3: For code tasks, look for code markers
+    if task_type == "code":
+        has_code = "```" in answer or "def " in answer or "function" in answer
+        if not has_code:
+            violations.append("no_code_block")
+
+    # Calculate scores
+    sat_t = 1.0 if len(violations) == 0 else max(0.0, 1.0 - 0.3 * len(violations))
+    sev_t = 1.0 if "empty_answer" in violations else 0.0
+    prog_t = 1.0 if "empty_answer" not in violations else 0.0
+
+    return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t, violations=violations)
+
+
+# =============================================================================
+# Minimal User Simulator (Fixed Queries)
+# =============================================================================
+
+def get_fixed_queries() -> List[Dict[str, Any]]:
+    """
+    Return fixed queries for pilot test.
+    Mix of preference statements and tasks.
+    """
+    return [
+        {
+            "query": "I prefer short, concise answers. Please keep responses under 100 words.",
+            "type": "preference",
+            "task_type": "general",
+        },
+        {
+            "query": "What are three tips for better sleep?",
+            "type": "task",
+            "task_type": "general",
+        },
+        {
+            "query": "I also prefer bullet points when listing things.",
+            "type": "preference",
+            "task_type": "general",
+        },
+        {
+            "query": "What are the main benefits of exercise?",
+            "type": "task",
+            "task_type": "general",
+        },
+        {
+            "query": "Summarize what you know about my preferences.",
+            "type": "task",
+            "task_type": "general",
+        },
+    ]
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+    """Log entry for one turn."""
+    turn_id: int
+    query: str
+    query_type: str
+    answer: str
+    answer_length: int
+    sat_t: float
+    sev_t: float
+    prog_t: float
+    violations: List[str]
+    reward: float
+    gating: float
+    z_long_norm_before: float
+    z_long_norm_after: float
+    z_short_norm_before: float
+    z_short_norm_after: float
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    num_memories_retrieved: int
+    num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+    """Save logs to JSONL file."""
+    os.makedirs(os.path.dirname(filepath), exist_ok=True)
+    with open(filepath, "w") as f:
+        for log in logs:
+            f.write(json.dumps(asdict(log)) + "\n")
+
+
+# =============================================================================
+# Pilot Runner
+# =============================================================================
+
+def run_pilot(
+    llm: PersonalizedLLM,
+    user_id: str = "pilot_user_0",
+    queries: Optional[List[Dict[str, Any]]] = None,
+) -> List[TurnLog]:
+    """
+    Run a single pilot session.
+
+    Returns list of turn logs.
+    """
+    if queries is None:
+        queries = get_fixed_queries()
+
+    logs: List[TurnLog] = []
+
+    print(f"\n{'='*60}")
+    print(f"PILOT SESSION: user_id={user_id}, turns={len(queries)}")
+    print(f"{'='*60}")
+
+    # Reset user for clean start
+    print(f"\n[Pilot] Resetting user: {user_id}")
+    llm.reset_user(user_id)
+
+    # Start session
+    print(f"[Pilot] Starting session")
+    llm.reset_session(user_id)
+
+    # Get initial state
+    state_before = llm.get_user_state_summary(user_id)
+    print(f"[Pilot] Initial state: z_long_norm={state_before['z_long_norm']:.6f}, z_short_norm={state_before['z_short_norm']:.6f}")
+
+    for turn_id, q_info in enumerate(queries):
+        query = q_info["query"]
+        query_type = q_info.get("type", "task")
+        task_type = q_info.get("task_type", "general")
+
+        print(f"\n--- Turn {turn_id} ---")
+        print(f"[Query] ({query_type}) {query[:80]}...")
+
+        # Get state before
+        state_before = llm.get_user_state_summary(user_id)
+        z_long_before = state_before["z_long_norm"]
+        z_short_before = state_before["z_short_norm"]
+
+        # Apply feedback for previous turn (from turn 1 onwards)
+        if turn_id > 0 and len(logs) > 0:
+            prev_log = logs[-1]
+            feedback = Feedback(
+                user_id=user_id,
+                turn_id=turn_id - 1,
+                reward=prev_log.reward,
+                gating=prev_log.gating,
+                meta={"source": "pilot_v0"}
+            )
+            print(f"[Feedback] Applying: reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+            llm.apply_feedback(feedback)
+
+        # Chat
+        resp: AssistantResponse = llm.chat(user_id, query)
+
+        print(f"[Answer] {resp.answer[:100]}..." if len(resp.answer) > 100 else f"[Answer] {resp.answer}")
+        print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")
+
+        # Judge
+        judge_result = minimal_judge(query, resp.answer, task_type)
+        print(f"[Judge] sat={judge_result.sat_t:.2f}, prog={judge_result.prog_t:.2f}, violations={judge_result.violations}")
+
+        # Compute reward and gating
+        reward = judge_result.sat_t  # Simple: reward = satisfaction
+        gating = 1.0  # Always allow learning for pilot
+
+        # Get state after
+        state_after = llm.get_user_state_summary(user_id)
+        z_long_after = state_after["z_long_norm"]
+        z_short_after = state_after["z_short_norm"]
+
+        # Debug info
+        num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+        num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+        print(f"[State] z_long: {z_long_before:.6f} -> {z_long_after:.6f}, z_short: {z_short_before:.6f} -> {z_short_after:.6f}")
+        print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")
+
+        # Log
+        log = TurnLog(
+            turn_id=turn_id,
+            query=query,
+            query_type=query_type,
+            answer=resp.answer,
+            answer_length=len(resp.answer),
+            sat_t=judge_result.sat_t,
+            sev_t=judge_result.sev_t,
+            prog_t=judge_result.prog_t,
+            violations=judge_result.violations,
+            reward=reward,
+            gating=gating,
+            z_long_norm_before=z_long_before,
+            z_long_norm_after=z_long_after,
+            z_short_norm_before=z_short_before,
+            z_short_norm_after=z_short_after,
+            prompt_tokens=resp.usage.prompt_tokens,
+            completion_tokens=resp.usage.completion_tokens,
+            total_tokens=resp.usage.total_tokens,
+            num_memories_retrieved=num_memories,
+            num_prefs_extracted=num_prefs,
+        )
+        logs.append(log)
+
+    # Apply final feedback
+    if len(logs) > 0:
+        last_log = logs[-1]
+        feedback = Feedback(
+            user_id=user_id,
+            turn_id=len(queries) - 1,
+            reward=last_log.reward,
+            gating=last_log.gating,
+            meta={"source": "pilot_v0", "final": True}
+        )
+        print(f"\n[Final Feedback] reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+        llm.apply_feedback(feedback)
+
+    return logs
+
+
+def print_summary(logs: List[TurnLog]):
+    """Print summary statistics."""
+    print(f"\n{'='*60}")
+    print("PILOT SUMMARY")
+    print(f"{'='*60}")
+
+    total_turns = len(logs)
+    avg_sat = sum(l.sat_t for l in logs) / total_turns if total_turns > 0 else 0
+    avg_prog = sum(l.prog_t for l in logs) / total_turns if total_turns > 0 else 0
+    total_tokens = sum(l.total_tokens for l in logs)
+    total_prompt = sum(l.prompt_tokens for l in logs)
+    total_completion = sum(l.completion_tokens for l in logs)
+
+    # Check if RL updates happened (vector norms changed)
+    z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs]
+    z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs]
+    any_z_long_change = any(c > 1e-6 for c in z_long_changes)
+    any_z_short_change = any(c > 1e-6 for c in z_short_changes)
+
+    print(f"Total turns: {total_turns}")
+    print(f"Average satisfaction: {avg_sat:.3f}")
+    print(f"Average progress: {avg_prog:.3f}")
+    print(f"Total tokens: {total_tokens} (prompt: {total_prompt}, completion: {total_completion})")
+    print(f"z_long changed: {any_z_long_change} (max delta: {max(z_long_changes):.6f})")
+    print(f"z_short changed: {any_z_short_change} (max delta: {max(z_short_changes):.6f})")
+
+    # Violations breakdown
+    all_violations = [v for l in logs for v in l.violations]
+    if all_violations:
+        from collections import Counter
+        print(f"Violations: {dict(Counter(all_violations))}")
+    else:
+        print("Violations: None")
+
+    # RL Health Check
+    print(f"\n--- RL Health Check ---")
+    if any_z_long_change or any_z_short_change:
+        print("✓ User vectors ARE being updated by RL")
+    else:
+        print("✗ WARNING: User vectors NOT changing - check apply_feedback")
+
+
+def main():
+    print("=" * 60)
+    print("PILOT RUNNER v0")
+    print("=" * 60)
+    print(f"Started at: {datetime.now().isoformat()}")
+
+    # Initialize LLM
+    print("\n[Init] Loading PersonalizedLLM...")
+    llm = PersonalizedLLM(
+        user_store_path="data/users/user_store_pilot.npz",
+        only_own_memories=True,
+        enable_preference_extraction=True,
+        enable_rl_updates=True,
+    )
+
+    # Run pilot
+    user_id = "pilot_user_0"
+    logs = run_pilot(llm, user_id=user_id)
+
+    # Summary
+    print_summary(logs)
+
+    # Save logs
+    log_path = f"data/logs/pilot_v0_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+    log_to_jsonl(logs, log_path)
+    print(f"\n[Logs] Saved to: {log_path}")
+
+    # Final state
+    final_state = llm.get_user_state_summary(user_id)
+    print(f"\n[Final State] {final_state}")
+
+    print(f"\nCompleted at: {datetime.now().isoformat()}")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
