""" Contextual Adapter - Full conversation history in context baseline. This implements a simple baseline where: - Full conversation history is passed to the LLM - No persistent memory across sessions - Token-based context window truncation to prevent overflow Now uses vLLM for fast inference instead of local transformers. """ import sys from pathlib import Path from typing import Optional, List, Dict, Any # Add parent for utils import sys.path.insert(0, str(Path(__file__).parent.parent)) from utils.vllm_client import VLLMClient, VLLMConfig # Default vLLM URL (agent server on port 8003) DEFAULT_VLLM_URL = "http://localhost:8003/v1" # Model context limits MAX_MODEL_LEN = 16384 # vLLM max_model_len setting MAX_GENERATION_TOKENS = 1024 # Reserved for generation SYSTEM_PROMPT_BUFFER = 500 # Buffer for system prompt overhead # Safe limit for conversation context - reduced to force faster forgetting # This keeps only ~2-3 sessions worth of history visible MAX_CONTEXT_TOKENS = 4000 # Reduced from ~14860 to make contextual forget faster # Basic agent system prompt AGENT_SYSTEM_PROMPT = """You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems. # Conversation Guidelines: - If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer. - Your goal is to help the user solve their problem. Do your best to help them.""" def estimate_tokens(text: str) -> int: """ Estimate token count for text using character-based heuristic. Uses ~2.5 characters per token which is conservative for LLaMA tokenizers, especially with math/code content where tokenization is less efficient. """ return int(len(text) / 2.5) + 1 def estimate_messages_tokens(messages: List[Dict[str, str]]) -> int: """Estimate total tokens in a list of messages.""" total = 0 for msg in messages: # Add overhead for role tags and formatting (~4 tokens per message) total += estimate_tokens(msg.get("content", "")) + 4 return total class ContextualAdapter: """ Contextual baseline - full history in context, no memory. Uses vLLM for fast inference, passes full conversation history to the model. """ def __init__( self, model_name: str = None, # Ignored - vLLM auto-discovers model device_assignment: dict = None, # Ignored - vLLM handles GPU api_base: str = None, # vLLM server URL api_key: str = None, # Ignored max_context_turns: int = 15, # Fallback turn-based truncation (reduced from 50) max_context_tokens: int = None, # Token-based truncation (primary) vllm_url: str = None, # vLLM server URL ): self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL self.max_context_turns = max_context_turns self.max_context_tokens = max_context_tokens or MAX_CONTEXT_TOKENS self._current_user_id: Optional[str] = None self._conversation_history: List[Dict[str, str]] = [] # vLLM client (initialized lazily) self._client: Optional[VLLMClient] = None self._initialized = False def initialize(self): """Initialize the adapter (connects to vLLM server).""" if self._initialized: return print(f"[ContextualAdapter] Connecting to vLLM server at {self.vllm_url}...") # Retry connection with exponential backoff import time max_retries = 30 for attempt in range(max_retries): try: self._client = VLLMClient(base_url=self.vllm_url) if self._client.health_check(): break except Exception as e: pass if attempt < max_retries - 1: wait_time = min(2 ** attempt * 0.5, 10) # 0.5, 1, 2, 4, 8, 10, 10... time.sleep(wait_time) else: raise RuntimeError(f"vLLM server not responding at {self.vllm_url} after {max_retries} retries") self._initialized = True print(f"[ContextualAdapter] Connected to vLLM (model: {self._client.config.model})") def _truncate_to_token_limit(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: """ Truncate messages to fit within token limit. Removes oldest messages first while keeping the most recent context. """ if not messages: return messages total_tokens = estimate_messages_tokens(messages) # If within limit, return as-is if total_tokens <= self.max_context_tokens: return messages # Truncate from the beginning (oldest messages) until under limit truncated = list(messages) while len(truncated) > 1 and estimate_messages_tokens(truncated) > self.max_context_tokens: truncated.pop(0) return truncated def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str: """Generate response using vLLM server.""" if not self._initialized: self.initialize() result = self._client.chat( messages=messages, max_tokens=max_new_tokens, temperature=0.7, top_p=0.9, ) return result["content"] def start_session(self, user_id: str, user_profile: dict = None): """Start a new session (conversation history persists across sessions for this baseline).""" if not self._initialized: self.initialize() self._current_user_id = user_id # NOTE: For contextual baseline, we keep history across sessions # This is different from vanilla which resets each session def generate_response( self, query: str, conversation_history: List[Dict[str, str]] = None ) -> Dict[str, Any]: """Generate response with full conversation context.""" if not self._initialized: self.initialize() # Add current query self._conversation_history.append({"role": "user", "content": query}) # Token-based truncation (primary) - keeps most recent messages within token limit context = self._truncate_to_token_limit(self._conversation_history) # Fallback: also apply turn-based limit if still too many turns if len(context) > self.max_context_turns * 2: context = context[-(self.max_context_turns * 2):] # Build messages with system prompt messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}] messages.extend(context) # Generate response response_text = self._generate(messages) self._conversation_history.append({"role": "assistant", "content": response_text}) # Track truncation for debugging original_turns = len(self._conversation_history) // 2 context_turns = len(context) // 2 truncated = original_turns > context_turns return { "response": response_text, "reasoning": "", "debug": { "context_turns": context_turns, "total_turns": original_turns, "truncated": truncated, "estimated_context_tokens": estimate_messages_tokens(context), } } def prepare_prompt( self, query: str, conversation_history: List[Dict[str, str]] = None ) -> tuple: """ Prepare prompt for batch processing without calling LLM. Args: query: Current user query conversation_history: Previous conversation Returns: Tuple of (messages, context) for batch processing """ if not self._initialized: self.initialize() # Add current query to history self._conversation_history.append({"role": "user", "content": query}) # Token-based truncation context = self._truncate_to_token_limit(self._conversation_history) # Fallback: also apply turn-based limit if len(context) > self.max_context_turns * 2: context = context[-(self.max_context_turns * 2):] # Build messages with system prompt messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}] messages.extend(context) # Context for post-processing ctx = { "context": context, "original_history_len": len(self._conversation_history), } return messages, ctx def process_response( self, response: str, context: dict ) -> Dict[str, Any]: """ Process LLM response after batch call. Args: response: LLM response text context: Context dict from prepare_prompt() Returns: Dict with 'response', 'reasoning', and debug info """ # Add response to history self._conversation_history.append({"role": "assistant", "content": response}) ctx_context = context["context"] original_turns = context["original_history_len"] // 2 context_turns = len(ctx_context) // 2 truncated = original_turns > context_turns return { "response": response, "reasoning": "", "debug": { "context_turns": context_turns, "total_turns": original_turns, "truncated": truncated, "estimated_context_tokens": estimate_messages_tokens(ctx_context), } } def end_session(self, task_success: bool = False) -> Dict[str, Any]: """End session (no memory update for contextual baseline).""" return { "turns": len(self._conversation_history), "task_success": task_success, } def reset_user(self, user_id: str): """Reset conversation history for user.""" self._conversation_history = [] def __call__( self, messages: List[Dict[str, str]], user_profile: dict = None, **kwargs ) -> str: """Callable interface.""" if not messages: return "How can I help you?" last_user_msg = None for msg in reversed(messages): if msg["role"] == "user": last_user_msg = msg["content"] break if last_user_msg is None: return "How can I help you?" result = self.generate_response(last_user_msg, messages) return result["response"]