summaryrefslogtreecommitdiff
path: root/collaborativeagents/adapters/reflection_adapter.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-10 20:16:36 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-10 20:16:36 +0000
commit5626080ca4c4219aec4888d6b9406d0d3349fb55 (patch)
tree86287d9fd5833e11ccd78566992540f2664fd195 /collaborativeagents/adapters/reflection_adapter.py
parenta2036838807428424bbbaff507a6563749a83145 (diff)
Add RAG rewrite, 60-session experiment scripts, and analysis tools
- RAG rewrite adapter and vector preference pipeline in personalized_llm
- 60-session experiment queue scripts (reflection, rag, rag_vector, rag_rewrite)
- Vector-preference correlation analysis and visualization scripts
- Local reward model batch processing improvements
- Updated CLAUDE.md with full experiment documentation and notes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/adapters/reflection_adapter.py')
-rw-r--r--collaborativeagents/adapters/reflection_adapter.py186
1 file changed, 104 insertions, 82 deletions
diff --git a/collaborativeagents/adapters/reflection_adapter.py b/collaborativeagents/adapters/reflection_adapter.py
index d535be2..451c694 100644
--- a/collaborativeagents/adapters/reflection_adapter.py
+++ b/collaborativeagents/adapters/reflection_adapter.py
@@ -111,6 +111,18 @@ class ReflectionAdapter:
return result["content"]
+ def _prepend_notes_to_conversation(self, conversation, notes_text):
+ """Prepend notes to the first message of a conversation copy."""
+ if conversation:
+ conversation = list(conversation)
+ conversation[0] = dict(conversation[0])
+ conversation[0]["content"] = (
+ f"Remember, you have been taking notes throughout past conversations "
+ f"about user preferences. Use whatever is relevant in these notes to "
+ f"guide your response:\n{notes_text}\n\n" + conversation[0]["content"]
+ )
+ return conversation
+
def _add_scaffolding_to_conversation(
self,
conversation: List[Dict[str, str]],
@@ -118,73 +130,73 @@ class ReflectionAdapter:
) -> List[Dict[str, str]]:
"""
Add scaffolding (memory notes) to conversation.
-
- This is the EXACT logic from CollaboratorAgent.add_scaffolding_to_conversation().
-
- If with_proper_scaffolding=True: Use LLM to extract relevant notes
- If with_proper_scaffolding=False: Prepend all notes to first message
+ Sequential version - for non-batch use only.
"""
if not self.with_proper_scaffolding:
- # Simple mode: prepend all notes
- if conversation:
- conversation = list(conversation) # Copy to avoid mutation
- conversation[0] = dict(conversation[0]) # Copy first message
- conversation[0]["content"] = (
- f"Remember, you have been taking notes throughout past conversations "
- f"about user preferences. Use whatever is relevant in these notes to "
- f"guide your response:\n{agent_notes}\n\n" + conversation[0]["content"]
- )
- return conversation
+ return self._prepend_notes_to_conversation(conversation, agent_notes)
else:
- # Proper scaffolding: use LLM to extract relevant notes
- conversation_str = get_conversation_string(conversation)
- formatted_prompt = proper_scaffolding_prompt.format(
- conversation_history=conversation_str,
- complete_agent_notes=agent_notes
- )
-
- # Call LLM to extract relevant notes (with retries)
- for attempt in range(3):
- try:
- messages = [{"role": "user", "content": formatted_prompt}]
- response = self._generate(messages, max_new_tokens=512)
-
- parsed = repair_json(response, return_objects=True)
- missing_keys = [k for k in ["reasoning", "relevant_notes"] if k not in parsed]
-
- if missing_keys:
- print(f"[ReflectionAdapter] Scaffolding missing keys: {missing_keys}")
- continue
-
- scaffolded_notes = parsed["relevant_notes"]
-
- # Prepend extracted notes to first message
- if conversation:
- conversation = list(conversation)
- conversation[0] = dict(conversation[0])
- conversation[0]["content"] = (
- f"Remember, you have been taking notes throughout past conversations "
- f"about user preferences. Use these notes to guide your response:\n"
- f"{scaffolded_notes}\n\n" + conversation[0]["content"]
- )
-
- return conversation
+ prompt = self.get_scaffolding_prompt(conversation, agent_notes)
+ if prompt is None:
+ return self._prepend_notes_to_conversation(conversation, agent_notes)
+ response = self._generate([{"role": "user", "content": prompt}], max_new_tokens=512)
+ return self.apply_scaffolding_response(conversation, agent_notes, response)
+
+ def get_scaffolding_prompt(self, conversation, agent_notes):
+ """Build the scaffolding prompt for batch processing. Returns None if no scaffolding needed."""
+ if not self.with_proper_scaffolding:
+ return None
+ conversation_str = get_conversation_string(conversation)
+ return proper_scaffolding_prompt.format(
+ conversation_history=conversation_str,
+ complete_agent_notes=agent_notes
+ )
- except Exception as e:
- print(f"[ReflectionAdapter] Scaffolding attempt {attempt+1} failed: {e}")
- continue
+ def apply_scaffolding_response(self, conversation, agent_notes, response):
+ """Apply a scaffolding LLM response to the conversation."""
+ try:
+ parsed = repair_json(response, return_objects=True)
+ if "relevant_notes" in parsed:
+ notes_text = parsed["relevant_notes"]
+ if conversation:
+ conversation = list(conversation)
+ conversation[0] = dict(conversation[0])
+ conversation[0]["content"] = (
+ f"Remember, you have been taking notes throughout past conversations "
+ f"about user preferences. Use these notes to guide your response:\n"
+ f"{notes_text}\n\n" + conversation[0]["content"]
+ )
+ return conversation
+ except Exception:
+ pass
+ # Fallback: use all notes
+ return self._prepend_notes_to_conversation(conversation, agent_notes)
+
+ def get_note_update_prompt(self, user_id=None):
+ """Build the note-update prompt for batch processing. Returns (messages, user_id) or None."""
+ uid = user_id or self._current_user_id
+ if not uid or not self._conversation_history:
+ return None
+ current_notes = self._user_notes.get(uid, "No notes yet.")
+ conversation_str = get_conversation_string(self._conversation_history)
+ prompt = update_agent_notes_prompt.format(
+ agent_notes=current_notes,
+ conversation_str=conversation_str
+ )
+ return [{"role": "user", "content": prompt}]
- # Fallback: use all notes if retrieval fails
- print("[ReflectionAdapter] Scaffolding failed, using full notes")
- if conversation:
- conversation = list(conversation)
- conversation[0] = dict(conversation[0])
- conversation[0]["content"] = (
- f"Remember, you have been taking notes throughout past conversations "
- f"about user preferences. Use whatever is relevant in these notes to "
- f"guide your response:\n{agent_notes}\n\n" + conversation[0]["content"]
- )
- return conversation
+ def apply_note_update_response(self, response, user_id=None):
+ """Apply a note-update LLM response."""
+ uid = user_id or self._current_user_id
+ if not uid:
+ return
+ try:
+ # 8B model: use raw response directly
+ updated_notes = response.strip()
+ if updated_notes:
+ old_len = len(self._user_notes.get(uid, ""))
+ self._user_notes[uid] = updated_notes
+ except Exception:
+ pass
def start_session(self, user_id: str, user_profile: dict = None):
"""Start a new session for a user."""
@@ -266,11 +278,27 @@ class ReflectionAdapter:
# Get current notes for this user
agent_notes = self._user_notes.get(self._current_user_id, "No notes yet about this user.")
- # Build conversation with scaffolding (may involve LLM call for proper_scaffolding)
+ # Store for batch scaffolding (prepare_prompt may be called after batch scaffolding)
+ self._pending_agent_notes = agent_notes
+ self._pending_scaffolded = False
+
+ # Build conversation with scaffolding
if self.with_scaffolding and agent_notes != "No notes yet about this user.":
- conversation_with_notes = self._add_scaffolding_to_conversation(
- self._conversation_history, agent_notes
- )
+ if hasattr(self, '_scaffolding_result') and self._scaffolding_result is not None:
+ # Use pre-computed batch scaffolding result
+ conversation_with_notes = self.apply_scaffolding_response(
+ list(self._conversation_history), agent_notes, self._scaffolding_result)
+ self._scaffolding_result = None
+ self._pending_scaffolded = True
+ elif not self.with_proper_scaffolding:
+ conversation_with_notes = self._prepend_notes_to_conversation(
+ self._conversation_history, agent_notes)
+ self._pending_scaffolded = True
+ else:
+ # Sequential fallback - should not happen in batch mode
+ conversation_with_notes = self._add_scaffolding_to_conversation(
+ self._conversation_history, agent_notes)
+ self._pending_scaffolded = True
else:
conversation_with_notes = self._conversation_history
@@ -357,30 +385,24 @@ class ReflectionAdapter:
return None # All retries failed, keep old notes
- def end_session(self, task_success: bool = False) -> Dict[str, Any]:
+ def end_session(self, task_success: bool = False, skip_note_update: bool = False) -> Dict[str, Any]:
"""
End session and update agent notes via reflection.
- Uses the ORIGINAL update_agent_notes logic from CollaborativeAgents.
+ Args:
+ task_success: Whether the task was completed successfully
+ skip_note_update: If True, skip note update (already done via batch)
"""
if not self._current_user_id:
return {}
- # Get current notes
- current_notes = self._user_notes.get(self._current_user_id, "No notes yet.")
-
- # Update notes via session-level reflection
- if len(self._conversation_history) > 0:
- result = self._update_agent_notes(current_notes, self._conversation_history)
-
+ # Update notes via session-level reflection (skip if batch already did it)
+ if not skip_note_update and len(self._conversation_history) > 0:
+ result = self._update_agent_notes(
+ self._user_notes.get(self._current_user_id, "No notes yet."),
+ self._conversation_history)
if result is not None and "agent_notes" in result:
- updated_notes = result["agent_notes"]
- self._user_notes[self._current_user_id] = updated_notes
- print(f"[ReflectionAdapter] Updated notes for {self._current_user_id} "
- f"({len(current_notes)} -> {len(updated_notes)} chars)")
- else:
- print(f"[ReflectionAdapter] Keeping old notes for {self._current_user_id} "
- f"(update failed)")
+ self._user_notes[self._current_user_id] = result["agent_notes"]
return {
"turns": len(self._conversation_history),