diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 15:43:42 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 15:43:42 -0600 |
| commit | f918fc90b8d71d1287590b016d926268be573de0 (patch) | |
| tree | d9009c8612c8e7f866c31d22fb979892a5b55eeb /src/personalization/models/llm/llama_instruct.py | |
| parent | 680513b7771a29f27cbbb3ffb009a69a913de6f9 (diff) | |
Add model wrapper modules (embedding, reranker, llm, preference_extractor)
Add Python wrappers for:
- Qwen3/Nemotron embedding models
- BGE/Qwen3 rerankers
- vLLM/Llama/Qwen LLM backends
- GPT-4o/LLM-based preference extractors
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/llm/llama_instruct.py')
| -rw-r--r-- | src/personalization/models/llm/llama_instruct.py | 129 |
1 file changed, 129 insertions, 0 deletions
from typing import List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn


class LlamaChatModel(ChatModel):
    """ChatModel backend for Llama-style instruct checkpoints via transformers.

    Loads a causal-LM checkpoint with ``AutoModelForCausalLM`` and renders the
    conversation history (plus free-text memory notes) into a prompt, using the
    tokenizer's chat template when available.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "bfloat16",  # dtype name for legacy callers; torch.dtype also accepted
        max_context_length: int = 8192,
    ):
        """Load the tokenizer and model.

        Args:
            model_path: HF hub id or local path of the checkpoint.
            device: "cuda"/"cpu"/device-map string, or a specific device such
                as "cuda:1".
            dtype: torch dtype name (e.g. "bfloat16") or a ``torch.dtype``.
            max_context_length: Maximum prompt length in tokens; longer
                prompts are truncated at encode time.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Accept either a dtype name ("bfloat16") or an actual torch.dtype.
        torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        if device and device.startswith("cuda:"):
            # A specific GPU was requested: load with device_map=None (weights
            # land on CPU) and then move the whole model, so accelerate does
            # not scatter it across devices.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=None,
                low_cpu_mem_usage=True,
            ).to(device)
        else:
            # Generic device string ("cuda", "cpu", ...): let accelerate's
            # device mapping place the weights.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=device,
            )

        self.max_context_length = max_context_length
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render history and memory notes into a single prompt string.

        Prefers ``tokenizer.apply_chat_template`` — required for Llama-3
        instruct models, which need their exact special tokens — and falls
        back to a generic "role: text" transcript otherwise.
        """
        memory_block = ""
        if memory_notes:
            bullets = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullets}\n\n"
            )

        if hasattr(self.tokenizer, "apply_chat_template"):
            messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
            for turn in history:
                messages.append({"role": turn.role, "content": turn.text})
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        # Fallback: generic transcript of the last 8 turns. This does NOT emit
        # Llama-3's special tags, so it is only a best-effort degradation for
        # tokenizers without a chat template.
        history_lines = []
        for turn in history[-8:]:
            role_tag = "user" if turn.role == "user" else "assistant"
            history_lines.append(f"{role_tag}: {turn.text}")

        return (
            "System: You are a helpful assistant.\n"
            + memory_block
            + "\n".join(history_lines)
            + "\nassistant:"
        )

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate the assistant's next reply.

        Args:
            history: Conversation turns so far.
            memory_notes: Memory/preference notes injected into the system
                prompt.
            max_new_tokens: Generation budget.
            temperature: Sampling temperature; values <= 0 select greedy
                decoding.
            top_p: Nucleus-sampling threshold (used only when sampling).
            top_k: Optional top-k cutoff (used only when sampling).

        Returns:
            The newly generated text with special tokens stripped.
        """
        prompt = self._build_prompt(history, memory_notes)
        # NOTE(review): truncation=True here drops tokens from the END of the
        # prompt — i.e. the most recent turns — when it overflows. Confirm
        # this is intended; left-truncation usually preserves recency.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
        if do_sample:
            # Only pass sampling knobs when actually sampling: transformers
            # warns (and errors on temperature=0.0) if they accompany greedy
            # decoding.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                eos_token_id=self.tokenizer.eos_token_id,
                **gen_kwargs,
            )

        # Extract only the newly generated part by slicing token ids at the
        # prompt length. String-prefix stripping of the decoded output is
        # unreliable because skip_special_tokens changes the prompt text.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
