summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/visualize.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/scripts/visualize.py')
-rw-r--r--collaborativeagents/scripts/visualize.py492
1 files changed, 492 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/visualize.py b/collaborativeagents/scripts/visualize.py
new file mode 100644
index 0000000..2cf7369
--- /dev/null
+++ b/collaborativeagents/scripts/visualize.py
@@ -0,0 +1,492 @@
+import json
+from itertools import zip_longest
+import textwrap
+
def load_data(filepath):
    """Load user records from a JSONL file.

    Parameters
    ----------
    filepath : str
        Path to a JSON-lines file; each non-empty line is one JSON object.
        Blank lines are skipped.

    Returns
    -------
    list[dict]
        The parsed objects, in file order.
    """
    users = []
    # Explicit UTF-8: the rendered output elsewhere in this tool uses
    # non-ASCII glyphs (✓ / ✗ / —), so relying on the platform-default
    # encoding could raise UnicodeDecodeError on some systems.
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                users.append(json.loads(line))
    return users
+
def format_conversation(conv, file_label):
    """Format a conversation into a list of display lines.

    Parameters
    ----------
    conv : dict
        May contain 'conversation' (list of {'role', 'content'} messages),
        'full_conversation_log' (per-turn metadata incl. the
        'enforce_preferences' flag), and 'evaluation'.
    file_label : str
        Header label identifying which run this conversation came from.

    Returns
    -------
    list[str]
        Lines ready for column layout / printing.
    """
    lines = []
    lines.append(f">>> {file_label} <<<")
    lines.append("")

    # The log may legitimately be absent (e.g. placeholder conversations);
    # default to an empty list so both uses below are safe.
    full_log = conv.get('full_conversation_log', [])

    def _is_enforced(entry):
        # The flag is stored either as a bool or as the string "True".
        val = entry.get('enforce_preferences')
        return val is True or val == "True"

    if 'conversation' in conv:
        for i, msg in enumerate(conv['conversation'], 1):
            role = msg.get('role', 'unknown').upper()
            content = msg.get('content', '')
            lines.append(f"[{i}] {role}:")
            # Split content into lines and indent
            for content_line in content.split('\n'):
                lines.append(f"    {content_line}")

            # Message i pairs with log entry i-2 (entry 0 corresponds to the
            # first assistant reply).  BUGFIX: the original check was only
            # `i-2 < len(...)`, so for i == 1 the index -1 silently read the
            # *last* log entry and could falsely tag the first message.
            if 0 <= i - 2 < len(full_log):
                if _is_enforced(full_log[i - 2]):
                    lines.append(f"<<<<< Enforced preference >>>>>")

            lines.append("")  # Empty line after each message

    # Format evaluation
    if 'evaluation' in conv:
        lines.append("[EVALUATION]")
        eval_data = conv['evaluation']

        if 'final_answer' in eval_data:
            lines.append(f"• Final Answer: {eval_data['final_answer']}")

        if 'accuracy' in eval_data:
            acc = eval_data['accuracy']['accuracy']
            acc_symbol = "✓" if acc == 1 else "✗"
            lines.append(f"• Accuracy: {acc} {acc_symbol}")

        # BUGFIX: was an unconditional conv['full_conversation_log'] lookup,
        # which raised KeyError for conversations without a log.
        num_enforced_preferences = len([m for m in full_log if _is_enforced(m)])
        lines.append(f"• Number of enforced preferences: {num_enforced_preferences}")

        if 'conversation_length' in eval_data:
            lines.append(f"• Length: {eval_data['conversation_length']} msgs")

    return lines
+
def wrap_lines(lines, width):
    """Wrap each line to at most *width* characters.

    Lines that already fit are passed through untouched; longer lines are
    broken with textwrap.wrap (long words may be split, hyphens are not
    treated as break points).
    """
    wrapped = []
    for text in lines:
        if len(text) > width:
            wrapped += textwrap.wrap(
                text,
                width=width,
                break_long_words=True,
                break_on_hyphens=False,
            )
        else:
            wrapped.append(text)
    return wrapped
+
def calculate_aggregate_stats(users_data):
    """Compute aggregate statistics across all users' conversations.

    Returns a 3-tuple: (average accuracy, average conversation length,
    average number of enforced preferences per conversation).  Each average
    is 0 when no data points were collected.
    """
    accuracies = []
    lengths = []
    enforced_counts = []

    def _mean(values):
        # Empty input averages to 0, matching the display code's expectations.
        return sum(values) / len(values) if values else 0

    for user in users_data:
        for conv in user.get('generated_conversations', []):
            evaluation = conv.get('evaluation', {})

            if 'accuracy' in evaluation:
                accuracies.append(evaluation['accuracy']['accuracy'])

            if 'conversation_length' in evaluation:
                lengths.append(evaluation['conversation_length'])

            if 'full_conversation_log' in conv:
                # The enforcement flag may be a bool or the string "True".
                count = sum(
                    1
                    for entry in conv['full_conversation_log']
                    if entry.get('enforce_preferences') in (True, "True")
                )
                enforced_counts.append(count)

    return _mean(accuracies), _mean(lengths), _mean(enforced_counts)
+
def print_side_by_side(conv1, conv2, label1, label2, col_width=60):
    """Render two conversations as two aligned text columns on stdout."""
    left = wrap_lines(format_conversation(conv1, label1), col_width)
    right = wrap_lines(format_conversation(conv2, label2), col_width)

    # Header row followed by a dashed separator.
    print(f"\n{label1:<{col_width}} | {label2}")
    divider = '-' * col_width
    print(f"{divider} | {divider}")

    # zip_longest pads the shorter column with '' so rows stay aligned;
    # only the left column needs explicit padding to the column width.
    for left_line, right_line in zip_longest(left, right, fillvalue=''):
        print(f"{left_line.ljust(col_width)} | {right_line}")
+
def print_side_by_side_3(conv1, conv2, conv3, label1, label2, label3, col_width=42):
    """Render three conversations as three aligned text columns on stdout."""
    columns = [
        wrap_lines(format_conversation(c, lbl), col_width)
        for c, lbl in ((conv1, label1), (conv2, label2), (conv3, label3))
    ]

    # Header row and dashed separator.
    print(f"\n{label1:<{col_width}} | {label2:<{col_width}} | {label3}")
    rule = '-' * col_width
    print(f"{rule} | {rule} | {rule}")

    # Pad the first two columns; the last one may run ragged.
    for a, b, c in zip_longest(*columns, fillvalue=''):
        print(f"{a.ljust(col_width)} | {b.ljust(col_width)} | {c}")
+
def format_detailed_full_log(conv, file_label):
    """Format detailed conversation including all fields from full_conversation_log

    For each log entry this emits: the raw 'response' text, an optional
    enforcement banner, a fixed-order bullet list of known keys, then any
    remaining keys grouped under "Other fields".  Ends with a pretty-printed
    dump of conv['evaluation'] when present.  Returns a list of text lines.
    """
    lines = []
    lines.append(f">>> {file_label} — FULL LOG <<<")
    lines.append("")

    if 'full_conversation_log' in conv and conv['full_conversation_log']:
        for j, msg in enumerate(conv['full_conversation_log'], 1):
            # Alternate roles starting with USER
            # NOTE(review): role is inferred purely from position (odd = USER);
            # confirm the log always alternates strictly with no dropped turns.
            role_label = 'USER' if j % 2 == 1 else 'ASSISTANT'
            lines.append(f"[{j}] {role_label}:")

            def is_enforced(value):
                # The flag may be stored as a bool or as "True"/"true" strings.
                return value is True or value == "True" or value == "true"

            # 1) Response first (as plain text)
            response_text = msg.get('response')
            if response_text is not None:
                for line in str(response_text).split('\n'):
                    lines.append(f"{line}")

            # 1a) Enforcement tag if applicable
            if 'enforce_preferences' in msg and is_enforced(msg['enforce_preferences']):
                lines.append("<<<<< Preferences Enforced >>>>>")

            # 2) Ordered keys as bulleted items
            # Known metadata keys, rendered in this fixed order when present.
            ordered_keys = [
                'preference_1_satisfied',
                'preference_2_satisfied',
                'preference_3_satisfied',
                'enforce_preferences',
                'draft_answer',
                'reasoning',
                'should_terminate',
            ]

            def append_bullet(key, value):
                # Dicts/lists are pretty-printed as JSON under the bullet;
                # scalars go on the bullet line with continuations indented.
                if isinstance(value, (dict, list)):
                    try:
                        pretty_value = json.dumps(value, indent=2, sort_keys=True, ensure_ascii=False)
                    except Exception:
                        # Non-JSON-serializable values fall back to str().
                        pretty_value = str(value)
                    lines.append(f" - {key}:")
                    for ln in pretty_value.split('\n'):
                        lines.append(f" {ln}")
                else:
                    value_str = str(value) if value is not None else ""
                    value_lines = value_str.split('\n') if value_str else [""]
                    # First line on the bullet
                    lines.append(f" - {key}: {value_lines[0]}")
                    # Continuation lines indented slightly further
                    for cont in value_lines[1:]:
                        lines.append(f" {cont}")

            for key in ordered_keys:
                if key in msg:
                    append_bullet(key, msg.get(key))

            # 3) Remaining keys grouped under Other fields
            shown_keys = set(['response'] + ordered_keys)
            remaining_keys = [k for k in msg.keys() if k not in shown_keys]
            if remaining_keys:
                lines.append(" - Other fields:")
                # Sorted for a deterministic display order.
                for k in sorted(remaining_keys):
                    v = msg[k]
                    if isinstance(v, (dict, list)):
                        try:
                            pretty_v = json.dumps(v, indent=2, sort_keys=True, ensure_ascii=False)
                        except Exception:
                            pretty_v = str(v)
                        lines.append(f" {k}:")
                        for ln in pretty_v.split('\n'):
                            lines.append(f" {ln}")
                    else:
                        v_str = str(v)
                        v_lines = v_str.split('\n') if v_str else [""]
                        lines.append(f" {k}: {v_lines[0]}")
                        for cont in v_lines[1:]:
                            lines.append(f" {cont}")

            # Blank line between log entries.
            lines.append("")
    else:
        lines.append("[No full_conversation_log available]")

    # Include evaluation details if present
    if 'evaluation' in conv:
        lines.append("[EVALUATION — FULL]")
        try:
            eval_pretty = json.dumps(conv['evaluation'], indent=2, sort_keys=True, ensure_ascii=False)
        except Exception:
            eval_pretty = str(conv['evaluation'])
        for line in eval_pretty.split('\n'):
            lines.append(f" {line}")

    return lines
+
def print_detailed_logs_3(conv1, conv2, conv3, label1, label2, label3, col_width=42):
    """Print the detailed full logs of three conversations side by side."""
    cols = [
        wrap_lines(format_detailed_full_log(c, lbl), col_width)
        for c, lbl in ((conv1, label1), (conv2, label2), (conv3, label3))
    ]

    # Header and separator rows.
    print(f"\n{label1:<{col_width}} | {label2:<{col_width}} | {label3}")
    sep = '-' * col_width
    print(f"{sep} | {sep} | {sep}")

    for a, b, c in zip_longest(*cols, fillvalue=''):
        print(f"{a.ljust(col_width)} | {b.ljust(col_width)} | {c}")
+
def print_user_info(user_data):
    """Print the user's profile: ID, persona, and preference list."""
    print("\n[USER PROFILE]")
    # Scalar fields share one rendering pattern.
    for key, header in (('i', 'User ID'), ('persona', 'Persona')):
        if key in user_data:
            print(f"{header}: {user_data[key]}")
    if 'preferences' in user_data:
        print("Preferences:")
        for pref in user_data['preferences']:
            print(f" - {pref}")
+
+
+
+
+
+
# ---------------------------------------------------------------------------
# Driver: load four experiment variants for one task and visualize the
# conversations of a single user side by side.
# ---------------------------------------------------------------------------
for task in ["bigcodebench"]: # ["math_500", "logiqa", "math_hard", "medqa", "mmlu"]:
    # NOTE(review): these f-strings contain no "{task}" placeholder — the
    # paths are pinned to the logiqa run even though the loop says
    # "bigcodebench".  Presumably a debugging leftover; confirm the intent.
    user_profiles_without_preferences_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b_temp_1/user_profiles_without_preferences/logiqa_llama70b_user_llama70b_agent_user_profiles_without_preferences_eval_size_20.jsonl"
    user_profiles_with_preferences_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b_temp_1/user_profiles_with_preferences/logiqa_llama70b_user_llama70b_agent_user_profiles_with_preferences_eval_size_20.jsonl"
    agent_with_userpreferences_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b_temp_1/agent_with_user_preferences/logiqa_llama70b_user_llama70b_agent_agent_with_user_preferences_eval_size_20_v2.jsonl"
    # NOTE(review): "agnet" is a typo, kept as-is because the name is used below.
    agnet_with_reflection_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b_temp_1/agent_with_reflection_v3/logiqa_llama70b_user_llama70b_agent_agent_with_reflection_eval_size_20.jsonl"


    file1_path = user_profiles_without_preferences_path
    file2_path = user_profiles_with_preferences_path
    file3_path = agent_with_userpreferences_path
    file4_path = agnet_with_reflection_path


    # Load users from all three files
    # (actually four; each is a JSONL with one user record per line)
    data1 = load_data(file1_path)
    data2 = load_data(file2_path)
    data3 = load_data(file3_path)
    data4 = load_data(file4_path)

    # NOTE(review): data2/3/4 entries overwrite data1 entries for duplicate
    # user ids here, so id_to_user_data1[id] is typically the *data4* record
    # for that id — verify this union is intentional and not a bug.
    id_to_user_data1 = {elem['i']: elem for elem in data1+data2+data3+data4}
    id_to_user_data2 = {elem['i']: elem for elem in data2}
    id_to_user_data3 = {elem['i']: elem for elem in data3}
    id_to_user_data4 = {elem['i']: elem for elem in data4}



    for id in id_to_user_data1:
        # NOTE(review): debug filter — only user 23 is rendered; drop this
        # line to visualize every user.  (`id` also shadows the builtin.)
        if id != 23: continue
        # Per-variant averages computed over this single user's conversations.
        user_avg_acc1, user_avg_len1, user_avg_enf1 = calculate_aggregate_stats([id_to_user_data1[id]])
        user_avg_acc2, user_avg_len2, user_avg_enf2 = calculate_aggregate_stats([id_to_user_data2[id]])
        user_avg_acc3, user_avg_len3, user_avg_enf3 = calculate_aggregate_stats([id_to_user_data3[id]])
        user_avg_acc4, user_avg_len4, user_avg_enf4 = calculate_aggregate_stats([id_to_user_data4[id]])


        # print user info
        print("\n" + "="*125 + "\n")
        print(f"### Task: {task}\n")
        print("LOGGING FOR USER ID: ", id)
        print_user_info(id_to_user_data1[id])

        # Print the average performance for id_to_user_data1[id]
        # Print the average performance for id_to_user_data2[id]
        print("\n" + "-"*125)
        print("COMPARISON FOR THIS USER")
        print("-"*125)


        print("\nUser Without Preferences:")
        print(f" Average Accuracy: {user_avg_acc1:.2f}")
        print(f" Average # Messages: {user_avg_len1:.2f}")
        print(f" Average # Enforced Preferences: {user_avg_enf1:.2f}")

        print("\nUser With Preferences:")
        print(f" Average Accuracy: {user_avg_acc2:.2f}")
        print(f" Average # Messages: {user_avg_len2:.2f}")
        print(f" Average # Enforced Preferences: {user_avg_enf2:.2f}")

        print("\nAgent With User Preferences:")
        print(f" Average Accuracy: {user_avg_acc3:.2f}")
        print(f" Average # Messages: {user_avg_len3:.2f}")
        print(f" Average # Enforced Preferences: {user_avg_enf3:.2f}")

        print("\nAgent With Reflection:")
        print(f" Average Accuracy: {user_avg_acc4:.2f}")
        print(f" Average # Messages: {user_avg_len4:.2f}")
        print(f" Average # Enforced Preferences: {user_avg_enf4:.2f}")


        # print conversations
        # Key each variant's conversations by problem statement so the same
        # problem aligns across variants.  NOTE(review): variant 4 is loaded
        # and averaged above but never rendered in the columns below.
        problem_to_conversation1 = {conv['sample']['problem']: conv for conv in id_to_user_data1[id]['generated_conversations']}
        problem_to_conversation2 = {conv['sample']['problem']: conv for conv in id_to_user_data2[id]['generated_conversations']}
        problem_to_conversation3 = {conv['sample']['problem']: conv for conv in id_to_user_data3[id]['generated_conversations']}

        for problem in problem_to_conversation1:
            print("\n" + "="*125)
            print(f"\n[PROBLEM]")
            print(problem)
            print(f"\n[SOLUTION]")
            print(problem_to_conversation1[problem]['sample']['solution'])
            print("\n" + "="*125)

            # Variant 3 may lack this problem; fall back to an empty
            # conversation so the third column still renders.
            print_side_by_side_3(
                problem_to_conversation1[problem],
                problem_to_conversation2[problem],
                problem_to_conversation3.get(problem, {'conversation': [], 'evaluation': {}}),
                "FILE 1 (WITHOUT PREFERENCES)",
                "FILE 2 (WITH PREFERENCES)",
                "FILE 3 (AGENT WITH USER PREFS)",
                col_width=55
            )

            # Detailed logs below with all fields
            print("\n" + "-"*125)
            print("DETAILED FULL LOGS")
            print("-"*125)
            print_detailed_logs_3(
                problem_to_conversation1[problem],
                problem_to_conversation2[problem],
                problem_to_conversation3.get(problem, {'conversation': [], 'evaluation': {}}),
                "FILE 1 (WITHOUT PREFERENCES)",
                "FILE 2 (WITH PREFERENCES)",
                "FILE 3 (AGENT WITH USER PREFS)",
                col_width=55
            )

        # break
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# # ==============================================================================
+# # SEPARATE SECTION: Per-User Statistics Averaged Over All Tasks
+# # ==============================================================================
+
+# print("\n" + "="*125)
+# print("="*125)
+# print("STATISTICS FOR EACH USER, AVERAGED OVER ALL TASKS")
+# print("="*125)
+# print("="*125 + "\n")
+
+# # Dictionary to store all data for each user across all tasks
+# user_to_all_data = {}
+
+# # Collect data for all users across all tasks
+# for task in ["math_500", "logiqa", "math_hard", "medqa", "mmlu"]:
+# user_profiles_without_preferences_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b/user_profiles_without_preferences/{task}_llama70b_user_llama70b_agent_user_profiles_without_preferences_eval_size_20.jsonl"
+# user_profiles_with_preferences_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b/user_profiles_with_preferences/{task}_llama70b_user_llama70b_agent_user_profiles_with_preferences_eval_size_20.jsonl"
+# agent_with_userpreferences_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b/agent_with_user_preferences/{task}_llama70b_user_llama70b_agent_agent_with_user_preferences_eval_size_20_v2.jsonl"
+# agent_with_reflection_path = f"/shared/storage-01/users/mehri2/mem/collaborativeagents/scripts/runs/llama70b/agent_with_reflection/{task}_llama70b_user_llama70b_agent_agent_with_reflection_eval_size_20.jsonl"
+
+# data1 = load_data(user_profiles_without_preferences_path)
+# data2 = load_data(user_profiles_with_preferences_path)
+# data3 = load_data(agent_with_userpreferences_path)
+# data4 = load_data(agent_with_reflection_path)
+
+# # For each user in this task, store their data
+# for user_data in data1:
+# user_id = user_data['i']
+# if user_id not in user_to_all_data:
+# user_to_all_data[user_id] = {
+# 'persona': user_data.get('persona'),
+# 'preferences': user_data.get('preferences'),
+# 'data1': [], # without preferences
+# 'data2': [], # with preferences
+# 'data3': [], # agent with user preferences
+# 'data4': [] # agent with reflection
+# }
+# user_to_all_data[user_id]['data1'].append(user_data)
+
+# for user_data in data2:
+# user_id = user_data['i']
+# if user_id in user_to_all_data:
+# user_to_all_data[user_id]['data2'].append(user_data)
+
+# for user_data in data3:
+# user_id = user_data['i']
+# if user_id in user_to_all_data:
+# user_to_all_data[user_id]['data3'].append(user_data)
+
+# for user_data in data4:
+# user_id = user_data['i']
+# if user_id in user_to_all_data:
+# user_to_all_data[user_id]['data4'].append(user_data)
+
+# # Now print statistics for each user, averaged over all tasks
+# for user_id in sorted(user_to_all_data.keys()):
+# user_info = user_to_all_data[user_id]
+
+# # Calculate aggregate stats across all tasks for this user
+# user_avg_acc1, user_avg_len1, user_avg_enf1 = calculate_aggregate_stats(user_info['data1'])
+# user_avg_acc2, user_avg_len2, user_avg_enf2 = calculate_aggregate_stats(user_info['data2'])
+# user_avg_acc3, user_avg_len3, user_avg_enf3 = calculate_aggregate_stats(user_info['data3'])
+# user_avg_acc4, user_avg_len4, user_avg_enf4 = calculate_aggregate_stats(user_info['data4'])
+
+# print("\n" + "="*125)
+# print(f"USER ID: {user_id}")
+# print("="*125)
+
+# # Print user profile info
+# if user_info['persona']:
+# print(f"Persona: {user_info['persona']}")
+# if user_info['preferences']:
+# print(f"Preferences:")
+# for preference in user_info['preferences']:
+# print(f" - {preference}")
+
+# print("\n" + "-"*125)
+# print("STATISTICS AVERAGED OVER ALL TASKS")
+# print("-"*125)
+
+# print("\nUser Without Preferences:")
+# print(f" Average Accuracy: {user_avg_acc1:.2f}")
+# print(f" Average # Messages: {user_avg_len1:.2f}")
+# print(f" Average # Enforced Preferences: {user_avg_enf1:.2f}")
+
+# print("\nUser With Preferences:")
+# print(f" Average Accuracy: {user_avg_acc2:.2f}")
+# print(f" Average # Messages: {user_avg_len2:.2f}")
+# print(f" Average # Enforced Preferences: {user_avg_enf2:.2f}")
+
+# print("\nAgent With User Preferences:")
+# print(f" Average Accuracy: {user_avg_acc3:.2f}")
+# print(f" Average # Messages: {user_avg_len3:.2f}")
+# print(f" Average # Enforced Preferences: {user_avg_enf3:.2f}")
+
+# print("\nAgent With Reflection:")
+# print(f" Average Accuracy: {user_avg_acc4:.2f}")
+# print(f" Average # Messages: {user_avg_len4:.2f}")
+# print(f" Average # Enforced Preferences: {user_avg_enf4:.2f}")
+
+# print("\n" + "="*125)
+# print("END OF PER-USER STATISTICS")
+# print("="*125 + "\n")
+