summaryrefslogtreecommitdiff
path: root/train_gee.py
blob: 2e62fee4058a47a1950e1d036ea7d7e2b046db3d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import argparse
import os
import random
import time
from pathlib import Path

import psutil
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader

import wandb

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

# 导入自定义模块
import sys
sys.path.append('.')
from dataset.gee_processor import GEEProcessor
from losses.gee_loss import GEELoss, gender_to_label

# Raise the NCCL collective timeout and the PyTorch NCCL heartbeat timeout to
# 2700 s unless the launcher already exported a value (setdefault never
# overwrites). NOTE(review): presumably sized for long generation phases during
# which ranks block on each other — confirm.
for _nccl_env_var in ("NCCL_TIMEOUT", "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"):
    os.environ.setdefault(_nccl_env_var, "2700")
del _nccl_env_var

class GEEDataset(Dataset):
    """Map-style ``Dataset`` serving items straight from an in-memory container.

    Items are returned exactly as stored; no transformation is applied.

    Args:
        data: Any sized, indexable container (e.g. a list of sample dicts).
    """

    def __init__(self, data):
        # Keep a reference only; the container is neither copied nor validated.
        self.data = data

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item at position *idx*, delegating indexing to the container."""
        return self.data[idx]

def custom_collate(batch):
    """Collate a list of sample dicts into a dict of parallel Python lists.

    Only the ``"input"`` and ``"gender"`` fields are kept (in that order); values
    are not converted to tensors, so raw strings reach the tokenizer unchanged.
    A sample missing either field raises ``KeyError``.
    """
    kept_fields = ("input", "gender")
    return {field: [sample[field] for sample in batch] for field in kept_fields}

def parse_args(argv=None):
    """Parse the command-line options for GEE training.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected; passing a list lets the parser be driven from tests.

    Returns:
        An ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    # GEE-related options
    parser.add_argument('--lambda_weight', type=float, default=3.0, help='GEE lambda weight')
    parser.add_argument('--use_l1', action='store_true', help='Use L1 loss instead of L2')
    parser.add_argument('--auto_anneal', action='store_true', help='Use automatic annealing')
    parser.add_argument('--bias_eval_steps', type=int, default=10, help='Bias evaluation frequency')
    # Balancing defaults to True, so the store_true flag alone can never switch it
    # off; --no_balance_dataset writes False into the same destination.
    parser.add_argument('--balance_dataset', action='store_true', default=True, help='Balance dataset by gender')
    parser.add_argument('--no_balance_dataset', dest='balance_dataset', action='store_false',
                        help='Disable balancing the dataset by gender')
    parser.add_argument('--target_size', type=int, default=None, help='Target dataset size for balancing')

    # Options retained from the original training script
    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
    parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
    parser.add_argument('--micro_batch_size', type=int, default=2, help='Micro batch size')
    parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
    parser.add_argument('--max_steps', type=int, default=50, help='Maximum training steps')
    parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
    parser.add_argument('--run_name', type=str, default='one_shot_gee', help='Experiment run name')
    parser.add_argument('--wandb_project', type=str, default='one-shot-gee', help='W&B project name')
    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
    parser.add_argument('--seed', type=int, default=15, help='Random seed')
    parser.add_argument('--use_test_data', action='store_true', help='Use synthetic test data instead of real data')
    return parser.parse_args(argv)

def apply_chat_template(tokenizer, problem: str) -> str:
    """Render *problem* as a single user turn using the tokenizer's chat template.

    The result is an untokenized prompt string with the generation prompt
    appended (``add_generation_prompt=True``).
    """
    conversation = [{"role": "user", "content": problem}]
    return tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )

def _build_ds_config(micro_bs, eff_bs, world_size):
    """Derive gradient-accumulation steps and build a self-consistent DeepSpeed config.

    DeepSpeed requires ``train_batch_size == train_micro_batch_size_per_gpu *
    gradient_accumulation_steps * world_size``. Accumulation steps come from a
    floor division of the requested global batch, so when ``eff_bs`` is not a
    multiple of ``micro_bs * world_size`` the request cannot be met exactly; the
    config then reports the global batch that is actually realised rather than
    the requested one (which DeepSpeed would reject).

    Args:
        micro_bs: Per-device micro batch size.
        eff_bs: Requested global batch size.
        world_size: Number of training processes.

    Returns:
        Tuple ``(ds_config, accum_steps, realised_batch)``.
    """
    accum_steps = max(1, eff_bs // (micro_bs * world_size))
    realised_batch = micro_bs * accum_steps * world_size
    ds_config = {
        "train_micro_batch_size_per_gpu": micro_bs,
        "train_batch_size": realised_batch,
        "gradient_accumulation_steps": accum_steps,
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 2,
            # Optimizer state is offloaded to CPU; parameters are not offloaded.
            "offload_optimizer": {"device": "cpu"},
            "offload_param": {"device": "none"},
        },
        "gradient_clipping": 1.0,
    }
    return ds_config, accum_steps, realised_batch


def main():
    """Fine-tune a causal LM with the GEE loss computed on its own sampled responses.

    Each iteration tokenizes a batch of prompts, samples continuations without
    gradients, re-runs the model on prompt + continuation with gradients, turns
    the logits into per-sample entropies via ``GEELoss`` and back-propagates the
    GEE loss computed from those entropies and the batch's gender labels.
    Checkpoints go under ``save_root`` every ``--save_steps`` iterations and once
    more at the end.

    NOTE: ``step`` counts dataloader iterations (micro-batches), not optimizer
    updates; ``--max_steps``, ``--log_steps`` and ``--save_steps`` use those
    units. Only a single pass over the dataloader is made.
    """
    args = parse_args()
    set_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    world_size = int(os.getenv("WORLD_SIZE", "1"))
    micro_bs = args.micro_batch_size
    eff_bs = args.effective_batch
    lr = args.learning_rate
    # NOTE(review): --temperature is parsed but never read here; sampling uses --sample_temp.
    ds_config, accum_steps, global_bs = _build_ds_config(micro_bs, eff_bs, world_size)

    save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
    accelerator = Accelerator(mixed_precision="bf16",
                              gradient_accumulation_steps=accum_steps,
                              deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
    # Route output through accelerator.print (main process only). This assignment
    # makes `print` a local name for all of main(), so print() must not be
    # called above this line.
    print = accelerator.print
    if global_bs != eff_bs:
        print(f"Warning: effective_batch={eff_bs} is not reachable with "
              f"micro_batch_size={micro_bs} x world_size={world_size}; "
              f"using global batch size {global_bs}.")

    model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
    config = AutoConfig.from_pretrained(model_path)
    config.use_cache = False
    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
    model.gradient_checkpointing_enable()
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    # Fall back to EOS as the pad token. NOTE(review): with pad == eos, any EOS in
    # a sequence is also treated as padding by the masks built below.
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

    # Initialise the GEE data processor and loss function
    gee_processor = GEEProcessor(tokenizer)
    gee_loss_fn = GEELoss(lambda_weight=args.lambda_weight, use_l1=args.use_l1)

    if accelerator.is_main_process:
        # An explicit --wandb_name wins; run_name has a non-empty default, so
        # trying it first would make --wandb_name unreachable.
        wandb.init(project=args.wandb_project, name=args.wandb_name or args.run_name or args.model_name, config=vars(args))

    # Prepare the training data
    if args.use_test_data:
        print("使用合成测试数据...")
        train_data = gee_processor.create_test_data(num_samples=200)
    else:
        print("使用真实数据...")
        train_data = gee_processor.prepare_gee_data(
            args.train_data,
            balance=args.balance_dataset,
            target_size=args.target_size
        )

    train_loader = DataLoader(
        GEEDataset(train_data),
        batch_size=micro_bs,
        shuffle=True,
        collate_fn=custom_collate
    )

    optimizer = AdamW(model.parameters(), lr=lr)
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    initial_entropy_gap = None
    model.train()

    for step, batch in enumerate(train_loader, start=1):
        if step > args.max_steps:
            print(f"Exceed max step {args.max_steps}, training stopped.")
            break

        with accelerator.accumulate(model):
            # Tokenise the prompts (left-padded to the longest in the batch)
            inputs = tokenizer(
                batch["input"],
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=2048
            ).to(accelerator.device)

            # Sample responses from the current policy without tracking gradients
            with torch.no_grad():
                gen_ids = accelerator.unwrap_model(model).generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    top_p=0.95,
                    temperature=args.sample_temp,
                    synced_gpus=True,
                    repetition_penalty=1.15,
                    pad_token_id=tokenizer.pad_token_id,
                    use_cache=False
                )

            # Full sequence = prompt ids + newly generated ids, capped at 4096 tokens
            seq = torch.cat([inputs.input_ids, gen_ids[:, inputs.input_ids.shape[1]:]], dim=1)[:, :4096]
            pad_mask = seq.ne(tokenizer.pad_token_id)
            # Non-pad tokens inside the (left-padded) prompt region of each row
            prompt_lengths = pad_mask[:, :inputs.input_ids.shape[1]].sum(-1)

            # Forward pass with gradients, then per-token and per-sample entropy
            logits = model(seq, attention_mask=pad_mask).logits
            H_tok = gee_loss_fn.compute_token_entropy(logits, pad_mask)
            H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)

            # Map gender strings to integer labels
            gender_labels = torch.tensor([
                gender_to_label(g) for g in batch["gender"]
            ], device=accelerator.device)

            # GEE loss and its diagnostic metrics
            loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)

            # Optional automatic annealing: rescale lambda by current_gap / initial_gap.
            # This runs after this step's loss was computed, so (presumably) the new
            # lambda takes effect from the next step.
            if args.auto_anneal:
                # float() works for Python numbers and 0-d tensors alike and detaches,
                # so a tensor metric cannot keep step 1's autograd graph alive.
                current_gap = float(metrics['entropy_gap'])
                if initial_entropy_gap is None:
                    initial_entropy_gap = current_gap
                if initial_entropy_gap > 0:
                    anneal_factor = current_gap / initial_entropy_gap
                    new_lambda = args.lambda_weight * anneal_factor
                    gee_loss_fn.update_lambda(new_lambda)
                    metrics['lambda_weight'] = new_lambda

            # Backward pass and (accumulation-aware) optimizer update
            accelerator.backward(loss)
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        # Logging and periodic checkpoints (main process only)
        if accelerator.is_main_process:
            if step % args.log_steps == 0:
                print(f"Step {step} | loss={loss.item():.6f} | "
                      f"entropy_gap={metrics['entropy_gap']:.6f} | "
                      f"H_male={metrics['H_male']:.6f} | "
                      f"H_female={metrics['H_female']:.6f}")
                wandb.log({"step": step, **metrics})

            if step % args.save_steps == 0:
                ckpt = Path(save_root) / f"step_{step}"
                ckpt.mkdir(parents=True, exist_ok=True)
                accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
                tokenizer.save_pretrained(ckpt)
                print(f"Checkpoint saved to {ckpt}")

    # Every rank must have left the training loop before rank 0 writes the final
    # checkpoint and closes the W&B run.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final = Path(save_root) / "final"
        final.mkdir(parents=True, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
        tokenizer.save_pretrained(final)
        print(f"Final checkpoint saved to {final}")
        wandb.finish()

# Script entry point: training starts only when this file is executed directly.
if __name__ == "__main__":
    main()