summaryrefslogtreecommitdiff
path: root/collaborativeagents/agents/local_user_agent.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/agents/local_user_agent.py')
-rw-r--r--collaborativeagents/agents/local_user_agent.py334
1 files changed, 334 insertions, 0 deletions
diff --git a/collaborativeagents/agents/local_user_agent.py b/collaborativeagents/agents/local_user_agent.py
new file mode 100644
index 0000000..eae311e
--- /dev/null
+++ b/collaborativeagents/agents/local_user_agent.py
@@ -0,0 +1,334 @@
+"""
+Local User Agent - Uses local transformers model for user simulation.
+
+This replaces the litellm-based UserAgent with a local transformers implementation
+for running experiments without requiring an API server.
+"""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from typing import List, Dict, Any, Optional
+from copy import deepcopy
+from json_repair import repair_json
+
# Default model paths: a locally downloaded 8B checkpoint and a quantized
# 70B checkpoint pulled from the Hugging Face hub.
DEFAULT_MODEL_PATH_8B = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct"
DEFAULT_MODEL_PATH_70B = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"

# Use 70B by default for better user simulation
DEFAULT_MODEL_PATH = DEFAULT_MODEL_PATH_70B

# Termination signal from CollaborativeAgents; the simulated user emits this
# string in its "response" field to end the conversation.
TERMINATION_SIGNAL = "TERMINATE"

# User system prompt with preferences (simplified version).
# Filled via .format() with: problem, user_persona, user_preferences,
# termination_signal. The doubled braces ({{ }}) escape the literal JSON
# example so .format() leaves it intact.
USER_SYSTEM_PROMPT_WITH_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with user preferences, which you must actively enforce throughout the conversation.

# Problem Description
{problem}
Note: the agent cannot see this problem description.

# User Persona
{user_persona}

# User Preferences
{user_preferences}
These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:
- **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed.
- **Never proceed without compliance**: Do NOT move forward until the agent follows your preferences.

# Draft Answer Management
- Maintain a draft answer throughout the conversation. Start with "I don't know".
- Update your draft based on agent responses that follow your preferences.

# Conversation Termination
When ready to terminate (draft answer is good OR agent cannot help), respond with "{termination_signal}".

# Output Format:
Respond with a JSON object:
{{
    "reasoning": "Brief reasoning about the agent's response and your preferences",
    "draft_answer": "Your current working draft answer",
    "should_terminate": true/false,
    "response": "Your response to the agent"
}}
"""

# Variant used when no user preferences are configured; same placeholders
# except user_preferences.
USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem.

# Problem Description
{problem}

# User Persona
{user_persona}

# Conversation Termination
When ready to terminate, respond with "{termination_signal}".

# Output Format:
{{
    "reasoning": "Brief reasoning",
    "draft_answer": "Your current working draft answer",
    "should_terminate": true/false,
    "response": "Your response to the agent"
}}
"""
+
+
class LocalUserAgent:
    """
    Local user-simulator agent backed by a transformers causal LM.

    Replaces the litellm-based UserAgent so experiments can run without an
    API server. The litellm-specific constructor arguments (model_name,
    api_base, api_key) are still accepted, but ignored, so existing call
    sites keep working.

    The simulated user:
    - presents a problem to the agent (the agent never sees the raw text),
    - enforces the configured preferences throughout the conversation,
    - decides when to terminate by emitting TERMINATION_SIGNAL.
    """

    def __init__(
        self,
        user_task_description: str,
        problem: str,
        user_persona: str = None,
        user_preferences: str = None,
        model_path: str = DEFAULT_MODEL_PATH,
        num_retries: int = 3,
        # For compatibility with original UserAgent interface (unused here)
        model_name: str = None,
        api_base: str = None,
        api_key: str = None,
    ):
        """
        Args:
            user_task_description: Accepted for interface compatibility; not
                used by this implementation.
            problem: Problem statement shown only to the simulated user.
            user_persona: Short persona description; a generic default is
                substituted when omitted.
            user_preferences: Optional preference text. When present, the
                preference-enforcing system prompt variant is used.
            model_path: HF hub id or local path of the causal LM to load.
            num_retries: Maximum generation attempts in
                generate_user_response.
            model_name, api_base, api_key: Ignored (litellm compatibility).
        """
        self.problem = problem
        self.user_persona = user_persona or "A helpful user seeking assistance."
        self.user_preferences = user_preferences
        self.num_retries = num_retries
        self.model_path = model_path

        # Pick the prompt variant based on whether preferences were supplied.
        if user_preferences:
            self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
                problem=problem,
                user_persona=self.user_persona,
                user_preferences=user_preferences,
                termination_signal=TERMINATION_SIGNAL,
            )
        else:
            self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
                problem=problem,
                user_persona=self.user_persona,
                termination_signal=TERMINATION_SIGNAL,
            )

        # Model components; loaded lazily on first generation so that merely
        # constructing an agent is cheap.
        self._model = None
        self._tokenizer = None
        self._initialized = False

    def _ensure_initialized(self):
        """Load tokenizer and model on first use (idempotent)."""
        if self._initialized:
            return

        import os
        # Use the project HF cache unless HF_HOME overrides it.
        cache_dir = os.environ.get("HF_HOME", "/projects/bfqt/users/yurenh2/hf_cache/huggingface")

        print(f"[LocalUserAgent] Loading model from {self.model_path}...")
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            cache_dir=cache_dir,
            trust_remote_code=True,
        )

        # AWQ checkpoints (detected by path naming convention) ship quantized
        # fp16 weights; everything else is loaded in bf16.
        is_awq = "awq" in self.model_path.lower()

        if is_awq:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                cache_dir=cache_dir,
                trust_remote_code=True,
            )
        else:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                cache_dir=cache_dir,
            )

        # generate() needs a pad token; Llama tokenizers often lack one.
        if self._tokenizer.pad_token_id is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._initialized = True
        print(f"[LocalUserAgent] Initialized (AWQ={is_awq})")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 512) -> str:
        """
        Run one sampled generation over the given chat messages.

        Args:
            messages: Chat messages ({"role", "content"}) including the
                system prompt.
            max_new_tokens: Generation budget for the reply.

        Returns:
            The decoded completion text (prompt stripped, whitespace trimmed).
        """
        self._ensure_initialized()

        # Render the chat into the model's prompt format.
        prompt = self._tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # NOTE(review): truncation at 8192 tokens drops the *end* of very
        # long conversations, i.e. the most recent turns — confirm this is
        # acceptable for long experiments.
        inputs = self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=8192,
        ).to(self._model.device)

        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=self._tokenizer.eos_token_id,
                pad_token_id=self._tokenizer.pad_token_id,
            )

        # Keep only the newly generated tokens (generate() returns
        # prompt + completion).
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        response = self._tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        return response

    def get_system_prompt(self) -> str:
        """Get the system prompt."""
        return self.system_prompt

    def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Flip the conversation to the user-simulator's perspective: agent
        ("assistant") turns become "user" input, everything else becomes
        "assistant" output. The input list is not mutated.
        """
        conversation = deepcopy(conversation)
        return [
            {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
            for msg in conversation
        ]

    def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
        """
        Generate user response given the conversation history.

        Args:
            conversation: List of {"role": "user"|"assistant", "content": str}
                From user perspective: agent messages are "assistant"

        Returns:
            Dict with keys: reasoning, draft_answer, should_terminate, response
            Or None if all retries failed.
        """
        for attempt in range(self.num_retries):
            try:
                # Build messages: system prompt + role-reversed conversation.
                messages = [{"role": "system", "content": self.system_prompt}]
                messages.extend(self.reverse_roles(conversation))

                # Generate response
                response_text = self._generate(messages)

                # Try to parse as JSON
                try:
                    parsed = repair_json(response_text, return_objects=True)

                    # repair_json can return a str or list when the text is
                    # not JSON-like; a non-dict would make the `in` checks
                    # below do substring/element tests instead of key tests.
                    # Route it to the raw-text fallback instead.
                    if not isinstance(parsed, dict):
                        raise ValueError(
                            f"expected JSON object, got {type(parsed).__name__}"
                        )

                    # Check for required keys
                    required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
                    missing = [k for k in required_keys if k not in parsed]

                    if missing:
                        print(f"[LocalUserAgent] Missing keys: {missing}, attempt {attempt+1}")
                        continue

                    return parsed

                except Exception as e:
                    # JSON parsing failed: salvage the raw text instead of
                    # burning a retry.
                    print(f"[LocalUserAgent] JSON parse failed: {e}, attempt {attempt+1}")

                    if TERMINATION_SIGNAL in response_text:
                        return {
                            "reasoning": "Ending conversation",
                            "draft_answer": "",
                            "should_terminate": True,
                            "response": TERMINATION_SIGNAL,
                        }
                    else:
                        return {
                            "reasoning": "",
                            "draft_answer": "",
                            "should_terminate": False,
                            "response": response_text,
                        }

            except Exception as e:
                # Generation itself failed (OOM, template error, ...); retry.
                print(f"[LocalUserAgent] Error: {e}, attempt {attempt+1}")
                continue

        print(f"[LocalUserAgent] Failed after {self.num_retries} attempts")
        return None
+
+
# Process-wide singleton model/tokenizer shared by SharedLocalUserAgent
# instances, so many agents in one process load the weights only once.
_shared_model = None
_shared_tokenizer = None
+
+
class SharedLocalUserAgent(LocalUserAgent):
    """
    LocalUserAgent variant that shares one model/tokenizer across all
    instances in the process to save GPU memory.

    The shared weights are loaded once, from the model_path of the first
    instance that triggers initialization; later instances reuse them even
    when constructed with a different model_path. A warning is printed in
    that case rather than loading a second copy.
    """

    def _ensure_initialized(self):
        """Bind this instance to the process-wide model, loading it on first use."""
        global _shared_model, _shared_tokenizer

        if self._initialized:
            return

        if _shared_model is None:
            import os
            # Use the project HF cache unless HF_HOME overrides it.
            cache_dir = os.environ.get("HF_HOME", "/projects/bfqt/users/yurenh2/hf_cache/huggingface")

            print(f"[SharedLocalUserAgent] Loading shared model from {self.model_path}...")
            _shared_tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                cache_dir=cache_dir,
                trust_remote_code=True,
            )

            # AWQ checkpoints ship quantized fp16 weights; everything else
            # is loaded in bf16 (mirrors LocalUserAgent._ensure_initialized).
            is_awq = "awq" in self.model_path.lower()

            if is_awq:
                _shared_model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    cache_dir=cache_dir,
                    trust_remote_code=True,
                )
            else:
                _shared_model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    cache_dir=cache_dir,
                )

            # Generation requires a pad token; fall back to EOS.
            if _shared_tokenizer.pad_token_id is None:
                _shared_tokenizer.pad_token = _shared_tokenizer.eos_token
            print(f"[SharedLocalUserAgent] Shared model loaded (AWQ={is_awq})")
        else:
            # Guard against silently reusing weights loaded for a different
            # checkpoint. name_or_path is set by from_pretrained; getattr
            # keeps this a no-op if the attribute is ever absent.
            loaded_from = getattr(_shared_model, "name_or_path", self.model_path)
            if loaded_from != self.model_path:
                print(
                    f"[SharedLocalUserAgent] WARNING: shared model was loaded from "
                    f"{loaded_from}; ignoring requested model_path {self.model_path}"
                )

        self._model = _shared_model
        self._tokenizer = _shared_tokenizer
        self._initialized = True