from typing import List, Dict, Any

import torch
import json
import os

from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.models.preference_extractor.base import PreferenceExtractorBase
from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
from personalization.config.settings import LocalModelsConfig
from personalization.config.registry import choose_dtype, choose_device_map


class PreferenceExtractorLLM(PreferenceExtractorBase):
    """Extract user preferences from chat history with a local causal LM.

    Loads a HuggingFace causal LM and tokenizer from ``model_path`` and frames
    the extraction task with a system-prompt template read from
    ``prompt_template_path`` (falling back to a generic instruction when the
    template file is missing).
    """

    def __init__(
        self,
        model_path: str,
        prompt_template_path: str = "fine_tuning_prompt_template.txt",
        device_map: str = "auto",
        dtype: torch.dtype = torch.bfloat16,
        max_new_tokens: int = 512,
    ) -> None:
        """Load the model, tokenizer, and system prompt template.

        Args:
            model_path: HF hub id or local path of the causal LM.
            prompt_template_path: Text file containing the system prompt.
            device_map: Device placement passed to ``from_pretrained``
                (e.g. ``"auto"``).
            dtype: Torch dtype for the model weights.
            max_new_tokens: Generation budget for the extraction call.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        self.max_new_tokens = max_new_tokens

        # Prefer the on-disk template; fall back to a generic instruction so
        # the extractor still works (degraded) when the file is absent.
        if os.path.exists(prompt_template_path):
            with open(prompt_template_path, "r", encoding="utf-8") as f:
                self.prompt_template = f.read()
        else:
            print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.")
            self.prompt_template = "Extract user preferences from the following conversation."

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM":
        """Placeholder factory — NOT implemented yet; always returns ``None``.

        The design doc keys extractors by name under ``preference_extractor``
        (e.g. ``qwen3_0_6b_sft``), but ``LocalModelsConfig`` currently exposes
        only a single ``ModelSpec``, so the named lookup cannot be resolved
        here. Until ``settings.py``/``registry.py`` are refactored to support
        a map of extractor specs, construct :class:`PreferenceExtractorLLM`
        directly with explicit arguments.

        Args:
            cfg: Local-models configuration (currently unused).
            name: Config key of the desired extractor (currently unused).

        Returns:
            ``None`` — callers must not rely on this method until implemented.
        """
        # TODO(settings-refactor): once LocalModelsConfig holds a dict of
        # extractor specs, resolve `name` against it and return
        # cls(model_path=spec.path, device_map=choose_device_map(spec),
        #     dtype=choose_dtype(spec)).
        return None

    def _build_prompt(self, turns: List[ChatTurn]) -> str:
        """Build the chat-template prompt from the most recent turns."""
        # System prompt frames the extraction task for the model.
        messages = [{"role": "system", "content": self.prompt_template}]
        # Fixed context window: only the last 6 turns are considered.
        window = turns[-6:]
        # Add conversation history
        # We need to format the conversation as input context.
        # Since the task is to extract preferences from the *whole* context (or latest turn?),
        # usually we provide the conversation and ask for extraction.
        # But LLaMA-Factory SFT usually expects:
        # System: