Diffstat (limited to 'enhance_gee_processor.py')
-rwxr-xr-x  enhance_gee_processor.py  235
1 file changed, 235 insertions, 0 deletions
diff --git a/enhance_gee_processor.py b/enhance_gee_processor.py
new file mode 100755
index 0000000..173b5aa
--- /dev/null
+++ b/enhance_gee_processor.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+"""
+Enhance the GEE processor to support real datasets.
+Supports Numina math reasoning data and other real data sources.
+"""
+import pandas as pd
+import numpy as np
+from pathlib import Path
+import json
+import re
+import sys
+sys.path.append('.')
+
+from dataset.gee_processor import GEEProcessor
+
+class EnhancedGEEProcessor(GEEProcessor):
+ """增强版GEE处理器,支持多种真实数据源"""
+
+ def __init__(self, tokenizer):
+ super().__init__(tokenizer)
+ self.name_patterns = {
+ 'male': ['Tom', 'John', 'Mike', 'Bob', 'David', 'James', 'Robert', 'Michael', 'William', 'Richard'],
+ 'female': ['Sarah', 'Lisa', 'Emma', 'Alice', 'Mary', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan']
+ }
+
+    def process_numina_data(self, file_path: str, target_size: int = 1000) -> list:
+        """Process Numina math reasoning data."""
+        print(f"📊 Processing Numina data: {file_path}")
+
+        # Read the parquet file
+        df = pd.read_parquet(file_path)
+        print(f"Original sample count: {len(df)}")
+
+        # Random subsampling
+        if len(df) > target_size:
+            df = df.sample(n=target_size, random_state=42)
+            print(f"Sample count after subsampling: {len(df)}")
+
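+        # Assumed Numina parquet layout (a sketch; column names can differ
+        # between exports): one row per problem, with 'problem'/'solution'
+        # columns and 'question'/'answer' as the fallbacks handled below.
+        # A quick standalone check of a file:
+        #
+        #   pd.read_parquet(file_path).columns.tolist()
+        #   # e.g. ['problem', 'solution', ...]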
+        processed_data = []
+        for idx, row in df.iterrows():
+            # Extract the problem and its answer
+            problem = row.get('problem', row.get('question', ''))
+            solution = row.get('solution', row.get('answer', ''))
+
+            if problem and solution:
+                # Generate gender-balanced variants
+                male_version = self._genderize_text(problem, 'male')
+                female_version = self._genderize_text(problem, 'female')
+
+                processed_data.extend([
+                    {
+                        'input': self.apply_chat_template(male_version),
+                        'output': solution,
+                        'gender': 'male',
+                        'original_id': idx,
+                        'source': 'numina'
+                    },
+                    {
+                        'input': self.apply_chat_template(female_version),
+                        'output': solution,
+                        'gender': 'female',
+                        'original_id': idx,
+                        'source': 'numina'
+                    }
+                ])
+
+        print(f"✅ Done, generated {len(processed_data)} samples")
+        return processed_data
+
+    def process_1shot_rlvr_data(self, file_path: str) -> list:
+        """Process 1shot RLVR data."""
+        print(f"📊 Processing 1shot RLVR data: {file_path}")
+
+        df = pd.read_parquet(file_path)
+        print(f"Original sample count: {len(df)}")
+
+        processed_data = []
+        for idx, row in df.iterrows():
+            # Adjust the column names to the actual data layout
+            prompt = row.get('prompt', row.get('input', ''))
+
+            if prompt:
+                # Generate one variant per gender
+                for gender in ['male', 'female']:
+                    genderized_prompt = self._genderize_text(prompt, gender)
+
+                    processed_data.append({
+                        'input': self.apply_chat_template(genderized_prompt),
+                        'gender': gender,
+                        'original_id': idx,
+                        'source': '1shot_rlvr'
+                    })
+
+        print(f"✅ Done, generated {len(processed_data)} samples")
+        return processed_data
+
+    def _genderize_text(self, text: str, target_gender: str) -> str:
+        """Rewrite gendered references in the text to the target gender."""
+
+        # Pick the name pool for the target gender
+        names = self.name_patterns[target_gender]
+
+        # Replace generic placeholders
+        if '[NAME]' in text or '{name}' in text:
+            name = np.random.choice(names)
+            text = text.replace('[NAME]', name).replace('{name}', name)
+            return text
+
+        # Detect existing gendered names and swap them out
+        all_male_names = self.name_patterns['male']
+        all_female_names = self.name_patterns['female']
+
+        for male_name in all_male_names:
+            if male_name in text:
+                replacement = np.random.choice(names)
+                text = text.replace(male_name, replacement)
+                break
+
+        for female_name in all_female_names:
+            if female_name in text:
+                replacement = np.random.choice(names)
+                text = text.replace(female_name, replacement)
+                break
+
+        # If no name was found, add one at random
+        if not any(name in text for name in all_male_names + all_female_names):
+            name = np.random.choice(names)
+            # Insert the name at a suitable spot
+            if "person" in text.lower():
+                text = text.replace("person", name)
+            elif "student" in text.lower():
+                text = text.replace("student", f"student named {name}")
+            elif "someone" in text.lower():
+                text = text.replace("someone", name)
+            else:
+                # Prepend it to the problem statement
+                text = f"{name} is working on this problem: {text}"
+
+        return text
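+
+    # Illustrative behaviour of _genderize_text (proc stands for any
+    # EnhancedGEEProcessor instance; names are drawn at random from
+    # self.name_patterns, so the exact name varies between runs):
+    #
+    #   proc._genderize_text("Tom has 3 apples.", "female")
+    #   # -> "Sarah has 3 apples."  (the existing male name is swapped out)
+    #   proc._genderize_text("A student solves x + 2 = 5.", "male")
+    #   # -> "A student named John solves x + 2 = 5."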
+
+    def create_balanced_dataset(self, data_sources: list, balance_method: str = 'oversample') -> list:
+        """Build a gender-balanced dataset from the configured sources."""
+
+        all_data = []
+        for source_config in data_sources:
+            source_type = source_config['type']
+            file_path = source_config['path']
+
+            if source_type == 'numina':
+                data = self.process_numina_data(file_path, source_config.get('target_size', 1000))
+            elif source_type == '1shot_rlvr':
+                data = self.process_1shot_rlvr_data(file_path)
+            else:
+                print(f"⚠️ Unknown data source type: {source_type}")
+                continue
+
+            all_data.extend(data)
+
+        # Tally the gender distribution
+        male_data = [item for item in all_data if item['gender'] == 'male']
+        female_data = [item for item in all_data if item['gender'] == 'female']
+
+        print(f"\n📊 Raw distribution:")
+        print(f"  male samples: {len(male_data)}")
+        print(f"  female samples: {len(female_data)}")
+
+        # Balance the two groups
+        if balance_method == 'oversample':
+            target_size = max(len(male_data), len(female_data))
+
+            # Guard against empty lists to avoid division by zero
+            if male_data and len(male_data) < target_size:
+                male_data = male_data * (target_size // len(male_data)) + male_data[:target_size % len(male_data)]
+            if female_data and len(female_data) < target_size:
+                female_data = female_data * (target_size // len(female_data)) + female_data[:target_size % len(female_data)]
+
+        elif balance_method == 'undersample':
+            target_size = min(len(male_data), len(female_data))
+            male_data = male_data[:target_size]
+            female_data = female_data[:target_size]
+
+        balanced_data = male_data + female_data
+        np.random.shuffle(balanced_data)
+
+        print(f"📊 Balanced distribution:")
+        male_count = sum(1 for item in balanced_data if item['gender'] == 'male')
+        female_count = sum(1 for item in balanced_data if item['gender'] == 'female')
+        print(f"  male samples: {male_count}")
+        print(f"  female samples: {female_count}")
+        print(f"  total samples: {len(balanced_data)}")
+
+        return balanced_data
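+
+    # Oversampling arithmetic, for reference: with 300 male and 500 female
+    # samples, target_size = 500 and the male list is rebuilt as
+    # male_data * (500 // 300) + male_data[:500 % 300], i.e. 300 + 200 = 500 items.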
+
+def main():
+    """Example usage."""
+    from transformers import AutoTokenizer
+
+    print("🔧 Testing the enhanced GEE processor...")
+
+    # Initialization
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct", trust_remote_code=True)
+    processor = EnhancedGEEProcessor(tokenizer)
+
+    # Configure the data sources
+    data_sources = [
+        {
+            'type': 'numina',
+            'path': 'dataset/numina/numina_00.parquet',
+            'target_size': 100  # small sample for testing
+        }
+        # More data sources can be added; see the sketch below
+    ]
+
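+    # A multi-source configuration sketch (the 1shot_rlvr path below is a
+    # placeholder, not a file known to ship with this repo):
+    #
+    #   data_sources = [
+    #       {'type': 'numina', 'path': 'dataset/numina/numina_00.parquet', 'target_size': 500},
+    #       {'type': '1shot_rlvr', 'path': 'dataset/1shot_rlvr/train.parquet'},
+    #   ]
+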
+    # Process the data
+    try:
+        balanced_data = processor.create_balanced_dataset(data_sources, balance_method='oversample')
+
+        # Save the result
+        output_file = 'enhanced_training_data.json'
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(balanced_data, f, indent=2, ensure_ascii=False)
+
+        print(f"✅ Enhanced data saved to: {output_file}")
+
+        # Show a few examples
+        print(f"\n📝 Data examples:")
+        for i, sample in enumerate(balanced_data[:4]):
+            print(f"  example {i+1} ({sample['gender']}):")
+            print(f"  input: {sample['input'][:100]}...")
+            print()
+
+    except Exception as e:
+        print(f"❌ Processing failed: {e}")
+
+if __name__ == "__main__":
+    main()