summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/run_baseline_comparison.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/scripts/run_baseline_comparison.py')
-rw-r--r--collaborativeagents/scripts/run_baseline_comparison.py608
1 files changed, 608 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/run_baseline_comparison.py b/collaborativeagents/scripts/run_baseline_comparison.py
new file mode 100644
index 0000000..0bdbcb5
--- /dev/null
+++ b/collaborativeagents/scripts/run_baseline_comparison.py
@@ -0,0 +1,608 @@
+"""
+Run baseline comparison experiments for personalization methods.
+
+Baselines:
+1. Vanilla - No memory
+2. Contextual Memory - Full history in context (summarize if exceeds limit)
+3. Reflection Memory - CollaborativeAgents' agent_notes approach
+4. Reflection + GRPO - Trained version of reflection
+5. All Memory Cards in Context - Extract all, no retrieval
+6. Extractor + RAG - Retrieval without user vector
+7. Extractor + RAG + User Vector - Full personalization
+
+Metrics:
+- Task Accuracy
+- User Effort (user token count)
+- Total Efficiency (all tokens)
+- Conflict Resolution Accuracy (new)
+- User Vector Similarity to Ground Truth (new)
+"""
+
+import json
+import time
+from pathlib import Path
+from dataclasses import dataclass, field, asdict
+from typing import Optional, Callable
+from abc import ABC, abstractmethod
+import numpy as np
+
+# ============================================================================
+# Metrics
+# ============================================================================
+
@dataclass
class ConversationMetrics:
    """Per-conversation evaluation metrics collected by the runner."""
    task_accuracy: float  # 1.0 if the final answer was correct, else 0.0
    user_tokens: int  # tokens contributed by the user
    assistant_tokens: int  # tokens produced by the assistant
    total_tokens: int  # user + assistant tokens combined
    num_turns: int  # number of conversation turns
    num_preference_enforcements: int  # times the user had to re-state a preference
    conflict_resolution_correct: Optional[bool] = None  # set only for conflict tests
    latency_seconds: float = 0.0  # wall-clock duration of the conversation

    @property
    def user_effort(self) -> int:
        """User effort, measured as tokens the user had to type (lower is better)."""
        return self.user_tokens

    @property
    def efficiency(self) -> float:
        """Accuracy per 1000 tokens spent (higher is better); 0.0 when no tokens."""
        if not self.total_tokens:
            return 0.0
        return self.task_accuracy / self.total_tokens * 1000
+
+
@dataclass
class ExperimentResults:
    """Aggregated results for an experiment.

    ``metrics`` maps a metric name to the list of per-conversation values;
    ``summary()`` reduces each list to mean/std as plain Python floats so
    the result can be passed directly to ``json.dump``.
    """
    baseline_name: str
    num_conversations: int
    metrics: dict = field(default_factory=dict)

    def add_conversation(self, conv_metrics: "ConversationMetrics"):
        """Append one conversation's metric values to the per-metric lists."""
        for key in ('task_accuracy', 'user_tokens', 'assistant_tokens',
                    'total_tokens', 'num_turns', 'num_preference_enforcements'):
            self.metrics.setdefault(key, []).append(getattr(conv_metrics, key))

        # Conflict accuracy is only tracked for conversations that were
        # actually conflict tests (attribute left as None otherwise).
        if conv_metrics.conflict_resolution_correct is not None:
            self.metrics.setdefault('conflict_resolution_correct', []).append(
                1.0 if conv_metrics.conflict_resolution_correct else 0.0
            )

    def summary(self) -> dict:
        """Compute summary statistics (mean/std per metric).

        np.mean/np.std return numpy scalars, which json.dump rejects with
        TypeError; cast to float so save_results() can serialize the summary.
        """
        summary = {"baseline": self.baseline_name, "n": self.num_conversations}
        for key, values in self.metrics.items():
            if values:
                summary[f"{key}_mean"] = float(np.mean(values))
                summary[f"{key}_std"] = float(np.std(values))
        return summary
+
+
+# ============================================================================
+# Baseline Implementations (Abstract)
+# ============================================================================
+
class BaselineMethod(ABC):
    """Abstract base class for all baseline methods.

    Subclasses implement one memory strategy each; the experiment runner
    drives them through initialize_session / generate_response /
    update_memory and injects get_context_for_prompt() into prompts.
    """

    def __init__(self, name: str, config: Optional[dict] = None):
        self.name = name
        self.config = config or {}

    @abstractmethod
    def initialize_session(self, user_id: str, user_profile: dict):
        """Initialize a new session for a user."""
        pass

    @abstractmethod
    def generate_response(self, query: str, conversation_history: list) -> str:
        """Generate a response given query and history."""
        pass

    @abstractmethod
    def update_memory(self, conversation: list, feedback: dict = None):
        """Update memory after a conversation or turn."""
        pass

    @abstractmethod
    def get_context_for_prompt(self) -> str:
        """Get the memory/context to include in prompts."""
        pass

    def count_tokens(self, text: str) -> int:
        """Estimate token count (~1.3 tokens per whitespace-delimited word).

        Truncates to int: the previous implementation returned a float
        despite the declared ``-> int``, silently turning downstream
        integer token counters (e.g. ConversationMetrics.user_tokens)
        into floats.
        """
        return int(len(text.split()) * 1.3)
+
+
class VanillaBaseline(BaselineMethod):
    """Stateless baseline: every conversation starts from a blank slate."""

    def __init__(self):
        super().__init__("vanilla")

    def initialize_session(self, user_id: str, user_profile: dict):
        """Record the active user; no memory is kept for this baseline."""
        self.user_id = user_id

    def generate_response(self, query: str, conversation_history: list) -> str:
        """Placeholder for the LLM call."""
        return None

    def update_memory(self, conversation: list, feedback: dict = None):
        """Nothing to persist for the vanilla baseline."""
        return None

    def get_context_for_prompt(self) -> str:
        """Vanilla contributes no memory context to the prompt."""
        return ""
+
+
class ContextualMemoryBaseline(BaselineMethod):
    """
    Keeps the entire conversation history in the prompt context and
    compresses older turns into a summary once the token budget is exceeded.
    """

    def __init__(self, max_context_tokens: int = 32000):
        super().__init__("contextual_memory")
        self.max_context_tokens = max_context_tokens
        self.full_history = []  # raw messages kept verbatim
        self.summarized_history = ""  # compressed older history

    def initialize_session(self, user_id: str, user_profile: dict):
        """Start a session; accumulated history deliberately persists across sessions."""
        self.user_id = user_id

    def generate_response(self, query: str, conversation_history: list) -> str:
        """Placeholder for the LLM call."""
        return None

    def update_memory(self, conversation: list, feedback: dict = None):
        """Append the new turns and compress when the token budget is blown."""
        self.full_history += conversation

        budget_used = sum(self.count_tokens(msg['content']) for msg in self.full_history)
        if budget_used > self.max_context_tokens:
            self._summarize_old_history()

    def _summarize_old_history(self):
        """Condense everything but the most recent turns into a summary.

        This compression step is where the baseline loses information.
        """
        keep_recent = 10  # trailing turns kept verbatim
        older = self.full_history[:-keep_recent]
        recent = self.full_history[-keep_recent:]

        # Placeholder: a real run would LLM-summarize `older` into
        # self.summarized_history here.
        self.full_history = recent

    def get_context_for_prompt(self) -> str:
        """Assemble the summary (if any) plus the last 20 raw messages."""
        parts = []
        if self.summarized_history:
            parts.append(f"Previous conversation summary:\n{self.summarized_history}\n\n")
        parts.append("Recent conversation:\n")
        parts.extend(f"{msg['role']}: {msg['content']}\n" for msg in self.full_history[-20:])
        return "".join(parts)
+
+
class ReflectionMemoryBaseline(BaselineMethod):
    """
    CollaborativeAgents-style memory: per-user free-text agent notes,
    refreshed by a reflection step after each conversation.
    """

    def __init__(self):
        super().__init__("reflection_memory")
        self.agent_notes = {}  # user_id -> notes string

    def initialize_session(self, user_id: str, user_profile: dict):
        """Ensure a (possibly empty) notes slot exists for this user."""
        self.user_id = user_id
        self.agent_notes.setdefault(user_id, "")

    def generate_response(self, query: str, conversation_history: list) -> str:
        """Placeholder for the LLM call."""
        return None

    def update_memory(self, conversation: list, feedback: dict = None):
        """Placeholder for the reflection pass (their update_agent_notes_prompt)."""
        return None

    def get_context_for_prompt(self) -> str:
        """Expose the stored notes for the current user."""
        notes = self.agent_notes.get(self.user_id, '')
        return f"Notes about this user:\n{notes}"
+
+
class AllMemoryCardsBaseline(BaselineMethod):
    """
    Extracts preferences into memory cards but performs no retrieval:
    every stored card is dumped straight into the prompt.
    """

    def __init__(self, max_cards_in_context: int = 100):
        super().__init__("all_memory_cards")
        self.max_cards = max_cards_in_context
        self.memory_cards = {}  # user_id -> list of cards

    def initialize_session(self, user_id: str, user_profile: dict):
        """Ensure a card list exists for this user."""
        self.user_id = user_id
        self.memory_cards.setdefault(user_id, [])

    def generate_response(self, query: str, conversation_history: list) -> str:
        """Placeholder for the LLM call."""
        return None

    def update_memory(self, conversation: list, feedback: dict = None):
        """Placeholder: would run the preference extractor and append cards."""
        return None

    def get_context_for_prompt(self) -> str:
        """Render every known card (up to the cap) — no relevance filtering.

        Dumping everything is this baseline's deliberate weakness.
        """
        cards = self.memory_cards.get(self.user_id, [])
        if not cards:
            return ""

        lines = ["User preferences (all known):\n"]
        for idx, card in enumerate(cards[:self.max_cards], start=1):
            lines.append(f"{idx}. When {card['condition']}: {card['action']}\n")
        return "".join(lines)
+
+
class ExtractorRAGBaseline(BaselineMethod):
    """Preference extraction + relevance-only RAG retrieval (no user vector)."""

    def __init__(self, top_k: int = 5):
        super().__init__("extractor_rag")
        self.top_k = top_k
        self.memory_store = None  # placeholder for a vector store

    def initialize_session(self, user_id: str, user_profile: dict):
        """Record the active user."""
        self.user_id = user_id

    def generate_response(self, query: str, conversation_history: list) -> str:
        """Placeholder for the LLM call."""
        return None

    def update_memory(self, conversation: list, feedback: dict = None):
        """Placeholder: would extract preferences and index them in the vector DB."""
        return None

    def get_context_for_prompt(self) -> str:
        """Placeholder for relevance-based retrieval output."""
        return "Retrieved preferences:\n..."
+
+
class ExtractorRAGUserVectorBaseline(BaselineMethod):
    """
    Full method: extraction + RAG + a learned user vector that
    personalizes retrieval (would integrate with PersonalizedLLM).
    """

    def __init__(self, top_k: int = 5):
        super().__init__("extractor_rag_user_vector")
        self.top_k = top_k

    def initialize_session(self, user_id: str, user_profile: dict):
        """Record the active user."""
        self.user_id = user_id

    def generate_response(self, query: str, conversation_history: list) -> str:
        """Placeholder for the LLM call."""
        return None

    def update_memory(self, conversation: list, feedback: dict = None):
        """Placeholder: would extract, store, and REINFORCE-update the user vector."""
        return None

    def get_context_for_prompt(self) -> str:
        """Placeholder for policy-based (user-vector-conditioned) retrieval output."""
        return "Retrieved preferences (personalized):\n..."
+
+
+# ============================================================================
+# Experiment Runner
+# ============================================================================
+
@dataclass
class ExperimentConfig:
    """Configuration for an experiment run.

    Default paths assume the collaborativeagents repository layout;
    override them when running from elsewhere.
    """
    baselines: list  # List of baseline names to run (keys of ExperimentRunner.BASELINE_CLASSES)
    dataset: str  # Dataset to use (e.g. "math-500")
    num_sessions: int = 10  # Sessions per user
    num_users: int = 20  # Number of user profiles to load
    max_turns_per_session: int = 15  # cap on turns within one simulated session
    profile_path: str = "collaborativeagents/data/complex_profiles/profiles.jsonl"  # JSONL, one profile per line
    output_dir: str = "collaborativeagents/results"  # destination for summary/detailed JSON
    include_conflict_tests: bool = True  # also run per-profile conflict-resolution probes
    seed: int = 42  # RNG seed for reproducibility
+
+
class ExperimentRunner:
    """Runs baseline comparison experiments.

    Orchestrates the full sweep: for every baseline name in the config it
    instantiates the implementation from BASELINE_CLASSES, runs all users'
    sessions, collects ConversationMetrics into an ExperimentResults, and
    can print/save the aggregated comparison.
    """

    # Registry mapping config/CLI baseline names to implementation classes.
    BASELINE_CLASSES = {
        "vanilla": VanillaBaseline,
        "contextual_memory": ContextualMemoryBaseline,
        "reflection_memory": ReflectionMemoryBaseline,
        "all_memory_cards": AllMemoryCardsBaseline,
        "extractor_rag": ExtractorRAGBaseline,
        "extractor_rag_user_vector": ExtractorRAGUserVectorBaseline,
    }

    def __init__(self, config: ExperimentConfig):
        # self.results maps baseline name -> ExperimentResults.
        self.config = config
        self.results = {}

    def load_profiles(self) -> list:
        """Load user profiles.

        Reads JSONL (one profile object per line) from config.profile_path
        and keeps only the first config.num_users entries.
        """
        profiles = []
        with open(self.config.profile_path) as f:
            for line in f:
                profiles.append(json.loads(line))
        return profiles[:self.config.num_users]

    def load_dataset(self) -> list:
        """Load evaluation dataset.

        NOTE(review): placeholder — currently returns None, so the
        ``dataset[...]`` slicing in run_experiment will raise TypeError
        until this is implemented.
        """
        # Would load from collaborativeagents datasets
        pass

    def run_single_conversation(
        self,
        baseline: BaselineMethod,
        user_profile: dict,
        problem: dict,
        session_num: int
    ) -> ConversationMetrics:
        """Run a single conversation and collect metrics.

        Currently a skeleton: the conversation loop is unimplemented, so
        the token/turn counters stay zero and task_accuracy is fixed at
        0.0; only wall-clock latency is measured for real.
        """
        baseline.initialize_session(user_profile['user_id'], user_profile)

        conversation = []
        user_tokens = 0
        assistant_tokens = 0
        num_enforcements = 0

        # Simulate conversation
        # In practice, would use UserAgent and actual LLM calls

        start_time = time.time()

        # ... conversation loop ...

        latency = time.time() - start_time

        return ConversationMetrics(
            task_accuracy=0.0,  # Would evaluate
            user_tokens=user_tokens,
            assistant_tokens=assistant_tokens,
            total_tokens=user_tokens + assistant_tokens,
            num_turns=len(conversation) // 2,  # one turn = user msg + assistant msg
            num_preference_enforcements=num_enforcements,
            latency_seconds=latency
        )

    def run_conflict_test(
        self,
        baseline: BaselineMethod,
        user_profile: dict,
        conflict_test: dict
    ) -> bool:
        """Test if baseline correctly resolves a preference conflict.

        NOTE(review): placeholder — ``response`` and ``correct_pref_id``
        are computed but never analyzed, and the method always returns
        False.
        """
        baseline.initialize_session(user_profile['user_id'], user_profile)

        # Generate response to conflicting query
        query = conflict_test['query']
        response = baseline.generate_response(query, [])

        # Check if correct preference was applied
        correct_pref_id = conflict_test['correct_pref_id']
        # Would analyze response to check which preference was followed

        return False  # Placeholder

    def run_experiment(self):
        """Run full experiment across all baselines.

        For each baseline: build a fresh instance (fresh memory state),
        then for each user profile run config.num_sessions sessions of
        three consecutive problems each, plus optional conflict tests.
        Populates self.results and returns it.
        """
        profiles = self.load_profiles()
        dataset = self.load_dataset()

        for baseline_name in self.config.baselines:
            print(f"\n{'='*60}")
            print(f"Running baseline: {baseline_name}")
            print(f"{'='*60}")

            # Fresh instance per baseline so no memory leaks across methods.
            baseline_class = self.BASELINE_CLASSES[baseline_name]
            baseline = baseline_class()

            results = ExperimentResults(
                baseline_name=baseline_name,
                num_conversations=0
            )

            for user_profile in profiles:
                user_id = user_profile['user_id']
                print(f"\nUser: {user_id}")

                # Run multiple sessions
                for session in range(self.config.num_sessions):
                    # Select problems for this session
                    # (three consecutive dataset items per session, no shuffling)
                    session_problems = dataset[session * 3:(session + 1) * 3]

                    for problem in session_problems:
                        metrics = self.run_single_conversation(
                            baseline, user_profile, problem, session
                        )
                        results.add_conversation(metrics)
                        results.num_conversations += 1

                # Run conflict tests
                if self.config.include_conflict_tests:
                    for conflict_test in user_profile.get('conflict_tests', []):
                        correct = self.run_conflict_test(
                            baseline, user_profile, conflict_test
                        )
                        # Would add to results

            self.results[baseline_name] = results

        return self.results

    def compute_user_vector_similarity(
        self,
        learned_vector: np.ndarray,
        ground_truth_profile: dict
    ) -> float:
        """
        Compute similarity between learned user vector and ground truth.

        Ground truth is derived from the preference profile:
        - One-hot encode preference categories
        - Weight by how often each preference was triggered

        NOTE(review): unimplemented — currently returns None, not a float.
        """
        # Create ground truth vector from profile
        # This is a key metric for your method!
        pass

    def save_results(self):
        """Save experiment results.

        Writes summary.json plus one ``<baseline>_detailed.json`` per
        baseline into config.output_dir (created if missing).
        """
        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Summary table
        summary = []
        for name, results in self.results.items():
            summary.append(results.summary())

        # NOTE(review): json.dump raises TypeError on numpy scalar types;
        # this relies on summary() producing plain Python floats — confirm.
        with open(output_dir / "summary.json", 'w') as f:
            json.dump(summary, f, indent=2)

        # Detailed results
        for name, results in self.results.items():
            with open(output_dir / f"{name}_detailed.json", 'w') as f:
                json.dump(asdict(results), f, indent=2)

        print(f"\nResults saved to {output_dir}")

    def print_comparison_table(self):
        """Print a comparison table of all baselines.

        Metrics missing from a baseline's summary (e.g. conflict accuracy
        when no conflict tests ran) are printed as 0.
        """
        print("\n" + "=" * 80)
        print("BASELINE COMPARISON RESULTS")
        print("=" * 80)

        headers = ["Baseline", "Accuracy", "User Effort", "Total Tokens", "Conflict Acc"]
        row_format = "{:<30} {:>10} {:>12} {:>14} {:>12}"

        print(row_format.format(*headers))
        print("-" * 80)

        for name, results in self.results.items():
            summary = results.summary()
            print(row_format.format(
                name,
                f"{summary.get('task_accuracy_mean', 0):.3f}",
                f"{summary.get('user_tokens_mean', 0):.0f}",
                f"{summary.get('total_tokens_mean', 0):.0f}",
                f"{summary.get('conflict_resolution_correct_mean', 0):.3f}"
            ))
+
+
+# ============================================================================
+# Analysis Functions
+# ============================================================================
+
def analyze_context_overflow(results: dict) -> dict:
    """
    Analyze how methods degrade as context grows.

    Returns degradation curves for each method.
    """
    # Placeholder: per-session accuracy grouping not yet implemented, so
    # every baseline currently maps to an empty per-session dict.
    return {baseline_name: {} for baseline_name in results}
+
+
def analyze_conflict_resolution(results: dict, conflict_tests: list) -> dict:
    """
    Analyze conflict resolution accuracy by conflict type.
    """
    analysis = {}

    conflict_groups = {test['conflict_group'] for test in conflict_tests}
    for group in conflict_groups:
        # TODO: use the per-group subset once per-test outcomes are recorded.
        group_tests = [test for test in conflict_tests
                       if test['conflict_group'] == group]

        for name in results:
            # Placeholder accuracy until real per-type scoring exists.
            analysis.setdefault(name, {})[group] = 0.0

    return analysis
+
+
def analyze_user_vector_quality(
    learned_vectors: dict,
    ground_truth_profiles: list
) -> dict:
    """
    Analyze how well user vectors capture user identity.

    Tests:
    1. Same user across sessions -> high similarity
    2. Different users -> low similarity
    3. Users with similar preferences -> moderate similarity
    """
    report = {}
    report["intra_user_similarity"] = []  # same user, different sessions
    report["inter_user_similarity"] = []  # different users
    report["preference_cluster_quality"] = 0.0  # clustering quality by preference

    # Placeholder: similarity computations not yet implemented.
    return report
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--baselines", nargs="+", default=[
+ "vanilla", "contextual_memory", "reflection_memory",
+ "all_memory_cards", "extractor_rag", "extractor_rag_user_vector"
+ ])
+ parser.add_argument("--dataset", default="math-500")
+ parser.add_argument("--num_sessions", type=int, default=10)
+ parser.add_argument("--num_users", type=int, default=20)
+ parser.add_argument("--output_dir", default="collaborativeagents/results")
+ parser.add_argument("--seed", type=int, default=42)
+
+ args = parser.parse_args()
+
+ config = ExperimentConfig(
+ baselines=args.baselines,
+ dataset=args.dataset,
+ num_sessions=args.num_sessions,
+ num_users=args.num_users,
+ output_dir=args.output_dir,
+ seed=args.seed
+ )
+
+ runner = ExperimentRunner(config)
+ results = runner.run_experiment()
+ runner.print_comparison_table()
+ runner.save_results()