| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/analyze_full_vs_nopersonal.py | |
Diffstat (limited to 'scripts/analyze_full_vs_nopersonal.py')
| -rw-r--r-- | scripts/analyze_full_vs_nopersonal.py | 361 |
1 file changed, 361 insertions, 0 deletions
diff --git a/scripts/analyze_full_vs_nopersonal.py b/scripts/analyze_full_vs_nopersonal.py
new file mode 100644
index 0000000..a10f3ef
--- /dev/null
+++ b/scripts/analyze_full_vs_nopersonal.py
@@ -0,0 +1,361 @@
#!/usr/bin/env python3
"""
Analyze Full vs NoPersonal Baseline Comparison.

This script loads logs from pilot_runner_v4 runs (both full and nopersonal modes)
and produces comparison metrics for:
1. Session 2 retention (base task avg satisfaction)
2. Violation rates by type
3. Preference memory recall@k

Usage:
    python scripts/analyze_full_vs_nopersonal.py \
        --full data/logs/pilot_v4_full_TIMESTAMP.jsonl \
        --nopersonal data/logs/pilot_v4_nopersonal_TIMESTAMP.jsonl
"""

import json
import argparse
import re
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Set
from collections import defaultdict


@dataclass
class TurnLog:
    """Parsed log entry."""
    user_id: str
    persona_id: str
    session_id: int
    turn_id: int
    query: str
    query_type: str
    task_type: str
    answer: str
    answer_length: int
    sat_t: float
    sev_t: float
    prog_t: float
    violations: List[str]
    enforced_constraints: List[str]
    reward: float
    gating: float
    is_complaint: bool
    reveal_state_before: Dict[str, bool]
    reveal_state_after: Dict[str, bool]
    newly_revealed: List[str]
    z_long_norm_before: float
    z_long_norm_after: float
    z_short_norm_before: float
    z_short_norm_after: float
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    num_memories_retrieved: int
    num_prefs_extracted: int
    selected_memory_ids: List[str]
    selected_memory_notes: List[str]
    selected_memory_scores: List[float]
    num_candidates: int
    num_total_memories: int
    mode: str
    eval_mode: bool  # True = greedy, False = sample


def load_logs(filepath: str) -> List[TurnLog]:
    """Load logs from JSONL file."""
    logs = []
    with open(filepath, "r") as f:
        for line in f:
            if line.strip():
                data = json.loads(line)
                # Handle missing fields with defaults
                log = TurnLog(
                    user_id=data.get("user_id", ""),
                    persona_id=data.get("persona_id", ""),
                    session_id=data.get("session_id", 0),
                    turn_id=data.get("turn_id", 0),
                    query=data.get("query", ""),
                    query_type=data.get("query_type", ""),
                    task_type=data.get("task_type", ""),
                    answer=data.get("answer", ""),
                    answer_length=data.get("answer_length", 0),
                    sat_t=data.get("sat_t", 0.0),
                    sev_t=data.get("sev_t", 0.0),
                    prog_t=data.get("prog_t", 0.0),
                    violations=data.get("violations", []),
                    enforced_constraints=data.get("enforced_constraints", []),
                    reward=data.get("reward", 0.0),
                    gating=data.get("gating", 0.0),
                    is_complaint=data.get("is_complaint", False),
                    reveal_state_before=data.get("reveal_state_before", {}),
                    reveal_state_after=data.get("reveal_state_after", {}),
                    newly_revealed=data.get("newly_revealed", []),
                    z_long_norm_before=data.get("z_long_norm_before", 0.0),
                    z_long_norm_after=data.get("z_long_norm_after", 0.0),
                    z_short_norm_before=data.get("z_short_norm_before", 0.0),
                    z_short_norm_after=data.get("z_short_norm_after", 0.0),
                    prompt_tokens=data.get("prompt_tokens", 0),
                    completion_tokens=data.get("completion_tokens", 0),
                    total_tokens=data.get("total_tokens", 0),
                    num_memories_retrieved=data.get("num_memories_retrieved", 0),
                    num_prefs_extracted=data.get("num_prefs_extracted", 0),
                    selected_memory_ids=data.get("selected_memory_ids", []),
                    selected_memory_notes=data.get("selected_memory_notes", []),
                    selected_memory_scores=data.get("selected_memory_scores", []),
                    num_candidates=data.get("num_candidates", 0),
                    num_total_memories=data.get("num_total_memories", 0),
                    mode=data.get("mode", "unknown"),
                    eval_mode=data.get("eval_mode", True),
                )
                logs.append(log)
    return logs
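
# A sketch of one input record that load_logs() above would parse; the field
# values are made up for illustration (any key missing from a record falls
# back to the .get() default used in the constructor call):
#
#   {"user_id": "user_01", "persona_id": "persona_short_zh", "session_id": 2,
#    "turn_id": 5, "query_type": "task", "task_type": "summary",
#    "sat_t": 0.85, "violations": ["too_long"], "is_complaint": false,
#    "selected_memory_notes": ["User prefers answers under 200 characters"],
#    "mode": "full", "eval_mode": true}
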
def is_base_task_turn(log: TurnLog) -> bool:
    """Check if this is a base task turn (not complaint, not preference)."""
    if log.is_complaint:
        return False
    if log.query_type == "preference":
        return False
    if log.query_type in ("task", "task_list"):
        return True
    return False


def compute_session2_base_avg_sat(logs: List[TurnLog]) -> Dict[str, float]:
    """
    Compute average satisfaction for Session 2 base tasks.
    Returns: {user_id: avg_sat}
    """
    user_sat = defaultdict(list)

    for log in logs:
        if log.session_id == 2 and is_base_task_turn(log):
            user_sat[log.user_id].append(log.sat_t)

    result = {}
    for user_id, sats in user_sat.items():
        if sats:
            result[user_id] = sum(sats) / len(sats)

    return result


def compute_overall_session2_avg_sat(logs: List[TurnLog]) -> float:
    """Compute overall average satisfaction for Session 2 base tasks."""
    sats = []
    for log in logs:
        if log.session_id == 2 and is_base_task_turn(log):
            sats.append(log.sat_t)
    return sum(sats) / len(sats) if sats else 0.0


def compute_violation_rates(logs: List[TurnLog], session_filter: Optional[int] = None) -> Dict[str, float]:
    """
    Compute violation rates by type.
    Returns: {violation_type: rate}
    """
    violation_counts = defaultdict(int)
    total_base_tasks = 0

    for log in logs:
        if session_filter is not None and log.session_id != session_filter:
            continue
        if not is_base_task_turn(log):
            continue

        total_base_tasks += 1
        for v in log.violations:
            violation_counts[v] += 1

    if total_base_tasks == 0:
        return {}

    return {v: count / total_base_tasks for v, count in violation_counts.items()}
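
# Worked example for compute_violation_rates() (hypothetical counts): with
# session_filter=2 and 20 Session 2 base-task turns, of which 3 list
# "too_long" and 1 lists "wrong_lang", the function returns
# {"too_long": 0.15, "wrong_lang": 0.05}. A turn may carry several violation
# types at once, so rates are per-type and need not sum to 1.
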
def is_pref_memory(note_text: str, dim: str) -> bool:
    """
    Check if a memory note relates to a preference dimension.
    dim: "short", "bullets", or "lang"
    """
    text_lower = note_text.lower()

    if dim == "short":
        keywords = [
            "short", "concise", "brief", "200", "characters", "less",
            "简短", "精简", "字以内", "不超过", "简洁"
        ]
        return any(kw in text_lower for kw in keywords)

    elif dim == "bullets":
        keywords = [
            "bullet", "bullets", "list", "point", "points",
            "要点", "列表", "项目符号"
        ]
        # Also check for "no bullet" / "don't use bullet"
        no_bullet = any(x in text_lower for x in ["no bullet", "don't use bullet", "without bullet", "不要要点", "不使用列表"])
        if no_bullet:
            return True  # It's still about the bullets preference
        return any(kw in text_lower for kw in keywords)

    elif dim == "lang":
        # Check for language preferences
        zh_keywords = ["chinese", "中文", "用中文", "请用中文"]
        en_keywords = ["english", "英文", "in english"]
        return any(kw in text_lower for kw in zh_keywords + en_keywords)

    return False


def compute_pref_recall_at_k(logs: List[TurnLog], dim: str, session_filter: Optional[int] = None) -> float:
    """
    Compute preference memory recall@k for a given dimension.
    Returns: fraction of base task turns where a relevant pref memory was retrieved.
    """
    hits = 0
    total = 0

    for log in logs:
        if session_filter is not None and log.session_id != session_filter:
            continue
        if not is_base_task_turn(log):
            continue

        total += 1
        # Check if any selected memory note matches the dimension
        for note in log.selected_memory_notes:
            if is_pref_memory(note, dim):
                hits += 1
                break

    return hits / total if total > 0 else 0.0
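
# Worked example for compute_pref_recall_at_k() (hypothetical counts): over
# 20 Session 2 base-task turns, if at least one note in
# selected_memory_notes matches dim="short" on 15 of them, recall@k is
# 15 / 20 = 0.75. The "k" is implicit: it equals however many memories the
# runner selected for that turn (the length of selected_memory_notes).
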
def print_comparison_table(full_logs: List[TurnLog], nopersonal_logs: List[TurnLog]):
    """Print a comparison table of Full vs NoPersonal metrics."""

    # Detect mode from logs
    full_mode = full_logs[0].mode if full_logs else "unknown"
    full_eval = "greedy" if (full_logs and full_logs[0].eval_mode) else "sample"
    np_mode = nopersonal_logs[0].mode if nopersonal_logs else "unknown"
    np_eval = "greedy" if (nopersonal_logs and nopersonal_logs[0].eval_mode) else "sample"

    print("\n" + "=" * 70)
    print("FULL vs NOPERSONAL COMPARISON")
    print(f"Full: mode={full_mode}, selection={full_eval}")
    print(f"NoPersonal: mode={np_mode}, selection={np_eval}")
    print("=" * 70)

    # 1. Session 2 Base Task Average Satisfaction
    print("\n### 1. Session 2 Base Task Average Satisfaction")
    print("-" * 50)

    full_s2_sat = compute_overall_session2_avg_sat(full_logs)
    nopersonal_s2_sat = compute_overall_session2_avg_sat(nopersonal_logs)
    delta = full_s2_sat - nopersonal_s2_sat

    print(f"{'Metric':<30} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
    print("-" * 50)
    print(f"{'avg_sat_S2_base':<30} {full_s2_sat:<12.4f} {nopersonal_s2_sat:<12.4f} {delta:<+12.4f}")

    # Per-user breakdown
    full_user_sat = compute_session2_base_avg_sat(full_logs)
    nopersonal_user_sat = compute_session2_base_avg_sat(nopersonal_logs)

    print("\nPer-user Session 2 avg_sat:")
    print(f"{'User':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
    print("-" * 50)
    all_users = set(full_user_sat.keys()) | set(nopersonal_user_sat.keys())
    for user_id in sorted(all_users):
        f_sat = full_user_sat.get(user_id, 0.0)
        n_sat = nopersonal_user_sat.get(user_id, 0.0)
        d = f_sat - n_sat
        print(f"{user_id:<20} {f_sat:<12.4f} {n_sat:<12.4f} {d:<+12.4f}")

    # 2. Violation Rates
    print("\n### 2. Session 2 Violation Rates")
    print("-" * 50)

    full_viol = compute_violation_rates(full_logs, session_filter=2)
    nopersonal_viol = compute_violation_rates(nopersonal_logs, session_filter=2)

    all_viols = set(full_viol.keys()) | set(nopersonal_viol.keys())
    key_viols = ["too_long", "no_bullets", "has_bullets", "wrong_lang", "empty_answer"]

    print(f"{'Violation Type':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
    print("-" * 50)
    for v in key_viols:
        if v in all_viols:
            f_rate = full_viol.get(v, 0.0)
            n_rate = nopersonal_viol.get(v, 0.0)
            d = f_rate - n_rate
            print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}")

    # Other violations
    other_viols = [v for v in all_viols if v not in key_viols]
    for v in sorted(other_viols):
        f_rate = full_viol.get(v, 0.0)
        n_rate = nopersonal_viol.get(v, 0.0)
        d = f_rate - n_rate
        print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}")

    # 3. Preference Memory Recall@k
    print("\n### 3. Session 2 Preference Memory Recall@k")
    print("-" * 50)

    dims = ["short", "bullets", "lang"]
    print(f"{'Dimension':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
    print("-" * 50)
    for dim in dims:
        f_recall = compute_pref_recall_at_k(full_logs, dim, session_filter=2)
        n_recall = compute_pref_recall_at_k(nopersonal_logs, dim, session_filter=2)
        d = f_recall - n_recall
        print(f"{dim:<20} {f_recall:<12.4f} {n_recall:<12.4f} {d:<+12.4f}")

    # 4. Summary Statistics
    print("\n### 4. Summary Statistics")
    print("-" * 50)

    def count_base_tasks(logs, session=None):
        return sum(1 for l in logs if (session is None or l.session_id == session) and is_base_task_turn(l))

    def count_complaints(logs, session=None):
        return sum(1 for l in logs if (session is None or l.session_id == session) and l.is_complaint)

    print(f"{'Statistic':<30} {'Full':<12} {'NoPersonal':<12}")
    print("-" * 50)
    print(f"{'Total turns':<30} {len(full_logs):<12} {len(nopersonal_logs):<12}")
    print(f"{'S2 base task turns':<30} {count_base_tasks(full_logs, 2):<12} {count_base_tasks(nopersonal_logs, 2):<12}")
    print(f"{'S2 complaint turns':<30} {count_complaints(full_logs, 2):<12} {count_complaints(nopersonal_logs, 2):<12}")

    # Token usage
    full_tokens = sum(l.total_tokens for l in full_logs)
    nopersonal_tokens = sum(l.total_tokens for l in nopersonal_logs)
    print(f"{'Total tokens':<30} {full_tokens:<12} {nopersonal_tokens:<12}")

    print("\n" + "=" * 70)


def main():
    parser = argparse.ArgumentParser(description="Analyze Full vs NoPersonal Comparison")
    parser.add_argument("--full", type=str, required=True, help="Path to Full mode log file")
    parser.add_argument("--nopersonal", type=str, required=True, help="Path to NoPersonal mode log file")
    args = parser.parse_args()

    print(f"Loading Full logs from: {args.full}")
    full_logs = load_logs(args.full)
    print(f"  Loaded {len(full_logs)} turns")

    print(f"Loading NoPersonal logs from: {args.nopersonal}")
    nopersonal_logs = load_logs(args.nopersonal)
    print(f"  Loaded {len(nopersonal_logs)} turns")

    print_comparison_table(full_logs, nopersonal_logs)


if __name__ == "__main__":
    main()
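
# For orientation, the first block of the report printed by
# print_comparison_table() looks roughly like this (all numbers below are
# illustrative placeholders, not results from a real run):
#
#   ======================================================================
#   FULL vs NOPERSONAL COMPARISON
#   Full: mode=full, selection=greedy
#   NoPersonal: mode=nopersonal, selection=greedy
#   ======================================================================
#
#   ### 1. Session 2 Base Task Average Satisfaction
#   --------------------------------------------------
#   Metric                         Full         NoPersonal   Delta
#   --------------------------------------------------
#   avg_sat_S2_base                0.8125       0.6300       +0.1825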
