summaryrefslogtreecommitdiff
path: root/src/personalization/models/preference_extractor/llm_extractor.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
commitf918fc90b8d71d1287590b016d926268be573de0 (patch)
treed9009c8612c8e7f866c31d22fb979892a5b55eeb /src/personalization/models/preference_extractor/llm_extractor.py
parent680513b7771a29f27cbbb3ffb009a69a913de6f9 (diff)
Add model wrapper modules (embedding, reranker, llm, preference_extractor)
Add Python wrappers for: - Qwen3/Nemotron embedding models - BGE/Qwen3 rerankers - vLLM/Llama/Qwen LLM backends - GPT-4o/LLM-based preference extractors Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/preference_extractor/llm_extractor.py')
-rw-r--r--src/personalization/models/preference_extractor/llm_extractor.py153
1 files changed, 153 insertions, 0 deletions
diff --git a/src/personalization/models/preference_extractor/llm_extractor.py b/src/personalization/models/preference_extractor/llm_extractor.py
new file mode 100644
index 0000000..8f7a6cb
--- /dev/null
+++ b/src/personalization/models/preference_extractor/llm_extractor.py
@@ -0,0 +1,153 @@
+from typing import List, Dict, Any
+import torch
+import json
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.models.preference_extractor.base import PreferenceExtractorBase
+from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
+class PreferenceExtractorLLM(PreferenceExtractorBase):
+ def __init__(
+ self,
+ model_path: str,
+ prompt_template_path: str = "fine_tuning_prompt_template.txt",
+ device_map: str = "auto",
+ dtype: torch.dtype = torch.bfloat16,
+ max_new_tokens: int = 512,
+ ) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+ self.max_new_tokens = max_new_tokens
+
+ if os.path.exists(prompt_template_path):
+ with open(prompt_template_path, "r", encoding="utf-8") as f:
+ self.prompt_template = f.read()
+ else:
+ print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.")
+ self.prompt_template = "Extract user preferences from the following conversation."
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM":
+ # We need to access the specific extractor config by name
+ # Assuming cfg has a way to access extra configs or we update LocalModelsConfig to support multiple extractors
+ # For now, let's look for it in the 'preference_extractor' dict if it was a Dict, but it is a ModelSpec.
+ # We need to update LocalModelsConfig to support a dictionary of extractors or a specific one.
+ # Based on user design doc:
+ # preference_extractor:
+ # qwen3_0_6b_sft: ...
+
+ # We might need to manually parse the raw config or update settings.py
+ # Let's assume settings.py will be updated to hold a map or specific fields.
+ # For now, if we use the existing ModelSpec for preference_extractor in cfg, we assume it points to this model.
+
+ # BUT the design doc says "preference_extractor" in local_models.yaml will have "qwen3_0_6b_sft" key.
+ # The current settings.py defines preference_extractor as a single ModelSpec.
+ # We will need to update settings.py first to support multiple extractors or a dict.
+ # I will proceed implementing this class assuming arguments are passed, and update settings/registry later.
+
+ # This from_config might change depending on how settings.py is refactored.
+ # For now I will implement it assuming a direct ModelSpec is passed, or we handle it in registry.
+ pass
+ return None
+
+ def _build_prompt(self, turns: List[ChatTurn]) -> str:
+ # Construct messages list for chat template
+ messages = [{"role": "system", "content": self.prompt_template}]
+
+ # Window size 6
+ window = turns[-6:]
+
+ # Add conversation history
+ # We need to format the conversation as input context.
+ # Since the task is to extract preferences from the *whole* context (or latest turn?),
+ # usually we provide the conversation and ask for extraction.
+ # But LLaMA-Factory SFT usually expects:
+ # System: <template>
+ # User: <input>
+ # Assistant: <output>
+
+ # We should pack the conversation history into the User message?
+ # Or if we trained with multi-turn chat format?
+ # Assuming "Input" column in dataset was the conversation history.
+
+ history_texts = []
+ for t in window:
+ role = "User" if t.role == "user" else "Assistant"
+ history_texts.append(f"{role}: {t.text}")
+
+ conversation_text = "\n".join(history_texts)
+
+ # Construct the User input
+ # We append a trigger instruction if it wasn't part of the training input implicitly.
+ # But based on your template, the User Input Example was just the query "I am a Python developer..."
+ # So likely we should just feed the conversation text as the user message.
+
+ messages.append({"role": "user", "content": conversation_text})
+
+ # Apply chat template
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+
+ return prompt
+
+ def _generate(self, prompt: str) -> str:
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=self.max_new_tokens,
+ do_sample=False,
+ temperature=0.0,
+ eos_token_id=self.tokenizer.eos_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+ full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+ return full_text[len(prompt):]
+
+ def _parse_preferences(self, raw_output: str) -> PreferenceList:
+ start = raw_output.find("{")
+ end = raw_output.rfind("}")
+
+ if start == -1 or end == -1 or end <= start:
+ return PreferenceList(preferences=[])
+
+ json_str = raw_output[start:end+1]
+ try:
+ data = json.loads(json_str)
+ return PreferenceList.model_validate(data)
+ except Exception:
+ return PreferenceList(preferences=[])
+
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ prompt = self._build_prompt(turns)
+ raw_output = self._generate(prompt)
+ return self._parse_preferences(raw_output)
+
+ # Legacy support
+ def build_preference_prompt(self, query: str) -> str:
+ # Wrap query in a dummy turn
+ turn = ChatTurn(
+ user_id="dummy", session_id="dummy", turn_id=0,
+ role="user", text=query
+ )
+ return self._build_prompt([turn])
+
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ turn = ChatTurn(
+ user_id="dummy", session_id="dummy", turn_id=0,
+ role="user", text=query
+ )
+ prefs = self.extract_turn([turn])
+ return prefs.model_dump()
+