summaryrefslogtreecommitdiff
path: root/src/personalization/models/preference_extractor
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
commitf918fc90b8d71d1287590b016d926268be573de0 (patch)
treed9009c8612c8e7f866c31d22fb979892a5b55eeb /src/personalization/models/preference_extractor
parent680513b7771a29f27cbbb3ffb009a69a913de6f9 (diff)
Add model wrapper modules (embedding, reranker, llm, preference_extractor)
Add Python wrappers for: - Qwen3/Nemotron embedding models - BGE/Qwen3 rerankers - vLLM/Llama/Qwen LLM backends - GPT-4o/LLM-based preference extractors Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/preference_extractor')
-rw-r--r--src/personalization/models/preference_extractor/__init__.py5
-rw-r--r--src/personalization/models/preference_extractor/base.py17
-rw-r--r--src/personalization/models/preference_extractor/gpt4o_extractor.py97
-rw-r--r--src/personalization/models/preference_extractor/llm_extractor.py153
-rw-r--r--src/personalization/models/preference_extractor/rule_extractor.py152
5 files changed, 424 insertions, 0 deletions
diff --git a/src/personalization/models/preference_extractor/__init__.py b/src/personalization/models/preference_extractor/__init__.py
new file mode 100644
index 0000000..65e2595
--- /dev/null
+++ b/src/personalization/models/preference_extractor/__init__.py
@@ -0,0 +1,5 @@
+# Package surface for preference extractors: re-export the concrete
+# extractors and the abstract base alias so callers can import from the
+# package root.
+from .rule_extractor import QwenRuleExtractor
+from .gpt4o_extractor import GPT4OExtractor
+from .base import PreferenceExtractor
+
+__all__ = ["QwenRuleExtractor", "GPT4OExtractor", "PreferenceExtractor"]
diff --git a/src/personalization/models/preference_extractor/base.py b/src/personalization/models/preference_extractor/base.py
new file mode 100644
index 0000000..850292f
--- /dev/null
+++ b/src/personalization/models/preference_extractor/base.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
+
+class PreferenceExtractorBase(ABC):
+    """Abstract interface for preference extraction backends.
+
+    NOTE(review): the concrete subclasses in this commit disagree on the
+    contract — `GPT4OExtractor.extract_turn` accepts a single ChatTurn,
+    while this base (and PreferenceExtractorLLM / QwenRuleExtractor)
+    declares `List[ChatTurn]`. Confirm the intended signature and align
+    the subclasses.
+    """
+
+    @abstractmethod
+    def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+        """
+        Extract preferences from a window of chat turns (history + current query).
+        """
+        raise NotImplementedError
+
+# Alias for backward compatibility if needed,
+# though specific extractors should inherit from PreferenceExtractorBase now.
+PreferenceExtractor = PreferenceExtractorBase
diff --git a/src/personalization/models/preference_extractor/gpt4o_extractor.py b/src/personalization/models/preference_extractor/gpt4o_extractor.py
new file mode 100644
index 0000000..212bb13
--- /dev/null
+++ b/src/personalization/models/preference_extractor/gpt4o_extractor.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, Dict, List
+
+from openai import OpenAI
+from personalization.config.settings import LocalModelsConfig
+from personalization.models.preference_extractor.base import PreferenceExtractorBase as PreferenceExtractor
+from personalization.retrieval.preference_store.schemas import (
+ ChatTurn,
+ PreferenceList,
+ preference_list_json_schema,
+)
+
+
+class GPT4OExtractor(PreferenceExtractor):
+    """Preference extractor backed by the OpenAI chat-completions API.
+
+    Sends the SFT system prompt plus the user text to GPT-4o with
+    `response_format={"type": "json_object"}` and parses the returned
+    JSON into a PreferenceList.
+    """
+
+    def __init__(self, api_key: str, model: str = "gpt-4o") -> None:
+        """Create an OpenAI client and load the extraction system prompt.
+
+        Args:
+            api_key: OpenAI API key.
+            model: chat model name (default "gpt-4o").
+        """
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+        # Load system prompt template
+        # NOTE(review): path is relative to the current working directory,
+        # not this module — silently falls back if the process is started
+        # from elsewhere. Consider resolving against the repo root.
+        template_path = "fine_tuning_prompt_template.txt"
+        if os.path.exists(template_path):
+            with open(template_path, "r", encoding="utf-8") as f:
+                self.system_prompt = f.read()
+        else:
+            # Fallback simple prompt if file missing
+            self.system_prompt = (
+                "You are a preference extraction assistant. "
+                "Extract user preferences from the query into a JSON object."
+            )
+
+    @classmethod
+    def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor":
+        """Build an extractor from config; the API key comes from the env.
+
+        Raises:
+            ValueError: if OPENAI_API_KEY is not set.
+        """
+        # We rely on env var for API key, config for other potential settings if needed
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            raise ValueError("OPENAI_API_KEY environment variable not set")
+        return cls(api_key=api_key)
+
+    def build_preference_prompt(self, query: str) -> str:
+        """Return the system prompt; `query` is intentionally unused here
+        because GPT-4o receives the query as a separate user message."""
+        # GPT4OExtractor uses the system prompt loaded in __init__
+        return self.system_prompt
+
+    def extract_preferences(self, query: str) -> Dict[str, Any]:
+        """Extract preferences for a raw query string; returns the parsed
+        JSON dict, or {"preferences": []} on any failure (API error,
+        empty content, or unparseable JSON)."""
+        # Reuse logic but return raw dict
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "system", "content": self.system_prompt},
+                    {"role": "user", "content": query},
+                ],
+                response_format={"type": "json_object"},
+                temperature=0.0,  # deterministic extraction
+            )
+            content = response.choices[0].message.content
+            if content:
+                return json.loads(content)
+        except Exception as e:
+            # NOTE(review): broad catch + print swallows auth/network/JSON
+            # errors alike; prefer logging and narrower exception types.
+            print(f"Error calling GPT-4o: {e}")
+        return {"preferences": []}
+
+    def extract_turn(self, turn: ChatTurn) -> PreferenceList:
+        """Extract preferences from a single user turn.
+
+        NOTE(review): signature mismatch with PreferenceExtractorBase,
+        which declares extract_turn(self, turns: List[ChatTurn]) — callers
+        coded against the base interface will pass a list here. Confirm
+        and align.
+        """
+        # Non-user turns carry no extractable preferences by design.
+        if turn.role != "user":
+            return PreferenceList(preferences=[])
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "system", "content": self.system_prompt},
+                    {"role": "user", "content": turn.text},
+                ],
+                response_format={"type": "json_object"},
+                temperature=0.0,
+            )
+
+            content = response.choices[0].message.content
+            if not content:
+                return PreferenceList(preferences=[])
+
+            data = json.loads(content)
+            # The prompt might return {"preferences": [...]}, validate it
+            return PreferenceList.model_validate(data)
+
+        except Exception as e:
+            print(f"Error calling GPT-4o: {e}")
+            return PreferenceList(preferences=[])
+
+    def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
+        """Extract preferences for every turn independently; one
+        PreferenceList per input turn (empty for non-user turns).
+        Makes one API call per user turn."""
+        results = []
+        for turn in turns:
+            results.append(self.extract_turn(turn))
+        return results
+
+
diff --git a/src/personalization/models/preference_extractor/llm_extractor.py b/src/personalization/models/preference_extractor/llm_extractor.py
new file mode 100644
index 0000000..8f7a6cb
--- /dev/null
+++ b/src/personalization/models/preference_extractor/llm_extractor.py
@@ -0,0 +1,153 @@
+from typing import List, Dict, Any
+import torch
+import json
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.models.preference_extractor.base import PreferenceExtractorBase
+from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
+class PreferenceExtractorLLM(PreferenceExtractorBase):
+    """Preference extractor backed by a local HF causal LM (SFT'd Qwen).
+
+    Builds a chat-template prompt from the last 6 turns, runs greedy
+    generation, and parses the first {...} JSON span of the output into
+    a PreferenceList.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        prompt_template_path: str = "fine_tuning_prompt_template.txt",
+        device_map: str = "auto",
+        dtype: torch.dtype = torch.bfloat16,
+        max_new_tokens: int = 512,
+    ) -> None:
+        """Load tokenizer + model and the system prompt template.
+
+        NOTE(review): prompt_template_path is CWD-relative; falls back to
+        a generic prompt with only a print() warning if missing.
+        """
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=dtype,
+            device_map=device_map,
+            trust_remote_code=True,
+        )
+        self.max_new_tokens = max_new_tokens
+
+        if os.path.exists(prompt_template_path):
+            with open(prompt_template_path, "r", encoding="utf-8") as f:
+                self.prompt_template = f.read()
+        else:
+            print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.")
+            self.prompt_template = "Extract user preferences from the following conversation."
+
+    @classmethod
+    def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM":
+        """Stub — NOT implemented.
+
+        NOTE(review): this returns None, violating its own return
+        annotation; any caller will crash with AttributeError later.
+        Either implement it against the settled LocalModelsConfig shape
+        or raise NotImplementedError explicitly.
+        """
+        # We need to access the specific extractor config by name
+        # Assuming cfg has a way to access extra configs or we update LocalModelsConfig to support multiple extractors
+        # For now, let's look for it in the 'preference_extractor' dict if it was a Dict, but it is a ModelSpec.
+        # We need to update LocalModelsConfig to support a dictionary of extractors or a specific one.
+        # Based on user design doc:
+        # preference_extractor:
+        #   qwen3_0_6b_sft: ...
+
+        # We might need to manually parse the raw config or update settings.py
+        # Let's assume settings.py will be updated to hold a map or specific fields.
+        # For now, if we use the existing ModelSpec for preference_extractor in cfg, we assume it points to this model.
+
+        # BUT the design doc says "preference_extractor" in local_models.yaml will have "qwen3_0_6b_sft" key.
+        # The current settings.py defines preference_extractor as a single ModelSpec.
+        # We will need to update settings.py first to support multiple extractors or a dict.
+        # I will proceed implementing this class assuming arguments are passed, and update settings/registry later.
+
+        # This from_config might change depending on how settings.py is refactored.
+        # For now I will implement it assuming a direct ModelSpec is passed, or we handle it in registry.
+        pass
+        return None
+
+    def _build_prompt(self, turns: List[ChatTurn]) -> str:
+        """Render the last 6 turns as "User:/Assistant:" lines inside a
+        single user message, under the SFT system prompt, via the
+        tokenizer's chat template (with generation prompt appended)."""
+        # Construct messages list for chat template
+        messages = [{"role": "system", "content": self.prompt_template}]
+
+        # Window size 6
+        window = turns[-6:]
+
+        # Add conversation history
+        # We need to format the conversation as input context.
+        # Since the task is to extract preferences from the *whole* context (or latest turn?),
+        # usually we provide the conversation and ask for extraction.
+        # But LLaMA-Factory SFT usually expects:
+        # System: <template>
+        # User: <input>
+        # Assistant: <output>
+
+        # We should pack the conversation history into the User message?
+        # Or if we trained with multi-turn chat format?
+        # Assuming "Input" column in dataset was the conversation history.
+
+        history_texts = []
+        for t in window:
+            role = "User" if t.role == "user" else "Assistant"
+            history_texts.append(f"{role}: {t.text}")
+
+        conversation_text = "\n".join(history_texts)
+
+        # Construct the User input
+        # We append a trigger instruction if it wasn't part of the training input implicitly.
+        # But based on your template, the User Input Example was just the query "I am a Python developer..."
+        # So likely we should just feed the conversation text as the user message.
+
+        messages.append({"role": "user", "content": conversation_text})
+
+        # Apply chat template
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        return prompt
+
+    def _generate(self, prompt: str) -> str:
+        """Greedy-decode a completion and return only the generated tail.
+
+        NOTE(review): temperature=0.0 alongside do_sample=False is
+        contradictory (transformers warns; temperature is ignored when
+        not sampling). Also, `full_text[len(prompt):]` assumes
+        decode(encode(prompt)) reproduces the prompt byte-for-byte —
+        safer to slice the output token ids past input length, as
+        QwenRuleExtractor.extract_preferences does.
+        """
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=False,
+                temperature=0.0,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+        full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return full_text[len(prompt):]
+
+    def _parse_preferences(self, raw_output: str) -> PreferenceList:
+        """Parse the first '{' .. last '}' span as JSON; return an empty
+        PreferenceList when no span exists or validation fails."""
+        start = raw_output.find("{")
+        end = raw_output.rfind("}")
+
+        if start == -1 or end == -1 or end <= start:
+            return PreferenceList(preferences=[])
+
+        json_str = raw_output[start:end+1]
+        try:
+            data = json.loads(json_str)
+            return PreferenceList.model_validate(data)
+        except Exception:
+            # Malformed or schema-violating JSON → treat as "no preferences".
+            return PreferenceList(preferences=[])
+
+    def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+        """Extract preferences from a turn window (base-class contract):
+        prompt → generate → parse."""
+        prompt = self._build_prompt(turns)
+        raw_output = self._generate(prompt)
+        return self._parse_preferences(raw_output)
+
+    # Legacy support
+    def build_preference_prompt(self, query: str) -> str:
+        """Legacy shim: wrap a bare query string in a synthetic user
+        ChatTurn and return the rendered prompt."""
+        # Wrap query in a dummy turn
+        turn = ChatTurn(
+            user_id="dummy", session_id="dummy", turn_id=0,
+            role="user", text=query
+        )
+        return self._build_prompt([turn])
+
+    def extract_preferences(self, query: str) -> Dict[str, Any]:
+        """Legacy shim: extract from a bare query string and return the
+        PreferenceList as a plain dict."""
+        turn = ChatTurn(
+            user_id="dummy", session_id="dummy", turn_id=0,
+            role="user", text=query
+        )
+        prefs = self.extract_turn([turn])
+        return prefs.model_dump()
+
+
diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py
new file mode 100644
index 0000000..0f743d9
--- /dev/null
+++ b/src/personalization/models/preference_extractor/rule_extractor.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import json
+import re
+import os
+from typing import Any, Dict, List
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import PreferenceExtractor
+from personalization.retrieval.preference_store.schemas import (
+ PreferenceList,
+ preference_list_json_schema,
+ ChatTurn,
+)
+
+# Hardcoded System Prompt to match SFT training
+# This MUST match what was used in training (scripts/split_train_test.py)
+SFT_SYSTEM_PROMPT = (
+    "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
+    "If no preferences are found, return {\"preferences\": []}."
+)
+
+class QwenRuleExtractor(PreferenceExtractor):
+    """
+    Extractor using a Fine-Tuned (SFT) Qwen model.
+    Despite the name 'RuleExtractor' (legacy), this now performs direct End-to-End extraction.
+    """
+    def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
+        """Load tokenizer + SFT model and ensure a usable pad token.
+
+        NOTE(review): this passes `dtype=` to from_pretrained while
+        llm_extractor.py passes `torch_dtype=` — the accepted keyword
+        depends on the transformers version (torch_dtype was renamed to
+        dtype in newer releases); confirm against the pinned version and
+        make the two wrappers consistent.
+        """
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, use_fast=True, trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            dtype=dtype,
+            device_map=device_map,
+            trust_remote_code=True,
+        )
+        # Some causal LMs ship without a pad token; generation needs one.
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    @classmethod
+    def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor":
+        """Build the extractor from cfg.preference_extractor (ModelSpec):
+        resolves dtype/device_map via the registry helpers."""
+        spec = cfg.preference_extractor
+        dtype = choose_dtype(spec.dtype)
+        device_map = choose_device_map(spec.device_map)
+        return cls(spec.local_path, dtype=dtype, device_map=device_map)
+
+    def build_preference_prompt(self, query: str) -> str:
+        """
+        Construct the prompt string using the tokenizer's chat template.
+        Matches the format seen during SFT training.
+        """
+        messages = [
+            {"role": "system", "content": SFT_SYSTEM_PROMPT},
+            {"role": "user", "content": query}
+        ]
+        prompt = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        return prompt
+
+    @torch.inference_mode()
+    def extract_preferences(self, query: str) -> Dict[str, Any]:
+        """
+        Directly extract preferences from query using the SFT model.
+        Returns a dict compatible with PreferenceList model (key: 'preferences').
+        """
+        prompt = self.build_preference_prompt(query)
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        outputs = self.model.generate(
+            **inputs,
+            do_sample=False,          # Deterministic greedy decoding
+            max_new_tokens=512,       # Allow enough space for JSON
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )
+
+        # Slice generated tokens past the input length (robust, unlike
+        # string-slicing by prompt length).
+        input_len = inputs["input_ids"].shape[1]
+        gen_ids = outputs[0][input_len:]
+        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+        # Opt-in debug dump of the raw model output.
+        if os.getenv("PREF_DEBUG") == "1":
+            print(f"[debug][extractor] Raw output: {text}")
+
+        # Try parsing JSON
+        try:
+            # 1. Direct parse
+            data = json.loads(text)
+
+            # 2. Validate against schema structure
+            validated = PreferenceList.model_validate(data)
+            return validated.model_dump()
+
+        except Exception:
+            # Fallback: Try to find JSON blob if model outputted extra text (rare for SFT but possible)
+            extracted_json = self._extract_json_substring(text)
+            if extracted_json:
+                try:
+                    data = json.loads(extracted_json)
+                    validated = PreferenceList.model_validate(data)
+                    return validated.model_dump()
+                except:  # NOTE(review): bare except also swallows SystemExit/KeyboardInterrupt — narrow to Exception.
+                    pass
+
+            # If all fails, return empty
+            return {"preferences": []}
+
+    def _extract_json_substring(self, text: str) -> str | None:
+        """Helper to find { ... } block in text. Returns None when no
+        plausible JSON span exists."""
+        # Find first '{' and last '}'
+        start = text.find('{')
+        end = text.rfind('}')
+        if start != -1 and end != -1 and end > start:
+            return text[start : end + 1]
+        return None
+
+    def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+        """
+        Extract preferences from the LAST user turn in the history.
+        We don't concat history because our SFT model was trained on single-turn extraction.
+        Using context might confuse it unless we trained it that way.
+        """
+        # Find the last user message
+        last_user_msg = None
+        for t in reversed(turns):
+            if t.role == "user":
+                last_user_msg = t.text
+                break
+
+        # No user turn in the window → nothing to extract.
+        if not last_user_msg:
+            return PreferenceList(preferences=[])
+
+        result_dict = self.extract_preferences(last_user_msg)
+        return PreferenceList.model_validate(result_dict)
+
+    def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
+        """
+        Extract preferences from ALL user turns individually.
+        Output list is positionally aligned with `turns`; non-user turns
+        yield an empty PreferenceList. Runs one generation per user turn.
+        """
+        results = []
+        for turn in turns:
+            if turn.role == "user":
+                res = self.extract_preferences(turn.text)
+                results.append(PreferenceList.model_validate(res))
+            else:
+                results.append(PreferenceList(preferences=[]))
+        return results