| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 | /scripts/pilot_runner_v1.py |
Diffstat (limited to 'scripts/pilot_runner_v1.py')
| -rw-r--r-- | scripts/pilot_runner_v1.py | 607 |
1 file changed, 607 insertions, 0 deletions
diff --git a/scripts/pilot_runner_v1.py b/scripts/pilot_runner_v1.py
new file mode 100644
index 0000000..fbb2876
--- /dev/null
+++ b/scripts/pilot_runner_v1.py
@@ -0,0 +1,607 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v1 - Style-Aware Judge + Gating Logic
+
+Upgrade from v0:
+- StylePrefs: User style preferences (length, bullets, language)
+- style_judge: Checks style conformance, not just non-empty
+- compute_feedback_for_turn: gating=1 only for preference-related turns
+- Extended queries: ~10 turns with preference/task mix
+
+Goal: Verify that:
+1. sat_t varies based on style violations (not always 1)
+2. gating=1 only on preference turns, 0 on regular tasks
+3. RL updates happen when gating=1 and reward != baseline
+4. Over turns, model may adapt to preferences (sat_t improves)
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict, field
+from typing import List, Dict, Any, Optional, Tuple
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+    """User's style preferences for the judge to check."""
+    require_short: bool = False
+    max_chars: int = 300
+    require_bullets: bool = False
+    lang: str = "en"  # "en" or "zh"
+
+
+# =============================================================================
+# Style-Aware Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+    """Output from the judge for one turn."""
+    sat_t: float           # Satisfaction score [0, 1]
+    sev_t: float           # Severity of violations [0, 1]
+    prog_t: float          # Task progress [0, 1]
+    violations: List[str]  # List of violated constraints
+
+
+def style_judge(
+    query: str,
+    answer: str,
+    task_type: str,
+    prefs: StylePrefs,
+) -> JudgeResult:
+    """
+    Style-aware judge that checks:
+    - Empty/too short answer
+    - Length constraint (max_chars)
+    - Bullet point requirement
+    - Language preference
+    - Code block for code tasks
+
+    Returns:
+        JudgeResult with sat_t, sev_t, prog_t, and violations list.
+    """
+    violations: List[str] = []
+    text = (answer or "").strip()
+
+    # 0) Empty answer - immediate fail
+    if not text or len(text) < 5:
+        violations.append("empty_answer")
+        return JudgeResult(
+            sat_t=0.0,
+            sev_t=1.0,
+            prog_t=0.0,
+            violations=violations,
+        )
+
+    # 1) Length preference
+    if prefs.require_short:
+        if len(text) > prefs.max_chars:
+            violations.append("too_long")
+
+    # 2) Bullet preference (only for general/list tasks, not pure preference statements)
+    if prefs.require_bullets and task_type in ("general", "list"):
+        has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text)
+        if not has_bullets:
+            violations.append("no_bullets")
+
+    # 3) Language preference (rough heuristic)
+    if prefs.lang == "zh":
+        # For Chinese preference, check if answer has enough non-ASCII chars
+        ascii_count = sum(c.isascii() for c in text)
+        ascii_ratio = ascii_count / max(1, len(text))
+        if ascii_ratio > 0.7:  # Too much ASCII = probably not Chinese
+            violations.append("wrong_lang")
+    elif prefs.lang == "en":
+        # For English preference, check if answer is mostly ASCII
+        ascii_count = sum(c.isascii() for c in text)
+        ascii_ratio = ascii_count / max(1, len(text))
+        if ascii_ratio < 0.5:  # Too little ASCII = probably not English
+            violations.append("wrong_lang")
+
+    # 4) Code task: must have code markers
+    prog_t = 1.0
+    if task_type == "code":
+        has_code = ("```" in text) or ("def " in text) or ("function " in text)
+        if not has_code:
+            violations.append("no_code_block")
+            prog_t = 0.0
+
+    # 5) Compute sat_t and sev_t from violations
+    if not violations:
+        sat_t = 1.0
+        sev_t = 0.0
+    else:
+        # Each violation costs 0.3, minimum 0
+        sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
+        # Hard violations trigger sev_t=1
+        hard_violations = {"empty_answer", "too_long", "wrong_lang"}
+        sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
+
+    return JudgeResult(
+        sat_t=sat_t,
+        sev_t=sev_t,
+        prog_t=prog_t,
+        violations=violations,
+    )
+
+
+# =============================================================================
+# Feedback Computation (reward + gating)
+# =============================================================================
+
+def compute_feedback_for_turn(
+    turn_id: int,
+    query: str,
+    query_type: str,
+    task_type: str,
+    judge_result: JudgeResult,
+) -> Tuple[float, float]:
+    """
+    Convert JudgeResult into (reward, gating):
+    - reward = sat_t (style satisfaction)
+    - gating = 1 only if this turn is preference-related (declared or complained)
+
+    Args:
+        turn_id: The turn index
+        query: The user's query text
+        query_type: "preference" or "task" from query metadata
+        task_type: "general", "list", "code", etc.
+        judge_result: The judge's evaluation
+
+    Returns:
+        (reward, gating) tuple
+    """
+    reward = judge_result.sat_t
+
+    # Gating logic: only allow RL update on preference-related turns
+    # 1. Explicit preference declaration (query_type == "preference")
+    # 2. Complaint about not following preference
+    lower_q = (query or "").lower()
+
+    is_pref_turn = (
+        query_type == "preference"
+        or "i prefer" in lower_q
+        or "my preference" in lower_q
+        or "please use" in lower_q
+        or "please keep" in lower_q
+        or "you didn't follow" in lower_q
+        or "you forgot" in lower_q
+        or "remember that i" in lower_q
+        or "i told you" in lower_q
+        or "i asked for" in lower_q
+    )
+
+    if is_pref_turn:
+        gating = 1.0
+    else:
+        gating = 0.0
+
+    return reward, gating
+
+
+# =============================================================================
+# Extended Queries for Pilot v1 (~10 turns)
+# =============================================================================
+
+def get_pilot_v1_queries() -> List[Dict[str, Any]]:
+    """
+    Extended query set for pilot v1.
+    Mix of preference declarations and tasks.
+    Tests: length constraint, bullet points, task completion.
+    """
+    return [
+        # Turn 0: Declare length preference
+        {
+            "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
+            "type": "preference",
+            "task_type": "general",
+        },
+        # Turn 1: Task that should be short
+        {
+            "query": "What are three tips for better sleep?",
+            "type": "task",
+            "task_type": "list",
+        },
+        # Turn 2: Declare bullet preference
+        {
+            "query": "I also prefer bullet points when listing things. Please use bullet points.",
+            "type": "preference",
+            "task_type": "general",
+        },
+        # Turn 3: Task that should use bullets
+        {
+            "query": "What are the main benefits of regular exercise?",
+            "type": "task",
+            "task_type": "list",
+        },
+        # Turn 4: Another task (test if preferences stick)
+        {
+            "query": "Name five popular programming languages.",
+            "type": "task",
+            "task_type": "list",
+        },
+        # Turn 5: Complaint if needed (always include to test gating)
+        {
+            "query": "Remember that I asked for short answers with bullet points. Can you list three healthy breakfast ideas?",
+            "type": "preference",
+            "task_type": "list",
+        },
+        # Turn 6: Regular task
+        {
+            "query": "What is the capital of France?",
+            "type": "task",
+            "task_type": "general",
+        },
+        # Turn 7: Task requiring list
+        {
+            "query": "What are four seasons of the year?",
+            "type": "task",
+            "task_type": "list",
+        },
+        # Turn 8: Another preference reminder
+        {
+            "query": "I prefer concise bullet points. Please list three types of renewable energy.",
+            "type": "preference",
+            "task_type": "list",
+        },
+        # Turn 9: Final task - test memory
+        {
+            "query": "Summarize what you know about my communication preferences.",
+            "type": "task",
+            "task_type": "general",
+        },
+    ]
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+    """Log entry for one turn."""
+    turn_id: int
+    query: str
+    query_type: str
+    task_type: str
+    answer: str
+    answer_length: int
+    sat_t: float
+    sev_t: float
+    prog_t: float
+    violations: List[str]
+    reward: float
+    gating: float
+    z_long_norm_before: float
+    z_long_norm_after: float
+    z_short_norm_before: float
+    z_short_norm_after: float
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    num_memories_retrieved: int
+    num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+    """Save logs to JSONL file."""
+    os.makedirs(os.path.dirname(filepath), exist_ok=True)
+    with open(filepath, "w") as f:
+        for log in logs:
+            f.write(json.dumps(asdict(log)) + "\n")
+
+
+# =============================================================================
+# Pilot Runner v1
+# =============================================================================
+
+def run_pilot_v1(
+    llm: PersonalizedLLM,
+    user_id: str = "pilot_user_v1",
+    prefs: Optional[StylePrefs] = None,
+    queries: Optional[List[Dict[str, Any]]] = None,
+) -> List[TurnLog]:
+    """
+    Run pilot v1 with style-aware judge and gating.
+
+    Args:
+        llm: PersonalizedLLM instance
+        user_id: User identifier
+        prefs: Style preferences for this user
+        queries: Query list (defaults to get_pilot_v1_queries)
+
+    Returns:
+        List of TurnLog entries
+    """
+    if prefs is None:
+        # Default preferences: short + bullets + English
+        prefs = StylePrefs(
+            require_short=True,
+            max_chars=200,
+            require_bullets=True,
+            lang="en",
+        )
+
+    if queries is None:
+        queries = get_pilot_v1_queries()
+
+    logs: List[TurnLog] = []
+
+    print(f"\n{'='*60}")
+    print(f"PILOT v1 SESSION: user_id={user_id}, turns={len(queries)}")
+    print(f"Preferences: short={prefs.require_short}, max_chars={prefs.max_chars}, bullets={prefs.require_bullets}, lang={prefs.lang}")
+    print(f"{'='*60}")
+
+    # Reset user for clean start
+    print(f"\n[Pilot] Resetting user: {user_id}")
+    llm.reset_user(user_id)
+
+    # Start session
+    print(f"[Pilot] Starting session")
+    llm.reset_session(user_id)
+
+    # Get initial state
+    state_before = llm.get_user_state_summary(user_id)
+    print(f"[Pilot] Initial state: z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}")
+
+    for turn_id, q_info in enumerate(queries):
+        query = q_info["query"]
+        query_type = q_info.get("type", "task")
+        task_type = q_info.get("task_type", "general")
+
+        print(f"\n{'─'*60}")
+        print(f"Turn {turn_id} [{query_type}]")
+        print(f"{'─'*60}")
+        print(f"[Query] {query}")
+
+        # Get state before
+        state_before = llm.get_user_state_summary(user_id)
+        z_long_before = state_before["z_long_norm"]
+        z_short_before = state_before["z_short_norm"]
+
+        # Apply feedback for previous turn (from turn 1 onwards)
+        if turn_id > 0 and len(logs) > 0:
+            prev_log = logs[-1]
+            prev_query = queries[turn_id - 1]
+
+            # Re-judge the previous answer with current context
+            # (In practice we already have the result, but this shows the flow)
+            feedback = Feedback(
+                user_id=user_id,
+                turn_id=turn_id - 1,
+                reward=prev_log.reward,
+                gating=prev_log.gating,
+                meta={
+                    "sat_t": prev_log.sat_t,
+                    "sev_t": prev_log.sev_t,
+                    "prog_t": prev_log.prog_t,
+                    "violations": prev_log.violations,
+                    "task_type": prev_log.task_type,
+                    "source": "pilot_v1",
+                }
+            )
+            print(f"[Feedback] turn={turn_id-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+            llm.apply_feedback(feedback)
+
+        # Chat
+        resp: AssistantResponse = llm.chat(user_id, query)
+
+        # Truncate answer for display
+        answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer
+        print(f"[Answer] ({len(resp.answer)} chars) {answer_display}")
+        print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")
+
+        # Judge with style preferences
+        judge_result = style_judge(query, resp.answer, task_type, prefs)
+        print(f"[Judge] sat={judge_result.sat_t:.2f}, sev={judge_result.sev_t:.1f}, prog={judge_result.prog_t:.1f}")
+        if judge_result.violations:
+            print(f"[Judge] violations={judge_result.violations}")
+
+        # Compute feedback for THIS turn (will be applied next turn)
+        reward, gating = compute_feedback_for_turn(
+            turn_id=turn_id,
+            query=query,
+            query_type=query_type,
+            task_type=task_type,
+            judge_result=judge_result,
+        )
+        print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f} (computed for this turn)")
+
+        # Get state after
+        state_after = llm.get_user_state_summary(user_id)
+        z_long_after = state_after["z_long_norm"]
+        z_short_after = state_after["z_short_norm"]
+
+        # Debug info
+        num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+        num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+        z_long_delta = z_long_after - z_long_before
+        z_short_delta = z_short_after - z_short_before
+        print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})")
+        print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})")
+        print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")
+
+        # Log
+        log = TurnLog(
+            turn_id=turn_id,
+            query=query,
+            query_type=query_type,
+            task_type=task_type,
+            answer=resp.answer,
+            answer_length=len(resp.answer),
+            sat_t=judge_result.sat_t,
+            sev_t=judge_result.sev_t,
+            prog_t=judge_result.prog_t,
+            violations=judge_result.violations,
+            reward=reward,
+            gating=gating,
+            z_long_norm_before=z_long_before,
+            z_long_norm_after=z_long_after,
+            z_short_norm_before=z_short_before,
+            z_short_norm_after=z_short_after,
+            prompt_tokens=resp.usage.prompt_tokens,
+            completion_tokens=resp.usage.completion_tokens,
+            total_tokens=resp.usage.total_tokens,
+            num_memories_retrieved=num_memories,
+            num_prefs_extracted=num_prefs,
+        )
+        logs.append(log)
+
+    # Apply final feedback for last turn
+    if len(logs) > 0:
+        last_log = logs[-1]
+        feedback = Feedback(
+            user_id=user_id,
+            turn_id=len(queries) - 1,
+            reward=last_log.reward,
+            gating=last_log.gating,
+            meta={"source": "pilot_v1", "final": True}
+        )
+        print(f"\n[Final Feedback] turn={len(queries)-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+        llm.apply_feedback(feedback)
+
+    return logs
+
+
+def print_summary_v1(logs: List[TurnLog], prefs: StylePrefs):
+    """Print summary statistics for pilot v1."""
+    print(f"\n{'='*60}")
+    print("PILOT v1 SUMMARY")
+    print(f"{'='*60}")
+
+    total_turns = len(logs)
+    if total_turns == 0:
+        print("No turns to summarize.")
+        return
+
+    # Basic stats
+    avg_sat = sum(l.sat_t for l in logs) / total_turns
+    avg_prog = sum(l.prog_t for l in logs) / total_turns
+    total_tokens = sum(l.total_tokens for l in logs)
+    total_prompt = sum(l.prompt_tokens for l in logs)
+    total_completion = sum(l.completion_tokens for l in logs)
+
+    # Gating stats
+    gated_turns = [l for l in logs if l.gating > 0]
+    non_gated_turns = [l for l in logs if l.gating == 0]
+
+    print(f"\n--- Turn Statistics ---")
+    print(f"Total turns: {total_turns}")
+    print(f"Gated turns (RL active): {len(gated_turns)}")
+    print(f"Non-gated turns (RL skipped): {len(non_gated_turns)}")
+
+    print(f"\n--- Satisfaction ---")
+    print(f"Average sat_t (all): {avg_sat:.3f}")
+    if gated_turns:
+        avg_sat_gated = sum(l.sat_t for l in gated_turns) / len(gated_turns)
+        print(f"Average sat_t (gated only): {avg_sat_gated:.3f}")
+    print(f"Average prog_t: {avg_prog:.3f}")
+
+    print(f"\n--- Token Usage ---")
+    print(f"Total tokens: {total_tokens}")
+    print(f"  Prompt: {total_prompt}")
+    print(f"  Completion: {total_completion}")
+    print(f"Avg tokens/turn: {total_tokens / total_turns:.1f}")
+
+    # Violations breakdown
+    print(f"\n--- Violations ---")
+    from collections import Counter
+    all_violations = [v for l in logs for v in l.violations]
+    if all_violations:
+        print(f"Total violations: {len(all_violations)}")
+        for v, count in Counter(all_violations).most_common():
+            print(f"  {v}: {count}")
+    else:
+        print("No violations")
+
+    # Answer length analysis
+    print(f"\n--- Answer Lengths (max_chars={prefs.max_chars}) ---")
+    lengths = [l.answer_length for l in logs]
+    over_limit = sum(1 for l in lengths if l > prefs.max_chars)
+    print(f"Min: {min(lengths)}, Max: {max(lengths)}, Avg: {sum(lengths)/len(lengths):.1f}")
+    print(f"Over limit: {over_limit}/{total_turns}")
+
+    # RL Health Check
+    print(f"\n--- RL Health Check ---")
+    z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs]
+    z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs]
+    any_z_long_change = any(c > 1e-6 for c in z_long_changes)
+    any_z_short_change = any(c > 1e-6 for c in z_short_changes)
+
+    print(f"z_long changed: {any_z_long_change} (max Δ: {max(z_long_changes):.6f})")
+    print(f"z_short changed: {any_z_short_change} (max Δ: {max(z_short_changes):.6f})")
+
+    if any_z_long_change or any_z_short_change:
+        print("✓ User vectors ARE being updated by RL")
+    else:
+        print("✗ WARNING: User vectors NOT changing")
+        print("  Check: gating=1 on some turns? reward != baseline?")
+
+    # Per-turn detail table
+    print(f"\n--- Turn-by-Turn Summary ---")
+    print(f"{'Turn':>4} {'Type':>10} {'Len':>5} {'sat':>5} {'gate':>5} {'violations'}")
+    print("-" * 60)
+    for l in logs:
+        viol_str = ",".join(l.violations) if l.violations else "-"
+        print(f"{l.turn_id:>4} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {l.gating:>5.1f} {viol_str}")
+
+
+def main():
+    print("=" * 60)
+    print("PILOT RUNNER v1 - Style-Aware Judge + Gating")
+    print("=" * 60)
+    print(f"Started at: {datetime.now().isoformat()}")
+
+    # Define user preferences
+    prefs = StylePrefs(
+        require_short=True,
+        max_chars=200,
+        require_bullets=True,
+        lang="en",
+    )
+    print(f"\n[Config] User preferences: {prefs}")
+
+    # Initialize LLM
+    print("\n[Init] Loading PersonalizedLLM...")
+    llm = PersonalizedLLM(
+        user_store_path="data/users/user_store_pilot_v1.npz",
+        only_own_memories=True,
+        enable_preference_extraction=True,
+        enable_rl_updates=True,
+    )
+
+    # Run pilot
+    user_id = "pilot_user_v1"
+    logs = run_pilot_v1(llm, user_id=user_id, prefs=prefs)
+
+    # Summary
+    print_summary_v1(logs, prefs)
+
+    # Save logs
+    log_path = f"data/logs/pilot_v1_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+    log_to_jsonl(logs, log_path)
+    print(f"\n[Logs] Saved to: {log_path}")
+
+    # Final state
+    final_state = llm.get_user_state_summary(user_id)
+    print(f"\n[Final State] {final_state}")
+
+    print(f"\nCompleted at: {datetime.now().isoformat()}")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
