diff options
| author | blackhao <13851610112@163.com> | 2025-07-04 18:50:03 -0700 |
|---|---|---|
| committer | blackhao <13851610112@163.com> | 2025-07-04 18:50:03 -0700 |
| commit | 24c5437993a0199074e027785d11ae3647576123 (patch) | |
| tree | 53ad67e379a3cd0ddd1ff1b32e62860e637ac26f /explore_winobias.py | |
| parent | f35d5c8b7380b906a6819ef436a9d808041731fb (diff) | |
add winobis
Diffstat (limited to 'explore_winobias.py')
| -rw-r--r-- | explore_winobias.py | 243 |
1 files changed, 243 insertions, 0 deletions
#!/usr/bin/env python3
"""
WinoBias dataset exploration script.

Checks whether the dataset fits our mask-based gender debiasing needs:
it previews a few sentences from every configuration/split, counts how many
mention both an occupation and a gendered pronoun, converts those into
``[MASK]`` templates and writes a JSON report.
"""

import json
import re
from collections import Counter

import pandas as pd

# WinoBias ships four configurations.
CONFIGS = ['type1_anti', 'type1_pro', 'type2_anti', 'type2_pro']

# Keywords are lower-cased once because they are compared against lower-cased
# text; a literal 'CEO' would otherwise never match.
OCCUPATION_KEYWORDS = tuple(keyword.lower() for keyword in (
    'teacher', 'doctor', 'nurse', 'engineer', 'manager', 'secretary',
    'lawyer', 'developer', 'designer', 'cook', 'mechanic', 'driver',
    'farmer', 'clerk', 'CEO', 'assistant', 'guard', 'baker', 'analyst',
    'salesperson', 'receptionist', 'auditor', 'carpenter', 'mover',
    'accountant', 'janitor', 'librarian',
))

# Order matters: create_mask_template tries the pronouns in this order.
GENDER_PRONOUNS = ('he', 'she', 'his', 'her', 'him')

MASK_TOKEN = '[MASK]'

# How many rows of each split are inspected / printed / saved.
PREVIEW_ROWS = 10
PRINTED_ROWS = 3
SAVED_ROWS = 5

# One whole-word, case-insensitive pattern per pronoun, in priority order.
_PRONOUN_PATTERNS = tuple(
    re.compile(r'\b' + pronoun + r'\b', re.IGNORECASE)
    for pronoun in GENDER_PRONOUNS
)

# Single pattern matching any of the pronouns as a whole word.
_ANY_PRONOUN = re.compile(
    r'\b(?:' + '|'.join(GENDER_PRONOUNS) + r')\b', re.IGNORECASE
)


def contains_occupation(text):
    """Return True if `text` mentions one of OCCUPATION_KEYWORDS.

    Matching is case-insensitive and substring based, so inflected forms such
    as "teachers" are accepted as well.
    """
    lowered = text.lower()
    return any(keyword in lowered for keyword in OCCUPATION_KEYWORDS)


def contains_gender_pronoun(text):
    """Return True if `text` contains he/she/his/her/him as a whole word.

    Whole-word matching is required: a plain substring test would find "he"
    inside "the" or "teacher" and flag almost every sentence.
    """
    return _ANY_PRONOUN.search(text) is not None


def _split_sentences(split_data, limit):
    """Return the first `limit` rows of a split as space-joined sentences.

    `split_data` must expose `column_names` and support slicing that returns
    a mapping from column name to a list of values (as the objects returned
    by `load_dataset` are used here). Returns an empty list when the split
    has no 'tokens' column.
    """
    if 'tokens' not in split_data.column_names:
        return []
    batch = split_data[:limit]
    return [' '.join(tokens) for tokens in batch['tokens']]


def main():
    """Download every WinoBias configuration, preview it and run the analysis."""
    # Imported lazily so the text helpers in this module can be imported
    # without the `datasets` package being installed.
    from datasets import load_dataset

    print("🔍 正在下载WinoBias数据集...")

    all_examples = []
    all_configs_data = {}

    for config in CONFIGS:
        # Best effort: a configuration that fails to load (or to preview) is
        # reported and skipped so the remaining ones are still processed.
        try:
            print(f"\n📥 正在加载配置: {config}")
            dataset = load_dataset("wino_bias", config)
            all_configs_data[config] = dataset

            print(f"✅ {config} 下载成功!")

            # Show the structure of the dataset.
            print(f"\n📊 {config} 数据集信息:")
            print(f"- 数据集分片: {list(dataset.keys())}")

            for split_name, split_data in dataset.items():
                print(f"- {split_name}: {len(split_data)} 条数据")
                print(f"- 字段: {split_data.column_names}")

                if len(split_data) == 0:
                    continue

                # Slicing a split yields a dict of columns, so the number of
                # rows has to come from the split itself, not from len(batch).
                row_count = min(PREVIEW_ROWS, len(split_data))
                batch = split_data[:row_count]

                # Preview the first rows and turn their tokens into text.
                print(f"\n📝 {config}-{split_name} 前3个例子:")
                for row in range(min(PRINTED_ROWS, row_count)):
                    print(f"\n例子 {row + 1}:")

                    if 'tokens' in batch:
                        text = ' '.join(batch['tokens'][row])
                        print(f" 文本: {text}")
                        all_examples.append(text)

                    # Show the coreference annotation when available.
                    if 'coreference_clusters' in batch:
                        print(f" 指代关系: {batch['coreference_clusters'][row]}")

        except Exception as e:
            print(f"❌ 加载 {config} 失败: {e}")

    # Only analyse when at least one sentence was collected.
    if all_examples:
        analyze_dataset(all_examples, all_configs_data)
    else:
        print("❌ 未找到任何文本数据")


def analyze_dataset(all_examples, configs_data):
    """Analyse whether the collected sentences suit the mask task.

    Args:
        all_examples: sentences (plain strings) collected by `main`.
        configs_data: mapping from configuration name to its loaded dataset
            (itself a mapping from split name to split).

    Prints a report and hands the figures to `save_analysis_results`.
    """
    print("\n🎯 整体适用性分析:")
    print(f"总共 {len(all_examples)} 个例子")

    suitable_for_masking = []
    mask_examples = []

    print("\n🔍 分析前20个例子:")
    for number, example in enumerate(all_examples[:20], start=1):
        print(f"{number}. {example}")

        has_occupation = contains_occupation(example)
        has_gender_pronoun = contains_gender_pronoun(example)

        print(f" 职业词汇: {has_occupation}, 性别代词: {has_gender_pronoun}")

        # A sentence is usable only if it has both an occupation and a pronoun.
        if not (has_occupation and has_gender_pronoun):
            continue
        suitable_for_masking.append(example)

        masked_version = create_mask_template(example)
        if masked_version:
            mask_examples.append({
                'original': example,
                'masked': masked_version,
                'gender_words': extract_gender_words(example),
            })

    print(f"\n✅ 适合mask任务的例子数量: {len(suitable_for_masking)}")
    print(f"✅ 成功转换为mask格式的例子: {len(mask_examples)}")

    if mask_examples:
        print("\n📋 Mask转换示例 (前10个):")
        for number, example in enumerate(mask_examples[:10], start=1):
            print(f"\n{number}. 原句: {example['original']}")
            print(f" Mask: {example['masked']}")
            print(f" 性别词: {example['gender_words']}")

    # Per-configuration statistics over the first rows of every split.
    print("\n🔍 各配置数据特点:")
    for config_name, dataset in configs_data.items():
        print(f"\n{config_name}:")
        if not dataset:
            continue

        total_size = 0
        occupation_total = 0
        gender_total = 0

        for split_data in dataset.values():
            total_size += len(split_data)
            for sentence in _split_sentences(split_data, PREVIEW_ROWS):
                if contains_occupation(sentence):
                    occupation_total += 1
                if contains_gender_pronoun(sentence):
                    gender_total += 1

        print(f" - 总数据量: {total_size}")
        print(f" - 包含职业词汇的例子(前10个/分片): {occupation_total}")
        print(f" - 包含性别代词的例子(前10个/分片): {gender_total}")

    save_analysis_results(
        len(all_examples), len(suitable_for_masking), mask_examples, configs_data
    )


def create_mask_template(text):
    """Replace one gendered pronoun in `text` with ``[MASK]``.

    The pronouns are tried in the order he, she, his, her, him; the first
    occurrence of the first pronoun that is present (whole word, any case) is
    replaced. Note this is pattern priority, not position in the sentence.

    Returns:
        The masked sentence, or None when `text` contains no such pronoun.
    """
    for pattern in _PRONOUN_PATTERNS:
        masked_text, replaced = pattern.subn(MASK_TOKEN, text, count=1)
        if replaced:
            return masked_text
    return None


def extract_gender_words(text):
    """Return the distinct gendered pronouns in `text`, lower-cased and sorted.

    Sorting keeps the result (and the JSON report built from it) stable
    between runs.
    """
    return sorted({word.lower() for word in _ANY_PRONOUN.findall(text)})


def save_analysis_results(total_examples, suitable_count, mask_examples, configs_data):
    """Write the report to 'winobias_analysis.json' and print a summary.

    Args:
        total_examples: number of sentences that were collected.
        suitable_count: how many of them had an occupation and a pronoun.
        mask_examples: dicts with 'original', 'masked' and 'gender_words'.
        configs_data: mapping from configuration name to its loaded dataset.
    """
    # Collect a few raw sentences from every split of every configuration.
    all_text_samples = []
    for config_name, dataset in configs_data.items():
        for split_name, split_data in dataset.items():
            for text in _split_sentences(split_data, SAVED_ROWS):
                all_text_samples.append({
                    'config': config_name,
                    'split': split_name,
                    'text': text,
                })

    # max(..., 1) guards the ratios against division by zero.
    results = {
        'total_examples': total_examples,
        'suitable_for_masking': suitable_count,
        'mask_conversion_success_rate': len(mask_examples) / max(suitable_count, 1),
        'sample_mask_examples': mask_examples[:15],
        'configs_info': {},
        'all_text_samples': all_text_samples[:20],  # keep the first 20 samples
    }

    # Record size and columns of every split.
    for config_name, dataset in configs_data.items():
        if not dataset:
            continue
        results['configs_info'][config_name] = {
            'splits': {
                split_name: {
                    'size': len(split_data),
                    'columns': split_data.column_names,
                }
                for split_name, split_data in dataset.items()
            }
        }

    with open('winobias_analysis.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False, default=str)

    print("\n💾 详细分析结果已保存到 'winobias_analysis.json'")

    # Summary.
    print("\n📊 总结:")
    print(f"- 总例子数: {total_examples}")
    print(f"- 适合mask任务: {suitable_count} ({suitable_count/max(total_examples,1)*100:.1f}%)")
    print(f"- 成功转换率: {len(mask_examples)/max(suitable_count,1)*100:.1f}%")

    if len(mask_examples) > 10:
        print("✅ WinoBias数据集非常适合我们的需求!")
        print(" - 包含大量职业+性别代词的句子")
        print(" - 可以轻松转换为[MASK]格式")
        print(" - 有pro/anti版本用于对比性别偏见")
    else:
        print("⚠️ 可用数据较少,可能需要额外处理或寻找其他数据集")


if __name__ == "__main__":
    main()
\ No newline at end of file |
