summaryrefslogtreecommitdiff
path: root/src/personalization/models/llm/qwen_instruct.py
blob: cf2047dce68fba08d90f56caf59cc1ae1ac08eb5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from typing import List, Optional, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn
from personalization.config.settings import LocalModelsConfig
from personalization.config.registry import choose_dtype, choose_device_map

class QwenInstruct(ChatModel):
    """Chat model wrapper around a locally stored Qwen instruct checkpoint.

    Loads the tokenizer and causal-LM weights with Hugging Face ``transformers``
    and exposes two entry points:

    * :meth:`generate` -- legacy raw-prompt completion (no chat template).
    * :meth:`answer` -- chat completion built from a ``ChatTurn`` history plus
      memory notes, formatted with the tokenizer's chat template.
    """

    def __init__(
        self,
        model_path: str,
        device: Any = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_context_length: int = 4096,
    ):
        """Load the tokenizer and model from ``model_path``.

        Args:
            model_path: Path (or model id) handed to both ``from_pretrained`` calls.
            device: Forwarded unchanged as ``device_map`` to
                ``AutoModelForCausalLM.from_pretrained`` (e.g. ``"cuda"``,
                ``"cuda:0"``, ``"auto"``).
            dtype: Forwarded unchanged as ``torch_dtype``.
            max_context_length: Maximum number of prompt tokens :meth:`answer`
                feeds to the model; longer prompts keep only their most recent
                tokens.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=device,
            trust_remote_code=True,
        )
        self.max_context_length = max_context_length
        # Fall back to EOS as the pad token so the ``pad_token_id`` passed to
        # ``model.generate`` is never None.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    # ------------------------------------------------------------------ helpers

    def _gen_kwargs(
        self,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
    ) -> Dict[str, Any]:
        """Build the keyword arguments shared by every ``model.generate`` call.

        ``temperature <= 0`` selects greedy decoding. The sampling knobs
        (temperature / top_p / top_k) are only passed when sampling: greedy
        decoding ignores them, and passing them anyway makes ``transformers``
        emit generation-config validation warnings.
        """
        kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if kwargs["do_sample"]:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p
            if top_k is not None:
                kwargs["top_k"] = top_k
        return kwargs

    def _decode_new_tokens(self, outputs: torch.Tensor, prompt_len: int) -> str:
        """Decode only the tokens generated after the prompt (first sequence)."""
        # ``generate`` echoes the prompt ids in front of the completion.
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut ``text`` at the earliest occurrence of any non-empty stop string."""
        if not stop:
            return text
        # Earliest hit in the text wins, regardless of order in ``stop``.
        # Empty strings are skipped (they would match at position 0).
        hits = [pos for pos in (text.find(s) for s in stop if s) if pos >= 0]
        return text[: min(hits)] if hits else text

    def _encode_chat_prompt(self, prompt: str) -> Dict[str, torch.Tensor]:
        """Tokenize a chat-templated prompt and move it to the model device.

        If the prompt exceeds ``max_context_length`` tokens, the OLDEST tokens
        are dropped. The end of the prompt holds the latest user turn and the
        assistant generation prompt; right-truncation would cut exactly those.
        Note that in very long conversations this also drops the system
        message and its memory block.
        """
        # The chat template already inserted every special token it needs.
        enc = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        limit = self.max_context_length
        if limit and enc["input_ids"].shape[1] > limit:
            enc = {key: value[:, -limit:] for key, value in enc.items()}
        return {key: value.to(self.model.device) for key, value in enc.items()}

    # ------------------------------------------------------------ public API

    # Legacy helper for manual generation without a chat template.
    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ) -> str:
        """Complete a raw, already-formatted ``prompt``.

        Args:
            prompt: Text fed to the tokenizer as-is (no chat template applied).
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature; ``<= 0`` means greedy decoding.
            top_p: Nucleus-sampling threshold (sampling only).
            stop: Optional stop strings; the output is cut at the earliest one.
            top_k: Optional top-k cutoff (sampling only).

        Returns:
            The newly generated text only, without the echoed prompt.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            **self._gen_kwargs(max_new_tokens, temperature, top_p, top_k),
        )
        text = self._decode_new_tokens(outputs, inputs["input_ids"].shape[1])
        return self._truncate_at_stop(text, stop)

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render the chat prompt string with the tokenizer's chat template.

        The system message is a fixed assistant instruction, followed by the
        memory notes as a bullet list when any are given. Each ``ChatTurn`` is
        mapped to a ``{"role", "content"}`` message, and a generation prompt is
        appended so the model answers as the assistant.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
        messages.extend({"role": turn.role, "content": turn.text} for turn in history)

        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.inference_mode()
    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate the assistant's next reply for a chat ``history``.

        Args:
            history: Conversation turns, oldest first. For migration, a list of
                legacy ``{"role": ..., "content": ...}`` dicts is also accepted
                and converted to ``ChatTurn`` (raises ``KeyError`` if a dict
                lacks either key).
            memory_notes: User preference/memory strings injected into the
                system message.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature; ``<= 0`` means greedy decoding.
            top_p: Nucleus-sampling threshold (sampling only).
            top_k: Optional top-k cutoff (sampling only).

        Returns:
            The generated reply with surrounding whitespace stripped.
        """
        # Legacy callers may still pass plain dicts; only the first element is
        # inspected to detect that shape.
        if history and isinstance(history[0], dict):
            history = [
                ChatTurn(
                    user_id="unknown", session_id="unknown", turn_id=i,
                    role=h["role"], text=h["content"],
                )
                for i, h in enumerate(history)
            ]

        prompt = self._build_prompt(history, memory_notes)
        inputs = self._encode_chat_prompt(prompt)
        outputs = self.model.generate(
            **inputs,
            **self._gen_kwargs(max_new_tokens, temperature, top_p, top_k),
        )
        return self._decode_new_tokens(outputs, inputs["input_ids"].shape[1]).strip()

    # Factory method for legacy config loading
    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenInstruct":
        """Construct an instance from ``cfg.llm`` (path, dtype, device map).

        The config values are resolved through the registry helpers and the
        RESOLVED values are handed to the constructor, which forwards them
        unchanged to ``from_pretrained``. ``choose_dtype`` is assumed to return
        a ``torch.dtype``, or another value ``from_pretrained`` accepts for
        ``torch_dtype`` -- TODO confirm in ``personalization.config.registry``.
        """
        spec = cfg.llm
        return cls(
            model_path=spec.local_path,
            device=choose_device_map(spec.device_map),
            dtype=choose_dtype(spec.dtype),
        )