#!/usr/bin/env python3 """ Analyze Full vs NoPersonal Baseline Comparison. This script loads logs from pilot_runner_v4 runs (both full and nopersonal modes) and produces comparison metrics for: 1. Session 2 retention (base task avg satisfaction) 2. Violation rates by type 3. Preference memory recall@k Usage: python scripts/analyze_full_vs_nopersonal.py \ --full data/logs/pilot_v4_full_TIMESTAMP.jsonl \ --nopersonal data/logs/pilot_v4_nopersonal_TIMESTAMP.jsonl """ import json import argparse import re from dataclasses import dataclass from typing import List, Dict, Any, Optional, Set from collections import defaultdict @dataclass class TurnLog: """Parsed log entry.""" user_id: str persona_id: str session_id: int turn_id: int query: str query_type: str task_type: str answer: str answer_length: int sat_t: float sev_t: float prog_t: float violations: List[str] enforced_constraints: List[str] reward: float gating: float is_complaint: bool reveal_state_before: Dict[str, bool] reveal_state_after: Dict[str, bool] newly_revealed: List[str] z_long_norm_before: float z_long_norm_after: float z_short_norm_before: float z_short_norm_after: float prompt_tokens: int completion_tokens: int total_tokens: int num_memories_retrieved: int num_prefs_extracted: int selected_memory_ids: List[str] selected_memory_notes: List[str] selected_memory_scores: List[float] num_candidates: int num_total_memories: int mode: str eval_mode: bool # True = greedy, False = sample def load_logs(filepath: str) -> List[TurnLog]: """Load logs from JSONL file.""" logs = [] with open(filepath, "r") as f: for line in f: if line.strip(): data = json.loads(line) # Handle missing fields with defaults log = TurnLog( user_id=data.get("user_id", ""), persona_id=data.get("persona_id", ""), session_id=data.get("session_id", 0), turn_id=data.get("turn_id", 0), query=data.get("query", ""), query_type=data.get("query_type", ""), task_type=data.get("task_type", ""), answer=data.get("answer", ""), answer_length=data.get("answer_length", 0), sat_t=data.get("sat_t", 0.0), sev_t=data.get("sev_t", 0.0), prog_t=data.get("prog_t", 0.0), violations=data.get("violations", []), enforced_constraints=data.get("enforced_constraints", []), reward=data.get("reward", 0.0), gating=data.get("gating", 0.0), is_complaint=data.get("is_complaint", False), reveal_state_before=data.get("reveal_state_before", {}), reveal_state_after=data.get("reveal_state_after", {}), newly_revealed=data.get("newly_revealed", []), z_long_norm_before=data.get("z_long_norm_before", 0.0), z_long_norm_after=data.get("z_long_norm_after", 0.0), z_short_norm_before=data.get("z_short_norm_before", 0.0), z_short_norm_after=data.get("z_short_norm_after", 0.0), prompt_tokens=data.get("prompt_tokens", 0), completion_tokens=data.get("completion_tokens", 0), total_tokens=data.get("total_tokens", 0), num_memories_retrieved=data.get("num_memories_retrieved", 0), num_prefs_extracted=data.get("num_prefs_extracted", 0), selected_memory_ids=data.get("selected_memory_ids", []), selected_memory_notes=data.get("selected_memory_notes", []), selected_memory_scores=data.get("selected_memory_scores", []), num_candidates=data.get("num_candidates", 0), num_total_memories=data.get("num_total_memories", 0), mode=data.get("mode", "unknown"), eval_mode=data.get("eval_mode", True), ) logs.append(log) return logs def is_base_task_turn(log: TurnLog) -> bool: """Check if this is a base task turn (not complaint, not preference).""" if log.is_complaint: return False if log.query_type == "preference": return False if log.query_type in ("task", "task_list"): return True return False def compute_session2_base_avg_sat(logs: List[TurnLog]) -> Dict[str, float]: """ Compute average satisfaction for Session 2 base tasks. Returns: {user_id: avg_sat} """ user_sat = defaultdict(list) for log in logs: if log.session_id == 2 and is_base_task_turn(log): user_sat[log.user_id].append(log.sat_t) result = {} for user_id, sats in user_sat.items(): if sats: result[user_id] = sum(sats) / len(sats) return result def compute_overall_session2_avg_sat(logs: List[TurnLog]) -> float: """Compute overall average satisfaction for Session 2 base tasks.""" sats = [] for log in logs: if log.session_id == 2 and is_base_task_turn(log): sats.append(log.sat_t) return sum(sats) / len(sats) if sats else 0.0 def compute_violation_rates(logs: List[TurnLog], session_filter: Optional[int] = None) -> Dict[str, float]: """ Compute violation rates by type. Returns: {violation_type: rate} """ violation_counts = defaultdict(int) total_base_tasks = 0 for log in logs: if session_filter is not None and log.session_id != session_filter: continue if not is_base_task_turn(log): continue total_base_tasks += 1 for v in log.violations: violation_counts[v] += 1 if total_base_tasks == 0: return {} return {v: count / total_base_tasks for v, count in violation_counts.items()} def is_pref_memory(note_text: str, dim: str) -> bool: """ Check if a memory note relates to a preference dimension. dim: "short", "bullets", or "lang" """ text_lower = note_text.lower() if dim == "short": keywords = [ "short", "concise", "brief", "200", "characters", "less", "简短", "精简", "字以内", "不超过", "简洁" ] return any(kw in text_lower for kw in keywords) elif dim == "bullets": keywords = [ "bullet", "bullets", "list", "point", "points", "要点", "列表", "项目符号" ] # Also check for "no bullet" / "don't use bullet" no_bullet = any(x in text_lower for x in ["no bullet", "don't use bullet", "without bullet", "不要要点", "不使用列表"]) if no_bullet: return True # It's still about bullets preference return any(kw in text_lower for kw in keywords) elif dim == "lang": # Check for language preferences zh_keywords = ["chinese", "中文", "用中文", "请用中文"] en_keywords = ["english", "英文", "in english"] return any(kw in text_lower for kw in zh_keywords + en_keywords) return False def compute_pref_recall_at_k(logs: List[TurnLog], dim: str, session_filter: Optional[int] = None) -> float: """ Compute preference memory recall@k for a given dimension. Returns: fraction of base task turns where a relevant pref memory was retrieved. """ hits = 0 total = 0 for log in logs: if session_filter is not None and log.session_id != session_filter: continue if not is_base_task_turn(log): continue total += 1 # Check if any selected memory note matches the dimension for note in log.selected_memory_notes: if is_pref_memory(note, dim): hits += 1 break return hits / total if total > 0 else 0.0 def print_comparison_table(full_logs: List[TurnLog], nopersonal_logs: List[TurnLog]): """Print a comparison table of Full vs NoPersonal metrics.""" # Detect mode from logs full_mode = full_logs[0].mode if full_logs else "unknown" full_eval = "greedy" if (full_logs and full_logs[0].eval_mode) else "sample" np_mode = nopersonal_logs[0].mode if nopersonal_logs else "unknown" np_eval = "greedy" if (nopersonal_logs and nopersonal_logs[0].eval_mode) else "sample" print("\n" + "=" * 70) print("FULL vs NOPERSONAL COMPARISON") print(f"Full: mode={full_mode}, selection={full_eval}") print(f"NoPersonal: mode={np_mode}, selection={np_eval}") print("=" * 70) # 1. Session 2 Base Task Average Satisfaction print("\n### 1. Session 2 Base Task Average Satisfaction") print("-" * 50) full_s2_sat = compute_overall_session2_avg_sat(full_logs) nopersonal_s2_sat = compute_overall_session2_avg_sat(nopersonal_logs) delta = full_s2_sat - nopersonal_s2_sat print(f"{'Metric':<30} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") print("-" * 50) print(f"{'avg_sat_S2_base':<30} {full_s2_sat:<12.4f} {nopersonal_s2_sat:<12.4f} {delta:<+12.4f}") # Per-user breakdown full_user_sat = compute_session2_base_avg_sat(full_logs) nopersonal_user_sat = compute_session2_base_avg_sat(nopersonal_logs) print("\nPer-user Session 2 avg_sat:") print(f"{'User':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") print("-" * 50) all_users = set(full_user_sat.keys()) | set(nopersonal_user_sat.keys()) for user_id in sorted(all_users): f_sat = full_user_sat.get(user_id, 0.0) n_sat = nopersonal_user_sat.get(user_id, 0.0) d = f_sat - n_sat print(f"{user_id:<20} {f_sat:<12.4f} {n_sat:<12.4f} {d:<+12.4f}") # 2. Violation Rates print("\n### 2. Session 2 Violation Rates") print("-" * 50) full_viol = compute_violation_rates(full_logs, session_filter=2) nopersonal_viol = compute_violation_rates(nopersonal_logs, session_filter=2) all_viols = set(full_viol.keys()) | set(nopersonal_viol.keys()) key_viols = ["too_long", "no_bullets", "has_bullets", "wrong_lang", "empty_answer"] print(f"{'Violation Type':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") print("-" * 50) for v in key_viols: if v in all_viols: f_rate = full_viol.get(v, 0.0) n_rate = nopersonal_viol.get(v, 0.0) d = f_rate - n_rate print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}") # Other violations other_viols = [v for v in all_viols if v not in key_viols] for v in sorted(other_viols): f_rate = full_viol.get(v, 0.0) n_rate = nopersonal_viol.get(v, 0.0) d = f_rate - n_rate print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}") # 3. Preference Memory Recall@k print("\n### 3. Session 2 Preference Memory Recall@k") print("-" * 50) dims = ["short", "bullets", "lang"] print(f"{'Dimension':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") print("-" * 50) for dim in dims: f_recall = compute_pref_recall_at_k(full_logs, dim, session_filter=2) n_recall = compute_pref_recall_at_k(nopersonal_logs, dim, session_filter=2) d = f_recall - n_recall print(f"{dim:<20} {f_recall:<12.4f} {n_recall:<12.4f} {d:<+12.4f}") # 4. Summary Statistics print("\n### 4. Summary Statistics") print("-" * 50) def count_base_tasks(logs, session=None): return sum(1 for l in logs if (session is None or l.session_id == session) and is_base_task_turn(l)) def count_complaints(logs, session=None): return sum(1 for l in logs if (session is None or l.session_id == session) and l.is_complaint) print(f"{'Statistic':<30} {'Full':<12} {'NoPersonal':<12}") print("-" * 50) print(f"{'Total turns':<30} {len(full_logs):<12} {len(nopersonal_logs):<12}") print(f"{'S2 base task turns':<30} {count_base_tasks(full_logs, 2):<12} {count_base_tasks(nopersonal_logs, 2):<12}") print(f"{'S2 complaint turns':<30} {count_complaints(full_logs, 2):<12} {count_complaints(nopersonal_logs, 2):<12}") # Token usage full_tokens = sum(l.total_tokens for l in full_logs) nopersonal_tokens = sum(l.total_tokens for l in nopersonal_logs) print(f"{'Total tokens':<30} {full_tokens:<12} {nopersonal_tokens:<12}") print("\n" + "=" * 70) def main(): parser = argparse.ArgumentParser(description="Analyze Full vs NoPersonal Comparison") parser.add_argument("--full", type=str, required=True, help="Path to Full mode log file") parser.add_argument("--nopersonal", type=str, required=True, help="Path to NoPersonal mode log file") args = parser.parse_args() print(f"Loading Full logs from: {args.full}") full_logs = load_logs(args.full) print(f" Loaded {len(full_logs)} turns") print(f"Loading NoPersonal logs from: {args.nopersonal}") nopersonal_logs = load_logs(args.nopersonal) print(f" Loaded {len(nopersonal_logs)} turns") print_comparison_table(full_logs, nopersonal_logs) if __name__ == "__main__": main()