summaryrefslogtreecommitdiff
path: root/src/personalization/models/llm/llama_instruct.py
blob: bdf0dff72d2086a68b927f3ea6e0cc76a792ce4c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from typing import Dict, List, Optional, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn

class LlamaChatModel(ChatModel):
    """Chat model backed by a Hugging Face causal LM (e.g. Llama 3 Instruct).

    The user's memory notes are injected into the system prompt, the
    conversation is rendered with the tokenizer's chat template when one is
    configured (with a plain-text fallback otherwise), and ``answer`` generates
    a single assistant reply.
    """

    # Llama 3's end-of-turn marker. Some Llama 3 Instruct tokenizer configs
    # may declare ``<|end_of_text|>`` as the EOS token, so stopping only on
    # ``tokenizer.eos_token_id`` can let the model run past the end of its turn.
    _END_OF_TURN_TOKEN = "<|eot_id|>"

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: Union[str, torch.dtype] = "bfloat16",
        max_context_length: int = 8192,
    ):
        """Load the tokenizer and model weights.

        Args:
            model_path: Hub id or local directory passed to ``from_pretrained``.
            device: An explicit GPU such as ``"cuda:0"`` (weights are loaded on
                CPU and then moved there); any other value is forwarded as
                ``device_map`` to ``from_pretrained``.
            dtype: Name of a ``torch`` dtype (e.g. ``"bfloat16"``), the string
                ``"auto"``, or a ``torch.dtype`` object.
            max_context_length: Maximum number of prompt tokens fed to the
                model; longer prompts keep only their most recent tokens.

        Raises:
            ValueError: If ``dtype`` is a string that names no torch dtype.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        torch_dtype = self._resolve_dtype(dtype)

        if device and device.startswith("cuda:"):
            # Specific GPU index (e.g. "cuda:1"): load to CPU first, then move
            # the whole model onto that device.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=None,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device)
        else:
            # Otherwise let accelerate place the weights via device_map.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=device,
            )

        self.max_context_length = max_context_length
        # Checkpoints without a pad token reuse EOS for padding.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _resolve_dtype(dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
        """Map a dtype name such as ``"bfloat16"`` to the matching ``torch.dtype``.

        Non-string values and the string ``"auto"`` (accepted by
        ``from_pretrained``) are returned unchanged.

        Raises:
            ValueError: If ``dtype`` is a string that names no torch dtype.
        """
        if not isinstance(dtype, str) or dtype == "auto":
            return dtype
        resolved = getattr(torch, dtype, None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unknown torch dtype name: {dtype!r}")
        return resolved

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render the system prompt, memory notes and chat history as one string.

        The tokenizer's chat template is used when available, so model-specific
        special tokens (e.g. Llama 3 header/eot tags) come out exactly right.
        If the tokenizer has no template, a generic "role: text" transcript of
        the last 8 turns is produced instead. That format is only a rough
        approximation for instruct models that expect exact control tokens.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        if hasattr(self.tokenizer, "apply_chat_template"):
            messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
            messages.extend({"role": turn.role, "content": turn.text} for turn in history)
            try:
                return self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            except ValueError:
                # Recent transformers versions always define the method but
                # raise ValueError when no chat template is configured; fall
                # through to the manual format in that case.
                pass

        history_lines = []
        for turn in history[-8:]:
            role_tag = "user" if turn.role == "user" else "assistant"
            history_lines.append(f"{role_tag}: {turn.text}")

        return (
            "System: You are a helpful assistant.\n"
            + memory_block
            + "\n".join(history_lines)
            + "\nassistant:"
        )

    def _encode_prompt(self, prompt: str) -> Dict[str, torch.Tensor]:
        """Tokenize ``prompt`` and keep at most ``max_context_length`` tokens.

        When the prompt is too long, the *oldest* tokens are dropped. The end
        of the prompt holds the newest user turn and the assistant header the
        model must continue from, which right-side truncation (the usual
        tokenizer default) would cut off. The returned tensors are on the
        model's device.
        """
        bos = self.tokenizer.bos_token
        # Chat templates such as Llama 3's already begin with the BOS token;
        # letting the tokenizer add its own would produce a doubled BOS.
        add_special_tokens = not (bos and prompt.startswith(bos))
        encoded = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        )
        keep = self.max_context_length
        if keep <= 0:
            return {name: ids.to(self.model.device) for name, ids in encoded.items()}
        return {name: ids[:, -keep:].to(self.model.device) for name, ids in encoded.items()}

    def _stop_token_ids(self) -> List[int]:
        """Collect the token ids that should end generation.

        Combines the model's generation-config EOS id(s), the tokenizer's EOS
        id, and Llama 3's end-of-turn token when the vocabulary contains it.
        Missing ids are dropped and duplicates removed, preserving order.
        """
        candidates: List[Optional[int]] = []
        generation_config = getattr(self.model, "generation_config", None)
        config_eos = getattr(generation_config, "eos_token_id", None)
        if isinstance(config_eos, int):
            candidates.append(config_eos)
        elif config_eos is not None:
            candidates.extend(config_eos)
        candidates.append(self.tokenizer.eos_token_id)
        eot_id = self.tokenizer.convert_tokens_to_ids(self._END_OF_TURN_TOKEN)
        # Tokens absent from the vocab map to unk_token_id (or None), which
        # must never be treated as a stop token.
        if eot_id != self.tokenizer.unk_token_id:
            candidates.append(eot_id)
        return list(dict.fromkeys(i for i in candidates if i is not None))

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate the assistant's next reply.

        Args:
            history: Conversation so far; each turn provides ``role`` and ``text``.
            memory_notes: User preference/memory strings added to the system prompt.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature; a value ``<= 0`` selects greedy
                decoding (``do_sample=False``).
            top_p: Nucleus-sampling threshold.
            top_k: Optional top-k cutoff; omitted from generation when ``None``.

        Returns:
            The generated reply, with special tokens removed and surrounding
            whitespace stripped.
        """
        prompt = self._build_prompt(history, memory_notes)
        inputs = self._encode_prompt(prompt)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "top_p": top_p,
        }
        if top_k is not None:
            gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                eos_token_id=self._stop_token_ids() or None,
                pad_token_id=self.tokenizer.pad_token_id,
                **gen_kwargs,
            )

        # Decode only the newly generated ids. Stripping the prompt from a
        # full decode is unreliable because skip_special_tokens changes the
        # text relative to the prompt string.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()