diff options
Diffstat (limited to 'src/personalization/models/llm/qwen_instruct.py')
| -rw-r--r-- | src/personalization/models/llm/qwen_instruct.py | 164 |
1 file changed, 164 insertions, 0 deletions
"""Qwen instruction-tuned chat model wrapper.

Loads a local Qwen checkpoint via Hugging Face ``transformers`` and exposes
the project's ``ChatModel`` interface: raw-prompt ``generate`` and
chat-template-based ``answer``.
"""

from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.config.registry import choose_device_map, choose_dtype
from personalization.config.settings import LocalModelsConfig
from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn


class QwenInstruct(ChatModel):
    """Chat wrapper around a locally stored Qwen instruct checkpoint."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_context_length: int = 4096,
    ):
        """Load tokenizer and model weights from *model_path*.

        Args:
            model_path: Local path (or hub id) of the Qwen checkpoint.
            device: Device string forwarded as ``device_map``
                (e.g. ``"cuda"``, ``"cuda:0"``, ``"auto"``).
            dtype: Torch dtype for the model weights (already a
                ``torch.dtype`` — callers convert string names first).
            max_context_length: Maximum prompt length in tokens; prompts are
                truncated to this length in :meth:`answer`.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=device,
            trust_remote_code=True,
        )
        self.max_context_length = max_context_length
        # Some Qwen tokenizers ship without a pad token; reuse EOS so that
        # model.generate() always has a valid pad_token_id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _sampling_kwargs(
        self,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
    ) -> Dict[str, Any]:
        """Build the ``model.generate`` kwargs shared by generate()/answer().

        Greedy decoding is used when ``temperature <= 0``; in that case the
        sampling-only flags (temperature/top_p/top_k) are omitted, which is
        behavior-equivalent (transformers ignores them when
        ``do_sample=False``) but avoids the "unused sampling flags" warning.
        """
        do_sample = temperature > 0
        kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p
            if top_k is not None:
                kwargs["top_k"] = top_k
        return kwargs

    # Legacy helper for manual generation without a chat template.
    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a completion for a raw *prompt*.

        Returns only the newly generated text (the echoed prompt is
        stripped). If any string in *stop* occurs in the output, the text
        is truncated at the first such occurrence.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            **self._sampling_kwargs(max_new_tokens, temperature, top_p, top_k),
        )
        # Keep only tokens generated after the prompt.
        input_len = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        if stop:
            for s in stop:
                if s in text:
                    text = text.split(s)[0]
                    break
        return text

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render *history* plus *memory_notes* into a single prompt string.

        Memory notes are injected as a bullet list inside the system message;
        each ``ChatTurn`` maps to a ``{"role", "content"}`` message. The
        tokenizer's ``apply_chat_template`` produces the final (untokenized)
        prompt with the generation marker appended.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        messages = [
            {"role": "system", "content": "You are a helpful assistant.\n" + memory_block}
        ]
        for turn in history:
            messages.append({"role": turn.role, "content": turn.text})

        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.inference_mode()
    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Answer the latest user turn given *history* and *memory_notes*.

        Accepts legacy ``[{"role": ..., "content": ...}]`` histories during
        migration and converts them to ``ChatTurn`` on the fly. The prompt is
        truncated to ``max_context_length`` tokens before generation.
        """
        # Migration shim: auto-convert legacy dict-style histories.
        if history and isinstance(history[0], dict):
            history = [
                ChatTurn(
                    user_id="unknown",
                    session_id="unknown",
                    turn_id=i,
                    role=h["role"],
                    text=h["content"],
                )
                for i, h in enumerate(history)
            ]

        prompt = self._build_prompt(history, memory_notes)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            **self._sampling_kwargs(max_new_tokens, temperature, top_p, top_k),
        )

        # Decode only the newly generated suffix; the previous version also
        # decoded the full sequence into an unused variable (dead work).
        input_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][input_len:], skip_special_tokens=True
        ).strip()

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenInstruct":
        """Build a :class:`QwenInstruct` from the local-models config.

        Fix: the resolved values from the registry helpers are now actually
        used. Previously ``choose_dtype``/``choose_device_map`` were called
        but discarded, and the raw string ``spec.dtype`` was forwarded to a
        constructor that expects a ``torch.dtype``.
        """
        spec = cfg.llm
        dtype = choose_dtype(spec.dtype)  # string name -> torch.dtype
        # NOTE(review): assumes choose_device_map returns a value acceptable
        # as a transformers device_map (e.g. "auto" or "cuda:0") — confirm
        # against registry implementation.
        device = choose_device_map(spec.device_map)

        return cls(
            model_path=spec.local_path,
            device=device,
            dtype=dtype,
        )
