Diffstat (limited to 'scripts/pilot_runner_v0.py')
-rw-r--r--  scripts/pilot_runner_v0.py | 362
1 file changed, 362 insertions(+), 0 deletions(-)
diff --git a/scripts/pilot_runner_v0.py b/scripts/pilot_runner_v0.py
new file mode 100644
index 0000000..8b7773a
--- /dev/null
+++ b/scripts/pilot_runner_v0.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v0 - Minimal End-to-End Test
+
+Goal: Prove the chat → judge → apply_feedback → next query loop works.
+
+Setup:
+- 1 user × 1 session × 5 turns
+- Fixed queries (no fancy user simulator yet)
+- Rule-based judge: answer non-empty → sat=1, else 0
+- reward = sat, gating = 1 always
+
+What we're checking:
+1. No crashes (KeyError, NoneType, etc.)
+2. User vector norms change after feedback (RL is being called)
+3. resp.usage returns reasonable numbers
+4. Logs are generated correctly
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict
+from typing import List, Dict, Any, Optional
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Minimal Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float # Satisfaction score [0, 1]
+ sev_t: float # Severity of violations [0, 1]
+ prog_t: float # Task progress [0, 1]
+ violations: List[str] # List of violated constraints
+
+
+def minimal_judge(query: str, answer: str, task_type: str = "general") -> JudgeResult:
+ """
+ Minimal rule-based judge for pilot.
+
+ For now:
+ - sat_t = 1 if answer is non-empty, else 0
+ - sev_t = 0 (no severity tracking yet)
+ - prog_t = 1 if answer looks reasonable, else 0
+ """
+ violations = []
+
+ # Check 1: Answer is non-empty
+ if not answer or len(answer.strip()) < 5:
+ violations.append("empty_answer")
+ return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0, violations=violations)
+
+ # Check 2: Answer is not too short (at least 20 chars for real content)
+ if len(answer.strip()) < 20:
+ violations.append("too_short")
+
+ # Check 3: For code tasks, look for code markers
+ if task_type == "code":
+ has_code = "```" in answer or "def " in answer or "function" in answer
+ if not has_code:
+ violations.append("no_code_block")
+
+ # Calculate scores
+ sat_t = 1.0 if len(violations) == 0 else max(0.0, 1.0 - 0.3 * len(violations))
+ sev_t = 1.0 if "empty_answer" in violations else 0.0
+ prog_t = 1.0 if "empty_answer" not in violations else 0.0
+
+ return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t, violations=violations)
+
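+# Illustrative behaviour of minimal_judge (a sketch of what the rules above
+# produce, not asserted by any test):
+#   minimal_judge("Sort a list", "")
+#       -> sat_t=0.0, sev_t=1.0, prog_t=0.0, violations=["empty_answer"]
+#   minimal_judge("Sort a list", "Sure, here is how you sort things.", task_type="code")
+#       -> violations=["no_code_block"], so sat_t is reduced to 0.7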
+
+# =============================================================================
+# Minimal User Simulator (Fixed Queries)
+# =============================================================================
+
+def get_fixed_queries() -> List[Dict[str, Any]]:
+ """
+ Return fixed queries for pilot test.
+ Mix of preference statements and tasks.
+ """
+ return [
+ {
+ "query": "I prefer short, concise answers. Please keep responses under 100 words.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ {
+ "query": "What are three tips for better sleep?",
+ "type": "task",
+ "task_type": "general",
+ },
+ {
+ "query": "I also prefer bullet points when listing things.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ {
+ "query": "What are the main benefits of exercise?",
+ "type": "task",
+ "task_type": "general",
+ },
+ {
+ "query": "Summarize what you know about my preferences.",
+ "type": "task",
+ "task_type": "general",
+ },
+ ]
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn."""
+ turn_id: int
+ query: str
+ query_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ reward: float
+ gating: float
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ """Save logs to JSONL file."""
+    dirpath = os.path.dirname(filepath)
+    if dirpath:
+        os.makedirs(dirpath, exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
+
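+# Reading the logs back for analysis (illustrative only; assumes pandas is
+# available in the environment):
+#   import pandas as pd
+#   df = pd.read_json("data/logs/pilot_v0_<timestamp>.jsonl", lines=True)
+#   print(df[["turn_id", "sat_t", "reward", "total_tokens"]])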
+
+# =============================================================================
+# Pilot Runner
+# =============================================================================
+
+def run_pilot(
+ llm: PersonalizedLLM,
+ user_id: str = "pilot_user_0",
+ queries: Optional[List[Dict[str, Any]]] = None,
+) -> List[TurnLog]:
+ """
+ Run a single pilot session.
+
+ Returns list of turn logs.
+ """
+ if queries is None:
+ queries = get_fixed_queries()
+
+ logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"PILOT SESSION: user_id={user_id}, turns={len(queries)}")
+ print(f"{'='*60}")
+
+ # Reset user for clean start
+ print(f"\n[Pilot] Resetting user: {user_id}")
+ llm.reset_user(user_id)
+
+ # Start session
+ print(f"[Pilot] Starting session")
+ llm.reset_session(user_id)
+
+ # Get initial state
+ state_before = llm.get_user_state_summary(user_id)
+ print(f"[Pilot] Initial state: z_long_norm={state_before['z_long_norm']:.6f}, z_short_norm={state_before['z_short_norm']:.6f}")
+
+ for turn_id, q_info in enumerate(queries):
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+
+ print(f"\n--- Turn {turn_id} ---")
+ print(f"[Query] ({query_type}) {query[:80]}...")
+
+        # Get state before this turn. Note that feedback for the previous turn
+        # is applied below, after this snapshot, so the before/after delta
+        # reflects both that feedback update and this turn's chat.
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+ # Apply feedback for previous turn (from turn 1 onwards)
+ if turn_id > 0 and len(logs) > 0:
+ prev_log = logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=turn_id - 1,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={"source": "pilot_v0"}
+ )
+ print(f"[Feedback] Applying: reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ print(f"[Answer] {resp.answer[:100]}..." if len(resp.answer) > 100 else f"[Answer] {resp.answer}")
+ print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")
+
+ # Judge
+ judge_result = minimal_judge(query, resp.answer, task_type)
+ print(f"[Judge] sat={judge_result.sat_t:.2f}, prog={judge_result.prog_t:.2f}, violations={judge_result.violations}")
+
+ # Compute reward and gating
+ reward = judge_result.sat_t # Simple: reward = satisfaction
+ gating = 1.0 # Always allow learning for pilot
+
+ # Get state after
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ # Debug info
+ num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+ num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+ print(f"[State] z_long: {z_long_before:.6f} -> {z_long_after:.6f}, z_short: {z_short_before:.6f} -> {z_short_after:.6f}")
+ print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")
+
+ # Log
+ log = TurnLog(
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ reward=reward,
+ gating=gating,
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ )
+ logs.append(log)
+
+ # Apply final feedback
+ if len(logs) > 0:
+ last_log = logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=len(queries) - 1,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v0", "final": True}
+ )
+ print(f"\n[Final Feedback] reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ return logs
+
+
+def print_summary(logs: List[TurnLog]):
+ """Print summary statistics."""
+ print(f"\n{'='*60}")
+ print("PILOT SUMMARY")
+ print(f"{'='*60}")
+
+ total_turns = len(logs)
+ avg_sat = sum(l.sat_t for l in logs) / total_turns if total_turns > 0 else 0
+ avg_prog = sum(l.prog_t for l in logs) / total_turns if total_turns > 0 else 0
+ total_tokens = sum(l.total_tokens for l in logs)
+ total_prompt = sum(l.prompt_tokens for l in logs)
+ total_completion = sum(l.completion_tokens for l in logs)
+
+ # Check if RL updates happened (vector norms changed)
+ z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs]
+ z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs]
+ any_z_long_change = any(c > 1e-6 for c in z_long_changes)
+ any_z_short_change = any(c > 1e-6 for c in z_short_changes)
+
+ print(f"Total turns: {total_turns}")
+ print(f"Average satisfaction: {avg_sat:.3f}")
+ print(f"Average progress: {avg_prog:.3f}")
+ print(f"Total tokens: {total_tokens} (prompt: {total_prompt}, completion: {total_completion})")
+ print(f"z_long changed: {any_z_long_change} (max delta: {max(z_long_changes):.6f})")
+ print(f"z_short changed: {any_z_short_change} (max delta: {max(z_short_changes):.6f})")
+
+ # Violations breakdown
+ all_violations = [v for l in logs for v in l.violations]
+ if all_violations:
+ from collections import Counter
+ print(f"Violations: {dict(Counter(all_violations))}")
+ else:
+ print("Violations: None")
+
+ # RL Health Check
+ print(f"\n--- RL Health Check ---")
+ if any_z_long_change or any_z_short_change:
+ print("✓ User vectors ARE being updated by RL")
+ else:
+ print("✗ WARNING: User vectors NOT changing - check apply_feedback")
+
+
+def main():
+ print("=" * 60)
+ print("PILOT RUNNER v0")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+
+ # Initialize LLM
+ print("\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_pilot.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Run pilot
+ user_id = "pilot_user_0"
+ logs = run_pilot(llm, user_id=user_id)
+
+ # Summary
+ print_summary(logs)
+
+ # Save logs
+ log_path = f"data/logs/pilot_v0_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ # Final state
+ final_state = llm.get_user_state_summary(user_id)
+ print(f"\n[Final State] {final_state}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+