diff options
Diffstat (limited to 'collaborativeagents/adapters/reflection_adapter.py')
| -rw-r--r-- | collaborativeagents/adapters/reflection_adapter.py | 416 |
1 files changed, 416 insertions, 0 deletions
"""
Reflection Adapter - vLLM-based implementation using original CollaborativeAgents prompts.

This implements the "reflection" baseline from the MULTISESSIONCOLLAB paper:
- After each session, agent reflects on interaction to update memory (agent_notes)
- Memory is provided to agent at start of subsequent sessions
- Uses session-level reflection + persistent memory
- Uses LLM-based retrieval (proper_scaffolding) to prevent context overflow

Now uses vLLM for fast inference instead of local transformers.
Uses EXACT prompts from the original CollaborativeAgents paper for fairness.
"""

import sys
import time
from pathlib import Path
from typing import Optional, List, Dict, Any

from json_repair import repair_json

# Add parent for utils import
sys.path.insert(0, str(Path(__file__).parent.parent))
from utils.vllm_client import VLLMClient, VLLMConfig

# Import ORIGINAL prompts from CollaborativeAgents for fair reproduction
sys.path.insert(0, str(Path(__file__).parent.parent / "collaborativeagents"))
from collaborativeagents.prompts import (
    reflective_agent_system_prompt_no_json,
    update_agent_notes_prompt,
    proper_scaffolding_prompt,
)
from collaborativeagents.utils import get_conversation_string

# Default vLLM URL (agent server on port 8003)
DEFAULT_VLLM_URL = "http://localhost:8003/v1"

# Sentinel notes value for users we have not reflected on yet. Scaffolding is
# skipped while the notes still equal this sentinel.
_NO_NOTES = "No notes yet about this user."


class ReflectionAdapter:
    """
    Adapter for the Reflection baseline from MULTISESSIONCOLLAB.

    Uses vLLM for fast inference with:
    - agent_notes: Persistent memory updated via session-level reflection
    - Memory retrieval at each turn via proper_scaffolding (when notes are long)

    Uses ORIGINAL CollaborativeAgents prompts for fair benchmark reproduction.
    """

    def __init__(
        self,
        model_name: str = None,  # Ignored - vLLM auto-discovers model
        device_assignment: dict = None,  # Ignored - vLLM handles GPU
        api_base: str = None,  # vLLM server URL
        api_key: str = None,  # Ignored
        with_scaffolding: bool = True,
        with_proper_scaffolding: bool = True,  # Enable LLM-based retrieval by default
        vllm_url: str = None,  # vLLM server URL
        max_new_tokens: int = 2048,  # Match original CollaborativeAgents setting
    ):
        """
        Create the adapter. No network activity happens here; the vLLM
        connection is established lazily in :meth:`initialize`.

        Args:
            model_name: Accepted for interface compatibility; unused.
            device_assignment: Accepted for interface compatibility; unused.
            api_base: Fallback vLLM server URL (used when vllm_url is None).
            api_key: Accepted for interface compatibility; unused.
            with_scaffolding: Whether to inject memory notes into the conversation.
            with_proper_scaffolding: Whether to use LLM-based note retrieval
                instead of prepending the full notes.
            vllm_url: Preferred vLLM server URL; overrides api_base.
            max_new_tokens: Token budget for the main response generation.
        """
        self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL
        self.with_scaffolding = with_scaffolding
        self.with_proper_scaffolding = with_proper_scaffolding
        self.max_new_tokens = max_new_tokens

        # Per-user memory storage: user_id -> agent notes text
        self._user_notes: Dict[str, str] = {}
        self._current_user_id: Optional[str] = None
        self._conversation_history: List[Dict[str, str]] = []

        # vLLM client (initialized lazily)
        self._client: Optional[VLLMClient] = None
        self._initialized = False

    def initialize(self):
        """
        Connect to the vLLM server, retrying with exponential backoff.

        Raises:
            RuntimeError: If the server does not respond after all retries.
                The last underlying connection error (if any) is chained as
                the cause so diagnostics are not lost.
        """
        if self._initialized:
            return

        print(f"[ReflectionAdapter] Connecting to vLLM server at {self.vllm_url}...")
        print(f"[ReflectionAdapter] Using proper_scaffolding={self.with_proper_scaffolding}")

        max_retries = 30
        last_err: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                self._client = VLLMClient(base_url=self.vllm_url)
                if self._client.health_check():
                    break
            except Exception as e:
                # Remember the failure so the final RuntimeError can chain it.
                last_err = e

            if attempt < max_retries - 1:
                # Exponential backoff capped at 10s: 0.5, 1, 2, 4, 8, 10, 10...
                wait_time = min(2 ** attempt * 0.5, 10)
                time.sleep(wait_time)
            else:
                raise RuntimeError(
                    f"vLLM server not responding at {self.vllm_url} after {max_retries} retries"
                ) from last_err

        self._initialized = True
        print(f"[ReflectionAdapter] Connected to vLLM (model: {self._client.config.model})")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str:
        """Generate a completion for `messages` via the vLLM server and return its text."""
        if not self._initialized:
            self.initialize()

        result = self._client.chat(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=1.0,  # Match original CollaborativeAgents setting
            top_p=0.9,
        )

        return result["content"]

    @staticmethod
    def _prepend_to_first_message(
        conversation: List[Dict[str, str]], preamble: str
    ) -> List[Dict[str, str]]:
        """
        Return a copy of `conversation` with `preamble` prepended to the first
        message's content. Copies the outer list and the first message dict so
        the caller's conversation is never mutated. Empty conversations are
        returned unchanged.
        """
        if not conversation:
            return conversation
        conversation = list(conversation)
        conversation[0] = dict(conversation[0])
        conversation[0]["content"] = preamble + conversation[0]["content"]
        return conversation

    def _add_scaffolding_to_conversation(
        self,
        conversation: List[Dict[str, str]],
        agent_notes: str
    ) -> List[Dict[str, str]]:
        """
        Add scaffolding (memory notes) to conversation.

        This is the EXACT logic from CollaboratorAgent.add_scaffolding_to_conversation().

        If with_proper_scaffolding=True: Use LLM to extract relevant notes
        If with_proper_scaffolding=False: Prepend all notes to first message
        """
        # Preamble used both for simple mode and as the retrieval-failure fallback.
        full_notes_preamble = (
            f"Remember, you have been taking notes throughout past conversations "
            f"about user preferences. Use whatever is relevant in these notes to "
            f"guide your response:\n{agent_notes}\n\n"
        )

        if not self.with_proper_scaffolding:
            # Simple mode: prepend all notes
            return self._prepend_to_first_message(conversation, full_notes_preamble)

        # Proper scaffolding: use LLM to extract relevant notes
        conversation_str = get_conversation_string(conversation)
        formatted_prompt = proper_scaffolding_prompt.format(
            conversation_history=conversation_str,
            complete_agent_notes=agent_notes
        )

        # Call LLM to extract relevant notes (with retries)
        for attempt in range(3):
            try:
                messages = [{"role": "user", "content": formatted_prompt}]
                response = self._generate(messages, max_new_tokens=512)

                parsed = repair_json(response, return_objects=True)
                # repair_json may return a str/list when the response is not a
                # JSON object; treat that the same as missing keys and retry.
                if not isinstance(parsed, dict):
                    print(f"[ReflectionAdapter] Scaffolding returned non-dict: {type(parsed).__name__}")
                    continue

                missing_keys = [k for k in ["reasoning", "relevant_notes"] if k not in parsed]
                if missing_keys:
                    print(f"[ReflectionAdapter] Scaffolding missing keys: {missing_keys}")
                    continue

                scaffolded_notes = parsed["relevant_notes"]

                # Prepend extracted notes to first message
                return self._prepend_to_first_message(
                    conversation,
                    f"Remember, you have been taking notes throughout past conversations "
                    f"about user preferences. Use these notes to guide your response:\n"
                    f"{scaffolded_notes}\n\n"
                )

            except Exception as e:
                print(f"[ReflectionAdapter] Scaffolding attempt {attempt+1} failed: {e}")
                continue

        # Fallback: use all notes if retrieval fails
        print("[ReflectionAdapter] Scaffolding failed, using full notes")
        return self._prepend_to_first_message(conversation, full_notes_preamble)

    def start_session(self, user_id: str, user_profile: dict = None):
        """Start a new session for a user (resets the per-session conversation)."""
        if not self._initialized:
            self.initialize()

        self._current_user_id = user_id
        self._conversation_history = []

    def _build_messages(self, query: str) -> tuple:
        """
        Append `query` to the session history and build the full message list
        (system prompt + scaffolded conversation) for the main generation call.

        Shared by generate_response() and prepare_prompt(). May itself call the
        LLM when proper_scaffolding is enabled (for memory retrieval).

        Returns:
            Tuple of (messages, agent_notes).
        """
        self._conversation_history.append({"role": "user", "content": query})

        # Get current notes for this user
        agent_notes = self._user_notes.get(self._current_user_id, _NO_NOTES)

        # Build conversation with scaffolding (uses proper_scaffolding if enabled)
        if self.with_scaffolding and agent_notes != _NO_NOTES:
            conversation_with_notes = self._add_scaffolding_to_conversation(
                self._conversation_history, agent_notes
            )
        else:
            conversation_with_notes = self._conversation_history

        # Build system prompt using ORIGINAL CollaborativeAgents prompt.
        # Note: For no_json mode, we don't include agent_notes in system prompt
        # because they're added via scaffolding to the conversation.
        system_prompt = reflective_agent_system_prompt_no_json.format(
            agent_notes=agent_notes if not self.with_scaffolding else "See notes in conversation."
        )

        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(conversation_with_notes)
        return messages, agent_notes

    def generate_response(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Generate a response using the reflection agent.

        Args:
            query: Current user query.
            conversation_history: Accepted for interface compatibility; the
                adapter tracks its own per-session history.

        Returns:
            Dict with 'response', 'reasoning', and 'debug' keys.
        """
        if not self._initialized:
            self.initialize()

        messages, agent_notes = self._build_messages(query)

        # Generate response
        response_text = self._generate(messages, max_new_tokens=self.max_new_tokens)

        self._conversation_history.append({"role": "assistant", "content": response_text})

        return {
            "response": response_text,
            "reasoning": "",
            "debug": {"agent_notes": agent_notes, "proper_scaffolding": self.with_proper_scaffolding}
        }

    def prepare_prompt(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> tuple:
        """
        Prepare prompt for batch processing without calling main LLM.

        Note: This may still call LLM for scaffolding (memory retrieval),
        but the main generation is deferred for batching.

        Args:
            query: Current user query
            conversation_history: Previous conversation

        Returns:
            Tuple of (messages, context) for batch processing
        """
        if not self._initialized:
            self.initialize()

        messages, agent_notes = self._build_messages(query)

        # Context for post-processing
        ctx = {
            "agent_notes": agent_notes,
        }

        return messages, ctx

    def process_response(
        self,
        response: str,
        context: dict
    ) -> Dict[str, Any]:
        """
        Process LLM response after batch call.

        Args:
            response: LLM response text
            context: Context dict from prepare_prompt()

        Returns:
            Dict with 'response', 'reasoning', and debug info
        """
        self._conversation_history.append({"role": "assistant", "content": response})

        return {
            "response": response,
            "reasoning": "",
            "debug": {
                "agent_notes": context["agent_notes"],
                "proper_scaffolding": self.with_proper_scaffolding
            }
        }

    def _update_agent_notes(self, agent_notes: str, conversation: List[Dict[str, str]]) -> Optional[Dict]:
        """
        Update agent notes using ORIGINAL CollaborativeAgents logic.

        For 8B models (no_json): Use raw response directly
        For JSON models: Check for required keys, retry up to num_retries times
        Returns None if all retries fail (keeps old notes)
        """
        conversation_str = get_conversation_string(conversation)
        formatted_prompt = update_agent_notes_prompt.format(
            agent_notes=agent_notes,
            conversation_str=conversation_str
        )

        num_retries = 10
        no_json = True  # 8B model

        for attempt in range(num_retries):
            try:
                messages = [{"role": "user", "content": formatted_prompt}]
                response = self._generate(messages, max_new_tokens=512)

                # For 8B models (no_json=True): use raw response directly
                if no_json:
                    return {"agent_notes": response}

                # For JSON models: parse and check keys. repair_json may return
                # a non-dict for malformed output; retry in that case too.
                processed_response = repair_json(response, return_objects=True)
                if not isinstance(processed_response, dict):
                    print(f"[ReflectionAdapter] Non-dict notes response, attempt {attempt + 1}")
                    continue

                missing_keys = [k for k in ["user_preferences_reasoning", "agent_notes"] if k not in processed_response]
                if missing_keys:
                    print(f"[ReflectionAdapter] Missing keys: {missing_keys}, attempt {attempt + 1}")
                    continue

                return processed_response

            except Exception as e:
                print(f"[ReflectionAdapter] Failed to update agent notes: {e}")

        return None  # All retries failed, keep old notes

    def end_session(self, task_success: bool = False) -> Dict[str, Any]:
        """
        End session and update agent notes via reflection.

        Uses the ORIGINAL update_agent_notes logic from CollaborativeAgents.

        Returns:
            Summary dict with 'turns', 'task_success', and 'notes_updated'.
            'notes_updated' is True only when the reflection call actually
            produced new notes (previously it was hard-coded to True).
        """
        if not self._current_user_id:
            return {}

        # Get current notes
        current_notes = self._user_notes.get(self._current_user_id, "No notes yet.")

        notes_updated = False
        # Update notes via session-level reflection
        if len(self._conversation_history) > 0:
            result = self._update_agent_notes(current_notes, self._conversation_history)

            if result is not None and "agent_notes" in result:
                updated_notes = result["agent_notes"]
                self._user_notes[self._current_user_id] = updated_notes
                notes_updated = True
                print(f"[ReflectionAdapter] Updated notes for {self._current_user_id} "
                      f"({len(current_notes)} -> {len(updated_notes)} chars)")
            else:
                print(f"[ReflectionAdapter] Keeping old notes for {self._current_user_id} "
                      f"(update failed)")

        return {
            "turns": len(self._conversation_history),
            "task_success": task_success,
            "notes_updated": notes_updated,
        }

    def reset_user(self, user_id: str):
        """Reset all memory for a user."""
        # pop() with default avoids the lookup-then-delete race/duplication.
        self._user_notes.pop(user_id, None)

    def __call__(
        self,
        messages: List[Dict[str, str]],
        user_profile: dict = None,
        **kwargs
    ) -> str:
        """Callable interface for ConversationGenerator compatibility."""
        if not messages:
            return "How can I help you?"

        # Find the most recent user message to use as the query.
        last_user_msg = next(
            (msg["content"] for msg in reversed(messages) if msg["role"] == "user"),
            None,
        )

        if last_user_msg is None:
            return "How can I help you?"

        result = self.generate_response(last_user_msg, messages)
        return result["response"]
