author     YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
commit     e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree       6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/analyze_full_vs_nopersonal.py
Initial commit (clean history) (HEAD, main)
Diffstat (limited to 'scripts/analyze_full_vs_nopersonal.py')
-rw-r--r--  scripts/analyze_full_vs_nopersonal.py  361
1 file changed, 361 insertions(+), 0 deletions(-)
diff --git a/scripts/analyze_full_vs_nopersonal.py b/scripts/analyze_full_vs_nopersonal.py
new file mode 100644
index 0000000..a10f3ef
--- /dev/null
+++ b/scripts/analyze_full_vs_nopersonal.py
@@ -0,0 +1,361 @@
+#!/usr/bin/env python3
+"""
+Analyze Full vs NoPersonal Baseline Comparison.
+
+This script loads logs from pilot_runner_v4 runs (both full and nopersonal modes)
+and produces comparison metrics for:
+1. Session 2 retention (base task avg satisfaction)
+2. Violation rates by type
+3. Preference memory recall@k
+
+Usage:
+ python scripts/analyze_full_vs_nopersonal.py \
+ --full data/logs/pilot_v4_full_TIMESTAMP.jsonl \
+ --nopersonal data/logs/pilot_v4_nopersonal_TIMESTAMP.jsonl
+"""
+
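+# Example input record (a hedged sketch of the JSONL schema this script expects;
+# the keys mirror the TurnLog fields below, but the values are illustrative, not
+# taken from a real run):
+#   {"user_id": "u01", "persona_id": "p_short_zh", "session_id": 2, "turn_id": 5,
+#    "query_type": "task", "sat_t": 0.72, "violations": ["too_long"],
+#    "selected_memory_notes": ["User prefers concise answers"],
+#    "mode": "full", "eval_mode": true, ...}
+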
+import json
+import argparse
+from dataclasses import dataclass
+from typing import List, Dict, Optional
+from collections import defaultdict
+
+
+@dataclass
+class TurnLog:
+ """Parsed log entry."""
+ user_id: str
+ persona_id: str
+ session_id: int
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+ reward: float
+ gating: float
+ is_complaint: bool
+ reveal_state_before: Dict[str, bool]
+ reveal_state_after: Dict[str, bool]
+ newly_revealed: List[str]
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+ selected_memory_ids: List[str]
+ selected_memory_notes: List[str]
+ selected_memory_scores: List[float]
+ num_candidates: int
+ num_total_memories: int
+ mode: str
+ eval_mode: bool # True = greedy, False = sample
+
+
+def load_logs(filepath: str) -> List[TurnLog]:
+ """Load logs from JSONL file."""
+ logs = []
+ with open(filepath, "r") as f:
+ for line in f:
+ if line.strip():
+ data = json.loads(line)
+ # Handle missing fields with defaults
+ log = TurnLog(
+ user_id=data.get("user_id", ""),
+ persona_id=data.get("persona_id", ""),
+ session_id=data.get("session_id", 0),
+ turn_id=data.get("turn_id", 0),
+ query=data.get("query", ""),
+ query_type=data.get("query_type", ""),
+ task_type=data.get("task_type", ""),
+ answer=data.get("answer", ""),
+ answer_length=data.get("answer_length", 0),
+ sat_t=data.get("sat_t", 0.0),
+ sev_t=data.get("sev_t", 0.0),
+ prog_t=data.get("prog_t", 0.0),
+ violations=data.get("violations", []),
+ enforced_constraints=data.get("enforced_constraints", []),
+ reward=data.get("reward", 0.0),
+ gating=data.get("gating", 0.0),
+ is_complaint=data.get("is_complaint", False),
+ reveal_state_before=data.get("reveal_state_before", {}),
+ reveal_state_after=data.get("reveal_state_after", {}),
+ newly_revealed=data.get("newly_revealed", []),
+ z_long_norm_before=data.get("z_long_norm_before", 0.0),
+ z_long_norm_after=data.get("z_long_norm_after", 0.0),
+ z_short_norm_before=data.get("z_short_norm_before", 0.0),
+ z_short_norm_after=data.get("z_short_norm_after", 0.0),
+ prompt_tokens=data.get("prompt_tokens", 0),
+ completion_tokens=data.get("completion_tokens", 0),
+ total_tokens=data.get("total_tokens", 0),
+ num_memories_retrieved=data.get("num_memories_retrieved", 0),
+ num_prefs_extracted=data.get("num_prefs_extracted", 0),
+ selected_memory_ids=data.get("selected_memory_ids", []),
+ selected_memory_notes=data.get("selected_memory_notes", []),
+ selected_memory_scores=data.get("selected_memory_scores", []),
+ num_candidates=data.get("num_candidates", 0),
+ num_total_memories=data.get("num_total_memories", 0),
+ mode=data.get("mode", "unknown"),
+ eval_mode=data.get("eval_mode", True),
+ )
+ logs.append(log)
+ return logs
+
+
+def is_base_task_turn(log: TurnLog) -> bool:
+ """Check if this is a base task turn (not complaint, not preference)."""
+ if log.is_complaint:
+ return False
+ if log.query_type == "preference":
+ return False
+ if log.query_type in ("task", "task_list"):
+ return True
+ return False
+
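+# Example: a turn with query_type "task" or "task_list" and is_complaint == False
+# counts as a base task; "preference" turns (the user stating a preference) are
+# excluded, so the satisfaction/violation metrics below reflect task answers only.
+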
+
+def compute_session2_base_avg_sat(logs: List[TurnLog]) -> Dict[str, float]:
+ """
+ Compute average satisfaction for Session 2 base tasks.
+ Returns: {user_id: avg_sat}
+ """
+ user_sat = defaultdict(list)
+
+ for log in logs:
+ if log.session_id == 2 and is_base_task_turn(log):
+ user_sat[log.user_id].append(log.sat_t)
+
+ result = {}
+ for user_id, sats in user_sat.items():
+ if sats:
+ result[user_id] = sum(sats) / len(sats)
+
+ return result
+
+
+def compute_overall_session2_avg_sat(logs: List[TurnLog]) -> float:
+ """Compute overall average satisfaction for Session 2 base tasks."""
+ sats = []
+ for log in logs:
+ if log.session_id == 2 and is_base_task_turn(log):
+ sats.append(log.sat_t)
+ return sum(sats) / len(sats) if sats else 0.0
+
+
+def compute_violation_rates(logs: List[TurnLog], session_filter: Optional[int] = None) -> Dict[str, float]:
+ """
+ Compute violation rates by type.
+ Returns: {violation_type: rate}
+ """
+ violation_counts = defaultdict(int)
+ total_base_tasks = 0
+
+ for log in logs:
+ if session_filter is not None and log.session_id != session_filter:
+ continue
+ if not is_base_task_turn(log):
+ continue
+
+ total_base_tasks += 1
+ for v in log.violations:
+ violation_counts[v] += 1
+
+ if total_base_tasks == 0:
+ return {}
+
+ return {v: count / total_base_tasks for v, count in violation_counts.items()}
+
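+# Example return value (illustrative): {"too_long": 0.25, "wrong_lang": 0.10}
+# means 25% of the counted base-task turns triggered a "too_long" violation.
+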
+
+def is_pref_memory(note_text: str, dim: str) -> bool:
+ """
+ Check if a memory note relates to a preference dimension.
+ dim: "short", "bullets", or "lang"
+ """
+ text_lower = note_text.lower()
+
+ if dim == "short":
+ keywords = [
+ "short", "concise", "brief", "200", "characters", "less",
+ "简短", "精简", "字以内", "不超过", "简洁"
+ ]
+ return any(kw in text_lower for kw in keywords)
+
+ elif dim == "bullets":
+ keywords = [
+ "bullet", "bullets", "list", "point", "points",
+ "要点", "列表", "项目符号"
+ ]
+        # Also check negated forms: "no bullet" / "don't use bullet"
+        # ("不要要点" = "no bullet points", "不使用列表" = "don't use lists")
+ no_bullet = any(x in text_lower for x in ["no bullet", "don't use bullet", "without bullet", "不要要点", "不使用列表"])
+ if no_bullet:
+ return True # It's still about bullets preference
+ return any(kw in text_lower for kw in keywords)
+
+ elif dim == "lang":
+ # Check for language preferences
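+        # (Glosses: "中文" = "Chinese", "用中文"/"请用中文" = "(please) use Chinese", "英文" = "English")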
+ zh_keywords = ["chinese", "中文", "用中文", "请用中文"]
+ en_keywords = ["english", "英文", "in english"]
+ return any(kw in text_lower for kw in zh_keywords + en_keywords)
+
+ return False
+
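+# Example classifications (hypothetical note strings, for illustration only):
+#   is_pref_memory("User wants answers under 200 characters", "short")   -> True
+#   is_pref_memory("User prefers bullet points in summaries", "bullets") -> True
+#   is_pref_memory("Likes hiking on weekends", "lang")                   -> False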
+
+def compute_pref_recall_at_k(logs: List[TurnLog], dim: str, session_filter: Optional[int] = None) -> float:
+ """
+    Compute preference memory recall@k for a given dimension. Here k is implicit:
+    it is the runner's retrieval budget, i.e. every note in selected_memory_notes
+    is checked.
+    Returns: fraction of base task turns where at least one relevant pref memory
+    was among the retrieved notes.
+ """
+ hits = 0
+ total = 0
+
+ for log in logs:
+ if session_filter is not None and log.session_id != session_filter:
+ continue
+ if not is_base_task_turn(log):
+ continue
+
+ total += 1
+ # Check if any selected memory note matches the dimension
+ for note in log.selected_memory_notes:
+ if is_pref_memory(note, dim):
+ hits += 1
+ break
+
+ return hits / total if total > 0 else 0.0
+
+
+def print_comparison_table(full_logs: List[TurnLog], nopersonal_logs: List[TurnLog]):
+ """Print a comparison table of Full vs NoPersonal metrics."""
+
+    # Detect mode/selection from the first log entry; empty files report "unknown"
+    full_mode = full_logs[0].mode if full_logs else "unknown"
+    full_eval = ("greedy" if full_logs[0].eval_mode else "sample") if full_logs else "unknown"
+    np_mode = nopersonal_logs[0].mode if nopersonal_logs else "unknown"
+    np_eval = ("greedy" if nopersonal_logs[0].eval_mode else "sample") if nopersonal_logs else "unknown"
+
+ print("\n" + "=" * 70)
+ print("FULL vs NOPERSONAL COMPARISON")
+ print(f"Full: mode={full_mode}, selection={full_eval}")
+ print(f"NoPersonal: mode={np_mode}, selection={np_eval}")
+ print("=" * 70)
+
+ # 1. Session 2 Base Task Average Satisfaction
+ print("\n### 1. Session 2 Base Task Average Satisfaction")
+ print("-" * 50)
+
+ full_s2_sat = compute_overall_session2_avg_sat(full_logs)
+ nopersonal_s2_sat = compute_overall_session2_avg_sat(nopersonal_logs)
+ delta = full_s2_sat - nopersonal_s2_sat
+
+ print(f"{'Metric':<30} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ print(f"{'avg_sat_S2_base':<30} {full_s2_sat:<12.4f} {nopersonal_s2_sat:<12.4f} {delta:<+12.4f}")
+
+ # Per-user breakdown
+ full_user_sat = compute_session2_base_avg_sat(full_logs)
+ nopersonal_user_sat = compute_session2_base_avg_sat(nopersonal_logs)
+
+ print("\nPer-user Session 2 avg_sat:")
+ print(f"{'User':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ all_users = set(full_user_sat.keys()) | set(nopersonal_user_sat.keys())
+ for user_id in sorted(all_users):
+ f_sat = full_user_sat.get(user_id, 0.0)
+ n_sat = nopersonal_user_sat.get(user_id, 0.0)
+ d = f_sat - n_sat
+ print(f"{user_id:<20} {f_sat:<12.4f} {n_sat:<12.4f} {d:<+12.4f}")
+
+ # 2. Violation Rates
+ print("\n### 2. Session 2 Violation Rates")
+ print("-" * 50)
+
+ full_viol = compute_violation_rates(full_logs, session_filter=2)
+ nopersonal_viol = compute_violation_rates(nopersonal_logs, session_filter=2)
+
+ all_viols = set(full_viol.keys()) | set(nopersonal_viol.keys())
+ key_viols = ["too_long", "no_bullets", "has_bullets", "wrong_lang", "empty_answer"]
+
+ print(f"{'Violation Type':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ for v in key_viols:
+ if v in all_viols:
+ f_rate = full_viol.get(v, 0.0)
+ n_rate = nopersonal_viol.get(v, 0.0)
+ d = f_rate - n_rate
+ print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}")
+
+ # Other violations
+ other_viols = [v for v in all_viols if v not in key_viols]
+ for v in sorted(other_viols):
+ f_rate = full_viol.get(v, 0.0)
+ n_rate = nopersonal_viol.get(v, 0.0)
+ d = f_rate - n_rate
+ print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}")
+
+ # 3. Preference Memory Recall@k
+ print("\n### 3. Session 2 Preference Memory Recall@k")
+ print("-" * 50)
+
+ dims = ["short", "bullets", "lang"]
+ print(f"{'Dimension':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ for dim in dims:
+ f_recall = compute_pref_recall_at_k(full_logs, dim, session_filter=2)
+ n_recall = compute_pref_recall_at_k(nopersonal_logs, dim, session_filter=2)
+ d = f_recall - n_recall
+ print(f"{dim:<20} {f_recall:<12.4f} {n_recall:<12.4f} {d:<+12.4f}")
+
+ # 4. Summary Statistics
+ print("\n### 4. Summary Statistics")
+ print("-" * 50)
+
+    def count_base_tasks(logs, session=None):
+        return sum(1 for log in logs if (session is None or log.session_id == session) and is_base_task_turn(log))
+
+    def count_complaints(logs, session=None):
+        return sum(1 for log in logs if (session is None or log.session_id == session) and log.is_complaint)
+
+ print(f"{'Statistic':<30} {'Full':<12} {'NoPersonal':<12}")
+ print("-" * 50)
+ print(f"{'Total turns':<30} {len(full_logs):<12} {len(nopersonal_logs):<12}")
+ print(f"{'S2 base task turns':<30} {count_base_tasks(full_logs, 2):<12} {count_base_tasks(nopersonal_logs, 2):<12}")
+ print(f"{'S2 complaint turns':<30} {count_complaints(full_logs, 2):<12} {count_complaints(nopersonal_logs, 2):<12}")
+
+ # Token usage
+    full_tokens = sum(log.total_tokens for log in full_logs)
+    nopersonal_tokens = sum(log.total_tokens for log in nopersonal_logs)
+ print(f"{'Total tokens':<30} {full_tokens:<12} {nopersonal_tokens:<12}")
+
+ print("\n" + "=" * 70)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Analyze Full vs NoPersonal Comparison")
+ parser.add_argument("--full", type=str, required=True, help="Path to Full mode log file")
+ parser.add_argument("--nopersonal", type=str, required=True, help="Path to NoPersonal mode log file")
+ args = parser.parse_args()
+
+ print(f"Loading Full logs from: {args.full}")
+ full_logs = load_logs(args.full)
+ print(f" Loaded {len(full_logs)} turns")
+
+ print(f"Loading NoPersonal logs from: {args.nopersonal}")
+ nopersonal_logs = load_logs(args.nopersonal)
+ print(f" Loaded {len(nopersonal_logs)} turns")
+
+ print_comparison_table(full_logs, nopersonal_logs)
+
+
+if __name__ == "__main__":
+ main()
+