author    YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
commit    e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree      6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5
Initial commit (clean history) (HEAD -> main)
-rw-r--r--  .env.example  0
-rw-r--r--  .gitignore  35
-rw-r--r--  README.md  0
-rw-r--r--  configs/base.yaml  0
-rw-r--r--  configs/local_models.yaml  50
-rw-r--r--  configs/qwen2.5_0.5b_full_sft.yaml  34
-rw-r--r--  configs/qwen2.5_1.5b_full_sft.yaml  33
-rw-r--r--  configs/qwen3_0.6b_full_sft.yaml  35
-rw-r--r--  configs/qwen3_1.7b_full_sft.yaml  34
-rw-r--r--  configs/reranker.yaml  3
-rw-r--r--  configs/retrieval.yaml  5
-rw-r--r--  configs/user_model.yaml  14
-rw-r--r--  fine_tuning_prompt_template.txt  31
-rw-r--r--  pyproject.toml  38
-rw-r--r--  requirements.txt  9
-rw-r--r--  scripts/analyze_full_vs_nopersonal.py  361
-rw-r--r--  scripts/analyze_learning_trend.py  521
-rw-r--r--  scripts/analyze_memory.py  61
-rw-r--r--  scripts/analyze_memory_coverage.py  103
-rw-r--r--  scripts/analyze_user_similarity.py  445
-rw-r--r--  scripts/assemble_dataset.py  85
-rw-r--r--  scripts/build_item_space.py  52
-rw-r--r--  scripts/check_batch_status.py  60
-rw-r--r--  scripts/clean_memory_store.py  52
-rw-r--r--  scripts/convert_to_llama_factory.py  62
-rw-r--r--  scripts/day1_demo.py  120
-rw-r--r--  scripts/day2_demo.py  162
-rw-r--r--  scripts/day3_demo_feedback.py  127
-rw-r--r--  scripts/day4_offline_rl_replay.py  208
-rw-r--r--  scripts/debug_context_file.py  14
-rw-r--r--  scripts/debug_minimal_day3.py  40
-rw-r--r--  scripts/debug_personamem_hash.py  22
-rw-r--r--  scripts/diagnose_oom.py  78
-rw-r--r--  scripts/download_datasets.py  210
-rw-r--r--  scripts/download_llama.py  16
-rw-r--r--  scripts/download_oasst1.py  78
-rw-r--r--  scripts/download_personamem.py  25
-rw-r--r--  scripts/eval_embedder_reranker.py  0
-rw-r--r--  scripts/eval_interface_example.py  154
-rw-r--r--  scripts/eval_single_ckpt.py  145
-rw-r--r--  scripts/evaluate_checkpoints.py  205
-rw-r--r--  scripts/finish_retry_batches.py  154
-rw-r--r--  scripts/full_labeling.py  125
-rw-r--r--  scripts/index_corpus.py  0
-rw-r--r--  scripts/init_user_states.py  86
-rw-r--r--  scripts/migrate_preferences.py  165
-rw-r--r--  scripts/online_personalization_demo.py  399
-rw-r--r--  scripts/personamem_build_user_vectors.py  193
-rw-r--r--  scripts/personamem_eval_base_vs_ours.py  299
-rw-r--r--  scripts/pilot_runner_v0.py  362
-rw-r--r--  scripts/pilot_runner_v1.py  607
-rw-r--r--  scripts/pilot_runner_v2.py  852
-rw-r--r--  scripts/pilot_runner_v3.py  924
-rw-r--r--  scripts/pilot_runner_v4.py  1230
-rw-r--r--  scripts/pilot_study.py  109
-rw-r--r--  scripts/process_putnam_batch.py  239
-rw-r--r--  scripts/pull_models.py  76
-rw-r--r--  scripts/recompute_embeddings.py  65
-rw-r--r--  scripts/recover_and_merge.py  151
-rw-r--r--  scripts/retrieve_batch_results.py  151
-rw-r--r--  scripts/retrieve_oasst1.py  96
-rw-r--r--  scripts/retrieve_synthesis.py  118
-rw-r--r--  scripts/run_putnam_evaluation.py  164
-rw-r--r--  scripts/run_server.py  0
-rw-r--r--  scripts/smoke_extractor_llm.py  54
-rw-r--r--  scripts/smoke_llms.py  73
-rw-r--r--  scripts/split_train_test.py  76
-rw-r--r--  scripts/stats_and_extract.py  56
-rw-r--r--  scripts/submit_batch.py  111
-rw-r--r--  scripts/submit_oasst1_batch.py  120
-rw-r--r--  scripts/submit_retry_batch.py  88
-rw-r--r--  scripts/submit_synthesis_batch.py  131
-rw-r--r--  scripts/upload_to_hf.py  69
-rw-r--r--  src/personalization/__init__.py  0
-rw-r--r--  src/personalization/config/__init__.py  0
-rw-r--r--  src/personalization/config/registry.py  131
-rw-r--r--  src/personalization/config/settings.py  73
-rw-r--r--  src/personalization/data/personamem_loader.py  84
-rw-r--r--  src/personalization/evaluation/__init__.py  0
-rw-r--r--  src/personalization/evaluation/compare_pairs.py  0
-rw-r--r--  src/personalization/evaluation/metrics.py  0
-rw-r--r--  src/personalization/feedback/__init__.py  0
-rw-r--r--  src/personalization/feedback/gating.py  72
-rw-r--r--  src/personalization/feedback/handlers.py  50
-rw-r--r--  src/personalization/feedback/online_update.py  0
-rw-r--r--  src/personalization/feedback/reward_model.py  64
-rw-r--r--  src/personalization/feedback/sampler.py  109
-rw-r--r--  src/personalization/feedback/schemas.py  23
-rw-r--r--  src/personalization/retrieval/__init__.py  0
-rw-r--r--  src/personalization/retrieval/chunking/__init__.py  0
-rw-r--r--  src/personalization/retrieval/chunking/rules.py  0
-rw-r--r--  src/personalization/retrieval/pipeline.py  250
-rw-r--r--  src/personalization/retrieval/preference_store/__init__.py  0
-rw-r--r--  src/personalization/retrieval/preference_store/base.py  0
-rw-r--r--  src/personalization/retrieval/preference_store/schemas.py  47
-rw-r--r--  src/personalization/retrieval/preference_store/vector_kv.py  0
-rw-r--r--  src/personalization/retrieval/rerank.py  0
-rw-r--r--  src/personalization/retrieval/store/__init__.py  0
-rw-r--r--  src/personalization/retrieval/store/base.py  0
-rw-r--r--  src/personalization/retrieval/store/faiss_store.py  0
-rw-r--r--  src/personalization/retrieval/store/pgvector_store.py  0
-rw-r--r--  src/personalization/serving/__init__.py  22
-rw-r--r--  src/personalization/serving/api/__init__.py  0
-rw-r--r--  src/personalization/serving/api/main.py  0
-rw-r--r--  src/personalization/serving/api/routes/__init__.py  0
-rw-r--r--  src/personalization/serving/api/routes/feedback.py  0
-rw-r--r--  src/personalization/serving/api/routes/query.py  0
-rw-r--r--  src/personalization/serving/api/routes/users.py  0
-rw-r--r--  src/personalization/serving/api/schemas.py  0
-rw-r--r--  src/personalization/serving/personalized_llm.py  837
-rw-r--r--  src/personalization/types.py  4
-rw-r--r--  src/personalization/user_model/__init__.py  0
-rw-r--r--  src/personalization/user_model/features.py  49
-rw-r--r--  src/personalization/user_model/policy/__init__.py  0
-rw-r--r--  src/personalization/user_model/policy/optimizer.py  0
-rw-r--r--  src/personalization/user_model/policy/reinforce.py  104
-rw-r--r--  src/personalization/user_model/scoring.py  25
-rw-r--r--  src/personalization/user_model/session_state.py  19
-rw-r--r--  src/personalization/user_model/tensor_store.py  80
-rw-r--r--  src/personalization/utils/__init__.py  0
-rw-r--r--  src/personalization/utils/ids.py  0
-rw-r--r--  src/personalization/utils/io.py  0
-rw-r--r--  src/personalization/utils/logging.py  0
-rw-r--r--  src/personalization/utils/timing.py  0
124 files changed, 13113 insertions, 0 deletions
diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/.env.example
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f141132
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,35 @@
+# Python
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.Python
+env/
+venv/
+.venv/
+*.egg-info/
+
+# Models (Large model weights)
+models/
+*.safetensors
+*.bin
+*.pt
+*.pth
+
+# IDE / System
+.vscode/
+.idea/
+.DS_Store
+
+# Logs and temporary files
+*.log
+tmp/
+
+# LLaMA-Factory
+LLaMA-Factory/
+# Cursor
+.cursor/
+
+data/
+
+saves/
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/README.md
diff --git a/configs/base.yaml b/configs/base.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/configs/base.yaml
diff --git a/configs/local_models.yaml b/configs/local_models.yaml
new file mode 100644
index 0000000..13c7fcc
--- /dev/null
+++ b/configs/local_models.yaml
@@ -0,0 +1,50 @@
+models:
+ llm:
+ # New Multi-Backend Config
+ qwen_1_5b:
+ backend: qwen
+ path: models/qwen2.5-1.5b-instruct
+ device: auto
+ dtype: bfloat16
+ max_context_length: 4096
+
+ llama_8b:
+ backend: llama
+ path: models/llama-3.1-8b-instruct
+ device: auto
+ dtype: bfloat16
+ max_context_length: 8192
+
+ # Legacy fallback (needed if from_config is called directly without name)
+ hf_id: Qwen/Qwen2.5-1.5B-Instruct
+ local_path: models/qwen2.5-1.5b-instruct
+ dtype: bfloat16
+ device_map: auto
+
+ preference_extractor:
+ # Default/Legacy
+ default:
+ hf_id: Qwen/Qwen2.5-0.5B-Instruct
+ local_path: models/qwen2.5-0.5b-instruct
+ dtype: bfloat16
+ device_map: auto
+ # New SFT Extractor
+ qwen3_0_6b_sft:
+ path: saves/qwen3-0.6b-full-sft-h200/checkpoint-4358
+ prompt_template_path: fine_tuning_prompt_template.txt
+ device: auto
+ dtype: bfloat16
+ max_new_tokens: 512
+ embedding:
+ qwen3:
+ hf_id: Qwen/Qwen3-Embedding-8B
+ local_path: models/qwen3-embedding-8b
+ nemotron:
+ hf_id: nvidia/llama-embed-nemotron-8b
+ local_path: models/llama-embed-nemotron-8b
+ reranker:
+ qwen3_8b:
+ hf_id: Qwen/Qwen3-Reranker-8B
+ local_path: models/rerankers/qwen3-reranker-8b
+ dtype: bfloat16
+ device_map: auto
diff --git a/configs/qwen2.5_0.5b_full_sft.yaml b/configs/qwen2.5_0.5b_full_sft.yaml
new file mode 100644
index 0000000..ca1cca2
--- /dev/null
+++ b/configs/qwen2.5_0.5b_full_sft.yaml
@@ -0,0 +1,34 @@
+### Qwen2.5-0.5B Full SFT Config
+model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
+stage: sft
+do_train: true
+finetuning_type: full
+freeze_trainable_layers: 0
+
+dataset: preference_extractor_train
+template: qwen
+cutoff_len: 1024
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+output_dir: saves/qwen2.5-0.5b-full-sft
+logging_steps: 10
+save_strategy: steps
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+per_device_train_batch_size: 16
+gradient_accumulation_steps: 8
+learning_rate: 2.0e-5
+num_train_epochs: 1.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.05
+bf16: true
+flash_attn: fa2
+
+val_size: 0.01
+per_device_eval_batch_size: 16
+eval_strategy: steps
+eval_steps: 500
+
diff --git a/configs/qwen2.5_1.5b_full_sft.yaml b/configs/qwen2.5_1.5b_full_sft.yaml
new file mode 100644
index 0000000..e91176b
--- /dev/null
+++ b/configs/qwen2.5_1.5b_full_sft.yaml
@@ -0,0 +1,33 @@
+### Qwen2.5-1.5B Full SFT Config
+model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
+stage: sft
+do_train: true
+finetuning_type: full
+
+dataset: preference_extractor_train
+template: qwen
+cutoff_len: 1024
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+output_dir: saves/qwen2.5-1.5b-full-sft
+logging_steps: 10
+save_strategy: steps
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+per_device_train_batch_size: 8
+gradient_accumulation_steps: 16
+learning_rate: 2.0e-5
+num_train_epochs: 1.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.05
+bf16: true
+flash_attn: fa2
+
+val_size: 0.01
+per_device_eval_batch_size: 8
+eval_strategy: steps
+eval_steps: 500
+
diff --git a/configs/qwen3_0.6b_full_sft.yaml b/configs/qwen3_0.6b_full_sft.yaml
new file mode 100644
index 0000000..e41a419
--- /dev/null
+++ b/configs/qwen3_0.6b_full_sft.yaml
@@ -0,0 +1,35 @@
+### Qwen3-0.6B Full SFT Config (H200x4 Optimized)
+model_name_or_path: Qwen/Qwen3-0.6B
+stage: sft
+do_train: true
+finetuning_type: full
+freeze_trainable_layers: 0
+
+dataset: preference_extractor_train
+template: qwen
+cutoff_len: 1024
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+output_dir: saves/qwen3-0.6b-full-sft-h200
+logging_steps: 5
+save_strategy: steps
+save_steps: 200
+plot_loss: true
+overwrite_output_dir: true
+
+# H200x4 Configuration
+# Total Batch Size = 32 * 4 * 1 = 128
+per_device_train_batch_size: 32
+gradient_accumulation_steps: 1
+learning_rate: 2.0e-5
+num_train_epochs: 1.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.05
+bf16: true
+flash_attn: fa2
+
+val_size: 0.01
+per_device_eval_batch_size: 32
+eval_strategy: steps
+eval_steps: 200
diff --git a/configs/qwen3_1.7b_full_sft.yaml b/configs/qwen3_1.7b_full_sft.yaml
new file mode 100644
index 0000000..069c53d
--- /dev/null
+++ b/configs/qwen3_1.7b_full_sft.yaml
@@ -0,0 +1,34 @@
+### Qwen3-1.7B Full SFT Config (H200x4 Optimized)
+model_name_or_path: models/Qwen3-1.7B
+stage: sft
+do_train: true
+finetuning_type: full
+
+dataset: preference_extractor_train
+template: qwen
+cutoff_len: 1024
+overwrite_cache: true
+preprocessing_num_workers: 4
+
+output_dir: saves/qwen3-1.7b-full-sft-h200
+logging_steps: 5
+save_strategy: steps
+save_steps: 200
+plot_loss: true
+overwrite_output_dir: true
+
+# H200x4 Configuration
+# Total Batch Size = 32 * 4 * 1 = 128
+per_device_train_batch_size: 32
+gradient_accumulation_steps: 1
+learning_rate: 2.0e-5
+num_train_epochs: 1.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.05
+bf16: true
+flash_attn: fa2
+
+val_size: 0.01
+per_device_eval_batch_size: 32
+eval_strategy: steps
+eval_steps: 200
diff --git a/configs/reranker.yaml b/configs/reranker.yaml
new file mode 100644
index 0000000..c376fc7
--- /dev/null
+++ b/configs/reranker.yaml
@@ -0,0 +1,3 @@
+reranker:
+ default: qwen3_8b
+
diff --git a/configs/retrieval.yaml b/configs/retrieval.yaml
new file mode 100644
index 0000000..d2e100e
--- /dev/null
+++ b/configs/retrieval.yaml
@@ -0,0 +1,5 @@
+retrieval:
+ dense_topk: 64 # Initial recall count
+ rerank_topk: 8 # Count fed to LLM after rerank
+ pca_dim: 256
+
diff --git a/configs/user_model.yaml b/configs/user_model.yaml
new file mode 100644
index 0000000..7b8e230
--- /dev/null
+++ b/configs/user_model.yaml
@@ -0,0 +1,14 @@
+user_model:
+ item_dim: 256
+ user_dim: 256
+ beta_long: 0.1 # Enable personalization for Day 4
+ beta_short: 0.3
+ tau: 1.0
+ preference_extractor_name: qwen3_0_6b_sft # Switch to new extractor
+ rl:
+ eta_long: 1.0e-3
+ eta_short: 5.0e-3
+ ema_alpha: 0.05
+ short_decay: 0.1
+
+llm_name: llama_8b # Switch backend to Llama 3.1
diff --git a/fine_tuning_prompt_template.txt b/fine_tuning_prompt_template.txt
new file mode 100644
index 0000000..749fdb6
--- /dev/null
+++ b/fine_tuning_prompt_template.txt
@@ -0,0 +1,31 @@
+=== System Prompt ===
+You are a preference extraction assistant. Your task is to analyze the user's query and extract any persistent preferences (such as style, formatting, programming language, or constraints) that should apply to future interactions.
+
+Output strictly valid JSON following this structure:
+{
+ "preferences": [
+ {
+ "condition": "When to apply this preference (e.g., 'code generation', 'general')",
+ "action": "What to do (e.g., 'use Python', 'be concise')",
+ "confidence": 1.0
+ }
+ ]
+}
+
+If the user query contains no persistent preferences (e.g., simple greetings, one-off tasks, factual questions), return:
+{"preferences": []}
+
+=== User Input Example ===
+I am a Python developer, so always give me code examples in Python.
+
+=== Assistant Output Example ===
+{
+ "preferences": [
+ {
+ "condition": "code examples",
+ "action": "use Python",
+ "confidence": 1.0
+ }
+ ]
+}
+
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..9c1e7ac
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,38 @@
+[tool.black]
+
+[build-system]
+requires = ["hatchling>=1.18"]
+build-backend = "hatchling.build"
+
+[project]
+name = "personalization-user-model"
+version = "0.1.0"
+description = "Personalized memory RAG system with online user modeling"
+readme = "README.md"
+requires-python = ">=3.10"
+license = { text = "Apache-2.0" }
+authors = [
+ { name = "yurenh2" }
+]
+dependencies = [
+ "torch>=2.3.0",
+ "transformers>=4.44.0",
+ "accelerate>=0.33.0",
+ "huggingface_hub>=0.24.0",
+ "pydantic>=2.7.0",
+ "pyyaml>=6.0.0",
+ "safetensors>=0.4.2"
+]
+
+[project.urls]
+homepage = "https://example.com"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/personalization"]
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..1e227de
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+torch>=2.3.0
+transformers>=4.44.0
+accelerate>=0.33.0
+huggingface_hub>=0.24.0
+pydantic>=2.7.0
+PyYAML>=6.0.0
+safetensors>=0.4.2
+
+
diff --git a/scripts/analyze_full_vs_nopersonal.py b/scripts/analyze_full_vs_nopersonal.py
new file mode 100644
index 0000000..a10f3ef
--- /dev/null
+++ b/scripts/analyze_full_vs_nopersonal.py
@@ -0,0 +1,361 @@
+#!/usr/bin/env python3
+"""
+Analyze Full vs NoPersonal Baseline Comparison.
+
+This script loads logs from pilot_runner_v4 runs (both full and nopersonal modes)
+and produces comparison metrics for:
+1. Session 2 retention (base task avg satisfaction)
+2. Violation rates by type
+3. Preference memory recall@k
+
+Usage:
+ python scripts/analyze_full_vs_nopersonal.py \
+ --full data/logs/pilot_v4_full_TIMESTAMP.jsonl \
+ --nopersonal data/logs/pilot_v4_nopersonal_TIMESTAMP.jsonl
+"""
+
+import json
+import argparse
+import re
+from dataclasses import dataclass
+from typing import List, Dict, Any, Optional, Set
+from collections import defaultdict
+
+
+@dataclass
+class TurnLog:
+ """Parsed log entry."""
+ user_id: str
+ persona_id: str
+ session_id: int
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+ reward: float
+ gating: float
+ is_complaint: bool
+ reveal_state_before: Dict[str, bool]
+ reveal_state_after: Dict[str, bool]
+ newly_revealed: List[str]
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+ selected_memory_ids: List[str]
+ selected_memory_notes: List[str]
+ selected_memory_scores: List[float]
+ num_candidates: int
+ num_total_memories: int
+ mode: str
+ eval_mode: bool # True = greedy, False = sample
+
+
+def load_logs(filepath: str) -> List[TurnLog]:
+ """Load logs from JSONL file."""
+ logs = []
+ with open(filepath, "r") as f:
+ for line in f:
+ if line.strip():
+ data = json.loads(line)
+ # Handle missing fields with defaults
+ log = TurnLog(
+ user_id=data.get("user_id", ""),
+ persona_id=data.get("persona_id", ""),
+ session_id=data.get("session_id", 0),
+ turn_id=data.get("turn_id", 0),
+ query=data.get("query", ""),
+ query_type=data.get("query_type", ""),
+ task_type=data.get("task_type", ""),
+ answer=data.get("answer", ""),
+ answer_length=data.get("answer_length", 0),
+ sat_t=data.get("sat_t", 0.0),
+ sev_t=data.get("sev_t", 0.0),
+ prog_t=data.get("prog_t", 0.0),
+ violations=data.get("violations", []),
+ enforced_constraints=data.get("enforced_constraints", []),
+ reward=data.get("reward", 0.0),
+ gating=data.get("gating", 0.0),
+ is_complaint=data.get("is_complaint", False),
+ reveal_state_before=data.get("reveal_state_before", {}),
+ reveal_state_after=data.get("reveal_state_after", {}),
+ newly_revealed=data.get("newly_revealed", []),
+ z_long_norm_before=data.get("z_long_norm_before", 0.0),
+ z_long_norm_after=data.get("z_long_norm_after", 0.0),
+ z_short_norm_before=data.get("z_short_norm_before", 0.0),
+ z_short_norm_after=data.get("z_short_norm_after", 0.0),
+ prompt_tokens=data.get("prompt_tokens", 0),
+ completion_tokens=data.get("completion_tokens", 0),
+ total_tokens=data.get("total_tokens", 0),
+ num_memories_retrieved=data.get("num_memories_retrieved", 0),
+ num_prefs_extracted=data.get("num_prefs_extracted", 0),
+ selected_memory_ids=data.get("selected_memory_ids", []),
+ selected_memory_notes=data.get("selected_memory_notes", []),
+ selected_memory_scores=data.get("selected_memory_scores", []),
+ num_candidates=data.get("num_candidates", 0),
+ num_total_memories=data.get("num_total_memories", 0),
+ mode=data.get("mode", "unknown"),
+ eval_mode=data.get("eval_mode", True),
+ )
+ logs.append(log)
+ return logs
+
+
+def is_base_task_turn(log: TurnLog) -> bool:
+ """Check if this is a base task turn (not complaint, not preference)."""
+ if log.is_complaint:
+ return False
+ if log.query_type == "preference":
+ return False
+ if log.query_type in ("task", "task_list"):
+ return True
+ return False
+
+
+def compute_session2_base_avg_sat(logs: List[TurnLog]) -> Dict[str, float]:
+ """
+ Compute average satisfaction for Session 2 base tasks.
+ Returns: {user_id: avg_sat}
+ """
+ user_sat = defaultdict(list)
+
+ for log in logs:
+ if log.session_id == 2 and is_base_task_turn(log):
+ user_sat[log.user_id].append(log.sat_t)
+
+ result = {}
+ for user_id, sats in user_sat.items():
+ if sats:
+ result[user_id] = sum(sats) / len(sats)
+
+ return result
+
+
+def compute_overall_session2_avg_sat(logs: List[TurnLog]) -> float:
+ """Compute overall average satisfaction for Session 2 base tasks."""
+ sats = []
+ for log in logs:
+ if log.session_id == 2 and is_base_task_turn(log):
+ sats.append(log.sat_t)
+ return sum(sats) / len(sats) if sats else 0.0
+
+
+def compute_violation_rates(logs: List[TurnLog], session_filter: Optional[int] = None) -> Dict[str, float]:
+ """
+ Compute violation rates by type.
+ Returns: {violation_type: rate}
+ """
+ violation_counts = defaultdict(int)
+ total_base_tasks = 0
+
+ for log in logs:
+ if session_filter is not None and log.session_id != session_filter:
+ continue
+ if not is_base_task_turn(log):
+ continue
+
+ total_base_tasks += 1
+ for v in log.violations:
+ violation_counts[v] += 1
+
+ if total_base_tasks == 0:
+ return {}
+
+ return {v: count / total_base_tasks for v, count in violation_counts.items()}
+
+
+def is_pref_memory(note_text: str, dim: str) -> bool:
+ """
+ Check if a memory note relates to a preference dimension.
+ dim: "short", "bullets", or "lang"
+ """
+ text_lower = note_text.lower()
+
+ if dim == "short":
+ keywords = [
+ "short", "concise", "brief", "200", "characters", "less",
+ "简短", "精简", "字以内", "不超过", "简洁"
+ ]
+ return any(kw in text_lower for kw in keywords)
+
+ elif dim == "bullets":
+ keywords = [
+ "bullet", "bullets", "list", "point", "points",
+ "要点", "列表", "项目符号"
+ ]
+ # Also check for "no bullet" / "don't use bullet"
+ no_bullet = any(x in text_lower for x in ["no bullet", "don't use bullet", "without bullet", "不要要点", "不使用列表"])
+ if no_bullet:
+ return True # It's still about bullets preference
+ return any(kw in text_lower for kw in keywords)
+
+ elif dim == "lang":
+ # Check for language preferences
+ zh_keywords = ["chinese", "中文", "用中文", "请用中文"]
+ en_keywords = ["english", "英文", "in english"]
+ return any(kw in text_lower for kw in zh_keywords + en_keywords)
+
+ return False
+
+
+def compute_pref_recall_at_k(logs: List[TurnLog], dim: str, session_filter: Optional[int] = None) -> float:
+ """
+ Compute preference memory recall@k for a given dimension.
+ Returns: fraction of base task turns where a relevant pref memory was retrieved.
+ """
+ hits = 0
+ total = 0
+
+ for log in logs:
+ if session_filter is not None and log.session_id != session_filter:
+ continue
+ if not is_base_task_turn(log):
+ continue
+
+ total += 1
+ # Check if any selected memory note matches the dimension
+ for note in log.selected_memory_notes:
+ if is_pref_memory(note, dim):
+ hits += 1
+ break
+
+ return hits / total if total > 0 else 0.0
+
+
+def print_comparison_table(full_logs: List[TurnLog], nopersonal_logs: List[TurnLog]):
+ """Print a comparison table of Full vs NoPersonal metrics."""
+
+ # Detect mode from logs
+ full_mode = full_logs[0].mode if full_logs else "unknown"
+ full_eval = "greedy" if (full_logs and full_logs[0].eval_mode) else "sample"
+ np_mode = nopersonal_logs[0].mode if nopersonal_logs else "unknown"
+ np_eval = "greedy" if (nopersonal_logs and nopersonal_logs[0].eval_mode) else "sample"
+
+ print("\n" + "=" * 70)
+ print("FULL vs NOPERSONAL COMPARISON")
+ print(f"Full: mode={full_mode}, selection={full_eval}")
+ print(f"NoPersonal: mode={np_mode}, selection={np_eval}")
+ print("=" * 70)
+
+ # 1. Session 2 Base Task Average Satisfaction
+ print("\n### 1. Session 2 Base Task Average Satisfaction")
+ print("-" * 50)
+
+ full_s2_sat = compute_overall_session2_avg_sat(full_logs)
+ nopersonal_s2_sat = compute_overall_session2_avg_sat(nopersonal_logs)
+ delta = full_s2_sat - nopersonal_s2_sat
+
+ print(f"{'Metric':<30} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ print(f"{'avg_sat_S2_base':<30} {full_s2_sat:<12.4f} {nopersonal_s2_sat:<12.4f} {delta:<+12.4f}")
+
+ # Per-user breakdown
+ full_user_sat = compute_session2_base_avg_sat(full_logs)
+ nopersonal_user_sat = compute_session2_base_avg_sat(nopersonal_logs)
+
+ print("\nPer-user Session 2 avg_sat:")
+ print(f"{'User':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ all_users = set(full_user_sat.keys()) | set(nopersonal_user_sat.keys())
+ for user_id in sorted(all_users):
+ f_sat = full_user_sat.get(user_id, 0.0)
+ n_sat = nopersonal_user_sat.get(user_id, 0.0)
+ d = f_sat - n_sat
+ print(f"{user_id:<20} {f_sat:<12.4f} {n_sat:<12.4f} {d:<+12.4f}")
+
+ # 2. Violation Rates
+ print("\n### 2. Session 2 Violation Rates")
+ print("-" * 50)
+
+ full_viol = compute_violation_rates(full_logs, session_filter=2)
+ nopersonal_viol = compute_violation_rates(nopersonal_logs, session_filter=2)
+
+ all_viols = set(full_viol.keys()) | set(nopersonal_viol.keys())
+ key_viols = ["too_long", "no_bullets", "has_bullets", "wrong_lang", "empty_answer"]
+
+ print(f"{'Violation Type':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ for v in key_viols:
+ if v in all_viols:
+ f_rate = full_viol.get(v, 0.0)
+ n_rate = nopersonal_viol.get(v, 0.0)
+ d = f_rate - n_rate
+ print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}")
+
+ # Other violations
+ other_viols = [v for v in all_viols if v not in key_viols]
+ for v in sorted(other_viols):
+ f_rate = full_viol.get(v, 0.0)
+ n_rate = nopersonal_viol.get(v, 0.0)
+ d = f_rate - n_rate
+ print(f"{v:<20} {f_rate:<12.4f} {n_rate:<12.4f} {d:<+12.4f}")
+
+ # 3. Preference Memory Recall@k
+ print("\n### 3. Session 2 Preference Memory Recall@k")
+ print("-" * 50)
+
+ dims = ["short", "bullets", "lang"]
+ print(f"{'Dimension':<20} {'Full':<12} {'NoPersonal':<12} {'Delta':<12}")
+ print("-" * 50)
+ for dim in dims:
+ f_recall = compute_pref_recall_at_k(full_logs, dim, session_filter=2)
+ n_recall = compute_pref_recall_at_k(nopersonal_logs, dim, session_filter=2)
+ d = f_recall - n_recall
+ print(f"{dim:<20} {f_recall:<12.4f} {n_recall:<12.4f} {d:<+12.4f}")
+
+ # 4. Summary Statistics
+ print("\n### 4. Summary Statistics")
+ print("-" * 50)
+
+ def count_base_tasks(logs, session=None):
+ return sum(1 for l in logs if (session is None or l.session_id == session) and is_base_task_turn(l))
+
+ def count_complaints(logs, session=None):
+ return sum(1 for l in logs if (session is None or l.session_id == session) and l.is_complaint)
+
+ print(f"{'Statistic':<30} {'Full':<12} {'NoPersonal':<12}")
+ print("-" * 50)
+ print(f"{'Total turns':<30} {len(full_logs):<12} {len(nopersonal_logs):<12}")
+ print(f"{'S2 base task turns':<30} {count_base_tasks(full_logs, 2):<12} {count_base_tasks(nopersonal_logs, 2):<12}")
+ print(f"{'S2 complaint turns':<30} {count_complaints(full_logs, 2):<12} {count_complaints(nopersonal_logs, 2):<12}")
+
+ # Token usage
+ full_tokens = sum(l.total_tokens for l in full_logs)
+ nopersonal_tokens = sum(l.total_tokens for l in nopersonal_logs)
+ print(f"{'Total tokens':<30} {full_tokens:<12} {nopersonal_tokens:<12}")
+
+ print("\n" + "=" * 70)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Analyze Full vs NoPersonal Comparison")
+ parser.add_argument("--full", type=str, required=True, help="Path to Full mode log file")
+ parser.add_argument("--nopersonal", type=str, required=True, help="Path to NoPersonal mode log file")
+ args = parser.parse_args()
+
+ print(f"Loading Full logs from: {args.full}")
+ full_logs = load_logs(args.full)
+ print(f" Loaded {len(full_logs)} turns")
+
+ print(f"Loading NoPersonal logs from: {args.nopersonal}")
+ nopersonal_logs = load_logs(args.nopersonal)
+ print(f" Loaded {len(nopersonal_logs)} turns")
+
+ print_comparison_table(full_logs, nopersonal_logs)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/analyze_learning_trend.py b/scripts/analyze_learning_trend.py
new file mode 100644
index 0000000..9ab4699
--- /dev/null
+++ b/scripts/analyze_learning_trend.py
@@ -0,0 +1,521 @@
+#!/usr/bin/env python3
+"""
+Analyze Learning Trend: Correlation and z_u Norm over Sessions
+
+This script shows that:
+1. User vector norms (||z_u||) grow over sessions (learning is happening)
+2. Correlation between learned and ground-truth similarity increases over sessions
+
+Usage:
+ python scripts/analyze_learning_trend.py \
+ --logs data/logs/pilot_v4_full-greedy_*.jsonl
+"""
+
+import argparse
+import json
+import numpy as np
+from typing import Dict, List, Tuple
+from collections import defaultdict
+from dataclasses import dataclass
+import os
+
+
+# =============================================================================
+# Persona Definitions (ground truth)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en"
+
+
+PERSONAS = {
+ "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
+ "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
+ "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
+ "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
+ "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
+ "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
+}
+
+
+# =============================================================================
+# Data Loading
+# =============================================================================
+
+def load_logs(filepath: str) -> List[dict]:
+ """Load turn logs from JSONL file."""
+ logs = []
+ with open(filepath, "r") as f:
+ for line in f:
+ if line.strip():
+ logs.append(json.loads(line))
+ return logs
+
+
+def extract_z_norms_by_session(logs: List[dict]) -> Dict[str, Dict[int, Tuple[float, float]]]:
+ """
+ Extract z_long_norm and z_short_norm at the end of each session for each user.
+
+ Returns:
+ {user_id: {session_id: (z_long_norm, z_short_norm)}}
+ """
+ user_session_norms = defaultdict(dict)
+
+ # Group by user and session, take the last turn's z_norm
+ user_session_turns = defaultdict(lambda: defaultdict(list))
+ for log in logs:
+ user_id = log["user_id"]
+ session_id = log["session_id"]
+ user_session_turns[user_id][session_id].append(log)
+
+ for user_id, sessions in user_session_turns.items():
+ for session_id, turns in sessions.items():
+ # Get the last turn of this session
+ last_turn = max(turns, key=lambda x: x["turn_id"])
+ z_long = last_turn.get("z_long_norm_after", 0.0)
+ z_short = last_turn.get("z_short_norm_after", 0.0)
+ user_session_norms[user_id][session_id] = (z_long, z_short)
+
+ return dict(user_session_norms)
+
+
+# =============================================================================
+# Similarity Computation
+# =============================================================================
+
+def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
+ """Compute cosine similarity."""
+ norm1 = np.linalg.norm(v1)
+ norm2 = np.linalg.norm(v2)
+ if norm1 < 1e-10 or norm2 < 1e-10:
+ return 0.0
+ return float(np.dot(v1, v2) / (norm1 * norm2))
+
+
+def compute_ground_truth_similarity_matrix(user_order: List[str]) -> np.ndarray:
+ """Compute ground truth similarity based on preference overlap."""
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ if u1 not in PERSONAS or u2 not in PERSONAS:
+ sim_matrix[i, j] = 0.0 if i != j else 1.0
+ continue
+
+ p1 = PERSONAS[u1]
+ p2 = PERSONAS[u2]
+
+ matches = 0
+ if p1.require_short == p2.require_short:
+ matches += 1
+ if p1.require_bullets == p2.require_bullets:
+ matches += 1
+ if p1.lang == p2.lang:
+ matches += 1
+
+ sim_matrix[i, j] = matches / 3.0
+
+ return sim_matrix
+
+
+def compute_spearman_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> float:
+ """Compute Spearman correlation between similarity matrices."""
+ from scipy.stats import spearmanr
+
+ n = learned.shape[0]
+ learned_flat = []
+ gt_flat = []
+
+ for i in range(n):
+ for j in range(i + 1, n):
+ learned_flat.append(learned[i, j])
+ gt_flat.append(ground_truth[i, j])
+
+ if len(learned_flat) < 2:
+ return 0.0
+
+ # Handle case where all values are the same
+ if np.std(learned_flat) < 1e-10:
+ return 0.0
+
+ corr, _ = spearmanr(learned_flat, gt_flat)
+ return float(corr) if not np.isnan(corr) else 0.0
+
+
+def load_final_z_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
+ """Load final z_u vectors from saved user store."""
+ try:
+ data = np.load(user_store_path, allow_pickle=True)
+ user_vectors = {}
+
+ # UserTensorStore saves in format: {uid}_long, {uid}_short
+ user_ids = set()
+ for key in data.files:
+ if key.endswith("_long"):
+ uid = key[:-5]
+ user_ids.add(uid)
+
+ for uid in user_ids:
+ long_key = f"{uid}_long"
+ short_key = f"{uid}_short"
+ if long_key in data.files and short_key in data.files:
+ user_vectors[uid] = (data[long_key], data[short_key])
+
+ return user_vectors
+ except Exception as e:
+ print(f"[Warning] Could not load user store: {e}")
+ return {}
+
+
+# Global cache for final z vectors
+_FINAL_Z_VECTORS = None
+
+
+def get_z_vectors_at_session(
+ logs: List[dict],
+ user_order: List[str],
+ up_to_session: int,
+ final_z_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]]
+) -> Dict[str, np.ndarray]:
+ """
+ Estimate z_u vectors at a given session checkpoint.
+
+ Method: Use the DIRECTION of the final z_u, scaled by the z_norm at session s.
+ This assumes z_u direction is relatively stable but magnitude grows.
+
+ z_u(s) ≈ (z_final / ||z_final||) * ||z(s)||
+ """
+ user_vectors = {}
+
+ for user_id in user_order:
+ # Get z_norm at the end of this session
+ user_turns = [l for l in logs if l["user_id"] == user_id and l["session_id"] <= up_to_session]
+
+ if not user_turns:
+ user_vectors[user_id] = np.zeros(512) # 256 + 256
+ continue
+
+ # Get the last turn's z_norm at this session
+ last_turn = max(user_turns, key=lambda x: (x["session_id"], x["turn_id"]))
+ z_long_norm_s = last_turn.get("z_long_norm_after", 0.0)
+ z_short_norm_s = last_turn.get("z_short_norm_after", 0.0)
+
+ # Get final z vectors (direction)
+ if user_id in final_z_vectors:
+ z_long_final, z_short_final = final_z_vectors[user_id]
+
+ # Compute unit vectors (direction)
+ z_long_final_norm = np.linalg.norm(z_long_final)
+ z_short_final_norm = np.linalg.norm(z_short_final)
+
+ if z_long_final_norm > 1e-10:
+ z_long_unit = z_long_final / z_long_final_norm
+ else:
+ z_long_unit = np.zeros_like(z_long_final)
+
+ if z_short_final_norm > 1e-10:
+ z_short_unit = z_short_final / z_short_final_norm
+ else:
+ z_short_unit = np.zeros_like(z_short_final)
+
+ # Scale by the norm at this session
+ z_long_s = z_long_unit * z_long_norm_s
+ z_short_s = z_short_unit * z_short_norm_s
+
+ # Concatenate
+ user_vectors[user_id] = np.concatenate([z_long_s, z_short_s])
+ else:
+ user_vectors[user_id] = np.zeros(512)
+
+ return user_vectors
+
+
+def compute_similarity_at_session(
+ logs: List[dict],
+ user_order: List[str],
+ up_to_session: int,
+ final_z_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]] = None
+) -> np.ndarray:
+ """Compute learned similarity matrix at a given session using actual z vectors."""
+ if final_z_vectors:
+ user_vectors = get_z_vectors_at_session(logs, user_order, up_to_session, final_z_vectors)
+ else:
+ # Fallback to old method
+ user_vectors = simulate_z_vectors_at_session_fallback(logs, user_order, up_to_session)
+
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ v1 = user_vectors.get(u1, np.zeros(512))
+ v2 = user_vectors.get(u2, np.zeros(512))
+ sim_matrix[i, j] = cosine_similarity(v1, v2)
+
+ return sim_matrix
+
+
+def simulate_z_vectors_at_session_fallback(
+ logs: List[dict],
+ user_order: List[str],
+ up_to_session: int,
+ dim: int = 256
+) -> Dict[str, np.ndarray]:
+ """Fallback: simulate z_u based on violation patterns (less accurate)."""
+ user_vectors = {}
+
+ for user_id in user_order:
+ user_turns = [l for l in logs if l["user_id"] == user_id and l["session_id"] <= up_to_session]
+
+ if not user_turns:
+ user_vectors[user_id] = np.zeros(dim * 2)
+ continue
+
+ last_turn = max(user_turns, key=lambda x: (x["session_id"], x["turn_id"]))
+ z_long_norm = last_turn.get("z_long_norm_after", 0.0)
+ z_short_norm = last_turn.get("z_short_norm_after", 0.0)
+
+ violation_counts = defaultdict(int)
+ for turn in user_turns:
+ for v in turn.get("violations", []):
+ violation_counts[v] += 1
+
+ feature_dim = 10
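+ # Only the first six slots are populated below; the remaining dims stay zero.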
+ features = np.zeros(feature_dim)
+ features[0] = violation_counts.get("too_long", 0)
+ features[1] = violation_counts.get("no_bullets", 0)
+ features[2] = violation_counts.get("has_bullets", 0)
+ features[3] = violation_counts.get("wrong_lang", 0)
+ features[4] = z_long_norm * 100
+ features[5] = z_short_norm * 100
+
+ norm = np.linalg.norm(features)
+ if norm > 1e-10:
+ features = features / norm
+
+ user_vectors[user_id] = features
+
+ return user_vectors
+
+
+# =============================================================================
+# Main Analysis
+# =============================================================================
+
+def analyze_learning_trend(logs_path: str, output_dir: str = "data/analysis",
+ user_store_path: str = "data/users/user_store_pilot_v4_full-greedy.npz"):
+ """Analyze correlation and z_u norm trends over sessions."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ print("=" * 70)
+ print("LEARNING TREND ANALYSIS")
+ print("=" * 70)
+
+ # Load logs
+ print(f"\n[1] Loading logs from: {logs_path}")
+ logs = load_logs(logs_path)
+ print(f" Loaded {len(logs)} turns")
+
+ # Get user order
+ user_order = [u for u in PERSONAS.keys() if any(l["user_id"] == u for l in logs)]
+ print(f" Users: {user_order}")
+
+ # Get max session
+ max_session = max(l["session_id"] for l in logs)
+ print(f" Sessions: 1 to {max_session}")
+
+ # Extract z_norms by session
+ print("\n[2] Extracting z_u norms by session...")
+ z_norms_by_session = extract_z_norms_by_session(logs)
+
+ # Load final z vectors from user store
+ print(f"\n[2.5] Loading final z vectors from: {user_store_path}")
+ final_z_vectors = load_final_z_vectors(user_store_path)
+ if final_z_vectors:
+ print(f" Loaded final z vectors for {len(final_z_vectors)} users")
+ else:
+ print(" [Warning] No final z vectors found, using fallback method")
+
+ # Compute ground truth similarity (constant)
+ gt_sim = compute_ground_truth_similarity_matrix(user_order)
+
+ # Compute CUMULATIVE correlation and avg z_norm
+ # At session N, we use all data from session 1 to N
+ print("\n[3] Computing CUMULATIVE correlation trend (S1→S1-2→S1-3→...→S1-N)...")
+ sessions = list(range(1, max_session + 1))
+ correlations = []
+ avg_z_norms = []
+
+ for s in sessions:
+ # Compute similarity using z_u at end of session s (cumulative learning)
+ learned_sim = compute_similarity_at_session(logs, user_order, s, final_z_vectors)
+ corr = compute_spearman_correlation(learned_sim, gt_sim)
+ correlations.append(corr)
+
+ # Compute average z_norm at the END of session s (this is already cumulative)
+ z_norms = []
+ for user_id in user_order:
+ if user_id in z_norms_by_session and s in z_norms_by_session[user_id]:
+ zl, zs = z_norms_by_session[user_id][s]
+ z_norms.append(np.sqrt(zl**2 + zs**2)) # Combined norm
+
+ avg_z = np.mean(z_norms) if z_norms else 0.0
+ avg_z_norms.append(avg_z)
+
+ # Print results
+ print("\n[4] Results:")
+ print("-" * 60)
+ print(f"{'Session':<10} {'Correlation':<15} {'Avg ||z_u||':<15}")
+ print("-" * 60)
+ for s, corr, z_norm in zip(sessions, correlations, avg_z_norms):
+ print(f"{s:<10} {corr:<15.4f} {z_norm:<15.6f}")
+
+ # Summary statistics
+ print("\n[5] Trend Summary:")
+ print("-" * 60)
+
+ # Linear regression for correlation trend
+ from scipy.stats import linregress
+ slope_corr, intercept_corr, r_corr, p_corr, _ = linregress(sessions, correlations)
+ print(f" Correlation trend: slope={slope_corr:.4f}, R²={r_corr**2:.4f}, p={p_corr:.4f}")
+
+ # Linear regression for z_norm trend
+ slope_z, intercept_z, r_z, p_z, _ = linregress(sessions, avg_z_norms)
+ print(f" ||z_u|| trend: slope={slope_z:.6f}, R²={r_z**2:.4f}, p={p_z:.4f}")
+
+ # Correlation between the two trends
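+ # (spearmanr needs more than two points; with fewer, fall back to zero correlation)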
+ trend_corr, _ = spearmanr(correlations, avg_z_norms) if len(correlations) > 2 else (0, 1)
+ print(f" Correlation between trends: {trend_corr:.4f}")
+
+ # Save data
+ results = {
+ "sessions": np.array(sessions),
+ "correlations": np.array(correlations),
+ "avg_z_norms": np.array(avg_z_norms),
+ "slope_corr": slope_corr,
+ "slope_z": slope_z,
+ "trend_corr": trend_corr,
+ }
+ results_path = os.path.join(output_dir, "learning_trend_results.npz")
+ np.savez(results_path, **results)
+ print(f"\n[Results] Saved to: {results_path}")
+
+ # Plot
+ print("\n[6] Generating plots...")
+ plot_learning_trend(sessions, correlations, avg_z_norms, output_dir)
+
+ print("\n" + "=" * 70)
+ print("ANALYSIS COMPLETE")
+ print("=" * 70)
+
+ return results
+
+
+def plot_learning_trend(sessions, correlations, avg_z_norms, output_dir):
+ """Generate plots for learning trend."""
+ try:
+ import matplotlib.pyplot as plt
+ import matplotlib
+ matplotlib.use('Agg') # Non-interactive backend
+ except ImportError:
+ print("[Warning] matplotlib not available, skipping plots")
+ # Save as text instead
+ with open(os.path.join(output_dir, "learning_trend.txt"), "w") as f:
+ f.write("Session,Correlation,Avg_Z_Norm\n")
+ for s, c, z in zip(sessions, correlations, avg_z_norms):
+ f.write(f"{s},{c:.4f},{z:.6f}\n")
+ print(f"[Data] Saved to: {os.path.join(output_dir, 'learning_trend.txt')}")
+ return
+
+ fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+
+ # Plot 1: Correlation vs Session
+ ax1 = axes[0]
+ ax1.plot(sessions, correlations, 'o-', color='#2ecc71', linewidth=2, markersize=8)
+ ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ # Add trend line
+ from scipy.stats import linregress
+ slope, intercept, _, _, _ = linregress(sessions, correlations)
+ trend_line = [slope * s + intercept for s in sessions]
+ ax1.plot(sessions, trend_line, '--', color='#27ae60', alpha=0.7, label=f'Trend (slope={slope:.3f})')
+
+ ax1.set_xlabel('Sessions (Cumulative: 1→N)', fontsize=12)
+ ax1.set_ylabel('Spearman Correlation', fontsize=12)
+ ax1.set_title('Learned vs Ground-Truth Similarity\nCorrelation with Cumulative Data', fontsize=14)
+ ax1.set_xticks(sessions)
+ ax1.legend()
+ ax1.grid(True, alpha=0.3)
+ ax1.set_ylim(-0.5, 1.0)
+
+ # Plot 2: z_u norm vs Session
+ ax2 = axes[1]
+ ax2.plot(sessions, avg_z_norms, 's-', color='#3498db', linewidth=2, markersize=8)
+
+ # Add trend line
+ slope_z, intercept_z, _, _, _ = linregress(sessions, avg_z_norms)
+ trend_line_z = [slope_z * s + intercept_z for s in sessions]
+ ax2.plot(sessions, trend_line_z, '--', color='#2980b9', alpha=0.7, label=f'Trend (slope={slope_z:.5f})')
+
+ ax2.set_xlabel('Session (End of)', fontsize=12)
+ ax2.set_ylabel('Average ||z_u||', fontsize=12)
+ ax2.set_title('User Vector Norm\n(Cumulative Learning)', fontsize=14)
+ ax2.set_xticks(sessions)
+ ax2.legend()
+ ax2.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+
+ output_path = os.path.join(output_dir, "learning_trend.png")
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
+ print(f"[Plot] Saved to: {output_path}")
+
+ # Also save as PDF for paper
+ pdf_path = os.path.join(output_dir, "learning_trend.pdf")
+ plt.savefig(pdf_path, bbox_inches='tight')
+ print(f"[Plot] Saved to: {pdf_path}")
+
+
+# Need this import at top level for trend calculation
+from scipy.stats import spearmanr
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Analyze Learning Trend")
+ parser.add_argument("--logs", type=str, required=True, help="Path to log file")
+ parser.add_argument("--user-store", type=str, default="data/users/user_store_pilot_v4_full-greedy.npz",
+ help="Path to user store with final z vectors")
+ parser.add_argument("--output-dir", type=str, default="data/analysis", help="Output directory")
+ args = parser.parse_args()
+
+ analyze_learning_trend(args.logs, args.output_dir, args.user_store)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/analyze_memory.py b/scripts/analyze_memory.py
new file mode 100644
index 0000000..4c439cf
--- /dev/null
+++ b/scripts/analyze_memory.py
@@ -0,0 +1,61 @@
+import json
+import os
+from collections import Counter, defaultdict
+
+MEMORY_FILE = "data/corpora/memory_cards.jsonl"
+
+def analyze_memory():
+ if not os.path.exists(MEMORY_FILE):
+ print(f"Error: {MEMORY_FILE} not found.")
+ return
+
+ print(f"Analyzing {MEMORY_FILE}...")
+
+ total_cards = 0
+ user_counts = Counter()
+ content_hashes = defaultdict(int)
+ user_content_hashes = defaultdict(int)
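+ # Count repeated note texts globally and per (user, text) pair to surface duplicates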
+
+ with open(MEMORY_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if not line.strip(): continue
+ try:
+ card = json.loads(line)
+ total_cards += 1
+ uid = card.get("user_id", "unknown")
+ text = card.get("note_text", "").strip()
+
+ user_counts[uid] += 1
+ content_hashes[text] += 1
+ user_content_hashes[(uid, text)] += 1
+ except Exception:
+ # Skip malformed JSONL lines
+ pass
+
+ print("\n" + "="*40)
+ print("MEMORY STORE ANALYSIS")
+ print("="*40)
+ print(f"Total Cards: {total_cards}")
+ print(f"Unique Users: {len(user_counts)}")
+ print("-" * 40)
+
+ print("\nTop 10 Users by Card Count:")
+ for uid, count in user_counts.most_common(10):
+ print(f" {uid}: {count}")
+
+ print("\nTop 10 Most Frequent Contents (Global):")
+ sorted_content = sorted(content_hashes.items(), key=lambda x: x[1], reverse=True)[:10]
+ for text, count in sorted_content:
+ display_text = (text[:50] + '...') if len(text) > 50 else text
+ print(f" [{count}] {display_text}")
+
+ print("\nTop 10 Most Frequent (User, Content) Duplicates:")
+ sorted_user_content = sorted(user_content_hashes.items(), key=lambda x: x[1], reverse=True)[:10]
+ for (uid, text), count in sorted_user_content:
+ display_text = (text[:50] + '...') if len(text) > 50 else text
+ print(f" [{count}] User: {uid} | {display_text}")
+
+ print("="*40)
+
+if __name__ == "__main__":
+ analyze_memory()
+
diff --git a/scripts/analyze_memory_coverage.py b/scripts/analyze_memory_coverage.py
new file mode 100644
index 0000000..e7ec498
--- /dev/null
+++ b/scripts/analyze_memory_coverage.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+"""
+Script to analyze Memory Card coverage statistics.
+"""
+import sys
+import os
+import json
+import numpy as np
+from collections import defaultdict
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
+def main():
+ cards_path = "data/personamem/memory_cards.jsonl"
+
+ if not os.path.exists(cards_path):
+ print(f"Error: {cards_path} not found.")
+ return
+
+ print(f"Loading memory cards from {cards_path}...")
+
+ cards_by_user = defaultdict(int)
+ total_cards = 0
+
+ with open(cards_path, "r") as f:
+ for line in f:
+ try:
+ card = json.loads(line)
+ uid = card.get("user_id")
+ if uid:
+ cards_by_user[uid] += 1
+ total_cards += 1
+ except:
+ continue
+
+ # We also need the TOTAL number of personas (including those with 0 cards) as the
+ # denominator. It could be inferred from the user_vectors file, but the shared
+ # contexts file gives the denominator directly, so check that when available.
+
+ ctx_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl"
+ total_personas = 0
+ if os.path.exists(ctx_path):
+ with open(ctx_path, "r") as f:
+ for line in f:
+ data = json.loads(line)
+ # Each line is a dict mapping a context id to its messages (see personamem_loader),
+ # so counting the keys gives the number of personas on that line.
+ total_personas += len(data)
+ else:
+ print("Warning: Context file not found, can't calculate 0-memory users accurately.")
+ total_personas = len(cards_by_user) # Fallback
+
+ users_with_memory = len(cards_by_user)
+ users_without_memory = total_personas - users_with_memory
+
+ counts = list(cards_by_user.values())
+ if users_without_memory > 0:
+ counts.extend([0] * users_without_memory)
+
+ print("\n" + "="*40)
+ print("Memory Coverage Statistics")
+ print("="*40)
+ print(f"Total Personas (Est): {total_personas}")
+ print(f"Total Memory Cards: {total_cards}")
+ print(f"Users with Memory: {users_with_memory} ({users_with_memory/total_personas*100:.2f}%)")
+ print(f"Users w/o Memory: {users_without_memory} ({users_without_memory/total_personas*100:.2f}%)")
+ print("-" * 40)
+
+ if counts:
+ avg_cards = np.mean(counts)
+ median_cards = np.median(counts)
+ max_cards = np.max(counts)
+
+ print(f"Avg Cards/User: {avg_cards:.2f}")
+ print(f"Median Cards/User: {median_cards:.2f}")
+ print(f"Max Cards/User: {max_cards}")
+
+ # Percentiles
+ p25, p75 = np.percentile(counts, [25, 75])
+ print(f"25th Percentile: {p25:.2f}")
+ print(f"75th Percentile: {p75:.2f}")
+
+ print("\nDistribution:")
+
+ # Adjust for exact 0
+ zero_count = counts.count(0)
+
+ print(f" 0 : {zero_count}")
+ # Custom bins for >0
+ non_zero_counts = [c for c in counts if c > 0]
+ if non_zero_counts:
+ hist_nz, edges = np.histogram(non_zero_counts, bins=[1, 5, 10, 20, 50, 1000])
+ for i in range(len(hist_nz)):
+ range_str = f"{int(edges[i])}-{int(edges[i+1]-1)}"
+ print(f" {range_str:<8}: {hist_nz[i]}")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/analyze_user_similarity.py b/scripts/analyze_user_similarity.py
new file mode 100644
index 0000000..538a89a
--- /dev/null
+++ b/scripts/analyze_user_similarity.py
@@ -0,0 +1,445 @@
+#!/usr/bin/env python3
+"""
+User Vector Similarity Analysis
+
+This script analyzes the similarity between user vectors (z_u) learned by the
+online personalization system. It computes:
+1. Cosine similarity matrix between all user vectors
+2. Ground truth similarity based on preference overlap
+3. Correlation between learned and expected similarities
+
+Usage:
+ python scripts/analyze_user_similarity.py \
+ --user-store data/users/user_store_pilot_v4_full-greedy.npz
+"""
+
+import argparse
+import numpy as np
+from typing import Dict, List, Tuple
+from dataclasses import dataclass
+
+
+# =============================================================================
+# Persona Definitions (must match pilot_runner_v4.py)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """User's TRUE style preferences."""
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en"
+
+
+# Ground truth personas
+PERSONAS = {
+ "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
+ "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
+ "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
+ "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
+ "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
+ "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
+}
+
+
+# =============================================================================
+# User Vector Loading
+# =============================================================================
+
+def load_user_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
+ """
+ Load user vectors from saved user store.
+
+ Returns:
+ {user_id: (z_long, z_short)}
+ """
+ data = np.load(user_store_path, allow_pickle=True)
+
+ user_vectors = {}
+
+ # UserTensorStore saves in format: {uid}_long, {uid}_short, {uid}_meta
+ # First, find all unique user IDs
+ user_ids = set()
+ for key in data.files:
+ if key.endswith("_long"):
+ uid = key[:-5] # Remove "_long"
+ user_ids.add(uid)
+
+ # Load vectors for each user
+ for uid in user_ids:
+ long_key = f"{uid}_long"
+ short_key = f"{uid}_short"
+
+ if long_key in data.files and short_key in data.files:
+ z_long = data[long_key]
+ z_short = data[short_key]
+ user_vectors[uid] = (z_long, z_short)
+
+ return user_vectors
+
+
+def load_user_vectors_from_internal(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
+ """
+ Alternative loader that understands the internal format.
+ """
+ data = np.load(user_store_path, allow_pickle=True)
+
+ print(f"[Debug] Available keys in npz: {list(data.files)}")
+
+ user_vectors = {}
+
+ # Try to find user vectors in various formats
+ for key in data.files:
+ print(f" {key}: shape={data[key].shape if hasattr(data[key], 'shape') else 'N/A'}")
+
+ # Format 1: Separate arrays per user
+ seen_users = set()
+ for key in data.files:
+ if "_z_long" in key or key.startswith("z_long_"):
+ # Extract user_id
+ if key.startswith("z_long_"):
+ user_id = key[7:] # Remove "z_long_"
+ else:
+ user_id = key.split("_z_long")[0]
+ seen_users.add(user_id)
+
+ for user_id in seen_users:
+ # Try different key formats
+ z_long_keys = [f"z_long_{user_id}", f"{user_id}_z_long"]
+ z_short_keys = [f"z_short_{user_id}", f"{user_id}_z_short"]
+
+ z_long = None
+ z_short = None
+
+ for k in z_long_keys:
+ if k in data.files:
+ z_long = data[k]
+ break
+
+ for k in z_short_keys:
+ if k in data.files:
+ z_short = data[k]
+ break
+
+ if z_long is not None and z_short is not None:
+ user_vectors[user_id] = (z_long, z_short)
+
+ return user_vectors
+
+
+# =============================================================================
+# Similarity Computation
+# =============================================================================
+
+def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
+ """Compute cosine similarity between two vectors."""
+ norm1 = np.linalg.norm(v1)
+ norm2 = np.linalg.norm(v2)
+
+ if norm1 < 1e-10 or norm2 < 1e-10:
+ return 0.0
+
+ return float(np.dot(v1, v2) / (norm1 * norm2))
+
+
+def compute_learned_similarity_matrix(
+ user_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]],
+ user_order: List[str]
+) -> np.ndarray:
+ """
+ Compute similarity matrix from learned user vectors.
+
+ Uses concatenated [z_long, z_short] as the user representation.
+ """
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ if u1 in user_vectors and u2 in user_vectors:
+ z1 = np.concatenate(user_vectors[u1])
+ z2 = np.concatenate(user_vectors[u2])
+ sim_matrix[i, j] = cosine_similarity(z1, z2)
+ elif i == j:
+ sim_matrix[i, j] = 1.0
+
+ return sim_matrix
+
+
+def compute_ground_truth_similarity(
+ personas: Dict[str, StylePrefs],
+ user_order: List[str]
+) -> np.ndarray:
+ """
+ Compute ground truth similarity based on preference overlap.
+
+ Uses Jaccard-like similarity:
+ - short: +1 if both require_short or both don't
+ - bullets: +1 if both require_bullets match
+ - lang: +1 if both lang match
+
+ Then normalize to [0, 1].
+ """
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ if u1 not in personas or u2 not in personas:
+ sim_matrix[i, j] = 0.0 if i != j else 1.0
+ continue
+
+ p1 = personas[u1]
+ p2 = personas[u2]
+
+ # Count matching dimensions
+ matches = 0
+ total = 3 # short, bullets, lang
+
+ if p1.require_short == p2.require_short:
+ matches += 1
+ if p1.require_bullets == p2.require_bullets:
+ matches += 1
+ if p1.lang == p2.lang:
+ matches += 1
+
+ sim_matrix[i, j] = matches / total
+
+ return sim_matrix
+
+
+def compute_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> Tuple[float, float]:
+ """
+ Compute Pearson and Spearman correlation between learned and ground truth similarity.
+ Only uses upper triangle (excluding diagonal) to avoid bias.
+ """
+ n = learned.shape[0]
+
+ # Extract upper triangle (excluding diagonal)
+ learned_flat = []
+ gt_flat = []
+
+ for i in range(n):
+ for j in range(i + 1, n):
+ learned_flat.append(learned[i, j])
+ gt_flat.append(ground_truth[i, j])
+
+ learned_flat = np.array(learned_flat)
+ gt_flat = np.array(gt_flat)
+
+ # Pearson correlation
+ if np.std(learned_flat) < 1e-10 or np.std(gt_flat) < 1e-10:
+ pearson = 0.0
+ else:
+ pearson = float(np.corrcoef(learned_flat, gt_flat)[0, 1])
+
+    # Spearman correlation (rank-based); spearmanr returns nan for constant input,
+    # so reuse the same degenerate-variance guard as above.
+    from scipy.stats import spearmanr
+    if np.std(learned_flat) < 1e-10 or np.std(gt_flat) < 1e-10:
+        spearman = 0.0
+    else:
+        spearman, _ = spearmanr(learned_flat, gt_flat)
+
+    return pearson, float(spearman)
+
+
+# =============================================================================
+# Visualization
+# =============================================================================
+
+def print_similarity_matrix(matrix: np.ndarray, user_order: List[str], title: str):
+ """Print similarity matrix in ASCII format."""
+ print(f"\n{title}")
+ print("=" * 70)
+
+ # Short labels
+ labels = [u.replace("user_", "").replace("_", " ")[:15] for u in user_order]
+
+ # Header
+ print(f"{'':>16}", end="")
+ for label in labels:
+ print(f"{label[:8]:>10}", end="")
+ print()
+
+ # Rows
+ for i, label in enumerate(labels):
+ print(f"{label:>16}", end="")
+ for j in range(len(labels)):
+ print(f"{matrix[i, j]:>10.3f}", end="")
+ print()
+
+ print()
+
+
+def save_visualization(
+ learned: np.ndarray,
+ ground_truth: np.ndarray,
+ user_order: List[str],
+ output_path: str
+):
+ """Save similarity matrices as heatmap visualization."""
+ try:
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ except ImportError:
+ print("[Warning] matplotlib/seaborn not available, skipping visualization")
+ return
+
+ # Short labels
+ labels = [u.replace("user_", "")[:12] for u in user_order]
+
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
+
+ # Learned similarity
+ sns.heatmap(learned, annot=True, fmt=".2f",
+ xticklabels=labels, yticklabels=labels,
+ cmap="RdYlGn", vmin=-1, vmax=1,
+ ax=axes[0])
+ axes[0].set_title("Learned User Vector Similarity\n(cosine similarity)")
+ axes[0].tick_params(axis='x', rotation=45)
+ axes[0].tick_params(axis='y', rotation=0)
+
+ # Ground truth similarity
+ sns.heatmap(ground_truth, annot=True, fmt=".2f",
+ xticklabels=labels, yticklabels=labels,
+ cmap="RdYlGn", vmin=0, vmax=1,
+ ax=axes[1])
+ axes[1].set_title("Ground Truth Preference Overlap\n(Jaccard-like)")
+ axes[1].tick_params(axis='x', rotation=45)
+ axes[1].tick_params(axis='y', rotation=0)
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
+ print(f"[Visualization] Saved to: {output_path}")
+
+
+# =============================================================================
+# Main Analysis
+# =============================================================================
+
+def analyze_user_similarity(user_store_path: str, output_dir: str = "data/analysis"):
+ """Run full user similarity analysis."""
+ import os
+ os.makedirs(output_dir, exist_ok=True)
+
+ print("=" * 70)
+ print("USER VECTOR SIMILARITY ANALYSIS")
+ print("=" * 70)
+ print(f"User store: {user_store_path}")
+
+ # Load user vectors
+ print("\n[1] Loading user vectors...")
+ user_vectors = load_user_vectors(user_store_path)
+
+ if not user_vectors:
+ print("[Warning] No user vectors found with standard format, trying alternative...")
+ user_vectors = load_user_vectors_from_internal(user_store_path)
+
+ if not user_vectors:
+ print("[Error] Could not load user vectors!")
+ return
+
+ print(f" Found {len(user_vectors)} users: {list(user_vectors.keys())}")
+
+ # Print vector norms
+ print("\n[2] User vector norms:")
+ for uid, (z_long, z_short) in user_vectors.items():
+ print(f" {uid}: ||z_long||={np.linalg.norm(z_long):.4f}, ||z_short||={np.linalg.norm(z_short):.4f}")
+
+ # Determine user order (intersection of loaded users and known personas)
+ user_order = [u for u in PERSONAS.keys() if u in user_vectors]
+ print(f"\n[3] Analyzing {len(user_order)} users: {user_order}")
+
+ if len(user_order) < 2:
+ print("[Error] Need at least 2 users for similarity analysis!")
+ return
+
+ # Compute similarity matrices
+ print("\n[4] Computing similarity matrices...")
+ learned_sim = compute_learned_similarity_matrix(user_vectors, user_order)
+ gt_sim = compute_ground_truth_similarity(PERSONAS, user_order)
+
+ # Print matrices
+ print_similarity_matrix(learned_sim, user_order, "LEARNED SIMILARITY (Cosine of z_u)")
+ print_similarity_matrix(gt_sim, user_order, "GROUND TRUTH SIMILARITY (Preference Overlap)")
+
+ # Compute correlation
+ print("\n[5] Correlation Analysis:")
+ print("-" * 50)
+ pearson, spearman = compute_correlation(learned_sim, gt_sim)
+ print(f" Pearson correlation: {pearson:.4f}")
+ print(f" Spearman correlation: {spearman:.4f}")
+
+ # Interpretation
+ print("\n[6] Interpretation:")
+ print("-" * 50)
+ if spearman > 0.7:
+ print(" ✅ STRONG correlation: User vectors encode preference similarity well!")
+ elif spearman > 0.4:
+ print(" ⚠️ MODERATE correlation: User vectors partially capture preferences.")
+ elif spearman > 0:
+ print(" ⚠️ WEAK correlation: User vectors weakly capture preferences.")
+ else:
+ print(" ❌ NO/NEGATIVE correlation: User vectors do not reflect preferences.")
+
+ # Key comparisons
+ print("\n[7] Key Similarity Comparisons:")
+ print("-" * 50)
+
+ def get_sim(u1, u2, matrix, user_order):
+ if u1 in user_order and u2 in user_order:
+ i, j = user_order.index(u1), user_order.index(u2)
+ return matrix[i, j]
+ return None
+
+ comparisons = [
+ ("user_A_short_bullets_en", "user_F_extreme_short_en", ">", "user_A_short_bullets_en", "user_E_long_no_bullets_zh",
+ "A~F (both short+bullets) should be > A~E (opposite)"),
+ ("user_A_short_bullets_en", "user_D_short_bullets_zh", ">", "user_A_short_bullets_en", "user_C_long_bullets_en",
+ "A~D (both short+bullets) should be > A~C (only bullets match)"),
+ ("user_B_short_no_bullets_en", "user_E_long_no_bullets_zh", ">", "user_B_short_no_bullets_en", "user_A_short_bullets_en",
+ "B~E (both no_bullets) should be > B~A (bullets differ)"),
+ ]
+
+ for u1, u2, op, u3, u4, desc in comparisons:
+ sim1 = get_sim(u1, u2, learned_sim, user_order)
+ sim2 = get_sim(u3, u4, learned_sim, user_order)
+
+ if sim1 is not None and sim2 is not None:
+ passed = sim1 > sim2 if op == ">" else sim1 < sim2
+ status = "✅ PASS" if passed else "❌ FAIL"
+ print(f" {status}: sim({u1[:6]},{u2[:6]})={sim1:.3f} {op} sim({u3[:6]},{u4[:6]})={sim2:.3f}")
+ print(f" ({desc})")
+
+ # Save visualization
+ print("\n[8] Saving visualization...")
+ output_path = os.path.join(output_dir, "user_similarity_matrix.png")
+ save_visualization(learned_sim, gt_sim, user_order, output_path)
+
+ # Save numerical results
+ results_path = os.path.join(output_dir, "user_similarity_results.npz")
+ np.savez(results_path,
+ learned_similarity=learned_sim,
+ ground_truth_similarity=gt_sim,
+ user_order=user_order,
+ pearson=pearson,
+ spearman=spearman)
+ print(f"[Results] Saved to: {results_path}")
+
+ print("\n" + "=" * 70)
+ print("ANALYSIS COMPLETE")
+ print("=" * 70)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="User Vector Similarity Analysis")
+ parser.add_argument("--user-store", type=str, required=True,
+ help="Path to user store npz file")
+ parser.add_argument("--output-dir", type=str, default="data/analysis",
+ help="Output directory for results")
+ args = parser.parse_args()
+
+ analyze_user_similarity(args.user_store, args.output_dir)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/assemble_dataset.py b/scripts/assemble_dataset.py
new file mode 100644
index 0000000..024f91f
--- /dev/null
+++ b/scripts/assemble_dataset.py
@@ -0,0 +1,85 @@
+import json
+import os
+import random
+
+# Source Files
+FILE_ORIGINAL = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+FILE_SYNTHESIS = "data/raw_datasets/synthesized_positives.jsonl"
+
+# Output Files
+OUTPUT_RAW = "data/finetune/preference_extractor_450k.jsonl"
+
+def assemble_dataset():
+ os.makedirs(os.path.dirname(OUTPUT_RAW), exist_ok=True)
+
+ print("Assembling final dataset...")
+
+ records = []
+
+ # 1. Load Original (Pos + Neg)
+ print(f"Loading {FILE_ORIGINAL}...")
+ if os.path.exists(FILE_ORIGINAL):
+ with open(FILE_ORIGINAL, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ item = json.loads(line)
+ # Standardize format: {"input": ..., "output": ...}
+ # Input is user query. Output is extracted JSON string.
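+                    # Illustrative record shape (field values hypothetical):
+                    #   {"input": "Keep answers short from now on",
+                    #    "output": "{\"preferences\": [...]}", "source": "original"}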
+
+ query = item.get("original_query", "")
+ output_json = item.get("extracted_json", {"preferences": []})
+
+ # Ensure output is a string of JSON, minimal whitespace to save tokens
+ output_str = json.dumps(output_json, ensure_ascii=False)
+
+ records.append({
+ "input": query,
+ "output": output_str,
+ "source": item.get("source", "original")
+ })
+ else:
+ print(f"Warning: {FILE_ORIGINAL} missing!")
+
+ print(f"Loaded {len(records)} from original.")
+
+ # 2. Load Synthesis (Pos)
+ print(f"Loading {FILE_SYNTHESIS}...")
+ syn_count = 0
+ if os.path.exists(FILE_SYNTHESIS):
+ with open(FILE_SYNTHESIS, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ item = json.loads(line)
+
+ query = item.get("original_query", "")
+ output_json = item.get("extracted_json", {"preferences": []})
+ output_str = json.dumps(output_json, ensure_ascii=False)
+
+ records.append({
+ "input": query,
+ "output": output_str,
+ "source": "synthesis"
+ })
+ syn_count += 1
+ else:
+ print(f"Warning: {FILE_SYNTHESIS} missing!")
+
+ print(f"Loaded {syn_count} from synthesis.")
+
+ # 3. Shuffle
+ print("Shuffling...")
+ random.shuffle(records)
+
+ # 4. Save
+ print(f"Saving {len(records)} records to {OUTPUT_RAW}...")
+ with open(OUTPUT_RAW, "w", encoding="utf-8") as f:
+ for r in records:
+ f.write(json.dumps(r, ensure_ascii=False) + "\n")
+
+ print("Done!")
+ print("\nTo upload to Hugging Face, run:")
+ print("huggingface-cli upload <repo_id> data/finetune/preference_extractor_450k.jsonl --repo-type dataset")
+
+if __name__ == "__main__":
+ assemble_dataset()
+
diff --git a/scripts/build_item_space.py b/scripts/build_item_space.py
new file mode 100644
index 0000000..c98238c
--- /dev/null
+++ b/scripts/build_item_space.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+"""
+Script to build Item Space (PCA Projection) from Memory Embeddings.
+Inputs:
+- data/corpora/memory_embeddings.npy (M x 4096)
+Outputs:
+- data/corpora/item_projection.npz (P, mean, V)
+"""
+
+import sys
+import os
+import numpy as np
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.user_model.features import ItemProjection
+
+def main():
+ emb_path = "data/corpora/memory_embeddings.npy"
+ out_path = "data/corpora/item_projection.npz"
+
+ if not os.path.exists(emb_path):
+ print(f"Error: {emb_path} not found. Run migrate_preferences.py first.")
+ sys.exit(1)
+
+ print(f"Loading embeddings from {emb_path}...")
+ E = np.load(emb_path)
+ print(f"Loaded shape: {E.shape}")
+
+ # Target dimension k=256
+ k = 256
+ print(f"Fitting PCA with k={k}...")
+
+ proj = ItemProjection.from_pca(E, k=k)
+
+ print("Transforming all embeddings to item space...")
+ V = proj.transform_embeddings(E)
+ print(f"Item vectors shape: {V.shape}")
+
+ print(f"Saving projection to {out_path}...")
+ np.savez(
+ out_path,
+ P=proj.P,
+ mean=proj.mean,
+ V=V
+ )
+ print("Done.")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/check_batch_status.py b/scripts/check_batch_status.py
new file mode 100644
index 0000000..8c3ea5d
--- /dev/null
+++ b/scripts/check_batch_status.py
@@ -0,0 +1,60 @@
+import json
+import os
+import time
+from openai import OpenAI
+
+BATCH_IDS_FILE = "data/putnam_eval/submitted_batch_ids.json"
+
+def check_status():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ if not os.path.exists(BATCH_IDS_FILE):
+ print(f"Error: {BATCH_IDS_FILE} not found. Did you submit batches?")
+ return
+
+ with open(BATCH_IDS_FILE, "r") as f:
+ batch_ids = json.load(f)
+
+ print(f"Checking status for {len(batch_ids)} batches...\n")
+
+ all_completed = True
+ completed_count = 0
+
+ print(f"{'BATCH ID':<35} | {'STATUS':<15} | {'REQ COUNT':<10} | {'COMPLETED':<10} | {'FAILED':<10}")
+ print("-" * 95)
+
+ for b_id in batch_ids:
+ try:
+ batch = client.batches.retrieve(b_id)
+ status = batch.status
+ counts = batch.request_counts
+ total = counts.total if counts else 0
+ comp = counts.completed if counts else 0
+ fail = counts.failed if counts else 0
+
+ print(f"{b_id:<35} | {status:<15} | {total:<10} | {comp:<10} | {fail:<10}")
+
+ if status != "completed":
+ all_completed = False
+ else:
+ completed_count += 1
+
+ except Exception as e:
+ print(f"{b_id:<35} | {'ERROR':<15} | {str(e)}")
+ all_completed = False
+
+ print("-" * 95)
+ print(f"\nProgress: {completed_count}/{len(batch_ids)} batches completed.")
+
+ if all_completed:
+ print("\nSUCCESS: All batches finished! You can now run scripts/retrieve_results.py")
+ else:
+ print("\nSome batches are still processing. Check again later.")
+
+if __name__ == "__main__":
+ check_status()
+
diff --git a/scripts/clean_memory_store.py b/scripts/clean_memory_store.py
new file mode 100644
index 0000000..3b07a95
--- /dev/null
+++ b/scripts/clean_memory_store.py
@@ -0,0 +1,52 @@
+import json
+import os
+import shutil
+
+MEMORY_FILE = "data/corpora/memory_cards.jsonl"
+BACKUP_FILE = "data/corpora/memory_cards.jsonl.bak"
+
+def clean_memory_store():
+ if not os.path.exists(MEMORY_FILE):
+ print(f"Error: {MEMORY_FILE} not found.")
+ return
+
+ # 1. Backup
+ print(f"Backing up to {BACKUP_FILE}...")
+ shutil.copy2(MEMORY_FILE, BACKUP_FILE)
+
+ unique_keys = set()
+ cleaned_records = []
+ total_read = 0
+
+ # 2. Read and Dedup
+ print("Scanning and deduplicating...")
+ with open(MEMORY_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if not line.strip(): continue
+ try:
+ card = json.loads(line)
+ total_read += 1
+
+ uid = card.get("user_id")
+ text = card.get("note_text", "").strip()
+
+ # Key: (User, Content)
+ key = (uid, text)
+
+ if key not in unique_keys:
+ unique_keys.add(key)
+ cleaned_records.append(line.strip())
+            except json.JSONDecodeError:
+                # Skip malformed lines
+                pass
+
+ # 3. Write Back
+ print(f"Writing {len(cleaned_records)} records back (Removed {total_read - len(cleaned_records)} duplicates)...")
+ with open(MEMORY_FILE, "w", encoding="utf-8") as f:
+ for line in cleaned_records:
+ f.write(line + "\n")
+
+ print("Done!")
+
+if __name__ == "__main__":
+ clean_memory_store()
+
diff --git a/scripts/convert_to_llama_factory.py b/scripts/convert_to_llama_factory.py
new file mode 100644
index 0000000..d8b7565
--- /dev/null
+++ b/scripts/convert_to_llama_factory.py
@@ -0,0 +1,62 @@
+import json
+import os
+
+INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl"
+OUTPUT_FILE = "data/finetune/train_llama_factory.json"
+
+# We embed the system prompt as the "instruction" field so the model learns to respond to
+# this specific instruction. If you instead place the prompt in the system slot of the chat
+# template, the instruction field can be left empty or simplified. For a 0.5B-scale model,
+# an explicit instruction in the prompt is often helpful.
+SYSTEM_INSTRUCTION = (
+ "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
+ "If no preferences are found, return {\"preferences\": []}."
+)
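+# Each record becomes an Alpaca-style triple (values illustrative):
+#   {"instruction": SYSTEM_INSTRUCTION, "input": "<user query>", "output": "{\"preferences\": [...]}"}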
+
+def convert():
+ if not os.path.exists(INPUT_FILE):
+ print(f"Error: {INPUT_FILE} not found. Run scripts/assemble_dataset.py first.")
+ return
+
+ print(f"Reading {INPUT_FILE}...")
+ dataset = []
+
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ item = json.loads(line)
+
+ # Alpaca format
+ record = {
+ "instruction": SYSTEM_INSTRUCTION,
+ "input": item["input"],
+ "output": item["output"]
+ }
+ dataset.append(record)
+
+ print(f"Converted {len(dataset)} items.")
+
+ # Save as JSON list (LLaMA-Factory standard)
+ print(f"Saving to {OUTPUT_FILE}...")
+ with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
+ json.dump(dataset, f, indent=2, ensure_ascii=False)
+
+ print("Done!")
+
+ print("\nNext steps for LLaMA-Factory:")
+ print("1. Copy data/finetune/train_llama_factory.json to your LLaMA-Factory data/ folder.")
+ print("2. Add entry to dataset_info.json:")
+ print(json.dumps({
+ "preference_extractor_v1": {
+ "file_name": "train_llama_factory.json",
+ "columns": {
+ "prompt": "instruction",
+ "query": "input",
+ "response": "output"
+ }
+ }
+ }, indent=2))
+
+if __name__ == "__main__":
+ convert()
+
diff --git a/scripts/day1_demo.py b/scripts/day1_demo.py
new file mode 100644
index 0000000..b201229
--- /dev/null
+++ b/scripts/day1_demo.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+"""
+Day 1 Demo: End-to-end Minimal Memory RAG.
+1. Load MemoryCards + Embeddings.
+2. Receive a query.
+3. Retrieve top-k memories.
+4. Generate answer with QwenInstruct.
+"""
+
+import json
+import numpy as np
+import torch
+import sys
+import os
+
+# Add src to sys.path so we can import personalization
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from typing import List
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.llm.qwen_instruct import QwenInstruct
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
+def load_memory_store(cards_path: str, embs_path: str):
+ print(f"Loading memory store from {cards_path}...")
+ cards = []
+ with open(cards_path, "r", encoding="utf-8") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+
+ embs = np.load(embs_path)
+ return cards, embs
+
+def cosine_similarity(E: np.ndarray, e_q: np.ndarray) -> np.ndarray:
+ # E: [M, d], e_q: [d]
+ # Assumes vectors are normalized
+ return np.dot(E, e_q)
+
+def dense_retrieve(
+ query: str,
+ embedder: Qwen3Embedding8B,
+ cards: List[MemoryCard],
+ E: np.ndarray,
+ topk: int = 3
+) -> List[MemoryCard]:
+
+ # Encode query
+ # encode returns list[list[float]] or tensor
+ e_q_list = embedder.encode([query], normalize=True, return_tensor=False)
+ e_q = np.array(e_q_list[0], dtype=np.float32)
+
+ # Sim
+ sims = cosine_similarity(E, e_q)
+
+ # Top-k
+ # argsort is ascending, so take last k and reverse
+ if len(cards) == 0:
+ return []
+
+ k = min(topk, len(cards))
+ idx = np.argsort(sims)[-k:][::-1]
+
+ results = [cards[i] for i in idx]
+ return results
+
+def main():
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+
+ try:
+ cards, embs = load_memory_store(cards_path, embs_path)
+ print(f"Loaded {len(cards)} memory cards.")
+ except FileNotFoundError:
+ print("Error: Memory store not found. Please run scripts/migrate_preferences.py first.")
+ sys.exit(1)
+
+ cfg = load_local_models_config()
+
+ print("Initializing models...")
+ embedder = Qwen3Embedding8B.from_config(cfg)
+ llm = QwenInstruct.from_config(cfg)
+
+ # Demo Query
+    # Pick a query likely to trigger retrieval if relevant memories exist; the pilot-study
+    # data should contain some Python/formatting preferences, and a generic query still
+    # exercises the pipeline either way.
+ query = "Please write a function to calculate fibonacci numbers. Remember my preferences."
+
+ # Or let's allow user input or command line arg
+ if len(sys.argv) > 1:
+ query = sys.argv[1]
+
+ print(f"\nQuery: {query}")
+
+ # Retrieve
+ hits = dense_retrieve(query, embedder, cards, embs, topk=3)
+
+ print(f"\nRetrieved {len(hits)} memories:")
+ notes = []
+ for h in hits:
+ print(f" - [{h.kind}] {h.note_text} (from user: {h.user_id})")
+ notes.append(h.note_text)
+
+ # Generate
+ print("\nGenerating answer...")
+ # Mock history: just the current turn
+ history = [{"role": "user", "content": query}]
+
+ answer = llm.answer(history, notes)
+
+ print("-" * 40)
+ print("Answer:")
+ print(answer)
+ print("-" * 40)
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/day2_demo.py b/scripts/day2_demo.py
new file mode 100644
index 0000000..ca81d99
--- /dev/null
+++ b/scripts/day2_demo.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+Day 2 Demo: End-to-end Memory RAG with Reranker and (Shell) Personalization.
+"""
+
+import sys
+import os
+import numpy as np
+import torch
+from typing import List
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.llm.qwen_instruct import QwenInstruct
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.retrieval.pipeline import retrieve_with_rerank
+
+def main():
+ # Paths
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+ user_store_path = "data/users/user_store.npz"
+
+ # 1. Load Data
+ print("Loading data stores...")
+ if not os.path.exists(cards_path) or not os.path.exists(embs_path):
+ print("Memory data missing. Run migrate_preferences.py")
+ sys.exit(1)
+
+ cards = []
+ with open(cards_path, "r") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+
+ memory_embeddings = np.load(embs_path)
+
+ if not os.path.exists(item_proj_path):
+ print("Item projection missing. Run build_item_space.py")
+ sys.exit(1)
+
+ proj_data = np.load(item_proj_path)
+ item_vectors = proj_data["V"]
+
+ # 2. Load Models
+ print("Loading models...")
+ cfg = load_local_models_config()
+
+ embedder = Qwen3Embedding8B.from_config(cfg)
+ reranker = Qwen3Reranker.from_config(cfg)
+ llm = QwenInstruct.from_config(cfg)
+
+ # 3. Load User Store
+ # k = item_vectors.shape[1]
+ k = 256 # Hardcoded per config
+ user_store = UserTensorStore(k=k, path=user_store_path)
+
+ # --- CHECK 1: User Vector Similarity ---
+ print("\n--- CHECK 1: User Vector Similarity ---")
+ # Get 3 users with memories
+ valid_users = [uid for uid, state in user_store._states.items()
+ if np.linalg.norm(state.z_long) > 1e-6] # Only non-zero users
+
+ if len(valid_users) < 3:
+ print(f"Not enough users with memories found (found {len(valid_users)}). Skipping similarity check.")
+ else:
+ # Pick 3 random users
+ import random
+ selected_users = random.sample(valid_users, 3)
+ vectors = [user_store.get_state(uid).z_long for uid in selected_users]
+
+ # Calculate pairwise cosine similarity
+ def cos_sim(a, b):
+ norm_a = np.linalg.norm(a)
+ norm_b = np.linalg.norm(b)
+ if norm_a == 0 or norm_b == 0: return 0.0
+ return np.dot(a, b) / (norm_a * norm_b)
+
+ print(f"Selected Users: {selected_users}")
+ print(f"Sim(0, 1): {cos_sim(vectors[0], vectors[1]):.4f}")
+ print(f"Sim(0, 2): {cos_sim(vectors[0], vectors[2]):.4f}")
+ print(f"Sim(1, 2): {cos_sim(vectors[1], vectors[2]):.4f}")
+
+ # --- CHECK 2: Real User Retrieval ---
+ print("\n--- CHECK 2: Real User Retrieval ---")
+
+ if len(valid_users) > 0:
+ # Pick one user
+ target_user = valid_users[0]
+        # Ideally we would replay a real query from this user, but the source dataset is
+        # not loaded here, so use a generic coding query (OASST1 has many) that is likely
+        # to hit some tech-related preferences.
+ query = "How do I write a Python function for fibonacci?"
+
+ print(f"User: {target_user}")
+ print(f"Query: {query}")
+
+ # 5. Retrieve Pipeline
+ print("\nRunning Retrieval Pipeline (GLOBAL search)...")
+ hits_global = retrieve_with_rerank(
+ user_id=target_user,
+ query=query,
+ embed_model=embedder,
+ reranker=reranker,
+ memory_cards=cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=64,
+ topk_rerank=3,
+ beta_long=0.0,
+ beta_short=0.0,
+ only_own_memories=False # Global search
+ )
+
+ print(f"\nTop {len(hits_global)} Memories (Global):")
+ for h in hits_global:
+ print(f" - [{h.kind}] {h.note_text} (User: {h.user_id})")
+
+ print("\nRunning Retrieval Pipeline (OWN memories only)...")
+ hits_own = retrieve_with_rerank(
+ user_id=target_user,
+ query=query,
+ embed_model=embedder,
+ reranker=reranker,
+ memory_cards=cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=64,
+ topk_rerank=3,
+ beta_long=0.0,
+ beta_short=0.0,
+ only_own_memories=True # Own search
+ )
+
+ print(f"\nTop {len(hits_own)} Memories (Own):")
+ notes = []
+ for h in hits_own:
+ print(f" - [{h.kind}] {h.note_text} (User: {h.user_id})")
+ notes.append(h.note_text)
+
+ # 6. Generate Answer (using OWN memories by default for demo)
+ print("\nGenerating Answer (using Own Memories)...")
+ history = [{"role": "user", "content": query}]
+ answer = llm.answer(history, notes)
+
+ print("-" * 40)
+ print(answer)
+ print("-" * 40)
+ else:
+ print("No valid users found for demo.")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/day3_demo_feedback.py b/scripts/day3_demo_feedback.py
new file mode 100644
index 0000000..4e3d594
--- /dev/null
+++ b/scripts/day3_demo_feedback.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+"""
+Day 3 Demo: Feedback Loop Simulation (Reward + Gating).
+"""
+
+import sys
+import os
+import json
+import numpy as np
+import random
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.retrieval.pipeline import retrieve_with_rerank
+from personalization.feedback.handlers import eval_step
+
+def main():
+ # Paths
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+ user_store_path = "data/users/user_store.npz"
+ oasst_path = "data/raw_datasets/oasst1_queries.jsonl" # Source of turns
+
+ # 1. Load Data
+ print("Loading data stores...")
+ if not os.path.exists(cards_path) or not os.path.exists(embs_path):
+ print("Memory data missing.")
+ sys.exit(1)
+
+ cards = []
+ with open(cards_path, "r") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+
+ memory_embeddings = np.load(embs_path)
+
+ proj_data = np.load(item_proj_path)
+ item_vectors = proj_data["V"]
+
+ # 2. Load Models
+ print("Loading models...")
+ cfg = load_local_models_config()
+ embedder = Qwen3Embedding8B.from_config(cfg)
+ reranker = Qwen3Reranker.from_config(cfg)
+
+ user_store = UserTensorStore(k=256, path=user_store_path)
+
+ # 3. Simulate a Session
+    # 'oasst1_queries.jsonl' contains flat queries rather than full sessions, so this
+    # demo constructs a synthetic scenario instead of replaying real chat logs.
+
+ print("\n--- Synthetic Session Evaluation ---")
+
+ # Scenario 1: Success
+ # User asks python, system gives good python answer, user asks follow up.
+
+ user_id = "test_user_ok"
+ q_t = "How do I list files in a directory in Python?"
+
+    # Run real retrieval; in this scenario the returned cards are treated as relevant.
+ hits = retrieve_with_rerank(
+ user_id=user_id,
+ query=q_t,
+ embed_model=embedder,
+ reranker=reranker,
+ memory_cards=cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=64,
+ topk_rerank=3
+ )
+
+ a_t = "You can use os.listdir() or pathlib.Path.iterdir(). Here is an example..."
+ q_t1 = "Great, can you show me the pathlib one?"
+
+ print(f"\n[Scenario 1]")
+ print(f"Q_t: {q_t}")
+ print(f"A_t: {a_t}")
+ print(f"Q_t+1: {q_t1}")
+ print(f"Memories: {[m.note_text for m in hits]}")
+
+ # Eval
+ e_q = embedder.encode([q_t], return_tensor=False)[0]
+ e_q1 = embedder.encode([q_t1], return_tensor=False)[0]
+ e_q = np.array(e_q)
+ e_q1 = np.array(e_q1)
+
+ r_hat, g_hat = eval_step(q_t, a_t, q_t1, hits, e_q, e_q1)
+ print(f"-> Reward: {r_hat:.2f}")
+ print(f"-> Gating: {g_hat:.2f}")
+
+ # Scenario 2: Failure (Complaint)
+ q_t = "Explain quantum entanglement."
+ a_t = "Quantum entanglement is a phenomenon where particles..."
+ q_t1 = "No, that's not what I meant. Explain it simply like I'm five."
+
+    # Reuse the hits from Scenario 1; they concern Python, so they are intentionally
+    # irrelevant to this query.
+
+ print(f"\n[Scenario 2]")
+ print(f"Q_t: {q_t}")
+ print(f"A_t: {a_t}")
+ print(f"Q_t+1: {q_t1}")
+ print(f"Memories: {[m.note_text for m in hits]} (Irrelevant)")
+
+ e_q = embedder.encode([q_t], return_tensor=False)[0]
+ e_q1 = embedder.encode([q_t1], return_tensor=False)[0]
+ e_q = np.array(e_q)
+ e_q1 = np.array(e_q1)
+
+ r_hat, g_hat = eval_step(q_t, a_t, q_t1, hits, e_q, e_q1)
+ print(f"-> Reward: {r_hat:.2f}")
+ print(f"-> Gating: {g_hat:.2f}")
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/day4_offline_rl_replay.py b/scripts/day4_offline_rl_replay.py
new file mode 100644
index 0000000..1c5cdd9
--- /dev/null
+++ b/scripts/day4_offline_rl_replay.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+"""
+Day 4: Offline RL Replay Script.
+Simulates REINFORCE update on user state using OASST1 chat logs.
+"""
+
+import sys
+import os
+import json
+import numpy as np
+import random
+from tqdm import tqdm
+from collections import defaultdict
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.retrieval.pipeline import retrieve_with_policy
+from personalization.feedback.handlers import eval_step
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+
+def main():
+ # Configuration
+ cfg = load_local_models_config()
+ rl_cfg = {
+ "item_dim": 256,
+ "beta_long": 0.1,
+ "beta_short": 0.3,
+ "tau": 1.0,
+ "eta_long": 1e-3,
+ "eta_short": 5e-3,
+ "ema_alpha": 0.05,
+ "short_decay": 0.1
+ }
+    # These values could be loaded from a YAML config; they are hardcoded here for clarity in the demo.
+
+ # Paths
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+ user_store_path = "data/users/user_store.npz"
+ chat_path = "data/raw_datasets/oasst1_queries.jsonl"
+
+ # 1. Load Data
+ print("Loading data stores...")
+ if not os.path.exists(cards_path):
+ print("Data missing.")
+ sys.exit(1)
+
+ cards = []
+ with open(cards_path, "r") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+
+ memory_embeddings = np.load(embs_path)
+ item_vectors = np.load(item_proj_path)["V"]
+
+ # 2. Load Models
+ print("Loading models...")
+ embedder = Qwen3Embedding8B.from_config(cfg)
+ reranker = Qwen3Reranker.from_config(cfg)
+
+ user_store = UserTensorStore(k=rl_cfg["item_dim"], path=user_store_path)
+
+ # 3. Load Chat Logs and Group by Session
+ print("Loading chat logs...")
+ sessions = defaultdict(list)
+ with open(chat_path, "r") as f:
+ for line in f:
+ row = json.loads(line)
+ sessions[row["session_id"]].append(row)
+
+ # Filter sessions with at least 2 user turns to have q_t and q_t+1
+ valid_sessions = [s for s in sessions.values() if len(s) >= 2]
+ print(f"Found {len(valid_sessions)} valid sessions for replay.")
+
+ # Sample a few users/sessions to replay
+ # For speed, pick 50 sessions
+ sampled_sessions = random.sample(valid_sessions, min(50, len(valid_sessions)))
+
+ total_reward = 0.0
+ total_gating = 0.0
+ steps = 0
+
+ # 4. Replay Loop
+ print("\nStarting RL Replay...")
+
+ for session in tqdm(sampled_sessions, desc="Replay Sessions"):
+ # Sort by turn
+ session.sort(key=lambda x: x.get("turn_id", 0))
+
+ # We need (q_t, a_t, q_t+1)
+        # The flat query file contains only user turns, so the assistant answer a_t is
+        # unavailable (the presence of q_t+1 implies an answer was given). eval_step
+        # accepts an empty a_t; the current reward model mainly looks at q_t+1 keywords.
+
+ user_id = session[0]["user_id"]
+
+ # Pre-check user state
+ state_before = user_store.get_state(user_id)
+ norm_long_before = np.linalg.norm(state_before.z_long)
+ norm_short_before = np.linalg.norm(state_before.z_short)
+
+ num_updates = 0
+ for i in range(len(session) - 1):
+ q_t_row = session[i]
+ q_t1_row = session[i+1]
+
+ q_t = q_t_row["original_query"]
+ q_t1 = q_t1_row["original_query"]
+ a_t = "" # Missing in this file
+
+ # 1. Retrieve with Policy
+            # only_own_memories=True would simulate a per-user memory bank, but a user with
+            # no stored memories would yield zero candidates and the policy could not select
+            # anything. Use global search for replay so there are always candidates to
+            # rerank; a real system would search both.
+
+ candidates, cand_vectors, base_scores, chosen_idx, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=q_t,
+ embed_model=embedder,
+ reranker=reranker,
+ memory_cards=cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=32,
+ topk_rerank=5,
+ beta_long=rl_cfg["beta_long"],
+ beta_short=rl_cfg["beta_short"],
+ tau=rl_cfg["tau"],
+ only_own_memories=False
+ )
+
+ if not candidates:
+ continue
+
+ chosen_memories = [candidates[i] for i in chosen_idx]
+
+ # 2. Eval (Reward & Gating)
+ # Need embeddings
+ e_q = embedder.encode([q_t], return_tensor=False)[0]
+ e_q1 = embedder.encode([q_t1], return_tensor=False)[0]
+
+ r_hat, g_hat = eval_step(
+ q_t, a_t, q_t1,
+ chosen_memories,
+ query_embedding_t=np.array(e_q),
+ query_embedding_t1=np.array(e_q1)
+ )
+
+ total_reward += r_hat
+ total_gating += g_hat
+ steps += 1
+
+ # 3. Update (REINFORCE)
+ state = user_store.get_state(user_id)
+ updated = reinforce_update_user_state(
+ user_state=state,
+ item_vectors=cand_vectors,
+ chosen_indices=chosen_idx,
+ policy_probs=probs,
+ reward_hat=r_hat,
+ gating=g_hat,
+ tau=rl_cfg["tau"],
+ eta_long=rl_cfg["eta_long"],
+ eta_short=rl_cfg["eta_short"],
+ ema_alpha=rl_cfg["ema_alpha"],
+ short_decay=rl_cfg["short_decay"]
+ )
+ if updated:
+ num_updates += 1
+
+ user_store.save_state(state)
+
+ # Post-check
+ state_after = user_store.get_state(user_id)
+ norm_long_after = np.linalg.norm(state_after.z_long)
+ norm_short_after = np.linalg.norm(state_after.z_short)
+
+ # Print change only if updated
+ if num_updates > 0:
+ print(f"User {user_id}: {num_updates} updates")
+ print(f" ||z_long|| {norm_long_before:.8f} -> {norm_long_after:.8f}")
+ print(f" ||z_short|| {norm_short_before:.8f} -> {norm_short_after:.8f}")
+
+ print("\n--- Replay Finished ---")
+ print(f"Total Steps: {steps}")
+ print(f"Avg Reward: {total_reward / max(1, steps):.8f}")
+ print(f"Avg Gating: {total_gating / max(1, steps):.8f}")
+
+ user_store.persist()
+ print(f"User store saved to {user_store_path}")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/debug_context_file.py b/scripts/debug_context_file.py
new file mode 100644
index 0000000..81ac6b9
--- /dev/null
+++ b/scripts/debug_context_file.py
@@ -0,0 +1,14 @@
+import json
+
+path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl"
+with open(path, 'r') as f:
+ line = f.readline()
+ data = json.loads(line)
+ print(f"Type: {type(data)}")
+ if isinstance(data, dict):
+ print(f"Keys: {list(data.keys())}")
+ # Peek into values
+ for k, v in data.items():
+ print(f"Key '{k}' type: {type(v)}")
+ if isinstance(v, list):
+ print(f" Length: {len(v)}")
diff --git a/scripts/debug_minimal_day3.py b/scripts/debug_minimal_day3.py
new file mode 100644
index 0000000..169cf13
--- /dev/null
+++ b/scripts/debug_minimal_day3.py
@@ -0,0 +1,40 @@
+import sys
+import os
+import torch
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+
+def main():
+ print("--- Minimal Day 3 Debug ---")
+
+ # 1. Load Config
+ print("Loading Local Models Config...")
+ cfg = load_local_models_config()
+ print(f"Config loaded.")
+
+ # Check what we got
+ spec = cfg.embedding.qwen3
+ print(f"Model Path: {spec.local_path}")
+ print(f"Dtype: {spec.dtype}")
+ print(f"Device Map: {spec.device_map}")
+    # Check whether trust_remote_code made it into the spec (it should, after the config change)
+ trc = getattr(spec, "trust_remote_code", "UNKNOWN")
+ print(f"Trust Remote Code: {trc}")
+
+ # 2. Init Embedder
+ print("Initializing Embedder via Qwen3Embedding8B.from_config...")
+ try:
+ embedder = Qwen3Embedding8B.from_config(cfg)
+ print("Embedder loaded successfully!")
+ except Exception as e:
+ print(f"Failed to load embedder: {e}")
+ import traceback
+ traceback.print_exc()
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/debug_personamem_hash.py b/scripts/debug_personamem_hash.py
new file mode 100644
index 0000000..7ef442d
--- /dev/null
+++ b/scripts/debug_personamem_hash.py
@@ -0,0 +1,22 @@
+import hashlib
+import json
+
+def get_line_hash(line_str: str) -> str:
+ """Compute SHA256 hash of the line content to match shared_context_id."""
+ return hashlib.sha256(line_str.strip().encode("utf-8")).hexdigest()
+
+def debug_hash():
+ jsonl_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl"
+ with open(jsonl_path, "r") as f:
+ first_line = f.readline()
+
+ computed_hash = get_line_hash(first_line)
+ target_hash = "e898d03fec683b1cabf29f57287ff66f8a31842543ecef44b56766844c1c1301"
+
+ print(f"Computed: {computed_hash}")
+ print(f"Target: {target_hash}")
+ print(f"Match: {computed_hash == target_hash}")
+
+if __name__ == "__main__":
+ debug_hash()
+
diff --git a/scripts/diagnose_oom.py b/scripts/diagnose_oom.py
new file mode 100644
index 0000000..22de3f9
--- /dev/null
+++ b/scripts/diagnose_oom.py
@@ -0,0 +1,78 @@
+import torch
+from transformers import AutoModel, AutoTokenizer
+import os
+import psutil
+import time
+import sys
+import gc
+
+def log_mem(msg):
+ mem = psutil.Process().memory_info().rss / (1024**3)
+ if torch.cuda.is_available():
+ gpu = torch.cuda.memory_allocated() / (1024**3)
+ gpu_res = torch.cuda.memory_reserved() / (1024**3)
+ print(f"[{msg}] RAM: {mem:.2f}GB | GPU Alloc: {gpu:.2f}GB | GPU Res: {gpu_res:.2f}GB")
+ else:
+ print(f"[{msg}] RAM: {mem:.2f}GB | GPU: N/A")
+ sys.stdout.flush()
+
+def main():
+ print("--- Diagnostic Script ---")
+ log_mem("Start")
+
+ model_path = "models/qwen3-embedding-8b"
+ print(f"Model path: {model_path}")
+
+ # Check config
+ import yaml
+ try:
+ with open("configs/local_models.yaml", "r") as f:
+ cfg = yaml.safe_load(f)
+ print("Config loaded from local_models.yaml:")
+ print(cfg['models']['embedding']['qwen3'])
+ except Exception as e:
+ print(f"Could not load config: {e}")
+
+ # Explicit garbage collection
+ gc.collect()
+ torch.cuda.empty_cache()
+ log_mem("Pre-Load")
+
+ print("Loading Tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=False)
+ log_mem("Tokenizer Loaded")
+
+ print("Loading Model (trust_remote_code=False)...")
+ try:
+        # Pass low_cpu_mem_usage=True explicitly (device_map loading usually enables it anyway)
+ model = AutoModel.from_pretrained(
+ model_path,
+ device_map="cuda:0",
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=False,
+ low_cpu_mem_usage=True
+ )
+ print("Model loaded successfully.")
+ except Exception as e:
+ print(f"Model load failed: {e}")
+ return
+
+ log_mem("Model Loaded")
+
+ print("Testing forward pass with small input...")
+ input_text = "Hello world"
+ inputs = tokenizer(input_text, return_tensors="pt").to("cuda:0")
+
+ try:
+ with torch.no_grad():
+ outputs = model(**inputs)
+ print("Forward pass success.")
+ print(f"Output shape: {outputs.last_hidden_state.shape}")
+ except Exception as e:
+ print(f"Forward pass failed: {e}")
+
+ log_mem("End")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py
new file mode 100644
index 0000000..f78b15f
--- /dev/null
+++ b/scripts/download_datasets.py
@@ -0,0 +1,210 @@
+import os
+import json
+import random
+from typing import List, Dict, Any
+from datasets import load_dataset
+from tqdm import tqdm
+
+# Configuration
+OUTPUT_DIR = "data/raw_datasets"
+
+# Dataset configurations
+# Each entry: {"id": huggingface_id, "subset": ..., "split": ..., "type": parser key, "limit": approximate cap}
+SOURCES = [
+ {
+ "id": "lmsys/lmsys-chat-1m",
+ "subset": None,
+ "split": "train",
+ "type": "lmsys",
+ "limit": 200000
+ },
+ {
+ "id": "allenai/WildChat",
+ "subset": None,
+ "split": "train",
+ "type": "wildchat",
+ "limit": 150000
+ },
+ {
+ "id": "anon8231489123/ShareGPT_Vicuna_unfiltered",
+ "subset": None,
+ "split": "train",
+ "data_files": "ShareGPT_V3_unfiltered_cleaned_split.json",
+ "type": "sharegpt",
+ "limit": 50000
+ },
+ {
+ "id": "yahma/alpaca-cleaned",
+ "subset": None,
+ "split": "train",
+ "type": "alpaca",
+ "limit": 52000
+ },
+ {
+ "id": "Open-Orca/SlimOrca",
+ "subset": None,
+ "split": "train",
+ "type": "slimorca",
+ "limit": 100000
+ }
+]
+
+def ensure_english(text: str) -> bool:
+ # A simple heuristic to filter non-English text.
+ # For production, use langdetect or similar libraries.
+    # Here we keep only strings that are entirely ASCII (str.isascii()), which also drops
+    # English text containing accented characters or emoji.
+ try:
+ return text.isascii()
+ except:
+ return False
+
+def process_lmsys(example: Dict[str, Any]) -> str | None:
+ # LMSYS format: conversation is in 'conversation' list of dicts
+ try:
+ conversation = example.get("conversation", [])
+ if not conversation:
+ return None
+ # Get first user message
+ if conversation[0]["role"] == "user":
+ return conversation[0]["content"]
+ except:
+ pass
+ return None
+
+def process_wildchat(example: Dict[str, Any]) -> str | None:
+    # WildChat format: 'conversation' is a list of dicts with 'role' and 'content' keys.
+ try:
+ conversation = example.get("conversation", [])
+ if not conversation:
+ return None
+ if conversation[0]["role"] == "user":
+ return conversation[0]["content"]
+ except:
+ pass
+ return None
+
+def process_sharegpt(example: Dict[str, Any]) -> str | None:
+ # ShareGPT format: 'conversations' list
+ try:
+ conversations = example.get("conversations", [])
+ if not conversations:
+ return None
+ # Usually human/gpt or user/assistant
+ if conversations[0]["from"] in ["human", "user"]:
+ return conversations[0]["value"]
+ except:
+ pass
+ return None
+
+def process_alpaca(example: Dict[str, Any]) -> str | None:
+ # Alpaca format: 'instruction' and 'input'. We combine them if input exists.
+ try:
+ instruction = example.get("instruction", "").strip()
+ inp = example.get("input", "").strip()
+ if inp:
+ return f"{instruction}\n\nInput: {inp}"
+ return instruction
+ except:
+ pass
+ return None
+
+def process_slimorca(example: Dict[str, Any]) -> str | None:
+ # SlimOrca format: 'conversations' list of dicts (from, value)
+ # Similar to ShareGPT but keys might differ slightly
+ try:
+ conversations = example.get("conversations", [])
+ if not conversations:
+ return None
+ # Usually from: human/user
+ if conversations[0]["from"] in ["human", "user"]:
+ return conversations[0]["value"]
+ except:
+ pass
+ return None
+
+def download_and_process():
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ all_queries = []
+
+ # Target new sources only (alpaca and slimorca)
+ # You can comment out this filter if you want to re-run everything
+ new_types = ["alpaca", "slimorca"]
+
+ for source in SOURCES:
+ if source["type"] not in new_types:
+ continue
+
+ print(f"Processing {source['id']}...")
+ try:
+ # Load streaming to save disk/memory
+ kwargs = {"streaming": True}
+ if "data_files" in source:
+ kwargs["data_files"] = source["data_files"]
+
+ ds = load_dataset(source["id"], source["subset"], split=source["split"], **kwargs)
+
+ count = 0
+ limit = source["limit"]
+
+ for example in tqdm(ds, desc=f"Reading {source['id']}", total=limit):
+ if count >= limit:
+ break
+
+ query = None
+ if source["type"] == "lmsys":
+ query = process_lmsys(example)
+ elif source["type"] == "wildchat":
+ query = process_wildchat(example)
+ elif source["type"] == "sharegpt":
+ query = process_sharegpt(example)
+ elif source["type"] == "alpaca":
+ query = process_alpaca(example)
+ elif source["type"] == "slimorca":
+ query = process_slimorca(example)
+
+ # Basic cleaning
+ if query and len(query.strip()) > 5 and ensure_english(query):
+ all_queries.append({
+ "source": source["id"],
+ "query": query.strip()
+ })
+ count += 1
+
+ except Exception as e:
+ print(f"Error processing {source['id']}: {e}")
+
+ # Deduplicate based on query content
+ print(f"Total collected new items: {len(all_queries)}")
+
+ # Load existing if available to dedup against
+ output_path = os.path.join(OUTPUT_DIR, "combined_raw_queries.jsonl")
+ existing_data = []
+ if os.path.exists(output_path):
+ print("Loading existing data for deduplication...")
+ with open(output_path, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ existing_data.append(json.loads(line))
+
+ combined = existing_data + all_queries
+ print(f"Total before final deduplication: {len(combined)}")
+
+ unique_queries = {item["query"]: item for item in combined}.values()
+ final_data = list(unique_queries)
+ print(f"Total after final deduplication: {len(final_data)}")
+
+ # Shuffle
+ random.shuffle(final_data)
+
+ # Save
+ print(f"Saving to {output_path}...")
+ with open(output_path, "w", encoding="utf-8") as f:
+ for item in final_data:
+ f.write(json.dumps(item, ensure_ascii=False) + "\n")
+
+ print("Done!")
+
+if __name__ == "__main__":
+ download_and_process()
diff --git a/scripts/download_llama.py b/scripts/download_llama.py
new file mode 100644
index 0000000..47c7243
--- /dev/null
+++ b/scripts/download_llama.py
@@ -0,0 +1,16 @@
+from huggingface_hub import snapshot_download
+import os
+
+model_id = "meta-llama/Llama-3.1-8B-Instruct"
+local_dir = "models/llama-3.1-8b-instruct"
+
+os.makedirs(local_dir, exist_ok=True)
+
+print(f"Downloading {model_id} to {local_dir}...")
+snapshot_download(
+ repo_id=model_id,
+ local_dir=local_dir,
+ # token=os.getenv("HF_TOKEN"), # Assuming token is in env or cached
+)
+print("Download complete.")
+
diff --git a/scripts/download_oasst1.py b/scripts/download_oasst1.py
new file mode 100644
index 0000000..3105f28
--- /dev/null
+++ b/scripts/download_oasst1.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""
+Script to download and process OpenAssistant/oasst1 dataset.
+Converts it into a flat list of user turns (ChatTurn) for our pipeline.
+
+Output format per line (JSONL):
+{
+ "original_query": str,
+ "source": "oasst1",
+ "user_id": str,
+ "session_id": str,
+ "turn_id": int
+}
+"""
+
+import json
+import os
+from datasets import load_dataset
+from tqdm import tqdm
+
+def main():
+ output_path = "data/raw_datasets/oasst1_queries.jsonl"
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ print("Downloading OpenAssistant/oasst1 dataset...")
+ # OASST1 is a tree structure. We need to traverse it to reconstruct conversations.
+ # It has 'message_id', 'parent_id', 'user_id', 'text', 'role'
+ ds = load_dataset("OpenAssistant/oasst1", split="train")
+
+ print(f"Loaded {len(ds)} messages. Reconstructing threads...")
+
+ # Index by message_id
+ id2msg = {}
+ for row in tqdm(ds, desc="Indexing"):
+ id2msg[row["message_id"]] = row
+
+    # We could find leaf nodes and trace each thread back to its root, but the goal here is
+    # simply to build a unified ChatTurn sequence from OASST1 (with user_id and session_id),
+    # i.e. every valid user turn.
+    # OASST1 'user_id' is the author ID.
+    # 'message_tree_id' identifies the conversation tree (session).
+
+ # We can iterate all messages. If role=='prompter' (user), we treat it as a turn.
+ # We use 'message_tree_id' as session_id.
+
+ queries = []
+
+ # Iterate all rows
+ for row in tqdm(ds, desc="Processing"):
+ if row["role"] == "prompter":
+ # This is a user turn
+ user_id = row["user_id"] # Author ID
+ session_id = row["message_tree_id"]
+ text = row["text"]
+
+ # Simple metadata
+ queries.append({
+ "original_query": text,
+ "source": "oasst1",
+ "user_id": str(user_id),
+ "session_id": str(session_id),
+ "turn_id": 0 # We don't strictly need precise turn_id for Day 1 pipeline right now unless we sort
+ })
+
+ print(f"Extracted {len(queries)} user queries.")
+
+ # Save
+ print(f"Saving to {output_path}...")
+ with open(output_path, "w", encoding="utf-8") as f:
+ for q in queries:
+ f.write(json.dumps(q, ensure_ascii=False) + "\n")
+
+ print("Done.")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/download_personamem.py b/scripts/download_personamem.py
new file mode 100644
index 0000000..31b4e0e
--- /dev/null
+++ b/scripts/download_personamem.py
@@ -0,0 +1,25 @@
+from huggingface_hub import hf_hub_download
+import os
+
+repo_id = "bowen-upenn/PersonaMem"
+local_dir = "data/raw_datasets/personamem"
+files_to_download = [
+ "questions_32k.csv",
+ "shared_contexts_32k.jsonl"
+]
+
+os.makedirs(local_dir, exist_ok=True)
+
+print(f"Downloading files from {repo_id} to {local_dir}...")
+
+for filename in files_to_download:
+ print(f"Downloading {filename}...")
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ repo_type="dataset",
+ local_dir=local_dir
+ )
+
+print("Download complete.")
+
diff --git a/scripts/eval_embedder_reranker.py b/scripts/eval_embedder_reranker.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/eval_embedder_reranker.py
diff --git a/scripts/eval_interface_example.py b/scripts/eval_interface_example.py
new file mode 100644
index 0000000..d5dc6cd
--- /dev/null
+++ b/scripts/eval_interface_example.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+"""
+Example: Using the PersonalizedLLM Interface for Evaluation.
+
+This script demonstrates the evaluation interface that can be used
+by a user simulator or evaluation framework.
+
+Call sequence per evaluation run:
+1. reset_user(user_id) - Start fresh for this user's "life"
+2. For each session (s=1..S):
+ a. reset_session(user_id) - New chat window
+ b. For each turn (t=1..T):
+ i. [Turn 2+] apply_feedback() for previous turn
+ ii. resp = chat(user_id, query)
+ iii. [Simulator computes reward from response]
+3. persist() - Save state at end
+"""
+
+import sys
+import os
+
+# Add src to sys.path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import (
+ PersonalizedLLM,
+ AssistantResponse,
+ Feedback,
+)
+
+
+def main():
+ print("=" * 60)
+ print("PersonalizedLLM Evaluation Interface Demo")
+ print("=" * 60)
+
+ # Initialize the system
+ # Note: This will load models, which takes time and GPU memory
+ print("\n[1] Initializing PersonalizedLLM...")
+
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_eval_demo.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Define test user
+ user_id = "eval_demo_user"
+
+ # Reset user for clean experiment
+ print(f"\n[2] Resetting user: {user_id}")
+ llm.reset_user(user_id)
+
+ # Check initial state
+ print(f"\n[3] Initial user state:")
+ print(f" {llm.get_user_state_summary(user_id)}")
+
+ # Simulate multiple sessions
+ num_sessions = 2
+ queries_per_session = [
+ # Session 1: Food preferences
+ [
+ "What's a good recipe for dinner tonight?",
+ "I prefer vegetarian food with Asian flavors.",
+ "Can you suggest something spicy?",
+ ],
+ # Session 2: Test personalization retention
+ [
+ "What should I cook for lunch?",
+ "Give me a quick meal idea.",
+ ],
+ ]
+
+ all_responses = []
+
+ for session_idx, session_queries in enumerate(queries_per_session):
+ print(f"\n{'=' * 60}")
+ print(f"SESSION {session_idx + 1}")
+ print("=" * 60)
+
+ # Reset session (new chat window)
+ llm.reset_session(user_id)
+ print(f"[Session {session_idx + 1}] Started new session")
+
+ session_responses = []
+
+ for turn_idx, query in enumerate(session_queries):
+ print(f"\n--- Turn {turn_idx + 1} ---")
+
+ # Apply feedback for previous turn (from turn 2 onwards)
+ if turn_idx > 0:
+ # Simulated feedback - in real eval, this comes from user simulator
+ simulated_reward = 0.7 + 0.1 * (turn_idx % 2) # Varies by turn
+ simulated_gating = 1.0 if turn_idx > 0 else 0.0
+
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=turn_idx - 1,
+ reward=simulated_reward,
+ gating=simulated_gating,
+ meta={"source": "demo_simulator"}
+ )
+
+ print(f"[Feedback] Applying: reward={simulated_reward:.2f}, gating={simulated_gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ # Main chat call
+ print(f"User: {query}")
+ response: AssistantResponse = llm.chat(user_id, query)
+
+ print(f"Assistant: {response.answer[:200]}..." if len(response.answer) > 200 else f"Assistant: {response.answer}")
+ print(f"[Usage] prompt={response.usage.prompt_tokens}, completion={response.usage.completion_tokens}, model={response.usage.model}")
+
+ if response.debug:
+ print(f"[Debug] memories={len(response.debug.selected_memory_ids)}, z_long_norm={response.debug.extra.get('z_long_norm', 0):.4f}")
+ if response.debug.extracted_preferences:
+ print(f"[Debug] Extracted {len(response.debug.extracted_preferences)} preferences")
+
+ session_responses.append(response)
+
+ all_responses.append(session_responses)
+
+ # Show user state after session
+ print(f"\n[Session {session_idx + 1}] Final state:")
+ print(f" {llm.get_user_state_summary(user_id)}")
+
+ # Summary
+ print(f"\n{'=' * 60}")
+ print("EVALUATION SUMMARY")
+ print("=" * 60)
+
+ total_tokens = sum(
+ r.usage.total_tokens
+ for session in all_responses
+ for r in session
+ )
+ total_turns = sum(len(s) for s in all_responses)
+
+ print(f"Total sessions: {len(all_responses)}")
+ print(f"Total turns: {total_turns}")
+ print(f"Total tokens: {total_tokens}")
+ print(f"Final user state: {llm.get_user_state_summary(user_id)}")
+
+ # Persist (optional, for saving state between runs)
+ # llm.persist()
+ # print("\nState persisted to disk.")
+
+ print("\nDemo complete!")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/eval_single_ckpt.py b/scripts/eval_single_ckpt.py
new file mode 100644
index 0000000..8597907
--- /dev/null
+++ b/scripts/eval_single_ckpt.py
@@ -0,0 +1,145 @@
+import json
+import os
+import torch
+import glob
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+from torch.utils.data import Dataset, DataLoader
+
+# --- Configuration ---
+# You can manually set the checkpoint path here if glob fails or is slow
+# Example: "saves/qwen3-0.6b-full-sft-h200/checkpoint-4358"
+CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200"
+TEST_FILE = "data/test_llama_factory.json"
+BATCH_SIZE = 128
+USE_FLASH_ATTN = False
+
+# Load System Prompt
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ SYSTEM_PROMPT = f.read()
+
+class EvalDataset(Dataset):
+ def __init__(self, data):
+ self.data = data
+ def __len__(self):
+ return len(self.data)
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+def load_test_data():
+ with open(TEST_FILE, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+def batch_generate(model, tokenizer, batch_data, device="cuda"):
+ prompts = []
+ for item in batch_data:
+ messages = [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": item["input"]}
+ ]
+ text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ prompts.append(text)
+
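+    # Decoder-only models need left padding for batched generation so new tokens are appended
+    # directly after each prompt; this keeps the generated_ids[:, input_len:] slice below
+    # returning only generated text.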
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device)
+
+ with torch.no_grad():
+ generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=256,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )
+
+ input_len = inputs.input_ids.shape[1]
+ gen_tokens = generated_ids[:, input_len:]
+ responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
+ return responses
+
+def evaluate_ckpt():
+ # 1. Find the latest checkpoint
+ checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1]))
+ if not checkpoints:
+ print(f"No checkpoints found in {CHECKPOINT_DIR}")
+ return
+
+ latest_ckpt = checkpoints[-1]
+ print(f"\nTarget Checkpoint: {latest_ckpt}")
+
+ device = "cuda"
+ print(f"Loading model (Batch Size: {BATCH_SIZE})...")
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(latest_ckpt, trust_remote_code=True, padding_side="left")
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True}
+ if USE_FLASH_ATTN:
+ kwargs["attn_implementation"] = "flash_attention_2"
+
+ model = AutoModelForCausalLM.from_pretrained(latest_ckpt, **kwargs)
+ model.eval()
+ except Exception as e:
+ print(f"CRITICAL ERROR loading model: {e}")
+ return
+
+ # 2. Prepare Data
+ test_data = load_test_data()
+ dataset = EvalDataset(test_data)
+ # Reduce num_workers to avoid hang if system is stressed
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
+
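+    # Binary labels: does the gold / predicted JSON contain at least one extracted preference?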
+ y_true_has_pref = []
+ y_pred_has_pref = []
+ json_valid_count = 0
+
+ print(f"Evaluating {len(test_data)} samples...")
+
+ # 3. Inference Loop
+ for batch in tqdm(dataloader):
+ inputs = batch["input"]
+ outputs = batch["output"]
+
+ # Ground Truth
+ for gt_str in outputs:
+ try:
+ gt_json = json.loads(gt_str)
+ gt_has = len(gt_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, AttributeError):
+ gt_has = False
+ y_true_has_pref.append(gt_has)
+
+ # Prediction
+ batch_items = [{"input": inp} for inp in inputs]
+ responses = batch_generate(model, tokenizer, batch_items, device)
+
+ for pred_str in responses:
+ pred_has = False
+ try:
+ pred_json = json.loads(pred_str)
+ json_valid_count += 1
+ pred_has = len(pred_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, AttributeError):
+ pass
+ y_pred_has_pref.append(pred_has)
+
+ # 4. Metrics
+ print("\n" + "="*40)
+ print(f"RESULTS for {latest_ckpt}")
+ print("="*40)
+ print(f"JSON Validity: {json_valid_count / len(test_data):.4f}")
+ print(f"Accuracy: {accuracy_score(y_true_has_pref, y_pred_has_pref):.4f}")
+ print(f"Precision: {precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
+ print(f"Recall: {recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
+ print(f"F1 Score: {f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
+ print("="*40)
+
+if __name__ == "__main__":
+ evaluate_ckpt()
+
diff --git a/scripts/evaluate_checkpoints.py b/scripts/evaluate_checkpoints.py
new file mode 100644
index 0000000..cb5c993
--- /dev/null
+++ b/scripts/evaluate_checkpoints.py
@@ -0,0 +1,205 @@
+import json
+import os
+import glob
+import torch
+import matplotlib.pyplot as plt
+import pandas as pd
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+from torch.utils.data import Dataset, DataLoader
+
+# --- Configuration ---
+BASE_MODEL_NAME = "Qwen/Qwen3-0.6B" # Or local path models/Qwen3-0.6B
+CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200"
+TEST_FILE = "data/test_llama_factory.json"
+RESULTS_FILE = "evaluation_results.csv"
+PLOT_FILE = "evaluation_plot.png"
+
+# H200 Optimization
+BATCH_SIZE = 128 # H200 can handle massive batches for 0.6B model
+USE_FLASH_ATTN = False
+
+# Load System Prompt
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ SYSTEM_PROMPT = f.read()
+
+class EvalDataset(Dataset):
+ def __init__(self, data):
+ self.data = data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+def load_test_data():
+ with open(TEST_FILE, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+def batch_generate(model, tokenizer, batch_data, device="cuda"):
+ prompts = []
+ for item in batch_data:
+ messages = [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": item["input"]}
+ ]
+ text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ prompts.append(text)
+
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device)
+
+ with torch.no_grad():
+ generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=256,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )
+
+ # Slice only generated tokens
+ input_len = inputs.input_ids.shape[1]
+ gen_tokens = generated_ids[:, input_len:]
+ responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
+ return responses
+
+def evaluate_single_model(model_path, test_data, device="cuda"):
+ print(f"Loading model: {model_path}...")
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
+ # Ensure pad token is set for batch generation
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True}
+ if USE_FLASH_ATTN:
+ kwargs["attn_implementation"] = "flash_attention_2"
+
+ model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+ except Exception as e:
+ print(f"Failed to load {model_path}: {e}")
+ return None
+
+ model.eval()
+
+ dataset = EvalDataset(test_data)
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4) # Use workers for data loading
+
+ y_true_has_pref = []
+ y_pred_has_pref = []
+ json_valid_count = 0
+
+ print(f"Evaluating on {len(test_data)} samples (Batch Size: {BATCH_SIZE})...")
+
+ for batch in tqdm(dataloader):
+        # The default collate_fn turns the list of dicts into a dict of lists:
+        # {"input": [...], "output": [...]}
+ inputs = batch["input"]
+ outputs = batch["output"]
+
+ # Ground Truth
+ for gt_str in outputs:
+ try:
+ gt_json = json.loads(gt_str)
+ gt_has = len(gt_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, AttributeError):
+ gt_has = False
+ y_true_has_pref.append(gt_has)
+
+ # Prediction
+        # batch_generate expects a list of dicts with an "input" key, so rebuild that structure here
+ batch_items = [{"input": inp} for inp in inputs]
+ responses = batch_generate(model, tokenizer, batch_items, device)
+
+ for pred_str in responses:
+ pred_has = False
+ try:
+ pred_json = json.loads(pred_str)
+ json_valid_count += 1
+ pred_has = len(pred_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, AttributeError):
+ pass
+ y_pred_has_pref.append(pred_has)
+
+ # Metrics
+ metrics = {
+ "json_validity": json_valid_count / len(test_data),
+ "accuracy": accuracy_score(y_true_has_pref, y_pred_has_pref),
+ "precision": precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0),
+ "recall": recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0),
+ "f1": f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0)
+ }
+
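+    # Free GPU memory before the next checkpoint is loaded.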
+ del model
+ del tokenizer
+ torch.cuda.empty_cache()
+
+ return metrics
+
+def main():
+ test_data = load_test_data()
+ results = []
+
+ # 1. Evaluate Base Model
+ print("\n--- Evaluating Base Model ---")
+ base_metrics = evaluate_single_model(BASE_MODEL_NAME, test_data)
+ if base_metrics:
+ base_metrics["step"] = 0
+ base_metrics["model"] = "Base"
+ results.append(base_metrics)
+ print(f"Base: {base_metrics}")
+
+ # 2. Evaluate Checkpoints
+ checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1]))
+ print(f"\nFound {len(checkpoints)} checkpoints.")
+
+ # Filter to only Base + Last Checkpoint (User Request)
+ if checkpoints:
+ checkpoints = [checkpoints[-1]]
+ print(f"Selecting only the last checkpoint: {checkpoints[0]}")
+
+ for ckpt in checkpoints:
+ step = int(ckpt.split("-")[-1])
+ print(f"\n--- Evaluating Checkpoint {step} ---")
+ metrics = evaluate_single_model(ckpt, test_data)
+ if metrics:
+ metrics["step"] = step
+ metrics["model"] = f"Ckpt-{step}"
+ results.append(metrics)
+ print(f"Step {step}: {metrics}")
+
+ # 3. Save & Plot
+ if not results:
+ print("No results generated.")
+ return
+
+ df = pd.DataFrame(results)
+ df = df.sort_values("step")
+ df.to_csv(RESULTS_FILE, index=False)
+ print(f"\nResults saved to {RESULTS_FILE}")
+ print(df)
+
+ plt.figure(figsize=(10, 6))
+ plt.plot(df["step"], df["f1"], marker='o', label="F1 Score")
+ plt.plot(df["step"], df["precision"], marker='s', label="Precision")
+ plt.plot(df["step"], df["recall"], marker='^', label="Recall")
+ plt.plot(df["step"], df["json_validity"], marker='x', linestyle='--', label="JSON Validity")
+
+ plt.title("Preference Extractor Training Progress")
+ plt.xlabel("Training Steps")
+ plt.ylabel("Score")
+ plt.legend()
+ plt.grid(True)
+ plt.savefig(PLOT_FILE)
+ print(f"Plot saved to {PLOT_FILE}")
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/finish_retry_batches.py b/scripts/finish_retry_batches.py
new file mode 100644
index 0000000..f266327
--- /dev/null
+++ b/scripts/finish_retry_batches.py
@@ -0,0 +1,154 @@
+import json
+import os
+import asyncio
+from openai import OpenAI, AsyncOpenAI
+from typing import Dict, Any, Set, List
+
+# --- Configuration ---
+BATCH_IDS_FILE = "data/raw_datasets/submitted_retry_batch_ids.json"
+# The input file for *this specific retry batch* run
+RETRY_INPUT_SOURCE = "data/raw_datasets/retry_requests.jsonl"
+# Where to append the final results
+OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+MODEL_NAME = "gpt-5.1"
+
+def load_retry_queries() -> Dict[str, Dict[str, Any]]:
+ """
+    Load the requests that were submitted in the retry batch (one Batch API request
+    object per JSONL line) and map custom_id -> original user query.
+ """
+ print("Loading retry source requests...")
+ mapping = {}
+ with open(RETRY_INPUT_SOURCE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ req = json.loads(line)
+ # Structure: {"custom_id": "...", "body": {"messages": [..., {"role": "user", "content": "..."}]}}
+ custom_id = req["custom_id"]
+ # Extract user query back from the request body
+ user_content = ""
+ for m in req["body"]["messages"]:
+ if m["role"] == "user":
+ user_content = m["content"]
+ break
+
+ mapping[custom_id] = {
+ "query": user_content,
+                    # The original source field was not propagated into the retry requests,
+                    # so only the query is recovered here; recovered records are tagged below.
+ }
+ return mapping
+
+async def process_and_finish():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+
+ sync_client = OpenAI(api_key=api_key)
+ async_client = AsyncOpenAI(api_key=api_key)
+
+ if not os.path.exists(BATCH_IDS_FILE):
+ print(f"Error: {BATCH_IDS_FILE} not found.")
+ return
+
+ with open(BATCH_IDS_FILE, "r") as f:
+ batch_ids = json.load(f)
+
+ query_map = load_retry_queries()
+ processed_ids: Set[str] = set()
+
+ print(f"Total requests in retry batch: {len(query_map)}")
+
+ success_count = 0
+
+ # 1. Download results from Batch API (even if expired)
+ print("Downloading batch results...")
+ with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
+ for b_id in batch_ids:
+ try:
+ batch = sync_client.batches.retrieve(b_id)
+ if batch.output_file_id:
+ content = sync_client.files.content(batch.output_file_id).text
+ for line in content.splitlines():
+ if not line.strip(): continue
+ res = json.loads(line)
+ custom_id = res["custom_id"]
+
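+                        # Only results the Batch API completed successfully are written here;
+                        # anything missing is retried via the direct API below.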
+                        if res["response"]["status_code"] == 200:
+ try:
+ body = res["response"]["body"]
+ llm_content = body["choices"][0]["message"]["content"]
+ parsed_json = json.loads(llm_content)
+
+ original = query_map.get(custom_id)
+ if original:
+ record = {
+ "custom_id": custom_id,
+ "original_query": original["query"],
+ "source": "retry_recovery", # Lost original source, marking as recovery
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ processed_ids.add(custom_id)
+ success_count += 1
+                            except Exception:
+ pass
+ except Exception as e:
+ print(f"Error checking batch {b_id}: {e}")
+
+ # 2. Identify Missing
+ missing_ids = [cid for cid in query_map.keys() if cid not in processed_ids]
+ print(f"\nMissing/Failed items: {len(missing_ids)}")
+
+ # 3. Finish with Direct API
+ if missing_ids:
+ print("Processing missing items via Direct API...")
+
+ # Load System Prompt
+ with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ sys_prompt = f.read()
+
+ with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
+ for cid in missing_ids:
+ item = query_map[cid]
+ query = item["query"]
+ print(f" Fixing {cid}...")
+
+ try:
+ resp = await async_client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[
+ {"role": "system", "content": sys_prompt},
+ {"role": "user", "content": query}
+ ],
+ response_format={"type": "json_object"}
+ )
+
+ content = resp.choices[0].message.content
+ parsed_json = json.loads(content)
+
+ record = {
+ "custom_id": cid,
+ "original_query": query,
+ "source": "retry_direct_fix",
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ success_count += 1
+
+ except Exception as e:
+ print(f" Failed to fix {cid}: {e}")
+
+ print("\n" + "="*50)
+ print("ALL RETRY BATCHES RECOVERED.")
+ print(f"Total processed in this run: {success_count}")
+ print(f"Full dataset updated at: {OUTPUT_LABEL_FILE}")
+ print("="*50)
+
+if __name__ == "__main__":
+ asyncio.run(process_and_finish())
+
diff --git a/scripts/full_labeling.py b/scripts/full_labeling.py
new file mode 100644
index 0000000..1c52819
--- /dev/null
+++ b/scripts/full_labeling.py
@@ -0,0 +1,125 @@
+import json
+import os
+import asyncio
+import aiofiles
+from typing import List, Dict, Any
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
+# --- Configuration ---
+INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset.jsonl"
+CHECKPOINT_FILE = "data/raw_datasets/labeling_checkpoint.txt"
+MODEL_NAME = "gpt-5.1" # Or "gpt-4o"
+MAX_CONCURRENCY = 500 # Adjust based on rate limits
+SAVE_INTERVAL = 1000 # Save batch to disk every N items
+
+# --- Load System Prompt ---
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ SYSTEM_PROMPT = f.read()
+
+async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
+ query = item["query"]
+ async with sem:
+ try:
+            # Rely on the client library's built-in retry behaviour; for bulk labeling it is
+            # better to record an error and move on than to stall the whole run.
+ response = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ],
+ temperature=0.0,
+ response_format={"type": "json_object"}
+ )
+ result_text = response.choices[0].message.content
+
+ try:
+ parsed = json.loads(result_text)
+ prefs = parsed.get("preferences", [])
+ has_pref = len(prefs) > 0
+            except (json.JSONDecodeError, AttributeError):
+ parsed = {"error": "json_parse_fail", "raw": result_text}
+ has_pref = False
+
+ return {
+ "original_query": query,
+ "source": item.get("source"),
+ "extracted_json": parsed,
+ "has_preference": has_pref
+ }
+ except Exception as e:
+ return {
+ "original_query": query,
+ "source": item.get("source"),
+ "error": str(e),
+ "has_preference": False
+ }
+
+async def main():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+
+ # 1. Determine start position (Resume logic)
+ processed_count = 0
+ if os.path.exists(OUTPUT_FILE):
+ # Quick line count to see how many we've done
+ # (This assumes we append strictly)
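+        # (gather preserves task order and chunks are written sequentially, so output line i
+        #  always corresponds to input line i)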
+ with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
+ for _ in f:
+ processed_count += 1
+
+ print(f"Resuming from index {processed_count}...")
+
+ # 2. Load Data (skip already processed)
+ # Since reading 400k lines is fast, we just read all and slice
+ all_items = []
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ all_items.append(json.loads(line))
+
+ total_items = len(all_items)
+ remaining_items = all_items[processed_count:]
+
+ if not remaining_items:
+ print("All items processed!")
+ return
+
+ print(f"Total: {total_items}, Remaining: {len(remaining_items)}")
+
+ # 3. Setup Client
+ client = AsyncOpenAI(api_key=api_key)
+ sem = asyncio.Semaphore(MAX_CONCURRENCY)
+
+ # 4. Batch Processing
+ # We process in chunks to allow periodic saving and memory management
+ batch_size = SAVE_INTERVAL
+
+ # Open file in append mode
+ async with aiofiles.open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
+
+ for i in range(0, len(remaining_items), batch_size):
+ batch = remaining_items[i : i + batch_size]
+ tasks = [label_query(client, sem, item) for item in batch]
+
+ # Run batch
+ results = await tqdm_asyncio.gather(*tasks, desc=f"Batch {i//batch_size}", leave=False)
+
+ # Write batch
+ lines = [json.dumps(res, ensure_ascii=False) + "\n" for res in results]
+ await f_out.writelines(lines)
+ await f_out.flush() # Ensure written to disk
+
+ # Optional: Print stats every now and then
+ pos_in_batch = sum(1 for r in results if r.get("has_preference"))
+ # print(f"Batch saved. Positive in this batch: {pos_in_batch}/{len(batch)}")
+
+ print(f"Done! Saved to {OUTPUT_FILE}")
+
+if __name__ == "__main__":
+ asyncio.run(main())
+
diff --git a/scripts/index_corpus.py b/scripts/index_corpus.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/index_corpus.py
diff --git a/scripts/init_user_states.py b/scripts/init_user_states.py
new file mode 100644
index 0000000..73c7435
--- /dev/null
+++ b/scripts/init_user_states.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+"""
+Script to initialize User States (z_long) from Memory Embeddings.
+"""
+
+import sys
+import os
+import numpy as np
+import json
+from collections import defaultdict
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
+def main():
+ cards_path = "data/corpora/memory_cards.jsonl"
+ item_proj_path = "data/corpora/item_projection.npz"
+ user_store_path = "data/users/user_store.npz"
+
+ # Ensure user dir
+ os.makedirs(os.path.dirname(user_store_path), exist_ok=True)
+
+ # 1. Load data
+ print("Loading memory cards...")
+ cards = []
+ if os.path.exists(cards_path):
+ with open(cards_path, "r") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+ else:
+ print("No memory cards found. Exiting.")
+ return
+
+ print("Loading item projection V...")
+ if not os.path.exists(item_proj_path):
+ print("Item projection not found. Run build_item_space.py first.")
+ return
+
+ proj_data = np.load(item_proj_path)
+ V = proj_data["V"] # [M, k]
+
+ if len(cards) != V.shape[0]:
+ print(f"Warning: Number of cards ({len(cards)}) != V rows ({V.shape[0]}). Mismatch?")
+ # If mismatch, we might need to be careful. For now assume aligned.
+
+ k = V.shape[1]
+
+ # 2. Group by user
+ user_indices = defaultdict(list)
+ for idx, card in enumerate(cards):
+ user_indices[card.user_id].append(idx)
+
+ # 3. Initialize Store
+ print(f"Initializing UserStore at {user_store_path}...")
+ store = UserTensorStore(k=k, path=user_store_path)
+
+ # 4. Compute z_long and save
+ print(f"Processing {len(user_indices)} users...")
+ for uid, indices in user_indices.items():
+ if not indices:
+ continue
+
+ # Get item vectors for this user
+ # indices is list of int, V is numpy array
+ user_items = V[indices]
+
+ # Mean pooling
+ z_long = np.mean(user_items, axis=0)
+
+ # Get/Create state
+ state = store.get_state(uid)
+ state.z_long = z_long
+ state.z_short = np.zeros(k, dtype=np.float32)
+ state.reward_ma = 0.0
+
+ store.save_state(state)
+
+ store.persist()
+ print("Done. User states initialized.")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/migrate_preferences.py b/scripts/migrate_preferences.py
new file mode 100644
index 0000000..5d393c9
--- /dev/null
+++ b/scripts/migrate_preferences.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+"""
+Script to migrate raw queries into MemoryCards by extracting preferences.
+It reads from data/raw_datasets/pilot_study_1000.jsonl and outputs:
+- data/corpora/memory_cards.jsonl
+- data/corpora/memory_embeddings.npy
+"""
+
+import json
+import os
+import sys
+
+# Add src to sys.path so we can import personalization
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+import uuid
+import numpy as np
+import torch
+from pathlib import Path
+from tqdm import tqdm
+from typing import List
+
+from personalization.config.settings import load_local_models_config
+# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard, PreferenceList
+
+def ensure_dir(path: str):
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
+
+def main():
+ # 1. Setup paths
+ input_path = "data/corpora/oasst1_labeled.jsonl"
+ # input_path = "data/raw_datasets/oasst1_queries.jsonl"
+ output_cards_path = "data/corpora/memory_cards.jsonl"
+ output_emb_path = "data/corpora/memory_embeddings.npy"
+ ensure_dir(output_cards_path)
+
+ print("Loading models configuration...")
+ cfg = load_local_models_config()
+
+ # 2. Initialize models
+ # print("Initializing Preference Extractor (GPT-4o)...")
+ # extractor = GPT4OExtractor.from_config(cfg)
+
+ print("Initializing Embedding Model...")
+ embedder = Qwen3Embedding8B.from_config(cfg)
+
+ # 3. Process data
+ print(f"Reading from {input_path}...")
+ memory_cards: List[MemoryCard] = []
+
+    # Preferences were already extracted offline (has_preference / extracted_json in the input
+    # file), so this loop just filters and converts rows; embedding is batched further below.
+
+ with open(input_path, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ # Synthetic user distribution (round robin for 10 users)
+ users = [f"user_{i}" for i in range(10)]
+
+ print("Extracting preferences...")
+ # Use tqdm for progress
+ for idx, line in enumerate(tqdm(lines)):
+ # if idx >= 100: # LIMIT to 100 items
+ # break
+
+ row = json.loads(line)
+ query = row.get("original_query", "").strip()
+ if not query:
+ continue
+
+ # Use real metadata from dataset
+ user_id = row.get("user_id", f"user_{idx}")
+ session_id = row.get("session_id", f"sess_{idx}")
+ turn_id = row.get("turn_id", 0)
+
+ # Load pre-extracted preferences
+ has_pref = row.get("has_preference", False)
+ extracted_data = row.get("extracted_json", {})
+
+ # Skip if no preference (according to label)
+ if not has_pref:
+ continue
+
+ try:
+ pref_list = PreferenceList.model_validate(extracted_data)
+ except Exception:
+ # Fallback or skip if validation fails
+ continue
+
+ # If we have preferences, create a memory card
+ if pref_list.preferences:
+ # Construct a note text: "condition: action"
+ notes = [f"{p.condition}: {p.action}" for p in pref_list.preferences]
+ note_summary = "; ".join(notes)
+
+ # Create MemoryCard (embedding will be filled later)
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id=session_id,
+ source_turn_ids=[turn_id],
+ raw_queries=[query],
+ preference_list=pref_list,
+ note_text=note_summary,
+ embedding_e=[], # To be filled
+ kind="pref"
+ )
+ memory_cards.append(card)
+
+ print(f"Found {len(memory_cards)} memory cards. Generating embeddings...")
+
+ if not memory_cards:
+ print("No preferences found. Exiting.")
+ return
+
+ # 4. Generate Embeddings
+    # Following the design doc ("Qwen3Embedding8B.encode([turn.text])"), we embed the original
+    # query that produced the memory rather than the serialized note_text.
+
+ texts_to_embed = [card.raw_queries[0] for card in memory_cards]
+
+ print(f"Embedding {len(texts_to_embed)} memories...")
+ embeddings_list = []
+ chunk_size = 2000 # Process in chunks to avoid OOM
+
+ for i in range(0, len(texts_to_embed), chunk_size):
+ print(f" Embedding chunk {i} to {min(i+chunk_size, len(texts_to_embed))}...")
+ chunk = texts_to_embed[i : i + chunk_size]
+
+ # Batch encode with larger batch_size for A40
+ chunk_emb = embedder.encode(
+ chunk,
+ batch_size=128,
+ normalize=True,
+ return_tensor=False
+ )
+ embeddings_list.extend(chunk_emb)
+
+ # Assign back to cards and prepare matrix
+ emb_matrix = []
+ for card, emb in zip(memory_cards, embeddings_list):
+ card.embedding_e = emb
+ emb_matrix.append(emb)
+
+ # 5. Save
+ print(f"Saving {len(memory_cards)} cards to {output_cards_path}...")
+ with open(output_cards_path, "w", encoding="utf-8") as f:
+ for card in memory_cards:
+ f.write(card.model_dump_json() + "\n")
+
+ print(f"Saving embeddings matrix to {output_emb_path}...")
+ np_emb = np.array(emb_matrix, dtype=np.float32)
+ np.save(output_emb_path, np_emb)
+
+ print("Done!")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/online_personalization_demo.py b/scripts/online_personalization_demo.py
new file mode 100644
index 0000000..f5b6d68
--- /dev/null
+++ b/scripts/online_personalization_demo.py
@@ -0,0 +1,399 @@
+#!/usr/bin/env python3
+"""
+Online Personalization REPL Demo.
+Interactive CLI for chatting with the Personalized Memory RAG system.
+Includes:
+- Extractor-0.6B for online preference extraction
+- Reranker + Policy Retrieval
+- Online RL updates (REINFORCE)
+"""
+
+import sys
+import os
+import uuid
+import numpy as np
+import torch
+import readline # For better input handling
+import yaml
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import (
+ get_preference_extractor,
+ get_chat_model,
+)
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+# from personalization.models.llm.qwen_instruct import QwenInstruct # Deprecated direct import
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.user_model.session_state import OnlineSessionState
+from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
+from personalization.retrieval.pipeline import retrieve_with_policy
+from personalization.feedback.handlers import eval_step
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+from personalization.user_model.features import ItemProjection
+
+def load_memory_store():
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+
+ if not os.path.exists(cards_path) or not os.path.exists(embs_path):
+ print("Memory data missing. Starting with empty memory store is possible but item space requires base data.")
+ # For this demo, we assume base data exists to define PCA space.
+ sys.exit(1)
+
+ print(f"Loading memory cards from {cards_path}...")
+ cards = []
+ with open(cards_path, "r") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+
+ memory_embeddings = np.load(embs_path)
+
+ # Load PCA projection
+ proj_data = np.load(item_proj_path)
+ # We need to reconstruct ItemProjection object to transform new memories
+ projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+ item_vectors = proj_data["V"]
+
+ return cards, memory_embeddings, item_vectors, projection
+
+def build_user_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
+ return ChatTurn(
+ user_id=user_id,
+ session_id="online_debug_session",
+ turn_id=turn_id,
+ role="user",
+ text=text,
+ meta={"source": "repl"}
+ )
+
+def build_assistant_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
+ return ChatTurn(
+ user_id=user_id,
+ session_id="online_debug_session",
+ turn_id=turn_id,
+ role="assistant",
+ text=text,
+ meta={"source": "repl"}
+ )
+
+def add_preferences_as_memory_cards(
+ prefs: PreferenceList,
+ query: str,
+ user_id: str,
+ turn_id: int,
+ embed_model: Qwen3Embedding8B,
+ projection: ItemProjection,
+ memory_cards: list,
+ memory_embeddings: np.ndarray,
+ item_vectors: np.ndarray,
+) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Adds extracted preferences as new memory cards.
+ Returns updated memory_embeddings and item_vectors.
+ """
+ if not prefs.preferences:
+ print(" [Extractor] No preferences found in this turn.")
+ return memory_embeddings, item_vectors
+
+ e_m_list = []
+ v_m_list = []
+
+    # The query embedding is computed once and shared by every preference extracted from this
+    # turn; the current design stores the original-query embedding as e_m (not the note text).
+ e_q = embed_model.encode([query], return_tensor=False)[0]
+ v_q = projection.transform_vector(np.array(e_q))
+
+ print(f" [Extractor] Extracted {len(prefs.preferences)} preferences:")
+ for pref in prefs.preferences:
+ note_text = f"When {pref.condition}, {pref.action}."
+ print(f" - {note_text}")
+
+ # Simple deduplication check based on note_text for this user
+ # In a real system, use vector similarity or hash
+ is_duplicate = False
+ for card in memory_cards:
+ if card.user_id == user_id and card.note_text == note_text:
+ is_duplicate = True
+ break
+
+ if is_duplicate:
+ print(" (Duplicate, skipping add)")
+ continue
+
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id="online_debug_session",
+ source_turn_ids=[turn_id],
+ raw_queries=[query],
+ preference_list=PreferenceList(preferences=[pref]),
+ note_text=note_text,
+ embedding_e=e_q, # Store list[float]
+ kind="pref",
+ )
+
+ memory_cards.append(card)
+ e_m_list.append(e_q)
+ v_m_list.append(v_q)
+
+ # Update numpy arrays
+ if e_m_list:
+ new_embs = np.array(e_m_list)
+ new_vecs = np.array(v_m_list)
+ memory_embeddings = np.vstack([memory_embeddings, new_embs])
+ item_vectors = np.vstack([item_vectors, new_vecs])
+ print(f" [Debug] Added {len(e_m_list)} new cards. Total cards: {len(memory_cards)}")
+
+ return memory_embeddings, item_vectors
+
+def main():
+ # 1. Load Config & Models
+ print("Loading configuration...")
+ cfg = load_local_models_config()
+
+    # RL hyperparameters (ideally loaded from configs/user_model.yaml; hardcoded so the demo runs standalone)
+ rl_cfg = {
+ "item_dim": 256,
+ "beta_long": 0.1,
+ "beta_short": 0.3,
+ "tau": 1.0,
+ "eta_long": 1e-3,
+ "eta_short": 5e-3,
+ "ema_alpha": 0.05,
+ "short_decay": 0.1,
+ "dense_topk": 64,
+ "rerank_topk": 3,
+ "max_new_tokens": 512
+ }
+
+ print("Loading models and stores...")
+ # Using explicit classes for clarity, but registry can also be used
+ embed_model = Qwen3Embedding8B.from_config(cfg)
+ reranker = Qwen3Reranker.from_config(cfg)
+
+ # Use registry for ChatModel (supports switching backends)
+ # Default to "qwen_1_5b" if not specified in user_model.yaml
+ llm_name = "qwen_1_5b"
+
+ # Try loading from config safely
+ try:
+ config_path = os.path.join(os.path.dirname(__file__), "../configs/user_model.yaml")
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ user_cfg = yaml.safe_load(f)
+ if user_cfg and "llm_name" in user_cfg:
+ llm_name = user_cfg["llm_name"]
+ print(f"Loaded llm_name from config: {llm_name}")
+ else:
+ print(f"Warning: Config file not found at {config_path}")
+ except Exception as e:
+ print(f"Failed to load user_model.yaml: {e}")
+
+ print(f"Loading ChatModel: {llm_name}...")
+ chat_model = get_chat_model(llm_name)
+
+ # Use registry for extractor to support switching
+ extractor_name = "qwen3_0_6b_sft" # Default per design doc
+ print(f"Loading extractor: {extractor_name}...")
+ try:
+ extractor = get_preference_extractor(extractor_name)
+ except Exception as e:
+ print(f"Failed to load {extractor_name}: {e}. Fallback to rule.")
+ extractor = get_preference_extractor("rule")
+
+ user_store = UserTensorStore(
+ k=rl_cfg["item_dim"],
+ path="data/users/user_store_online.npz",
+ )
+
+ # Load Memory
+ memory_cards, memory_embeddings, item_vectors, projection = load_memory_store()
+
+ # 2. Init Session
+ user_id = "debug_user"
+ user_state = user_store.get_state(user_id)
+ session_state = OnlineSessionState(user_id=user_id)
+
+ print(f"\n--- Online Personalization REPL (User: {user_id}) ---")
+ print(f"Initial State: ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
+ print("Type 'exit' or 'quit' to stop.\n")
+
+ while True:
+ try:
+ q_t = input("User: ").strip()
+ except (EOFError, KeyboardInterrupt):
+ print("\nExiting...")
+ break
+
+ if q_t.lower() in ("exit", "quit"):
+ break
+ if not q_t:
+ continue
+
+ # 3. RL Update (from previous turn)
+ e_q_t = embed_model.encode([q_t], return_tensor=False)[0]
+ e_q_t = np.array(e_q_t)
+
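+        # The reward for the previous turn can only be estimated once the user sends the next
+        # query, so the RL update for turn t-1 runs lazily at the start of turn t.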
+ if session_state.last_query is not None:
+ r_hat, g_hat = eval_step(
+ q_t=session_state.last_query,
+ answer_t=session_state.last_answer,
+ q_t1=q_t,
+ memories_t=session_state.last_memories,
+ query_embedding_t=session_state.last_query_embedding,
+ query_embedding_t1=e_q_t,
+ )
+
+ print(f" [Feedback] Reward: {r_hat:.2f}, Gating: {g_hat:.2f}")
+
+ if (session_state.last_candidate_item_vectors is not None and
+ session_state.last_policy_probs is not None and
+ len(session_state.last_chosen_indices) > 0):
+
+ # IMPORTANT: Extract the vectors of the chosen items to align with probs
+ # last_candidate_item_vectors: [64, dim]
+ # last_chosen_indices: [3] indices into the 64 candidates
+ # last_policy_probs: [3] probabilities for the chosen items
+
+ # We need the vectors corresponding to the chosen indices
+ # chosen_indices contains indices into candidates list
+ chosen_vectors = session_state.last_candidate_item_vectors[session_state.last_chosen_indices]
+
+ updated = reinforce_update_user_state(
+ user_state=user_state,
+ item_vectors=chosen_vectors, # Corrected: Pass only chosen vectors [3, dim]
+ chosen_indices=np.arange(len(session_state.last_chosen_indices)), # Indices are now 0,1,2 relative to chosen_vectors
+ policy_probs=session_state.last_policy_probs,
+ reward_hat=r_hat,
+ gating=g_hat,
+ tau=rl_cfg["tau"],
+ eta_long=rl_cfg["eta_long"],
+ eta_short=rl_cfg["eta_short"],
+ ema_alpha=rl_cfg["ema_alpha"],
+ short_decay=rl_cfg["short_decay"],
+ )
+ if updated:
+ print(" [RL] User state updated.")
+ user_store.save_state(user_state) # Save immediately for safety
+
+ # 4. Update History
+ user_turn = build_user_turn(user_id, q_t, len(session_state.history))
+ session_state.history.append(user_turn)
+
+ # 5. Extract Preferences -> New Memory
+ # Extract from recent history
+ prefs = extractor.extract_turn(session_state.history)
+ memory_embeddings, item_vectors = add_preferences_as_memory_cards(
+ prefs, q_t, user_id, user_turn.turn_id,
+ embed_model, projection, memory_cards, memory_embeddings, item_vectors
+ )
+
+ # 6. Retrieve + Policy
+ # Use only_own_memories=True to allow strict privacy
+ # Fix unpacking order: pipeline returns (candidates, vecs, scores, indices, probs)
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=q_t,
+ embed_model=embed_model,
+ reranker=reranker,
+ memory_cards=memory_cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=rl_cfg["dense_topk"],
+ topk_rerank=rl_cfg["rerank_topk"],
+ beta_long=rl_cfg["beta_long"],
+ beta_short=rl_cfg["beta_short"],
+ tau=rl_cfg["tau"],
+ only_own_memories=True # User requested strict privacy for demo
+ )
+
+ # Map back to indices in candidates list (0..K-1)
+ print(f"DEBUG: candidates len={len(candidates)}, type={type(candidates)}")
+ print(f"DEBUG: chosen_indices={chosen_indices}, type={type(chosen_indices)}")
+ if len(chosen_indices) > 0:
+ print(f"DEBUG: first idx type={type(chosen_indices[0])}, val={chosen_indices[0]}")
+
+ memories_t = [candidates[int(i)] for i in chosen_indices]
+ if memories_t:
+ print(f" [Retrieval] Found {len(memories_t)} memories:")
+
+ # Display Deduplication: Group by note_text
+ from collections import Counter
+ content_counts = Counter([m.note_text for m in memories_t])
+
+ # Print unique contents with counts
+ for text, count in content_counts.most_common():
+                user_info = f" (x{count})" if count > 1 else ""
+ print(f" - {text}{user_info}")
+
+ # 7. LLM Answer
+ memory_notes = [m.note_text for m in memories_t]
+        # session_state.history is already a list of ChatTurn objects, which is what chat_model.answer expects
+ answer_t = chat_model.answer(
+ history=session_state.history,
+ memory_notes=memory_notes,
+ max_new_tokens=rl_cfg["max_new_tokens"]
+ )
+
+ print(f"Assistant: {answer_t}")
+
+ # 8. Update State for Next Turn
+ assist_turn = build_assistant_turn(user_id, answer_t, len(session_state.history))
+ session_state.history.append(assist_turn)
+
+ session_state.last_query = q_t
+ session_state.last_answer = answer_t
+ session_state.last_memories = memories_t
+ session_state.last_query_embedding = e_q_t
+ session_state.last_candidate_item_vectors = cand_item_vecs
+ session_state.last_policy_probs = probs
+ session_state.last_chosen_indices = chosen_indices
+
+ print(f" [State] ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
+ print("-" * 40)
+
+ print("Saving final user state...")
+ user_store.save_state(user_state)
+ user_store.persist()
+
+ # Save updated memories
+ print(f"Saving {len(memory_cards)} memory cards to disk...")
+    # Ideally this would be atomic or append-only; for the demo we simply rewrite the files in place.
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+
+ with open(cards_path, "w", encoding="utf-8") as f:
+ for card in memory_cards:
+ f.write(card.model_dump_json() + "\n")
+
+ np.save(embs_path, memory_embeddings)
+
+    # item_projection.npz stores the projection matrix P and the mean; the cached item vectors V
+    # are refreshed here so the next load sees the newly added memories, keeping P and mean intact.
+ proj_data = np.load(item_proj_path)
+ np.savez(
+ item_proj_path,
+ P=proj_data["P"],
+ mean=proj_data["mean"],
+ V=item_vectors
+ )
+ print("Memory store updated.")
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/scripts/personamem_build_user_vectors.py b/scripts/personamem_build_user_vectors.py
new file mode 100644
index 0000000..4719f9c
--- /dev/null
+++ b/scripts/personamem_build_user_vectors.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+"""
+Script to build user vectors for PersonaMem 32k dataset.
+1. Load shared contexts.
+2. Convert to ChatTurns.
+3. Extract preferences -> MemoryCards.
+4. Compute embeddings & item vectors.
+5. Aggregate to user vectors.
+"""
+
+import sys
+import os
+import json
+import uuid
+import numpy as np
+from tqdm import tqdm
+from typing import List, Dict
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import get_preference_extractor
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
+from personalization.data.personamem_loader import load_personamem_contexts_32k, PersonaMemContext
+from personalization.user_model.features import ItemProjection
+
+def context_to_chatturns(context: PersonaMemContext) -> List[ChatTurn]:
+ turns = []
+ for i, msg in enumerate(context.messages):
+ if isinstance(msg, dict):
+ role = msg.get("role", "user")
+ text = msg.get("content", "")
+ elif isinstance(msg, str):
+            # Fallback for plain-string messages: treat them as user turns.
+ role = "user"
+ text = msg
+ else:
+ continue
+
+ turns.append(ChatTurn(
+ user_id=context.shared_context_id, # Use context id as user id equivalent for building vector
+ session_id=context.shared_context_id,
+ turn_id=i,
+ role="user" if role == "user" else "assistant",
+ text=text,
+ timestamp=None,
+ meta={}
+ ))
+ return turns
+
+def sliding_windows(seq: List, window_size: int, step: int):
+ for i in range(0, len(seq), step):
+ yield seq[i : i + window_size]
+
+def find_last_user_text(window: List[ChatTurn]) -> str:
+ for t in reversed(window):
+ if t.role == "user":
+ return t.text
+ return ""
+
+def main():
+ # Paths (adjust as needed, assume downloaded to data/raw_datasets/personamem)
+ ctx_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl"
+ item_proj_path = "data/corpora/item_projection.npz"
+ output_vec_path = "data/personamem/user_vectors.npz"
+ output_cards_path = "data/personamem/memory_cards.jsonl"
+
+ # Ensure dirs
+ os.makedirs(os.path.dirname(output_vec_path), exist_ok=True)
+
+ if not os.path.exists(ctx_path):
+ print(f"Error: Context file not found at {ctx_path}")
+ return
+
+ if not os.path.exists(item_proj_path):
+ print(f"Error: Item projection not found at {item_proj_path}. Run build_item_space.py first.")
+ return
+
+ # Load Models
+ print("Loading models...")
+ cfg = load_local_models_config()
+ # Explicitly use Qwen3Embedding8B
+ embed_model = Qwen3Embedding8B.from_config(cfg)
+
+ # Use registry for extractor (SFT model)
+ extractor_name = "qwen3_0_6b_sft"
+ try:
+ extractor = get_preference_extractor(extractor_name)
+    except Exception:
+        print(f"{extractor_name} not found; falling back to the rule-based extractor.")
+ extractor = get_preference_extractor("rule")
+
+ # Load Projection
+ proj_data = np.load(item_proj_path)
+ projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+
+ # Load Contexts
+ print("Loading contexts...")
+ contexts = load_personamem_contexts_32k(ctx_path)
+ print(f"Loaded {len(contexts)} contexts.")
+
+ all_cards = []
+
+ # Process each context
+ print("Extracting preferences...")
+    # A full pass over every context can take a while; slice `contexts` here for a quicker smoke test.
+
+ for ctx_id, ctx in tqdm(contexts.items()):
+ turns = context_to_chatturns(ctx)
+ # print(f"Context {ctx_id}: {len(turns)} turns")
+
+ # Sliding window extraction
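+        # Overlapping windows (size 6, step 3): each turn can appear in up to two extraction windows.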
+ for window in sliding_windows(turns, window_size=6, step=3):
+ # Only extract if window has user turns
+ if not any(t.role == "user" for t in window):
+ continue
+
+ try:
+ prefs = extractor.extract_turn(window)
+ # if prefs.preferences:
+ # print(f" Found {len(prefs.preferences)} preferences")
+ except Exception as e:
+ print(f"Extraction failed: {e}")
+ continue
+
+ if not prefs.preferences:
+ continue
+
+ source_query = find_last_user_text(window)
+ if not source_query:
+ continue
+
+ # Embed
+ e_m = embed_model.encode([source_query], return_tensor=False)[0]
+ e_m_np = np.array(e_m)
+ v_m = projection.transform_vector(e_m_np)
+
+ # Serialize note
+ notes = [f"When {p.condition}, {p.action}." for p in prefs.preferences]
+ note_text = " ".join(notes)
+
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=ctx_id, # persona_id/context_id
+ source_session_id=ctx_id,
+ source_turn_ids=[t.turn_id for t in window if t.role == "user"],
+ raw_queries=[source_query],
+ preference_list=prefs,
+ note_text=note_text,
+ embedding_e=e_m,
+ kind="pref"
+ )
+ all_cards.append(card)
+
+ print(f"Extracted {len(all_cards)} memory cards.")
+
+ # Build User Vectors
+ print("Building user vectors...")
+ z_by_user = {}
+
+ # Group cards by user
+ cards_by_user = {}
+ for c in all_cards:
+ if c.user_id not in cards_by_user:
+ cards_by_user[c.user_id] = []
+ cards_by_user[c.user_id].append(c)
+
+ for uid, u_cards in cards_by_user.items():
+ # Stack v_m: [M_u, k]
+ V = np.stack([projection.transform_vector(np.array(c.embedding_e, dtype=np.float32)) for c in u_cards], axis=0)
+ z = np.mean(V, axis=0)
+ z_by_user[uid] = z
+
+ # Save
+ print(f"Saving {len(all_cards)} cards to {output_cards_path}...")
+ with open(output_cards_path, "w", encoding="utf-8") as f:
+ for c in all_cards:
+ f.write(c.model_dump_json() + "\n")
+
+ print(f"Saving user vectors to {output_vec_path}...")
+ user_ids = list(z_by_user.keys())
+ Z = np.array([z_by_user[uid] for uid in user_ids], dtype=np.float32)
+ np.savez(output_vec_path, user_ids=user_ids, Z=Z)
+
+ print("Done.")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/personamem_eval_base_vs_ours.py b/scripts/personamem_eval_base_vs_ours.py
new file mode 100644
index 0000000..e319765
--- /dev/null
+++ b/scripts/personamem_eval_base_vs_ours.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+"""
+Evaluation script for PersonaMem task: Base vs Ours (User Vector).
+Metric: Accuracy (Top-1 correct option)
+"""
+
+import sys
+import os
+import numpy as np
+from tqdm import tqdm
+from typing import List
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.data.personamem_loader import load_personamem_questions_32k
+from personalization.user_model.features import ItemProjection
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserState
+
+def cosine_sim_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ # a: [d] or [N, d], b: [d] or [M, d]
+ norm_a = np.linalg.norm(a, axis=-1, keepdims=True)
+ norm_b = np.linalg.norm(b, axis=-1, keepdims=True)
+
+    # Result convention: sim[i, j] = cos(b[i], a[j]); if a is 1-D this collapses to a length-M vector.
+    # b is typically [M, d] (memories or options).
+
+ dot = np.dot(b, a.T)
+ denom = np.dot(norm_b, norm_a.T) + 1e-8
+
+ # Flatten if needed
+ sim = dot / denom
+ if sim.ndim == 2 and sim.shape[1] == 1:
+ return sim.flatten()
+ return sim
+
+def dense_retrieval(
+ e_q: np.ndarray,
+ memory_embeddings: np.ndarray,
+ topk: int = 5
+) -> np.ndarray:
+ """Returns topk indices of memories most similar to query."""
+ if memory_embeddings.shape[0] == 0:
+ return np.array([], dtype=int)
+
+ sims = cosine_sim_matrix(e_q, memory_embeddings)
+ # Get topk
+ k = min(topk, len(sims))
+ idx = np.argsort(sims)[-k:][::-1]
+ return idx
+
+def policy_retrieval(
+ e_q: np.ndarray,
+ memory_embeddings: np.ndarray,
+ item_vectors: np.ndarray,
+ z_user: np.ndarray,
+ item_proj: ItemProjection,
+ topk_dense: int = 20,
+ topk_final: int = 5,
+ beta: float = 0.2
+) -> np.ndarray:
+ """
+ Simulates retrieve_with_policy:
+ 1. Dense retrieval (topk_dense)
+ 2. Policy Scoring: s = s_base + beta * (z . v_m)
+ 3. Select topk_final
+ """
+ if memory_embeddings.shape[0] == 0:
+ return np.array([], dtype=int)
+
+ # 1. Dense Candidate Generation
+ dense_idx = dense_retrieval(e_q, memory_embeddings, topk=topk_dense)
+ if len(dense_idx) == 0:
+ return np.array([], dtype=int)
+
+ candidates_e = memory_embeddings[dense_idx]
+ candidates_v = item_vectors[dense_idx]
+
+ # 2. Base Scores (Sim(q, m))
+ base_scores = cosine_sim_matrix(e_q, candidates_e)
+
+ # 3. Policy Bonus
+ # z: [k], v: [K, k]
+ bonus = np.dot(candidates_v, z_user)
+
+ total_scores = base_scores + beta * bonus
+
+ # 4. Final Selection
+ k = min(topk_final, len(total_scores))
+ local_top_idx = np.argsort(total_scores)[-k:][::-1]
+
+ # Map back to global indices
+ return dense_idx[local_top_idx]
+
+def score_rag(
+ e_q: np.ndarray,
+ e_opts: np.ndarray,
+ retrieved_embeddings: np.ndarray
+) -> np.ndarray:
+ """
+ Computes score for each option based on query match AND memory match.
+ Score = Sim(Q, Opt) + Mean(Sim(Memories, Opt))
+ """
+ # Base: Query-Option similarity
+ # e_q: [d], e_opts: [4, d] -> [4]
+ s_q_opt = cosine_sim_matrix(e_q, e_opts)
+
+ if len(retrieved_embeddings) == 0:
+ return s_q_opt
+
+ # Memory-Option similarity
+ # retrieved: [K, d], e_opts: [4, d]
+ # We want for each option, the average similarity to retrieved memories.
+ # sim_matrix: [4, K] (options x memories) - check implementation of cosine_sim_matrix
+ # cosine_sim_matrix(a, b) does dot(b, a.T).
+ # Let a=retrieved [K, d], b=e_opts [4, d]. Result: [4, K]
+
+ s_mem_opt = cosine_sim_matrix(retrieved_embeddings, e_opts)
+
+ # Max pooling or Mean pooling over memories
+ # Usually max is better for "if any memory supports this option"
+ s_rag = np.max(s_mem_opt, axis=1) # [4]
+
+ # Combine
+ return s_q_opt + s_rag
+
+def main():
+ # Paths
+ q_path = "data/raw_datasets/personamem/questions_32k.csv"
+ vec_path = "data/personamem/user_vectors.npz"
+ cards_path = "data/personamem/memory_cards.jsonl" # Generated by builder
+ item_proj_path = "data/corpora/item_projection.npz"
+
+ if not os.path.exists(q_path) or not os.path.exists(vec_path) or not os.path.exists(cards_path):
+ print("Data missing. Run personamem_build_user_vectors.py first.")
+ sys.exit(1)
+
+ print("Loading resources...")
+ cfg = load_local_models_config()
+ embed_model = Qwen3Embedding8B.from_config(cfg)
+
+ proj_data = np.load(item_proj_path)
+ item_proj = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+
+ # Load User Vectors
+ uv_data = np.load(vec_path, allow_pickle=True)
+ user_ids = uv_data["user_ids"]
+ Z = uv_data["Z"]
+ user_vector_map = {uid: Z[i] for i, uid in enumerate(user_ids)}
+ print(f"Loaded {len(user_vector_map)} user vectors.")
+
+ # Load Memory Cards & Embeddings
+ print("Loading memory store...")
+ cards_by_user = {}
+ embs_by_user = {}
+ vecs_by_user = {} # v vectors
+
+    # Load all cards up front to build per-user indices; the extracted card set is small enough
+    # to hold in memory (the builder produced only a few hundred cards on a sample run).
+
+ with open(cards_path, "r") as f:
+ for line in f:
+ card = MemoryCard.model_validate_json(line)
+ uid = card.user_id
+ if uid not in cards_by_user:
+ cards_by_user[uid] = []
+ embs_by_user[uid] = []
+
+ cards_by_user[uid].append(card)
+ embs_by_user[uid].append(card.embedding_e)
+
+ # Convert lists to numpy arrays
+ for uid in embs_by_user:
+ E = np.array(embs_by_user[uid], dtype=np.float32)
+ embs_by_user[uid] = E
+        # The builder does not save item vectors separately, so project the embeddings here.
+ vecs_by_user[uid] = item_proj.transform_embeddings(E)
+
+ print(f"Loaded memories for {len(cards_by_user)} users.")
+
+ # Load Questions
+ questions = load_personamem_questions_32k(q_path)
+ print(f"Loaded {len(questions)} questions.")
+
+ correct_base = []
+ correct_ours = []
+
+ # Hyperparams
+    betas = [0.0, 1.0]  # beta=0.0 is a sanity check; see the note below
+
+    # "Ours" pipeline: 1. dense retrieval (top 20) -> 2. re-rank with base score + beta * (z . v_m) -> 3. keep top 5.
+    # With beta=0 the re-ranking score is just Sim(q, m), which is exactly the dense-retrieval order,
+    # so as long as topk_dense >= topk_final the beta=0 run retrieves the same memories as plain dense retrieval.
+    # The "Base" condition inside the loop is scored without any retrieved memories (no-RAG).
+
+ for beta in betas:
+ print(f"\nEvaluating with RAG (beta={beta})...")
+
+ correct_base = []
+ correct_ours = []
+
+ case_count = 0
+
+ for q in tqdm(questions):
+ target_id = q.shared_context_id
+
+ # Skip if no vector or no memories
+ if target_id not in user_vector_map or target_id not in embs_by_user:
+ continue
+
+ z_user = user_vector_map[target_id]
+ mem_E = embs_by_user[target_id]
+ mem_V = vecs_by_user[target_id]
+
+ # Embed Query
+ e_q = embed_model.encode([q.user_question_or_message], return_tensor=False)[0]
+ e_q = np.array(e_q, dtype=np.float32)
+
+ # Embed Options
+ if not q.all_options:
+ continue
+ e_opts = embed_model.encode(q.all_options, return_tensor=False)
+ e_opts = np.array(e_opts, dtype=np.float32)
+
+ # --- BASE (No RAG) ---
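+            # An empty memory array makes score_rag fall back to pure query-option similarity.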
+ s_base = score_rag(e_q, e_opts, np.array([]))
+
+ # --- OURS (Personalized RAG) ---
+ ours_idx = policy_retrieval(e_q, mem_E, mem_V, z_user, item_proj, topk_dense=20, topk_final=5, beta=beta)
+ ours_mem_E = mem_E[ours_idx]
+ s_ours = score_rag(e_q, e_opts, ours_mem_E)
+
+ pred_base = int(np.argmax(s_base))
+ pred_ours = int(np.argmax(s_ours))
+
+ is_correct = (pred_ours == q.correct_index)
+ base_correct = (pred_base == q.correct_index)
+
+ # Detailed Case Print (Task A)
+ # Print only when Beta > 0 (to avoid duplicate logs) and when they disagree
+ if beta > 0.0 and pred_base != pred_ours and case_count < 5:
+ case_count += 1
+
+ # Reconstruct memory text (need to find card in list)
+ # Optimization: Create a map or just linear search in cards_by_user[target_id]
+ user_cards = cards_by_user[target_id]
+ # ours_idx are indices into mem_E which corresponds to cards_by_user list order
+ retrieved_notes = [user_cards[i].note_text for i in ours_idx]
+
+ print(f"\n" + "="*60)
+ print(f"[CASE ANALYSIS] QID: {q.question_id}")
+ print(f"User Question: {q.user_question_or_message}")
+ print(f"Correct Option ({q.correct_index}): {q.all_options[q.correct_index]}")
+ print("-" * 30)
+ print(f"Base Pred ({pred_base}): {q.all_options[pred_base]} [{'CORRECT' if base_correct else 'WRONG'}]")
+ print(f"Ours Pred ({pred_ours}): {q.all_options[pred_ours]} [{'CORRECT' if is_correct else 'WRONG'}]")
+ print("-" * 30)
+ print(f"Retrieved Memories (Top 3 of {len(retrieved_notes)}):")
+ for note in retrieved_notes[:3]:
+ print(f" - {note}")
+ print("-" * 30)
+ print(f"Scores Base: {s_base}")
+ print(f"Scores Ours: {s_ours}")
+ print("="*60 + "\n")
+
+ if q.correct_index != -1:
+ correct_base.append(1 if pred_base == q.correct_index else 0)
+ correct_ours.append(1 if pred_ours == q.correct_index else 0)
+
+ if not correct_base:
+ print("No valid evaluation samples processed.")
+ continue
+
+ acc_base = np.mean(correct_base)
+ acc_ours = np.mean(correct_ours)
+
+ # Win rate
+ wins = [1 if (c_o == 1 and c_b == 0) else 0 for c_o, c_b in zip(correct_ours, correct_base)]
+ win_rate = np.mean(wins)
+
+ print(f"\n--- Results (Beta={beta}) ---")
+ print(f"Total Samples: {len(correct_base)}")
+ print(f"Accuracy (Base No-RAG): {acc_base:.4f}")
+ print(f"Accuracy (Ours RAG): {acc_ours:.4f}")
+ print(f"Win Rate: {win_rate:.4f}")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/pilot_runner_v0.py b/scripts/pilot_runner_v0.py
new file mode 100644
index 0000000..8b7773a
--- /dev/null
+++ b/scripts/pilot_runner_v0.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v0 - Minimal End-to-End Test
+
+Goal: Prove the chat → judge → apply_feedback → next query loop works.
+
+Setup:
+- 1 user × 1 session × 5 turns
+- Fixed queries (no fancy user simulator yet)
+- Rule-based judge: answer non-empty → sat=1, else 0
+- reward = sat, gating = 1 always
+
+What we're checking:
+1. No crashes (KeyError, NoneType, etc.)
+2. User vector norms change after feedback (RL is being called)
+3. resp.usage returns reasonable numbers
+4. Logs are generated correctly
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict
+from typing import List, Dict, Any, Optional
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Minimal Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float # Satisfaction score [0, 1]
+ sev_t: float # Severity of violations [0, 1]
+ prog_t: float # Task progress [0, 1]
+ violations: List[str] # List of violated constraints
+
+
+def minimal_judge(query: str, answer: str, task_type: str = "general") -> JudgeResult:
+ """
+ Minimal rule-based judge for pilot.
+
+ For now:
+ - sat_t = 1 if answer is non-empty, else 0
+ - sev_t = 0 (no severity tracking yet)
+ - prog_t = 1 if answer looks reasonable, else 0
+ """
+ violations = []
+
+ # Check 1: Answer is non-empty
+ if not answer or len(answer.strip()) < 5:
+ violations.append("empty_answer")
+ return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0, violations=violations)
+
+ # Check 2: Answer is not too short (at least 20 chars for real content)
+ if len(answer.strip()) < 20:
+ violations.append("too_short")
+
+ # Check 3: For code tasks, look for code markers
+ if task_type == "code":
+ has_code = "```" in answer or "def " in answer or "function" in answer
+ if not has_code:
+ violations.append("no_code_block")
+
+ # Calculate scores
+ sat_t = 1.0 if len(violations) == 0 else max(0.0, 1.0 - 0.3 * len(violations))
+ sev_t = 1.0 if "empty_answer" in violations else 0.0
+ prog_t = 1.0 if "empty_answer" not in violations else 0.0
+
+ return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t, violations=violations)
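+
+# Illustrative expected outputs (comments only, not executed; values follow the rules above):
+#   minimal_judge("q", "")        -> sat_t=0.0, sev_t=1.0, prog_t=0.0, ["empty_answer"]
+#   minimal_judge("q", "Four.")   -> sat_t=0.7, sev_t=0.0, prog_t=1.0, ["too_short"]
+#   minimal_judge("q", "Sleep in a dark, cool room.") -> sat_t=1.0, no violations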
+
+
+# =============================================================================
+# Minimal User Simulator (Fixed Queries)
+# =============================================================================
+
+def get_fixed_queries() -> List[Dict[str, Any]]:
+ """
+ Return fixed queries for pilot test.
+ Mix of preference statements and tasks.
+ """
+ return [
+ {
+ "query": "I prefer short, concise answers. Please keep responses under 100 words.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ {
+ "query": "What are three tips for better sleep?",
+ "type": "task",
+ "task_type": "general",
+ },
+ {
+ "query": "I also prefer bullet points when listing things.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ {
+ "query": "What are the main benefits of exercise?",
+ "type": "task",
+ "task_type": "general",
+ },
+ {
+ "query": "Summarize what you know about my preferences.",
+ "type": "task",
+ "task_type": "general",
+ },
+ ]
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn."""
+ turn_id: int
+ query: str
+ query_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ reward: float
+ gating: float
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ """Save logs to JSONL file."""
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
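+
+# Example (sketch): the JSONL written above can be loaded back for ad-hoc analysis,
+# e.g. with pandas, assuming it is available in the analysis environment:
+#
+#   import pandas as pd
+#   df = pd.read_json("data/logs/pilot_v0_<timestamp>.jsonl", lines=True)
+#   print(df[["turn_id", "sat_t", "reward", "total_tokens"]])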
+
+
+# =============================================================================
+# Pilot Runner
+# =============================================================================
+
+def run_pilot(
+ llm: PersonalizedLLM,
+ user_id: str = "pilot_user_0",
+ queries: Optional[List[Dict[str, Any]]] = None,
+) -> List[TurnLog]:
+ """
+ Run a single pilot session.
+
+ Returns list of turn logs.
+ """
+ if queries is None:
+ queries = get_fixed_queries()
+
+ logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"PILOT SESSION: user_id={user_id}, turns={len(queries)}")
+ print(f"{'='*60}")
+
+ # Reset user for clean start
+ print(f"\n[Pilot] Resetting user: {user_id}")
+ llm.reset_user(user_id)
+
+ # Start session
+ print(f"[Pilot] Starting session")
+ llm.reset_session(user_id)
+
+ # Get initial state
+ state_before = llm.get_user_state_summary(user_id)
+ print(f"[Pilot] Initial state: z_long_norm={state_before['z_long_norm']:.6f}, z_short_norm={state_before['z_short_norm']:.6f}")
+
+ for turn_id, q_info in enumerate(queries):
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+
+ print(f"\n--- Turn {turn_id} ---")
+ print(f"[Query] ({query_type}) {query[:80]}...")
+
+ # Get state before
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+ # Apply feedback for previous turn (from turn 1 onwards)
+ if turn_id > 0 and len(logs) > 0:
+ prev_log = logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=turn_id - 1,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={"source": "pilot_v0"}
+ )
+ print(f"[Feedback] Applying: reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ print(f"[Answer] {resp.answer[:100]}..." if len(resp.answer) > 100 else f"[Answer] {resp.answer}")
+ print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")
+
+ # Judge
+ judge_result = minimal_judge(query, resp.answer, task_type)
+ print(f"[Judge] sat={judge_result.sat_t:.2f}, prog={judge_result.prog_t:.2f}, violations={judge_result.violations}")
+
+ # Compute reward and gating
+ reward = judge_result.sat_t # Simple: reward = satisfaction
+ gating = 1.0 # Always allow learning for pilot
+
+ # Get state after
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ # Debug info
+ num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+ num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+ print(f"[State] z_long: {z_long_before:.6f} -> {z_long_after:.6f}, z_short: {z_short_before:.6f} -> {z_short_after:.6f}")
+ print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")
+
+ # Log
+ log = TurnLog(
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ reward=reward,
+ gating=gating,
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ )
+ logs.append(log)
+
+ # Apply final feedback
+ if len(logs) > 0:
+ last_log = logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=len(queries) - 1,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v0", "final": True}
+ )
+ print(f"\n[Final Feedback] reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ return logs
+
+
+def print_summary(logs: List[TurnLog]):
+ """Print summary statistics."""
+ print(f"\n{'='*60}")
+ print("PILOT SUMMARY")
+ print(f"{'='*60}")
+
+    total_turns = len(logs)
+    if total_turns == 0:
+        print("No turns to summarize.")
+        return
+
+    avg_sat = sum(l.sat_t for l in logs) / total_turns
+    avg_prog = sum(l.prog_t for l in logs) / total_turns
+ total_tokens = sum(l.total_tokens for l in logs)
+ total_prompt = sum(l.prompt_tokens for l in logs)
+ total_completion = sum(l.completion_tokens for l in logs)
+
+ # Check if RL updates happened (vector norms changed)
+ z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs]
+ z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs]
+ any_z_long_change = any(c > 1e-6 for c in z_long_changes)
+ any_z_short_change = any(c > 1e-6 for c in z_short_changes)
+
+ print(f"Total turns: {total_turns}")
+ print(f"Average satisfaction: {avg_sat:.3f}")
+ print(f"Average progress: {avg_prog:.3f}")
+ print(f"Total tokens: {total_tokens} (prompt: {total_prompt}, completion: {total_completion})")
+ print(f"z_long changed: {any_z_long_change} (max delta: {max(z_long_changes):.6f})")
+ print(f"z_short changed: {any_z_short_change} (max delta: {max(z_short_changes):.6f})")
+
+ # Violations breakdown
+ all_violations = [v for l in logs for v in l.violations]
+ if all_violations:
+ from collections import Counter
+ print(f"Violations: {dict(Counter(all_violations))}")
+ else:
+ print("Violations: None")
+
+ # RL Health Check
+ print(f"\n--- RL Health Check ---")
+ if any_z_long_change or any_z_short_change:
+ print("✓ User vectors ARE being updated by RL")
+ else:
+ print("✗ WARNING: User vectors NOT changing - check apply_feedback")
+
+
+def main():
+ print("=" * 60)
+ print("PILOT RUNNER v0")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+
+ # Initialize LLM
+ print("\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_pilot.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Run pilot
+ user_id = "pilot_user_0"
+ logs = run_pilot(llm, user_id=user_id)
+
+ # Summary
+ print_summary(logs)
+
+ # Save logs
+ log_path = f"data/logs/pilot_v0_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ # Final state
+ final_state = llm.get_user_state_summary(user_id)
+ print(f"\n[Final State] {final_state}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/pilot_runner_v1.py b/scripts/pilot_runner_v1.py
new file mode 100644
index 0000000..fbb2876
--- /dev/null
+++ b/scripts/pilot_runner_v1.py
@@ -0,0 +1,607 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v1 - Style-Aware Judge + Gating Logic
+
+Upgrade from v0:
+- StylePrefs: User style preferences (length, bullets, language)
+- style_judge: Checks style conformance, not just non-empty
+- compute_feedback_for_turn: gating=1 only for preference-related turns
+- Extended queries: ~10 turns with preference/task mix
+
+Goal: Verify that:
+1. sat_t varies based on style violations (not always 1)
+2. gating=1 only on preference turns, 0 on regular tasks
+3. RL updates happen when gating=1 and reward != baseline
+4. Over turns, model may adapt to preferences (sat_t improves)
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict, field
+from typing import List, Dict, Any, Optional, Tuple
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """User's style preferences for the judge to check."""
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en" # "en" or "zh"
+
+
+# =============================================================================
+# Style-Aware Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float # Satisfaction score [0, 1]
+ sev_t: float # Severity of violations [0, 1]
+ prog_t: float # Task progress [0, 1]
+ violations: List[str] # List of violated constraints
+
+
+def style_judge(
+ query: str,
+ answer: str,
+ task_type: str,
+ prefs: StylePrefs,
+) -> JudgeResult:
+ """
+ Style-aware judge that checks:
+ - Empty/too short answer
+ - Length constraint (max_chars)
+ - Bullet point requirement
+ - Language preference
+ - Code block for code tasks
+
+ Returns:
+ JudgeResult with sat_t, sev_t, prog_t, and violations list.
+ """
+ violations: List[str] = []
+ text = (answer or "").strip()
+
+ # 0) Empty answer - immediate fail
+ if not text or len(text) < 5:
+ violations.append("empty_answer")
+ return JudgeResult(
+ sat_t=0.0,
+ sev_t=1.0,
+ prog_t=0.0,
+ violations=violations,
+ )
+
+ # 1) Length preference
+ if prefs.require_short:
+ if len(text) > prefs.max_chars:
+ violations.append("too_long")
+
+    # 2) Bullet preference (checked for "general" and "list" task types)
+ if prefs.require_bullets and task_type in ("general", "list"):
+ has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text)
+ if not has_bullets:
+ violations.append("no_bullets")
+
+ # 3) Language preference (rough heuristic)
+ if prefs.lang == "zh":
+ # For Chinese preference, check if answer has enough non-ASCII chars
+ ascii_count = sum(c.isascii() for c in text)
+ ascii_ratio = ascii_count / max(1, len(text))
+ if ascii_ratio > 0.7: # Too much ASCII = probably not Chinese
+ violations.append("wrong_lang")
+ elif prefs.lang == "en":
+ # For English preference, check if answer is mostly ASCII
+ ascii_count = sum(c.isascii() for c in text)
+ ascii_ratio = ascii_count / max(1, len(text))
+ if ascii_ratio < 0.5: # Too little ASCII = probably not English
+ violations.append("wrong_lang")
+
+ # 4) Code task: must have code markers
+ prog_t = 1.0
+ if task_type == "code":
+ has_code = ("```" in text) or ("def " in text) or ("function " in text)
+ if not has_code:
+ violations.append("no_code_block")
+ prog_t = 0.0
+
+ # 5) Compute sat_t and sev_t from violations
+ if not violations:
+ sat_t = 1.0
+ sev_t = 0.0
+ else:
+ # Each violation costs 0.3, minimum 0
+ sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
+ # Hard violations trigger sev_t=1
+ hard_violations = {"empty_answer", "too_long", "wrong_lang"}
+ sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
+
+ return JudgeResult(
+ sat_t=sat_t,
+ sev_t=sev_t,
+ prog_t=prog_t,
+ violations=violations,
+ )
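+
+# Illustrative scoring (comments only, not executed), with prefs =
+# StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en")
+# and task_type="list":
+#   - 500-char prose answer with no bullets -> violations=["too_long", "no_bullets"],
+#     sat_t = 1.0 - 0.3*2 = 0.4, sev_t = 1.0 (too_long is a hard violation)
+#   - 150-char bulleted English answer      -> no violations, sat_t = 1.0, sev_t = 0.0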
+
+
+# =============================================================================
+# Feedback Computation (reward + gating)
+# =============================================================================
+
+def compute_feedback_for_turn(
+ turn_id: int,
+ query: str,
+ query_type: str,
+ task_type: str,
+ judge_result: JudgeResult,
+) -> Tuple[float, float]:
+ """
+ Convert JudgeResult into (reward, gating):
+ - reward = sat_t (style satisfaction)
+ - gating = 1 only if this turn is preference-related (declared or complained)
+
+ Args:
+ turn_id: The turn index
+ query: The user's query text
+ query_type: "preference" or "task" from query metadata
+ task_type: "general", "list", "code", etc.
+ judge_result: The judge's evaluation
+
+ Returns:
+ (reward, gating) tuple
+ """
+ reward = judge_result.sat_t
+
+ # Gating logic: only allow RL update on preference-related turns
+ # 1. Explicit preference declaration (query_type == "preference")
+ # 2. Complaint about not following preference
+ lower_q = (query or "").lower()
+
+ is_pref_turn = (
+ query_type == "preference"
+ or "i prefer" in lower_q
+ or "my preference" in lower_q
+ or "please use" in lower_q
+ or "please keep" in lower_q
+ or "you didn't follow" in lower_q
+ or "you forgot" in lower_q
+ or "remember that i" in lower_q
+ or "i told you" in lower_q
+ or "i asked for" in lower_q
+ )
+
+ if is_pref_turn:
+ gating = 1.0
+ else:
+ gating = 0.0
+
+ return reward, gating
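+
+# Illustrative gating behaviour (comments only, not executed):
+#   "I prefer short answers."          -> gating = 1.0 (explicit preference turn)
+#   "What is the capital of France?"   -> gating = 0.0 (plain task, RL update skipped)
+# In both cases reward equals judge_result.sat_t.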
+
+
+# =============================================================================
+# Extended Queries for Pilot v1 (~10 turns)
+# =============================================================================
+
+def get_pilot_v1_queries() -> List[Dict[str, Any]]:
+ """
+ Extended query set for pilot v1.
+ Mix of preference declarations and tasks.
+ Tests: length constraint, bullet points, task completion.
+ """
+ return [
+ # Turn 0: Declare length preference
+ {
+ "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ # Turn 1: Task that should be short
+ {
+ "query": "What are three tips for better sleep?",
+ "type": "task",
+ "task_type": "list",
+ },
+ # Turn 2: Declare bullet preference
+ {
+ "query": "I also prefer bullet points when listing things. Please use bullet points.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ # Turn 3: Task that should use bullets
+ {
+ "query": "What are the main benefits of regular exercise?",
+ "type": "task",
+ "task_type": "list",
+ },
+ # Turn 4: Another task (test if preferences stick)
+ {
+ "query": "Name five popular programming languages.",
+ "type": "task",
+ "task_type": "list",
+ },
+        # Turn 5: Preference reminder (always included so the gating path is exercised)
+ {
+ "query": "Remember that I asked for short answers with bullet points. Can you list three healthy breakfast ideas?",
+ "type": "preference",
+ "task_type": "list",
+ },
+ # Turn 6: Regular task
+ {
+ "query": "What is the capital of France?",
+ "type": "task",
+ "task_type": "general",
+ },
+ # Turn 7: Task requiring list
+ {
+ "query": "What are four seasons of the year?",
+ "type": "task",
+ "task_type": "list",
+ },
+ # Turn 8: Another preference reminder
+ {
+ "query": "I prefer concise bullet points. Please list three types of renewable energy.",
+ "type": "preference",
+ "task_type": "list",
+ },
+ # Turn 9: Final task - test memory
+ {
+ "query": "Summarize what you know about my communication preferences.",
+ "type": "task",
+ "task_type": "general",
+ },
+ ]
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn."""
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ reward: float
+ gating: float
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ """Save logs to JSONL file."""
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
+
+
+# =============================================================================
+# Pilot Runner v1
+# =============================================================================
+
+def run_pilot_v1(
+ llm: PersonalizedLLM,
+ user_id: str = "pilot_user_v1",
+ prefs: Optional[StylePrefs] = None,
+ queries: Optional[List[Dict[str, Any]]] = None,
+) -> List[TurnLog]:
+ """
+ Run pilot v1 with style-aware judge and gating.
+
+ Args:
+ llm: PersonalizedLLM instance
+ user_id: User identifier
+ prefs: Style preferences for this user
+ queries: Query list (defaults to get_pilot_v1_queries)
+
+ Returns:
+ List of TurnLog entries
+ """
+ if prefs is None:
+ # Default preferences: short + bullets + English
+ prefs = StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ )
+
+ if queries is None:
+ queries = get_pilot_v1_queries()
+
+ logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"PILOT v1 SESSION: user_id={user_id}, turns={len(queries)}")
+ print(f"Preferences: short={prefs.require_short}, max_chars={prefs.max_chars}, bullets={prefs.require_bullets}, lang={prefs.lang}")
+ print(f"{'='*60}")
+
+ # Reset user for clean start
+ print(f"\n[Pilot] Resetting user: {user_id}")
+ llm.reset_user(user_id)
+
+ # Start session
+ print(f"[Pilot] Starting session")
+ llm.reset_session(user_id)
+
+ # Get initial state
+ state_before = llm.get_user_state_summary(user_id)
+ print(f"[Pilot] Initial state: z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}")
+
+ for turn_id, q_info in enumerate(queries):
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+
+ print(f"\n{'─'*60}")
+ print(f"Turn {turn_id} [{query_type}]")
+ print(f"{'─'*60}")
+ print(f"[Query] {query}")
+
+ # Get state before
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+ # Apply feedback for previous turn (from turn 1 onwards)
+ if turn_id > 0 and len(logs) > 0:
+ prev_log = logs[-1]
+ prev_query = queries[turn_id - 1]
+
+            # Reuse the judgement already computed for the previous answer;
+            # the scores were stored in prev_log, so no re-judging is needed.
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=turn_id - 1,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={
+ "sat_t": prev_log.sat_t,
+ "sev_t": prev_log.sev_t,
+ "prog_t": prev_log.prog_t,
+ "violations": prev_log.violations,
+ "task_type": prev_log.task_type,
+ "source": "pilot_v1",
+ }
+ )
+ print(f"[Feedback] turn={turn_id-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ # Truncate answer for display
+ answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer
+ print(f"[Answer] ({len(resp.answer)} chars) {answer_display}")
+ print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")
+
+ # Judge with style preferences
+ judge_result = style_judge(query, resp.answer, task_type, prefs)
+ print(f"[Judge] sat={judge_result.sat_t:.2f}, sev={judge_result.sev_t:.1f}, prog={judge_result.prog_t:.1f}")
+ if judge_result.violations:
+ print(f"[Judge] violations={judge_result.violations}")
+
+ # Compute feedback for THIS turn (will be applied next turn)
+ reward, gating = compute_feedback_for_turn(
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ judge_result=judge_result,
+ )
+ print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f} (computed for this turn)")
+
+ # Get state after
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ # Debug info
+ num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+ num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+ z_long_delta = z_long_after - z_long_before
+ z_short_delta = z_short_after - z_short_before
+ print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})")
+ print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})")
+ print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")
+
+ # Log
+ log = TurnLog(
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ reward=reward,
+ gating=gating,
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ )
+ logs.append(log)
+
+ # Apply final feedback for last turn
+ if len(logs) > 0:
+ last_log = logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=len(queries) - 1,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v1", "final": True}
+ )
+ print(f"\n[Final Feedback] turn={len(queries)-1}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ return logs
+
+
+def print_summary_v1(logs: List[TurnLog], prefs: StylePrefs):
+ """Print summary statistics for pilot v1."""
+ print(f"\n{'='*60}")
+ print("PILOT v1 SUMMARY")
+ print(f"{'='*60}")
+
+ total_turns = len(logs)
+ if total_turns == 0:
+ print("No turns to summarize.")
+ return
+
+ # Basic stats
+ avg_sat = sum(l.sat_t for l in logs) / total_turns
+ avg_prog = sum(l.prog_t for l in logs) / total_turns
+ total_tokens = sum(l.total_tokens for l in logs)
+ total_prompt = sum(l.prompt_tokens for l in logs)
+ total_completion = sum(l.completion_tokens for l in logs)
+
+ # Gating stats
+ gated_turns = [l for l in logs if l.gating > 0]
+ non_gated_turns = [l for l in logs if l.gating == 0]
+
+ print(f"\n--- Turn Statistics ---")
+ print(f"Total turns: {total_turns}")
+ print(f"Gated turns (RL active): {len(gated_turns)}")
+ print(f"Non-gated turns (RL skipped): {len(non_gated_turns)}")
+
+ print(f"\n--- Satisfaction ---")
+ print(f"Average sat_t (all): {avg_sat:.3f}")
+ if gated_turns:
+ avg_sat_gated = sum(l.sat_t for l in gated_turns) / len(gated_turns)
+ print(f"Average sat_t (gated only): {avg_sat_gated:.3f}")
+ print(f"Average prog_t: {avg_prog:.3f}")
+
+ print(f"\n--- Token Usage ---")
+ print(f"Total tokens: {total_tokens}")
+ print(f" Prompt: {total_prompt}")
+ print(f" Completion: {total_completion}")
+ print(f"Avg tokens/turn: {total_tokens / total_turns:.1f}")
+
+ # Violations breakdown
+ print(f"\n--- Violations ---")
+ from collections import Counter
+ all_violations = [v for l in logs for v in l.violations]
+ if all_violations:
+ print(f"Total violations: {len(all_violations)}")
+ for v, count in Counter(all_violations).most_common():
+ print(f" {v}: {count}")
+ else:
+ print("No violations")
+
+ # Answer length analysis
+ print(f"\n--- Answer Lengths (max_chars={prefs.max_chars}) ---")
+ lengths = [l.answer_length for l in logs]
+ over_limit = sum(1 for l in lengths if l > prefs.max_chars)
+ print(f"Min: {min(lengths)}, Max: {max(lengths)}, Avg: {sum(lengths)/len(lengths):.1f}")
+ print(f"Over limit: {over_limit}/{total_turns}")
+
+ # RL Health Check
+ print(f"\n--- RL Health Check ---")
+ z_long_changes = [abs(l.z_long_norm_after - l.z_long_norm_before) for l in logs]
+ z_short_changes = [abs(l.z_short_norm_after - l.z_short_norm_before) for l in logs]
+ any_z_long_change = any(c > 1e-6 for c in z_long_changes)
+ any_z_short_change = any(c > 1e-6 for c in z_short_changes)
+
+ print(f"z_long changed: {any_z_long_change} (max Δ: {max(z_long_changes):.6f})")
+ print(f"z_short changed: {any_z_short_change} (max Δ: {max(z_short_changes):.6f})")
+
+ if any_z_long_change or any_z_short_change:
+ print("✓ User vectors ARE being updated by RL")
+ else:
+ print("✗ WARNING: User vectors NOT changing")
+ print(" Check: gating=1 on some turns? reward != baseline?")
+
+ # Per-turn detail table
+ print(f"\n--- Turn-by-Turn Summary ---")
+ print(f"{'Turn':>4} {'Type':>10} {'Len':>5} {'sat':>5} {'gate':>5} {'violations'}")
+ print("-" * 60)
+ for l in logs:
+ viol_str = ",".join(l.violations) if l.violations else "-"
+ print(f"{l.turn_id:>4} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {l.gating:>5.1f} {viol_str}")
+
+
+def main():
+ print("=" * 60)
+ print("PILOT RUNNER v1 - Style-Aware Judge + Gating")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+
+ # Define user preferences
+ prefs = StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ )
+ print(f"\n[Config] User preferences: {prefs}")
+
+ # Initialize LLM
+ print("\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_pilot_v1.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Run pilot
+ user_id = "pilot_user_v1"
+ logs = run_pilot_v1(llm, user_id=user_id, prefs=prefs)
+
+ # Summary
+ print_summary_v1(logs, prefs)
+
+ # Save logs
+ log_path = f"data/logs/pilot_v1_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ # Final state
+ final_state = llm.get_user_state_summary(user_id)
+ print(f"\n[Final State] {final_state}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/pilot_runner_v2.py b/scripts/pilot_runner_v2.py
new file mode 100644
index 0000000..d3c2aa8
--- /dev/null
+++ b/scripts/pilot_runner_v2.py
@@ -0,0 +1,852 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v2 - Cross-Session Preference Reveal Mechanism
+
+Upgrade from v1:
+- RevealState: Tracks which preferences have been explicitly revealed by the user
+- pref_true[k] vs pref_revealed_global[k] distinction
+- Style constraints only enforced AFTER user reveals them
+- Reveal state persists across sessions, resets on reset_user()
+
+Key concepts:
+- pref_true[k]: User's true preference (from StylePrefs)
+- pref_revealed_global[k]: Whether preference k has been revealed at least once
+
+Enforcement rule:
+- A style constraint is enforced only when BOTH pref_true[k] AND pref_revealed_global[k] are true
+
+Session semantics:
+- reset_user(): Clears ALL state including reveal flags
+- reset_session(): Keeps reveal flags (cross-session memory)
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict, field
+from typing import List, Dict, Any, Optional, Tuple, Set
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences (True Preferences)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """
+ User's TRUE style preferences.
+ These are the ground truth preferences that the user actually has,
+ but they may not have revealed all of them to the system yet.
+ """
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en" # "en" or "zh"
+
+
+# =============================================================================
+# Reveal State (What has been explicitly revealed)
+# =============================================================================
+
+@dataclass
+class RevealState:
+ """
+ Tracks which preferences have been explicitly revealed by the user.
+
+ This persists across sessions for the same user but resets on reset_user().
+ A preference is revealed when the user explicitly mentions it in a query.
+ """
+ short_revealed: bool = False # "short", "concise", "brief", length constraints
+ bullets_revealed: bool = False # "bullet", "bullet points", "list format"
+ lang_revealed: bool = False # Language preference mentioned
+
+ def reset(self):
+ """Reset all reveal flags (called on reset_user)."""
+ self.short_revealed = False
+ self.bullets_revealed = False
+ self.lang_revealed = False
+
+ def to_dict(self) -> Dict[str, bool]:
+ return {
+ "short": self.short_revealed,
+ "bullets": self.bullets_revealed,
+ "lang": self.lang_revealed,
+ }
+
+ def __str__(self) -> str:
+ flags = []
+ if self.short_revealed:
+ flags.append("short")
+ if self.bullets_revealed:
+ flags.append("bullets")
+ if self.lang_revealed:
+ flags.append("lang")
+ return f"RevealState({', '.join(flags) if flags else 'none'})"
+
+
+class RevealStateManager:
+ """
+ Manages reveal state for multiple users.
+ Persists across sessions, resets on reset_user().
+ """
+
+ def __init__(self):
+ self._states: Dict[str, RevealState] = {}
+
+ def get_state(self, user_id: str) -> RevealState:
+ """Get or create reveal state for a user."""
+ if user_id not in self._states:
+ self._states[user_id] = RevealState()
+ return self._states[user_id]
+
+ def reset_user(self, user_id: str):
+ """Reset reveal state for a user (called on reset_user)."""
+ if user_id in self._states:
+ self._states[user_id].reset()
+ else:
+ self._states[user_id] = RevealState()
+
+ def reset_session(self, user_id: str):
+ """
+ Called on reset_session - does NOT reset reveal state.
+ Reveal state persists across sessions.
+ """
+ # Intentionally do nothing - reveal state persists
+ pass
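+
+# Illustrative session semantics (comments only, not executed):
+#   mgr = RevealStateManager()
+#   state = mgr.get_state("u1"); state.short_revealed = True
+#   mgr.reset_session("u1")   # state.short_revealed stays True (persists across sessions)
+#   mgr.reset_user("u1")      # state.short_revealed is reset to False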
+
+
+# =============================================================================
+# Preference Detection from Queries
+# =============================================================================
+
+def detect_revealed_preferences(query: str) -> Dict[str, bool]:
+ """
+ Detect which preferences are mentioned in a query.
+
+ Returns a dict with keys: "short", "bullets", "lang"
+ Each value is True if that preference was mentioned.
+ """
+ lower_q = (query or "").lower()
+
+ revealed = {
+ "short": False,
+ "bullets": False,
+ "lang": False,
+ }
+
+ # Short/length preference detection
+ short_patterns = [
+ "short", "concise", "brief", "under ", "less than",
+ "keep it short", "keep responses", "keep answers",
+ "maximum ", "max ", "characters", "words or less",
+ "200 ", "100 ", "50 ", "300 ", # Common char limits
+ ]
+ for pattern in short_patterns:
+ if pattern in lower_q:
+ revealed["short"] = True
+ break
+
+ # Bullet preference detection
+ bullet_patterns = [
+ "bullet", "bullet point", "bullet-point",
+ "bulleted", "list format", "use bullets",
+ "use bullet", "with bullets", "in bullets",
+ "- format", "• ", "numbered list",
+ ]
+ for pattern in bullet_patterns:
+ if pattern in lower_q:
+ revealed["bullets"] = True
+ break
+
+ # Language preference detection
+ lang_patterns_zh = [
+ "chinese", "中文", "in chinese", "用中文",
+ "speak chinese", "write chinese", "respond in chinese",
+ "please use chinese", "mandarin",
+ ]
+ lang_patterns_en = [
+ "english", "in english", "use english",
+ "speak english", "write english", "respond in english",
+ "please use english",
+ ]
+
+ for pattern in lang_patterns_zh + lang_patterns_en:
+ if pattern in lower_q:
+ revealed["lang"] = True
+ break
+
+ return revealed
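+
+# Illustrative detections (comments only, not executed):
+#   "Please keep responses under 200 characters."  -> {"short": True,  "bullets": False, "lang": False}
+#   "I also prefer bullet points when listing."    -> {"short": False, "bullets": True,  "lang": False}
+#   "Please respond in Chinese."                   -> {"short": False, "bullets": False, "lang": True}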
+
+
+def update_reveal_state(reveal_state: RevealState, query: str) -> Set[str]:
+ """
+ Update reveal state based on query content.
+ Returns set of newly revealed preferences.
+ """
+ detected = detect_revealed_preferences(query)
+ newly_revealed = set()
+
+ if detected["short"] and not reveal_state.short_revealed:
+ reveal_state.short_revealed = True
+ newly_revealed.add("short")
+
+ if detected["bullets"] and not reveal_state.bullets_revealed:
+ reveal_state.bullets_revealed = True
+ newly_revealed.add("bullets")
+
+ if detected["lang"] and not reveal_state.lang_revealed:
+ reveal_state.lang_revealed = True
+ newly_revealed.add("lang")
+
+ return newly_revealed
+
+
+# =============================================================================
+# Style-Aware Judge with Reveal State
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float # Satisfaction score [0, 1]
+ sev_t: float # Severity of violations [0, 1]
+ prog_t: float # Task progress [0, 1]
+ violations: List[str] # List of violated constraints
+ enforced_constraints: List[str] # Which constraints were actually enforced
+
+
+def style_judge_with_reveal(
+ query: str,
+ answer: str,
+ task_type: str,
+ prefs: StylePrefs,
+ reveal_state: RevealState,
+) -> JudgeResult:
+ """
+ Style-aware judge that ONLY enforces revealed preferences.
+
+ A constraint is enforced only when:
+ - pref_true[k] is True (user has this preference)
+ - pref_revealed_global[k] is True (user has revealed this preference)
+
+ Args:
+ query: User's query
+ answer: Assistant's answer
+ task_type: Type of task ("general", "list", "code")
+ prefs: User's TRUE preferences (StylePrefs)
+ reveal_state: Which preferences have been revealed
+
+ Returns:
+ JudgeResult with sat_t, sev_t, prog_t, violations, and enforced_constraints
+ """
+ violations: List[str] = []
+ enforced: List[str] = []
+ text = (answer or "").strip()
+
+ # 0) Empty answer - always a violation regardless of reveal state
+ if not text or len(text) < 5:
+ violations.append("empty_answer")
+ return JudgeResult(
+ sat_t=0.0,
+ sev_t=1.0,
+ prog_t=0.0,
+ violations=violations,
+ enforced_constraints=["non_empty"],
+ )
+
+ # 1) Length preference - enforce only if BOTH true AND revealed
+ if prefs.require_short and reveal_state.short_revealed:
+ enforced.append("short")
+ if len(text) > prefs.max_chars:
+ violations.append("too_long")
+
+ # 2) Bullet preference - enforce only if BOTH true AND revealed
+    #    Applied to "general" and "list" task types
+ if prefs.require_bullets and reveal_state.bullets_revealed:
+ if task_type in ("general", "list"):
+ enforced.append("bullets")
+ has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text)
+ if not has_bullets:
+ violations.append("no_bullets")
+
+ # 3) Language preference - enforce only if BOTH true AND revealed
+ if reveal_state.lang_revealed:
+ enforced.append("lang")
+ if prefs.lang == "zh":
+ ascii_count = sum(c.isascii() for c in text)
+ ascii_ratio = ascii_count / max(1, len(text))
+ if ascii_ratio > 0.7:
+ violations.append("wrong_lang")
+ elif prefs.lang == "en":
+ ascii_count = sum(c.isascii() for c in text)
+ ascii_ratio = ascii_count / max(1, len(text))
+ if ascii_ratio < 0.5:
+ violations.append("wrong_lang")
+
+ # 4) Code task: always enforce code markers (not a user preference)
+ prog_t = 1.0
+ if task_type == "code":
+ enforced.append("code_block")
+ has_code = ("```" in text) or ("def " in text) or ("function " in text)
+ if not has_code:
+ violations.append("no_code_block")
+ prog_t = 0.0
+
+ # 5) Compute sat_t and sev_t from violations
+ if not violations:
+ sat_t = 1.0
+ sev_t = 0.0
+ else:
+ sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
+ hard_violations = {"empty_answer", "too_long", "wrong_lang"}
+ sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
+
+ return JudgeResult(
+ sat_t=sat_t,
+ sev_t=sev_t,
+ prog_t=prog_t,
+ violations=violations,
+ enforced_constraints=enforced,
+ )
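+
+# Illustrative effect of the reveal gate (comments only, not executed), with
+# prefs = StylePrefs(require_short=True, max_chars=200, ...):
+#   - Before the user ever mentions length (short_revealed=False), a 500-char
+#     answer is NOT flagged as too_long ("short" is absent from enforced_constraints).
+#   - After reveal (short_revealed=True), the same answer gets "too_long";
+#     if that is the only violation, sat_t = 0.7 and sev_t = 1.0.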
+
+
+# =============================================================================
+# Feedback Computation (reward + gating)
+# =============================================================================
+
+def compute_feedback_for_turn(
+ turn_id: int,
+ query: str,
+ query_type: str,
+ task_type: str,
+ judge_result: JudgeResult,
+) -> Tuple[float, float]:
+ """
+ Convert JudgeResult into (reward, gating).
+ Same as v1 - reward = sat_t, gating = 1 for preference turns.
+ """
+ reward = judge_result.sat_t
+
+ lower_q = (query or "").lower()
+
+ is_pref_turn = (
+ query_type == "preference"
+ or "i prefer" in lower_q
+ or "my preference" in lower_q
+ or "please use" in lower_q
+ or "please keep" in lower_q
+ or "you didn't follow" in lower_q
+ or "you forgot" in lower_q
+ or "remember that i" in lower_q
+ or "i told you" in lower_q
+ or "i asked for" in lower_q
+ )
+
+ gating = 1.0 if is_pref_turn else 0.0
+ return reward, gating
+
+
+# =============================================================================
+# Multi-Session Queries for Pilot v2
+# =============================================================================
+
+def get_session_1_queries() -> List[Dict[str, Any]]:
+ """
+ Session 1: User reveals preferences and does some tasks.
+ """
+ return [
+ {
+ "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ {
+ "query": "What are three tips for better sleep?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "I also prefer bullet points when listing things.",
+ "type": "preference",
+ "task_type": "general",
+ },
+ {
+ "query": "What are the main benefits of exercise?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "Name five programming languages.",
+ "type": "task",
+ "task_type": "list",
+ },
+ ]
+
+
+def get_session_2_queries() -> List[Dict[str, Any]]:
+ """
+ Session 2: User does NOT restate preferences.
+ Tests cross-session preference retention.
+ """
+ return [
+ {
+ "query": "What are three healthy breakfast ideas?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "List four seasons of the year.",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "What is the capital of France?",
+ "type": "task",
+ "task_type": "general",
+ },
+ {
+ "query": "Name three types of renewable energy.",
+ "type": "task",
+ "task_type": "list",
+ },
+ ]
+
+
+def get_session_3_queries() -> List[Dict[str, Any]]:
+ """
+ Session 3: Mix of tasks and one complaint/reminder.
+ """
+ return [
+ {
+ "query": "What are five common fruits?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "Remember that I asked for short bullet points. List three ocean animals.",
+ "type": "preference",
+ "task_type": "list",
+ },
+ {
+ "query": "What is 2 + 2?",
+ "type": "task",
+ "task_type": "general",
+ },
+ ]
+
+
+# =============================================================================
+# Logging (Extended for v2)
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn (extended for v2)."""
+ session_id: int
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+ reward: float
+ gating: float
+ reveal_state_before: Dict[str, bool]
+ reveal_state_after: Dict[str, bool]
+ newly_revealed: List[str]
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ """Save logs to JSONL file."""
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
+
+
+# =============================================================================
+# Pilot Runner v2 (Multi-Session with Reveal State)
+# =============================================================================
+
+def run_session(
+ llm: PersonalizedLLM,
+ user_id: str,
+ session_id: int,
+ prefs: StylePrefs,
+ reveal_state: RevealState,
+ queries: List[Dict[str, Any]],
+) -> List[TurnLog]:
+ """
+ Run a single session with reveal-aware judging.
+ """
+ logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"SESSION {session_id}: user_id={user_id}, turns={len(queries)}")
+ print(f"Reveal state (start): {reveal_state}")
+ print(f"{'='*60}")
+
+ # Reset session (clears history, z_short; keeps z_long and reveal state)
+ llm.reset_session(user_id)
+
+ state_before = llm.get_user_state_summary(user_id)
+ print(f"[Session] z_long={state_before['z_long_norm']:.6f}, z_short={state_before['z_short_norm']:.6f}")
+
+ for turn_id, q_info in enumerate(queries):
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+
+ print(f"\n{'─'*60}")
+ print(f"Session {session_id} / Turn {turn_id} [{query_type}]")
+ print(f"{'─'*60}")
+ print(f"[Query] {query}")
+
+ # Capture reveal state BEFORE this turn
+ reveal_before = reveal_state.to_dict()
+
+ # Update reveal state based on query content
+ newly_revealed = update_reveal_state(reveal_state, query)
+ if newly_revealed:
+ print(f"[Reveal] Newly revealed: {newly_revealed}")
+ print(f"[Reveal] State: {reveal_state}")
+
+ # Capture reveal state AFTER update
+ reveal_after = reveal_state.to_dict()
+
+ # Get user state before
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+ # Apply feedback for previous turn (from turn 1 onwards in this session)
+ if turn_id > 0 and len(logs) > 0:
+ # Find the last log from THIS session
+ session_logs = [l for l in logs if l.session_id == session_id]
+ if session_logs:
+ prev_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=prev_log.turn_id,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={
+ "sat_t": prev_log.sat_t,
+ "violations": prev_log.violations,
+ "source": "pilot_v2",
+ "session_id": session_id,
+ }
+ )
+ print(f"[Feedback] turn={prev_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ answer_display = resp.answer[:150] + "..." if len(resp.answer) > 150 else resp.answer
+ print(f"[Answer] ({len(resp.answer)} chars) {answer_display}")
+ print(f"[Usage] prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}")
+
+ # Judge with reveal-aware logic
+ judge_result = style_judge_with_reveal(query, resp.answer, task_type, prefs, reveal_state)
+ print(f"[Judge] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}")
+ if judge_result.violations:
+ print(f"[Judge] violations={judge_result.violations}")
+
+ # Compute feedback
+ reward, gating = compute_feedback_for_turn(
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ judge_result=judge_result,
+ )
+ print(f"[Feedback] reward={reward:.2f}, gating={gating:.1f}")
+
+ # Get state after
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ z_long_delta = z_long_after - z_long_before
+ z_short_delta = z_short_after - z_short_before
+ print(f"[State] z_long: {z_long_before:.6f} → {z_long_after:.6f} (Δ={z_long_delta:+.6f})")
+ print(f"[State] z_short: {z_short_before:.6f} → {z_short_after:.6f} (Δ={z_short_delta:+.6f})")
+
+ # Debug info
+ num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+ num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+ print(f"[Debug] memories={num_memories}, prefs_extracted={num_prefs}")
+
+ # Log
+ log = TurnLog(
+ session_id=session_id,
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ enforced_constraints=judge_result.enforced_constraints,
+ reward=reward,
+ gating=gating,
+ reveal_state_before=reveal_before,
+ reveal_state_after=reveal_after,
+ newly_revealed=list(newly_revealed),
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ )
+ logs.append(log)
+
+ # Apply final feedback for this session
+ session_logs = [l for l in logs if l.session_id == session_id]
+ if session_logs:
+ last_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=last_log.turn_id,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v2", "session_id": session_id, "final": True}
+ )
+ print(f"\n[Final Feedback] turn={last_log.turn_id}, reward={feedback.reward:.2f}, gating={feedback.gating:.1f}")
+ llm.apply_feedback(feedback)
+
+ print(f"\n[Session {session_id} End] Reveal state: {reveal_state}")
+
+ return logs
+
+
+def run_pilot_v2(
+ llm: PersonalizedLLM,
+ user_id: str = "pilot_user_v2",
+ prefs: Optional[StylePrefs] = None,
+) -> List[TurnLog]:
+ """
+ Run multi-session pilot with reveal state tracking.
+
+ Session 1: User reveals preferences
+ Session 2: User does NOT restate preferences (tests cross-session retention)
+ Session 3: Mix of tasks and reminders
+ """
+ if prefs is None:
+ prefs = StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ )
+
+ # Initialize reveal state manager
+ reveal_manager = RevealStateManager()
+
+ print(f"\n{'#'*60}")
+ print(f"PILOT v2: CROSS-SESSION PREFERENCE REVEAL TEST")
+ print(f"User: {user_id}")
+ print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
+ print(f"{'#'*60}")
+
+ # Reset user completely (clears all state including reveal)
+ print(f"\n[Pilot] Resetting user: {user_id}")
+ llm.reset_user(user_id)
+ reveal_manager.reset_user(user_id)
+
+ all_logs: List[TurnLog] = []
+ reveal_state = reveal_manager.get_state(user_id)
+
+ # Session 1: Reveal preferences
+ session_1_queries = get_session_1_queries()
+ logs_s1 = run_session(llm, user_id, 1, prefs, reveal_state, session_1_queries)
+ all_logs.extend(logs_s1)
+
+ # Session 2: NO preference restatement (test cross-session retention)
+ # Note: reveal_state persists, but reset_session clears history
+ reveal_manager.reset_session(user_id) # Does nothing to reveal state
+ session_2_queries = get_session_2_queries()
+ logs_s2 = run_session(llm, user_id, 2, prefs, reveal_state, session_2_queries)
+ all_logs.extend(logs_s2)
+
+ # Session 3: Reminder and more tasks
+ reveal_manager.reset_session(user_id)
+ session_3_queries = get_session_3_queries()
+ logs_s3 = run_session(llm, user_id, 3, prefs, reveal_state, session_3_queries)
+ all_logs.extend(logs_s3)
+
+ return all_logs
+
+
+def print_summary_v2(logs: List[TurnLog], prefs: StylePrefs):
+ """Print summary for pilot v2."""
+ print(f"\n{'='*60}")
+ print("PILOT v2 SUMMARY - Cross-Session Reveal")
+ print(f"{'='*60}")
+
+ if not logs:
+ print("No logs to summarize.")
+ return
+
+ # Per-session stats
+ sessions = sorted(set(l.session_id for l in logs))
+
+ print(f"\n--- Per-Session Statistics ---")
+ for sid in sessions:
+ session_logs = [l for l in logs if l.session_id == sid]
+ avg_sat = sum(l.sat_t for l in session_logs) / len(session_logs)
+ violations = [v for l in session_logs for v in l.violations]
+
+ # What was revealed at session end
+ if session_logs:
+ final_reveal = session_logs[-1].reveal_state_after
+ else:
+ final_reveal = {}
+
+ print(f"\nSession {sid}: {len(session_logs)} turns")
+ print(f" Avg sat_t: {avg_sat:.3f}")
+ print(f" Violations: {len(violations)} ({violations if violations else 'none'})")
+ print(f" Reveal state at end: {final_reveal}")
+
+ # Overall stats
+ total = len(logs)
+ avg_sat = sum(l.sat_t for l in logs) / total
+ total_tokens = sum(l.total_tokens for l in logs)
+
+ print(f"\n--- Overall Statistics ---")
+ print(f"Total turns: {total}")
+ print(f"Overall avg sat_t: {avg_sat:.3f}")
+ print(f"Total tokens: {total_tokens}")
+
+ # Violations by type
+ print(f"\n--- Violations Breakdown ---")
+ from collections import Counter
+ all_violations = [v for l in logs for v in l.violations]
+ if all_violations:
+ for v, count in Counter(all_violations).most_common():
+ print(f" {v}: {count}")
+ else:
+ print(" No violations")
+
+ # Enforcement tracking
+ print(f"\n--- Constraint Enforcement ---")
+ for constraint in ["short", "bullets", "lang"]:
+ enforced_count = sum(1 for l in logs if constraint in l.enforced_constraints)
+ print(f" {constraint}: enforced in {enforced_count}/{total} turns")
+
+ # Cross-session reveal verification
+ print(f"\n--- Cross-Session Reveal Verification ---")
+
+ # Session 1: Should have some reveals
+ s1_logs = [l for l in logs if l.session_id == 1]
+ s1_reveals = set()
+ for l in s1_logs:
+ s1_reveals.update(l.newly_revealed)
+ print(f"Session 1 revealed: {s1_reveals if s1_reveals else 'none'}")
+
+ # Session 2: Should NOT have new reveals (no preference queries)
+ s2_logs = [l for l in logs if l.session_id == 2]
+ s2_reveals = set()
+ for l in s2_logs:
+ s2_reveals.update(l.newly_revealed)
+ print(f"Session 2 revealed: {s2_reveals if s2_reveals else 'none (expected)'}")
+
+ # But Session 2 should still ENFORCE the constraints revealed in Session 1
+ if s2_logs:
+ s2_enforced = set()
+ for l in s2_logs:
+ s2_enforced.update(l.enforced_constraints)
+ print(f"Session 2 enforced: {s2_enforced}")
+
+ if s1_reveals and s1_reveals.issubset(s2_enforced):
+ print("✓ Cross-session retention VERIFIED: Session 1 reveals enforced in Session 2")
+ else:
+ print("✗ Cross-session retention issue: some reveals not enforced")
+
+ # Turn-by-turn table
+ print(f"\n--- Turn-by-Turn Summary ---")
+ print(f"{'S':>2} {'T':>2} {'Type':>10} {'Len':>5} {'sat':>5} {'enforced':<20} {'violations'}")
+ print("-" * 70)
+ for l in logs:
+ enforced_str = ",".join(l.enforced_constraints) if l.enforced_constraints else "-"
+ viol_str = ",".join(l.violations) if l.violations else "-"
+ print(f"{l.session_id:>2} {l.turn_id:>2} {l.query_type:>10} {l.answer_length:>5} {l.sat_t:>5.2f} {enforced_str:<20} {viol_str}")
+
+
+def main():
+ print("=" * 60)
+ print("PILOT RUNNER v2 - Cross-Session Preference Reveal")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+
+ # Define user's TRUE preferences
+ prefs = StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ )
+ print(f"\n[Config] True preferences: {prefs}")
+ print("[Config] Note: Constraints only enforced AFTER user reveals them")
+
+ # Initialize LLM
+ print("\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_pilot_v2.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Run pilot
+ user_id = "pilot_user_v2"
+ logs = run_pilot_v2(llm, user_id=user_id, prefs=prefs)
+
+ # Summary
+ print_summary_v2(logs, prefs)
+
+ # Save logs
+ log_path = f"data/logs/pilot_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ # Final state
+ final_state = llm.get_user_state_summary(user_id)
+ print(f"\n[Final State] {final_state}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/scripts/pilot_runner_v3.py b/scripts/pilot_runner_v3.py
new file mode 100644
index 0000000..d232d10
--- /dev/null
+++ b/scripts/pilot_runner_v3.py
@@ -0,0 +1,924 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v3 - Multi-User Multi-Session with Personas
+
+Upgrades from v2:
+- Persona: Bundles StylePrefs into user types
+- 5 test personas (A-E) targeting different style combinations
+- Multi-user × multi-session evaluation
+- Refined judge: bullets only on list tasks, relaxed empty_answer
+- Baseline mode support (no-personalization comparison)
+
+5 Test Personas:
+- A: short + bullets + en (sanity check)
+- B: short + NO bullets + en (anti-bullet)
+- C: long + bullets + en (no length constraint)
+- D: short + bullets + zh (Chinese)
+- E: long + NO bullets + zh (most "anti-default")
+"""
+
+import sys
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict, field
+from typing import List, Dict, Any, Optional, Tuple, Set, Literal
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences (True Preferences)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """User's TRUE style preferences."""
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en" # "en" or "zh"
+
+
+# =============================================================================
+# Persona Definition
+# =============================================================================
+
+@dataclass
+class Persona:
+ """
+ A user persona that bundles style preferences.
+ Each persona represents a distinct user type for testing.
+ """
+ persona_id: str
+ style_prefs: StylePrefs
+ description: str = ""
+ # Future extensions:
+ # task_preferences: Dict[str, float] # e.g., {"code": 0.3, "rewrite": 0.7}
+ # tone: str = "neutral" # "formal", "casual", etc.
+ # domain: str = "general" # "tech", "daily_life", etc.
+
+
+# =============================================================================
+# 5 Test Personas (A-E)
+# =============================================================================
+
+PERSONA_A = Persona(
+ persona_id="A_short_bullets_en",
+ style_prefs=StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="en",
+ ),
+ description="Short + bullets + English (sanity check, same as v2)",
+)
+
+PERSONA_B = Persona(
+ persona_id="B_short_no_bullets_en",
+ style_prefs=StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=False,
+ lang="en",
+ ),
+ description="Short + NO bullets + English (anti-bullet test)",
+)
+
+PERSONA_C = Persona(
+ persona_id="C_long_bullets_en",
+ style_prefs=StylePrefs(
+ require_short=False,
+ max_chars=800,
+ require_bullets=True,
+ lang="en",
+ ),
+ description="Long + bullets + English (no length constraint)",
+)
+
+PERSONA_D = Persona(
+ persona_id="D_short_bullets_zh",
+ style_prefs=StylePrefs(
+ require_short=True,
+ max_chars=200,
+ require_bullets=True,
+ lang="zh",
+ ),
+ description="Short + bullets + Chinese (language test)",
+)
+
+PERSONA_E = Persona(
+ persona_id="E_long_no_bullets_zh",
+ style_prefs=StylePrefs(
+ require_short=False,
+ max_chars=800,
+ require_bullets=False,
+ lang="zh",
+ ),
+ description="Long + NO bullets + Chinese (most anti-default)",
+)
+
+ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E]
+
+
+def get_persona_by_id(persona_id: str) -> Optional[Persona]:
+ """Get persona by ID."""
+ for p in ALL_PERSONAS:
+ if p.persona_id == persona_id:
+ return p
+ return None
+
+
+# =============================================================================
+# Reveal State
+# =============================================================================
+
+@dataclass
+class RevealState:
+ """Tracks which preferences have been explicitly revealed."""
+ short_revealed: bool = False
+ bullets_revealed: bool = False
+ lang_revealed: bool = False
+
+ def reset(self):
+ self.short_revealed = False
+ self.bullets_revealed = False
+ self.lang_revealed = False
+
+ def to_dict(self) -> Dict[str, bool]:
+ return {
+ "short": self.short_revealed,
+ "bullets": self.bullets_revealed,
+ "lang": self.lang_revealed,
+ }
+
+ def __str__(self) -> str:
+ flags = []
+ if self.short_revealed:
+ flags.append("short")
+ if self.bullets_revealed:
+ flags.append("bullets")
+ if self.lang_revealed:
+ flags.append("lang")
+ return f"RevealState({', '.join(flags) if flags else 'none'})"
+
+
+class RevealStateManager:
+ """Manages reveal state for multiple users."""
+
+ def __init__(self):
+ self._states: Dict[str, RevealState] = {}
+
+ def get_state(self, user_id: str) -> RevealState:
+ if user_id not in self._states:
+ self._states[user_id] = RevealState()
+ return self._states[user_id]
+
+ def reset_user(self, user_id: str):
+ if user_id in self._states:
+ self._states[user_id].reset()
+ else:
+ self._states[user_id] = RevealState()
+
+ def reset_session(self, user_id: str):
+ pass # Reveal state persists across sessions
+
+
+# =============================================================================
+# Preference Detection
+# =============================================================================
+
+def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]:
+ """
+ Detect which preferences are mentioned in a query.
+    The prefs argument is accepted for interface consistency with the judge;
+    detection itself is purely keyword-based on the query text.
+ """
+ lower_q = (query or "").lower()
+
+ revealed = {
+ "short": False,
+ "bullets": False,
+ "lang": False,
+ }
+
+ # Short/length preference
+ short_patterns = [
+ "short", "concise", "brief", "under ", "less than",
+ "keep it short", "keep responses", "keep answers",
+ "maximum ", "max ", "characters", "words or less",
+ "200 ", "100 ", "50 ", "300 ",
+ ]
+ for pattern in short_patterns:
+ if pattern in lower_q:
+ revealed["short"] = True
+ break
+
+ # Bullet preference (both positive and negative)
+ bullet_patterns = [
+ "bullet", "bullet point", "bullet-point",
+ "bulleted", "list format", "use bullets",
+ "no bullet", "don't use bullet", "without bullet",
+ "numbered list", "use numbers",
+ ]
+ for pattern in bullet_patterns:
+ if pattern in lower_q:
+ revealed["bullets"] = True
+ break
+
+ # Language preference
+ lang_patterns_zh = [
+ "chinese", "中文", "in chinese", "用中文",
+ "speak chinese", "write chinese", "respond in chinese",
+ "please use chinese", "mandarin", "请用中文",
+ ]
+ lang_patterns_en = [
+ "english", "in english", "use english",
+ "speak english", "write english", "respond in english",
+ "please use english",
+ ]
+
+ for pattern in lang_patterns_zh + lang_patterns_en:
+ if pattern in lower_q:
+ revealed["lang"] = True
+ break
+
+ return revealed
+
+
+def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]:
+ """Update reveal state based on query content."""
+ detected = detect_revealed_preferences(query, prefs)
+ newly_revealed = set()
+
+ if detected["short"] and not reveal_state.short_revealed:
+ reveal_state.short_revealed = True
+ newly_revealed.add("short")
+
+ if detected["bullets"] and not reveal_state.bullets_revealed:
+ reveal_state.bullets_revealed = True
+ newly_revealed.add("bullets")
+
+ if detected["lang"] and not reveal_state.lang_revealed:
+ reveal_state.lang_revealed = True
+ newly_revealed.add("lang")
+
+ return newly_revealed
+
+
+# =============================================================================
+# Refined Style Judge
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+
+
+def style_judge_v3(
+ query: str,
+ answer: str,
+ task_type: str,
+ prefs: StylePrefs,
+ reveal_state: RevealState,
+) -> JudgeResult:
+ """
+ Refined style judge with:
+ - Bullets only enforced on list-type tasks
+    - Relaxed empty_answer (only a truly empty answer is flagged)
+ - Reveal-aware enforcement
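+
+    Worked example (illustrative): with short (max_chars=200) and bullets both
+    required and revealed, a 250-char answer with no bullet markers on a "list"
+    task yields violations=["too_long", "no_bullets"], sat_t = max(0, 1 - 0.3*2)
+    = 0.4 and sev_t = 1.0 (too_long is a hard violation); prog_t stays 1.0 for
+    non-code tasks.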
+ """
+ violations: List[str] = []
+ enforced: List[str] = []
+ text = (answer or "").strip()
+
+    # 0) Empty answer - flag only a truly empty answer
+    # Relaxed vs v2: short factual answers like "4" or "Paris" are allowed
+ if len(text) == 0:
+ violations.append("empty_answer")
+ return JudgeResult(
+ sat_t=0.0,
+ sev_t=1.0,
+ prog_t=0.0,
+ violations=violations,
+ enforced_constraints=["non_empty"],
+ )
+
+ # 1) Length - enforce only if BOTH true AND revealed
+ if prefs.require_short and reveal_state.short_revealed:
+ enforced.append("short")
+ if len(text) > prefs.max_chars:
+ violations.append("too_long")
+
+ # 2) Bullets - enforce ONLY on list-type tasks AND if revealed
+ # task_type "list" = listing tasks (Name X things, What are the N...)
+ # task_type "qa" = factual QA (What is the capital...)
+ # task_type "general" = other general tasks
+ if prefs.require_bullets and reveal_state.bullets_revealed:
+ if task_type == "list": # Only enforce on list tasks
+ enforced.append("bullets")
+ has_bullets = ("- " in text) or ("• " in text) or ("* " in text) or ("\n- " in text)
+ if not has_bullets:
+ violations.append("no_bullets")
+
+ # 3) Language - enforce only if revealed
+ if reveal_state.lang_revealed:
+ enforced.append("lang")
+ if prefs.lang == "zh":
+ # For Chinese: should have significant non-ASCII content
+ ascii_count = sum(c.isascii() for c in text)
+ ascii_ratio = ascii_count / max(1, len(text))
+ if ascii_ratio > 0.7:
+ violations.append("wrong_lang")
+ elif prefs.lang == "en":
+ # For English: should be mostly ASCII
+ ascii_count = sum(c.isascii() for c in text)
+ ascii_ratio = ascii_count / max(1, len(text))
+ if ascii_ratio < 0.5:
+ violations.append("wrong_lang")
+
+ # 4) Code task: always enforce
+ prog_t = 1.0
+ if task_type == "code":
+ enforced.append("code_block")
+ has_code = ("```" in text) or ("def " in text) or ("function " in text)
+ if not has_code:
+ violations.append("no_code_block")
+ prog_t = 0.0
+
+ # 5) Compute scores
+ if not violations:
+ sat_t = 1.0
+ sev_t = 0.0
+ else:
+ sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
+ hard_violations = {"empty_answer", "too_long", "wrong_lang"}
+ sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
+
+ return JudgeResult(
+ sat_t=sat_t,
+ sev_t=sev_t,
+ prog_t=prog_t,
+ violations=violations,
+ enforced_constraints=enforced,
+ )
+
+
+# =============================================================================
+# Feedback Computation
+# =============================================================================
+
+def compute_feedback_for_turn(
+ turn_id: int,
+ query: str,
+ query_type: str,
+ task_type: str,
+ judge_result: JudgeResult,
+) -> Tuple[float, float]:
+ """Convert JudgeResult into (reward, gating)."""
+ reward = judge_result.sat_t
+
+ lower_q = (query or "").lower()
+
+ is_pref_turn = (
+ query_type == "preference"
+ or "i prefer" in lower_q
+ or "my preference" in lower_q
+ or "please use" in lower_q
+ or "please keep" in lower_q
+ or "you didn't follow" in lower_q
+ or "you forgot" in lower_q
+ or "remember that i" in lower_q
+ or "i told you" in lower_q
+ or "i asked for" in lower_q
+ or "中文" in lower_q
+ or "用中文" in lower_q
+ )
+
+ gating = 1.0 if is_pref_turn else 0.0
+ return reward, gating
+
+
+# =============================================================================
+# Query Generation per Persona
+# =============================================================================
+
+def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 1: Reveal preferences.
+ Customize based on persona's true preferences.
+ """
+ queries = []
+ prefs = persona.style_prefs
+
+ # Turn 0: Reveal length preference
+ if prefs.require_short:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "我喜欢简短的回答,请保持回复在200字以内。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "I prefer short, concise answers. Please keep responses under 200 characters.",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ # Long preference - don't reveal (let short_revealed stay False)
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "你好,我想了解一些问题。",
+ "type": "task",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Hello, I have some questions for you.",
+ "type": "task",
+ "task_type": "general",
+ })
+
+ # Turn 1: First task
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "列出三个改善睡眠的建议。",
+ "type": "task",
+ "task_type": "list",
+ })
+ else:
+ queries.append({
+ "query": "List three tips for better sleep.",
+ "type": "task",
+ "task_type": "list",
+ })
+
+ # Turn 2: Reveal bullet preference
+ if prefs.require_bullets:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "我喜欢用项目符号列出要点,请使用bullet points。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "I prefer bullet points when listing things. Please use bullet points.",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+        # Reveal the anti-bullet preference explicitly (no bullets, continuous prose)
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "请不要用项目符号,我更喜欢连续的句子。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Please don't use bullet points. I prefer continuous prose.",
+ "type": "preference",
+ "task_type": "general",
+ })
+
+    # Turn 3: Reveal language preference (stated explicitly for both zh and en personas)
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "请用中文回答我的问题。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Please respond in English.",
+ "type": "preference",
+ "task_type": "general",
+ })
+
+ # Turn 4-5: Tasks
+ if prefs.lang == "zh":
+ queries.extend([
+ {
+ "query": "锻炼有什么好处?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "列出五种流行的编程语言。",
+ "type": "task",
+ "task_type": "list",
+ },
+ ])
+ else:
+ queries.extend([
+ {
+ "query": "What are the benefits of exercise?",
+ "type": "task",
+ "task_type": "list",
+ },
+ {
+ "query": "Name five popular programming languages.",
+ "type": "task",
+ "task_type": "list",
+ },
+ ])
+
+ return queries
+
+
+def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 2: NO preference restatement.
+ Tests cross-session retention.
+ """
+ prefs = persona.style_prefs
+
+ if prefs.lang == "zh":
+ return [
+ {"query": "推荐三种健康的早餐。", "type": "task", "task_type": "list"},
+ {"query": "一年有哪四个季节?", "type": "task", "task_type": "list"},
+ {"query": "法国的首都是哪里?", "type": "task", "task_type": "qa"},
+ {"query": "列出三种可再生能源。", "type": "task", "task_type": "list"},
+ ]
+ else:
+ return [
+ {"query": "What are three healthy breakfast ideas?", "type": "task", "task_type": "list"},
+ {"query": "What are the four seasons of the year?", "type": "task", "task_type": "list"},
+ {"query": "What is the capital of France?", "type": "task", "task_type": "qa"},
+ {"query": "Name three types of renewable energy.", "type": "task", "task_type": "list"},
+ ]
+
+
+def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 3: Mix of tasks and one reminder.
+ """
+ prefs = persona.style_prefs
+
+ if prefs.lang == "zh":
+ return [
+ {"query": "列出五种常见的水果。", "type": "task", "task_type": "list"},
+ {"query": "请记住我喜欢简短的回答。列出三种海洋动物。", "type": "preference", "task_type": "list"},
+ {"query": "2加2等于多少?", "type": "task", "task_type": "qa"},
+ ]
+ else:
+ return [
+ {"query": "Name five common fruits.", "type": "task", "task_type": "list"},
+ {"query": "Remember that I asked for short answers. List three ocean animals.", "type": "preference", "task_type": "list"},
+ {"query": "What is 2 + 2?", "type": "task", "task_type": "qa"},
+ ]
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn."""
+ user_id: str
+ persona_id: str
+ session_id: int
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+ reward: float
+ gating: float
+ reveal_state_before: Dict[str, bool]
+ reveal_state_after: Dict[str, bool]
+ newly_revealed: List[str]
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ """Save logs to JSONL file."""
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
+
+
+# =============================================================================
+# Session Runner
+# =============================================================================
+
+def run_session(
+ llm: PersonalizedLLM,
+ user_id: str,
+ persona: Persona,
+ session_id: int,
+ reveal_state: RevealState,
+ queries: List[Dict[str, Any]],
+ all_logs: List[TurnLog],
+) -> List[TurnLog]:
+ """Run a single session for a user."""
+
+ prefs = persona.style_prefs
+ session_logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"[{persona.persona_id}] Session {session_id}: {len(queries)} turns")
+ print(f"Reveal state (start): {reveal_state}")
+ print(f"{'='*60}")
+
+ # Reset session (clears history, z_short; keeps z_long and reveal state)
+ llm.reset_session(user_id)
+
+ for turn_id, q_info in enumerate(queries):
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+
+ print(f"\n--- S{session_id}/T{turn_id} [{query_type}] ---")
+ print(f"[Q] {query[:60]}{'...' if len(query) > 60 else ''}")
+
+ # Capture reveal state BEFORE
+ reveal_before = reveal_state.to_dict()
+
+ # Update reveal state
+ newly_revealed = update_reveal_state(reveal_state, query, prefs)
+ if newly_revealed:
+ print(f"[Reveal] Newly: {newly_revealed}")
+
+ reveal_after = reveal_state.to_dict()
+
+ # Get user state before
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+ # Apply feedback for previous turn
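+        # (The judge needs the model's answer, so turn t's reward is applied
+        # one turn late, just before turn t+1's chat call; the final turn's
+        # feedback is applied after the loop.)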
+ if turn_id > 0 and session_logs:
+ prev_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=prev_log.turn_id,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={"source": "pilot_v3", "session_id": session_id}
+ )
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer
+ print(f"[A] ({len(resp.answer)}c) {answer_display}")
+
+ # Judge
+ judge_result = style_judge_v3(query, resp.answer, task_type, prefs, reveal_state)
+ print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}")
+
+ # Compute feedback
+ reward, gating = compute_feedback_for_turn(turn_id, query, query_type, task_type, judge_result)
+
+ # Get state after
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ # Debug info
+ num_memories = len(resp.debug.selected_memory_ids) if resp.debug else 0
+ num_prefs = len(resp.debug.extracted_preferences) if resp.debug else 0
+
+ # Log
+ log = TurnLog(
+ user_id=user_id,
+ persona_id=persona.persona_id,
+ session_id=session_id,
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ enforced_constraints=judge_result.enforced_constraints,
+ reward=reward,
+ gating=gating,
+ reveal_state_before=reveal_before,
+ reveal_state_after=reveal_after,
+ newly_revealed=list(newly_revealed),
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ )
+ session_logs.append(log)
+ all_logs.append(log)
+
+ # Apply final feedback
+ if session_logs:
+ last_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=last_log.turn_id,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v3", "session_id": session_id, "final": True}
+ )
+ llm.apply_feedback(feedback)
+
+ return session_logs
+
+
+# =============================================================================
+# Multi-User Multi-Session Runner
+# =============================================================================
+
+def run_multi_user_pilot(
+ llm: PersonalizedLLM,
+ personas: List[Persona],
+ num_sessions: int = 3,
+ reveal_manager: Optional[RevealStateManager] = None,
+) -> List[TurnLog]:
+ """
+ Run multi-user multi-session pilot.
+
+ Args:
+ llm: PersonalizedLLM instance
+ personas: List of personas to test
+ num_sessions: Number of sessions per user
+ reveal_manager: Optional existing reveal manager
+ """
+ if reveal_manager is None:
+ reveal_manager = RevealStateManager()
+
+ all_logs: List[TurnLog] = []
+
+ print(f"\n{'#'*60}")
+ print(f"PILOT v3: MULTI-USER MULTI-SESSION")
+ print(f"Users: {len(personas)}, Sessions per user: {num_sessions}")
+ print(f"{'#'*60}")
+
+ for persona in personas:
+ user_id = f"user_{persona.persona_id}"
+ prefs = persona.style_prefs
+
+ print(f"\n{'*'*60}")
+ print(f"USER: {user_id}")
+ print(f"Persona: {persona.description}")
+ print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
+ print(f"{'*'*60}")
+
+ # Reset user completely
+ llm.reset_user(user_id)
+ reveal_manager.reset_user(user_id)
+ reveal_state = reveal_manager.get_state(user_id)
+
+ # Run sessions
+ for session_id in range(1, num_sessions + 1):
+ if session_id == 1:
+ queries = get_session_1_queries_for_persona(persona)
+ elif session_id == 2:
+ queries = get_session_2_queries_for_persona(persona)
+ else:
+ queries = get_session_3_queries_for_persona(persona)
+
+ reveal_manager.reset_session(user_id) # No-op, just for clarity
+ run_session(llm, user_id, persona, session_id, reveal_state, queries, all_logs)
+
+ return all_logs
+
+
+# =============================================================================
+# Summary
+# =============================================================================
+
+def print_summary_v3(logs: List[TurnLog]):
+ """Print summary for pilot v3."""
+ print(f"\n{'='*60}")
+ print("PILOT v3 SUMMARY - Multi-User Multi-Session")
+ print(f"{'='*60}")
+
+ if not logs:
+ print("No logs.")
+ return
+
+ from collections import Counter, defaultdict
+
+ # Per-persona stats
+ personas = sorted(set(l.persona_id for l in logs))
+
+ print(f"\n--- Per-Persona Statistics ---")
+ for pid in personas:
+ p_logs = [l for l in logs if l.persona_id == pid]
+
+ # Per-session breakdown
+ sessions = sorted(set(l.session_id for l in p_logs))
+
+ print(f"\n{pid}:")
+ for sid in sessions:
+ s_logs = [l for l in p_logs if l.session_id == sid]
+ avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0
+ violations = [v for l in s_logs for v in l.violations]
+ enforced = set(c for l in s_logs for c in l.enforced_constraints)
+
+ print(f" Session {sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, enforced={enforced}")
+ if violations:
+ print(f" violations: {dict(Counter(violations))}")
+
+ # Cross-session retention check
+ print(f"\n--- Cross-Session Retention ---")
+ for pid in personas:
+ p_logs = [l for l in logs if l.persona_id == pid]
+ s1_logs = [l for l in p_logs if l.session_id == 1]
+ s2_logs = [l for l in p_logs if l.session_id == 2]
+
+ if s1_logs and s2_logs:
+ s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs)
+ s2_sat = sum(l.sat_t for l in s2_logs) / len(s2_logs)
+
+ # Check what was enforced in S2
+ s2_enforced = set(c for l in s2_logs for c in l.enforced_constraints)
+
+ print(f"{pid}: S1_sat={s1_sat:.3f} → S2_sat={s2_sat:.3f}, S2_enforced={s2_enforced}")
+
+ # Overall stats
+ total = len(logs)
+ avg_sat = sum(l.sat_t for l in logs) / total
+ total_tokens = sum(l.total_tokens for l in logs)
+
+ print(f"\n--- Overall ---")
+ print(f"Total turns: {total}")
+ print(f"Overall avg sat_t: {avg_sat:.3f}")
+ print(f"Total tokens: {total_tokens}")
+
+ # Violations by type
+ all_violations = [v for l in logs for v in l.violations]
+ if all_violations:
+ print(f"\nViolations: {dict(Counter(all_violations))}")
+
+
+def main():
+ print("=" * 60)
+ print("PILOT RUNNER v3 - Multi-User Multi-Session with Personas")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+
+ # Select personas
+ personas = ALL_PERSONAS # All 5 personas
+ print(f"\n[Config] Running {len(personas)} personas:")
+ for p in personas:
+ print(f" - {p.persona_id}: {p.description}")
+
+ # Initialize LLM
+ print("\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path="data/users/user_store_pilot_v3.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ )
+
+ # Run pilot
+ logs = run_multi_user_pilot(llm, personas, num_sessions=3)
+
+ # Summary
+ print_summary_v3(logs)
+
+ # Save logs
+ log_path = f"data/logs/pilot_v3_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/pilot_runner_v4.py b/scripts/pilot_runner_v4.py
new file mode 100644
index 0000000..b3e2058
--- /dev/null
+++ b/scripts/pilot_runner_v4.py
@@ -0,0 +1,1230 @@
+#!/usr/bin/env python3
+"""
+Pilot Runner v4 - Critical Fixes for Baseline Comparison
+
+Fixes from v3:
+1. Chinese short reveal detection (简短/字以内/不超过 etc.)
+2. Symmetric bullets constraint (has_bullets violation for require_bullets=False)
+3. Better wrong_lang with CJK ratio + math exemption
+4. Persona-conditional query templates (no self-contradiction)
+5. Violation-triggered complaint mechanism (for online RL signal)
+
+This version is ready for proper baseline comparison.
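+
+Usage (assuming the personalization serving stack is importable), e.g.:
+    python scripts/pilot_runner_v4.py --mode compare --sessions 3 --seed 42
+Modes: full/full-greedy, full-sample, nopersonal, vanilla, compare (default), all;
+--no-complaints disables complaint injection.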
+"""
+
+import sys
+import os
+import re
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict, field
+from typing import List, Dict, Any, Optional, Tuple, Set, Literal
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.serving import PersonalizedLLM, Feedback, AssistantResponse
+
+
+# =============================================================================
+# Style Preferences (True Preferences)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ """User's TRUE style preferences."""
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False # True = want bullets, False = don't want bullets
+ lang: str = "en" # "en" or "zh"
+
+
+# =============================================================================
+# Persona Definition
+# =============================================================================
+
+@dataclass
+class Persona:
+ """A user persona that bundles style preferences."""
+ persona_id: str
+ style_prefs: StylePrefs
+ description: str = ""
+
+
+# 5 Test Personas
+PERSONA_A = Persona(
+ persona_id="A_short_bullets_en",
+ style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
+ description="Short + bullets + English",
+)
+
+PERSONA_B = Persona(
+ persona_id="B_short_no_bullets_en",
+ style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
+ description="Short + NO bullets + English (anti-bullet)",
+)
+
+PERSONA_C = Persona(
+ persona_id="C_long_bullets_en",
+ style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
+ description="Long + bullets + English",
+)
+
+PERSONA_D = Persona(
+ persona_id="D_short_bullets_zh",
+ style_prefs=StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
+ description="Short + bullets + Chinese",
+)
+
+PERSONA_E = Persona(
+ persona_id="E_long_no_bullets_zh",
+ style_prefs=StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
+ description="Long + NO bullets + Chinese (most anti-default)",
+)
+
+# Extreme short persona for case study - LLM default is much longer
+PERSONA_F = Persona(
+ persona_id="F_extreme_short_en",
+ style_prefs=StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
+ description="EXTREME short (100 chars) + bullets + English",
+)
+
+ALL_PERSONAS = [PERSONA_A, PERSONA_B, PERSONA_C, PERSONA_D, PERSONA_E, PERSONA_F]
+
+
+# =============================================================================
+# Reveal State
+# =============================================================================
+
+@dataclass
+class RevealState:
+ """Tracks which preferences have been explicitly revealed."""
+ short_revealed: bool = False
+ bullets_revealed: bool = False
+ lang_revealed: bool = False
+
+ def reset(self):
+ self.short_revealed = False
+ self.bullets_revealed = False
+ self.lang_revealed = False
+
+ def to_dict(self) -> Dict[str, bool]:
+ return {"short": self.short_revealed, "bullets": self.bullets_revealed, "lang": self.lang_revealed}
+
+ def __str__(self) -> str:
+ flags = [k for k, v in self.to_dict().items() if v]
+ return f"RevealState({', '.join(flags) if flags else 'none'})"
+
+
+class RevealStateManager:
+ """Manages reveal state for multiple users."""
+ def __init__(self):
+ self._states: Dict[str, RevealState] = {}
+
+ def get_state(self, user_id: str) -> RevealState:
+ if user_id not in self._states:
+ self._states[user_id] = RevealState()
+ return self._states[user_id]
+
+ def reset_user(self, user_id: str):
+ self._states[user_id] = RevealState()
+
+ def reset_session(self, user_id: str):
+ pass # Reveal state persists
+
+
+# =============================================================================
+# FIX 1: Improved Preference Detection (with Chinese support)
+# =============================================================================
+
+def detect_revealed_preferences(query: str, prefs: StylePrefs) -> Dict[str, bool]:
+ """
+ Detect which preferences are mentioned in a query.
+ FIX: Added Chinese keywords for short detection.
+ """
+ lower_q = (query or "").lower()
+ original_q = query or ""
+
+ revealed = {"short": False, "bullets": False, "lang": False}
+
+ # Short/length preference - English patterns
+ short_patterns_en = [
+ "short", "concise", "brief", "under ", "less than",
+ "keep it short", "keep responses", "keep answers",
+ "maximum ", "max ", "characters", "words or less",
+ ]
+
+ # FIX: Chinese patterns for short preference
+ short_patterns_zh = [
+ "简短", "精简", "尽量短", "不要太长", "字以内", "不超过",
+ "少于", "控制在", "简洁", "简明",
+ ]
+
+ # Regex patterns for number-based length constraints
+ short_regex_patterns = [
+ r"(\d+)\s*字以内", # "200字以内"
+ r"不超过\s*(\d+)\s*字", # "不超过200字"
+ r"under\s*(\d+)", # "under 200"
+ r"less\s*than\s*(\d+)", # "less than 200"
+ ]
+
+ for pattern in short_patterns_en:
+ if pattern in lower_q:
+ revealed["short"] = True
+ break
+
+ if not revealed["short"]:
+ for pattern in short_patterns_zh:
+ if pattern in original_q:
+ revealed["short"] = True
+ break
+
+ if not revealed["short"]:
+ for regex in short_regex_patterns:
+ if re.search(regex, original_q, re.IGNORECASE):
+ revealed["short"] = True
+ break
+
+ # Bullet preference - both positive and negative
+ bullet_patterns_positive = [
+ "bullet", "bullet point", "bullet-point", "bulleted",
+ "list format", "use bullets", "use bullet",
+ "项目符号", "要点", "用bullet",
+ ]
+ bullet_patterns_negative = [
+ "no bullet", "don't use bullet", "without bullet",
+ "不要bullet", "不要项目符号", "不用bullet",
+ "continuous prose", "paragraph form", "flowing text",
+ "连续句子", "段落形式",
+ ]
+
+ for pattern in bullet_patterns_positive + bullet_patterns_negative:
+ if pattern in lower_q or pattern in original_q:
+ revealed["bullets"] = True
+ break
+
+ # Language preference
+ lang_patterns_zh = [
+ "chinese", "中文", "in chinese", "用中文",
+ "speak chinese", "respond in chinese", "请用中文",
+ ]
+ lang_patterns_en = [
+ "english", "in english", "use english",
+ "speak english", "respond in english",
+ ]
+
+ for pattern in lang_patterns_zh + lang_patterns_en:
+ if pattern in lower_q or pattern in original_q:
+ revealed["lang"] = True
+ break
+
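+    # Illustrative example: "请保持回复在200字以内。" is caught both by the Chinese
+    # keyword "字以内" and by the regex r"(\d+)\s*字以内", so revealed["short"] is
+    # set; the v3 detector, whose length keywords are English-only, missed this.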
+ return revealed
+
+
+def update_reveal_state(reveal_state: RevealState, query: str, prefs: StylePrefs) -> Set[str]:
+ """Update reveal state based on query content."""
+ detected = detect_revealed_preferences(query, prefs)
+ newly_revealed = set()
+
+ if detected["short"] and not reveal_state.short_revealed:
+ reveal_state.short_revealed = True
+ newly_revealed.add("short")
+
+ if detected["bullets"] and not reveal_state.bullets_revealed:
+ reveal_state.bullets_revealed = True
+ newly_revealed.add("bullets")
+
+ if detected["lang"] and not reveal_state.lang_revealed:
+ reveal_state.lang_revealed = True
+ newly_revealed.add("lang")
+
+ return newly_revealed
+
+
+# =============================================================================
+# FIX 3: Better Language Detection
+# =============================================================================
+
+def is_math_or_symbol_only(text: str) -> bool:
+ """Check if text is purely math/symbols (language neutral)."""
+ # Pattern: only digits, operators, whitespace, punctuation
+ math_pattern = r'^[\d+\-*/=().,%\s\n\r]+$'
+ return bool(re.match(math_pattern, text.strip()))
+
+
+def count_cjk_chars(text: str) -> int:
+ """Count CJK (Chinese/Japanese/Korean) characters."""
+ # CJK Unified Ideographs range
+ cjk_pattern = re.compile(r'[\u4e00-\u9fff\u3400-\u4dbf]')
+ return len(cjk_pattern.findall(text))
+
+
+def count_latin_letters(text: str) -> int:
+ """Count Latin letters (a-z, A-Z)."""
+ return sum(1 for c in text if c.isalpha() and c.isascii())
+
+
+def check_language_violation(text: str, target_lang: str) -> bool:
+ """
+ FIX: Better language violation check using CJK ratio.
+ Returns True if there's a violation.
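+
+    Illustrative examples: with target_lang="zh", "Paris is the capital of
+    France." has zero CJK characters and is flagged, while "巴黎是法国的首都
+    (Paris)" keeps a CJK ratio well above 0.2 and passes; "2 + 2 = 4" is exempt
+    as pure math/symbols.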
+ """
+ text = text.strip()
+
+ # Exempt pure math/symbols
+ if is_math_or_symbol_only(text):
+ return False
+
+ cjk_count = count_cjk_chars(text)
+ latin_count = count_latin_letters(text)
+ total = cjk_count + latin_count
+
+ if total == 0:
+ return False # No meaningful text to judge
+
+ if target_lang == "zh":
+ # For Chinese: want high CJK ratio
+ cjk_ratio = cjk_count / (total + 1e-9)
+ # Allow some English proper nouns - only flag if very low CJK
+ return cjk_ratio < 0.2 # Less than 20% CJK = wrong language
+
+ elif target_lang == "en":
+ # For English: want high Latin ratio
+ latin_ratio = latin_count / (total + 1e-9)
+ return latin_ratio < 0.5 # Less than 50% Latin = wrong language
+
+ return False
+
+
+# =============================================================================
+# FIX 2: Symmetric Bullets Constraint + FIX 3: Language
+# =============================================================================
+
+@dataclass
+class JudgeResult:
+ """Output from the judge for one turn."""
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+
+
+def has_bullet_markers(text: str) -> bool:
+ """Check if text contains bullet point markers."""
+ return bool(re.search(r'(^|\n)\s*[-•*]\s', text))
+
+
+def style_judge_v4(
+ query: str,
+ answer: str,
+ task_type: str,
+ prefs: StylePrefs,
+ reveal_state: RevealState,
+) -> JudgeResult:
+ """
+ Style judge v4 with:
+ - FIX 2: Symmetric bullets (has_bullets violation for require_bullets=False)
+ - FIX 3: Better wrong_lang with CJK ratio + math exemption
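+
+    Worked example (illustrative): Persona B (short + NO bullets, both revealed)
+    answering a 150-char bulleted list on a "list" task gets
+    violations=["has_bullets"], sat_t = 1 - 0.3 = 0.7, sev_t = 0.0; v3 raised no
+    bullet violation in this case because it only checked the
+    require_bullets=True direction.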
+ """
+ violations: List[str] = []
+ enforced: List[str] = []
+ text = (answer or "").strip()
+
+ # 0) Empty answer
+ if len(text) == 0:
+ violations.append("empty_answer")
+ return JudgeResult(sat_t=0.0, sev_t=1.0, prog_t=0.0,
+ violations=violations, enforced_constraints=["non_empty"])
+
+ # 1) Length - enforce only if true AND revealed
+ if prefs.require_short and reveal_state.short_revealed:
+ enforced.append("short")
+ if len(text) > prefs.max_chars:
+ violations.append("too_long")
+
+ # 2) FIX 2: Symmetric bullets constraint (only for list tasks)
+ if reveal_state.bullets_revealed and task_type == "list":
+ has_bullets = has_bullet_markers(text)
+
+ if prefs.require_bullets:
+ # Want bullets but don't have them
+ enforced.append("require_bullets")
+ if not has_bullets:
+ violations.append("no_bullets")
+ else:
+ # Don't want bullets but have them
+ enforced.append("no_bullets_pref")
+ if has_bullets:
+ violations.append("has_bullets")
+
+ # 3) FIX 3: Language with CJK ratio + math exemption
+ if reveal_state.lang_revealed:
+ enforced.append("lang")
+ if check_language_violation(text, prefs.lang):
+ violations.append("wrong_lang")
+
+ # 4) Code task
+ prog_t = 1.0
+ if task_type == "code":
+ enforced.append("code_block")
+ has_code = ("```" in text) or ("def " in text) or ("function " in text)
+ if not has_code:
+ violations.append("no_code_block")
+ prog_t = 0.0
+
+ # 5) Compute scores
+ if not violations:
+ sat_t = 1.0
+ sev_t = 0.0
+ else:
+ sat_t = max(0.0, 1.0 - 0.3 * float(len(violations)))
+ hard_violations = {"empty_answer", "too_long", "wrong_lang"}
+ sev_t = 1.0 if any(v in hard_violations for v in violations) else 0.0
+
+ return JudgeResult(sat_t=sat_t, sev_t=sev_t, prog_t=prog_t,
+ violations=violations, enforced_constraints=enforced)
+
+
+# =============================================================================
+# Feedback Computation
+# =============================================================================
+
+def compute_feedback_for_turn(
+ query: str,
+ query_type: str,
+ judge_result: JudgeResult,
+) -> Tuple[float, float]:
+ """Convert JudgeResult into (reward, gating)."""
+ reward = judge_result.sat_t
+
+ lower_q = (query or "").lower()
+ original_q = query or ""
+
+ is_pref_turn = (
+ query_type == "preference"
+ or "i prefer" in lower_q
+ or "my preference" in lower_q
+ or "please use" in lower_q
+ or "please keep" in lower_q
+ or "you didn't follow" in lower_q
+ or "you forgot" in lower_q
+ or "remember that i" in lower_q
+ or "i told you" in lower_q
+ or "i asked for" in lower_q
+ or "that was too" in lower_q
+ or "too long" in lower_q
+ or "请用中文" in original_q
+ or "不要" in original_q
+ or "简短" in original_q
+ )
+
+ gating = 1.0 if is_pref_turn else 0.0
+ return reward, gating
+
+
+# =============================================================================
+# FIX 5: Violation-Triggered Complaint Generation
+# =============================================================================
+
+def generate_complaint_query(violations: List[str], prefs: StylePrefs) -> Optional[Dict[str, Any]]:
+ """
+ Generate a complaint query based on violations.
+ Returns None if no complaint needed.
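+
+    Illustrative example: violations=["no_bullets", "too_long"] produces only
+    the too_long complaint (priority: too_long > wrong_lang > no_bullets >
+    has_bullets); a remaining violation can trigger its own complaint on a
+    later turn if it recurs.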
+ """
+ if not violations:
+ return None
+
+ # Priority: address most severe violation first
+ complaint = None
+
+ if "too_long" in violations:
+ if prefs.lang == "zh":
+ complaint = {
+ "query": f"回答太长了。请保持回复在{prefs.max_chars}字以内。",
+ "type": "preference",
+ "task_type": "general",
+ }
+ else:
+ complaint = {
+ "query": f"That was too long. Please keep responses under {prefs.max_chars} characters.",
+ "type": "preference",
+ "task_type": "general",
+ }
+
+ elif "wrong_lang" in violations:
+ if prefs.lang == "zh":
+ complaint = {
+ "query": "请用中文回答。",
+ "type": "preference",
+ "task_type": "general",
+ }
+ else:
+ complaint = {
+ "query": "Please respond in English.",
+ "type": "preference",
+ "task_type": "general",
+ }
+
+ elif "no_bullets" in violations:
+ if prefs.lang == "zh":
+ complaint = {
+ "query": "请在列出内容时使用项目符号(bullet points)。",
+ "type": "preference",
+ "task_type": "general",
+ }
+ else:
+ complaint = {
+ "query": "Please use bullet points when listing things.",
+ "type": "preference",
+ "task_type": "general",
+ }
+
+ elif "has_bullets" in violations:
+ if prefs.lang == "zh":
+ complaint = {
+ "query": "请不要使用项目符号,用连续的句子来表达。",
+ "type": "preference",
+ "task_type": "general",
+ }
+ else:
+ complaint = {
+ "query": "Please don't use bullet points. Use continuous prose instead.",
+ "type": "preference",
+ "task_type": "general",
+ }
+
+ return complaint
+
+
+# =============================================================================
+# FIX 4: Persona-Conditional Query Templates
+# =============================================================================
+
+def get_session_1_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 1: Reveal preferences (persona-conditional).
+ FIX: Only reveal preferences that match the persona's true prefs.
+ """
+ queries = []
+ prefs = persona.style_prefs
+
+ # Turn 0: Reveal length preference (only if require_short=True)
+ if prefs.require_short:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": f"我喜欢简短的回答,请保持回复在{prefs.max_chars}字以内。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": f"I prefer short, concise answers. Please keep responses under {prefs.max_chars} characters.",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ # Don't reveal short preference for long-preferring personas
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "你好,我有一些问题想问你。",
+ "type": "task",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Hello, I have some questions for you.",
+ "type": "task",
+ "task_type": "general",
+ })
+
+ # Turn 1: First task
+ if prefs.lang == "zh":
+ queries.append({"query": "列出三个改善睡眠的建议。", "type": "task", "task_type": "list"})
+ else:
+ queries.append({"query": "List three tips for better sleep.", "type": "task", "task_type": "list"})
+
+ # Turn 2: Reveal bullet preference (conditional on require_bullets)
+ if prefs.require_bullets:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "我喜欢用项目符号列出要点,请使用bullet points。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "I prefer bullet points when listing things. Please use bullet points.",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ # Explicitly say NO bullets
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "请不要用项目符号,我更喜欢连续的句子来表达。",
+ "type": "preference",
+ "task_type": "general",
+ })
+ else:
+ queries.append({
+ "query": "Please don't use bullet points. I prefer continuous prose.",
+ "type": "preference",
+ "task_type": "general",
+ })
+
+ # Turn 3: Reveal language preference
+ if prefs.lang == "zh":
+ queries.append({"query": "请用中文回答我的问题。", "type": "preference", "task_type": "general"})
+ else:
+ queries.append({"query": "Please respond in English.", "type": "preference", "task_type": "general"})
+
+ # Turn 4-5: Tasks
+ if prefs.lang == "zh":
+ queries.extend([
+ {"query": "锻炼有什么好处?", "type": "task", "task_type": "list"},
+ {"query": "列出五种流行的编程语言。", "type": "task", "task_type": "list"},
+ ])
+ else:
+ queries.extend([
+ {"query": "What are the benefits of exercise?", "type": "task", "task_type": "list"},
+ {"query": "Name five popular programming languages.", "type": "task", "task_type": "list"},
+ ])
+
+ return queries
+
+
+def get_session_2_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """Session 2: NO preference restatement."""
+ prefs = persona.style_prefs
+
+ if prefs.lang == "zh":
+ return [
+ {"query": "推荐三种健康的早餐。", "type": "task", "task_type": "list"},
+ {"query": "一年有哪四个季节?", "type": "task", "task_type": "list"},
+ {"query": "法国的首都是哪里?", "type": "task", "task_type": "qa"},
+ {"query": "列出三种可再生能源。", "type": "task", "task_type": "list"},
+ ]
+ else:
+ return [
+ {"query": "What are three healthy breakfast ideas?", "type": "task", "task_type": "list"},
+ {"query": "What are the four seasons of the year?", "type": "task", "task_type": "list"},
+ {"query": "What is the capital of France?", "type": "task", "task_type": "qa"},
+ {"query": "Name three types of renewable energy.", "type": "task", "task_type": "list"},
+ ]
+
+
+def get_session_3_queries_for_persona(persona: Persona) -> List[Dict[str, Any]]:
+ """
+ Session 3: Tasks with ONE persona-conditional reminder.
+ FIX: Reminder matches persona's actual preferences.
+ """
+ prefs = persona.style_prefs
+ queries = []
+
+ # First task
+ if prefs.lang == "zh":
+ queries.append({"query": "列出五种常见的水果。", "type": "task", "task_type": "list"})
+ else:
+ queries.append({"query": "Name five common fruits.", "type": "task", "task_type": "list"})
+
+ # Persona-conditional reminder
+ if prefs.require_short and prefs.require_bullets:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "记住我喜欢简短的回答和项目符号。列出三种海洋动物。",
+ "type": "preference", "task_type": "list"
+ })
+ else:
+ queries.append({
+ "query": "Remember I prefer short answers with bullet points. List three ocean animals.",
+ "type": "preference", "task_type": "list"
+ })
+ elif prefs.require_short and not prefs.require_bullets:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "记住我喜欢简短的回答,不要用项目符号。列出三种海洋动物。",
+ "type": "preference", "task_type": "list"
+ })
+ else:
+ queries.append({
+ "query": "Remember I prefer short answers without bullet points. List three ocean animals.",
+ "type": "preference", "task_type": "list"
+ })
+ elif not prefs.require_short and prefs.require_bullets:
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "记住我喜欢用项目符号列出要点。列出三种海洋动物。",
+ "type": "preference", "task_type": "list"
+ })
+ else:
+ queries.append({
+ "query": "Remember I prefer bullet points. List three ocean animals.",
+ "type": "preference", "task_type": "list"
+ })
+ else: # not short and not bullets
+ if prefs.lang == "zh":
+ queries.append({
+ "query": "记住我不喜欢用项目符号,喜欢连续的句子。列出三种海洋动物。",
+ "type": "preference", "task_type": "list"
+ })
+ else:
+ queries.append({
+ "query": "Remember I prefer continuous prose without bullet points. List three ocean animals.",
+ "type": "preference", "task_type": "list"
+ })
+
+ # Final task
+ if prefs.lang == "zh":
+ queries.append({"query": "2加2等于多少?", "type": "task", "task_type": "qa"})
+ else:
+ queries.append({"query": "What is 2 + 2?", "type": "task", "task_type": "qa"})
+
+ return queries
+
+
+def get_pure_task_queries_for_persona(persona: Persona, session_idx: int) -> List[Dict[str, Any]]:
+ """
+ Pure task sessions (S4+): NO preference reminders at all.
+ Used for testing long-term retention without any in-context hints.
+ Different task sets per session to avoid repetition.
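+
+    Rotation (illustrative): with four pools, session 4 uses pool 0, session 5
+    pool 1, session 6 pool 2, session 7 pool 3, and session 8 wraps back to
+    pool 0 via (session_idx - 4) % len(pools).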
+ """
+ prefs = persona.style_prefs
+
+ # Task pools for variety
+ zh_task_pools = [
+ # Pool 1
+ [
+ {"query": "列出三种热带水果。", "type": "task", "task_type": "list"},
+ {"query": "列出三种常见的编程语言。", "type": "task", "task_type": "list"},
+ {"query": "什么是光合作用?", "type": "task", "task_type": "qa"},
+ {"query": "太阳系有几颗行星?", "type": "task", "task_type": "qa"},
+ ],
+ # Pool 2
+ [
+ {"query": "列出三种室内植物。", "type": "task", "task_type": "list"},
+ {"query": "列出三种运动项目。", "type": "task", "task_type": "list"},
+ {"query": "什么是人工智能?", "type": "task", "task_type": "qa"},
+ {"query": "地球的自转周期是多少?", "type": "task", "task_type": "qa"},
+ ],
+ # Pool 3
+ [
+ {"query": "列出三种乐器。", "type": "task", "task_type": "list"},
+ {"query": "列出三种社交媒体平台。", "type": "task", "task_type": "list"},
+ {"query": "什么是区块链?", "type": "task", "task_type": "qa"},
+ {"query": "月球绕地球一周需要多长时间?", "type": "task", "task_type": "qa"},
+ ],
+ # Pool 4
+ [
+ {"query": "列出三种鸟类。", "type": "task", "task_type": "list"},
+ {"query": "列出三种数据库系统。", "type": "task", "task_type": "list"},
+ {"query": "什么是机器学习?", "type": "task", "task_type": "qa"},
+ {"query": "水的沸点是多少?", "type": "task", "task_type": "qa"},
+ ],
+ ]
+
+ en_task_pools = [
+ # Pool 1
+ [
+ {"query": "List three tropical fruits.", "type": "task", "task_type": "list"},
+ {"query": "List three popular programming languages.", "type": "task", "task_type": "list"},
+ {"query": "What is photosynthesis?", "type": "task", "task_type": "qa"},
+ {"query": "How many planets are in our solar system?", "type": "task", "task_type": "qa"},
+ ],
+ # Pool 2
+ [
+ {"query": "List three indoor plants.", "type": "task", "task_type": "list"},
+ {"query": "List three types of sports.", "type": "task", "task_type": "list"},
+ {"query": "What is artificial intelligence?", "type": "task", "task_type": "qa"},
+ {"query": "How long is a day on Earth?", "type": "task", "task_type": "qa"},
+ ],
+ # Pool 3
+ [
+ {"query": "List three musical instruments.", "type": "task", "task_type": "list"},
+ {"query": "List three social media platforms.", "type": "task", "task_type": "list"},
+ {"query": "What is blockchain?", "type": "task", "task_type": "qa"},
+ {"query": "How long does it take the Moon to orbit Earth?", "type": "task", "task_type": "qa"},
+ ],
+ # Pool 4
+ [
+ {"query": "List three types of birds.", "type": "task", "task_type": "list"},
+ {"query": "List three database systems.", "type": "task", "task_type": "list"},
+ {"query": "What is machine learning?", "type": "task", "task_type": "qa"},
+ {"query": "What is the boiling point of water?", "type": "task", "task_type": "qa"},
+ ],
+ ]
+
+ pools = zh_task_pools if prefs.lang == "zh" else en_task_pools
+ # Rotate through pools based on session index
+ pool_idx = (session_idx - 4) % len(pools)
+ return pools[pool_idx]
+
+
+def get_queries_for_session(persona: Persona, session_id: int) -> List[Dict[str, Any]]:
+ """
+ Get queries for a specific session.
+ S1: Preference reveal
+ S2: Pure task (no reminder)
+ S3: Tasks with ONE reminder
+ S4+: Pure task (testing long-term retention)
+ """
+ if session_id == 1:
+ return get_session_1_queries_for_persona(persona)
+ elif session_id == 2:
+ return get_session_2_queries_for_persona(persona)
+ elif session_id == 3:
+ return get_session_3_queries_for_persona(persona)
+ else:
+ return get_pure_task_queries_for_persona(persona, session_id)
+
+
+# =============================================================================
+# Logging
+# =============================================================================
+
+@dataclass
+class TurnLog:
+ """Log entry for one turn."""
+ user_id: str
+ persona_id: str
+ session_id: int
+ turn_id: int
+ query: str
+ query_type: str
+ task_type: str
+ answer: str
+ answer_length: int
+ sat_t: float
+ sev_t: float
+ prog_t: float
+ violations: List[str]
+ enforced_constraints: List[str]
+ reward: float
+ gating: float
+ is_complaint: bool
+ reveal_state_before: Dict[str, bool]
+ reveal_state_after: Dict[str, bool]
+ newly_revealed: List[str]
+ z_long_norm_before: float
+ z_long_norm_after: float
+ z_short_norm_before: float
+ z_short_norm_after: float
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ # Memory retrieval details
+ num_memories_retrieved: int
+ num_prefs_extracted: int
+ selected_memory_ids: List[str]
+ selected_memory_notes: List[str]
+ selected_memory_scores: List[float]
+ num_candidates: int
+ num_total_memories: int
+ # Mode indicators
+    mode: str  # "full", "nopersonal", or "vanilla"
+ eval_mode: bool # True = greedy, False = sample
+
+
+def log_to_jsonl(logs: List[TurnLog], filepath: str):
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, "w") as f:
+ for log in logs:
+ f.write(json.dumps(asdict(log)) + "\n")
+
+
+# =============================================================================
+# Session Runner with Complaint Injection (FIX 5)
+# =============================================================================
+
+def run_session_v4(
+ llm: PersonalizedLLM,
+ user_id: str,
+ persona: Persona,
+ session_id: int,
+ reveal_state: RevealState,
+ base_queries: List[Dict[str, Any]],
+ all_logs: List[TurnLog],
+ enable_complaints: bool = True,
+) -> List[TurnLog]:
+ """
+ Run session with violation-triggered complaint injection.
+ """
+ prefs = persona.style_prefs
+ session_logs: List[TurnLog] = []
+
+ print(f"\n{'='*60}")
+ print(f"[{persona.persona_id}] Session {session_id}: base queries={len(base_queries)}")
+ print(f"Reveal state (start): {reveal_state}")
+ print(f"{'='*60}")
+
+ llm.reset_session(user_id)
+
+ # Build dynamic query queue
+ query_queue = list(base_queries)
+ turn_id = 0
+
+ while query_queue:
+ q_info = query_queue.pop(0)
+ query = q_info["query"]
+ query_type = q_info.get("type", "task")
+ task_type = q_info.get("task_type", "general")
+ is_complaint = q_info.get("is_complaint", False)
+
+ print(f"\n--- S{session_id}/T{turn_id} [{query_type}]{' [COMPLAINT]' if is_complaint else ''} ---")
+ print(f"[Q] {query[:60]}{'...' if len(query) > 60 else ''}")
+
+ reveal_before = reveal_state.to_dict()
+ newly_revealed = update_reveal_state(reveal_state, query, prefs)
+ if newly_revealed:
+ print(f"[Reveal] Newly: {newly_revealed}")
+ reveal_after = reveal_state.to_dict()
+
+ state_before = llm.get_user_state_summary(user_id)
+ z_long_before = state_before["z_long_norm"]
+ z_short_before = state_before["z_short_norm"]
+
+ # Apply feedback for previous turn
+ if turn_id > 0 and session_logs:
+ prev_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=prev_log.turn_id,
+ reward=prev_log.reward,
+ gating=prev_log.gating,
+ meta={"source": "pilot_v4", "session_id": session_id}
+ )
+ llm.apply_feedback(feedback)
+
+ # Chat
+ resp: AssistantResponse = llm.chat(user_id, query)
+
+ answer_display = resp.answer[:80] + "..." if len(resp.answer) > 80 else resp.answer
+ print(f"[A] ({len(resp.answer)}c) {answer_display}")
+
+ # Judge
+ judge_result = style_judge_v4(query, resp.answer, task_type, prefs, reveal_state)
+ print(f"[J] sat={judge_result.sat_t:.2f}, enforced={judge_result.enforced_constraints}, viol={judge_result.violations}")
+
+ # Compute feedback
+ reward, gating = compute_feedback_for_turn(query, query_type, judge_result)
+
+ state_after = llm.get_user_state_summary(user_id)
+ z_long_after = state_after["z_long_norm"]
+ z_short_after = state_after["z_short_norm"]
+
+ # Extract memory info from debug
+ if resp.debug:
+ num_memories = len(resp.debug.selected_memory_ids)
+ num_prefs = len(resp.debug.extracted_preferences)
+ selected_memory_ids = resp.debug.selected_memory_ids
+ selected_memory_notes = resp.debug.selected_memory_notes
+ selected_memory_scores = resp.debug.selected_memory_scores
+ num_candidates = resp.debug.extra.get("num_candidates", 0)
+ num_total_memories = resp.debug.extra.get("num_total_memories", 0)
+ else:
+ num_memories = 0
+ num_prefs = 0
+ selected_memory_ids = []
+ selected_memory_notes = []
+ selected_memory_scores = []
+ num_candidates = 0
+ num_total_memories = 0
+
+ # Log
+ log = TurnLog(
+ user_id=user_id,
+ persona_id=persona.persona_id,
+ session_id=session_id,
+ turn_id=turn_id,
+ query=query,
+ query_type=query_type,
+ task_type=task_type,
+ answer=resp.answer,
+ answer_length=len(resp.answer),
+ sat_t=judge_result.sat_t,
+ sev_t=judge_result.sev_t,
+ prog_t=judge_result.prog_t,
+ violations=judge_result.violations,
+ enforced_constraints=judge_result.enforced_constraints,
+ reward=reward,
+ gating=gating,
+ is_complaint=is_complaint,
+ reveal_state_before=reveal_before,
+ reveal_state_after=reveal_after,
+ newly_revealed=list(newly_revealed),
+ z_long_norm_before=z_long_before,
+ z_long_norm_after=z_long_after,
+ z_short_norm_before=z_short_before,
+ z_short_norm_after=z_short_after,
+ prompt_tokens=resp.usage.prompt_tokens,
+ completion_tokens=resp.usage.completion_tokens,
+ total_tokens=resp.usage.total_tokens,
+ num_memories_retrieved=num_memories,
+ num_prefs_extracted=num_prefs,
+ selected_memory_ids=selected_memory_ids,
+ selected_memory_notes=selected_memory_notes,
+ selected_memory_scores=selected_memory_scores,
+ num_candidates=num_candidates,
+ num_total_memories=num_total_memories,
+ mode=llm.mode,
+ eval_mode=llm.eval_mode,
+ )
+ session_logs.append(log)
+ all_logs.append(log)
+
+ # FIX 5: Inject complaint if there were violations and this wasn't already a complaint
+ if enable_complaints and judge_result.violations and not is_complaint:
+ complaint = generate_complaint_query(judge_result.violations, prefs)
+ if complaint:
+ complaint["is_complaint"] = True
+ query_queue.insert(0, complaint) # Insert at front
+ print(f"[Complaint Injected] Will complain about: {judge_result.violations}")
+
+ turn_id += 1
+
+ # Final feedback
+ if session_logs:
+ last_log = session_logs[-1]
+ feedback = Feedback(
+ user_id=user_id,
+ turn_id=last_log.turn_id,
+ reward=last_log.reward,
+ gating=last_log.gating,
+ meta={"source": "pilot_v4", "session_id": session_id, "final": True}
+ )
+ llm.apply_feedback(feedback)
+
+ print(f"\n[Session {session_id} End] Reveal: {reveal_state}, Turns: {turn_id}")
+ return session_logs
+
+
+# =============================================================================
+# Multi-User Multi-Session Runner
+# =============================================================================
+
+def run_multi_user_pilot_v4(
+ llm: PersonalizedLLM,
+ personas: List[Persona],
+ num_sessions: int = 3,
+ enable_complaints: bool = True,
+) -> List[TurnLog]:
+ """Run multi-user multi-session pilot v4."""
+ reveal_manager = RevealStateManager()
+ all_logs: List[TurnLog] = []
+
+ print(f"\n{'#'*60}")
+ print(f"PILOT v4: MULTI-USER MULTI-SESSION (Fixed)")
+ print(f"Users: {len(personas)}, Sessions: {num_sessions}, Complaints: {enable_complaints}")
+ print(f"{'#'*60}")
+
+ for persona in personas:
+ user_id = f"user_{persona.persona_id}"
+ prefs = persona.style_prefs
+
+ print(f"\n{'*'*60}")
+ print(f"USER: {user_id}")
+ print(f"Persona: {persona.description}")
+ print(f"True prefs: short={prefs.require_short}, bullets={prefs.require_bullets}, lang={prefs.lang}")
+ print(f"{'*'*60}")
+
+ llm.reset_user(user_id)
+ reveal_manager.reset_user(user_id)
+ reveal_state = reveal_manager.get_state(user_id)
+
+ for session_id in range(1, num_sessions + 1):
+ queries = get_queries_for_session(persona, session_id)
+
+ reveal_manager.reset_session(user_id)
+ run_session_v4(llm, user_id, persona, session_id, reveal_state, queries, all_logs, enable_complaints)
+
+ return all_logs
+
+
+# =============================================================================
+# Summary
+# =============================================================================
+
+def print_summary_v4(logs: List[TurnLog]):
+ """Print summary for pilot v4."""
+ print(f"\n{'='*60}")
+ print("PILOT v4 SUMMARY")
+ print(f"{'='*60}")
+
+ if not logs:
+ print("No logs.")
+ return
+
+ from collections import Counter
+
+ personas = sorted(set(l.persona_id for l in logs))
+
+ print(f"\n--- Per-Persona Statistics ---")
+ for pid in personas:
+ p_logs = [l for l in logs if l.persona_id == pid]
+ sessions = sorted(set(l.session_id for l in p_logs))
+
+ print(f"\n{pid}:")
+ for sid in sessions:
+ s_logs = [l for l in p_logs if l.session_id == sid]
+ avg_sat = sum(l.sat_t for l in s_logs) / len(s_logs) if s_logs else 0
+ violations = [v for l in s_logs for v in l.violations]
+ enforced = set(c for l in s_logs for c in l.enforced_constraints)
+ complaints = sum(1 for l in s_logs if l.is_complaint)
+
+ print(f" S{sid}: {len(s_logs)} turns, avg_sat={avg_sat:.3f}, complaints={complaints}")
+ print(f" enforced={enforced}")
+ if violations:
+ print(f" violations: {dict(Counter(violations))}")
+
+ # Cross-session retention
+ print(f"\n--- Cross-Session Retention (S2 without preferences) ---")
+ for pid in personas:
+ p_logs = [l for l in logs if l.persona_id == pid]
+ s1_logs = [l for l in p_logs if l.session_id == 1]
+ s2_logs = [l for l in p_logs if l.session_id == 2]
+
+ if s1_logs and s2_logs:
+ s1_sat = sum(l.sat_t for l in s1_logs) / len(s1_logs)
+ s2_sat = sum(l.sat_t for l in s2_logs) / len(s2_logs)
+ s2_enforced = set(c for l in s2_logs for c in l.enforced_constraints)
+ print(f"{pid}: S1={s1_sat:.3f} → S2={s2_sat:.3f}, enforced={s2_enforced}")
+
+ # Violation rates
+ print(f"\n--- Violation Rates by Type ---")
+ all_violations = [v for l in logs for v in l.violations]
+ total_turns = len(logs)
+ if all_violations:
+ for v, count in Counter(all_violations).most_common():
+ rate = count / total_turns * 100
+ print(f" {v}: {count} ({rate:.1f}%)")
+ else:
+ print(" No violations")
+
+ # Complaint effectiveness
+ print(f"\n--- Complaint Effectiveness ---")
+ complaint_logs = [l for l in logs if l.is_complaint]
+ if complaint_logs:
+ print(f"Total complaints: {len(complaint_logs)}")
+ avg_sat_complaint = sum(l.sat_t for l in complaint_logs) / len(complaint_logs)
+ print(f"Avg sat on complaint turns: {avg_sat_complaint:.3f}")
+ else:
+ print("No complaints generated")
+
+ # Overall
+ total = len(logs)
+ avg_sat = sum(l.sat_t for l in logs) / total
+ total_tokens = sum(l.total_tokens for l in logs)
+ print(f"\n--- Overall ---")
+ print(f"Total turns: {total}, Avg sat: {avg_sat:.3f}, Total tokens: {total_tokens}")
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser(description="Pilot Runner v4 - Full vs Vanilla Comparison")
+ parser.add_argument("--mode", type=str,
+ choices=["full", "full-greedy", "full-sample", "nopersonal", "vanilla", "compare", "all"],
+ default="compare",
+ help="Mode: 'full-greedy' (personalized, deterministic), "
+ "'full-sample' (personalized, stochastic), "
+ "'nopersonal' (retrieval baseline without z_u), "
+ "'vanilla' (pure LLM, no memory), "
+ "'compare' (full-greedy vs vanilla), "
+ "'all' (run all modes)")
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+ parser.add_argument("--sessions", type=int, default=3, help="Number of sessions per user")
+ parser.add_argument("--no-complaints", action="store_true", help="Disable complaint injection")
+ args = parser.parse_args()
+
+ # Set seeds for reproducibility
+ import random
+ import numpy as np
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ print("=" * 60)
+ print("PILOT RUNNER v4 - Full vs Vanilla Comparison")
+ print("=" * 60)
+ print(f"Started at: {datetime.now().isoformat()}")
+ print(f"Mode: {args.mode}, Seed: {args.seed}, Sessions: {args.sessions}")
+
+ personas = ALL_PERSONAS
+ print(f"\n[Config] {len(personas)} personas:")
+ for p in personas:
+ print(f" - {p.persona_id}: {p.description}")
+
+ enable_complaints = not args.no_complaints
+
+ # Map mode argument to actual run configurations
+ # Each config: (mode_name, llm_mode, eval_mode)
+ # llm_mode: "full", "nopersonal", or "vanilla"
+ # eval_mode: True = greedy/deterministic, False = stochastic sampling
+ if args.mode == "all":
+ run_configs = [
+ ("full-greedy", "full", True),
+ ("full-sample", "full", False),
+ ("nopersonal", "nopersonal", True),
+ ("vanilla", "vanilla", True),
+ ]
+ elif args.mode == "compare":
+ # Main comparison: Full (with memory) vs Vanilla (no memory)
+ run_configs = [
+ ("full-greedy", "full", True),
+ ("vanilla", "vanilla", True),
+ ]
+ elif args.mode == "full" or args.mode == "full-greedy":
+ run_configs = [("full-greedy", "full", True)]
+ elif args.mode == "full-sample":
+ run_configs = [("full-sample", "full", False)]
+ elif args.mode == "vanilla":
+ run_configs = [("vanilla", "vanilla", True)]
+ elif args.mode == "nopersonal":
+ run_configs = [("nopersonal", "nopersonal", True)]
+ else:
+ run_configs = [(args.mode, args.mode, True)]
+
+ for run_name, llm_mode, eval_mode in run_configs:
+ print(f"\n{'#'*60}")
+ print(f"RUNNING: {run_name.upper()}")
+ print(f" llm_mode={llm_mode}, eval_mode={eval_mode} ({'greedy' if eval_mode else 'sample'})")
+ print(f"{'#'*60}")
+
+ # Reset seeds before each run for exact reproducibility
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ print(f"\n[Init] Loading PersonalizedLLM...")
+ llm = PersonalizedLLM(
+ user_store_path=f"data/users/user_store_pilot_v4_{run_name}.npz",
+ only_own_memories=True,
+ enable_preference_extraction=True,
+ enable_rl_updates=(llm_mode == "full"), # Disable RL for nopersonal
+ mode=llm_mode,
+ eval_mode=eval_mode,
+ device_assignment={
+ "embed": "cuda:0",
+ "reranker": "cuda:1",
+ "chat": "cuda:2",
+ "extractor": "cuda:3",
+ },
+ )
+
+ logs = run_multi_user_pilot_v4(llm, personas, num_sessions=args.sessions, enable_complaints=enable_complaints)
+
+ print_summary_v4(logs)
+
+ log_path = f"data/logs/pilot_v4_{run_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
+ log_to_jsonl(logs, log_path)
+ print(f"\n[Logs] Saved to: {log_path}")
+
+ # Save user vectors for similarity analysis
+ if llm_mode == "full":
+ llm.persist()
+ print(f"[Persist] User vectors saved to: {llm._user_store.path}")
+
+ print(f"\nCompleted at: {datetime.now().isoformat()}")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
+
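Each run of pilot_runner_v4.py writes one JSONL log per mode, so the "compare" mode can be scored offline. A minimal post-hoc comparison sketch, assuming each JSONL line carries the TurnLog fields (in particular sat_t) and using hypothetical file names in place of the timestamped paths printed by the run:

import json

def load_sat(path):
    # Read a pilot_v4 JSONL log and return the per-turn satisfaction scores.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line)["sat_t"] for line in f if line.strip()]

# Hypothetical file names; substitute the timestamps printed by the actual run.
full_sat = load_sat("data/logs/pilot_v4_full-greedy_20250101_000000.jsonl")
vanilla_sat = load_sat("data/logs/pilot_v4_vanilla_20250101_000000.jsonl")

print(f"full-greedy avg sat: {sum(full_sat) / len(full_sat):.3f}")
print(f"vanilla     avg sat: {sum(vanilla_sat) / len(vanilla_sat):.3f}")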
diff --git a/scripts/pilot_study.py b/scripts/pilot_study.py
new file mode 100644
index 0000000..9754c42
--- /dev/null
+++ b/scripts/pilot_study.py
@@ -0,0 +1,109 @@
+import json
+import os
+import random
+import asyncio
+from typing import List, Dict, Any
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
+# --- Configuration ---
+INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+OUTPUT_FILE = "data/raw_datasets/pilot_study_1000.jsonl"
+SAMPLE_SIZE = 1000
+MODEL_NAME = "gpt-5.1" # Or your specific model ID
+MAX_CONCURRENCY = 100 # Adjust based on your rate limits
+
+# --- Load System Prompt ---
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ # The template could be truncated before the few-shot examples to save tokens,
+ # but we keep the full file as the system instruction so the few-shot behavior
+ # is preserved and label quality stays high.
+ SYSTEM_PROMPT = f.read()
+
+# --- Async Worker ---
+async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
+ query = item["query"]
+ async with sem:
+ try:
+ response = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ],
+ temperature=0.0, # Deterministic for extraction
+ response_format={"type": "json_object"} # Enforce JSON
+ )
+ result_text = response.choices[0].message.content
+
+ # Parse to ensure validity
+ try:
+ parsed = json.loads(result_text)
+ prefs = parsed.get("preferences", [])
+ has_pref = len(prefs) > 0
+ except Exception:
+ parsed = {"error": "json_parse_fail", "raw": result_text}
+ has_pref = False
+
+ return {
+ "original_query": query,
+ "source": item.get("source"),
+ "extracted_json": parsed,
+ "has_preference": has_pref
+ }
+ except Exception as e:
+ return {
+ "original_query": query,
+ "source": item.get("source"),
+ "error": str(e),
+ "has_preference": False
+ }
+
+async def main():
+ # 1. Load and Sample
+ print(f"Loading data from {INPUT_FILE}...")
+ all_lines = []
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ all_lines.append(json.loads(line))
+
+ if len(all_lines) > SAMPLE_SIZE:
+ sampled_data = random.sample(all_lines, SAMPLE_SIZE)
+ else:
+ sampled_data = all_lines
+ print(f"Sampled {len(sampled_data)} items.")
+
+ # 2. Setup OpenAI Client
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY environment variable not set.")
+ return
+
+ client = AsyncOpenAI(api_key=api_key)
+ sem = asyncio.Semaphore(MAX_CONCURRENCY)
+
+ # 3. Run Labeling
+ tasks = [label_query(client, sem, item) for item in sampled_data]
+ results = await tqdm_asyncio.gather(*tasks, desc="Labeling")
+
+ # 4. Statistics & Save
+ pos_count = sum(1 for r in results if r.get("has_preference"))
+ total = len(results)
+ ratio = (pos_count / total) * 100 if total > 0 else 0
+
+ print(f"\n--- Results ---")
+ print(f"Total processed: {total}")
+ print(f"Positive (has preferences): {pos_count}")
+ print(f"Negative (empty): {total - pos_count}")
+ print(f"Positive Ratio: {ratio:.2f}%")
+
+ with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
+ for res in results:
+ f.write(json.dumps(res, ensure_ascii=False) + "\n")
+ print(f"Saved detailed results to {OUTPUT_FILE}")
+
+if __name__ == "__main__":
+ asyncio.run(main())
+
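The pilot output records carry a source and has_preference field, so the positive ratio can also be broken down by source after the run. A small follow-up sketch over the saved file:

import json
from collections import Counter

totals, positives = Counter(), Counter()
with open("data/raw_datasets/pilot_study_1000.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        rec = json.loads(line)
        src = rec.get("source") or "unknown"
        totals[src] += 1
        if rec.get("has_preference"):
            positives[src] += 1

for src in sorted(totals):
    ratio = positives[src] / totals[src] * 100
    print(f"{src}: {positives[src]}/{totals[src]} positive ({ratio:.1f}%)")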
diff --git a/scripts/process_putnam_batch.py b/scripts/process_putnam_batch.py
new file mode 100644
index 0000000..27e0465
--- /dev/null
+++ b/scripts/process_putnam_batch.py
@@ -0,0 +1,239 @@
+import json
+import os
+import time
+from typing import List, Dict, Any
+from openai import OpenAI
+from collections import Counter
+
+# Configuration
+API_KEY = os.getenv("OPENAI_API_KEY")  # Read the key from the environment; never hard-code credentials
+BATCH_ID_FILE = "data/putnam_eval/submitted_batch_ids.json"
+INPUT_FILE = "data/putnam_eval/putnam_eval_batch.jsonl"
+OUTPUT_FILE = "data/putnam_eval/final_results.json"
+MODEL_NAME = "gpt-5"
+
+def load_input_requests(filepath: str) -> Dict[str, Any]:
+ """Load the original requests to allow retrying."""
+ requests_map = {}
+ print(f"Loading input requests from {filepath}...")
+ with open(filepath, "r", encoding="utf-8") as f:
+ for line in f:
+ if not line.strip():
+ continue
+ item = json.loads(line)
+ requests_map[item["custom_id"]] = item
+ return requests_map
+
+def retrieve_batch_results(client: OpenAI, batch_id: str) -> List[Dict[str, Any]]:
+ """Retrieve and parse batch results."""
+ print(f"Checking status for batch {batch_id}...")
+ batch = client.batches.retrieve(batch_id)
+
+ print(f"Batch Status: {batch.status}")
+ print(f"Output File ID: {batch.output_file_id}")
+ print(f"Error File ID: {batch.error_file_id}")
+
+ results = []
+
+ if batch.output_file_id:
+ print("Downloading output file...")
+ file_response = client.files.content(batch.output_file_id)
+ file_content = file_response.read().decode("utf-8")
+
+ for line in file_content.splitlines():
+ if line.strip():
+ results.append(json.loads(line))
+
+ if batch.error_file_id:
+ print("Downloading error file (if any)...")
+ # Usually contains request-level errors that didn't generate a response object in output
+ try:
+ err_response = client.files.content(batch.error_file_id)
+ err_content = err_response.read().decode("utf-8")
+ for line in err_content.splitlines():
+ if line.strip():
+ results.append(json.loads(line))
+ except Exception as e:
+ print(f"Note: Could not download/parse error file: {e}")
+
+ return results
+
+def process_results_and_find_failures(results: List[Dict[str, Any]], all_request_ids: set) -> tuple[List[Dict[str, Any]], List[str]]:
+ """Separate successful parsable results from failures."""
+ valid_results = []
+ failed_ids = []
+ seen_ids = set()
+
+ for res in results:
+ custom_id = res.get("custom_id")
+ seen_ids.add(custom_id)
+
+ # Check for API level errors
+ if res.get("error"):
+ print(f"Request {custom_id} failed with error: {res['error']}")
+ failed_ids.append(custom_id)
+ continue
+
+ response = res.get("response", {})
+ if response.get("status_code") != 200:
+ print(f"Request {custom_id} failed with status {response.get('status_code')}")
+ failed_ids.append(custom_id)
+ continue
+
+ # Try to parse the content as JSON
+ try:
+ body = response.get("body", {})
+ choices = body.get("choices", [])
+ if not choices:
+ print(f"Request {custom_id} has no choices.")
+ failed_ids.append(custom_id)
+ continue
+
+ content_str = choices[0].get("message", {}).get("content", "")
+ content_json = json.loads(content_str)
+
+ valid_results.append({
+ "custom_id": custom_id,
+ "analysis": content_json
+ })
+ except json.JSONDecodeError:
+ print(f"Request {custom_id} returned invalid JSON content.")
+ failed_ids.append(custom_id)
+ except Exception as e:
+ print(f"Request {custom_id} unexpected processing error: {e}")
+ failed_ids.append(custom_id)
+
+ # Check for completely missing requests
+ missing_ids = all_request_ids - seen_ids
+ if missing_ids:
+ print(f"Found {len(missing_ids)} missing requests that were not in the batch output.")
+ failed_ids.extend(list(missing_ids))
+
+ return valid_results, failed_ids
+
+def retry_failed_requests(client: OpenAI, failed_ids: List[str], input_map: Dict[str, Any]) -> List[Dict[str, Any]]:
+ """Retry specific requests synchronously."""
+ retried_results = []
+ print(f"\nRetrying {len(failed_ids)} failed requests synchronously...")
+
+ for i, custom_id in enumerate(failed_ids):
+ if custom_id not in input_map:
+ print(f"Warning: Original request for {custom_id} not found.")
+ continue
+
+ print(f"Retrying {i+1}/{len(failed_ids)}: {custom_id}")
+ original_req = input_map[custom_id]
+ body = original_req["body"]
+
+ try:
+ response = client.chat.completions.create(
+ model=MODEL_NAME,  # Use the script-level constant so every retry targets the same judge model
+ messages=body["messages"],
+ response_format=body.get("response_format"),
+ temperature=body.get("temperature", 1.0)  # Fall back to the API default when the batch request did not set one
+ )
+
+ content_str = response.choices[0].message.content
+ content_json = json.loads(content_str)
+
+ retried_results.append({
+ "custom_id": custom_id,
+ "analysis": content_json
+ })
+ except Exception as e:
+ print(f"Retry failed for {custom_id}: {e}")
+
+ return retried_results
+
+def print_stats(final_results: List[Dict[str, Any]]):
+ """Calculate and print statistics."""
+ total = len(final_results)
+ if total == 0:
+ print("No results to analyze.")
+ return
+
+ # Categories
+ valid_variant_count = 0
+ correct_solution_count = 0
+ equivalent_count = 0
+ strongly_related_count = 0
+
+ # Validation Consistency
+ both_valid_and_equiv = 0
+
+ print(f"\n--- Statistics (N={total}) ---")
+
+ for item in final_results:
+ analysis = item["analysis"]
+ validity = analysis.get("variant_validity", {})
+ relation = analysis.get("relation_to_original", {})
+
+ is_valid = validity.get("is_problem_valid", False)
+ is_correct = validity.get("is_solution_correct", False)
+ is_equiv = relation.get("is_equivalent", False)
+ is_related = relation.get("is_strongly_related", False)
+
+ if is_valid: valid_variant_count += 1
+ if is_correct: correct_solution_count += 1
+ if is_equiv: equivalent_count += 1
+ if is_related: strongly_related_count += 1
+
+ if is_valid and is_correct and (is_equiv or is_related):
+ both_valid_and_equiv += 1
+
+ print(f"Variant Valid: {valid_variant_count} ({valid_variant_count/total:.1%})")
+ print(f"Solution Correct: {correct_solution_count} ({correct_solution_count/total:.1%})")
+ print(f"Equivalent: {equivalent_count} ({equivalent_count/total:.1%})")
+ print(f"Strongly Related: {strongly_related_count} ({strongly_related_count/total:.1%})")
+ print(f"Valid & Rel/Equiv: {both_valid_and_equiv} ({both_valid_and_equiv/total:.1%})")
+
+def main():
+ if not API_KEY:
+ print("Error: API_KEY not set.")
+ return
+
+ client = OpenAI(api_key=API_KEY)
+
+ # 1. Get Batch ID
+ if not os.path.exists(BATCH_ID_FILE):
+ print(f"Batch ID file not found at {BATCH_ID_FILE}")
+ return
+
+ with open(BATCH_ID_FILE, "r") as f:
+ batch_ids = json.load(f)
+ if not batch_ids:
+ print("No batch IDs found.")
+ return
+ batch_id = batch_ids[-1] # Take the latest one
+ print(f"Processing Batch ID: {batch_id}")
+
+ # 2. Retrieve Results
+ raw_results = retrieve_batch_results(client, batch_id)
+
+ # 3. Load Inputs (to identify missing/failed IDs)
+ input_map = load_input_requests(INPUT_FILE)
+ all_request_ids = set(input_map.keys())
+
+ # 4. Parse and Find Failures
+ valid_results, failed_ids = process_results_and_find_failures(raw_results, all_request_ids)
+ print(f"Successfully parsed: {len(valid_results)}")
+ print(f"Failed/Missing: {len(failed_ids)}")
+
+ # 5. Retry Failures
+ if failed_ids:
+ retry_results = retry_failed_requests(client, failed_ids, input_map)
+ valid_results.extend(retry_results)
+
+ # 6. Save Final Results
+ print(f"Saving {len(valid_results)} results to {OUTPUT_FILE}...")
+ with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
+ json.dump(valid_results, f, indent=2)
+
+ # 7. Stats
+ print_stats(valid_results)
+
+if __name__ == "__main__":
+ main()
+
+
+
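process_putnam_batch.py assumes the batch has already finished. A small polling sketch that waits for completion first, using only client.batches.retrieve (the same call the script itself relies on):

import json
import time
from openai import OpenAI

client = OpenAI()  # expects OPENAI_API_KEY in the environment

with open("data/putnam_eval/submitted_batch_ids.json", "r") as f:
    batch_id = json.load(f)[-1]  # latest submitted batch

while True:
    batch = client.batches.retrieve(batch_id)
    print(f"{batch_id}: {batch.status}")
    if batch.status in ("completed", "failed", "expired", "cancelled"):
        break
    time.sleep(60)  # poll once a minute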
diff --git a/scripts/pull_models.py b/scripts/pull_models.py
new file mode 100644
index 0000000..1e4abc3
--- /dev/null
+++ b/scripts/pull_models.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+import sys
+
+ROOT = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(ROOT / "src"))
+
+from huggingface_hub import snapshot_download # type: ignore
+
+from personalization.config.settings import load_local_models_config
+
+
+def pull_one(repo_id: str, dest: Path, force: bool = False) -> None:
+ dest.mkdir(parents=True, exist_ok=True)
+ if any(dest.iterdir()) and not force:
+ print(f"[skip] {dest} already populated. Use --force to overwrite.")
+ return
+ print(f"[pull] {repo_id} -> {dest}")
+ snapshot_download(repo_id=repo_id, local_dir=str(dest), local_dir_use_symlinks=False)
+ print(f"[done] {repo_id}")
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--target",
+ choices=[
+ "llm",
+ "preference_extractor",
+ "embed_qwen3",
+ "embed_nemotron",
+ "embedders",
+ "reranker_qwen3",
+ "rerankers",
+ "all",
+ ],
+ default="all",
+ )
+ parser.add_argument("--force", action="store_true")
+ args = parser.parse_args()
+
+ cfg = load_local_models_config()
+ if args.target in ("llm", "all"):
+ pull_one(cfg.llm.hf_id, ROOT / cfg.llm.local_path, force=args.force)
+ if args.target in ("preference_extractor", "all"):
+ pull_one(
+ cfg.preference_extractor.hf_id,
+ ROOT / cfg.preference_extractor.local_path,
+ force=args.force,
+ )
+ if args.target in ("embed_qwen3", "embedders", "all") and cfg.embedding and cfg.embedding.qwen3:
+ pull_one(
+ cfg.embedding.qwen3.hf_id,
+ ROOT / cfg.embedding.qwen3.local_path,
+ force=args.force,
+ )
+ if args.target in ("embed_nemotron", "embedders", "all") and cfg.embedding and cfg.embedding.nemotron:
+ pull_one(
+ cfg.embedding.nemotron.hf_id,
+ ROOT / cfg.embedding.nemotron.local_path,
+ force=args.force,
+ )
+ if args.target in ("reranker_qwen3", "rerankers", "all") and cfg.reranker and cfg.reranker.qwen3_8b:
+ pull_one(
+ cfg.reranker.qwen3_8b.hf_id,
+ ROOT / cfg.reranker.qwen3_8b.local_path,
+ force=args.force,
+ )
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/scripts/recompute_embeddings.py b/scripts/recompute_embeddings.py
new file mode 100644
index 0000000..884cc7b
--- /dev/null
+++ b/scripts/recompute_embeddings.py
@@ -0,0 +1,65 @@
+import json
+import os
+import sys
+import numpy as np
+import torch
+from tqdm import tqdm
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
+CARDS_FILE = "data/corpora/memory_cards.jsonl"
+EMBEDDINGS_FILE = "data/corpora/memory_embeddings.npy"
+
+def recompute_embeddings():
+ if not os.path.exists(CARDS_FILE):
+ print(f"Error: {CARDS_FILE} not found.")
+ return
+
+ print("Loading configuration and model...")
+ cfg = load_local_models_config()
+ embed_model = Qwen3Embedding8B.from_config(cfg)
+
+ print(f"Reading memory cards from {CARDS_FILE}...")
+ cards = []
+ texts = []
+ with open(CARDS_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if not line.strip(): continue
+ card = MemoryCard.model_validate_json(line)
+ cards.append(card)
+ # Embed note_text (the distilled preference); this is the field retrieval queries against.
+ texts.append(card.note_text)
+
+ print(f"Total cards: {len(cards)}")
+
+ if not cards:
+ print("No cards found.")
+ return
+
+ print("Computing embeddings...")
+ # Batch processing
+ batch_size = 32
+ all_embs = []
+
+ for i in tqdm(range(0, len(texts), batch_size)):
+ batch_texts = texts[i : i + batch_size]
+ # Qwen3Embedding8B.encode returns list of lists (if return_tensor=False)
+ embs = embed_model.encode(batch_texts, return_tensor=False)
+ all_embs.extend(embs)
+
+ emb_array = np.array(all_embs, dtype=np.float32)
+ print(f"Embeddings shape: {emb_array.shape}")
+
+ print(f"Saving to {EMBEDDINGS_FILE}...")
+ np.save(EMBEDDINGS_FILE, emb_array)
+ print("Done!")
+
+if __name__ == "__main__":
+ recompute_embeddings()
+
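Rows in memory_embeddings.npy stay aligned with the lines of memory_cards.jsonl, so a quick retrieval sanity check needs only numpy. A sketch, assuming a query vector q produced by the same Qwen3Embedding8B encoder:

import numpy as np

emb = np.load("data/corpora/memory_embeddings.npy")          # shape: (num_cards, dim)
emb_norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize rows

def top_k(q: np.ndarray, k: int = 5) -> list[tuple[int, float]]:
    # Return (row_index, cosine_score) pairs; row i corresponds to line i of memory_cards.jsonl.
    q = q / np.linalg.norm(q)
    scores = emb_norm @ q
    idx = np.argsort(-scores)[:k]
    return [(int(i), float(scores[i])) for i in idx]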
diff --git a/scripts/recover_and_merge.py b/scripts/recover_and_merge.py
new file mode 100644
index 0000000..b0f37f7
--- /dev/null
+++ b/scripts/recover_and_merge.py
@@ -0,0 +1,151 @@
+import json
+import os
+from openai import OpenAI
+from typing import Dict, Any
+
+# --- Configuration ---
+# 1. Main batch IDs (the ~340k successful results whose local copies were lost)
+MAIN_BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
+# 2. OASST1 Batch IDs (New)
+OASST1_BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
+OASST1_METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
+
+# The file we want to APPEND to (currently has 68k retry items)
+OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+
+# Original queries map for main batch reconstruction
+ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+
+def load_original_queries() -> Dict[str, Dict[str, Any]]:
+ print("Loading original queries map (Main)...")
+ mapping = {}
+ with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ if line.strip():
+ mapping[f"req_{idx}"] = json.loads(line)
+ return mapping
+
+def load_oasst1_metadata() -> Dict[str, Dict[str, Any]]:
+ print("Loading OASST1 metadata map...")
+ mapping = {}
+ if os.path.exists(OASST1_METADATA_FILE):
+ with open(OASST1_METADATA_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ item = json.loads(line)
+ mapping[item["custom_id"]] = item
+ return mapping
+
+def recover_and_merge():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ # Load Maps
+ main_query_map = load_original_queries()
+ oasst1_meta_map = load_oasst1_metadata()
+
+ # We will append to the existing file which holds the RETRY results.
+ # So we don't lose the 68k we just fixed.
+ print(f"Appending recovered data to {OUTPUT_FILE}...")
+
+ count_main = 0
+ count_oasst1 = 0
+
+ with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
+
+ # --- 1. Recover Main Batches ---
+ if os.path.exists(MAIN_BATCH_IDS_FILE):
+ with open(MAIN_BATCH_IDS_FILE, "r") as f:
+ main_ids = json.load(f)
+
+ print(f"\nRecovering {len(main_ids)} Main Batches...")
+ for b_id in main_ids:
+ try:
+ batch = client.batches.retrieve(b_id)
+ if batch.output_file_id:
+ print(f" Downloading {b_id} (Output: {batch.output_file_id})...")
+ content = client.files.content(batch.output_file_id).text
+
+ for line in content.splitlines():
+ if not line.strip(): continue
+ res = json.loads(line)
+ custom_id = res["custom_id"]
+
+ if res["response"]["status_code"] == 200:
+ try:
+ body = res["response"]["body"]
+ llm_content = body["choices"][0]["message"]["content"]
+ parsed_json = json.loads(llm_content)
+
+ original = main_query_map.get(custom_id)
+ if original:
+ record = {
+ "custom_id": custom_id,
+ "original_query": original["query"],
+ "source": original.get("source"),
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ count_main += 1
+ except Exception:
+ pass  # Skip records whose content is not valid JSON
+ except Exception as e:
+ print(f" Error {b_id}: {e}")
+
+ # --- 2. Retrieve OASST1 Batches ---
+ # User requested to skip OASST1 merge for now.
+ # if os.path.exists(OASST1_BATCH_IDS_FILE):
+ # with open(OASST1_BATCH_IDS_FILE, "r") as f:
+ # oasst_ids = json.load(f)
+
+ # print(f"\nRetrieving {len(oasst_ids)} OASST1 Batches...")
+ # for b_id in oasst_ids:
+ # try:
+ # batch = client.batches.retrieve(b_id)
+ # if batch.status == "completed" and batch.output_file_id:
+ # print(f" Downloading {b_id}...")
+ # content = client.files.content(batch.output_file_id).text
+
+ # for line in content.splitlines():
+ # if not line.strip(): continue
+ # res = json.loads(line)
+ # custom_id = res["custom_id"]
+
+ # if res["response"]["status_code"] == 200:
+ # try:
+ # body = res["response"]["body"]
+ # llm_content = body["choices"][0]["message"]["content"]
+ # parsed_json = json.loads(llm_content)
+
+ # meta = oasst1_meta_map.get(custom_id)
+ # if meta:
+ # record = {
+ # "custom_id": custom_id,
+ # "original_query": meta["original_query"],
+ # "source": "oasst1",
+ # "user_id": meta.get("user_id"), # Preserve User ID!
+ # "session_id": meta.get("session_id"),
+ # "extracted_json": parsed_json,
+ # "has_preference": len(parsed_json.get("preferences", [])) > 0
+ # }
+ # f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ # count_oasst1 += 1
+ # except:
+ # pass
+ # except Exception as e:
+ # print(f" Error {b_id}: {e}")
+
+ print("\n" + "="*50)
+ print("RECOVERY & MERGE COMPLETE")
+ print(f"Recovered Main: {count_main}")
+ print(f"New OASST1: {count_oasst1}")
+ print(f"Full dataset updated at: {OUTPUT_FILE}")
+ print("="*50)
+
+if __name__ == "__main__":
+ recover_and_merge()
+
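Because recover_and_merge.py appends, re-running it can duplicate records. A one-off de-duplication sketch keyed on custom_id (records that were written without one are kept as-is):

import json

path = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
seen, kept = set(), []

with open(path, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        rec = json.loads(line)
        cid = rec.get("custom_id")
        if cid is not None and cid in seen:
            continue  # drop duplicate recovery of the same request
        if cid is not None:
            seen.add(cid)
        kept.append(rec)

with open(path, "w", encoding="utf-8") as f:
    for rec in kept:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Kept {len(kept)} unique records")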
diff --git a/scripts/retrieve_batch_results.py b/scripts/retrieve_batch_results.py
new file mode 100644
index 0000000..aa26e28
--- /dev/null
+++ b/scripts/retrieve_batch_results.py
@@ -0,0 +1,151 @@
+import json
+import os
+import time
+from typing import Dict, Any, List, Set
+from openai import OpenAI
+
+# --- Configuration ---
+BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
+ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+RETRY_INPUT_FILE = "data/raw_datasets/retry_requests.jsonl"
+MODEL_NAME = "gpt-5.1" # Need this for reconstruction
+
+# Load System Prompt locally to avoid import errors
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ SYSTEM_PROMPT = f.read()
+
+def load_original_queries() -> Dict[str, Dict[str, Any]]:
+ print("Loading original queries map...")
+ mapping = {}
+ with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ if line.strip():
+ mapping[f"req_{idx}"] = json.loads(line)
+ return mapping
+
+def process_batch_results():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ if not os.path.exists(BATCH_IDS_FILE):
+ print(f"Error: {BATCH_IDS_FILE} not found.")
+ return
+
+ with open(BATCH_IDS_FILE, "r") as f:
+ batch_ids = json.load(f)
+
+ query_map = load_original_queries()
+ processed_ids: Set[str] = set()
+
+ # Earlier runs did not store custom_id in the output file, so there is no
+ # reliable way to de-duplicate against it. For this recovery run we simply
+ # overwrite OUTPUT_LABEL_FILE and rebuild it from the batch outputs.
+
+ print("Starting fresh download/processing (Overwriting output)...")
+
+ success_count = 0
+ fail_count = 0
+
+ with open(OUTPUT_LABEL_FILE, "w", encoding="utf-8") as f_success:
+ for b_id in batch_ids:
+ print(f"\nProcessing Batch {b_id}...")
+ try:
+ batch = client.batches.retrieve(b_id)
+
+ # 1. Output File (Success)
+ if batch.output_file_id:
+ print(f" Downloading output {batch.output_file_id}...")
+ content = client.files.content(batch.output_file_id).text
+
+ for line in content.splitlines():
+ if not line.strip(): continue
+ res = json.loads(line)
+ custom_id = res["custom_id"]
+
+ if res["response"]["status_code"] == 200:
+ try:
+ body = res["response"]["body"]
+ llm_content = body["choices"][0]["message"]["content"]
+ parsed_json = json.loads(llm_content)
+
+ original_item = query_map.get(custom_id)
+ if original_item:
+ record = {
+ "custom_id": custom_id, # Add this to help debug later
+ "original_query": original_item["query"],
+ "source": original_item.get("source"),
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_success.write(json.dumps(record, ensure_ascii=False) + "\n")
+ processed_ids.add(custom_id)
+ success_count += 1
+ except Exception as e:
+ print(f" Parse Error {custom_id}: {e}")
+ # Parse failures simply fall through to the retry pass below.
+ # Non-200 responses are also treated as failures: they never enter processed_ids.
+
+ # 2. Error File (Explicit Failures)
+ # We don't need to explicitly read error file to write retries,
+ # because we will do a global "Missing Check" at the end.
+ # But reading it helps debugging.
+ if batch.error_file_id:
+ print(f" Downloading ERROR {batch.error_file_id}...")
+ # Just print count
+ # content = client.files.content(batch.error_file_id).text
+ # print(f" Found {len(content.splitlines())} errors in error file.")
+
+ except Exception as e:
+ print(f" CRITICAL ERROR processing batch {b_id}: {e}")
+
+ # --- Missing Check & Retry Generation ---
+ print(f"\nVerifying completeness... (Total Queries: {len(query_map)})")
+ print(f"Successful processed: {len(processed_ids)}")
+
+ with open(RETRY_INPUT_FILE, "w", encoding="utf-8") as f_retry:
+ for custom_id, original_item in query_map.items():
+ if custom_id not in processed_ids:
+ fail_count += 1
+
+ # Reconstruct Request
+ request_obj = {
+ "custom_id": custom_id,
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": MODEL_NAME,
+ "messages": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": original_item["query"]}
+ ],
+ "temperature": 0.0,
+ "response_format": {"type": "json_object"}
+ }
+ }
+ f_retry.write(json.dumps(request_obj) + "\n")
+
+ print("\n" + "="*50)
+ print(f"Processing Complete.")
+ print(f"Successful: {success_count} (Saved to {OUTPUT_LABEL_FILE})")
+ print(f"To Retry: {fail_count} (Saved to {RETRY_INPUT_FILE})")
+ print("="*50)
+
+if __name__ == "__main__":
+ process_batch_results()
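After a run, a quick coverage check confirms that every req_{idx} ended up either in the labeled output or in the retry file; the sketch below assumes the custom_id field written by the script above is present in both files:

import json

def ids_from(path: str) -> set[str]:
    # Collect the custom_id of every JSONL record in the file.
    out: set[str] = set()
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                out.add(json.loads(line)["custom_id"])
    return out

labeled = ids_from("data/raw_datasets/labeled_full_dataset_batch.jsonl")
retries = ids_from("data/raw_datasets/retry_requests.jsonl")

with open("data/raw_datasets/combined_raw_queries.jsonl", "r", encoding="utf-8") as f:
    total = sum(1 for line in f if line.strip())

print(f"labeled={len(labeled)}, retry={len(retries)}, overlap={len(labeled & retries)}")
print(f"covered={len(labeled | retries)} of {total} original queries")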
diff --git a/scripts/retrieve_oasst1.py b/scripts/retrieve_oasst1.py
new file mode 100644
index 0000000..436d329
--- /dev/null
+++ b/scripts/retrieve_oasst1.py
@@ -0,0 +1,96 @@
+import json
+import os
+from openai import OpenAI
+from typing import Dict, Any
+
+# --- Configuration ---
+BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
+METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
+# Store independently for Memory/User Modeling initialization
+OUTPUT_FILE = "data/corpora/oasst1_labeled.jsonl"
+
+def load_metadata() -> Dict[str, Dict[str, Any]]:
+ print("Loading OASST1 metadata map...")
+ mapping = {}
+ if os.path.exists(METADATA_FILE):
+ with open(METADATA_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ item = json.loads(line)
+ mapping[item["custom_id"]] = item
+ return mapping
+
+def retrieve_oasst1():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ if not os.path.exists(BATCH_IDS_FILE):
+ print(f"Error: {BATCH_IDS_FILE} not found.")
+ return
+
+ with open(BATCH_IDS_FILE, "r") as f:
+ batch_ids = json.load(f)
+
+ meta_map = load_metadata()
+ count_success = 0
+ count_fail = 0
+
+ print(f"Appending OASST1 results to {OUTPUT_FILE}...")
+
+ with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
+ for b_id in batch_ids:
+ print(f"\nProcessing Batch {b_id}...")
+ try:
+ batch = client.batches.retrieve(b_id)
+ if batch.output_file_id:
+ print(f" Downloading output {batch.output_file_id}...")
+ content = client.files.content(batch.output_file_id).text
+
+ for line in content.splitlines():
+ if not line.strip(): continue
+ res = json.loads(line)
+ custom_id = res["custom_id"]
+
+ if res["response"]["status_code"] == 200:
+ try:
+ body = res["response"]["body"]
+ llm_content = body["choices"][0]["message"]["content"]
+ parsed_json = json.loads(llm_content)
+
+ meta = meta_map.get(custom_id)
+ if meta:
+ record = {
+ "custom_id": custom_id,
+ "original_query": meta["original_query"],
+ "source": "oasst1",
+ "user_id": meta.get("user_id"),
+ "session_id": meta.get("session_id"),
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ count_success += 1
+ else:
+ # Fallback if metadata missing (unlikely)
+ print(f"Warning: Metadata missing for {custom_id}")
+ except Exception as e:
+ print(f"Parse error {custom_id}: {e}")
+ count_fail += 1
+ else:
+ count_fail += 1
+ except Exception as e:
+ print(f"Error checking batch {b_id}: {e}")
+
+ print("\n" + "="*50)
+ print("OASST1 RETRIEVAL COMPLETE")
+ print(f"Successfully processed: {count_success}")
+ print(f"Failed/Parse Error: {count_fail}")
+ print(f"Full dataset updated at: {OUTPUT_FILE}")
+ print("="*50)
+
+if __name__ == "__main__":
+ retrieve_oasst1()
+
diff --git a/scripts/retrieve_synthesis.py b/scripts/retrieve_synthesis.py
new file mode 100644
index 0000000..cbc4573
--- /dev/null
+++ b/scripts/retrieve_synthesis.py
@@ -0,0 +1,118 @@
+import json
+import os
+from openai import OpenAI
+from typing import Dict, Any
+
+# --- Configuration ---
+BATCH_IDS_FILE = "data/raw_datasets/submitted_synthesis_batch_ids.json"
+SEED_FILE = "data/raw_datasets/positive_seeds.jsonl"
+# Where to save the new synthesized records
+OUTPUT_FILE = "data/raw_datasets/synthesized_positives.jsonl"
+
+def load_seeds() -> Dict[str, Dict[str, Any]]:
+ print("Loading seeds map...")
+ mapping = {}
+ with open(SEED_FILE, "r", encoding="utf-8") as f:
+ # Map each seed's custom_id back to its full record so synthesized rewrites can
+ # inherit the ground-truth preferences. submit_synthesis_batch.py prefixed request
+ # IDs with "syn_", so stripping that prefix recovers the original seed ID.
+ # positive_seeds.jsonl stores the full record, including 'extracted_json'.
+ for idx, line in enumerate(f):
+ if line.strip():
+ item = json.loads(line)
+ # Records produced by retrieve_batch_results.py carry a custom_id;
+ # otherwise fall back to "seed_{idx}", the ID used at submission time.
+ cid = item.get("custom_id")
+ if not cid:
+ # Fallback if custom_id missing (e.g. from some older process)
+ # We generated "seed_{i}" in submit script.
+ cid = f"seed_{idx}"
+
+ mapping[cid] = item
+ return mapping
+
+def retrieve_synthesis():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ if not os.path.exists(BATCH_IDS_FILE):
+ print(f"Error: {BATCH_IDS_FILE} not found.")
+ return
+
+ with open(BATCH_IDS_FILE, "r") as f:
+ batch_ids = json.load(f)
+
+ seed_map = load_seeds()
+ count_rewrites = 0
+ count_source_seeds = 0
+
+ print(f"Processing Synthesis Batches -> {OUTPUT_FILE}...")
+
+ with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out:
+ for b_id in batch_ids:
+ print(f"\nProcessing Batch {b_id}...")
+ try:
+ batch = client.batches.retrieve(b_id)
+ if batch.output_file_id:
+ print(f" Downloading output {batch.output_file_id}...")
+ content = client.files.content(batch.output_file_id).text
+
+ for line in content.splitlines():
+ if not line.strip(): continue
+ res = json.loads(line)
+ syn_id = res["custom_id"] # e.g. "syn_req_123"
+
+ # Derive original seed ID: remove "syn_" prefix
+ if syn_id.startswith("syn_"):
+ orig_id = syn_id[4:]
+ else:
+ orig_id = syn_id
+
+ if res["response"]["status_code"] == 200:
+ try:
+ body = res["response"]["body"]
+ llm_content = body["choices"][0]["message"]["content"]
+ parsed_json = json.loads(llm_content)
+
+ rewrites = parsed_json.get("rewrites", [])
+ if not rewrites:
+ continue
+
+ # Find original preference to inherit
+ seed = seed_map.get(orig_id)
+ if seed:
+ prefs = seed.get("extracted_json")
+ # Create new records
+ for rw in rewrites:
+ new_record = {
+ "original_query": rw,
+ "source": "synthesis_gpt4o",
+ "parent_id": orig_id,
+ "extracted_json": prefs, # INHERIT PREFERENCE
+ "has_preference": True
+ }
+ f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
+ count_rewrites += 1
+ count_source_seeds += 1
+ else:
+ # print(f"Warning: Seed {orig_id} not found in map")
+ pass
+ except Exception as e:
+ print(f"Parse error {syn_id}: {e}")
+ except Exception as e:
+ print(f"Error checking batch {b_id}: {e}")
+
+ print("\n" + "="*50)
+ print("SYNTHESIS RETRIEVAL COMPLETE")
+ print(f"Processed Source Seeds: {count_source_seeds}")
+ print(f"Generated New Samples: {count_rewrites}")
+ print(f"Saved to: {OUTPUT_FILE}")
+ print("="*50)
+
+if __name__ == "__main__":
+ retrieve_synthesis()
+
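The synthesized records share the schema of the labeled dataset, so folding them in is a straight concatenation; provenance is already preserved via the source field written above. A sketch, with a hypothetical merged-file name:

import json

MAIN = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
SYN = "data/raw_datasets/synthesized_positives.jsonl"
MERGED = "data/raw_datasets/labeled_plus_synth.jsonl"  # hypothetical output path

with open(MERGED, "w", encoding="utf-8") as out:
    for path in (MAIN, SYN):
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    out.write(line if line.endswith("\n") else line + "\n")

print(f"Merged dataset written to {MERGED}")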
diff --git a/scripts/run_putnam_evaluation.py b/scripts/run_putnam_evaluation.py
new file mode 100644
index 0000000..f320eea
--- /dev/null
+++ b/scripts/run_putnam_evaluation.py
@@ -0,0 +1,164 @@
+import json
+import os
+import glob
+import argparse
+from typing import List, Dict, Any
+from openai import OpenAI
+
+# Configuration
+DATA_DIR = "LLaMA-Factory/preprocess/PutnamGAP"
+OUTPUT_DIR = "data/putnam_eval"
+OUTPUT_FILENAME = "putnam_eval_batch.jsonl"
+MODEL_NAME = "gpt-5" # User requested gpt-5
+
+SYSTEM_PROMPT = """You are an expert mathematician and a judge for math competitions. You are given an original math problem (and its solution) and a "kernel variant" of that problem (and its solution).
+
+Your task is to:
+1. Evaluate the correctness of the kernel variant. Is the problem statement mathematically sound and clear? Is the provided solution correct?
+2. Evaluate the relationship between the original problem and the kernel variant. Are they mathematically equivalent? Or is the variant a strong abstraction/generalization/simplification of the original? Do they test the same core concepts?
+
+Output your analysis in the following JSON format:
+{
+ "variant_validity": {
+ "is_problem_valid": boolean,
+ "is_solution_correct": boolean,
+ "comments": "string"
+ },
+ "relation_to_original": {
+ "is_equivalent": boolean,
+ "is_strongly_related": boolean,
+ "relationship_description": "string"
+ }
+}"""
+
+def load_dataset(data_dir: str) -> List[Dict[str, Any]]:
+ files = glob.glob(os.path.join(data_dir, "*.json"))
+ items = []
+ print(f"Scanning {len(files)} files in {data_dir}...")
+ for fpath in files:
+ try:
+ with open(fpath, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ # Check for required fields
+ if "variants" not in data or "kernel_variant" not in data["variants"]:
+ continue
+
+ orig_q = data.get("question", "")
+ orig_s = data.get("solution", "")
+ kv = data["variants"]["kernel_variant"]
+ kv_q = kv.get("question", "")
+ kv_s = kv.get("solution", "")
+
+ if not kv_q:
+ continue
+
+ items.append({
+ "id": data.get("index", os.path.basename(fpath)),
+ "original_question": orig_q,
+ "original_solution": orig_s,
+ "kernel_variant_question": kv_q,
+ "kernel_variant_solution": kv_s
+ })
+ except Exception as e:
+ print(f"Error reading {fpath}: {e}")
+ return items
+
+def create_batch_file(items: List[Dict[str, Any]], output_path: str):
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ count = 0
+ with open(output_path, "w", encoding="utf-8") as f:
+ for item in items:
+ user_content = f"""[Original Problem]
+{item['original_question']}
+
+[Original Solution]
+{item['original_solution']}
+
+[Kernel Variant Problem]
+{item['kernel_variant_question']}
+
+[Kernel Variant Solution]
+{item['kernel_variant_solution']}"""
+
+ # Construct request
+ request_obj = {
+ "custom_id": f"req_{item['id']}",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": MODEL_NAME,
+ "messages": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": user_content}
+ ],
+ "response_format": {"type": "json_object"}
+ }
+ }
+ f.write(json.dumps(request_obj) + "\n")
+ count += 1
+
+ print(f"Created batch file at {output_path} with {count} requests.")
+ return count
+
+def submit_batch(file_path: str):
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set. Cannot submit.")
+ return
+
+ client = OpenAI(api_key=api_key)
+
+ print(f"Uploading {file_path} to OpenAI...")
+ with open(file_path, "rb") as f:
+ batch_file_obj = client.files.create(
+ file=f,
+ purpose="batch"
+ )
+ file_id = batch_file_obj.id
+ print(f"Uploaded. File ID: {file_id}")
+
+ print("Submitting Batch Job...")
+ batch_job = client.batches.create(
+ input_file_id=file_id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={
+ "description": "PutnamGAP Evaluation"
+ }
+ )
+ print(f"Submitted. Batch ID: {batch_job.id}")
+
+ # Save Batch ID
+ id_file = os.path.join(os.path.dirname(file_path), "submitted_batch_ids.json")
+ existing_ids = []
+ if os.path.exists(id_file):
+ try:
+ with open(id_file, "r") as f:
+ existing_ids = json.load(f)
+ except:
+ pass
+ existing_ids.append(batch_job.id)
+ with open(id_file, "w") as f:
+ json.dump(existing_ids, f, indent=2)
+ print(f"Batch ID saved to {id_file}")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Prepare and optionally submit PutnamGAP evaluation batch.")
+ parser.add_argument("--submit", action="store_true", help="Submit the batch to OpenAI after generating.")
+ args = parser.parse_args()
+
+ items = load_dataset(DATA_DIR)
+ print(f"Found {len(items)} items with kernel variants.")
+
+ output_path = os.path.join(OUTPUT_DIR, OUTPUT_FILENAME)
+ if items:
+ create_batch_file(items, output_path)
+ if args.submit:
+ submit_batch(output_path)
+ else:
+ print("Use --submit to submit the batch to OpenAI.")
+ else:
+ print("No items found to process.")
+
diff --git a/scripts/run_server.py b/scripts/run_server.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/run_server.py
diff --git a/scripts/smoke_extractor_llm.py b/scripts/smoke_extractor_llm.py
new file mode 100644
index 0000000..b16d0e2
--- /dev/null
+++ b/scripts/smoke_extractor_llm.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+"""
+Smoke test for PreferenceExtractorLLM (Qwen3-0.6B).
+Requires 'saves/qwen3-0.6b-full-sft-h200/checkpoint-4358' to be present.
+"""
+
+import sys
+import os
+import json
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.registry import get_preference_extractor
+from personalization.retrieval.preference_store.schemas import ChatTurn
+
+def main():
+ print("Initializing Preference Extractor (qwen3_0_6b_sft)...")
+ try:
+ extractor = get_preference_extractor("qwen3_0_6b_sft")
+ except Exception as e:
+ print(f"Failed to load extractor: {e}")
+ print("Please check if the checkpoint exists at saves/qwen3-0.6b-full-sft-h200/checkpoint-4358")
+ print("and local_models.yaml is configured correctly.")
+ sys.exit(1)
+
+ print("Extractor loaded successfully.")
+
+ # Construct dummy conversation
+ turns = [
+ ChatTurn(user_id="u1", session_id="s1", turn_id=0, role="user", text="Hi, I am learning Python. Please always use Python 3.11 in your code examples."),
+ ChatTurn(user_id="u1", session_id="s1", turn_id=1, role="assistant", text="Hello! Python is a great language. How can I help?"),
+ ChatTurn(user_id="u1", session_id="s1", turn_id=2, role="user", text="Please explain lists. And btw, always use snake_case for variables in your code examples."),
+ ]
+
+ print("\n--- Input Turns ---")
+ for t in turns:
+ print(f"[{t.role}]: {t.text}")
+
+ print("\n--- Extracting ---")
+ prefs = extractor.extract_turn(turns)
+
+ print("\n--- Output PreferenceList ---")
+ print(prefs.model_dump_json(indent=2))
+
+ # Validation
+ if prefs.preferences:
+ print("\nSUCCESS: Extracted preferences found.")
+ else:
+ print("\nWARNING: No preferences extracted. (Model might need warming up or prompt adjustment)")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/smoke_llms.py b/scripts/smoke_llms.py
new file mode 100644
index 0000000..109020a
--- /dev/null
+++ b/scripts/smoke_llms.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from pathlib import Path
+import sys
+import json
+import os
+
+ROOT = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(ROOT / "src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.llm.qwen_instruct import QwenInstruct
+from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+
+
+def ensure_models_present() -> bool:
+ cfg = load_local_models_config()
+ ok = True
+ for spec in (cfg.llm, cfg.preference_extractor):
+ path = ROOT / spec.local_path
+ if not path.exists() or not any(path.iterdir()):
+ print(f"[missing] {path} is empty. Run: python scripts/pull_models.py --target all")
+ ok = False
+ return ok
+
+
+def main() -> None:
+ if not ensure_models_present():
+ return
+
+ cfg = load_local_models_config()
+ os.environ["PREF_DEBUG"] = "1"
+ llm = QwenInstruct.from_config(cfg)
+ extractor = QwenRuleExtractor.from_config(cfg)
+
+ print("[llm] generating...")
+ out = llm.generate("Say hello in one short sentence.", max_new_tokens=32, temperature=0.2)
+ print(out)
+
+ print("[extractor] extracting...")
+ scenarios = [
+ (
+ "math_latex",
+ "Consider the sequence defined by a_1 = 1 and a_{n+1} = a_n + 1/n for n >= 1. "
+ "(1) Prove that a_n diverges. (2) Derive an asymptotic expression for a_n in terms of the harmonic numbers H_n. "
+ "(3) Compute the limit of (a_n - ln n) as n -> infinity. Please use LaTeX for the output.",
+ ),
+ (
+ "code_python311",
+ "I have a performance bottleneck in my Python code that processes large CSV files. It reads rows, aggregates stats, and writes summaries. "
+ "Explain how to optimize I/O and memory, discuss multiprocessing vs async. When you show code, please use Python 3.11 syntax and include type hints in the snippets.",
+ ),
+ (
+ "data_json_only",
+ "Given a dataset of user events with timestamps, device types, and regions, outline steps to compute DAU, WAU, and retention. "
+ "List pitfalls and how to handle missing data. Return your final answer as JSON only.",
+ ),
+ (
+ "writing_concise_no_emoji",
+ "Explain the difference between supervised and reinforcement learning with practical examples and cautions. "
+ "Keep answers concise and avoid emojis.",
+ ),
+ ]
+ for name, query in scenarios:
+ print(f"\n[scenario] {name}")
+ prefs = extractor.extract_preferences(query)
+ print(json.dumps(prefs, indent=2, ensure_ascii=False))
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/scripts/split_train_test.py b/scripts/split_train_test.py
new file mode 100644
index 0000000..ccb4cb1
--- /dev/null
+++ b/scripts/split_train_test.py
@@ -0,0 +1,76 @@
+import json
+import os
+import random
+
+INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl"
+TRAIN_FILE = "data/finetune/train_llama_factory.json"
+TEST_FILE = "data/finetune/test_llama_factory.json"
+TEST_SIZE = 1000
+
+SYSTEM_INSTRUCTION = (
+ "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
+ "If no preferences are found, return {\"preferences\": []}."
+)
+
+def split_and_convert():
+ if not os.path.exists(INPUT_FILE):
+ print(f"Error: {INPUT_FILE} not found.")
+ return
+
+ print(f"Reading {INPUT_FILE}...")
+ all_data = []
+
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ item = json.loads(line)
+ # Convert to LLaMA-Factory format immediately
+ record = {
+ "instruction": SYSTEM_INSTRUCTION,
+ "input": item["input"],
+ "output": item["output"]
+ }
+ all_data.append(record)
+
+ print(f"Total records: {len(all_data)}")
+
+ # Shuffle
+ random.seed(42) # Fixed seed for reproducibility
+ random.shuffle(all_data)
+
+ # Split
+ test_data = all_data[:TEST_SIZE]
+ train_data = all_data[TEST_SIZE:]
+
+ print(f"Train size: {len(train_data)}")
+ print(f"Test size: {len(test_data)}")
+
+ # Save Train
+ print(f"Saving train set to {TRAIN_FILE}...")
+ with open(TRAIN_FILE, "w", encoding="utf-8") as f:
+ json.dump(train_data, f, indent=2, ensure_ascii=False)
+
+ # Save Test
+ print(f"Saving test set to {TEST_FILE}...")
+ with open(TEST_FILE, "w", encoding="utf-8") as f:
+ json.dump(test_data, f, indent=2, ensure_ascii=False)
+
+ print("Done!")
+
+ # Update dataset_info advice
+ print("\nUpdate dataset_info.json with:")
+ info = {
+ "preference_extractor_train": {
+ "file_name": "train_llama_factory.json",
+ "columns": {"prompt": "instruction", "query": "input", "response": "output"}
+ },
+ "preference_extractor_test": {
+ "file_name": "test_llama_factory.json",
+ "columns": {"prompt": "instruction", "query": "input", "response": "output"}
+ }
+ }
+ print(json.dumps(info, indent=2))
+
+if __name__ == "__main__":
+ split_and_convert()
+
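A small verification sketch for the converted files: each LLaMA-Factory record should carry instruction/input/output keys, and (per the system instruction above) each output should itself be a JSON string following the PreferenceList schema.

import json

def is_json(s: str) -> bool:
    # True if s parses as JSON; outputs are expected to be PreferenceList JSON strings.
    try:
        json.loads(s)
        return True
    except json.JSONDecodeError:
        return False

for path in ("data/finetune/train_llama_factory.json", "data/finetune/test_llama_factory.json"):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    assert all({"instruction", "input", "output"} <= set(rec) for rec in data)
    bad = sum(1 for rec in data if not is_json(rec["output"]))
    print(f"{path}: {len(data)} records, {bad} non-JSON outputs")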
diff --git a/scripts/stats_and_extract.py b/scripts/stats_and_extract.py
new file mode 100644
index 0000000..402ade7
--- /dev/null
+++ b/scripts/stats_and_extract.py
@@ -0,0 +1,56 @@
+import json
+import os
+
+INPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+OUTPUT_POS_FILE = "data/raw_datasets/positive_seeds.jsonl"
+
+def extract_and_stats():
+ if not os.path.exists(INPUT_FILE):
+ print(f"Error: {INPUT_FILE} not found.")
+ return
+
+ print(f"Scanning {INPUT_FILE}...")
+
+ total = 0
+ pos_count = 0
+ neg_count = 0
+
+ # Optional: Track distribution of preference types/keys if needed
+
+ with open(INPUT_FILE, "r", encoding="utf-8") as f_in, \
+ open(OUTPUT_POS_FILE, "w", encoding="utf-8") as f_out:
+
+ for line in f_in:
+ if not line.strip(): continue
+ try:
+ item = json.loads(line)
+ total += 1
+
+ # Check if positive
+ # Our labeling script ensures 'has_preference' boolean,
+ # but let's double check the actual list to be safe.
+ prefs = item.get("extracted_json", {}).get("preferences", [])
+
+ if prefs and len(prefs) > 0:
+ pos_count += 1
+ f_out.write(line)
+ else:
+ neg_count += 1
+ except Exception:
+ pass  # Skip malformed lines
+
+ ratio = (pos_count / total * 100) if total > 0 else 0
+
+ print("\n" + "="*30)
+ print("DATASET STATISTICS")
+ print("="*30)
+ print(f"Total Rows: {total}")
+ print(f"Positive Rows: {pos_count} ({ratio:.2f}%)")
+ print(f"Negative Rows: {neg_count}")
+ print("-" * 30)
+ print(f"Positive seeds saved to: {OUTPUT_POS_FILE}")
+ print("="*30)
+
+if __name__ == "__main__":
+ extract_and_stats()
+
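The script above only counts positive rows. A follow-up sketch over the extracted positive seeds that looks at how many preferences each positive query carries, using only the length of the preferences list (the per-item schema is defined elsewhere):

import json
from collections import Counter

dist = Counter()
with open("data/raw_datasets/positive_seeds.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        item = json.loads(line)
        prefs = item.get("extracted_json", {}).get("preferences", [])
        dist[len(prefs)] += 1

for n, count in sorted(dist.items()):
    print(f"{n} preference(s): {count} queries")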
diff --git a/scripts/submit_batch.py b/scripts/submit_batch.py
new file mode 100644
index 0000000..e848dc5
--- /dev/null
+++ b/scripts/submit_batch.py
@@ -0,0 +1,111 @@
+import json
+import os
+import time
+from typing import List
+from openai import OpenAI
+
+# --- Configuration ---
+INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+BATCH_DIR = "data/raw_datasets/batch_files"
+MODEL_NAME = "gpt-5.1" # Or "gpt-4o"
+BATCH_SIZE_LIMIT = 49000 # Safe under 50k limit
+
+# --- Load System Prompt ---
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ SYSTEM_PROMPT = f.read()
+
+def prepare_and_submit_batches():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ os.makedirs(BATCH_DIR, exist_ok=True)
+
+ print(f"Reading from {INPUT_FILE}...")
+
+ # Read all lines first
+ all_lines = []
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ all_lines.append(json.loads(line))
+
+ total_items = len(all_lines)
+ print(f"Total items: {total_items}")
+
+ batch_ids = []
+
+ # Split and Process
+ for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
+ chunk = all_lines[i : i + BATCH_SIZE_LIMIT]
+ chunk_filename = os.path.join(BATCH_DIR, f"batch_input_part_{batch_idx}.jsonl")
+
+ print(f"\n--- Processing Batch {batch_idx} ({len(chunk)} items) ---")
+
+ # 1. Create File
+ with open(chunk_filename, "w", encoding="utf-8") as f_out:
+ for item_idx, item in enumerate(chunk):
+ # Global index to track back later if needed
+ global_idx = i + item_idx
+ query = item["query"]
+
+ # Custom ID: "req_{global_index}"
+ custom_id = f"req_{global_idx}"
+
+ request_obj = {
+ "custom_id": custom_id,
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": MODEL_NAME,
+ "messages": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ],
+ "temperature": 0.0,
+ "response_format": {"type": "json_object"}
+ }
+ }
+ f_out.write(json.dumps(request_obj) + "\n")
+
+ print(f"File created: {chunk_filename}")
+
+ # 2. Upload File
+ print("Uploading to OpenAI...")
+ batch_file_obj = client.files.create(
+ file=open(chunk_filename, "rb"),
+ purpose="batch"
+ )
+ file_id = batch_file_obj.id
+ print(f"Uploaded. File ID: {file_id}")
+
+ # 3. Submit Batch
+ print("Submitting Batch Job...")
+ batch_job = client.batches.create(
+ input_file_id=file_id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={
+ "description": f"Pers. Extractor Part {batch_idx}",
+ "part_index": str(batch_idx)
+ }
+ )
+ print(f"Submitted. Batch ID: {batch_job.id}")
+ batch_ids.append(batch_job.id)
+
+ # Sleep briefly to be nice to API
+ time.sleep(1)
+
+ # Save all Batch IDs
+ id_file = "data/raw_datasets/submitted_batch_ids.json"
+ with open(id_file, "w") as f:
+ json.dump(batch_ids, f, indent=2)
+
+ print(f"\nALL DONE! Submitted {len(batch_ids)} batches.")
+ print(f"Batch IDs saved to {id_file}")
+ print("Run scripts/check_batch_status.py (you need to write it) to monitor.")
+
+if __name__ == "__main__":
+ prepare_and_submit_batches()
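The repository ships scripts/check_batch_status.py for monitoring; the sketch below is not that script's contents, just a minimal one-shot status listing over the saved IDs using the same client.batches.retrieve call:

import json
from openai import OpenAI

client = OpenAI()  # expects OPENAI_API_KEY in the environment

with open("data/raw_datasets/submitted_batch_ids.json", "r") as f:
    batch_ids = json.load(f)

for b_id in batch_ids:
    b = client.batches.retrieve(b_id)
    counts = b.request_counts
    print(f"{b_id}: {b.status} "
          f"(completed={counts.completed}, failed={counts.failed}, total={counts.total})")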
diff --git a/scripts/submit_oasst1_batch.py b/scripts/submit_oasst1_batch.py
new file mode 100644
index 0000000..1a96dd0
--- /dev/null
+++ b/scripts/submit_oasst1_batch.py
@@ -0,0 +1,120 @@
+import json
+import os
+import time
+from openai import OpenAI
+
+# --- Configuration ---
+INPUT_FILE = "data/raw_datasets/oasst1_queries.jsonl"
+BATCH_DIR = "data/raw_datasets/batch_files_oasst1"
+METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
+MODEL_NAME = "gpt-5.1"
+BATCH_SIZE_LIMIT = 49000
+
+# --- Load System Prompt ---
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ SYSTEM_PROMPT = f.read()
+
+def submit_oasst1_batch():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ os.makedirs(BATCH_DIR, exist_ok=True)
+
+ print(f"Reading from {INPUT_FILE}...")
+
+ all_lines = []
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ all_lines.append(json.loads(line))
+
+ total_items = len(all_lines)
+ print(f"Total OASST1 items: {total_items}")
+
+ # 1. Generate Metadata Map first
+ # This ensures we have the mapping even if batch submission fails mid-way
+ print(f"Generating metadata map to {METADATA_FILE}...")
+ with open(METADATA_FILE, "w", encoding="utf-8") as f_meta:
+ for idx, item in enumerate(all_lines):
+ custom_id = f"oasst1_req_{idx}"
+ meta_record = {
+ "custom_id": custom_id,
+ "user_id": item.get("user_id"),
+ "session_id": item.get("session_id"),
+ "turn_id": item.get("turn_id"),
+ "original_query": item.get("original_query") or item.get("query")
+ }
+ f_meta.write(json.dumps(meta_record, ensure_ascii=False) + "\n")
+
+ # Store custom_id back to item list for batch generation
+ item["_temp_custom_id"] = custom_id
+
+ # 2. Split and Submit
+ batch_ids = []
+
+ for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
+ chunk = all_lines[i : i + BATCH_SIZE_LIMIT]
+ chunk_filename = os.path.join(BATCH_DIR, f"oasst1_batch_part_{batch_idx}.jsonl")
+
+ print(f"\n--- Processing OASST1 Batch {batch_idx} ({len(chunk)} items) ---")
+
+ with open(chunk_filename, "w", encoding="utf-8") as f_out:
+ for item in chunk:
+ custom_id = item["_temp_custom_id"]
+ query = item.get("original_query") or item.get("query")
+
+ request_obj = {
+ "custom_id": custom_id,
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": MODEL_NAME,
+ "messages": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ],
+ "temperature": 0.0,
+ "response_format": {"type": "json_object"}
+ }
+ }
+ f_out.write(json.dumps(request_obj) + "\n")
+
+ print(f"File created: {chunk_filename}")
+
+ print("Uploading to OpenAI...")
+ batch_file_obj = client.files.create(
+ file=open(chunk_filename, "rb"),
+ purpose="batch"
+ )
+ file_id = batch_file_obj.id
+ print(f"Uploaded. File ID: {file_id}")
+
+ print("Submitting Batch Job...")
+ batch_job = client.batches.create(
+ input_file_id=file_id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={
+ "description": f"Pers. Extractor OASST1 Part {batch_idx}",
+ "dataset": "oasst1"
+ }
+ )
+ print(f"Submitted. Batch ID: {batch_job.id}")
+ batch_ids.append(batch_job.id)
+
+ time.sleep(1)
+
+ id_file = "data/raw_datasets/submitted_oasst1_batch_ids.json"
+ with open(id_file, "w") as f:
+ json.dump(batch_ids, f, indent=2)
+
+ print(f"\nALL DONE! Submitted {len(batch_ids)} OASST1 batches.")
+ print(f"Metadata saved to {METADATA_FILE}")
+ print(f"Batch IDs saved to {id_file}")
+
+if __name__ == "__main__":
+ submit_oasst1_batch()
+
diff --git a/scripts/submit_retry_batch.py b/scripts/submit_retry_batch.py
new file mode 100644
index 0000000..f564c4b
--- /dev/null
+++ b/scripts/submit_retry_batch.py
@@ -0,0 +1,88 @@
+import json
+import os
+import time
+from openai import OpenAI
+
+# --- Configuration ---
+RETRY_INPUT_FILE = "data/raw_datasets/retry_requests.jsonl"
+BATCH_DIR = "data/raw_datasets/batch_files_retry"
+BATCH_SIZE_LIMIT = 10000  # Smaller chunks for retry batches
+
+def submit_retry_batches():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ os.makedirs(BATCH_DIR, exist_ok=True)
+
+ if not os.path.exists(RETRY_INPUT_FILE):
+ print(f"Error: {RETRY_INPUT_FILE} not found.")
+ return
+
+ print(f"Reading retry requests from {RETRY_INPUT_FILE}...")
+
+ all_lines = []
+ with open(RETRY_INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ all_lines.append(line.strip()) # Keep as string, no need to parse json
+
+ total_items = len(all_lines)
+ print(f"Total retry items: {total_items}")
+
+ batch_ids = []
+
+ # Split and Submit
+ for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
+ chunk = all_lines[i : i + BATCH_SIZE_LIMIT]
+ chunk_filename = os.path.join(BATCH_DIR, f"retry_batch_part_{batch_idx}.jsonl")
+
+ print(f"\n--- Processing Retry Batch {batch_idx} ({len(chunk)} items) ---")
+
+ # 1. Create File
+ with open(chunk_filename, "w", encoding="utf-8") as f_out:
+ for line in chunk:
+ f_out.write(line + "\n")
+
+ print(f"File created: {chunk_filename}")
+
+ # 2. Upload File
+ print("Uploading to OpenAI...")
+        with open(chunk_filename, "rb") as fh:
+            batch_file_obj = client.files.create(
+                file=fh,
+                purpose="batch"
+            )
+ file_id = batch_file_obj.id
+ print(f"Uploaded. File ID: {file_id}")
+
+ # 3. Submit Batch
+ print("Submitting Batch Job...")
+ batch_job = client.batches.create(
+ input_file_id=file_id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={
+ "description": f"Pers. Extractor RETRY Part {batch_idx}",
+ "retry": "true"
+ }
+ )
+ print(f"Submitted. Batch ID: {batch_job.id}")
+ batch_ids.append(batch_job.id)
+
+ time.sleep(1)
+
+    # Save retry batch IDs to a separate file so the original
+    # submitted_batch_ids.json is not overwritten.
+ id_file = "data/raw_datasets/submitted_retry_batch_ids.json"
+ with open(id_file, "w") as f:
+ json.dump(batch_ids, f, indent=2)
+
+ print(f"\nALL DONE! Submitted {len(batch_ids)} retry batches.")
+ print(f"Batch IDs saved to {id_file}")
+ print("Use scripts/check_retry_status.py (need to create/modify) to monitor.")
+
+if __name__ == "__main__":
+ submit_retry_batches()
+
diff --git a/scripts/submit_synthesis_batch.py b/scripts/submit_synthesis_batch.py
new file mode 100644
index 0000000..025782d
--- /dev/null
+++ b/scripts/submit_synthesis_batch.py
@@ -0,0 +1,131 @@
+import json
+import os
+from openai import OpenAI
+import time
+
+# --- Configuration ---
+INPUT_SEEDS = "data/raw_datasets/positive_seeds.jsonl"
+BATCH_DIR = "data/raw_datasets/batch_files_synthesis"
+MODEL_NAME = "gpt-5.1" # Or gpt-4o
+BATCH_SIZE_LIMIT = 30000 # 31k total, splitting into 2 files is safe
+
+SYNTHESIS_SYSTEM_PROMPT = """You are a data augmentation assistant.
+Your task is to rewrite a User Query that contains specific preferences into 5 different variations.
+The goal is to train a model to recognize these preferences in various contexts.
+
+Variations required:
+1. Formal/Polite: Use sophisticated language and polite markers.
+2. Casual/Direct: Use slang, abbreviations, or very direct commands.
+3. Implicit/Contextual: Embed the preference naturally within a larger context or story, making it harder to spot.
+4. Distractor-Heavy: Mix the preference with irrelevant information or another task.
+5. Imperative/Short: Extremely concise, almost robotic.
+
+Output strictly a JSON object with a single key "rewrites" containing a list of 5 strings.
+Example: {"rewrites": ["string1", "string2", "string3", "string4", "string5"]}
+"""
+
+def submit_synthesis_batch():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ os.makedirs(BATCH_DIR, exist_ok=True)
+
+ if not os.path.exists(INPUT_SEEDS):
+ print(f"Error: {INPUT_SEEDS} not found.")
+ return
+
+ print(f"Reading seeds from {INPUT_SEEDS}...")
+
+ seeds = []
+ with open(INPUT_SEEDS, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ seeds.append(json.loads(line))
+
+ total_items = len(seeds)
+ print(f"Total seeds: {total_items}")
+
+ batch_ids = []
+
+ # Split and Submit
+ for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
+ chunk = seeds[i : i + BATCH_SIZE_LIMIT]
+ chunk_filename = os.path.join(BATCH_DIR, f"synthesis_batch_part_{batch_idx}.jsonl")
+
+ print(f"\n--- Processing Synthesis Batch {batch_idx} ({len(chunk)} items) ---")
+
+ # 1. Create File
+ with open(chunk_filename, "w", encoding="utf-8") as f_out:
+            for offset, item in enumerate(chunk):
+                # Pass the original query to the rewriter. Including the extracted
+                # preference as well would help it preserve the core intent, but
+                # the query alone is sufficient here.
+                original_query = item["original_query"]
+
+                # Reuse the original custom_id when available (recovered batches
+                # usually carry one); otherwise fall back to the global seed index
+                # so ids stay unique across items and chunks.
+                base_id = item.get("custom_id", f"seed_{i + offset}")
+                custom_id = f"syn_{base_id}"  # Prefix marks synthesis requests
+
+                user_content = f"Original Query: {original_query}"
+                # Optionally append: f"\nCore Preference: {json.dumps(item['extracted_json'])}"
+
+ request_obj = {
+ "custom_id": custom_id,
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": MODEL_NAME,
+ "messages": [
+ {"role": "system", "content": SYNTHESIS_SYSTEM_PROMPT},
+ {"role": "user", "content": user_content}
+ ],
+ "temperature": 0.7, # Higher temp for diversity
+ "response_format": {"type": "json_object"}
+ }
+ }
+ f_out.write(json.dumps(request_obj) + "\n")
+
+ print(f"File created: {chunk_filename}")
+
+ # 2. Upload
+ print("Uploading to OpenAI...")
+        with open(chunk_filename, "rb") as fh:
+            batch_file_obj = client.files.create(
+                file=fh,
+                purpose="batch"
+            )
+ file_id = batch_file_obj.id
+ print(f"Uploaded. File ID: {file_id}")
+
+ # 3. Submit
+ print("Submitting Batch Job...")
+ batch_job = client.batches.create(
+ input_file_id=file_id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={
+ "description": f"Pers. Extractor Synthesis Part {batch_idx}",
+ "type": "synthesis"
+ }
+ )
+ print(f"Submitted. Batch ID: {batch_job.id}")
+ batch_ids.append(batch_job.id)
+
+ time.sleep(1)
+
+ id_file = "data/raw_datasets/submitted_synthesis_batch_ids.json"
+ with open(id_file, "w") as f:
+ json.dump(batch_ids, f, indent=2)
+
+ print(f"\nALL DONE! Submitted {len(batch_ids)} synthesis batches.")
+ print(f"Batch IDs saved to {id_file}")
+
+if __name__ == "__main__":
+ submit_synthesis_batch()
+
diff --git a/scripts/upload_to_hf.py b/scripts/upload_to_hf.py
new file mode 100644
index 0000000..3c41011
--- /dev/null
+++ b/scripts/upload_to_hf.py
@@ -0,0 +1,69 @@
+import os
+import argparse
+from huggingface_hub import HfApi, create_repo
+
+def main():
+ parser = argparse.ArgumentParser(description="Upload a checkpoint to Hugging Face Hub")
+ parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint directory")
+ parser.add_argument("--repo_id", type=str, required=True, help="Hugging Face repo ID (e.g., username/model-name)")
+ parser.add_argument("--token", type=str, help="Hugging Face token (optional if logged in via CLI)")
+ parser.add_argument("--private", action="store_true", help="Make the repository private")
+ parser.add_argument("--include-optimizer", action="store_true", help="Include optimizer states (optimizer.pt, scheduler.pt, etc.)")
+
+ args = parser.parse_args()
+
+ # Expand path
+ ckpt_path = os.path.abspath(args.ckpt_path)
+ if not os.path.exists(ckpt_path):
+ print(f"Error: Checkpoint path does not exist: {ckpt_path}")
+ return
+
+ print(f"Preparing to upload {ckpt_path} to {args.repo_id}...")
+
+ api = HfApi(token=args.token)
+
+ # Create repo if it doesn't exist
+ try:
+ print(f"Creating repository {args.repo_id} (if not exists)...")
+ create_repo(
+ repo_id=args.repo_id,
+ token=args.token,
+ private=args.private,
+ exist_ok=True,
+ repo_type="model"
+ )
+ except Exception as e:
+ print(f"Error creating repo: {e}")
+ print("Please check your token and permissions.")
+ return
+
+ # patterns to ignore if we don't want optimizer states
+ ignore_patterns = []
+ if not args.include_optimizer:
+ ignore_patterns = [
+ "optimizer.pt",
+ "scheduler.pt",
+ "rng_state_*.pth",
+ "trainer_state.json",
+ "training_args.bin"
+ ]
+ print("Excluding optimizer states from upload.")
+
+ print("Starting upload...")
+ try:
+ api.upload_folder(
+ folder_path=ckpt_path,
+ repo_id=args.repo_id,
+ repo_type="model",
+ ignore_patterns=ignore_patterns,
+ token=args.token
+ )
+ print(f"Successfully uploaded to https://huggingface.co/{args.repo_id}")
+ except Exception as e:
+ print(f"Upload failed: {e}")
+
+if __name__ == "__main__":
+ main()
+
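+# Example invocation (checkpoint path and repo id are placeholders):
+#   python scripts/upload_to_hf.py \
+#       --ckpt_path outputs/qwen3_0.6b_full_sft/checkpoint-500 \
+#       --repo_id username/qwen3-0.6b-pref-extractor --private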
+
+
diff --git a/src/personalization/__init__.py b/src/personalization/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/__init__.py
diff --git a/src/personalization/config/__init__.py b/src/personalization/config/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/config/__init__.py
diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py
new file mode 100644
index 0000000..d825ad3
--- /dev/null
+++ b/src/personalization/config/registry.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+import torch
+import yaml
+
+from personalization.config import settings
+
+# Avoid circular imports by NOT importing extractors here at top level
+# from personalization.models.preference_extractor.base import PreferenceExtractorBase
+# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM
+
+_DTYPE_MAP: Dict[str, torch.dtype] = {
+ "bfloat16": torch.bfloat16,
+ "float16": torch.float16,
+ "float32": torch.float32,
+}
+
+def choose_dtype(preferred: Optional[str] = None) -> torch.dtype:
+ if preferred and preferred.lower() in _DTYPE_MAP:
+ dt = _DTYPE_MAP[preferred.lower()]
+ else:
+ dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ if dt is torch.bfloat16 and not torch.cuda.is_available():
+ return torch.float32
+ return dt
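+
+# Example (illustrative): choose_dtype("float16") -> torch.float16;
+# choose_dtype(None) -> torch.bfloat16 when CUDA is available, else torch.float32.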
+
+def choose_device_map(spec: Optional[str] = "auto") -> Any:
+ return spec or "auto"
+
+def ensure_local_path(path_str: str) -> str:
+ path = Path(path_str)
+ if not path.exists():
+ path.mkdir(parents=True, exist_ok=True)
+ return str(path)
+
+# --- Chat Model Factory ---
+def get_chat_model(name: str, device_override: Optional[str] = None):
+ """
+ Get a chat model by name.
+
+ Args:
+ name: Model name (e.g., "qwen_1_5b", "llama_8b")
+ device_override: Optional device override (e.g., "cuda:2"). If None, uses config default.
+ """
+ from personalization.models.llm.base import ChatModel
+ from personalization.models.llm.qwen_instruct import QwenInstruct
+ from personalization.models.llm.llama_instruct import LlamaChatModel
+
+ cfg = settings.load_local_models_config()
+
+ # Try to load raw config to support multi-backend map
+ with open("configs/local_models.yaml", "r") as f:
+ raw_cfg = yaml.safe_load(f)
+
+ models = raw_cfg.get("models", {}).get("llm", {})
+
+ # If models['llm'] is a dict of configs (new style)
+ if isinstance(models, dict) and "backend" in models.get(name, {}):
+ spec = models[name]
+ backend = spec.get("backend", "qwen")
+ path = spec["path"]
+ device = device_override or spec.get("device", "cuda") # Use override if provided
+ dtype = spec.get("dtype", "bfloat16")
+ max_len = spec.get("max_context_length", 4096)
+
+ if backend == "qwen":
+ return QwenInstruct(
+ model_path=path,
+ device=device,
+ dtype=choose_dtype(dtype), # Converts string to torch.dtype
+ max_context_length=max_len
+ )
+ elif backend == "llama":
+ return LlamaChatModel(
+ model_path=path,
+ device=device,
+ dtype=choose_dtype(dtype), # Converts string to torch.dtype
+ max_context_length=max_len
+ )
+
+ # Fallback to legacy single config
+ return QwenInstruct.from_config(cfg)
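+
+# Assumed shape of the multi-backend section in configs/local_models.yaml
+# (model names, paths and devices below are placeholders):
+#
+#   models:
+#     llm:
+#       qwen_1_5b:
+#         backend: qwen
+#         path: /models/Qwen2.5-1.5B-Instruct
+#         device: cuda:0
+#         dtype: bfloat16
+#         max_context_length: 4096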
+
+def get_preference_extractor(name: Optional[str] = None):
+ # Deferred imports to break circular dependency
+ from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+ from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+ from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM
+
+ cfg = settings.load_local_models_config()
+ pref_cfg = cfg.preference_extractor
+
+ if name is None:
+ if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg:
+ name = "qwen3_0_6b_sft"
+ else:
+ name = "rule"
+
+ if isinstance(pref_cfg, dict) and name in pref_cfg:
+ spec = pref_cfg[name]
+ if name == "qwen3_0_6b_sft":
+ # Use QwenRuleExtractor which we have updated for SFT End-to-End logic
+ return QwenRuleExtractor(
+ model_path=spec["path"],
+ device_map=spec.get("device", "auto"),
+ dtype=choose_dtype(spec.get("dtype", "bfloat16")),
+ )
+        # Map the "default" entry onto the rule-based extractor branch below.
+        if name == "default":
+            name = "rule"
+
+ if name == "gpt4o":
+ return GPT4OExtractor.from_config(cfg)
+ elif name == "rule":
+ if isinstance(pref_cfg, dict):
+ if "default" in pref_cfg:
+ # Manually construct to bypass ModelSpec mismatch if needed
+ spec_dict = pref_cfg["default"]
+ return QwenRuleExtractor(
+ model_path=spec_dict["local_path"],
+ dtype=choose_dtype(spec_dict.get("dtype")),
+ device_map=choose_device_map(spec_dict.get("device_map"))
+ )
+ else:
+ return QwenRuleExtractor.from_config(cfg)
+
+ raise ValueError(f"Could not load preference extractor: {name}")
diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py
new file mode 100644
index 0000000..1bb1bbe
--- /dev/null
+++ b/src/personalization/config/settings.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Optional, Any, Dict
+
+import yaml
+from pydantic import BaseModel, Field
+
+
+class ModelSpec(BaseModel):
+ hf_id: str = Field(..., description="Hugging Face repository id")
+ local_path: str = Field(..., description="Local directory for model weights")
+ dtype: Optional[str] = Field(
+ default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32"
+ )
+ device_map: Optional[str] = Field(default="auto", description="Device map policy")
+
+
+class EmbeddingModelsConfig(BaseModel):
+ qwen3: Optional[ModelSpec] = None
+ nemotron: Optional[ModelSpec] = None
+
+
+class RerankerModelsConfig(BaseModel):
+ qwen3_8b: Optional[ModelSpec] = None
+
+
+class LocalModelsConfig(BaseModel):
+ llm: ModelSpec
+ preference_extractor: Any # Allow flexible dict or ModelSpec for now to support map
+ embedding: Optional[EmbeddingModelsConfig] = None
+ reranker: Optional[RerankerModelsConfig] = None
+
+
+def _resolve_config_path(env_key: str, default_rel: str) -> Path:
+ value = os.getenv(env_key)
+ if value:
+ return Path(value).expanduser().resolve()
+ return (Path.cwd() / default_rel).resolve()
+
+
+def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig:
+ config_path = Path(path) if path else _resolve_config_path(
+ "LOCAL_MODELS_CONFIG", "configs/local_models.yaml"
+ )
+ with open(config_path, "r", encoding="utf-8") as f:
+ raw = yaml.safe_load(f) or {}
+ models = raw.get("models", {})
+ embedding_cfg = None
+ if "embedding" in models:
+ emb = models["embedding"] or {}
+ # dtype/device_map are not necessary for embedders; ModelSpec still accepts them
+ embedding_cfg = EmbeddingModelsConfig(
+ qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None,
+ nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None,
+ )
+
+ reranker_cfg = None
+ if "reranker" in models:
+ rer = models["reranker"] or {}
+ reranker_cfg = RerankerModelsConfig(
+ qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if "qwen3_8b" in rer else None
+ )
+
+ return LocalModelsConfig(
+ llm=ModelSpec(**models["llm"]),
+ preference_extractor=models["preference_extractor"], # Pass raw dict/value
+ embedding=embedding_cfg,
+ reranker=reranker_cfg,
+ )
+
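+# Usage sketch (requires configs/local_models.yaml, or $LOCAL_MODELS_CONFIG pointing
+# at a valid file; attribute values come from whatever models are defined there):
+#   cfg = load_local_models_config()
+#   print(cfg.llm.hf_id, cfg.llm.local_path)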
+
diff --git a/src/personalization/data/personamem_loader.py b/src/personalization/data/personamem_loader.py
new file mode 100644
index 0000000..3b516ad
--- /dev/null
+++ b/src/personalization/data/personamem_loader.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import csv
+import json
+from dataclasses import dataclass
+from typing import Dict, List
+
+@dataclass
+class PersonaMemQuestion:
+ persona_id: str
+ question_id: str
+ question_type: str
+ topic: str
+ user_question_or_message: str
+ all_options: List[str] # 4 options
+ correct_index: int # 0..3
+ shared_context_id: str
+ end_index_in_shared_context: int
+
+@dataclass
+class PersonaMemContext:
+ shared_context_id: str
+ messages: List[dict] # raw dicts with "role"/"content" etc
+
+def load_personamem_questions_32k(path_csv: str) -> List[PersonaMemQuestion]:
+ questions = []
+ with open(path_csv, "r", encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+            # The official CSV typically provides question_id, persona_id,
+            # shared_context_id, the question text, correct_answer and the
+            # options; field names may need adjusting for other releases.
+ try:
+ options_str = row.get("all_options", "[]") # Assuming json string
+ try:
+ options = json.loads(options_str)
+                except (json.JSONDecodeError, TypeError):
+                    # Fall back to an empty list if the field is not a JSON
+                    # array (e.g. a Python repr); the row is still kept.
+ options = []
+
+ # Handle raw answer format (e.g. "(c)" or "c")
+ raw_ans = row.get("correct_answer", "").strip()
+ # Remove parens if present
+ if raw_ans.startswith("(") and raw_ans.endswith(")"):
+ raw_ans = raw_ans[1:-1]
+
+ # Parse correct index
+ # If correct_answer is 'A','B','C','D' -> 0,1,2,3
+ ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'a': 0, 'b': 1, 'c': 2, 'd': 3}
+ correct_idx = ans_map.get(raw_ans, -1)
+
+ q = PersonaMemQuestion(
+ persona_id=row["persona_id"],
+ question_id=row["question_id"],
+ question_type=row.get("question_type", "unknown"),
+ topic=row.get("topic", "unknown"),
+ user_question_or_message=row.get("user_question_or_message", row.get("question", "")),
+ all_options=options,
+ correct_index=correct_idx,
+ shared_context_id=row["shared_context_id"],
+ end_index_in_shared_context=int(row.get("end_index_in_shared_context", -1))
+ )
+ questions.append(q)
+ except KeyError as e:
+ # print(f"Skipping row due to missing key: {e}")
+ continue
+ return questions
+
+def load_personamem_contexts_32k(path_jsonl: str) -> Dict[str, PersonaMemContext]:
+ contexts = {}
+ with open(path_jsonl, "r", encoding="utf-8") as f:
+ for line in f:
+ data = json.loads(line)
+ # Format: {"hash_id": [messages...]}
+ for cid, msgs in data.items():
+ contexts[cid] = PersonaMemContext(
+ shared_context_id=cid,
+ messages=msgs
+ )
+ return contexts
+
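+# Expected input layout (inferred from the parsing above; adjust for other releases):
+#   questions CSV: persona_id, question_id, question_type, topic,
+#     user_question_or_message, all_options (JSON list of 4 strings),
+#     correct_answer ("a".."d" or "(a)".."(d)"), shared_context_id,
+#     end_index_in_shared_context
+#   contexts JSONL, one object per line:
+#     {"<shared_context_id>": [{"role": "user", "content": "..."}, ...]}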
diff --git a/src/personalization/evaluation/__init__.py b/src/personalization/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/evaluation/__init__.py
diff --git a/src/personalization/evaluation/compare_pairs.py b/src/personalization/evaluation/compare_pairs.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/evaluation/compare_pairs.py
diff --git a/src/personalization/evaluation/metrics.py b/src/personalization/evaluation/metrics.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/evaluation/metrics.py
diff --git a/src/personalization/feedback/__init__.py b/src/personalization/feedback/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/feedback/__init__.py
diff --git a/src/personalization/feedback/gating.py b/src/personalization/feedback/gating.py
new file mode 100644
index 0000000..d741874
--- /dev/null
+++ b/src/personalization/feedback/gating.py
@@ -0,0 +1,72 @@
+import numpy as np
+from personalization.feedback.schemas import TurnSample
+
+def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
+ # matrix: [N, d], vector: [d]
+ # return: [N]
+ norm_m = np.linalg.norm(matrix, axis=1)
+ norm_v = np.linalg.norm(vector)
+
+ # Avoid div by zero
+ den = norm_m * norm_v
+ den[den == 0] = 1e-9
+
+ return np.dot(matrix, vector) / den
+
+def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float:
+ """
+ Return g_t in [0,1], representing how much the reward is due to retrieval.
+ """
+ e_q = sample.query_embedding_t
+ e_q1 = sample.query_embedding_t1
+
+ if e_q is None or e_q1 is None or not sample.memories:
+ return 0.5 # Neutral
+
+ # We need embeddings of the memories.
+ # In a real pipeline, we might pass them in sample.memory_embeddings.
+ # If missing, we can't compute sim.
+ if sample.memory_embeddings is None:
+        # Fall back to the embedding_e stored on each memory card (List[float]).
+        try:
+            mem_embs = np.array([m.embedding_e for m in sample.memories])
+            if mem_embs.ndim != 2 or mem_embs.shape[1] == 0:  # missing or ragged embeddings
+                return 0.5
+        except Exception:
+            return 0.5
+ else:
+ mem_embs = sample.memory_embeddings
+
+ # Compute similarities
+ # shape: [K]
+ sims_q = cosine_sim_batch(mem_embs, e_q)
+ sims_q1 = cosine_sim_batch(mem_embs, e_q1)
+
+ s_q_max = sims_q.max() if len(sims_q) > 0 else 0
+ s_q1_max = sims_q1.max() if len(sims_q1) > 0 else 0
+
+ g = 0.5
+
+ # Heuristics
+
+ # Case A: Retrieval clearly irrelevant + bad reward
+ # q_t / q_{t+1} have low similarity to memories -> likely retrieval failure (or no relevant memories)
+ if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
+ g = 0.9 # Blame retrieval (for failing to find anything, or nothing exists)
+
+ # Case B: Retrieval looks good but reward is bad
+ # Memories are relevant to query, but user still unhappy -> LLM didn't use them well?
+ elif reward_hat < -0.5 and s_q_max > 0.5:
+ g = 0.2 # Likely LLM fault
+
+ # Case C: Good reward
+ # If reward is high, we assume both did okay.
+ elif reward_hat > 0.5:
+ if s_q_max > 0.4:
+ g = 0.6 # Retrieval helped
+ else:
+ g = 0.3 # LLM handled it without strong retrieval help
+
+ return float(g)
+
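+# Worked example (heuristic only): reward_hat = -0.8 with a maximum memory
+# similarity of 0.1 to both q_t and q_{t+1} falls under Case A, so g = 0.9
+# (blame retrieval); the same reward with s_q_max = 0.6 falls under Case B,
+# so g = 0.2 (blame the LLM).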
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
new file mode 100644
index 0000000..60a8d17
--- /dev/null
+++ b/src/personalization/feedback/handlers.py
@@ -0,0 +1,50 @@
+from typing import Tuple, List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.reward_model import estimate_reward
+from personalization.feedback.gating import estimate_retrieval_gating
+
+def eval_step(
+ q_t: str,
+ answer_t: str,
+ q_t1: str,
+ memories_t: List[MemoryCard],
+ query_embedding_t: Optional[np.ndarray] = None,
+ query_embedding_t1: Optional[np.ndarray] = None,
+) -> Tuple[float, float]:
+ """
+ Unified evaluation interface.
+ Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
+ """
+
+ # Construct a lightweight TurnSample
+ # We might need embeddings for gating. If not provided, gating might return default.
+
+ # Ensure memories have embeddings for gating
+ mem_embs = None
+ if memories_t and memories_t[0].embedding_e:
+ try:
+ mem_embs = np.array([m.embedding_e for m in memories_t])
+        except Exception:
+            pass  # Ragged or missing embeddings; gating falls back to neutral
+
+ sample = TurnSample(
+ user_id="", # Not needed for simple eval
+ session_id="",
+ turn_id=0,
+ query_t=q_t,
+ answer_t=answer_t,
+ query_t1=q_t1,
+ memories=memories_t,
+ query_embedding_t=query_embedding_t,
+ query_embedding_t1=query_embedding_t1,
+ memory_embeddings=mem_embs
+ )
+
+ r_hat = estimate_reward(sample)
+ g_hat = estimate_retrieval_gating(sample, r_hat)
+
+ return r_hat, g_hat
+
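+# Minimal smoke test (illustrative strings; no models, memories or embeddings
+# are required, so gating falls back to its neutral value).
+if __name__ == "__main__":
+    r, g = eval_step(
+        q_t="How do I write a binary search in Python?",
+        answer_t="Here is an implementation using two pointers...",
+        q_t1="Thanks, can you give an example with duplicates?",
+        memories_t=[],
+    )
+    print(f"reward_hat={r:.2f}, gating_hat={g:.2f}")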
diff --git a/src/personalization/feedback/online_update.py b/src/personalization/feedback/online_update.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/feedback/online_update.py
diff --git a/src/personalization/feedback/reward_model.py b/src/personalization/feedback/reward_model.py
new file mode 100644
index 0000000..3584b43
--- /dev/null
+++ b/src/personalization/feedback/reward_model.py
@@ -0,0 +1,64 @@
+import numpy as np
+from personalization.feedback.schemas import TurnSample
+
+def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
+ norm_a = np.linalg.norm(a)
+ norm_b = np.linalg.norm(b)
+ if norm_a == 0 or norm_b == 0:
+ return 0.0
+ return float(np.dot(a, b) / (norm_a * norm_b))
+
+def estimate_reward(sample: TurnSample) -> float:
+ """
+ Return a scalar reward_hat, indicating if the previous answer was helpful.
+ Range: [-1.0, 1.0] (approx)
+ """
+
+ # 1. Language/Topic Coherence
+ if sample.query_embedding_t is None or sample.query_embedding_t1 is None:
+ topic_sim = 0.5
+ else:
+ topic_sim = cosine_sim(sample.query_embedding_t, sample.query_embedding_t1)
+
+ # 2. Negative Keywords (Complaint/Correction)
+ negative_keywords = [
+ "you didn't", "that's not", "incorrect", "redo", "again", "explain more",
+ "doesn't help", "wrong", "no", "not what i asked",
+ "你没", "不是", "这不是", "重来", "重新", "不对", "错了", "没说清楚"
+ ]
+
+ # 3. Positive Keywords (Follow-up/Elaboration)
+ positive_keywords = [
+ "can you elaborate", "give an example", "continue", "what if", "based on that",
+ "thanks", "good", "great", "cool",
+ "能不能详细一点", "举个例子", "再继续", "那如果", "接下来", "在这个基础上", "谢谢", "不错"
+ ]
+
+ q1_lower = sample.query_t1.lower()
+
+ has_negative = any(kw in q1_lower for kw in negative_keywords)
+ has_positive = any(kw in q1_lower for kw in positive_keywords)
+
+ reward = 0.0
+
+ if has_negative:
+ reward -= 1.0
+
+ if has_positive:
+        # A positive cue alone earns a small reward; add the rest only when
+        # topic similarity is decent (otherwise it may just be "thanks, bye").
+ reward += 0.5
+ if topic_sim > 0.3:
+ reward += 0.5
+
+ if topic_sim < 0.2:
+ # Topic shift -> previous interaction likely finished or failed.
+ # If no explicit positive/negative, assume neutral/slightly decayed.
+ # If user changes topic, it often means the previous task is done (neutral/positive)
+ # OR they gave up (negative). Hard to tell.
+ # Let's dampen the reward towards 0.
+ reward *= 0.5
+
+ # Clip
+ return max(-1.0, min(1.0, reward))
+
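+# Worked example (heuristic only): if q_{t+1} is "That's not what I asked, redo it",
+# the negative keywords fire and reward starts at -1.0; with no positive cue and a
+# default topic_sim of 0.5 (no embeddings provided) nothing is added back, so the
+# clipped reward stays at -1.0.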
diff --git a/src/personalization/feedback/sampler.py b/src/personalization/feedback/sampler.py
new file mode 100644
index 0000000..9e26912
--- /dev/null
+++ b/src/personalization/feedback/sampler.py
@@ -0,0 +1,109 @@
+from typing import Iterable, List, Optional
+import numpy as np
+from tqdm import tqdm
+
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.retrieval.pipeline import retrieve_with_rerank
+from personalization.models.llm.qwen_instruct import QwenInstruct
+from personalization.models.embedding.base import EmbeddingModel
+from personalization.models.reranker.base import Reranker
+from personalization.user_model.tensor_store import UserTensorStore
+
+def build_turn_samples_from_sessions(
+ sessions: Iterable[List[ChatTurn]],
+ embed_model: EmbeddingModel,
+ llm: QwenInstruct,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray,
+ user_store: UserTensorStore,
+ item_vectors: np.ndarray,
+ max_samples: Optional[int] = None,
+ topk_dense: int = 64,
+ topk_rerank: int = 3,
+) -> List[TurnSample]:
+ samples = []
+
+ for turns in tqdm(sessions, desc="Building TurnSamples"):
+ if max_samples and len(samples) >= max_samples:
+ break
+
+ # Ensure sorted by turn_id
+ sorted_turns = sorted(turns, key=lambda x: x.turn_id)
+
+ # Iterate to find (q_t, a_t, q_{t+1})
+ for i in range(len(sorted_turns)):
+ if max_samples and len(samples) >= max_samples:
+ break
+
+ q_t = sorted_turns[i]
+ if q_t.role != "user":
+ continue
+
+ # Find next user turn
+ # Also try to find assistant response in between
+ a_t_text = ""
+ q_t1 = None
+
+ # Look ahead
+ for j in range(i + 1, len(sorted_turns)):
+ next_turn = sorted_turns[j]
+ if next_turn.role == "assistant" and not a_t_text:
+ a_t_text = next_turn.text
+ elif next_turn.role == "user":
+ q_t1 = next_turn
+ break
+
+ if not q_t1:
+ # End of session or no subsequent user query
+ continue
+
+            # We now have q_t, a_t (optional but preferred) and q_{t+1}.
+            # RL evaluation ideally needs the assistant answer, but we proceed
+            # even when a_t is empty: OASST1 sessions normally include assistant
+            # turns, and generating missing answers with the LLM is deliberately
+            # out of scope for this offline sampler.
+
+            # Retrieve memories for q_t
+ memories_t = retrieve_with_rerank(
+ user_id=q_t.user_id,
+ query=q_t.text,
+ embed_model=embed_model,
+ reranker=reranker,
+ memory_cards=memory_cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=topk_dense,
+ topk_rerank=topk_rerank,
+ beta_long=0.0,
+ beta_short=0.0,
+ only_own_memories=True # Assume we want user specific memories
+ )
+
+            # Compute query embeddings (done per sample here; could be batched
+            # later for efficiency)
+ e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0]
+ e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0]
+
+ sample = TurnSample(
+ user_id=q_t.user_id,
+ session_id=q_t.session_id,
+ turn_id=q_t.turn_id,
+ query_t=q_t.text,
+ answer_t=a_t_text,
+ query_t1=q_t1.text,
+ memories=memories_t,
+ query_embedding_t=np.array(e_q_t),
+ query_embedding_t1=np.array(e_q_t1)
+ )
+ samples.append(sample)
+
+ return samples
+
diff --git a/src/personalization/feedback/schemas.py b/src/personalization/feedback/schemas.py
new file mode 100644
index 0000000..b15db80
--- /dev/null
+++ b/src/personalization/feedback/schemas.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional, Any
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
+@dataclass
+class TurnSample:
+ user_id: str
+ session_id: str
+ turn_id: int # index of q_t within the session
+ query_t: str # q_t
+ answer_t: str # a_t
+ query_t1: str # q_{t+1}
+ memories: List[MemoryCard] # A_t
+
+ # Optional pre-computed vectors and features
+ query_embedding_t: Optional[np.ndarray] = None
+ query_embedding_t1: Optional[np.ndarray] = None
+ memory_embeddings: Optional[np.ndarray] = None # corresponding e_m or v_m for memories
+
diff --git a/src/personalization/retrieval/__init__.py b/src/personalization/retrieval/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/__init__.py
diff --git a/src/personalization/retrieval/chunking/__init__.py b/src/personalization/retrieval/chunking/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/chunking/__init__.py
diff --git a/src/personalization/retrieval/chunking/rules.py b/src/personalization/retrieval/chunking/rules.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/chunking/rules.py
diff --git a/src/personalization/retrieval/pipeline.py b/src/personalization/retrieval/pipeline.py
new file mode 100644
index 0000000..3d3eeb7
--- /dev/null
+++ b/src/personalization/retrieval/pipeline.py
@@ -0,0 +1,250 @@
+import os
+from typing import List, Tuple
+import numpy as np
+
+from personalization.models.embedding.base import EmbeddingModel
+from personalization.models.reranker.base import Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.user_model.scoring import score_with_user
+from personalization.user_model.policy.reinforce import compute_policy_scores
+
+def cosine_similarity_matrix(E: np.ndarray, e_q: np.ndarray) -> np.ndarray:
+ # E: [M, d], e_q: [d]
+ return np.dot(E, e_q)
+
+def dense_topk_indices(
+ query: str,
+ embed_model: EmbeddingModel,
+ memory_embeddings: np.ndarray,
+ valid_indices: List[int] = None,
+ topk: int = 64
+) -> List[int]:
+ """
+ Return indices of topk memories based on dense embedding similarity.
+ If valid_indices is provided, only search within that subset.
+ """
+ if valid_indices is not None and len(valid_indices) == 0:
+ return []
+
+ e_q_list = embed_model.encode([query], normalize=True, return_tensor=False)
+ e_q = np.array(e_q_list[0], dtype=np.float32)
+
+ # Select subset of embeddings if restricted
+ if valid_indices is not None:
+        # Restrict the dot product to the allowed subset, then map the
+        # subset-relative argsort indices back to the original positions.
+        # E_sub: [M_sub, d]
+ E_sub = memory_embeddings[valid_indices]
+ sims_sub = np.dot(E_sub, e_q)
+
+ # Topk within subset
+ k = min(topk, len(sims_sub))
+ if k == 0:
+ return []
+
+ # argsort gives indices relative to E_sub (0..M_sub-1)
+ # We need to map back to original indices
+ idx_sub = np.argsort(sims_sub)[-k:][::-1]
+
+ return [valid_indices[i] for i in idx_sub]
+
+ # Global search
+ sims = np.dot(memory_embeddings, e_q)
+ k = min(topk, len(memory_embeddings))
+ if k == 0:
+ return []
+
+ idx = np.argsort(sims)[-k:][::-1]
+ return idx.tolist()
+
+def retrieve_with_policy(
+ user_id: str,
+ query: str,
+ embed_model: EmbeddingModel,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray, # shape: [M, d]
+ user_store: UserTensorStore,
+ item_vectors: np.ndarray, # shape: [M, k], v_m
+ topk_dense: int = 64,
+ topk_rerank: int = 8,
+ beta_long: float = 0.0,
+ beta_short: float = 0.0,
+ tau: float = 1.0,
+ only_own_memories: bool = False,
+ sample: bool = False,
+) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
+ """
+ Returns extended info for policy update:
+ (candidates, candidate_item_vectors, base_scores, chosen_indices, policy_probs)
+
+ Args:
+ sample: If True, use stochastic sampling from policy distribution (for training/exploration).
+ If False, use deterministic top-k by policy scores (for evaluation).
+ """
+ # 0. Filter indices if needed
+ valid_indices = None
+ if only_own_memories:
+ valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
+ if not valid_indices:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ # 1. Dense retrieval
+ dense_idx = dense_topk_indices(
+ query,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+ # DEBUG: Check for duplicates or out of bounds
+    if len(dense_idx) > 0:
+        if os.getenv("RETRIEVAL_DEBUG") == "1":
+ print(f" [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...")
+ print(f" [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}")
+
+ if not dense_idx:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ candidates = [memory_cards[i] for i in dense_idx]
+ candidate_docs = [c.note_text for c in candidates]
+
+ # 2. Rerank base score (P(yes|q,m))
+ base_scores = np.array(reranker.score(query, candidate_docs))
+
+ # 3. Policy Scoring (Softmax)
+ user_state: UserState = user_store.get_state(user_id)
+ candidate_vectors = item_vectors[dense_idx] # [K, k]
+
+ policy_out = compute_policy_scores(
+ base_scores=base_scores,
+ user_state=user_state,
+ item_vectors=candidate_vectors,
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=tau
+ )
+
+ # 4. Selection: Greedy (eval) or Stochastic (training)
+ k = min(topk_rerank, len(policy_out.scores))
+
+ if sample:
+ # Stochastic sampling from policy distribution (for training/exploration)
+ # Sample k indices without replacement, weighted by policy probs
+ probs = policy_out.probs
+ # Normalize to ensure sum to 1 (handle numerical issues)
+ probs = probs / (probs.sum() + 1e-10)
+ # Sample without replacement
+ chosen_indices = np.random.choice(
+ len(probs), size=k, replace=False, p=probs
+ ).tolist()
+ else:
+ # Deterministic top-k by policy scores (for evaluation)
+ top_indices_local = policy_out.scores.argsort()[-k:][::-1]
+ chosen_indices = top_indices_local.tolist()
+
+    if os.getenv("RETRIEVAL_DEBUG") == "1":
+ print(f" [Pipeline] Candidates: {len(candidates)} | Chosen Indices: {chosen_indices} | Sample: {sample}")
+
+ return candidates, candidate_vectors, base_scores, chosen_indices, policy_out.probs
+
+def retrieve_no_policy(
+ user_id: str,
+ query: str,
+ embed_model: EmbeddingModel,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray, # shape: [M, d]
+ topk_dense: int = 64,
+ topk_rerank: int = 8,
+ only_own_memories: bool = False,
+) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
+ """
+ Deterministic retrieval baseline (NoPersonal mode):
+ - Dense retrieval -> Rerank -> Top-K (no policy sampling, no user vector influence)
+
+ Returns same structure as retrieve_with_policy for compatibility:
+ (candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen)
+
+    Note: candidate_item_vectors is an empty array (unused in NoPersonal mode),
+    and the last return value holds rerank scores instead of policy probs.
+ """
+ # 0. Filter indices if needed
+ valid_indices = None
+ if only_own_memories:
+ valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
+ if not valid_indices:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ # 1. Dense retrieval
+ dense_idx = dense_topk_indices(
+ query,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+
+ if not dense_idx:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ candidates = [memory_cards[i] for i in dense_idx]
+ candidate_docs = [c.note_text for c in candidates]
+
+ # 2. Rerank base score (P(yes|q,m))
+ base_scores = np.array(reranker.score(query, candidate_docs))
+
+ # 3. Deterministic Top-K selection based on rerank scores ONLY (no policy)
+ k = min(topk_rerank, len(base_scores))
+ top_indices_local = base_scores.argsort()[-k:][::-1]
+ chosen_indices = top_indices_local.tolist()
+
+ # Get scores for chosen items (for logging compatibility)
+ chosen_scores = base_scores[top_indices_local]
+
+ # Return empty item vectors (not used in NoPersonal mode)
+ # Return rerank scores as the "probs" field for logging compatibility
+ return candidates, np.array([]), base_scores, chosen_indices, chosen_scores
+
+
+def retrieve_with_rerank(
+ user_id: str,
+ query: str,
+ embed_model: EmbeddingModel,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray, # shape: [M, d]
+ user_store: UserTensorStore,
+ item_vectors: np.ndarray, # shape: [M, k], v_m
+ topk_dense: int = 64,
+ topk_rerank: int = 8,
+ beta_long: float = 0.0,
+ beta_short: float = 0.0,
+ only_own_memories: bool = False,
+) -> List[MemoryCard]:
+ """
+ Wrapper around retrieve_with_policy for standard inference.
+ """
+ candidates, _, _, chosen_indices, _ = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=embed_model,
+ reranker=reranker,
+ memory_cards=memory_cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=topk_dense,
+ topk_rerank=topk_rerank,
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=1.0, # Default tau
+ only_own_memories=only_own_memories
+ )
+
+ return [candidates[i] for i in chosen_indices]
+
+
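+# Usage sketch (embedder, reranker, memory cards/embeddings, user store and item
+# vectors must already be loaded; the variable names here are illustrative):
+#   memories = retrieve_with_rerank(
+#       user_id="u1",
+#       query="suggest a low-sodium dinner",
+#       embed_model=embedder, reranker=reranker,
+#       memory_cards=cards, memory_embeddings=E,
+#       user_store=store, item_vectors=V,
+#       topk_dense=64, topk_rerank=3,
+#       only_own_memories=True,
+#   )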
diff --git a/src/personalization/retrieval/preference_store/__init__.py b/src/personalization/retrieval/preference_store/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/__init__.py
diff --git a/src/personalization/retrieval/preference_store/base.py b/src/personalization/retrieval/preference_store/base.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/base.py
diff --git a/src/personalization/retrieval/preference_store/schemas.py b/src/personalization/retrieval/preference_store/schemas.py
new file mode 100644
index 0000000..eb82558
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/schemas.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from typing import List, Literal, Optional, Dict, Any
+
+from pydantic import BaseModel, Field, confloat
+
+
+class Preference(BaseModel):
+ condition: str = Field(
+ ..., min_length=1, max_length=128, description="When the rule applies"
+ )
+ action: str = Field(
+ ..., min_length=1, max_length=256, description="What to do in that case"
+ )
+ confidence: confloat(ge=0.0, le=1.0) = Field(
+ ..., description="Confidence the rule is correct"
+ )
+
+
+class PreferenceList(BaseModel):
+ preferences: List[Preference] = Field(default_factory=list)
+
+
+def preference_list_json_schema() -> dict:
+ return PreferenceList.model_json_schema()
+
+
+class ChatTurn(BaseModel):
+ user_id: str
+ session_id: str
+ turn_id: int
+ role: Literal["user", "assistant"]
+ text: str
+ timestamp: Optional[float] = None
+ meta: Dict[str, Any] = Field(default_factory=dict)
+
+
+class MemoryCard(BaseModel):
+ card_id: str
+ user_id: str
+ source_session_id: str
+ source_turn_ids: List[int]
+ raw_queries: List[str] # The original user utterances
+ preference_list: PreferenceList
+ note_text: str # Summarized "condition: action" text
+ embedding_e: List[float] # The embedding vector
+ kind: Literal["pref", "fact"] = "pref"
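+
+
+# Minimal sketch of the structure the extractor is expected to emit; the values
+# below are illustrative only, not taken from any real user.
+if __name__ == "__main__":
+    example = PreferenceList(
+        preferences=[
+            Preference(
+                condition="user asks for code",
+                action="answer in Python with type hints",
+                confidence=0.9,
+            )
+        ]
+    )
+    print(example.model_dump_json(indent=2))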
diff --git a/src/personalization/retrieval/preference_store/vector_kv.py b/src/personalization/retrieval/preference_store/vector_kv.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/vector_kv.py
diff --git a/src/personalization/retrieval/rerank.py b/src/personalization/retrieval/rerank.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/rerank.py
diff --git a/src/personalization/retrieval/store/__init__.py b/src/personalization/retrieval/store/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/store/__init__.py
diff --git a/src/personalization/retrieval/store/base.py b/src/personalization/retrieval/store/base.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/store/base.py
diff --git a/src/personalization/retrieval/store/faiss_store.py b/src/personalization/retrieval/store/faiss_store.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/store/faiss_store.py
diff --git a/src/personalization/retrieval/store/pgvector_store.py b/src/personalization/retrieval/store/pgvector_store.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/store/pgvector_store.py
diff --git a/src/personalization/serving/__init__.py b/src/personalization/serving/__init__.py
new file mode 100644
index 0000000..11adcf8
--- /dev/null
+++ b/src/personalization/serving/__init__.py
@@ -0,0 +1,22 @@
+# Personalization Serving Module
+#
+# This module provides the interface layer for the personalization system.
+
+from personalization.serving.personalized_llm import (
+ PersonalizedLLM,
+ AssistantResponse,
+ UsageStats,
+ DebugInfo,
+ Feedback,
+ create_personalized_llm,
+)
+
+__all__ = [
+ "PersonalizedLLM",
+ "AssistantResponse",
+ "UsageStats",
+ "DebugInfo",
+ "Feedback",
+ "create_personalized_llm",
+]
+
diff --git a/src/personalization/serving/api/__init__.py b/src/personalization/serving/api/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/serving/api/__init__.py
diff --git a/src/personalization/serving/api/main.py b/src/personalization/serving/api/main.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/serving/api/main.py
diff --git a/src/personalization/serving/api/routes/__init__.py b/src/personalization/serving/api/routes/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/serving/api/routes/__init__.py
diff --git a/src/personalization/serving/api/routes/feedback.py b/src/personalization/serving/api/routes/feedback.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/serving/api/routes/feedback.py
diff --git a/src/personalization/serving/api/routes/query.py b/src/personalization/serving/api/routes/query.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/serving/api/routes/query.py
diff --git a/src/personalization/serving/api/routes/users.py b/src/personalization/serving/api/routes/users.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/serving/api/routes/users.py
diff --git a/src/personalization/serving/api/schemas.py b/src/personalization/serving/api/schemas.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/serving/api/schemas.py
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py
new file mode 100644
index 0000000..2c4d5a8
--- /dev/null
+++ b/src/personalization/serving/personalized_llm.py
@@ -0,0 +1,837 @@
+#!/usr/bin/env python3
+"""
+Personalized LLM Interface for Evaluation.
+
+This module provides the `PersonalizedLLM` class that wraps the entire
+personalization system into a clean interface for evaluation frameworks
+and user simulators.
+
+Interface contract:
+- chat(user_id, query) -> AssistantResponse: Main online interface
+- reset_session(user_id): Clear session history and short-term state
+- reset_user(user_id): Completely reset user (long-term, short-term, memories)
+- apply_feedback(feedback): Apply external feedback for RL updates
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import yaml
+
+# Ensure src is in path for standalone usage
+_src_path = os.path.join(os.path.dirname(__file__), "../../..")
+if _src_path not in sys.path:
+ sys.path.insert(0, _src_path)
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import get_preference_extractor, get_chat_model
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.user_model.session_state import OnlineSessionState
+from personalization.user_model.features import ItemProjection
+from personalization.retrieval.preference_store.schemas import (
+ MemoryCard, ChatTurn, PreferenceList, Preference
+)
+from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy
+from personalization.feedback.handlers import eval_step
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+
+
+# =============================================================================
+# Data Classes for Interface
+# =============================================================================
+
+@dataclass
+class UsageStats:
+ """Token usage statistics from a chat completion."""
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ model: str
+
+
+@dataclass
+class DebugInfo:
+ """
+ Debug information for analysis and ablation studies.
+ All fields are optional - fill what you have, leave empty what you don't.
+ """
+ selected_memory_ids: List[str] = field(default_factory=list)
+ selected_memory_notes: List[str] = field(default_factory=list)
+ selected_memory_scores: List[float] = field(default_factory=list)
+ user_vector_before: Optional[List[float]] = None
+ user_vector_after: Optional[List[float]] = None
+ extracted_preferences: List[Dict[str, Any]] = field(default_factory=list)
+ extra: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class AssistantResponse:
+ """Response from the personalized LLM chat interface."""
+ answer: str
+ usage: UsageStats
+ debug: Optional[DebugInfo] = None
+
+
+@dataclass
+class Feedback:
+ """
+ Feedback data structure for RL updates from user simulator or judge.
+
+ Attributes:
+ user_id: The user this feedback is for.
+ turn_id: The turn this feedback refers to (from the previous turn).
+ reward: Reward scalar computed by user simulator / judge.
+ gating: Gating flag (1=valid learning signal, 0=skip update).
+ meta: Additional metadata for training/analysis.
+ """
+ user_id: str
+ turn_id: int
+ reward: float
+ gating: float # Can be 0.0 or 1.0, or continuous
+ meta: Dict[str, Any] = field(default_factory=dict)
+
+
+# =============================================================================
+# Internal Session State Extended
+# =============================================================================
+
+@dataclass
+class _SessionContext:
+ """Extended session context for evaluation tracking."""
+ session_state: OnlineSessionState
+ turn_counter: int = 0
+ # Store info needed for apply_feedback
+ pending_rl_update: Optional[Dict[str, Any]] = None
+
+
+# =============================================================================
+# PersonalizedLLM Class
+# =============================================================================
+
+class PersonalizedLLM:
+ """
+ Personalized LLM wrapper for evaluation frameworks.
+
+ This class provides a clean interface that accepts only (user_id, query)
+ for the main chat function, while internally managing:
+ - User state vectors (z_long, z_short)
+ - Session history
+ - Memory retrieval and policy
+ - Preference extraction and storage
+ - RL updates
+
+ Example usage:
+ llm = PersonalizedLLM()
+
+ # Reset user for fresh experiment
+ llm.reset_user("user_123")
+
+ # Start a session
+ llm.reset_session("user_123")
+
+ # Chat
+ response = llm.chat("user_123", "What's a good recipe for dinner?")
+ print(response.answer)
+
+ # Apply feedback from previous turn (from turn 2 onwards)
+ llm.apply_feedback(Feedback(
+ user_id="user_123",
+ turn_id=0,
+ reward=0.8,
+ gating=1.0
+ ))
+ """
+
+ def __init__(
+ self,
+ config_path: Optional[str] = None,
+ user_store_path: str = "data/users/user_store_eval.npz",
+ memory_cards_path: str = "data/corpora/memory_cards.jsonl",
+ memory_embeddings_path: str = "data/corpora/memory_embeddings.npy",
+ item_projection_path: str = "data/corpora/item_projection.npz",
+ only_own_memories: bool = True,
+ enable_preference_extraction: bool = True,
+ enable_rl_updates: bool = True,
+ mode: str = "full", # "full", "nopersonal", or "vanilla"
+ eval_mode: bool = True, # True = greedy selection, False = stochastic sampling
+ device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support
+ ):
+ """
+ Initialize the PersonalizedLLM.
+
+ Args:
+ config_path: Path to config file. If None, uses default locations.
+ user_store_path: Path to persist user state vectors.
+ memory_cards_path: Path to memory cards JSONL file.
+ memory_embeddings_path: Path to memory embeddings numpy file.
+ item_projection_path: Path to item projection (PCA) file.
+ only_own_memories: If True, only retrieve user's own memories (strict privacy).
+ enable_preference_extraction: If True, extract preferences from user turns.
+ enable_rl_updates: If True, apply RL updates via apply_feedback.
+ mode: "full" for full personalization, "nopersonal" for baseline (no user vector influence),
+ "vanilla" for pure LLM without any memory retrieval or preference extraction.
+ eval_mode: If True, use greedy/deterministic selection (for evaluation).
+ If False, use stochastic sampling (for training/exploration).
+ device_assignment: Optional dict to assign models to specific GPUs.
+ Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"}
+ If None, uses "auto" for all models.
+ """
+ self.only_own_memories = only_own_memories
+ self.enable_preference_extraction = enable_preference_extraction
+ self.enable_rl_updates = enable_rl_updates
+ self.mode = mode # "full" or "nopersonal"
+ self.eval_mode = eval_mode # True = greedy, False = sample
+
+ # Multi-GPU device assignment
+ self._device_assignment = device_assignment or {
+ "embed": "auto",
+ "reranker": "auto",
+ "chat": "auto",
+ "extractor": "auto",
+ }
+
+ # Paths
+ self._memory_cards_path = memory_cards_path
+ self._memory_embeddings_path = memory_embeddings_path
+ self._item_projection_path = item_projection_path
+
+ # RL Configuration
+ # Note: beta/eta increased for more significant z_u updates
+ self._rl_cfg = {
+ "item_dim": 256,
+ "beta_long": 2.0, # Increased from 0.1 for stronger personalization
+ "beta_short": 5.0, # Increased from 0.3
+ "tau": 1.0,
+ "eta_long": 0.01, # Increased from 1e-3 for faster learning
+ "eta_short": 0.05, # Increased from 5e-3
+ "ema_alpha": 0.05,
+ "short_decay": 0.1,
+ "dense_topk": 64,
+ "rerank_topk": 3,
+ "max_new_tokens": 512,
+ }
+
+ # Load config and override RL params if available
+ self._load_config(config_path)
+
+ # Load models
+ print("[PersonalizedLLM] Loading models...")
+ self._load_models()
+
+ # Load memory store
+ print("[PersonalizedLLM] Loading memory store...")
+ self._load_memory_store()
+
+ # Initialize user store
+ self._user_store = UserTensorStore(
+ k=self._rl_cfg["item_dim"],
+ path=user_store_path,
+ )
+
+ # Session contexts per user (in-memory)
+ self._sessions: Dict[str, _SessionContext] = {}
+
+ print("[PersonalizedLLM] Initialization complete.")
+
+ def _load_config(self, config_path: Optional[str]):
+ """Load configuration from yaml files."""
+ self._cfg = load_local_models_config()
+
+ # Try to load user_model.yaml for RL params
+ if config_path is None:
+ config_path = "configs/user_model.yaml"
+
+ self._llm_name = "qwen_1_5b" # Default
+
+ try:
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ user_cfg = yaml.safe_load(f)
+ if user_cfg:
+ # Override RL params if present
+ for key in self._rl_cfg:
+ if key in user_cfg:
+ self._rl_cfg[key] = user_cfg[key]
+ # LLM name
+ if "llm_name" in user_cfg:
+ self._llm_name = user_cfg["llm_name"]
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load config: {e}")
+
+ def _load_models(self):
+ """Load all ML models with optional multi-GPU assignment."""
+ import torch
+
+ # Report GPU availability
+ num_gpus = torch.cuda.device_count()
+ print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
+ for i in range(num_gpus):
+ mem = torch.cuda.get_device_properties(i).total_memory / 1e9
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
+
+ embed_device = self._device_assignment.get("embed", "auto")
+ reranker_device = self._device_assignment.get("reranker", "auto")
+ chat_device = self._device_assignment.get("chat", "auto")
+ extractor_device = self._device_assignment.get("extractor", "auto")
+
+ # Embedding model
+ print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
+ self._embed_model = Qwen3Embedding8B(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ dtype=torch.bfloat16,
+ device_map=embed_device,
+ )
+
+ # Reranker
+ print(f"[PersonalizedLLM] Loading Reranker on {reranker_device}...")
+ self._reranker = Qwen3Reranker(
+ model_path=self._cfg.reranker.qwen3_8b.local_path,
+ device_map=reranker_device,
+ dtype=torch.bfloat16,
+ )
+
+ # Chat model (via registry for backend switching)
+ print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...")
+ # Pass device override if specified (not "auto")
+ device_for_chat = chat_device if chat_device != "auto" else None
+ self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat)
+
+ # Preference extractor
+ if self.enable_preference_extraction:
+ extractor_name = "qwen3_0_6b_sft"
+ print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
+ try:
+ self._extractor = get_preference_extractor(extractor_name)
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.")
+ self._extractor = get_preference_extractor("rule")
+ else:
+ print("[PersonalizedLLM] Preference extraction disabled, using rule-based extractor.")
+ self._extractor = get_preference_extractor("rule")
+
+ def _load_memory_store(self):
+ """Load memory cards and embeddings."""
+ if not os.path.exists(self._memory_cards_path):
+ print(f"[PersonalizedLLM] Warning: Memory cards not found at {self._memory_cards_path}")
+ self._memory_cards: List[MemoryCard] = []
+ self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+ self._projection = None
+ return
+
+ # Load cards
+ self._memory_cards = []
+ with open(self._memory_cards_path, "r") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ self._memory_cards.append(MemoryCard.model_validate_json(line))
+
+ # Load embeddings
+ if os.path.exists(self._memory_embeddings_path):
+ self._memory_embeddings = np.load(self._memory_embeddings_path)
+ else:
+ self._memory_embeddings = np.zeros((len(self._memory_cards), 4096), dtype=np.float32)
+
+ # Load projection
+ if os.path.exists(self._item_projection_path):
+ proj_data = np.load(self._item_projection_path)
+ self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+ self._item_vectors = proj_data["V"]
+ else:
+ self._projection = None
+ self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.")
+
+ def _get_or_create_session(self, user_id: str) -> _SessionContext:
+ """Get or create session context for a user."""
+ if user_id not in self._sessions:
+ self._sessions[user_id] = _SessionContext(
+ session_state=OnlineSessionState(user_id=user_id),
+ turn_counter=0,
+ )
+ return self._sessions[user_id]
+
+ def _build_chat_turn(self, user_id: str, text: str, role: str, turn_id: int) -> ChatTurn:
+ """Build a ChatTurn object."""
+ return ChatTurn(
+ user_id=user_id,
+ session_id=f"eval_session_{user_id}",
+ turn_id=turn_id,
+ role=role,
+ text=text,
+ meta={"source": "eval"}
+ )
+
+ def _count_tokens(self, text: str) -> int:
+ """Estimate token count using the tokenizer."""
+ try:
+ # Use the chat model's tokenizer if available
+ if hasattr(self._chat_model, 'tokenizer'):
+ return len(self._chat_model.tokenizer.encode(text))
+ else:
+ # Rough estimate: ~4 chars per token
+ return len(text) // 4
+ except Exception:
+ return len(text) // 4
+
+ def _add_preferences_as_memory(
+ self,
+ prefs: PreferenceList,
+ query: str,
+ user_id: str,
+ turn_id: int,
+ ) -> List[Dict[str, Any]]:
+ """
+ Add extracted preferences as new memory cards.
+ Returns list of preference dicts for debug info.
+ """
+ extracted = []
+
+ if not prefs.preferences or self._projection is None:
+ return extracted
+
+ # Compute embedding for the query
+ e_q = self._embed_model.encode([query], return_tensor=False)[0]
+ v_q = self._projection.transform_vector(np.array(e_q))
+
+ for pref in prefs.preferences:
+ note_text = f"When {pref.condition}, {pref.action}."
+
+ # Record for debug
+ extracted.append({
+ "condition": pref.condition,
+ "action": pref.action,
+ "confidence": pref.confidence,
+ })
+
+ # Deduplication check
+ is_duplicate = any(
+ card.user_id == user_id and card.note_text == note_text
+ for card in self._memory_cards
+ )
+
+ if is_duplicate:
+ continue
+
+ # Create new memory card
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id=f"eval_session_{user_id}",
+ source_turn_ids=[turn_id],
+ raw_queries=[query],
+ preference_list=PreferenceList(preferences=[pref]),
+ note_text=note_text,
+ embedding_e=list(e_q),
+ kind="pref",
+ )
+
+ # Add to memory store
+ self._memory_cards.append(card)
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_q])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_q])])
+
+ return extracted
+
+ # =========================================================================
+ # Public Interface
+ # =========================================================================
+
+ def chat(self, user_id: str, query: str) -> AssistantResponse:
+ """
+ Main online chat interface.
+
+ Args:
+ user_id: Unique identifier for the user.
+ query: Current user query/message.
+
+ Returns:
+ AssistantResponse containing the answer, usage stats, and debug info.
+
+ Notes:
+ - Internally manages user state, session history, memory retrieval
+ - After this call, you can call apply_feedback() with the turn's feedback
+ """
+ ctx = self._get_or_create_session(user_id)
+ session = ctx.session_state
+ user_state = self._user_store.get_state(user_id)
+
+ # Record user vector before for debug
+ z_long_before = user_state.z_long.copy().tolist()
+ z_short_before = user_state.z_short.copy().tolist()
+
+ # Compute query embedding
+ e_q_t = np.array(self._embed_model.encode([query], return_tensor=False)[0])
+
+ # Store pending RL update info from last turn (for apply_feedback)
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Add user turn to history
+ user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
+ session.history.append(user_turn)
+
+ # Vanilla mode: pure LLM without any memory or preference extraction
+ if self.mode == "vanilla":
+ # Skip preference extraction and memory retrieval entirely
+ extracted_prefs = []
+ candidates = []
+ cand_item_vecs = np.array([])
+ base_scores = np.array([])
+ chosen_indices = []
+ probs = np.array([])
+ memories_t = []
+ memory_notes = []
+ else:
+ # Extract preferences from conversation (if enabled)
+ extracted_prefs = []
+ if self.enable_preference_extraction:
+ prefs = self._extractor.extract_turn(session.history)
+ extracted_prefs = self._add_preferences_as_memory(
+ prefs, query, user_id, ctx.turn_counter
+ )
+
+ # Retrieve memories
+ # In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector
+ # In "full" mode: policy-based retrieval with user vector influence
+ if self.mode == "nopersonal":
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=self._memory_cards,
+ memory_embeddings=self._memory_embeddings,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ only_own_memories=self.only_own_memories,
+ )
+ else:
+ beta_long = self._rl_cfg["beta_long"]
+ beta_short = self._rl_cfg["beta_short"]
+ # eval_mode=True -> sample=False (greedy/deterministic)
+ # eval_mode=False -> sample=True (stochastic/exploration)
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=self._memory_cards,
+ memory_embeddings=self._memory_embeddings,
+ user_store=self._user_store,
+ item_vectors=self._item_vectors,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=self._rl_cfg["tau"],
+ only_own_memories=self.only_own_memories,
+ sample=not self.eval_mode,
+ )
+
+ # Get selected memories
+ memories_t = [candidates[int(i)] for i in chosen_indices] if len(chosen_indices) > 0 else []
+ memory_notes = [m.note_text for m in memories_t]
+
+ # Build prompt and count tokens
+ prompt_tokens = self._count_tokens(query)
+ for turn in session.history:
+ prompt_tokens += self._count_tokens(turn.text)
+ for note in memory_notes:
+ prompt_tokens += self._count_tokens(note)
+
+ # Generate answer
+ answer_t = self._chat_model.answer(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ )
+
+ completion_tokens = self._count_tokens(answer_t)
+
+ # Add assistant turn to history
+ assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
+ session.history.append(assist_turn)
+
+ # Update session state for next turn
+ session.last_query = query
+ session.last_answer = answer_t
+ session.last_memories = memories_t
+ session.last_query_embedding = e_q_t
+ session.last_candidate_item_vectors = cand_item_vecs
+ session.last_policy_probs = probs
+ session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []
+
+ ctx.turn_counter += 1
+
+ # Build debug info
+ debug = DebugInfo(
+ selected_memory_ids=[m.card_id for m in memories_t],
+ selected_memory_notes=[m.note_text for m in memories_t],
+ selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
+ user_vector_before=z_long_before + z_short_before, # Concatenated for simplicity
+ user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
+ extracted_preferences=extracted_prefs,
+ extra={
+ "num_candidates": len(candidates),
+ "num_total_memories": len(self._memory_cards),
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ }
+ )
+
+ # Build usage stats
+ usage = UsageStats(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ model=self._llm_name,
+ )
+
+ return AssistantResponse(
+ answer=answer_t,
+ usage=usage,
+ debug=debug,
+ )
+
+ def reset_session(self, user_id: str) -> None:
+ """
+ Reset session for a user (new chat window).
+
+ This clears:
+ - Session conversation history
+ - Short-term user vector (z_short)
+ - Pending RL update info
+
+ This preserves:
+ - Long-term user vector (z_long)
+ - User's memory cards
+
+ Args:
+ user_id: The user whose session to reset.
+ """
+ # Clear session context
+ if user_id in self._sessions:
+ del self._sessions[user_id]
+
+ # Create fresh session
+ self._sessions[user_id] = _SessionContext(
+ session_state=OnlineSessionState(user_id=user_id),
+ turn_counter=0,
+ )
+
+ # Reset short-term vector but keep long-term
+ user_state = self._user_store.get_state(user_id)
+ user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+ self._user_store.save_state(user_state)
+
+ def reset_user(self, user_id: str) -> None:
+ """
+ Completely reset a user (new "life").
+
+ This clears:
+ - Long-term user vector (z_long)
+ - Short-term user vector (z_short)
+ - User's memory cards
+ - Session history
+ - All cached state
+
+ Args:
+ user_id: The user to reset.
+ """
+ # Clear session
+ if user_id in self._sessions:
+ del self._sessions[user_id]
+
+ # Reset user state vectors
+ user_state = self._user_store.get_state(user_id)
+ user_state.z_long = self._user_store.global_init_z.copy()
+ user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+ user_state.reward_ma = 0.0
+ self._user_store.save_state(user_state)
+
+ # Find indices to KEEP (cards NOT belonging to this user)
+ # Must do this BEFORE modifying _memory_cards
+ keep_indices = [
+ i for i, card in enumerate(self._memory_cards)
+ if card.user_id != user_id
+ ]
+
+ # Filter memory cards
+ self._memory_cards = [self._memory_cards[i] for i in keep_indices]
+
+ # Filter embeddings and item vectors to match
+ if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
+ self._memory_embeddings = self._memory_embeddings[keep_indices]
+ self._item_vectors = self._item_vectors[keep_indices]
+ else:
+ # No cards left or no embeddings
+ embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
+ self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ def apply_feedback(self, feedback: Feedback) -> None:
+ """
+ Apply feedback from user simulator or judge.
+
+ This performs the REINFORCE update to user vectors based on
+ the reward signal from the previous turn.
+
+ Args:
+ feedback: Feedback object containing reward, gating, and metadata.
+
+ Notes:
+ - Should be called AFTER chat() but BEFORE the next chat() call
+ - Uses the stored context from the previous turn
+ - If enable_rl_updates is False, this is a no-op
+ - If mode is "nopersonal" or "vanilla", this is a no-op (baseline comparison)
+ """
+ if not self.enable_rl_updates:
+ return
+
+ # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline)
+ if self.mode in ("nopersonal", "vanilla"):
+ return
+
+ user_id = feedback.user_id
+ ctx = self._sessions.get(user_id)
+
+ if ctx is None or ctx.pending_rl_update is None:
+ return
+
+ pending = ctx.pending_rl_update
+ user_state = self._user_store.get_state(user_id)
+
+ # Check if we have the necessary data for RL update
+ if (pending.get("last_candidate_item_vectors") is not None and
+ pending.get("last_policy_probs") is not None and
+ pending.get("last_chosen_indices") is not None and
+ len(pending["last_chosen_indices"]) > 0):
+
+ # Extract chosen vectors
+ chosen_indices = pending["last_chosen_indices"]
+ candidate_vectors = pending["last_candidate_item_vectors"]
+
+ if len(candidate_vectors) > 0:
+ # REINFORCE expects:
+ # - item_vectors: ALL candidate vectors [K, k]
+ # - chosen_indices: indices into those candidates
+ # - policy_probs: probabilities over all K candidates [K]
+ updated = reinforce_update_user_state(
+ user_state=user_state,
+ item_vectors=candidate_vectors, # All candidates, not just chosen
+ chosen_indices=chosen_indices, # Original indices into candidates
+ policy_probs=pending["last_policy_probs"],
+ reward_hat=feedback.reward,
+ gating=feedback.gating,
+ tau=self._rl_cfg["tau"],
+ eta_long=self._rl_cfg["eta_long"],
+ eta_short=self._rl_cfg["eta_short"],
+ ema_alpha=self._rl_cfg["ema_alpha"],
+ short_decay=self._rl_cfg["short_decay"],
+ )
+
+ if updated:
+ self._user_store.save_state(user_state)
+
+ # Clear pending update
+ ctx.pending_rl_update = None
+
+ def get_user_state_summary(self, user_id: str) -> Dict[str, Any]:
+ """
+ Get a summary of the user's current state (for debugging/analysis).
+
+ Args:
+ user_id: The user to query.
+
+ Returns:
+ Dictionary with user state information.
+ """
+ user_state = self._user_store.get_state(user_id)
+ ctx = self._sessions.get(user_id)
+
+ user_memory_count = sum(
+ 1 for card in self._memory_cards if card.user_id == user_id
+ )
+
+ return {
+ "user_id": user_id,
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ "reward_ma": user_state.reward_ma,
+ "session_history_length": len(ctx.session_state.history) if ctx else 0,
+ "turn_counter": ctx.turn_counter if ctx else 0,
+ "user_memory_count": user_memory_count,
+ "total_memory_count": len(self._memory_cards),
+ }
+
+ def persist(self) -> None:
+ """
+ Persist all state to disk.
+
+ Call this at the end of an evaluation run to save:
+ - User state vectors
+ - Memory cards
+ """
+ # Save user store
+ self._user_store.persist()
+
+ # Save memory cards
+ with open(self._memory_cards_path, "w", encoding="utf-8") as f:
+ for card in self._memory_cards:
+ f.write(card.model_dump_json() + "\n")
+
+ # Save embeddings
+ np.save(self._memory_embeddings_path, self._memory_embeddings)
+
+ # Save item projection with updated vectors
+ if self._projection is not None:
+ np.savez(
+ self._item_projection_path,
+ P=self._projection.P,
+ mean=self._projection.mean,
+ V=self._item_vectors,
+ )
+
+ print("[PersonalizedLLM] State persisted to disk.")
+
+
+# =============================================================================
+# Convenience Factory
+# =============================================================================
+
+def create_personalized_llm(
+ config_path: Optional[str] = None,
+ **kwargs
+) -> PersonalizedLLM:
+ """
+ Factory function to create a PersonalizedLLM instance.
+
+ Args:
+ config_path: Optional path to configuration file.
+ **kwargs: Additional arguments passed to PersonalizedLLM constructor.
+
+ Returns:
+ Configured PersonalizedLLM instance.
+ """
+ return PersonalizedLLM(config_path=config_path, **kwargs)
+
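+ # A minimal end-to-end sketch of the public interface above (illustrative only,
+ # not wired into any runner). The 4-GPU device_assignment mirrors the example
+ # in the constructor docstring, and Feedback is assumed to accept keyword
+ # arguments for the user_id / reward / gating fields read by apply_feedback().
+ def _usage_sketch() -> None:
+     llm = create_personalized_llm(
+         mode="full",              # "full" | "nopersonal" | "vanilla"
+         eval_mode=False,          # stochastic retrieval for online exploration
+         device_assignment={"embed": "cuda:0", "reranker": "cuda:1",
+                            "chat": "cuda:2", "extractor": "cuda:3"},
+     )
+     resp = llm.chat("user_1", "Recommend a short sci-fi novel.")
+     print(resp.answer, resp.usage.total_tokens)
+     # Per the apply_feedback() contract: after chat(), before the next chat().
+     llm.apply_feedback(Feedback(user_id="user_1", reward=1.0, gating=1.0))
+     resp = llm.chat("user_1", "Something darker this time, please.")
+     print(llm.get_user_state_summary("user_1"))
+     llm.persist()                 # flush user vectors and memory cards to disk
+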
diff --git a/src/personalization/types.py b/src/personalization/types.py
new file mode 100644
index 0000000..a25b560
--- /dev/null
+++ b/src/personalization/types.py
@@ -0,0 +1,4 @@
+from personalization.retrieval.preference_store.schemas import ChatTurn
+
+__all__ = ["ChatTurn"]
+
diff --git a/src/personalization/user_model/__init__.py b/src/personalization/user_model/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/__init__.py
diff --git a/src/personalization/user_model/features.py b/src/personalization/user_model/features.py
new file mode 100644
index 0000000..a4508b4
--- /dev/null
+++ b/src/personalization/user_model/features.py
@@ -0,0 +1,49 @@
+import numpy as np
+from dataclasses import dataclass
+from sklearn.decomposition import PCA
+
+@dataclass
+class ItemProjection:
+ P: np.ndarray # [k, d]
+ mean: np.ndarray # [d]
+
+ @classmethod
+ def from_pca(cls, embeddings: np.ndarray, k: int) -> "ItemProjection":
+ """
+ embeddings: [M, d]
+ """
+ mean = embeddings.mean(axis=0)
+ centered = embeddings - mean
+
+ # Ensure k is not larger than min(n_samples, n_features)
+ n_samples, n_features = embeddings.shape
+ actual_k = min(k, n_samples, n_features)
+
+ pca = PCA(n_components=actual_k)
+ pca.fit(centered)
+
+ # pca.components_: [k, d]
+ P = pca.components_ # Each row is a principal component vector
+
+ # PCA can return at most min(n_samples, n_features) components, but the rest
+ # of the pipeline expects a fixed k. If fewer components are available, pad
+ # the projection with zero rows so P stays [k, d].
+ if actual_k < k:
+ padding = np.zeros((k - actual_k, n_features), dtype=P.dtype)
+ P = np.vstack([P, padding])
+
+ return cls(P=P, mean=mean)
+
+ def transform_embeddings(self, E: np.ndarray) -> np.ndarray:
+ """
+ E: [N, d] -> [N, k]
+ """
+ return (E - self.mean) @ self.P.T
+
+ def transform_vector(self, e: np.ndarray) -> np.ndarray:
+ """
+ e: [d] -> [k]
+ """
+ return self.P @ (e - self.mean)
+
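+ # Small self-check sketch for the projection above, using random embeddings in
+ # place of real encoder outputs (the pipeline uses d=4096; a tiny d is used
+ # here purely for illustration).
+ def _projection_sketch() -> None:
+     rng = np.random.default_rng(0)
+     E = rng.standard_normal((200, 32)).astype(np.float32)  # [M, d]
+     proj = ItemProjection.from_pca(E, k=8)                  # P: [8, 32]
+     V = proj.transform_embeddings(E)                        # [200, 8]
+     v = proj.transform_vector(E[0])                         # [8]
+     assert V.shape == (200, 8) and v.shape == (8,)
+     # The batched and single-vector paths must agree on row 0.
+     assert np.allclose(V[0], v, atol=1e-5)
+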
diff --git a/src/personalization/user_model/policy/__init__.py b/src/personalization/user_model/policy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/policy/__init__.py
diff --git a/src/personalization/user_model/policy/optimizer.py b/src/personalization/user_model/policy/optimizer.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/policy/optimizer.py
diff --git a/src/personalization/user_model/policy/reinforce.py b/src/personalization/user_model/policy/reinforce.py
new file mode 100644
index 0000000..adfaef7
--- /dev/null
+++ b/src/personalization/user_model/policy/reinforce.py
@@ -0,0 +1,104 @@
+from typing import Sequence, List
+from dataclasses import dataclass
+import numpy as np
+
+from personalization.user_model.tensor_store import UserState
+
+@dataclass
+class PolicyScores:
+ scores: np.ndarray # [K] s(q_t, m; u)
+ probs: np.ndarray # [K] π_z(m|q_t)
+
+def compute_policy_scores(
+ base_scores: np.ndarray, # [K], from reranker
+ user_state: UserState,
+ item_vectors: np.ndarray, # [K, k], v_m for the K candidates
+ beta_long: float,
+ beta_short: float,
+ tau: float,
+) -> PolicyScores:
+ """
+ Compute personalized scores and softmax probabilities.
+ s(q_t, m; u) = s_0(q_t,m) + z_t^{(eff)}.T @ v_m
+ z_t^{(eff)} = beta_long * z_long + beta_short * z_short
+ """
+ if len(item_vectors) == 0:
+ return PolicyScores(scores=np.array([]), probs=np.array([]))
+
+ z_eff = beta_long * user_state.z_long + beta_short * user_state.z_short
+
+ # Calculate personalized term
+ # item_vectors: [K, k]
+ # z_eff: [k]
+ # term: [K]
+ personalization_term = np.dot(item_vectors, z_eff)
+
+ # Total scores
+ scores = base_scores + personalization_term
+
+ # Softmax
+ # Use exp(score/tau)
+ # Subtract max for stability
+ scaled_scores = scores / tau
+ exp_scores = np.exp(scaled_scores - np.max(scaled_scores))
+ probs = exp_scores / np.sum(exp_scores)
+
+ return PolicyScores(scores=scores, probs=probs)
+
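+ # Tiny worked example of the personalized scoring above (toy numbers, k=2,
+ # K=3 candidates); UserState is built directly here, whereas the pipeline
+ # obtains it from UserTensorStore.
+ def _policy_scores_sketch() -> None:
+     state = UserState(user_id="u", z_long=np.array([1.0, 0.0]),
+                       z_short=np.zeros(2), reward_ma=0.0)
+     V = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])  # candidate item vectors
+     base = np.array([0.5, 0.5, 0.5])                      # reranker scores
+     out = compute_policy_scores(base, state, V, beta_long=2.0, beta_short=5.0, tau=1.0)
+     # z_eff = 2 * z_long = [2, 0]; personalized scores = [2.5, 0.5, -1.5]
+     assert np.argmax(out.probs) == 0 and np.isclose(out.probs.sum(), 1.0)
+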
+def reinforce_update_user_state(
+ user_state: UserState,
+ item_vectors: np.ndarray, # [K, k] for candidates
+ chosen_indices: Sequence[int], # indices of A_t in 0..K-1
+ policy_probs: np.ndarray, # [K] π_z(m|q_t)
+ reward_hat: float, # \hat r_t
+ gating: float, # g_t
+ tau: float,
+ eta_long: float,
+ eta_short: float,
+ ema_alpha: float,
+ short_decay: float,
+) -> bool:
+ """
+ Update user_state.z_long / z_short / reward_ma in place via REINFORCE.
+ Returns True if update occurred, False otherwise.
+ """
+ if len(chosen_indices) == 0:
+ return False
+
+ # 1. Baseline Advantage
+ advantage = gating * (reward_hat - user_state.reward_ma)
+
+ # Optimization: skip if advantage is negligible
+ if abs(advantage) < 1e-6:
+ return False
+
+ # 2. Chosen Vector Average (v_{chosen,t})
+ chosen_mask = np.zeros(len(item_vectors), dtype=np.float32)
+ for idx in chosen_indices:
+ idx_int = int(idx)
+ if 0 <= idx_int < len(item_vectors):
+ chosen_mask[idx_int] = 1.0
+
+ if chosen_mask.sum() == 0:
+ return False
+
+ chosen_mask /= chosen_mask.sum() # Normalize to average
+ v_chosen = np.dot(chosen_mask, item_vectors) # [k]
+
+ # 3. Expected Vector (\mu_t(z))
+ # policy_probs: [K]
+ # item_vectors: [K, k]
+ v_expect = np.dot(policy_probs, item_vectors) # [k]
+
+ # 4. Gradient Direction
+ grad = (advantage / tau) * (v_chosen - v_expect)
+
+ # 5. Update Vectors
+ user_state.z_long += eta_long * grad
+ user_state.z_short = (1.0 - short_decay) * user_state.z_short + eta_short * grad
+
+ # 6. Update Reward Baseline (EMA)
+ user_state.reward_ma = (1.0 - ema_alpha) * user_state.reward_ma + ema_alpha * reward_hat
+
+ return True
+
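+ # Sketch of a single REINFORCE step on toy data: with a reward above the
+ # baseline, z_long / z_short move toward the chosen item vector and away from
+ # the policy's expected vector (hyperparameters here match the _rl_cfg defaults
+ # used by PersonalizedLLM and are illustrative).
+ def _reinforce_sketch() -> None:
+     state = UserState(user_id="u", z_long=np.zeros(2), z_short=np.zeros(2), reward_ma=0.0)
+     V = np.array([[1.0, 0.0], [0.0, 1.0]])
+     updated = reinforce_update_user_state(
+         user_state=state, item_vectors=V, chosen_indices=[0],
+         policy_probs=np.array([0.5, 0.5]), reward_hat=1.0, gating=1.0, tau=1.0,
+         eta_long=0.01, eta_short=0.05, ema_alpha=0.05, short_decay=0.1,
+     )
+     # grad = (advantage / tau) * (v_chosen - E[v]) = [0.5, -0.5]
+     assert updated and np.allclose(state.z_long, [0.005, -0.005])
+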
diff --git a/src/personalization/user_model/scoring.py b/src/personalization/user_model/scoring.py
new file mode 100644
index 0000000..75ffc84
--- /dev/null
+++ b/src/personalization/user_model/scoring.py
@@ -0,0 +1,25 @@
+import numpy as np
+from .tensor_store import UserState
+
+def score_with_user(
+ base_score: float,
+ user_state: UserState,
+ v_m: np.ndarray, # [k]
+ beta_long: float,
+ beta_short: float,
+) -> float:
+ """
+ Personalized scoring:
+ s = base_score + (beta_long * z_long + beta_short * z_short) . v_m
+ Day2: beta_long = beta_short = 0 -> s == base_score
+ """
+ z_eff = beta_long * user_state.z_long + beta_short * user_state.z_short
+ # Guard against a dimension mismatch between the user vector and v_m;
+ # fall back to the unpersonalized base score in that case.
+ if v_m.shape != z_eff.shape:
+ return float(base_score)
+
+ term = np.dot(z_eff, v_m)
+ return float(base_score + term)
+
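+ # Quick sketch of the betas' effect (toy vectors; UserState constructed
+ # directly for illustration):
+ #     state = UserState("u", z_long=np.ones(4), z_short=np.zeros(4), reward_ma=0.0)
+ #     score_with_user(0.3, state, np.ones(4), beta_long=0.0, beta_short=0.0)  # -> 0.3 (Day2)
+ #     score_with_user(0.3, state, np.ones(4), beta_long=1.0, beta_short=0.0)  # -> 0.3 + 4.0 = 4.3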
diff --git a/src/personalization/user_model/session_state.py b/src/personalization/user_model/session_state.py
new file mode 100644
index 0000000..5cd2243
--- /dev/null
+++ b/src/personalization/user_model/session_state.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard
+
+@dataclass
+class OnlineSessionState:
+ user_id: str
+ history: List[ChatTurn] = field(default_factory=list)
+ last_query: Optional[str] = None
+ last_answer: Optional[str] = None
+ last_memories: List[MemoryCard] = field(default_factory=list)
+ last_query_embedding: Optional[np.ndarray] = None
+ last_candidate_item_vectors: Optional[np.ndarray] = None # [K, k]
+ last_policy_probs: Optional[np.ndarray] = None # [K]
+ last_chosen_indices: List[int] = field(default_factory=list)
+
+
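+ # Note: the last_* fields carry turn t's query/answer/retrieval context into
+ # turn t+1, where PersonalizedLLM.chat() snapshots them as the pending
+ # REINFORCE update that apply_feedback() later consumes.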
diff --git a/src/personalization/user_model/tensor_store.py b/src/personalization/user_model/tensor_store.py
new file mode 100644
index 0000000..42dbf4e
--- /dev/null
+++ b/src/personalization/user_model/tensor_store.py
@@ -0,0 +1,80 @@
+import numpy as np
+from dataclasses import dataclass
+from typing import Dict, Optional
+import os
+
+@dataclass
+class UserState:
+ user_id: str
+ z_long: np.ndarray # [k]
+ z_short: np.ndarray # [k]
+ reward_ma: float # baseline for reward, init 0.0
+
+class UserTensorStore:
+ def __init__(self, k: int, path: str):
+ self.k = k
+ self.path = path
+ self._states: Dict[str, UserState] = {}
+ self._load()
+
+ # Calculate global mean for initialization
+ if self._states:
+ z_all = np.stack([st.z_long for st in self._states.values()])
+ self.global_init_z = np.mean(z_all, axis=0)
+ else:
+ self.global_init_z = np.zeros(self.k, dtype=np.float32)
+
+ def _load(self):
+ if os.path.exists(self.path):
+ try:
+ data = np.load(self.path, allow_pickle=True)
+ # Flat npz schema, one set of arrays per user:
+ #   "{uid}_long"  -> z_long [k]
+ #   "{uid}_short" -> z_short [k]
+ #   "{uid}_meta"  -> [reward_ma]
+ for key in data.files:
+ if key.endswith("_long"):
+ uid = key[:-5]
+ z_long = data[key]
+ z_short = data.get(f"{uid}_short", np.zeros(self.k))
+ meta = data.get(f"{uid}_meta", np.array([0.0]))
+ self._states[uid] = UserState(uid, z_long, z_short, float(meta[0]))
+ except Exception as e:
+ print(f"Warning: Failed to load UserStore from {self.path}: {e}")
+
+ def _save(self):
+ # Save to npz
+ save_dict = {}
+ for uid, state in self._states.items():
+ save_dict[f"{uid}_long"] = state.z_long
+ save_dict[f"{uid}_short"] = state.z_short
+ save_dict[f"{uid}_meta"] = np.array([state.reward_ma])
+ np.savez(self.path, **save_dict)
+
+ def get_state(self, user_id: str) -> UserState:
+ if user_id not in self._states:
+ # Lazy init with global mean for new users
+ state = UserState(
+ user_id=user_id,
+ z_long=self.global_init_z.copy(),
+ z_short=np.zeros(self.k, dtype=np.float32),
+ reward_ma=0.0,
+ )
+ self._states[user_id] = state
+ return self._states[user_id]
+
+ def save_state(self, state: UserState) -> None:
+ self._states[state.user_id] = state
+
+ def persist(self):
+ """Public method to force save to disk."""
+ self._save()
+
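+ # Round-trip sketch for the flat npz schema above ("{uid}_long", "{uid}_short",
+ # "{uid}_meta"); the path is illustrative. np.savez appends ".npz" when the
+ # filename lacks it, so pass a path that already ends in ".npz" to keep _load()
+ # and _save() pointing at the same file.
+ def _store_sketch() -> None:
+     store = UserTensorStore(k=4, path="/tmp/user_store_example.npz")
+     st = store.get_state("user_1")   # lazily initialised from the global mean
+     st.z_long += 0.1
+     store.save_state(st)
+     store.persist()                  # writes user_1_long / user_1_short / user_1_meta
+     reloaded = UserTensorStore(k=4, path="/tmp/user_store_example.npz")
+     assert np.allclose(reloaded.get_state("user_1").z_long, st.z_long)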
diff --git a/src/personalization/utils/__init__.py b/src/personalization/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/__init__.py
diff --git a/src/personalization/utils/ids.py b/src/personalization/utils/ids.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/ids.py
diff --git a/src/personalization/utils/io.py b/src/personalization/utils/io.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/io.py
diff --git a/src/personalization/utils/logging.py b/src/personalization/utils/logging.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/logging.py
diff --git a/src/personalization/utils/timing.py b/src/personalization/utils/timing.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/timing.py