"""Entropy-maximization fine-tuning with optional Group-wise Entropy Equalization (GEE) debiasing."""
import argparse
import os
import random
import time
from pathlib import Path

import psutil
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import wandb
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
import json
import math

os.environ.setdefault("NCCL_TIMEOUT", "2700")
os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
    parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
    parser.add_argument('--micro_batch_size', type=str, default='2', help='Micro batch size or "auto"')
    parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
    parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
    parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
    parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
    parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
    parser.add_argument('--seed', type=int, default=15, help='Random seed')
    parser.add_argument('--no_deepspeed', action='store_true', help='Disable DeepSpeed and use plain Accelerator (Colab-friendly)')
    parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['bf16', 'fp16', 'no'], help='Mixed precision mode')
    # GEE options
    parser.add_argument('--gee_enable', action='store_true', help='Enable Group-wise Entropy Equalization (debiasing)')
    parser.add_argument('--gee_groups_path', type=str, default='groups/gender.json', help='Path to JSON defining groups')
    parser.add_argument('--gee_alpha', type=float, default=1.0, help='Weight for group mass parity loss')
    parser.add_argument('--gee_beta', type=float, default=0.3, help='Weight for group entropy equalization loss')
    parser.add_argument('--gee_lambda', type=float, default=0.0, help='Weight for global entropy anchor')
    parser.add_argument('--gee_gamma', type=float, default=0.0, help='Weight for sensitive coverage anchor')
    parser.add_argument('--gee_tau', type=float, default=1e-6, help='Min union mass to apply GEE losses')
    parser.add_argument('--gee_top_m', type=int, default=200, help='Apply GEE if any group token in top-M at a position')
    parser.add_argument('--gee_em_mix', type=float, default=0.1, help='Additive EM loss mix to stabilize training (0 to disable)')
    return parser.parse_args()


class FTDataset(Dataset):
    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]
def custom_collate(batch):
    return {"input": [item["input"] for item in batch]}


def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
    model_configs = {
        "1.5B": {"base_batch": 4, "keywords": ["1.5B", "1B"]},
        "2B": {"base_batch": 4, "keywords": ["2B"]},
        "3B": {"base_batch": 2, "keywords": ["3B"]},
        "7B": {"base_batch": 2, "keywords": ["7B"]},
        "8B+": {"base_batch": 1, "keywords": ["8B", "9B", "10B", "11B", "12B", "13B", "14B"]},
    }
    model_name_upper = model_name.upper()
    detected = next((cfg for cfg in model_configs.values() if any(k in model_name_upper for k in cfg["keywords"])), None)
    base_batch = detected["base_batch"] if detected else 2
    if world_size > 1:
        return min(base_batch + 1, int(base_batch * 1.5))
    return base_batch


def apply_chat_template(tokenizer, problem: str) -> str:
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": problem}],
        tokenize=False,
        add_generation_prompt=True
    )


def main():
    args = parse_args()
    set_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # Resolve the micro batch size; "auto" picks a heuristic value based on model size.
    if str(args.micro_batch_size).lower() == "auto":
        micro_bs = get_optimal_micro_batch_size(args.model_name, world_size)
    else:
        micro_bs = int(args.micro_batch_size)
    eff_bs = args.effective_batch
    accum_steps = max(1, eff_bs // (micro_bs * world_size))
    temp = args.temperature
    lr = args.learning_rate
    save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")

    # Resolve mixed precision automatically if requested bf16 is unsupported
    mp = args.mixed_precision
    if mp == "bf16":
        if not torch.cuda.is_available() or not torch.cuda.is_bf16_supported():
            mp = "fp16" if torch.cuda.is_available() else "no"

    if args.no_deepspeed:
        accelerator = Accelerator(mixed_precision=mp, gradient_accumulation_steps=accum_steps)
    else:
        ds_config = {
            "train_micro_batch_size_per_gpu": micro_bs,
            "train_batch_size": eff_bs,
            "gradient_accumulation_steps": accum_steps,
            "bf16": {"enabled": mp == "bf16"},
            "zero_optimization": {
                "stage": 2,
                "offload_optimizer": {"device": "cpu"},
                "offload_param": {"device": "none"}
            },
            "gradient_clipping": 1.0,
        }
        accelerator = Accelerator(mixed_precision=mp, gradient_accumulation_steps=accum_steps,
                                  deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
    print = accelerator.print

    model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
    config = AutoConfig.from_pretrained(model_path)
    config.use_cache = False
    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
    model.gradient_checkpointing_enable()
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

    # Prepare GEE group ids and targets if enabled
    group_name_list = []
    group_id_lists = []
    group_target_pi = []
    if args.gee_enable:
        if not os.path.exists(args.gee_groups_path):
            raise FileNotFoundError(f"GEE groups file not found: {args.gee_groups_path}")
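        # Expected layout of the groups JSON (inferred from the keys read below;
        # the group names, tokens, and weights shown here are illustrative only):
        # {
        #   "groups": {"female": [" she", " her"], "male": [" he", " his"]},
        #   "pi":     {"female": 0.5, "male": 0.5}
        # }
        # "groups" maps each group name to surface strings (only the first sub-token
        # of each string is used); "pi" gives optional target masses, normalized below.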
        with open(args.gee_groups_path, 'r', encoding='utf-8') as f:
            groups_payload = json.load(f)
        groups = groups_payload.get('groups', {})
        pis = groups_payload.get('pi', {})
        if not groups:
            raise ValueError('GEE groups.json is missing a "groups" object')
        # Build token-id lists per group (first sub-token of each provided string)
        for gname, tokens in groups.items():
            ids = []
            for w in tokens:
                toks = tokenizer.tokenize(w)
                if toks:
                    tid = tokenizer.convert_tokens_to_ids(toks[0])
                    if tid is not None:
                        ids.append(tid)
            ids = sorted(set([i for i in ids if isinstance(i, int) and i >= 0]))
            if len(ids) == 0:
                continue
            group_name_list.append(gname)
            group_id_lists.append(torch.tensor(ids, dtype=torch.long))
            group_target_pi.append(float(pis.get(gname, 1.0)))
        if not group_id_lists:
            raise ValueError('No valid group token ids produced from groups file')
        # Normalize pi to sum to 1
        total_pi = sum(group_target_pi)
        if total_pi <= 0:
            group_target_pi = [1.0 / len(group_id_lists)] * len(group_id_lists)
        else:
            group_target_pi = [p / total_pi for p in group_target_pi]

    if accelerator.is_main_process:
        wandb.init(project=args.wandb_project,
                   name=args.run_name or args.wandb_name or args.model_name,
                   config=vars(args))

    # Friendly error if the parquet path is missing
    if not os.path.exists(args.train_data):
        raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.")

    df = pd.read_parquet(args.train_data)
    train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
    train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)

    optimizer = AdamW(model.parameters(), lr=lr)
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    prev_logits = None
    baseline_avg_H = None  # for L_reg
    baseline_union_mass = None  # for coverage anchor

    model.train()
    for step, batch in enumerate(train_loader, start=1):
        if step > args.max_steps:
            print(f"Exceeded max step {args.max_steps}, training stopped.")
            break
        with accelerator.accumulate(model):
            enc = tokenizer(batch["input"], return_tensors="pt", padding="longest",
                            truncation=True, max_length=2048).to(accelerator.device)
            with torch.no_grad():
                use_synced = accelerator.num_processes > 1
                gen_ids = accelerator.unwrap_model(model).generate(
                    **enc,
                    max_new_tokens=512,
                    do_sample=True,
                    top_p=0.95,
                    temperature=args.sample_temp,
                    synced_gpus=use_synced,
                    repetition_penalty=1.15,
                    pad_token_id=tokenizer.pad_token_id,
                    use_cache=False,
                )
            seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
            pad_mask = seq.ne(tokenizer.pad_token_id)
            prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
            token_idx = torch.arange(seq.size(1), device=seq.device)
            gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask

            logits = model(seq, attention_mask=pad_mask).logits  # [B,T,V]
            probs = F.softmax(logits / temp, dim=-1)
            H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)  # [B,T]
            # Precompute EM loss (used standalone or as mixin)
            em_loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)

            if args.gee_enable:
                # Compute GEE losses on generated positions only
                mask = gen_mask
                denom = mask.sum().clamp_min(1)
                # Union mass and per-group mass
                group_masses = []  # list of [B,T]
                top_m = args.gee_top_m
                # Top-M indices to decide triggering
                if top_m > 0:
                    topk = probs.topk(k=min(top_m, probs.size(-1)), dim=-1).indices  # [B,T,M]
                else:
                    topk = None
                for ids in group_id_lists:
                    ids = ids.to(probs.device)
                    gm = probs.index_select(-1, ids).sum(-1)  # [B,T]
                    group_masses.append(gm)
                union_mass = torch.stack(group_masses, dim=-1).sum(-1)  # [B,T]
                # Trigger mask: apply only when union_mass >= tau OR any group id in top-M
                trigger = union_mass >= args.gee_tau
                if topk is not None:
                    any_in_top = torch.zeros_like(trigger)
                    vocab_in_top = topk  # [B,T,M]
                    for ids in group_id_lists:
                        ids = ids.to(probs.device)
                        g_match = (vocab_in_top.unsqueeze(-1) == ids.view(1, 1, 1, -1)).any(-1)  # [B,T,M]
                        any_in_top |= g_match.any(-1)
                    trigger |= any_in_top
                eff_mask = mask & trigger
                n_triggered = eff_mask.sum()
                eff_denom = n_triggered.clamp_min(1)

                # L_mass: group-mass parity to target pi
                pi = torch.tensor(group_target_pi, device=probs.device, dtype=probs.dtype).view(1, 1, -1)  # [1,1,K]
                masses_stacked = torch.stack(group_masses, dim=-1)  # [B,T,K]
                mass_gap = (masses_stacked - pi).pow(2)  # [B,T,K]
                L_mass = (mass_gap.sum(-1) * eff_mask).sum() / eff_denom

                # L_GEE: equalize normalized group entropy per position
                norm_group_entropies = []  # [B,T] per group
                for ids in group_id_lists:
                    ids = ids.to(probs.device)
                    p_sub = probs.index_select(-1, ids)  # [B,T,|G|]
                    denom_g = p_sub.sum(-1, keepdim=True).clamp_min(1e-12)
                    p_g = p_sub / denom_g
                    H_g = -(p_g * torch.log(p_g + 1e-12)).sum(-1)  # [B,T]
                    max_H = math.log(p_sub.size(-1)) if p_sub.size(-1) > 1 else 1.0
                    H_g_norm = H_g / max(max_H, 1e-12)
                    norm_group_entropies.append(H_g_norm)
                H_stack = torch.stack(norm_group_entropies, dim=-1)  # [B,T,K]
                H_bar = H_stack.mean(-1, keepdim=True)  # [B,T,1]
                L_gee = (((H_stack - H_bar) ** 2).sum(-1) * eff_mask).sum() / eff_denom

                # L_reg: global entropy anchor to baseline average
                if args.gee_lambda > 0:
                    avg_H = (H_tok * mask).sum() / denom
                    if baseline_avg_H is None:
                        baseline_avg_H = avg_H.detach()
                    L_reg = (avg_H - baseline_avg_H).pow(2)
                else:
                    L_reg = torch.zeros((), device=probs.device, dtype=probs.dtype)

                # L_cov: keep union sensitive mass near baseline
                if args.gee_gamma > 0:
                    avg_union = (union_mass * mask).sum() / denom
                    if baseline_union_mass is None:
                        baseline_union_mass = avg_union.detach()
                    L_cov = (avg_union - baseline_union_mass).pow(2)
                else:
                    L_cov = torch.zeros((), device=probs.device, dtype=probs.dtype)

                loss_gee = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
                # Fallback: if no positions triggered, use EM loss to ensure updates
                # (checked before clamping, otherwise the clamp would mask the zero case)
                if n_triggered.item() == 0:
                    loss = em_loss
                else:
                    loss = loss_gee + (args.gee_em_mix * em_loss if args.gee_em_mix > 0 else 0.0)
                # Log activation ratio if main process
                if accelerator.is_main_process:
                    gee_active_ratio = (n_triggered / denom).item()
                    try:
                        wandb.log({"gee_active_ratio": gee_active_ratio,
                                   "L_mass": float(L_mass.detach().item()),
                                   "L_gee": float(L_gee.detach().item())})
                    except Exception:
                        pass
            else:
                # Original one-shot EM loss
                loss = em_loss

            prev_logits = logits.detach()
            accelerator.backward(loss)
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        if accelerator.is_main_process:
            if step % args.log_steps == 0:
                print(f"Step {step} | loss={loss.item():.6f}")
                wandb.log({"step": step, "loss": loss.item()})
            if step % args.save_steps == 0:
                ckpt = Path(save_root) / f"step_{step}"
                ckpt.mkdir(parents=True, exist_ok=True)
                accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
                tokenizer.save_pretrained(ckpt)
                print(f"Checkpoint saved to {ckpt}")

    if accelerator.is_main_process:
        final = Path(save_root) / "final"
        final.mkdir(parents=True, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
        tokenizer.save_pretrained(final)
        print(f"Final checkpoint saved to {final}")
        wandb.finish()


if __name__ == "__main__":
    main()
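# Example invocation (illustrative; the script filename and paths are placeholders,
# adjust them to your environment; all flags below are defined in parse_args above):
#   accelerate launch train_em_gee.py \
#       --model_name Qwen2.5-Math-7B \
#       --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
#       --micro_batch_size auto \
#       --gee_enable --gee_groups_path groups/gender.json \
#       --run_name gee-debias-demo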