1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
from typing import List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn
class LlamaChatModel(ChatModel):
    """Llama chat model served locally through Hugging Face ``transformers``.

    Wraps an ``AutoModelForCausalLM`` checkpoint and exposes ``answer()``,
    which renders the conversation history plus user memory notes into a
    prompt and generates the assistant's next reply.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "bfloat16",  # dtype name ("bfloat16", "float16", ...) or a torch.dtype (legacy)
        max_context_length: int = 8192,
    ):
        """Load tokenizer and weights.

        Args:
            model_path: Hugging Face model id or local checkpoint directory.
            device: "cuda", "cpu", an accelerate device map value, or an
                explicit device such as "cuda:1".
            dtype: Name of a torch dtype; a ``torch.dtype`` object is also
                accepted for legacy callers.
            max_context_length: Hard cap on prompt tokens fed to the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Accept either the string name of a dtype or a torch.dtype object.
        torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        if device and device.startswith("cuda:"):
            # Explicit GPU index: load with device_map=None (CPU) and move,
            # since accelerate's device_map does not accept "cuda:N".
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=None,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device)
        else:
            # Let accelerate place the weights ("cuda", "cpu", "auto", ...).
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=device,
            )
        # Inference-only usage: disable dropout and similar training behavior.
        self.model.eval()

        self.max_context_length = max_context_length
        # Llama tokenizers ship without a pad token; generate() needs one.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render ``history`` and ``memory_notes`` into a single prompt string.

        Prefers the tokenizer's chat template (required for Llama-3-Instruct,
        whose format needs exact special tokens); falls back to a generic
        "role: text" transcript when no template is configured.
        """
        memory_block = ""
        if memory_notes:
            bullets = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullets}\n\n"
            )

        # BUGFIX: the original gated on hasattr(..., "apply_chat_template"),
        # which is always True on modern tokenizers — so the fallback below
        # was unreachable, and tokenizers with no configured template raised
        # inside apply_chat_template. Check for an actual template instead.
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
            for turn in history:
                messages.append({"role": turn.role, "content": turn.text})
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        # Generic fallback with no model-specific special tokens. Keeps only
        # the last 8 turns to bound prompt size.
        history_lines = []
        for turn in history[-8:]:
            role_tag = "user" if turn.role == "user" else "assistant"
            history_lines.append(f"{role_tag}: {turn.text}")
        return (
            "System: You are a helpful assistant.\n"
            + memory_block
            + "\n".join(history_lines)
            + "\nassistant:"
        )

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate the assistant's next reply.

        Args:
            history: Prior conversation turns (user/assistant).
            memory_notes: Free-text memory bullets injected into the system prompt.
            max_new_tokens: Generation budget in tokens.
            temperature: Sampling temperature; 0 selects greedy decoding.
            top_p: Nucleus-sampling mass (used only when sampling).
            top_k: Optional top-k cutoff (used only when sampling).

        Returns:
            The decoded newly generated text only (prompt tokens sliced off).
        """
        prompt = self._build_prompt(history, memory_notes)
        # NOTE(review): right-side truncation (the tokenizer default) drops
        # the MOST RECENT turns on overflow; tokenizer.truncation_side="left"
        # may be intended — confirm with callers before changing.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        # BUGFIX: only pass sampling knobs when sampling. The original always
        # passed temperature/top_p, which transformers warns about (and newer
        # versions reject) when temperature == 0 forces greedy decoding.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                eos_token_id=self.tokenizer.eos_token_id,
                # Silences the per-call "pad_token_id not set" warning; the
                # ctor guarantees a pad token exists.
                pad_token_id=self.tokenizer.pad_token_id,
                **gen_kwargs,
            )

        # Slice off the prompt by token count rather than string-matching the
        # decoded text: skip_special_tokens makes the decoded string diverge
        # from the original prompt, so prefix-stripping is unreliable.
        # (The original also decoded the full output into an unused variable;
        # that dead code is removed.)
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
|