summaryrefslogtreecommitdiff
path: root/src/personalization/models/preference_extractor/rule_extractor.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/preference_extractor/rule_extractor.py')
-rw-r--r--src/personalization/models/preference_extractor/rule_extractor.py152
1 files changed, 152 insertions, 0 deletions
diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py
new file mode 100644
index 0000000..0f743d9
--- /dev/null
+++ b/src/personalization/models/preference_extractor/rule_extractor.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import json
+import re
+import os
+from typing import Any, Dict, List
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import PreferenceExtractor
+from personalization.retrieval.preference_store.schemas import (
+ PreferenceList,
+ preference_list_json_schema,
+ ChatTurn,
+)
+
+# Hardcoded System Prompt to match SFT training
+# This MUST match what was used in training (scripts/split_train_test.py)
+SFT_SYSTEM_PROMPT = (
+ "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
+ "If no preferences are found, return {\"preferences\": []}."
+)
+
+class QwenRuleExtractor(PreferenceExtractor):
+ """
+ Extractor using a Fine-Tuned (SFT) Qwen model.
+ Despite the name 'RuleExtractor' (legacy), this now performs direct End-to-End extraction.
+ """
+ def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=True
+ )
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor":
+ spec = cfg.preference_extractor
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, dtype=dtype, device_map=device_map)
+
+ def build_preference_prompt(self, query: str) -> str:
+ """
+ Construct the prompt string using the tokenizer's chat template.
+ Matches the format seen during SFT training.
+ """
+ messages = [
+ {"role": "system", "content": SFT_SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ]
+ prompt = self.tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ return prompt
+
+ @torch.inference_mode()
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ """
+ Directly extract preferences from query using the SFT model.
+ Returns a dict compatible with PreferenceList model (key: 'preferences').
+ """
+ prompt = self.build_preference_prompt(query)
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+ outputs = self.model.generate(
+ **inputs,
+ do_sample=False, # Deterministic greedy decoding
+ max_new_tokens=512, # Allow enough space for JSON
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ )
+
+ input_len = inputs["input_ids"].shape[1]
+ gen_ids = outputs[0][input_len:]
+ text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+ if os.getenv("PREF_DEBUG") == "1":
+ print(f"[debug][extractor] Raw output: {text}")
+
+ # Try parsing JSON
+ try:
+ # 1. Direct parse
+ data = json.loads(text)
+
+ # 2. Validate against schema structure
+ validated = PreferenceList.model_validate(data)
+ return validated.model_dump()
+
+ except Exception:
+ # Fallback: Try to find JSON blob if model outputted extra text (rare for SFT but possible)
+ extracted_json = self._extract_json_substring(text)
+ if extracted_json:
+ try:
+ data = json.loads(extracted_json)
+ validated = PreferenceList.model_validate(data)
+ return validated.model_dump()
+ except:
+ pass
+
+ # If all fails, return empty
+ return {"preferences": []}
+
+ def _extract_json_substring(self, text: str) -> str | None:
+ """Helper to find { ... } block in text."""
+ # Find first '{' and last '}'
+ start = text.find('{')
+ end = text.rfind('}')
+ if start != -1 and end != -1 and end > start:
+ return text[start : end + 1]
+ return None
+
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ """
+ Extract preferences from the LAST user turn in the history.
+ We don't concat history because our SFT model was trained on single-turn extraction.
+ Using context might confuse it unless we trained it that way.
+ """
+ # Find the last user message
+ last_user_msg = None
+ for t in reversed(turns):
+ if t.role == "user":
+ last_user_msg = t.text
+ break
+
+ if not last_user_msg:
+ return PreferenceList(preferences=[])
+
+ result_dict = self.extract_preferences(last_user_msg)
+ return PreferenceList.model_validate(result_dict)
+
+ def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
+ """
+ Extract preferences from ALL user turns individually.
+ """
+ results = []
+ for turn in turns:
+ if turn.role == "user":
+ res = self.extract_preferences(turn.text)
+ results.append(PreferenceList.model_validate(res))
+ else:
+ results.append(PreferenceList(preferences=[]))
+ return results