diff options
Diffstat (limited to 'src/personalization/models/preference_extractor')
5 files changed, 545 insertions, 0 deletions
diff --git a/src/personalization/models/preference_extractor/__init__.py b/src/personalization/models/preference_extractor/__init__.py new file mode 100644 index 0000000..65e2595 --- /dev/null +++ b/src/personalization/models/preference_extractor/__init__.py @@ -0,0 +1,5 @@ +from .rule_extractor import QwenRuleExtractor +from .gpt4o_extractor import GPT4OExtractor +from .base import PreferenceExtractor + +__all__ = ["QwenRuleExtractor", "GPT4OExtractor", "PreferenceExtractor"] diff --git a/src/personalization/models/preference_extractor/base.py b/src/personalization/models/preference_extractor/base.py new file mode 100644 index 0000000..850292f --- /dev/null +++ b/src/personalization/models/preference_extractor/base.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List +from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList + +class PreferenceExtractorBase(ABC): + @abstractmethod + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: + """ + Extract preferences from a window of chat turns (history + current query). + """ + raise NotImplementedError + +# Alias for backward compatibility if needed, +# though specific extractors should inherit from PreferenceExtractorBase now. +PreferenceExtractor = PreferenceExtractorBase diff --git a/src/personalization/models/preference_extractor/gpt4o_extractor.py b/src/personalization/models/preference_extractor/gpt4o_extractor.py new file mode 100644 index 0000000..0f70522 --- /dev/null +++ b/src/personalization/models/preference_extractor/gpt4o_extractor.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import json +import os +from typing import Any, Dict, List + +from openai import OpenAI +from personalization.config.settings import LocalModelsConfig +from personalization.models.preference_extractor.base import PreferenceExtractorBase as PreferenceExtractor +from personalization.retrieval.preference_store.schemas import ( + ChatTurn, + PreferenceList, + preference_list_json_schema, +) + + +class GPT4OExtractor(PreferenceExtractor): + def __init__(self, api_key: str, model: str = "gpt-4o") -> None: + self.client = OpenAI(api_key=api_key) + self.model = model + + # Load system prompt template + template_path = "fine_tuning_prompt_template.txt" + if os.path.exists(template_path): + with open(template_path, "r", encoding="utf-8") as f: + self.system_prompt = f.read() + else: + # Structured prompt that enforces the PreferenceList schema + self.system_prompt = ( + "You are a preference extraction assistant. " + "Given a user message, extract any user preferences as condition-action rules.\n\n" + "Return a JSON object with exactly this structure:\n" + '{"preferences": [{"condition": "<when this applies>", "action": "<what to do>", "confidence": <0.0-1.0>}]}\n\n' + "Examples of preferences:\n" + '- {"condition": "general", "action": "respond in Chinese", "confidence": 0.9}\n' + '- {"condition": "when writing code", "action": "use Python with type hints", "confidence": 0.8}\n' + '- {"condition": "when explaining math", "action": "show step-by-step derivation", "confidence": 0.7}\n\n' + "If no preferences are found, return {\"preferences\": []}.\n" + "IMPORTANT: The output MUST be a JSON object with a \"preferences\" key containing a list." + ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor": + # We rely on env var for API key, config for other potential settings if needed + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY environment variable not set") + return cls(api_key=api_key) + + def build_preference_prompt(self, query: str) -> str: + # GPT4OExtractor uses the system prompt loaded in __init__ + return self.system_prompt + + def _call_kwargs(self, messages): + """Build kwargs for chat completion, skipping temperature for models that don't support it.""" + kwargs = { + "model": self.model, + "messages": messages, + "response_format": {"type": "json_object"}, + } + # GPT-5 series doesn't support temperature=0 + if not self.model.startswith("gpt-5"): + kwargs["temperature"] = 0.0 + return kwargs + + def extract_preferences(self, query: str) -> Dict[str, Any]: + # Reuse logic but return raw dict + try: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": query}, + ] + response = self.client.chat.completions.create(**self._call_kwargs(messages)) + content = response.choices[0].message.content + if content: + return json.loads(content) + except Exception as e: + print(f"Error calling GPT-4o: {e}") + return {"preferences": []} + + def extract_turn(self, turns) -> PreferenceList: + # Accept both a single ChatTurn and a list of ChatTurns (history) + if isinstance(turns, list): + # Find the last user message in history + last_user_msg = None + for t in reversed(turns): + if hasattr(t, 'role') and t.role == "user": + last_user_msg = t.text + break + if not last_user_msg: + return PreferenceList(preferences=[]) + else: + # Single ChatTurn + if turns.role != "user": + return PreferenceList(preferences=[]) + last_user_msg = turns.text + + try: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": last_user_msg}, + ] + response = self.client.chat.completions.create(**self._call_kwargs(messages)) + + content = response.choices[0].message.content + if not content: + return PreferenceList(preferences=[]) + + data = json.loads(content) + return self._parse_to_preference_list(data) + + except Exception as e: + print(f"Error calling GPT-4o: {e}") + return PreferenceList(preferences=[]) + + @staticmethod + def _parse_to_preference_list(data: dict) -> PreferenceList: + """Robustly convert GPT output to PreferenceList, handling non-standard formats.""" + # Best case: already matches schema + if "preferences" in data and isinstance(data["preferences"], list): + prefs = [] + for item in data["preferences"]: + if isinstance(item, dict) and "condition" in item and "action" in item: + prefs.append({ + "condition": str(item["condition"])[:128], + "action": str(item["action"])[:256], + "confidence": float(item.get("confidence", 0.7)), + }) + return PreferenceList.model_validate({"preferences": prefs}) + + # GPT returned a flat dict of preferences - convert to condition/action pairs + prefs = [] + for key, value in data.items(): + if isinstance(value, str) and len(value) > 2: + prefs.append({ + "condition": str(key)[:128] if len(str(key)) > 1 else "general", + "action": str(value)[:256], + "confidence": 0.7, + }) + elif isinstance(value, dict): + # Nested dict: try to extract meaningful pairs + for sub_key, sub_val in value.items(): + if isinstance(sub_val, str) and len(sub_val) > 2: + prefs.append({ + "condition": str(sub_key)[:128], + "action": str(sub_val)[:256], + "confidence": 0.7, + }) + elif isinstance(value, list): + for item in value: + if isinstance(item, str) and len(item) > 2: + prefs.append({ + "condition": str(key)[:128], + "action": str(item)[:256], + "confidence": 0.7, + }) + + return PreferenceList.model_validate({"preferences": prefs[:20]}) + + def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]: + results = [] + for turn in turns: + results.append(self.extract_turn(turn)) + return results + diff --git a/src/personalization/models/preference_extractor/llm_extractor.py b/src/personalization/models/preference_extractor/llm_extractor.py new file mode 100644 index 0000000..8f7a6cb --- /dev/null +++ b/src/personalization/models/preference_extractor/llm_extractor.py @@ -0,0 +1,153 @@ +from typing import List, Dict, Any +import torch +import json +import os +from transformers import AutoModelForCausalLM, AutoTokenizer + +from personalization.models.preference_extractor.base import PreferenceExtractorBase +from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList +from personalization.config.settings import LocalModelsConfig +from personalization.config.registry import choose_dtype, choose_device_map + +class PreferenceExtractorLLM(PreferenceExtractorBase): + def __init__( + self, + model_path: str, + prompt_template_path: str = "fine_tuning_prompt_template.txt", + device_map: str = "auto", + dtype: torch.dtype = torch.bfloat16, + max_new_tokens: int = 512, + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + self.max_new_tokens = max_new_tokens + + if os.path.exists(prompt_template_path): + with open(prompt_template_path, "r", encoding="utf-8") as f: + self.prompt_template = f.read() + else: + print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.") + self.prompt_template = "Extract user preferences from the following conversation." + + @classmethod + def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM": + # We need to access the specific extractor config by name + # Assuming cfg has a way to access extra configs or we update LocalModelsConfig to support multiple extractors + # For now, let's look for it in the 'preference_extractor' dict if it was a Dict, but it is a ModelSpec. + # We need to update LocalModelsConfig to support a dictionary of extractors or a specific one. + # Based on user design doc: + # preference_extractor: + # qwen3_0_6b_sft: ... + + # We might need to manually parse the raw config or update settings.py + # Let's assume settings.py will be updated to hold a map or specific fields. + # For now, if we use the existing ModelSpec for preference_extractor in cfg, we assume it points to this model. + + # BUT the design doc says "preference_extractor" in local_models.yaml will have "qwen3_0_6b_sft" key. + # The current settings.py defines preference_extractor as a single ModelSpec. + # We will need to update settings.py first to support multiple extractors or a dict. + # I will proceed implementing this class assuming arguments are passed, and update settings/registry later. + + # This from_config might change depending on how settings.py is refactored. + # For now I will implement it assuming a direct ModelSpec is passed, or we handle it in registry. + pass + return None + + def _build_prompt(self, turns: List[ChatTurn]) -> str: + # Construct messages list for chat template + messages = [{"role": "system", "content": self.prompt_template}] + + # Window size 6 + window = turns[-6:] + + # Add conversation history + # We need to format the conversation as input context. + # Since the task is to extract preferences from the *whole* context (or latest turn?), + # usually we provide the conversation and ask for extraction. + # But LLaMA-Factory SFT usually expects: + # System: <template> + # User: <input> + # Assistant: <output> + + # We should pack the conversation history into the User message? + # Or if we trained with multi-turn chat format? + # Assuming "Input" column in dataset was the conversation history. + + history_texts = [] + for t in window: + role = "User" if t.role == "user" else "Assistant" + history_texts.append(f"{role}: {t.text}") + + conversation_text = "\n".join(history_texts) + + # Construct the User input + # We append a trigger instruction if it wasn't part of the training input implicitly. + # But based on your template, the User Input Example was just the query "I am a Python developer..." + # So likely we should just feed the conversation text as the user message. + + messages.append({"role": "user", "content": conversation_text}) + + # Apply chat template + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + return prompt + + def _generate(self, prompt: str) -> str: + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + temperature=0.0, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return full_text[len(prompt):] + + def _parse_preferences(self, raw_output: str) -> PreferenceList: + start = raw_output.find("{") + end = raw_output.rfind("}") + + if start == -1 or end == -1 or end <= start: + return PreferenceList(preferences=[]) + + json_str = raw_output[start:end+1] + try: + data = json.loads(json_str) + return PreferenceList.model_validate(data) + except Exception: + return PreferenceList(preferences=[]) + + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: + prompt = self._build_prompt(turns) + raw_output = self._generate(prompt) + return self._parse_preferences(raw_output) + + # Legacy support + def build_preference_prompt(self, query: str) -> str: + # Wrap query in a dummy turn + turn = ChatTurn( + user_id="dummy", session_id="dummy", turn_id=0, + role="user", text=query + ) + return self._build_prompt([turn]) + + def extract_preferences(self, query: str) -> Dict[str, Any]: + turn = ChatTurn( + user_id="dummy", session_id="dummy", turn_id=0, + role="user", text=query + ) + prefs = self.extract_turn([turn]) + return prefs.model_dump() + diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py new file mode 100644 index 0000000..42f43ed --- /dev/null +++ b/src/personalization/models/preference_extractor/rule_extractor.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import json +import re +import os +from typing import Any, Dict, List + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from personalization.config.registry import choose_dtype, choose_device_map +from personalization.config.settings import LocalModelsConfig +from .base import PreferenceExtractor +from personalization.retrieval.preference_store.schemas import ( + PreferenceList, + preference_list_json_schema, + ChatTurn, +) + +# Hardcoded System Prompt to match SFT training +# This MUST match what was used in training (scripts/split_train_test.py) +SFT_SYSTEM_PROMPT = ( + "Extract user preferences from the query into JSON format based on the PreferenceList schema. " + "If no preferences are found, return {\"preferences\": []}." +) + +class QwenRuleExtractor(PreferenceExtractor): + """ + Extractor using a Fine-Tuned (SFT) Qwen model. + Despite the name 'RuleExtractor' (legacy), this now performs direct End-to-End extraction. + """ + def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=True + ) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor": + spec = cfg.preference_extractor + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls(spec.local_path, dtype=dtype, device_map=device_map) + + def build_preference_prompt(self, query: str) -> str: + """ + Construct the prompt string using the tokenizer's chat template. + Matches the format seen during SFT training. + """ + messages = [ + {"role": "system", "content": SFT_SYSTEM_PROMPT}, + {"role": "user", "content": query} + ] + prompt = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return prompt + + @torch.inference_mode() + def extract_preferences(self, query: str) -> Dict[str, Any]: + """ + Directly extract preferences from query using the SFT model. + Returns a dict compatible with PreferenceList model (key: 'preferences'). + """ + prompt = self.build_preference_prompt(query) + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + + outputs = self.model.generate( + **inputs, + do_sample=False, # Deterministic greedy decoding + max_new_tokens=512, # Allow enough space for JSON + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + + input_len = inputs["input_ids"].shape[1] + gen_ids = outputs[0][input_len:] + text = self.tokenizer.decode(gen_ids, skip_special_tokens=True) + + if os.getenv("PREF_DEBUG") == "1": + print(f"[debug][extractor] Raw output: {text}") + + # Try parsing JSON + try: + # 1. Direct parse + data = json.loads(text) + + # 2. Validate against schema structure + validated = PreferenceList.model_validate(data) + return validated.model_dump() + + except Exception: + # Fallback: Try to find JSON blob if model outputted extra text (rare for SFT but possible) + extracted_json = self._extract_json_substring(text) + if extracted_json: + try: + data = json.loads(extracted_json) + validated = PreferenceList.model_validate(data) + return validated.model_dump() + except: + pass + + # If all fails, return empty + return {"preferences": []} + + def _extract_json_substring(self, text: str) -> str | None: + """Helper to find { ... } block in text.""" + # Find first '{' and last '}' + start = text.find('{') + end = text.rfind('}') + if start != -1 and end != -1 and end > start: + return text[start : end + 1] + return None + + @torch.inference_mode() + def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]: + """ + Batch extract preferences from multiple queries using left-padded batching. + """ + if not queries: + return [] + + # Save and set padding side for decoder-only batched generation + orig_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + + all_results = [] + prompts = [self.build_preference_prompt(q) for q in queries] + + for start in range(0, len(prompts), batch_size): + batch_prompts = prompts[start:start + batch_size] + inputs = self.tokenizer( + batch_prompts, return_tensors="pt", padding=True, truncation=True + ).to(self.model.device) + + outputs = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=512, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + + for i in range(len(batch_prompts)): + input_len = (inputs["attention_mask"][i] == 1).sum().item() + gen_ids = outputs[i][input_len:] + text = self.tokenizer.decode(gen_ids, skip_special_tokens=True) + + try: + data = json.loads(text) + validated = PreferenceList.model_validate(data) + all_results.append(validated.model_dump()) + except Exception: + extracted_json = self._extract_json_substring(text) + if extracted_json: + try: + data = json.loads(extracted_json) + validated = PreferenceList.model_validate(data) + all_results.append(validated.model_dump()) + continue + except Exception: + pass + all_results.append({"preferences": []}) + + self.tokenizer.padding_side = orig_padding_side + return all_results + + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: + """ + Extract preferences from the LAST user turn in the history. + We don't concat history because our SFT model was trained on single-turn extraction. + Using context might confuse it unless we trained it that way. + """ + # Find the last user message + last_user_msg = None + for t in reversed(turns): + if t.role == "user": + last_user_msg = t.text + break + + if not last_user_msg: + return PreferenceList(preferences=[]) + + result_dict = self.extract_preferences(last_user_msg) + return PreferenceList.model_validate(result_dict) + + def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]: + """ + Extract preferences from ALL user turns individually. + """ + results = [] + for turn in turns: + if turn.role == "user": + res = self.extract_preferences(turn.text) + results.append(PreferenceList.model_validate(res)) + else: + results.append(PreferenceList(preferences=[])) + return results |
