"""
Run baseline comparison experiments for personalization methods.

Baselines:
1. Vanilla - No memory
2. Contextual Memory - Full history in context (summarize if exceeds limit)
3. Reflection Memory - CollaborativeAgents' agent_notes approach
4. Reflection + GRPO - Trained version of reflection
5. All Memory Cards in Context - Extract all, no retrieval
6. Extractor + RAG - Retrieval without user vector
7. Extractor + RAG + User Vector - Full personalization

Metrics:
- Task Accuracy
- User Effort (user token count)
- Total Efficiency (all tokens)
- Conflict Resolution Accuracy (new)
- User Vector Similarity to Ground Truth (new)
"""

import json
import time
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional, Callable
from abc import ABC, abstractmethod
import numpy as np

# ============================================================================
# Metrics
# ============================================================================

@dataclass
class ConversationMetrics:
    """Metrics for a single conversation."""
    task_accuracy: float  # 0 or 1 for correct answer
    user_tokens: int  # Total tokens from user messages
    assistant_tokens: int  # Total tokens from assistant messages
    total_tokens: int  # All tokens
    num_turns: int  # Number of conversation turns
    num_preference_enforcements: int  # How many times user enforced preferences
    conflict_resolution_correct: Optional[bool] = None  # If conflict test, was it resolved correctly?
    latency_seconds: float = 0.0

    @property
    def user_effort(self) -> int:
        """User effort = user tokens (lower is better)."""
        return self.user_tokens

    @property
    def efficiency(self) -> float:
        """Efficiency = accuracy / total_tokens * 1000 (higher is better)."""
        # Guard against division by zero for empty conversations.
        if self.total_tokens == 0:
            return 0.0
        return self.task_accuracy / self.total_tokens * 1000


@dataclass
class ExperimentResults:
    """Aggregated per-conversation metrics for one baseline."""
    baseline_name: str
    num_conversations: int
    # Maps metric name -> list of per-conversation values.
    metrics: dict = field(default_factory=dict)

    def add_conversation(self, conv_metrics: ConversationMetrics):
        """Append one conversation's metric values to the running lists."""
        for key in ['task_accuracy', 'user_tokens', 'assistant_tokens',
                    'total_tokens', 'num_turns', 'num_preference_enforcements']:
            self.metrics.setdefault(key, []).append(getattr(conv_metrics, key))

        # Conflict-resolution accuracy is only recorded for conflict tests,
        # where the flag is explicitly set to True/False.
        if conv_metrics.conflict_resolution_correct is not None:
            self.metrics.setdefault('conflict_resolution_correct', []).append(
                1.0 if conv_metrics.conflict_resolution_correct else 0.0
            )

    def summary(self) -> dict:
        """Compute mean/std summary statistics for every collected metric."""
        summary = {"baseline": self.baseline_name, "n": self.num_conversations}
        for key, values in self.metrics.items():
            if values:
                summary[f"{key}_mean"] = np.mean(values)
                summary[f"{key}_std"] = np.std(values)
        return summary


# ============================================================================
# Baseline Implementations (Abstract)
# ============================================================================

class BaselineMethod(ABC):
    """Abstract base class for all baseline methods.

    Subclasses implement session setup, response generation, memory updates,
    and prompt-context construction for one personalization strategy.
    """

    def __init__(self, name: str, config: Optional[dict] = None):
        self.name = name
        # None is accepted for convenience and normalized to an empty dict.
        self.config = config or {}

    @abstractmethod
    def initialize_session(self, user_id: str, user_profile: dict):
        """Initialize a new session for a user."""
        pass

    @abstractmethod
    def generate_response(self, query: str, conversation_history: list) -> str:
        """Generate a response given query and history."""
        pass

    @abstractmethod
    def update_memory(self, conversation: list, feedback: Optional[dict] = None):
        """Update memory after a conversation or turn."""
        pass

    @abstractmethod
    def get_context_for_prompt(self) -> str:
        """Get the memory/context to include in prompts."""
        pass

    def count_tokens(self, text: str) -> int:
        """Estimate token count (~1.3 tokens per whitespace-separated word).

        Fix: the original returned ``len(text.split()) * 1.3`` -- a float --
        despite the declared ``int`` return type, so token totals and the
        context-limit comparison silently operated on floats. Truncate to int.
        """
        return int(len(text.split()) * 1.3)  # Rough estimate


class VanillaBaseline(BaselineMethod):
    """No memory - fresh context each time."""

    def __init__(self):
        super().__init__("vanilla")

    def initialize_session(self, user_id: str, user_profile: dict):
        self.user_id = user_id
        # No memory initialization needed

    def generate_response(self, query: str, conversation_history: list) -> str:
        # Would call LLM here (not implemented in this scaffold).
        pass

    def update_memory(self, conversation: list, feedback: Optional[dict] = None):
        # No memory to update
        pass

    def get_context_for_prompt(self) -> str:
        return ""  # No additional context


class ContextualMemoryBaseline(BaselineMethod):
    """
    Full conversation history in context.
    Summarize when exceeds context limit.
    """

    def __init__(self, max_context_tokens: int = 32000):
        super().__init__("contextual_memory")
        self.max_context_tokens = max_context_tokens
        self.full_history = []  # List of {'role', 'content'} messages.
        self.summarized_history = ""  # LLM summary of evicted history.

    def initialize_session(self, user_id: str, user_profile: dict):
        self.user_id = user_id
        # Keep accumulated history across sessions

    def generate_response(self, query: str, conversation_history: list) -> str:
        # Would call LLM here (not implemented in this scaffold).
        pass

    def update_memory(self, conversation: list, feedback: Optional[dict] = None):
        """Append the conversation and summarize if the context limit is exceeded."""
        self.full_history.extend(conversation)

        # Check if we need to summarize
        total_tokens = sum(self.count_tokens(msg['content']) for msg in self.full_history)
        if total_tokens > self.max_context_tokens:
            self._summarize_old_history()

    def _summarize_old_history(self):
        """Summarize older parts of history to fit context.

        This is where information loss happens!
        """
        # Keep recent conversations, summarize older ones
        keep_recent = 10  # Keep last 10 turns verbatim
        to_summarize = self.full_history[:-keep_recent]
        recent = self.full_history[-keep_recent:]

        # NOTE(review): summarization is a placeholder -- `to_summarize` is
        # currently dropped without producing a summary.
        # self.summarized_history = summarize_with_llm(to_summarize)
        self.full_history = recent

    def get_context_for_prompt(self) -> str:
        """Build the prompt context: optional summary plus recent messages."""
        context = ""
        if self.summarized_history:
            context += f"Previous conversation summary:\n{self.summarized_history}\n\n"
        context += "Recent conversation:\n"
        for msg in self.full_history[-20:]:  # Last 20 messages
            context += f"{msg['role']}: {msg['content']}\n"
        return context


class ReflectionMemoryBaseline(BaselineMethod):
    """
    CollaborativeAgents' approach: maintain agent_notes that are
    updated after each conversation via reflection.
    """

    def __init__(self):
        super().__init__("reflection_memory")
        self.agent_notes = {}  # user_id -> free-text notes string.

    def initialize_session(self, user_id: str, user_profile: dict):
        self.user_id = user_id
        if user_id not in self.agent_notes:
            self.agent_notes[user_id] = ""

    def generate_response(self, query: str, conversation_history: list) -> str:
        # Would call LLM here (not implemented in this scaffold).
        pass

    def update_memory(self, conversation: list, feedback: Optional[dict] = None):
        # After conversation, reflect and update notes
        # This is their update_agent_notes_prompt approach
        pass

    def get_context_for_prompt(self) -> str:
        return f"Notes about this user:\n{self.agent_notes.get(self.user_id, '')}"


class AllMemoryCardsBaseline(BaselineMethod):
    """
    Extract preferences into memory cards, but put ALL in context.
    No retrieval - just dump everything.
    """

    def __init__(self, max_cards_in_context: int = 100):
        super().__init__("all_memory_cards")
        self.max_cards = max_cards_in_context
        self.memory_cards = {}  # user_id -> list of cards

    def initialize_session(self, user_id: str, user_profile: dict):
        self.user_id = user_id
        if user_id not in self.memory_cards:
            self.memory_cards[user_id] = []

    def generate_response(self, query: str, conversation_history: list) -> str:
        # Would call LLM here (not implemented in this scaffold).
        pass

    def update_memory(self, conversation: list, feedback: Optional[dict] = None):
        # Extract preferences from conversation and add to cards
        # Would use preference_extractor here
        pass

    def get_context_for_prompt(self) -> str:
        """List every known card, capped at `max_cards`."""
        cards = self.memory_cards.get(self.user_id, [])
        if not cards:
            return ""

        # Just dump all cards - this is the weakness!
        context = "User preferences (all known):\n"
        for i, card in enumerate(cards[:self.max_cards]):
            context += f"{i+1}. When {card['condition']}: {card['action']}\n"
        return context


class ExtractorRAGBaseline(BaselineMethod):
    """
    Extract preferences + RAG retrieval.
    No user vector - just relevance-based retrieval.
    """

    def __init__(self, top_k: int = 5):
        super().__init__("extractor_rag")
        self.top_k = top_k
        self.memory_store = None  # Would be vector store

    def initialize_session(self, user_id: str, user_profile: dict):
        self.user_id = user_id

    def generate_response(self, query: str, conversation_history: list) -> str:
        # Would call LLM here (not implemented in this scaffold).
        pass

    def update_memory(self, conversation: list, feedback: Optional[dict] = None):
        # Extract and store in vector DB
        pass

    def get_context_for_prompt(self) -> str:
        # Would retrieve relevant memories here
        return "Retrieved preferences:\n..."


class ExtractorRAGUserVectorBaseline(BaselineMethod):
    """
    Full method: Extract + RAG + User Vector for personalized retrieval.
    """

    def __init__(self, top_k: int = 5):
        super().__init__("extractor_rag_user_vector")
        self.top_k = top_k
        # Would integrate with your PersonalizedLLM

    def initialize_session(self, user_id: str, user_profile: dict):
        self.user_id = user_id

    def generate_response(self, query: str, conversation_history: list) -> str:
        # Would call LLM here (not implemented in this scaffold).
        pass

    def update_memory(self, conversation: list, feedback: Optional[dict] = None):
        # Extract, store, and update user vector via REINFORCE
        pass

    def get_context_for_prompt(self) -> str:
        # Would use policy-based retrieval here
        return "Retrieved preferences (personalized):\n..."
# ============================================================================
# Experiment Runner
# ============================================================================

@dataclass
class ExperimentConfig:
    """Configuration for an experiment run."""
    baselines: list  # List of baseline names to run (keys of ExperimentRunner.BASELINE_CLASSES)
    dataset: str  # Dataset to use
    num_sessions: int = 10  # Sessions per user
    num_users: int = 20  # Number of user profiles
    max_turns_per_session: int = 15
    profile_path: str = "collaborativeagents/data/complex_profiles/profiles.jsonl"
    output_dir: str = "collaborativeagents/results"
    include_conflict_tests: bool = True
    seed: int = 42


class ExperimentRunner:
    """Runs baseline comparison experiments."""

    # Registry mapping baseline name -> implementing class.
    BASELINE_CLASSES = {
        "vanilla": VanillaBaseline,
        "contextual_memory": ContextualMemoryBaseline,
        "reflection_memory": ReflectionMemoryBaseline,
        "all_memory_cards": AllMemoryCardsBaseline,
        "extractor_rag": ExtractorRAGBaseline,
        "extractor_rag_user_vector": ExtractorRAGUserVectorBaseline,
    }

    def __init__(self, config: ExperimentConfig):
        self.config = config
        self.results = {}  # baseline name -> ExperimentResults

    def load_profiles(self) -> list:
        """Load the first `num_users` user profiles from the JSONL profile file."""
        profiles = []
        # Explicit encoding avoids platform-dependent decoding of the JSONL file.
        with open(self.config.profile_path, encoding="utf-8") as f:
            for line in f:
                profiles.append(json.loads(line))
        return profiles[:self.config.num_users]

    def load_dataset(self) -> list:
        """Load evaluation dataset.

        Fix: the original fell through with an implicit ``None`` return, but
        ``run_experiment`` slices the dataset (``dataset[a:b]``), which raised
        ``TypeError`` on ``None``. Return an empty list until loading is
        implemented.
        """
        # Would load from collaborativeagents datasets
        return []

    def run_single_conversation(
        self,
        baseline: BaselineMethod,
        user_profile: dict,
        problem: dict,
        session_num: int
    ) -> ConversationMetrics:
        """Run a single conversation and collect metrics.

        NOTE(review): the conversation loop itself is a placeholder; in
        practice it would use UserAgent and actual LLM calls, so all counters
        below currently stay at zero.
        """
        baseline.initialize_session(user_profile['user_id'], user_profile)

        conversation = []
        user_tokens = 0
        assistant_tokens = 0
        num_enforcements = 0

        start_time = time.time()

        # ... conversation loop ...

        latency = time.time() - start_time

        return ConversationMetrics(
            task_accuracy=0.0,  # Would evaluate
            user_tokens=user_tokens,
            assistant_tokens=assistant_tokens,
            total_tokens=user_tokens + assistant_tokens,
            num_turns=len(conversation) // 2,  # Two messages per turn.
            num_preference_enforcements=num_enforcements,
            latency_seconds=latency
        )

    def run_conflict_test(
        self,
        baseline: BaselineMethod,
        user_profile: dict,
        conflict_test: dict
    ) -> bool:
        """Test if baseline correctly resolves a preference conflict.

        Returns True when the response applies the preference identified by
        ``conflict_test['correct_pref_id']``. Currently a placeholder that
        always returns False.
        """
        baseline.initialize_session(user_profile['user_id'], user_profile)

        # Generate response to conflicting query
        query = conflict_test['query']
        response = baseline.generate_response(query, [])

        # Check if correct preference was applied
        correct_pref_id = conflict_test['correct_pref_id']
        # Would analyze `response` to check which preference was followed

        return False  # Placeholder

    def run_experiment(self):
        """Run full experiment across all baselines.

        Returns the dict of per-baseline ExperimentResults (also stored on
        ``self.results``).
        """
        profiles = self.load_profiles()
        dataset = self.load_dataset()

        for baseline_name in self.config.baselines:
            print(f"\n{'='*60}")
            print(f"Running baseline: {baseline_name}")
            print(f"{'='*60}")

            baseline_class = self.BASELINE_CLASSES[baseline_name]
            baseline = baseline_class()

            results = ExperimentResults(
                baseline_name=baseline_name,
                num_conversations=0
            )

            for user_profile in profiles:
                user_id = user_profile['user_id']
                print(f"\nUser: {user_id}")

                # Run multiple sessions, three problems per session.
                for session in range(self.config.num_sessions):
                    session_problems = dataset[session * 3:(session + 1) * 3]

                    for problem in session_problems:
                        metrics = self.run_single_conversation(
                            baseline, user_profile, problem, session
                        )
                        results.add_conversation(metrics)
                        results.num_conversations += 1

                # Run conflict tests
                if self.config.include_conflict_tests:
                    for conflict_test in user_profile.get('conflict_tests', []):
                        correct = self.run_conflict_test(
                            baseline, user_profile, conflict_test
                        )
                        # Would add `correct` to results

            self.results[baseline_name] = results

        return self.results

    def compute_user_vector_similarity(
        self,
        learned_vector: np.ndarray,
        ground_truth_profile: dict
    ) -> float:
        """
        Compute similarity between learned user vector and ground truth.

        Ground truth is derived from the preference profile:
        - One-hot encode preference categories
        - Weight by how often each preference was triggered

        NOTE(review): not implemented yet; currently returns None.
        """
        # Create ground truth vector from profile
        # This is a key metric for your method!
        pass

    def save_results(self):
        """Save summary and per-baseline detailed results as JSON files."""
        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Summary table
        summary = [results.summary() for results in self.results.values()]

        with open(output_dir / "summary.json", 'w') as f:
            json.dump(summary, f, indent=2)

        # Detailed results
        for name, results in self.results.items():
            with open(output_dir / f"{name}_detailed.json", 'w') as f:
                json.dump(asdict(results), f, indent=2)

        print(f"\nResults saved to {output_dir}")

    def print_comparison_table(self):
        """Print a comparison table of all baselines."""
        print("\n" + "=" * 80)
        print("BASELINE COMPARISON RESULTS")
        print("=" * 80)

        headers = ["Baseline", "Accuracy", "User Effort", "Total Tokens", "Conflict Acc"]
        row_format = "{:<30} {:>10} {:>12} {:>14} {:>12}"

        print(row_format.format(*headers))
        print("-" * 80)

        for name, results in self.results.items():
            summary = results.summary()
            print(row_format.format(
                name,
                f"{summary.get('task_accuracy_mean', 0):.3f}",
                f"{summary.get('user_tokens_mean', 0):.0f}",
                f"{summary.get('total_tokens_mean', 0):.0f}",
                f"{summary.get('conflict_resolution_correct_mean', 0):.3f}"
            ))
# ============================================================================
# Analysis Functions
# ============================================================================

def analyze_context_overflow(results: dict) -> dict:
    """
    Analyze how methods degrade as context grows.

    Returns degradation curves for each method (placeholder: one empty
    per-session mapping per baseline until the analysis is implemented).
    """
    # Would group each baseline's conversations by session number and track
    # accuracy degradation; for now every baseline maps to an empty dict.
    return {method_name: {} for method_name in results}


def analyze_conflict_resolution(results: dict, conflict_tests: list) -> dict:
    """
    Analyze conflict resolution accuracy by conflict type.

    Returns {baseline_name: {conflict_group: accuracy}} (placeholder: all
    accuracies are 0.0 until response analysis is implemented).
    """
    analysis = {}
    conflict_groups = {test['conflict_group'] for test in conflict_tests}

    for group in conflict_groups:
        group_tests = [test for test in conflict_tests
                       if test['conflict_group'] == group]

        for method_name in results:
            # Would compute accuracy per conflict type from `group_tests`.
            analysis.setdefault(method_name, {})[group] = 0.0

    return analysis


def analyze_user_vector_quality(
    learned_vectors: dict,
    ground_truth_profiles: list
) -> dict:
    """
    Analyze how well user vectors capture user identity.

    Tests:
    1. Same user across sessions -> high similarity
    2. Different users -> low similarity
    3. Users with similar preferences -> moderate similarity
    """
    # Placeholder structure; similarity computation not yet implemented.
    return {
        "intra_user_similarity": [],      # Same user, different sessions
        "inter_user_similarity": [],      # Different users
        "preference_cluster_quality": 0.0  # How well vectors cluster by preference
    }


# ============================================================================
# Main
# ============================================================================

if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--baselines", nargs="+", default=[
        "vanilla", "contextual_memory", "reflection_memory",
        "all_memory_cards", "extractor_rag", "extractor_rag_user_vector"
    ])
    cli.add_argument("--dataset", default="math-500")
    cli.add_argument("--num_sessions", type=int, default=10)
    cli.add_argument("--num_users", type=int, default=20)
    cli.add_argument("--output_dir", default="collaborativeagents/results")
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    experiment_config = ExperimentConfig(
        baselines=opts.baselines,
        dataset=opts.dataset,
        num_sessions=opts.num_sessions,
        num_users=opts.num_users,
        output_dir=opts.output_dir,
        seed=opts.seed
    )

    runner = ExperimentRunner(experiment_config)
    runner.run_experiment()
    runner.print_comparison_table()
    runner.save_results()