#!/usr/bin/env python3
"""
纯偏见减少训练脚本
目标:只最小化男女间熵差,不进行整体熵最小化
"""
import argparse
import os
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import pandas as pd
import numpy as np
from pathlib import Path
import wandb
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
# Import project-local modules
import sys
sys.path.append('.')
from dataset.gee_processor import GEEProcessor
from losses.debiasing_loss import DebiasingLoss, gender_to_label
from smart_balanced_dataloader import create_smart_balanced_dataloader
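# Relax distributed timeouts (values in seconds) so that long generation steps
# do not trip the NCCL watchdog during multi-GPU runs.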
os.environ.setdefault("NCCL_TIMEOUT", "2700")
os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
def parse_args():
    parser = argparse.ArgumentParser()
    # Debiasing-specific arguments
    parser.add_argument('--scale_factor', type=float, default=1.0, help='Debiasing loss scale factor')
    parser.add_argument('--use_l1', action='store_true', help='Use L1 loss instead of L2')
    parser.add_argument('--target_gap', type=float, default=0.01, help='Target entropy gap for early stopping')
    # Model arguments
    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-1.5B-Instruct', help='Model name')
    parser.add_argument('--model_path', type=str, required=True, help='Model path')
    parser.add_argument('--effective_batch', type=int, default=4, help='Global batch size')
    parser.add_argument('--micro_batch_size', type=int, default=2, help='Micro batch size (must be >=2 for gender balance)')
    parser.add_argument('--learning_rate', type=float, default=1e-5, help='Learning rate (lower for debiasing)')
    parser.add_argument('--max_steps', type=int, default=20, help='Maximum training steps')
    parser.add_argument('--sample_temp', type=float, default=0.7, help='Generation temperature')
    # Run arguments
    parser.add_argument('--run_name', type=str, default='pure_debiasing', help='Run name')
    parser.add_argument('--wandb_project', type=str, default='debiasing-only', help='W&B project name')
    parser.add_argument('--use_test_data', action='store_true', help='Use synthetic test data')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--log_steps', type=int, default=1, help='Logging frequency')
    parser.add_argument('--save_steps', type=int, default=10, help='Save frequency')
    return parser.parse_args()
def main():
    args = parse_args()
    set_seed(args.seed)
    # Validate the micro batch size up front. Raise SystemExit instead of calling print():
    # `print` is rebound to accelerator.print later in this function, so calling it here
    # would raise UnboundLocalError.
    if args.micro_batch_size < 2:
        raise SystemExit(
            "❌ Error: micro_batch_size must be >= 2 to guarantee gender balance! "
            "Please use --micro_batch_size 2 or larger."
        )
    # DeepSpeed configuration
    ds_config = {
        "train_micro_batch_size_per_gpu": args.micro_batch_size,
        "train_batch_size": args.effective_batch,
        "gradient_accumulation_steps": max(1, args.effective_batch // args.micro_batch_size),
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {"device": "cpu"},
            "offload_param": {"device": "none"}
        },
        "gradient_clipping": 1.0,
    }
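    # DeepSpeed checks that train_batch_size == train_micro_batch_size_per_gpu
    # * gradient_accumulation_steps * world_size, so the derivation above is
    # consistent for single-GPU runs; scale effective_batch accordingly for multi-GPU.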
    accelerator = Accelerator(
        mixed_precision="bf16",
        gradient_accumulation_steps=max(1, args.effective_batch // args.micro_batch_size),
        deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config)
    )
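    # accelerator.print only emits from the main process, so multi-GPU logs are not duplicated.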
    print = accelerator.print
    print(f"🎯 Starting pure debiasing training - {args.run_name}")
    print(f"📊 Configuration:")
    print(f"   Batch sizes: micro={args.micro_batch_size}, effective={args.effective_batch}")
    print(f"   Scale factor: {args.scale_factor}")
    print(f"   Target entropy gap: {args.target_gap}")
    print(f"   Learning rate: {args.learning_rate}")
    print(f"   ⚠️ Note: debiasing only, no overall entropy minimization")
    # Load model
    model_path = args.model_path
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    config.use_cache = False
    model = AutoModelForCausalLM.from_pretrained(model_path, config=config, trust_remote_code=True)
    model.gradient_checkpointing_enable()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Initialize the pure debiasing loss
    gee_processor = GEEProcessor(tokenizer)
    debiasing_loss_fn = DebiasingLoss(use_l1=args.use_l1, scale_factor=args.scale_factor)
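    # DebiasingLoss is assumed (based on its usage below) to return (loss, metrics),
    # where the loss penalizes only the gap between the mean per-sample entropies of
    # the two gender groups, roughly:
    #   L = scale_factor * |H_female - H_male|       (with --use_l1)
    #   L = scale_factor * (H_female - H_male)^2     (otherwise)
    # and metrics contains 'entropy_gap', 'H_male', and 'H_female'.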
    if accelerator.is_main_process:
        wandb.init(project=args.wandb_project, name=args.run_name, config=vars(args))
    # Prepare data
    if args.use_test_data:
        print("📊 Using synthetic test data...")
        train_data = gee_processor.create_test_data(num_samples=100)
        # Check gender balance of the raw data
        male_count = sum(1 for item in train_data if item['gender'] == 'male')
        female_count = sum(1 for item in train_data if item['gender'] == 'female')
        print(f"Raw data: male={male_count}, female={female_count}")
        # Build the gender-balanced dataloader
        train_loader = create_smart_balanced_dataloader(
            train_data,
            batch_size=args.micro_batch_size,
            num_batches=args.max_steps + 5
        )
    else:
        print("❌ Please run with --use_test_data for testing")
        return
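    # create_smart_balanced_dataloader is expected to yield micro-batches containing
    # both genders; the per-step check inside the training loop still guards against
    # any imbalanced batch that slips through.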
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    model, optimizer = accelerator.prepare(model, optimizer)
    print(f"🎯 Starting training...")
    print(f"   Objective: minimize |H_female - H_male|")
    print(f"   Ideal outcome: entropy_gap → {args.target_gap}")
    # Training loop
    model.train()
    best_gap = float('inf')
    converged_steps = 0
    for step, batch in enumerate(train_loader, start=1):
        if step > args.max_steps:
            print(f"🛑 Reached the maximum number of steps ({args.max_steps}), stopping training")
            break
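        # Each step: verify gender balance, tokenize prompts, sample responses,
        # re-score the full sequences to get per-sample entropies, then update the
        # model with the entropy-gap (debiasing) loss only.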
        with accelerator.accumulate(model):
            try:
                # Verify that the batch is gender-balanced
                male_count = sum(1 for g in batch["gender"] if g == 'male')
                female_count = sum(1 for g in batch["gender"] if g == 'female')
                if male_count == 0 or female_count == 0:
                    print(f"💥 Step {step}: imbalanced batch! male={male_count}, female={female_count}")
                    continue
                # Tokenize prompts
                inputs = tokenizer(
                    batch["input"],
                    return_tensors="pt",
                    padding="longest",
                    truncation=True,
                    max_length=1024
                ).to(accelerator.device)
                # Sample responses (no gradients through generation)
                with torch.no_grad():
                    gen_ids = accelerator.unwrap_model(model).generate(
                        **inputs,
                        max_new_tokens=128,
                        do_sample=True,
                        top_p=0.95,
                        temperature=args.sample_temp,
                        synced_gpus=True,
                        repetition_penalty=1.15,
                        pad_token_id=tokenizer.pad_token_id,
                        use_cache=False
                    )
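                # The generated ids already include the prompt, so the concatenation below
                # simply rebuilds the full prompt+response sequence. pad_mask marks non-pad
                # positions (if pad_token was set to eos_token above, generated EOS tokens
                # are masked as well), and prompt_lengths counts non-pad prompt tokens,
                # presumably so the sample entropy is computed over the generated region only.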
                # Build the full sequence and masks
                seq = torch.cat([inputs.input_ids, gen_ids[:, inputs.input_ids.shape[1]:]], dim=1)
                pad_mask = seq.ne(tokenizer.pad_token_id)
                prompt_lengths = pad_mask[:, :inputs.input_ids.shape[1]].sum(-1)
                # Compute logits, per-token entropy, and per-sample entropy
                logits = model(seq, attention_mask=pad_mask).logits
                H_tok = debiasing_loss_fn.compute_token_entropy(logits, pad_mask)
                H_i = debiasing_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)
                # Prepare gender labels
                gender_labels = torch.tensor([
                    gender_to_label(g) for g in batch["gender"]
                ], device=accelerator.device)
                # Compute the pure debiasing loss
                loss, metrics = debiasing_loss_fn.compute_debiasing_loss(H_i, gender_labels)
                # Backward pass and optimizer step
                accelerator.backward(loss)
                # Clipping is also configured in ds_config; with DeepSpeed this call may be a no-op
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                # Track convergence
                current_gap = metrics['entropy_gap']
                if current_gap < best_gap:
                    best_gap = current_gap
                    converged_steps = 0
                else:
                    converged_steps += 1
                # Early-stopping check
                if current_gap <= args.target_gap:
                    print(f"🎉 Reached the target entropy gap {args.target_gap}! Current: {current_gap:.6f}")
                    print(f"Target achieved at step {step}")
                    break
                # Logging
                if accelerator.is_main_process:
                    if step % args.log_steps == 0:
                        gap_direction = "📉" if current_gap <= best_gap else "📈"
                        print(f"{gap_direction} Step {step} | loss={loss.item():.6f} | "
                              f"gap={current_gap:.6f} | "
                              f"H_male={metrics['H_male']:.6f} | "
                              f"H_female={metrics['H_female']:.6f} | "
                              f"batch[{male_count}M,{female_count}F]")
                        # Add progress metrics (the formula assumes an initial gap of ~0.1)
                        progress = max(0, min(1, (0.1 - current_gap) / 0.1)) * 100
                        metrics['progress_percent'] = progress
                        metrics['best_gap'] = best_gap
                        wandb.log({"step": step, **metrics})
                # Save checkpoint
                if accelerator.is_main_process and step % args.save_steps == 0:
                    ckpt = Path(f"checkpoints/{args.model_name}/{args.run_name}") / f"step_{step}"
                    ckpt.mkdir(parents=True, exist_ok=True)
                    accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
                    tokenizer.save_pretrained(ckpt)
                    print(f"💾 Checkpoint saved: {ckpt}")
            except Exception as e:
                print(f"❌ Error at training step {step}: {e}")
                continue
    if accelerator.is_main_process:
        print(f"\n🎉 Pure debiasing training finished!")
        print(f"📊 Final results:")
        print(f"   Best entropy gap: {best_gap:.6f}")
        print(f"   Target entropy gap: {args.target_gap}")
        if best_gap <= args.target_gap:
            print("✅ Debiasing target reached!")
        elif best_gap <= 0.05:
            print("⚠️ Bias substantially reduced, close to the target")
        else:
            print("❌ Limited bias reduction; more training steps or different hyperparameters may be needed")
        # Save the final model
        final = Path(f"checkpoints/{args.model_name}/{args.run_name}") / "final"
        final.mkdir(parents=True, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
        tokenizer.save_pretrained(final)
        print(f"💾 Final model saved: {final}")
        wandb.finish()
if __name__ == "__main__":
    main()