Diffstat (limited to 'scripts/pilot_runner_v2.py')
| -rw-r--r-- | scripts/pilot_runner_v2.py | 852 |
1 file changed, 852 insertions(+), 0 deletions(-)
diff --git a/scripts/pilot_runner_v2.py b/scripts/pilot_runner_v2.py new file mode 100644 index 0000000..d3c2aa8 --- /dev/null +++ b/scripts/pilot_runner_v2.py @@ -0,0 +1,852 @@ +#!/usr/bin/env python3 +""" +Pilot Runner v2 - Cross-Session Preference Reveal Mechanism + +Upgrade from v1: +- RevealState: Tracks which preferences have been explicitly revealed by the user +- pref_true[k] vs pref_revealed_global[k] distinction +- Style constraints only enforced AFTER user reveals them +- Reveal state persists across sessions, resets on reset_user() + +Key concepts: +- pref_true[k]: User's true preference (from StylePrefs) +- pref_revealed_global[k]: Whether preference k has been revealed at least once + +Enforcement rule: +- A style constraint is enforced only when BOTH pref_true[k] AND pref_revealed_global[k] + +Session semantics: +- reset_user(): Clears ALL state including reveal flags +- reset_session(): Keeps reveal flags (cross-session memory) +""" + +import sys +import os +import json +from datetime import datetime +from dataclasses import dataclass, asdict, field +from typing import List, Dict, Any, Optional, Tuple, Set + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse + + +# ============================================================================= +# Style Preferences (True Preferences) +# ============================================================================= + +@dataclass +class StylePrefs: + """ + User's TRUE style preferences. + These are the ground truth preferences that the user actually has, + but they may not have revealed all of them to the system yet. + """ + require_short: bool = False + max_chars: int = 300 + require_bullets: bool = False + lang: str = "en" # "en" or "zh" + + +# ============================================================================= +# Reveal State (What has been explicitly revealed) +# ============================================================================= + +@dataclass +class RevealState: + """ + Tracks which preferences have been explicitly revealed by the user. + + This persists across sessions for the same user but resets on reset_user(). + A preference is revealed when the user explicitly mentions it in a query. + """ + short_revealed: bool = False # "short", "concise", "brief", length constraints + bullets_revealed: bool = False # "bullet", "bullet points", "list format" + lang_revealed: bool = False # Language preference mentioned + + def reset(self): + """Reset all reveal flags (called on reset_user).""" + self.short_revealed = False + self.bullets_revealed = False + self.lang_revealed = False + + def to_dict(self) -> Dict[str, bool]: + return { + "short": self.short_revealed, + "bullets": self.bullets_revealed, + "lang": self.lang_revealed, + } + + def __str__(self) -> str: + flags = [] + if self.short_revealed: + flags.append("short") + if self.bullets_revealed: + flags.append("bullets") + if self.lang_revealed: + flags.append("lang") + return f"RevealState({', '.join(flags) if flags else 'none'})" + + +class RevealStateManager: + """ + Manages reveal state for multiple users. + Persists across sessions, resets on reset_user(). 
+ """ + + def __init__(self): + self._states: Dict[str, RevealState] = {} + + def get_state(self, user_id: str) -> RevealState: + """Get or create reveal state for a user.""" + if user_id not in self._states: + self._states[user_id] = RevealState() + return self._states[user_id] + + def reset_user(self, user_id: str): + """Reset reveal state for a user (called on reset_user).""" + if user_id in self._states: + self._states[user_id].reset() + else: + self._states[user_id] = RevealState() + + def reset_session(self, user_id: str): + """ + Called on reset_session - does NOT reset reveal state. + Reveal state persists across sessions. + """ + # Intentionally do nothing - reveal state persists + pass + + +# ============================================================================= +# Preference Detection from Queries +# ============================================================================= + +def detect_revealed_preferences(query: str) -> Dict[str, bool]: + """ + Detect which preferences are mentioned in a query. + + Returns a dict with keys: "short", "bullets", "lang" + Each value is True if that preference was mentioned. + """ + lower_q = (query or "").lower() + + revealed = { + "short": False, + "bullets": False, + "lang": False, + } + + # Short/length preference detection + short_patterns = [ + "short", "concise", "brief", "under ", "less than", + "keep it short", "keep responses", "keep answers", + "maximum ", "max ", "characters", "words or less", + "200 ", "100 ", "50 ", "300 ", # Common char limits + ] + for pattern in short_patterns: + if pattern in lower_q: + revealed["short"] = True + break + + # Bullet preference detection + bullet_patterns = [ + "bullet", "bullet point", "bullet-point", + "bulleted", "list format", "use bullets", + "use bullet", "with bullets", "in bullets", + "- format", "• ", "numbered list", + ] + for pattern in bullet_patterns: + if pattern in lower_q: + revealed["bullets"] = True + break + + # Language preference detection + lang_patterns_zh = [ + "chinese", "中文", "in chinese", "用中文", + "speak chinese", "write chinese", "respond in chinese", + "please use chinese", "mandarin", + ] + lang_patterns_en = [ + "english", "in english", "use english", + "speak english", "write english", "respond in english", + "please use english", + ] + + for pattern in lang_patterns_zh + lang_patterns_en: + if pattern in lower_q: + revealed["lang"] = True + break + + return revealed + + +def update_reveal_state(reveal_state: RevealState, query: str) -> Set[str]: + """ + Update reveal state based on query content. + Returns set of newly revealed preferences. 
+ """ + detected = detect_revealed_preferences(query) + newly_revealed = set() + + if detected["short"] and not reveal_state.short_revealed: + reveal_state.short_revealed = True + newly_revealed.add("short") + + if detected["bullets"] and not reveal_state.bullets_revealed: + reveal_state.bullets_revealed = True + newly_revealed.add("bullets") + + if detected["lang"] and not reveal_state.lang_revealed: + reveal_state.lang_revealed = True + newly_revealed.add("lang") + + return newly_revealed + + +# ============================================================================= +# Style-Aware Judge with Reveal State +# ============================================================================= + +@dataclass +class JudgeResult: + """Output from the judge for one turn.""" + sat_t: float # Satisfaction score [0, 1] + sev_t: float # Severity of violations [0, 1] + prog_t: float # Task progress [0, 1] + violations: List[str] # List of violated constraints + enforced_constraints: List[str] # Which constraints were actually enforced + + +def style_judge_with_reveal( + query: str, + answer: str, + task_type: str, + prefs: StylePrefs, + reveal_state: RevealState, +) -> JudgeResult: + """ + Style-aware judge that ONLY enforces revealed preferences. + + A constraint is enforced only when: + - pref_true[k] is True (user has this preference) + - pref_revealed_global[k] is True (user has revealed this preference) + + Args: + query: User's query + answer: Assistant's answer + task_type: Type of task ("general", "list", "code") + prefs: User's TRUE preferences (StylePrefs) + reveal_state: Which preferences have been revealed + + Returns: + JudgeResult with sat_t, sev_t, prog_t, violations, and enforced_constraints + """ + violations: List[str] = [] + enforced: List[str] = [] + text = (answer or "").strip() + + # 0) Empty answer - always a violation regardless of reveal state + if not text or len(text) < 5: + violations.append("empty_answer") + return JudgeResult( + sat_t=0.0, + sev_t=1.0, + prog_t=0.0, + violations=violations, + enforced_constraints=["non_empty"], + ) + + # 1) Length preference - enforce only if BOTH true AND revealed + if prefs.require_short and reveal_state.short_revealed: + enforced.append("short") + if len(text) > prefs.max_chars: + violations.append("too_long") + + # 2) Bullet preference - enforce only if BOTH true AND revealed + # Also only for list-type tasks + if prefs.require_bullets and reveal_state.bullets_revealed: + if task_type in ("general", "list"): + enforced.append("bullets") + has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text) + if not has_bullets: + violations.append("no_bullets") + + # 3) Language preference - enforce only if BOTH true AND revealed + if reveal_state.lang_revealed: + enforced.append("lang") + if prefs.lang == "zh": + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio > 0.7: + violations.append("wrong_lang") + elif prefs.lang == "en": + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio < 0.5: + violations.append("wrong_lang") + + # 4) Code task: always enforce code markers (not a user preference) + prog_t = 1.0 + if task_type == "code": + enforced.append("code_block") + has_code = ("```" in text) or ("def " in text) or ("function " in text) + if not has_code: + violations.append("no_code_block") + prog_t = 0.0 + + # 5) Compute sat_t and sev_t from violations + if not violations: + sat_t = 1.0 + sev_t = 
0.0 + else: + sat_t = max(0.0, 1.0 - 0.3 * float(len(violations))) + hard_violations = {"empty_answer", "too_long", "wrong_lang"} + sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0 + + return JudgeResult( + sat_t=sat_t, + sev_t=sev_t, + prog_t=prog_t, + violations=violations, + enforced_constraints=enforced, + ) + + +# ============================================================================= +# Feedback Computation (reward + gating) +# ============================================================================= + +def compute_feedback_for_turn( + turn_id: int, + query: str, + query_type: str, + task_type: str, + judge_result: JudgeResult, +) -> Tuple[float, float]: + """ + Convert JudgeResult into (reward, gating). + Same as v1 - reward = sat_t, gating = 1 for preference turns. + """ + reward = judge_result.sat_t + + lower_q = (query or "").lower() + + is_pref_turn = ( + query_type == "preference" + or "i prefer" in lower_q + or "my preference" in lower_q + or "please use" in lower_q + or "please keep" in lower_q + or "you didn't follow" in lower_q + or "you forgot" in lower_q + or "remember that i" in lower_q + or "i told you" in lower_q + or "i asked for" in lower_q + ) + + gating = 1.0 if is_pref_turn else 0.0 + return reward, gating + + +# ============================================================================= +# Multi-Session Queries for Pilot v2 +# ============================================================================= + +def get_session_1_queries() -> List[Dict[str, Any]]: + """ + Session 1: User reveals preferences and does some tasks. + """ + return [ + { + "query": "I prefer short, concise answers. Please keep responses under 200 characters.", + "type": "preference", + "task_type": "general", + }, + { + "query": "What are three tips for better sleep?", + "type": "task", + "task_type": "list", + }, + { + "query": "I also prefer bullet points when listing things.", + "type": "preference", + "task_type": "general", + }, + { + "query": "What are the main benefits of exercise?", + "type": "task", + "task_type": "list", + }, + { + "query": "Name five programming languages.", + "type": "task", + "task_type": "list", + }, + ] + + +def get_session_2_queries() -> List[Dict[str, Any]]: + """ + Session 2: User does NOT restate preferences. + Tests cross-session preference retention. + """ + return [ + { + "query": "What are three healthy breakfast ideas?", + "type": "task", + "task_type": "list", + }, + { + "query": "List four seasons of the year.", + "type": "task", + "task_type": "list", + }, + { + "query": "What is the capital of France?", + "type": "task", + "task_type": "general", + }, + { + "query": "Name three types of renewable energy.", + "type": "task", + "task_type": "list", + }, + ] + + +def get_session_3_queries() -> List[Dict[str, Any]]: + """ + Session 3: Mix of tasks and one complaint/reminder. + """ + return [ + { + "query": "What are five common fruits?", + "type": "task", + "task_type": "list", + }, + { + "query": "Remember that I asked for short bullet points. 
List three ocean animals.", + "type": "preference", + "task_type": "list", + }, + { + "query": "What is 2 + 2?", + "type": "task", + "task_type": "general", + }, + ] + + +# ============================================================================= +# Logging (Extended for v2) +# ============================================================================= + +@dataclass +class TurnLog: + """Log entry for one turn (extended for v2).""" + session_id: int + turn_id: int + query: str + query_type: str + task_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + reward: float + gating: float + reveal_state_before: Dict[str, bool] + reveal_state_after: Dict[str, bool] + newly_revealed: List[str] + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + num_memories_retrieved: int + num_prefs_extracted: int + + +def log_to_jsonl(logs: List[TurnLog], filepath: str): + """Save logs to JSONL file.""" + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as f: + for log in logs: + f.write(json.dumps(asdict(log)) + "\n") + + +# ============================================================================= +# Pilot Runner v2 (Multi-Session with Reveal State) +# ============================================================================= + +def run_session( + llm: PersonalizedLLM, + user_id: str, + session_id: int, + prefs: StylePrefs, + reveal_state: RevealState, + queries: List[Dict[str, Any]], +) -> List[TurnLog]: + """ + Run a single session with reveal-aware judging. + """ + logs: List[TurnLog] = [] + + print(f"\n{'='*60}") + print(f"SESSION {session_id}: user_id={user_id}, turns={len(queries)}") + print(f"Reveal state (start): {reveal_state}") + print(f"{'='*60}") + + # Reset session (clears history, z_short; keeps z_long and reveal state) + llm.reset_session(user_id) + + state_before = llm.get_user_state_summary(user_id) + print(f"[Session] z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}") + + for turn_id, q_info in enumerate(queries): + query = q_info["query"] + query_type = q_info.get("type", "task") + task_type = q_info.get("task_type", "general") + + print(f"\n{'─'*60}") + print(f"Session {session_id} / Turn {turn_id} [{query_type}]") + print(f"{'─'*60}") + print(f"[Query] {query}") + + # Capture reveal state BEFORE this turn + reveal_before = reveal_state.to_dict() + + # Update reveal state based on query content + newly_revealed = update_reveal_state(reveal_state, query) + if newly_revealed: + print(f"[Reveal] Newly revealed: {newly_revealed}") + print(f"[Reveal] State: {reveal_state}") + + # Capture reveal state AFTER update + reveal_after = reveal_state.to_dict() + + # Get user state before + state_before = llm.get_user_state_summary(user_id) + z_long_before = state_before["z_long_norm"] + z_short_before = state_before["z_short_norm"] + + # Apply feedback for previous turn (from turn 1 onwards in this session) + if turn_id > 0 and len(logs) > 0: + # Find the last log from THIS session + session_logs = [l for l in logs if l.session_id == session_id] + if session_logs: + prev_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=prev_log.turn_id, + reward=prev_log.reward, + gating=prev_log.gating, + meta={ + "sat_t": prev_log.sat_t, + "violations": prev_log.violations, + "source": 
"pilot_v2", + "session_id": session_id, + } + ) + print(f"[Feedback] turn={prev_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + # Chat + resp: AssistantResponse = llm.chat(user_id, query) + + answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer + print(f"[Answer] ({len(resp.answer)} chars) {answer_display}") + print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}") + + # Judge with reveal-aware logic + judge_result = style_judge_with_reveal(query, resp.answer, task_type, prefs, reveal_state) + print(f"[Judge] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}") + if judge_result.violations: + print(f"[Judge] violations={judge_result.violations}") + + # Compute feedback + reward, gating = compute_feedback_for_turn( + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + judge_result=judge_result, + ) + print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f}") + + # Get state after + state_after = llm.get_user_state_summary(user_id) + z_long_after = state_after["z_long_norm"] + z_short_after = state_after["z_short_norm"] + + z_long_delta = z_long_after - z_long_before + z_short_delta = z_short_after - z_short_before + print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})") + print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})") + + # Debug info + num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0 + num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0 + print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}") + + # Log + log = TurnLog( + session_id=session_id, + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + answer=resp.answer, + answer_length=len(resp.answer), + sat_t=judge_result.sat_t, + sev_t=judge_result.sev_t, + prog_t=judge_result.prog_t, + violations=judge_result.violations, + enforced_constraints=judge_result.enforced_constraints, + reward=reward, + gating=gating, + reveal_state_before=reveal_before, + reveal_state_after=reveal_after, + newly_revealed=list(newly_revealed), + z_long_norm_before=z_long_before, + z_long_norm_after=z_long_after, + z_short_norm_before=z_short_before, + z_short_norm_after=z_short_after, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + total_tokens=resp.usage.total_tokens, + num_memories_retrieved=num_memories, + num_prefs_extracted=num_prefs, + ) + logs.append(log) + + # Apply final feedback for this session + session_logs = [l for l in logs if l.session_id == session_id] + if session_logs: + last_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=last_log.turn_id, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v2", "session_id": session_id, "final": True} + ) + print(f"\n[Final Feedback] turn={last_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + print(f"\n[Session {session_id} End] Reveal state: {reveal_state}") + + return logs + + +def run_pilot_v2( + llm: PersonalizedLLM, + user_id: str = "pilot_user_v2", + prefs: Optional[StylePrefs] = None, +) -> List[TurnLog]: + """ + Run multi-session pilot with reveal state tracking. 
+ + Session 1: User reveals preferences + Session 2: User does NOT restate preferences (tests cross-session retention) + Session 3: Mix of tasks and reminders + """ + if prefs is None: + prefs = StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="en", + ) + + # Initialize reveal state manager + reveal_manager = RevealStateManager() + + print(f"\n{'#'*60}") + print(f"PILOT v2: CROSS-SESSION PREFERENCE REVEAL TEST") + print(f"User: {user_id}") + print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}") + print(f"{'#'*60}") + + # Reset user completely (clears all state including reveal) + print(f"\n[Pilot] Resetting user: {user_id}") + llm.reset_user(user_id) + reveal_manager.reset_user(user_id) + + all_logs: List[TurnLog] = [] + reveal_state = reveal_manager.get_state(user_id) + + # Session 1: Reveal preferences + session_1_queries = get_session_1_queries() + logs_s1 = run_session(llm, user_id, 1, prefs, reveal_state, session_1_queries) + all_logs.extend(logs_s1) + + # Session 2: NO preference restatement (test cross-session retention) + # Note: reveal_state persists, but reset_session clears history + reveal_manager.reset_session(user_id) # Does nothing to reveal state + session_2_queries = get_session_2_queries() + logs_s2 = run_session(llm, user_id, 2, prefs, reveal_state, session_2_queries) + all_logs.extend(logs_s2) + + # Session 3: Reminder and more tasks + reveal_manager.reset_session(user_id) + session_3_queries = get_session_3_queries() + logs_s3 = run_session(llm, user_id, 3, prefs, reveal_state, session_3_queries) + all_logs.extend(logs_s3) + + return all_logs + + +def print_summary_v2(logs: List[TurnLog], prefs: StylePrefs): + """Print summary for pilot v2.""" + print(f"\n{'='*60}") + print("PILOT v2 SUMMARY - Cross-Session Reveal") + print(f"{'='*60}") + + if not logs: + print("No logs to summarize.") + return + + # Per-session stats + sessions = sorted(set(l.session_id for l in logs)) + + print(f"\n--- Per-Session Statistics ---") + for sid in sessions: + session_logs = [l for l in logs if l.session_id == sid] + avg_sat = sum(l.sat_t for l in session_logs) / len(session_logs) + violations = [v for l in session_logs for v in l.violations] + + # What was revealed at session end + if session_logs: + final_reveal = session_logs[-1].reveal_state_after + else: + final_reveal = {} + + print(f"\nSession {sid}: {len(session_logs)} turns") + print(f" Avg sat_t: {avg_sat:.3f}") + print(f" Violations: {len(violations)} ({violations if violations else 'none'})") + print(f" Reveal state at end: {final_reveal}") + + # Overall stats + total = len(logs) + avg_sat = sum(l.sat_t for l in logs) / total + total_tokens = sum(l.total_tokens for l in logs) + + print(f"\n--- Overall Statistics ---") + print(f"Total turns: {total}") + print(f"Overall avg sat_t: {avg_sat:.3f}") + print(f"Total tokens: {total_tokens}") + + # Violations by type + print(f"\n--- Violations Breakdown ---") + from collections import Counter + all_violations = [v for l in logs for v in l.violations] + if all_violations: + for v, count in Counter(all_violations).most_common(): + print(f" {v}: {count}") + else: + print(" No violations") + + # Enforcement tracking + print(f"\n--- Constraint Enforcement ---") + for constraint in ["short", "bullets", "lang"]: + enforced_count = sum(1 for l in logs if constraint in l.enforced_constraints) + print(f" {constraint}: enforced in {enforced_count}/{total} turns") + + # Cross-session reveal verification + 
print(f"\n--- Cross-Session Reveal Verification ---") + + # Session 1: Should have some reveals + s1_logs = [l for l in logs if l.session_id == 1] + s1_reveals = set() + for l in s1_logs: + s1_reveals.update(l.newly_revealed) + print(f"Session 1 revealed: {s1_reveals if s1_reveals else 'none'}") + + # Session 2: Should NOT have new reveals (no preference queries) + s2_logs = [l for l in logs if l.session_id == 2] + s2_reveals = set() + for l in s2_logs: + s2_reveals.update(l.newly_revealed) + print(f"Session 2 revealed: {s2_reveals if s2_reveals else 'none (expected)'}") + + # But Session 2 should still ENFORCE the constraints revealed in Session 1 + if s2_logs: + s2_enforced = set() + for l in s2_logs: + s2_enforced.update(l.enforced_constraints) + print(f"Session 2 enforced: {s2_enforced}") + + if s1_reveals and s1_reveals.issubset(s2_enforced): + print("✓ Cross-session retention VERIFIED: Session 1 reveals enforced in Session 2") + else: + print("✗ Cross-session retention issue: some reveals not enforced") + + # Turn-by-turn table + print(f"\n--- Turn-by-Turn Summary ---") + print(f"{'S':>2} {'T':>2} {'Type':>10} {'Len':>5} {'sat':>5} {'enforced':<20} {'violations'}") + print("-" * 70) + for l in logs: + enforced_str = ",".join(l.enforced_constraints) if l.enforced_constraints else "-" + viol_str = ",".join(l.violations) if l.violations else "-" + print(f"{l.session_id:>2} {l.turn_id:>2} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {enforced_str:<20} {viol_str}") + + +def main(): + print("=" * 60) + print("PILOT RUNNER v2 - Cross-Session Preference Reveal") + print("=" * 60) + print(f"Started at: {datetime.now().isoformat()}") + + # Define user's TRUE preferences + prefs = StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="en", + ) + print(f"\n[Config] True preferences: {prefs}") + print("[Config] Note: Constraints only enforced AFTER user reveals them") + + # Initialize LLM + print("\n[Init] Loading PersonalizedLLM...") + llm = PersonalizedLLM( + user_store_path="data/users/user_store_pilot_v2.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=True, + ) + + # Run pilot + user_id = "pilot_user_v2" + logs = run_pilot_v2(llm, user_id=user_id, prefs=prefs) + + # Summary + print_summary_v2(logs, prefs) + + # Save logs + log_path = f"data/logs/pilot_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" + log_to_jsonl(logs, log_path) + print(f"\n[Logs] Saved to: {log_path}") + + # Final state + final_state = llm.get_user_state_summary(user_id) + print(f"\n[Final State] {final_state}") + + print(f"\nCompleted at: {datetime.now().isoformat()}") + print("=" * 60) + + +if __name__ == "__main__": + main() + + |
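The reveal-detection flow in the new file is easy to exercise on its own. The sketch below mirrors detect_revealed_preferences and update_reveal_state with abridged pattern lists; importing the script directly would also pull in personalization.serving, which may not be on your path, so the relevant pieces are re-declared inline.

from dataclasses import dataclass
from typing import Dict, Set

@dataclass
class RevealState:
    # Mirrors the dataclass in the diff: one flag per style preference.
    short_revealed: bool = False
    bullets_revealed: bool = False
    lang_revealed: bool = False

def detect_revealed_preferences(query: str) -> Dict[str, bool]:
    # Abridged pattern lists; the script's full lists also cover
    # character limits ("under ", "200 ", ...) and more phrasings.
    q = (query or "").lower()
    return {
        "short": any(p in q for p in ("short", "concise", "brief")),
        "bullets": any(p in q for p in ("bullet", "list format", "use bullets")),
        "lang": any(p in q for p in ("english", "chinese", "中文")),
    }

def update_reveal_state(state: RevealState, query: str) -> Set[str]:
    # A flag flips at most once; only first-time reveals are reported.
    detected = detect_revealed_preferences(query)
    newly: Set[str] = set()
    if detected["short"] and not state.short_revealed:
        state.short_revealed = True
        newly.add("short")
    if detected["bullets"] and not state.bullets_revealed:
        state.bullets_revealed = True
        newly.add("bullets")
    if detected["lang"] and not state.lang_revealed:
        state.lang_revealed = True
        newly.add("lang")
    return newly

state = RevealState()
assert update_reveal_state(state, "I prefer short, concise answers.") == {"short"}
assert update_reveal_state(state, "Keep it short.") == set()  # already revealed
assert update_reveal_state(state, "I also prefer bullet points.") == {"bullets"}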
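Everything the pilot tests hinges on the difference between reset_user() and reset_session(): reveal flags survive a session reset but not a user reset. A minimal check, reusing the RevealState dataclass from the sketch above; replacing the state object on reset_user is an assumption that is behaviorally equivalent to the script's in-place reset for this check.

class RevealStateManager:
    def __init__(self):
        self._states = {}

    def get_state(self, user_id):
        # Same lazy-create behavior as the manager in the diff.
        return self._states.setdefault(user_id, RevealState())

    def reset_user(self, user_id):
        # The script resets the flags in place; swapping in a fresh
        # object gives the same observable result here.
        self._states[user_id] = RevealState()

    def reset_session(self, user_id):
        # Intentionally a no-op: reveal state persists across sessions.
        pass

mgr = RevealStateManager()
mgr.get_state("u1").short_revealed = True

mgr.reset_session("u1")
assert mgr.get_state("u1").short_revealed       # survives the session boundary

mgr.reset_user("u1")
assert not mgr.get_state("u1").short_revealed   # cleared with the user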
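The heart of style_judge_with_reveal is the enforcement rule (a constraint fires only when the preference is both held and revealed) and the sat_t/sev_t arithmetic. The condensed version below keeps only the length constraint; judge_length is a hypothetical helper, but the 0.3-per-violation penalty and the hard-violation set are copied from the diff.

def score(violations):
    # sat_t / sev_t rules copied from style_judge_with_reveal above.
    if not violations:
        return 1.0, 0.0
    sat = max(0.0, 1.0 - 0.3 * len(violations))
    hard = {"empty_answer", "too_long", "wrong_lang"}
    sev = 1.0 if any(v in hard for v in violations) else 0.0
    return sat, sev

def judge_length(answer, require_short, short_revealed, max_chars=200):
    # Hypothetical helper: the constraint fires only when the
    # preference is BOTH held (pref_true) and revealed.
    violations = []
    if require_short and short_revealed and len(answer) > max_chars:
        violations.append("too_long")
    return score(violations)

long_answer = "x" * 500
# Hidden preference: no penalty even though the answer is over budget.
assert judge_length(long_answer, require_short=True, short_revealed=False) == (1.0, 0.0)
# Revealed preference: too_long counts as a hard violation (sev_t = 1.0).
sat, sev = judge_length(long_answer, require_short=True, short_revealed=True)
assert (round(sat, 2), sev) == (0.7, 1.0)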
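The language check is a blunt ASCII-ratio heuristic: output fails for a Chinese-preferring user above 70% ASCII, and for an English-preferring user below 50%. Distilled below with the thresholds taken from the diff; wrong_lang is a hypothetical wrapper.

def wrong_lang(text, lang):
    # Thresholds (0.7 for zh, 0.5 for en) copied from the judge above.
    ascii_ratio = sum(c.isascii() for c in text) / max(1, len(text))
    if lang == "zh":
        return ascii_ratio > 0.7
    if lang == "en":
        return ascii_ratio < 0.5
    return False

assert wrong_lang("This is an English answer.", "zh")   # too much ASCII for zh
assert not wrong_lang("This is an English answer.", "en")
assert wrong_lang("这是中文回答。", "en")                  # too little ASCII for en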
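compute_feedback_for_turn maps the judge result onto (reward, gating): reward is simply sat_t, and gating is 1.0 only on turns where the user states, restates, or complains about a preference. A condensed turn classifier with the marker phrases copied from the diff; gating_for is a hypothetical name.

PREF_MARKERS = (
    "i prefer", "my preference", "please use", "please keep",
    "you didn't follow", "you forgot", "remember that i",
    "i told you", "i asked for",
)

def gating_for(query, query_type):
    # A turn gates an update if it is labeled "preference" or
    # contains any explicit preference/complaint phrase.
    q = (query or "").lower()
    is_pref = query_type == "preference" or any(m in q for m in PREF_MARKERS)
    return 1.0 if is_pref else 0.0

assert gating_for("I prefer short answers.", "task") == 1.0
assert gating_for("What is 2 + 2?", "task") == 0.0
assert gating_for("Name five programming languages.", "preference") == 1.0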
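Finally, the cross-session check in print_summary_v2 boils down to a subset test: everything revealed in Session 1 must appear among the constraints enforced in Session 2. On toy log dicts standing in for TurnLog entries:

s1_logs = [{"newly_revealed": ["short"]}, {"newly_revealed": ["bullets"]}]
s2_logs = [
    {"enforced_constraints": ["short", "bullets"]},
    {"enforced_constraints": ["short"]},
]

s1_reveals = {r for log in s1_logs for r in log["newly_revealed"]}
s2_enforced = {c for log in s2_logs for c in log["enforced_constraints"]}

# Same retention check print_summary_v2 performs.
assert s1_reveals.issubset(s2_enforced)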
