summaryrefslogtreecommitdiff
path: root/src/personalization/models/llm/qwen_instruct.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/llm/qwen_instruct.py')
-rw-r--r--src/personalization/models/llm/qwen_instruct.py164
1 files changed, 164 insertions, 0 deletions
diff --git a/src/personalization/models/llm/qwen_instruct.py b/src/personalization/models/llm/qwen_instruct.py
new file mode 100644
index 0000000..cf2047d
--- /dev/null
+++ b/src/personalization/models/llm/qwen_instruct.py
@@ -0,0 +1,164 @@
+from typing import List, Optional, Dict, Any
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
class QwenInstruct(ChatModel):
    """Chat model backed by a local Qwen instruct checkpoint.

    Loads tokenizer and causal-LM weights from ``model_path`` and exposes
    two entry points: ``generate`` (raw prompt, no chat template) and
    ``answer`` (chat history + memory notes rendered through the
    tokenizer's chat template).
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_context_length: int = 4096,
    ):
        """Load tokenizer and model weights.

        Args:
            model_path: Local path (or hub id) of the Qwen checkpoint.
            device: Forwarded to transformers as ``device_map``
                (e.g. ``"cuda"``, ``"cuda:0"``, ``"auto"``).
            dtype: Weight dtype; must already be a ``torch.dtype``
                (callers convert strings, e.g. via ``choose_dtype``).
            max_context_length: Token truncation limit for chat prompts.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,  # already a torch.dtype, no conversion needed
            device_map=device,
            trust_remote_code=True,
        )
        self.max_context_length = max_context_length
        # Some Qwen tokenizers ship without a pad token; generation needs one.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _gen_kwargs(
        self,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
    ) -> Dict[str, Any]:
        """Build the ``model.generate`` kwargs shared by ``generate``/``answer``.

        ``temperature <= 0`` selects greedy decoding (``do_sample=False``).
        """
        kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "top_p": top_p,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if top_k is not None:
            kwargs["top_k"] = top_k
        return kwargs

    # Legacy helper for manual generation without the chat template.
    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a raw completion for ``prompt``.

        Args:
            prompt: Text fed to the model verbatim (no chat template).
            max_new_tokens: Generation budget.
            temperature: Sampling temperature; ``<= 0`` means greedy.
            top_p: Nucleus-sampling mass.
            stop: Optional stop strings; output is cut at the first match.
            top_k: Optional top-k cutoff.

        Returns:
            The newly generated text (prompt echo removed), truncated at
            the first stop string if one occurs.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            **self._gen_kwargs(max_new_tokens, temperature, top_p, top_k),
        )
        # Return only the newly generated portion, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        if stop:
            for s in stop:
                if s in text:
                    text = text.split(s)[0]
                    break
        return text

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render history + memory notes into a chat-template prompt string.

        Memory notes are injected as a bullet list inside the system
        message; each ``ChatTurn`` maps to one chat message.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
        for turn in history:
            messages.append({"role": turn.role, "content": turn.text})

        # Qwen tokenizers support apply_chat_template (ChatML-style).
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.inference_mode()
    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Answer the latest user turn given chat history and memory notes.

        Args:
            history: Chat turns (``ChatTurn``); legacy ``{"role", "content"}``
                dicts are auto-converted for callers not yet migrated.
            memory_notes: User preference/memory strings for the system prompt.
            max_new_tokens: Generation budget.
            temperature: Sampling temperature; ``<= 0`` means greedy.
            top_p: Nucleus-sampling mass.
            top_k: Optional top-k cutoff.

        Returns:
            The assistant's reply text, stripped, with the prompt echo removed.
        """
        # Migration shim: accept legacy dict-style history and convert.
        if history and isinstance(history[0], dict):
            history = [
                ChatTurn(
                    user_id="unknown", session_id="unknown", turn_id=i,
                    role=h["role"], text=h["content"],
                )
                for i, h in enumerate(history)
            ]

        prompt = self._build_prompt(history, memory_notes)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            **self._gen_kwargs(max_new_tokens, temperature, top_p, top_k),
        )

        # Decode only the generated suffix; the prompt occupies the first
        # input_ids.shape[1] positions of the output sequence.
        input_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][input_len:], skip_special_tokens=True
        ).strip()

    # Factory method for legacy config loading.
    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenInstruct":
        """Construct from ``LocalModelsConfig.llm``.

        Converts the config's dtype string to a ``torch.dtype`` via
        ``choose_dtype`` before passing it to ``__init__``.
        """
        spec = cfg.llm
        # BUG FIX: previously the raw string spec.dtype was passed to the
        # constructor, which expects an actual torch.dtype and does not
        # convert — the choose_dtype result was computed but discarded.
        dtype = choose_dtype(spec.dtype)
        # The constructor takes a device string which transformers accepts
        # as device_map (plain "cuda:0" or "auto" both work).
        device = spec.device_map if isinstance(spec.device_map, str) else "cuda"

        return cls(
            model_path=spec.local_path,
            device=device,
            dtype=dtype,
        )