"""Visualize and compare collaborative-agent conversation runs side by side.

Loads JSONL result files for several experiment conditions (user profiles
with/without preferences, agent with user preferences, agent with reflection)
and prints, per user and per problem, the transcripts in fixed-width columns,
enforced-preference markers, detailed full logs, and aggregate statistics
(accuracy, conversation length, enforced-preference counts).
"""

import json
import textwrap
from itertools import zip_longest


def _is_enforced(value):
    """Return True when an ``enforce_preferences`` field marks enforcement.

    The logs store this flag inconsistently (bool ``True`` or the strings
    ``"True"``/``"true"``), so all spellings are accepted here.
    """
    return value is True or value in ("True", "true")


def load_data(filepath):
    """Load one JSON object per non-blank line from a JSONL file."""
    users = []
    with open(filepath, 'r') as f:
        for line in f:
            if line.strip():
                users.append(json.loads(line))
    return users


def format_conversation(conv, file_label):
    """Format a conversation dict into a list of display lines.

    Renders each message, tags turns whose matching full-log entry enforced
    preferences, and appends an evaluation summary when present.
    """
    lines = [f">>> {file_label} <<<", ""]
    # .get() guard: placeholder conversations (e.g. a problem missing from one
    # run) carry an evaluation but no full_conversation_log; indexing the key
    # directly used to raise KeyError for them.
    full_log = conv.get('full_conversation_log', [])

    for i, msg in enumerate(conv.get('conversation', []), 1):
        role = msg.get('role', 'unknown').upper()
        lines.append(f"[{i}] {role}:")
        # Indent the message content under its header.
        for content_line in msg.get('content', '').split('\n'):
            lines.append(f"    {content_line}")

        # Conversation message i corresponds to full-log entry i-2 (the log
        # starts at the second message).  Guard the lower bound: without it,
        # i == 1 produced index -1 and wrongly tagged the first message with
        # the enforcement state of the LAST log entry.
        log_index = i - 2
        if 0 <= log_index < len(full_log) and _is_enforced(
                full_log[log_index].get('enforce_preferences')):
            lines.append("<<<<< Enforced preference >>>>>")

        lines.append("")  # blank line after each message

    if 'evaluation' in conv:
        lines.append("[EVALUATION]")
        eval_data = conv['evaluation']

        if 'final_answer' in eval_data:
            lines.append(f"• Final Answer: {eval_data['final_answer']}")

        if 'accuracy' in eval_data:
            acc = eval_data['accuracy']['accuracy']
            acc_symbol = "✓" if acc == 1 else "✗"
            lines.append(f"• Accuracy: {acc} {acc_symbol}")

        num_enforced_preferences = sum(
            1 for message in full_log
            if _is_enforced(message.get('enforce_preferences')))
        lines.append(f"• Number of enforced preferences: {num_enforced_preferences}")

        if 'conversation_length' in eval_data:
            lines.append(f"• Length: {eval_data['conversation_length']} msgs")

    return lines


def wrap_lines(lines, width):
    """Hard-wrap each line to ``width`` columns; short lines pass through."""
    wrapped = []
    for line in lines:
        if len(line) <= width:
            wrapped.append(line)
        else:
            wrapped.extend(textwrap.wrap(
                line, width=width,
                break_long_words=True, break_on_hyphens=False))
    return wrapped


def calculate_aggregate_stats(users_data):
    """Return (avg accuracy, avg length, avg enforced-preference count).

    Aggregates over every generated conversation of every user; each average
    is 0 when that statistic has no data points.
    """
    all_accuracies = []
    all_lengths = []
    all_enforced_counts = []

    for user_data in users_data:
        for conv in user_data.get('generated_conversations', []):
            evaluation = conv.get('evaluation', {})
            if 'accuracy' in evaluation:
                all_accuracies.append(evaluation['accuracy']['accuracy'])
            if 'conversation_length' in evaluation:
                all_lengths.append(evaluation['conversation_length'])
            if 'full_conversation_log' in conv:
                all_enforced_counts.append(sum(
                    1 for msg in conv['full_conversation_log']
                    if _is_enforced(msg.get('enforce_preferences'))))

    def _mean(values):
        return sum(values) / len(values) if values else 0

    return (_mean(all_accuracies), _mean(all_lengths),
            _mean(all_enforced_counts))


def _print_columns(line_lists, labels, col_width):
    """Print pre-wrapped line lists side by side as fixed-width columns.

    All columns except the last are padded to ``col_width``; the last one is
    left ragged, matching the original per-arity printers this replaces.
    """
    header_cells = [label.ljust(col_width) for label in labels[:-1]] + [labels[-1]]
    print("\n" + " | ".join(header_cells))
    print(" | ".join(["-" * col_width] * len(labels)))
    for row in zip_longest(*line_lists, fillvalue=''):
        cells = [cell.ljust(col_width) for cell in row[:-1]] + [row[-1]]
        print(" | ".join(cells))


def print_side_by_side(conv1, conv2, label1, label2, col_width=60):
    """Print two conversations side by side in ``col_width`` columns."""
    _print_columns(
        [wrap_lines(format_conversation(c, lab), col_width)
         for c, lab in ((conv1, label1), (conv2, label2))],
        [label1, label2], col_width)


def print_side_by_side_3(conv1, conv2, conv3, label1, label2, label3,
                         col_width=42):
    """Print three conversations side by side in ``col_width`` columns."""
    _print_columns(
        [wrap_lines(format_conversation(c, lab), col_width)
         for c, lab in ((conv1, label1), (conv2, label2), (conv3, label3))],
        [label1, label2, label3], col_width)


def format_detailed_full_log(conv, file_label):
    """Format every field of each full_conversation_log entry for display.

    Per entry: the response text first, an enforcement tag when applicable,
    a fixed-order bullet list of well-known keys, then any remaining keys
    grouped (sorted) under "Other fields".  Ends with the pretty-printed
    evaluation dict when present.
    """
    lines = [f">>> {file_label} — FULL LOG <<<", ""]

    def _pretty(value):
        # sort_keys gives stable output; fall back to str() for values that
        # are not JSON-serializable.
        try:
            return json.dumps(value, indent=2, sort_keys=True,
                              ensure_ascii=False)
        except Exception:
            return str(value)

    # Well-known keys, rendered first in this fixed order.
    ordered_keys = [
        'preference_1_satisfied',
        'preference_2_satisfied',
        'preference_3_satisfied',
        'enforce_preferences',
        'draft_answer',
        'reasoning',
        'should_terminate',
    ]

    full_log = conv.get('full_conversation_log') or []
    if full_log:
        for j, msg in enumerate(full_log, 1):
            # Roles alternate; the log is assumed to start with USER.
            role_label = 'USER' if j % 2 == 1 else 'ASSISTANT'
            lines.append(f"[{j}] {role_label}:")

            # 1) Response text first, verbatim.
            response_text = msg.get('response')
            if response_text is not None:
                lines.extend(str(response_text).split('\n'))

            # 1a) Enforcement tag if applicable.
            if 'enforce_preferences' in msg and _is_enforced(
                    msg['enforce_preferences']):
                lines.append("<<<<< Preferences Enforced >>>>>")

            # 2) Well-known keys as ordered bullets.
            for key in ordered_keys:
                if key not in msg:
                    continue
                value = msg[key]
                if isinstance(value, (dict, list)):
                    lines.append(f"  - {key}:")
                    lines.extend(f"      {ln}"
                                 for ln in _pretty(value).split('\n'))
                else:
                    value_lines = ("" if value is None else str(value)).split('\n')
                    lines.append(f"  - {key}: {value_lines[0]}")
                    # Continuation lines indented slightly further.
                    lines.extend(f"      {cont}" for cont in value_lines[1:])

            # 3) Remaining keys grouped under "Other fields".
            shown_keys = set(['response'] + ordered_keys)
            remaining_keys = [k for k in msg if k not in shown_keys]
            if remaining_keys:
                lines.append("  - Other fields:")
                for k in sorted(remaining_keys):
                    v = msg[k]
                    if isinstance(v, (dict, list)):
                        lines.append(f"      {k}:")
                        lines.extend(f"        {ln}"
                                     for ln in _pretty(v).split('\n'))
                    else:
                        v_lines = str(v).split('\n')
                        lines.append(f"      {k}: {v_lines[0]}")
                        lines.extend(f"        {cont}" for cont in v_lines[1:])

            lines.append("")
    else:
        lines.append("[No full_conversation_log available]")

    if 'evaluation' in conv:
        lines.append("[EVALUATION — FULL]")
        lines.extend(f"  {line}"
                     for line in _pretty(conv['evaluation']).split('\n'))

    return lines


def print_detailed_logs_3(conv1, conv2, conv3, label1, label2, label3,
                          col_width=42):
    """Print detailed full logs for three conversations side by side."""
    _print_columns(
        [wrap_lines(format_detailed_full_log(c, lab), col_width)
         for c, lab in ((conv1, label1), (conv2, label2), (conv3, label3))],
        [label1, label2, label3], col_width)


def print_user_info(user_data):
    """Print a user's id, persona, and preference list, when present."""
    print("\n[USER PROFILE]")
    if 'i' in user_data:
        print(f"User ID: {user_data['i']}")
    if 'persona' in user_data:
        print(f"Persona: {user_data['persona']}")
    if 'preferences' in user_data:
        print("Preferences:")
        for preference in user_data['preferences']:
            print(f"  - {preference}")


# Root of the experiment-run outputs compared below.
_RUNS_ROOT = ("/shared/storage-01/users/mehri2/mem/collaborativeagents/"
              "scripts/runs/llama70b_temp_1")


def _print_stats(heading, accuracy, length, enforced):
    """Print one condition's averaged statistics."""
    print(f"\n{heading}:")
    print(f"  Average Accuracy: {accuracy:.2f}")
    print(f"  Average # Messages: {length:.2f}")
    print(f"  Average # Enforced Preferences: {enforced:.2f}")


def main():
    """Compare conversation runs across experiment conditions, per task."""
    for task in ["bigcodebench"]:  # ["math_500", "logiqa", "math_hard", "medqa", "mmlu"]:
        # NOTE(review): these paths hard-code "logiqa" rather than using
        # `task` — confirm whether {task} was intended in the filenames.
        without_prefs_path = f"{_RUNS_ROOT}/user_profiles_without_preferences/logiqa_llama70b_user_llama70b_agent_user_profiles_without_preferences_eval_size_20.jsonl"
        with_prefs_path = f"{_RUNS_ROOT}/user_profiles_with_preferences/logiqa_llama70b_user_llama70b_agent_user_profiles_with_preferences_eval_size_20.jsonl"
        agent_prefs_path = f"{_RUNS_ROOT}/agent_with_user_preferences/logiqa_llama70b_user_llama70b_agent_agent_with_user_preferences_eval_size_20_v2.jsonl"
        agent_reflection_path = f"{_RUNS_ROOT}/agent_with_reflection_v3/logiqa_llama70b_user_llama70b_agent_agent_with_reflection_eval_size_20.jsonl"

        data1 = load_data(without_prefs_path)
        data2 = load_data(with_prefs_path)
        data3 = load_data(agent_prefs_path)
        data4 = load_data(agent_reflection_path)

        # NOTE(review): this union keeps the LAST occurrence of each user id
        # (so data4 wins on collisions), yet the entry is printed below under
        # "User Without Preferences" — confirm whether a data1-only mapping
        # was intended.
        id_to_user_data1 = {elem['i']: elem
                            for elem in data1 + data2 + data3 + data4}
        id_to_user_data2 = {elem['i']: elem for elem in data2}
        id_to_user_data3 = {elem['i']: elem for elem in data3}
        id_to_user_data4 = {elem['i']: elem for elem in data4}

        for user_id in id_to_user_data1:
            # Debug filter: only user 23 is inspected for now.
            if user_id != 23:
                continue
            stats1 = calculate_aggregate_stats([id_to_user_data1[user_id]])
            stats2 = calculate_aggregate_stats([id_to_user_data2[user_id]])
            stats3 = calculate_aggregate_stats([id_to_user_data3[user_id]])
            stats4 = calculate_aggregate_stats([id_to_user_data4[user_id]])

            # Per-user header and profile.
            print("\n" + "=" * 125 + "\n")
            print(f"### Task: {task}\n")
            print("LOGGING FOR USER ID: ", user_id)
            print_user_info(id_to_user_data1[user_id])

            # Averaged performance for each condition.
            print("\n" + "-" * 125)
            print("COMPARISON FOR THIS USER")
            print("-" * 125)
            _print_stats("User Without Preferences", *stats1)
            _print_stats("User With Preferences", *stats2)
            _print_stats("Agent With User Preferences", *stats3)
            _print_stats("Agent With Reflection", *stats4)

            # Per-problem transcripts, keyed by problem text.
            problem_to_conv1 = {
                conv['sample']['problem']: conv
                for conv in id_to_user_data1[user_id]['generated_conversations']}
            problem_to_conv2 = {
                conv['sample']['problem']: conv
                for conv in id_to_user_data2[user_id]['generated_conversations']}
            problem_to_conv3 = {
                conv['sample']['problem']: conv
                for conv in id_to_user_data3[user_id]['generated_conversations']}

            # Placeholder for problems absent from the third run.
            empty_conv = {'conversation': [], 'evaluation': {}}
            for problem in problem_to_conv1:
                print("\n" + "=" * 125)
                print("\n[PROBLEM]")
                print(problem)
                print("\n[SOLUTION]")
                print(problem_to_conv1[problem]['sample']['solution'])
                print("\n" + "=" * 125)

                print_side_by_side_3(
                    problem_to_conv1[problem],
                    problem_to_conv2[problem],
                    problem_to_conv3.get(problem, empty_conv),
                    "FILE 1 (WITHOUT PREFERENCES)",
                    "FILE 2 (WITH PREFERENCES)",
                    "FILE 3 (AGENT WITH USER PREFS)",
                    col_width=55,
                )

                # Detailed logs below, with all fields.
                print("\n" + "-" * 125)
                print("DETAILED FULL LOGS")
                print("-" * 125)
                print_detailed_logs_3(
                    problem_to_conv1[problem],
                    problem_to_conv2[problem],
                    problem_to_conv3.get(problem, empty_conv),
                    "FILE 1 (WITHOUT PREFERENCES)",
                    "FILE 2 (WITH PREFERENCES)",
                    "FILE 3 (AGENT WITH USER PREFS)",
                    col_width=55,
                )


# A commented-out "per-user statistics averaged over all tasks" report
# previously lived here; restore it from VCS history if needed.

if __name__ == "__main__":
    main()