"""
Contextual Adapter - Full conversation history in context baseline.
This implements a simple baseline where:
- Full conversation history is passed to the LLM
- No persistent memory across sessions
- Token-based context window truncation to prevent overflow
Now uses vLLM for fast inference instead of local transformers.
"""
import sys
from pathlib import Path
from typing import Optional, List, Dict, Any
# Add parent for utils import
sys.path.insert(0, str(Path(__file__).parent.parent))
from utils.vllm_client import VLLMClient, VLLMConfig
# Default vLLM URL (agent server on port 8003)
DEFAULT_VLLM_URL: str = "http://localhost:8003/v1"
# Model context limits
MAX_MODEL_LEN: int = 16384 # vLLM max_model_len setting
MAX_GENERATION_TOKENS: int = 1024 # Reserved for generation
SYSTEM_PROMPT_BUFFER: int = 500 # Buffer for system prompt overhead
# Safe limit for conversation context - reduced to force faster forgetting
# This keeps only ~2-3 sessions worth of history visible
MAX_CONTEXT_TOKENS: int = 4000 # Reduced from ~14860 to make contextual forget faster
# Basic agent system prompt, sent verbatim as the "system" message on every call
AGENT_SYSTEM_PROMPT: str = """You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.
# Conversation Guidelines:
- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.
- Your goal is to help the user solve their problem. Do your best to help them."""
def estimate_tokens(text: str) -> int:
    """
    Estimate the token count of *text* with a character-based heuristic.

    Assumes roughly 2.5 characters per token, which is conservative for
    LLaMA-family tokenizers (math/code content tokenizes less efficiently).
    Always returns at least 1, even for an empty string.
    """
    chars_per_token = 2.5
    return 1 + int(len(text) / chars_per_token)
def estimate_messages_tokens(messages: List[Dict[str, str]]) -> int:
    """
    Estimate the combined token count of a list of chat messages.

    Each message contributes its content estimate plus ~4 tokens of
    overhead for role tags and chat-template formatting.
    """
    per_message_overhead = 4
    return sum(
        estimate_tokens(msg.get("content", "")) + per_message_overhead
        for msg in messages
    )
class ContextualAdapter:
    """
    Contextual baseline - full history in context, no memory.

    Uses vLLM for fast inference and passes the (truncated) conversation
    history to the model on every call. Unlike the vanilla baseline, the
    internal history persists across sessions; nothing is written to any
    external memory store. Truncation is token-based first, with a
    turn-count limit as fallback.
    """

    def __init__(
        self,
        model_name: str = None,  # Ignored - vLLM auto-discovers model
        device_assignment: dict = None,  # Ignored - vLLM handles GPU
        api_base: str = None,  # vLLM server URL
        api_key: str = None,  # Ignored
        max_context_turns: int = 15,  # Fallback turn-based truncation (reduced from 50)
        max_context_tokens: int = None,  # Token-based truncation (primary)
        vllm_url: str = None,  # vLLM server URL
    ):
        # URL precedence: explicit vllm_url > legacy api_base > module default.
        self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL
        self.max_context_turns = max_context_turns
        self.max_context_tokens = max_context_tokens or MAX_CONTEXT_TOKENS
        self._current_user_id: Optional[str] = None
        self._conversation_history: List[Dict[str, str]] = []
        # vLLM client - created lazily in initialize() so construction is cheap.
        self._client: Optional[VLLMClient] = None
        self._initialized = False

    def initialize(self):
        """Connect to the vLLM server (idempotent).

        Retries with exponential backoff (0.5s doubling, capped at 10s per
        wait) since the server may still be starting up.

        Raises:
            RuntimeError: if the server never becomes healthy; chained to
                the last connection error, if one was raised.
        """
        if self._initialized:
            return
        print(f"[ContextualAdapter] Connecting to vLLM server at {self.vllm_url}...")
        # Retry connection with exponential backoff
        import time
        max_retries = 30
        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                self._client = VLLMClient(base_url=self.vllm_url)
                if self._client.health_check():
                    break
            except Exception as e:
                # Keep the latest error so the final failure is diagnosable
                # instead of silently discarding it.
                last_error = e
            if attempt < max_retries - 1:
                wait_time = min(2 ** attempt * 0.5, 10)  # 0.5, 1, 2, 4, 8, 10, 10...
                time.sleep(wait_time)
            else:
                raise RuntimeError(
                    f"vLLM server not responding at {self.vllm_url} after {max_retries} retries"
                ) from last_error
        self._initialized = True
        print(f"[ContextualAdapter] Connected to vLLM (model: {self._client.config.model})")

    def _truncate_to_token_limit(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate messages to fit within the token limit.

        Drops the oldest messages first, always keeping at least the most
        recent one. Per-message costs are computed once up front so the whole
        pass is O(n) rather than re-estimating the full list per removal.
        """
        if not messages:
            return messages
        # Same cost model as estimate_messages_tokens(): content estimate
        # plus ~4 tokens of role/formatting overhead per message.
        costs = [estimate_tokens(msg.get("content", "")) + 4 for msg in messages]
        total_tokens = sum(costs)
        # If within limit, return as-is
        if total_tokens <= self.max_context_tokens:
            return messages
        # Advance past the oldest messages until under the limit.
        start = 0
        while len(messages) - start > 1 and total_tokens > self.max_context_tokens:
            total_tokens -= costs[start]
            start += 1
        return list(messages[start:])

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = MAX_GENERATION_TOKENS) -> str:
        """Generate a response via the vLLM server (connects lazily)."""
        if not self._initialized:
            self.initialize()
        result = self._client.chat(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
        )
        return result["content"]

    def _build_context(self) -> tuple:
        """Truncate the history and prepend the system prompt.

        Returns:
            (messages, context): `messages` is the full prompt including the
            system message; `context` is just the truncated history portion.
        """
        # Token-based truncation (primary) - keeps most recent messages within limit.
        context = self._truncate_to_token_limit(self._conversation_history)
        # Fallback: also cap the number of turns if still too many.
        if len(context) > self.max_context_turns * 2:
            context = context[-(self.max_context_turns * 2):]
        messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]
        messages.extend(context)
        return messages, context

    def _debug_info(self, context: List[Dict[str, str]], total_history_len: int) -> Dict[str, Any]:
        """Build the debug dict reported alongside each response."""
        original_turns = total_history_len // 2
        context_turns = len(context) // 2
        return {
            "context_turns": context_turns,
            "total_turns": original_turns,
            "truncated": original_turns > context_turns,
            "estimated_context_tokens": estimate_messages_tokens(context),
        }

    def start_session(self, user_id: str, user_profile: dict = None):
        """Start a new session (history persists across sessions for this baseline)."""
        if not self._initialized:
            self.initialize()
        self._current_user_id = user_id
        # NOTE: For contextual baseline, we keep history across sessions
        # This is different from vanilla which resets each session

    def generate_response(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """Generate a response with (truncated) full conversation context.

        Args:
            query: Current user query; appended to the internal history.
            conversation_history: Unused - history is tracked internally.

        Returns:
            Dict with 'response', 'reasoning' (always empty), and 'debug'
            truncation statistics.
        """
        if not self._initialized:
            self.initialize()
        # Record the query, then build the truncated prompt.
        self._conversation_history.append({"role": "user", "content": query})
        messages, context = self._build_context()
        response_text = self._generate(messages)
        self._conversation_history.append({"role": "assistant", "content": response_text})
        return {
            "response": response_text,
            "reasoning": "",
            "debug": self._debug_info(context, len(self._conversation_history)),
        }

    def prepare_prompt(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> tuple:
        """
        Prepare a prompt for batch processing without calling the LLM.

        Args:
            query: Current user query; appended to the internal history.
            conversation_history: Unused - history is tracked internally.

        Returns:
            Tuple of (messages, context-dict) for later process_response().
        """
        if not self._initialized:
            self.initialize()
        self._conversation_history.append({"role": "user", "content": query})
        messages, context = self._build_context()
        # Context for post-processing; history length captured before the
        # assistant reply is appended.
        ctx = {
            "context": context,
            "original_history_len": len(self._conversation_history),
        }
        return messages, ctx

    def process_response(
        self,
        response: str,
        context: dict
    ) -> Dict[str, Any]:
        """
        Process an LLM response after a batch call.

        Args:
            response: LLM response text.
            context: Context dict produced by prepare_prompt().

        Returns:
            Dict with 'response', 'reasoning', and debug info.
        """
        # Record the reply in the persistent history.
        self._conversation_history.append({"role": "assistant", "content": response})
        return {
            "response": response,
            "reasoning": "",
            "debug": self._debug_info(context["context"], context["original_history_len"]),
        }

    def end_session(self, task_success: bool = False) -> Dict[str, Any]:
        """End session (no memory update for contextual baseline).

        Note: 'turns' is the raw message count (user + assistant), not turn pairs.
        """
        return {
            "turns": len(self._conversation_history),
            "task_success": task_success,
        }

    def reset_user(self, user_id: str):
        """Reset conversation history (user_id accepted for interface parity)."""
        self._conversation_history = []

    def __call__(
        self,
        messages: List[Dict[str, str]],
        user_profile: dict = None,
        **kwargs
    ) -> str:
        """Callable interface: answer the most recent user message in *messages*."""
        if not messages:
            return "How can I help you?"
        # Scan backwards for the latest user message; skip malformed entries
        # with no "role" key instead of raising KeyError.
        last_user_msg = None
        for msg in reversed(messages):
            if msg.get("role") == "user":
                last_user_msg = msg["content"]
                break
        if last_user_msg is None:
            return "How can I help you?"
        result = self.generate_response(last_user_msg, messages)
        return result["response"]