# path: root/collaborativeagents/adapters/contextual_adapter.py
# blob: ef5e92ee4ed4c654da61eb7bb3e431be9db72566
"""
Contextual Adapter - Full conversation history in context baseline.

This implements a simple baseline where:
- Full conversation history is passed to the LLM
- No separate persistent memory store; the in-process conversation history
  carries over between sessions until reset_user() is called
- Token-based context window truncation to prevent overflow

Now uses vLLM for fast inference instead of local transformers.
"""

import sys
from pathlib import Path
from typing import Optional, List, Dict, Any

# Add parent for utils import
sys.path.insert(0, str(Path(__file__).parent.parent))
from utils.vllm_client import VLLMClient, VLLMConfig

# Default vLLM URL (agent server on port 8003); used when neither `vllm_url`
# nor `api_base` is passed to ContextualAdapter.
DEFAULT_VLLM_URL = "http://localhost:8003/v1"

# Model context limits.
# NOTE: the first three values are not referenced by the code visible in this
# file; they record how the full context budget is derived:
#   MAX_MODEL_LEN - MAX_GENERATION_TOKENS - SYSTEM_PROMPT_BUFFER = 14860
MAX_MODEL_LEN = 16384  # vLLM max_model_len setting
MAX_GENERATION_TOKENS = 1024  # Reserved for generation
SYSTEM_PROMPT_BUFFER = 500  # Buffer for system prompt overhead
# Safe limit for conversation context - deliberately far below the 14860
# budget above to force faster forgetting.
# This keeps only ~2-3 sessions worth of history visible.
# Default for ContextualAdapter's `max_context_tokens`.
MAX_CONTEXT_TOKENS = 4000  # Reduced from ~14860 to make contextual forget faster

# Basic agent system prompt; sent as the first ("system") message of every
# request built by ContextualAdapter.
AGENT_SYSTEM_PROMPT = """You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.

# Conversation Guidelines:
- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.
- Your goal is to help the user solve their problem. Do your best to help them."""


def estimate_tokens(text: str) -> int:
    """
    Return a rough token count for *text* from its character length.

    Assumes about 2.5 characters per token, which deliberately over-counts
    for LLaMA-style tokenizers (math and code tokenize less efficiently).
    The result is always at least 1, even for an empty string.
    """
    chars_per_token = 2.5
    approx_tokens = len(text) / chars_per_token
    # Truncate toward zero, then add one so the estimate never reaches 0.
    return 1 + int(approx_tokens)


def estimate_messages_tokens(messages: List[Dict[str, str]]) -> int:
    """
    Estimate the total token count of a chat message list.

    Each message contributes the estimate for its "content" field (a missing
    field counts as an empty string) plus a flat allowance of about 4 tokens
    for role tags and formatting. An empty list yields 0.
    """
    per_message_overhead = 4
    return sum(
        estimate_tokens(entry.get("content", "")) + per_message_overhead
        for entry in messages
    )


class ContextualAdapter:
    """
    Contextual baseline - full history in context, no memory.

    Uses vLLM for fast inference and passes the (truncated) conversation
    history to the model on every request. The history lives on the instance
    and is NOT cleared by start_session(); only reset_user() clears it.

    Two equivalent ways to run a turn:
      * generate_response(query)              - build prompt, call the LLM, record reply
      * prepare_prompt(query) + process_response(reply, ctx)
                                               - same steps with the LLM call done
                                                 externally (batch processing)
    Both report the same debug information for the same conversation state.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,  # Ignored - vLLM auto-discovers model
        device_assignment: Optional[dict] = None,  # Ignored - vLLM handles GPU
        api_base: Optional[str] = None,  # vLLM server URL
        api_key: Optional[str] = None,   # Ignored
        max_context_turns: int = 15,  # Fallback turn-based truncation (reduced from 50)
        max_context_tokens: Optional[int] = None,  # Token-based truncation (primary)
        vllm_url: Optional[str] = None,  # vLLM server URL
    ):
        """
        Store configuration; no network access happens until initialize().

        Args:
            model_name: Accepted for interface compatibility; unused.
            device_assignment: Accepted for interface compatibility; unused.
            api_base: Server URL, used only when `vllm_url` is not given.
            api_key: Accepted for interface compatibility; unused.
            max_context_turns: Upper bound on turns kept in context; applied
                as `2 * max_context_turns` messages after token truncation.
            max_context_tokens: Estimated-token budget for the context.
                A falsy value (None or 0) selects MAX_CONTEXT_TOKENS.
            vllm_url: Server URL; takes precedence over `api_base`.
        """
        self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL
        self.max_context_turns = max_context_turns
        self.max_context_tokens = max_context_tokens or MAX_CONTEXT_TOKENS

        self._current_user_id: Optional[str] = None
        self._conversation_history: List[Dict[str, str]] = []

        # vLLM client (initialized lazily)
        self._client: Optional[VLLMClient] = None
        self._initialized = False

    def initialize(self):
        """
        Connect to the vLLM server, retrying until its health check passes.

        Idempotent: returns immediately once a connection has been made.

        Raises:
            RuntimeError: If the server never reports healthy within the
                retry budget. The last connection error, if any, is attached
                as the exception's cause.
        """
        if self._initialized:
            return

        print(f"[ContextualAdapter] Connecting to vLLM server at {self.vllm_url}...")

        # Retry connection with exponential backoff
        import time
        max_retries = 30
        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                self._client = VLLMClient(base_url=self.vllm_url)
                if self._client.health_check():
                    break
            except Exception as exc:
                # The server may simply not be up yet, so keep retrying, but
                # remember the failure so the final error explains what went wrong.
                last_error = exc

            # No point sleeping after the final attempt.
            if attempt < max_retries - 1:
                wait_time = min(2 ** attempt * 0.5, 10)  # 0.5, 1, 2, 4, 8, 10, 10...
                time.sleep(wait_time)
        else:
            # Loop finished without `break`: every attempt failed.
            raise RuntimeError(
                f"vLLM server not responding at {self.vllm_url} after {max_retries} retries"
            ) from last_error

        self._initialized = True
        print(f"[ContextualAdapter] Connected to vLLM (model: {self._client.config.model})")

    def _truncate_to_token_limit(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate messages to fit within the token limit.

        Drops the oldest messages first and always keeps at least the most
        recent one, even if that message alone exceeds the limit.

        Returns:
            `messages` itself (same object) when it is empty or already within
            `self.max_context_tokens`; otherwise a new list holding the
            retained tail.
        """
        if not messages:
            return messages

        # estimate_messages_tokens() is a plain sum over messages, so costing
        # each message once and keeping a running total gives exactly the same
        # answer as re-estimating the whole list after every removal.
        costs = [estimate_messages_tokens([msg]) for msg in messages]
        total_tokens = sum(costs)

        # If within limit, return as-is
        if total_tokens <= self.max_context_tokens:
            return messages

        # Advance past the oldest messages until under the limit, never
        # dropping the final (most recent) message.
        start = 0
        last_index = len(messages) - 1
        while start < last_index and total_tokens > self.max_context_tokens:
            total_tokens -= costs[start]
            start += 1

        return messages[start:]

    def _build_context(self, history: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Return a fresh list with the tail of `history` that fits the token and turn limits."""
        # Token-based truncation (primary) - keeps most recent messages within token limit
        context = self._truncate_to_token_limit(history)

        # Fallback: also apply turn-based limit if still too many turns
        max_messages = self.max_context_turns * 2
        if len(context) > max_messages:
            context = context[-max_messages:]

        # Always hand back a snapshot: the truncation helper may return the
        # live history list, and later appends to the history must not leak
        # into the context that was actually sent to the model.
        return list(context)

    @staticmethod
    def _build_messages(context: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Prefix `context` with the agent system prompt."""
        return [{"role": "system", "content": AGENT_SYSTEM_PROMPT}, *context]

    @staticmethod
    def _build_result(
        response: str,
        context: List[Dict[str, str]],
        history_len: int,
    ) -> Dict[str, Any]:
        """
        Assemble the response dict shared by both generation paths.

        Args:
            response: Assistant reply text.
            context: The messages (without system prompt) that were sent.
            history_len: Number of history messages at prompt time, i.e.
                including the current user query but not the reply.
        """
        # Both lengths normally end on the unanswered user query, so round up
        # to count that in-progress turn as a turn.
        context_turns = (len(context) + 1) // 2
        total_turns = (history_len + 1) // 2

        return {
            "response": response,
            "reasoning": "",
            "debug": {
                "context_turns": context_turns,
                "total_turns": total_turns,
                # Compare message counts, not turn counts, so dropping a
                # single message is still reported.
                "truncated": len(context) < history_len,
                "estimated_context_tokens": estimate_messages_tokens(context),
            }
        }

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str:
        """
        Generate a response using the vLLM server.

        Args:
            messages: Full chat message list, system prompt included.
            max_new_tokens: Passed to the client as `max_tokens`.

        Returns:
            The "content" field of the client's chat result.
        """
        if not self._initialized:
            self.initialize()

        result = self._client.chat(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
        )

        return result["content"]

    def start_session(self, user_id: str, user_profile: Optional[dict] = None):
        """
        Start a new session (conversation history persists across sessions for this baseline).

        Args:
            user_id: Recorded as the current user.
            user_profile: Accepted for interface compatibility; unused.
        """
        if not self._initialized:
            self.initialize()

        self._current_user_id = user_id
        # NOTE: For contextual baseline, we keep history across sessions
        # This is different from vanilla which resets each session

    def generate_response(
        self,
        query: str,
        conversation_history: Optional[List[Dict[str, str]]] = None
    ) -> Dict[str, Any]:
        """
        Generate a response with the adapter's own conversation context.

        The query and the reply are added to the internal history only after
        generation succeeds, so a failed call leaves the history untouched
        and can be retried without duplicating the query.

        Args:
            query: Current user query.
            conversation_history: Accepted for interface compatibility;
                unused - the adapter's internal history is the source of truth.

        Returns:
            Dict with 'response', 'reasoning' (always empty) and 'debug' info.
        """
        if not self._initialized:
            self.initialize()

        user_message = {"role": "user", "content": query}

        # Work on a temporary history so nothing is committed if the LLM call fails.
        pending_history = self._conversation_history + [user_message]
        context = self._build_context(pending_history)
        messages = self._build_messages(context)

        # Generate response
        response_text = self._generate(messages)

        # Success: commit the whole turn (in place, preserving list identity).
        self._conversation_history.append(user_message)
        self._conversation_history.append({"role": "assistant", "content": response_text})

        return self._build_result(response_text, context, len(pending_history))

    def prepare_prompt(
        self,
        query: str,
        conversation_history: Optional[List[Dict[str, str]]] = None
    ) -> tuple:
        """
        Prepare prompt for batch processing without calling LLM.

        The query is appended to the internal history; the matching reply
        must later be supplied through process_response().

        Args:
            query: Current user query
            conversation_history: Accepted for interface compatibility; unused.

        Returns:
            Tuple of (messages, context) for batch processing. `context` is a
            dict with "context" (snapshot of the messages sent, without the
            system prompt) and "original_history_len" (history length
            including this query).
        """
        if not self._initialized:
            self.initialize()

        user_message = {"role": "user", "content": query}

        # Build first, record second, so a failure while building the prompt
        # does not leave an unanswered query in the history.
        context = self._build_context(self._conversation_history + [user_message])
        messages = self._build_messages(context)

        # Add current query to history
        self._conversation_history.append(user_message)

        # Context for post-processing
        ctx = {
            "context": context,
            "original_history_len": len(self._conversation_history),
        }

        return messages, ctx

    def process_response(
        self,
        response: str,
        context: dict
    ) -> Dict[str, Any]:
        """
        Process LLM response after batch call.

        Args:
            response: LLM response text
            context: Context dict from prepare_prompt()

        Returns:
            Dict with 'response', 'reasoning', and debug info
        """
        # Add response to history
        self._conversation_history.append({"role": "assistant", "content": response})

        return self._build_result(
            response,
            context["context"],
            context["original_history_len"],
        )

    def end_session(self, task_success: bool = False) -> Dict[str, Any]:
        """
        End session (no memory update for contextual baseline).

        Returns:
            Dict with "turns" (number of messages currently in the history)
            and "task_success" (echoed back).
        """
        return {
            "turns": len(self._conversation_history),
            "task_success": task_success,
        }

    def reset_user(self, user_id: str):
        """Reset conversation history (`user_id` is accepted but not used)."""
        self._conversation_history = []

    def __call__(
        self,
        messages: List[Dict[str, str]],
        user_profile: Optional[dict] = None,
        **kwargs
    ) -> str:
        """
        Callable interface: answer the most recent user message in `messages`.

        Only the last message whose role is "user" is used as the query;
        earlier messages are not merged into the adapter's own history.

        Returns:
            The response text, or a fixed greeting when `messages` is empty
            or contains no user message.
        """
        if not messages:
            return "How can I help you?"

        last_user_msg = next(
            (msg["content"] for msg in reversed(messages) if msg["role"] == "user"),
            None,
        )

        if last_user_msg is None:
            return "How can I help you?"

        result = self.generate_response(last_user_msg, messages)
        return result["response"]