summaryrefslogtreecommitdiff
path: root/scripts/migrate_preferences.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/migrate_preferences.py')
-rw-r--r--scripts/migrate_preferences.py165
1 files changed, 165 insertions, 0 deletions
diff --git a/scripts/migrate_preferences.py b/scripts/migrate_preferences.py
new file mode 100644
index 0000000..5d393c9
--- /dev/null
+++ b/scripts/migrate_preferences.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+"""
+Script to migrate labeled queries into MemoryCards using their pre-extracted preferences.
+It reads from data/corpora/oasst1_labeled.jsonl and outputs:
+- data/corpora/memory_cards.jsonl
+- data/corpora/memory_embeddings.npy
+"""
+
+import json
+import os
+import sys
+
+# Add src to sys.path so we can import personalization
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+import uuid
+import numpy as np
+import torch
+from pathlib import Path
+from tqdm import tqdm
+from typing import List
+
+from personalization.config.settings import load_local_models_config
+# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard, PreferenceList
+
+def ensure_dir(path: str) -> None:
+    """Create the parent directory of `path`, including missing ancestors, if absent."""
+    # Note: it is the *parent* of `path` that is created, so `path` itself is
+    # expected to be a file path rather than a directory.
+    parent_dir = Path(path).parent
+    parent_dir.mkdir(parents=True, exist_ok=True)
+
+def main() -> None:
+    """
+    Build MemoryCards from labeled queries and save them with their embeddings.
+
+    Reads data/corpora/oasst1_labeled.jsonl (one JSON object per line). A row is
+    skipped when its line is blank, its "original_query" is empty, its
+    "has_preference" flag is falsy, its "extracted_json" fails PreferenceList
+    validation, or the validated preference list is empty. Every remaining row
+    becomes one MemoryCard, embedded from its original query text.
+
+    Outputs (paths are relative to the current working directory):
+    - data/corpora/memory_cards.jsonl: one serialized MemoryCard per line
+    - data/corpora/memory_embeddings.npy: float32 matrix, one row per card,
+      in the same order as the cards file
+
+    Raises:
+        RuntimeError: if the embedder returns a different number of vectors
+            than there are cards (nothing is written in that case).
+    """
+    # 1. Setup paths
+    input_path = "data/corpora/oasst1_labeled.jsonl"
+    output_cards_path = "data/corpora/memory_cards.jsonl"
+    output_emb_path = "data/corpora/memory_embeddings.npy"
+    ensure_dir(output_cards_path)
+    ensure_dir(output_emb_path)
+
+    print("Loading models configuration...")
+    cfg = load_local_models_config()
+
+    # 2. Initialize models
+    # Only the embedder is needed: preferences are already present in the
+    # labeled input, so no extractor model is loaded here.
+    print("Initializing Embedding Model...")
+    embedder = Qwen3Embedding8B.from_config(cfg)
+
+    # 3. Process data
+    print(f"Reading from {input_path}...")
+    memory_cards: List[MemoryCard] = []
+    invalid_rows = 0  # rows labeled as having a preference but failing validation
+
+    with open(input_path, "r", encoding="utf-8") as f:
+        lines = f.readlines()
+
+    print("Extracting preferences...")
+    for idx, line in enumerate(tqdm(lines)):
+        # Tolerate blank lines instead of crashing in json.loads.
+        if not line.strip():
+            continue
+
+        row = json.loads(line)
+        # `or ""` also covers an explicit null value for the query.
+        query = (row.get("original_query") or "").strip()
+        if not query:
+            continue
+
+        # Skip if no preference (according to label)
+        if not row.get("has_preference", False):
+            continue
+
+        try:
+            pref_list = PreferenceList.model_validate(row.get("extracted_json", {}))
+        except Exception:
+            # Best-effort: drop the row, but count it so the loss is visible.
+            invalid_rows += 1
+            continue
+
+        if not pref_list.preferences:
+            continue
+
+        # Use metadata from the dataset, falling back to ids derived from the
+        # line index when a field is missing.
+        user_id = row.get("user_id", f"user_{idx}")
+        session_id = row.get("session_id", f"sess_{idx}")
+        turn_id = row.get("turn_id", 0)
+
+        # Construct a note text: "condition: action"
+        note_summary = "; ".join(
+            f"{p.condition}: {p.action}" for p in pref_list.preferences
+        )
+
+        # Create MemoryCard (embedding will be filled later)
+        memory_cards.append(
+            MemoryCard(
+                card_id=str(uuid.uuid4()),
+                user_id=user_id,
+                source_session_id=session_id,
+                source_turn_ids=[turn_id],
+                raw_queries=[query],
+                preference_list=pref_list,
+                note_text=note_summary,
+                embedding_e=[],  # To be filled
+                kind="pref",
+            )
+        )
+
+    if invalid_rows:
+        print(f"Skipped {invalid_rows} rows whose extracted_json failed validation.")
+
+    print(f"Found {len(memory_cards)} memory cards. Generating embeddings...")
+
+    if not memory_cards:
+        print("No preferences found. Exiting.")
+        return
+
+    # 4. Generate Embeddings
+    # Per the design doc ("Qwen3Embedding8B.encode([turn.text])") we embed the
+    # original query that generated the memory, not the derived note_text.
+    texts_to_embed = [card.raw_queries[0] for card in memory_cards]
+
+    print(f"Embedding {len(texts_to_embed)} memories...")
+    embeddings_list = []
+    chunk_size = 2000  # Process in chunks to avoid OOM
+
+    for i in range(0, len(texts_to_embed), chunk_size):
+        print(f" Embedding chunk {i} to {min(i+chunk_size, len(texts_to_embed))}...")
+        chunk = texts_to_embed[i : i + chunk_size]
+
+        # Batch encode with larger batch_size for A40
+        chunk_emb = embedder.encode(
+            chunk,
+            batch_size=128,
+            normalize=True,
+            return_tensor=False,
+        )
+        embeddings_list.extend(chunk_emb)
+
+    # zip() would silently truncate on a length mismatch and leave the cards
+    # file and the embedding matrix misaligned, so fail loudly instead.
+    if len(embeddings_list) != len(memory_cards):
+        raise RuntimeError(
+            f"Embedder returned {len(embeddings_list)} vectors "
+            f"for {len(memory_cards)} memory cards"
+        )
+
+    # Assign back to cards
+    for card, emb in zip(memory_cards, embeddings_list):
+        card.embedding_e = emb
+
+    # 5. Save
+    print(f"Saving {len(memory_cards)} cards to {output_cards_path}...")
+    with open(output_cards_path, "w", encoding="utf-8") as f:
+        for card in memory_cards:
+            f.write(card.model_dump_json() + "\n")
+
+    print(f"Saving embeddings matrix to {output_emb_path}...")
+    np_emb = np.array(embeddings_list, dtype=np.float32)
+    np.save(output_emb_path, np_emb)
+
+    print("Done!")
+
+# Run the migration only when this file is executed as a script, not on import.
+if __name__ == "__main__":
+ main()
+