diff options
Diffstat (limited to 'collaborativeagents/agents/local_user_agent.py')
| -rw-r--r-- | collaborativeagents/agents/local_user_agent.py | 334 |
1 files changed, 334 insertions, 0 deletions
"""
Local User Agent - Uses local transformers model for user simulation.

This replaces the litellm-based UserAgent with a local transformers implementation
for running experiments without requiring an API server.
"""

import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import torch
from json_repair import repair_json
from transformers import AutoModelForCausalLM, AutoTokenizer

# Default model paths
DEFAULT_MODEL_PATH_8B = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct"
DEFAULT_MODEL_PATH_70B = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"

# Use 70B by default for better user simulation
DEFAULT_MODEL_PATH = DEFAULT_MODEL_PATH_70B

# Cache directory used when the HF_HOME environment variable is not set
DEFAULT_HF_CACHE_DIR = "/projects/bfqt/users/yurenh2/hf_cache/huggingface"

# Prompts longer than this many tokens are truncated by the tokenizer
MAX_PROMPT_TOKENS = 8192

# Termination signal from CollaborativeAgents
TERMINATION_SIGNAL = "TERMINATE"

# Keys every structured user response must contain
REQUIRED_RESPONSE_KEYS = ("reasoning", "draft_answer", "should_terminate", "response")

# User system prompt with preferences (simplified version)
USER_SYSTEM_PROMPT_WITH_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with user preferences, which you must actively enforce throughout the conversation.

# Problem Description
{problem}
Note: the agent cannot see this problem description.

# User Persona
{user_persona}

# User Preferences
{user_preferences}
These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:
- **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed.
- **Never proceed without compliance**: Do NOT move forward until the agent follows your preferences.

# Draft Answer Management
- Maintain a draft answer throughout the conversation. Start with "I don't know".
- Update your draft based on agent responses that follow your preferences.

# Conversation Termination
When ready to terminate (draft answer is good OR agent cannot help), respond with "{termination_signal}".

# Output Format:
Respond with a JSON object:
{{
    "reasoning": "Brief reasoning about the agent's response and your preferences",
    "draft_answer": "Your current working draft answer",
    "should_terminate": true/false,
    "response": "Your response to the agent"
}}
"""

USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem.

# Problem Description
{problem}

# User Persona
{user_persona}

# Conversation Termination
When ready to terminate, respond with "{termination_signal}".

# Output Format:
{{
    "reasoning": "Brief reasoning",
    "draft_answer": "Your current working draft answer",
    "should_terminate": true/false,
    "response": "Your response to the agent"
}}
"""


def _load_model_and_tokenizer(model_path: str) -> Tuple[Any, Any, bool]:
    """
    Load a causal LM and its tokenizer from ``model_path``.

    The cache directory is taken from the ``HF_HOME`` environment variable,
    falling back to ``DEFAULT_HF_CACHE_DIR``. Paths containing "awq"
    (case-insensitive) are loaded in float16 with ``trust_remote_code=True``;
    everything else is loaded in bfloat16. If the tokenizer has no pad token,
    the EOS token is reused as the pad token.

    Args:
        model_path: Local path or hub identifier passed to ``from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer, is_awq)``.
    """
    cache_dir = os.environ.get("HF_HOME", DEFAULT_HF_CACHE_DIR)

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=cache_dir,
        trust_remote_code=True,
    )

    # AWQ checkpoints are detected purely from the path name.
    is_awq = "awq" in model_path.lower()

    if is_awq:
        # AWQ models use float16 and auto device map
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            cache_dir=cache_dir,
            trust_remote_code=True,
        )
    else:
        # Standard model loading
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            cache_dir=cache_dir,
        )

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, is_awq


class LocalUserAgent:
    """
    Local User Agent using transformers for user simulation.

    Simulates a user who:
    - Presents problems to the agent
    - Enforces preferences throughout the conversation
    - Decides when to terminate the conversation
    """

    def __init__(
        self,
        user_task_description: str,
        problem: str,
        user_persona: Optional[str] = None,
        user_preferences: Optional[str] = None,
        model_path: str = DEFAULT_MODEL_PATH,
        num_retries: int = 3,
        # For compatibility with original UserAgent interface
        model_name: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """
        Build the system prompt; the model itself is loaded lazily on first use.

        Args:
            user_task_description: Accepted for interface compatibility; unused here.
            problem: Problem statement inserted into the system prompt.
            user_persona: Persona text; a generic persona is used when falsy.
            user_preferences: Preferences text. When falsy, the prompt variant
                without a preferences section is used.
            model_path: Local path or hub identifier of the model to load.
            num_retries: Maximum generation attempts per user response.
            model_name: Accepted for interface compatibility; unused here.
            api_base: Accepted for interface compatibility; unused here.
            api_key: Accepted for interface compatibility; unused here.
        """
        self.problem = problem
        self.user_persona = user_persona or "A helpful user seeking assistance."
        self.user_preferences = user_preferences
        self.num_retries = num_retries
        self.model_path = model_path

        # Build system prompt
        if user_preferences:
            self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
                problem=problem,
                user_persona=self.user_persona,
                user_preferences=user_preferences,
                termination_signal=TERMINATION_SIGNAL,
            )
        else:
            self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
                problem=problem,
                user_persona=self.user_persona,
                termination_signal=TERMINATION_SIGNAL,
            )

        # Model components (loaded lazily)
        self._model = None
        self._tokenizer = None
        self._initialized = False

    def _ensure_initialized(self):
        """Lazy initialization of model."""
        if self._initialized:
            return

        print(f"[LocalUserAgent] Loading model from {self.model_path}...")
        self._model, self._tokenizer, is_awq = _load_model_and_tokenizer(self.model_path)

        self._initialized = True
        print(f"[LocalUserAgent] Initialized (AWQ={is_awq})")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 512) -> str:
        """
        Generate response using local model.

        Args:
            messages: Chat messages (``role``/``content`` dicts) rendered with
                the tokenizer's chat template.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded, whitespace-stripped completion without the prompt.
        """
        self._ensure_initialized()

        # Apply chat template
        prompt = self._tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The rendered template already carries its special tokens, so they
        # must not be added a second time when tokenizing (avoids e.g. a
        # duplicated BOS token).
        inputs = self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_PROMPT_TOKENS,
            add_special_tokens=False,
        ).to(self._model.device)

        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=self._tokenizer.eos_token_id,
                pad_token_id=self._tokenizer.pad_token_id,
            )

        # Extract only the generated part
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self._tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    def get_system_prompt(self) -> str:
        """Get the system prompt."""
        return self.system_prompt

    def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Reverse roles for user perspective (agent becomes user, user becomes assistant).

        Any role other than "assistant" is mapped to "assistant". The input
        list is not modified.
        """
        conversation = deepcopy(conversation)
        return [
            {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
            for msg in conversation
        ]

    @staticmethod
    def _fallback_response(response_text: str) -> Dict[str, Any]:
        """Wrap non-JSON model output into the structured response format."""
        if TERMINATION_SIGNAL in response_text:
            return {
                "reasoning": "Ending conversation",
                "draft_answer": "",
                "should_terminate": True,
                "response": TERMINATION_SIGNAL,
            }
        return {
            "reasoning": "",
            "draft_answer": "",
            "should_terminate": False,
            "response": response_text,
        }

    def _parse_response(self, response_text: str, attempt: int) -> Optional[Dict[str, Any]]:
        """
        Turn raw model output into a structured response.

        Args:
            response_text: Raw text produced by the model.
            attempt: Zero-based attempt index, used only for log messages.

        Returns:
            The parsed dict when it is a JSON object with all required keys;
            a fallback dict built from the raw text when the output is not a
            JSON object; ``None`` when the caller should retry (JSON object
            with missing keys, or empty output).
        """
        try:
            parsed = repair_json(response_text, return_objects=True)
        except Exception as e:
            print(f"[LocalUserAgent] JSON parse failed: {e}, attempt {attempt+1}")
            parsed = None

        if isinstance(parsed, dict):
            missing = [k for k in REQUIRED_RESPONSE_KEYS if k not in parsed]
            if missing:
                print(f"[LocalUserAgent] Missing keys: {missing}, attempt {attempt+1}")
                return None
            return parsed

        # Not a JSON object (plain text, list, empty repair result, ...).
        # Checking `key in parsed` on such values would be a substring or
        # element test, so they are routed to the raw-text fallback instead.
        if not response_text.strip():
            print(f"[LocalUserAgent] Empty response, attempt {attempt+1}")
            return None
        return self._fallback_response(response_text)

    def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
        """
        Generate user response given the conversation history.

        Args:
            conversation: List of {"role": "user"|"assistant", "content": str}
                From user perspective: agent messages are "assistant"

        Returns:
            Dict with keys: reasoning, draft_answer, should_terminate, response
            Or None if failed
        """
        for attempt in range(self.num_retries):
            try:
                # Build messages: system prompt + reversed conversation
                messages = [{"role": "system", "content": self.system_prompt}]
                messages.extend(self.reverse_roles(conversation))

                # Generate response
                response_text = self._generate(messages)
                result = self._parse_response(response_text, attempt)
            except Exception as e:
                print(f"[LocalUserAgent] Error: {e}, attempt {attempt+1}")
                continue

            if result is not None:
                return result

        print(f"[LocalUserAgent] Failed after {self.num_retries} attempts")
        return None


# Singleton model for efficiency (shared across multiple LocalUserAgent instances)
_shared_model = None
_shared_tokenizer = None
# Path the shared model was loaded from (None until this module loads it)
_shared_model_path = None


class SharedLocalUserAgent(LocalUserAgent):
    """
    LocalUserAgent that shares model across instances to save memory.

    The first instance that needs the model loads it; later instances reuse
    it regardless of their own ``model_path`` (a warning is printed when the
    paths differ).
    """

    def _ensure_initialized(self):
        """Use shared model instead of loading a new one."""
        global _shared_model, _shared_tokenizer, _shared_model_path

        if self._initialized:
            return

        if _shared_model is None:
            print(f"[SharedLocalUserAgent] Loading shared model from {self.model_path}...")
            _shared_model, _shared_tokenizer, is_awq = _load_model_and_tokenizer(self.model_path)
            _shared_model_path = self.model_path
            print(f"[SharedLocalUserAgent] Shared model loaded (AWQ={is_awq})")
        elif _shared_model_path is not None and _shared_model_path != self.model_path:
            # Only one model is ever shared; make the mismatch visible
            # instead of silently ignoring the requested path.
            print(
                f"[SharedLocalUserAgent] Warning: requested {self.model_path} "
                f"but reusing shared model from {_shared_model_path}"
            )

        self._model = _shared_model
        self._tokenizer = _shared_tokenizer
        self._initialized = True
