| field | value | date |
|---|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/adapters/contextual_adapter.py | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/adapters/contextual_adapter.py')
| mode | file | insertions |
|---|---|---|
| -rw-r--r-- | collaborativeagents/adapters/contextual_adapter.py | 305 |

1 file changed, 305 insertions, 0 deletions
diff --git a/collaborativeagents/adapters/contextual_adapter.py b/collaborativeagents/adapters/contextual_adapter.py
new file mode 100644
index 0000000..ef5e92e
--- /dev/null
+++ b/collaborativeagents/adapters/contextual_adapter.py
@@ -0,0 +1,305 @@
+"""
+Contextual Adapter - Full conversation history in context baseline.
+
+This implements a simple baseline where:
+- Full conversation history is passed to the LLM
+- No persistent memory across sessions
+- Token-based context window truncation to prevent overflow
+
+Now uses vLLM for fast inference instead of local transformers.
+"""
+
+import sys
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+
+# Add parent for utils import
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from utils.vllm_client import VLLMClient, VLLMConfig
+
+# Default vLLM URL (agent server on port 8003)
+DEFAULT_VLLM_URL = "http://localhost:8003/v1"
+
+# Model context limits
+MAX_MODEL_LEN = 16384  # vLLM max_model_len setting
+MAX_GENERATION_TOKENS = 1024  # Reserved for generation
+SYSTEM_PROMPT_BUFFER = 500  # Buffer for system prompt overhead
+# Safe limit for conversation context - reduced to force faster forgetting
+# This keeps only ~2-3 sessions worth of history visible
+MAX_CONTEXT_TOKENS = 4000  # Reduced from ~14860 to make contextual forget faster
+
+# Basic agent system prompt
+AGENT_SYSTEM_PROMPT = """You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.
+
+# Conversation Guidelines:
+- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.
+- Your goal is to help the user solve their problem. Do your best to help them."""
+
+
+def estimate_tokens(text: str) -> int:
+    """
+    Estimate token count for text using character-based heuristic.
+    Uses ~2.5 characters per token which is conservative for LLaMA tokenizers,
+    especially with math/code content where tokenization is less efficient.
+    """
+    return int(len(text) / 2.5) + 1
+
+
+def estimate_messages_tokens(messages: List[Dict[str, str]]) -> int:
+    """Estimate total tokens in a list of messages."""
+    total = 0
+    for msg in messages:
+        # Add overhead for role tags and formatting (~4 tokens per message)
+        total += estimate_tokens(msg.get("content", "")) + 4
+    return total
+
+
+class ContextualAdapter:
+    """
+    Contextual baseline - full history in context, no memory.
+
+    Uses vLLM for fast inference, passes full conversation history to the model.
+    """
+
+    def __init__(
+        self,
+        model_name: str = None,  # Ignored - vLLM auto-discovers model
+        device_assignment: dict = None,  # Ignored - vLLM handles GPU
+        api_base: str = None,  # vLLM server URL
+        api_key: str = None,  # Ignored
+        max_context_turns: int = 15,  # Fallback turn-based truncation (reduced from 50)
+        max_context_tokens: int = None,  # Token-based truncation (primary)
+        vllm_url: str = None,  # vLLM server URL
+    ):
+        self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL
+        self.max_context_turns = max_context_turns
+        self.max_context_tokens = max_context_tokens or MAX_CONTEXT_TOKENS
+
+        self._current_user_id: Optional[str] = None
+        self._conversation_history: List[Dict[str, str]] = []
+
+        # vLLM client (initialized lazily)
+        self._client: Optional[VLLMClient] = None
+        self._initialized = False
+
+    def initialize(self):
+        """Initialize the adapter (connects to vLLM server)."""
+        if self._initialized:
+            return
+
+        print(f"[ContextualAdapter] Connecting to vLLM server at {self.vllm_url}...")
+
+        # Retry connection with exponential backoff
+        import time
+        max_retries = 30
+        for attempt in range(max_retries):
+            try:
+                self._client = VLLMClient(base_url=self.vllm_url)
+                if self._client.health_check():
+                    break
+            except Exception as e:
+                pass
+
+            if attempt < max_retries - 1:
+                wait_time = min(2 ** attempt * 0.5, 10)  # 0.5, 1, 2, 4, 8, 10, 10...
+                time.sleep(wait_time)
+            else:
+                raise RuntimeError(f"vLLM server not responding at {self.vllm_url} after {max_retries} retries")
+
+        self._initialized = True
+        print(f"[ContextualAdapter] Connected to vLLM (model: {self._client.config.model})")
+
+    def _truncate_to_token_limit(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        """
+        Truncate messages to fit within token limit.
+        Removes oldest messages first while keeping the most recent context.
+        """
+        if not messages:
+            return messages
+
+        total_tokens = estimate_messages_tokens(messages)
+
+        # If within limit, return as-is
+        if total_tokens <= self.max_context_tokens:
+            return messages
+
+        # Truncate from the beginning (oldest messages) until under limit
+        truncated = list(messages)
+        while len(truncated) > 1 and estimate_messages_tokens(truncated) > self.max_context_tokens:
+            truncated.pop(0)
+
+        return truncated
+
+    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str:
+        """Generate response using vLLM server."""
+        if not self._initialized:
+            self.initialize()
+
+        result = self._client.chat(
+            messages=messages,
+            max_tokens=max_new_tokens,
+            temperature=0.7,
+            top_p=0.9,
+        )
+
+        return result["content"]
+
+    def start_session(self, user_id: str, user_profile: dict = None):
+        """Start a new session (conversation history persists across sessions for this baseline)."""
+        if not self._initialized:
+            self.initialize()
+
+        self._current_user_id = user_id
+        # NOTE: For contextual baseline, we keep history across sessions
+        # This is different from vanilla which resets each session
+
+    def generate_response(
+        self,
+        query: str,
+        conversation_history: List[Dict[str, str]] = None
+    ) -> Dict[str, Any]:
+        """Generate response with full conversation context."""
+        if not self._initialized:
+            self.initialize()
+
+        # Add current query
+        self._conversation_history.append({"role": "user", "content": query})
+
+        # Token-based truncation (primary) - keeps most recent messages within token limit
+        context = self._truncate_to_token_limit(self._conversation_history)
+
+        # Fallback: also apply turn-based limit if still too many turns
+        if len(context) > self.max_context_turns * 2:
+            context = context[-(self.max_context_turns * 2):]
+
+        # Build messages with system prompt
+        messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]
+        messages.extend(context)
+
+        # Generate response
+        response_text = self._generate(messages)
+
+        self._conversation_history.append({"role": "assistant", "content": response_text})
+
+        # Track truncation for debugging
+        original_turns = len(self._conversation_history) // 2
+        context_turns = len(context) // 2
+        truncated = original_turns > context_turns
+
+        return {
+            "response": response_text,
+            "reasoning": "",
+            "debug": {
+                "context_turns": context_turns,
+                "total_turns": original_turns,
+                "truncated": truncated,
+                "estimated_context_tokens": estimate_messages_tokens(context),
+            }
+        }
+
+    def prepare_prompt(
+        self,
+        query: str,
+        conversation_history: List[Dict[str, str]] = None
+    ) -> tuple:
+        """
+        Prepare prompt for batch processing without calling LLM.
+
+        Args:
+            query: Current user query
+            conversation_history: Previous conversation
+
+        Returns:
+            Tuple of (messages, context) for batch processing
+        """
+        if not self._initialized:
+            self.initialize()
+
+        # Add current query to history
+        self._conversation_history.append({"role": "user", "content": query})
+
+        # Token-based truncation
+        context = self._truncate_to_token_limit(self._conversation_history)
+
+        # Fallback: also apply turn-based limit
+        if len(context) > self.max_context_turns * 2:
+            context = context[-(self.max_context_turns * 2):]
+
+        # Build messages with system prompt
+        messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]
+        messages.extend(context)
+
+        # Context for post-processing
+        ctx = {
+            "context": context,
+            "original_history_len": len(self._conversation_history),
+        }
+
+        return messages, ctx
+
+    def process_response(
+        self,
+        response: str,
+        context: dict
+    ) -> Dict[str, Any]:
+        """
+        Process LLM response after batch call.
+
+        Args:
+            response: LLM response text
+            context: Context dict from prepare_prompt()
+
+        Returns:
+            Dict with 'response', 'reasoning', and debug info
+        """
+        # Add response to history
+        self._conversation_history.append({"role": "assistant", "content": response})
+
+        ctx_context = context["context"]
+        original_turns = context["original_history_len"] // 2
+        context_turns = len(ctx_context) // 2
+        truncated = original_turns > context_turns
+
+        return {
+            "response": response,
+            "reasoning": "",
+            "debug": {
+                "context_turns": context_turns,
+                "total_turns": original_turns,
+                "truncated": truncated,
+                "estimated_context_tokens": estimate_messages_tokens(ctx_context),
+            }
+        }
+
+    def end_session(self, task_success: bool = False) -> Dict[str, Any]:
+        """End session (no memory update for contextual baseline)."""
+        return {
+            "turns": len(self._conversation_history),
+            "task_success": task_success,
+        }
+
+    def reset_user(self, user_id: str):
+        """Reset conversation history for user."""
+        self._conversation_history = []
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        user_profile: dict = None,
+        **kwargs
+    ) -> str:
+        """Callable interface."""
+        if not messages:
+            return "How can I help you?"
+
+        last_user_msg = None
+        for msg in reversed(messages):
+            if msg["role"] == "user":
+                last_user_msg = msg["content"]
+                break
+
+        if last_user_msg is None:
+            return "How can I help you?"
+
+        result = self.generate_response(last_user_msg, messages)
+        return result["response"]
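A minimal sketch of the adapter's interactive path, assuming the module is importable as `collaborativeagents.adapters.contextual_adapter` (package `__init__.py` files in place), a vLLM OpenAI-compatible server already serving on the default `http://localhost:8003/v1`, and `utils.vllm_client.VLLMClient` resolvable at import time; the user id and queries are illustrative only.

```python
# Interactive-path sketch. Assumptions: import path below, a running vLLM
# server at http://localhost:8003/v1, and utils.vllm_client importable.
from collaborativeagents.adapters.contextual_adapter import ContextualAdapter

adapter = ContextualAdapter()
adapter.start_session(user_id="user_001")  # triggers the lazy vLLM connection

turn = adapter.generate_response("Draft a short essay on memory in collaborative agents.")
print(turn["response"])
print(turn["debug"])  # context_turns, total_turns, truncated, estimated_context_tokens

followup = adapter.generate_response("Trim the introduction to two sentences.")
print(followup["response"])

print(adapter.end_session(task_success=True))  # no memory write for this baseline
```

Note that `start_session` does not clear `_conversation_history`, so history carries across sessions until `reset_user()` is called.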
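The `prepare_prompt` / `process_response` pair splits a turn so that generation can happen in an external batched vLLM call, which is not part of this file. A sketch of that contract follows, with the adapter's own single-request `_generate()` standing in for the batch harness (an assumption, purely for illustration).

```python
# Batch-processing contract sketch; adapter._generate() stands in for the
# external batch harness. The query is illustrative.
adapter = ContextualAdapter()
adapter.start_session(user_id="user_002")

messages, ctx = adapter.prepare_prompt("What is 17 * 24?")  # system prompt + truncated history
response_text = adapter._generate(messages)                 # stand-in for the external batched call
result = adapter.process_response(response_text, ctx)       # appends to history, returns debug info

print(result["response"])
print(result["debug"]["estimated_context_tokens"])
```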
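The truncation budget follows directly from the module constants: `estimate_tokens` assumes roughly 2.5 characters per token, so `MAX_CONTEXT_TOKENS = 4000` retains only about 10,000 characters of recent history, which is what the "~2-3 sessions" comment refers to. A small self-contained check, under the same import-path assumption as above:

```python
# Checking the arithmetic behind the 4000-token context budget.
from collaborativeagents.adapters.contextual_adapter import (
    MAX_CONTEXT_TOKENS,
    estimate_messages_tokens,
    estimate_tokens,
)

msg = {"role": "user", "content": "x" * 1000}   # a 1,000-character message
print(estimate_tokens(msg["content"]))          # 401   -> int(1000 / 2.5) + 1
print(estimate_messages_tokens([msg] * 10))     # 4050  -> ten such messages already exceed the budget
print(MAX_CONTEXT_TOKENS * 2.5)                 # 10000.0 characters of history, roughly 2-3 sessions
```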
