summaryrefslogtreecommitdiff
path: root/src/personalization/models/llm/llama_instruct.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/llm/llama_instruct.py')
-rw-r--r--src/personalization/models/llm/llama_instruct.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/src/personalization/models/llm/llama_instruct.py b/src/personalization/models/llm/llama_instruct.py
new file mode 100644
index 0000000..bdf0dff
--- /dev/null
+++ b/src/personalization/models/llm/llama_instruct.py
@@ -0,0 +1,129 @@
+from typing import List, Optional
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+
class LlamaChatModel(ChatModel):
    """Chat wrapper around a local Llama-style instruct checkpoint.

    Loads a causal LM via HuggingFace ``transformers`` and answers chat
    turns, optionally conditioning on user "memory notes" injected into the
    system prompt.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "bfloat16",  # legacy callers may pass a torch.dtype object
        max_context_length: int = 8192,
    ):
        """Load tokenizer and model weights.

        Args:
            model_path: HF hub id or local path of the checkpoint.
            device: ``"cuda"``, ``"cuda:N"``, ``"cpu"``, or any string
                accepted by accelerate's ``device_map``.
            dtype: torch dtype name (e.g. ``"bfloat16"``) or a torch.dtype.
            max_context_length: prompt truncation limit in tokens.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Accept either a dtype name string or an actual torch.dtype.
        torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        if device and device.startswith("cuda:"):
            # device_map cannot cleanly target one specific GPU, so load on
            # CPU first (low_cpu_mem_usage avoids a transient fp32 copy) and
            # move the whole model to the requested device.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=None,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device)
        else:
            # Let accelerate place the weights ("cuda", "cpu", "auto", ...).
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=device,
            )

        self.max_context_length = max_context_length
        # Llama tokenizers ship without a pad token; reuse EOS so padding /
        # truncation code paths do not fail.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render chat history plus memory notes into a single prompt string.

        Prefers the tokenizer's chat template (the correct path for
        Llama-3-Instruct, which requires exact special-token framing);
        falls back to a plain ``role: text`` transcript otherwise.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        # BUGFIX: hasattr(tokenizer, "apply_chat_template") is True on every
        # modern tokenizer even when no chat template is configured, in which
        # case apply_chat_template raises and the fallback below was dead
        # code. Gate on an actual configured template instead.
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
            for turn in history:
                messages.append({"role": turn.role, "content": turn.text})
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        # Generic fallback: last 8 turns as "role: text" lines. Not exact
        # Llama-3 token framing, but usable for template-less tokenizers.
        history_lines = []
        for turn in history[-8:]:
            role_tag = "user" if turn.role == "user" else "assistant"
            history_lines.append(f"{role_tag}: {turn.text}")

        prompt = (
            "System: You are a helpful assistant.\n"
            + memory_block
            + "\n".join(history_lines)
            + "\nassistant:"
        )
        return prompt

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate the assistant's next reply.

        Args:
            history: prior chat turns, oldest first.
            memory_notes: user memory strings injected into the system prompt.
            max_new_tokens: generation budget in tokens.
            temperature: values <= 0 select greedy decoding.
            top_p: nucleus-sampling mass (sampling mode only).
            top_k: optional top-k cutoff (sampling mode only).

        Returns:
            Newly generated assistant text with special tokens stripped.
        """
        prompt = self._build_prompt(history, memory_notes)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        # BUGFIX: only forward sampling knobs when actually sampling.
        # Passing temperature=0.0 alongside do_sample=False trips
        # transformers' GenerationConfig validation on recent versions.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                eos_token_id=self.tokenizer.eos_token_id,
                # Explicit pad id silences the per-call "Setting pad_token_id"
                # warning (pad token was aliased to EOS in __init__).
                pad_token_id=self.tokenizer.pad_token_id,
                **gen_kwargs,
            )

        # Slice off the prompt by token count rather than string-matching the
        # decoded text: decoding with skip_special_tokens drops tokens, so the
        # decoded output would not align with the prompt string.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        answer_text = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        return answer_text
+