from typing import List, Optional, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn


class LlamaChatModel(ChatModel):
    """Chat model backed by a local Llama-style causal LM loaded via transformers.

    Builds a prompt from conversation history plus user "memory" notes,
    preferring the tokenizer's chat template when available, and generates a
    reply with ``model.generate``.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        # Annotated as str for legacy callers, but a torch.dtype is also accepted.
        dtype: Union[str, torch.dtype] = "bfloat16",
        max_context_length: int = 8192,
    ):
        """Load tokenizer and model weights.

        Args:
            model_path: Hugging Face model id or local checkpoint directory.
            device: "cuda", "cpu", or an explicit index like "cuda:1".
            dtype: torch dtype name (e.g. "bfloat16") or a torch.dtype object.
            max_context_length: token budget used to truncate the prompt.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Accept either a dtype name string or an actual torch.dtype.
        if isinstance(dtype, str):
            torch_dtype = getattr(torch, dtype)
        else:
            torch_dtype = dtype

        if device and device.startswith("cuda:"):
            # Explicit GPU index: load on CPU first (device_map=None) and then
            # move, so accelerate's auto device mapping does not override the
            # requested device.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=None,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device)
        else:
            # Generic device string: let accelerate place the weights.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=device,
            )

        self.max_context_length = max_context_length

        # Llama tokenizers typically ship without a pad token; reuse EOS so
        # padding/truncation (and generate) work without warnings.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render history + memory notes into a single prompt string.

        Uses ``tokenizer.apply_chat_template`` when the tokenizer provides it
        (required for exact Llama-3 special tokens); otherwise falls back to a
        simple generic "role: text" transcript of the last 8 turns.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        # Preferred path: the tokenizer's own chat template emits the exact
        # special tokens (e.g. Llama-3's <|start_header_id|> tags) that the
        # instruct model was trained on.
        if hasattr(self.tokenizer, "apply_chat_template"):
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant.\n" + memory_block,
                }
            ]
            for turn in history:
                messages.append({"role": turn.role, "content": turn.text})
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        # Fallback: generic plain-text transcript. This is NOT token-exact for
        # Llama-3 instruct models, which need the template path above.
        history_lines = []
        for turn in history[-8:]:
            role_tag = "user" if turn.role == "user" else "assistant"
            history_lines.append(f"{role_tag}: {turn.text}")
        prompt = (
            "System: You are a helpful assistant.\n"
            + memory_block
            + "\n".join(history_lines)
            + "\nassistant:"
        )
        return prompt

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a reply to the latest turn in ``history``.

        Args:
            history: prior chat turns (each with ``.role`` and ``.text``).
            memory_notes: free-text memory/preference notes for the system prompt.
            max_new_tokens: generation budget.
            temperature: sampling temperature; <= 0 means greedy decoding.
            top_p: nucleus-sampling threshold (sampling mode only).
            top_k: optional top-k cutoff (sampling mode only).

        Returns:
            The decoded assistant reply (new tokens only), stripped.
        """
        prompt = self._build_prompt(history, memory_notes)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        # Sampling knobs are only valid when sampling; passing temperature=0 /
        # top_p with do_sample=False triggers warnings (or errors) in recent
        # transformers versions.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                eos_token_id=self.tokenizer.eos_token_id,
                # Explicit pad id avoids the "pad_token_id set to eos_token_id"
                # warning; __init__ guarantees the tokenizer has one.
                pad_token_id=self.tokenizer.pad_token_id,
                **gen_kwargs,
            )

        # Decode only the newly generated ids. Decoding the full sequence and
        # string-stripping the prompt is unreliable because skip_special_tokens
        # makes the decoded text diverge from the prompt string.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        answer_text = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        return answer_text