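"""One-shot entropy-minimization (EM) fine-tuning with optional Group-wise
Entropy Equalization (GEE) debiasing losses.

Each step samples continuations from the model on a batch of prompts, then
minimizes token-level predictive entropy over the generated positions; with
--gee_enable it additionally pushes per-group probability mass toward target
shares (pi) and equalizes normalized within-group entropies.

Example launch (a sketch; the script file name and all paths are illustrative
and should be adapted to your setup):

    accelerate launch em_ft.py \
        --model_name Qwen2.5-Math-7B \
        --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
        --micro_batch_size auto \
        --gee_enable --gee_groups_path groups/gender.json
"""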
import argparse
import json
import math
import os
from pathlib import Path

import pandas as pd
import torch
import torch.nn.functional as F
import wandb
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
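# Raise NCCL watchdog timeouts: sampled generation inside the training loop can
# stall ranks long enough to trip the default distributed heartbeat.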
os.environ.setdefault("NCCL_TIMEOUT", "2700")
os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
parser.add_argument('--model_path', type=str, default=None, help='Local model path')
parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
    parser.add_argument('--micro_batch_size', type=str, default='2', help='Micro batch size or "auto"')
parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
parser.add_argument('--seed', type=int, default=15, help='Random seed')
parser.add_argument('--no_deepspeed', action='store_true', help='Disable DeepSpeed and use plain Accelerator (Colab-friendly)')
parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['bf16', 'fp16', 'no'], help='Mixed precision mode')
# GEE options
parser.add_argument('--gee_enable', action='store_true', help='Enable Group-wise Entropy Equalization (debiasing)')
parser.add_argument('--gee_groups_path', type=str, default='groups/gender.json', help='Path to JSON defining groups')
parser.add_argument('--gee_alpha', type=float, default=1.0, help='Weight for group mass parity loss')
parser.add_argument('--gee_beta', type=float, default=0.3, help='Weight for group entropy equalization loss')
parser.add_argument('--gee_lambda', type=float, default=0.0, help='Weight for global entropy anchor')
parser.add_argument('--gee_gamma', type=float, default=0.0, help='Weight for sensitive coverage anchor')
parser.add_argument('--gee_tau', type=float, default=1e-6, help='Min union mass to apply GEE losses')
parser.add_argument('--gee_top_m', type=int, default=200, help='Apply GEE if any group token in top-M at a position')
parser.add_argument('--gee_em_mix', type=float, default=0.1, help='Additive EM loss mix to stabilize training (0 to disable)')
return parser.parse_args()
class FTDataset(Dataset):
    """Thin Dataset wrapper over a list of pre-templated prompt rows."""
    def __init__(self, rows): self.rows = rows
    def __len__(self): return len(self.rows)
    def __getitem__(self, idx): return self.rows[idx]
def custom_collate(batch):
return {"input": [item["input"] for item in batch]}
def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
    """Heuristic per-GPU micro batch size keyed on the parameter count in the model name."""
    model_configs = {
"1.5B": {"base_batch": 4, "keywords": ["1.5B", "1B"]},
"2B": {"base_batch": 4, "keywords": ["2B"]},
"3B": {"base_batch": 2, "keywords": ["3B"]},
"7B": {"base_batch": 2, "keywords": ["7B"]},
"8B+": {"base_batch": 1, "keywords": ["8B", "9B", "10B", "11B", "12B", "13B", "14B"]},
}
model_name_upper = model_name.upper()
detected = next((cfg for cfg in model_configs.values() if any(k in model_name_upper for k in cfg["keywords"])), None)
base_batch = detected["base_batch"] if detected else 2
if world_size > 1:
return min(base_batch + 1, int(base_batch * 1.5))
return base_batch
def apply_chat_template(tokenizer, problem: str) -> str:
return tokenizer.apply_chat_template(
[{"role": "user", "content": problem}],
tokenize=False, add_generation_prompt=True
)
def main():
args = parse_args()
set_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # "--micro_batch_size auto" defers to the size heuristic above; otherwise parse the integer.
    if str(args.micro_batch_size).lower() == "auto":
        micro_bs = get_optimal_micro_batch_size(args.model_name, world_size)
    else:
        micro_bs = int(args.micro_batch_size)
    eff_bs = args.effective_batch
    accum_steps = max(1, eff_bs // (micro_bs * world_size))
temp = args.temperature
lr = args.learning_rate
save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
# Resolve mixed precision automatically if requested bf16 is unsupported
mp = args.mixed_precision
if mp == "bf16":
if not torch.cuda.is_available() or not torch.cuda.is_bf16_supported():
mp = "fp16" if torch.cuda.is_available() else "no"
if args.no_deepspeed:
accelerator = Accelerator(mixed_precision=mp, gradient_accumulation_steps=accum_steps)
else:
        ds_config = {
            "train_micro_batch_size_per_gpu": micro_bs,
            "train_batch_size": eff_bs,
            "gradient_accumulation_steps": accum_steps,
            "bf16": {"enabled": mp == "bf16"},
            "fp16": {"enabled": mp == "fp16"},
            "zero_optimization": {
                "stage": 2,
                "offload_optimizer": {"device": "cpu"},
                "offload_param": {"device": "none"}
            },
            "gradient_clipping": 1.0,
        }
accelerator = Accelerator(mixed_precision=mp,
gradient_accumulation_steps=accum_steps,
deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
    print = accelerator.print  # route prints through rank 0 only
model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
config = AutoConfig.from_pretrained(model_path)
config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
# Prepare GEE group ids and targets if enabled
group_name_list = []
group_id_lists = []
group_target_pi = []
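    # Expected groups file schema (inferred from the parsing below; keys are illustrative):
    #   {"groups": {"female": ["she", "her"], "male": ["he", "him"]},
    #    "pi":     {"female": 0.5, "male": 0.5}}
    # "pi" gives optional target mass shares and is re-normalized to sum to 1.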
if args.gee_enable:
if not os.path.exists(args.gee_groups_path):
raise FileNotFoundError(f"GEE groups file not found: {args.gee_groups_path}")
with open(args.gee_groups_path, 'r', encoding='utf-8') as f:
groups_payload = json.load(f)
groups = groups_payload.get('groups', {})
pis = groups_payload.get('pi', {})
if not groups:
raise ValueError('GEE groups.json is missing a "groups" object')
# Build token-id lists per group (first sub-token of each provided string)
for gname, tokens in groups.items():
ids = []
for w in tokens:
toks = tokenizer.tokenize(w)
if toks:
tid = tokenizer.convert_tokens_to_ids(toks[0])
if tid is not None:
ids.append(tid)
ids = sorted(set([i for i in ids if isinstance(i, int) and i >= 0]))
if len(ids) == 0:
continue
group_name_list.append(gname)
group_id_lists.append(torch.tensor(ids, dtype=torch.long))
group_target_pi.append(float(pis.get(gname, 1.0)))
if not group_id_lists:
raise ValueError('No valid group token ids produced from groups file')
# Normalize pi to sum to 1
total_pi = sum(group_target_pi)
if total_pi <= 0:
group_target_pi = [1.0 / len(group_id_lists)] * len(group_id_lists)
else:
group_target_pi = [p / total_pi for p in group_target_pi]
if accelerator.is_main_process:
wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
    # Friendly error if the parquet path is missing
    if not os.path.exists(args.train_data):
        raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.")
df = pd.read_parquet(args.train_data)
train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
optimizer = AdamW(model.parameters(), lr=lr)
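    # accelerator.prepare() moves the model and optimizer to device and, in
    # multi-process runs, wraps the DataLoader with a distributed sampler.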
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    baseline_avg_H = None        # for L_reg: global entropy anchor
    baseline_union_mass = None   # for L_cov: union sensitive-mass anchor
model.train()
for step, batch in enumerate(train_loader, start=1):
if step > args.max_steps:
print(f"Exceed max step {args.max_steps}, training stopped.")
break
with accelerator.accumulate(model):
enc = tokenizer(batch["input"],
return_tensors="pt",
padding="longest",
truncation=True,
max_length=2048).to(accelerator.device)
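            # Sample continuations without gradients; the differentiable forward
            # pass over the full sequence happens below.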
with torch.no_grad():
                use_synced = accelerator.num_processes > 1
gen_ids = accelerator.unwrap_model(model).generate(
**enc,
max_new_tokens=512,
do_sample=True,
top_p=0.95,
temperature=args.sample_temp,
synced_gpus=use_synced,
repetition_penalty=1.15,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
            # generate() returns prompt + continuation, so gen_ids already holds the full sequence
            seq = gen_ids[:, :4096]
            pad_mask = seq.ne(tokenizer.pad_token_id)
            # With left padding, every row's continuation starts at the padded prompt
            # width, so mask generated positions by index rather than by per-row prompt length.
            prompt_width = enc.input_ids.shape[1]
            token_idx = torch.arange(seq.size(1), device=seq.device)
            gen_mask = (token_idx.unsqueeze(0) >= prompt_width) & pad_mask
logits = model(seq, attention_mask=pad_mask).logits # [B,T,V]
probs = F.softmax(logits / temp, dim=-1)
H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1) # [B,T]
# Precompute EM loss (used standalone or as mixin)
em_loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
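            # GEE composite objective over triggered generated positions:
            #   L = alpha*L_mass + beta*L_gee + lambda*L_reg + gamma*L_cov
            # L_mass : squared gap between each group's token mass and its target pi
            # L_gee  : squared spread of normalized within-group entropies
            # L_reg  : anchor average entropy to its first-step baseline
            # L_cov  : anchor union sensitive mass to its first-step baseline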
if args.gee_enable:
# Compute GEE losses on generated positions only
mask = gen_mask
denom = mask.sum().clamp_min(1)
# Union mass and per-group mass
group_masses = [] # list of [B,T]
top_m = args.gee_top_m
# Top-M indices to decide triggering
if top_m > 0:
topk = probs.topk(k=min(top_m, probs.size(-1)), dim=-1).indices # [B,T,M]
else:
topk = None
for ids in group_id_lists:
ids = ids.to(probs.device)
gm = probs.index_select(-1, ids).sum(-1) # [B,T]
group_masses.append(gm)
union_mass = torch.stack(group_masses, dim=-1).sum(-1) # [B,T]
# Trigger mask: apply only when union_mass >= tau OR any group id in top-M
trigger = union_mass >= args.gee_tau
if topk is not None:
any_in_top = torch.zeros_like(trigger)
vocab_in_top = topk # [B,T,M]
for ids in group_id_lists:
ids = ids.to(probs.device)
g_match = (vocab_in_top.unsqueeze(-1) == ids.view(1,1,1,-1)).any(-1) # [B,T,M]
any_in_top |= g_match.any(-1)
trigger |= any_in_top
                eff_mask = mask & trigger
                n_eff = eff_mask.sum()           # raw count, used by the fallback check below
                eff_denom = n_eff.clamp_min(1)
# L_mass: group-mass parity to target pi
pi = torch.tensor(group_target_pi, device=probs.device, dtype=probs.dtype).view(1,1,-1) # [1,1,K]
masses_stacked = torch.stack(group_masses, dim=-1) # [B,T,K]
mass_gap = (masses_stacked - pi).pow(2) # [B,T,K]
L_mass = (mass_gap.sum(-1) * eff_mask).sum() / eff_denom
# L_GEE: equalize normalized group entropy per position
norm_group_entropies = [] # [B,T] per group
for ids in group_id_lists:
ids = ids.to(probs.device)
p_sub = probs.index_select(-1, ids) # [B,T,|G|]
denom_g = p_sub.sum(-1, keepdim=True).clamp_min(1e-12)
p_g = p_sub / denom_g
H_g = -(p_g * torch.log(p_g + 1e-12)).sum(-1) # [B,T]
max_H = math.log(p_sub.size(-1)) if p_sub.size(-1) > 1 else 1.0
H_g_norm = H_g / max(max_H, 1e-12)
norm_group_entropies.append(H_g_norm)
H_stack = torch.stack(norm_group_entropies, dim=-1) # [B,T,K]
H_bar = H_stack.mean(-1, keepdim=True) # [B,T,1]
L_gee = (((H_stack - H_bar) ** 2).sum(-1) * eff_mask).sum() / eff_denom
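                # sum_k (H_k - H_bar)^2 equals K times the per-position variance
                # of the K normalized group entropies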
# L_reg: global entropy anchor to baseline average
if args.gee_lambda > 0:
avg_H = (H_tok * mask).sum() / denom
if baseline_avg_H is None:
baseline_avg_H = avg_H.detach()
L_reg = (avg_H - baseline_avg_H).pow(2)
else:
L_reg = torch.zeros((), device=probs.device, dtype=probs.dtype)
# L_cov: keep union sensitive mass near baseline
if args.gee_gamma > 0:
avg_union = (union_mass * mask).sum() / denom
if baseline_union_mass is None:
baseline_union_mass = avg_union.detach()
L_cov = (avg_union - baseline_union_mass).pow(2)
else:
L_cov = torch.zeros((), device=probs.device, dtype=probs.dtype)
loss_gee = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
                # Fallback: if no positions triggered, fall back to the EM loss so the
                # step still updates (check the raw count; eff_denom is clamped and can never be zero).
                if n_eff.item() == 0:
                    loss = em_loss
                else:
                    loss = loss_gee + (args.gee_em_mix * em_loss if args.gee_em_mix > 0 else 0.0)
# Log activation ratio if main process
if accelerator.is_main_process:
                    gee_active_ratio = (n_eff / denom).item()
try:
wandb.log({"gee_active_ratio": gee_active_ratio,
"L_mass": float(L_mass.detach().item()),
"L_gee": float(L_gee.detach().item())})
except Exception:
pass
else:
# Original One-shot EM loss
loss = em_loss
accelerator.backward(loss)
            # Clip only on real optimizer steps, i.e. when gradients are synced
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
if accelerator.is_main_process:
if step % args.log_steps == 0:
print(f"Step {step} | loss={loss.item():.6f}")
wandb.log({"step": step, "loss": loss.item()})
if step % args.save_steps == 0:
ckpt = Path(save_root) / f"step_{step}"
ckpt.mkdir(parents=True, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
tokenizer.save_pretrained(ckpt)
print(f"Checkpoint saved to {ckpt}")
if accelerator.is_main_process:
final = Path(save_root) / "final"
final.mkdir(parents=True, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
tokenizer.save_pretrained(final)
print(f"Final checkpoint saved to {final}")
wandb.finish()
if __name__ == "__main__":
main()