summaryrefslogtreecommitdiff
path: root/explore_winobias.py
diff options
context:
space:
mode:
authorblackhao <13851610112@163.com>2025-07-04 18:50:03 -0700
committerblackhao <13851610112@163.com>2025-07-04 18:50:03 -0700
commit24c5437993a0199074e027785d11ae3647576123 (patch)
tree53ad67e379a3cd0ddd1ff1b32e62860e637ac26f /explore_winobias.py
parentf35d5c8b7380b906a6819ef436a9d808041731fb (diff)
add winobis
Diffstat (limited to 'explore_winobias.py')
-rw-r--r--explore_winobias.py243
1 files changed, 243 insertions, 0 deletions
diff --git a/explore_winobias.py b/explore_winobias.py
new file mode 100644
index 0000000..291cc54
--- /dev/null
+++ b/explore_winobias.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+"""
+WinoBias数据集探索脚本
+用于分析数据集是否符合我们的mask-based gender debiasing需求
+"""
+
+from datasets import load_dataset
+import pandas as pd
+import json
+import re
+from collections import Counter
+
+def main():
+ print("🔍 正在下载WinoBias数据集...")
+
+ # WinoBias有4个配置
+ configs = ['type1_anti', 'type1_pro', 'type2_anti', 'type2_pro']
+
+ all_examples = []
+ all_configs_data = {}
+
+ for config in configs:
+ try:
+ print(f"\n📥 正在加载配置: {config}")
+ dataset = load_dataset("wino_bias", config)
+ all_configs_data[config] = dataset
+
+ print(f"✅ {config} 下载成功!")
+
+ # 查看数据集结构
+ print(f"\n📊 {config} 数据集信息:")
+ print(f"- 数据集分片: {list(dataset.keys())}")
+
+ for split_name, split_data in dataset.items():
+ print(f"- {split_name}: {len(split_data)} 条数据")
+ print(f"- 字段: {split_data.column_names}")
+
+ # 收集所有文本
+ if len(split_data) > 0:
+ examples = split_data[:min(10, len(split_data))]
+
+ # 查看数据结构并转换tokens为文本
+ print(f"\n📝 {config}-{split_name} 前3个例子:")
+ for i in range(min(3, len(examples))):
+ print(f"\n例子 {i+1}:")
+
+ # 将tokens转换为文本
+ if 'tokens' in examples:
+ text = ' '.join(examples['tokens'][i])
+ print(f" 文本: {text}")
+ all_examples.append(text)
+
+ # 显示coreference信息
+ if 'coreference_clusters' in examples:
+ print(f" 指代关系: {examples['coreference_clusters'][i]}")
+
+ except Exception as e:
+ print(f"❌ 加载 {config} 失败: {e}")
+
+ # 如果有数据,进行分析
+ if all_examples:
+ analyze_dataset(all_examples, all_configs_data)
+ else:
+ print("❌ 未找到任何文本数据")
+
+def analyze_dataset(all_examples, configs_data):
+ """分析数据集是否适合mask任务"""
+
+ print(f"\n🎯 整体适用性分析:")
+ print(f"总共 {len(all_examples)} 个例子")
+
+ # 查找包含职业的句子
+ occupation_keywords = [
+ 'teacher', 'doctor', 'nurse', 'engineer', 'manager', 'secretary',
+ 'lawyer', 'developer', 'designer', 'cook', 'mechanic', 'driver',
+ 'farmer', 'clerk', 'CEO', 'assistant', 'guard', 'baker', 'analyst',
+ 'salesperson', 'receptionist', 'auditor', 'carpenter', 'mover',
+ 'accountant', 'janitor', 'librarian'
+ ]
+
+ gender_pronouns = ['he', 'she', 'his', 'her', 'him']
+
+ suitable_for_masking = []
+ mask_examples = []
+
+ print(f"\n🔍 分析前20个例子:")
+ for i, example in enumerate(all_examples[:20]):
+ print(f"{i+1}. {example}")
+
+ text_lower = example.lower()
+
+ # 检查是否包含职业词汇
+ has_occupation = any(occ in text_lower for occ in occupation_keywords)
+
+ # 检查是否包含性别代词
+ has_gender_pronoun = any(pronoun in text_lower for pronoun in gender_pronouns)
+
+ print(f" 职业词汇: {has_occupation}, 性别代词: {has_gender_pronoun}")
+
+ # 检查句子结构是否适合转换为mask
+ if has_occupation and has_gender_pronoun:
+ suitable_for_masking.append(example)
+
+ # 创建mask版本
+ masked_version = create_mask_template(example)
+ if masked_version:
+ mask_examples.append({
+ 'original': example,
+ 'masked': masked_version,
+ 'gender_words': extract_gender_words(example)
+ })
+
+ print(f"\n✅ 适合mask任务的例子数量: {len(suitable_for_masking)}")
+ print(f"✅ 成功转换为mask格式的例子: {len(mask_examples)}")
+
+ if mask_examples:
+ print(f"\n📋 Mask转换示例 (前10个):")
+ for i, example in enumerate(mask_examples[:10]):
+ print(f"\n{i+1}. 原句: {example['original']}")
+ print(f" Mask: {example['masked']}")
+ print(f" 性别词: {example['gender_words']}")
+
+ # 分析不同配置的特点
+ print(f"\n🔍 各配置数据特点:")
+ for config_name, dataset in configs_data.items():
+ print(f"\n{config_name}:")
+ if dataset and len(dataset) > 0:
+ # 处理所有分片的数据
+ total_size = 0
+ occupation_total = 0
+ gender_total = 0
+
+ for split_name, split_data in dataset.items():
+ total_size += len(split_data)
+
+ # 分析前10个例子
+ examples = split_data[:min(10, len(split_data))]
+ if 'tokens' in split_data.column_names:
+ for tokens in examples['tokens']:
+ text = ' '.join(tokens).lower()
+ if any(occ in text for occ in occupation_keywords):
+ occupation_total += 1
+ if any(pronoun in text for pronoun in gender_pronouns):
+ gender_total += 1
+
+ print(f" - 总数据量: {total_size}")
+ print(f" - 包含职业词汇的例子(前10个/分片): {occupation_total}")
+ print(f" - 包含性别代词的例子(前10个/分片): {gender_total}")
+
+ # 保存分析结果
+ save_analysis_results(len(all_examples), len(suitable_for_masking), mask_examples, configs_data)
+
+def create_mask_template(text):
+ """将句子转换为mask格式"""
+ patterns = [
+ (r'\bhe\b', '[MASK]'),
+ (r'\bshe\b', '[MASK]'),
+ (r'\bhis\b', '[MASK]'),
+ (r'\bher\b', '[MASK]'),
+ (r'\bhim\b', '[MASK]')
+ ]
+
+ masked_text = text
+ replacements_made = 0
+
+ for pattern, replacement in patterns:
+ if re.search(pattern, masked_text, flags=re.IGNORECASE):
+ masked_text = re.sub(pattern, replacement, masked_text, flags=re.IGNORECASE, count=1)
+ replacements_made += 1
+ break # 只替换第一个找到的
+
+ return masked_text if replacements_made > 0 else None
+
+def extract_gender_words(text):
+ """提取文本中的性别词汇"""
+ gender_words = []
+ patterns = [r'\bhe\b', r'\bshe\b', r'\bhis\b', r'\bher\b', r'\bhim\b']
+
+ for pattern in patterns:
+ matches = re.findall(pattern, text, flags=re.IGNORECASE)
+ gender_words.extend(matches)
+
+ return list(set([word.lower() for word in gender_words]))
+
+def save_analysis_results(total_examples, suitable_count, mask_examples, configs_data):
+ """保存分析结果"""
+
+ # 提取所有文本样例
+ all_text_samples = []
+
+ for config_name, dataset in configs_data.items():
+ for split_name, split_data in dataset.items():
+ examples = split_data[:5] # 每个分片取5个例子
+ if 'tokens' in split_data.column_names:
+ for tokens in examples['tokens']:
+ text = ' '.join(tokens)
+ all_text_samples.append({
+ 'config': config_name,
+ 'split': split_name,
+ 'text': text
+ })
+
+ results = {
+ 'total_examples': total_examples,
+ 'suitable_for_masking': suitable_count,
+ 'mask_conversion_success_rate': len(mask_examples) / max(suitable_count, 1),
+ 'sample_mask_examples': mask_examples[:15],
+ 'configs_info': {},
+ 'all_text_samples': all_text_samples[:20] # 保存前20个文本样例
+ }
+
+ # 保存各配置信息
+ for config_name, dataset in configs_data.items():
+ if dataset:
+ config_info = {'splits': {}}
+ for split_name, split_data in dataset.items():
+ config_info['splits'][split_name] = {
+ 'size': len(split_data),
+ 'columns': split_data.column_names
+ }
+ results['configs_info'][config_name] = config_info
+
+ with open('winobias_analysis.json', 'w', encoding='utf-8') as f:
+ json.dump(results, f, indent=2, ensure_ascii=False, default=str)
+
+ print(f"\n💾 详细分析结果已保存到 'winobias_analysis.json'")
+
+ # 输出总结
+ print(f"\n📊 总结:")
+ print(f"- 总例子数: {total_examples}")
+ print(f"- 适合mask任务: {suitable_count} ({suitable_count/max(total_examples,1)*100:.1f}%)")
+ print(f"- 成功转换率: {len(mask_examples)/max(suitable_count,1)*100:.1f}%")
+
+ if len(mask_examples) > 10:
+ print(f"✅ WinoBias数据集非常适合我们的需求!")
+ print(f" - 包含大量职业+性别代词的句子")
+ print(f" - 可以轻松转换为[MASK]格式")
+ print(f" - 有pro/anti版本用于对比性别偏见")
+ else:
+ print(f"⚠️ 可用数据较少,可能需要额外处理或寻找其他数据集")
+
+if __name__ == "__main__":
+ main() \ No newline at end of file