summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/run_experiments.py
diff options
context:
space:
mode:
author: YurenHao0426 <blackhao0426@gmail.com> 2026-02-10 20:16:36 +0000
committer: YurenHao0426 <blackhao0426@gmail.com> 2026-02-10 20:16:36 +0000
commit5626080ca4c4219aec4888d6b9406d0d3349fb55 (patch)
tree86287d9fd5833e11ccd78566992540f2664fd195 /collaborativeagents/scripts/run_experiments.py
parenta2036838807428424bbbaff507a6563749a83145 (diff)
Add RAG rewrite, 60-session experiment scripts, and analysis tools
- RAG rewrite adapter and vector preference pipeline in personalized_llm
- 60-session experiment queue scripts (reflection, rag, rag_vector, rag_rewrite)
- Vector-preference correlation analysis and visualization scripts
- Local reward model batch processing improvements
- Updated CLAUDE.md with full experiment documentation and notes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts/run_experiments.py')
-rw-r--r--collaborativeagents/scripts/run_experiments.py317
1 files changed, 298 insertions, 19 deletions
diff --git a/collaborativeagents/scripts/run_experiments.py b/collaborativeagents/scripts/run_experiments.py
index e04680c..da3549b 100644
--- a/collaborativeagents/scripts/run_experiments.py
+++ b/collaborativeagents/scripts/run_experiments.py
@@ -15,6 +15,7 @@ import json
import yaml
import os
import sys
+import numpy as np
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional
@@ -113,7 +114,13 @@ AVAILABLE_METHODS = {
"reflection_grpo": "Reflection + GRPO training",
"all_memory": "All extracted memories in context (no retrieval)",
"rag": "Extractor + RAG (no user vector)",
+ "rag_dynamic": "Extractor + RAG with dynamic topk (min=3, max=8, ratio=0.5)",
+ "rag_rewrite": "Extractor + RAG with LLM preference rewrite/merge",
+ "rag_rewrite_vector": "Extractor + RAG + user vector + LLM preference rewrite",
"rag_vector": "Extractor + RAG + user vector (proposed method)",
+ "rag_vector_fast": "Extractor + RAG + user vector with 10x learning rate",
+ "rag_vector_consolidate": "Extractor + RAG + user vector with session-level preference consolidation",
+ "rag_vector_balanced": "Extractor + RAG + user vector with balanced rewards (10x LR + positive signal for good turns)",
"rag_bge": "Extractor + RAG with BGE reranker (278M)",
"rag_vector_bge": "Extractor + RAG + user vector with BGE reranker (278M)",
}
@@ -256,6 +263,68 @@ class ExperimentRunner:
# Profile will be passed to start_session() when the conversation begins
return adapter
+ def _export_user_vectors(self, method: str, adapters: Dict[int, Any]) -> None:
+ """
+ Export user vectors from all adapters to disk for later analysis.
+
+ Saves both .npz (efficient numpy format) and .json (human-readable).
+
+ Args:
+ method: Method name for the output directory
+ adapters: Dict mapping profile_idx to adapter instances
+ """
+ method_dir = self.output_dir / method
+
+ # Collect all user vectors from adapters
+ all_vectors = {}
+ for profile_idx, adapter in adapters.items():
+ if hasattr(adapter, 'export_all_user_vectors'):
+ vectors = adapter.export_all_user_vectors()
+ all_vectors.update(vectors)
+
+ if not all_vectors:
+ logger.info(f" No user vectors to export for {method}")
+ return
+
+ # Save as .npz for efficient analysis
+ npz_path = method_dir / "user_vectors.npz"
+ user_ids = list(all_vectors.keys())
+ k = len(all_vectors[user_ids[0]]["z_long"])
+ z_long = np.zeros((len(user_ids), k), dtype=np.float32)
+ z_short = np.zeros((len(user_ids), k), dtype=np.float32)
+ reward_ma = np.zeros(len(user_ids), dtype=np.float32)
+
+ for i, uid in enumerate(user_ids):
+ z_long[i] = all_vectors[uid]["z_long"]
+ z_short[i] = all_vectors[uid]["z_short"]
+ reward_ma[i] = all_vectors[uid]["reward_ma"]
+
+ np.savez(
+ npz_path,
+ user_ids=np.array(user_ids),
+ z_long=z_long,
+ z_short=z_short,
+ reward_ma=reward_ma,
+ )
+
+ # Also save summary stats as JSON
+ summary = {
+ "n_users": len(user_ids),
+ "vector_dim": k,
+ "z_long_norms": {uid: all_vectors[uid]["z_long_norm"] for uid in user_ids},
+ "z_short_norms": {uid: all_vectors[uid]["z_short_norm"] for uid in user_ids},
+ "reward_mas": {uid: all_vectors[uid]["reward_ma"] for uid in user_ids},
+ "stats": {
+ "z_long_norm_mean": float(np.mean([all_vectors[uid]["z_long_norm"] for uid in user_ids])),
+ "z_long_norm_max": float(np.max([all_vectors[uid]["z_long_norm"] for uid in user_ids])),
+ "z_long_norm_std": float(np.std([all_vectors[uid]["z_long_norm"] for uid in user_ids])),
+ }
+ }
+ with open(method_dir / "user_vectors_summary.json", "w") as f:
+ json.dump(summary, f, indent=2)
+
+ logger.info(f" Exported {len(user_ids)} user vectors to {npz_path}")
+
def run_single_session(
self,
method: str,
@@ -297,11 +366,11 @@ class ExperimentRunner:
# Structured preferences with condition/action
pref_str = "\n".join([
f"- When {p.get('condition', '')}, {p.get('action', '')}"
- for p in user_prefs[:10] # Top 10 preferences
+ for p in user_prefs
])
else:
# Simple string preferences
- pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]])
+ pref_str = "\n".join([f"- {p}" for p in user_prefs])
else:
pref_str = str(user_prefs)
@@ -619,6 +688,9 @@ class ExperimentRunner:
json.dump(results, f, indent=2)
logger.info(f" Profile {profile_idx + 1} completed and checkpointed")
+ # Export user vectors at the end of sequential processing
+ self._export_user_vectors(method, {0: adapter})
+
return results
def _run_method_parallel(
@@ -690,6 +762,10 @@ class ExperimentRunner:
except Exception as e:
logger.error(f" Profile {profile_idx} failed: {e}")
+ # Note: Parallel mode doesn't export user vectors because adapters are
+ # created/destroyed per profile. Use batch mode for vector export.
+ logger.info(f" Parallel mode: user vectors not exported (use batch mode)")
+
def _run_method_batch(
self,
method: str,
@@ -724,7 +800,7 @@ class ExperimentRunner:
else:
user_client = BatchVLLMClient(
vllm_url=self.config.vllm_user_url,
- max_tokens=4096,
+ max_tokens=1024, # User responses typically short, but allow for edge cases
temperature=1.0,
timeout=None,
max_concurrent=100,
@@ -799,21 +875,34 @@ class ExperimentRunner:
adapters = {}
profile_sessions = {}
+ # Build session problem list ONCE (shared across all profiles for controlled comparison)
+ # Each dataset contributes exactly n_per_dataset problems (front 10), no repeats
+ shared_sessions = []
+ dataset_names = list(self.datasets.keys())
+ n_per_dataset = self.config.n_sessions_per_profile // len(dataset_names)
+ remainder = self.config.n_sessions_per_profile % len(dataset_names)
+
+ for i, ds_name in enumerate(dataset_names):
+ ds_obj = self.datasets[ds_name]
+ items = ds_obj.get_testset()
+ n_take = n_per_dataset + (1 if i < remainder else 0)
+ if n_take > len(items):
+ logger.warning(f" Dataset {ds_name} has only {len(items)} problems, need {n_take}")
+ for j in range(n_take):
+ item = items[j % len(items)]
+ shared_sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain})
+
+ n_conflict = int(len(shared_sessions) * self.config.conflict_ratio)
+ shared_session_list = [(s, idx < n_conflict) for idx, s in enumerate(shared_sessions)]
+ logger.info(f" Built shared session list: {len(shared_sessions)} problems from {len(dataset_names)} datasets ({n_per_dataset} each, same for all profiles)")
+
for profile_idx in profiles_to_run:
profile = self.profiles[profile_idx]
adapter = self._create_method_adapter(method, profile, use_shared_models=True)
if hasattr(adapter, 'initialize'):
adapter.initialize()
adapters[profile_idx] = adapter
-
- sessions = []
- for ds_name, ds_obj in self.datasets.items():
- ds_items = ds_obj.get_testset()
- for item in ds_items[:self.config.n_sessions_per_profile]:
- sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain})
- sessions = sessions[:self.config.n_sessions_per_profile]
- n_conflict = int(len(sessions) * self.config.conflict_ratio)
- profile_sessions[profile_idx] = [(s, idx < n_conflict) for idx, s in enumerate(sessions)]
+ profile_sessions[profile_idx] = shared_session_list
n_sessions = self.config.n_sessions_per_profile
@@ -860,9 +949,9 @@ class ExperimentRunner:
user_prefs = profile.get("preferences", [])
if isinstance(user_prefs, list) and user_prefs:
if isinstance(user_prefs[0], dict):
- pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs[:10]])
+ pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs])
else:
- pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]])
+ pref_str = "\n".join([f"- {p}" for p in user_prefs])
else:
pref_str = str(user_prefs)
@@ -916,21 +1005,105 @@ class ExperimentRunner:
state["conversation"].append({"role": "user", "content": user_msg})
state["full_log"].append(parsed)
- if parsed.get("enforce_preferences", False):
+ enforce = parsed.get("enforce_preferences", False)
+ if isinstance(enforce, str):
+ enforce = enforce.lower() == "true"
+ if enforce:
state["enforcement_count"] += 1
+ # Detect disappointment and satisfaction from user message
+ # Disappointment indicators (not quite right, could be better, etc.)
+ user_msg_lower = user_msg.lower()
+ disappointment = any(phrase in user_msg_lower for phrase in [
+ "not quite", "not what i", "that's not", "incorrect",
+ "wrong", "mistake", "error", "confused", "doesn't make sense",
+ "try again", "not helpful", "not useful"
+ ])
+ # Satisfaction indicators (explicit positive feedback)
+ satisfaction = parsed.get("should_terminate", False) or any(phrase in user_msg_lower for phrase in [
+ "perfect", "exactly", "great", "thanks", "helpful",
+ "that's right", "correct", "good job", "well done",
+ "makes sense", "understand now", "got it"
+ ])
+
+ # Store parsed feedback for REINFORCE (applied AFTER prepare_prompt sets pending_rl_update)
+ state["_pending_feedback"] = {
+ "user_msg": user_msg,
+ "enforce": bool(enforce),
+ "disappointment": disappointment and not enforce, # Don't double-count
+ "satisfaction": satisfaction and not enforce, # Don't count if also enforcing
+ "draft_answer": bool(parsed.get("draft_answer")),
+ }
+
if parsed.get("should_terminate", False) or TERMINATION_SIGNAL in user_msg:
to_remove.append(pidx)
continue
- # Prepare agent prompt for batching (don't call LLM yet)
+ # Batch preference extraction for PersonalizedLLM adapters
+ extraction_batch = [] # (pidx, query)
+ remaining_active = [pidx for pidx in active_list if pidx not in to_remove]
+ for pidx in remaining_active:
+ adapter = adapters.get(pidx)
+ if adapter and hasattr(adapter, '_llm') and hasattr(adapter._llm, 'enable_preference_extraction'):
+ if adapter._llm.enable_preference_extraction and adapter._llm._extractor is not None:
+ query = adapter._llm.get_last_user_query(adapter._current_user_id) if hasattr(adapter._llm, 'get_last_user_query') else None
+ if not query:
+ state = all_states[pidx]
+ query = state["conversation"][-1]["content"] if state["conversation"] else ""
+ if query:
+ extraction_batch.append((pidx, query))
+
+ if extraction_batch:
+ extractor = extraction_batch[0][1] # just need any adapter to get the extractor
+ adapter0 = adapters[extraction_batch[0][0]]
+ shared_extractor = adapter0._llm._extractor
+ if hasattr(shared_extractor, 'batch_extract_preferences'):
+ queries = [q for _, q in extraction_batch]
+ batch_results = shared_extractor.batch_extract_preferences(queries)
+ for (pidx, _), pref_dict in zip(extraction_batch, batch_results):
+ adapter = adapters[pidx]
+ adapter._llm.apply_extracted_preferences(adapter._current_user_id, pref_dict)
+ else:
+ # Fallback: sequential
+ for pidx, query in extraction_batch:
+ adapter = adapters[pidx]
+ adapter._llm._extractor.extract_turn(adapter._llm._sessions[adapter._current_user_id].session_state.history)
+
+ # Batch scaffolding for reflection adapters before prepare_prompt
+ scaffolding_batch = [] # (pidx, prompt)
+ remaining_active = [pidx for pidx in active_list if pidx not in to_remove]
+ for pidx in remaining_active:
+ adapter = adapters.get(pidx)
+ if adapter and hasattr(adapter, 'get_scaffolding_prompt'):
+ state = all_states[pidx]
+ # Temporarily add user msg to history for scaffolding
+ agent_notes = adapter._user_notes.get(adapter._current_user_id, "No notes yet about this user.")
+ if adapter.with_scaffolding and agent_notes != "No notes yet about this user.":
+ prompt = adapter.get_scaffolding_prompt(
+ state["conversation"], agent_notes)
+ if prompt is not None:
+ scaffolding_batch.append((pidx, prompt))
+
+ if scaffolding_batch:
+ scaff_messages = [[{"role": "user", "content": p}] for _, p in scaffolding_batch]
+ scaff_responses = agent_client.batch_completion(scaff_messages)
+ for (pidx, _), resp in zip(scaffolding_batch, scaff_responses):
+ adapter = adapters[pidx]
+ adapter._scaffolding_result = resp if resp else None
+
+ # Prepare agent prompts for batching
+ # NOTE: prepare_prompt calls chat_prepare which sets pending_rl_update
+ # from the previous turn's data. REINFORCE feedback must be applied
+ # AFTER this call so that pending_rl_update is available.
+ for pidx in remaining_active:
+ state = all_states[pidx]
try:
adapter = adapters[pidx]
+ user_msg = state["conversation"][-1]["content"]
if hasattr(adapter, 'prepare_prompt'):
messages, context = adapter.prepare_prompt(user_msg, state["conversation"][:-1])
agent_prompts_batch.append((pidx, messages, context))
elif hasattr(adapter, 'generate_response'):
- # Fallback for adapters without prepare_prompt
agent_prompts_batch.append((pidx, None, None))
else:
state["conversation"].append({"role": "assistant", "content": "[Error: Adapter not configured]"})
@@ -938,6 +1111,53 @@ class ExperimentRunner:
logger.error(f" Agent prepare error p{pidx} t{turn}: {e}")
state["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. Could you rephrase?"})
+ # Apply REINFORCE feedback NOW (after prepare_prompt set pending_rl_update)
+ for pidx in remaining_active:
+ state = all_states[pidx]
+ fb = state.pop("_pending_feedback", None)
+ if fb:
+ adapter = adapters.get(pidx)
+ if adapter and hasattr(adapter, 'process_user_turn'):
+ adapter.process_user_turn(
+ user_response=fb["user_msg"],
+ enforce_preferences=fb["enforce"],
+ express_disappointment=fb.get("disappointment", False),
+ express_satisfaction=fb["satisfaction"],
+ draft_answer_updated=fb["draft_answer"],
+ )
+
+ # Also apply feedback for terminated sessions (they skipped prepare_prompt
+ # but still need the reward signal from their last turn)
+ for pidx in to_remove:
+ state = all_states.get(pidx)
+ if not state:
+ continue
+ fb = state.pop("_pending_feedback", None)
+ if fb:
+ adapter = adapters.get(pidx)
+ if adapter and hasattr(adapter, 'process_user_turn'):
+ # For terminated sessions, we can't call prepare_prompt
+ # (no next turn), but we still want the reward applied.
+ # Call chat_prepare with a dummy to set pending_rl_update,
+ # then apply feedback.
+ try:
+ if hasattr(adapter, '_llm') and hasattr(adapter._llm, 'chat_prepare'):
+ adapter._llm.chat_prepare(
+ adapter._current_user_id,
+ fb["user_msg"],
+ skip_extraction=True,
+ skip_auto_reward=True,
+ )
+ adapter.process_user_turn(
+ user_response=fb["user_msg"],
+ enforce_preferences=fb["enforce"],
+ express_disappointment=fb.get("disappointment", False),
+ express_satisfaction=fb["satisfaction"],
+ draft_answer_updated=fb["draft_answer"],
+ )
+ except Exception:
+ pass # Best effort for terminated sessions
+
# Batch vLLM call for all agent prompts
if agent_prompts_batch:
# Separate prompts that can be batched from fallback
@@ -979,6 +1199,25 @@ class ExperimentRunner:
active_set -= set(to_remove)
+ # Batch note-update for reflection adapters before end_session
+ note_update_batch = [] # (profile_idx, messages)
+ for profile_idx in profiles_to_run:
+ if profile_idx not in all_states:
+ continue
+ adapter = adapters.get(profile_idx)
+ if adapter and hasattr(adapter, 'get_note_update_prompt'):
+ prompt_msgs = adapter.get_note_update_prompt()
+ if prompt_msgs is not None:
+ note_update_batch.append((profile_idx, prompt_msgs))
+
+ if note_update_batch:
+ note_messages = [msgs for _, msgs in note_update_batch]
+ note_responses = agent_client.batch_completion(note_messages)
+ for (profile_idx, _), resp in zip(note_update_batch, note_responses):
+ if resp:
+ adapter = adapters[profile_idx]
+ adapter.apply_note_update_response(resp)
+
# Save results for this session round
for profile_idx in profiles_to_run:
if profile_idx not in all_states:
@@ -995,10 +1234,20 @@ class ExperimentRunner:
task_success = 0
for entry in full_log:
if entry.get("should_terminate", False):
- draft = entry.get("draft_answer", "")
- if draft and "don't know" not in draft.lower() and len(draft) > 20:
+ draft = str(entry.get("draft_answer", ""))
+ if draft and "don't know" not in draft.lower():
task_success = 1
+ # End session on adapter (applies task completion reward for REINFORCE)
+ adapter = adapters.get(profile_idx)
+ if adapter and hasattr(adapter, 'end_session'):
+ # Skip note update if batch already handled it
+ skip_notes = hasattr(adapter, 'get_note_update_prompt')
+ try:
+ adapter.end_session(task_success=bool(task_success), skip_note_update=skip_notes)
+ except TypeError:
+ adapter.end_session(task_success=bool(task_success))
+
results.append({
"method": method,
"profile_id": self.profiles[profile_idx].get("user_id", f"user_{profile_idx}"),
@@ -1023,6 +1272,33 @@ class ExperimentRunner:
"adapter_metrics": {},
})
+ # Collect adapter metrics (e.g. user_vector_norm for rag_vector)
+ adapter = adapters.get(profile_idx)
+ if adapter and hasattr(adapter, 'get_user_vector'):
+ user_id = self.profiles[profile_idx].get("user_id", f"user_{profile_idx}")
+ vec = adapter.get_user_vector(user_id)
+ if vec is not None:
+ results[-1]["adapter_metrics"] = {
+ "user_vector_norm": float(np.linalg.norm(vec)),
+ }
+
+ # Save user vector snapshots every 10 sessions
+ if (session_idx + 1) % 10 == 0:
+ vectors_dir = checkpoint_file.parent / "vectors"
+ vectors_dir.mkdir(parents=True, exist_ok=True)
+ user_vectors = {}
+ for profile_idx in profiles_to_run:
+ adapter = adapters.get(profile_idx)
+ if adapter and hasattr(adapter, 'get_user_vector'):
+ user_id = self.profiles[profile_idx].get("user_id", f"user_{profile_idx}")
+ vec = adapter.get_user_vector(user_id)
+ if vec is not None:
+ user_vectors[user_id] = vec
+ if user_vectors:
+ snapshot_path = vectors_dir / f"vectors_session_{session_idx+1}.npy"
+ np.save(snapshot_path, user_vectors)
+ logger.info(f" Saved {len(user_vectors)} user vectors to {snapshot_path}")
+
# Checkpoint after each session round with session-level tracking
# Only increment for profiles that actually ran in this round (those in all_states)
for profile_idx in all_states.keys():
@@ -1043,6 +1319,9 @@ class ExperimentRunner:
rate = sessions_done / elapsed * 3600 if elapsed > 0 else 0
logger.info(f" Session round {session_idx+1}/{n_sessions}: {sessions_done} total, {rate:.0f} sessions/hr")
+ # Export user vectors before cleanup (for RAG methods with user vectors)
+ self._export_user_vectors(method, adapters)
+
# Explicitly free adapter models to prevent GPU OOM across methods
for pidx, adapter in adapters.items():
if hasattr(adapter, 'cleanup'):