diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 20:16:36 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 20:16:36 +0000 |
| commit | 5626080ca4c4219aec4888d6b9406d0d3349fb55 (patch) | |
| tree | 86287d9fd5833e11ccd78566992540f2664fd195 /collaborativeagents/scripts/run_experiments.py | |
| parent | a2036838807428424bbbaff507a6563749a83145 (diff) | |
Add RAG rewrite, 60-session experiment scripts, and analysis tools
- RAG rewrite adapter and vector preference pipeline in personalized_llm
- 60-session experiment queue scripts (reflection, rag, rag_vector, rag_rewrite)
- Vector-preference correlation analysis and visualization scripts
- Local reward model batch processing improvements
- Updated CLAUDE.md with full experiment documentation and notes
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts/run_experiments.py')
| -rw-r--r-- | collaborativeagents/scripts/run_experiments.py | 317 |
1 file changed, 298 insertions, 19 deletions
diff --git a/collaborativeagents/scripts/run_experiments.py b/collaborativeagents/scripts/run_experiments.py index e04680c..da3549b 100644 --- a/collaborativeagents/scripts/run_experiments.py +++ b/collaborativeagents/scripts/run_experiments.py @@ -15,6 +15,7 @@ import json import yaml import os import sys +import numpy as np from pathlib import Path from datetime import datetime from typing import List, Dict, Any, Optional @@ -113,7 +114,13 @@ AVAILABLE_METHODS = { "reflection_grpo": "Reflection + GRPO training", "all_memory": "All extracted memories in context (no retrieval)", "rag": "Extractor + RAG (no user vector)", + "rag_dynamic": "Extractor + RAG with dynamic topk (min=3, max=8, ratio=0.5)", + "rag_rewrite": "Extractor + RAG with LLM preference rewrite/merge", + "rag_rewrite_vector": "Extractor + RAG + user vector + LLM preference rewrite", "rag_vector": "Extractor + RAG + user vector (proposed method)", + "rag_vector_fast": "Extractor + RAG + user vector with 10x learning rate", + "rag_vector_consolidate": "Extractor + RAG + user vector with session-level preference consolidation", + "rag_vector_balanced": "Extractor + RAG + user vector with balanced rewards (10x LR + positive signal for good turns)", "rag_bge": "Extractor + RAG with BGE reranker (278M)", "rag_vector_bge": "Extractor + RAG + user vector with BGE reranker (278M)", } @@ -256,6 +263,68 @@ class ExperimentRunner: # Profile will be passed to start_session() when the conversation begins return adapter + def _export_user_vectors(self, method: str, adapters: Dict[int, Any]) -> None: + """ + Export user vectors from all adapters to disk for later analysis. + + Saves both .npz (efficient numpy format) and .json (human-readable). 
+ + Args: + method: Method name for the output directory + adapters: Dict mapping profile_idx to adapter instances + """ + method_dir = self.output_dir / method + + # Collect all user vectors from adapters + all_vectors = {} + for profile_idx, adapter in adapters.items(): + if hasattr(adapter, 'export_all_user_vectors'): + vectors = adapter.export_all_user_vectors() + all_vectors.update(vectors) + + if not all_vectors: + logger.info(f" No user vectors to export for {method}") + return + + # Save as .npz for efficient analysis + npz_path = method_dir / "user_vectors.npz" + user_ids = list(all_vectors.keys()) + k = len(all_vectors[user_ids[0]]["z_long"]) + z_long = np.zeros((len(user_ids), k), dtype=np.float32) + z_short = np.zeros((len(user_ids), k), dtype=np.float32) + reward_ma = np.zeros(len(user_ids), dtype=np.float32) + + for i, uid in enumerate(user_ids): + z_long[i] = all_vectors[uid]["z_long"] + z_short[i] = all_vectors[uid]["z_short"] + reward_ma[i] = all_vectors[uid]["reward_ma"] + + np.savez( + npz_path, + user_ids=np.array(user_ids), + z_long=z_long, + z_short=z_short, + reward_ma=reward_ma, + ) + + # Also save summary stats as JSON + summary = { + "n_users": len(user_ids), + "vector_dim": k, + "z_long_norms": {uid: all_vectors[uid]["z_long_norm"] for uid in user_ids}, + "z_short_norms": {uid: all_vectors[uid]["z_short_norm"] for uid in user_ids}, + "reward_mas": {uid: all_vectors[uid]["reward_ma"] for uid in user_ids}, + "stats": { + "z_long_norm_mean": float(np.mean([all_vectors[uid]["z_long_norm"] for uid in user_ids])), + "z_long_norm_max": float(np.max([all_vectors[uid]["z_long_norm"] for uid in user_ids])), + "z_long_norm_std": float(np.std([all_vectors[uid]["z_long_norm"] for uid in user_ids])), + } + } + with open(method_dir / "user_vectors_summary.json", "w") as f: + json.dump(summary, f, indent=2) + + logger.info(f" Exported {len(user_ids)} user vectors to {npz_path}") + def run_single_session( self, method: str, @@ -297,11 +366,11 @@ class 
ExperimentRunner: # Structured preferences with condition/action pref_str = "\n".join([ f"- When {p.get('condition', '')}, {p.get('action', '')}" - for p in user_prefs[:10] # Top 10 preferences + for p in user_prefs ]) else: # Simple string preferences - pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]]) + pref_str = "\n".join([f"- {p}" for p in user_prefs]) else: pref_str = str(user_prefs) @@ -619,6 +688,9 @@ class ExperimentRunner: json.dump(results, f, indent=2) logger.info(f" Profile {profile_idx + 1} completed and checkpointed") + # Export user vectors at the end of sequential processing + self._export_user_vectors(method, {0: adapter}) + return results def _run_method_parallel( @@ -690,6 +762,10 @@ class ExperimentRunner: except Exception as e: logger.error(f" Profile {profile_idx} failed: {e}") + # Note: Parallel mode doesn't export user vectors because adapters are + # created/destroyed per profile. Use batch mode for vector export. + logger.info(f" Parallel mode: user vectors not exported (use batch mode)") + def _run_method_batch( self, method: str, @@ -724,7 +800,7 @@ class ExperimentRunner: else: user_client = BatchVLLMClient( vllm_url=self.config.vllm_user_url, - max_tokens=4096, + max_tokens=1024, # User responses typically short, but allow for edge cases temperature=1.0, timeout=None, max_concurrent=100, @@ -799,21 +875,34 @@ class ExperimentRunner: adapters = {} profile_sessions = {} + # Build session problem list ONCE (shared across all profiles for controlled comparison) + # Each dataset contributes exactly n_per_dataset problems (front 10), no repeats + shared_sessions = [] + dataset_names = list(self.datasets.keys()) + n_per_dataset = self.config.n_sessions_per_profile // len(dataset_names) + remainder = self.config.n_sessions_per_profile % len(dataset_names) + + for i, ds_name in enumerate(dataset_names): + ds_obj = self.datasets[ds_name] + items = ds_obj.get_testset() + n_take = n_per_dataset + (1 if i < remainder else 0) + if n_take > 
len(items): + logger.warning(f" Dataset {ds_name} has only {len(items)} problems, need {n_take}") + for j in range(n_take): + item = items[j % len(items)] + shared_sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain}) + + n_conflict = int(len(shared_sessions) * self.config.conflict_ratio) + shared_session_list = [(s, idx < n_conflict) for idx, s in enumerate(shared_sessions)] + logger.info(f" Built shared session list: {len(shared_sessions)} problems from {len(dataset_names)} datasets ({n_per_dataset} each, same for all profiles)") + for profile_idx in profiles_to_run: profile = self.profiles[profile_idx] adapter = self._create_method_adapter(method, profile, use_shared_models=True) if hasattr(adapter, 'initialize'): adapter.initialize() adapters[profile_idx] = adapter - - sessions = [] - for ds_name, ds_obj in self.datasets.items(): - ds_items = ds_obj.get_testset() - for item in ds_items[:self.config.n_sessions_per_profile]: - sessions.append({"problem": item.problem, "solution": item.solution, "domain": ds_obj.domain}) - sessions = sessions[:self.config.n_sessions_per_profile] - n_conflict = int(len(sessions) * self.config.conflict_ratio) - profile_sessions[profile_idx] = [(s, idx < n_conflict) for idx, s in enumerate(sessions)] + profile_sessions[profile_idx] = shared_session_list n_sessions = self.config.n_sessions_per_profile @@ -860,9 +949,9 @@ class ExperimentRunner: user_prefs = profile.get("preferences", []) if isinstance(user_prefs, list) and user_prefs: if isinstance(user_prefs[0], dict): - pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs[:10]]) + pref_str = "\n".join([f"- When {p.get('condition','')}, {p.get('action','')}" for p in user_prefs]) else: - pref_str = "\n".join([f"- {p}" for p in user_prefs[:10]]) + pref_str = "\n".join([f"- {p}" for p in user_prefs]) else: pref_str = str(user_prefs) @@ -916,21 +1005,105 @@ class ExperimentRunner: 
state["conversation"].append({"role": "user", "content": user_msg}) state["full_log"].append(parsed) - if parsed.get("enforce_preferences", False): + enforce = parsed.get("enforce_preferences", False) + if isinstance(enforce, str): + enforce = enforce.lower() == "true" + if enforce: state["enforcement_count"] += 1 + # Detect disappointment and satisfaction from user message + # Disappointment indicators (not quite right, could be better, etc.) + user_msg_lower = user_msg.lower() + disappointment = any(phrase in user_msg_lower for phrase in [ + "not quite", "not what i", "that's not", "incorrect", + "wrong", "mistake", "error", "confused", "doesn't make sense", + "try again", "not helpful", "not useful" + ]) + # Satisfaction indicators (explicit positive feedback) + satisfaction = parsed.get("should_terminate", False) or any(phrase in user_msg_lower for phrase in [ + "perfect", "exactly", "great", "thanks", "helpful", + "that's right", "correct", "good job", "well done", + "makes sense", "understand now", "got it" + ]) + + # Store parsed feedback for REINFORCE (applied AFTER prepare_prompt sets pending_rl_update) + state["_pending_feedback"] = { + "user_msg": user_msg, + "enforce": bool(enforce), + "disappointment": disappointment and not enforce, # Don't double-count + "satisfaction": satisfaction and not enforce, # Don't count if also enforcing + "draft_answer": bool(parsed.get("draft_answer")), + } + if parsed.get("should_terminate", False) or TERMINATION_SIGNAL in user_msg: to_remove.append(pidx) continue - # Prepare agent prompt for batching (don't call LLM yet) + # Batch preference extraction for PersonalizedLLM adapters + extraction_batch = [] # (pidx, query) + remaining_active = [pidx for pidx in active_list if pidx not in to_remove] + for pidx in remaining_active: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, '_llm') and hasattr(adapter._llm, 'enable_preference_extraction'): + if adapter._llm.enable_preference_extraction and 
adapter._llm._extractor is not None: + query = adapter._llm.get_last_user_query(adapter._current_user_id) if hasattr(adapter._llm, 'get_last_user_query') else None + if not query: + state = all_states[pidx] + query = state["conversation"][-1]["content"] if state["conversation"] else "" + if query: + extraction_batch.append((pidx, query)) + + if extraction_batch: + extractor = extraction_batch[0][1] # just need any adapter to get the extractor + adapter0 = adapters[extraction_batch[0][0]] + shared_extractor = adapter0._llm._extractor + if hasattr(shared_extractor, 'batch_extract_preferences'): + queries = [q for _, q in extraction_batch] + batch_results = shared_extractor.batch_extract_preferences(queries) + for (pidx, _), pref_dict in zip(extraction_batch, batch_results): + adapter = adapters[pidx] + adapter._llm.apply_extracted_preferences(adapter._current_user_id, pref_dict) + else: + # Fallback: sequential + for pidx, query in extraction_batch: + adapter = adapters[pidx] + adapter._llm._extractor.extract_turn(adapter._llm._sessions[adapter._current_user_id].session_state.history) + + # Batch scaffolding for reflection adapters before prepare_prompt + scaffolding_batch = [] # (pidx, prompt) + remaining_active = [pidx for pidx in active_list if pidx not in to_remove] + for pidx in remaining_active: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, 'get_scaffolding_prompt'): + state = all_states[pidx] + # Temporarily add user msg to history for scaffolding + agent_notes = adapter._user_notes.get(adapter._current_user_id, "No notes yet about this user.") + if adapter.with_scaffolding and agent_notes != "No notes yet about this user.": + prompt = adapter.get_scaffolding_prompt( + state["conversation"], agent_notes) + if prompt is not None: + scaffolding_batch.append((pidx, prompt)) + + if scaffolding_batch: + scaff_messages = [[{"role": "user", "content": p}] for _, p in scaffolding_batch] + scaff_responses = 
agent_client.batch_completion(scaff_messages) + for (pidx, _), resp in zip(scaffolding_batch, scaff_responses): + adapter = adapters[pidx] + adapter._scaffolding_result = resp if resp else None + + # Prepare agent prompts for batching + # NOTE: prepare_prompt calls chat_prepare which sets pending_rl_update + # from the previous turn's data. REINFORCE feedback must be applied + # AFTER this call so that pending_rl_update is available. + for pidx in remaining_active: + state = all_states[pidx] try: adapter = adapters[pidx] + user_msg = state["conversation"][-1]["content"] if hasattr(adapter, 'prepare_prompt'): messages, context = adapter.prepare_prompt(user_msg, state["conversation"][:-1]) agent_prompts_batch.append((pidx, messages, context)) elif hasattr(adapter, 'generate_response'): - # Fallback for adapters without prepare_prompt agent_prompts_batch.append((pidx, None, None)) else: state["conversation"].append({"role": "assistant", "content": "[Error: Adapter not configured]"}) @@ -938,6 +1111,53 @@ class ExperimentRunner: logger.error(f" Agent prepare error p{pidx} t{turn}: {e}") state["conversation"].append({"role": "assistant", "content": "I apologize, I encountered an error. 
Could you rephrase?"}) + # Apply REINFORCE feedback NOW (after prepare_prompt set pending_rl_update) + for pidx in remaining_active: + state = all_states[pidx] + fb = state.pop("_pending_feedback", None) + if fb: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, 'process_user_turn'): + adapter.process_user_turn( + user_response=fb["user_msg"], + enforce_preferences=fb["enforce"], + express_disappointment=fb.get("disappointment", False), + express_satisfaction=fb["satisfaction"], + draft_answer_updated=fb["draft_answer"], + ) + + # Also apply feedback for terminated sessions (they skipped prepare_prompt + # but still need the reward signal from their last turn) + for pidx in to_remove: + state = all_states.get(pidx) + if not state: + continue + fb = state.pop("_pending_feedback", None) + if fb: + adapter = adapters.get(pidx) + if adapter and hasattr(adapter, 'process_user_turn'): + # For terminated sessions, we can't call prepare_prompt + # (no next turn), but we still want the reward applied. + # Call chat_prepare with a dummy to set pending_rl_update, + # then apply feedback. 
+ try: + if hasattr(adapter, '_llm') and hasattr(adapter._llm, 'chat_prepare'): + adapter._llm.chat_prepare( + adapter._current_user_id, + fb["user_msg"], + skip_extraction=True, + skip_auto_reward=True, + ) + adapter.process_user_turn( + user_response=fb["user_msg"], + enforce_preferences=fb["enforce"], + express_disappointment=fb.get("disappointment", False), + express_satisfaction=fb["satisfaction"], + draft_answer_updated=fb["draft_answer"], + ) + except Exception: + pass # Best effort for terminated sessions + # Batch vLLM call for all agent prompts if agent_prompts_batch: # Separate prompts that can be batched from fallback @@ -979,6 +1199,25 @@ class ExperimentRunner: active_set -= set(to_remove) + # Batch note-update for reflection adapters before end_session + note_update_batch = [] # (profile_idx, messages) + for profile_idx in profiles_to_run: + if profile_idx not in all_states: + continue + adapter = adapters.get(profile_idx) + if adapter and hasattr(adapter, 'get_note_update_prompt'): + prompt_msgs = adapter.get_note_update_prompt() + if prompt_msgs is not None: + note_update_batch.append((profile_idx, prompt_msgs)) + + if note_update_batch: + note_messages = [msgs for _, msgs in note_update_batch] + note_responses = agent_client.batch_completion(note_messages) + for (profile_idx, _), resp in zip(note_update_batch, note_responses): + if resp: + adapter = adapters[profile_idx] + adapter.apply_note_update_response(resp) + # Save results for this session round for profile_idx in profiles_to_run: if profile_idx not in all_states: @@ -995,10 +1234,20 @@ class ExperimentRunner: task_success = 0 for entry in full_log: if entry.get("should_terminate", False): - draft = entry.get("draft_answer", "") - if draft and "don't know" not in draft.lower() and len(draft) > 20: + draft = str(entry.get("draft_answer", "")) + if draft and "don't know" not in draft.lower(): task_success = 1 + # End session on adapter (applies task completion reward for REINFORCE) + adapter 
= adapters.get(profile_idx) + if adapter and hasattr(adapter, 'end_session'): + # Skip note update if batch already handled it + skip_notes = hasattr(adapter, 'get_note_update_prompt') + try: + adapter.end_session(task_success=bool(task_success), skip_note_update=skip_notes) + except TypeError: + adapter.end_session(task_success=bool(task_success)) + results.append({ "method": method, "profile_id": self.profiles[profile_idx].get("user_id", f"user_{profile_idx}"), @@ -1023,6 +1272,33 @@ class ExperimentRunner: "adapter_metrics": {}, }) + # Collect adapter metrics (e.g. user_vector_norm for rag_vector) + adapter = adapters.get(profile_idx) + if adapter and hasattr(adapter, 'get_user_vector'): + user_id = self.profiles[profile_idx].get("user_id", f"user_{profile_idx}") + vec = adapter.get_user_vector(user_id) + if vec is not None: + results[-1]["adapter_metrics"] = { + "user_vector_norm": float(np.linalg.norm(vec)), + } + + # Save user vector snapshots every 10 sessions + if (session_idx + 1) % 10 == 0: + vectors_dir = checkpoint_file.parent / "vectors" + vectors_dir.mkdir(parents=True, exist_ok=True) + user_vectors = {} + for profile_idx in profiles_to_run: + adapter = adapters.get(profile_idx) + if adapter and hasattr(adapter, 'get_user_vector'): + user_id = self.profiles[profile_idx].get("user_id", f"user_{profile_idx}") + vec = adapter.get_user_vector(user_id) + if vec is not None: + user_vectors[user_id] = vec + if user_vectors: + snapshot_path = vectors_dir / f"vectors_session_{session_idx+1}.npy" + np.save(snapshot_path, user_vectors) + logger.info(f" Saved {len(user_vectors)} user vectors to {snapshot_path}") + # Checkpoint after each session round with session-level tracking # Only increment for profiles that actually ran in this round (those in all_states) for profile_idx in all_states.keys(): @@ -1043,6 +1319,9 @@ class ExperimentRunner: rate = sessions_done / elapsed * 3600 if elapsed > 0 else 0 logger.info(f" Session round {session_idx+1}/{n_sessions}: 
{sessions_done} total, {rate:.0f} sessions/hr") + # Export user vectors before cleanup (for RAG methods with user vectors) + self._export_user_vectors(method, adapters) + # Explicitly free adapter models to prevent GPU OOM across methods for pidx, adapter in adapters.items(): if hasattr(adapter, 'cleanup'): |
