author    YurenHao0426 <blackhao0426@gmail.com>    2026-01-27 09:57:37 -0600
committer YurenHao0426 <blackhao0426@gmail.com>    2026-01-27 09:57:37 -0600
commit    dc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree      599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/scripts/run.py
parent    e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts/run.py')
-rw-r--r--  collaborativeagents/scripts/run.py  504
1 file changed, 504 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/run.py b/collaborativeagents/scripts/run.py
new file mode 100644
index 0000000..f6ed79e
--- /dev/null
+++ b/collaborativeagents/scripts/run.py
@@ -0,0 +1,504 @@
+import argparse
+import json
+import os
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from collaborativeagents.conversation_generator import ConversationGenerator
+from collaborativeagents.conversation_evaluator import ConversationEvaluator
+from collaborativeagents.datasets import datasets_info
+from collaborativeagents.agents import CollaboratorAgent, UserAgent
+from collaborativeagents.prompts import agent_system_prompt_no_user
+
+# import litellm
+# litellm._turn_on_debug()
+
+
+def load_dataset(dataset_name, eval_size, training=False):
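+    """Look up dataset_name in datasets_info and return (samples, user_task_description)."""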
+ if dataset_name not in datasets_info:
+ raise ValueError(f"Dataset '{dataset_name}' not found. Available datasets: {list(datasets_info.keys())}")
+
+    dataset_class, user_task_description = datasets_info[dataset_name]['class'], datasets_info[dataset_name]['task_description']
+ dataset_instance = dataset_class(eval_size=eval_size, training=training)
+ dataset = dataset_instance.get_dataset()
+ print(f"Loaded {len(dataset)} samples from {dataset_name}")
+
+    return dataset, user_task_description
+
+def load_user_profiles(training=False):
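+    """Load user persona/preference profiles from JSON; `training` selects the training split."""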
+ if training:
+ with open("/shared/storage-01/users/mehri2/mem/collaborativeagents/collaborativeagents/user_profiles/training_user_profiles.json", 'r') as f:
+ user_profiles = json.load(f)
+ else:
+ with open("/shared/storage-01/users/mehri2/mem/collaborativeagents/collaborativeagents/user_profiles/user_profiles.json", 'r') as f:
+ user_profiles = json.load(f)
+ return user_profiles
+
+
+def run_no_user(
+ dataset_name="math-hard",
+ eval_size=20,
+ batch_size=50,
+ collaborator_model_name="gpt-4.1-mini",
+ collaborator_api_base=None,
+ collaborator_api_key=None,
+ judge_model_name="gpt-4.1-mini",
+ judge_api_base=None,
+ judge_api_key=None,
+ output_file=None
+ ):
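+    # Resume shortcut: if results already exist, re-print the stored summary instead of regenerating.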
+ if os.path.exists(output_file):
+ with open(output_file, 'r') as f:
+ evaluation_results = []
+ for line in f:
+ if line.strip() == "":
+ continue
+ evaluation_result = json.loads(line)
+ evaluation_results.append(evaluation_result)
+
+ print(f"\n\n\nAll conversations generation and evaluation complete!")
+ print(f" # Total conversations: {len(evaluation_results)}")
+ print("\nEvaluation Results:")
+ print(f" # Overall average accuracy: {evaluation_results[0]['average_accuracy']}")
+ print(f" # Overall average conversation length (# messages): {evaluation_results[0]['average_conversation_length']}")
+ return
+
+    dataset, _ = load_dataset(dataset_name, eval_size)
+
+ collaborator_agent = CollaboratorAgent(
+ model_name=collaborator_model_name,
+ api_base=collaborator_api_base,
+ api_key=collaborator_api_key,
+ )
+ conversationEvaluator = ConversationEvaluator(
+ dataset_name=dataset_name,
+ model_name=judge_model_name,
+ api_base=judge_api_base,
+ api_key=judge_api_key
+ )
+
+ # Generate and evaluate conversations
+ print(f"\n\n\nGenerating answers for {len(dataset)} {dataset_name} samples\n")
+ generated_conversations = []
+
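+    # Process the dataset in fixed-size batches; ceil-divide so the progress bar covers a final partial batch.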
+ total_batches = (len(dataset) + batch_size - 1) // batch_size
+ with tqdm(total=total_batches, desc="Generating conversations") as progress_bar:
+ for i in range(0, len(dataset), batch_size):
+ batch_samples = dataset[i:i+batch_size]
+ # Prepare conversations for the collaborator
+            batch_conversations = [[{"role": "user", "content": s['problem']}] for s in batch_samples]
+
+ # Batched collaborator responses
+ collab_responses = collaborator_agent.generate_collaborator_responses_batch(batch_conversations)
+
+ # Assemble results
+ for sample, conv, collab_response in zip(batch_samples, batch_conversations, collab_responses):
+ if collab_response is None:
+ # Skip failed items; they will be counted downstream if needed
+ continue
+ conv.append({"role": "assistant", "content": str(collab_response["response"])})
+
+ # Add draft_answer key for evaluator compatibility
+ collab_response["draft_answer"] = collab_response["response"]
+ full_conversation_log = [collab_response]
+
+ res = {
+ "sample": sample,
+ "conversation": conv,
+ "full_conversation_log": full_conversation_log
+ }
+ generated_conversations.append(res)
+
+ progress_bar.update(1)
+
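+    # Judge all generated conversations, then append the aggregate metrics as a single JSONL line.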
+ evaluation_results = conversationEvaluator.evaluate_conversations(generated_conversations)
+
+ with open(output_file, 'a') as f:
+ f.write(json.dumps(evaluation_results) + "\n")
+ f.flush()
+
+ print(f"\n\n\nAll conversations generation and evaluation complete!")
+ print(f" # Total conversations: {len(generated_conversations)}")
+ print("\nEvaluation Results:")
+ print(f" # Overall average accuracy: {evaluation_results['average_accuracy']}")
+ print(f" # Overall average conversation length (# messages): {evaluation_results['average_conversation_length']}")
+
+def run_user_no_profile(
+ dataset_name="math-hard",
+ eval_size=20,
+ max_turns=10,
+ batch_size=100,
+ user_model_name="gpt-4.1-mini",
+ user_api_base=None,
+ user_api_key=None,
+ collaborator_model_name="gpt-4.1-mini",
+ collaborator_api_base=None,
+ collaborator_api_key=None,
+ judge_model_name="gpt-4.1-mini",
+ judge_api_base=None,
+ judge_api_key=None,
+ output_file=None
+ ):
+ if os.path.exists(output_file):
+ with open(output_file, 'r') as f:
+ evaluation_results = []
+ for line in f:
+ if line.strip() == "":
+ continue
+ evaluation_result = json.loads(line)
+ evaluation_results.append(evaluation_result)
+
+ print(f"\n\n\nAll conversations generation and evaluation complete!")
+ print(f" # Total conversations: {len(evaluation_results)}")
+ print("\nEvaluation Results:")
+ print(f" # Overall average accuracy: {evaluation_results[0]['average_accuracy']}")
+ print(f" # Overall average conversation length (# messages): {evaluation_results[0]['average_conversation_length']}")
+ return
+
+    dataset, user_task_description = load_dataset(dataset_name, eval_size)
+
+ # Generate conversations
+ generated_conversations = []
+
+ print(f"\n\n\nStarting generation conversations for user no preferences\n")
+
+ conversationGenerator = ConversationGenerator(
+ user_task_description=user_task_description,
+ user_persona=None,
+ user_preferences=None,
+ max_turns=max_turns,
+ agent_with_user_preferences=False,
+ batch_size=batch_size,
+ user_model_name=user_model_name,
+ user_api_base=user_api_base,
+ user_api_key=user_api_key,
+ collaborator_model_name=collaborator_model_name,
+ collaborator_api_base=collaborator_api_base,
+ collaborator_api_key=collaborator_api_key
+ )
+ generated_conversations = conversationGenerator.generate_conversations_parallel(dataset)
+
+ conversationEvaluator = ConversationEvaluator(
+ dataset_name=dataset_name,
+ model_name=judge_model_name,
+ api_base=judge_api_base,
+ api_key=judge_api_key
+ )
+ evaluation_results = conversationEvaluator.evaluate_conversations(generated_conversations)
+
+ with open(output_file, 'a') as f:
+ f.write(json.dumps(evaluation_results) + "\n")
+ f.flush()
+
+ print(f"\n\n\nAll conversations generation and evaluation complete!")
+ print(f" # Total conversations: {len(generated_conversations)}")
+ print("\nEvaluation Results:")
+ print(f" # Overall average accuracy: {evaluation_results['average_accuracy']}")
+ print(f" # Overall average conversation length (# messages): {evaluation_results['average_conversation_length']}")
+
+def run_user_profiles(
+ dataset_name="math-hard",
+ training=False,
+ user_profiles=None,
+ user_with_preferences=False,
+ agent_with_user_preferences=False,
+ agent_with_reflection=False,
+ with_scaffolding=False,
+ with_proper_scaffolding=False,
+ eval_size=20,
+ max_turns=10,
+ batch_size=100,
+ user_model_name="gpt-4.1-mini",
+ user_api_base=None,
+ user_api_key=None,
+ collaborator_model_name="gpt-4.1-mini",
+ collaborator_api_base=None,
+ collaborator_api_key=None,
+ judge_model_name="gpt-4.1-mini",
+ judge_api_base=None,
+ judge_api_key=None,
+ output_file=None
+ ):
+    dataset, user_task_description = load_dataset(dataset_name, eval_size, training=training)
+
+ generated_user_sessions = []
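+    # Resume support: reload previously completed user sessions and drop their profiles from the pending list.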
+ if os.path.exists(output_file):
+ with open(output_file, 'r') as f:
+ seen_users = set()
+ for line in f:
+ if line.strip() == "":
+ continue
+ curr_result = json.loads(line)
+ seen_users.add(curr_result["i"])
+ generated_user_sessions.append(curr_result)
+ user_profiles = [user_profile_elem for user_profile_elem in user_profiles if user_profile_elem["i"] not in seen_users]
+
+ def generate_and_evaluate_single_user_profile(user_profile_elem):
+ user_profile_i = user_profile_elem["i"]
+ user_persona = user_profile_elem["persona"]
+ if user_with_preferences:
+ user_preferences = "\n".join([f"{i+1}. {pref}" for i, pref in enumerate(user_profile_elem["preferences"])])
+ else:
+ user_preferences = None
+
+ # Generate conversations
+ if agent_with_reflection:
+ print(f"Starting generation conversation sessions for User {user_profile_i}")
+ conversationGenerator = ConversationGenerator(
+ user_task_description=user_task_description,
+ user_persona=user_persona,
+ user_preferences=user_preferences,
+ agent_with_user_preferences=agent_with_user_preferences,
+ max_turns=max_turns,
+ with_scaffolding=with_scaffolding,
+ with_proper_scaffolding=with_proper_scaffolding,
+ batch_size=batch_size,
+ user_model_name=user_model_name,
+ user_api_base=user_api_base,
+ user_api_key=user_api_key,
+ collaborator_model_name=collaborator_model_name,
+ collaborator_api_base=collaborator_api_base,
+ collaborator_api_key=collaborator_api_key
+ )
+ generated_conversations = conversationGenerator.generate_conversations_with_reflective_agent(dataset, training=training)
+ print(f"Finished generation conversation sessions for User {user_profile_i}")
+ print(f" # succeeded user conversation sessions: {len(generated_conversations)}")
+ print(f" # failed user conversation sessions: {len(dataset) - len(generated_conversations)}")
+ else:
+ print(f"Starting generation conversation sessions for User {user_profile_i}")
+ conversationGenerator = ConversationGenerator(
+ user_task_description=user_task_description,
+ user_persona=user_persona,
+ user_preferences=user_preferences,
+ agent_with_user_preferences=agent_with_user_preferences,
+ max_turns=max_turns,
+ batch_size=batch_size,
+ user_model_name=user_model_name,
+ user_api_base=user_api_base,
+ user_api_key=user_api_key,
+ collaborator_model_name=collaborator_model_name,
+ collaborator_api_base=collaborator_api_base,
+ collaborator_api_key=collaborator_api_key
+ )
+ generated_conversations = conversationGenerator.generate_conversations_parallel(dataset)
+ print(f"Finished generation conversation sessions for User {user_profile_i}")
+ print(f" # succeeded user conversation sessions: {len(generated_conversations)}")
+ print(f" # failed user conversation sessions: {len(dataset) - len(generated_conversations)}")
+
+ # Evaluate conversations
+ conversationEvaluator = ConversationEvaluator(
+ dataset_name=dataset_name,
+ model_name=judge_model_name,
+ api_base=judge_api_base,
+ api_key=judge_api_key
+ )
+ evaluation_results = conversationEvaluator.evaluate_conversations(generated_conversations)
+ user_profile_elem["generated_conversations"] = generated_conversations
+ user_profile_elem["evaluation"] = evaluation_results
+
+ return user_profile_elem
+
+
+ with open(output_file, 'a') as f:
+ with tqdm(total=len(user_profiles), desc="Processing user profiles") as progress_bar:
+ for i in range(0, len(user_profiles), batch_size):
+ batch = user_profiles[i:i+batch_size]
+
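+                # Fan out one worker thread per profile in this batch; each worker generates and evaluates a full session.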
+ with ThreadPoolExecutor(max_workers=min(batch_size, len(batch))) as executor:
+ futures_to_profile = {
+ executor.submit(generate_and_evaluate_single_user_profile, user_profile_elem): user_profile_elem
+ for user_profile_elem in batch
+ }
+
+ for future in as_completed(futures_to_profile):
+ curr_result = future.result()
+ generated_user_sessions.append(curr_result)
+
+ f.write(json.dumps(curr_result) + "\n")
+ f.flush()
+
+ progress_bar.update(1)
+
+ # Aggregate evaluation results from all user sessions
+ avg_accuracy = sum([user_session['evaluation']['average_accuracy'] for user_session in generated_user_sessions]) / len(generated_user_sessions)
+ avg_length = sum([user_session['evaluation']['average_conversation_length'] for user_session in generated_user_sessions]) / len(generated_user_sessions)
+
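+    # Count, per conversation, how many logged turns flagged enforce_preferences as true.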
+ num_enforced_preferences_per_conversation = []
+ for generated_user_session in generated_user_sessions:
+ for generated_conversation in generated_user_session['generated_conversations']:
+ curr_num_enforced_preferences = 0
+ for message in generated_conversation['full_conversation_log']:
+                if message.get("enforce_preferences") in (True, "True"):
+                    curr_num_enforced_preferences += 1
+ num_enforced_preferences_per_conversation.append(curr_num_enforced_preferences)
+
+ print(f"\n\n\nAll user profiles generation and evaluation complete!")
+ print(f" # Total user profiles processed: {len(generated_user_sessions)}")
+ print(f" # Total conversations: {sum([len(user_session['generated_conversations']) for user_session in generated_user_sessions])}")
+ print("\nEvaluation Results:")
+ print(f" # Overall average accuracy: {avg_accuracy}")
+ print(f" # Overall average conversation length (# messages): {avg_length}")
+ print(f" # Overall average number of enforced preferences: {sum(num_enforced_preferences_per_conversation) / len(num_enforced_preferences_per_conversation)}")
+
+
+if __name__ == "__main__":
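+    # Example invocation (values below are illustrative, not defaults):
+    #   python collaborativeagents/scripts/run.py --experiment_type user_profiles_with_preferences \
+    #       --dataset math-hard --eval_size 20 --max_turns 10 --batch_size 100 \
+    #       --user_model_name gpt-4.1-mini --collaborator_model_name gpt-4.1-mini \
+    #       --judge_model_name gpt-4.1-mini --output_file results.jsonl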
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--experiment_type", type=str)
+ parser.add_argument("--dataset", type=str)
+ parser.add_argument("--eval_size", type=int)
+ parser.add_argument("--output_file", type=str)
+ parser.add_argument("--max_turns", type=int)
+ parser.add_argument("--batch_size", type=int)
+ parser.add_argument("--user_model_name", type=str)
+ parser.add_argument("--user_api_base", type=str)
+ parser.add_argument("--user_api_key", type=str)
+ parser.add_argument("--collaborator_model_name", type=str)
+ parser.add_argument("--collaborator_api_base", type=str)
+ parser.add_argument("--collaborator_api_key", type=str)
+ parser.add_argument("--judge_model_name", type=str)
+ parser.add_argument("--judge_api_base", type=str)
+ parser.add_argument("--judge_api_key", type=str)
+ args = parser.parse_args()
+
+ if args.experiment_type == "no_user":
+ run_no_user(
+ dataset_name=args.dataset,
+ eval_size=args.eval_size,
+ batch_size=args.batch_size,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "user_no_profile":
+ run_user_no_profile(
+ dataset_name=args.dataset,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "user_profiles_without_preferences":
+ user_profiles = load_user_profiles()
+ run_user_profiles(
+ dataset_name=args.dataset,
+ training=False,
+ user_profiles=user_profiles,
+ user_with_preferences=False,
+ agent_with_reflection=False,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ batch_size=args.batch_size,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "user_profiles_with_preferences":
+ user_profiles = load_user_profiles()
+ run_user_profiles(
+ dataset_name=args.dataset,
+ training=False,
+ user_profiles=user_profiles,
+ user_with_preferences=True,
+ agent_with_reflection=False,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ batch_size=args.batch_size,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "agent_with_user_preferences":
+ user_profiles = load_user_profiles()
+ run_user_profiles(
+ dataset_name=args.dataset,
+ training=False,
+ user_profiles=user_profiles,
+ user_with_preferences=True,
+ agent_with_user_preferences=True,
+ agent_with_reflection=False,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ batch_size=args.batch_size,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "agent_with_reflection":
+ user_profiles = load_user_profiles()
+ run_user_profiles(
+ dataset_name=args.dataset,
+ training=False,
+ user_profiles=user_profiles,
+ user_with_preferences=True,
+ agent_with_user_preferences=True,
+ agent_with_reflection=True,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ batch_size=args.batch_size,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "agent_with_reflection_and_scaffolding":
+ user_profiles = load_user_profiles()
+ run_user_profiles(
+ dataset_name=args.dataset,
+ training=False,
+ user_profiles=user_profiles,
+ user_with_preferences=True,
+ agent_with_user_preferences=True,
+ agent_with_reflection=True,
+ with_scaffolding=True,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ batch_size=args.batch_size,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "agent_with_reflection_and_proper_scaffolding":
+ user_profiles = load_user_profiles()
+ run_user_profiles(
+ dataset_name=args.dataset,
+ training=False,
+ user_profiles=user_profiles,
+ user_with_preferences=True,
+ agent_with_user_preferences=True,
+ agent_with_reflection=True,
+ with_scaffolding=True,
+ with_proper_scaffolding=True,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ batch_size=args.batch_size,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ elif args.experiment_type == "training_data_with_user_profiles_with_preferences":
+ user_profiles = load_user_profiles(training=True)
+ run_user_profiles(
+ dataset_name=args.dataset,
+ training=True,
+ user_profiles=user_profiles,
+ user_with_preferences=True,
+ agent_with_user_preferences=True,
+ agent_with_reflection=True,
+ eval_size=args.eval_size,
+ max_turns=args.max_turns,
+ batch_size=args.batch_size,
+ user_model_name=args.user_model_name, user_api_base=args.user_api_base, user_api_key=args.user_api_key,
+ collaborator_model_name=args.collaborator_model_name, collaborator_api_base=args.collaborator_api_base, collaborator_api_key=args.collaborator_api_key,
+ judge_model_name=args.judge_model_name, judge_api_base=args.judge_api_base, judge_api_key=args.judge_api_key,
+ output_file=args.output_file
+ )
+ else:
+        raise ValueError(f"Invalid experiment type: {args.experiment_type}")
\ No newline at end of file