summaryrefslogtreecommitdiff
path: root/collaborativeagents/adapters/reflection_adapter.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/adapters/reflection_adapter.py')
-rw-r--r--collaborativeagents/adapters/reflection_adapter.py416
1 files changed, 416 insertions, 0 deletions
diff --git a/collaborativeagents/adapters/reflection_adapter.py b/collaborativeagents/adapters/reflection_adapter.py
new file mode 100644
index 0000000..d535be2
--- /dev/null
+++ b/collaborativeagents/adapters/reflection_adapter.py
@@ -0,0 +1,416 @@
+"""
+Reflection Adapter - vLLM-based implementation using original CollaborativeAgents prompts.
+
+This implements the "reflection" baseline from the MULTISESSIONCOLLAB paper:
+- After each session, agent reflects on interaction to update memory (agent_notes)
+- Memory is provided to agent at start of subsequent sessions
+- Uses session-level reflection + persistent memory
+- Uses LLM-based retrieval (proper_scaffolding) to prevent context overflow
+
+Now uses vLLM for fast inference instead of local transformers.
+Uses EXACT prompts from the original CollaborativeAgents paper for fairness.
+"""
+
+import sys
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+from json_repair import repair_json
+
+# Add parent for utils import
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from utils.vllm_client import VLLMClient, VLLMConfig
+
+# Import ORIGINAL prompts from CollaborativeAgents for fair reproduction
+sys.path.insert(0, str(Path(__file__).parent.parent / "collaborativeagents"))
+from collaborativeagents.prompts import (
+ reflective_agent_system_prompt_no_json,
+ update_agent_notes_prompt,
+ proper_scaffolding_prompt,
+)
+from collaborativeagents.utils import get_conversation_string
+
# Default vLLM OpenAI-compatible endpoint (agent server listens on port 8003).
DEFAULT_VLLM_URL = "http://localhost:8003/v1"
+
+
class ReflectionAdapter:
    """
    Adapter for the Reflection baseline from MULTISESSIONCOLLAB.

    Uses vLLM for fast inference with:
    - agent_notes: persistent per-user memory, updated via session-level reflection
    - memory retrieval at each turn via proper_scaffolding (when notes are long)

    Uses ORIGINAL CollaborativeAgents prompts for fair benchmark reproduction.
    """

    def __init__(
        self,
        model_name: str = None,          # Ignored - vLLM auto-discovers model
        device_assignment: dict = None,  # Ignored - vLLM handles GPU
        api_base: str = None,            # Alias for vllm_url (kept for compatibility)
        api_key: str = None,             # Ignored
        with_scaffolding: bool = True,
        with_proper_scaffolding: bool = True,  # Enable LLM-based retrieval by default
        vllm_url: str = None,            # vLLM server URL
        max_new_tokens: int = 2048,      # Match original CollaborativeAgents setting
    ):
        """Create the adapter; no network activity until initialize() is called.

        Args:
            model_name, device_assignment, api_key: accepted for interface
                compatibility with other adapters, but ignored here.
            api_base: fallback server URL if vllm_url is not given.
            with_scaffolding: inject memory notes into the conversation.
            with_proper_scaffolding: use an LLM call to extract only the
                relevant notes instead of prepending all of them.
            vllm_url: vLLM server URL; defaults to DEFAULT_VLLM_URL.
            max_new_tokens: generation budget for the main response.
        """
        self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL
        self.with_scaffolding = with_scaffolding
        self.with_proper_scaffolding = with_proper_scaffolding
        self.max_new_tokens = max_new_tokens

        # Per-user memory storage (user_id -> accumulated agent notes).
        self._user_notes: Dict[str, str] = {}
        self._current_user_id: Optional[str] = None
        self._conversation_history: List[Dict[str, str]] = []

        # vLLM client (initialized lazily on first use).
        self._client: Optional[VLLMClient] = None
        self._initialized = False

    def initialize(self):
        """Connect to the vLLM server, retrying with exponential backoff.

        Idempotent: returns immediately if already connected.

        Raises:
            RuntimeError: if the server never responds within the retry budget.
        """
        if self._initialized:
            return

        print(f"[ReflectionAdapter] Connecting to vLLM server at {self.vllm_url}...")
        print(f"[ReflectionAdapter] Using proper_scaffolding={self.with_proper_scaffolding}")

        import time  # local import: only needed for connection backoff

        max_retries = 30
        for attempt in range(max_retries):
            try:
                self._client = VLLMClient(base_url=self.vllm_url)
                if self._client.health_check():
                    break
            except Exception:
                # Server may still be starting up; treat any error as "not ready yet".
                pass

            if attempt < max_retries - 1:
                # Exponential backoff capped at 10s: 0.5, 1, 2, 4, 8, 10, 10...
                time.sleep(min(2 ** attempt * 0.5, 10))
            else:
                raise RuntimeError(f"vLLM server not responding at {self.vllm_url} after {max_retries} retries")

        self._initialized = True
        print(f"[ReflectionAdapter] Connected to vLLM (model: {self._client.config.model})")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str:
        """Generate a completion for `messages` via the vLLM server.

        Connects lazily if initialize() has not been called yet.
        """
        if not self._initialized:
            self.initialize()

        result = self._client.chat(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=1.0,  # Match original CollaborativeAgents setting
            top_p=0.9,
        )

        return result["content"]

    @staticmethod
    def _prepend_to_first_message(
        conversation: List[Dict[str, str]],
        prefix: str
    ) -> List[Dict[str, str]]:
        """Return a shallow copy of `conversation` with `prefix` prepended to
        the first message's content. The caller's list and messages are never
        mutated."""
        conversation = list(conversation)
        conversation[0] = dict(conversation[0])
        conversation[0]["content"] = prefix + conversation[0]["content"]
        return conversation

    def _add_scaffolding_to_conversation(
        self,
        conversation: List[Dict[str, str]],
        agent_notes: str
    ) -> List[Dict[str, str]]:
        """
        Add scaffolding (memory notes) to conversation.

        This is the EXACT logic from CollaboratorAgent.add_scaffolding_to_conversation().

        If with_proper_scaffolding=True: Use LLM to extract relevant notes
        If with_proper_scaffolding=False: Prepend all notes to first message
        """
        # Prefix used both in simple mode and as the fallback when retrieval fails.
        full_notes_prefix = (
            f"Remember, you have been taking notes throughout past conversations "
            f"about user preferences. Use whatever is relevant in these notes to "
            f"guide your response:\n{agent_notes}\n\n"
        )

        if not self.with_proper_scaffolding:
            # Simple mode: prepend all notes.
            if conversation:
                conversation = self._prepend_to_first_message(conversation, full_notes_prefix)
            return conversation

        # Proper scaffolding: use LLM to extract relevant notes.
        conversation_str = get_conversation_string(conversation)
        formatted_prompt = proper_scaffolding_prompt.format(
            conversation_history=conversation_str,
            complete_agent_notes=agent_notes
        )

        # Call LLM to extract relevant notes (with retries).
        for attempt in range(3):
            try:
                messages = [{"role": "user", "content": formatted_prompt}]
                response = self._generate(messages, max_new_tokens=512)

                parsed = repair_json(response, return_objects=True)
                missing_keys = [k for k in ["reasoning", "relevant_notes"] if k not in parsed]

                if missing_keys:
                    print(f"[ReflectionAdapter] Scaffolding missing keys: {missing_keys}")
                    continue

                scaffolded_notes = parsed["relevant_notes"]

                # Prepend only the extracted (relevant) notes to the first message.
                if conversation:
                    conversation = self._prepend_to_first_message(
                        conversation,
                        f"Remember, you have been taking notes throughout past conversations "
                        f"about user preferences. Use these notes to guide your response:\n"
                        f"{scaffolded_notes}\n\n"
                    )

                return conversation

            except Exception as e:
                print(f"[ReflectionAdapter] Scaffolding attempt {attempt+1} failed: {e}")
                continue

        # Fallback: use all notes if retrieval fails.
        print("[ReflectionAdapter] Scaffolding failed, using full notes")
        if conversation:
            conversation = self._prepend_to_first_message(conversation, full_notes_prefix)
        return conversation

    def start_session(self, user_id: str, user_profile: dict = None):
        """Start a new session for a user.

        Resets the in-session conversation history; persistent notes for the
        user (if any) are kept. `user_profile` is accepted for interface
        compatibility but unused.
        """
        if not self._initialized:
            self.initialize()

        self._current_user_id = user_id
        self._conversation_history = []

    def _build_generation_messages(self, query: str) -> tuple:
        """Shared prompt construction for generate_response()/prepare_prompt().

        Appends `query` to the session history, applies scaffolding if enabled
        (may trigger an LLM retrieval call), and builds the full message list.

        Returns:
            (messages, agent_notes) where messages includes the system prompt.
        """
        # Add user query to history.
        self._conversation_history.append({"role": "user", "content": query})

        # Get current notes for this user.
        agent_notes = self._user_notes.get(self._current_user_id, "No notes yet about this user.")

        # Build conversation with scaffolding (uses proper_scaffolding if enabled).
        if self.with_scaffolding and agent_notes != "No notes yet about this user.":
            conversation_with_notes = self._add_scaffolding_to_conversation(
                self._conversation_history, agent_notes
            )
        else:
            conversation_with_notes = self._conversation_history

        # Build system prompt using ORIGINAL CollaborativeAgents prompt.
        # Note: when scaffolding is on, the notes are injected into the
        # conversation instead of the system prompt.
        system_prompt = reflective_agent_system_prompt_no_json.format(
            agent_notes=agent_notes if not self.with_scaffolding else "See notes in conversation."
        )

        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(conversation_with_notes)
        return messages, agent_notes

    def generate_response(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """Generate a response using the reflection agent.

        Args:
            query: current user query.
            conversation_history: accepted for interface compatibility; the
                adapter tracks its own session history internally.

        Returns:
            Dict with 'response', 'reasoning', and debug info.
        """
        if not self._initialized:
            self.initialize()

        messages, agent_notes = self._build_generation_messages(query)

        # Generate response.
        response_text = self._generate(messages, max_new_tokens=self.max_new_tokens)

        self._conversation_history.append({"role": "assistant", "content": response_text})

        return {
            "response": response_text,
            "reasoning": "",
            "debug": {"agent_notes": agent_notes, "proper_scaffolding": self.with_proper_scaffolding}
        }

    def prepare_prompt(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> tuple:
        """
        Prepare prompt for batch processing without calling main LLM.

        Note: This may still call LLM for scaffolding (memory retrieval),
        but the main generation is deferred for batching.

        Args:
            query: Current user query
            conversation_history: Previous conversation (unused; kept for
                interface compatibility)

        Returns:
            Tuple of (messages, context) for batch processing
        """
        if not self._initialized:
            self.initialize()

        messages, agent_notes = self._build_generation_messages(query)

        # Context for post-processing.
        ctx = {
            "agent_notes": agent_notes,
        }

        return messages, ctx

    def process_response(
        self,
        response: str,
        context: dict
    ) -> Dict[str, Any]:
        """
        Process LLM response after batch call.

        Args:
            response: LLM response text
            context: Context dict from prepare_prompt()

        Returns:
            Dict with 'response', 'reasoning', and debug info
        """
        self._conversation_history.append({"role": "assistant", "content": response})

        return {
            "response": response,
            "reasoning": "",
            "debug": {
                "agent_notes": context["agent_notes"],
                "proper_scaffolding": self.with_proper_scaffolding
            }
        }

    def _update_agent_notes(self, agent_notes: str, conversation: List[Dict[str, str]]) -> Optional[Dict]:
        """
        Update agent notes using ORIGINAL CollaborativeAgents logic.

        For 8B models (no_json): Use raw response directly
        For JSON models: Check for required keys, retry up to num_retries times
        Returns None if all retries fail (keeps old notes)
        """
        conversation_str = get_conversation_string(conversation)
        formatted_prompt = update_agent_notes_prompt.format(
            agent_notes=agent_notes,
            conversation_str=conversation_str
        )

        num_retries = 10
        no_json = True  # 8B model: raw-text notes, no JSON parsing needed

        for attempt in range(num_retries):
            try:
                messages = [{"role": "user", "content": formatted_prompt}]
                response = self._generate(messages, max_new_tokens=512)

                # For 8B models (no_json=True): use raw response directly.
                if no_json:
                    return {"agent_notes": response}

                # For JSON models: parse and check keys.
                processed_response = repair_json(response, return_objects=True)
                missing_keys = [k for k in ["user_preferences_reasoning", "agent_notes"] if k not in processed_response]

                if missing_keys:
                    print(f"[ReflectionAdapter] Missing keys: {missing_keys}, attempt {attempt + 1}")
                    continue

                return processed_response

            except Exception as e:
                print(f"[ReflectionAdapter] Failed to update agent notes: {e}")

        return None  # All retries failed, keep old notes

    def end_session(self, task_success: bool = False) -> Dict[str, Any]:
        """
        End session and update agent notes via reflection.

        Uses the ORIGINAL update_agent_notes logic from CollaborativeAgents.

        Returns:
            Session summary dict; 'notes_updated' is True only if the
            reflection actually produced new notes (False on failure or when
            the session had no turns).
        """
        if not self._current_user_id:
            return {}

        # Get current notes.
        current_notes = self._user_notes.get(self._current_user_id, "No notes yet.")

        # Update notes via session-level reflection.
        notes_updated = False
        if len(self._conversation_history) > 0:
            result = self._update_agent_notes(current_notes, self._conversation_history)

            if result is not None and "agent_notes" in result:
                updated_notes = result["agent_notes"]
                self._user_notes[self._current_user_id] = updated_notes
                notes_updated = True
                print(f"[ReflectionAdapter] Updated notes for {self._current_user_id} "
                      f"({len(current_notes)} -> {len(updated_notes)} chars)")
            else:
                print(f"[ReflectionAdapter] Keeping old notes for {self._current_user_id} "
                      f"(update failed)")

        return {
            "turns": len(self._conversation_history),
            "task_success": task_success,
            # BUGFIX: was hard-coded to True even when the update failed.
            "notes_updated": notes_updated,
        }

    def reset_user(self, user_id: str):
        """Reset all memory for a user (no-op if the user has no notes)."""
        self._user_notes.pop(user_id, None)

    def __call__(
        self,
        messages: List[Dict[str, str]],
        user_profile: dict = None,
        **kwargs
    ) -> str:
        """Callable interface for ConversationGenerator compatibility.

        Extracts the last user message and routes it through
        generate_response(); returns a canned greeting if there is none.
        """
        if not messages:
            return "How can I help you?"

        last_user_msg = next(
            (msg["content"] for msg in reversed(messages) if msg["role"] == "user"),
            None,
        )

        if last_user_msg is None:
            return "How can I help you?"

        result = self.generate_response(last_user_msg, messages)
        return result["response"]