diff options
Diffstat (limited to 'enhance_gee_processor.py')
| -rwxr-xr-x | enhance_gee_processor.py | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/enhance_gee_processor.py b/enhance_gee_processor.py new file mode 100755 index 0000000..173b5aa --- /dev/null +++ b/enhance_gee_processor.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +""" +增强GEE处理器以支持真实数据集 +支持Numina数学推理数据和其他真实数据源 +""" +import pandas as pd +import numpy as np +from pathlib import Path +import json +import re +import sys +sys.path.append('.') + +from dataset.gee_processor import GEEProcessor + +class EnhancedGEEProcessor(GEEProcessor): + """增强版GEE处理器,支持多种真实数据源""" + + def __init__(self, tokenizer): + super().__init__(tokenizer) + self.name_patterns = { + 'male': ['Tom', 'John', 'Mike', 'Bob', 'David', 'James', 'Robert', 'Michael', 'William', 'Richard'], + 'female': ['Sarah', 'Lisa', 'Emma', 'Alice', 'Mary', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan'] + } + + def process_numina_data(self, file_path: str, target_size: int = 1000) -> list: + """处理Numina数学推理数据""" + print(f"📊 处理Numina数据: {file_path}") + + # 读取parquet文件 + df = pd.read_parquet(file_path) + print(f"原始数据量: {len(df)}") + + # 随机采样 + if len(df) > target_size: + df = df.sample(n=target_size, random_state=42) + print(f"采样后数据量: {len(df)}") + + processed_data = [] + for idx, row in df.iterrows(): + # 提取问题和答案 + problem = row.get('problem', row.get('question', '')) + solution = row.get('solution', row.get('answer', '')) + + if problem and solution: + # 生成性别平衡的变体 + male_version = self._genderize_text(problem, 'male') + female_version = self._genderize_text(problem, 'female') + + processed_data.extend([ + { + 'input': self.apply_chat_template(male_version), + 'output': solution, + 'gender': 'male', + 'original_id': idx, + 'source': 'numina' + }, + { + 'input': self.apply_chat_template(female_version), + 'output': solution, + 'gender': 'female', + 'original_id': idx, + 'source': 'numina' + } + ]) + + print(f"✅ 处理完成,生成 {len(processed_data)} 个样本") + return processed_data + + def process_1shot_rlvr_data(self, file_path: str) -> list: + """处理1shot RLVR数据""" + print(f"�� 处理1shot RLVR数据: {file_path}") + + df = pd.read_parquet(file_path) + print(f"原始数据量: {len(df)}") + + processed_data = [] + for idx, row in df.iterrows(): + # 根据实际数据结构调整 + prompt = row.get('prompt', row.get('input', '')) + + if prompt: + # 生成性别变体 + for gender in ['male', 'female']: + genderized_prompt = self._genderize_text(prompt, gender) + + processed_data.append({ + 'input': self.apply_chat_template(genderized_prompt), + 'gender': gender, + 'original_id': idx, + 'source': '1shot_rlvr' + }) + + print(f"✅ 处理完成,生成 {len(processed_data)} 个样本") + return processed_data + + def _genderize_text(self, text: str, target_gender: str) -> str: + """将文本中的性别引用转换为指定性别""" + + # 选择名字 + names = self.name_patterns[target_gender] + + # 替换通用占位符 + if '[NAME]' in text or '{name}' in text: + name = np.random.choice(names) + text = text.replace('[NAME]', name).replace('{name}', name) + return text + + # 检测现有性别名字并替换 + all_male_names = self.name_patterns['male'] + all_female_names = self.name_patterns['female'] + + for male_name in all_male_names: + if male_name in text: + replacement = np.random.choice(names) + text = text.replace(male_name, replacement) + break + + for female_name in all_female_names: + if female_name in text: + replacement = np.random.choice(names) + text = text.replace(female_name, replacement) + break + + # 如果没有找到名字,随机添加一个 + if not any(name in text for name in all_male_names + all_female_names): + name = np.random.choice(names) + # 在合适的地方插入名字 + if "person" in text.lower(): + text = text.replace("person", name) + elif "student" in text.lower(): + text = text.replace("student", f"student named {name}") + elif "someone" in text.lower(): + text = text.replace("someone", name) + else: + # 在句子开头添加 + text = f"{name} is working on this problem: {text}" + + return text + + def create_balanced_dataset(self, data_sources: list, balance_method: str = 'oversample') -> list: + """创建性别平衡的数据集""" + + all_data = [] + for source_config in data_sources: + source_type = source_config['type'] + file_path = source_config['path'] + + if source_type == 'numina': + data = self.process_numina_data(file_path, source_config.get('target_size', 1000)) + elif source_type == '1shot_rlvr': + data = self.process_1shot_rlvr_data(file_path) + else: + print(f"⚠️ 未知数据源类型: {source_type}") + continue + + all_data.extend(data) + + # 统计性别分布 + male_data = [item for item in all_data if item['gender'] == 'male'] + female_data = [item for item in all_data if item['gender'] == 'female'] + + print(f"\n📊 原始数据分布:") + print(f" 男性样本: {len(male_data)}") + print(f" 女性样本: {len(female_data)}") + + # 平衡处理 + if balance_method == 'oversample': + target_size = max(len(male_data), len(female_data)) + + if len(male_data) < target_size: + male_data = male_data * (target_size // len(male_data)) + male_data[:target_size % len(male_data)] + if len(female_data) < target_size: + female_data = female_data * (target_size // len(female_data)) + female_data[:target_size % len(female_data)] + + elif balance_method == 'undersample': + target_size = min(len(male_data), len(female_data)) + male_data = male_data[:target_size] + female_data = female_data[:target_size] + + balanced_data = male_data + female_data + np.random.shuffle(balanced_data) + + print(f"📊 平衡后数据分布:") + male_count = sum(1 for item in balanced_data if item['gender'] == 'male') + female_count = sum(1 for item in balanced_data if item['gender'] == 'female') + print(f" 男性样本: {male_count}") + print(f" 女性样本: {female_count}") + print(f" 总样本数: {len(balanced_data)}") + + return balanced_data + +def main(): + """示例用法""" + from transformers import AutoTokenizer + + print("🔧 测试增强版GEE处理器...") + + # 初始化 + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct", trust_remote_code=True) + processor = EnhancedGEEProcessor(tokenizer) + + # 配置数据源 + data_sources = [ + { + 'type': 'numina', + 'path': 'dataset/numina/numina_00.parquet', + 'target_size': 100 # 测试用小样本 + } + # 可以添加更多数据源 + ] + + # 处理数据 + try: + balanced_data = processor.create_balanced_dataset(data_sources, balance_method='oversample') + + # 保存结果 + output_file = 'enhanced_training_data.json' + with open(output_file, 'w', encoding='utf-8') as f: + json.dump(balanced_data, f, indent=2, ensure_ascii=False) + + print(f"✅ 增强数据已保存: {output_file}") + + # 显示示例 + print(f"\n📝 数据示例:") + for i, sample in enumerate(balanced_data[:4]): + print(f" 示例 {i+1} ({sample['gender']}):") + print(f" 输入: {sample['input'][:100]}...") + print() + + except Exception as e: + print(f"❌ 处理失败: {e}") + +if __name__ == "__main__": + main() |
