summaryrefslogtreecommitdiff
path: root/collaborativeagents/adapters
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/adapters
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/adapters')
-rw-r--r--collaborativeagents/adapters/__init__.py15
-rw-r--r--collaborativeagents/adapters/contextual_adapter.py305
-rw-r--r--collaborativeagents/adapters/personalized_llm_adapter.py731
-rw-r--r--collaborativeagents/adapters/reflection_adapter.py416
-rw-r--r--collaborativeagents/adapters/reflection_grpo_adapter.py321
5 files changed, 1788 insertions, 0 deletions
diff --git a/collaborativeagents/adapters/__init__.py b/collaborativeagents/adapters/__init__.py
new file mode 100644
index 0000000..b5cfda6
--- /dev/null
+++ b/collaborativeagents/adapters/__init__.py
@@ -0,0 +1,15 @@
+"""Adapters for integrating PersonalizedLLM with benchmark frameworks."""
+
+from .personalized_llm_adapter import (
+ AdapterConfig,
+ PersonalizedLLMAdapter,
+ PersonalizedCollaborator,
+ create_baseline_adapter,
+)
+
+__all__ = [
+ "AdapterConfig",
+ "PersonalizedLLMAdapter",
+ "PersonalizedCollaborator",
+ "create_baseline_adapter",
+]
diff --git a/collaborativeagents/adapters/contextual_adapter.py b/collaborativeagents/adapters/contextual_adapter.py
new file mode 100644
index 0000000..ef5e92e
--- /dev/null
+++ b/collaborativeagents/adapters/contextual_adapter.py
@@ -0,0 +1,305 @@
+"""
+Contextual Adapter - Full conversation history in context baseline.
+
+This implements a simple baseline where:
+- Full conversation history is passed to the LLM
+- No persistent memory across sessions
+- Token-based context window truncation to prevent overflow
+
+Now uses vLLM for fast inference instead of local transformers.
+"""
+
+import sys
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+
+# Add parent for utils import
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from utils.vllm_client import VLLMClient, VLLMConfig
+
+# Default vLLM URL (agent server on port 8003)
+DEFAULT_VLLM_URL = "http://localhost:8003/v1"
+
+# Model context limits
+MAX_MODEL_LEN = 16384 # vLLM max_model_len setting
+MAX_GENERATION_TOKENS = 1024 # Reserved for generation
+SYSTEM_PROMPT_BUFFER = 500 # Buffer for system prompt overhead
+# NOTE(review): MAX_MODEL_LEN, MAX_GENERATION_TOKENS and SYSTEM_PROMPT_BUFFER are
+# not referenced anywhere else in this module — only MAX_CONTEXT_TOKENS below
+# actually drives truncation. Confirm whether they are intended for future use.
+# Safe limit for conversation context - reduced to force faster forgetting
+# This keeps only ~2-3 sessions worth of history visible
+MAX_CONTEXT_TOKENS = 4000 # Reduced from ~14860 to make contextual forget faster
+
+# Basic agent system prompt
+AGENT_SYSTEM_PROMPT = """You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.
+
+# Conversation Guidelines:
+- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.
+- Your goal is to help the user solve their problem. Do your best to help them."""
+
+
+def estimate_tokens(text: str) -> int:
+ """
+ Estimate token count for text using character-based heuristic.
+ Uses ~2.5 characters per token which is conservative for LLaMA tokenizers,
+ especially with math/code content where tokenization is less efficient.
+ """
+ # The trailing +1 guarantees a minimum estimate of 1, even for empty text.
+ return int(len(text) / 2.5) + 1
+
+
+def estimate_messages_tokens(messages: List[Dict[str, str]]) -> int:
+ """Estimate total tokens in a list of messages."""
+ total = 0
+ for msg in messages:
+ # Add overhead for role tags and formatting (~4 tokens per message)
+ # A message missing "content" contributes only this fixed overhead
+ # (estimate_tokens("") == 1, plus 4).
+ total += estimate_tokens(msg.get("content", "")) + 4
+ return total
+
+
+class ContextualAdapter:
+ """
+ Contextual baseline - full history in context, no memory.
+
+ Uses vLLM for fast inference, passes full conversation history to the model.
+ """
+
+ def __init__(
+ self,
+ model_name: str = None, # Ignored - vLLM auto-discovers model
+ device_assignment: dict = None, # Ignored - vLLM handles GPU
+ api_base: str = None, # vLLM server URL
+ api_key: str = None, # Ignored
+ max_context_turns: int = 15, # Fallback turn-based truncation (reduced from 50)
+ max_context_tokens: int = None, # Token-based truncation (primary)
+ vllm_url: str = None, # vLLM server URL
+ ):
+ # vllm_url takes priority over api_base; both fall back to the default URL.
+ self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL
+ self.max_context_turns = max_context_turns
+ self.max_context_tokens = max_context_tokens or MAX_CONTEXT_TOKENS
+
+ self._current_user_id: Optional[str] = None
+ self._conversation_history: List[Dict[str, str]] = []
+
+ # vLLM client (initialized lazily)
+ self._client: Optional[VLLMClient] = None
+ self._initialized = False
+
+ def initialize(self):
+ """Initialize the adapter (connects to vLLM server)."""
+ if self._initialized:
+ return
+
+ print(f"[ContextualAdapter] Connecting to vLLM server at {self.vllm_url}...")
+
+ # Retry connection with exponential backoff
+ import time
+ max_retries = 30
+ for attempt in range(max_retries):
+ try:
+ self._client = VLLMClient(base_url=self.vllm_url)
+ if self._client.health_check():
+ break
+ except Exception as e:
+ # NOTE(review): connection errors are swallowed silently; 'e' is
+ # captured but never logged. Consider at least printing it on the
+ # final attempt for debuggability.
+ pass
+
+ if attempt < max_retries - 1:
+ wait_time = min(2 ** attempt * 0.5, 10) # 0.5, 1, 2, 4, 8, 10, 10...
+ time.sleep(wait_time)
+ else:
+ raise RuntimeError(f"vLLM server not responding at {self.vllm_url} after {max_retries} retries")
+
+ self._initialized = True
+ print(f"[ContextualAdapter] Connected to vLLM (model: {self._client.config.model})")
+
+ def _truncate_to_token_limit(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """
+ Truncate messages to fit within token limit.
+ Removes oldest messages first while keeping the most recent context.
+ """
+ if not messages:
+ return messages
+
+ total_tokens = estimate_messages_tokens(messages)
+
+ # If within limit, return as-is
+ if total_tokens <= self.max_context_tokens:
+ return messages
+
+ # Truncate from the beginning (oldest messages) until under limit
+ # NOTE(review): popping single messages can split a user/assistant pair,
+ # leaving the truncated context starting with an assistant turn — verify
+ # the serving template tolerates that.
+ truncated = list(messages)
+ while len(truncated) > 1 and estimate_messages_tokens(truncated) > self.max_context_tokens:
+ truncated.pop(0)
+
+ return truncated
+
+ def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str:
+ """Generate response using vLLM server."""
+ if not self._initialized:
+ self.initialize()
+
+ # Fixed sampling parameters for this baseline.
+ result = self._client.chat(
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ return result["content"]
+
+ def start_session(self, user_id: str, user_profile: dict = None):
+ """Start a new session (conversation history persists across sessions for this baseline)."""
+ if not self._initialized:
+ self.initialize()
+
+ self._current_user_id = user_id
+ # NOTE: For contextual baseline, we keep history across sessions
+ # This is different from vanilla which resets each session
+
+ def generate_response(
+ self,
+ query: str,
+ conversation_history: List[Dict[str, str]] = None
+ ) -> Dict[str, Any]:
+ """Generate response with full conversation context.
+
+ Note: conversation_history is accepted for interface compatibility but
+ ignored — this adapter tracks its own _conversation_history.
+ """
+ if not self._initialized:
+ self.initialize()
+
+ # Add current query
+ self._conversation_history.append({"role": "user", "content": query})
+
+ # Token-based truncation (primary) - keeps most recent messages within token limit
+ context = self._truncate_to_token_limit(self._conversation_history)
+
+ # Fallback: also apply turn-based limit if still too many turns
+ if len(context) > self.max_context_turns * 2:
+ context = context[-(self.max_context_turns * 2):]
+
+ # Build messages with system prompt
+ messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]
+ messages.extend(context)
+
+ # Generate response
+ response_text = self._generate(messages)
+
+ self._conversation_history.append({"role": "assistant", "content": response_text})
+
+ # Track truncation for debugging
+ original_turns = len(self._conversation_history) // 2
+ context_turns = len(context) // 2
+ truncated = original_turns > context_turns
+
+ return {
+ "response": response_text,
+ "reasoning": "",
+ "debug": {
+ "context_turns": context_turns,
+ "total_turns": original_turns,
+ "truncated": truncated,
+ "estimated_context_tokens": estimate_messages_tokens(context),
+ }
+ }
+
+ def prepare_prompt(
+ self,
+ query: str,
+ conversation_history: List[Dict[str, str]] = None
+ ) -> tuple:
+ """
+ Prepare prompt for batch processing without calling LLM.
+
+ Args:
+ query: Current user query
+ conversation_history: Previous conversation
+
+ Returns:
+ Tuple of (messages, context) for batch processing
+ """
+ if not self._initialized:
+ self.initialize()
+
+ # Add current query to history
+ self._conversation_history.append({"role": "user", "content": query})
+
+ # Token-based truncation
+ context = self._truncate_to_token_limit(self._conversation_history)
+
+ # Fallback: also apply turn-based limit
+ if len(context) > self.max_context_turns * 2:
+ context = context[-(self.max_context_turns * 2):]
+
+ # Build messages with system prompt
+ messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]
+ messages.extend(context)
+
+ # Context for post-processing
+ ctx = {
+ "context": context,
+ "original_history_len": len(self._conversation_history),
+ }
+
+ return messages, ctx
+
+ def process_response(
+ self,
+ response: str,
+ context: dict
+ ) -> Dict[str, Any]:
+ """
+ Process LLM response after batch call.
+
+ Args:
+ response: LLM response text
+ context: Context dict from prepare_prompt()
+
+ Returns:
+ Dict with 'response', 'reasoning', and debug info
+ """
+ # Add response to history
+ self._conversation_history.append({"role": "assistant", "content": response})
+
+ ctx_context = context["context"]
+ original_turns = context["original_history_len"] // 2
+ context_turns = len(ctx_context) // 2
+ truncated = original_turns > context_turns
+
+ return {
+ "response": response,
+ "reasoning": "",
+ "debug": {
+ "context_turns": context_turns,
+ "total_turns": original_turns,
+ "truncated": truncated,
+ "estimated_context_tokens": estimate_messages_tokens(ctx_context),
+ }
+ }
+
+ def end_session(self, task_success: bool = False) -> Dict[str, Any]:
+ """End session (no memory update for contextual baseline)."""
+ # NOTE(review): "turns" here is the number of individual messages
+ # (user + assistant), not conversational turn pairs — other code in this
+ # class divides by 2 to get turn pairs. Confirm which unit callers expect.
+ return {
+ "turns": len(self._conversation_history),
+ "task_success": task_success,
+ }
+
+ def reset_user(self, user_id: str):
+ """Reset conversation history for user.
+
+ Note: user_id is accepted for interface symmetry but unused — there is a
+ single shared history, so this clears it regardless of user.
+ """
+ self._conversation_history = []
+
+ def __call__(
+ self,
+ messages: List[Dict[str, str]],
+ user_profile: dict = None,
+ **kwargs
+ ) -> str:
+ """Callable interface: extract the last user message and answer it."""
+ if not messages:
+ return "How can I help you?"
+
+ # Scan backwards for the most recent user message.
+ last_user_msg = None
+ for msg in reversed(messages):
+ if msg["role"] == "user":
+ last_user_msg = msg["content"]
+ break
+
+ if last_user_msg is None:
+ return "How can I help you?"
+
+ result = self.generate_response(last_user_msg, messages)
+ return result["response"]
diff --git a/collaborativeagents/adapters/personalized_llm_adapter.py b/collaborativeagents/adapters/personalized_llm_adapter.py
new file mode 100644
index 0000000..c2d4727
--- /dev/null
+++ b/collaborativeagents/adapters/personalized_llm_adapter.py
@@ -0,0 +1,731 @@
+"""
+Adapter to integrate PersonalizedLLM with CollaborativeAgents benchmark.
+
+This adapter wraps PersonalizedLLM to work as a CollaboratorAgent in the
+MULTISESSIONCOLLAB framework while maintaining all personalization features.
+"""
+
+import sys
+import os
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+from dataclasses import dataclass, field
+import json
+import numpy as np
+
+# Add paths
+_project_root = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(_project_root / "src"))
+
+# Import from your personalization system
+from personalization.serving.personalized_llm import (
+ PersonalizedLLM,
+ AssistantResponse,
+ Feedback,
+ create_personalized_llm
+)
+
+
+@dataclass
+class AdapterConfig:
+ """Configuration for the PersonalizedLLM adapter."""
+ # PersonalizedLLM config
+ mode: str = "full" # "full", "nopersonal", "vanilla"
+ eval_mode: bool = True
+ enable_preference_extraction: bool = True
+ enable_rl_updates: bool = True
+ use_user_vector: bool = True # Whether to use user vector in policy scoring
+
+ # Paths (absolute paths to actual file locations in the repo)
+ # Note: Using empty_store to start fresh - RAG will accumulate memories during evaluation
+ # NOTE(review): these defaults are machine-specific absolute paths
+ # (/projects/bfqt/...) — they will break on any other host. Consider
+ # deriving them from an env var or a repo-relative root.
+ user_store_path: str = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/data/users/collab_eval_store.npz"
+ memory_cards_path: str = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/data/corpora/empty_store/memory_cards.jsonl"
+ memory_embeddings_path: str = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/data/corpora/empty_store/memory_embeddings.npy"
+ item_projection_path: str = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/data/corpora/item_projection.npz"
+
+ # Multi-GPU assignment
+ device_assignment: Optional[Dict[str, str]] = None
+
+ # LLM backend selection
+ llm_name: str = "qwen_1_5b" # Use "llama_8b_vllm" for vLLM backend
+
+ # Shared model mode for multi-threaded efficiency
+ use_shared_models: bool = False # If True, share embedding/reranker across parallel workers
+
+ # Reranker selection: "qwen3" (8B) or "bge" (278M)
+ reranker_type: str = "qwen3"
+
+ # Best-of-N sampling: generate N responses and pick best (for RAG methods)
+ best_of_n: int = 1
+
+ # Reward mode: "keyword" (legacy heuristic) or "llm" (GPT-5-nano judge)
+ reward_mode: str = "keyword"
+
+ # Reward mapping for user behavior
+ preference_enforcement_reward: float = -0.8 # Negative reward when user enforces
+ disappointment_expression_reward: float = -0.4 # Milder negative for disappointment
+ positive_feedback_reward: float = 0.5 # When user expresses satisfaction
+ task_completion_reward: float = 1.0 # When task is solved correctly
+
+
+class PersonalizedLLMAdapter:
+ """
+ Adapter that wraps PersonalizedLLM for use in CollaborativeAgents.
+
+ This adapter:
+ 1. Translates CollaborativeAgents conversation format to PersonalizedLLM
+ 2. Converts user simulator signals to reward/gating for REINFORCE
+ 3. Tracks metrics for evaluation
+ 4. Supports all baseline modes
+ """
+
+ def __init__(self, config: AdapterConfig = None):
+ self.config = config or AdapterConfig()
+ self._llm: Optional[PersonalizedLLM] = None
+ self._initialized = False
+
+ # Session tracking
+ self._current_user_id: Optional[str] = None
+ self._turn_counter: int = 0
+ self._session_metrics: Dict[str, Any] = {}
+
+ # Metrics accumulation (across all sessions of this adapter's lifetime)
+ self._total_enforcements: int = 0
+ self._total_disappointments: int = 0
+ self._total_turns: int = 0
+
+ def initialize(self):
+ """Initialize the PersonalizedLLM instance (idempotent, lazy)."""
+ if self._initialized:
+ return
+
+ shared_mode_str = " (shared models)" if self.config.use_shared_models else ""
+ print(f"[Adapter] Initializing PersonalizedLLM with LLM: {self.config.llm_name}{shared_mode_str}...")
+ self._llm = PersonalizedLLM(
+ mode=self.config.mode,
+ eval_mode=self.config.eval_mode,
+ enable_preference_extraction=self.config.enable_preference_extraction,
+ enable_rl_updates=self.config.enable_rl_updates,
+ user_store_path=self.config.user_store_path,
+ memory_cards_path=self.config.memory_cards_path,
+ memory_embeddings_path=self.config.memory_embeddings_path,
+ item_projection_path=self.config.item_projection_path,
+ device_assignment=self.config.device_assignment,
+ llm_name=self.config.llm_name,
+ use_shared_models=self.config.use_shared_models,
+ reranker_type=self.config.reranker_type,
+ best_of_n=self.config.best_of_n,
+ reward_mode=self.config.reward_mode,
+ )
+ self._initialized = True
+ print("[Adapter] Initialization complete.")
+
+ def start_session(self, user_id: str, user_profile: dict = None):
+ """
+ Start a new session for a user.
+
+ Args:
+ user_id: Unique user identifier
+ user_profile: Optional user profile with preferences (for ground truth)
+ """
+ if not self._initialized:
+ self.initialize()
+
+ self._current_user_id = user_id
+ self._turn_counter = 0
+ self._session_metrics = {
+ "user_id": user_id,
+ "enforcements": 0,
+ "disappointments": 0,
+ "turns": 0,
+ "rewards_applied": [],
+ }
+
+ # Reset session (keeps z_long, clears z_short and history)
+ self._llm.reset_session(user_id)
+
+ def generate_response(
+ self,
+ query: str,
+ conversation_history: List[Dict[str, str]] = None
+ ) -> Dict[str, Any]:
+ """
+ Generate a response using PersonalizedLLM.
+
+ Args:
+ query: Current user query
+ conversation_history: Previous conversation (for context, though
+ PersonalizedLLM tracks its own history)
+
+ Returns:
+ Dict with 'response', 'reasoning', and debug info
+ """
+ if not self._initialized:
+ self.initialize()
+
+ # Call PersonalizedLLM
+ result: AssistantResponse = self._llm.chat(self._current_user_id, query)
+
+ self._turn_counter += 1
+ self._session_metrics["turns"] = self._turn_counter
+
+ # Handle None result defensively
+ if result is None:
+ return {"response": "[Error: LLM returned None]", "reasoning": "", "debug": {}}
+
+ # Format response for CollaborativeAgents
+ answer = result.answer if result.answer else "[No answer generated]"
+ debug_info = result.debug if result.debug else None
+ usage_info = result.usage if result.usage else None
+
+ return {
+ "response": answer,
+ "reasoning": f"Retrieved {len(debug_info.selected_memory_notes) if debug_info else 0} memories",
+ "debug": {
+ "selected_memories": debug_info.selected_memory_notes if debug_info else [],
+ "memory_scores": debug_info.selected_memory_scores if debug_info else [],
+ "extracted_preferences": debug_info.extracted_preferences if debug_info else [],
+ "user_vector_norm": debug_info.extra.get("z_long_norm", 0) if debug_info and debug_info.extra else 0,
+ "usage": {
+ "prompt_tokens": usage_info.prompt_tokens if usage_info else 0,
+ "completion_tokens": usage_info.completion_tokens if usage_info else 0,
+ "total_tokens": usage_info.total_tokens if usage_info else 0,
+ } if usage_info else {}
+ }
+ }
+
+ def prepare_prompt(
+ self,
+ query: str,
+ conversation_history: List[Dict[str, str]] = None
+ ) -> tuple:
+ """
+ Prepare prompt for batch processing without calling LLM.
+
+ This method does all preparation (embedding, memory retrieval) and
+ returns messages for batched vLLM call.
+
+ Args:
+ query: Current user query
+ conversation_history: Previous conversation
+
+ Returns:
+ Tuple of (messages, context) where messages is ready for vLLM batch
+ and context is needed for process_response().
+ """
+ if not self._initialized:
+ self.initialize()
+
+ # Use chat_prepare from PersonalizedLLM
+ result = self._llm.chat_prepare(self._current_user_id, query)
+ return result["messages"], result["context"]
+
+ def process_response(
+ self,
+ response: str,
+ context: dict
+ ) -> Dict[str, Any]:
+ """
+ Process LLM response after batch call.
+
+ This method takes the LLM response and context from prepare_prompt(),
+ does post-processing, and returns the formatted result.
+
+ Args:
+ response: LLM response text from batched vLLM call
+ context: Context dict from prepare_prompt()
+
+ Returns:
+ Dict with 'response', 'reasoning', and debug info
+ """
+ # Use chat_complete from PersonalizedLLM
+ result: AssistantResponse = self._llm.chat_complete(response, context)
+
+ self._turn_counter += 1
+ self._session_metrics["turns"] = self._turn_counter
+
+ # Handle None result defensively
+ if result is None:
+ return {"response": "[Error: LLM returned None]", "reasoning": "", "debug": {}}
+
+ # Format response for CollaborativeAgents
+ answer = result.answer if result.answer else "[No answer generated]"
+ debug_info = result.debug if result.debug else None
+ usage_info = result.usage if result.usage else None
+
+ return {
+ "response": answer,
+ "reasoning": f"Retrieved {len(debug_info.selected_memory_notes) if debug_info else 0} memories",
+ "debug": {
+ "selected_memories": debug_info.selected_memory_notes if debug_info else [],
+ "memory_scores": debug_info.selected_memory_scores if debug_info else [],
+ "extracted_preferences": debug_info.extracted_preferences if debug_info else [],
+ "user_vector_norm": debug_info.extra.get("z_long_norm", 0) if debug_info and debug_info.extra else 0,
+ "usage": {
+ "prompt_tokens": usage_info.prompt_tokens if usage_info else 0,
+ "completion_tokens": usage_info.completion_tokens if usage_info else 0,
+ "total_tokens": usage_info.total_tokens if usage_info else 0,
+ } if usage_info else {}
+ }
+ }
+
+ def process_user_turn(
+ self,
+ user_response: str,
+ enforce_preferences: bool = False,
+ express_disappointment: bool = False,
+ express_satisfaction: bool = False,
+ draft_answer_updated: bool = False
+ ):
+ """
+ Process user turn and derive reward signal for REINFORCE.
+
+ Args:
+ user_response: The user's response text
+ enforce_preferences: Whether user explicitly enforced preferences
+ express_disappointment: Whether user expressed disappointment
+ express_satisfaction: Whether user expressed satisfaction
+ draft_answer_updated: Whether user updated their draft answer
+
+ This is called AFTER generate_response and BEFORE the next turn.
+ """
+ # Derive reward from user behavior.
+ # Precedence (via if/elif): enforcement > disappointment > satisfaction —
+ # only one reward is applied per turn even if several signals are set.
+ reward = 0.0
+ gating = 1.0 # Always apply (could be conditional)
+
+ if enforce_preferences:
+ reward = self.config.preference_enforcement_reward
+ self._session_metrics["enforcements"] += 1
+ self._total_enforcements += 1
+
+ elif express_disappointment:
+ reward = self.config.disappointment_expression_reward
+ self._session_metrics["disappointments"] += 1
+ self._total_disappointments += 1
+
+ elif express_satisfaction or draft_answer_updated:
+ reward = self.config.positive_feedback_reward
+
+ # Apply feedback to PersonalizedLLM (skipped when reward is exactly 0.0)
+ if self.config.enable_rl_updates and reward != 0.0:
+ feedback = Feedback(
+ user_id=self._current_user_id,
+ turn_id=self._turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ meta={
+ "enforce": enforce_preferences,
+ "disappointment": express_disappointment,
+ "satisfaction": express_satisfaction,
+ }
+ )
+ self._llm.apply_feedback(feedback)
+ self._session_metrics["rewards_applied"].append(reward)
+
+ def end_session(self, task_success: bool = False) -> Dict[str, Any]:
+ """
+ End the current session and return metrics.
+
+ Args:
+ task_success: Whether the task was solved correctly
+
+ Returns:
+ Session metrics dictionary
+ """
+ # Apply final reward for task completion
+ if task_success and self.config.enable_rl_updates:
+ feedback = Feedback(
+ user_id=self._current_user_id,
+ turn_id=self._turn_counter,
+ reward=self.config.task_completion_reward,
+ gating=1.0,
+ meta={"task_success": True}
+ )
+ self._llm.apply_feedback(feedback)
+ self._session_metrics["rewards_applied"].append(
+ self.config.task_completion_reward
+ )
+
+ self._session_metrics["task_success"] = task_success
+ self._total_turns += self._turn_counter
+
+ # Return a copy so callers cannot mutate internal session state.
+ return self._session_metrics.copy()
+
+ def reset_user(self, user_id: str):
+ """Completely reset a user (new experiment)."""
+ if self._initialized:
+ self._llm.reset_user(user_id)
+
+ def get_user_vector(self, user_id: str) -> Optional[np.ndarray]:
+ """Get the user's z_long vector for analysis.
+
+ NOTE(review): reaches into the private ``_user_store`` attribute of
+ PersonalizedLLM — consider exposing a public accessor there instead.
+ """
+ if not self._initialized:
+ return None
+
+ state = self._llm._user_store.get_state(user_id)
+ return state.z_long.copy()
+
+ def get_user_state_summary(self, user_id: str) -> Dict[str, Any]:
+ """Get summary of user state for analysis."""
+ if not self._initialized:
+ return {}
+
+ return self._llm.get_user_state_summary(user_id)
+
+ def persist(self):
+ """Save all state to disk."""
+ if self._initialized:
+ self._llm.persist()
+
+ # =========================================================================
+ # CollaborativeAgents Interface Methods
+ # =========================================================================
+
+ def __call__(
+ self,
+ messages: List[Dict[str, str]],
+ user_profile: dict = None,
+ **kwargs
+ ) -> str:
+ """
+ Callable interface for CollaborativeAgents ConversationGenerator.
+
+ Args:
+ messages: Conversation history in [{"role": "user/assistant", "content": "..."}]
+ user_profile: Optional user profile
+
+ Returns:
+ Response string
+ """
+ if not messages:
+ return "How can I help you?"
+
+ # Get the last user message
+ last_user_msg = None
+ for msg in reversed(messages):
+ if msg["role"] == "user":
+ last_user_msg = msg["content"]
+ break
+
+ if last_user_msg is None:
+ return "How can I help you?"
+
+ result = self.generate_response(last_user_msg, messages)
+ return result["response"]
+
+
+# =============================================================================
+# Baseline Adapter Factory
+# =============================================================================
+
+def create_baseline_adapter(
+ baseline_name: str,
+ device_assignment: dict = None,
+ use_vllm: bool = False,
+ use_shared_models: bool = False,
+ reward_mode: str = "keyword",
+) -> PersonalizedLLMAdapter:
+ """
+ Create an adapter configured for a specific baseline.
+
+ Args:
+ baseline_name: One of:
+ - "vanilla": No memory or personalization
+ - "contextual": Full history in context (truncate if overflow)
+ - "reflection": CollaborativeAgents' agent_notes approach
+ - "reflection_grpo": Reflection + GRPO training
+ - "all_memory": All extracted memories in context (no retrieval)
+ - "rag": Extractor + RAG (no user vector)
+ - "rag_vector": Full personalization (Extractor + RAG + User Vector)
+ device_assignment: GPU assignment dict
+ use_vllm: If True, use vLLM HTTP API for LLM inference (much faster)
+ reward_mode: Global reward mode ("keyword" or "llm") applied to all methods
+ use_shared_models: If True, share embedding/reranker models across parallel
+ workers. ESSENTIAL for parallel profile processing to avoid OOM.
+
+ Returns:
+ Configured adapter (PersonalizedLLMAdapter or baseline-specific adapter)
+ """
+ # Select LLM backend
+ llm_name = "llama_8b_vllm" if use_vllm else "llama_8b"
+ configs = {
+ # Baseline 1: Vanilla - no memory at all
+ "vanilla": AdapterConfig(
+ mode="vanilla",
+ enable_preference_extraction=False,
+ enable_rl_updates=False,
+ use_user_vector=False,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ ),
+ # Baseline 2: Contextual - full history in context
+ # This needs a separate adapter (ContextualAdapter)
+ "contextual": None, # Handled separately
+ # Baseline 3: Reflection - agent_notes mechanism
+ # This needs a separate adapter (ReflectionAdapter)
+ "reflection": None, # Handled separately
+ # Baseline 4: Reflection + GRPO
+ # This needs a separate adapter (ReflectionGRPOAdapter)
+ "reflection_grpo": None, # Handled separately
+ # Baseline 5: All memory in context (no retrieval)
+ "all_memory": AdapterConfig(
+ mode="nopersonal", # Uses all memories, no policy selection
+ enable_preference_extraction=True,
+ enable_rl_updates=False,
+ use_user_vector=False,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ ),
+ # Baseline 6: Extractor + RAG (no user vector)
+ # Use "nopersonal" mode for pure dense+rerank retrieval without user vector influence
+ # Device assignment: GPUs 2,3 for HF models (8B vLLM uses 40% memory, leaving room)
+ "rag": AdapterConfig(
+ mode="nopersonal",
+ enable_preference_extraction=True,
+ enable_rl_updates=False, # No RL updates
+ use_user_vector=False, # No user vector in policy
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ device_assignment={
+ "embed": "cuda:2",
+ "reranker": "cuda:3",
+ "extractor": "cuda:2",
+ },
+ ),
+ # Baseline 7: Full - Extractor + RAG + User Vector (proposed method)
+ # Device assignment: GPUs 2,3 for HF models (8B vLLM uses 40% memory, leaving room)
+ "rag_vector": AdapterConfig(
+ mode="full",
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ use_user_vector=True,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ device_assignment={
+ "embed": "cuda:2",
+ "reranker": "cuda:3",
+ "extractor": "cuda:2",
+ },
+ ),
+ # Baseline 8: RAG with BGE reranker (278M instead of 8B)
+ "rag_bge": AdapterConfig(
+ mode="nopersonal",
+ enable_preference_extraction=True,
+ enable_rl_updates=False,
+ use_user_vector=False,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ reranker_type="bge",
+ device_assignment={
+ "embed": "cuda:2",
+ "reranker": "cuda:3",
+ "extractor": "cuda:2",
+ },
+ ),
+ # Baseline 9: RAG + Vector with BGE reranker (278M instead of 8B)
+ "rag_vector_bge": AdapterConfig(
+ mode="full",
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ use_user_vector=True,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ reranker_type="bge",
+ device_assignment={
+ "embed": "cuda:2",
+ "reranker": "cuda:3",
+ "extractor": "cuda:2",
+ },
+ ),
+ # Baseline 10: RAG + Vector with best-of-3 sampling
+ "rag_vector_best3": AdapterConfig(
+ mode="full",
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ use_user_vector=True,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ best_of_n=3,
+ device_assignment={
+ "embed": "cuda:2",
+ "reranker": "cuda:3",
+ "extractor": "cuda:2",
+ },
+ ),
+ # Legacy aliases
+ "nopersonal": AdapterConfig(
+ mode="nopersonal",
+ enable_preference_extraction=True,
+ enable_rl_updates=False,
+ use_user_vector=False,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ ),
+ "full": AdapterConfig(
+ mode="full",
+ enable_preference_extraction=True,
+ enable_rl_updates=True,
+ use_user_vector=True,
+ llm_name=llm_name,
+ use_shared_models=use_shared_models,
+ ),
+ }
+
+ if baseline_name not in configs:
+ raise ValueError(f"Unknown baseline: {baseline_name}. Choose from {list(configs.keys())}")
+
+ config = configs[baseline_name]
+
+ # Handle baselines that need separate adapters
+ # NOTE(review): this branch returns before reward_mode and use_shared_models
+ # are applied — the contextual/reflection/reflection_grpo adapters receive
+ # only device_assignment. Confirm this asymmetry is intentional.
+ if config is None:
+ if baseline_name == "contextual":
+ from .contextual_adapter import ContextualAdapter
+ return ContextualAdapter(device_assignment=device_assignment)
+ elif baseline_name == "reflection":
+ from .reflection_adapter import ReflectionAdapter
+ return ReflectionAdapter(device_assignment=device_assignment)
+ elif baseline_name == "reflection_grpo":
+ from .reflection_grpo_adapter import ReflectionGRPOAdapter
+ return ReflectionGRPOAdapter(device_assignment=device_assignment)
+ else:
+ raise ValueError(f"Baseline {baseline_name} not implemented yet")
+
+ # Caller-supplied device assignment overrides the per-baseline default.
+ if device_assignment:
+ config.device_assignment = device_assignment
+
+ # Apply global reward_mode to all methods (overrides per-method defaults)
+ config.reward_mode = reward_mode
+
+ return PersonalizedLLMAdapter(config)
+
+
+# =============================================================================
+# Integration with CollaborativeAgents ConversationGenerator
+# =============================================================================
+
+class PersonalizedCollaborator:
+ """
+ Drop-in replacement for CollaboratorAgent that uses PersonalizedLLM.
+
+ Compatible with ConversationGenerator.generate_conversation()
+ """
+
+ def __init__(
+ self,
+ adapter: PersonalizedLLMAdapter,
+ user_id: str,
+ user_profile: dict = None,
+ max_new_tokens: int = 1024
+ ):
+ self.adapter = adapter
+ self.user_id = user_id
+ self.user_profile = user_profile
+ self.max_new_tokens = max_new_tokens
+
+ # Start session immediately — constructing a collaborator opens a session.
+ self.adapter.start_session(user_id, user_profile)
+
+ def generate(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
+ """
+ Generate response in CollaborativeAgents format.
+
+ Returns dict with 'reasoning' and 'response' keys.
+ """
+ # Extract last user message
+ last_user_msg = ""
+ for msg in reversed(messages):
+ if msg["role"] == "user":
+ last_user_msg = msg["content"]
+ break
+
+ # Check for preference enforcement in the user message
+ enforce_detected = self._detect_enforcement(last_user_msg)
+ disappointment_detected = self._detect_disappointment(last_user_msg)
+ satisfaction_detected = self._detect_satisfaction(last_user_msg)
+
+ # Process the previous turn's feedback (if any)
+ # NOTE(review): "len(messages) > 2" is used as a first-turn heuristic;
+ # confirm it matches the message layout ConversationGenerator produces.
+ if len(messages) > 2: # Not the first turn
+ self.adapter.process_user_turn(
+ last_user_msg,
+ enforce_preferences=enforce_detected,
+ express_disappointment=disappointment_detected,
+ express_satisfaction=satisfaction_detected,
+ )
+
+ # Generate response
+ result = self.adapter.generate_response(last_user_msg, messages)
+
+ return {
+ "reasoning": result["reasoning"],
+ "response": result["response"],
+ "debug": result.get("debug", {})
+ }
+
+ def _detect_enforcement(self, text: str) -> bool:
+ """Detect if user is enforcing preferences.
+
+ NOTE(review): substring matching on generic phrases like "can you" or
+ "don't" is prone to false positives in ordinary requests — consider
+ tightening the phrase list or anchoring on word boundaries.
+ """
+ enforcement_phrases = [
+ "please use", "i asked for", "i prefer", "can you",
+ "instead of", "not what i wanted", "i said", "remember that",
+ "you should", "don't", "avoid", "stop"
+ ]
+ text_lower = text.lower()
+ return any(phrase in text_lower for phrase in enforcement_phrases)
+
+ def _detect_disappointment(self, text: str) -> bool:
+ """Detect expressions of disappointment (keyword heuristic)."""
+ disappointment_phrases = [
+ "not quite", "that's not", "hmm", "not really",
+ "i was hoping", "could be better", "not exactly"
+ ]
+ text_lower = text.lower()
+ return any(phrase in text_lower for phrase in disappointment_phrases)
+
+ def _detect_satisfaction(self, text: str) -> bool:
+ """Detect expressions of satisfaction (keyword heuristic)."""
+ satisfaction_phrases = [
+ "thanks", "perfect", "great", "exactly", "that's what i",
+ "helpful", "makes sense", "got it", "understand now"
+ ]
+ text_lower = text.lower()
+ return any(phrase in text_lower for phrase in satisfaction_phrases)
+
+ def end_session(self, task_success: bool) -> Dict[str, Any]:
+ """End session and get metrics."""
+ return self.adapter.end_session(task_success)
+
+
+# =============================================================================
+# Usage Example
+# =============================================================================
+
+if __name__ == "__main__":
+ # Example usage / smoke test: requires the full personalization stack and
+ # the absolute data paths from AdapterConfig to be available on this host.
+ adapter = create_baseline_adapter("full")
+ adapter.initialize()
+
+ # Simulate a session
+ user_id = "test_user_001"
+ adapter.start_session(user_id)
+
+ # First turn
+ response = adapter.generate_response("How do I implement quicksort?")
+ print(f"Response: {response['response'][:200]}...")
+
+ # User provides feedback (simulating enforcement)
+ adapter.process_user_turn(
+ "Can you use bullet points instead?",
+ enforce_preferences=True
+ )
+
+ # Second turn
+ response = adapter.generate_response("Can you use bullet points instead?")
+ print(f"Response: {response['response'][:200]}...")
+
+ # End session
+ metrics = adapter.end_session(task_success=True)
+ print(f"Session metrics: {metrics}")
+
+ # Get user vector for analysis
+ z_long = adapter.get_user_vector(user_id)
+ print(f"User vector norm: {np.linalg.norm(z_long):.4f}")
+
+ # Persist learned state to disk before exit
+ adapter.persist()
diff --git a/collaborativeagents/adapters/reflection_adapter.py b/collaborativeagents/adapters/reflection_adapter.py
new file mode 100644
index 0000000..d535be2
--- /dev/null
+++ b/collaborativeagents/adapters/reflection_adapter.py
@@ -0,0 +1,416 @@
+"""
+Reflection Adapter - vLLM-based implementation using original CollaborativeAgents prompts.
+
+This implements the "reflection" baseline from the MULTISESSIONCOLLAB paper:
+- After each session, agent reflects on interaction to update memory (agent_notes)
+- Memory is provided to agent at start of subsequent sessions
+- Uses session-level reflection + persistent memory
+- Uses LLM-based retrieval (proper_scaffolding) to prevent context overflow
+
+Now uses vLLM for fast inference instead of local transformers.
+Uses EXACT prompts from the original CollaborativeAgents paper for fairness.
+"""
+
+import sys
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+from json_repair import repair_json
+
+# Add parent for utils import
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from utils.vllm_client import VLLMClient, VLLMConfig
+
+# Import ORIGINAL prompts from CollaborativeAgents for fair reproduction
+sys.path.insert(0, str(Path(__file__).parent.parent / "collaborativeagents"))
+from collaborativeagents.prompts import (
+ reflective_agent_system_prompt_no_json,
+ update_agent_notes_prompt,
+ proper_scaffolding_prompt,
+)
+from collaborativeagents.utils import get_conversation_string
+
+# Default vLLM URL (agent server on port 8003)
+DEFAULT_VLLM_URL = "http://localhost:8003/v1"
+
+
class ReflectionAdapter:
    """
    Adapter for the Reflection baseline from MULTISESSIONCOLLAB.

    Uses vLLM for fast inference with:
    - agent_notes: Persistent memory updated via session-level reflection
    - Memory retrieval at each turn via proper_scaffolding (when notes are long)

    Uses ORIGINAL CollaborativeAgents prompts for fair benchmark reproduction.
    """

    def __init__(
        self,
        model_name: str = None,  # Ignored - vLLM auto-discovers model
        device_assignment: dict = None,  # Ignored - vLLM handles GPU
        api_base: str = None,  # vLLM server URL (legacy alias for vllm_url)
        api_key: str = None,  # Ignored
        with_scaffolding: bool = True,
        with_proper_scaffolding: bool = True,  # Enable LLM-based retrieval by default
        vllm_url: str = None,  # vLLM server URL
        max_new_tokens: int = 2048,  # Match original CollaborativeAgents setting
    ):
        """Create the adapter.

        Args:
            model_name / device_assignment / api_key: Accepted for interface
                compatibility with sibling adapters; unused here.
            api_base: Fallback server URL when ``vllm_url`` is not given.
            with_scaffolding: If True, inject stored agent notes into the
                conversation before generating.
            with_proper_scaffolding: If True, spend one extra LLM call to
                extract only the notes relevant to the current conversation.
            vllm_url: vLLM server URL; falls back to ``api_base`` then
                ``DEFAULT_VLLM_URL``.
            max_new_tokens: Generation budget for the main response.
        """
        self.vllm_url = vllm_url or api_base or DEFAULT_VLLM_URL
        self.with_scaffolding = with_scaffolding
        self.with_proper_scaffolding = with_proper_scaffolding
        self.max_new_tokens = max_new_tokens

        # Per-user memory storage (user_id -> accumulated notes text)
        self._user_notes: Dict[str, str] = {}
        self._current_user_id: Optional[str] = None
        self._conversation_history: List[Dict[str, str]] = []

        # vLLM client (initialized lazily on first use)
        self._client: Optional[VLLMClient] = None
        self._initialized = False

    def initialize(self):
        """Connect to the vLLM server, retrying with exponential backoff.

        Idempotent: returns immediately once connected.

        Raises:
            RuntimeError: If the server never responds within the retry
                budget; chained to the last connection error, if any.
        """
        if self._initialized:
            return

        print(f"[ReflectionAdapter] Connecting to vLLM server at {self.vllm_url}...")
        print(f"[ReflectionAdapter] Using proper_scaffolding={self.with_proper_scaffolding}")

        import time
        max_retries = 30
        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                self._client = VLLMClient(base_url=self.vllm_url)
                if self._client.health_check():
                    break
            except Exception as exc:
                # Remember the failure so the terminal error is diagnosable
                # (previously the exception was silently discarded).
                last_error = exc

            if attempt < max_retries - 1:
                # Exponential backoff capped at 10s: 0.5, 1, 2, 4, 8, 10, 10...
                wait_time = min(2 ** attempt * 0.5, 10)
                time.sleep(wait_time)
            else:
                raise RuntimeError(
                    f"vLLM server not responding at {self.vllm_url} after {max_retries} retries"
                ) from last_error

        self._initialized = True
        print(f"[ReflectionAdapter] Connected to vLLM (model: {self._client.config.model})")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str:
        """Run one chat completion against the vLLM server and return its text."""
        if not self._initialized:
            self.initialize()

        result = self._client.chat(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=1.0,  # Match original CollaborativeAgents setting
            top_p=0.9,
        )

        return result["content"]

    @staticmethod
    def _prepend_to_first_message(
        conversation: List[Dict[str, str]],
        prefix: str,
    ) -> List[Dict[str, str]]:
        """Return a shallow copy of *conversation* with *prefix* prepended to
        the first message's content.

        Neither the input list nor its message dicts are mutated. An empty
        conversation is returned unchanged.
        """
        if not conversation:
            return conversation
        conversation = list(conversation)  # Copy to avoid mutation
        conversation[0] = dict(conversation[0])  # Copy first message
        conversation[0]["content"] = prefix + conversation[0]["content"]
        return conversation

    def _add_scaffolding_to_conversation(
        self,
        conversation: List[Dict[str, str]],
        agent_notes: str
    ) -> List[Dict[str, str]]:
        """
        Add scaffolding (memory notes) to conversation.

        Mirrors CollaboratorAgent.add_scaffolding_to_conversation() from the
        original CollaborativeAgents codebase.

        If with_proper_scaffolding=True: Use LLM to extract relevant notes
        If with_proper_scaffolding=False: Prepend all notes to first message
        """
        # Preamble used both by simple mode and as the failure fallback.
        full_notes_prefix = (
            f"Remember, you have been taking notes throughout past conversations "
            f"about user preferences. Use whatever is relevant in these notes to "
            f"guide your response:\n{agent_notes}\n\n"
        )

        if not self.with_proper_scaffolding:
            # Simple mode: prepend all notes verbatim
            return self._prepend_to_first_message(conversation, full_notes_prefix)

        # Proper scaffolding: use LLM to extract relevant notes
        conversation_str = get_conversation_string(conversation)
        formatted_prompt = proper_scaffolding_prompt.format(
            conversation_history=conversation_str,
            complete_agent_notes=agent_notes
        )

        # Call LLM to extract relevant notes (with retries)
        for attempt in range(3):
            try:
                messages = [{"role": "user", "content": formatted_prompt}]
                response = self._generate(messages, max_new_tokens=512)

                parsed = repair_json(response, return_objects=True)
                # repair_json may yield a str/list on malformed output; the
                # membership test below would then be meaningless.
                if not isinstance(parsed, dict):
                    print("[ReflectionAdapter] Scaffolding returned non-dict output")
                    continue
                missing_keys = [k for k in ["reasoning", "relevant_notes"] if k not in parsed]

                if missing_keys:
                    print(f"[ReflectionAdapter] Scaffolding missing keys: {missing_keys}")
                    continue

                # Prepend only the extracted, relevant notes.
                return self._prepend_to_first_message(
                    conversation,
                    f"Remember, you have been taking notes throughout past conversations "
                    f"about user preferences. Use these notes to guide your response:\n"
                    f"{parsed['relevant_notes']}\n\n",
                )

            except Exception as e:
                print(f"[ReflectionAdapter] Scaffolding attempt {attempt+1} failed: {e}")
                continue

        # Fallback: use all notes if retrieval fails
        print("[ReflectionAdapter] Scaffolding failed, using full notes")
        return self._prepend_to_first_message(conversation, full_notes_prefix)

    def start_session(self, user_id: str, user_profile: dict = None):
        """Begin a fresh session for *user_id*, clearing the turn history.

        Stored agent notes for the user persist across sessions; only the
        per-session conversation buffer resets. ``user_profile`` is accepted
        for interface compatibility and unused.
        """
        if not self._initialized:
            self.initialize()

        self._current_user_id = user_id
        self._conversation_history = []

    def _assemble_prompt(self, query: str) -> tuple:
        """Record *query* in the history and build the full chat message list.

        Shared by ``generate_response`` and ``prepare_prompt`` so the two
        code paths cannot drift apart (they were previously duplicated).

        Returns:
            Tuple of (messages, agent_notes): the system prompt plus the
            (possibly scaffolded) conversation, and the raw notes used.
        """
        # Add user query to history
        self._conversation_history.append({"role": "user", "content": query})

        # Get current notes for this user
        agent_notes = self._user_notes.get(self._current_user_id, "No notes yet about this user.")

        # Build conversation with scaffolding (uses proper_scaffolding if
        # enabled); skip entirely when there is nothing to inject yet.
        if self.with_scaffolding and agent_notes != "No notes yet about this user.":
            conversation_with_notes = self._add_scaffolding_to_conversation(
                self._conversation_history, agent_notes
            )
        else:
            conversation_with_notes = self._conversation_history

        # Build system prompt using ORIGINAL CollaborativeAgents prompt.
        # When scaffolding is on, the notes travel inside the conversation
        # rather than the system prompt.
        system_prompt = reflective_agent_system_prompt_no_json.format(
            agent_notes=agent_notes if not self.with_scaffolding else "See notes in conversation."
        )

        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(conversation_with_notes)
        return messages, agent_notes

    def generate_response(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """Generate a response to *query* using the reflection agent.

        ``conversation_history`` is accepted for interface compatibility; the
        adapter tracks its own per-session history internally.

        Returns:
            Dict with 'response', 'reasoning' (always empty here), and debug
            info including the notes that were in effect.
        """
        if not self._initialized:
            self.initialize()

        messages, agent_notes = self._assemble_prompt(query)

        # Generate response
        response_text = self._generate(messages, max_new_tokens=self.max_new_tokens)

        self._conversation_history.append({"role": "assistant", "content": response_text})

        return {
            "response": response_text,
            "reasoning": "",
            "debug": {"agent_notes": agent_notes, "proper_scaffolding": self.with_proper_scaffolding}
        }

    def prepare_prompt(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> tuple:
        """
        Prepare prompt for batch processing without calling the main LLM.

        Note: This may still call the LLM for scaffolding (memory retrieval),
        but the main generation is deferred for batching.

        Args:
            query: Current user query
            conversation_history: Accepted for interface compatibility; unused.

        Returns:
            Tuple of (messages, context) for batch processing
        """
        if not self._initialized:
            self.initialize()

        messages, agent_notes = self._assemble_prompt(query)

        # Context for post-processing
        ctx = {
            "agent_notes": agent_notes,
        }

        return messages, ctx

    def process_response(
        self,
        response: str,
        context: dict
    ) -> Dict[str, Any]:
        """
        Process LLM response after a batch call.

        Args:
            response: LLM response text
            context: Context dict from prepare_prompt()

        Returns:
            Dict with 'response', 'reasoning', and debug info
        """
        self._conversation_history.append({"role": "assistant", "content": response})

        return {
            "response": response,
            "reasoning": "",
            "debug": {
                "agent_notes": context["agent_notes"],
                "proper_scaffolding": self.with_proper_scaffolding
            }
        }

    def _update_agent_notes(self, agent_notes: str, conversation: List[Dict[str, str]]) -> Optional[Dict]:
        """
        Update agent notes using ORIGINAL CollaborativeAgents logic.

        For 8B models (no_json): Use raw response directly
        For JSON models: Check for required keys, retry up to num_retries times

        Returns:
            Dict containing at least an "agent_notes" key, or None if all
            retries failed (callers should keep the old notes).
        """
        conversation_str = get_conversation_string(conversation)
        formatted_prompt = update_agent_notes_prompt.format(
            agent_notes=agent_notes,
            conversation_str=conversation_str
        )

        num_retries = 10
        no_json = True  # 8B model

        for attempt in range(num_retries):
            try:
                messages = [{"role": "user", "content": formatted_prompt}]
                response = self._generate(messages, max_new_tokens=512)

                # For 8B models (no_json=True): use raw response directly
                if no_json:
                    return {"agent_notes": response}

                # For JSON models: parse and check keys
                processed_response = repair_json(response, return_objects=True)
                missing_keys = [k for k in ["user_preferences_reasoning", "agent_notes"] if k not in processed_response]

                if missing_keys:
                    print(f"[ReflectionAdapter] Missing keys: {missing_keys}, attempt {attempt + 1}")
                    continue

                return processed_response

            except Exception as e:
                print(f"[ReflectionAdapter] Failed to update agent notes: {e}")

        return None  # All retries failed, keep old notes

    def end_session(self, task_success: bool = False) -> Dict[str, Any]:
        """
        End the session and update agent notes via reflection.

        Uses the ORIGINAL update_agent_notes logic from CollaborativeAgents.

        Returns:
            Session metrics. ``notes_updated`` now reflects whether the
            reflection actually produced new notes (previously it was
            hard-coded to True even when the update failed or the session
            had no turns).
        """
        if not self._current_user_id:
            return {}

        # Get current notes
        current_notes = self._user_notes.get(self._current_user_id, "No notes yet.")

        notes_updated = False
        # Update notes via session-level reflection
        if len(self._conversation_history) > 0:
            result = self._update_agent_notes(current_notes, self._conversation_history)

            if result is not None and "agent_notes" in result:
                updated_notes = result["agent_notes"]
                self._user_notes[self._current_user_id] = updated_notes
                notes_updated = True
                print(f"[ReflectionAdapter] Updated notes for {self._current_user_id} "
                      f"({len(current_notes)} -> {len(updated_notes)} chars)")
            else:
                print(f"[ReflectionAdapter] Keeping old notes for {self._current_user_id} "
                      f"(update failed)")

        return {
            "turns": len(self._conversation_history),
            "task_success": task_success,
            "notes_updated": notes_updated,
        }

    def reset_user(self, user_id: str):
        """Forget all stored notes for *user_id* (no-op if none exist)."""
        self._user_notes.pop(user_id, None)

    def __call__(
        self,
        messages: List[Dict[str, str]],
        user_profile: dict = None,
        **kwargs
    ) -> str:
        """Callable interface for ConversationGenerator compatibility.

        Responds to the most recent user message in *messages*; falls back
        to a generic greeting when there is none.
        """
        last_user_msg = next(
            (msg["content"] for msg in reversed(messages) if msg["role"] == "user"),
            None,
        )

        if last_user_msg is None:
            return "How can I help you?"

        result = self.generate_response(last_user_msg, messages)
        return result["response"]
diff --git a/collaborativeagents/adapters/reflection_grpo_adapter.py b/collaborativeagents/adapters/reflection_grpo_adapter.py
new file mode 100644
index 0000000..09c5b26
--- /dev/null
+++ b/collaborativeagents/adapters/reflection_grpo_adapter.py
@@ -0,0 +1,321 @@
+"""
+Reflection + GRPO Adapter - Local transformers-based implementation.
+
+This implements the "Reflection + GRPO" baseline from the MULTISESSIONCOLLAB paper:
+- Uses a GRPO-trained model for session-level reflection
+- The model is trained to generate higher-quality reflections that capture user preferences
+- Training uses rewards from LLM judge evaluating reflection quality
+
+Key difference from vanilla reflection:
+- Uses GRPO-trained model for reflection generation (better preference capture)
+- Produces more actionable and comprehensive agent notes
+"""
+
+import sys
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from json_repair import repair_json
+
# Model paths - Use GRPO-trained model if available, fallback to base
GRPO_MODEL_PATH = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/outputs/grpo_reflection/final"
SFT_MODEL_PATH = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/outputs/sft_reflection"
DEFAULT_MODEL_PATH = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct"

def get_best_available_model():
    """Return the path of the best checkpoint on disk (GRPO > SFT > base)."""
    ranked = (
        (Path(GRPO_MODEL_PATH), "Using GRPO-trained model"),
        (Path(SFT_MODEL_PATH), "Using SFT model (GRPO not found)"),
    )
    for candidate, label in ranked:
        # A usable checkpoint directory must contain a config.json.
        if candidate.exists() and (candidate / "config.json").exists():
            print(f"[ReflectionGRPOAdapter] {label}: {candidate}")
            return str(candidate)
    print("[ReflectionGRPOAdapter] WARNING: No trained model found, using base model")
    print("[ReflectionGRPOAdapter] To train: run collaborativeagents/slurm/run_sft_training.sh")
    print("[ReflectionGRPOAdapter] then collaborativeagents/slurm/run_grpo_training.sh")
    return DEFAULT_MODEL_PATH
+
# GRPO-enhanced system prompt with proper scaffolding.
# {agent_notes} is filled with the per-user notes at generation time.
REFLECTIVE_AGENT_SYSTEM_PROMPT = """You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.

# Notes
Remember, you have been taking notes throughout past conversations about user preferences. Use these notes to guide your response:
{agent_notes}

# Conversation Guidelines:
- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.
- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem."""

# Free-text reflection prompt (no JSON output contract).
# NOTE(review): not referenced by ReflectionGRPOAdapter below, which uses the
# *_GRPO variant — confirm whether this constant is still needed before removal.
UPDATE_AGENT_NOTES_PROMPT = """You are a collaborative AI agent learning to better help a user with problem-solving tasks across multi-session interactions. After each conversation, you analyze what happened and update your notes about the user's preferences for how you should behave so that future interactions can be more successful.

# Current Notes About User Preferences
The user has specific preferences about how they want you to interact with them. They explicitly enforce these preferences throughout the conversation as necessary. Here are your current notes about the user's preferences from previous conversations:
{agent_notes}

# Conversation to Analyze
{conversation_str}

# Notes Updating Task
Analyze the conversation above to identify the user's preferences and how you can best satisfy them. Your goal is to create actionable notes that help you satisfy these preferences for future conversations. Keep your notes concise and actionable, without adding unnecessary details. Consider:
- When did the user explicitly ask you to adjust your response? What specifically did they want changed?
- What specific actions, formats, or approaches satisfy each preference? What should you keep in mind for future conversations?
As new situations arise, you may refine, combine, or split preferences to better reflect the user's needs. When updating the notes, do not lose any useful information from past interactions.
Make sure to add information about the user preferences that you are sure about, and do not hallucinate preferences.

Provide your updated notes as a clear, structured response. List each preference with actionable guidance."""

# GRPO-trained reflection prompt - produces higher quality reflections.
# Expects a JSON object back (see Output Format); the doubled braces {{ }}
# are literal braces escaped for str.format().
UPDATE_AGENT_NOTES_PROMPT_GRPO = """You are a collaborative AI agent learning to better help a user with problem-solving tasks across multi-session interactions. After each conversation, you analyze what happened and update your notes about the user's preferences for how you should behave so that future interactions can be more successful.

# Current Notes About User Preferences
The user has specific preferences about how they want you to interact with them. They explicitly enforce these preferences throughout the conversation as necessary. Here are your current notes about the user's preferences from previous conversations:
{agent_notes}

# Conversation to Analyze
{conversation_str}

# Notes Updating Task
Analyze the conversation above to identify the user's preferences and how you can best satisfy them. Your goal is to create actionable notes that help you satisfy these preferences for future conversations. Keep your notes concise and actionable, without adding unnecessary details. Consider:
- When did the user explicitly ask you to adjust your response? What specifically did they want changed?
- What specific actions, formats, or approaches satisfy each preference? What should you keep in mind for future conversations?
As new situations arise, you may refine, combine, or split preferences to better reflect the user's needs. When updating the notes, do not lose any useful information from past interactions.
Make sure to add information about the user preferences that you are sure about, and do not hallucinate preferences.

# Output Format:
{{
    "user_preferences_reasoning": str, # Reasoning about the user preferences and how to satisfy them
    "agent_notes": str, # Updated notes. Provide a description of the user preferences, how to satisfy them, and any additional notes. This will be provided to you in future conversations with this user. Ensure that you provide a structured response that is clear and easy to understand.
}}
For each response, output a valid JSON object using the exact format above, do not include any text before or after the JSON object."""
+
+
class ReflectionGRPOAdapter:
    """
    Adapter for the Reflection + GRPO baseline from MULTISESSIONCOLLAB.

    Uses GRPO-trained model for:
    - Higher quality session-level reflections that better capture user preferences
    - The model was trained with rewards from LLM judge evaluating reflection quality

    Key difference from vanilla ReflectionAdapter:
    - Uses GRPO-trained model (if available) for reflection generation
    - Removes the faulty preprocessing step that was causing issues
    - Produces more comprehensive and actionable agent notes
    """

    def __init__(
        self,
        model_name: str = None,  # Auto-detect best available model
        device_assignment: dict = None,
        api_base: str = None,  # Ignored, kept for compatibility
        api_key: str = None,  # Ignored, kept for compatibility
    ):
        """Create the adapter, resolving the best checkpoint on disk.

        Args:
            model_name: Explicit model path; when None, the best available
                checkpoint is auto-detected (GRPO > SFT > base).
            device_assignment: Stored for compatibility; device placement is
                delegated to ``device_map="auto"``.
            api_base / api_key: Ignored, kept for interface compatibility.
        """
        # Auto-detect best model (GRPO > SFT > base)
        self.model_path = model_name if model_name else get_best_available_model()
        self.device_assignment = device_assignment

        # Per-user memory storage (user_id -> accumulated notes text)
        self._user_notes: Dict[str, str] = {}
        self._current_user_id: Optional[str] = None
        self._conversation_history: List[Dict[str, str]] = []

        # Model components (loaded lazily on first use)
        self._model = None
        self._tokenizer = None
        self._initialized = False

    def initialize(self):
        """Load tokenizer and model; idempotent after the first call."""
        if self._initialized:
            return

        print(f"[ReflectionGRPOAdapter] Loading model from {self.model_path}...")
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # Llama tokenizers ship without a pad token; generation needs one.
        if self._tokenizer.pad_token_id is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._initialized = True
        print("[ReflectionGRPOAdapter] Initialized")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 1024) -> str:
        """Run one sampled generation with the local model and return its text."""
        if not self._initialized:
            self.initialize()

        # Apply chat template
        prompt = self._tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=8192
        ).to(self._model.device)

        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=self._tokenizer.eos_token_id,
                pad_token_id=self._tokenizer.pad_token_id,
            )

        # Extract only the generated part (strip the echoed prompt tokens)
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        response = self._tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        return response

    def start_session(self, user_id: str, user_profile: dict = None):
        """Begin a fresh session for *user_id*, clearing the turn history.

        Stored agent notes persist across sessions. ``user_profile`` is
        accepted for interface compatibility and unused.
        """
        if not self._initialized:
            self.initialize()

        self._current_user_id = user_id
        self._conversation_history = []

    def generate_response(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Generate a response using the GRPO-trained reflection agent.

        Note: GRPO training improves the REFLECTION quality (in end_session),
        not the runtime response generation. The improvement comes from better
        agent_notes that are generated after each session.

        ``conversation_history`` is accepted for interface compatibility; the
        adapter tracks its own per-session history internally.
        """
        if not self._initialized:
            self.initialize()

        # Add user query to history
        self._conversation_history.append({"role": "user", "content": query})

        # Get current notes for this user (these are higher quality due to GRPO training)
        agent_notes = self._user_notes.get(self._current_user_id, "No notes yet about this user.")

        # Build system prompt with notes
        system_prompt = REFLECTIVE_AGENT_SYSTEM_PROMPT.format(agent_notes=agent_notes)

        # Build messages for generation
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(self._conversation_history)

        # Generate response
        response_text = self._generate(messages)

        self._conversation_history.append({"role": "assistant", "content": response_text})

        return {
            "response": response_text,
            "reasoning": "",
            "debug": {"agent_notes": agent_notes}
        }

    def end_session(self, task_success: bool = False) -> Dict[str, Any]:
        """
        End session and update agent notes via GRPO-trained reflection.

        This is the KEY DIFFERENCE from vanilla reflection:
        - The GRPO-trained model generates higher quality reflections
        - Reflections better capture user preferences without hallucination
        - Notes are more actionable and comprehensive

        Returns:
            Session metrics. ``notes_updated`` reflects whether new notes were
            actually stored (previously it was hard-coded to True even when
            the reflection failed or the session had no turns).
        """
        if not self._current_user_id:
            return {}

        # Get current notes
        current_notes = self._user_notes.get(self._current_user_id, "No notes yet.")

        notes_updated = False
        # Update notes via GRPO-trained session-level reflection
        if len(self._conversation_history) > 0:
            try:
                # Render the conversation as a plain-text transcript
                # (join instead of quadratic += concatenation).
                parts = []
                for msg in self._conversation_history:
                    role = "User" if msg["role"] == "user" else "Assistant"
                    parts.append(f"{role}: {msg['content']}\n\n")
                conv_str = "".join(parts)

                # Generate reflection using GRPO-trained model
                reflection_prompt = UPDATE_AGENT_NOTES_PROMPT_GRPO.format(
                    agent_notes=current_notes,
                    conversation_str=conv_str
                )

                messages = [{"role": "user", "content": reflection_prompt}]
                raw_output = self._generate(messages, max_new_tokens=512)

                # Parse JSON output (GRPO-trained model outputs structured JSON);
                # fall back to the raw text when parsing fails.
                try:
                    parsed = repair_json(raw_output, return_objects=True)
                    if isinstance(parsed, dict) and "agent_notes" in parsed:
                        updated_notes = parsed["agent_notes"]
                    else:
                        updated_notes = raw_output
                except Exception:  # BUGFIX: was a bare except (caught KeyboardInterrupt/SystemExit)
                    updated_notes = raw_output

                if updated_notes:
                    self._user_notes[self._current_user_id] = updated_notes
                    notes_updated = True
                    print(f"[ReflectionGRPOAdapter] Updated notes for {self._current_user_id}")

            except Exception as e:
                print(f"[ReflectionGRPOAdapter] Failed to update notes: {e}")

        return {
            "turns": len(self._conversation_history),
            "task_success": task_success,
            "notes_updated": notes_updated,
        }

    def reset_user(self, user_id: str):
        """Forget all stored notes for *user_id* (no-op if none exist)."""
        self._user_notes.pop(user_id, None)

    def __call__(
        self,
        messages: List[Dict[str, str]],
        user_profile: dict = None,
        **kwargs
    ) -> str:
        """Callable interface for ConversationGenerator compatibility.

        Responds to the most recent user message in *messages*; falls back
        to a generic greeting when there is none.
        """
        last_user_msg = next(
            (msg["content"] for msg in reversed(messages) if msg["role"] == "user"),
            None,
        )

        if last_user_msg is None:
            return "How can I help you?"

        result = self.generate_response(last_user_msg, messages)
        return result["response"]