from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn
from personalization.config.settings import LocalModelsConfig
from personalization.config.registry import choose_dtype, choose_device_map


class QwenInstruct(ChatModel):
    """Chat wrapper around a locally stored Qwen instruct checkpoint.

    Loads tokenizer + model via `transformers`, builds prompts with the
    tokenizer's chat template, and exposes `generate` (raw prompt) and
    `answer` (chat history + memory notes) entry points.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_context_length: int = 4096,
    ):
        """Load tokenizer and model weights.

        Args:
            model_path: Local path (or HF id) of the checkpoint.
            device: Value forwarded to `from_pretrained(device_map=...)`,
                e.g. "cuda", "cuda:0", or "auto".
            dtype: Weight dtype; must already be a `torch.dtype`.
            max_context_length: Token cap applied when tokenizing the
                prompt in `answer`.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,  # dtype is already torch.dtype, no getattr needed
            device_map=device,
            trust_remote_code=True,
        )
        self.max_context_length = max_context_length
        # Some checkpoints ship without a pad token; generation needs one,
        # so fall back to EOS (standard practice for causal LMs).
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _generation_kwargs(
        self,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
    ) -> Dict[str, Any]:
        """Build the `model.generate` kwargs shared by `generate` and `answer`."""
        kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "top_p": top_p,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if top_k is not None:
            kwargs["top_k"] = top_k
        return kwargs

    # Legacy helper for manual generation without template
    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a completion for a raw prompt string.

        Args:
            prompt: Text fed to the model verbatim (no chat template).
            max_new_tokens: Generation length cap.
            temperature: Sampling temperature; 0 disables sampling (greedy).
            top_p: Nucleus-sampling probability mass.
            stop: Optional stop strings; output is truncated at the first match.
            top_k: Optional top-k sampling cutoff.

        Returns:
            Only the newly generated text (the echoed prompt is stripped).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        gen_kwargs = self._generation_kwargs(max_new_tokens, temperature, top_p, top_k)
        outputs = self.model.generate(**inputs, **gen_kwargs)
        # Return only the newly generated portion, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        if stop:
            for s in stop:
                if s in text:
                    text = text.split(s)[0]
                    break
        return text

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Construct the chat prompt via the tokenizer's chat template.

        Memory notes are injected as a bullet list inside the system
        message; each ChatTurn is mapped to a {"role", "content"} message.
        Qwen tokenizers support `apply_chat_template` (ChatML-like format).
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )
        messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
        for turn in history:
            messages.append({"role": turn.role, "content": turn.text})
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.inference_mode()
    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Answer the latest user turn given chat history and memory notes.

        Args:
            history: Prior ChatTurns (legacy dict lists are auto-converted).
            memory_notes: Free-text user preferences injected into the prompt.
            max_new_tokens: Generation length cap.
            temperature: Sampling temperature; 0 disables sampling.
            top_p: Nucleus-sampling probability mass.
            top_k: Optional top-k sampling cutoff.

        Returns:
            The assistant reply text, stripped of the echoed prompt.
        """
        # Migration shim: legacy callers pass [{"role", "content"}] dicts;
        # convert them to ChatTurn so the rest of the method has one shape.
        if history and isinstance(history[0], dict):
            history = [
                ChatTurn(
                    user_id="unknown",
                    session_id="unknown",
                    turn_id=i,
                    role=h["role"],
                    text=h["content"],
                )
                for i, h in enumerate(history)
            ]
        prompt = self._build_prompt(history, memory_notes)
        # NOTE(review): truncation can cut the templated prompt mid-message
        # when history is very long; acceptable here but worth confirming.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)
        gen_kwargs = self._generation_kwargs(max_new_tokens, temperature, top_p, top_k)
        outputs = self.model.generate(**inputs, **gen_kwargs)
        # Slice off the prompt tokens so only the new completion is decoded.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    # Factory method for legacy config loading
    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenInstruct":
        """Build a QwenInstruct from a LocalModelsConfig.

        Resolves the config's string dtype/device fields into concrete
        values before calling the constructor. (Previously the raw string
        dtype was forwarded even though __init__ expects a torch.dtype.)
        """
        spec = cfg.llm
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        # Presumably choose_device_map normalizes to a device string or
        # "auto" mapping — TODO confirm; fall back to "cuda" otherwise.
        device = device_map if isinstance(device_map, str) else "cuda"
        return cls(
            model_path=spec.local_path,
            device=device,
            dtype=dtype,  # converted torch.dtype, matching __init__'s contract
        )