""" Reflection Adapter - vLLM-based implementation using original CollaborativeAgents prompts. This implements the "reflection" baseline from the MULTISESSIONCOLLAB paper: - After each session, agent reflects on interaction to update memory (agent_notes) - Memory is provided to agent at start of subsequent sessions - Uses session-level reflection + persistent memory - Uses LLM-based retrieval (proper_scaffolding) to prevent context overflow Now uses vLLM for fast inference instead of local transformers. Uses EXACT prompts from the original CollaborativeAgents paper for fairness. """ import sys from pathlib import Path from typing import Optional, List, Dict, Any from json_repair import repair_json # Add parent for utils import sys.path.insert(0, str(Path(__file__).parent.parent)) from utils.vllm_client import VLLMClient, VLLMConfig # Import ORIGINAL prompts from CollaborativeAgents for fair reproduction sys.path.insert(0, str(Path(__file__).parent.parent / "collaborativeagents")) from collaborativeagents.prompts import ( reflective_agent_system_prompt_no_json, update_agent_notes_prompt, proper_scaffolding_prompt, ) from collaborativeagents.utils import get_conversation_string # Default vLLM URL (agent server on port 8003) DEFAULT_VLLM_URL = "http://localhost:8003/v1" class ReflectionAdapter: """ Adapter for the Reflection baseline from MULTISESSIONCOLLAB. Uses vLLM for fast inference with: - agent_notes: Persistent memory updated via session-level reflection - Memory retrieval at each turn via proper_scaffolding (when notes are long) Uses ORIGINAL CollaborativeAgents prompts for fair benchmark reproduction. """ def __init__( self, model_name: str = None, # Ignored - vLLM auto-discovers model device_assignment: dict = None, # Ignored - vLLM handles GPU api_base: str = None, # vLLM server URL api_key: str = None, # Ignored with_scaffolding: bool = True, with_proper_scaffolding: bool = True, # Enable LLM-based retrieval by default vllm_url: str = None, # vLLM server URL max_new_tokens: int = 2048, # Match original CollaborativeAgents setting ): self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL self.with_scaffolding = with_scaffolding self.with_proper_scaffolding = with_proper_scaffolding self.max_new_tokens = max_new_tokens # Per-user memory storage self._user_notes: Dict[str, str] = {} self._current_user_id: Optional[str] = None self._conversation_history: List[Dict[str, str]] = [] # vLLM client (initialized lazily) self._client: Optional[VLLMClient] = None self._initialized = False def initialize(self): """Initialize the adapter (connects to vLLM server).""" if self._initialized: return print(f"[ReflectionAdapter] Connecting to vLLM server at {self.vllm_url}...") print(f"[ReflectionAdapter] Using proper_scaffolding={self.with_proper_scaffolding}") # Retry connection with exponential backoff import time max_retries = 30 for attempt in range(max_retries): try: self._client = VLLMClient(base_url=self.vllm_url) if self._client.health_check(): break except Exception as e: pass if attempt < max_retries - 1: wait_time = min(2 ** attempt * 0.5, 10) # 0.5, 1, 2, 4, 8, 10, 10... time.sleep(wait_time) else: raise RuntimeError(f"vLLM server not responding at {self.vllm_url} after {max_retries} retries") self._initialized = True print(f"[ReflectionAdapter] Connected to vLLM (model: {self._client.config.model})") def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str: """Generate response using vLLM server.""" if not self._initialized: self.initialize() result = self._client.chat( messages=messages, max_tokens=max_new_tokens, temperature=1.0, # Match original CollaborativeAgents setting top_p=0.9, ) return result["content"] def _add_scaffolding_to_conversation( self, conversation: List[Dict[str, str]], agent_notes: str ) -> List[Dict[str, str]]: """ Add scaffolding (memory notes) to conversation. This is the EXACT logic from CollaboratorAgent.add_scaffolding_to_conversation(). If with_proper_scaffolding=True: Use LLM to extract relevant notes If with_proper_scaffolding=False: Prepend all notes to first message """ if not self.with_proper_scaffolding: # Simple mode: prepend all notes if conversation: conversation = list(conversation) # Copy to avoid mutation conversation[0] = dict(conversation[0]) # Copy first message conversation[0]["content"] = ( f"Remember, you have been taking notes throughout past conversations " f"about user preferences. Use whatever is relevant in these notes to " f"guide your response:\n{agent_notes}\n\n" + conversation[0]["content"] ) return conversation else: # Proper scaffolding: use LLM to extract relevant notes conversation_str = get_conversation_string(conversation) formatted_prompt = proper_scaffolding_prompt.format( conversation_history=conversation_str, complete_agent_notes=agent_notes ) # Call LLM to extract relevant notes (with retries) for attempt in range(3): try: messages = [{"role": "user", "content": formatted_prompt}] response = self._generate(messages, max_new_tokens=512) parsed = repair_json(response, return_objects=True) missing_keys = [k for k in ["reasoning", "relevant_notes"] if k not in parsed] if missing_keys: print(f"[ReflectionAdapter] Scaffolding missing keys: {missing_keys}") continue scaffolded_notes = parsed["relevant_notes"] # Prepend extracted notes to first message if conversation: conversation = list(conversation) conversation[0] = dict(conversation[0]) conversation[0]["content"] = ( f"Remember, you have been taking notes throughout past conversations " f"about user preferences. Use these notes to guide your response:\n" f"{scaffolded_notes}\n\n" + conversation[0]["content"] ) return conversation except Exception as e: print(f"[ReflectionAdapter] Scaffolding attempt {attempt+1} failed: {e}") continue # Fallback: use all notes if retrieval fails print("[ReflectionAdapter] Scaffolding failed, using full notes") if conversation: conversation = list(conversation) conversation[0] = dict(conversation[0]) conversation[0]["content"] = ( f"Remember, you have been taking notes throughout past conversations " f"about user preferences. Use whatever is relevant in these notes to " f"guide your response:\n{agent_notes}\n\n" + conversation[0]["content"] ) return conversation def start_session(self, user_id: str, user_profile: dict = None): """Start a new session for a user.""" if not self._initialized: self.initialize() self._current_user_id = user_id self._conversation_history = [] def generate_response( self, query: str, conversation_history: List[Dict[str, str]] = None ) -> Dict[str, Any]: """Generate a response using the reflection agent.""" if not self._initialized: self.initialize() # Add user query to history self._conversation_history.append({"role": "user", "content": query}) # Get current notes for this user agent_notes = self._user_notes.get(self._current_user_id, "No notes yet about this user.") # Build conversation with scaffolding (uses proper_scaffolding if enabled) if self.with_scaffolding and agent_notes != "No notes yet about this user.": conversation_with_notes = self._add_scaffolding_to_conversation( self._conversation_history, agent_notes ) else: conversation_with_notes = self._conversation_history # Build system prompt using ORIGINAL CollaborativeAgents prompt # Note: For no_json mode, we don't include agent_notes in system prompt # because they're added via scaffolding to the conversation system_prompt = reflective_agent_system_prompt_no_json.format( agent_notes=agent_notes if not self.with_scaffolding else "See notes in conversation." ) # Build messages for generation messages = [{"role": "system", "content": system_prompt}] messages.extend(conversation_with_notes) # Generate response response_text = self._generate(messages, max_new_tokens=self.max_new_tokens) self._conversation_history.append({"role": "assistant", "content": response_text}) return { "response": response_text, "reasoning": "", "debug": {"agent_notes": agent_notes, "proper_scaffolding": self.with_proper_scaffolding} } def prepare_prompt( self, query: str, conversation_history: List[Dict[str, str]] = None ) -> tuple: """ Prepare prompt for batch processing without calling main LLM. Note: This may still call LLM for scaffolding (memory retrieval), but the main generation is deferred for batching. Args: query: Current user query conversation_history: Previous conversation Returns: Tuple of (messages, context) for batch processing """ if not self._initialized: self.initialize() # Add user query to history self._conversation_history.append({"role": "user", "content": query}) # Get current notes for this user agent_notes = self._user_notes.get(self._current_user_id, "No notes yet about this user.") # Build conversation with scaffolding (may involve LLM call for proper_scaffolding) if self.with_scaffolding and agent_notes != "No notes yet about this user.": conversation_with_notes = self._add_scaffolding_to_conversation( self._conversation_history, agent_notes ) else: conversation_with_notes = self._conversation_history # Build system prompt using ORIGINAL CollaborativeAgents prompt system_prompt = reflective_agent_system_prompt_no_json.format( agent_notes=agent_notes if not self.with_scaffolding else "See notes in conversation." ) # Build messages for generation messages = [{"role": "system", "content": system_prompt}] messages.extend(conversation_with_notes) # Context for post-processing ctx = { "agent_notes": agent_notes, } return messages, ctx def process_response( self, response: str, context: dict ) -> Dict[str, Any]: """ Process LLM response after batch call. Args: response: LLM response text context: Context dict from prepare_prompt() Returns: Dict with 'response', 'reasoning', and debug info """ self._conversation_history.append({"role": "assistant", "content": response}) return { "response": response, "reasoning": "", "debug": { "agent_notes": context["agent_notes"], "proper_scaffolding": self.with_proper_scaffolding } } def _update_agent_notes(self, agent_notes: str, conversation: List[Dict[str, str]]) -> Optional[Dict]: """ Update agent notes using ORIGINAL CollaborativeAgents logic. For 8B models (no_json): Use raw response directly For JSON models: Check for required keys, retry up to num_retries times Returns None if all retries fail (keeps old notes) """ conversation_str = get_conversation_string(conversation) formatted_prompt = update_agent_notes_prompt.format( agent_notes=agent_notes, conversation_str=conversation_str ) num_retries = 10 no_json = True # 8B model for attempt in range(num_retries): try: messages = [{"role": "user", "content": formatted_prompt}] response = self._generate(messages, max_new_tokens=512) # For 8B models (no_json=True): use raw response directly if no_json: return {"agent_notes": response} # For JSON models: parse and check keys processed_response = repair_json(response, return_objects=True) missing_keys = [k for k in ["user_preferences_reasoning", "agent_notes"] if k not in processed_response] if missing_keys: print(f"[ReflectionAdapter] Missing keys: {missing_keys}, attempt {attempt + 1}") continue return processed_response except Exception as e: print(f"[ReflectionAdapter] Failed to update agent notes: {e}") return None # All retries failed, keep old notes def end_session(self, task_success: bool = False) -> Dict[str, Any]: """ End session and update agent notes via reflection. Uses the ORIGINAL update_agent_notes logic from CollaborativeAgents. """ if not self._current_user_id: return {} # Get current notes current_notes = self._user_notes.get(self._current_user_id, "No notes yet.") # Update notes via session-level reflection if len(self._conversation_history) > 0: result = self._update_agent_notes(current_notes, self._conversation_history) if result is not None and "agent_notes" in result: updated_notes = result["agent_notes"] self._user_notes[self._current_user_id] = updated_notes print(f"[ReflectionAdapter] Updated notes for {self._current_user_id} " f"({len(current_notes)} -> {len(updated_notes)} chars)") else: print(f"[ReflectionAdapter] Keeping old notes for {self._current_user_id} " f"(update failed)") return { "turns": len(self._conversation_history), "task_success": task_success, "notes_updated": True, } def reset_user(self, user_id: str): """Reset all memory for a user.""" if user_id in self._user_notes: del self._user_notes[user_id] def __call__( self, messages: List[Dict[str, str]], user_profile: dict = None, **kwargs ) -> str: """Callable interface for ConversationGenerator compatibility.""" if not messages: return "How can I help you?" last_user_msg = None for msg in reversed(messages): if msg["role"] == "user": last_user_msg = msg["content"] break if last_user_msg is None: return "How can I help you?" result = self.generate_response(last_user_msg, messages) return result["response"]