diff options
Diffstat (limited to 'collaborativeagents/adapters/reflection_adapter.py')
| -rw-r--r-- | collaborativeagents/adapters/reflection_adapter.py | 186 |
1 file changed, 104 insertions, 82 deletions
diff --git a/collaborativeagents/adapters/reflection_adapter.py b/collaborativeagents/adapters/reflection_adapter.py index d535be2..451c694 100644 --- a/collaborativeagents/adapters/reflection_adapter.py +++ b/collaborativeagents/adapters/reflection_adapter.py @@ -111,6 +111,18 @@ class ReflectionAdapter: return result["content"] + def _prepend_notes_to_conversation(self, conversation, notes_text): + """Prepend notes to the first message of a conversation copy.""" + if conversation: + conversation = list(conversation) + conversation[0] = dict(conversation[0]) + conversation[0]["content"] = ( + f"Remember, you have been taking notes throughout past conversations " + f"about user preferences. Use whatever is relevant in these notes to " + f"guide your response:\n{notes_text}\n\n" + conversation[0]["content"] + ) + return conversation + def _add_scaffolding_to_conversation( self, conversation: List[Dict[str, str]], @@ -118,73 +130,73 @@ class ReflectionAdapter: ) -> List[Dict[str, str]]: """ Add scaffolding (memory notes) to conversation. - - This is the EXACT logic from CollaboratorAgent.add_scaffolding_to_conversation(). - - If with_proper_scaffolding=True: Use LLM to extract relevant notes - If with_proper_scaffolding=False: Prepend all notes to first message + Sequential version - for non-batch use only. """ if not self.with_proper_scaffolding: - # Simple mode: prepend all notes - if conversation: - conversation = list(conversation) # Copy to avoid mutation - conversation[0] = dict(conversation[0]) # Copy first message - conversation[0]["content"] = ( - f"Remember, you have been taking notes throughout past conversations " - f"about user preferences. 
Use whatever is relevant in these notes to " - f"guide your response:\n{agent_notes}\n\n" + conversation[0]["content"] - ) - return conversation + return self._prepend_notes_to_conversation(conversation, agent_notes) else: - # Proper scaffolding: use LLM to extract relevant notes - conversation_str = get_conversation_string(conversation) - formatted_prompt = proper_scaffolding_prompt.format( - conversation_history=conversation_str, - complete_agent_notes=agent_notes - ) - - # Call LLM to extract relevant notes (with retries) - for attempt in range(3): - try: - messages = [{"role": "user", "content": formatted_prompt}] - response = self._generate(messages, max_new_tokens=512) - - parsed = repair_json(response, return_objects=True) - missing_keys = [k for k in ["reasoning", "relevant_notes"] if k not in parsed] - - if missing_keys: - print(f"[ReflectionAdapter] Scaffolding missing keys: {missing_keys}") - continue - - scaffolded_notes = parsed["relevant_notes"] - - # Prepend extracted notes to first message - if conversation: - conversation = list(conversation) - conversation[0] = dict(conversation[0]) - conversation[0]["content"] = ( - f"Remember, you have been taking notes throughout past conversations " - f"about user preferences. Use these notes to guide your response:\n" - f"{scaffolded_notes}\n\n" + conversation[0]["content"] - ) - - return conversation + prompt = self.get_scaffolding_prompt(conversation, agent_notes) + if prompt is None: + return self._prepend_notes_to_conversation(conversation, agent_notes) + response = self._generate([{"role": "user", "content": prompt}], max_new_tokens=512) + return self.apply_scaffolding_response(conversation, agent_notes, response) + + def get_scaffolding_prompt(self, conversation, agent_notes): + """Build the scaffolding prompt for batch processing. 
Returns None if no scaffolding needed.""" + if not self.with_proper_scaffolding: + return None + conversation_str = get_conversation_string(conversation) + return proper_scaffolding_prompt.format( + conversation_history=conversation_str, + complete_agent_notes=agent_notes + ) - except Exception as e: - print(f"[ReflectionAdapter] Scaffolding attempt {attempt+1} failed: {e}") - continue + def apply_scaffolding_response(self, conversation, agent_notes, response): + """Apply a scaffolding LLM response to the conversation.""" + try: + parsed = repair_json(response, return_objects=True) + if "relevant_notes" in parsed: + notes_text = parsed["relevant_notes"] + if conversation: + conversation = list(conversation) + conversation[0] = dict(conversation[0]) + conversation[0]["content"] = ( + f"Remember, you have been taking notes throughout past conversations " + f"about user preferences. Use these notes to guide your response:\n" + f"{notes_text}\n\n" + conversation[0]["content"] + ) + return conversation + except Exception: + pass + # Fallback: use all notes + return self._prepend_notes_to_conversation(conversation, agent_notes) + + def get_note_update_prompt(self, user_id=None): + """Build the note-update prompt for batch processing. 
Returns (messages, user_id) or None.""" + uid = user_id or self._current_user_id + if not uid or not self._conversation_history: + return None + current_notes = self._user_notes.get(uid, "No notes yet.") + conversation_str = get_conversation_string(self._conversation_history) + prompt = update_agent_notes_prompt.format( + agent_notes=current_notes, + conversation_str=conversation_str + ) + return [{"role": "user", "content": prompt}] - # Fallback: use all notes if retrieval fails - print("[ReflectionAdapter] Scaffolding failed, using full notes") - if conversation: - conversation = list(conversation) - conversation[0] = dict(conversation[0]) - conversation[0]["content"] = ( - f"Remember, you have been taking notes throughout past conversations " - f"about user preferences. Use whatever is relevant in these notes to " - f"guide your response:\n{agent_notes}\n\n" + conversation[0]["content"] - ) - return conversation + def apply_note_update_response(self, response, user_id=None): + """Apply a note-update LLM response.""" + uid = user_id or self._current_user_id + if not uid: + return + try: + # 8B model: use raw response directly + updated_notes = response.strip() + if updated_notes: + old_len = len(self._user_notes.get(uid, "")) + self._user_notes[uid] = updated_notes + except Exception: + pass def start_session(self, user_id: str, user_profile: dict = None): """Start a new session for a user.""" @@ -266,11 +278,27 @@ class ReflectionAdapter: # Get current notes for this user agent_notes = self._user_notes.get(self._current_user_id, "No notes yet about this user.") - # Build conversation with scaffolding (may involve LLM call for proper_scaffolding) + # Store for batch scaffolding (prepare_prompt may be called after batch scaffolding) + self._pending_agent_notes = agent_notes + self._pending_scaffolded = False + + # Build conversation with scaffolding if self.with_scaffolding and agent_notes != "No notes yet about this user.": - conversation_with_notes = 
self._add_scaffolding_to_conversation( - self._conversation_history, agent_notes - ) + if hasattr(self, '_scaffolding_result') and self._scaffolding_result is not None: + # Use pre-computed batch scaffolding result + conversation_with_notes = self.apply_scaffolding_response( + list(self._conversation_history), agent_notes, self._scaffolding_result) + self._scaffolding_result = None + self._pending_scaffolded = True + elif not self.with_proper_scaffolding: + conversation_with_notes = self._prepend_notes_to_conversation( + self._conversation_history, agent_notes) + self._pending_scaffolded = True + else: + # Sequential fallback - should not happen in batch mode + conversation_with_notes = self._add_scaffolding_to_conversation( + self._conversation_history, agent_notes) + self._pending_scaffolded = True else: conversation_with_notes = self._conversation_history @@ -357,30 +385,24 @@ class ReflectionAdapter: return None # All retries failed, keep old notes - def end_session(self, task_success: bool = False) -> Dict[str, Any]: + def end_session(self, task_success: bool = False, skip_note_update: bool = False) -> Dict[str, Any]: """ End session and update agent notes via reflection. - Uses the ORIGINAL update_agent_notes logic from CollaborativeAgents. 
+ Args: + task_success: Whether the task was completed successfully + skip_note_update: If True, skip note update (already done via batch) """ if not self._current_user_id: return {} - # Get current notes - current_notes = self._user_notes.get(self._current_user_id, "No notes yet.") - - # Update notes via session-level reflection - if len(self._conversation_history) > 0: - result = self._update_agent_notes(current_notes, self._conversation_history) - + # Update notes via session-level reflection (skip if batch already did it) + if not skip_note_update and len(self._conversation_history) > 0: + result = self._update_agent_notes( + self._user_notes.get(self._current_user_id, "No notes yet."), + self._conversation_history) if result is not None and "agent_notes" in result: - updated_notes = result["agent_notes"] - self._user_notes[self._current_user_id] = updated_notes - print(f"[ReflectionAdapter] Updated notes for {self._current_user_id} " - f"({len(current_notes)} -> {len(updated_notes)} chars)") - else: - print(f"[ReflectionAdapter] Keeping old notes for {self._current_user_id} " - f"(update failed)") + self._user_notes[self._current_user_id] = result["agent_notes"] return { "turns": len(self._conversation_history), |
