| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 | |
124 files changed, 13113 insertions, 0 deletions
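Among the files below, `fine_tuning_prompt_template.txt` pins down the JSON schema the SFT preference extractor must emit. As a reading aid, here is a minimal validation sketch using pydantic (a declared dependency in `pyproject.toml`); the class and field names simply mirror the template's example and are otherwise hypothetical, not part of the commit:

```python
# Sketch only: validates the {"preferences": [...]} schema from
# fine_tuning_prompt_template.txt. Class names are illustrative.
import json
from typing import List
from pydantic import BaseModel, Field


class Preference(BaseModel):
    condition: str  # when the preference applies, e.g. "code generation"
    action: str     # what to do, e.g. "use Python"
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)


class ExtractorOutput(BaseModel):
    preferences: List[Preference] = []


raw = '{"preferences": [{"condition": "code examples", "action": "use Python", "confidence": 1.0}]}'
out = ExtractorOutput.model_validate(json.loads(raw))
print(out.preferences[0].action)  # -> use Python
```

The template's no-preference case, `{"preferences": []}`, validates the same way to an empty list.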
diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/.env.example diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f141132 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +# Python +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +env/ +venv/ +.venv/ +*.egg-info/ + +# Models (Large model weights) +models/ +*.safetensors +*.bin +*.pt +*.pth + +# IDE / System +.vscode/ +.idea/ +.DS_Store + +# Logs and temporary files +*.log +tmp/ + +# LLaMA-Factory +LLaMA-Factory/ +# Cursor +.cursor/ + +data/ + +saves/
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/README.md diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/configs/base.yaml diff --git a/configs/local_models.yaml b/configs/local_models.yaml new file mode 100644 index 0000000..13c7fcc --- /dev/null +++ b/configs/local_models.yaml @@ -0,0 +1,50 @@ +models: + llm: + # New Multi-Backend Config + qwen_1_5b: + backend: qwen + path: models/qwen2.5-1.5b-instruct + device: auto + dtype: bfloat16 + max_context_length: 4096 + + llama_8b: + backend: llama + path: models/llama-3.1-8b-instruct + device: auto + dtype: bfloat16 + max_context_length: 8192 + + # Legacy fallback (needed if from_config is called directly without name) + hf_id: Qwen/Qwen2.5-1.5B-Instruct + local_path: models/qwen2.5-1.5b-instruct + dtype: bfloat16 + device_map: auto + + preference_extractor: + # Default/Legacy + default: + hf_id: Qwen/Qwen2.5-0.5B-Instruct + local_path: models/qwen2.5-0.5b-instruct + dtype: bfloat16 + device_map: auto + # New SFT Extractor + qwen3_0_6b_sft: + path: saves/qwen3-0.6b-full-sft-h200/checkpoint-4358 + prompt_template_path: fine_tuning_prompt_template.txt + device: auto + dtype: bfloat16 + max_new_tokens: 512 + embedding: + qwen3: + hf_id: Qwen/Qwen3-Embedding-8B + local_path: models/qwen3-embedding-8b + nemotron: + hf_id: nvidia/llama-embed-nemotron-8b + local_path: models/llama-embed-nemotron-8b + reranker: + qwen3_8b: + hf_id: Qwen/Qwen3-Reranker-8B + local_path: models/rerankers/qwen3-reranker-8b + dtype: bfloat16 + device_map: auto diff --git a/configs/qwen2.5_0.5b_full_sft.yaml b/configs/qwen2.5_0.5b_full_sft.yaml new file mode 100644 index 0000000..ca1cca2 --- /dev/null +++ b/configs/qwen2.5_0.5b_full_sft.yaml @@ -0,0 +1,34 @@ +### Qwen2.5-0.5B Full SFT Config +model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct +stage: sft +do_train: true +finetuning_type: full +freeze_trainable_layers: 0 + +dataset: preference_extractor_train +template: qwen +cutoff_len: 1024 +overwrite_cache: true +preprocessing_num_workers: 16 + +output_dir: saves/qwen2.5-0.5b-full-sft +logging_steps: 10 +save_strategy: steps +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +per_device_train_batch_size: 16 +gradient_accumulation_steps: 8 +learning_rate: 2.0e-5 +num_train_epochs: 1.0 +lr_scheduler_type: cosine +warmup_ratio: 0.05 +bf16: true +flash_attn: fa2 + +val_size: 0.01 +per_device_eval_batch_size: 16 +eval_strategy: steps +eval_steps: 500 + diff --git a/configs/qwen2.5_1.5b_full_sft.yaml b/configs/qwen2.5_1.5b_full_sft.yaml new file mode 100644 index 0000000..e91176b --- /dev/null +++ b/configs/qwen2.5_1.5b_full_sft.yaml @@ -0,0 +1,33 @@ +### Qwen2.5-1.5B Full SFT Config +model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct +stage: sft +do_train: true +finetuning_type: full + +dataset: preference_extractor_train +template: qwen +cutoff_len: 1024 +overwrite_cache: true +preprocessing_num_workers: 16 + +output_dir: saves/qwen2.5-1.5b-full-sft +logging_steps: 10 +save_strategy: steps +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +per_device_train_batch_size: 8 +gradient_accumulation_steps: 16 +learning_rate: 2.0e-5 +num_train_epochs: 1.0 +lr_scheduler_type: cosine +warmup_ratio: 0.05 +bf16: true +flash_attn: fa2 + +val_size: 0.01 +per_device_eval_batch_size: 8 +eval_strategy: steps +eval_steps: 500 + diff --git a/configs/qwen3_0.6b_full_sft.yaml b/configs/qwen3_0.6b_full_sft.yaml new file mode 100644 
index 0000000..e41a419 --- /dev/null +++ b/configs/qwen3_0.6b_full_sft.yaml @@ -0,0 +1,35 @@ +### Qwen3-0.6B Full SFT Config (H200x4 Optimized) +model_name_or_path: Qwen/Qwen3-0.6B +stage: sft +do_train: true +finetuning_type: full +freeze_trainable_layers: 0 + +dataset: preference_extractor_train +template: qwen +cutoff_len: 1024 +overwrite_cache: true +preprocessing_num_workers: 16 + +output_dir: saves/qwen3-0.6b-full-sft-h200 +logging_steps: 5 +save_strategy: steps +save_steps: 200 +plot_loss: true +overwrite_output_dir: true + +# H200x4 Configuration +# Total Batch Size = 32 * 4 * 1 = 128 +per_device_train_batch_size: 32 +gradient_accumulation_steps: 1 +learning_rate: 2.0e-5 +num_train_epochs: 1.0 +lr_scheduler_type: cosine +warmup_ratio: 0.05 +bf16: true +flash_attn: fa2 + +val_size: 0.01 +per_device_eval_batch_size: 32 +eval_strategy: steps +eval_steps: 200 diff --git a/configs/qwen3_1.7b_full_sft.yaml b/configs/qwen3_1.7b_full_sft.yaml new file mode 100644 index 0000000..069c53d --- /dev/null +++ b/configs/qwen3_1.7b_full_sft.yaml @@ -0,0 +1,34 @@ +### Qwen3-1.7B Full SFT Config (H200x4 Optimized) +model_name_or_path: models/Qwen3-1.7B +stage: sft +do_train: true +finetuning_type: full + +dataset: preference_extractor_train +template: qwen +cutoff_len: 1024 +overwrite_cache: true +preprocessing_num_workers: 4 + +output_dir: saves/qwen3-1.7b-full-sft-h200 +logging_steps: 5 +save_strategy: steps +save_steps: 200 +plot_loss: true +overwrite_output_dir: true + +# H200x4 Configuration +# Total Batch Size = 32 * 4 * 1 = 128 +per_device_train_batch_size: 32 +gradient_accumulation_steps: 1 +learning_rate: 2.0e-5 +num_train_epochs: 1.0 +lr_scheduler_type: cosine +warmup_ratio: 0.05 +bf16: true +flash_attn: fa2 + +val_size: 0.01 +per_device_eval_batch_size: 32 +eval_strategy: steps +eval_steps: 200 diff --git a/configs/reranker.yaml b/configs/reranker.yaml new file mode 100644 index 0000000..c376fc7 --- /dev/null +++ b/configs/reranker.yaml @@ -0,0 +1,3 @@ +reranker: + default: qwen3_8b + diff --git a/configs/retrieval.yaml b/configs/retrieval.yaml new file mode 100644 index 0000000..d2e100e --- /dev/null +++ b/configs/retrieval.yaml @@ -0,0 +1,5 @@ +retrieval: + dense_topk: 64 # Initial recall count + rerank_topk: 8 # Count fed to LLM after rerank + pca_dim: 256 + diff --git a/configs/user_model.yaml b/configs/user_model.yaml new file mode 100644 index 0000000..7b8e230 --- /dev/null +++ b/configs/user_model.yaml @@ -0,0 +1,14 @@ +user_model: + item_dim: 256 + user_dim: 256 + beta_long: 0.1 # Enable personalization for Day 4 + beta_short: 0.3 + tau: 1.0 + preference_extractor_name: qwen3_0_6b_sft # Switch to new extractor + rl: + eta_long: 1.0e-3 + eta_short: 5.0e-3 + ema_alpha: 0.05 + short_decay: 0.1 + +llm_name: llama_8b # Switch backend to Llama 3.1 diff --git a/fine_tuning_prompt_template.txt b/fine_tuning_prompt_template.txt new file mode 100644 index 0000000..749fdb6 --- /dev/null +++ b/fine_tuning_prompt_template.txt @@ -0,0 +1,31 @@ +=== System Prompt === +You are a preference extraction assistant. Your task is to analyze the user's query and extract any persistent preferences (such as style, formatting, programming language, or constraints) that should apply to future interactions. 
+ +Output strictly valid JSON following this structure: +{ + "preferences": [ + { + "condition": "When to apply this preference (e.g., 'code generation', 'general')", + "action": "What to do (e.g., 'use Python', 'be concise')", + "confidence": 1.0 + } + ] +} + +If the user query contains no persistent preferences (e.g., simple greetings, one-off tasks, factual questions), return: +{"preferences": []} + +=== User Input Example === +I am a Python developer, so always give me code examples in Python. + +=== Assistant Output Example === +{ + "preferences": [ + { + "condition": "code examples", + "action": "use Python", + "confidence": 1.0 + } + ] +} + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9c1e7ac --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[tool.black] + +[build-system] +requires = ["hatchling>=1.18"] +build-backend = "hatchling.build" + +[project] +name = "personalization-user-model" +version = "0.1.0" +description = "Personalized memory RAG system with online user modeling" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "Apache-2.0" } +authors = [ + { name = "yurenh2" } +] +dependencies = [ + "torch>=2.3.0", + "transformers>=4.44.0", + "accelerate>=0.33.0", + "huggingface_hub>=0.24.0", + "pydantic>=2.7.0", + "pyyaml>=6.0.0", + "safetensors>=0.4.2" +] + +[project.urls] +homepage = "https://example.com" + +[tool.hatch.build.targets.wheel] +packages = ["src/personalization"] + +[tool.hatch.metadata] +allow-direct-references = true + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1e227de --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +torch>=2.3.0 +transformers>=4.44.0 +accelerate>=0.33.0 +huggingface_hub>=0.24.0 +pydantic>=2.7.0 +PyYAML>=6.0.0 +safetensors>=0.4.2 + + diff --git a/scripts/analyze_full_vs_nopersonal.py b/scripts/analyze_full_vs_nopersonal.py new file mode 100644 index 0000000..a10f3ef --- /dev/null +++ b/scripts/analyze_full_vs_nopersonal.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python3 +""" +Analyze Full vs NoPersonal Baseline Comparison. + +This script loads logs from pilot_runner_v4 runs (both full and nopersonal modes) +and produces comparison metrics for: +1. Session 2 retention (base task avg satisfaction) +2. Violation rates by type +3. 
Preference memory recall@k + +Usage: + python scripts/analyze_full_vs_nopersonal.py \ + --full data/logs/pilot_v4_full_TIMESTAMP.jsonl \ + --nopersonal data/logs/pilot_v4_nopersonal_TIMESTAMP.jsonl +""" + +import json +import argparse +import re +from dataclasses import dataclass +from typing import List, Dict, Any, Optional, Set +from collections import defaultdict + + +@dataclass +class TurnLog: + """Parsed log entry.""" + user_id: str + persona_id: str + session_id: int + turn_id: int + query: str + query_type: str + task_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + reward: float + gating: float + is_complaint: bool + reveal_state_before: Dict[str, bool] + reveal_state_after: Dict[str, bool] + newly_revealed: List[str] + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + num_memories_retrieved: int + num_prefs_extracted: int + selected_memory_ids: List[str] + selected_memory_notes: List[str] + selected_memory_scores: List[float] + num_candidates: int + num_total_memories: int + mode: str + eval_mode: bool # True = greedy, False = sample + + +def load_logs(filepath: str) -> List[TurnLog]: + """Load logs from JSONL file.""" + logs = [] + with open(filepath, "r") as f: + for line in f: + if line.strip(): + data = json.loads(line) + # Handle missing fields with defaults + log = TurnLog( + user_id=data.get("user_id", ""), + persona_id=data.get("persona_id", ""), + session_id=data.get("session_id", 0), + turn_id=data.get("turn_id", 0), + query=data.get("query", ""), + query_type=data.get("query_type", ""), + task_type=data.get("task_type", ""), + answer=data.get("answer", ""), + answer_length=data.get("answer_length", 0), + sat_t=data.get("sat_t", 0.0), + sev_t=data.get("sev_t", 0.0), + prog_t=data.get("prog_t", 0.0), + violations=data.get("violations", []), + enforced_constraints=data.get("enforced_constraints", []), + reward=data.get("reward", 0.0), + gating=data.get("gating", 0.0), + is_complaint=data.get("is_complaint", False), + reveal_state_before=data.get("reveal_state_before", {}), + reveal_state_after=data.get("reveal_state_after", {}), + newly_revealed=data.get("newly_revealed", []), + z_long_norm_before=data.get("z_long_norm_before", 0.0), + z_long_norm_after=data.get("z_long_norm_after", 0.0), + z_short_norm_before=data.get("z_short_norm_before", 0.0), + z_short_norm_after=data.get("z_short_norm_after", 0.0), + prompt_tokens=data.get("prompt_tokens", 0), + completion_tokens=data.get("completion_tokens", 0), + total_tokens=data.get("total_tokens", 0), + num_memories_retrieved=data.get("num_memories_retrieved", 0), + num_prefs_extracted=data.get("num_prefs_extracted", 0), + selected_memory_ids=data.get("selected_memory_ids", []), + selected_memory_notes=data.get("selected_memory_notes", []), + selected_memory_scores=data.get("selected_memory_scores", []), + num_candidates=data.get("num_candidates", 0), + num_total_memories=data.get("num_total_memories", 0), + mode=data.get("mode", "unknown"), + eval_mode=data.get("eval_mode", True), + ) + logs.append(log) + return logs + + +def is_base_task_turn(log: TurnLog) -> bool: + """Check if this is a base task turn (not complaint, not preference).""" + if log.is_complaint: + return False + if log.query_type == "preference": + return False + if log.query_type in ("task", "task_list"): + return True + return False 
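For reference, a minimal sketch of the JSONL record shape that `load_logs` consumes and `is_base_task_turn` classifies; the field values are illustrative, and any field missing from a record falls back to the defaults shown in `load_logs` above:

```python
# Illustrative pilot_runner_v4 log record (values are made up); only a
# subset of the fields read by load_logs is shown.
example_record = {
    "user_id": "user_A_short_bullets_en",
    "session_id": 2,
    "turn_id": 3,
    "query_type": "task",      # "task" / "task_list" count as base task turns
    "is_complaint": False,     # complaint turns are excluded from base tasks
    "sat_t": 0.85,
    "violations": ["too_long"],
    "selected_memory_notes": ["prefers short answers"],
    "mode": "full",
    "eval_mode": True,         # True = greedy memory selection
}
```

A record like this, once parsed into a `TurnLog`, satisfies `is_base_task_turn` because it is neither a complaint nor a preference turn and its `query_type` is "task".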
+ + +def compute_session2_base_avg_sat(logs: List[TurnLog]) -> Dict[str, float]: + """ + Compute average satisfaction for Session 2 base tasks. + Returns: {user_id: avg_sat} + """ + user_sat = defaultdict(list) + + for log in logs: + if log.session_id == 2 and is_base_task_turn(log): + user_sat[log.user_id].append(log.sat_t) + + result = {} + for user_id, sats in user_sat.items(): + if sats: + result[user_id] = sum(sats) / len(sats) + + return result + + +def compute_overall_session2_avg_sat(logs: List[TurnLog]) -> float: + """Compute overall average satisfaction for Session 2 base tasks.""" + sats = [] + for log in logs: + if log.session_id == 2 and is_base_task_turn(log): + sats.append(log.sat_t) + return sum(sats) / len(sats) if sats else 0.0 + + +def compute_violation_rates(logs: List[TurnLog], session_filter: Optional[int] = None) -> Dict[str, float]: + """ + Compute violation rates by type. + Returns: {violation_type: rate} + """ + violation_counts = defaultdict(int) + total_base_tasks = 0 + + for log in logs: + if session_filter is not None and log.session_id != session_filter: + continue + if not is_base_task_turn(log): + continue + + total_base_tasks += 1 + for v in log.violations: + violation_counts[v] += 1 + + if total_base_tasks == 0: + return {} + + return {v: count / total_base_tasks for v, count in violation_counts.items()} + + +def is_pref_memory(note_text: str, dim: str) -> bool: + """ + Check if a memory note relates to a preference dimension. + dim: "short", "bullets", or "lang" + """ + text_lower = note_text.lower() + + if dim == "short": + keywords = [ + "short", "concise", "brief", "200", "characters", "less", + "简短", "精简", "字以内", "不超过", "简洁" + ] + return any(kw in text_lower for kw in keywords) + + elif dim == "bullets": + keywords = [ + "bullet", "bullets", "list", "point", "points", + "要点", "列表", "项目符号" + ] + # Also check for "no bullet" / "don't use bullet" + no_bullet = any(x in text_lower for x in ["no bullet", "don't use bullet", "without bullet", "不要要点", "不使用列表"]) + if no_bullet: + return True # It's still about bullets preference + return any(kw in text_lower for kw in keywords) + + elif dim == "lang": + # Check for language preferences + zh_keywords = ["chinese", "中文", "用中文", "请用中文"] + en_keywords = ["english", "英文", "in english"] + return any(kw in text_lower for kw in zh_keywords + en_keywords) + + return False + + +def compute_pref_recall_at_k(logs: List[TurnLog], dim: str, session_filter: Optional[int] = None) -> float: + """ + Compute preference memory recall@k for a given dimension. + Returns: fraction of base task turns where a relevant pref memory was retrieved. 
+ """ + hits = 0 + total = 0 + + for log in logs: + if session_filter is not None and log.session_id != session_filter: + continue + if not is_base_task_turn(log): + continue + + total += 1 + # Check if any selected memory note matches the dimension + for note in log.selected_memory_notes: + if is_pref_memory(note, dim): + hits += 1 + break + + return hits / total if total > 0 else 0.0 + + +def print_comparison_table(full_logs: List[TurnLog], nopersonal_logs: List[TurnLog]): + """Print a comparison table of Full vs NoPersonal metrics.""" + + # Detect mode from logs + full_mode = full_logs[0].mode if full_logs else "unknown" + full_eval = "greedy" if (full_logs and full_logs[0].eval_mode) else "sample" + np_mode = nopersonal_logs[0].mode if nopersonal_logs else "unknown" + np_eval = "greedy" if (nopersonal_logs and nopersonal_logs[0].eval_mode) else "sample" + + print("\n" + "=" * 70) + print("FULL vs NOPERSONAL COMPARISON") + print(f"Full: mode={full_mode}, selection={full_eval}") + print(f"NoPersonal: mode={np_mode}, selection={np_eval}") + print("=" * 70) + + # 1. Session 2 Base Task Average Satisfaction + print("\n### 1. Session 2 Base Task Average Satisfaction") + print("-" * 50) + + full_s2_sat = compute_overall_session2_avg_sat(full_logs) + nopersonal_s2_sat = compute_overall_session2_avg_sat(nopersonal_logs) + delta = full_s2_sat - nopersonal_s2_sat + + print(f"{'Metric':<30} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") + print("-" * 50) + print(f"{'avg_sat_S2_base':<30} {full_s2_sat:<12.4f} {nopersonal_s2_sat:<12.4f} {delta:<+12.4f}") + + # Per-user breakdown + full_user_sat = compute_session2_base_avg_sat(full_logs) + nopersonal_user_sat = compute_session2_base_avg_sat(nopersonal_logs) + + print("\nPer-user Session 2 avg_sat:") + print(f"{'User':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") + print("-" * 50) + all_users = set(full_user_sat.keys()) | set(nopersonal_user_sat.keys()) + for user_id in sorted(all_users): + f_sat = full_user_sat.get(user_id, 0.0) + n_sat = nopersonal_user_sat.get(user_id, 0.0) + d = f_sat - n_sat + print(f"{user_id:<20} {f_sat:<12.4f} {n_sat:<12.4f} {d:<+12.4f}") + + # 2. Violation Rates + print("\n### 2. Session 2 Violation Rates") + print("-" * 50) + + full_viol = compute_violation_rates(full_logs, session_filter=2) + nopersonal_viol = compute_violation_rates(nopersonal_logs, session_filter=2) + + all_viols = set(full_viol.keys()) | set(nopersonal_viol.keys()) + key_viols = ["too_long", "no_bullets", "has_bullets", "wrong_lang", "empty_answer"] + + print(f"{'Violation Type':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") + print("-" * 50) + for v in key_viols: + if v in all_viols: + f_rate = full_viol.get(v, 0.0) + n_rate = nopersonal_viol.get(v, 0.0) + d = f_rate - n_rate + print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}") + + # Other violations + other_viols = [v for v in all_viols if v not in key_viols] + for v in sorted(other_viols): + f_rate = full_viol.get(v, 0.0) + n_rate = nopersonal_viol.get(v, 0.0) + d = f_rate - n_rate + print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}") + + # 3. Preference Memory Recall@k + print("\n### 3. 
Session 2 Preference Memory Recall@k") + print("-" * 50) + + dims = ["short", "bullets", "lang"] + print(f"{'Dimension':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}") + print("-" * 50) + for dim in dims: + f_recall = compute_pref_recall_at_k(full_logs, dim, session_filter=2) + n_recall = compute_pref_recall_at_k(nopersonal_logs, dim, session_filter=2) + d = f_recall - n_recall + print(f"{dim:<20} {f_recall:<12.4f} {n_recall:<12.4f} {d:<+12.4f}") + + # 4. Summary Statistics + print("\n### 4. Summary Statistics") + print("-" * 50) + + def count_base_tasks(logs, session=None): + return sum(1 for l in logs if (session is None or l.session_id == session) and is_base_task_turn(l)) + + def count_complaints(logs, session=None): + return sum(1 for l in logs if (session is None or l.session_id == session) and l.is_complaint) + + print(f"{'Statistic':<30} {'Full':<12} {'NoPersonal':<12}") + print("-" * 50) + print(f"{'Total turns':<30} {len(full_logs):<12} {len(nopersonal_logs):<12}") + print(f"{'S2 base task turns':<30} {count_base_tasks(full_logs, 2):<12} {count_base_tasks(nopersonal_logs, 2):<12}") + print(f"{'S2 complaint turns':<30} {count_complaints(full_logs, 2):<12} {count_complaints(nopersonal_logs, 2):<12}") + + # Token usage + full_tokens = sum(l.total_tokens for l in full_logs) + nopersonal_tokens = sum(l.total_tokens for l in nopersonal_logs) + print(f"{'Total tokens':<30} {full_tokens:<12} {nopersonal_tokens:<12}") + + print("\n" + "=" * 70) + + +def main(): + parser = argparse.ArgumentParser(description="Analyze Full vs NoPersonal Comparison") + parser.add_argument("--full", type=str, required=True, help="Path to Full mode log file") + parser.add_argument("--nopersonal", type=str, required=True, help="Path to NoPersonal mode log file") + args = parser.parse_args() + + print(f"Loading Full logs from: {args.full}") + full_logs = load_logs(args.full) + print(f" Loaded {len(full_logs)} turns") + + print(f"Loading NoPersonal logs from: {args.nopersonal}") + nopersonal_logs = load_logs(args.nopersonal) + print(f" Loaded {len(nopersonal_logs)} turns") + + print_comparison_table(full_logs, nopersonal_logs) + + +if __name__ == "__main__": + main() + diff --git a/scripts/analyze_learning_trend.py b/scripts/analyze_learning_trend.py new file mode 100644 index 0000000..9ab4699 --- /dev/null +++ b/scripts/analyze_learning_trend.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 +""" +Analyze Learning Trend: Correlation and z_u Norm over Sessions + +This script shows that: +1. User vector norms (||z_u||) grow over sessions (learning is happening) +2. 
Correlation between learned and ground-truth similarity increases over sessions + +Usage: + python scripts/analyze_learning_trend.py \ + --logs data/logs/pilot_v4_full-greedy_*.jsonl +""" + +import argparse +import json +import numpy as np +from typing import Dict, List, Tuple +from collections import defaultdict +from dataclasses import dataclass +import os + + +# ============================================================================= +# Persona Definitions (ground truth) +# ============================================================================= + +@dataclass +class StylePrefs: + require_short: bool = False + max_chars: int = 300 + require_bullets: bool = False + lang: str = "en" + + +PERSONAS = { + "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"), + "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"), + "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"), + "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"), + "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"), + "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"), +} + + +# ============================================================================= +# Data Loading +# ============================================================================= + +def load_logs(filepath: str) -> List[dict]: + """Load turn logs from JSONL file.""" + logs = [] + with open(filepath, "r") as f: + for line in f: + if line.strip(): + logs.append(json.loads(line)) + return logs + + +def extract_z_norms_by_session(logs: List[dict]) -> Dict[str, Dict[int, Tuple[float, float]]]: + """ + Extract z_long_norm and z_short_norm at the end of each session for each user. 
+ + Returns: + {user_id: {session_id: (z_long_norm, z_short_norm)}} + """ + user_session_norms = defaultdict(dict) + + # Group by user and session, take the last turn's z_norm + user_session_turns = defaultdict(lambda: defaultdict(list)) + for log in logs: + user_id = log["user_id"] + session_id = log["session_id"] + user_session_turns[user_id][session_id].append(log) + + for user_id, sessions in user_session_turns.items(): + for session_id, turns in sessions.items(): + # Get the last turn of this session + last_turn = max(turns, key=lambda x: x["turn_id"]) + z_long = last_turn.get("z_long_norm_after", 0.0) + z_short = last_turn.get("z_short_norm_after", 0.0) + user_session_norms[user_id][session_id] = (z_long, z_short) + + return dict(user_session_norms) + + +# ============================================================================= +# Similarity Computation +# ============================================================================= + +def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float: + """Compute cosine similarity.""" + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + if norm1 < 1e-10 or norm2 < 1e-10: + return 0.0 + return float(np.dot(v1, v2) / (norm1 * norm2)) + + +def compute_ground_truth_similarity_matrix(user_order: List[str]) -> np.ndarray: + """Compute ground truth similarity based on preference overlap.""" + n = len(user_order) + sim_matrix = np.zeros((n, n)) + + for i, u1 in enumerate(user_order): + for j, u2 in enumerate(user_order): + if u1 not in PERSONAS or u2 not in PERSONAS: + sim_matrix[i, j] = 0.0 if i != j else 1.0 + continue + + p1 = PERSONAS[u1] + p2 = PERSONAS[u2] + + matches = 0 + if p1.require_short == p2.require_short: + matches += 1 + if p1.require_bullets == p2.require_bullets: + matches += 1 + if p1.lang == p2.lang: + matches += 1 + + sim_matrix[i, j] = matches / 3.0 + + return sim_matrix + + +def compute_spearman_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> float: + """Compute Spearman correlation between similarity matrices.""" + from scipy.stats import spearmanr + + n = learned.shape[0] + learned_flat = [] + gt_flat = [] + + for i in range(n): + for j in range(i + 1, n): + learned_flat.append(learned[i, j]) + gt_flat.append(ground_truth[i, j]) + + if len(learned_flat) < 2: + return 0.0 + + # Handle case where all values are the same + if np.std(learned_flat) < 1e-10: + return 0.0 + + corr, _ = spearmanr(learned_flat, gt_flat) + return float(corr) if not np.isnan(corr) else 0.0 + + +def load_final_z_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: + """Load final z_u vectors from saved user store.""" + try: + data = np.load(user_store_path, allow_pickle=True) + user_vectors = {} + + # UserTensorStore saves in format: {uid}_long, {uid}_short + user_ids = set() + for key in data.files: + if key.endswith("_long"): + uid = key[:-5] + user_ids.add(uid) + + for uid in user_ids: + long_key = f"{uid}_long" + short_key = f"{uid}_short" + if long_key in data.files and short_key in data.files: + user_vectors[uid] = (data[long_key], data[short_key]) + + return user_vectors + except Exception as e: + print(f"[Warning] Could not load user store: {e}") + return {} + + +# Global cache for final z vectors +_FINAL_Z_VECTORS = None + + +def get_z_vectors_at_session( + logs: List[dict], + user_order: List[str], + up_to_session: int, + final_z_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]] +) -> Dict[str, np.ndarray]: + """ + Estimate z_u vectors at a given session checkpoint. 
+ + Method: Use the DIRECTION of the final z_u, scaled by the z_norm at session s. + This assumes z_u direction is relatively stable but magnitude grows. + + z_u(s) ≈ (z_final / ||z_final||) * ||z(s)|| + """ + user_vectors = {} + + for user_id in user_order: + # Get z_norm at the end of this session + user_turns = [l for l in logs if l["user_id"] == user_id and l["session_id"] <= up_to_session] + + if not user_turns: + user_vectors[user_id] = np.zeros(512) # 256 + 256 + continue + + # Get the last turn's z_norm at this session + last_turn = max(user_turns, key=lambda x: (x["session_id"], x["turn_id"])) + z_long_norm_s = last_turn.get("z_long_norm_after", 0.0) + z_short_norm_s = last_turn.get("z_short_norm_after", 0.0) + + # Get final z vectors (direction) + if user_id in final_z_vectors: + z_long_final, z_short_final = final_z_vectors[user_id] + + # Compute unit vectors (direction) + z_long_final_norm = np.linalg.norm(z_long_final) + z_short_final_norm = np.linalg.norm(z_short_final) + + if z_long_final_norm > 1e-10: + z_long_unit = z_long_final / z_long_final_norm + else: + z_long_unit = np.zeros_like(z_long_final) + + if z_short_final_norm > 1e-10: + z_short_unit = z_short_final / z_short_final_norm + else: + z_short_unit = np.zeros_like(z_short_final) + + # Scale by the norm at this session + z_long_s = z_long_unit * z_long_norm_s + z_short_s = z_short_unit * z_short_norm_s + + # Concatenate + user_vectors[user_id] = np.concatenate([z_long_s, z_short_s]) + else: + user_vectors[user_id] = np.zeros(512) + + return user_vectors + + +def compute_similarity_at_session( + logs: List[dict], + user_order: List[str], + up_to_session: int, + final_z_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]] = None +) -> np.ndarray: + """Compute learned similarity matrix at a given session using actual z vectors.""" + if final_z_vectors: + user_vectors = get_z_vectors_at_session(logs, user_order, up_to_session, final_z_vectors) + else: + # Fallback to old method + user_vectors = simulate_z_vectors_at_session_fallback(logs, user_order, up_to_session) + + n = len(user_order) + sim_matrix = np.zeros((n, n)) + + for i, u1 in enumerate(user_order): + for j, u2 in enumerate(user_order): + v1 = user_vectors.get(u1, np.zeros(512)) + v2 = user_vectors.get(u2, np.zeros(512)) + sim_matrix[i, j] = cosine_similarity(v1, v2) + + return sim_matrix + + +def simulate_z_vectors_at_session_fallback( + logs: List[dict], + user_order: List[str], + up_to_session: int, + dim: int = 256 +) -> Dict[str, np.ndarray]: + """Fallback: simulate z_u based on violation patterns (less accurate).""" + user_vectors = {} + + for user_id in user_order: + user_turns = [l for l in logs if l["user_id"] == user_id and l["session_id"] <= up_to_session] + + if not user_turns: + user_vectors[user_id] = np.zeros(dim * 2) + continue + + last_turn = max(user_turns, key=lambda x: (x["session_id"], x["turn_id"])) + z_long_norm = last_turn.get("z_long_norm_after", 0.0) + z_short_norm = last_turn.get("z_short_norm_after", 0.0) + + violation_counts = defaultdict(int) + for turn in user_turns: + for v in turn.get("violations", []): + violation_counts[v] += 1 + + feature_dim = 10 + features = np.zeros(feature_dim) + features[0] = violation_counts.get("too_long", 0) + features[1] = violation_counts.get("no_bullets", 0) + features[2] = violation_counts.get("has_bullets", 0) + features[3] = violation_counts.get("wrong_lang", 0) + features[4] = z_long_norm * 100 + features[5] = z_short_norm * 100 + + norm = np.linalg.norm(features) + if norm > 1e-10: + features = 
features / norm + + user_vectors[user_id] = features + + return user_vectors + + +# ============================================================================= +# Main Analysis +# ============================================================================= + +def analyze_learning_trend(logs_path: str, output_dir: str = "data/analysis", + user_store_path: str = "data/users/user_store_pilot_v4_full-greedy.npz"): + """Analyze correlation and z_u norm trends over sessions.""" + os.makedirs(output_dir, exist_ok=True) + + print("=" * 70) + print("LEARNING TREND ANALYSIS") + print("=" * 70) + + # Load logs + print(f"\n[1] Loading logs from: {logs_path}") + logs = load_logs(logs_path) + print(f" Loaded {len(logs)} turns") + + # Get user order + user_order = [u for u in PERSONAS.keys() if any(l["user_id"] == u for l in logs)] + print(f" Users: {user_order}") + + # Get max session + max_session = max(l["session_id"] for l in logs) + print(f" Sessions: 1 to {max_session}") + + # Extract z_norms by session + print("\n[2] Extracting z_u norms by session...") + z_norms_by_session = extract_z_norms_by_session(logs) + + # Load final z vectors from user store + print(f"\n[2.5] Loading final z vectors from: {user_store_path}") + final_z_vectors = load_final_z_vectors(user_store_path) + if final_z_vectors: + print(f" Loaded final z vectors for {len(final_z_vectors)} users") + else: + print(" [Warning] No final z vectors found, using fallback method") + + # Compute ground truth similarity (constant) + gt_sim = compute_ground_truth_similarity_matrix(user_order) + + # Compute CUMULATIVE correlation and avg z_norm + # At session N, we use all data from session 1 to N + print("\n[3] Computing CUMULATIVE correlation trend (S1→S1-2→S1-3→...→S1-N)...") + sessions = list(range(1, max_session + 1)) + correlations = [] + avg_z_norms = [] + + for s in sessions: + # Compute similarity using z_u at end of session s (cumulative learning) + learned_sim = compute_similarity_at_session(logs, user_order, s, final_z_vectors) + corr = compute_spearman_correlation(learned_sim, gt_sim) + correlations.append(corr) + + # Compute average z_norm at the END of session s (this is already cumulative) + z_norms = [] + for user_id in user_order: + if user_id in z_norms_by_session and s in z_norms_by_session[user_id]: + zl, zs = z_norms_by_session[user_id][s] + z_norms.append(np.sqrt(zl**2 + zs**2)) # Combined norm + + avg_z = np.mean(z_norms) if z_norms else 0.0 + avg_z_norms.append(avg_z) + + # Print results + print("\n[4] Results:") + print("-" * 60) + print(f"{'Session':<10} {'Correlation':<15} {'Avg ||z_u||':<15}") + print("-" * 60) + for s, corr, z_norm in zip(sessions, correlations, avg_z_norms): + print(f"{s:<10} {corr:<15.4f} {z_norm:<15.6f}") + + # Summary statistics + print("\n[5] Trend Summary:") + print("-" * 60) + + # Linear regression for correlation trend + from scipy.stats import linregress + slope_corr, intercept_corr, r_corr, p_corr, _ = linregress(sessions, correlations) + 
print(f" Correlation trend: slope={slope_corr:.4f}, R²={r_corr**2:.4f}, p={p_corr:.4f}") + + # Linear regression for z_norm trend + slope_z, intercept_z, r_z, p_z, _ = linregress(sessions, avg_z_norms) + print(f" ||z_u|| trend: slope={slope_z:.6f}, R²={r_z**2:.4f}, p={p_z:.4f}") + + # Correlation between the two trends + trend_corr, _ = spearmanr(correlations, avg_z_norms) if len(correlations) > 2 else (0, 1) + print(f" Correlation between trends: {trend_corr:.4f}") + + # Save data + results = { + "sessions": np.array(sessions), + "correlations": np.array(correlations), + "avg_z_norms": np.array(avg_z_norms), + "slope_corr": slope_corr, + "slope_z": slope_z, + "trend_corr": trend_corr, + } + results_path = os.path.join(output_dir, "learning_trend_results.npz") + np.savez(results_path, **results) + print(f"\n[Results] Saved to: {results_path}") + + # Plot + print("\n[6] Generating plots...") + plot_learning_trend(sessions, correlations, avg_z_norms, output_dir) + + print("\n" + "=" * 70) + print("ANALYSIS COMPLETE") + print("=" * 70) + + return results + + +def plot_learning_trend(sessions, correlations, avg_z_norms, output_dir): + """Generate plots for learning trend.""" + try: + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') # Non-interactive backend + except ImportError: + print("[Warning] matplotlib not available, skipping plots") + # Save as text instead + with open(os.path.join(output_dir, "learning_trend.txt"), "w") as f: + f.write("Session,Correlation,Avg_Z_Norm\n") + for s, c, z in zip(sessions, correlations, avg_z_norms): + f.write(f"{s},{c:.4f},{z:.6f}\n") + print(f"[Data] Saved to: {os.path.join(output_dir, 'learning_trend.txt')}") + return + + fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + + # Plot 1: Correlation vs Session + ax1 = axes[0] + ax1.plot(sessions, correlations, 'o-', color='#2ecc71', linewidth=2, markersize=8) + ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + + # Add trend line + from scipy.stats import linregress + slope, intercept, _, _, _ = linregress(sessions, correlations) + trend_line = [slope * s + intercept for s in sessions] + ax1.plot(sessions, trend_line, '--', color='#27ae60', alpha=0.7, label=f'Trend (slope={slope:.3f})') + + ax1.set_xlabel('Sessions (Cumulative: 1→N)', fontsize=12) + ax1.set_ylabel('Spearman Correlation', fontsize=12) + ax1.set_title('Learned vs Ground-Truth Similarity\nCorrelation with Cumulative Data', fontsize=14) + ax1.set_xticks(sessions) + ax1.legend() + ax1.grid(True, alpha=0.3) + ax1.set_ylim(-0.5, 1.0) + + # Plot 2: z_u norm vs Session + ax2 = axes[1] + ax2.plot(sessions, avg_z_norms, 's-', color='#3498db', linewidth=2, markersize=8) + + # Add trend line + slope_z, intercept_z, _, _, _ = linregress(sessions, avg_z_norms) + trend_line_z = [slope_z * s + intercept_z for s in sessions] + ax2.plot(sessions, trend_line_z, '--', color='#2980b9', alpha=0.7, label=f'Trend (slope={slope_z:.5f})') + + ax2.set_xlabel('Session (End of)', fontsize=12) + ax2.set_ylabel('Average ||z_u||', fontsize=12) + ax2.set_title('User Vector Norm\n(Cumulative Learning)', fontsize=14) + ax2.set_xticks(sessions) + ax2.legend() + ax2.grid(True, alpha=0.3) + + plt.tight_layout() + + output_path = os.path.join(output_dir, "learning_trend.png") + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"[Plot] Saved to: {output_path}") + + # Also save as PDF for paper + pdf_path = os.path.join(output_dir, "learning_trend.pdf") + plt.savefig(pdf_path, bbox_inches='tight') + print(f"[Plot] Saved to: 
{pdf_path}") + + +# Need this import at top level for trend calculation +from scipy.stats import spearmanr + + +def main(): + parser = argparse.ArgumentParser(description="Analyze Learning Trend") + parser.add_argument("--logs", type=str, required=True, help="Path to log file") + parser.add_argument("--user-store", type=str, default="data/users/user_store_pilot_v4_full-greedy.npz", + help="Path to user store with final z vectors") + parser.add_argument("--output-dir", type=str, default="data/analysis", help="Output directory") + args = parser.parse_args() + + analyze_learning_trend(args.logs, args.output_dir, args.user_store) + + +if __name__ == "__main__": + main() + diff --git a/scripts/analyze_memory.py b/scripts/analyze_memory.py new file mode 100644 index 0000000..4c439cf --- /dev/null +++ b/scripts/analyze_memory.py @@ -0,0 +1,61 @@ +import json +import os +from collections import Counter, defaultdict + +MEMORY_FILE = "data/corpora/memory_cards.jsonl" + +def analyze_memory(): + if not os.path.exists(MEMORY_FILE): + print(f"Error: {MEMORY_FILE} not found.") + return + + print(f"Analyzing {MEMORY_FILE}...") + + total_cards = 0 + user_counts = Counter() + content_hashes = defaultdict(int) + user_content_hashes = defaultdict(int) + + with open(MEMORY_FILE, "r", encoding="utf-8") as f: + for line in f: + if not line.strip(): continue + try: + card = json.loads(line) + total_cards += 1 + uid = card.get("user_id", "unknown") + text = card.get("note_text", "").strip() + + user_counts[uid] += 1 + content_hashes[text] += 1 + user_content_hashes[(uid, text)] += 1 + except: + pass + + print("\n" + "="*40) + print("MEMORY STORE ANALYSIS") + print("="*40) + print(f"Total Cards: {total_cards}") + print(f"Unique Users: {len(user_counts)}") + print("-" * 40) + + print("\nTop 10 Users by Card Count:") + for uid, count in user_counts.most_common(10): + print(f" {uid}: {count}") + + print("\nTop 10 Most Frequent Contents (Global):") + sorted_content = sorted(content_hashes.items(), key=lambda x: x[1], reverse=True)[:10] + for text, count in sorted_content: + display_text = (text[:50] + '...') if len(text) > 50 else text + print(f" [{count}] {display_text}") + + print("\nTop 10 Most Frequent (User, Content) Duplicates:") + sorted_user_content = sorted(user_content_hashes.items(), key=lambda x: x[1], reverse=True)[:10] + for (uid, text), count in sorted_user_content: + display_text = (text[:50] + '...') if len(text) > 50 else text + print(f" [{count}] User: {uid} | {display_text}") + + print("="*40) + +if __name__ == "__main__": + analyze_memory() + diff --git a/scripts/analyze_memory_coverage.py b/scripts/analyze_memory_coverage.py new file mode 100644 index 0000000..e7ec498 --- /dev/null +++ b/scripts/analyze_memory_coverage.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +""" +Script to analyze Memory Card coverage statistics. 
+""" +import sys +import os +import json +import numpy as np +from collections import defaultdict + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.retrieval.preference_store.schemas import MemoryCard + +def main(): + cards_path = "data/personamem/memory_cards.jsonl" + + if not os.path.exists(cards_path): + print(f"Error: {cards_path} not found.") + return + + print(f"Loading memory cards from {cards_path}...") + + cards_by_user = defaultdict(int) + total_cards = 0 + + with open(cards_path, "r") as f: + for line in f: + try: + card = json.loads(line) + uid = card.get("user_id") + if uid: + cards_by_user[uid] += 1 + total_cards += 1 + except: + continue + + # We also need to know the TOTAL number of personas (including those with 0 cards) + # We can infer this from the user_vectors file if it exists, or just report on "users with memory" + # But better to check contexts file to see denominator + + ctx_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl" + total_personas = 0 + if os.path.exists(ctx_path): + with open(ctx_path, "r") as f: + for line in f: + data = json.loads(line) + total_personas += len(data) # Each line is {hash: [msgs]}? Wait, check format. + # personamem_loader says: line is dict {cid: msgs} + # So usually 1 per line? Or many? + # Let's count keys. + else: + print("Warning: Context file not found, can't calculate 0-memory users accurately.") + total_personas = len(cards_by_user) # Fallback + + users_with_memory = len(cards_by_user) + users_without_memory = total_personas - users_with_memory + + counts = list(cards_by_user.values()) + if users_without_memory > 0: + counts.extend([0] * users_without_memory) + + print("\n" + "="*40) + print("Memory Coverage Statistics") + print("="*40) + print(f"Total Personas (Est): {total_personas}") + print(f"Total Memory Cards: {total_cards}") + print(f"Users with Memory: {users_with_memory} ({users_with_memory/total_personas*100:.2f}%)") + print(f"Users w/o Memory: {users_without_memory} ({users_without_memory/total_personas*100:.2f}%)") + print("-" * 40) + + if counts: + avg_cards = np.mean(counts) + median_cards = np.median(counts) + max_cards = np.max(counts) + + print(f"Avg Cards/User: {avg_cards:.2f}") + print(f"Median Cards/User: {median_cards:.2f}") + print(f"Max Cards/User: {max_cards}") + + # Percentiles + p25, p75 = np.percentile(counts, [25, 75]) + print(f"25th Percentile: {p25:.2f}") + print(f"75th Percentile: {p75:.2f}") + + print("\nDistribution:") + + # Adjust for exact 0 + zero_count = counts.count(0) + + print(f" 0 : {zero_count}") + # Custom bins for >0 + non_zero_counts = [c for c in counts if c > 0] + if non_zero_counts: + hist_nz, edges = np.histogram(non_zero_counts, bins=[1, 5, 10, 20, 50, 1000]) + for i in range(len(hist_nz)): + range_str = f"{int(edges[i])}-{int(edges[i+1]-1)}" + print(f" {range_str:<8}: {hist_nz[i]}") + +if __name__ == "__main__": + main() + diff --git a/scripts/analyze_user_similarity.py b/scripts/analyze_user_similarity.py new file mode 100644 index 0000000..538a89a --- /dev/null +++ b/scripts/analyze_user_similarity.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python3 +""" +User Vector Similarity Analysis + +This script analyzes the similarity between user vectors (z_u) learned by the +online personalization system. It computes: +1. Cosine similarity matrix between all user vectors +2. Ground truth similarity based on preference overlap +3. 
Correlation between learned and expected similarities + +Usage: + python scripts/analyze_user_similarity.py \ + --user-store data/users/user_store_pilot_v4_full-greedy.npz +""" + +import argparse +import numpy as np +from typing import Dict, List, Tuple +from dataclasses import dataclass + + +# ============================================================================= +# Persona Definitions (must match pilot_runner_v4.py) +# ============================================================================= + +@dataclass +class StylePrefs: + """User's TRUE style preferences.""" + require_short: bool = False + max_chars: int = 300 + require_bullets: bool = False + lang: str = "en" + + +# Ground truth personas +PERSONAS = { + "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"), + "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"), + "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"), + "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"), + "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"), + "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"), +} + + +# ============================================================================= +# User Vector Loading +# ============================================================================= + +def load_user_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: + """ + Load user vectors from saved user store. + + Returns: + {user_id: (z_long, z_short)} + """ + data = np.load(user_store_path, allow_pickle=True) + + user_vectors = {} + + # UserTensorStore saves in format: {uid}_long, {uid}_short, {uid}_meta + # First, find all unique user IDs + user_ids = set() + for key in data.files: + if key.endswith("_long"): + uid = key[:-5] # Remove "_long" + user_ids.add(uid) + + # Load vectors for each user + for uid in user_ids: + long_key = f"{uid}_long" + short_key = f"{uid}_short" + + if long_key in data.files and short_key in data.files: + z_long = data[long_key] + z_short = data[short_key] + user_vectors[uid] = (z_long, z_short) + + return user_vectors + + +def load_user_vectors_from_internal(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: + """ + Alternative loader that understands the internal format. 
+ """ + data = np.load(user_store_path, allow_pickle=True) + + print(f"[Debug] Available keys in npz: {list(data.files)}") + + user_vectors = {} + + # Try to find user vectors in various formats + for key in data.files: + print(f" {key}: shape={data[key].shape if hasattr(data[key], 'shape') else 'N/A'}") + + # Format 1: Separate arrays per user + seen_users = set() + for key in data.files: + if "_z_long" in key or key.startswith("z_long_"): + # Extract user_id + if key.startswith("z_long_"): + user_id = key[7:] # Remove "z_long_" + else: + user_id = key.split("_z_long")[0] + seen_users.add(user_id) + + for user_id in seen_users: + # Try different key formats + z_long_keys = [f"z_long_{user_id}", f"{user_id}_z_long"] + z_short_keys = [f"z_short_{user_id}", f"{user_id}_z_short"] + + z_long = None + z_short = None + + for k in z_long_keys: + if k in data.files: + z_long = data[k] + break + + for k in z_short_keys: + if k in data.files: + z_short = data[k] + break + + if z_long is not None and z_short is not None: + user_vectors[user_id] = (z_long, z_short) + + return user_vectors + + +# ============================================================================= +# Similarity Computation +# ============================================================================= + +def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float: + """Compute cosine similarity between two vectors.""" + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + + if norm1 < 1e-10 or norm2 < 1e-10: + return 0.0 + + return float(np.dot(v1, v2) / (norm1 * norm2)) + + +def compute_learned_similarity_matrix( + user_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]], + user_order: List[str] +) -> np.ndarray: + """ + Compute similarity matrix from learned user vectors. + + Uses concatenated [z_long, z_short] as the user representation. + """ + n = len(user_order) + sim_matrix = np.zeros((n, n)) + + for i, u1 in enumerate(user_order): + for j, u2 in enumerate(user_order): + if u1 in user_vectors and u2 in user_vectors: + z1 = np.concatenate(user_vectors[u1]) + z2 = np.concatenate(user_vectors[u2]) + sim_matrix[i, j] = cosine_similarity(z1, z2) + elif i == j: + sim_matrix[i, j] = 1.0 + + return sim_matrix + + +def compute_ground_truth_similarity( + personas: Dict[str, StylePrefs], + user_order: List[str] +) -> np.ndarray: + """ + Compute ground truth similarity based on preference overlap. + + Uses Jaccard-like similarity: + - short: +1 if both require_short or both don't + - bullets: +1 if both require_bullets match + - lang: +1 if both lang match + + Then normalize to [0, 1]. + """ + n = len(user_order) + sim_matrix = np.zeros((n, n)) + + for i, u1 in enumerate(user_order): + for j, u2 in enumerate(user_order): + if u1 not in personas or u2 not in personas: + sim_matrix[i, j] = 0.0 if i != j else 1.0 + continue + + p1 = personas[u1] + p2 = personas[u2] + + # Count matching dimensions + matches = 0 + total = 3 # short, bullets, lang + + if p1.require_short == p2.require_short: + matches += 1 + if p1.require_bullets == p2.require_bullets: + matches += 1 + if p1.lang == p2.lang: + matches += 1 + + sim_matrix[i, j] = matches / total + + return sim_matrix + + +def compute_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> Tuple[float, float]: + """ + Compute Pearson and Spearman correlation between learned and ground truth similarity. + Only uses upper triangle (excluding diagonal) to avoid bias. 
+ """ + n = learned.shape[0] + + # Extract upper triangle (excluding diagonal) + learned_flat = [] + gt_flat = [] + + for i in range(n): + for j in range(i + 1, n): + learned_flat.append(learned[i, j]) + gt_flat.append(ground_truth[i, j]) + + learned_flat = np.array(learned_flat) + gt_flat = np.array(gt_flat) + + # Pearson correlation + if np.std(learned_flat) < 1e-10 or np.std(gt_flat) < 1e-10: + pearson = 0.0 + else: + pearson = float(np.corrcoef(learned_flat, gt_flat)[0, 1]) + + # Spearman correlation (rank-based) + from scipy.stats import spearmanr + spearman, _ = spearmanr(learned_flat, gt_flat) + + return pearson, float(spearman) + + +# ============================================================================= +# Visualization +# ============================================================================= + +def print_similarity_matrix(matrix: np.ndarray, user_order: List[str], title: str): + """Print similarity matrix in ASCII format.""" + print(f"\n{title}") + print("=" * 70) + + # Short labels + labels = [u.replace("user_", "").replace("_", " ")[:15] for u in user_order] + + # Header + print(f"{'':>16}", end="") + for label in labels: + print(f"{label[:8]:>10}", end="") + print() + + # Rows + for i, label in enumerate(labels): + print(f"{label:>16}", end="") + for j in range(len(labels)): + print(f"{matrix[i, j]:>10.3f}", end="") + print() + + print() + + +def save_visualization( + learned: np.ndarray, + ground_truth: np.ndarray, + user_order: List[str], + output_path: str +): + """Save similarity matrices as heatmap visualization.""" + try: + import matplotlib.pyplot as plt + import seaborn as sns + except ImportError: + print("[Warning] matplotlib/seaborn not available, skipping visualization") + return + + # Short labels + labels = [u.replace("user_", "")[:12] for u in user_order] + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + # Learned similarity + sns.heatmap(learned, annot=True, fmt=".2f", + xticklabels=labels, yticklabels=labels, + cmap="RdYlGn", vmin=-1, vmax=1, + ax=axes[0]) + axes[0].set_title("Learned User Vector Similarity\n(cosine similarity)") + axes[0].tick_params(axis='x', rotation=45) + axes[0].tick_params(axis='y', rotation=0) + + # Ground truth similarity + sns.heatmap(ground_truth, annot=True, fmt=".2f", + xticklabels=labels, yticklabels=labels, + cmap="RdYlGn", vmin=0, vmax=1, + ax=axes[1]) + axes[1].set_title("Ground Truth Preference Overlap\n(Jaccard-like)") + axes[1].tick_params(axis='x', rotation=45) + axes[1].tick_params(axis='y', rotation=0) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"[Visualization] Saved to: {output_path}") + + +# ============================================================================= +# Main Analysis +# ============================================================================= + +def analyze_user_similarity(user_store_path: str, output_dir: str = "data/analysis"): + """Run full user similarity analysis.""" + import os + os.makedirs(output_dir, exist_ok=True) + + print("=" * 70) + print("USER VECTOR SIMILARITY ANALYSIS") + print("=" * 70) + print(f"User store: {user_store_path}") + + # Load user vectors + print("\n[1] Loading user vectors...") + user_vectors = load_user_vectors(user_store_path) + + if not user_vectors: + print("[Warning] No user vectors found with standard format, trying alternative...") + user_vectors = load_user_vectors_from_internal(user_store_path) + + if not user_vectors: + print("[Error] Could not load user vectors!") + return + + print(f" Found 
{len(user_vectors)} users: {list(user_vectors.keys())}") + + # Print vector norms + print("\n[2] User vector norms:") + for uid, (z_long, z_short) in user_vectors.items(): + print(f" {uid}: ||z_long||={np.linalg.norm(z_long):.4f}, ||z_short||={np.linalg.norm(z_short):.4f}") + + # Determine user order (intersection of loaded users and known personas) + user_order = [u for u in PERSONAS.keys() if u in user_vectors] + print(f"\n[3] Analyzing {len(user_order)} users: {user_order}") + + if len(user_order) < 2: + print("[Error] Need at least 2 users for similarity analysis!") + return + + # Compute similarity matrices + print("\n[4] Computing similarity matrices...") + learned_sim = compute_learned_similarity_matrix(user_vectors, user_order) + gt_sim = compute_ground_truth_similarity(PERSONAS, user_order) + + # Print matrices + print_similarity_matrix(learned_sim, user_order, "LEARNED SIMILARITY (Cosine of z_u)") + print_similarity_matrix(gt_sim, user_order, "GROUND TRUTH SIMILARITY (Preference Overlap)") + + # Compute correlation + print("\n[5] Correlation Analysis:") + print("-" * 50) + pearson, spearman = compute_correlation(learned_sim, gt_sim) + print(f" Pearson correlation: {pearson:.4f}") + print(f" Spearman correlation: {spearman:.4f}") + + # Interpretation + print("\n[6] Interpretation:") + print("-" * 50) + if spearman > 0.7: + print(" ✅ STRONG correlation: User vectors encode preference similarity well!") + elif spearman > 0.4: + print(" ⚠️ MODERATE correlation: User vectors partially capture preferences.") + elif spearman > 0: + print(" ⚠️ WEAK correlation: User vectors weakly capture preferences.") + else: + print(" ❌ NO/NEGATIVE correlation: User vectors do not reflect preferences.") + + # Key comparisons + print("\n[7] Key Similarity Comparisons:") + print("-" * 50) + + def get_sim(u1, u2, matrix, user_order): + if u1 in user_order and u2 in user_order: + i, j = user_order.index(u1), user_order.index(u2) + return matrix[i, j] + return None + + comparisons = [ + ("user_A_short_bullets_en", "user_F_extreme_short_en", ">", "user_A_short_bullets_en", "user_E_long_no_bullets_zh", + "A~F (both short+bullets) should be > A~E (opposite)"), + ("user_A_short_bullets_en", "user_D_short_bullets_zh", ">", "user_A_short_bullets_en", "user_C_long_bullets_en", + "A~D (both short+bullets) should be > A~C (only bullets match)"), + ("user_B_short_no_bullets_en", "user_E_long_no_bullets_zh", ">", "user_B_short_no_bullets_en", "user_A_short_bullets_en", + "B~E (both no_bullets) should be > B~A (bullets differ)"), + ] + + for u1, u2, op, u3, u4, desc in comparisons: + sim1 = get_sim(u1, u2, learned_sim, user_order) + sim2 = get_sim(u3, u4, learned_sim, user_order) + + if sim1 is not None and sim2 is not None: + passed = sim1 > sim2 if op == ">" else sim1 < sim2 + status = "✅ PASS" if passed else "❌ FAIL" + print(f" {status}: sim({u1[:6]},{u2[:6]})={sim1:.3f} {op} sim({u3[:6]},{u4[:6]})={sim2:.3f}") + print(f" ({desc})") + + # Save visualization + print("\n[8] Saving visualization...") + output_path = os.path.join(output_dir, "user_similarity_matrix.png") + save_visualization(learned_sim, gt_sim, user_order, output_path) + + # Save numerical results + results_path = os.path.join(output_dir, "user_similarity_results.npz") + np.savez(results_path, + learned_similarity=learned_sim, + ground_truth_similarity=gt_sim, + user_order=user_order, + pearson=pearson, + spearman=spearman) + print(f"[Results] Saved to: {results_path}") + + print("\n" + "=" * 70) + print("ANALYSIS COMPLETE") + print("=" * 70) + + +def 
main(): + parser = argparse.ArgumentParser(description="User Vector Similarity Analysis") + parser.add_argument("--user-store", type=str, required=True, + help="Path to user store npz file") + parser.add_argument("--output-dir", type=str, default="data/analysis", + help="Output directory for results") + args = parser.parse_args() + + analyze_user_similarity(args.user_store, args.output_dir) + + +if __name__ == "__main__": + main() + diff --git a/scripts/assemble_dataset.py b/scripts/assemble_dataset.py new file mode 100644 index 0000000..024f91f --- /dev/null +++ b/scripts/assemble_dataset.py @@ -0,0 +1,85 @@ +import json +import os +import random + +# Source Files +FILE_ORIGINAL = "data/raw_datasets/labeled_full_dataset_batch.jsonl" +FILE_SYNTHESIS = "data/raw_datasets/synthesized_positives.jsonl" + +# Output Files +OUTPUT_RAW = "data/finetune/preference_extractor_450k.jsonl" + +def assemble_dataset(): + os.makedirs(os.path.dirname(OUTPUT_RAW), exist_ok=True) + + print("Assembling final dataset...") + + records = [] + + # 1. Load Original (Pos + Neg) + print(f"Loading {FILE_ORIGINAL}...") + if os.path.exists(FILE_ORIGINAL): + with open(FILE_ORIGINAL, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + item = json.loads(line) + # Standardize format: {"input": ..., "output": ...} + # Input is user query. Output is extracted JSON string. + + query = item.get("original_query", "") + output_json = item.get("extracted_json", {"preferences": []}) + + # Ensure output is a string of JSON, minimal whitespace to save tokens + output_str = json.dumps(output_json, ensure_ascii=False) + + records.append({ + "input": query, + "output": output_str, + "source": item.get("source", "original") + }) + else: + print(f"Warning: {FILE_ORIGINAL} missing!") + + print(f"Loaded {len(records)} from original.") + + # 2. Load Synthesis (Pos) + print(f"Loading {FILE_SYNTHESIS}...") + syn_count = 0 + if os.path.exists(FILE_SYNTHESIS): + with open(FILE_SYNTHESIS, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + item = json.loads(line) + + query = item.get("original_query", "") + output_json = item.get("extracted_json", {"preferences": []}) + output_str = json.dumps(output_json, ensure_ascii=False) + + records.append({ + "input": query, + "output": output_str, + "source": "synthesis" + }) + syn_count += 1 + else: + print(f"Warning: {FILE_SYNTHESIS} missing!") + + print(f"Loaded {syn_count} from synthesis.") + + # 3. Shuffle + print("Shuffling...") + random.shuffle(records) + + # 4. Save + print(f"Saving {len(records)} records to {OUTPUT_RAW}...") + with open(OUTPUT_RAW, "w", encoding="utf-8") as f: + for r in records: + f.write(json.dumps(r, ensure_ascii=False) + "\n") + + print("Done!") + print("\nTo upload to Hugging Face, run:") + print("huggingface-cli upload <repo_id> data/finetune/preference_extractor_450k.jsonl --repo-type dataset") + +if __name__ == "__main__": + assemble_dataset() + diff --git a/scripts/build_item_space.py b/scripts/build_item_space.py new file mode 100644 index 0000000..c98238c --- /dev/null +++ b/scripts/build_item_space.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +""" +Script to build Item Space (PCA Projection) from Memory Embeddings. 
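+
+The projection is fit once over all memory embeddings and reused afterwards:
+roughly v = P @ (e - mean), mapping each 4096-d embedding e to a k-d item
+vector v (k=256 below; the exact transform lives in ItemProjection in
+personalization.user_model.features).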
+Inputs: +- data/corpora/memory_embeddings.npy (M x 4096) +Outputs: +- data/corpora/item_projection.npz (P, mean, V) +""" + +import sys +import os +import numpy as np + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.user_model.features import ItemProjection + +def main(): + emb_path = "data/corpora/memory_embeddings.npy" + out_path = "data/corpora/item_projection.npz" + + if not os.path.exists(emb_path): + print(f"Error: {emb_path} not found. Run migrate_preferences.py first.") + sys.exit(1) + + print(f"Loading embeddings from {emb_path}...") + E = np.load(emb_path) + print(f"Loaded shape: {E.shape}") + + # Target dimension k=256 + k = 256 + print(f"Fitting PCA with k={k}...") + + proj = ItemProjection.from_pca(E, k=k) + + print("Transforming all embeddings to item space...") + V = proj.transform_embeddings(E) + print(f"Item vectors shape: {V.shape}") + + print(f"Saving projection to {out_path}...") + np.savez( + out_path, + P=proj.P, + mean=proj.mean, + V=V + ) + print("Done.") + +if __name__ == "__main__": + main() + diff --git a/scripts/check_batch_status.py b/scripts/check_batch_status.py new file mode 100644 index 0000000..8c3ea5d --- /dev/null +++ b/scripts/check_batch_status.py @@ -0,0 +1,60 @@ +import json +import os +import time +from openai import OpenAI + +BATCH_IDS_FILE = "data/putnam_eval/submitted_batch_ids.json" + +def check_status(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + client = OpenAI(api_key=api_key) + + if not os.path.exists(BATCH_IDS_FILE): + print(f"Error: {BATCH_IDS_FILE} not found. Did you submit batches?") + return + + with open(BATCH_IDS_FILE, "r") as f: + batch_ids = json.load(f) + + print(f"Checking status for {len(batch_ids)} batches...\n") + + all_completed = True + completed_count = 0 + + print(f"{'BATCH ID':<35} | {'STATUS':<15} | {'REQ COUNT':<10} | {'COMPLETED':<10} | {'FAILED':<10}") + print("-" * 95) + + for b_id in batch_ids: + try: + batch = client.batches.retrieve(b_id) + status = batch.status + counts = batch.request_counts + total = counts.total if counts else 0 + comp = counts.completed if counts else 0 + fail = counts.failed if counts else 0 + + print(f"{b_id:<35} | {status:<15} | {total:<10} | {comp:<10} | {fail:<10}") + + if status != "completed": + all_completed = False + else: + completed_count += 1 + + except Exception as e: + print(f"{b_id:<35} | {'ERROR':<15} | {str(e)}") + all_completed = False + + print("-" * 95) + print(f"\nProgress: {completed_count}/{len(batch_ids)} batches completed.") + + if all_completed: + print("\nSUCCESS: All batches finished! You can now run scripts/retrieve_results.py") + else: + print("\nSome batches are still processing. Check again later.") + +if __name__ == "__main__": + check_status() + diff --git a/scripts/clean_memory_store.py b/scripts/clean_memory_store.py new file mode 100644 index 0000000..3b07a95 --- /dev/null +++ b/scripts/clean_memory_store.py @@ -0,0 +1,52 @@ +import json +import os +import shutil + +MEMORY_FILE = "data/corpora/memory_cards.jsonl" +BACKUP_FILE = "data/corpora/memory_cards.jsonl.bak" + +def clean_memory_store(): + if not os.path.exists(MEMORY_FILE): + print(f"Error: {MEMORY_FILE} not found.") + return + + # 1. Backup + print(f"Backing up to {BACKUP_FILE}...") + shutil.copy2(MEMORY_FILE, BACKUP_FILE) + + unique_keys = set() + cleaned_records = [] + total_read = 0 + + # 2. 
Read and Dedup + print("Scanning and deduplicating...") + with open(MEMORY_FILE, "r", encoding="utf-8") as f: + for line in f: + if not line.strip(): continue + try: + card = json.loads(line) + total_read += 1 + + uid = card.get("user_id") + text = card.get("note_text", "").strip() + + # Key: (User, Content) + key = (uid, text) + + if key not in unique_keys: + unique_keys.add(key) + cleaned_records.append(line.strip()) + except: + pass + + # 3. Write Back + print(f"Writing {len(cleaned_records)} records back (Removed {total_read - len(cleaned_records)} duplicates)...") + with open(MEMORY_FILE, "w", encoding="utf-8") as f: + for line in cleaned_records: + f.write(line + "\n") + + print("Done!") + +if __name__ == "__main__": + clean_memory_store() + diff --git a/scripts/convert_to_llama_factory.py b/scripts/convert_to_llama_factory.py new file mode 100644 index 0000000..d8b7565 --- /dev/null +++ b/scripts/convert_to_llama_factory.py @@ -0,0 +1,62 @@ +import json +import os + +INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl" +OUTPUT_FILE = "data/finetune/train_llama_factory.json" + +# We embed the system prompt as "instruction" so the model learns to respond to this specific instruction. +# Or, if you plan to put this system prompt in the system slot of the chat template, +# you can leave instruction empty or simplified. +# Given 0.5B model, explicit instruction in the prompt is often helpful. +SYSTEM_INSTRUCTION = ( + "Extract user preferences from the query into JSON format based on the PreferenceList schema. " + "If no preferences are found, return {\"preferences\": []}." +) + +def convert(): + if not os.path.exists(INPUT_FILE): + print(f"Error: {INPUT_FILE} not found. Run scripts/assemble_dataset.py first.") + return + + print(f"Reading {INPUT_FILE}...") + dataset = [] + + with open(INPUT_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + item = json.loads(line) + + # Alpaca format + record = { + "instruction": SYSTEM_INSTRUCTION, + "input": item["input"], + "output": item["output"] + } + dataset.append(record) + + print(f"Converted {len(dataset)} items.") + + # Save as JSON list (LLaMA-Factory standard) + print(f"Saving to {OUTPUT_FILE}...") + with open(OUTPUT_FILE, "w", encoding="utf-8") as f: + json.dump(dataset, f, indent=2, ensure_ascii=False) + + print("Done!") + + print("\nNext steps for LLaMA-Factory:") + print("1. Copy data/finetune/train_llama_factory.json to your LLaMA-Factory data/ folder.") + print("2. Add entry to dataset_info.json:") + print(json.dumps({ + "preference_extractor_v1": { + "file_name": "train_llama_factory.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + } + }, indent=2)) + +if __name__ == "__main__": + convert() + diff --git a/scripts/day1_demo.py b/scripts/day1_demo.py new file mode 100644 index 0000000..b201229 --- /dev/null +++ b/scripts/day1_demo.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +""" +Day 1 Demo: End-to-end Minimal Memory RAG. +1. Load MemoryCards + Embeddings. +2. Receive a query. +3. Retrieve top-k memories. +4. Generate answer with QwenInstruct. 
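+
+Usage sketch (paths below are the script's defaults):
+    python scripts/day1_demo.py "Write a quicksort function, keep it short."
+With no argument, a built-in fibonacci query is used.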
+""" + +import json +import numpy as np +import torch +import sys +import os + +# Add src to sys.path so we can import personalization +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from typing import List + +from personalization.config.settings import load_local_models_config +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.models.llm.qwen_instruct import QwenInstruct +from personalization.retrieval.preference_store.schemas import MemoryCard + +def load_memory_store(cards_path: str, embs_path: str): + print(f"Loading memory store from {cards_path}...") + cards = [] + with open(cards_path, "r", encoding="utf-8") as f: + for line in f: + cards.append(MemoryCard.model_validate_json(line)) + + embs = np.load(embs_path) + return cards, embs + +def cosine_similarity(E: np.ndarray, e_q: np.ndarray) -> np.ndarray: + # E: [M, d], e_q: [d] + # Assumes vectors are normalized + return np.dot(E, e_q) + +def dense_retrieve( + query: str, + embedder: Qwen3Embedding8B, + cards: List[MemoryCard], + E: np.ndarray, + topk: int = 3 +) -> List[MemoryCard]: + + # Encode query + # encode returns list[list[float]] or tensor + e_q_list = embedder.encode([query], normalize=True, return_tensor=False) + e_q = np.array(e_q_list[0], dtype=np.float32) + + # Sim + sims = cosine_similarity(E, e_q) + + # Top-k + # argsort is ascending, so take last k and reverse + if len(cards) == 0: + return [] + + k = min(topk, len(cards)) + idx = np.argsort(sims)[-k:][::-1] + + results = [cards[i] for i in idx] + return results + +def main(): + cards_path = "data/corpora/memory_cards.jsonl" + embs_path = "data/corpora/memory_embeddings.npy" + + try: + cards, embs = load_memory_store(cards_path, embs_path) + print(f"Loaded {len(cards)} memory cards.") + except FileNotFoundError: + print("Error: Memory store not found. Please run scripts/migrate_preferences.py first.") + sys.exit(1) + + cfg = load_local_models_config() + + print("Initializing models...") + embedder = Qwen3Embedding8B.from_config(cfg) + llm = QwenInstruct.from_config(cfg) + + # Demo Query + # Let's try to pick a query that should trigger a retrieval if we have relevant memories. + # Since we processed pilot_study, let's assume we might have some "python code" or "formatting" prefs. + # If the pilot study didn't yield many prefs, we might just query something generic. + query = "Please write a function to calculate fibonacci numbers. Remember my preferences." + + # Or let's allow user input or command line arg + if len(sys.argv) > 1: + query = sys.argv[1] + + print(f"\nQuery: {query}") + + # Retrieve + hits = dense_retrieve(query, embedder, cards, embs, topk=3) + + print(f"\nRetrieved {len(hits)} memories:") + notes = [] + for h in hits: + print(f" - [{h.kind}] {h.note_text} (from user: {h.user_id})") + notes.append(h.note_text) + + # Generate + print("\nGenerating answer...") + # Mock history: just the current turn + history = [{"role": "user", "content": query}] + + answer = llm.answer(history, notes) + + print("-" * 40) + print("Answer:") + print(answer) + print("-" * 40) + +if __name__ == "__main__": + main() + diff --git a/scripts/day2_demo.py b/scripts/day2_demo.py new file mode 100644 index 0000000..ca81d99 --- /dev/null +++ b/scripts/day2_demo.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +""" +Day 2 Demo: End-to-end Memory RAG with Reranker and (Shell) Personalization. 
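+
+Flow sketch: dense retrieval over memory embeddings (topk_dense=64), rerank
+down to topk_rerank=3, then generate with the user's own memories. beta_long
+and beta_short are 0.0 in this demo, so user vectors do not bias retrieval.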
+""" + +import sys +import os +import numpy as np +import torch +from typing import List + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.models.llm.qwen_instruct import QwenInstruct +from personalization.models.reranker.qwen3_reranker import Qwen3Reranker +from personalization.retrieval.preference_store.schemas import MemoryCard +from personalization.user_model.tensor_store import UserTensorStore +from personalization.retrieval.pipeline import retrieve_with_rerank + +def main(): + # Paths + cards_path = "data/corpora/memory_cards.jsonl" + embs_path = "data/corpora/memory_embeddings.npy" + item_proj_path = "data/corpora/item_projection.npz" + user_store_path = "data/users/user_store.npz" + + # 1. Load Data + print("Loading data stores...") + if not os.path.exists(cards_path) or not os.path.exists(embs_path): + print("Memory data missing. Run migrate_preferences.py") + sys.exit(1) + + cards = [] + with open(cards_path, "r") as f: + for line in f: + cards.append(MemoryCard.model_validate_json(line)) + + memory_embeddings = np.load(embs_path) + + if not os.path.exists(item_proj_path): + print("Item projection missing. Run build_item_space.py") + sys.exit(1) + + proj_data = np.load(item_proj_path) + item_vectors = proj_data["V"] + + # 2. Load Models + print("Loading models...") + cfg = load_local_models_config() + + embedder = Qwen3Embedding8B.from_config(cfg) + reranker = Qwen3Reranker.from_config(cfg) + llm = QwenInstruct.from_config(cfg) + + # 3. Load User Store + # k = item_vectors.shape[1] + k = 256 # Hardcoded per config + user_store = UserTensorStore(k=k, path=user_store_path) + + # --- CHECK 1: User Vector Similarity --- + print("\n--- CHECK 1: User Vector Similarity ---") + # Get 3 users with memories + valid_users = [uid for uid, state in user_store._states.items() + if np.linalg.norm(state.z_long) > 1e-6] # Only non-zero users + + if len(valid_users) < 3: + print(f"Not enough users with memories found (found {len(valid_users)}). Skipping similarity check.") + else: + # Pick 3 random users + import random + selected_users = random.sample(valid_users, 3) + vectors = [user_store.get_state(uid).z_long for uid in selected_users] + + # Calculate pairwise cosine similarity + def cos_sim(a, b): + norm_a = np.linalg.norm(a) + norm_b = np.linalg.norm(b) + if norm_a == 0 or norm_b == 0: return 0.0 + return np.dot(a, b) / (norm_a * norm_b) + + print(f"Selected Users: {selected_users}") + print(f"Sim(0, 1): {cos_sim(vectors[0], vectors[1]):.4f}") + print(f"Sim(0, 2): {cos_sim(vectors[0], vectors[2]):.4f}") + print(f"Sim(1, 2): {cos_sim(vectors[1], vectors[2]):.4f}") + + # --- CHECK 2: Real User Retrieval --- + print("\n--- CHECK 2: Real User Retrieval ---") + + if len(valid_users) > 0: + # Pick one user + target_user = valid_users[0] + # Find a query from this user? + # For now, let's use a generic query that might hit some tech preferences, + # or ideally find a query from the dataset if we had it loaded. + # Let's try a generic coding query since OASST1 has many. + query = "How do I write a Python function for fibonacci?" + + print(f"User: {target_user}") + print(f"Query: {query}") + + # 5. 
Retrieve Pipeline + print("\nRunning Retrieval Pipeline (GLOBAL search)...") + hits_global = retrieve_with_rerank( + user_id=target_user, + query=query, + embed_model=embedder, + reranker=reranker, + memory_cards=cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=64, + topk_rerank=3, + beta_long=0.0, + beta_short=0.0, + only_own_memories=False # Global search + ) + + print(f"\nTop {len(hits_global)} Memories (Global):") + for h in hits_global: + print(f" - [{h.kind}] {h.note_text} (User: {h.user_id})") + + print("\nRunning Retrieval Pipeline (OWN memories only)...") + hits_own = retrieve_with_rerank( + user_id=target_user, + query=query, + embed_model=embedder, + reranker=reranker, + memory_cards=cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=64, + topk_rerank=3, + beta_long=0.0, + beta_short=0.0, + only_own_memories=True # Own search + ) + + print(f"\nTop {len(hits_own)} Memories (Own):") + notes = [] + for h in hits_own: + print(f" - [{h.kind}] {h.note_text} (User: {h.user_id})") + notes.append(h.note_text) + + # 6. Generate Answer (using OWN memories by default for demo) + print("\nGenerating Answer (using Own Memories)...") + history = [{"role": "user", "content": query}] + answer = llm.answer(history, notes) + + print("-" * 40) + print(answer) + print("-" * 40) + else: + print("No valid users found for demo.") + +if __name__ == "__main__": + main() + diff --git a/scripts/day3_demo_feedback.py b/scripts/day3_demo_feedback.py new file mode 100644 index 0000000..4e3d594 --- /dev/null +++ b/scripts/day3_demo_feedback.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +""" +Day 3 Demo: Feedback Loop Simulation (Reward + Gating). +""" + +import sys +import os +import json +import numpy as np +import random + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.models.reranker.qwen3_reranker import Qwen3Reranker +from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn +from personalization.user_model.tensor_store import UserTensorStore +from personalization.retrieval.pipeline import retrieve_with_rerank +from personalization.feedback.handlers import eval_step + +def main(): + # Paths + cards_path = "data/corpora/memory_cards.jsonl" + embs_path = "data/corpora/memory_embeddings.npy" + item_proj_path = "data/corpora/item_projection.npz" + user_store_path = "data/users/user_store.npz" + oasst_path = "data/raw_datasets/oasst1_queries.jsonl" # Source of turns + + # 1. Load Data + print("Loading data stores...") + if not os.path.exists(cards_path) or not os.path.exists(embs_path): + print("Memory data missing.") + sys.exit(1) + + cards = [] + with open(cards_path, "r") as f: + for line in f: + cards.append(MemoryCard.model_validate_json(line)) + + memory_embeddings = np.load(embs_path) + + proj_data = np.load(item_proj_path) + item_vectors = proj_data["V"] + + # 2. Load Models + print("Loading models...") + cfg = load_local_models_config() + embedder = Qwen3Embedding8B.from_config(cfg) + reranker = Qwen3Reranker.from_config(cfg) + + user_store = UserTensorStore(k=256, path=user_store_path) + + # 3. 
Simulate a Session
+    # Since we don't have full sessions in 'oasst1_queries.jsonl' (it only holds flat queries),
+    # we construct a synthetic session here; with full chat logs we could replay a real one.
+
+    print("\n--- Synthetic Session Evaluation ---")
+
+    # Scenario 1: Success
+    # User asks a Python question, the system answers well, and the user asks a follow-up.
+
+    user_id = "test_user_ok"
+    q_t = "How do I list files in a directory in Python?"
+
+    # Run real retrieval for this query; ideally the top hits are relevant cards.
+    hits = retrieve_with_rerank(
+        user_id=user_id,
+        query=q_t,
+        embed_model=embedder,
+        reranker=reranker,
+        memory_cards=cards,
+        memory_embeddings=memory_embeddings,
+        user_store=user_store,
+        item_vectors=item_vectors,
+        topk_dense=64,
+        topk_rerank=3
+    )
+
+    a_t = "You can use os.listdir() or pathlib.Path.iterdir(). Here is an example..."
+    q_t1 = "Great, can you show me the pathlib one?"
+
+    print("\n[Scenario 1]")
+    print(f"Q_t: {q_t}")
+    print(f"A_t: {a_t}")
+    print(f"Q_t+1: {q_t1}")
+    print(f"Memories: {[m.note_text for m in hits]}")
+
+    # Eval
+    e_q = embedder.encode([q_t], return_tensor=False)[0]
+    e_q1 = embedder.encode([q_t1], return_tensor=False)[0]
+    e_q = np.array(e_q)
+    e_q1 = np.array(e_q1)
+
+    r_hat, g_hat = eval_step(q_t, a_t, q_t1, hits, e_q, e_q1)
+    print(f"-> Reward: {r_hat:.2f}")
+    print(f"-> Gating: {g_hat:.2f}")
+
+    # Scenario 2: Failure (Complaint)
+    q_t = "Explain quantum entanglement."
+    a_t = "Quantum entanglement is a phenomenon where particles..."
+    q_t1 = "No, that's not what I meant. Explain it simply like I'm five."
+
+    # Reuse the Scenario 1 hits as deliberately irrelevant retrievals
+    # to simulate a retrieval mistake.
+
+    print("\n[Scenario 2]")
+    print(f"Q_t: {q_t}")
+    print(f"A_t: {a_t}")
+    print(f"Q_t+1: {q_t1}")
+    print(f"Memories: {[m.note_text for m in hits]} (Irrelevant)")
+
+    e_q = embedder.encode([q_t], return_tensor=False)[0]
+    e_q1 = embedder.encode([q_t1], return_tensor=False)[0]
+    e_q = np.array(e_q)
+    e_q1 = np.array(e_q1)
+
+    r_hat, g_hat = eval_step(q_t, a_t, q_t1, hits, e_q, e_q1)
+    print(f"-> Reward: {r_hat:.2f}")
+    print(f"-> Gating: {g_hat:.2f}")
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/day4_offline_rl_replay.py b/scripts/day4_offline_rl_replay.py
new file mode 100644
index 0000000..1c5cdd9
--- /dev/null
+++ b/scripts/day4_offline_rl_replay.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+"""
+Day 4: Offline RL Replay Script.
+Simulates REINFORCE updates on user states using OASST1 chat logs.
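+
+Per-step sketch: retrieve candidates under the current policy, score the
+(q_t, a_t, q_t+1) transition into a reward r_hat and gate g_hat via
+eval_step, then apply reinforce_update_user_state to z_long / z_short.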
+""" + +import sys +import os +import json +import numpy as np +import random +from tqdm import tqdm +from collections import defaultdict + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.models.reranker.qwen3_reranker import Qwen3Reranker +from personalization.retrieval.preference_store.schemas import MemoryCard +from personalization.user_model.tensor_store import UserTensorStore +from personalization.retrieval.pipeline import retrieve_with_policy +from personalization.feedback.handlers import eval_step +from personalization.user_model.policy.reinforce import reinforce_update_user_state + +def main(): + # Configuration + cfg = load_local_models_config() + rl_cfg = { + "item_dim": 256, + "beta_long": 0.1, + "beta_short": 0.3, + "tau": 1.0, + "eta_long": 1e-3, + "eta_short": 5e-3, + "ema_alpha": 0.05, + "short_decay": 0.1 + } + # Load from yaml if available, here hardcoded for safety/clarity in demo + + # Paths + cards_path = "data/corpora/memory_cards.jsonl" + embs_path = "data/corpora/memory_embeddings.npy" + item_proj_path = "data/corpora/item_projection.npz" + user_store_path = "data/users/user_store.npz" + chat_path = "data/raw_datasets/oasst1_queries.jsonl" + + # 1. Load Data + print("Loading data stores...") + if not os.path.exists(cards_path): + print("Data missing.") + sys.exit(1) + + cards = [] + with open(cards_path, "r") as f: + for line in f: + cards.append(MemoryCard.model_validate_json(line)) + + memory_embeddings = np.load(embs_path) + item_vectors = np.load(item_proj_path)["V"] + + # 2. Load Models + print("Loading models...") + embedder = Qwen3Embedding8B.from_config(cfg) + reranker = Qwen3Reranker.from_config(cfg) + + user_store = UserTensorStore(k=rl_cfg["item_dim"], path=user_store_path) + + # 3. Load Chat Logs and Group by Session + print("Loading chat logs...") + sessions = defaultdict(list) + with open(chat_path, "r") as f: + for line in f: + row = json.loads(line) + sessions[row["session_id"]].append(row) + + # Filter sessions with at least 2 user turns to have q_t and q_t+1 + valid_sessions = [s for s in sessions.values() if len(s) >= 2] + print(f"Found {len(valid_sessions)} valid sessions for replay.") + + # Sample a few users/sessions to replay + # For speed, pick 50 sessions + sampled_sessions = random.sample(valid_sessions, min(50, len(valid_sessions))) + + total_reward = 0.0 + total_gating = 0.0 + steps = 0 + + # 4. Replay Loop + print("\nStarting RL Replay...") + + for session in tqdm(sampled_sessions, desc="Replay Sessions"): + # Sort by turn + session.sort(key=lambda x: x.get("turn_id", 0)) + + # We need (q_t, a_t, q_t+1) + # OASST1 flat queries usually don't contain assistant answers in the same file if we only extracted user queries. + # But for this demo, let's assume we can proceed without a_t (or mock it) for retrieval check, + # OR we rely on the fact that if q_t+1 exists, there was an answer. + # Limitation: We might not have a_t text. + # Day 3 eval_step takes a_t. We can pass empty string if unavailable, + # as our current reward model mainly looks at q_t+1 keywords. 
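+        # Illustrative step (values made up, not from the dataset):
+        #   q_t   = "How do I list files in Python?"
+        #   q_t+1 = "Thanks - and how do I filter by extension?"
+        #   eval_step(q_t, "", q_t1, ...) -> (r_hat, g_hat), which scale the
+        #   REINFORCE update applied to z_long / z_short below.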
+ + user_id = session[0]["user_id"] + + # Pre-check user state + state_before = user_store.get_state(user_id) + norm_long_before = np.linalg.norm(state_before.z_long) + norm_short_before = np.linalg.norm(state_before.z_short) + + num_updates = 0 + for i in range(len(session) - 1): + q_t_row = session[i] + q_t1_row = session[i+1] + + q_t = q_t_row["original_query"] + q_t1 = q_t1_row["original_query"] + a_t = "" # Missing in this file + + # 1. Retrieve with Policy + # Note: We use only_own_memories=True to simulate personalized memory bank + # But if the user is new/no memory in our store, this returns empty. + # Let's try both or fallback? For RL on user vector, we need candidates. + # If user has no memory, policy cannot select anything. + # Let's use global search for replay to ensure we have candidates to rerank. + # In real system, we'd search both. + + candidates, cand_vectors, base_scores, chosen_idx, probs = retrieve_with_policy( + user_id=user_id, + query=q_t, + embed_model=embedder, + reranker=reranker, + memory_cards=cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=32, + topk_rerank=5, + beta_long=rl_cfg["beta_long"], + beta_short=rl_cfg["beta_short"], + tau=rl_cfg["tau"], + only_own_memories=False + ) + + if not candidates: + continue + + chosen_memories = [candidates[i] for i in chosen_idx] + + # 2. Eval (Reward & Gating) + # Need embeddings + e_q = embedder.encode([q_t], return_tensor=False)[0] + e_q1 = embedder.encode([q_t1], return_tensor=False)[0] + + r_hat, g_hat = eval_step( + q_t, a_t, q_t1, + chosen_memories, + query_embedding_t=np.array(e_q), + query_embedding_t1=np.array(e_q1) + ) + + total_reward += r_hat + total_gating += g_hat + steps += 1 + + # 3. Update (REINFORCE) + state = user_store.get_state(user_id) + updated = reinforce_update_user_state( + user_state=state, + item_vectors=cand_vectors, + chosen_indices=chosen_idx, + policy_probs=probs, + reward_hat=r_hat, + gating=g_hat, + tau=rl_cfg["tau"], + eta_long=rl_cfg["eta_long"], + eta_short=rl_cfg["eta_short"], + ema_alpha=rl_cfg["ema_alpha"], + short_decay=rl_cfg["short_decay"] + ) + if updated: + num_updates += 1 + + user_store.save_state(state) + + # Post-check + state_after = user_store.get_state(user_id) + norm_long_after = np.linalg.norm(state_after.z_long) + norm_short_after = np.linalg.norm(state_after.z_short) + + # Print change only if updated + if num_updates > 0: + print(f"User {user_id}: {num_updates} updates") + print(f" ||z_long|| {norm_long_before:.8f} -> {norm_long_after:.8f}") + print(f" ||z_short|| {norm_short_before:.8f} -> {norm_short_after:.8f}") + + print("\n--- Replay Finished ---") + print(f"Total Steps: {steps}") + print(f"Avg Reward: {total_reward / max(1, steps):.8f}") + print(f"Avg Gating: {total_gating / max(1, steps):.8f}") + + user_store.persist() + print(f"User store saved to {user_store_path}") + +if __name__ == "__main__": + main() + diff --git a/scripts/debug_context_file.py b/scripts/debug_context_file.py new file mode 100644 index 0000000..81ac6b9 --- /dev/null +++ b/scripts/debug_context_file.py @@ -0,0 +1,14 @@ +import json + +path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl" +with open(path, 'r') as f: + line = f.readline() + data = json.loads(line) + print(f"Type: {type(data)}") + if isinstance(data, dict): + print(f"Keys: {list(data.keys())}") + # Peek into values + for k, v in data.items(): + print(f"Key '{k}' type: {type(v)}") + if isinstance(v, list): + print(f" Length: {len(v)}") diff --git 
a/scripts/debug_minimal_day3.py b/scripts/debug_minimal_day3.py new file mode 100644 index 0000000..169cf13 --- /dev/null +++ b/scripts/debug_minimal_day3.py @@ -0,0 +1,40 @@ +import sys +import os +import torch + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B + +def main(): + print("--- Minimal Day 3 Debug ---") + + # 1. Load Config + print("Loading Local Models Config...") + cfg = load_local_models_config() + print(f"Config loaded.") + + # Check what we got + spec = cfg.embedding.qwen3 + print(f"Model Path: {spec.local_path}") + print(f"Dtype: {spec.dtype}") + print(f"Device Map: {spec.device_map}") + # Check if trust_remote_code is in spec (it should be if my edit worked) + trc = getattr(spec, "trust_remote_code", "UNKNOWN") + print(f"Trust Remote Code: {trc}") + + # 2. Init Embedder + print("Initializing Embedder via Qwen3Embedding8B.from_config...") + try: + embedder = Qwen3Embedding8B.from_config(cfg) + print("Embedder loaded successfully!") + except Exception as e: + print(f"Failed to load embedder: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() + diff --git a/scripts/debug_personamem_hash.py b/scripts/debug_personamem_hash.py new file mode 100644 index 0000000..7ef442d --- /dev/null +++ b/scripts/debug_personamem_hash.py @@ -0,0 +1,22 @@ +import hashlib +import json + +def get_line_hash(line_str: str) -> str: + """Compute SHA256 hash of the line content to match shared_context_id.""" + return hashlib.sha256(line_str.strip().encode("utf-8")).hexdigest() + +def debug_hash(): + jsonl_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl" + with open(jsonl_path, "r") as f: + first_line = f.readline() + + computed_hash = get_line_hash(first_line) + target_hash = "e898d03fec683b1cabf29f57287ff66f8a31842543ecef44b56766844c1c1301" + + print(f"Computed: {computed_hash}") + print(f"Target: {target_hash}") + print(f"Match: {computed_hash == target_hash}") + +if __name__ == "__main__": + debug_hash() + diff --git a/scripts/diagnose_oom.py b/scripts/diagnose_oom.py new file mode 100644 index 0000000..22de3f9 --- /dev/null +++ b/scripts/diagnose_oom.py @@ -0,0 +1,78 @@ +import torch +from transformers import AutoModel, AutoTokenizer +import os +import psutil +import time +import sys +import gc + +def log_mem(msg): + mem = psutil.Process().memory_info().rss / (1024**3) + if torch.cuda.is_available(): + gpu = torch.cuda.memory_allocated() / (1024**3) + gpu_res = torch.cuda.memory_reserved() / (1024**3) + print(f"[{msg}] RAM: {mem:.2f}GB | GPU Alloc: {gpu:.2f}GB | GPU Res: {gpu_res:.2f}GB") + else: + print(f"[{msg}] RAM: {mem:.2f}GB | GPU: N/A") + sys.stdout.flush() + +def main(): + print("--- Diagnostic Script ---") + log_mem("Start") + + model_path = "models/qwen3-embedding-8b" + print(f"Model path: {model_path}") + + # Check config + import yaml + try: + with open("configs/local_models.yaml", "r") as f: + cfg = yaml.safe_load(f) + print("Config loaded from local_models.yaml:") + print(cfg['models']['embedding']['qwen3']) + except Exception as e: + print(f"Could not load config: {e}") + + # Explicit garbage collection + gc.collect() + torch.cuda.empty_cache() + log_mem("Pre-Load") + + print("Loading Tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=False) + log_mem("Tokenizer Loaded") + + print("Loading Model 
(trust_remote_code=False)...") + try: + # Load with low_cpu_mem_usage=True explicit (though auto/cuda usually does it) + model = AutoModel.from_pretrained( + model_path, + device_map="cuda:0", + torch_dtype=torch.bfloat16, + trust_remote_code=False, + low_cpu_mem_usage=True + ) + print("Model loaded successfully.") + except Exception as e: + print(f"Model load failed: {e}") + return + + log_mem("Model Loaded") + + print("Testing forward pass with small input...") + input_text = "Hello world" + inputs = tokenizer(input_text, return_tensors="pt").to("cuda:0") + + try: + with torch.no_grad(): + outputs = model(**inputs) + print("Forward pass success.") + print(f"Output shape: {outputs.last_hidden_state.shape}") + except Exception as e: + print(f"Forward pass failed: {e}") + + log_mem("End") + +if __name__ == "__main__": + main() + diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py new file mode 100644 index 0000000..f78b15f --- /dev/null +++ b/scripts/download_datasets.py @@ -0,0 +1,210 @@ +import os +import json +import random +from typing import List, Dict, Any +from datasets import load_dataset +from tqdm import tqdm + +# Configuration +OUTPUT_DIR = "data/raw_datasets" + +# Dataset configurations +# Format: (huggingface_id, subset, split, text_column_name, approximate_limit) +SOURCES = [ + { + "id": "lmsys/lmsys-chat-1m", + "subset": None, + "split": "train", + "type": "lmsys", + "limit": 200000 + }, + { + "id": "allenai/WildChat", + "subset": None, + "split": "train", + "type": "wildchat", + "limit": 150000 + }, + { + "id": "anon8231489123/ShareGPT_Vicuna_unfiltered", + "subset": None, + "split": "train", + "data_files": "ShareGPT_V3_unfiltered_cleaned_split.json", + "type": "sharegpt", + "limit": 50000 + }, + { + "id": "yahma/alpaca-cleaned", + "subset": None, + "split": "train", + "type": "alpaca", + "limit": 52000 + }, + { + "id": "Open-Orca/SlimOrca", + "subset": None, + "split": "train", + "type": "slimorca", + "limit": 100000 + } +] + +def ensure_english(text: str) -> bool: + # A simple heuristic to filter non-English text. + # For production, use langdetect or similar libraries. + # Here we check if a significant portion of characters are ASCII. + try: + return text.isascii() + except: + return False + +def process_lmsys(example: Dict[str, Any]) -> str | None: + # LMSYS format: conversation is in 'conversation' list of dicts + try: + conversation = example.get("conversation", []) + if not conversation: + return None + # Get first user message + if conversation[0]["role"] == "user": + return conversation[0]["content"] + except: + pass + return None + +def process_wildchat(example: Dict[str, Any]) -> str | None: + # WildChat format: 'conversation' list of dicts or 'prompt' column? + # Checking dataset viewer, it usually has 'conversation' with 'content' and 'role' + try: + conversation = example.get("conversation", []) + if not conversation: + return None + if conversation[0]["role"] == "user": + return conversation[0]["content"] + except: + pass + return None + +def process_sharegpt(example: Dict[str, Any]) -> str | None: + # ShareGPT format: 'conversations' list + try: + conversations = example.get("conversations", []) + if not conversations: + return None + # Usually human/gpt or user/assistant + if conversations[0]["from"] in ["human", "user"]: + return conversations[0]["value"] + except: + pass + return None + +def process_alpaca(example: Dict[str, Any]) -> str | None: + # Alpaca format: 'instruction' and 'input'. We combine them if input exists. 
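+    # e.g. instruction="Translate to French", input="Good morning"
+    #      -> "Translate to French\n\nInput: Good morning"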
+ try: + instruction = example.get("instruction", "").strip() + inp = example.get("input", "").strip() + if inp: + return f"{instruction}\n\nInput: {inp}" + return instruction + except: + pass + return None + +def process_slimorca(example: Dict[str, Any]) -> str | None: + # SlimOrca format: 'conversations' list of dicts (from, value) + # Similar to ShareGPT but keys might differ slightly + try: + conversations = example.get("conversations", []) + if not conversations: + return None + # Usually from: human/user + if conversations[0]["from"] in ["human", "user"]: + return conversations[0]["value"] + except: + pass + return None + +def download_and_process(): + os.makedirs(OUTPUT_DIR, exist_ok=True) + + all_queries = [] + + # Target new sources only (alpaca and slimorca) + # You can comment out this filter if you want to re-run everything + new_types = ["alpaca", "slimorca"] + + for source in SOURCES: + if source["type"] not in new_types: + continue + + print(f"Processing {source['id']}...") + try: + # Load streaming to save disk/memory + kwargs = {"streaming": True} + if "data_files" in source: + kwargs["data_files"] = source["data_files"] + + ds = load_dataset(source["id"], source["subset"], split=source["split"], **kwargs) + + count = 0 + limit = source["limit"] + + for example in tqdm(ds, desc=f"Reading {source['id']}", total=limit): + if count >= limit: + break + + query = None + if source["type"] == "lmsys": + query = process_lmsys(example) + elif source["type"] == "wildchat": + query = process_wildchat(example) + elif source["type"] == "sharegpt": + query = process_sharegpt(example) + elif source["type"] == "alpaca": + query = process_alpaca(example) + elif source["type"] == "slimorca": + query = process_slimorca(example) + + # Basic cleaning + if query and len(query.strip()) > 5 and ensure_english(query): + all_queries.append({ + "source": source["id"], + "query": query.strip() + }) + count += 1 + + except Exception as e: + print(f"Error processing {source['id']}: {e}") + + # Deduplicate based on query content + print(f"Total collected new items: {len(all_queries)}") + + # Load existing if available to dedup against + output_path = os.path.join(OUTPUT_DIR, "combined_raw_queries.jsonl") + existing_data = [] + if os.path.exists(output_path): + print("Loading existing data for deduplication...") + with open(output_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + existing_data.append(json.loads(line)) + + combined = existing_data + all_queries + print(f"Total before final deduplication: {len(combined)}") + + unique_queries = {item["query"]: item for item in combined}.values() + final_data = list(unique_queries) + print(f"Total after final deduplication: {len(final_data)}") + + # Shuffle + random.shuffle(final_data) + + # Save + print(f"Saving to {output_path}...") + with open(output_path, "w", encoding="utf-8") as f: + for item in final_data: + f.write(json.dumps(item, ensure_ascii=False) + "\n") + + print("Done!") + +if __name__ == "__main__": + download_and_process() diff --git a/scripts/download_llama.py b/scripts/download_llama.py new file mode 100644 index 0000000..47c7243 --- /dev/null +++ b/scripts/download_llama.py @@ -0,0 +1,16 @@ +from huggingface_hub import snapshot_download +import os + +model_id = "meta-llama/Llama-3.1-8B-Instruct" +local_dir = "models/llama-3.1-8b-instruct" + +os.makedirs(local_dir, exist_ok=True) + +print(f"Downloading {model_id} to {local_dir}...") +snapshot_download( + repo_id=model_id, + local_dir=local_dir, + # 
token=os.getenv("HF_TOKEN"), # Assuming token is in env or cached
+)
+print("Download complete.")
+
diff --git a/scripts/download_oasst1.py b/scripts/download_oasst1.py
new file mode 100644
index 0000000..3105f28
--- /dev/null
+++ b/scripts/download_oasst1.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""
+Script to download and process the OpenAssistant/oasst1 dataset.
+Converts it into a flat list of user turns (ChatTurn) for our pipeline.
+
+Output format per line (JSONL):
+{
+    "original_query": str,
+    "source": "oasst1",
+    "user_id": str,
+    "session_id": str,
+    "turn_id": int
+}
+"""
+
+import json
+import os
+from datasets import load_dataset
+from tqdm import tqdm
+
+def main():
+    output_path = "data/raw_datasets/oasst1_queries.jsonl"
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+    print("Downloading OpenAssistant/oasst1 dataset...")
+    # OASST1 is a tree structure. We need to traverse it to reconstruct conversations.
+    # It has 'message_id', 'parent_id', 'user_id', 'text', 'role'
+    ds = load_dataset("OpenAssistant/oasst1", split="train")
+
+    print(f"Loaded {len(ds)} messages. Reconstructing threads...")
+
+    # Index by message_id
+    id2msg = {}
+    for row in tqdm(ds, desc="Indexing"):
+        id2msg[row["message_id"]] = row
+
+    # We could find leaf nodes and trace back threads, or trace every user message
+    # back to its root to establish session context.
+    # For this task: build a unified ChatTurn sequence from OASST1 (with user_id and session_id).
+    # We want valid user turns.
+    # OASST1 'user_id' is the author ID.
+    # 'message_tree_id' identifies the conversation tree (session).
+
+    # We can iterate all messages. If role=='prompter' (user), we treat it as a turn.
+    # We use 'message_tree_id' as session_id.
+
+    queries = []
+
+    # Iterate all rows
+    for row in tqdm(ds, desc="Processing"):
+        if row["role"] == "prompter":
+            # This is a user turn
+            user_id = row["user_id"]  # Author ID
+            session_id = row["message_tree_id"]
+            text = row["text"]
+
+            # Simple metadata
+            queries.append({
+                "original_query": text,
+                "source": "oasst1",
+                "user_id": str(user_id),
+                "session_id": str(session_id),
+                "turn_id": 0  # Precise turn_id is not needed for the Day 1 pipeline unless we sort turns
+            })
+
+    print(f"Extracted {len(queries)} user queries.")
+
+    # Save
+    print(f"Saving to {output_path}...")
+    with open(output_path, "w", encoding="utf-8") as f:
+        for q in queries:
+            f.write(json.dumps(q, ensure_ascii=False) + "\n")
+
+    print("Done.")
+
+if __name__ == "__main__":
+    main()
+
diff --git a/scripts/download_personamem.py b/scripts/download_personamem.py
new file mode 100644
index 0000000..31b4e0e
--- /dev/null
+++ b/scripts/download_personamem.py
@@ -0,0 +1,25 @@
+from huggingface_hub import hf_hub_download
+import os
+
+repo_id = "bowen-upenn/PersonaMem"
+local_dir = "data/raw_datasets/personamem"
+files_to_download = [
+    "questions_32k.csv",
+    "shared_contexts_32k.jsonl"
+]
+
+os.makedirs(local_dir, exist_ok=True)
+
+print(f"Downloading files from {repo_id} to {local_dir}...")
+
+for filename in files_to_download:
+    print(f"Downloading {filename}...")
+    hf_hub_download(
+        repo_id=repo_id,
+        filename=filename,
+        repo_type="dataset",
+        local_dir=local_dir
+    )
+
+print("Download complete.")
+
diff --git a/scripts/eval_embedder_reranker.py b/scripts/eval_embedder_reranker.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/eval_embedder_reranker.py
diff --git a/scripts/eval_interface_example.py b/scripts/eval_interface_example.py
new file mode 100644
index 0000000..d5dc6cd
--- /dev/null
+++ 
b/scripts/eval_interface_example.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +Example: Using the PersonalizedLLM Interface for Evaluation. + +This script demonstrates the evaluation interface that can be used +by a user simulator or evaluation framework. + +Call sequence per evaluation run: +1. reset_user(user_id) - Start fresh for this user's "life" +2. For each session (s=1..S): + a. reset_session(user_id) - New chat window + b. For each turn (t=1..T): + i. [Turn 2+] apply_feedback() for previous turn + ii. resp = chat(user_id, query) + iii. [Simulator computes reward from response] +3. persist() - Save state at end +""" + +import sys +import os + +# Add src to sys.path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.serving import ( + PersonalizedLLM, + AssistantResponse, + Feedback, +) + + +def main(): + print("=" * 60) + print("PersonalizedLLM Evaluation Interface Demo") + print("=" * 60) + + # Initialize the system + # Note: This will load models, which takes time and GPU memory + print("\n[1] Initializing PersonalizedLLM...") + + llm = PersonalizedLLM( + user_store_path="data/users/user_store_eval_demo.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=True, + ) + + # Define test user + user_id = "eval_demo_user" + + # Reset user for clean experiment + print(f"\n[2] Resetting user: {user_id}") + llm.reset_user(user_id) + + # Check initial state + print(f"\n[3] Initial user state:") + print(f" {llm.get_user_state_summary(user_id)}") + + # Simulate multiple sessions + num_sessions = 2 + queries_per_session = [ + # Session 1: Food preferences + [ + "What's a good recipe for dinner tonight?", + "I prefer vegetarian food with Asian flavors.", + "Can you suggest something spicy?", + ], + # Session 2: Test personalization retention + [ + "What should I cook for lunch?", + "Give me a quick meal idea.", + ], + ] + + all_responses = [] + + for session_idx, session_queries in enumerate(queries_per_session): + print(f"\n{'=' * 60}") + print(f"SESSION {session_idx + 1}") + print("=" * 60) + + # Reset session (new chat window) + llm.reset_session(user_id) + print(f"[Session {session_idx + 1}] Started new session") + + session_responses = [] + + for turn_idx, query in enumerate(session_queries): + print(f"\n--- Turn {turn_idx + 1} ---") + + # Apply feedback for previous turn (from turn 2 onwards) + if turn_idx > 0: + # Simulated feedback - in real eval, this comes from user simulator + simulated_reward = 0.7 + 0.1 * (turn_idx % 2) # Varies by turn + simulated_gating = 1.0 if turn_idx > 0 else 0.0 + + feedback = Feedback( + user_id=user_id, + turn_id=turn_idx - 1, + reward=simulated_reward, + gating=simulated_gating, + meta={"source": "demo_simulator"} + ) + + print(f"[Feedback] Applying: reward={simulated_reward:.2f}, gating={simulated_gating:.1f}") + llm.apply_feedback(feedback) + + # Main chat call + print(f"User: {query}") + response: AssistantResponse = llm.chat(user_id, query) + + print(f"Assistant: {response.answer[:200]}..." 
if len(response.answer) > 200 else f"Assistant: {response.answer}") + print(f"[Usage] prompt={response.usage.prompt_tokens}, completion={response.usage.completion_tokens}, model={response.usage.model}") + + if response.debug: + print(f"[Debug] memories={len(response.debug.selected_memory_ids)}, z_long_norm={response.debug.extra.get('z_long_norm', 0):.4f}") + if response.debug.extracted_preferences: + print(f"[Debug] Extracted {len(response.debug.extracted_preferences)} preferences") + + session_responses.append(response) + + all_responses.append(session_responses) + + # Show user state after session + print(f"\n[Session {session_idx + 1}] Final state:") + print(f" {llm.get_user_state_summary(user_id)}") + + # Summary + print(f"\n{'=' * 60}") + print("EVALUATION SUMMARY") + print("=" * 60) + + total_tokens = sum( + r.usage.total_tokens + for session in all_responses + for r in session + ) + total_turns = sum(len(s) for s in all_responses) + + print(f"Total sessions: {len(all_responses)}") + print(f"Total turns: {total_turns}") + print(f"Total tokens: {total_tokens}") + print(f"Final user state: {llm.get_user_state_summary(user_id)}") + + # Persist (optional, for saving state between runs) + # llm.persist() + # print("\nState persisted to disk.") + + print("\nDemo complete!") + + +if __name__ == "__main__": + main() + diff --git a/scripts/eval_single_ckpt.py b/scripts/eval_single_ckpt.py new file mode 100644 index 0000000..8597907 --- /dev/null +++ b/scripts/eval_single_ckpt.py @@ -0,0 +1,145 @@ +import json +import os +import torch +import glob +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score +from torch.utils.data import Dataset, DataLoader + +# --- Configuration --- +# You can manually set the checkpoint path here if glob fails or is slow +# Example: "saves/qwen3-0.6b-full-sft-h200/checkpoint-4358" +CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200" +TEST_FILE = "data/test_llama_factory.json" +BATCH_SIZE = 128 +USE_FLASH_ATTN = False + +# Load System Prompt +with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + SYSTEM_PROMPT = f.read() + +class EvalDataset(Dataset): + def __init__(self, data): + self.data = data + def __len__(self): + return len(self.data) + def __getitem__(self, idx): + return self.data[idx] + +def load_test_data(): + with open(TEST_FILE, "r", encoding="utf-8") as f: + return json.load(f) + +def batch_generate(model, tokenizer, batch_data, device="cuda"): + prompts = [] + for item in batch_data: + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": item["input"]} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + prompts.append(text) + + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=256, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + input_len = inputs.input_ids.shape[1] + gen_tokens = generated_ids[:, input_len:] + responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True) + return responses + +def evaluate_ckpt(): + # 1. 
Find the latest checkpoint + checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1])) + if not checkpoints: + print(f"No checkpoints found in {CHECKPOINT_DIR}") + return + + latest_ckpt = checkpoints[-1] + print(f"\nTarget Checkpoint: {latest_ckpt}") + + device = "cuda" + print(f"Loading model (Batch Size: {BATCH_SIZE})...") + + try: + tokenizer = AutoTokenizer.from_pretrained(latest_ckpt, trust_remote_code=True, padding_side="left") + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True} + if USE_FLASH_ATTN: + kwargs["attn_implementation"] = "flash_attention_2" + + model = AutoModelForCausalLM.from_pretrained(latest_ckpt, **kwargs) + model.eval() + except Exception as e: + print(f"CRITICAL ERROR loading model: {e}") + return + + # 2. Prepare Data + test_data = load_test_data() + dataset = EvalDataset(test_data) + # Reduce num_workers to avoid hang if system is stressed + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2) + + y_true_has_pref = [] + y_pred_has_pref = [] + json_valid_count = 0 + + print(f"Evaluating {len(test_data)} samples...") + + # 3. Inference Loop + for batch in tqdm(dataloader): + inputs = batch["input"] + outputs = batch["output"] + + # Ground Truth + for gt_str in outputs: + try: + gt_json = json.loads(gt_str) + gt_has = len(gt_json.get("preferences", [])) > 0 + except: + gt_has = False + y_true_has_pref.append(gt_has) + + # Prediction + batch_items = [{"input": inp} for inp in inputs] + responses = batch_generate(model, tokenizer, batch_items, device) + + for pred_str in responses: + pred_has = False + try: + pred_json = json.loads(pred_str) + json_valid_count += 1 + pred_has = len(pred_json.get("preferences", [])) > 0 + except: + pass + y_pred_has_pref.append(pred_has) + + # 4. 
Metrics + print("\n" + "="*40) + print(f"RESULTS for {latest_ckpt}") + print("="*40) + print(f"JSON Validity: {json_valid_count / len(test_data):.4f}") + print(f"Accuracy: {accuracy_score(y_true_has_pref, y_pred_has_pref):.4f}") + print(f"Precision: {precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}") + print(f"Recall: {recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}") + print(f"F1 Score: {f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}") + print("="*40) + +if __name__ == "__main__": + evaluate_ckpt() + diff --git a/scripts/evaluate_checkpoints.py b/scripts/evaluate_checkpoints.py new file mode 100644 index 0000000..cb5c993 --- /dev/null +++ b/scripts/evaluate_checkpoints.py @@ -0,0 +1,205 @@ +import json +import os +import glob +import torch +import matplotlib.pyplot as plt +import pandas as pd +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score +from torch.utils.data import Dataset, DataLoader + +# --- Configuration --- +BASE_MODEL_NAME = "Qwen/Qwen3-0.6B" # Or local path models/Qwen3-0.6B +CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200" +TEST_FILE = "data/test_llama_factory.json" +RESULTS_FILE = "evaluation_results.csv" +PLOT_FILE = "evaluation_plot.png" + +# H200 Optimization +BATCH_SIZE = 128 # H200 can handle massive batches for 0.6B model +USE_FLASH_ATTN = False + +# Load System Prompt +with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + SYSTEM_PROMPT = f.read() + +class EvalDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + +def load_test_data(): + with open(TEST_FILE, "r", encoding="utf-8") as f: + return json.load(f) + +def batch_generate(model, tokenizer, batch_data, device="cuda"): + prompts = [] + for item in batch_data: + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": item["input"]} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + prompts.append(text) + + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=256, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + # Slice only generated tokens + input_len = inputs.input_ids.shape[1] + gen_tokens = generated_ids[:, input_len:] + responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True) + return responses + +def evaluate_single_model(model_path, test_data, device="cuda"): + print(f"Loading model: {model_path}...") + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left") + # Ensure pad token is set for batch generation + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True} + if USE_FLASH_ATTN: + kwargs["attn_implementation"] = "flash_attention_2" + + model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + except Exception as e: + print(f"Failed to load {model_path}: {e}") + return None + + model.eval() + + dataset = EvalDataset(test_data) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4) # 
Use workers for data loading + + y_true_has_pref = [] + y_pred_has_pref = [] + json_valid_count = 0 + + print(f"Evaluating on {len(test_data)} samples (Batch Size: {BATCH_SIZE})...") + + for batch in tqdm(dataloader): + # batch is a dict of lists because default collate + # we need to reconstruct list of dicts or just access lists + # DataLoader collates list of dicts into dict of lists: {"input": [...], "output": [...]} + inputs = batch["input"] + outputs = batch["output"] + + # Ground Truth + for gt_str in outputs: + try: + gt_json = json.loads(gt_str) + gt_has = len(gt_json.get("preferences", [])) > 0 + except: + gt_has = False + y_true_has_pref.append(gt_has) + + # Prediction + # batch_data structure required by batch_generate needs to be list of dicts with "input" key + # Reconstruct for helper function + batch_items = [{"input": inp} for inp in inputs] + responses = batch_generate(model, tokenizer, batch_items, device) + + for pred_str in responses: + pred_has = False + try: + pred_json = json.loads(pred_str) + json_valid_count += 1 + pred_has = len(pred_json.get("preferences", [])) > 0 + except: + pass + y_pred_has_pref.append(pred_has) + + # Metrics + metrics = { + "json_validity": json_valid_count / len(test_data), + "accuracy": accuracy_score(y_true_has_pref, y_pred_has_pref), + "precision": precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0), + "recall": recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0), + "f1": f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0) + } + + del model + del tokenizer + torch.cuda.empty_cache() + + return metrics + +def main(): + test_data = load_test_data() + results = [] + + # 1. Evaluate Base Model + print("\n--- Evaluating Base Model ---") + base_metrics = evaluate_single_model(BASE_MODEL_NAME, test_data) + if base_metrics: + base_metrics["step"] = 0 + base_metrics["model"] = "Base" + results.append(base_metrics) + print(f"Base: {base_metrics}") + + # 2. Evaluate Checkpoints + checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1])) + print(f"\nFound {len(checkpoints)} checkpoints.") + + # Filter to only Base + Last Checkpoint (User Request) + if checkpoints: + checkpoints = [checkpoints[-1]] + print(f"Selecting only the last checkpoint: {checkpoints[0]}") + + for ckpt in checkpoints: + step = int(ckpt.split("-")[-1]) + print(f"\n--- Evaluating Checkpoint {step} ---") + metrics = evaluate_single_model(ckpt, test_data) + if metrics: + metrics["step"] = step + metrics["model"] = f"Ckpt-{step}" + results.append(metrics) + print(f"Step {step}: {metrics}") + + # 3. 
Save & Plot + if not results: + print("No results generated.") + return + + df = pd.DataFrame(results) + df = df.sort_values("step") + df.to_csv(RESULTS_FILE, index=False) + print(f"\nResults saved to {RESULTS_FILE}") + print(df) + + plt.figure(figsize=(10, 6)) + plt.plot(df["step"], df["f1"], marker='o', label="F1 Score") + plt.plot(df["step"], df["precision"], marker='s', label="Precision") + plt.plot(df["step"], df["recall"], marker='^', label="Recall") + plt.plot(df["step"], df["json_validity"], marker='x', linestyle='--', label="JSON Validity") + + plt.title("Preference Extractor Training Progress") + plt.xlabel("Training Steps") + plt.ylabel("Score") + plt.legend() + plt.grid(True) + plt.savefig(PLOT_FILE) + print(f"Plot saved to {PLOT_FILE}") + +if __name__ == "__main__": + main() diff --git a/scripts/finish_retry_batches.py b/scripts/finish_retry_batches.py new file mode 100644 index 0000000..f266327 --- /dev/null +++ b/scripts/finish_retry_batches.py @@ -0,0 +1,154 @@ +import json +import os +import asyncio +from openai import OpenAI, AsyncOpenAI +from typing import Dict, Any, Set, List + +# --- Configuration --- +BATCH_IDS_FILE = "data/raw_datasets/submitted_retry_batch_ids.json" +# The input file for *this specific retry batch* run +RETRY_INPUT_SOURCE = "data/raw_datasets/retry_requests.jsonl" +# Where to append the final results +OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl" +MODEL_NAME = "gpt-5.1" + +def load_retry_queries() -> Dict[str, Dict[str, Any]]: + """ + Load the requests that were submitted in the retry batch. + These are essentially JSON Request objects. + """ + print("Loading retry source requests...") + mapping = {} + with open(RETRY_INPUT_SOURCE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + req = json.loads(line) + # Structure: {"custom_id": "...", "body": {"messages": [..., {"role": "user", "content": "..."}]}} + custom_id = req["custom_id"] + # Extract user query back from the request body + user_content = "" + for m in req["body"]["messages"]: + if m["role"] == "user": + user_content = m["content"] + break + + mapping[custom_id] = { + "query": user_content, + # We might have lost source info in the retry conversion if not careful, + # but for now let's assume we just need the query. + # (Ideally we should have propagated source in metadata) + } + return mapping + +async def process_and_finish(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + + sync_client = OpenAI(api_key=api_key) + async_client = AsyncOpenAI(api_key=api_key) + + if not os.path.exists(BATCH_IDS_FILE): + print(f"Error: {BATCH_IDS_FILE} not found.") + return + + with open(BATCH_IDS_FILE, "r") as f: + batch_ids = json.load(f) + + query_map = load_retry_queries() + processed_ids: Set[str] = set() + + print(f"Total requests in retry batch: {len(query_map)}") + + success_count = 0 + + # 1. 
Download results from Batch API (even if expired) + print("Downloading batch results...") + with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out: + for b_id in batch_ids: + try: + batch = sync_client.batches.retrieve(b_id) + if batch.output_file_id: + content = sync_client.files.content(batch.output_file_id).text + for line in content.splitlines(): + if not line.strip(): continue + res = json.loads(line) + custom_id = res["custom_id"] + + if res["response"]["status_code"] == 200: + try: + body = res["response"]["body"] + llm_content = body["choices"][0]["message"]["content"] + parsed_json = json.loads(llm_content) + + original = query_map.get(custom_id) + if original: + record = { + "custom_id": custom_id, + "original_query": original["query"], + "source": "retry_recovery", # Lost original source, marking as recovery + "extracted_json": parsed_json, + "has_preference": len(parsed_json.get("preferences", [])) > 0 + } + f_out.write(json.dumps(record, ensure_ascii=False) + "\n") + processed_ids.add(custom_id) + success_count += 1 + except: + pass + except Exception as e: + print(f"Error checking batch {b_id}: {e}") + + # 2. Identify Missing + missing_ids = [cid for cid in query_map.keys() if cid not in processed_ids] + print(f"\nMissing/Failed items: {len(missing_ids)}") + + # 3. Finish with Direct API + if missing_ids: + print("Processing missing items via Direct API...") + + # Load System Prompt + with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + sys_prompt = f.read() + + with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out: + for cid in missing_ids: + item = query_map[cid] + query = item["query"] + print(f" Fixing {cid}...") + + try: + resp = await async_client.chat.completions.create( + model=MODEL_NAME, + messages=[ + {"role": "system", "content": sys_prompt}, + {"role": "user", "content": query} + ], + response_format={"type": "json_object"} + ) + + content = resp.choices[0].message.content + parsed_json = json.loads(content) + + record = { + "custom_id": cid, + "original_query": query, + "source": "retry_direct_fix", + "extracted_json": parsed_json, + "has_preference": len(parsed_json.get("preferences", [])) > 0 + } + f_out.write(json.dumps(record, ensure_ascii=False) + "\n") + success_count += 1 + + except Exception as e: + print(f" Failed to fix {cid}: {e}") + + print("\n" + "="*50) + print("ALL RETRY BATCHES RECOVERED.") + print(f"Total processed in this run: {success_count}") + print(f"Full dataset updated at: {OUTPUT_LABEL_FILE}") + print("="*50) + +if __name__ == "__main__": + asyncio.run(process_and_finish()) + diff --git a/scripts/full_labeling.py b/scripts/full_labeling.py new file mode 100644 index 0000000..1c52819 --- /dev/null +++ b/scripts/full_labeling.py @@ -0,0 +1,125 @@ +import json +import os +import asyncio +import aiofiles +from typing import List, Dict, Any +from openai import AsyncOpenAI +from tqdm.asyncio import tqdm_asyncio + +# --- Configuration --- +INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl" +OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset.jsonl" +CHECKPOINT_FILE = "data/raw_datasets/labeling_checkpoint.txt" +MODEL_NAME = "gpt-5.1" # Or "gpt-4o" +MAX_CONCURRENCY = 500 # Adjust based on rate limits +SAVE_INTERVAL = 1000 # Save batch to disk every N items + +# --- Load System Prompt --- +with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + SYSTEM_PROMPT = f.read() + +async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]: + 
query = item["query"] + async with sem: + try: + # We use a short timeout/retry strategy implicitly via library, + # but for bulk processing, just skipping errors is often better than stalling. + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": query} + ], + temperature=0.0, + response_format={"type": "json_object"} + ) + result_text = response.choices[0].message.content + + try: + parsed = json.loads(result_text) + prefs = parsed.get("preferences", []) + has_pref = len(prefs) > 0 + except: + parsed = {"error": "json_parse_fail", "raw": result_text} + has_pref = False + + return { + "original_query": query, + "source": item.get("source"), + "extracted_json": parsed, + "has_preference": has_pref + } + except Exception as e: + return { + "original_query": query, + "source": item.get("source"), + "error": str(e), + "has_preference": False + } + +async def main(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + + # 1. Determine start position (Resume logic) + processed_count = 0 + if os.path.exists(OUTPUT_FILE): + # Quick line count to see how many we've done + # (This assumes we append strictly) + with open(OUTPUT_FILE, "r", encoding="utf-8") as f: + for _ in f: + processed_count += 1 + + print(f"Resuming from index {processed_count}...") + + # 2. Load Data (skip already processed) + # Since reading 400k lines is fast, we just read all and slice + all_items = [] + with open(INPUT_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + all_items.append(json.loads(line)) + + total_items = len(all_items) + remaining_items = all_items[processed_count:] + + if not remaining_items: + print("All items processed!") + return + + print(f"Total: {total_items}, Remaining: {len(remaining_items)}") + + # 3. Setup Client + client = AsyncOpenAI(api_key=api_key) + sem = asyncio.Semaphore(MAX_CONCURRENCY) + + # 4. Batch Processing + # We process in chunks to allow periodic saving and memory management + batch_size = SAVE_INTERVAL + + # Open file in append mode + async with aiofiles.open(OUTPUT_FILE, "a", encoding="utf-8") as f_out: + + for i in range(0, len(remaining_items), batch_size): + batch = remaining_items[i : i + batch_size] + tasks = [label_query(client, sem, item) for item in batch] + + # Run batch + results = await tqdm_asyncio.gather(*tasks, desc=f"Batch {i//batch_size}", leave=False) + + # Write batch + lines = [json.dumps(res, ensure_ascii=False) + "\n" for res in results] + await f_out.writelines(lines) + await f_out.flush() # Ensure written to disk + + # Optional: Print stats every now and then + pos_in_batch = sum(1 for r in results if r.get("has_preference")) + # print(f"Batch saved. Positive in this batch: {pos_in_batch}/{len(batch)}") + + print(f"Done! Saved to {OUTPUT_FILE}") + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/scripts/index_corpus.py b/scripts/index_corpus.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/scripts/index_corpus.py diff --git a/scripts/init_user_states.py b/scripts/init_user_states.py new file mode 100644 index 0000000..73c7435 --- /dev/null +++ b/scripts/init_user_states.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Script to initialize User States (z_long) from Memory Embeddings. 
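+
+The update this script performs, as a minimal sketch (names match the code below):
+
+    user_items = V[indices]                    # item vectors of one user, [M_u, k]
+    z_long = np.mean(user_items, axis=0)       # long-term user vector, [k]
+    z_short = np.zeros(k, dtype=np.float32)    # short-term state starts at zero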
+""" + +import sys +import os +import numpy as np +import json +from collections import defaultdict + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.user_model.tensor_store import UserTensorStore, UserState +from personalization.retrieval.preference_store.schemas import MemoryCard + +def main(): + cards_path = "data/corpora/memory_cards.jsonl" + item_proj_path = "data/corpora/item_projection.npz" + user_store_path = "data/users/user_store.npz" + + # Ensure user dir + os.makedirs(os.path.dirname(user_store_path), exist_ok=True) + + # 1. Load data + print("Loading memory cards...") + cards = [] + if os.path.exists(cards_path): + with open(cards_path, "r") as f: + for line in f: + cards.append(MemoryCard.model_validate_json(line)) + else: + print("No memory cards found. Exiting.") + return + + print("Loading item projection V...") + if not os.path.exists(item_proj_path): + print("Item projection not found. Run build_item_space.py first.") + return + + proj_data = np.load(item_proj_path) + V = proj_data["V"] # [M, k] + + if len(cards) != V.shape[0]: + print(f"Warning: Number of cards ({len(cards)}) != V rows ({V.shape[0]}). Mismatch?") + # If mismatch, we might need to be careful. For now assume aligned. + + k = V.shape[1] + + # 2. Group by user + user_indices = defaultdict(list) + for idx, card in enumerate(cards): + user_indices[card.user_id].append(idx) + + # 3. Initialize Store + print(f"Initializing UserStore at {user_store_path}...") + store = UserTensorStore(k=k, path=user_store_path) + + # 4. Compute z_long and save + print(f"Processing {len(user_indices)} users...") + for uid, indices in user_indices.items(): + if not indices: + continue + + # Get item vectors for this user + # indices is list of int, V is numpy array + user_items = V[indices] + + # Mean pooling + z_long = np.mean(user_items, axis=0) + + # Get/Create state + state = store.get_state(uid) + state.z_long = z_long + state.z_short = np.zeros(k, dtype=np.float32) + state.reward_ma = 0.0 + + store.save_state(state) + + store.persist() + print("Done. User states initialized.") + +if __name__ == "__main__": + main() + diff --git a/scripts/migrate_preferences.py b/scripts/migrate_preferences.py new file mode 100644 index 0000000..5d393c9 --- /dev/null +++ b/scripts/migrate_preferences.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Script to migrate raw queries into MemoryCards by extracting preferences. +It reads from data/raw_datasets/pilot_study_1000.jsonl and outputs: +- data/corpora/memory_cards.jsonl +- data/corpora/memory_embeddings.npy +""" + +import json +import os +import sys + +# Add src to sys.path so we can import personalization +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +import uuid +import numpy as np +import torch +from pathlib import Path +from tqdm import tqdm +from typing import List + +from personalization.config.settings import load_local_models_config +# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor +from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard, PreferenceList + +def ensure_dir(path: str): + Path(path).parent.mkdir(parents=True, exist_ok=True) + +def main(): + # 1. 
Setup paths + input_path = "data/corpora/oasst1_labeled.jsonl" + # input_path = "data/raw_datasets/oasst1_queries.jsonl" + output_cards_path = "data/corpora/memory_cards.jsonl" + output_emb_path = "data/corpora/memory_embeddings.npy" + ensure_dir(output_cards_path) + + print("Loading models configuration...") + cfg = load_local_models_config() + + # 2. Initialize models + # print("Initializing Preference Extractor (GPT-4o)...") + # extractor = GPT4OExtractor.from_config(cfg) + + print("Initializing Embedding Model...") + embedder = Qwen3Embedding8B.from_config(cfg) + + # 3. Process data + print(f"Reading from {input_path}...") + memory_cards: List[MemoryCard] = [] + + # We will process in small batches to manage memory if needed, + # but for 1000 items, we can iterate one by one for extraction + # and maybe batch for embedding if we want optimization. + # Given the complexity, let's just do sequential for simplicity and safety. + + with open(input_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + # Synthetic user distribution (round robin for 10 users) + users = [f"user_{i}" for i in range(10)] + + print("Extracting preferences...") + # Use tqdm for progress + for idx, line in enumerate(tqdm(lines)): + # if idx >= 100: # LIMIT to 100 items + # break + + row = json.loads(line) + query = row.get("original_query", "").strip() + if not query: + continue + + # Use real metadata from dataset + user_id = row.get("user_id", f"user_{idx}") + session_id = row.get("session_id", f"sess_{idx}") + turn_id = row.get("turn_id", 0) + + # Load pre-extracted preferences + has_pref = row.get("has_preference", False) + extracted_data = row.get("extracted_json", {}) + + # Skip if no preference (according to label) + if not has_pref: + continue + + try: + pref_list = PreferenceList.model_validate(extracted_data) + except Exception: + # Fallback or skip if validation fails + continue + + # If we have preferences, create a memory card + if pref_list.preferences: + # Construct a note text: "condition: action" + notes = [f"{p.condition}: {p.action}" for p in pref_list.preferences] + note_summary = "; ".join(notes) + + # Create MemoryCard (embedding will be filled later) + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id=session_id, + source_turn_ids=[turn_id], + raw_queries=[query], + preference_list=pref_list, + note_text=note_summary, + embedding_e=[], # To be filled + kind="pref" + ) + memory_cards.append(card) + + print(f"Found {len(memory_cards)} memory cards. Generating embeddings...") + + if not memory_cards: + print("No preferences found. Exiting.") + return + + # 4. Generate Embeddings + # We'll embed the `raw_queries` (joined) or `note_text`? + # The design doc says: "Qwen3Embedding8B.encode([turn.text])" + # So we embed the original query that generated the memory. 
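+    # A possible alternative (not used here, sketch only): embed the distilled note
+    # instead, e.g. texts_to_embed = [card.note_text for card in memory_cards],
+    # which would retrieve on the summarized preference rather than the raw query.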
+ + texts_to_embed = [card.raw_queries[0] for card in memory_cards] + + print(f"Embedding {len(texts_to_embed)} memories...") + embeddings_list = [] + chunk_size = 2000 # Process in chunks to avoid OOM + + for i in range(0, len(texts_to_embed), chunk_size): + print(f" Embedding chunk {i} to {min(i+chunk_size, len(texts_to_embed))}...") + chunk = texts_to_embed[i : i + chunk_size] + + # Batch encode with larger batch_size for A40 + chunk_emb = embedder.encode( + chunk, + batch_size=128, + normalize=True, + return_tensor=False + ) + embeddings_list.extend(chunk_emb) + + # Assign back to cards and prepare matrix + emb_matrix = [] + for card, emb in zip(memory_cards, embeddings_list): + card.embedding_e = emb + emb_matrix.append(emb) + + # 5. Save + print(f"Saving {len(memory_cards)} cards to {output_cards_path}...") + with open(output_cards_path, "w", encoding="utf-8") as f: + for card in memory_cards: + f.write(card.model_dump_json() + "\n") + + print(f"Saving embeddings matrix to {output_emb_path}...") + np_emb = np.array(emb_matrix, dtype=np.float32) + np.save(output_emb_path, np_emb) + + print("Done!") + +if __name__ == "__main__": + main() + diff --git a/scripts/online_personalization_demo.py b/scripts/online_personalization_demo.py new file mode 100644 index 0000000..f5b6d68 --- /dev/null +++ b/scripts/online_personalization_demo.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +""" +Online Personalization REPL Demo. +Interactive CLI for chatting with the Personalized Memory RAG system. +Includes: +- Extractor-0.6B for online preference extraction +- Reranker + Policy Retrieval +- Online RL updates (REINFORCE) +""" + +import sys +import os +import uuid +import numpy as np +import torch +import readline # For better input handling +import yaml + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.config.registry import ( + get_preference_extractor, + get_chat_model, +) +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.models.reranker.qwen3_reranker import Qwen3Reranker +# from personalization.models.llm.qwen_instruct import QwenInstruct # Deprecated direct import +from personalization.user_model.tensor_store import UserTensorStore +from personalization.user_model.session_state import OnlineSessionState +from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList +from personalization.retrieval.pipeline import retrieve_with_policy +from personalization.feedback.handlers import eval_step +from personalization.user_model.policy.reinforce import reinforce_update_user_state +from personalization.user_model.features import ItemProjection + +def load_memory_store(): + cards_path = "data/corpora/memory_cards.jsonl" + embs_path = "data/corpora/memory_embeddings.npy" + item_proj_path = "data/corpora/item_projection.npz" + + if not os.path.exists(cards_path) or not os.path.exists(embs_path): + print("Memory data missing. Starting with empty memory store is possible but item space requires base data.") + # For this demo, we assume base data exists to define PCA space. 
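+        # (item_projection.npz holds the projection matrix P, the mean, and cached
+        #  item vectors V; conceptually v = (e - mean) @ P, though the exact formula
+        #  lives in ItemProjection -- treat this as an illustrative assumption.)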
+ sys.exit(1) + + print(f"Loading memory cards from {cards_path}...") + cards = [] + with open(cards_path, "r") as f: + for line in f: + cards.append(MemoryCard.model_validate_json(line)) + + memory_embeddings = np.load(embs_path) + + # Load PCA projection + proj_data = np.load(item_proj_path) + # We need to reconstruct ItemProjection object to transform new memories + projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"]) + item_vectors = proj_data["V"] + + return cards, memory_embeddings, item_vectors, projection + +def build_user_turn(user_id: str, text: str, turn_id: int) -> ChatTurn: + return ChatTurn( + user_id=user_id, + session_id="online_debug_session", + turn_id=turn_id, + role="user", + text=text, + meta={"source": "repl"} + ) + +def build_assistant_turn(user_id: str, text: str, turn_id: int) -> ChatTurn: + return ChatTurn( + user_id=user_id, + session_id="online_debug_session", + turn_id=turn_id, + role="assistant", + text=text, + meta={"source": "repl"} + ) + +def add_preferences_as_memory_cards( + prefs: PreferenceList, + query: str, + user_id: str, + turn_id: int, + embed_model: Qwen3Embedding8B, + projection: ItemProjection, + memory_cards: list, + memory_embeddings: np.ndarray, + item_vectors: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """ + Adds extracted preferences as new memory cards. + Returns updated memory_embeddings and item_vectors. + """ + if not prefs.preferences: + print(" [Extractor] No preferences found in this turn.") + return memory_embeddings, item_vectors + + e_m_list = [] + v_m_list = [] + + # Only compute embedding once if we use the query as source for all prefs from this turn + # Alternatively, embed the note text. + # The current design uses the original query embedding e_m. + e_q = embed_model.encode([query], return_tensor=False)[0] + v_q = projection.transform_vector(np.array(e_q)) + + print(f" [Extractor] Extracted {len(prefs.preferences)} preferences:") + for pref in prefs.preferences: + note_text = f"When {pref.condition}, {pref.action}." + print(f" - {note_text}") + + # Simple deduplication check based on note_text for this user + # In a real system, use vector similarity or hash + is_duplicate = False + for card in memory_cards: + if card.user_id == user_id and card.note_text == note_text: + is_duplicate = True + break + + if is_duplicate: + print(" (Duplicate, skipping add)") + continue + + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id="online_debug_session", + source_turn_ids=[turn_id], + raw_queries=[query], + preference_list=PreferenceList(preferences=[pref]), + note_text=note_text, + embedding_e=e_q, # Store list[float] + kind="pref", + ) + + memory_cards.append(card) + e_m_list.append(e_q) + v_m_list.append(v_q) + + # Update numpy arrays + if e_m_list: + new_embs = np.array(e_m_list) + new_vecs = np.array(v_m_list) + memory_embeddings = np.vstack([memory_embeddings, new_embs]) + item_vectors = np.vstack([item_vectors, new_vecs]) + print(f" [Debug] Added {len(e_m_list)} new cards. Total cards: {len(memory_cards)}") + + return memory_embeddings, item_vectors + +def main(): + # 1. 
Load Config & Models + print("Loading configuration...") + cfg = load_local_models_config() + + # RL Config (Should load from user_model.yaml, hardcoded for safety/demo) + rl_cfg = { + "item_dim": 256, + "beta_long": 0.1, + "beta_short": 0.3, + "tau": 1.0, + "eta_long": 1e-3, + "eta_short": 5e-3, + "ema_alpha": 0.05, + "short_decay": 0.1, + "dense_topk": 64, + "rerank_topk": 3, + "max_new_tokens": 512 + } + + print("Loading models and stores...") + # Using explicit classes for clarity, but registry can also be used + embed_model = Qwen3Embedding8B.from_config(cfg) + reranker = Qwen3Reranker.from_config(cfg) + + # Use registry for ChatModel (supports switching backends) + # Default to "qwen_1_5b" if not specified in user_model.yaml + llm_name = "qwen_1_5b" + + # Try loading from config safely + try: + config_path = os.path.join(os.path.dirname(__file__), "../configs/user_model.yaml") + if os.path.exists(config_path): + with open(config_path, "r") as f: + user_cfg = yaml.safe_load(f) + if user_cfg and "llm_name" in user_cfg: + llm_name = user_cfg["llm_name"] + print(f"Loaded llm_name from config: {llm_name}") + else: + print(f"Warning: Config file not found at {config_path}") + except Exception as e: + print(f"Failed to load user_model.yaml: {e}") + pass + + print(f"Loading ChatModel: {llm_name}...") + chat_model = get_chat_model(llm_name) + + # Use registry for extractor to support switching + extractor_name = "qwen3_0_6b_sft" # Default per design doc + print(f"Loading extractor: {extractor_name}...") + try: + extractor = get_preference_extractor(extractor_name) + except Exception as e: + print(f"Failed to load {extractor_name}: {e}. Fallback to rule.") + extractor = get_preference_extractor("rule") + + user_store = UserTensorStore( + k=rl_cfg["item_dim"], + path="data/users/user_store_online.npz", + ) + + # Load Memory + memory_cards, memory_embeddings, item_vectors, projection = load_memory_store() + + # 2. Init Session + user_id = "debug_user" + user_state = user_store.get_state(user_id) + session_state = OnlineSessionState(user_id=user_id) + + print(f"\n--- Online Personalization REPL (User: {user_id}) ---") + print(f"Initial State: ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}") + print("Type 'exit' or 'quit' to stop.\n") + + while True: + try: + q_t = input("User: ").strip() + except (EOFError, KeyboardInterrupt): + print("\nExiting...") + break + + if q_t.lower() in ("exit", "quit"): + break + if not q_t: + continue + + # 3. 
RL Update (from previous turn) + e_q_t = embed_model.encode([q_t], return_tensor=False)[0] + e_q_t = np.array(e_q_t) + + if session_state.last_query is not None: + r_hat, g_hat = eval_step( + q_t=session_state.last_query, + answer_t=session_state.last_answer, + q_t1=q_t, + memories_t=session_state.last_memories, + query_embedding_t=session_state.last_query_embedding, + query_embedding_t1=e_q_t, + ) + + print(f" [Feedback] Reward: {r_hat:.2f}, Gating: {g_hat:.2f}") + + if (session_state.last_candidate_item_vectors is not None and + session_state.last_policy_probs is not None and + len(session_state.last_chosen_indices) > 0): + + # IMPORTANT: Extract the vectors of the chosen items to align with probs + # last_candidate_item_vectors: [64, dim] + # last_chosen_indices: [3] indices into the 64 candidates + # last_policy_probs: [3] probabilities for the chosen items + + # We need the vectors corresponding to the chosen indices + # chosen_indices contains indices into candidates list + chosen_vectors = session_state.last_candidate_item_vectors[session_state.last_chosen_indices] + + updated = reinforce_update_user_state( + user_state=user_state, + item_vectors=chosen_vectors, # Corrected: Pass only chosen vectors [3, dim] + chosen_indices=np.arange(len(session_state.last_chosen_indices)), # Indices are now 0,1,2 relative to chosen_vectors + policy_probs=session_state.last_policy_probs, + reward_hat=r_hat, + gating=g_hat, + tau=rl_cfg["tau"], + eta_long=rl_cfg["eta_long"], + eta_short=rl_cfg["eta_short"], + ema_alpha=rl_cfg["ema_alpha"], + short_decay=rl_cfg["short_decay"], + ) + if updated: + print(" [RL] User state updated.") + user_store.save_state(user_state) # Save immediately for safety + + # 4. Update History + user_turn = build_user_turn(user_id, q_t, len(session_state.history)) + session_state.history.append(user_turn) + + # 5. Extract Preferences -> New Memory + # Extract from recent history + prefs = extractor.extract_turn(session_state.history) + memory_embeddings, item_vectors = add_preferences_as_memory_cards( + prefs, q_t, user_id, user_turn.turn_id, + embed_model, projection, memory_cards, memory_embeddings, item_vectors + ) + + # 6. 
Retrieve + Policy + # Use only_own_memories=True to allow strict privacy + # Fix unpacking order: pipeline returns (candidates, vecs, scores, indices, probs) + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( + user_id=user_id, + query=q_t, + embed_model=embed_model, + reranker=reranker, + memory_cards=memory_cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=rl_cfg["dense_topk"], + topk_rerank=rl_cfg["rerank_topk"], + beta_long=rl_cfg["beta_long"], + beta_short=rl_cfg["beta_short"], + tau=rl_cfg["tau"], + only_own_memories=True # User requested strict privacy for demo + ) + + # Map back to indices in candidates list (0..K-1) + print(f"DEBUG: candidates len={len(candidates)}, type={type(candidates)}") + print(f"DEBUG: chosen_indices={chosen_indices}, type={type(chosen_indices)}") + if len(chosen_indices) > 0: + print(f"DEBUG: first idx type={type(chosen_indices[0])}, val={chosen_indices[0]}") + + memories_t = [candidates[int(i)] for i in chosen_indices] + if memories_t: + print(f" [Retrieval] Found {len(memories_t)} memories:") + + # Display Deduplication: Group by note_text + from collections import Counter + content_counts = Counter([m.note_text for m in memories_t]) + + # Print unique contents with counts + for text, count in content_counts.most_common(): + user_info = f" ({count} users)" if count > 1 else "" + print(f" - {text}{user_info}") + + # 7. LLM Answer + memory_notes = [m.note_text for m in memories_t] + # history should be a list of ChatTurn objects, not dicts + # session_state.history is already a list of ChatTurn + answer_t = chat_model.answer( + history=session_state.history, + memory_notes=memory_notes, + max_new_tokens=rl_cfg["max_new_tokens"] + ) + + print(f"Assistant: {answer_t}") + + # 8. Update State for Next Turn + assist_turn = build_assistant_turn(user_id, answer_t, len(session_state.history)) + session_state.history.append(assist_turn) + + session_state.last_query = q_t + session_state.last_answer = answer_t + session_state.last_memories = memories_t + session_state.last_query_embedding = e_q_t + session_state.last_candidate_item_vectors = cand_item_vecs + session_state.last_policy_probs = probs + session_state.last_chosen_indices = chosen_indices + + print(f" [State] ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}") + print("-" * 40) + + print("Saving final user state...") + user_store.save_state(user_state) + user_store.persist() + + # Save updated memories + print(f"Saving {len(memory_cards)} memory cards to disk...") + # Ideally should be atomic or append-only, but for demo we rewrite + # Backup original first? For demo, direct overwrite is fine or save to new file + cards_path = "data/corpora/memory_cards.jsonl" + embs_path = "data/corpora/memory_embeddings.npy" + item_proj_path = "data/corpora/item_projection.npz" + + with open(cards_path, "w", encoding="utf-8") as f: + for card in memory_cards: + f.write(card.model_dump_json() + "\n") + + np.save(embs_path, memory_embeddings) + + # Update item projection file with new item vectors? + # item_projection.npz usually stores the Projection Matrix P and Mean. + # The 'V' (item vectors) in it is just a cache. + # We should update V in the npz so next load has them. 
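+    # A sturdier pattern, sketched here but not what the demo does: write to a
+    # temporary .npz first and atomically swap it in, e.g.
+    #     tmp_path = item_proj_path + ".tmp.npz"   # keep the .npz suffix for np.savez
+    #     np.savez(tmp_path, P=..., mean=..., V=item_vectors)
+    #     os.replace(tmp_path, item_proj_path)
+    # so a crash mid-write cannot corrupt the projection store.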
+ # Load original to keep P and mean + proj_data = np.load(item_proj_path) + np.savez( + item_proj_path, + P=proj_data["P"], + mean=proj_data["mean"], + V=item_vectors + ) + print("Memory store updated.") + +if __name__ == "__main__": + main() + + diff --git a/scripts/personamem_build_user_vectors.py b/scripts/personamem_build_user_vectors.py new file mode 100644 index 0000000..4719f9c --- /dev/null +++ b/scripts/personamem_build_user_vectors.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +""" +Script to build user vectors for PersonaMem 32k dataset. +1. Load shared contexts. +2. Convert to ChatTurns. +3. Extract preferences -> MemoryCards. +4. Compute embeddings & item vectors. +5. Aggregate to user vectors. +""" + +import sys +import os +import json +import uuid +import numpy as np +from tqdm import tqdm +from typing import List, Dict + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.config.registry import get_preference_extractor +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList +from personalization.data.personamem_loader import load_personamem_contexts_32k, PersonaMemContext +from personalization.user_model.features import ItemProjection + +def context_to_chatturns(context: PersonaMemContext) -> List[ChatTurn]: + turns = [] + for i, msg in enumerate(context.messages): + if isinstance(msg, dict): + role = msg.get("role", "user") + text = msg.get("content", "") + elif isinstance(msg, str): + # Fallback for string messages, assume user? or skip? + # print(f"Warning: msg is string: {msg[:50]}...") + role = "user" + text = msg + else: + continue + + turns.append(ChatTurn( + user_id=context.shared_context_id, # Use context id as user id equivalent for building vector + session_id=context.shared_context_id, + turn_id=i, + role="user" if role == "user" else "assistant", + text=text, + timestamp=None, + meta={} + )) + return turns + +def sliding_windows(seq: List, window_size: int, step: int): + for i in range(0, len(seq), step): + yield seq[i : i + window_size] + +def find_last_user_text(window: List[ChatTurn]) -> str: + for t in reversed(window): + if t.role == "user": + return t.text + return "" + +def main(): + # Paths (adjust as needed, assume downloaded to data/raw_datasets/personamem) + ctx_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl" + item_proj_path = "data/corpora/item_projection.npz" + output_vec_path = "data/personamem/user_vectors.npz" + output_cards_path = "data/personamem/memory_cards.jsonl" + + # Ensure dirs + os.makedirs(os.path.dirname(output_vec_path), exist_ok=True) + + if not os.path.exists(ctx_path): + print(f"Error: Context file not found at {ctx_path}") + return + + if not os.path.exists(item_proj_path): + print(f"Error: Item projection not found at {item_proj_path}. 
Run build_item_space.py first.")
+        return
+
+    # Load Models
+    print("Loading models...")
+    cfg = load_local_models_config()
+    # Explicitly use Qwen3Embedding8B
+    embed_model = Qwen3Embedding8B.from_config(cfg)
+
+    # Use registry for extractor (SFT model)
+    extractor_name = "qwen3_0_6b_sft"
+    try:
+        extractor = get_preference_extractor(extractor_name)
+    except Exception:
+        print(f"Extractor '{extractor_name}' not found; falling back to rule-based extractor.")
+        extractor = get_preference_extractor("rule")
+
+    # Load Projection
+    proj_data = np.load(item_proj_path)
+    projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+
+    # Load Contexts
+    print("Loading contexts...")
+    contexts = load_personamem_contexts_32k(ctx_path)
+    print(f"Loaded {len(contexts)} contexts.")
+
+    all_cards = []
+
+    # Process each context sequentially; a full run can take a while,
+    # so add a limit here if you only need a quick smoke test.
+    print("Extracting preferences...")
+
+    for ctx_id, ctx in tqdm(contexts.items()):
+        turns = context_to_chatturns(ctx)
+        # print(f"Context {ctx_id}: {len(turns)} turns")
+
+        # Sliding window extraction
+        for window in sliding_windows(turns, window_size=6, step=3):
+            # Only extract if window has user turns
+            if not any(t.role == "user" for t in window):
+                continue
+
+            try:
+                prefs = extractor.extract_turn(window)
+                # if prefs.preferences:
+                #     print(f"  Found {len(prefs.preferences)} preferences")
+            except Exception as e:
+                print(f"Extraction failed: {e}")
+                continue
+
+            if not prefs.preferences:
+                continue
+
+            source_query = find_last_user_text(window)
+            if not source_query:
+                continue
+
+            # Embed
+            e_m = embed_model.encode([source_query], return_tensor=False)[0]
+            e_m_np = np.array(e_m)
+            v_m = projection.transform_vector(e_m_np)
+
+            # Serialize note
+            notes = [f"When {p.condition}, {p.action}." 
for p in prefs.preferences] + note_text = " ".join(notes) + + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=ctx_id, # persona_id/context_id + source_session_id=ctx_id, + source_turn_ids=[t.turn_id for t in window if t.role == "user"], + raw_queries=[source_query], + preference_list=prefs, + note_text=note_text, + embedding_e=e_m, + kind="pref" + ) + all_cards.append(card) + + print(f"Extracted {len(all_cards)} memory cards.") + + # Build User Vectors + print("Building user vectors...") + z_by_user = {} + + # Group cards by user + cards_by_user = {} + for c in all_cards: + if c.user_id not in cards_by_user: + cards_by_user[c.user_id] = [] + cards_by_user[c.user_id].append(c) + + for uid, u_cards in cards_by_user.items(): + # Stack v_m: [M_u, k] + V = np.stack([projection.transform_vector(np.array(c.embedding_e, dtype=np.float32)) for c in u_cards], axis=0) + z = np.mean(V, axis=0) + z_by_user[uid] = z + + # Save + print(f"Saving {len(all_cards)} cards to {output_cards_path}...") + with open(output_cards_path, "w", encoding="utf-8") as f: + for c in all_cards: + f.write(c.model_dump_json() + "\n") + + print(f"Saving user vectors to {output_vec_path}...") + user_ids = list(z_by_user.keys()) + Z = np.array([z_by_user[uid] for uid in user_ids], dtype=np.float32) + np.savez(output_vec_path, user_ids=user_ids, Z=Z) + + print("Done.") + +if __name__ == "__main__": + main() + diff --git a/scripts/personamem_eval_base_vs_ours.py b/scripts/personamem_eval_base_vs_ours.py new file mode 100644 index 0000000..e319765 --- /dev/null +++ b/scripts/personamem_eval_base_vs_ours.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +""" +Evaluation script for PersonaMem task: Base vs Ours (User Vector). +Metric: Accuracy (Top-1 correct option) +""" + +import sys +import os +import numpy as np +from tqdm import tqdm +from typing import List + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.data.personamem_loader import load_personamem_questions_32k +from personalization.user_model.features import ItemProjection +from personalization.retrieval.preference_store.schemas import MemoryCard +from personalization.user_model.tensor_store import UserState + +def cosine_sim_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray: + # a: [d] or [N, d], b: [d] or [M, d] + norm_a = np.linalg.norm(a, axis=-1, keepdims=True) + norm_b = np.linalg.norm(b, axis=-1, keepdims=True) + + # Ensure correct shapes for broadcasting + # If a is 1D [d], dot gives 1D. If a is 2D [N, d], dot gives 2D. 
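+    # Concrete shapes (matching how this is called below, for illustration):
+    #   a = e_q [d],          b = e_opts [4, d]  ->  sim [4]
+    #   a = retrieved [K, d], b = e_opts [4, d]  ->  sim [4, K]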
+ # b is typically [M, d] (memories or options) + + dot = np.dot(b, a.T) + denom = np.dot(norm_b, norm_a.T) + 1e-8 + + # Flatten if needed + sim = dot / denom + if sim.ndim == 2 and sim.shape[1] == 1: + return sim.flatten() + return sim + +def dense_retrieval( + e_q: np.ndarray, + memory_embeddings: np.ndarray, + topk: int = 5 +) -> np.ndarray: + """Returns topk indices of memories most similar to query.""" + if memory_embeddings.shape[0] == 0: + return np.array([], dtype=int) + + sims = cosine_sim_matrix(e_q, memory_embeddings) + # Get topk + k = min(topk, len(sims)) + idx = np.argsort(sims)[-k:][::-1] + return idx + +def policy_retrieval( + e_q: np.ndarray, + memory_embeddings: np.ndarray, + item_vectors: np.ndarray, + z_user: np.ndarray, + item_proj: ItemProjection, + topk_dense: int = 20, + topk_final: int = 5, + beta: float = 0.2 +) -> np.ndarray: + """ + Simulates retrieve_with_policy: + 1. Dense retrieval (topk_dense) + 2. Policy Scoring: s = s_base + beta * (z . v_m) + 3. Select topk_final + """ + if memory_embeddings.shape[0] == 0: + return np.array([], dtype=int) + + # 1. Dense Candidate Generation + dense_idx = dense_retrieval(e_q, memory_embeddings, topk=topk_dense) + if len(dense_idx) == 0: + return np.array([], dtype=int) + + candidates_e = memory_embeddings[dense_idx] + candidates_v = item_vectors[dense_idx] + + # 2. Base Scores (Sim(q, m)) + base_scores = cosine_sim_matrix(e_q, candidates_e) + + # 3. Policy Bonus + # z: [k], v: [K, k] + bonus = np.dot(candidates_v, z_user) + + total_scores = base_scores + beta * bonus + + # 4. Final Selection + k = min(topk_final, len(total_scores)) + local_top_idx = np.argsort(total_scores)[-k:][::-1] + + # Map back to global indices + return dense_idx[local_top_idx] + +def score_rag( + e_q: np.ndarray, + e_opts: np.ndarray, + retrieved_embeddings: np.ndarray +) -> np.ndarray: + """ + Computes score for each option based on query match AND memory match. + Score = Sim(Q, Opt) + Mean(Sim(Memories, Opt)) + """ + # Base: Query-Option similarity + # e_q: [d], e_opts: [4, d] -> [4] + s_q_opt = cosine_sim_matrix(e_q, e_opts) + + if len(retrieved_embeddings) == 0: + return s_q_opt + + # Memory-Option similarity + # retrieved: [K, d], e_opts: [4, d] + # We want for each option, the average similarity to retrieved memories. + # sim_matrix: [4, K] (options x memories) - check implementation of cosine_sim_matrix + # cosine_sim_matrix(a, b) does dot(b, a.T). + # Let a=retrieved [K, d], b=e_opts [4, d]. Result: [4, K] + + s_mem_opt = cosine_sim_matrix(retrieved_embeddings, e_opts) + + # Max pooling or Mean pooling over memories + # Usually max is better for "if any memory supports this option" + s_rag = np.max(s_mem_opt, axis=1) # [4] + + # Combine + return s_q_opt + s_rag + +def main(): + # Paths + q_path = "data/raw_datasets/personamem/questions_32k.csv" + vec_path = "data/personamem/user_vectors.npz" + cards_path = "data/personamem/memory_cards.jsonl" # Generated by builder + item_proj_path = "data/corpora/item_projection.npz" + + if not os.path.exists(q_path) or not os.path.exists(vec_path) or not os.path.exists(cards_path): + print("Data missing. 
Run personamem_build_user_vectors.py first.")
+        sys.exit(1)
+
+    print("Loading resources...")
+    cfg = load_local_models_config()
+    embed_model = Qwen3Embedding8B.from_config(cfg)
+
+    proj_data = np.load(item_proj_path)
+    item_proj = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+
+    # Load User Vectors
+    uv_data = np.load(vec_path, allow_pickle=True)
+    user_ids = uv_data["user_ids"]
+    Z = uv_data["Z"]
+    user_vector_map = {uid: Z[i] for i, uid in enumerate(user_ids)}
+    print(f"Loaded {len(user_vector_map)} user vectors.")
+
+    # Load Memory Cards & Embeddings
+    print("Loading memory store...")
+    cards_by_user = {}
+    embs_by_user = {}
+    vecs_by_user = {}  # v vectors
+
+    # We load all cards up front to build per-user indices. This could be slow for a
+    # huge file, but a small-sample run of the builder reported only "Extracted 321
+    # memory cards", so we assume it fits in memory.
+    with open(cards_path, "r") as f:
+        for line in f:
+            card = MemoryCard.model_validate_json(line)
+            uid = card.user_id
+            if uid not in cards_by_user:
+                cards_by_user[uid] = []
+                embs_by_user[uid] = []
+
+            cards_by_user[uid].append(card)
+            embs_by_user[uid].append(card.embedding_e)
+
+    # Convert lists to numpy arrays
+    for uid in embs_by_user:
+        E = np.array(embs_by_user[uid], dtype=np.float32)
+        embs_by_user[uid] = E
+        # The builder does not save item vectors separately, so project embeddings here
+        vecs_by_user[uid] = item_proj.transform_embeddings(E)
+
+    print(f"Loaded memories for {len(cards_by_user)} users.")
+
+    # Load Questions
+    questions = load_personamem_questions_32k(q_path)
+    print(f"Loaded {len(questions)} questions.")
+
+    correct_base = []
+    correct_ours = []
+
+    # Hyperparams. Sanity check: with beta=0, policy retrieval should reduce to
+    # plain dense retrieval.
+    betas = [0.0, 1.0]
+
+    # Ours RAG pipeline: 1. Dense (top20) -> 2. Re-rank (Base + beta*Bonus) -> 3. Top5
+    # Base RAG pipeline: 1. Dense (top5)
+    # With beta=0 the re-ranking score is just the base score Sim(q, m), and dense
+    # retrieval already sorts by Sim(q, m), so the order is unchanged. Hence, as long
+    # as topk_dense >= topk_final, Ours (beta=0) selects the same top-5 as Base RAG. 
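+    # Toy example of the re-ranking (illustrative numbers): with base_scores
+    # [0.80, 0.75] and policy bonus z.v = [-0.10, 0.30], beta=0 keeps the dense
+    # order (0.80 > 0.75) while beta=1 flips it (0.70 < 1.05).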
+ + for beta in betas: + print(f"\nEvaluating with RAG (beta={beta})...") + + correct_base = [] + correct_ours = [] + + case_count = 0 + + for q in tqdm(questions): + target_id = q.shared_context_id + + # Skip if no vector or no memories + if target_id not in user_vector_map or target_id not in embs_by_user: + continue + + z_user = user_vector_map[target_id] + mem_E = embs_by_user[target_id] + mem_V = vecs_by_user[target_id] + + # Embed Query + e_q = embed_model.encode([q.user_question_or_message], return_tensor=False)[0] + e_q = np.array(e_q, dtype=np.float32) + + # Embed Options + if not q.all_options: + continue + e_opts = embed_model.encode(q.all_options, return_tensor=False) + e_opts = np.array(e_opts, dtype=np.float32) + + # --- BASE (No RAG) --- + s_base = score_rag(e_q, e_opts, np.array([])) + + # --- OURS (Personalized RAG) --- + ours_idx = policy_retrieval(e_q, mem_E, mem_V, z_user, item_proj, topk_dense=20, topk_final=5, beta=beta) + ours_mem_E = mem_E[ours_idx] + s_ours = score_rag(e_q, e_opts, ours_mem_E) + + pred_base = int(np.argmax(s_base)) + pred_ours = int(np.argmax(s_ours)) + + is_correct = (pred_ours == q.correct_index) + base_correct = (pred_base == q.correct_index) + + # Detailed Case Print (Task A) + # Print only when Beta > 0 (to avoid duplicate logs) and when they disagree + if beta > 0.0 and pred_base != pred_ours and case_count < 5: + case_count += 1 + + # Reconstruct memory text (need to find card in list) + # Optimization: Create a map or just linear search in cards_by_user[target_id] + user_cards = cards_by_user[target_id] + # ours_idx are indices into mem_E which corresponds to cards_by_user list order + retrieved_notes = [user_cards[i].note_text for i in ours_idx] + + print(f"\n" + "="*60) + print(f"[CASE ANALYSIS] QID: {q.question_id}") + print(f"User Question: {q.user_question_or_message}") + print(f"Correct Option ({q.correct_index}): {q.all_options[q.correct_index]}") + print("-" * 30) + print(f"Base Pred ({pred_base}): {q.all_options[pred_base]} [{'CORRECT' if base_correct else 'WRONG'}]") + print(f"Ours Pred ({pred_ours}): {q.all_options[pred_ours]} [{'CORRECT' if is_correct else 'WRONG'}]") + print("-" * 30) + print(f"Retrieved Memories (Top 3 of {len(retrieved_notes)}):") + for note in retrieved_notes[:3]: + print(f" - {note}") + print("-" * 30) + print(f"Scores Base: {s_base}") + print(f"Scores Ours: {s_ours}") + print("="*60 + "\n") + + if q.correct_index != -1: + correct_base.append(1 if pred_base == q.correct_index else 0) + correct_ours.append(1 if pred_ours == q.correct_index else 0) + + if not correct_base: + print("No valid evaluation samples processed.") + continue + + acc_base = np.mean(correct_base) + acc_ours = np.mean(correct_ours) + + # Win rate + wins = [1 if (c_o == 1 and c_b == 0) else 0 for c_o, c_b in zip(correct_ours, correct_base)] + win_rate = np.mean(wins) + + print(f"\n--- Results (Beta={beta}) ---") + print(f"Total Samples: {len(correct_base)}") + print(f"Accuracy (Base No-RAG): {acc_base:.4f}") + print(f"Accuracy (Ours RAG): {acc_ours:.4f}") + print(f"Win Rate: {win_rate:.4f}") + +if __name__ == "__main__": + main() + diff --git a/scripts/pilot_runner_v0.py b/scripts/pilot_runner_v0.py new file mode 100644 index 0000000..8b7773a --- /dev/null +++ b/scripts/pilot_runner_v0.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +""" +Pilot Runner v0 - Minimal End-to-End Test + +Goal: Prove the chat → judge → apply_feedback → next query loop works. 
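+
+Loop shape, as a sketch (names match this file and personalization.serving):
+
+    resp = llm.chat(user_id, query)
+    result = minimal_judge(query, resp.answer, task_type)
+    llm.apply_feedback(Feedback(user_id=user_id, turn_id=turn_id,
+                                reward=result.sat_t, gating=1.0))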
+ +Setup: +- 1 user × 1 session × 5 turns +- Fixed queries (no fancy user simulator yet) +- Rule-based judge: answer non-empty → sat=1, else 0 +- reward = sat, gating = 1 always + +What we're checking: +1. No crashes (KeyError, NoneType, etc.) +2. User vector norms change after feedback (RL is being called) +3. resp.usage returns reasonable numbers +4. Logs are generated correctly +""" + +import sys +import os +import json +from datetime import datetime +from dataclasses import dataclass, asdict +from typing import List, Dict, Any, Optional + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse + + +# ============================================================================= +# Minimal Judge +# ============================================================================= + +@dataclass +class JudgeResult: + """Output from the judge for one turn.""" + sat_t: float # Satisfaction score [0, 1] + sev_t: float # Severity of violations [0, 1] + prog_t: float # Task progress [0, 1] + violations: List[str] # List of violated constraints + + +def minimal_judge(query: str, answer: str, task_type: str = "general") -> JudgeResult: + """ + Minimal rule-based judge for pilot. + + For now: + - sat_t = 1 if answer is non-empty, else 0 + - sev_t = 0 (no severity tracking yet) + - prog_t = 1 if answer looks reasonable, else 0 + """ + violations = [] + + # Check 1: Answer is non-empty + if not answer or len(answer.strip()) < 5: + violations.append("empty_answer") + return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0, violations=violations) + + # Check 2: Answer is not too short (at least 20 chars for real content) + if len(answer.strip()) < 20: + violations.append("too_short") + + # Check 3: For code tasks, look for code markers + if task_type == "code": + has_code = "```" in answer or "def " in answer or "function" in answer + if not has_code: + violations.append("no_code_block") + + # Calculate scores + sat_t = 1.0 if len(violations) == 0 else max(0.0, 1.0 - 0.3 * len(violations)) + sev_t = 1.0 if "empty_answer" in violations else 0.0 + prog_t = 1.0 if "empty_answer" not in violations else 0.0 + + return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t, violations=violations) + + +# ============================================================================= +# Minimal User Simulator (Fixed Queries) +# ============================================================================= + +def get_fixed_queries() -> List[Dict[str, Any]]: + """ + Return fixed queries for pilot test. + Mix of preference statements and tasks. + """ + return [ + { + "query": "I prefer short, concise answers. 
Please keep responses under 100 words.", + "type": "preference", + "task_type": "general", + }, + { + "query": "What are three tips for better sleep?", + "type": "task", + "task_type": "general", + }, + { + "query": "I also prefer bullet points when listing things.", + "type": "preference", + "task_type": "general", + }, + { + "query": "What are the main benefits of exercise?", + "type": "task", + "task_type": "general", + }, + { + "query": "Summarize what you know about my preferences.", + "type": "task", + "task_type": "general", + }, + ] + + +# ============================================================================= +# Logging +# ============================================================================= + +@dataclass +class TurnLog: + """Log entry for one turn.""" + turn_id: int + query: str + query_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + reward: float + gating: float + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + num_memories_retrieved: int + num_prefs_extracted: int + + +def log_to_jsonl(logs: List[TurnLog], filepath: str): + """Save logs to JSONL file.""" + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as f: + for log in logs: + f.write(json.dumps(asdict(log)) + "\n") + + +# ============================================================================= +# Pilot Runner +# ============================================================================= + +def run_pilot( + llm: PersonalizedLLM, + user_id: str = "pilot_user_0", + queries: Optional[List[Dict[str, Any]]] = None, +) -> List[TurnLog]: + """ + Run a single pilot session. + + Returns list of turn logs. + """ + if queries is None: + queries = get_fixed_queries() + + logs: List[TurnLog] = [] + + print(f"\n{'='*60}") + print(f"PILOT SESSION: user_id={user_id}, turns={len(queries)}") + print(f"{'='*60}") + + # Reset user for clean start + print(f"\n[Pilot] Resetting user: {user_id}") + llm.reset_user(user_id) + + # Start session + print(f"[Pilot] Starting session") + llm.reset_session(user_id) + + # Get initial state + state_before = llm.get_user_state_summary(user_id) + print(f"[Pilot] Initial state: z_long_norm={state_before['z_long_norm']:.6f}, z_short_norm={state_before['z_short_norm']:.6f}") + + for turn_id, q_info in enumerate(queries): + query = q_info["query"] + query_type = q_info.get("type", "task") + task_type = q_info.get("task_type", "general") + + print(f"\n--- Turn {turn_id} ---") + print(f"[Query] ({query_type}) {query[:80]}...") + + # Get state before + state_before = llm.get_user_state_summary(user_id) + z_long_before = state_before["z_long_norm"] + z_short_before = state_before["z_short_norm"] + + # Apply feedback for previous turn (from turn 1 onwards) + if turn_id > 0 and len(logs) > 0: + prev_log = logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=turn_id - 1, + reward=prev_log.reward, + gating=prev_log.gating, + meta={"source": "pilot_v0"} + ) + print(f"[Feedback] Applying: reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + # Chat + resp: AssistantResponse = llm.chat(user_id, query) + + print(f"[Answer] {resp.answer[:100]}..." 
if len(resp.answer) > 100 else f"[Answer] {resp.answer}") + print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}") + + # Judge + judge_result = minimal_judge(query, resp.answer, task_type) + print(f"[Judge] sat={judge_result.sat_t:.2f}, prog={judge_result.prog_t:.2f}, violations={judge_result.violations}") + + # Compute reward and gating + reward = judge_result.sat_t # Simple: reward = satisfaction + gating = 1.0 # Always allow learning for pilot + + # Get state after + state_after = llm.get_user_state_summary(user_id) + z_long_after = state_after["z_long_norm"] + z_short_after = state_after["z_short_norm"] + + # Debug info + num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0 + num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0 + + print(f"[State] z_long: {z_long_before:.6f} -> {z_long_after:.6f}, z_short: {z_short_before:.6f} -> {z_short_after:.6f}") + print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}") + + # Log + log = TurnLog( + turn_id=turn_id, + query=query, + query_type=query_type, + answer=resp.answer, + answer_length=len(resp.answer), + sat_t=judge_result.sat_t, + sev_t=judge_result.sev_t, + prog_t=judge_result.prog_t, + violations=judge_result.violations, + reward=reward, + gating=gating, + z_long_norm_before=z_long_before, + z_long_norm_after=z_long_after, + z_short_norm_before=z_short_before, + z_short_norm_after=z_short_after, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + total_tokens=resp.usage.total_tokens, + num_memories_retrieved=num_memories, + num_prefs_extracted=num_prefs, + ) + logs.append(log) + + # Apply final feedback + if len(logs) > 0: + last_log = logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=len(queries) - 1, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v0", "final": True} + ) + print(f"\n[Final Feedback] reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + return logs + + +def print_summary(logs: List[TurnLog]): + """Print summary statistics.""" + print(f"\n{'='*60}") + print("PILOT SUMMARY") + print(f"{'='*60}") + + total_turns = len(logs) + avg_sat = sum(l.sat_t for l in logs) / total_turns if total_turns > 0 else 0 + avg_prog = sum(l.prog_t for l in logs) / total_turns if total_turns > 0 else 0 + total_tokens = sum(l.total_tokens for l in logs) + total_prompt = sum(l.prompt_tokens for l in logs) + total_completion = sum(l.completion_tokens for l in logs) + + # Check if RL updates happened (vector norms changed) + z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs] + z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs] + any_z_long_change = any(c > 1e-6 for c in z_long_changes) + any_z_short_change = any(c > 1e-6 for c in z_short_changes) + + print(f"Total turns: {total_turns}") + print(f"Average satisfaction: {avg_sat:.3f}") + print(f"Average progress: {avg_prog:.3f}") + print(f"Total tokens: {total_tokens} (prompt: {total_prompt}, completion: {total_completion})") + print(f"z_long changed: {any_z_long_change} (max delta: {max(z_long_changes):.6f})") + print(f"z_short changed: {any_z_short_change} (max delta: {max(z_short_changes):.6f})") + + # Violations breakdown + all_violations = [v for l in logs for v in l.violations] + if all_violations: + from collections import Counter + print(f"Violations: {dict(Counter(all_violations))}") + else: + print("Violations: 
None") + + # RL Health Check + print(f"\n--- RL Health Check ---") + if any_z_long_change or any_z_short_change: + print("✓ User vectors ARE being updated by RL") + else: + print("✗ WARNING: User vectors NOT changing - check apply_feedback") + + +def main(): + print("=" * 60) + print("PILOT RUNNER v0") + print("=" * 60) + print(f"Started at: {datetime.now().isoformat()}") + + # Initialize LLM + print("\n[Init] Loading PersonalizedLLM...") + llm = PersonalizedLLM( + user_store_path="data/users/user_store_pilot.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=True, + ) + + # Run pilot + user_id = "pilot_user_0" + logs = run_pilot(llm, user_id=user_id) + + # Summary + print_summary(logs) + + # Save logs + log_path = f"data/logs/pilot_v0_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" + log_to_jsonl(logs, log_path) + print(f"\n[Logs] Saved to: {log_path}") + + # Final state + final_state = llm.get_user_state_summary(user_id) + print(f"\n[Final State] {final_state}") + + print(f"\nCompleted at: {datetime.now().isoformat()}") + print("=" * 60) + + +if __name__ == "__main__": + main() + diff --git a/scripts/pilot_runner_v1.py b/scripts/pilot_runner_v1.py new file mode 100644 index 0000000..fbb2876 --- /dev/null +++ b/scripts/pilot_runner_v1.py @@ -0,0 +1,607 @@ +#!/usr/bin/env python3 +""" +Pilot Runner v1 - Style-Aware Judge + Gating Logic + +Upgrade from v0: +- StylePrefs: User style preferences (length, bullets, language) +- style_judge: Checks style conformance, not just non-empty +- compute_feedback_for_turn: gating=1 only for preference-related turns +- Extended queries: ~10 turns with preference/task mix + +Goal: Verify that: +1. sat_t varies based on style violations (not always 1) +2. gating=1 only on preference turns, 0 on regular tasks +3. RL updates happen when gating=1 and reward != baseline +4. Over turns, model may adapt to preferences (sat_t improves) +""" + +import sys +import os +import json +from datetime import datetime +from dataclasses import dataclass, asdict, field +from typing import List, Dict, Any, Optional, Tuple + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse + + +# ============================================================================= +# Style Preferences +# ============================================================================= + +@dataclass +class StylePrefs: + """User's style preferences for the judge to check.""" + require_short: bool = False + max_chars: int = 300 + require_bullets: bool = False + lang: str = "en" # "en" or "zh" + + +# ============================================================================= +# Style-Aware Judge +# ============================================================================= + +@dataclass +class JudgeResult: + """Output from the judge for one turn.""" + sat_t: float # Satisfaction score [0, 1] + sev_t: float # Severity of violations [0, 1] + prog_t: float # Task progress [0, 1] + violations: List[str] # List of violated constraints + + +def style_judge( + query: str, + answer: str, + task_type: str, + prefs: StylePrefs, +) -> JudgeResult: + """ + Style-aware judge that checks: + - Empty/too short answer + - Length constraint (max_chars) + - Bullet point requirement + - Language preference + - Code block for code tasks + + Returns: + JudgeResult with sat_t, sev_t, prog_t, and violations list. 
+ """ + violations: List[str] = [] + text = (answer or "").strip() + + # 0) Empty answer - immediate fail + if not text or len(text) < 5: + violations.append("empty_answer") + return JudgeResult( + sat_t=0.0, + sev_t=1.0, + prog_t=0.0, + violations=violations, + ) + + # 1) Length preference + if prefs.require_short: + if len(text) > prefs.max_chars: + violations.append("too_long") + + # 2) Bullet preference (only for general/list tasks, not pure preference statements) + if prefs.require_bullets and task_type in ("general", "list"): + has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text) + if not has_bullets: + violations.append("no_bullets") + + # 3) Language preference (rough heuristic) + if prefs.lang == "zh": + # For Chinese preference, check if answer has enough non-ASCII chars + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio > 0.7: # Too much ASCII = probably not Chinese + violations.append("wrong_lang") + elif prefs.lang == "en": + # For English preference, check if answer is mostly ASCII + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio < 0.5: # Too little ASCII = probably not English + violations.append("wrong_lang") + + # 4) Code task: must have code markers + prog_t = 1.0 + if task_type == "code": + has_code = ("```" in text) or ("def " in text) or ("function " in text) + if not has_code: + violations.append("no_code_block") + prog_t = 0.0 + + # 5) Compute sat_t and sev_t from violations + if not violations: + sat_t = 1.0 + sev_t = 0.0 + else: + # Each violation costs 0.3, minimum 0 + sat_t = max(0.0, 1.0 - 0.3 * float(len(violations))) + # Hard violations trigger sev_t=1 + hard_violations = {"empty_answer", "too_long", "wrong_lang"} + sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0 + + return JudgeResult( + sat_t=sat_t, + sev_t=sev_t, + prog_t=prog_t, + violations=violations, + ) + + +# ============================================================================= +# Feedback Computation (reward + gating) +# ============================================================================= + +def compute_feedback_for_turn( + turn_id: int, + query: str, + query_type: str, + task_type: str, + judge_result: JudgeResult, +) -> Tuple[float, float]: + """ + Convert JudgeResult into (reward, gating): + - reward = sat_t (style satisfaction) + - gating = 1 only if this turn is preference-related (declared or complained) + + Args: + turn_id: The turn index + query: The user's query text + query_type: "preference" or "task" from query metadata + task_type: "general", "list", "code", etc. + judge_result: The judge's evaluation + + Returns: + (reward, gating) tuple + """ + reward = judge_result.sat_t + + # Gating logic: only allow RL update on preference-related turns + # 1. Explicit preference declaration (query_type == "preference") + # 2. 
Complaint about not following preference + lower_q = (query or "").lower() + + is_pref_turn = ( + query_type == "preference" + or "i prefer" in lower_q + or "my preference" in lower_q + or "please use" in lower_q + or "please keep" in lower_q + or "you didn't follow" in lower_q + or "you forgot" in lower_q + or "remember that i" in lower_q + or "i told you" in lower_q + or "i asked for" in lower_q + ) + + if is_pref_turn: + gating = 1.0 + else: + gating = 0.0 + + return reward, gating + + +# ============================================================================= +# Extended Queries for Pilot v1 (~10 turns) +# ============================================================================= + +def get_pilot_v1_queries() -> List[Dict[str, Any]]: + """ + Extended query set for pilot v1. + Mix of preference declarations and tasks. + Tests: length constraint, bullet points, task completion. + """ + return [ + # Turn 0: Declare length preference + { + "query": "I prefer short, concise answers. Please keep responses under 200 characters.", + "type": "preference", + "task_type": "general", + }, + # Turn 1: Task that should be short + { + "query": "What are three tips for better sleep?", + "type": "task", + "task_type": "list", + }, + # Turn 2: Declare bullet preference + { + "query": "I also prefer bullet points when listing things. Please use bullet points.", + "type": "preference", + "task_type": "general", + }, + # Turn 3: Task that should use bullets + { + "query": "What are the main benefits of regular exercise?", + "type": "task", + "task_type": "list", + }, + # Turn 4: Another task (test if preferences stick) + { + "query": "Name five popular programming languages.", + "type": "task", + "task_type": "list", + }, + # Turn 5: Complaint if needed (always include to test gating) + { + "query": "Remember that I asked for short answers with bullet points. Can you list three healthy breakfast ideas?", + "type": "preference", + "task_type": "list", + }, + # Turn 6: Regular task + { + "query": "What is the capital of France?", + "type": "task", + "task_type": "general", + }, + # Turn 7: Task requiring list + { + "query": "What are four seasons of the year?", + "type": "task", + "task_type": "list", + }, + # Turn 8: Another preference reminder + { + "query": "I prefer concise bullet points. 
Please list three types of renewable energy.", + "type": "preference", + "task_type": "list", + }, + # Turn 9: Final task - test memory + { + "query": "Summarize what you know about my communication preferences.", + "type": "task", + "task_type": "general", + }, + ] + + +# ============================================================================= +# Logging +# ============================================================================= + +@dataclass +class TurnLog: + """Log entry for one turn.""" + turn_id: int + query: str + query_type: str + task_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + reward: float + gating: float + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + num_memories_retrieved: int + num_prefs_extracted: int + + +def log_to_jsonl(logs: List[TurnLog], filepath: str): + """Save logs to JSONL file.""" + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as f: + for log in logs: + f.write(json.dumps(asdict(log)) + "\n") + + +# ============================================================================= +# Pilot Runner v1 +# ============================================================================= + +def run_pilot_v1( + llm: PersonalizedLLM, + user_id: str = "pilot_user_v1", + prefs: Optional[StylePrefs] = None, + queries: Optional[List[Dict[str, Any]]] = None, +) -> List[TurnLog]: + """ + Run pilot v1 with style-aware judge and gating. + + Args: + llm: PersonalizedLLM instance + user_id: User identifier + prefs: Style preferences for this user + queries: Query list (defaults to get_pilot_v1_queries) + + Returns: + List of TurnLog entries + """ + if prefs is None: + # Default preferences: short + bullets + English + prefs = StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="en", + ) + + if queries is None: + queries = get_pilot_v1_queries() + + logs: List[TurnLog] = [] + + print(f"\n{'='*60}") + print(f"PILOT v1 SESSION: user_id={user_id}, turns={len(queries)}") + print(f"Preferences: short={prefs.require_short}, max_chars={prefs.max_chars}, bullets={prefs.require_bullets}, lang={prefs.lang}") + print(f"{'='*60}") + + # Reset user for clean start + print(f"\n[Pilot] Resetting user: {user_id}") + llm.reset_user(user_id) + + # Start session + print(f"[Pilot] Starting session") + llm.reset_session(user_id) + + # Get initial state + state_before = llm.get_user_state_summary(user_id) + print(f"[Pilot] Initial state: z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}") + + for turn_id, q_info in enumerate(queries): + query = q_info["query"] + query_type = q_info.get("type", "task") + task_type = q_info.get("task_type", "general") + + print(f"\n{'─'*60}") + print(f"Turn {turn_id} [{query_type}]") + print(f"{'─'*60}") + print(f"[Query] {query}") + + # Get state before + state_before = llm.get_user_state_summary(user_id) + z_long_before = state_before["z_long_norm"] + z_short_before = state_before["z_short_norm"] + + # Apply feedback for previous turn (from turn 1 onwards) + if turn_id > 0 and len(logs) > 0: + prev_log = logs[-1] + prev_query = queries[turn_id - 1] + + # Re-judge the previous answer with current context + # (In practice we already have the result, but this shows the flow) + feedback = Feedback( + user_id=user_id, + turn_id=turn_id - 1, + reward=prev_log.reward, + 
gating=prev_log.gating, + meta={ + "sat_t": prev_log.sat_t, + "sev_t": prev_log.sev_t, + "prog_t": prev_log.prog_t, + "violations": prev_log.violations, + "task_type": prev_log.task_type, + "source": "pilot_v1", + } + ) + print(f"[Feedback] turn={turn_id-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + # Chat + resp: AssistantResponse = llm.chat(user_id, query) + + # Truncate answer for display + answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer + print(f"[Answer] ({len(resp.answer)} chars) {answer_display}") + print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}") + + # Judge with style preferences + judge_result = style_judge(query, resp.answer, task_type, prefs) + print(f"[Judge] sat={judge_result.sat_t:.2f}, sev={judge_result.sev_t:.1f}, prog={judge_result.prog_t:.1f}") + if judge_result.violations: + print(f"[Judge] violations={judge_result.violations}") + + # Compute feedback for THIS turn (will be applied next turn) + reward, gating = compute_feedback_for_turn( + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + judge_result=judge_result, + ) + print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f} (computed for this turn)") + + # Get state after + state_after = llm.get_user_state_summary(user_id) + z_long_after = state_after["z_long_norm"] + z_short_after = state_after["z_short_norm"] + + # Debug info + num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0 + num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0 + + z_long_delta = z_long_after - z_long_before + z_short_delta = z_short_after - z_short_before + print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})") + print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})") + print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}") + + # Log + log = TurnLog( + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + answer=resp.answer, + answer_length=len(resp.answer), + sat_t=judge_result.sat_t, + sev_t=judge_result.sev_t, + prog_t=judge_result.prog_t, + violations=judge_result.violations, + reward=reward, + gating=gating, + z_long_norm_before=z_long_before, + z_long_norm_after=z_long_after, + z_short_norm_before=z_short_before, + z_short_norm_after=z_short_after, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + total_tokens=resp.usage.total_tokens, + num_memories_retrieved=num_memories, + num_prefs_extracted=num_prefs, + ) + logs.append(log) + + # Apply final feedback for last turn + if len(logs) > 0: + last_log = logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=len(queries) - 1, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v1", "final": True} + ) + print(f"\n[Final Feedback] turn={len(queries)-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + return logs + + +def print_summary_v1(logs: List[TurnLog], prefs: StylePrefs): + """Print summary statistics for pilot v1.""" + print(f"\n{'='*60}") + print("PILOT v1 SUMMARY") + print(f"{'='*60}") + + total_turns = len(logs) + if total_turns == 0: + print("No turns to summarize.") + return + + # Basic stats + avg_sat = sum(l.sat_t for l in logs) / total_turns + avg_prog = sum(l.prog_t for l in logs) / total_turns + total_tokens = sum(l.total_tokens 
for l in logs) + total_prompt = sum(l.prompt_tokens for l in logs) + total_completion = sum(l.completion_tokens for l in logs) + + # Gating stats + gated_turns = [l for l in logs if l.gating > 0] + non_gated_turns = [l for l in logs if l.gating == 0] + + print(f"\n--- Turn Statistics ---") + print(f"Total turns: {total_turns}") + print(f"Gated turns (RL active): {len(gated_turns)}") + print(f"Non-gated turns (RL skipped): {len(non_gated_turns)}") + + print(f"\n--- Satisfaction ---") + print(f"Average sat_t (all): {avg_sat:.3f}") + if gated_turns: + avg_sat_gated = sum(l.sat_t for l in gated_turns) / len(gated_turns) + print(f"Average sat_t (gated only): {avg_sat_gated:.3f}") + print(f"Average prog_t: {avg_prog:.3f}") + + print(f"\n--- Token Usage ---") + print(f"Total tokens: {total_tokens}") + print(f" Prompt: {total_prompt}") + print(f" Completion: {total_completion}") + print(f"Avg tokens/turn: {total_tokens / total_turns:.1f}") + + # Violations breakdown + print(f"\n--- Violations ---") + from collections import Counter + all_violations = [v for l in logs for v in l.violations] + if all_violations: + print(f"Total violations: {len(all_violations)}") + for v, count in Counter(all_violations).most_common(): + print(f" {v}: {count}") + else: + print("No violations") + + # Answer length analysis + print(f"\n--- Answer Lengths (max_chars={prefs.max_chars}) ---") + lengths = [l.answer_length for l in logs] + over_limit = sum(1 for l in lengths if l > prefs.max_chars) + print(f"Min: {min(lengths)}, Max: {max(lengths)}, Avg: {sum(lengths)/len(lengths):.1f}") + print(f"Over limit: {over_limit}/{total_turns}") + + # RL Health Check + print(f"\n--- RL Health Check ---") + z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs] + z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs] + any_z_long_change = any(c > 1e-6 for c in z_long_changes) + any_z_short_change = any(c > 1e-6 for c in z_short_changes) + + print(f"z_long changed: {any_z_long_change} (max Δ: {max(z_long_changes):.6f})") + print(f"z_short changed: {any_z_short_change} (max Δ: {max(z_short_changes):.6f})") + + if any_z_long_change or any_z_short_change: + print("✓ User vectors ARE being updated by RL") + else: + print("✗ WARNING: User vectors NOT changing") + print(" Check: gating=1 on some turns? 
reward != baseline?") + + # Per-turn detail table + print(f"\n--- Turn-by-Turn Summary ---") + print(f"{'Turn':>4} {'Type':>10} {'Len':>5} {'sat':>5} {'gate':>5} {'violations'}") + print("-" * 60) + for l in logs: + viol_str = ",".join(l.violations) if l.violations else "-" + print(f"{l.turn_id:>4} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {l.gating:>5.1f} {viol_str}") + + +def main(): + print("=" * 60) + print("PILOT RUNNER v1 - Style-Aware Judge + Gating") + print("=" * 60) + print(f"Started at: {datetime.now().isoformat()}") + + # Define user preferences + prefs = StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="en", + ) + print(f"\n[Config] User preferences: {prefs}") + + # Initialize LLM + print("\n[Init] Loading PersonalizedLLM...") + llm = PersonalizedLLM( + user_store_path="data/users/user_store_pilot_v1.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=True, + ) + + # Run pilot + user_id = "pilot_user_v1" + logs = run_pilot_v1(llm, user_id=user_id, prefs=prefs) + + # Summary + print_summary_v1(logs, prefs) + + # Save logs + log_path = f"data/logs/pilot_v1_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" + log_to_jsonl(logs, log_path) + print(f"\n[Logs] Saved to: {log_path}") + + # Final state + final_state = llm.get_user_state_summary(user_id) + print(f"\n[Final State] {final_state}") + + print(f"\nCompleted at: {datetime.now().isoformat()}") + print("=" * 60) + + +if __name__ == "__main__": + main() + diff --git a/scripts/pilot_runner_v2.py b/scripts/pilot_runner_v2.py new file mode 100644 index 0000000..d3c2aa8 --- /dev/null +++ b/scripts/pilot_runner_v2.py @@ -0,0 +1,852 @@ +#!/usr/bin/env python3 +""" +Pilot Runner v2 - Cross-Session Preference Reveal Mechanism + +Upgrade from v1: +- RevealState: Tracks which preferences have been explicitly revealed by the user +- pref_true[k] vs pref_revealed_global[k] distinction +- Style constraints only enforced AFTER user reveals them +- Reveal state persists across sessions, resets on reset_user() + +Key concepts: +- pref_true[k]: User's true preference (from StylePrefs) +- pref_revealed_global[k]: Whether preference k has been revealed at least once + +Enforcement rule: +- A style constraint is enforced only when BOTH pref_true[k] AND pref_revealed_global[k] + +Session semantics: +- reset_user(): Clears ALL state including reveal flags +- reset_session(): Keeps reveal flags (cross-session memory) +""" + +import sys +import os +import json +from datetime import datetime +from dataclasses import dataclass, asdict, field +from typing import List, Dict, Any, Optional, Tuple, Set + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse + + +# ============================================================================= +# Style Preferences (True Preferences) +# ============================================================================= + +@dataclass +class StylePrefs: + """ + User's TRUE style preferences. + These are the ground truth preferences that the user actually has, + but they may not have revealed all of them to the system yet. 
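+
+    Illustrative example (matches the defaults used in main() below): a user
+    who truly wants short, bulleted English answers is
+    StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
+    even before any of those preferences has been revealed.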
+ """ + require_short: bool = False + max_chars: int = 300 + require_bullets: bool = False + lang: str = "en" # "en" or "zh" + + +# ============================================================================= +# Reveal State (What has been explicitly revealed) +# ============================================================================= + +@dataclass +class RevealState: + """ + Tracks which preferences have been explicitly revealed by the user. + + This persists across sessions for the same user but resets on reset_user(). + A preference is revealed when the user explicitly mentions it in a query. + """ + short_revealed: bool = False # "short", "concise", "brief", length constraints + bullets_revealed: bool = False # "bullet", "bullet points", "list format" + lang_revealed: bool = False # Language preference mentioned + + def reset(self): + """Reset all reveal flags (called on reset_user).""" + self.short_revealed = False + self.bullets_revealed = False + self.lang_revealed = False + + def to_dict(self) -> Dict[str, bool]: + return { + "short": self.short_revealed, + "bullets": self.bullets_revealed, + "lang": self.lang_revealed, + } + + def __str__(self) -> str: + flags = [] + if self.short_revealed: + flags.append("short") + if self.bullets_revealed: + flags.append("bullets") + if self.lang_revealed: + flags.append("lang") + return f"RevealState({', '.join(flags) if flags else 'none'})" + + +class RevealStateManager: + """ + Manages reveal state for multiple users. + Persists across sessions, resets on reset_user(). + """ + + def __init__(self): + self._states: Dict[str, RevealState] = {} + + def get_state(self, user_id: str) -> RevealState: + """Get or create reveal state for a user.""" + if user_id not in self._states: + self._states[user_id] = RevealState() + return self._states[user_id] + + def reset_user(self, user_id: str): + """Reset reveal state for a user (called on reset_user).""" + if user_id in self._states: + self._states[user_id].reset() + else: + self._states[user_id] = RevealState() + + def reset_session(self, user_id: str): + """ + Called on reset_session - does NOT reset reveal state. + Reveal state persists across sessions. + """ + # Intentionally do nothing - reveal state persists + pass + + +# ============================================================================= +# Preference Detection from Queries +# ============================================================================= + +def detect_revealed_preferences(query: str) -> Dict[str, bool]: + """ + Detect which preferences are mentioned in a query. + + Returns a dict with keys: "short", "bullets", "lang" + Each value is True if that preference was mentioned. 
+ """ + lower_q = (query or "").lower() + + revealed = { + "short": False, + "bullets": False, + "lang": False, + } + + # Short/length preference detection + short_patterns = [ + "short", "concise", "brief", "under ", "less than", + "keep it short", "keep responses", "keep answers", + "maximum ", "max ", "characters", "words or less", + "200 ", "100 ", "50 ", "300 ", # Common char limits + ] + for pattern in short_patterns: + if pattern in lower_q: + revealed["short"] = True + break + + # Bullet preference detection + bullet_patterns = [ + "bullet", "bullet point", "bullet-point", + "bulleted", "list format", "use bullets", + "use bullet", "with bullets", "in bullets", + "- format", "• ", "numbered list", + ] + for pattern in bullet_patterns: + if pattern in lower_q: + revealed["bullets"] = True + break + + # Language preference detection + lang_patterns_zh = [ + "chinese", "中文", "in chinese", "用中文", + "speak chinese", "write chinese", "respond in chinese", + "please use chinese", "mandarin", + ] + lang_patterns_en = [ + "english", "in english", "use english", + "speak english", "write english", "respond in english", + "please use english", + ] + + for pattern in lang_patterns_zh + lang_patterns_en: + if pattern in lower_q: + revealed["lang"] = True + break + + return revealed + + +def update_reveal_state(reveal_state: RevealState, query: str) -> Set[str]: + """ + Update reveal state based on query content. + Returns set of newly revealed preferences. + """ + detected = detect_revealed_preferences(query) + newly_revealed = set() + + if detected["short"] and not reveal_state.short_revealed: + reveal_state.short_revealed = True + newly_revealed.add("short") + + if detected["bullets"] and not reveal_state.bullets_revealed: + reveal_state.bullets_revealed = True + newly_revealed.add("bullets") + + if detected["lang"] and not reveal_state.lang_revealed: + reveal_state.lang_revealed = True + newly_revealed.add("lang") + + return newly_revealed + + +# ============================================================================= +# Style-Aware Judge with Reveal State +# ============================================================================= + +@dataclass +class JudgeResult: + """Output from the judge for one turn.""" + sat_t: float # Satisfaction score [0, 1] + sev_t: float # Severity of violations [0, 1] + prog_t: float # Task progress [0, 1] + violations: List[str] # List of violated constraints + enforced_constraints: List[str] # Which constraints were actually enforced + + +def style_judge_with_reveal( + query: str, + answer: str, + task_type: str, + prefs: StylePrefs, + reveal_state: RevealState, +) -> JudgeResult: + """ + Style-aware judge that ONLY enforces revealed preferences. 
+ + A constraint is enforced only when: + - pref_true[k] is True (user has this preference) + - pref_revealed_global[k] is True (user has revealed this preference) + + Args: + query: User's query + answer: Assistant's answer + task_type: Type of task ("general", "list", "code") + prefs: User's TRUE preferences (StylePrefs) + reveal_state: Which preferences have been revealed + + Returns: + JudgeResult with sat_t, sev_t, prog_t, violations, and enforced_constraints + """ + violations: List[str] = [] + enforced: List[str] = [] + text = (answer or "").strip() + + # 0) Empty answer - always a violation regardless of reveal state + if not text or len(text) < 5: + violations.append("empty_answer") + return JudgeResult( + sat_t=0.0, + sev_t=1.0, + prog_t=0.0, + violations=violations, + enforced_constraints=["non_empty"], + ) + + # 1) Length preference - enforce only if BOTH true AND revealed + if prefs.require_short and reveal_state.short_revealed: + enforced.append("short") + if len(text) > prefs.max_chars: + violations.append("too_long") + + # 2) Bullet preference - enforce only if BOTH true AND revealed + # Also only for list-type tasks + if prefs.require_bullets and reveal_state.bullets_revealed: + if task_type in ("general", "list"): + enforced.append("bullets") + has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text) + if not has_bullets: + violations.append("no_bullets") + + # 3) Language preference - enforce only if BOTH true AND revealed + if reveal_state.lang_revealed: + enforced.append("lang") + if prefs.lang == "zh": + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio > 0.7: + violations.append("wrong_lang") + elif prefs.lang == "en": + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio < 0.5: + violations.append("wrong_lang") + + # 4) Code task: always enforce code markers (not a user preference) + prog_t = 1.0 + if task_type == "code": + enforced.append("code_block") + has_code = ("```" in text) or ("def " in text) or ("function " in text) + if not has_code: + violations.append("no_code_block") + prog_t = 0.0 + + # 5) Compute sat_t and sev_t from violations + if not violations: + sat_t = 1.0 + sev_t = 0.0 + else: + sat_t = max(0.0, 1.0 - 0.3 * float(len(violations))) + hard_violations = {"empty_answer", "too_long", "wrong_lang"} + sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0 + + return JudgeResult( + sat_t=sat_t, + sev_t=sev_t, + prog_t=prog_t, + violations=violations, + enforced_constraints=enforced, + ) + + +# ============================================================================= +# Feedback Computation (reward + gating) +# ============================================================================= + +def compute_feedback_for_turn( + turn_id: int, + query: str, + query_type: str, + task_type: str, + judge_result: JudgeResult, +) -> Tuple[float, float]: + """ + Convert JudgeResult into (reward, gating). + Same as v1 - reward = sat_t, gating = 1 for preference turns. 
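+
+    Illustrative sketch (a plain factual task gets gating=0, so no RL update):
+
+        >>> jr = JudgeResult(sat_t=0.7, sev_t=0.0, prog_t=1.0,
+        ...                  violations=["no_bullets"], enforced_constraints=["bullets"])
+        >>> compute_feedback_for_turn(3, "What is the capital of France?", "task", "general", jr)
+        (0.7, 0.0)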
+ """ + reward = judge_result.sat_t + + lower_q = (query or "").lower() + + is_pref_turn = ( + query_type == "preference" + or "i prefer" in lower_q + or "my preference" in lower_q + or "please use" in lower_q + or "please keep" in lower_q + or "you didn't follow" in lower_q + or "you forgot" in lower_q + or "remember that i" in lower_q + or "i told you" in lower_q + or "i asked for" in lower_q + ) + + gating = 1.0 if is_pref_turn else 0.0 + return reward, gating + + +# ============================================================================= +# Multi-Session Queries for Pilot v2 +# ============================================================================= + +def get_session_1_queries() -> List[Dict[str, Any]]: + """ + Session 1: User reveals preferences and does some tasks. + """ + return [ + { + "query": "I prefer short, concise answers. Please keep responses under 200 characters.", + "type": "preference", + "task_type": "general", + }, + { + "query": "What are three tips for better sleep?", + "type": "task", + "task_type": "list", + }, + { + "query": "I also prefer bullet points when listing things.", + "type": "preference", + "task_type": "general", + }, + { + "query": "What are the main benefits of exercise?", + "type": "task", + "task_type": "list", + }, + { + "query": "Name five programming languages.", + "type": "task", + "task_type": "list", + }, + ] + + +def get_session_2_queries() -> List[Dict[str, Any]]: + """ + Session 2: User does NOT restate preferences. + Tests cross-session preference retention. + """ + return [ + { + "query": "What are three healthy breakfast ideas?", + "type": "task", + "task_type": "list", + }, + { + "query": "List four seasons of the year.", + "type": "task", + "task_type": "list", + }, + { + "query": "What is the capital of France?", + "type": "task", + "task_type": "general", + }, + { + "query": "Name three types of renewable energy.", + "type": "task", + "task_type": "list", + }, + ] + + +def get_session_3_queries() -> List[Dict[str, Any]]: + """ + Session 3: Mix of tasks and one complaint/reminder. + """ + return [ + { + "query": "What are five common fruits?", + "type": "task", + "task_type": "list", + }, + { + "query": "Remember that I asked for short bullet points. 
List three ocean animals.", + "type": "preference", + "task_type": "list", + }, + { + "query": "What is 2 + 2?", + "type": "task", + "task_type": "general", + }, + ] + + +# ============================================================================= +# Logging (Extended for v2) +# ============================================================================= + +@dataclass +class TurnLog: + """Log entry for one turn (extended for v2).""" + session_id: int + turn_id: int + query: str + query_type: str + task_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + reward: float + gating: float + reveal_state_before: Dict[str, bool] + reveal_state_after: Dict[str, bool] + newly_revealed: List[str] + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + num_memories_retrieved: int + num_prefs_extracted: int + + +def log_to_jsonl(logs: List[TurnLog], filepath: str): + """Save logs to JSONL file.""" + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as f: + for log in logs: + f.write(json.dumps(asdict(log)) + "\n") + + +# ============================================================================= +# Pilot Runner v2 (Multi-Session with Reveal State) +# ============================================================================= + +def run_session( + llm: PersonalizedLLM, + user_id: str, + session_id: int, + prefs: StylePrefs, + reveal_state: RevealState, + queries: List[Dict[str, Any]], +) -> List[TurnLog]: + """ + Run a single session with reveal-aware judging. + """ + logs: List[TurnLog] = [] + + print(f"\n{'='*60}") + print(f"SESSION {session_id}: user_id={user_id}, turns={len(queries)}") + print(f"Reveal state (start): {reveal_state}") + print(f"{'='*60}") + + # Reset session (clears history, z_short; keeps z_long and reveal state) + llm.reset_session(user_id) + + state_before = llm.get_user_state_summary(user_id) + print(f"[Session] z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}") + + for turn_id, q_info in enumerate(queries): + query = q_info["query"] + query_type = q_info.get("type", "task") + task_type = q_info.get("task_type", "general") + + print(f"\n{'─'*60}") + print(f"Session {session_id} / Turn {turn_id} [{query_type}]") + print(f"{'─'*60}") + print(f"[Query] {query}") + + # Capture reveal state BEFORE this turn + reveal_before = reveal_state.to_dict() + + # Update reveal state based on query content + newly_revealed = update_reveal_state(reveal_state, query) + if newly_revealed: + print(f"[Reveal] Newly revealed: {newly_revealed}") + print(f"[Reveal] State: {reveal_state}") + + # Capture reveal state AFTER update + reveal_after = reveal_state.to_dict() + + # Get user state before + state_before = llm.get_user_state_summary(user_id) + z_long_before = state_before["z_long_norm"] + z_short_before = state_before["z_short_norm"] + + # Apply feedback for previous turn (from turn 1 onwards in this session) + if turn_id > 0 and len(logs) > 0: + # Find the last log from THIS session + session_logs = [l for l in logs if l.session_id == session_id] + if session_logs: + prev_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=prev_log.turn_id, + reward=prev_log.reward, + gating=prev_log.gating, + meta={ + "sat_t": prev_log.sat_t, + "violations": prev_log.violations, + "source": 
"pilot_v2", + "session_id": session_id, + } + ) + print(f"[Feedback] turn={prev_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + # Chat + resp: AssistantResponse = llm.chat(user_id, query) + + answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer + print(f"[Answer] ({len(resp.answer)} chars) {answer_display}") + print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}") + + # Judge with reveal-aware logic + judge_result = style_judge_with_reveal(query, resp.answer, task_type, prefs, reveal_state) + print(f"[Judge] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}") + if judge_result.violations: + print(f"[Judge] violations={judge_result.violations}") + + # Compute feedback + reward, gating = compute_feedback_for_turn( + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + judge_result=judge_result, + ) + print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f}") + + # Get state after + state_after = llm.get_user_state_summary(user_id) + z_long_after = state_after["z_long_norm"] + z_short_after = state_after["z_short_norm"] + + z_long_delta = z_long_after - z_long_before + z_short_delta = z_short_after - z_short_before + print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})") + print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})") + + # Debug info + num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0 + num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0 + print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}") + + # Log + log = TurnLog( + session_id=session_id, + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + answer=resp.answer, + answer_length=len(resp.answer), + sat_t=judge_result.sat_t, + sev_t=judge_result.sev_t, + prog_t=judge_result.prog_t, + violations=judge_result.violations, + enforced_constraints=judge_result.enforced_constraints, + reward=reward, + gating=gating, + reveal_state_before=reveal_before, + reveal_state_after=reveal_after, + newly_revealed=list(newly_revealed), + z_long_norm_before=z_long_before, + z_long_norm_after=z_long_after, + z_short_norm_before=z_short_before, + z_short_norm_after=z_short_after, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + total_tokens=resp.usage.total_tokens, + num_memories_retrieved=num_memories, + num_prefs_extracted=num_prefs, + ) + logs.append(log) + + # Apply final feedback for this session + session_logs = [l for l in logs if l.session_id == session_id] + if session_logs: + last_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=last_log.turn_id, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v2", "session_id": session_id, "final": True} + ) + print(f"\n[Final Feedback] turn={last_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}") + llm.apply_feedback(feedback) + + print(f"\n[Session {session_id} End] Reveal state: {reveal_state}") + + return logs + + +def run_pilot_v2( + llm: PersonalizedLLM, + user_id: str = "pilot_user_v2", + prefs: Optional[StylePrefs] = None, +) -> List[TurnLog]: + """ + Run multi-session pilot with reveal state tracking. 
+ + Session 1: User reveals preferences + Session 2: User does NOT restate preferences (tests cross-session retention) + Session 3: Mix of tasks and reminders + """ + if prefs is None: + prefs = StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="en", + ) + + # Initialize reveal state manager + reveal_manager = RevealStateManager() + + print(f"\n{'#'*60}") + print(f"PILOT v2: CROSS-SESSION PREFERENCE REVEAL TEST") + print(f"User: {user_id}") + print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}") + print(f"{'#'*60}") + + # Reset user completely (clears all state including reveal) + print(f"\n[Pilot] Resetting user: {user_id}") + llm.reset_user(user_id) + reveal_manager.reset_user(user_id) + + all_logs: List[TurnLog] = [] + reveal_state = reveal_manager.get_state(user_id) + + # Session 1: Reveal preferences + session_1_queries = get_session_1_queries() + logs_s1 = run_session(llm, user_id, 1, prefs, reveal_state, session_1_queries) + all_logs.extend(logs_s1) + + # Session 2: NO preference restatement (test cross-session retention) + # Note: reveal_state persists, but reset_session clears history + reveal_manager.reset_session(user_id) # Does nothing to reveal state + session_2_queries = get_session_2_queries() + logs_s2 = run_session(llm, user_id, 2, prefs, reveal_state, session_2_queries) + all_logs.extend(logs_s2) + + # Session 3: Reminder and more tasks + reveal_manager.reset_session(user_id) + session_3_queries = get_session_3_queries() + logs_s3 = run_session(llm, user_id, 3, prefs, reveal_state, session_3_queries) + all_logs.extend(logs_s3) + + return all_logs + + +def print_summary_v2(logs: List[TurnLog], prefs: StylePrefs): + """Print summary for pilot v2.""" + print(f"\n{'='*60}") + print("PILOT v2 SUMMARY - Cross-Session Reveal") + print(f"{'='*60}") + + if not logs: + print("No logs to summarize.") + return + + # Per-session stats + sessions = sorted(set(l.session_id for l in logs)) + + print(f"\n--- Per-Session Statistics ---") + for sid in sessions: + session_logs = [l for l in logs if l.session_id == sid] + avg_sat = sum(l.sat_t for l in session_logs) / len(session_logs) + violations = [v for l in session_logs for v in l.violations] + + # What was revealed at session end + if session_logs: + final_reveal = session_logs[-1].reveal_state_after + else: + final_reveal = {} + + print(f"\nSession {sid}: {len(session_logs)} turns") + print(f" Avg sat_t: {avg_sat:.3f}") + print(f" Violations: {len(violations)} ({violations if violations else 'none'})") + print(f" Reveal state at end: {final_reveal}") + + # Overall stats + total = len(logs) + avg_sat = sum(l.sat_t for l in logs) / total + total_tokens = sum(l.total_tokens for l in logs) + + print(f"\n--- Overall Statistics ---") + print(f"Total turns: {total}") + print(f"Overall avg sat_t: {avg_sat:.3f}") + print(f"Total tokens: {total_tokens}") + + # Violations by type + print(f"\n--- Violations Breakdown ---") + from collections import Counter + all_violations = [v for l in logs for v in l.violations] + if all_violations: + for v, count in Counter(all_violations).most_common(): + print(f" {v}: {count}") + else: + print(" No violations") + + # Enforcement tracking + print(f"\n--- Constraint Enforcement ---") + for constraint in ["short", "bullets", "lang"]: + enforced_count = sum(1 for l in logs if constraint in l.enforced_constraints) + print(f" {constraint}: enforced in {enforced_count}/{total} turns") + + # Cross-session reveal verification + 
print(f"\n--- Cross-Session Reveal Verification ---") + + # Session 1: Should have some reveals + s1_logs = [l for l in logs if l.session_id == 1] + s1_reveals = set() + for l in s1_logs: + s1_reveals.update(l.newly_revealed) + print(f"Session 1 revealed: {s1_reveals if s1_reveals else 'none'}") + + # Session 2: Should NOT have new reveals (no preference queries) + s2_logs = [l for l in logs if l.session_id == 2] + s2_reveals = set() + for l in s2_logs: + s2_reveals.update(l.newly_revealed) + print(f"Session 2 revealed: {s2_reveals if s2_reveals else 'none (expected)'}") + + # But Session 2 should still ENFORCE the constraints revealed in Session 1 + if s2_logs: + s2_enforced = set() + for l in s2_logs: + s2_enforced.update(l.enforced_constraints) + print(f"Session 2 enforced: {s2_enforced}") + + if s1_reveals and s1_reveals.issubset(s2_enforced): + print("✓ Cross-session retention VERIFIED: Session 1 reveals enforced in Session 2") + else: + print("✗ Cross-session retention issue: some reveals not enforced") + + # Turn-by-turn table + print(f"\n--- Turn-by-Turn Summary ---") + print(f"{'S':>2} {'T':>2} {'Type':>10} {'Len':>5} {'sat':>5} {'enforced':<20} {'violations'}") + print("-" * 70) + for l in logs: + enforced_str = ",".join(l.enforced_constraints) if l.enforced_constraints else "-" + viol_str = ",".join(l.violations) if l.violations else "-" + print(f"{l.session_id:>2} {l.turn_id:>2} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {enforced_str:<20} {viol_str}") + + +def main(): + print("=" * 60) + print("PILOT RUNNER v2 - Cross-Session Preference Reveal") + print("=" * 60) + print(f"Started at: {datetime.now().isoformat()}") + + # Define user's TRUE preferences + prefs = StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="en", + ) + print(f"\n[Config] True preferences: {prefs}") + print("[Config] Note: Constraints only enforced AFTER user reveals them") + + # Initialize LLM + print("\n[Init] Loading PersonalizedLLM...") + llm = PersonalizedLLM( + user_store_path="data/users/user_store_pilot_v2.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=True, + ) + + # Run pilot + user_id = "pilot_user_v2" + logs = run_pilot_v2(llm, user_id=user_id, prefs=prefs) + + # Summary + print_summary_v2(logs, prefs) + + # Save logs + log_path = f"data/logs/pilot_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" + log_to_jsonl(logs, log_path) + print(f"\n[Logs] Saved to: {log_path}") + + # Final state + final_state = llm.get_user_state_summary(user_id) + print(f"\n[Final State] {final_state}") + + print(f"\nCompleted at: {datetime.now().isoformat()}") + print("=" * 60) + + +if __name__ == "__main__": + main() + + diff --git a/scripts/pilot_runner_v3.py b/scripts/pilot_runner_v3.py new file mode 100644 index 0000000..d232d10 --- /dev/null +++ b/scripts/pilot_runner_v3.py @@ -0,0 +1,924 @@ +#!/usr/bin/env python3 +""" +Pilot Runner v3 - Multi-User Multi-Session with Personas + +Upgrades from v2: +- Persona: Bundles StylePrefs into user types +- 5 test personas (A-E) targeting different style combinations +- Multi-user × multi-session evaluation +- Refined judge: bullets only on list tasks, relaxed empty_answer +- Baseline mode support (no-personalization comparison) + +5 Test Personas: +- A: short + bullets + en (sanity check) +- B: short + NO bullets + en (anti-bullet) +- C: long + bullets + en (no length constraint) +- D: short + bullets + zh (Chinese) +- E: long + NO bullets + zh (most "anti-default") +""" + +import sys 
+import os +import json +from datetime import datetime +from dataclasses import dataclass, asdict, field +from typing import List, Dict, Any, Optional, Tuple, Set, Literal + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse + + +# ============================================================================= +# Style Preferences (True Preferences) +# ============================================================================= + +@dataclass +class StylePrefs: + """User's TRUE style preferences.""" + require_short: bool = False + max_chars: int = 300 + require_bullets: bool = False + lang: str = "en" # "en" or "zh" + + +# ============================================================================= +# Persona Definition +# ============================================================================= + +@dataclass +class Persona: + """ + A user persona that bundles style preferences. + Each persona represents a distinct user type for testing. + """ + persona_id: str + style_prefs: StylePrefs + description: str = "" + # Future extensions: + # task_preferences: Dict[str, float] # e.g., {"code": 0.3, "rewrite": 0.7} + # tone: str = "neutral" # "formal", "casual", etc. + # domain: str = "general" # "tech", "daily_life", etc. + + +# ============================================================================= +# 5 Test Personas (A-E) +# ============================================================================= + +PERSONA_A = Persona( + persona_id="A_short_bullets_en", + style_prefs=StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="en", + ), + description="Short + bullets + English (sanity check, same as v2)", +) + +PERSONA_B = Persona( + persona_id="B_short_no_bullets_en", + style_prefs=StylePrefs( + require_short=True, + max_chars=200, + require_bullets=False, + lang="en", + ), + description="Short + NO bullets + English (anti-bullet test)", +) + +PERSONA_C = Persona( + persona_id="C_long_bullets_en", + style_prefs=StylePrefs( + require_short=False, + max_chars=800, + require_bullets=True, + lang="en", + ), + description="Long + bullets + English (no length constraint)", +) + +PERSONA_D = Persona( + persona_id="D_short_bullets_zh", + style_prefs=StylePrefs( + require_short=True, + max_chars=200, + require_bullets=True, + lang="zh", + ), + description="Short + bullets + Chinese (language test)", +) + +PERSONA_E = Persona( + persona_id="E_long_no_bullets_zh", + style_prefs=StylePrefs( + require_short=False, + max_chars=800, + require_bullets=False, + lang="zh", + ), + description="Long + NO bullets + Chinese (most anti-default)", +) + +ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E] + + +def get_persona_by_id(persona_id: str) -> Optional[Persona]: + """Get persona by ID.""" + for p in ALL_PERSONAS: + if p.persona_id == persona_id: + return p + return None + + +# ============================================================================= +# Reveal State +# ============================================================================= + +@dataclass +class RevealState: + """Tracks which preferences have been explicitly revealed.""" + short_revealed: bool = False + bullets_revealed: bool = False + lang_revealed: bool = False + + def reset(self): + self.short_revealed = False + self.bullets_revealed = False + self.lang_revealed = False + + def to_dict(self) -> Dict[str, bool]: + return { + "short": self.short_revealed, + 
"bullets": self.bullets_revealed, + "lang": self.lang_revealed, + } + + def __str__(self) -> str: + flags = [] + if self.short_revealed: + flags.append("short") + if self.bullets_revealed: + flags.append("bullets") + if self.lang_revealed: + flags.append("lang") + return f"RevealState({', '.join(flags) if flags else 'none'})" + + +class RevealStateManager: + """Manages reveal state for multiple users.""" + + def __init__(self): + self._states: Dict[str, RevealState] = {} + + def get_state(self, user_id: str) -> RevealState: + if user_id not in self._states: + self._states[user_id] = RevealState() + return self._states[user_id] + + def reset_user(self, user_id: str): + if user_id in self._states: + self._states[user_id].reset() + else: + self._states[user_id] = RevealState() + + def reset_session(self, user_id: str): + pass # Reveal state persists across sessions + + +# ============================================================================= +# Preference Detection +# ============================================================================= + +def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]: + """ + Detect which preferences are mentioned in a query. + Also considers the user's true preferences for language detection. + """ + lower_q = (query or "").lower() + + revealed = { + "short": False, + "bullets": False, + "lang": False, + } + + # Short/length preference + short_patterns = [ + "short", "concise", "brief", "under ", "less than", + "keep it short", "keep responses", "keep answers", + "maximum ", "max ", "characters", "words or less", + "200 ", "100 ", "50 ", "300 ", + ] + for pattern in short_patterns: + if pattern in lower_q: + revealed["short"] = True + break + + # Bullet preference (both positive and negative) + bullet_patterns = [ + "bullet", "bullet point", "bullet-point", + "bulleted", "list format", "use bullets", + "no bullet", "don't use bullet", "without bullet", + "numbered list", "use numbers", + ] + for pattern in bullet_patterns: + if pattern in lower_q: + revealed["bullets"] = True + break + + # Language preference + lang_patterns_zh = [ + "chinese", "中文", "in chinese", "用中文", + "speak chinese", "write chinese", "respond in chinese", + "please use chinese", "mandarin", "请用中文", + ] + lang_patterns_en = [ + "english", "in english", "use english", + "speak english", "write english", "respond in english", + "please use english", + ] + + for pattern in lang_patterns_zh + lang_patterns_en: + if pattern in lower_q: + revealed["lang"] = True + break + + return revealed + + +def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]: + """Update reveal state based on query content.""" + detected = detect_revealed_preferences(query, prefs) + newly_revealed = set() + + if detected["short"] and not reveal_state.short_revealed: + reveal_state.short_revealed = True + newly_revealed.add("short") + + if detected["bullets"] and not reveal_state.bullets_revealed: + reveal_state.bullets_revealed = True + newly_revealed.add("bullets") + + if detected["lang"] and not reveal_state.lang_revealed: + reveal_state.lang_revealed = True + newly_revealed.add("lang") + + return newly_revealed + + +# ============================================================================= +# Refined Style Judge +# ============================================================================= + +@dataclass +class JudgeResult: + """Output from the judge for one turn.""" + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + 
enforced_constraints: List[str] + + +def style_judge_v3( + query: str, + answer: str, + task_type: str, + prefs: StylePrefs, + reveal_state: RevealState, +) -> JudgeResult: + """ + Refined style judge with: + - Bullets only enforced on list-type tasks + - Relaxed empty_answer (only truly empty or single char) + - Reveal-aware enforcement + """ + violations: List[str] = [] + enforced: List[str] = [] + text = (answer or "").strip() + + # 0) Empty answer - only truly empty or single non-meaningful char + # Relaxed: allow short factual answers like "4", "Paris" + if len(text) == 0: + violations.append("empty_answer") + return JudgeResult( + sat_t=0.0, + sev_t=1.0, + prog_t=0.0, + violations=violations, + enforced_constraints=["non_empty"], + ) + + # 1) Length - enforce only if BOTH true AND revealed + if prefs.require_short and reveal_state.short_revealed: + enforced.append("short") + if len(text) > prefs.max_chars: + violations.append("too_long") + + # 2) Bullets - enforce ONLY on list-type tasks AND if revealed + # task_type "list" = listing tasks (Name X things, What are the N...) + # task_type "qa" = factual QA (What is the capital...) + # task_type "general" = other general tasks + if prefs.require_bullets and reveal_state.bullets_revealed: + if task_type == "list": # Only enforce on list tasks + enforced.append("bullets") + has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text) + if not has_bullets: + violations.append("no_bullets") + + # 3) Language - enforce only if revealed + if reveal_state.lang_revealed: + enforced.append("lang") + if prefs.lang == "zh": + # For Chinese: should have significant non-ASCII content + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio > 0.7: + violations.append("wrong_lang") + elif prefs.lang == "en": + # For English: should be mostly ASCII + ascii_count = sum(c.isascii() for c in text) + ascii_ratio = ascii_count / max(1, len(text)) + if ascii_ratio < 0.5: + violations.append("wrong_lang") + + # 4) Code task: always enforce + prog_t = 1.0 + if task_type == "code": + enforced.append("code_block") + has_code = ("```" in text) or ("def " in text) or ("function " in text) + if not has_code: + violations.append("no_code_block") + prog_t = 0.0 + + # 5) Compute scores + if not violations: + sat_t = 1.0 + sev_t = 0.0 + else: + sat_t = max(0.0, 1.0 - 0.3 * float(len(violations))) + hard_violations = {"empty_answer", "too_long", "wrong_lang"} + sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0 + + return JudgeResult( + sat_t=sat_t, + sev_t=sev_t, + prog_t=prog_t, + violations=violations, + enforced_constraints=enforced, + ) + + +# ============================================================================= +# Feedback Computation +# ============================================================================= + +def compute_feedback_for_turn( + turn_id: int, + query: str, + query_type: str, + task_type: str, + judge_result: JudgeResult, +) -> Tuple[float, float]: + """Convert JudgeResult into (reward, gating).""" + reward = judge_result.sat_t + + lower_q = (query or "").lower() + + is_pref_turn = ( + query_type == "preference" + or "i prefer" in lower_q + or "my preference" in lower_q + or "please use" in lower_q + or "please keep" in lower_q + or "you didn't follow" in lower_q + or "you forgot" in lower_q + or "remember that i" in lower_q + or "i told you" in lower_q + or "i asked for" in lower_q + or "中文" in lower_q + or "用中文" in lower_q + ) + + gating 
= 1.0 if is_pref_turn else 0.0 + return reward, gating + + +# ============================================================================= +# Query Generation per Persona +# ============================================================================= + +def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]: + """ + Session 1: Reveal preferences. + Customize based on persona's true preferences. + """ + queries = [] + prefs = persona.style_prefs + + # Turn 0: Reveal length preference + if prefs.require_short: + if prefs.lang == "zh": + queries.append({ + "query": "我喜欢简短的回答,请保持回复在200字以内。", + "type": "preference", + "task_type": "general", + }) + else: + queries.append({ + "query": "I prefer short, concise answers. Please keep responses under 200 characters.", + "type": "preference", + "task_type": "general", + }) + else: + # Long preference - don't reveal (let short_revealed stay False) + if prefs.lang == "zh": + queries.append({ + "query": "你好,我想了解一些问题。", + "type": "task", + "task_type": "general", + }) + else: + queries.append({ + "query": "Hello, I have some questions for you.", + "type": "task", + "task_type": "general", + }) + + # Turn 1: First task + if prefs.lang == "zh": + queries.append({ + "query": "列出三个改善睡眠的建议。", + "type": "task", + "task_type": "list", + }) + else: + queries.append({ + "query": "List three tips for better sleep.", + "type": "task", + "task_type": "list", + }) + + # Turn 2: Reveal bullet preference + if prefs.require_bullets: + if prefs.lang == "zh": + queries.append({ + "query": "我喜欢用项目符号列出要点,请使用bullet points。", + "type": "preference", + "task_type": "general", + }) + else: + queries.append({ + "query": "I prefer bullet points when listing things. Please use bullet points.", + "type": "preference", + "task_type": "general", + }) + else: + # Don't reveal bullet preference (or reveal anti-bullet) + if prefs.lang == "zh": + queries.append({ + "query": "请不要用项目符号,我更喜欢连续的句子。", + "type": "preference", + "task_type": "general", + }) + else: + queries.append({ + "query": "Please don't use bullet points. I prefer continuous prose.", + "type": "preference", + "task_type": "general", + }) + + # Turn 3: Reveal language preference (for non-English personas) + if prefs.lang == "zh": + queries.append({ + "query": "请用中文回答我的问题。", + "type": "preference", + "task_type": "general", + }) + else: + queries.append({ + "query": "Please respond in English.", + "type": "preference", + "task_type": "general", + }) + + # Turn 4-5: Tasks + if prefs.lang == "zh": + queries.extend([ + { + "query": "锻炼有什么好处?", + "type": "task", + "task_type": "list", + }, + { + "query": "列出五种流行的编程语言。", + "type": "task", + "task_type": "list", + }, + ]) + else: + queries.extend([ + { + "query": "What are the benefits of exercise?", + "type": "task", + "task_type": "list", + }, + { + "query": "Name five popular programming languages.", + "type": "task", + "task_type": "list", + }, + ]) + + return queries + + +def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]: + """ + Session 2: NO preference restatement. + Tests cross-session retention. 
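+ Note: reveal state persists across sessions (reset_session is a no-op), so the judge keeps enforcing everything revealed in Session 1; a drop in sat_t here therefore isolates cross-session retention failures.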
+ """ + prefs = persona.style_prefs + + if prefs.lang == "zh": + return [ + {"query": "推荐三种健康的早餐。", "type": "task", "task_type": "list"}, + {"query": "一年有哪四个季节?", "type": "task", "task_type": "list"}, + {"query": "法国的首都是哪里?", "type": "task", "task_type": "qa"}, + {"query": "列出三种可再生能源。", "type": "task", "task_type": "list"}, + ] + else: + return [ + {"query": "What are three healthy breakfast ideas?", "type": "task", "task_type": "list"}, + {"query": "What are the four seasons of the year?", "type": "task", "task_type": "list"}, + {"query": "What is the capital of France?", "type": "task", "task_type": "qa"}, + {"query": "Name three types of renewable energy.", "type": "task", "task_type": "list"}, + ] + + +def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]: + """ + Session 3: Mix of tasks and one reminder. + """ + prefs = persona.style_prefs + + if prefs.lang == "zh": + return [ + {"query": "列出五种常见的水果。", "type": "task", "task_type": "list"}, + {"query": "请记住我喜欢简短的回答。列出三种海洋动物。", "type": "preference", "task_type": "list"}, + {"query": "2加2等于多少?", "type": "task", "task_type": "qa"}, + ] + else: + return [ + {"query": "Name five common fruits.", "type": "task", "task_type": "list"}, + {"query": "Remember that I asked for short answers. List three ocean animals.", "type": "preference", "task_type": "list"}, + {"query": "What is 2 + 2?", "type": "task", "task_type": "qa"}, + ] + + +# ============================================================================= +# Logging +# ============================================================================= + +@dataclass +class TurnLog: + """Log entry for one turn.""" + user_id: str + persona_id: str + session_id: int + turn_id: int + query: str + query_type: str + task_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + reward: float + gating: float + reveal_state_before: Dict[str, bool] + reveal_state_after: Dict[str, bool] + newly_revealed: List[str] + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + num_memories_retrieved: int + num_prefs_extracted: int + + +def log_to_jsonl(logs: List[TurnLog], filepath: str): + """Save logs to JSONL file.""" + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as f: + for log in logs: + f.write(json.dumps(asdict(log)) + "\n") + + +# ============================================================================= +# Session Runner +# ============================================================================= + +def run_session( + llm: PersonalizedLLM, + user_id: str, + persona: Persona, + session_id: int, + reveal_state: RevealState, + queries: List[Dict[str, Any]], + all_logs: List[TurnLog], +) -> List[TurnLog]: + """Run a single session for a user.""" + + prefs = persona.style_prefs + session_logs: List[TurnLog] = [] + + print(f"\n{'='*60}") + print(f"[{persona.persona_id}] Session {session_id}: {len(queries)} turns") + print(f"Reveal state (start): {reveal_state}") + print(f"{'='*60}") + + # Reset session (clears history, z_short; keeps z_long and reveal state) + llm.reset_session(user_id) + + for turn_id, q_info in enumerate(queries): + query = q_info["query"] + query_type = q_info.get("type", "task") + task_type = q_info.get("task_type", "general") + + print(f"\n--- S{session_id}/T{turn_id} [{query_type}] ---") + print(f"[Q] 
{query[:60]}{'...' if len(query) > 60 else ''}") + + # Capture reveal state BEFORE + reveal_before = reveal_state.to_dict() + + # Update reveal state + newly_revealed = update_reveal_state(reveal_state, query, prefs) + if newly_revealed: + print(f"[Reveal] Newly: {newly_revealed}") + + reveal_after = reveal_state.to_dict() + + # Get user state before + state_before = llm.get_user_state_summary(user_id) + z_long_before = state_before["z_long_norm"] + z_short_before = state_before["z_short_norm"] + + # Apply feedback for previous turn + if turn_id > 0 and session_logs: + prev_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=prev_log.turn_id, + reward=prev_log.reward, + gating=prev_log.gating, + meta={"source": "pilot_v3", "session_id": session_id} + ) + llm.apply_feedback(feedback) + + # Chat + resp: AssistantResponse = llm.chat(user_id, query) + + answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer + print(f"[A] ({len(resp.answer)}c) {answer_display}") + + # Judge + judge_result = style_judge_v3(query, resp.answer, task_type, prefs, reveal_state) + print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}") + + # Compute feedback + reward, gating = compute_feedback_for_turn(turn_id, query, query_type, task_type, judge_result) + + # Get state after + state_after = llm.get_user_state_summary(user_id) + z_long_after = state_after["z_long_norm"] + z_short_after = state_after["z_short_norm"] + + # Debug info + num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0 + num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0 + + # Log + log = TurnLog( + user_id=user_id, + persona_id=persona.persona_id, + session_id=session_id, + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + answer=resp.answer, + answer_length=len(resp.answer), + sat_t=judge_result.sat_t, + sev_t=judge_result.sev_t, + prog_t=judge_result.prog_t, + violations=judge_result.violations, + enforced_constraints=judge_result.enforced_constraints, + reward=reward, + gating=gating, + reveal_state_before=reveal_before, + reveal_state_after=reveal_after, + newly_revealed=list(newly_revealed), + z_long_norm_before=z_long_before, + z_long_norm_after=z_long_after, + z_short_norm_before=z_short_before, + z_short_norm_after=z_short_after, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + total_tokens=resp.usage.total_tokens, + num_memories_retrieved=num_memories, + num_prefs_extracted=num_prefs, + ) + session_logs.append(log) + all_logs.append(log) + + # Apply final feedback + if session_logs: + last_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=last_log.turn_id, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v3", "session_id": session_id, "final": True} + ) + llm.apply_feedback(feedback) + + return session_logs + + +# ============================================================================= +# Multi-User Multi-Session Runner +# ============================================================================= + +def run_multi_user_pilot( + llm: PersonalizedLLM, + personas: List[Persona], + num_sessions: int = 3, + reveal_manager: Optional[RevealStateManager] = None, +) -> List[TurnLog]: + """ + Run multi-user multi-session pilot. 
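+ Each persona maps to its own user_id and is fully reset (memories and reveal state) before Session 1; reveal state then persists across that user's later sessions.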
+ + Args: + llm: PersonalizedLLM instance + personas: List of personas to test + num_sessions: Number of sessions per user + reveal_manager: Optional existing reveal manager + """ + if reveal_manager is None: + reveal_manager = RevealStateManager() + + all_logs: List[TurnLog] = [] + + print(f"\n{'#'*60}") + print(f"PILOT v3: MULTI-USER MULTI-SESSION") + print(f"Users: {len(personas)}, Sessions per user: {num_sessions}") + print(f"{'#'*60}") + + for persona in personas: + user_id = f"user_{persona.persona_id}" + prefs = persona.style_prefs + + print(f"\n{'*'*60}") + print(f"USER: {user_id}") + print(f"Persona: {persona.description}") + print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}") + print(f"{'*'*60}") + + # Reset user completely + llm.reset_user(user_id) + reveal_manager.reset_user(user_id) + reveal_state = reveal_manager.get_state(user_id) + + # Run sessions + for session_id in range(1, num_sessions + 1): + if session_id == 1: + queries = get_session_1_queries_for_persona(persona) + elif session_id == 2: + queries = get_session_2_queries_for_persona(persona) + else: + queries = get_session_3_queries_for_persona(persona) + + reveal_manager.reset_session(user_id) # No-op, just for clarity + run_session(llm, user_id, persona, session_id, reveal_state, queries, all_logs) + + return all_logs + + +# ============================================================================= +# Summary +# ============================================================================= + +def print_summary_v3(logs: List[TurnLog]): + """Print summary for pilot v3.""" + print(f"\n{'='*60}") + print("PILOT v3 SUMMARY - Multi-User Multi-Session") + print(f"{'='*60}") + + if not logs: + print("No logs.") + return + + from collections import Counter, defaultdict + + # Per-persona stats + personas = sorted(set(l.persona_id for l in logs)) + + print(f"\n--- Per-Persona Statistics ---") + for pid in personas: + p_logs = [l for l in logs if l.persona_id == pid] + + # Per-session breakdown + sessions = sorted(set(l.session_id for l in p_logs)) + + print(f"\n{pid}:") + for sid in sessions: + s_logs = [l for l in p_logs if l.session_id == sid] + avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0 + violations = [v for l in s_logs for v in l.violations] + enforced = set(c for l in s_logs for c in l.enforced_constraints) + + print(f" Session {sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, enforced={enforced}") + if violations: + print(f" violations: {dict(Counter(violations))}") + + # Cross-session retention check + print(f"\n--- Cross-Session Retention ---") + for pid in personas: + p_logs = [l for l in logs if l.persona_id == pid] + s1_logs = [l for l in p_logs if l.session_id == 1] + s2_logs = [l for l in p_logs if l.session_id == 2] + + if s1_logs and s2_logs: + s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs) + s2_sat = sum(l.sat_t for l in s2_logs) / len(s2_logs) + + # Check what was enforced in S2 + s2_enforced = set(c for l in s2_logs for c in l.enforced_constraints) + + print(f"{pid}: S1_sat={s1_sat:.3f} → S2_sat={s2_sat:.3f}, S2_enforced={s2_enforced}") + + # Overall stats + total = len(logs) + avg_sat = sum(l.sat_t for l in logs) / total + total_tokens = sum(l.total_tokens for l in logs) + + print(f"\n--- Overall ---") + print(f"Total turns: {total}") + print(f"Overall avg sat_t: {avg_sat:.3f}") + print(f"Total tokens: {total_tokens}") + + # Violations by type + all_violations = [v for l in logs for v in l.violations] + if all_violations: 
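+ # e.g. {"too_long": 3, "no_bullets": 1}, aggregated over all personas and sessions (illustrative counts)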
+ print(f"\nViolations: {dict(Counter(all_violations))}") + + +def main(): + print("=" * 60) + print("PILOT RUNNER v3 - Multi-User Multi-Session with Personas") + print("=" * 60) + print(f"Started at: {datetime.now().isoformat()}") + + # Select personas + personas = ALL_PERSONAS # All 5 personas + print(f"\n[Config] Running {len(personas)} personas:") + for p in personas: + print(f" - {p.persona_id}: {p.description}") + + # Initialize LLM + print("\n[Init] Loading PersonalizedLLM...") + llm = PersonalizedLLM( + user_store_path="data/users/user_store_pilot_v3.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=True, + ) + + # Run pilot + logs = run_multi_user_pilot(llm, personas, num_sessions=3) + + # Summary + print_summary_v3(logs) + + # Save logs + log_path = f"data/logs/pilot_v3_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" + log_to_jsonl(logs, log_path) + print(f"\n[Logs] Saved to: {log_path}") + + print(f"\nCompleted at: {datetime.now().isoformat()}") + print("=" * 60) + + +if __name__ == "__main__": + main() + diff --git a/scripts/pilot_runner_v4.py b/scripts/pilot_runner_v4.py new file mode 100644 index 0000000..b3e2058 --- /dev/null +++ b/scripts/pilot_runner_v4.py @@ -0,0 +1,1230 @@ +#!/usr/bin/env python3 +""" +Pilot Runner v4 - Critical Fixes for Baseline Comparison + +Fixes from v3: +1. Chinese short reveal detection (简短/字以内/不超过 etc.) +2. Symmetric bullets constraint (has_bullets violation for require_bullets=False) +3. Better wrong_lang with CJK ratio + math exemption +4. Persona-conditional query templates (no self-contradiction) +5. Violation-triggered complaint mechanism (for online RL signal) + +This version is ready for proper baseline comparison. +""" + +import sys +import os +import re +import json +from datetime import datetime +from dataclasses import dataclass, asdict, field +from typing import List, Dict, Any, Optional, Tuple, Set, Literal + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse + + +# ============================================================================= +# Style Preferences (True Preferences) +# ============================================================================= + +@dataclass +class StylePrefs: + """User's TRUE style preferences.""" + require_short: bool = False + max_chars: int = 300 + require_bullets: bool = False # True = want bullets, False = don't want bullets + lang: str = "en" # "en" or "zh" + + +# ============================================================================= +# Persona Definition +# ============================================================================= + +@dataclass +class Persona: + """A user persona that bundles style preferences.""" + persona_id: str + style_prefs: StylePrefs + description: str = "" + + +# 5 Test Personas +PERSONA_A = Persona( + persona_id="A_short_bullets_en", + style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"), + description="Short + bullets + English", +) + +PERSONA_B = Persona( + persona_id="B_short_no_bullets_en", + style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"), + description="Short + NO bullets + English (anti-bullet)", +) + +PERSONA_C = Persona( + persona_id="C_long_bullets_en", + style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"), + description="Long + bullets + English", +) + +PERSONA_D = 
Persona( + persona_id="D_short_bullets_zh", + style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"), + description="Short + bullets + Chinese", +) + +PERSONA_E = Persona( + persona_id="E_long_no_bullets_zh", + style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"), + description="Long + NO bullets + Chinese (most anti-default)", +) + +# Extreme short persona for case study - LLM default is much longer +PERSONA_F = Persona( + persona_id="F_extreme_short_en", + style_prefs=StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"), + description="EXTREME short (100 chars) + bullets + English", +) + +ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E, PERSONA_F] + + +# ============================================================================= +# Reveal State +# ============================================================================= + +@dataclass +class RevealState: + """Tracks which preferences have been explicitly revealed.""" + short_revealed: bool = False + bullets_revealed: bool = False + lang_revealed: bool = False + + def reset(self): + self.short_revealed = False + self.bullets_revealed = False + self.lang_revealed = False + + def to_dict(self) -> Dict[str, bool]: + return {"short": self.short_revealed, "bullets": self.bullets_revealed, "lang": self.lang_revealed} + + def __str__(self) -> str: + flags = [k for k, v in self.to_dict().items() if v] + return f"RevealState({', '.join(flags) if flags else 'none'})" + + +class RevealStateManager: + """Manages reveal state for multiple users.""" + def __init__(self): + self._states: Dict[str, RevealState] = {} + + def get_state(self, user_id: str) -> RevealState: + if user_id not in self._states: + self._states[user_id] = RevealState() + return self._states[user_id] + + def reset_user(self, user_id: str): + self._states[user_id] = RevealState() + + def reset_session(self, user_id: str): + pass # Reveal state persists + + +# ============================================================================= +# FIX 1: Improved Preference Detection (with Chinese support) +# ============================================================================= + +def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]: + """ + Detect which preferences are mentioned in a query. + FIX: Added Chinese keywords for short detection. 
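+ Illustrative example (given the patterns below): + detect_revealed_preferences("请保持回复在200字以内。", prefs) + -> {"short": True, "bullets": False, "lang": False}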
+ """ + lower_q = (query or "").lower() + original_q = query or "" + + revealed = {"short": False, "bullets": False, "lang": False} + + # Short/length preference - English patterns + short_patterns_en = [ + "short", "concise", "brief", "under ", "less than", + "keep it short", "keep responses", "keep answers", + "maximum ", "max ", "characters", "words or less", + ] + + # FIX: Chinese patterns for short preference + short_patterns_zh = [ + "简短", "精简", "尽量短", "不要太长", "字以内", "不超过", + "少于", "控制在", "简洁", "简明", + ] + + # Regex patterns for number-based length constraints + short_regex_patterns = [ + r"(\d+)\s*字以内", # "200字以内" + r"不超过\s*(\d+)\s*字", # "不超过200字" + r"under\s*(\d+)", # "under 200" + r"less\s*than\s*(\d+)", # "less than 200" + ] + + for pattern in short_patterns_en: + if pattern in lower_q: + revealed["short"] = True + break + + if not revealed["short"]: + for pattern in short_patterns_zh: + if pattern in original_q: + revealed["short"] = True + break + + if not revealed["short"]: + for regex in short_regex_patterns: + if re.search(regex, original_q, re.IGNORECASE): + revealed["short"] = True + break + + # Bullet preference - both positive and negative + bullet_patterns_positive = [ + "bullet", "bullet point", "bullet-point", "bulleted", + "list format", "use bullets", "use bullet", + "项目符号", "要点", "用bullet", + ] + bullet_patterns_negative = [ + "no bullet", "don't use bullet", "without bullet", + "不要bullet", "不要项目符号", "不用bullet", + "continuous prose", "paragraph form", "flowing text", + "连续句子", "段落形式", + ] + + for pattern in bullet_patterns_positive + bullet_patterns_negative: + if pattern in lower_q or pattern in original_q: + revealed["bullets"] = True + break + + # Language preference + lang_patterns_zh = [ + "chinese", "中文", "in chinese", "用中文", + "speak chinese", "respond in chinese", "请用中文", + ] + lang_patterns_en = [ + "english", "in english", "use english", + "speak english", "respond in english", + ] + + for pattern in lang_patterns_zh + lang_patterns_en: + if pattern in lower_q or pattern in original_q: + revealed["lang"] = True + break + + return revealed + + +def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]: + """Update reveal state based on query content.""" + detected = detect_revealed_preferences(query, prefs) + newly_revealed = set() + + if detected["short"] and not reveal_state.short_revealed: + reveal_state.short_revealed = True + newly_revealed.add("short") + + if detected["bullets"] and not reveal_state.bullets_revealed: + reveal_state.bullets_revealed = True + newly_revealed.add("bullets") + + if detected["lang"] and not reveal_state.lang_revealed: + reveal_state.lang_revealed = True + newly_revealed.add("lang") + + return newly_revealed + + +# ============================================================================= +# FIX 3: Better Language Detection +# ============================================================================= + +def is_math_or_symbol_only(text: str) -> bool: + """Check if text is purely math/symbols (language neutral).""" + # Pattern: only digits, operators, whitespace, punctuation + math_pattern = r'^[\d+\-*/=().,%\s\n\r]+$' + return bool(re.match(math_pattern, text.strip())) + + +def count_cjk_chars(text: str) -> int: + """Count CJK (Chinese/Japanese/Korean) characters.""" + # CJK Unified Ideographs range + cjk_pattern = re.compile(r'[\u4e00-\u9fff\u3400-\u4dbf]') + return len(cjk_pattern.findall(text)) + + +def count_latin_letters(text: str) -> int: + """Count Latin letters (a-z, A-Z).""" 
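+ # Note: isascii() excludes accented Latin letters, which is acceptable for the coarse en-vs-zh split needed here.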
+ return sum(1 for c in text if c.isalpha() and c.isascii()) + + +def check_language_violation(text: str, target_lang: str) -> bool: + """ + FIX: Better language violation check using CJK ratio. + Returns True if there's a violation. + """ + text = text.strip() + + # Exempt pure math/symbols + if is_math_or_symbol_only(text): + return False + + cjk_count = count_cjk_chars(text) + latin_count = count_latin_letters(text) + total = cjk_count + latin_count + + if total == 0: + return False # No meaningful text to judge + + if target_lang == "zh": + # For Chinese: want high CJK ratio + cjk_ratio = cjk_count / (total + 1e-9) + # Allow some English proper nouns - only flag if very low CJK + return cjk_ratio < 0.2 # Less than 20% CJK = wrong language + + elif target_lang == "en": + # For English: want high Latin ratio + latin_ratio = latin_count / (total + 1e-9) + return latin_ratio < 0.5 # Less than 50% Latin = wrong language + + return False + + +# ============================================================================= +# FIX 2: Symmetric Bullets Constraint + FIX 3: Language +# ============================================================================= + +@dataclass +class JudgeResult: + """Output from the judge for one turn.""" + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + + +def has_bullet_markers(text: str) -> bool: + """Check if text contains bullet point markers.""" + return bool(re.search(r'(^|\n)\s*[-•*]\s', text)) + + +def style_judge_v4( + query: str, + answer: str, + task_type: str, + prefs: StylePrefs, + reveal_state: RevealState, +) -> JudgeResult: + """ + Style judge v4 with: + - FIX 2: Symmetric bullets (has_bullets violation for require_bullets=False) + - FIX 3: Better wrong_lang with CJK ratio + math exemption + """ + violations: List[str] = [] + enforced: List[str] = [] + text = (answer or "").strip() + + # 0) Empty answer + if len(text) == 0: + violations.append("empty_answer") + return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0, + violations=violations, enforced_constraints=["non_empty"]) + + # 1) Length - enforce only if true AND revealed + if prefs.require_short and reveal_state.short_revealed: + enforced.append("short") + if len(text) > prefs.max_chars: + violations.append("too_long") + + # 2) FIX 2: Symmetric bullets constraint (only for list tasks) + if reveal_state.bullets_revealed and task_type == "list": + has_bullets = has_bullet_markers(text) + + if prefs.require_bullets: + # Want bullets but don't have them + enforced.append("require_bullets") + if not has_bullets: + violations.append("no_bullets") + else: + # Don't want bullets but have them + enforced.append("no_bullets_pref") + if has_bullets: + violations.append("has_bullets") + + # 3) FIX 3: Language with CJK ratio + math exemption + if reveal_state.lang_revealed: + enforced.append("lang") + if check_language_violation(text, prefs.lang): + violations.append("wrong_lang") + + # 4) Code task + prog_t = 1.0 + if task_type == "code": + enforced.append("code_block") + has_code = ("```" in text) or ("def " in text) or ("function " in text) + if not has_code: + violations.append("no_code_block") + prog_t = 0.0 + + # 5) Compute scores + if not violations: + sat_t = 1.0 + sev_t = 0.0 + else: + sat_t = max(0.0, 1.0 - 0.3 * float(len(violations))) + hard_violations = {"empty_answer", "too_long", "wrong_lang"} + sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0 + + return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t, + 
violations=violations, enforced_constraints=enforced) + + +# ============================================================================= +# Feedback Computation +# ============================================================================= + +def compute_feedback_for_turn( + query: str, + query_type: str, + judge_result: JudgeResult, +) -> Tuple[float, float]: + """Convert JudgeResult into (reward, gating).""" + reward = judge_result.sat_t + + lower_q = (query or "").lower() + original_q = query or "" + + is_pref_turn = ( + query_type == "preference" + or "i prefer" in lower_q + or "my preference" in lower_q + or "please use" in lower_q + or "please keep" in lower_q + or "you didn't follow" in lower_q + or "you forgot" in lower_q + or "remember that i" in lower_q + or "i told you" in lower_q + or "i asked for" in lower_q + or "that was too" in lower_q + or "too long" in lower_q + or "请用中文" in original_q + or "不要" in original_q + or "简短" in original_q + ) + + gating = 1.0 if is_pref_turn else 0.0 + return reward, gating + + +# ============================================================================= +# FIX 5: Violation-Triggered Complaint Generation +# ============================================================================= + +def generate_complaint_query(violations: List[str], prefs: StylePrefs) -> Optional[Dict[str, Any]]: + """ + Generate a complaint query based on violations. + Returns None if no complaint needed. + """ + if not violations: + return None + + # Priority: address most severe violation first + complaint = None + + if "too_long" in violations: + if prefs.lang == "zh": + complaint = { + "query": f"回答太长了。请保持回复在{prefs.max_chars}字以内。", + "type": "preference", + "task_type": "general", + } + else: + complaint = { + "query": f"That was too long. Please keep responses under {prefs.max_chars} characters.", + "type": "preference", + "task_type": "general", + } + + elif "wrong_lang" in violations: + if prefs.lang == "zh": + complaint = { + "query": "请用中文回答。", + "type": "preference", + "task_type": "general", + } + else: + complaint = { + "query": "Please respond in English.", + "type": "preference", + "task_type": "general", + } + + elif "no_bullets" in violations: + if prefs.lang == "zh": + complaint = { + "query": "请在列出内容时使用项目符号(bullet points)。", + "type": "preference", + "task_type": "general", + } + else: + complaint = { + "query": "Please use bullet points when listing things.", + "type": "preference", + "task_type": "general", + } + + elif "has_bullets" in violations: + if prefs.lang == "zh": + complaint = { + "query": "请不要使用项目符号,用连续的句子来表达。", + "type": "preference", + "task_type": "general", + } + else: + complaint = { + "query": "Please don't use bullet points. Use continuous prose instead.", + "type": "preference", + "task_type": "general", + } + + return complaint + + +# ============================================================================= +# FIX 4: Persona-Conditional Query Templates +# ============================================================================= + +def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]: + """ + Session 1: Reveal preferences (persona-conditional). + FIX: Only reveal preferences that match the persona's true prefs. 
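+ Turn plan: T0 length preference (or a neutral greeting), T1 list task, T2 bullets (positive or explicit anti-bullet), T3 language, T4-T5 list tasks.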
+ """ + queries = [] + prefs = persona.style_prefs + + # Turn 0: Reveal length preference (only if require_short=True) + if prefs.require_short: + if prefs.lang == "zh": + queries.append({ + "query": f"我喜欢简短的回答,请保持回复在{prefs.max_chars}字以内。", + "type": "preference", + "task_type": "general", + }) + else: + queries.append({ + "query": f"I prefer short, concise answers. Please keep responses under {prefs.max_chars} characters.", + "type": "preference", + "task_type": "general", + }) + else: + # Don't reveal short preference for long-preferring personas + if prefs.lang == "zh": + queries.append({ + "query": "你好,我有一些问题想问你。", + "type": "task", + "task_type": "general", + }) + else: + queries.append({ + "query": "Hello, I have some questions for you.", + "type": "task", + "task_type": "general", + }) + + # Turn 1: First task + if prefs.lang == "zh": + queries.append({"query": "列出三个改善睡眠的建议。", "type": "task", "task_type": "list"}) + else: + queries.append({"query": "List three tips for better sleep.", "type": "task", "task_type": "list"}) + + # Turn 2: Reveal bullet preference (conditional on require_bullets) + if prefs.require_bullets: + if prefs.lang == "zh": + queries.append({ + "query": "我喜欢用项目符号列出要点,请使用bullet points。", + "type": "preference", + "task_type": "general", + }) + else: + queries.append({ + "query": "I prefer bullet points when listing things. Please use bullet points.", + "type": "preference", + "task_type": "general", + }) + else: + # Explicitly say NO bullets + if prefs.lang == "zh": + queries.append({ + "query": "请不要用项目符号,我更喜欢连续的句子来表达。", + "type": "preference", + "task_type": "general", + }) + else: + queries.append({ + "query": "Please don't use bullet points. I prefer continuous prose.", + "type": "preference", + "task_type": "general", + }) + + # Turn 3: Reveal language preference + if prefs.lang == "zh": + queries.append({"query": "请用中文回答我的问题。", "type": "preference", "task_type": "general"}) + else: + queries.append({"query": "Please respond in English.", "type": "preference", "task_type": "general"}) + + # Turn 4-5: Tasks + if prefs.lang == "zh": + queries.extend([ + {"query": "锻炼有什么好处?", "type": "task", "task_type": "list"}, + {"query": "列出五种流行的编程语言。", "type": "task", "task_type": "list"}, + ]) + else: + queries.extend([ + {"query": "What are the benefits of exercise?", "type": "task", "task_type": "list"}, + {"query": "Name five popular programming languages.", "type": "task", "task_type": "list"}, + ]) + + return queries + + +def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]: + """Session 2: NO preference restatement.""" + prefs = persona.style_prefs + + if prefs.lang == "zh": + return [ + {"query": "推荐三种健康的早餐。", "type": "task", "task_type": "list"}, + {"query": "一年有哪四个季节?", "type": "task", "task_type": "list"}, + {"query": "法国的首都是哪里?", "type": "task", "task_type": "qa"}, + {"query": "列出三种可再生能源。", "type": "task", "task_type": "list"}, + ] + else: + return [ + {"query": "What are three healthy breakfast ideas?", "type": "task", "task_type": "list"}, + {"query": "What are the four seasons of the year?", "type": "task", "task_type": "list"}, + {"query": "What is the capital of France?", "type": "task", "task_type": "qa"}, + {"query": "Name three types of renewable energy.", "type": "task", "task_type": "list"}, + ] + + +def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]: + """ + Session 3: Tasks with ONE persona-conditional reminder. + FIX: Reminder matches persona's actual preferences. 
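+ The reminder enumerates all four (require_short, require_bullets) combinations, so no persona is ever reminded of a preference it does not actually hold.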
+ """ + prefs = persona.style_prefs + queries = [] + + # First task + if prefs.lang == "zh": + queries.append({"query": "列出五种常见的水果。", "type": "task", "task_type": "list"}) + else: + queries.append({"query": "Name five common fruits.", "type": "task", "task_type": "list"}) + + # Persona-conditional reminder + if prefs.require_short and prefs.require_bullets: + if prefs.lang == "zh": + queries.append({ + "query": "记住我喜欢简短的回答和项目符号。列出三种海洋动物。", + "type": "preference", "task_type": "list" + }) + else: + queries.append({ + "query": "Remember I prefer short answers with bullet points. List three ocean animals.", + "type": "preference", "task_type": "list" + }) + elif prefs.require_short and not prefs.require_bullets: + if prefs.lang == "zh": + queries.append({ + "query": "记住我喜欢简短的回答,不要用项目符号。列出三种海洋动物。", + "type": "preference", "task_type": "list" + }) + else: + queries.append({ + "query": "Remember I prefer short answers without bullet points. List three ocean animals.", + "type": "preference", "task_type": "list" + }) + elif not prefs.require_short and prefs.require_bullets: + if prefs.lang == "zh": + queries.append({ + "query": "记住我喜欢用项目符号列出要点。列出三种海洋动物。", + "type": "preference", "task_type": "list" + }) + else: + queries.append({ + "query": "Remember I prefer bullet points. List three ocean animals.", + "type": "preference", "task_type": "list" + }) + else: # not short and not bullets + if prefs.lang == "zh": + queries.append({ + "query": "记住我不喜欢用项目符号,喜欢连续的句子。列出三种海洋动物。", + "type": "preference", "task_type": "list" + }) + else: + queries.append({ + "query": "Remember I prefer continuous prose without bullet points. List three ocean animals.", + "type": "preference", "task_type": "list" + }) + + # Final task + if prefs.lang == "zh": + queries.append({"query": "2加2等于多少?", "type": "task", "task_type": "qa"}) + else: + queries.append({"query": "What is 2 + 2?", "type": "task", "task_type": "qa"}) + + return queries + + +def get_pure_task_queries_for_persona(persona: Persona, session_idx: int) -> List[Dict[str, Any]]: + """ + Pure task sessions (S4+): NO preference reminders at all. + Used for testing long-term retention without any in-context hints. + Different task sets per session to avoid repetition. 
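+ Rotation example: session_idx=5 for an "en" persona selects en_task_pools[(5 - 4) % 4] == en_task_pools[1] (the indoor plants / sports / AI pool).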
+ """ + prefs = persona.style_prefs + + # Task pools for variety + zh_task_pools = [ + # Pool 1 + [ + {"query": "列出三种热带水果。", "type": "task", "task_type": "list"}, + {"query": "列出三种常见的编程语言。", "type": "task", "task_type": "list"}, + {"query": "什么是光合作用?", "type": "task", "task_type": "qa"}, + {"query": "太阳系有几颗行星?", "type": "task", "task_type": "qa"}, + ], + # Pool 2 + [ + {"query": "列出三种室内植物。", "type": "task", "task_type": "list"}, + {"query": "列出三种运动项目。", "type": "task", "task_type": "list"}, + {"query": "什么是人工智能?", "type": "task", "task_type": "qa"}, + {"query": "地球的自转周期是多少?", "type": "task", "task_type": "qa"}, + ], + # Pool 3 + [ + {"query": "列出三种乐器。", "type": "task", "task_type": "list"}, + {"query": "列出三种社交媒体平台。", "type": "task", "task_type": "list"}, + {"query": "什么是区块链?", "type": "task", "task_type": "qa"}, + {"query": "月球绕地球一周需要多长时间?", "type": "task", "task_type": "qa"}, + ], + # Pool 4 + [ + {"query": "列出三种鸟类。", "type": "task", "task_type": "list"}, + {"query": "列出三种数据库系统。", "type": "task", "task_type": "list"}, + {"query": "什么是机器学习?", "type": "task", "task_type": "qa"}, + {"query": "水的沸点是多少?", "type": "task", "task_type": "qa"}, + ], + ] + + en_task_pools = [ + # Pool 1 + [ + {"query": "List three tropical fruits.", "type": "task", "task_type": "list"}, + {"query": "List three popular programming languages.", "type": "task", "task_type": "list"}, + {"query": "What is photosynthesis?", "type": "task", "task_type": "qa"}, + {"query": "How many planets are in our solar system?", "type": "task", "task_type": "qa"}, + ], + # Pool 2 + [ + {"query": "List three indoor plants.", "type": "task", "task_type": "list"}, + {"query": "List three types of sports.", "type": "task", "task_type": "list"}, + {"query": "What is artificial intelligence?", "type": "task", "task_type": "qa"}, + {"query": "How long is a day on Earth?", "type": "task", "task_type": "qa"}, + ], + # Pool 3 + [ + {"query": "List three musical instruments.", "type": "task", "task_type": "list"}, + {"query": "List three social media platforms.", "type": "task", "task_type": "list"}, + {"query": "What is blockchain?", "type": "task", "task_type": "qa"}, + {"query": "How long does it take the Moon to orbit Earth?", "type": "task", "task_type": "qa"}, + ], + # Pool 4 + [ + {"query": "List three types of birds.", "type": "task", "task_type": "list"}, + {"query": "List three database systems.", "type": "task", "task_type": "list"}, + {"query": "What is machine learning?", "type": "task", "task_type": "qa"}, + {"query": "What is the boiling point of water?", "type": "task", "task_type": "qa"}, + ], + ] + + pools = zh_task_pools if prefs.lang == "zh" else en_task_pools + # Rotate through pools based on session index + pool_idx = (session_idx - 4) % len(pools) + return pools[pool_idx] + + +def get_queries_for_session(persona: Persona, session_id: int) -> List[Dict[str, Any]]: + """ + Get queries for a specific session. 
+ """ + S1: Preference reveal + S2: Pure task (no reminder) + S3: Tasks with ONE reminder + S4+: Pure task (testing long-term retention) + """ + if session_id == 1: + return get_session_1_queries_for_persona(persona) + elif session_id == 2: + return get_session_2_queries_for_persona(persona) + elif session_id == 3: + return get_session_3_queries_for_persona(persona) + else: + return get_pure_task_queries_for_persona(persona, session_id) + + +# ============================================================================= +# Logging +# ============================================================================= + +@dataclass +class TurnLog: + """Log entry for one turn.""" + user_id: str + persona_id: str + session_id: int + turn_id: int + query: str + query_type: str + task_type: str + answer: str + answer_length: int + sat_t: float + sev_t: float + prog_t: float + violations: List[str] + enforced_constraints: List[str] + reward: float + gating: float + is_complaint: bool + reveal_state_before: Dict[str, bool] + reveal_state_after: Dict[str, bool] + newly_revealed: List[str] + z_long_norm_before: float + z_long_norm_after: float + z_short_norm_before: float + z_short_norm_after: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + # Memory retrieval details + num_memories_retrieved: int + num_prefs_extracted: int + selected_memory_ids: List[str] + selected_memory_notes: List[str] + selected_memory_scores: List[float] + num_candidates: int + num_total_memories: int + # Mode indicators + mode: str # "full", "nopersonal", or "vanilla" + eval_mode: bool # True = greedy, False = sample + + +def log_to_jsonl(logs: List[TurnLog], filepath: str): + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as f: + for log in logs: + f.write(json.dumps(asdict(log)) + "\n") + + +# ============================================================================= +# Session Runner with Complaint Injection (FIX 5) +# ============================================================================= + +def run_session_v4( + llm: PersonalizedLLM, + user_id: str, + persona: Persona, + session_id: int, + reveal_state: RevealState, + base_queries: List[Dict[str, Any]], + all_logs: List[TurnLog], + enable_complaints: bool = True, +) -> List[TurnLog]: + """ + Run session with violation-triggered complaint injection. + """ + prefs = persona.style_prefs + session_logs: List[TurnLog] = [] + + print(f"\n{'='*60}") + print(f"[{persona.persona_id}] Session {session_id}: base queries={len(base_queries)}") + print(f"Reveal state (start): {reveal_state}") + print(f"{'='*60}") + + llm.reset_session(user_id) + + # Build dynamic query queue + query_queue = list(base_queries) + turn_id = 0 + + while query_queue: + q_info = query_queue.pop(0) + query = q_info["query"] + query_type = q_info.get("type", "task") + task_type = q_info.get("task_type", "general") + is_complaint = q_info.get("is_complaint", False) + + print(f"\n--- S{session_id}/T{turn_id} [{query_type}]{' [COMPLAINT]' if is_complaint else ''} ---") + print(f"[Q] {query[:60]}{'...'
if len(query) > 60 else ''}") + + reveal_before = reveal_state.to_dict() + newly_revealed = update_reveal_state(reveal_state, query, prefs) + if newly_revealed: + print(f"[Reveal] Newly: {newly_revealed}") + reveal_after = reveal_state.to_dict() + + state_before = llm.get_user_state_summary(user_id) + z_long_before = state_before["z_long_norm"] + z_short_before = state_before["z_short_norm"] + + # Apply feedback for previous turn + if turn_id > 0 and session_logs: + prev_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=prev_log.turn_id, + reward=prev_log.reward, + gating=prev_log.gating, + meta={"source": "pilot_v4", "session_id": session_id} + ) + llm.apply_feedback(feedback) + + # Chat + resp: AssistantResponse = llm.chat(user_id, query) + + answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer + print(f"[A] ({len(resp.answer)}c) {answer_display}") + + # Judge + judge_result = style_judge_v4(query, resp.answer, task_type, prefs, reveal_state) + print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}") + + # Compute feedback + reward, gating = compute_feedback_for_turn(query, query_type, judge_result) + + state_after = llm.get_user_state_summary(user_id) + z_long_after = state_after["z_long_norm"] + z_short_after = state_after["z_short_norm"] + + # Extract memory info from debug + if resp.debug: + num_memories = len(resp.debug.selected_memory_ids) + num_prefs = len(resp.debug.extracted_preferences) + selected_memory_ids = resp.debug.selected_memory_ids + selected_memory_notes = resp.debug.selected_memory_notes + selected_memory_scores = resp.debug.selected_memory_scores + num_candidates = resp.debug.extra.get("num_candidates", 0) + num_total_memories = resp.debug.extra.get("num_total_memories", 0) + else: + num_memories = 0 + num_prefs = 0 + selected_memory_ids = [] + selected_memory_notes = [] + selected_memory_scores = [] + num_candidates = 0 + num_total_memories = 0 + + # Log + log = TurnLog( + user_id=user_id, + persona_id=persona.persona_id, + session_id=session_id, + turn_id=turn_id, + query=query, + query_type=query_type, + task_type=task_type, + answer=resp.answer, + answer_length=len(resp.answer), + sat_t=judge_result.sat_t, + sev_t=judge_result.sev_t, + prog_t=judge_result.prog_t, + violations=judge_result.violations, + enforced_constraints=judge_result.enforced_constraints, + reward=reward, + gating=gating, + is_complaint=is_complaint, + reveal_state_before=reveal_before, + reveal_state_after=reveal_after, + newly_revealed=list(newly_revealed), + z_long_norm_before=z_long_before, + z_long_norm_after=z_long_after, + z_short_norm_before=z_short_before, + z_short_norm_after=z_short_after, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + total_tokens=resp.usage.total_tokens, + num_memories_retrieved=num_memories, + num_prefs_extracted=num_prefs, + selected_memory_ids=selected_memory_ids, + selected_memory_notes=selected_memory_notes, + selected_memory_scores=selected_memory_scores, + num_candidates=num_candidates, + num_total_memories=num_total_memories, + mode=llm.mode, + eval_mode=llm.eval_mode, + ) + session_logs.append(log) + all_logs.append(log) + + # FIX 5: Inject complaint if there were violations and this wasn't already a complaint + if enable_complaints and judge_result.violations and not is_complaint: + complaint = generate_complaint_query(judge_result.violations, prefs) + if complaint: + 
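+ # Tag the injected turn so a failed complaint cannot spawn yet another complaint.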
complaint["is_complaint"] = True + query_queue.insert(0, complaint) # Insert at front + print(f"[Complaint Injected] Will complain about: {judge_result.violations}") + + turn_id += 1 + + # Final feedback + if session_logs: + last_log = session_logs[-1] + feedback = Feedback( + user_id=user_id, + turn_id=last_log.turn_id, + reward=last_log.reward, + gating=last_log.gating, + meta={"source": "pilot_v4", "session_id": session_id, "final": True} + ) + llm.apply_feedback(feedback) + + print(f"\n[Session {session_id} End] Reveal: {reveal_state}, Turns: {turn_id}") + return session_logs + + +# ============================================================================= +# Multi-User Multi-Session Runner +# ============================================================================= + +def run_multi_user_pilot_v4( + llm: PersonalizedLLM, + personas: List[Persona], + num_sessions: int = 3, + enable_complaints: bool = True, +) -> List[TurnLog]: + """Run multi-user multi-session pilot v4.""" + reveal_manager = RevealStateManager() + all_logs: List[TurnLog] = [] + + print(f"\n{'#'*60}") + print(f"PILOT v4: MULTI-USER MULTI-SESSION (Fixed)") + print(f"Users: {len(personas)}, Sessions: {num_sessions}, Complaints: {enable_complaints}") + print(f"{'#'*60}") + + for persona in personas: + user_id = f"user_{persona.persona_id}" + prefs = persona.style_prefs + + print(f"\n{'*'*60}") + print(f"USER: {user_id}") + print(f"Persona: {persona.description}") + print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}") + print(f"{'*'*60}") + + llm.reset_user(user_id) + reveal_manager.reset_user(user_id) + reveal_state = reveal_manager.get_state(user_id) + + for session_id in range(1, num_sessions + 1): + queries = get_queries_for_session(persona, session_id) + + reveal_manager.reset_session(user_id) + run_session_v4(llm, user_id, persona, session_id, reveal_state, queries, all_logs, enable_complaints) + + return all_logs + + +# ============================================================================= +# Summary +# ============================================================================= + +def print_summary_v4(logs: List[TurnLog]): + """Print summary for pilot v4.""" + print(f"\n{'='*60}") + print("PILOT v4 SUMMARY") + print(f"{'='*60}") + + if not logs: + print("No logs.") + return + + from collections import Counter + + personas = sorted(set(l.persona_id for l in logs)) + + print(f"\n--- Per-Persona Statistics ---") + for pid in personas: + p_logs = [l for l in logs if l.persona_id == pid] + sessions = sorted(set(l.session_id for l in p_logs)) + + print(f"\n{pid}:") + for sid in sessions: + s_logs = [l for l in p_logs if l.session_id == sid] + avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0 + violations = [v for l in s_logs for v in l.violations] + enforced = set(c for l in s_logs for c in l.enforced_constraints) + complaints = sum(1 for l in s_logs if l.is_complaint) + + print(f" S{sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, complaints={complaints}") + print(f" enforced={enforced}") + if violations: + print(f" violations: {dict(Counter(violations))}") + + # Cross-session retention + print(f"\n--- Cross-Session Retention (S2 without preferences) ---") + for pid in personas: + p_logs = [l for l in logs if l.persona_id == pid] + s1_logs = [l for l in p_logs if l.session_id == 1] + s2_logs = [l for l in p_logs if l.session_id == 2] + + if s1_logs and s2_logs: + s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs) + s2_sat = sum(l.sat_t 
for l in s2_logs) / len(s2_logs) + s2_enforced = set(c for l in s2_logs for c in l.enforced_constraints) + print(f"{pid}: S1={s1_sat:.3f} → S2={s2_sat:.3f}, enforced={s2_enforced}") + + # Violation rates + print(f"\n--- Violation Rates by Type ---") + all_violations = [v for l in logs for v in l.violations] + total_turns = len(logs) + if all_violations: + for v, count in Counter(all_violations).most_common(): + rate = count / total_turns * 100 + print(f" {v}: {count} ({rate:.1f}%)") + else: + print(" No violations") + + # Complaint effectiveness + print(f"\n--- Complaint Effectiveness ---") + complaint_logs = [l for l in logs if l.is_complaint] + if complaint_logs: + print(f"Total complaints: {len(complaint_logs)}") + avg_sat_complaint = sum(l.sat_t for l in complaint_logs) / len(complaint_logs) + print(f"Avg sat on complaint turns: {avg_sat_complaint:.3f}") + else: + print("No complaints generated") + + # Overall + total = len(logs) + avg_sat = sum(l.sat_t for l in logs) / total + total_tokens = sum(l.total_tokens for l in logs) + print(f"\n--- Overall ---") + print(f"Total turns: {total}, Avg sat: {avg_sat:.3f}, Total tokens: {total_tokens}") + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Pilot Runner v4 - Full vs Vanilla Comparison") + parser.add_argument("--mode", type=str, + choices=["full", "full-greedy", "full-sample", "nopersonal", "vanilla", "compare", "all"], + default="compare", + help="Mode: 'full-greedy' (personalized, deterministic), " + "'full-sample' (personalized, stochastic), " + "'nopersonal' (retrieval baseline without z_u), " + "'vanilla' (pure LLM, no memory), " + "'compare' (full-greedy vs vanilla), " + "'all' (run all modes)") + parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") + parser.add_argument("--sessions", type=int, default=3, help="Number of sessions per user") + parser.add_argument("--no-complaints", action="store_true", help="Disable complaint injection") + args = parser.parse_args() + + # Set seeds for reproducibility + import random + import numpy as np + random.seed(args.seed) + np.random.seed(args.seed) + + print("=" * 60) + print("PILOT RUNNER v4 - Full vs Vanilla Comparison") + print("=" * 60) + print(f"Started at: {datetime.now().isoformat()}") + print(f"Mode: {args.mode}, Seed: {args.seed}, Sessions: {args.sessions}") + + personas = ALL_PERSONAS + print(f"\n[Config] {len(personas)} personas:") + for p in personas: + print(f" - {p.persona_id}: {p.description}") + + enable_complaints = not args.no_complaints + + # Map mode argument to actual run configurations + # Each config: (mode_name, llm_mode, eval_mode) + # llm_mode: "full", "nopersonal", or "vanilla" + # eval_mode: True = greedy/deterministic, False = stochastic sampling + if args.mode == "all": + run_configs = [ + ("full-greedy", "full", True), + ("full-sample", "full", False), + ("nopersonal", "nopersonal", True), + ("vanilla", "vanilla", True), + ] + elif args.mode == "compare": + # Main comparison: Full (with memory) vs Vanilla (no memory) + run_configs = [ + ("full-greedy", "full", True), + ("vanilla", "vanilla", True), + ] + elif args.mode == "full" or args.mode == "full-greedy": + run_configs = [("full-greedy", "full", True)] + elif args.mode == "full-sample": + run_configs = [("full-sample", "full", False)] + elif args.mode == "vanilla": + run_configs = [("vanilla", "vanilla", True)] + elif args.mode == "nopersonal": + run_configs = [("nopersonal", "nopersonal", True)] + else: + run_configs = [(args.mode, 
args.mode, True)] + + for run_name, llm_mode, eval_mode in run_configs: + print(f"\n{'#'*60}") + print(f"RUNNING: {run_name.upper()}") + print(f" llm_mode={llm_mode}, eval_mode={eval_mode} ({'greedy' if eval_mode else 'sample'})") + print(f"{'#'*60}") + + # Reset seeds before each run for exact reproducibility + random.seed(args.seed) + np.random.seed(args.seed) + + print(f"\n[Init] Loading PersonalizedLLM...") + llm = PersonalizedLLM( + user_store_path=f"data/users/user_store_pilot_v4_{run_name}.npz", + only_own_memories=True, + enable_preference_extraction=True, + enable_rl_updates=(llm_mode == "full"), # Disable RL for nopersonal + mode=llm_mode, + eval_mode=eval_mode, + device_assignment={ + "embed": "cuda:0", + "reranker": "cuda:1", + "chat": "cuda:2", + "extractor": "cuda:3", + }, + ) + + logs = run_multi_user_pilot_v4(llm, personas, num_sessions=args.sessions, enable_complaints=enable_complaints) + + print_summary_v4(logs) + + log_path = f"data/logs/pilot_v4_{run_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" + log_to_jsonl(logs, log_path) + print(f"\n[Logs] Saved to: {log_path}") + + # Save user vectors for similarity analysis + if llm_mode == "full": + llm.persist() + print(f"[Persist] User vectors saved to: {llm._user_store.path}") + + print(f"\nCompleted at: {datetime.now().isoformat()}") + print("=" * 60) + + +if __name__ == "__main__": + main() + diff --git a/scripts/pilot_study.py b/scripts/pilot_study.py new file mode 100644 index 0000000..9754c42 --- /dev/null +++ b/scripts/pilot_study.py @@ -0,0 +1,109 @@ +import json +import os +import random +import asyncio +from typing import List, Dict, Any +from openai import AsyncOpenAI +from tqdm.asyncio import tqdm_asyncio + +# --- Configuration --- +INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl" +OUTPUT_FILE = "data/raw_datasets/pilot_study_1000.jsonl" +SAMPLE_SIZE = 1000 +MODEL_NAME = "gpt-5.1" # Or your specific model ID +MAX_CONCURRENCY = 100 # Adjust based on your rate limits + +# --- Load System Prompt --- +with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + # Use the full template (including its few-shot examples) as the system + # instruction; trimming the examples would save tokens but costs label quality. + SYSTEM_PROMPT = f.read() + +# --- Async Worker --- +async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]: + query = item["query"] + async with sem: + try: + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": query} + ], + temperature=0.0, # Deterministic for extraction + response_format={"type": "json_object"} # Enforce JSON + ) + result_text = response.choices[0].message.content + + # Parse to ensure validity + try: + parsed = json.loads(result_text) + prefs = parsed.get("preferences", []) + has_pref = len(prefs) > 0 + except (json.JSONDecodeError, TypeError, AttributeError): + parsed = {"error": "json_parse_fail", "raw": result_text} + has_pref = False + + return { + "original_query": query, + "source": item.get("source"), + "extracted_json": parsed, + "has_preference": has_pref + } + except Exception as e: + return { + "original_query": query, + "source": item.get("source"), + "error": str(e), + "has_preference": False + } + +async def main(): + # 1.
Load and Sample + print(f"Loading data from {INPUT_FILE}...") + all_lines = [] + with open(INPUT_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + all_lines.append(json.loads(line)) + + if len(all_lines) > SAMPLE_SIZE: + sampled_data = random.sample(all_lines, SAMPLE_SIZE) + else: + sampled_data = all_lines + print(f"Sampled {len(sampled_data)} items.") + + # 2. Setup OpenAI Client + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY environment variable not set.") + return + + client = AsyncOpenAI(api_key=api_key) + sem = asyncio.Semaphore(MAX_CONCURRENCY) + + # 3. Run Labeling + tasks = [label_query(client, sem, item) for item in sampled_data] + results = await tqdm_asyncio.gather(*tasks, desc="Labeling") + + # 4. Statistics & Save + pos_count = sum(1 for r in results if r.get("has_preference")) + total = len(results) + ratio = (pos_count / total) * 100 if total > 0 else 0 + + print(f"\n--- Results ---") + print(f"Total processed: {total}") + print(f"Positive (has preferences): {pos_count}") + print(f"Negative (empty): {total - pos_count}") + print(f"Positive Ratio: {ratio:.2f}%") + + with open(OUTPUT_FILE, "w", encoding="utf-8") as f: + for res in results: + f.write(json.dumps(res, ensure_ascii=False) + "\n") + print(f"Saved detailed results to {OUTPUT_FILE}") + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/scripts/process_putnam_batch.py b/scripts/process_putnam_batch.py new file mode 100644 index 0000000..27e0465 --- /dev/null +++ b/scripts/process_putnam_batch.py @@ -0,0 +1,239 @@ +import json +import os +import time +from typing import List, Dict, Any +from openai import OpenAI +from collections import Counter + +# Configuration +API_KEY = os.getenv("OPENAI_API_KEY", "") # Read the key from the environment; never hard-code secrets +BATCH_ID_FILE = "data/putnam_eval/submitted_batch_ids.json" +INPUT_FILE = "data/putnam_eval/putnam_eval_batch.jsonl" +OUTPUT_FILE = "data/putnam_eval/final_results.json" +MODEL_NAME = "gpt-5" + +def load_input_requests(filepath: str) -> Dict[str, Any]: + """Load the original requests to allow retrying.""" + requests_map = {} + print(f"Loading input requests from {filepath}...") + with open(filepath, "r", encoding="utf-8") as f: + for line in f: + if not line.strip(): + continue + item = json.loads(line) + requests_map[item["custom_id"]] = item + return requests_map + +def retrieve_batch_results(client: OpenAI, batch_id: str) -> List[Dict[str, Any]]: + """Retrieve and parse batch results.""" + print(f"Checking status for batch {batch_id}...") + batch = client.batches.retrieve(batch_id) + + print(f"Batch Status: {batch.status}") + print(f"Output File ID: {batch.output_file_id}") + print(f"Error File ID: {batch.error_file_id}") + + results = [] + + if batch.output_file_id: + print("Downloading output file...") + file_response = client.files.content(batch.output_file_id) + file_content = file_response.read().decode("utf-8") + + for line in file_content.splitlines(): + if line.strip(): + results.append(json.loads(line)) + + if batch.error_file_id: + print("Downloading error file (if any)...") + # Usually contains request-level errors that produced no response object in the output file + try: + err_response = client.files.content(batch.error_file_id) + err_content = err_response.read().decode("utf-8") + for line in err_content.splitlines(): + if line.strip(): + results.append(json.loads(line)) + except Exception as
e: + print(f"Note: Could not download/parse error file: {e}") + + return results + +def process_results_and_find_failures(results: List[Dict[str, Any]], all_request_ids: set) -> tuple[List[Dict[str, Any]], List[str]]: + """Separate successful parsable results from failures.""" + valid_results = [] + failed_ids = [] + seen_ids = set() + + for res in results: + custom_id = res.get("custom_id") + seen_ids.add(custom_id) + + # Check for API level errors + if res.get("error"): + print(f"Request {custom_id} failed with error: {res['error']}") + failed_ids.append(custom_id) + continue + + response = res.get("response", {}) + if response.get("status_code") != 200: + print(f"Request {custom_id} failed with status {response.get('status_code')}") + failed_ids.append(custom_id) + continue + + # Try to parse the content as JSON + try: + body = response.get("body", {}) + choices = body.get("choices", []) + if not choices: + print(f"Request {custom_id} has no choices.") + failed_ids.append(custom_id) + continue + + content_str = choices[0].get("message", {}).get("content", "") + content_json = json.loads(content_str) + + valid_results.append({ + "custom_id": custom_id, + "analysis": content_json + }) + except json.JSONDecodeError: + print(f"Request {custom_id} returned invalid JSON content.") + failed_ids.append(custom_id) + except Exception as e: + print(f"Request {custom_id} unexpected processing error: {e}") + failed_ids.append(custom_id) + + # Check for completely missing requests + missing_ids = all_request_ids - seen_ids + if missing_ids: + print(f"Found {len(missing_ids)} missing requests that were not in the batch output.") + failed_ids.extend(list(missing_ids)) + + return valid_results, failed_ids + +def retry_failed_requests(client: OpenAI, failed_ids: List[str], input_map: Dict[str, Any]) -> List[Dict[str, Any]]: + """Retry specific requests synchronously.""" + retried_results = [] + print(f"\nRetrying {len(failed_ids)} failed requests synchronously...") + + for i, custom_id in enumerate(failed_ids): + if custom_id not in input_map: + print(f"Warning: Original request for {custom_id} not found.") + continue + + print(f"Retrying {i+1}/{len(failed_ids)}: {custom_id}") + original_req = input_map[custom_id] + body = original_req["body"] + + try: + response = client.chat.completions.create( + model=MODEL_NAME, # Use the model from the script constant, not necessarily the batch one if we want to enforce gpt-5 + messages=body["messages"], + response_format=body.get("response_format"), + temperature=body.get("temperature", 1.0) # Default if not set, usually 0 in our templates? 
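+                # Forwarding response_format unchanged keeps synchronous retries
+                # consistent with the original batch requests, which all set
+                # {"type": "json_object"} (see scripts/run_putnam_evaluation.py).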
+ ) + + content_str = response.choices[0].message.content + content_json = json.loads(content_str) + + retried_results.append({ + "custom_id": custom_id, + "analysis": content_json + }) + except Exception as e: + print(f"Retry failed for {custom_id}: {e}") + + return retried_results + +def print_stats(final_results: List[Dict[str, Any]]): + """Calculate and print statistics.""" + total = len(final_results) + if total == 0: + print("No results to analyze.") + return + + # Categories + valid_variant_count = 0 + correct_solution_count = 0 + equivalent_count = 0 + strongly_related_count = 0 + + # Validation Consistency + both_valid_and_equiv = 0 + + print(f"\n--- Statistics (N={total}) ---") + + for item in final_results: + analysis = item["analysis"] + validity = analysis.get("variant_validity", {}) + relation = analysis.get("relation_to_original", {}) + + is_valid = validity.get("is_problem_valid", False) + is_correct = validity.get("is_solution_correct", False) + is_equiv = relation.get("is_equivalent", False) + is_related = relation.get("is_strongly_related", False) + + if is_valid: valid_variant_count += 1 + if is_correct: correct_solution_count += 1 + if is_equiv: equivalent_count += 1 + if is_related: strongly_related_count += 1 + + if is_valid and is_correct and (is_equiv or is_related): + both_valid_and_equiv += 1 + + print(f"Variant Valid: {valid_variant_count} ({valid_variant_count/total:.1%})") + print(f"Solution Correct: {correct_solution_count} ({correct_solution_count/total:.1%})") + print(f"Equivalent: {equivalent_count} ({equivalent_count/total:.1%})") + print(f"Strongly Related: {strongly_related_count} ({strongly_related_count/total:.1%})") + print(f"Valid & Rel/Equiv: {both_valid_and_equiv} ({both_valid_and_equiv/total:.1%})") + +def main(): + if not API_KEY: + print("Error: API_KEY not set.") + return + + client = OpenAI(api_key=API_KEY) + + # 1. Get Batch ID + if not os.path.exists(BATCH_ID_FILE): + print(f"Batch ID file not found at {BATCH_ID_FILE}") + return + + with open(BATCH_ID_FILE, "r") as f: + batch_ids = json.load(f) + if not batch_ids: + print("No batch IDs found.") + return + batch_id = batch_ids[-1] # Take the latest one + print(f"Processing Batch ID: {batch_id}") + + # 2. Retrieve Results + raw_results = retrieve_batch_results(client, batch_id) + + # 3. Load Inputs (to identify missing/failed IDs) + input_map = load_input_requests(INPUT_FILE) + all_request_ids = set(input_map.keys()) + + # 4. Parse and Find Failures + valid_results, failed_ids = process_results_and_find_failures(raw_results, all_request_ids) + print(f"Successfully parsed: {len(valid_results)}") + print(f"Failed/Missing: {len(failed_ids)}") + + # 5. Retry Failures + if failed_ids: + retry_results = retry_failed_requests(client, failed_ids, input_map) + valid_results.extend(retry_results) + + # 6. Save Final Results + print(f"Saving {len(valid_results)} results to {OUTPUT_FILE}...") + with open(OUTPUT_FILE, "w", encoding="utf-8") as f: + json.dump(valid_results, f, indent=2) + + # 7. 
Stats + print_stats(valid_results) + +if __name__ == "__main__": + main() + + + diff --git a/scripts/pull_models.py b/scripts/pull_models.py new file mode 100644 index 0000000..1e4abc3 --- /dev/null +++ b/scripts/pull_models.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import argparse +from pathlib import Path +import sys + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + +from huggingface_hub import snapshot_download # type: ignore + +from personalization.config.settings import load_local_models_config + + +def pull_one(repo_id: str, dest: Path, force: bool = False) -> None: + dest.mkdir(parents=True, exist_ok=True) + if any(dest.iterdir()) and not force: + print(f"[skip] {dest} already populated. Use --force to overwrite.") + return + print(f"[pull] {repo_id} -> {dest}") + snapshot_download(repo_id=repo_id, local_dir=str(dest), local_dir_use_symlinks=False) + print(f"[done] {repo_id}") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--target", + choices=[ + "llm", + "preference_extractor", + "embed_qwen3", + "embed_nemotron", + "embedders", + "reranker_qwen3", + "rerankers", + "all", + ], + default="all", + ) + parser.add_argument("--force", action="store_true") + args = parser.parse_args() + + cfg = load_local_models_config() + if args.target in ("llm", "all"): + pull_one(cfg.llm.hf_id, ROOT / cfg.llm.local_path, force=args.force) + if args.target in ("preference_extractor", "all"): + pull_one( + cfg.preference_extractor.hf_id, + ROOT / cfg.preference_extractor.local_path, + force=args.force, + ) + if args.target in ("embed_qwen3", "embedders", "all") and cfg.embedding and cfg.embedding.qwen3: + pull_one( + cfg.embedding.qwen3.hf_id, + ROOT / cfg.embedding.qwen3.local_path, + force=args.force, + ) + if args.target in ("embed_nemotron", "embedders", "all") and cfg.embedding and cfg.embedding.nemotron: + pull_one( + cfg.embedding.nemotron.hf_id, + ROOT / cfg.embedding.nemotron.local_path, + force=args.force, + ) + if args.target in ("reranker_qwen3", "rerankers", "all") and cfg.reranker and cfg.reranker.qwen3_8b: + pull_one( + cfg.reranker.qwen3_8b.hf_id, + ROOT / cfg.reranker.qwen3_8b.local_path, + force=args.force, + ) + + +if __name__ == "__main__": + main() + + diff --git a/scripts/recompute_embeddings.py b/scripts/recompute_embeddings.py new file mode 100644 index 0000000..884cc7b --- /dev/null +++ b/scripts/recompute_embeddings.py @@ -0,0 +1,65 @@ +import json +import os +import sys +import numpy as np +import torch +from tqdm import tqdm + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.settings import load_local_models_config +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.retrieval.preference_store.schemas import MemoryCard + +CARDS_FILE = "data/corpora/memory_cards.jsonl" +EMBEDDINGS_FILE = "data/corpora/memory_embeddings.npy" + +def recompute_embeddings(): + if not os.path.exists(CARDS_FILE): + print(f"Error: {CARDS_FILE} not found.") + return + + print("Loading configuration and model...") + cfg = load_local_models_config() + embed_model = Qwen3Embedding8B.from_config(cfg) + + print(f"Reading memory cards from {CARDS_FILE}...") + cards = [] + texts = [] + with open(CARDS_FILE, "r", encoding="utf-8") as f: + for line in f: + if not line.strip(): continue + card = MemoryCard.model_validate_json(line) + cards.append(card) + # Embedding source: note_text (preference) or 
raw_query? + # Usually we embed the note_text for retrieval. + texts.append(card.note_text) + + print(f"Total cards: {len(cards)}") + + if not cards: + print("No cards found.") + return + + print("Computing embeddings...") + # Batch processing + batch_size = 32 + all_embs = [] + + for i in tqdm(range(0, len(texts), batch_size)): + batch_texts = texts[i : i + batch_size] + # Qwen3Embedding8B.encode returns list of lists (if return_tensor=False) + embs = embed_model.encode(batch_texts, return_tensor=False) + all_embs.extend(embs) + + emb_array = np.array(all_embs, dtype=np.float32) + print(f"Embeddings shape: {emb_array.shape}") + + print(f"Saving to {EMBEDDINGS_FILE}...") + np.save(EMBEDDINGS_FILE, emb_array) + print("Done!") + +if __name__ == "__main__": + recompute_embeddings() + diff --git a/scripts/recover_and_merge.py b/scripts/recover_and_merge.py new file mode 100644 index 0000000..b0f37f7 --- /dev/null +++ b/scripts/recover_and_merge.py @@ -0,0 +1,151 @@ +import json +import os +from openai import OpenAI +from typing import Dict, Any + +# --- Configuration --- +# 1. Main Batch IDs (The 340k success ones we lost) +MAIN_BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json" +# 2. OASST1 Batch IDs (New) +OASST1_BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json" +OASST1_METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl" + +# The file we want to APPEND to (currently has 68k retry items) +OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl" + +# Original queries map for main batch reconstruction +ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl" + +def load_original_queries() -> Dict[str, Dict[str, Any]]: + print("Loading original queries map (Main)...") + mapping = {} + with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f: + for idx, line in enumerate(f): + if line.strip(): + mapping[f"req_{idx}"] = json.loads(line) + return mapping + +def load_oasst1_metadata() -> Dict[str, Dict[str, Any]]: + print("Loading OASST1 metadata map...") + mapping = {} + if os.path.exists(OASST1_METADATA_FILE): + with open(OASST1_METADATA_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + item = json.loads(line) + mapping[item["custom_id"]] = item + return mapping + +def recover_and_merge(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + client = OpenAI(api_key=api_key) + + # Load Maps + main_query_map = load_original_queries() + oasst1_meta_map = load_oasst1_metadata() + + # We will append to the existing file which holds the RETRY results. + # So we don't lose the 68k we just fixed. + print(f"Appending recovered data to {OUTPUT_FILE}...") + + count_main = 0 + count_oasst1 = 0 + + with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out: + + # --- 1. 
Recover Main Batches --- + if os.path.exists(MAIN_BATCH_IDS_FILE): + with open(MAIN_BATCH_IDS_FILE, "r") as f: + main_ids = json.load(f) + + print(f"\nRecovering {len(main_ids)} Main Batches...") + for b_id in main_ids: + try: + batch = client.batches.retrieve(b_id) + if batch.output_file_id: + print(f" Downloading {b_id} (Output: {batch.output_file_id})...") + content = client.files.content(batch.output_file_id).text + + for line in content.splitlines(): + if not line.strip(): continue + res = json.loads(line) + custom_id = res["custom_id"] + + if res["response"]["status_code"] == 200: + try: + body = res["response"]["body"] + llm_content = body["choices"][0]["message"]["content"] + parsed_json = json.loads(llm_content) + + original = main_query_map.get(custom_id) + if original: + record = { + "custom_id": custom_id, + "original_query": original["query"], + "source": original.get("source"), + "extracted_json": parsed_json, + "has_preference": len(parsed_json.get("preferences", [])) > 0 + } + f_out.write(json.dumps(record, ensure_ascii=False) + "\n") + count_main += 1 + except: + pass + except Exception as e: + print(f" Error {b_id}: {e}") + + # --- 2. Retrieve OASST1 Batches --- + # User requested to skip OASST1 merge for now. + # if os.path.exists(OASST1_BATCH_IDS_FILE): + # with open(OASST1_BATCH_IDS_FILE, "r") as f: + # oasst_ids = json.load(f) + + # print(f"\nRetrieving {len(oasst_ids)} OASST1 Batches...") + # for b_id in oasst_ids: + # try: + # batch = client.batches.retrieve(b_id) + # if batch.status == "completed" and batch.output_file_id: + # print(f" Downloading {b_id}...") + # content = client.files.content(batch.output_file_id).text + + # for line in content.splitlines(): + # if not line.strip(): continue + # res = json.loads(line) + # custom_id = res["custom_id"] + + # if res["response"]["status_code"] == 200: + # try: + # body = res["response"]["body"] + # llm_content = body["choices"][0]["message"]["content"] + # parsed_json = json.loads(llm_content) + + # meta = oasst1_meta_map.get(custom_id) + # if meta: + # record = { + # "custom_id": custom_id, + # "original_query": meta["original_query"], + # "source": "oasst1", + # "user_id": meta.get("user_id"), # Preserve User ID! 
+                        #             "session_id": meta.get("session_id"),
+                        #             "extracted_json": parsed_json,
+                        #             "has_preference": len(parsed_json.get("preferences", [])) > 0
+                        #         }
+                        #         f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+                        #         count_oasst1 += 1
+                        #     except:
+                        #         pass
+        # except Exception as e:
+        #     print(f"  Error {b_id}: {e}")
+
+    print("\n" + "="*50)
+    print("RECOVERY & MERGE COMPLETE")
+    print(f"Recovered Main: {count_main}")
+    print(f"New OASST1: {count_oasst1}")
+    print(f"Full dataset updated at: {OUTPUT_FILE}")
+    print("="*50)
+
+if __name__ == "__main__":
+    recover_and_merge()
+
diff --git a/scripts/retrieve_batch_results.py b/scripts/retrieve_batch_results.py
new file mode 100644
index 0000000..aa26e28
--- /dev/null
+++ b/scripts/retrieve_batch_results.py
@@ -0,0 +1,151 @@
+import json
+import os
+import time
+from typing import Dict, Any, List, Set
+from openai import OpenAI
+
+# --- Configuration ---
+BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
+ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+RETRY_INPUT_FILE = "data/raw_datasets/retry_requests.jsonl"
+MODEL_NAME = "gpt-5.1"  # Needed to reconstruct retry requests
+
+# Load the system prompt locally to avoid import errors
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+    SYSTEM_PROMPT = f.read()
+
+def load_original_queries() -> Dict[str, Dict[str, Any]]:
+    print("Loading original queries map...")
+    mapping = {}
+    with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f:
+        for idx, line in enumerate(f):
+            if line.strip():
+                mapping[f"req_{idx}"] = json.loads(line)
+    return mapping
+
+def process_batch_results():
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        print("Error: OPENAI_API_KEY not set.")
+        return
+    client = OpenAI(api_key=api_key)
+
+    if not os.path.exists(BATCH_IDS_FILE):
+        print(f"Error: {BATCH_IDS_FILE} not found.")
+        return
+
+    with open(BATCH_IDS_FILE, "r") as f:
+        batch_ids = json.load(f)
+
+    query_map = load_original_queries()
+    processed_ids: Set[str] = set()
+
+    # Append or overwrite? Earlier runs did not store custom_id in the output,
+    # so an existing file cannot be deduplicated reliably. Overwriting is the
+    # clean option for this recovery run.
+    print("Starting fresh download/processing (Overwriting output)...")
+
+    success_count = 0
+    fail_count = 0
+
+    with open(OUTPUT_LABEL_FILE, "w", encoding="utf-8") as f_success:
+        for b_id in batch_ids:
+            print(f"\nProcessing Batch {b_id}...")
+            try:
+                batch = client.batches.retrieve(b_id)
+
+                # 1. 
Output File (Success) + if batch.output_file_id: + print(f" Downloading output {batch.output_file_id}...") + content = client.files.content(batch.output_file_id).text + + for line in content.splitlines(): + if not line.strip(): continue + res = json.loads(line) + custom_id = res["custom_id"] + + if res["response"]["status_code"] == 200: + try: + body = res["response"]["body"] + llm_content = body["choices"][0]["message"]["content"] + parsed_json = json.loads(llm_content) + + original_item = query_map.get(custom_id) + if original_item: + record = { + "custom_id": custom_id, # Add this to help debug later + "original_query": original_item["query"], + "source": original_item.get("source"), + "extracted_json": parsed_json, + "has_preference": len(parsed_json.get("preferences", [])) > 0 + } + f_success.write(json.dumps(record, ensure_ascii=False) + "\n") + processed_ids.add(custom_id) + success_count += 1 + except Exception as e: + print(f" Parse Error {custom_id}: {e}") + # Parse error -> Fail + # If not 200, it's a fail, handled by logic below (since it won't be in processed_ids) + + # 2. Error File (Explicit Failures) + # We don't need to explicitly read error file to write retries, + # because we will do a global "Missing Check" at the end. + # But reading it helps debugging. + if batch.error_file_id: + print(f" Downloading ERROR {batch.error_file_id}...") + # Just print count + # content = client.files.content(batch.error_file_id).text + # print(f" Found {len(content.splitlines())} errors in error file.") + + except Exception as e: + print(f" CRITICAL ERROR processing batch {b_id}: {e}") + + # --- Missing Check & Retry Generation --- + print(f"\nVerifying completeness... (Total Queries: {len(query_map)})") + print(f"Successful processed: {len(processed_ids)}") + + with open(RETRY_INPUT_FILE, "w", encoding="utf-8") as f_retry: + for custom_id, original_item in query_map.items(): + if custom_id not in processed_ids: + fail_count += 1 + + # Reconstruct Request + request_obj = { + "custom_id": custom_id, + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": MODEL_NAME, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": original_item["query"]} + ], + "temperature": 0.0, + "response_format": {"type": "json_object"} + } + } + f_retry.write(json.dumps(request_obj) + "\n") + + print("\n" + "="*50) + print(f"Processing Complete.") + print(f"Successful: {success_count} (Saved to {OUTPUT_LABEL_FILE})") + print(f"To Retry: {fail_count} (Saved to {RETRY_INPUT_FILE})") + print("="*50) + +if __name__ == "__main__": + process_batch_results() diff --git a/scripts/retrieve_oasst1.py b/scripts/retrieve_oasst1.py new file mode 100644 index 0000000..436d329 --- /dev/null +++ b/scripts/retrieve_oasst1.py @@ -0,0 +1,96 @@ +import json +import os +from openai import OpenAI +from typing import Dict, Any + +# --- Configuration --- +BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json" +METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl" +# Store independently for Memory/User Modeling initialization +OUTPUT_FILE = "data/corpora/oasst1_labeled.jsonl" + +def load_metadata() -> Dict[str, Dict[str, Any]]: + print("Loading OASST1 metadata map...") + mapping = {} + if os.path.exists(METADATA_FILE): + with open(METADATA_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + item = json.loads(line) + mapping[item["custom_id"]] = item + return mapping + +def retrieve_oasst1(): + api_key = os.getenv("OPENAI_API_KEY") + 
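+    # Sketch of a shared helper (hypothetical name, not used below): this script,
+    # retrieve_batch_results.py, and retrieve_synthesis.py all repeat the same
+    # retrieve -> download -> parse-JSONL loop, which could live in one place.
+    # It uses only SDK calls these scripts already rely on.
+    def iter_batch_output_records(client, batch_id):
+        """Yield parsed JSONL records from a completed batch's output file."""
+        batch = client.batches.retrieve(batch_id)
+        if not batch.output_file_id:
+            return
+        for raw in client.files.content(batch.output_file_id).text.splitlines():
+            if raw.strip():
+                yield json.loads(raw)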
if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + client = OpenAI(api_key=api_key) + + if not os.path.exists(BATCH_IDS_FILE): + print(f"Error: {BATCH_IDS_FILE} not found.") + return + + with open(BATCH_IDS_FILE, "r") as f: + batch_ids = json.load(f) + + meta_map = load_metadata() + count_success = 0 + count_fail = 0 + + print(f"Appending OASST1 results to {OUTPUT_FILE}...") + + with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out: + for b_id in batch_ids: + print(f"\nProcessing Batch {b_id}...") + try: + batch = client.batches.retrieve(b_id) + if batch.output_file_id: + print(f" Downloading output {batch.output_file_id}...") + content = client.files.content(batch.output_file_id).text + + for line in content.splitlines(): + if not line.strip(): continue + res = json.loads(line) + custom_id = res["custom_id"] + + if res["response"]["status_code"] == 200: + try: + body = res["response"]["body"] + llm_content = body["choices"][0]["message"]["content"] + parsed_json = json.loads(llm_content) + + meta = meta_map.get(custom_id) + if meta: + record = { + "custom_id": custom_id, + "original_query": meta["original_query"], + "source": "oasst1", + "user_id": meta.get("user_id"), + "session_id": meta.get("session_id"), + "extracted_json": parsed_json, + "has_preference": len(parsed_json.get("preferences", [])) > 0 + } + f_out.write(json.dumps(record, ensure_ascii=False) + "\n") + count_success += 1 + else: + # Fallback if metadata missing (unlikely) + print(f"Warning: Metadata missing for {custom_id}") + except Exception as e: + print(f"Parse error {custom_id}: {e}") + count_fail += 1 + else: + count_fail += 1 + except Exception as e: + print(f"Error checking batch {b_id}: {e}") + + print("\n" + "="*50) + print("OASST1 RETRIEVAL COMPLETE") + print(f"Successfully processed: {count_success}") + print(f"Failed/Parse Error: {count_fail}") + print(f"Full dataset updated at: {OUTPUT_FILE}") + print("="*50) + +if __name__ == "__main__": + retrieve_oasst1() + diff --git a/scripts/retrieve_synthesis.py b/scripts/retrieve_synthesis.py new file mode 100644 index 0000000..cbc4573 --- /dev/null +++ b/scripts/retrieve_synthesis.py @@ -0,0 +1,118 @@ +import json +import os +from openai import OpenAI +from typing import Dict, Any + +# --- Configuration --- +BATCH_IDS_FILE = "data/raw_datasets/submitted_synthesis_batch_ids.json" +SEED_FILE = "data/raw_datasets/positive_seeds.jsonl" +# Where to save the new synthesized records +OUTPUT_FILE = "data/raw_datasets/synthesized_positives.jsonl" + +def load_seeds() -> Dict[str, Dict[str, Any]]: + print("Loading seeds map...") + mapping = {} + with open(SEED_FILE, "r", encoding="utf-8") as f: + # We need to map custom_id back to the seed to get the GROUND TRUTH preferences. + # But wait, in submit_synthesis_batch.py, we created custom_id as "syn_{original_id}". + # And we need to find the original seed by that ID. + # Problem: positive_seeds.jsonl contains the FULL record including 'extracted_json'. + # We can iterate and build a map: original_custom_id -> record + for idx, line in enumerate(f): + if line.strip(): + item = json.loads(line) + # If item has custom_id, use it. If not, we used "seed_{i}" in submission. + # Let's hope positive_seeds.jsonl has custom_id (it should if it came from retrieve script). + cid = item.get("custom_id") + if not cid: + # Fallback if custom_id missing (e.g. from some older process) + # We generated "seed_{i}" in submit script. 
+ cid = f"seed_{idx}" + + mapping[cid] = item + return mapping + +def retrieve_synthesis(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + client = OpenAI(api_key=api_key) + + if not os.path.exists(BATCH_IDS_FILE): + print(f"Error: {BATCH_IDS_FILE} not found.") + return + + with open(BATCH_IDS_FILE, "r") as f: + batch_ids = json.load(f) + + seed_map = load_seeds() + count_rewrites = 0 + count_source_seeds = 0 + + print(f"Processing Synthesis Batches -> {OUTPUT_FILE}...") + + with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out: + for b_id in batch_ids: + print(f"\nProcessing Batch {b_id}...") + try: + batch = client.batches.retrieve(b_id) + if batch.output_file_id: + print(f" Downloading output {batch.output_file_id}...") + content = client.files.content(batch.output_file_id).text + + for line in content.splitlines(): + if not line.strip(): continue + res = json.loads(line) + syn_id = res["custom_id"] # e.g. "syn_req_123" + + # Derive original seed ID: remove "syn_" prefix + if syn_id.startswith("syn_"): + orig_id = syn_id[4:] + else: + orig_id = syn_id + + if res["response"]["status_code"] == 200: + try: + body = res["response"]["body"] + llm_content = body["choices"][0]["message"]["content"] + parsed_json = json.loads(llm_content) + + rewrites = parsed_json.get("rewrites", []) + if not rewrites: + continue + + # Find original preference to inherit + seed = seed_map.get(orig_id) + if seed: + prefs = seed.get("extracted_json") + # Create new records + for rw in rewrites: + new_record = { + "original_query": rw, + "source": "synthesis_gpt4o", + "parent_id": orig_id, + "extracted_json": prefs, # INHERIT PREFERENCE + "has_preference": True + } + f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n") + count_rewrites += 1 + count_source_seeds += 1 + else: + # print(f"Warning: Seed {orig_id} not found in map") + pass + except Exception as e: + print(f"Parse error {syn_id}: {e}") + except Exception as e: + print(f"Error checking batch {b_id}: {e}") + + print("\n" + "="*50) + print("SYNTHESIS RETRIEVAL COMPLETE") + print(f"Processed Source Seeds: {count_source_seeds}") + print(f"Generated New Samples: {count_rewrites}") + print(f"Saved to: {OUTPUT_FILE}") + print("="*50) + +if __name__ == "__main__": + retrieve_synthesis() + diff --git a/scripts/run_putnam_evaluation.py b/scripts/run_putnam_evaluation.py new file mode 100644 index 0000000..f320eea --- /dev/null +++ b/scripts/run_putnam_evaluation.py @@ -0,0 +1,164 @@ +import json +import os +import glob +import argparse +from typing import List, Dict, Any +from openai import OpenAI + +# Configuration +DATA_DIR = "LLaMA-Factory/preprocess/PutnamGAP" +OUTPUT_DIR = "data/putnam_eval" +OUTPUT_FILENAME = "putnam_eval_batch.jsonl" +MODEL_NAME = "gpt-5" # User requested gpt-5 + +SYSTEM_PROMPT = """You are an expert mathematician and a judge for math competitions. You are given an original math problem (and its solution) and a "kernel variant" of that problem (and its solution). + +Your task is to: +1. Evaluate the correctness of the kernel variant. Is the problem statement mathematically sound and clear? Is the provided solution correct? +2. Evaluate the relationship between the original problem and the kernel variant. Are they mathematically equivalent? Or is the variant a strong abstraction/generalization/simplification of the original? Do they test the same core concepts? 
+ +Output your analysis in the following JSON format: +{ + "variant_validity": { + "is_problem_valid": boolean, + "is_solution_correct": boolean, + "comments": "string" + }, + "relation_to_original": { + "is_equivalent": boolean, + "is_strongly_related": boolean, + "relationship_description": "string" + } +}""" + +def load_dataset(data_dir: str) -> List[Dict[str, Any]]: + files = glob.glob(os.path.join(data_dir, "*.json")) + items = [] + print(f"Scanning {len(files)} files in {data_dir}...") + for fpath in files: + try: + with open(fpath, "r", encoding="utf-8") as f: + data = json.load(f) + + # Check for required fields + if "variants" not in data or "kernel_variant" not in data["variants"]: + continue + + orig_q = data.get("question", "") + orig_s = data.get("solution", "") + kv = data["variants"]["kernel_variant"] + kv_q = kv.get("question", "") + kv_s = kv.get("solution", "") + + if not kv_q: + continue + + items.append({ + "id": data.get("index", os.path.basename(fpath)), + "original_question": orig_q, + "original_solution": orig_s, + "kernel_variant_question": kv_q, + "kernel_variant_solution": kv_s + }) + except Exception as e: + print(f"Error reading {fpath}: {e}") + return items + +def create_batch_file(items: List[Dict[str, Any]], output_path: str): + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + count = 0 + with open(output_path, "w", encoding="utf-8") as f: + for item in items: + user_content = f"""[Original Problem] +{item['original_question']} + +[Original Solution] +{item['original_solution']} + +[Kernel Variant Problem] +{item['kernel_variant_question']} + +[Kernel Variant Solution] +{item['kernel_variant_solution']}""" + + # Construct request + request_obj = { + "custom_id": f"req_{item['id']}", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": MODEL_NAME, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content} + ], + "response_format": {"type": "json_object"} + } + } + f.write(json.dumps(request_obj) + "\n") + count += 1 + + print(f"Created batch file at {output_path} with {count} requests.") + return count + +def submit_batch(file_path: str): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set. Cannot submit.") + return + + client = OpenAI(api_key=api_key) + + print(f"Uploading {file_path} to OpenAI...") + with open(file_path, "rb") as f: + batch_file_obj = client.files.create( + file=f, + purpose="batch" + ) + file_id = batch_file_obj.id + print(f"Uploaded. File ID: {file_id}") + + print("Submitting Batch Job...") + batch_job = client.batches.create( + input_file_id=file_id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "description": "PutnamGAP Evaluation" + } + ) + print(f"Submitted. 
Batch ID: {batch_job.id}") + + # Save Batch ID + id_file = os.path.join(os.path.dirname(file_path), "submitted_batch_ids.json") + existing_ids = [] + if os.path.exists(id_file): + try: + with open(id_file, "r") as f: + existing_ids = json.load(f) + except: + pass + existing_ids.append(batch_job.id) + with open(id_file, "w") as f: + json.dump(existing_ids, f, indent=2) + print(f"Batch ID saved to {id_file}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Prepare and optionally submit PutnamGAP evaluation batch.") + parser.add_argument("--submit", action="store_true", help="Submit the batch to OpenAI after generating.") + args = parser.parse_args() + + items = load_dataset(DATA_DIR) + print(f"Found {len(items)} items with kernel variants.") + + output_path = os.path.join(OUTPUT_DIR, OUTPUT_FILENAME) + if items: + create_batch_file(items, output_path) + if args.submit: + submit_batch(output_path) + else: + print("Use --submit to submit the batch to OpenAI.") + else: + print("No items found to process.") + diff --git a/scripts/run_server.py b/scripts/run_server.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/scripts/run_server.py diff --git a/scripts/smoke_extractor_llm.py b/scripts/smoke_extractor_llm.py new file mode 100644 index 0000000..b16d0e2 --- /dev/null +++ b/scripts/smoke_extractor_llm.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +""" +Smoke test for PreferenceExtractorLLM (Qwen3-0.6B). +Requires 'saves/qwen3-0.6b-full-sft-h200/checkpoint-4358' to be present. +""" + +import sys +import os +import json + +# Add src to sys.path +sys.path.append(os.path.join(os.path.dirname(__file__), "../src")) + +from personalization.config.registry import get_preference_extractor +from personalization.retrieval.preference_store.schemas import ChatTurn + +def main(): + print("Initializing Preference Extractor (qwen3_0_6b_sft)...") + try: + extractor = get_preference_extractor("qwen3_0_6b_sft") + except Exception as e: + print(f"Failed to load extractor: {e}") + print("Please check if the checkpoint exists at saves/qwen3-0.6b-full-sft-h200/checkpoint-4358") + print("and local_models.yaml is configured correctly.") + sys.exit(1) + + print("Extractor loaded successfully.") + + # Construct dummy conversation + turns = [ + ChatTurn(user_id="u1", session_id="s1", turn_id=0, role="user", text="Hi, I am learning Python. Please always use Python 3.11 in your code examples."), + ChatTurn(user_id="u1", session_id="s1", turn_id=1, role="assistant", text="Hello! Python is a great language. How can I help?"), + ChatTurn(user_id="u1", session_id="s1", turn_id=2, role="user", text="Please explain lists. And btw, always use snake_case for variables in your code examples."), + ] + + print("\n--- Input Turns ---") + for t in turns: + print(f"[{t.role}]: {t.text}") + + print("\n--- Extracting ---") + prefs = extractor.extract_turn(turns) + + print("\n--- Output PreferenceList ---") + print(prefs.model_dump_json(indent=2)) + + # Validation + if prefs.preferences: + print("\nSUCCESS: Extracted preferences found.") + else: + print("\nWARNING: No preferences extracted. 
(Model might need warming up or prompt adjustment)") + +if __name__ == "__main__": + main() + diff --git a/scripts/smoke_llms.py b/scripts/smoke_llms.py new file mode 100644 index 0000000..109020a --- /dev/null +++ b/scripts/smoke_llms.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from pathlib import Path +import sys +import json +import os + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + +from personalization.config.settings import load_local_models_config +from personalization.models.llm.qwen_instruct import QwenInstruct +from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor + + +def ensure_models_present() -> bool: + cfg = load_local_models_config() + ok = True + for spec in (cfg.llm, cfg.preference_extractor): + path = ROOT / spec.local_path + if not path.exists() or not any(path.iterdir()): + print(f"[missing] {path} is empty. Run: python scripts/pull_models.py --target all") + ok = False + return ok + + +def main() -> None: + if not ensure_models_present(): + return + + cfg = load_local_models_config() + os.environ["PREF_DEBUG"] = "1" + llm = QwenInstruct.from_config(cfg) + extractor = QwenRuleExtractor.from_config(cfg) + + print("[llm] generating...") + out = llm.generate("Say hello in one short sentence.", max_new_tokens=32, temperature=0.2) + print(out) + + print("[extractor] extracting...") + scenarios = [ + ( + "math_latex", + "Consider the sequence defined by a_1 = 1 and a_{n+1} = a_n + 1/n for n >= 1. " + "(1) Prove that a_n diverges. (2) Derive an asymptotic expression for a_n in terms of the harmonic numbers H_n. " + "(3) Compute the limit of (a_n - ln n) as n -> infinity. Please use LaTeX for the output.", + ), + ( + "code_python311", + "I have a performance bottleneck in my Python code that processes large CSV files. It reads rows, aggregates stats, and writes summaries. " + "Explain how to optimize I/O and memory, discuss multiprocessing vs async. When you show code, please use Python 3.11 syntax and include type hints in the snippets.", + ), + ( + "data_json_only", + "Given a dataset of user events with timestamps, device types, and regions, outline steps to compute DAU, WAU, and retention. " + "List pitfalls and how to handle missing data. Return your final answer as JSON only.", + ), + ( + "writing_concise_no_emoji", + "Explain the difference between supervised and reinforcement learning with practical examples and cautions. " + "Keep answers concise and avoid emojis.", + ), + ] + for name, query in scenarios: + print(f"\n[scenario] {name}") + prefs = extractor.extract_preferences(query) + print(json.dumps(prefs, indent=2, ensure_ascii=False)) + + +if __name__ == "__main__": + main() + + diff --git a/scripts/split_train_test.py b/scripts/split_train_test.py new file mode 100644 index 0000000..ccb4cb1 --- /dev/null +++ b/scripts/split_train_test.py @@ -0,0 +1,76 @@ +import json +import os +import random + +INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl" +TRAIN_FILE = "data/finetune/train_llama_factory.json" +TEST_FILE = "data/finetune/test_llama_factory.json" +TEST_SIZE = 1000 + +SYSTEM_INSTRUCTION = ( + "Extract user preferences from the query into JSON format based on the PreferenceList schema. " + "If no preferences are found, return {\"preferences\": []}." 
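+    # Matches the {"preferences": [...]} shape that the labeling runs store in
+    # "extracted_json", so training targets line up with the raw labels.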
+) + +def split_and_convert(): + if not os.path.exists(INPUT_FILE): + print(f"Error: {INPUT_FILE} not found.") + return + + print(f"Reading {INPUT_FILE}...") + all_data = [] + + with open(INPUT_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + item = json.loads(line) + # Convert to LLaMA-Factory format immediately + record = { + "instruction": SYSTEM_INSTRUCTION, + "input": item["input"], + "output": item["output"] + } + all_data.append(record) + + print(f"Total records: {len(all_data)}") + + # Shuffle + random.seed(42) # Fixed seed for reproducibility + random.shuffle(all_data) + + # Split + test_data = all_data[:TEST_SIZE] + train_data = all_data[TEST_SIZE:] + + print(f"Train size: {len(train_data)}") + print(f"Test size: {len(test_data)}") + + # Save Train + print(f"Saving train set to {TRAIN_FILE}...") + with open(TRAIN_FILE, "w", encoding="utf-8") as f: + json.dump(train_data, f, indent=2, ensure_ascii=False) + + # Save Test + print(f"Saving test set to {TEST_FILE}...") + with open(TEST_FILE, "w", encoding="utf-8") as f: + json.dump(test_data, f, indent=2, ensure_ascii=False) + + print("Done!") + + # Update dataset_info advice + print("\nUpdate dataset_info.json with:") + info = { + "preference_extractor_train": { + "file_name": "train_llama_factory.json", + "columns": {"prompt": "instruction", "query": "input", "response": "output"} + }, + "preference_extractor_test": { + "file_name": "test_llama_factory.json", + "columns": {"prompt": "instruction", "query": "input", "response": "output"} + } + } + print(json.dumps(info, indent=2)) + +if __name__ == "__main__": + split_and_convert() + diff --git a/scripts/stats_and_extract.py b/scripts/stats_and_extract.py new file mode 100644 index 0000000..402ade7 --- /dev/null +++ b/scripts/stats_and_extract.py @@ -0,0 +1,56 @@ +import json +import os + +INPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl" +OUTPUT_POS_FILE = "data/raw_datasets/positive_seeds.jsonl" + +def extract_and_stats(): + if not os.path.exists(INPUT_FILE): + print(f"Error: {INPUT_FILE} not found.") + return + + print(f"Scanning {INPUT_FILE}...") + + total = 0 + pos_count = 0 + neg_count = 0 + + # Optional: Track distribution of preference types/keys if needed + + with open(INPUT_FILE, "r", encoding="utf-8") as f_in, \ + open(OUTPUT_POS_FILE, "w", encoding="utf-8") as f_out: + + for line in f_in: + if not line.strip(): continue + try: + item = json.loads(line) + total += 1 + + # Check if positive + # Our labeling script ensures 'has_preference' boolean, + # but let's double check the actual list to be safe. 
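+                # Defensive sketch (an assumption, not something observed in the
+                # data): if a recovered row stored "extracted_json" as a JSON
+                # string rather than a dict, normalize it before reading
+                # "preferences" below.
+                if isinstance(item.get("extracted_json"), str):
+                    try:
+                        item["extracted_json"] = json.loads(item["extracted_json"])
+                    except json.JSONDecodeError:
+                        item["extracted_json"] = {}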
+ prefs = item.get("extracted_json", {}).get("preferences", []) + + if prefs and len(prefs) > 0: + pos_count += 1 + f_out.write(line) + else: + neg_count += 1 + except: + pass # Skip malformed lines + + ratio = (pos_count / total * 100) if total > 0 else 0 + + print("\n" + "="*30) + print("DATASET STATISTICS") + print("="*30) + print(f"Total Rows: {total}") + print(f"Positive Rows: {pos_count} ({ratio:.2f}%)") + print(f"Negative Rows: {neg_count}") + print("-" * 30) + print(f"Positive seeds saved to: {OUTPUT_POS_FILE}") + print("="*30) + +if __name__ == "__main__": + extract_and_stats() + diff --git a/scripts/submit_batch.py b/scripts/submit_batch.py new file mode 100644 index 0000000..e848dc5 --- /dev/null +++ b/scripts/submit_batch.py @@ -0,0 +1,111 @@ +import json +import os +import time +from typing import List +from openai import OpenAI + +# --- Configuration --- +INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl" +BATCH_DIR = "data/raw_datasets/batch_files" +MODEL_NAME = "gpt-5.1" # Or "gpt-4o" +BATCH_SIZE_LIMIT = 49000 # Safe under 50k limit + +# --- Load System Prompt --- +with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + SYSTEM_PROMPT = f.read() + +def prepare_and_submit_batches(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + client = OpenAI(api_key=api_key) + + os.makedirs(BATCH_DIR, exist_ok=True) + + print(f"Reading from {INPUT_FILE}...") + + # Read all lines first + all_lines = [] + with open(INPUT_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + all_lines.append(json.loads(line)) + + total_items = len(all_lines) + print(f"Total items: {total_items}") + + batch_ids = [] + + # Split and Process + for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)): + chunk = all_lines[i : i + BATCH_SIZE_LIMIT] + chunk_filename = os.path.join(BATCH_DIR, f"batch_input_part_{batch_idx}.jsonl") + + print(f"\n--- Processing Batch {batch_idx} ({len(chunk)} items) ---") + + # 1. Create File + with open(chunk_filename, "w", encoding="utf-8") as f_out: + for item_idx, item in enumerate(chunk): + # Global index to track back later if needed + global_idx = i + item_idx + query = item["query"] + + # Custom ID: "req_{global_index}" + custom_id = f"req_{global_idx}" + + request_obj = { + "custom_id": custom_id, + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": MODEL_NAME, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": query} + ], + "temperature": 0.0, + "response_format": {"type": "json_object"} + } + } + f_out.write(json.dumps(request_obj) + "\n") + + print(f"File created: {chunk_filename}") + + # 2. Upload File + print("Uploading to OpenAI...") + batch_file_obj = client.files.create( + file=open(chunk_filename, "rb"), + purpose="batch" + ) + file_id = batch_file_obj.id + print(f"Uploaded. File ID: {file_id}") + + # 3. Submit Batch + print("Submitting Batch Job...") + batch_job = client.batches.create( + input_file_id=file_id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "description": f"Pers. Extractor Part {batch_idx}", + "part_index": str(batch_idx) + } + ) + print(f"Submitted. Batch ID: {batch_job.id}") + batch_ids.append(batch_job.id) + + # Sleep briefly to be nice to API + time.sleep(1) + + # Save all Batch IDs + id_file = "data/raw_datasets/submitted_batch_ids.json" + with open(id_file, "w") as f: + json.dump(batch_ids, f, indent=2) + + print(f"\nALL DONE! 
Submitted {len(batch_ids)} batches.") + print(f"Batch IDs saved to {id_file}") + print("Run scripts/check_batch_status.py (you need to write it) to monitor.") + +if __name__ == "__main__": + prepare_and_submit_batches() diff --git a/scripts/submit_oasst1_batch.py b/scripts/submit_oasst1_batch.py new file mode 100644 index 0000000..1a96dd0 --- /dev/null +++ b/scripts/submit_oasst1_batch.py @@ -0,0 +1,120 @@ +import json +import os +import time +from openai import OpenAI + +# --- Configuration --- +INPUT_FILE = "data/raw_datasets/oasst1_queries.jsonl" +BATCH_DIR = "data/raw_datasets/batch_files_oasst1" +METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl" +MODEL_NAME = "gpt-5.1" +BATCH_SIZE_LIMIT = 49000 + +# --- Load System Prompt --- +with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + SYSTEM_PROMPT = f.read() + +def submit_oasst1_batch(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + client = OpenAI(api_key=api_key) + + os.makedirs(BATCH_DIR, exist_ok=True) + + print(f"Reading from {INPUT_FILE}...") + + all_lines = [] + with open(INPUT_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + all_lines.append(json.loads(line)) + + total_items = len(all_lines) + print(f"Total OASST1 items: {total_items}") + + # 1. Generate Metadata Map first + # This ensures we have the mapping even if batch submission fails mid-way + print(f"Generating metadata map to {METADATA_FILE}...") + with open(METADATA_FILE, "w", encoding="utf-8") as f_meta: + for idx, item in enumerate(all_lines): + custom_id = f"oasst1_req_{idx}" + meta_record = { + "custom_id": custom_id, + "user_id": item.get("user_id"), + "session_id": item.get("session_id"), + "turn_id": item.get("turn_id"), + "original_query": item.get("original_query") or item.get("query") + } + f_meta.write(json.dumps(meta_record, ensure_ascii=False) + "\n") + + # Store custom_id back to item list for batch generation + item["_temp_custom_id"] = custom_id + + # 2. Split and Submit + batch_ids = [] + + for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)): + chunk = all_lines[i : i + BATCH_SIZE_LIMIT] + chunk_filename = os.path.join(BATCH_DIR, f"oasst1_batch_part_{batch_idx}.jsonl") + + print(f"\n--- Processing OASST1 Batch {batch_idx} ({len(chunk)} items) ---") + + with open(chunk_filename, "w", encoding="utf-8") as f_out: + for item in chunk: + custom_id = item["_temp_custom_id"] + query = item.get("original_query") or item.get("query") + + request_obj = { + "custom_id": custom_id, + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": MODEL_NAME, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": query} + ], + "temperature": 0.0, + "response_format": {"type": "json_object"} + } + } + f_out.write(json.dumps(request_obj) + "\n") + + print(f"File created: {chunk_filename}") + + print("Uploading to OpenAI...") + batch_file_obj = client.files.create( + file=open(chunk_filename, "rb"), + purpose="batch" + ) + file_id = batch_file_obj.id + print(f"Uploaded. File ID: {file_id}") + + print("Submitting Batch Job...") + batch_job = client.batches.create( + input_file_id=file_id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "description": f"Pers. Extractor OASST1 Part {batch_idx}", + "dataset": "oasst1" + } + ) + print(f"Submitted. 
Batch ID: {batch_job.id}") + batch_ids.append(batch_job.id) + + time.sleep(1) + + id_file = "data/raw_datasets/submitted_oasst1_batch_ids.json" + with open(id_file, "w") as f: + json.dump(batch_ids, f, indent=2) + + print(f"\nALL DONE! Submitted {len(batch_ids)} OASST1 batches.") + print(f"Metadata saved to {METADATA_FILE}") + print(f"Batch IDs saved to {id_file}") + +if __name__ == "__main__": + submit_oasst1_batch() + diff --git a/scripts/submit_retry_batch.py b/scripts/submit_retry_batch.py new file mode 100644 index 0000000..f564c4b --- /dev/null +++ b/scripts/submit_retry_batch.py @@ -0,0 +1,88 @@ +import json +import os +import time +from openai import OpenAI + +# --- Configuration --- +RETRY_INPUT_FILE = "data/raw_datasets/retry_requests.jsonl" +BATCH_DIR = "data/raw_datasets/batch_files_retry" +BATCH_SIZE_LIMIT = 10000 # Smaller chunks as requested + +def submit_retry_batches(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Error: OPENAI_API_KEY not set.") + return + client = OpenAI(api_key=api_key) + + os.makedirs(BATCH_DIR, exist_ok=True) + + if not os.path.exists(RETRY_INPUT_FILE): + print(f"Error: {RETRY_INPUT_FILE} not found.") + return + + print(f"Reading retry requests from {RETRY_INPUT_FILE}...") + + all_lines = [] + with open(RETRY_INPUT_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + all_lines.append(line.strip()) # Keep as string, no need to parse json + + total_items = len(all_lines) + print(f"Total retry items: {total_items}") + + batch_ids = [] + + # Split and Submit + for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)): + chunk = all_lines[i : i + BATCH_SIZE_LIMIT] + chunk_filename = os.path.join(BATCH_DIR, f"retry_batch_part_{batch_idx}.jsonl") + + print(f"\n--- Processing Retry Batch {batch_idx} ({len(chunk)} items) ---") + + # 1. Create File + with open(chunk_filename, "w", encoding="utf-8") as f_out: + for line in chunk: + f_out.write(line + "\n") + + print(f"File created: {chunk_filename}") + + # 2. Upload File + print("Uploading to OpenAI...") + batch_file_obj = client.files.create( + file=open(chunk_filename, "rb"), + purpose="batch" + ) + file_id = batch_file_obj.id + print(f"Uploaded. File ID: {file_id}") + + # 3. Submit Batch + print("Submitting Batch Job...") + batch_job = client.batches.create( + input_file_id=file_id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "description": f"Pers. Extractor RETRY Part {batch_idx}", + "retry": "true" + } + ) + print(f"Submitted. Batch ID: {batch_job.id}") + batch_ids.append(batch_job.id) + + time.sleep(1) + + # Save Batch IDs (Append to existing or create new separate file?) + # Let's create a separate file for retries to avoid confusion. + id_file = "data/raw_datasets/submitted_retry_batch_ids.json" + with open(id_file, "w") as f: + json.dump(batch_ids, f, indent=2) + + print(f"\nALL DONE! 
Submitted {len(batch_ids)} retry batches.")
+    print(f"Batch IDs saved to {id_file}")
+    print("Use scripts/check_retry_status.py (you may need to create or adapt it) to monitor.")
+
+if __name__ == "__main__":
+    submit_retry_batches()
+
diff --git a/scripts/submit_synthesis_batch.py b/scripts/submit_synthesis_batch.py
new file mode 100644
index 0000000..025782d
--- /dev/null
+++ b/scripts/submit_synthesis_batch.py
@@ -0,0 +1,131 @@
+import json
+import os
+from openai import OpenAI
+import time
+
+# --- Configuration ---
+INPUT_SEEDS = "data/raw_datasets/positive_seeds.jsonl"
+BATCH_DIR = "data/raw_datasets/batch_files_synthesis"
+MODEL_NAME = "gpt-5.1"  # Or gpt-4o
+BATCH_SIZE_LIMIT = 30000  # ~31k seeds total, so splitting into 2 files is safe
+
+SYNTHESIS_SYSTEM_PROMPT = """You are a data augmentation assistant.
+Your task is to rewrite a User Query that contains specific preferences into 5 different variations.
+The goal is to train a model to recognize these preferences in various contexts.
+
+Variations required:
+1. Formal/Polite: Use sophisticated language and polite markers.
+2. Casual/Direct: Use slang, abbreviations, or very direct commands.
+3. Implicit/Contextual: Embed the preference naturally within a larger context or story, making it harder to spot.
+4. Distractor-Heavy: Mix the preference with irrelevant information or another task.
+5. Imperative/Short: Extremely concise, almost robotic.
+
+Output strictly a JSON object with a single key "rewrites" containing a list of 5 strings.
+Example: {"rewrites": ["string1", "string2", "string3", "string4", "string5"]}
+"""
+
+def submit_synthesis_batch():
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        print("Error: OPENAI_API_KEY not set.")
+        return
+    client = OpenAI(api_key=api_key)
+
+    os.makedirs(BATCH_DIR, exist_ok=True)
+
+    if not os.path.exists(INPUT_SEEDS):
+        print(f"Error: {INPUT_SEEDS} not found.")
+        return
+
+    print(f"Reading seeds from {INPUT_SEEDS}...")
+
+    seeds = []
+    with open(INPUT_SEEDS, "r", encoding="utf-8") as f:
+        for line in f:
+            if line.strip():
+                seeds.append(json.loads(line))
+
+    total_items = len(seeds)
+    print(f"Total seeds: {total_items}")
+
+    batch_ids = []
+
+    # Split and Submit
+    for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
+        chunk = seeds[i : i + BATCH_SIZE_LIMIT]
+        chunk_filename = os.path.join(BATCH_DIR, f"synthesis_batch_part_{batch_idx}.jsonl")
+
+        print(f"\n--- Processing Synthesis Batch {batch_idx} ({len(chunk)} items) ---")
+
+        # 1. Create File
+        with open(chunk_filename, "w", encoding="utf-8") as f_out:
+            for offset, item in enumerate(chunk):
+                # We pass the query itself; showing the extracted preference as well
+                # would further ensure the rewrite keeps the core intent.
+                original_query = item["original_query"]
+                # extracted_json = item["extracted_json"]  # Optional, see note above.
+
+                # Derive a custom_id: "syn_{original_custom_id}" when the seed has
+                # one (the recovered file usually does). Otherwise fall back to the
+                # seed's global index, which matches the "seed_{idx}" fallback used
+                # by scripts/retrieve_synthesis.py. (Previously this used f"seed_{i}",
+                # the chunk offset, which produced duplicate ids within a chunk.)
+                base_id = item.get("custom_id", f"seed_{i + offset}")
+                custom_id = f"syn_{base_id}"  # Prefix to distinguish
+
+                user_content = f"Original Query: {original_query}"
+                # Optionally add: f"\nCore Preference: {json.dumps(extracted_json)}"
+
+                request_obj = {
+                    "custom_id": custom_id,
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": {
+                        "model": MODEL_NAME,
+                        "messages": [
+                            {"role": "system", "content": SYNTHESIS_SYSTEM_PROMPT},
+                            {"role": "user", "content": user_content}
+                        ],
+                        "temperature": 0.7,  # Higher temperature for diversity
+                        "response_format": {"type": "json_object"}
+                    }
+                }
+                f_out.write(json.dumps(request_obj) + "\n")
+
+        print(f"File created: {chunk_filename}")
+
+        # 2. Upload
+        print("Uploading to OpenAI...")
+        batch_file_obj = client.files.create(
+            file=open(chunk_filename, "rb"),
+            purpose="batch"
+        )
+        file_id = batch_file_obj.id
+        print(f"Uploaded. File ID: {file_id}")
+
+        # 3. Submit
+        print("Submitting Batch Job...")
+        batch_job = client.batches.create(
+            input_file_id=file_id,
+            endpoint="/v1/chat/completions",
+            completion_window="24h",
+            metadata={
+                "description": f"Pers. Extractor Synthesis Part {batch_idx}",
+                "type": "synthesis"
+            }
+        )
+        print(f"Submitted. Batch ID: {batch_job.id}")
+        batch_ids.append(batch_job.id)
+
+        time.sleep(1)
+
+    id_file = "data/raw_datasets/submitted_synthesis_batch_ids.json"
+    with open(id_file, "w") as f:
+        json.dump(batch_ids, f, indent=2)
+
+    print(f"\nALL DONE! Submitted {len(batch_ids)} synthesis batches.")
+    print(f"Batch IDs saved to {id_file}")
+
+if __name__ == "__main__":
+    submit_synthesis_batch()
+
diff --git a/scripts/upload_to_hf.py b/scripts/upload_to_hf.py
new file mode 100644
index 0000000..3c41011
--- /dev/null
+++ b/scripts/upload_to_hf.py
@@ -0,0 +1,69 @@
+import os
+import argparse
+from huggingface_hub import HfApi, create_repo, upload_folder
+
+def main():
+    parser = argparse.ArgumentParser(description="Upload a checkpoint to Hugging Face Hub")
+    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint directory")
+    parser.add_argument("--repo_id", type=str, required=True, help="Hugging Face repo ID (e.g., username/model-name)")
+    parser.add_argument("--token", type=str, help="Hugging Face token (optional if logged in via CLI)")
+    parser.add_argument("--private", action="store_true", help="Make the repository private")
+    parser.add_argument("--include-optimizer", action="store_true", help="Include optimizer states (optimizer.pt, scheduler.pt, etc.)")
+
+    args = parser.parse_args()
+
+    # Expand path
+    ckpt_path = os.path.abspath(args.ckpt_path)
+    if not os.path.exists(ckpt_path):
+        print(f"Error: Checkpoint path does not exist: {ckpt_path}")
+        return
+
+    print(f"Preparing to upload {ckpt_path} to {args.repo_id}...")
+
+    api = HfApi(token=args.token)
+
+    # Create repo if it doesn't exist
+    try:
+        print(f"Creating repository {args.repo_id} (if not exists)...")
+        create_repo(
+            repo_id=args.repo_id,
+            token=args.token,
+            private=args.private,
+            exist_ok=True,
+            repo_type="model"
+        )
+    except Exception as e:
+        print(f"Error creating repo: {e}")
+        print("Please check your token and permissions.")
+        return
+
+    # Patterns to ignore when we don't want optimizer states
+    ignore_patterns = []
+    if not args.include_optimizer:
+        ignore_patterns = [
+            "optimizer.pt",
+            "scheduler.pt",
+            "rng_state_*.pth",
+            "trainer_state.json",
+            "training_args.bin"
+        ]
+        print("Excluding optimizer states from upload.")
+
+    print("Starting upload...")
+    try:
+        api.upload_folder(
+            folder_path=ckpt_path,
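+            # ignore_patterns stays empty when --include-optimizer is passed,
+            # so the full trainer state is uploaded in that case.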
repo_id=args.repo_id, + repo_type="model", + ignore_patterns=ignore_patterns, + token=args.token + ) + print(f"Successfully uploaded to https://huggingface.co/{args.repo_id}") + except Exception as e: + print(f"Upload failed: {e}") + +if __name__ == "__main__": + main() + + + diff --git a/src/personalization/__init__.py b/src/personalization/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/__init__.py diff --git a/src/personalization/config/__init__.py b/src/personalization/config/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/config/__init__.py diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py new file mode 100644 index 0000000..d825ad3 --- /dev/null +++ b/src/personalization/config/registry.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Optional +import torch +import yaml + +from personalization.config import settings + +# Avoid circular imports by NOT importing extractors here at top level +# from personalization.models.preference_extractor.base import PreferenceExtractorBase +# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor +# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor +# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM + +_DTYPE_MAP: Dict[str, torch.dtype] = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, +} + +def choose_dtype(preferred: Optional[str] = None) -> torch.dtype: + if preferred and preferred.lower() in _DTYPE_MAP: + dt = _DTYPE_MAP[preferred.lower()] + else: + dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + if dt is torch.bfloat16 and not torch.cuda.is_available(): + return torch.float32 + return dt + +def choose_device_map(spec: Optional[str] = "auto") -> Any: + return spec or "auto" + +def ensure_local_path(path_str: str) -> str: + path = Path(path_str) + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + return str(path) + +# --- Chat Model Factory --- +def get_chat_model(name: str, device_override: Optional[str] = None): + """ + Get a chat model by name. + + Args: + name: Model name (e.g., "qwen_1_5b", "llama_8b") + device_override: Optional device override (e.g., "cuda:2"). If None, uses config default. 
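+
+    Example (illustrative; the model name comes from configs/local_models.yaml and
+    generate() is called the same way in scripts/smoke_llms.py):
+        model = get_chat_model("qwen_1_5b", device_override="cuda:0")
+        print(model.generate("Say hello in one short sentence.", max_new_tokens=32))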
+ """ + from personalization.models.llm.base import ChatModel + from personalization.models.llm.qwen_instruct import QwenInstruct + from personalization.models.llm.llama_instruct import LlamaChatModel + + cfg = settings.load_local_models_config() + + # Try to load raw config to support multi-backend map + with open("configs/local_models.yaml", "r") as f: + raw_cfg = yaml.safe_load(f) + + models = raw_cfg.get("models", {}).get("llm", {}) + + # If models['llm'] is a dict of configs (new style) + if isinstance(models, dict) and "backend" in models.get(name, {}): + spec = models[name] + backend = spec.get("backend", "qwen") + path = spec["path"] + device = device_override or spec.get("device", "cuda") # Use override if provided + dtype = spec.get("dtype", "bfloat16") + max_len = spec.get("max_context_length", 4096) + + if backend == "qwen": + return QwenInstruct( + model_path=path, + device=device, + dtype=choose_dtype(dtype), # Converts string to torch.dtype + max_context_length=max_len + ) + elif backend == "llama": + return LlamaChatModel( + model_path=path, + device=device, + dtype=choose_dtype(dtype), # Converts string to torch.dtype + max_context_length=max_len + ) + + # Fallback to legacy single config + return QwenInstruct.from_config(cfg) + +def get_preference_extractor(name: Optional[str] = None): + # Deferred imports to break circular dependency + from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor + from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor + from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM + + cfg = settings.load_local_models_config() + pref_cfg = cfg.preference_extractor + + if name is None: + if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg: + name = "qwen3_0_6b_sft" + else: + name = "rule" + + if isinstance(pref_cfg, dict) and name in pref_cfg: + spec = pref_cfg[name] + if name == "qwen3_0_6b_sft": + # Use QwenRuleExtractor which we have updated for SFT End-to-End logic + return QwenRuleExtractor( + model_path=spec["path"], + device_map=spec.get("device", "auto"), + dtype=choose_dtype(spec.get("dtype", "bfloat16")), + ) + # Add 'default' handling if mapped to rule/gpt + if name == "default": + pass + + if name == "gpt4o": + return GPT4OExtractor.from_config(cfg) + elif name == "rule": + if isinstance(pref_cfg, dict): + if "default" in pref_cfg: + # Manually construct to bypass ModelSpec mismatch if needed + spec_dict = pref_cfg["default"] + return QwenRuleExtractor( + model_path=spec_dict["local_path"], + dtype=choose_dtype(spec_dict.get("dtype")), + device_map=choose_device_map(spec_dict.get("device_map")) + ) + else: + return QwenRuleExtractor.from_config(cfg) + + raise ValueError(f"Could not load preference extractor: {name}") diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py new file mode 100644 index 0000000..1bb1bbe --- /dev/null +++ b/src/personalization/config/settings.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional, Any, Dict + +import yaml +from pydantic import BaseModel, Field + + +class ModelSpec(BaseModel): + hf_id: str = Field(..., description="Hugging Face repository id") + local_path: str = Field(..., description="Local directory for model weights") + dtype: Optional[str] = Field( + default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32" + ) + device_map: Optional[str] = 
Field(default="auto", description="Device map policy") + + +class EmbeddingModelsConfig(BaseModel): + qwen3: Optional[ModelSpec] = None + nemotron: Optional[ModelSpec] = None + + +class RerankerModelsConfig(BaseModel): + qwen3_8b: Optional[ModelSpec] = None + + +class LocalModelsConfig(BaseModel): + llm: ModelSpec + preference_extractor: Any # Allow flexible dict or ModelSpec for now to support map + embedding: Optional[EmbeddingModelsConfig] = None + reranker: Optional[RerankerModelsConfig] = None + + +def _resolve_config_path(env_key: str, default_rel: str) -> Path: + value = os.getenv(env_key) + if value: + return Path(value).expanduser().resolve() + return (Path.cwd() / default_rel).resolve() + + +def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig: + config_path = Path(path) if path else _resolve_config_path( + "LOCAL_MODELS_CONFIG", "configs/local_models.yaml" + ) + with open(config_path, "r", encoding="utf-8") as f: + raw = yaml.safe_load(f) or {} + models = raw.get("models", {}) + embedding_cfg = None + if "embedding" in models: + emb = models["embedding"] or {} + # dtype/device_map are not necessary for embedders; ModelSpec still accepts them + embedding_cfg = EmbeddingModelsConfig( + qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None, + nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None, + ) + + reranker_cfg = None + if "reranker" in models: + rer = models["reranker"] or {} + reranker_cfg = RerankerModelsConfig( + qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if "qwen3_8b" in rer else None + ) + + return LocalModelsConfig( + llm=ModelSpec(**models["llm"]), + preference_extractor=models["preference_extractor"], # Pass raw dict/value + embedding=embedding_cfg, + reranker=reranker_cfg, + ) + + diff --git a/src/personalization/data/personamem_loader.py b/src/personalization/data/personamem_loader.py new file mode 100644 index 0000000..3b516ad --- /dev/null +++ b/src/personalization/data/personamem_loader.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import csv +import json +from dataclasses import dataclass +from typing import Dict, List + +@dataclass +class PersonaMemQuestion: + persona_id: str + question_id: str + question_type: str + topic: str + user_question_or_message: str + all_options: List[str] # 4 options + correct_index: int # 0..3 + shared_context_id: str + end_index_in_shared_context: int + +@dataclass +class PersonaMemContext: + shared_context_id: str + messages: List[dict] # raw dicts with "role"/"content" etc + +def load_personamem_questions_32k(path_csv: str) -> List[PersonaMemQuestion]: + questions = [] + with open(path_csv, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + # Check fields + # The official csv usually has: question_id, persona_id, shared_context_id, question, correct_answer, options etc. + # Assuming standard PersonaMem format or similar to provided description. + # We might need to adjust based on actual file content. + # Based on user description: + try: + options_str = row.get("all_options", "[]") # Assuming json string + try: + options = json.loads(options_str) + except: + # Fallback if it's not JSON (e.g. string repr) + # For now assume JSON or simple list + options = [] + + # Handle raw answer format (e.g. 
"(c)" or "c") + raw_ans = row.get("correct_answer", "").strip() + # Remove parens if present + if raw_ans.startswith("(") and raw_ans.endswith(")"): + raw_ans = raw_ans[1:-1] + + # Parse correct index + # If correct_answer is 'A','B','C','D' -> 0,1,2,3 + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'a': 0, 'b': 1, 'c': 2, 'd': 3} + correct_idx = ans_map.get(raw_ans, -1) + + q = PersonaMemQuestion( + persona_id=row["persona_id"], + question_id=row["question_id"], + question_type=row.get("question_type", "unknown"), + topic=row.get("topic", "unknown"), + user_question_or_message=row.get("user_question_or_message", row.get("question", "")), + all_options=options, + correct_index=correct_idx, + shared_context_id=row["shared_context_id"], + end_index_in_shared_context=int(row.get("end_index_in_shared_context", -1)) + ) + questions.append(q) + except KeyError as e: + # print(f"Skipping row due to missing key: {e}") + continue + return questions + +def load_personamem_contexts_32k(path_jsonl: str) -> Dict[str, PersonaMemContext]: + contexts = {} + with open(path_jsonl, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + # Format: {"hash_id": [messages...]} + for cid, msgs in data.items(): + contexts[cid] = PersonaMemContext( + shared_context_id=cid, + messages=msgs + ) + return contexts + diff --git a/src/personalization/evaluation/__init__.py b/src/personalization/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/evaluation/__init__.py diff --git a/src/personalization/evaluation/compare_pairs.py b/src/personalization/evaluation/compare_pairs.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/evaluation/compare_pairs.py diff --git a/src/personalization/evaluation/metrics.py b/src/personalization/evaluation/metrics.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/evaluation/metrics.py diff --git a/src/personalization/feedback/__init__.py b/src/personalization/feedback/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/feedback/__init__.py diff --git a/src/personalization/feedback/gating.py b/src/personalization/feedback/gating.py new file mode 100644 index 0000000..d741874 --- /dev/null +++ b/src/personalization/feedback/gating.py @@ -0,0 +1,72 @@ +import numpy as np +from personalization.feedback.schemas import TurnSample + +def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray: + # matrix: [N, d], vector: [d] + # return: [N] + norm_m = np.linalg.norm(matrix, axis=1) + norm_v = np.linalg.norm(vector) + + # Avoid div by zero + den = norm_m * norm_v + den[den == 0] = 1e-9 + + return np.dot(matrix, vector) / den + +def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float: + """ + Return g_t in [0,1], representing how much the reward is due to retrieval. + """ + e_q = sample.query_embedding_t + e_q1 = sample.query_embedding_t1 + + if e_q is None or e_q1 is None or not sample.memories: + return 0.5 # Neutral + + # We need embeddings of the memories. + # In a real pipeline, we might pass them in sample.memory_embeddings. + # If missing, we can't compute sim. 
+    if sample.memory_embeddings is None:
+        # Try to use embedding_e from the memory cards if available.
+        # MemoryCard.embedding_e is a List[float], so stack the cards into an array.
+        try:
+            mem_embs = np.array([m.embedding_e for m in sample.memories])
+            if mem_embs.shape[1] == 0:  # Empty embeddings
+                return 0.5
+        except Exception:
+            return 0.5
+    else:
+        mem_embs = sample.memory_embeddings
+
+    # Compute similarities, shape: [K]
+    sims_q = cosine_sim_batch(mem_embs, e_q)
+    sims_q1 = cosine_sim_batch(mem_embs, e_q1)
+
+    s_q_max = sims_q.max() if len(sims_q) > 0 else 0
+    s_q1_max = sims_q1.max() if len(sims_q1) > 0 else 0
+
+    g = 0.5
+
+    # Heuristics
+
+    # Case A: Retrieval clearly irrelevant + bad reward.
+    # q_t / q_{t+1} have low similarity to the memories -> likely a retrieval
+    # failure (or no relevant memories exist).
+    if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
+        g = 0.9  # Blame retrieval (for failing to find anything, or nothing exists)
+
+    # Case B: Retrieval looks good but reward is bad.
+    # Memories are relevant to the query, yet the user is still unhappy ->
+    # the LLM likely failed to use them well.
+    elif reward_hat < -0.5 and s_q_max > 0.5:
+        g = 0.2  # Likely LLM fault
+
+    # Case C: Good reward.
+    # If reward is high, we assume both components did okay.
+    elif reward_hat > 0.5:
+        if s_q_max > 0.4:
+            g = 0.6  # Retrieval helped
+        else:
+            g = 0.3  # LLM handled it without strong retrieval help
+
+    return float(g)
+
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
new file mode 100644
index 0000000..60a8d17
--- /dev/null
+++ b/src/personalization/feedback/handlers.py
@@ -0,0 +1,50 @@
+from typing import Tuple, List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.reward_model import estimate_reward
+from personalization.feedback.gating import estimate_retrieval_gating
+
+def eval_step(
+    q_t: str,
+    answer_t: str,
+    q_t1: str,
+    memories_t: List[MemoryCard],
+    query_embedding_t: Optional[np.ndarray] = None,
+    query_embedding_t1: Optional[np.ndarray] = None,
+) -> Tuple[float, float]:
+    """
+    Unified evaluation interface.
+    Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
+    """
+
+    # Construct a lightweight TurnSample.
+    # Gating needs embeddings; if none are provided, it falls back to a neutral default.
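+    # Illustrative call (hypothetical values):
+    #   r_hat, g_hat = eval_step(q_t="use bullet points", answer_t=ans,
+    #                            q_t1="that's not what I asked", memories_t=mems)
+    #   -> r_hat near -1.0 (negative keyword hit); g_hat in [0, 1] from gating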
+
+    # Ensure the memories carry embeddings for gating.
+    mem_embs = None
+    if memories_t and memories_t[0].embedding_e:
+        try:
+            mem_embs = np.array([m.embedding_e for m in memories_t])
+        except Exception:
+            pass
+
+    sample = TurnSample(
+        user_id="",  # Not needed for simple eval
+        session_id="",
+        turn_id=0,
+        query_t=q_t,
+        answer_t=answer_t,
+        query_t1=q_t1,
+        memories=memories_t,
+        query_embedding_t=query_embedding_t,
+        query_embedding_t1=query_embedding_t1,
+        memory_embeddings=mem_embs
+    )
+
+    r_hat = estimate_reward(sample)
+    g_hat = estimate_retrieval_gating(sample, r_hat)
+
+    return r_hat, g_hat
+
diff --git a/src/personalization/feedback/online_update.py b/src/personalization/feedback/online_update.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/feedback/online_update.py
diff --git a/src/personalization/feedback/reward_model.py b/src/personalization/feedback/reward_model.py
new file mode 100644
index 0000000..3584b43
--- /dev/null
+++ b/src/personalization/feedback/reward_model.py
@@ -0,0 +1,64 @@
+import numpy as np
+from personalization.feedback.schemas import TurnSample
+
+def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
+    norm_a = np.linalg.norm(a)
+    norm_b = np.linalg.norm(b)
+    if norm_a == 0 or norm_b == 0:
+        return 0.0
+    return float(np.dot(a, b) / (norm_a * norm_b))
+
+def estimate_reward(sample: TurnSample) -> float:
+    """
+    Return a scalar reward_hat indicating whether the previous answer was helpful.
+    Range: approximately [-1.0, 1.0].
+    """
+
+    # 1. Language/Topic Coherence
+    if sample.query_embedding_t is None or sample.query_embedding_t1 is None:
+        topic_sim = 0.5
+    else:
+        topic_sim = cosine_sim(sample.query_embedding_t, sample.query_embedding_t1)
+
+    # 2. Negative keywords (complaint/correction), English and Chinese
+    negative_keywords = [
+        "you didn't", "that's not", "incorrect", "redo", "again", "explain more",
+        "doesn't help", "wrong", "no", "not what i asked",
+        "你没", "不是", "这不是", "重来", "重新", "不对", "错了", "没说清楚"
+    ]
+
+    # 3. Positive keywords (follow-up/elaboration), English and Chinese
+    positive_keywords = [
+        "can you elaborate", "give an example", "continue", "what if", "based on that",
+        "thanks", "good", "great", "cool",
+        "能不能详细一点", "举个例子", "再继续", "那如果", "接下来", "在这个基础上", "谢谢", "不错"
+    ]
+
+    q1_lower = sample.query_t1.lower()
+
+    has_negative = any(kw in q1_lower for kw in negative_keywords)
+    has_positive = any(kw in q1_lower for kw in positive_keywords)
+
+    reward = 0.0
+
+    if has_negative:
+        reward -= 1.0
+
+    if has_positive:
+        # Only grant the full bonus if topic similarity is decent; otherwise this
+        # might just be a closing "thanks, bye" at the end of a session.
+        reward += 0.5
+        if topic_sim > 0.3:
+            reward += 0.5
+
+    if topic_sim < 0.2:
+        # Topic shift -> the previous interaction likely finished or failed.
+        # A topic change often means the previous task is done (neutral/positive),
+        # but it can also mean the user gave up (negative). Hard to tell apart,
+        # so dampen the reward towards 0.
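+        # Worked example: "thanks" followed by a topic change (topic_sim < 0.2)
+        # yields +0.5 dampened to +0.25; a complaint followed by a topic change
+        # goes from -1.0 to -0.5.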
+ reward *= 0.5 + + # Clip + return max(-1.0, min(1.0, reward)) + diff --git a/src/personalization/feedback/sampler.py b/src/personalization/feedback/sampler.py new file mode 100644 index 0000000..9e26912 --- /dev/null +++ b/src/personalization/feedback/sampler.py @@ -0,0 +1,109 @@ +from typing import Iterable, List, Optional +import numpy as np +from tqdm import tqdm + +from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard +from personalization.feedback.schemas import TurnSample +from personalization.retrieval.pipeline import retrieve_with_rerank +from personalization.models.llm.qwen_instruct import QwenInstruct +from personalization.models.embedding.base import EmbeddingModel +from personalization.models.reranker.base import Reranker +from personalization.user_model.tensor_store import UserTensorStore + +def build_turn_samples_from_sessions( + sessions: Iterable[List[ChatTurn]], + embed_model: EmbeddingModel, + llm: QwenInstruct, + reranker: Reranker, + memory_cards: List[MemoryCard], + memory_embeddings: np.ndarray, + user_store: UserTensorStore, + item_vectors: np.ndarray, + max_samples: Optional[int] = None, + topk_dense: int = 64, + topk_rerank: int = 3, +) -> List[TurnSample]: + samples = [] + + for turns in tqdm(sessions, desc="Building TurnSamples"): + if max_samples and len(samples) >= max_samples: + break + + # Ensure sorted by turn_id + sorted_turns = sorted(turns, key=lambda x: x.turn_id) + + # Iterate to find (q_t, a_t, q_{t+1}) + for i in range(len(sorted_turns)): + if max_samples and len(samples) >= max_samples: + break + + q_t = sorted_turns[i] + if q_t.role != "user": + continue + + # Find next user turn + # Also try to find assistant response in between + a_t_text = "" + q_t1 = None + + # Look ahead + for j in range(i + 1, len(sorted_turns)): + next_turn = sorted_turns[j] + if next_turn.role == "assistant" and not a_t_text: + a_t_text = next_turn.text + elif next_turn.role == "user": + q_t1 = next_turn + break + + if not q_t1: + # End of session or no subsequent user query + continue + + # We have q_t, a_t (optional but preferred), q_t1 + # If a_t is missing, we might skip or use empty string. + # For RL, we usually need the answer to evaluate quality. + # If dataset doesn't have assistant turns, we might need to generate one? + # For now, let's proceed even if a_t is empty, or maybe require it. + if not a_t_text: + # Try to use LLM to generate if needed, but for offline sampling + # from existing chats, we prefer existing answers. + # If using OASST1, it should have assistant turns. + pass + + # 3. Retrieve memories for q_t + memories_t = retrieve_with_rerank( + user_id=q_t.user_id, + query=q_t.text, + embed_model=embed_model, + reranker=reranker, + memory_cards=memory_cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=topk_dense, + topk_rerank=topk_rerank, + beta_long=0.0, + beta_short=0.0, + only_own_memories=True # Assume we want user specific memories + ) + + # 4. 
Precompute embeddings + # We can do this efficiently later or batch, but here per sample + e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0] + e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0] + + sample = TurnSample( + user_id=q_t.user_id, + session_id=q_t.session_id, + turn_id=q_t.turn_id, + query_t=q_t.text, + answer_t=a_t_text, + query_t1=q_t1.text, + memories=memories_t, + query_embedding_t=np.array(e_q_t), + query_embedding_t1=np.array(e_q_t1) + ) + samples.append(sample) + + return samples + diff --git a/src/personalization/feedback/schemas.py b/src/personalization/feedback/schemas.py new file mode 100644 index 0000000..b15db80 --- /dev/null +++ b/src/personalization/feedback/schemas.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Any +import numpy as np + +from personalization.retrieval.preference_store.schemas import MemoryCard + +@dataclass +class TurnSample: + user_id: str + session_id: str + turn_id: int # index of q_t within the session + query_t: str # q_t + answer_t: str # a_t + query_t1: str # q_{t+1} + memories: List[MemoryCard] # A_t + + # Optional pre-computed vectors and features + query_embedding_t: Optional[np.ndarray] = None + query_embedding_t1: Optional[np.ndarray] = None + memory_embeddings: Optional[np.ndarray] = None # corresponding e_m or v_m for memories + diff --git a/src/personalization/retrieval/__init__.py b/src/personalization/retrieval/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/__init__.py diff --git a/src/personalization/retrieval/chunking/__init__.py b/src/personalization/retrieval/chunking/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/chunking/__init__.py diff --git a/src/personalization/retrieval/chunking/rules.py b/src/personalization/retrieval/chunking/rules.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/chunking/rules.py diff --git a/src/personalization/retrieval/pipeline.py b/src/personalization/retrieval/pipeline.py new file mode 100644 index 0000000..3d3eeb7 --- /dev/null +++ b/src/personalization/retrieval/pipeline.py @@ -0,0 +1,250 @@ +from typing import List, Tuple +import numpy as np + +from personalization.models.embedding.base import EmbeddingModel +from personalization.models.reranker.base import Reranker +from personalization.retrieval.preference_store.schemas import MemoryCard +from personalization.user_model.tensor_store import UserTensorStore, UserState +from personalization.user_model.scoring import score_with_user +from personalization.user_model.policy.reinforce import compute_policy_scores + +def cosine_similarity_matrix(E: np.ndarray, e_q: np.ndarray) -> np.ndarray: + # E: [M, d], e_q: [d] + return np.dot(E, e_q) + +def dense_topk_indices( + query: str, + embed_model: EmbeddingModel, + memory_embeddings: np.ndarray, + valid_indices: List[int] = None, + topk: int = 64 +) -> List[int]: + """ + Return indices of topk memories based on dense embedding similarity. + If valid_indices is provided, only search within that subset. 
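+
+    Illustrative example (hypothetical data; embeddings assumed normalized):
+        idx = dense_topk_indices("pasta recipes", embed_model, E, valid_indices=[0, 5, 9], topk=2)
+        # idx holds original row indices into E, best match first, e.g. [9, 0]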
+ """ + if valid_indices is not None and len(valid_indices) == 0: + return [] + + e_q_list = embed_model.encode([query], normalize=True, return_tensor=False) + e_q = np.array(e_q_list[0], dtype=np.float32) + + # Select subset of embeddings if restricted + if valid_indices is not None: + # subset_embeddings = memory_embeddings[valid_indices] + # But valid_indices might be arbitrary. + # Efficient way: only dot product with subset + # E_sub: [M_sub, d] + E_sub = memory_embeddings[valid_indices] + sims_sub = np.dot(E_sub, e_q) + + # Topk within subset + k = min(topk, len(sims_sub)) + if k == 0: + return [] + + # argsort gives indices relative to E_sub (0..M_sub-1) + # We need to map back to original indices + idx_sub = np.argsort(sims_sub)[-k:][::-1] + + return [valid_indices[i] for i in idx_sub] + + # Global search + sims = np.dot(memory_embeddings, e_q) + k = min(topk, len(memory_embeddings)) + if k == 0: + return [] + + idx = np.argsort(sims)[-k:][::-1] + return idx.tolist() + +def retrieve_with_policy( + user_id: str, + query: str, + embed_model: EmbeddingModel, + reranker: Reranker, + memory_cards: List[MemoryCard], + memory_embeddings: np.ndarray, # shape: [M, d] + user_store: UserTensorStore, + item_vectors: np.ndarray, # shape: [M, k], v_m + topk_dense: int = 64, + topk_rerank: int = 8, + beta_long: float = 0.0, + beta_short: float = 0.0, + tau: float = 1.0, + only_own_memories: bool = False, + sample: bool = False, +) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]: + """ + Returns extended info for policy update: + (candidates, candidate_item_vectors, base_scores, chosen_indices, policy_probs) + + Args: + sample: If True, use stochastic sampling from policy distribution (for training/exploration). + If False, use deterministic top-k by policy scores (for evaluation). + """ + # 0. Filter indices if needed + valid_indices = None + if only_own_memories: + valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id] + if not valid_indices: + return [], np.array([]), np.array([]), [], np.array([]) + + # 1. Dense retrieval + dense_idx = dense_topk_indices( + query, + embed_model, + memory_embeddings, + valid_indices=valid_indices, + topk=topk_dense + ) + # DEBUG: Check for duplicates or out of bounds + if len(dense_idx) > 0: + import os + if os.getenv("RETRIEVAL_DEBUG") == "1": + print(f" [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...") + print(f" [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}") + + if not dense_idx: + return [], np.array([]), np.array([]), [], np.array([]) + + candidates = [memory_cards[i] for i in dense_idx] + candidate_docs = [c.note_text for c in candidates] + + # 2. Rerank base score (P(yes|q,m)) + base_scores = np.array(reranker.score(query, candidate_docs)) + + # 3. Policy Scoring (Softmax) + user_state: UserState = user_store.get_state(user_id) + candidate_vectors = item_vectors[dense_idx] # [K, k] + + policy_out = compute_policy_scores( + base_scores=base_scores, + user_state=user_state, + item_vectors=candidate_vectors, + beta_long=beta_long, + beta_short=beta_short, + tau=tau + ) + + # 4. 
Selection: Greedy (eval) or Stochastic (training) + k = min(topk_rerank, len(policy_out.scores)) + + if sample: + # Stochastic sampling from policy distribution (for training/exploration) + # Sample k indices without replacement, weighted by policy probs + probs = policy_out.probs + # Normalize to ensure sum to 1 (handle numerical issues) + probs = probs / (probs.sum() + 1e-10) + # Sample without replacement + chosen_indices = np.random.choice( + len(probs), size=k, replace=False, p=probs + ).tolist() + else: + # Deterministic top-k by policy scores (for evaluation) + top_indices_local = policy_out.scores.argsort()[-k:][::-1] + chosen_indices = top_indices_local.tolist() + + import os + if os.getenv("RETRIEVAL_DEBUG") == "1": + print(f" [Pipeline] Candidates: {len(candidates)} | Chosen Indices: {chosen_indices} | Sample: {sample}") + + return candidates, candidate_vectors, base_scores, chosen_indices, policy_out.probs + +def retrieve_no_policy( + user_id: str, + query: str, + embed_model: EmbeddingModel, + reranker: Reranker, + memory_cards: List[MemoryCard], + memory_embeddings: np.ndarray, # shape: [M, d] + topk_dense: int = 64, + topk_rerank: int = 8, + only_own_memories: bool = False, +) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]: + """ + Deterministic retrieval baseline (NoPersonal mode): + - Dense retrieval -> Rerank -> Top-K (no policy sampling, no user vector influence) + + Returns same structure as retrieve_with_policy for compatibility: + (candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen) + + Note: candidate_item_vectors is empty array (not used in NoPersonal mode) + The last return value is rerank scores instead of policy probs + """ + # 0. Filter indices if needed + valid_indices = None + if only_own_memories: + valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id] + if not valid_indices: + return [], np.array([]), np.array([]), [], np.array([]) + + # 1. Dense retrieval + dense_idx = dense_topk_indices( + query, + embed_model, + memory_embeddings, + valid_indices=valid_indices, + topk=topk_dense + ) + + if not dense_idx: + return [], np.array([]), np.array([]), [], np.array([]) + + candidates = [memory_cards[i] for i in dense_idx] + candidate_docs = [c.note_text for c in candidates] + + # 2. Rerank base score (P(yes|q,m)) + base_scores = np.array(reranker.score(query, candidate_docs)) + + # 3. Deterministic Top-K selection based on rerank scores ONLY (no policy) + k = min(topk_rerank, len(base_scores)) + top_indices_local = base_scores.argsort()[-k:][::-1] + chosen_indices = top_indices_local.tolist() + + # Get scores for chosen items (for logging compatibility) + chosen_scores = base_scores[top_indices_local] + + # Return empty item vectors (not used in NoPersonal mode) + # Return rerank scores as the "probs" field for logging compatibility + return candidates, np.array([]), base_scores, chosen_indices, chosen_scores + + +def retrieve_with_rerank( + user_id: str, + query: str, + embed_model: EmbeddingModel, + reranker: Reranker, + memory_cards: List[MemoryCard], + memory_embeddings: np.ndarray, # shape: [M, d] + user_store: UserTensorStore, + item_vectors: np.ndarray, # shape: [M, k], v_m + topk_dense: int = 64, + topk_rerank: int = 8, + beta_long: float = 0.0, + beta_short: float = 0.0, + only_own_memories: bool = False, +) -> List[MemoryCard]: + """ + Wrapper around retrieve_with_policy for standard inference. 
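+
+    Illustrative usage (argument values are placeholders):
+        notes = retrieve_with_rerank(user_id="u1", query="dinner ideas",
+                                     embed_model=embed_model, reranker=reranker,
+                                     memory_cards=cards, memory_embeddings=E,
+                                     user_store=store, item_vectors=V,
+                                     topk_rerank=3)
+        # Returns up to 3 MemoryCard objects, chosen greedily by policy score.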
+ """ + candidates, _, _, chosen_indices, _ = retrieve_with_policy( + user_id=user_id, + query=query, + embed_model=embed_model, + reranker=reranker, + memory_cards=memory_cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=topk_dense, + topk_rerank=topk_rerank, + beta_long=beta_long, + beta_short=beta_short, + tau=1.0, # Default tau + only_own_memories=only_own_memories + ) + + return [candidates[i] for i in chosen_indices] + + diff --git a/src/personalization/retrieval/preference_store/__init__.py b/src/personalization/retrieval/preference_store/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/preference_store/__init__.py diff --git a/src/personalization/retrieval/preference_store/base.py b/src/personalization/retrieval/preference_store/base.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/preference_store/base.py diff --git a/src/personalization/retrieval/preference_store/schemas.py b/src/personalization/retrieval/preference_store/schemas.py new file mode 100644 index 0000000..eb82558 --- /dev/null +++ b/src/personalization/retrieval/preference_store/schemas.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import List, Literal, Optional, Dict, Any + +from pydantic import BaseModel, Field, confloat + + +class Preference(BaseModel): + condition: str = Field( + ..., min_length=1, max_length=128, description="When the rule applies" + ) + action: str = Field( + ..., min_length=1, max_length=256, description="What to do in that case" + ) + confidence: confloat(ge=0.0, le=1.0) = Field( + ..., description="Confidence the rule is correct" + ) + + +class PreferenceList(BaseModel): + preferences: List[Preference] = Field(default_factory=list) + + +def preference_list_json_schema() -> dict: + return PreferenceList.model_json_schema() + + +class ChatTurn(BaseModel): + user_id: str + session_id: str + turn_id: int + role: Literal["user", "assistant"] + text: str + timestamp: Optional[float] = None + meta: Dict[str, Any] = Field(default_factory=dict) + + +class MemoryCard(BaseModel): + card_id: str + user_id: str + source_session_id: str + source_turn_ids: List[int] + raw_queries: List[str] # The original user utterances + preference_list: PreferenceList + note_text: str # Summarized "condition: action" text + embedding_e: List[float] # The embedding vector + kind: Literal["pref", "fact"] = "pref" diff --git a/src/personalization/retrieval/preference_store/vector_kv.py b/src/personalization/retrieval/preference_store/vector_kv.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/preference_store/vector_kv.py diff --git a/src/personalization/retrieval/rerank.py b/src/personalization/retrieval/rerank.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/rerank.py diff --git a/src/personalization/retrieval/store/__init__.py b/src/personalization/retrieval/store/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/store/__init__.py diff --git a/src/personalization/retrieval/store/base.py b/src/personalization/retrieval/store/base.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/store/base.py diff --git a/src/personalization/retrieval/store/faiss_store.py b/src/personalization/retrieval/store/faiss_store.py new file mode 100644 index 
0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/store/faiss_store.py diff --git a/src/personalization/retrieval/store/pgvector_store.py b/src/personalization/retrieval/store/pgvector_store.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/store/pgvector_store.py diff --git a/src/personalization/serving/__init__.py b/src/personalization/serving/__init__.py new file mode 100644 index 0000000..11adcf8 --- /dev/null +++ b/src/personalization/serving/__init__.py @@ -0,0 +1,22 @@ +# Personalization Serving Module +# +# This module provides the interface layer for the personalization system. + +from personalization.serving.personalized_llm import ( + PersonalizedLLM, + AssistantResponse, + UsageStats, + DebugInfo, + Feedback, + create_personalized_llm, +) + +__all__ = [ + "PersonalizedLLM", + "AssistantResponse", + "UsageStats", + "DebugInfo", + "Feedback", + "create_personalized_llm", +] + diff --git a/src/personalization/serving/api/__init__.py b/src/personalization/serving/api/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/serving/api/__init__.py diff --git a/src/personalization/serving/api/main.py b/src/personalization/serving/api/main.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/serving/api/main.py diff --git a/src/personalization/serving/api/routes/__init__.py b/src/personalization/serving/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/serving/api/routes/__init__.py diff --git a/src/personalization/serving/api/routes/feedback.py b/src/personalization/serving/api/routes/feedback.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/serving/api/routes/feedback.py diff --git a/src/personalization/serving/api/routes/query.py b/src/personalization/serving/api/routes/query.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/serving/api/routes/query.py diff --git a/src/personalization/serving/api/routes/users.py b/src/personalization/serving/api/routes/users.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/serving/api/routes/users.py diff --git a/src/personalization/serving/api/schemas.py b/src/personalization/serving/api/schemas.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/serving/api/schemas.py diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py new file mode 100644 index 0000000..2c4d5a8 --- /dev/null +++ b/src/personalization/serving/personalized_llm.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python3 +""" +Personalized LLM Interface for Evaluation. + +This module provides the `PersonalizedLLM` class that wraps the entire +personalization system into a clean interface for evaluation frameworks +and user simulators. 
+ +Interface contract: +- chat(user_id, query) -> AssistantResponse: Main online interface +- reset_session(user_id): Clear session history and short-term state +- reset_user(user_id): Completely reset user (long-term, short-term, memories) +- apply_feedback(feedback): Apply external feedback for RL updates +""" + +from __future__ import annotations + +import os +import sys +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import numpy as np +import yaml + +# Ensure src is in path for standalone usage +_src_path = os.path.join(os.path.dirname(__file__), "../../..") +if _src_path not in sys.path: + sys.path.insert(0, _src_path) + +from personalization.config.settings import load_local_models_config +from personalization.config.registry import get_preference_extractor, get_chat_model +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.models.reranker.qwen3_reranker import Qwen3Reranker +from personalization.user_model.tensor_store import UserTensorStore, UserState +from personalization.user_model.session_state import OnlineSessionState +from personalization.user_model.features import ItemProjection +from personalization.retrieval.preference_store.schemas import ( + MemoryCard, ChatTurn, PreferenceList, Preference +) +from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy +from personalization.feedback.handlers import eval_step +from personalization.user_model.policy.reinforce import reinforce_update_user_state + + +# ============================================================================= +# Data Classes for Interface +# ============================================================================= + +@dataclass +class UsageStats: + """Token usage statistics from a chat completion.""" + prompt_tokens: int + completion_tokens: int + total_tokens: int + model: str + + +@dataclass +class DebugInfo: + """ + Debug information for analysis and ablation studies. + All fields are optional - fill what you have, leave empty what you don't. + """ + selected_memory_ids: List[str] = field(default_factory=list) + selected_memory_notes: List[str] = field(default_factory=list) + selected_memory_scores: List[float] = field(default_factory=list) + user_vector_before: Optional[List[float]] = None + user_vector_after: Optional[List[float]] = None + extracted_preferences: List[Dict[str, Any]] = field(default_factory=list) + extra: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class AssistantResponse: + """Response from the personalized LLM chat interface.""" + answer: str + usage: UsageStats + debug: Optional[DebugInfo] = None + + +@dataclass +class Feedback: + """ + Feedback data structure for RL updates from user simulator or judge. + + Attributes: + user_id: The user this feedback is for. + turn_id: The turn this feedback refers to (from the previous turn). + reward: Reward scalar computed by user simulator / judge. + gating: Gating flag (1=valid learning signal, 0=skip update). + meta: Additional metadata for training/analysis. 
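+
+    Example:
+        Feedback(user_id="user_123", turn_id=0, reward=0.8, gating=1.0)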
+ """ + user_id: str + turn_id: int + reward: float + gating: float # Can be 0.0 or 1.0, or continuous + meta: Dict[str, Any] = field(default_factory=dict) + + +# ============================================================================= +# Internal Session State Extended +# ============================================================================= + +@dataclass +class _SessionContext: + """Extended session context for evaluation tracking.""" + session_state: OnlineSessionState + turn_counter: int = 0 + # Store info needed for apply_feedback + pending_rl_update: Optional[Dict[str, Any]] = None + + +# ============================================================================= +# PersonalizedLLM Class +# ============================================================================= + +class PersonalizedLLM: + """ + Personalized LLM wrapper for evaluation frameworks. + + This class provides a clean interface that accepts only (user_id, query) + for the main chat function, while internally managing: + - User state vectors (z_long, z_short) + - Session history + - Memory retrieval and policy + - Preference extraction and storage + - RL updates + + Example usage: + llm = PersonalizedLLM() + + # Reset user for fresh experiment + llm.reset_user("user_123") + + # Start a session + llm.reset_session("user_123") + + # Chat + response = llm.chat("user_123", "What's a good recipe for dinner?") + print(response.answer) + + # Apply feedback from previous turn (from turn 2 onwards) + llm.apply_feedback(Feedback( + user_id="user_123", + turn_id=0, + reward=0.8, + gating=1.0 + )) + """ + + def __init__( + self, + config_path: Optional[str] = None, + user_store_path: str = "data/users/user_store_eval.npz", + memory_cards_path: str = "data/corpora/memory_cards.jsonl", + memory_embeddings_path: str = "data/corpora/memory_embeddings.npy", + item_projection_path: str = "data/corpora/item_projection.npz", + only_own_memories: bool = True, + enable_preference_extraction: bool = True, + enable_rl_updates: bool = True, + mode: str = "full", # "full", "nopersonal", or "vanilla" + eval_mode: bool = True, # True = greedy selection, False = stochastic sampling + device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support + ): + """ + Initialize the PersonalizedLLM. + + Args: + config_path: Path to config file. If None, uses default locations. + user_store_path: Path to persist user state vectors. + memory_cards_path: Path to memory cards JSONL file. + memory_embeddings_path: Path to memory embeddings numpy file. + item_projection_path: Path to item projection (PCA) file. + only_own_memories: If True, only retrieve user's own memories (strict privacy). + enable_preference_extraction: If True, extract preferences from user turns. + enable_rl_updates: If True, apply RL updates via apply_feedback. + mode: "full" for full personalization, "nopersonal" for baseline (no user vector influence), + "vanilla" for pure LLM without any memory retrieval or preference extraction. + eval_mode: If True, use greedy/deterministic selection (for evaluation). + If False, use stochastic sampling (for training/exploration). + device_assignment: Optional dict to assign models to specific GPUs. + Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"} + If None, uses "auto" for all models. 
+ """ + self.only_own_memories = only_own_memories + self.enable_preference_extraction = enable_preference_extraction + self.enable_rl_updates = enable_rl_updates + self.mode = mode # "full" or "nopersonal" + self.eval_mode = eval_mode # True = greedy, False = sample + + # Multi-GPU device assignment + self._device_assignment = device_assignment or { + "embed": "auto", + "reranker": "auto", + "chat": "auto", + "extractor": "auto", + } + + # Paths + self._memory_cards_path = memory_cards_path + self._memory_embeddings_path = memory_embeddings_path + self._item_projection_path = item_projection_path + + # RL Configuration + # Note: beta/eta increased for more significant z_u updates + self._rl_cfg = { + "item_dim": 256, + "beta_long": 2.0, # Increased from 0.1 for stronger personalization + "beta_short": 5.0, # Increased from 0.3 + "tau": 1.0, + "eta_long": 0.01, # Increased from 1e-3 for faster learning + "eta_short": 0.05, # Increased from 5e-3 + "ema_alpha": 0.05, + "short_decay": 0.1, + "dense_topk": 64, + "rerank_topk": 3, + "max_new_tokens": 512, + } + + # Load config and override RL params if available + self._load_config(config_path) + + # Load models + print("[PersonalizedLLM] Loading models...") + self._load_models() + + # Load memory store + print("[PersonalizedLLM] Loading memory store...") + self._load_memory_store() + + # Initialize user store + self._user_store = UserTensorStore( + k=self._rl_cfg["item_dim"], + path=user_store_path, + ) + + # Session contexts per user (in-memory) + self._sessions: Dict[str, _SessionContext] = {} + + print("[PersonalizedLLM] Initialization complete.") + + def _load_config(self, config_path: Optional[str]): + """Load configuration from yaml files.""" + self._cfg = load_local_models_config() + + # Try to load user_model.yaml for RL params + if config_path is None: + config_path = "configs/user_model.yaml" + + self._llm_name = "qwen_1_5b" # Default + + try: + if os.path.exists(config_path): + with open(config_path, "r") as f: + user_cfg = yaml.safe_load(f) + if user_cfg: + # Override RL params if present + for key in self._rl_cfg: + if key in user_cfg: + self._rl_cfg[key] = user_cfg[key] + # LLM name + if "llm_name" in user_cfg: + self._llm_name = user_cfg["llm_name"] + except Exception as e: + print(f"[PersonalizedLLM] Warning: Failed to load config: {e}") + + def _load_models(self): + """Load all ML models with optional multi-GPU assignment.""" + import torch + + # Report GPU availability + num_gpus = torch.cuda.device_count() + print(f"[PersonalizedLLM] Available GPUs: {num_gpus}") + for i in range(num_gpus): + mem = torch.cuda.get_device_properties(i).total_memory / 1e9 + print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)") + + embed_device = self._device_assignment.get("embed", "auto") + reranker_device = self._device_assignment.get("reranker", "auto") + chat_device = self._device_assignment.get("chat", "auto") + extractor_device = self._device_assignment.get("extractor", "auto") + + # Embedding model + print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...") + self._embed_model = Qwen3Embedding8B( + model_path=self._cfg.embedding.qwen3.local_path, + dtype=torch.bfloat16, + device_map=embed_device, + ) + + # Reranker + print(f"[PersonalizedLLM] Loading Reranker on {reranker_device}...") + self._reranker = Qwen3Reranker( + model_path=self._cfg.reranker.qwen3_8b.local_path, + device_map=reranker_device, + dtype=torch.bfloat16, + ) + + # Chat model (via registry for backend switching) + print(f"[PersonalizedLLM] 
Loading ChatModel: {self._llm_name} on {chat_device}...") + # Pass device override if specified (not "auto") + device_for_chat = chat_device if chat_device != "auto" else None + self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat) + + # Preference extractor + if self.enable_preference_extraction: + extractor_name = "qwen3_0_6b_sft" + print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...") + try: + self._extractor = get_preference_extractor(extractor_name) + except Exception as e: + print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.") + self._extractor = get_preference_extractor("rule") + else: + print("[PersonalizedLLM] Preference extraction disabled, using rule-based extractor.") + self._extractor = get_preference_extractor("rule") + + def _load_memory_store(self): + """Load memory cards and embeddings.""" + if not os.path.exists(self._memory_cards_path): + print(f"[PersonalizedLLM] Warning: Memory cards not found at {self._memory_cards_path}") + self._memory_cards: List[MemoryCard] = [] + self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32) + self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + self._projection = None + return + + # Load cards + self._memory_cards = [] + with open(self._memory_cards_path, "r") as f: + for line in f: + line = line.strip() + if line: + self._memory_cards.append(MemoryCard.model_validate_json(line)) + + # Load embeddings + if os.path.exists(self._memory_embeddings_path): + self._memory_embeddings = np.load(self._memory_embeddings_path) + else: + self._memory_embeddings = np.zeros((len(self._memory_cards), 4096), dtype=np.float32) + + # Load projection + if os.path.exists(self._item_projection_path): + proj_data = np.load(self._item_projection_path) + self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"]) + self._item_vectors = proj_data["V"] + else: + self._projection = None + self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32) + + print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.") + + def _get_or_create_session(self, user_id: str) -> _SessionContext: + """Get or create session context for a user.""" + if user_id not in self._sessions: + self._sessions[user_id] = _SessionContext( + session_state=OnlineSessionState(user_id=user_id), + turn_counter=0, + ) + return self._sessions[user_id] + + def _build_chat_turn(self, user_id: str, text: str, role: str, turn_id: int) -> ChatTurn: + """Build a ChatTurn object.""" + return ChatTurn( + user_id=user_id, + session_id=f"eval_session_{user_id}", + turn_id=turn_id, + role=role, + text=text, + meta={"source": "eval"} + ) + + def _count_tokens(self, text: str) -> int: + """Estimate token count using the tokenizer.""" + try: + # Use the chat model's tokenizer if available + if hasattr(self._chat_model, 'tokenizer'): + return len(self._chat_model.tokenizer.encode(text)) + else: + # Rough estimate: ~4 chars per token + return len(text) // 4 + except Exception: + return len(text) // 4 + + def _add_preferences_as_memory( + self, + prefs: PreferenceList, + query: str, + user_id: str, + turn_id: int, + ) -> List[Dict[str, Any]]: + """ + Add extracted preferences as new memory cards. + Returns list of preference dicts for debug info. 
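+
+        For example (hypothetical preference), Preference(condition="asking for
+        recipes", action="keep them vegetarian", confidence=0.9) becomes the
+        note text "When asking for recipes, keep them vegetarian."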
+ """ + extracted = [] + + if not prefs.preferences or self._projection is None: + return extracted + + # Compute embedding for the query + e_q = self._embed_model.encode([query], return_tensor=False)[0] + v_q = self._projection.transform_vector(np.array(e_q)) + + for pref in prefs.preferences: + note_text = f"When {pref.condition}, {pref.action}." + + # Record for debug + extracted.append({ + "condition": pref.condition, + "action": pref.action, + "confidence": pref.confidence, + }) + + # Deduplication check + is_duplicate = any( + card.user_id == user_id and card.note_text == note_text + for card in self._memory_cards + ) + + if is_duplicate: + continue + + # Create new memory card + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id=f"eval_session_{user_id}", + source_turn_ids=[turn_id], + raw_queries=[query], + preference_list=PreferenceList(preferences=[pref]), + note_text=note_text, + embedding_e=list(e_q), + kind="pref", + ) + + # Add to memory store + self._memory_cards.append(card) + self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_q])]) + self._item_vectors = np.vstack([self._item_vectors, np.array([v_q])]) + + return extracted + + # ========================================================================= + # Public Interface + # ========================================================================= + + def chat(self, user_id: str, query: str) -> AssistantResponse: + """ + Main online chat interface. + + Args: + user_id: Unique identifier for the user. + query: Current user query/message. + + Returns: + AssistantResponse containing the answer, usage stats, and debug info. + + Notes: + - Internally manages user state, session history, memory retrieval + - After this call, you can call apply_feedback() with the turn's feedback + """ + ctx = self._get_or_create_session(user_id) + session = ctx.session_state + user_state = self._user_store.get_state(user_id) + + # Record user vector before for debug + z_long_before = user_state.z_long.copy().tolist() + z_short_before = user_state.z_short.copy().tolist() + + # Compute query embedding + e_q_t = np.array(self._embed_model.encode([query], return_tensor=False)[0]) + + # Store pending RL update info from last turn (for apply_feedback) + if session.last_query is not None and self.enable_rl_updates: + ctx.pending_rl_update = { + "last_query": session.last_query, + "last_answer": session.last_answer, + "last_memories": session.last_memories, + "last_query_embedding": session.last_query_embedding, + "current_query_embedding": e_q_t, + "last_candidate_item_vectors": session.last_candidate_item_vectors, + "last_policy_probs": session.last_policy_probs, + "last_chosen_indices": session.last_chosen_indices, + } + + # Add user turn to history + user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter) + session.history.append(user_turn) + + # Vanilla mode: pure LLM without any memory or preference extraction + if self.mode == "vanilla": + # Skip preference extraction and memory retrieval entirely + extracted_prefs = [] + candidates = [] + cand_item_vecs = np.array([]) + base_scores = np.array([]) + chosen_indices = [] + probs = np.array([]) + memories_t = [] + memory_notes = [] + else: + # Extract preferences from conversation (if enabled) + extracted_prefs = [] + if self.enable_preference_extraction: + prefs = self._extractor.extract_turn(session.history) + extracted_prefs = self._add_preferences_as_memory( + prefs, query, user_id, ctx.turn_counter + ) + + # 
Retrieve memories + # In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector + # In "full" mode: policy-based retrieval with user vector influence + if self.mode == "nopersonal": + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=self._memory_cards, + memory_embeddings=self._memory_embeddings, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + only_own_memories=self.only_own_memories, + ) + else: + beta_long = self._rl_cfg["beta_long"] + beta_short = self._rl_cfg["beta_short"] + # eval_mode=True -> sample=False (greedy/deterministic) + # eval_mode=False -> sample=True (stochastic/exploration) + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=self._memory_cards, + memory_embeddings=self._memory_embeddings, + user_store=self._user_store, + item_vectors=self._item_vectors, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + beta_long=beta_long, + beta_short=beta_short, + tau=self._rl_cfg["tau"], + only_own_memories=self.only_own_memories, + sample=not self.eval_mode, + ) + + # Get selected memories + memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] + memory_notes = [m.note_text for m in memories_t] + + # Build prompt and count tokens + prompt_tokens = self._count_tokens(query) + for turn in session.history: + prompt_tokens += self._count_tokens(turn.text) + for note in memory_notes: + prompt_tokens += self._count_tokens(note) + + # Generate answer + answer_t = self._chat_model.answer( + history=session.history, + memory_notes=memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + ) + + completion_tokens = self._count_tokens(answer_t) + + # Add assistant turn to history + assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter) + session.history.append(assist_turn) + + # Update session state for next turn + session.last_query = query + session.last_answer = answer_t + session.last_memories = memories_t + session.last_query_embedding = e_q_t + session.last_candidate_item_vectors = cand_item_vecs + session.last_policy_probs = probs + session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else [] + + ctx.turn_counter += 1 + + # Build debug info + debug = DebugInfo( + selected_memory_ids=[m.card_id for m in memories_t], + selected_memory_notes=[m.note_text for m in memories_t], + selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [], + user_vector_before=z_long_before + z_short_before, # Concatenated for simplicity + user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(), + extracted_preferences=extracted_prefs, + extra={ + "num_candidates": len(candidates), + "num_total_memories": len(self._memory_cards), + "z_long_norm": float(np.linalg.norm(user_state.z_long)), + "z_short_norm": float(np.linalg.norm(user_state.z_short)), + } + ) + + # Build usage stats + usage = UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + model=self._llm_name, + ) + + return AssistantResponse( + answer=answer_t, + usage=usage, + debug=debug, + ) + + def 
reset_session(self, user_id: str) -> None: + """ + Reset session for a user (new chat window). + + This clears: + - Session conversation history + - Short-term user vector (z_short) + - Pending RL update info + + This preserves: + - Long-term user vector (z_long) + - User's memory cards + + Args: + user_id: The user whose session to reset. + """ + # Clear session context + if user_id in self._sessions: + del self._sessions[user_id] + + # Create fresh session + self._sessions[user_id] = _SessionContext( + session_state=OnlineSessionState(user_id=user_id), + turn_counter=0, + ) + + # Reset short-term vector but keep long-term + user_state = self._user_store.get_state(user_id) + user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32) + self._user_store.save_state(user_state) + + def reset_user(self, user_id: str) -> None: + """ + Completely reset a user (new "life"). + + This clears: + - Long-term user vector (z_long) + - Short-term user vector (z_short) + - User's memory cards + - Session history + - All cached state + + Args: + user_id: The user to reset. + """ + # Clear session + if user_id in self._sessions: + del self._sessions[user_id] + + # Reset user state vectors + user_state = self._user_store.get_state(user_id) + user_state.z_long = self._user_store.global_init_z.copy() + user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32) + user_state.reward_ma = 0.0 + self._user_store.save_state(user_state) + + # Find indices to KEEP (cards NOT belonging to this user) + # Must do this BEFORE modifying _memory_cards + keep_indices = [ + i for i, card in enumerate(self._memory_cards) + if card.user_id != user_id + ] + + # Filter memory cards + self._memory_cards = [self._memory_cards[i] for i in keep_indices] + + # Filter embeddings and item vectors to match + if len(keep_indices) > 0 and len(self._memory_embeddings) > 0: + self._memory_embeddings = self._memory_embeddings[keep_indices] + self._item_vectors = self._item_vectors[keep_indices] + else: + # No cards left or no embeddings + embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096 + self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32) + self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + def apply_feedback(self, feedback: Feedback) -> None: + """ + Apply feedback from user simulator or judge. + + This performs the REINFORCE update to user vectors based on + the reward signal from the previous turn. + + Args: + feedback: Feedback object containing reward, gating, and metadata. 
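+
+        Example (values hypothetical):
+            llm.apply_feedback(Feedback(user_id="u1", turn_id=0, reward=-1.0, gating=1.0))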
+    def apply_feedback(self, feedback: Feedback) -> None:
+        """
+        Apply feedback from the user simulator or judge.
+
+        This performs the REINFORCE update to the user vectors based on
+        the reward signal from the previous turn.
+
+        Args:
+            feedback: Feedback object containing reward, gating, and metadata.
+
+        Notes:
+            - Should be called AFTER chat() but BEFORE the next chat() call.
+            - Uses the stored context from the previous turn.
+            - If enable_rl_updates is False, this is a no-op (logging only).
+            - In "nopersonal" or "vanilla" mode, this is a no-op (baseline comparison).
+        """
+        if not self.enable_rl_updates:
+            return
+
+        # The "nopersonal" and "vanilla" baselines skip RL updates entirely
+        if self.mode in ("nopersonal", "vanilla"):
+            return
+
+        user_id = feedback.user_id
+        ctx = self._sessions.get(user_id)
+
+        if ctx is None or ctx.pending_rl_update is None:
+            return
+
+        pending = ctx.pending_rl_update
+        user_state = self._user_store.get_state(user_id)
+
+        # Check that the previous turn stored everything the RL update needs
+        if (pending.get("last_candidate_item_vectors") is not None and
+                pending.get("last_policy_probs") is not None and
+                pending.get("last_chosen_indices") is not None and
+                len(pending["last_chosen_indices"]) > 0):
+
+            chosen_indices = pending["last_chosen_indices"]
+            candidate_vectors = pending["last_candidate_item_vectors"]
+
+            if len(candidate_vectors) > 0:
+                # REINFORCE expects:
+                # - item_vectors: ALL candidate vectors [K, k]
+                # - chosen_indices: indices into those candidates
+                # - policy_probs: probabilities over all K candidates [K]
+                updated = reinforce_update_user_state(
+                    user_state=user_state,
+                    item_vectors=candidate_vectors,   # all candidates, not just the chosen ones
+                    chosen_indices=chosen_indices,    # original indices into the candidates
+                    policy_probs=pending["last_policy_probs"],
+                    reward_hat=feedback.reward,
+                    gating=feedback.gating,
+                    tau=self._rl_cfg["tau"],
+                    eta_long=self._rl_cfg["eta_long"],
+                    eta_short=self._rl_cfg["eta_short"],
+                    ema_alpha=self._rl_cfg["ema_alpha"],
+                    short_decay=self._rl_cfg["short_decay"],
+                )
+
+                if updated:
+                    self._user_store.save_state(user_state)
+
+        # Clear the pending update
+        ctx.pending_rl_update = None
+
+    def get_user_state_summary(self, user_id: str) -> Dict[str, Any]:
+        """
+        Get a summary of the user's current state (for debugging/analysis).
+
+        Args:
+            user_id: The user to query.
+
+        Returns:
+            Dictionary with user state information.
+        """
+        user_state = self._user_store.get_state(user_id)
+        ctx = self._sessions.get(user_id)
+
+        user_memory_count = sum(
+            1 for card in self._memory_cards if card.user_id == user_id
+        )
+
+        return {
+            "user_id": user_id,
+            "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+            "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+            "reward_ma": user_state.reward_ma,
+            "session_history_length": len(ctx.session_state.history) if ctx else 0,
+            "turn_counter": ctx.turn_counter if ctx else 0,
+            "user_memory_count": user_memory_count,
+            "total_memory_count": len(self._memory_cards),
+        }
+
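+    # Illustrative return value (every number below is made up):
+    # {"user_id": "u1", "z_long_norm": 0.82, "z_short_norm": 0.10,
+    #  "reward_ma": 0.34, "session_history_length": 6, "turn_counter": 3,
+    #  "user_memory_count": 12, "total_memory_count": 480}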
+    def persist(self) -> None:
+        """
+        Persist all state to disk.
+
+        Call this at the end of an evaluation run to save:
+        - User state vectors
+        - Memory cards
+        - Memory embeddings and the item projection
+        """
+        # Save the user store
+        self._user_store.persist()
+
+        # Save the memory cards (one JSON object per line)
+        with open(self._memory_cards_path, "w", encoding="utf-8") as f:
+            for card in self._memory_cards:
+                f.write(card.model_dump_json() + "\n")
+
+        # Save the embeddings
+        np.save(self._memory_embeddings_path, self._memory_embeddings)
+
+        # Save the item projection together with the updated item vectors
+        if self._projection is not None:
+            np.savez(
+                self._item_projection_path,
+                P=self._projection.P,
+                mean=self._projection.mean,
+                V=self._item_vectors,
+            )
+
+        print("[PersonalizedLLM] State persisted to disk.")
+
+
+# =============================================================================
+# Convenience Factory
+# =============================================================================
+
+def create_personalized_llm(
+    config_path: Optional[str] = None,
+    **kwargs
+) -> PersonalizedLLM:
+    """
+    Factory function to create a PersonalizedLLM instance.
+
+    Args:
+        config_path: Optional path to a configuration file.
+        **kwargs: Additional arguments passed to the PersonalizedLLM constructor.
+
+    Returns:
+        Configured PersonalizedLLM instance.
+    """
+    return PersonalizedLLM(config_path=config_path, **kwargs)
+
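A minimal driver loop for this class might look like the sketch below. The `chat(user_id, query)` signature and the `Feedback(user_id, reward, gating)` constructor are assumptions inferred from the code above, not guaranteed by this commit:

    llm = create_personalized_llm(config_path="configs/base.yaml")    # path is illustrative
    resp = llm.chat(user_id="u1", query="Plan a low-sodium dinner.")  # assumed signature
    print(resp.answer, resp.usage.total_tokens)
    # Reward from a simulator/judge; must arrive before the next chat() call
    llm.apply_feedback(Feedback(user_id="u1", reward=0.8, gating=1.0))
    llm.persist()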
diff --git a/src/personalization/types.py b/src/personalization/types.py
new file mode 100644
index 0000000..a25b560
--- /dev/null
+++ b/src/personalization/types.py
@@ -0,0 +1,4 @@
+from personalization.retrieval.preference_store.schemas import ChatTurn
+
+__all__ = ["ChatTurn"]
+
diff --git a/src/personalization/user_model/__init__.py b/src/personalization/user_model/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/__init__.py
diff --git a/src/personalization/user_model/features.py b/src/personalization/user_model/features.py
new file mode 100644
index 0000000..a4508b4
--- /dev/null
+++ b/src/personalization/user_model/features.py
@@ -0,0 +1,49 @@
+import numpy as np
+from dataclasses import dataclass
+from sklearn.decomposition import PCA
+
+@dataclass
+class ItemProjection:
+    P: np.ndarray     # [k, d]
+    mean: np.ndarray  # [d]
+
+    @classmethod
+    def from_pca(cls, embeddings: np.ndarray, k: int) -> "ItemProjection":
+        """
+        embeddings: [M, d]
+        """
+        mean = embeddings.mean(axis=0)
+        centered = embeddings - mean
+
+        # n_components cannot exceed min(n_samples, n_features)
+        n_samples, n_features = embeddings.shape
+        actual_k = min(k, n_samples, n_features)
+
+        pca = PCA(n_components=actual_k)
+        pca.fit(centered)
+
+        # pca.components_: [actual_k, d]; each row is a principal component
+        P = pca.components_
+
+        # The rest of the system expects a fixed projection dimension k.
+        # If small data forced actual_k < k, pad P with zero rows so
+        # downstream code always sees a [k, d] projection.
+        if actual_k < k:
+            padding = np.zeros((k - actual_k, n_features), dtype=P.dtype)
+            P = np.vstack([P, padding])
+
+        return cls(P=P, mean=mean)
+
+    def transform_embeddings(self, E: np.ndarray) -> np.ndarray:
+        """
+        E: [N, d] -> [N, k]
+        """
+        return (E - self.mean) @ self.P.T
+
+    def transform_vector(self, e: np.ndarray) -> np.ndarray:
+        """
+        e: [d] -> [k]
+        """
+        return self.P @ (e - self.mean)
+
diff --git a/src/personalization/user_model/policy/__init__.py b/src/personalization/user_model/policy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/policy/__init__.py
diff --git a/src/personalization/user_model/policy/optimizer.py b/src/personalization/user_model/policy/optimizer.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/policy/optimizer.py
diff --git a/src/personalization/user_model/policy/reinforce.py b/src/personalization/user_model/policy/reinforce.py
new file mode 100644
index 0000000..adfaef7
--- /dev/null
+++ b/src/personalization/user_model/policy/reinforce.py
@@ -0,0 +1,104 @@
+from typing import Sequence
+from dataclasses import dataclass
+import numpy as np
+
+from personalization.user_model.tensor_store import UserState
+
+@dataclass
+class PolicyScores:
+    scores: np.ndarray  # [K] s(q_t, m; u)
+    probs: np.ndarray   # [K] π_z(m|q_t)
+
+def compute_policy_scores(
+    base_scores: np.ndarray,   # [K], from the reranker
+    user_state: UserState,
+    item_vectors: np.ndarray,  # [K, k], v_m for the K candidates
+    beta_long: float,
+    beta_short: float,
+    tau: float,
+) -> PolicyScores:
+    """
+    Compute personalized scores and softmax probabilities:
+        s(q_t, m; u) = s_0(q_t, m) + z_t^{(eff)} . v_m
+        z_t^{(eff)}  = beta_long * z_long + beta_short * z_short
+    """
+    if len(item_vectors) == 0:
+        return PolicyScores(scores=np.array([]), probs=np.array([]))
+
+    z_eff = beta_long * user_state.z_long + beta_short * user_state.z_short
+
+    # Personalization term: item_vectors [K, k] @ z_eff [k] -> [K]
+    personalization_term = np.dot(item_vectors, z_eff)
+
+    # Total scores
+    scores = base_scores + personalization_term
+
+    # Temperature softmax over exp(score / tau); subtract the max for stability
+    scaled_scores = scores / tau
+    exp_scores = np.exp(scaled_scores - np.max(scaled_scores))
+    probs = exp_scores / np.sum(exp_scores)
+
+    return PolicyScores(scores=scores, probs=probs)
+
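+# Hedged numeric sketch of compute_policy_scores (numbers are made up):
+# with base_scores [1.0, 0.8], personalization terms z_eff.v_m of
+# [+0.4, -0.2], and tau = 0.7, the personalized scores are [1.4, 0.6];
+# the softmax over [2.0, 0.86] gives probs ≈ [0.76, 0.24], shifting mass
+# toward the first candidate without discarding the second.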
+def reinforce_update_user_state(
+    user_state: UserState,
+    item_vectors: np.ndarray,       # [K, k] for the candidates
+    chosen_indices: Sequence[int],  # indices of A_t in 0..K-1
+    policy_probs: np.ndarray,       # [K] π_z(m|q_t)
+    reward_hat: float,              # \hat r_t
+    gating: float,                  # g_t
+    tau: float,
+    eta_long: float,
+    eta_short: float,
+    ema_alpha: float,
+    short_decay: float,
+) -> bool:
+    """
+    In-place REINFORCE update of user_state.z_long / z_short / reward_ma.
+    Returns True if an update occurred, False otherwise.
+    """
+    if len(chosen_indices) == 0:
+        return False
+
+    # 1. Baseline advantage: A_t = g_t * (\hat r_t - reward_ma)
+    advantage = gating * (reward_hat - user_state.reward_ma)
+
+    # Optimization: skip the update (including the baseline EMA) when the
+    # advantage is negligible
+    if abs(advantage) < 1e-6:
+        return False
+
+    # 2. Average of the chosen item vectors (v_{chosen,t})
+    chosen_mask = np.zeros(len(item_vectors), dtype=np.float32)
+    for idx in chosen_indices:
+        idx_int = int(idx)
+        if 0 <= idx_int < len(item_vectors):
+            chosen_mask[idx_int] = 1.0
+
+    if chosen_mask.sum() == 0:
+        return False
+
+    chosen_mask /= chosen_mask.sum()  # normalize to an average
+    v_chosen = np.dot(chosen_mask, item_vectors)  # [k]
+
+    # 3. Expected vector under the policy (\mu_t(z)):
+    #    policy_probs [K] @ item_vectors [K, k] -> [k]
+    v_expect = np.dot(policy_probs, item_vectors)
+
+    # 4. Gradient direction: (A_t / tau) * (v_chosen - \mu_t(z))
+    grad = (advantage / tau) * (v_chosen - v_expect)
+
+    # 5. Update the vectors: z_long takes a plain step; z_short decays first
+    user_state.z_long += eta_long * grad
+    user_state.z_short = (1.0 - short_decay) * user_state.z_short + eta_short * grad
+
+    # 6. Update the reward baseline (EMA)
+    user_state.reward_ma = (1.0 - ema_alpha) * user_state.reward_ma + ema_alpha * reward_hat
+
+    return True
+
diff --git a/src/personalization/user_model/scoring.py b/src/personalization/user_model/scoring.py
new file mode 100644
index 0000000..75ffc84
--- /dev/null
+++ b/src/personalization/user_model/scoring.py
@@ -0,0 +1,25 @@
+import numpy as np
+from .tensor_store import UserState
+
+def score_with_user(
+    base_score: float,
+    user_state: UserState,
+    v_m: np.ndarray,  # [k]
+    beta_long: float,
+    beta_short: float,
+) -> float:
+    """
+    Personalized scoring:
+        s = base_score + (beta_long * z_long + beta_short * z_short) . v_m
+    Day 2: beta_long = beta_short = 0, so s == base_score.
+    """
+    z_eff = beta_long * user_state.z_long + beta_short * user_state.z_short
+
+    # Guard against a dimension mismatch; fall back to the
+    # unpersonalized score rather than raising
+    if v_m.shape != z_eff.shape:
+        return float(base_score)
+
+    return float(base_score + np.dot(z_eff, v_m))
+
diff --git a/src/personalization/user_model/session_state.py b/src/personalization/user_model/session_state.py
new file mode 100644
index 0000000..5cd2243
--- /dev/null
+++ b/src/personalization/user_model/session_state.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard
+
+@dataclass
+class OnlineSessionState:
+    user_id: str
+    history: List[ChatTurn] = field(default_factory=list)
+    last_query: Optional[str] = None
+    last_answer: Optional[str] = None
+    last_memories: List[MemoryCard] = field(default_factory=list)
+    last_query_embedding: Optional[np.ndarray] = None
+    last_candidate_item_vectors: Optional[np.ndarray] = None  # [K, k]
+    last_policy_probs: Optional[np.ndarray] = None            # [K]
+    last_chosen_indices: List[int] = field(default_factory=list)
+
diff --git a/src/personalization/user_model/tensor_store.py b/src/personalization/user_model/tensor_store.py
new file mode 100644
index 0000000..42dbf4e
--- /dev/null
+++ b/src/personalization/user_model/tensor_store.py
@@ -0,0 +1,80 @@
+import numpy as np
+from dataclasses import dataclass
+from typing import Dict
+import os
+
+@dataclass
+class UserState:
+    user_id: str
+    z_long: np.ndarray   # [k]
+    z_short: np.ndarray  # [k]
+    reward_ma: float     # reward baseline, initialized to 0.0
+
+class UserTensorStore:
+    def __init__(self, k: int, path: str):
+        self.k = k
+        self.path = path
+        self._states: Dict[str, UserState] = {}
+        self._load()
+
+        # New users are initialized from the global mean of existing z_long vectors
+        if self._states:
+            z_all = np.stack([st.z_long for st in self._states.values()])
+            self.global_init_z = np.mean(z_all, axis=0)
+        else:
+            self.global_init_z = np.zeros(self.k, dtype=np.float32)
+
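+    # On-disk npz layout sketch (keys for an illustrative user id "u1"):
+    #   "u1_long"  -> z_long   [k] float32
+    #   "u1_short" -> z_short  [k] float32
+    #   "u1_meta"  -> np.array([reward_ma])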
+    def _load(self):
+        if not os.path.exists(self.path):
+            return
+        try:
+            data = np.load(self.path, allow_pickle=True)
+            # Simple npz schema, one triple of keys per user:
+            #   "{uid}_long", "{uid}_short", "{uid}_meta" (meta = [reward_ma])
+            for key in data.files:
+                if key.endswith("_long"):
+                    uid = key[:-5]
+                    z_long = data[key]
+                    z_short = data.get(f"{uid}_short", np.zeros(self.k, dtype=np.float32))
+                    meta = data.get(f"{uid}_meta", np.array([0.0]))
+                    self._states[uid] = UserState(uid, z_long, z_short, float(meta[0]))
+        except Exception as e:
+            print(f"Warning: Failed to load UserTensorStore from {self.path}: {e}")
+
+    def _save(self):
+        # Note: np.savez appends ".npz" to the filename if it is missing,
+        # so self.path should already carry the extension to keep _load in sync.
+        save_dict = {}
+        for uid, state in self._states.items():
+            save_dict[f"{uid}_long"] = state.z_long
+            save_dict[f"{uid}_short"] = state.z_short
+            save_dict[f"{uid}_meta"] = np.array([state.reward_ma])
+        np.savez(self.path, **save_dict)
+
+    def get_state(self, user_id: str) -> UserState:
+        if user_id not in self._states:
+            # Lazily initialize new users from the global mean
+            state = UserState(
+                user_id=user_id,
+                z_long=self.global_init_z.copy(),
+                z_short=np.zeros(self.k, dtype=np.float32),
+                reward_ma=0.0,
+            )
+            self._states[user_id] = state
+        return self._states[user_id]
+
+    def save_state(self, state: UserState) -> None:
+        self._states[state.user_id] = state
+
+    def persist(self):
+        """Public method to force a save to disk."""
+        self._save()
+
diff --git a/src/personalization/utils/__init__.py b/src/personalization/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/__init__.py
diff --git a/src/personalization/utils/ids.py b/src/personalization/utils/ids.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/ids.py
diff --git a/src/personalization/utils/io.py b/src/personalization/utils/io.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/io.py
diff --git a/src/personalization/utils/logging.py b/src/personalization/utils/logging.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/logging.py
diff --git a/src/personalization/utils/timing.py b/src/personalization/utils/timing.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/timing.py
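As a quick round-trip check of UserTensorStore, a minimal sketch (the k value and path below are illustrative, not from this commit):

    from personalization.user_model.tensor_store import UserTensorStore

    store = UserTensorStore(k=32, path="saves/user_store.npz")
    state = store.get_state("u1")   # lazily initialized from the global mean
    state.z_short[:] = 0.0          # e.g. a session reset
    store.save_state(state)
    store.persist()                 # writes the npz snapshot to disk

    reloaded = UserTensorStore(k=32, path="saves/user_store.npz")
    assert reloaded.get_state("u1").reward_ma == 0.0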
