summaryrefslogtreecommitdiff
path: root/Group-Entropy-Equalization/train.py
blob: 2cc9438134fe23a515273eb2e55da72e2edd0197 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import argparse
import os
import random
import time
from pathlib import Path

import psutil
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader

import wandb

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
import json
import math


# Default both NCCL timeout variables to "2700" (presumably seconds, per the
# _SEC suffix) without overriding values already present in the environment.
_NCCL_TIMEOUT_SECONDS = "2700"
for _nccl_var in ("NCCL_TIMEOUT", "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"):
    os.environ.setdefault(_nccl_var, _NCCL_TIMEOUT_SECONDS)

def parse_args(argv=None):
    """Parse command-line options for entropy-minimization / GEE fine-tuning.

    Args:
        argv: Argument list to parse. ``None`` (the default, used by the script
            entry point) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace. ``micro_batch_size`` is either a positive ``int`` or
        the string ``"auto"``.
    """

    def micro_batch_size_arg(value):
        # Accept "auto" (any case) or a positive integer; reject anything else at
        # parse time instead of failing later inside main().
        text = str(value).strip()
        if text.lower() == "auto":
            return "auto"
        try:
            size = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f'expected a positive integer or "auto", got {value!r}') from None
        if size < 1:
            raise argparse.ArgumentTypeError(f"micro batch size must be >= 1, got {size}")
        return size

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
    parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
    # A string default is run through `type`, so the parsed default is the int 2.
    parser.add_argument('--micro_batch_size', type=micro_batch_size_arg, default='2', help='Micro batch size or "auto"')
    parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
    parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
    parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
    parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
    parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
    parser.add_argument('--seed', type=int, default=15, help='Random seed')
    parser.add_argument('--no_deepspeed', action='store_true', help='Disable DeepSpeed and use plain Accelerator (Colab-friendly)')
    parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['bf16', 'fp16', 'no'], help='Mixed precision mode')
    # GEE options
    parser.add_argument('--gee_enable', action='store_true', help='Enable Group-wise Entropy Equalization (debiasing)')
    parser.add_argument('--gee_groups_path', type=str, default='groups/gender.json', help='Path to JSON defining groups')
    parser.add_argument('--gee_alpha', type=float, default=1.0, help='Weight for group mass parity loss')
    parser.add_argument('--gee_beta', type=float, default=0.3, help='Weight for group entropy equalization loss')
    parser.add_argument('--gee_lambda', type=float, default=0.0, help='Weight for global entropy anchor')
    parser.add_argument('--gee_gamma', type=float, default=0.0, help='Weight for sensitive coverage anchor')
    parser.add_argument('--gee_tau', type=float, default=1e-6, help='Min union mass to apply GEE losses')
    parser.add_argument('--gee_top_m', type=int, default=200, help='Apply GEE if any group token in top-M at a position')
    parser.add_argument('--gee_em_mix', type=float, default=0.1, help='Additive EM loss mix to stabilize training (0 to disable)')
    return parser.parse_args(argv)

class FTDataset(Dataset):
    """Map-style dataset over an in-memory list of rows.

    Rows are returned exactly as stored; main() builds them as dicts holding the
    chat-templated prompt under the "input" key.
    """

    def __init__(self, rows):
        self.rows = rows

    def __getitem__(self, idx):
        return self.rows[idx]

    def __len__(self):
        return len(self.rows)

def custom_collate(batch):
    """Gather each row's "input" value into one list, discarding any other keys.

    Returns a dict ``{"input": [...]}`` preserving the batch order.
    """
    prompts = []
    for row in batch:
        prompts.append(row["input"])
    return {"input": prompts}

def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
    """Heuristically pick a per-device micro batch size from a model's name.

    The size is read from whole size tokens such as "1.5B" or "7B" in
    *model_name* (case-insensitive). Whole-token matching means "32B" is not
    mistaken for "2B" and "13B" is not mistaken for "3B".

    Args:
        model_name: Model identifier, e.g. "Qwen2.5-Math-7B".
        world_size: Number of processes; values > 1 bump the batch slightly.

    Returns:
        The bucket's base batch size (2 when no known size token is found),
        adjusted to ``min(base + 1, int(base * 1.5))`` when ``world_size > 1``.
    """
    import re

    model_configs = {
        "1.5B": {"base_batch": 4, "keywords": ["1.5B", "1B"]},
        "2B": {"base_batch": 4, "keywords": ["2B"]},
        "3B": {"base_batch": 2, "keywords": ["3B"]},
        "7B": {"base_batch": 2, "keywords": ["7B"]},
        "8B+": {"base_batch": 1, "keywords": ["8B", "9B", "10B", "11B", "12B", "13B", "14B"]},
    }
    # A size token must not be preceded by a digit or '.', so the "2B" inside
    # "32B" and the "5B" inside "1.5B" are never matched on their own.
    size_tokens = re.findall(r"(?<![\d.])\d+(?:\.\d+)?B", model_name.upper())
    detected = next(
        (cfg for tok in size_tokens for cfg in model_configs.values() if tok in cfg["keywords"]),
        None,
    )
    base_batch = detected["base_batch"] if detected else 2
    if world_size > 1:
        return min(base_batch + 1, int(base_batch * 1.5))
    return base_batch

def apply_chat_template(tokenizer, problem: str) -> str:
    """Render *problem* as a single user turn using the tokenizer's chat template.

    The template is rendered to text (not token ids) and ends with the
    generation prompt, so the model's reply can be sampled directly after it.
    """
    messages = [{"role": "user", "content": problem}]
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered

def _resolve_micro_batch_size(value, model_name: str, world_size: int) -> int:
    """Turn the --micro_batch_size value (int, numeric string or "auto") into a positive int."""
    if isinstance(value, str) and value.strip().lower() == "auto":
        return get_optimal_micro_batch_size(model_name, world_size)
    size = int(value)
    if size < 1:
        raise ValueError(f"micro_batch_size must be >= 1, got {value!r}")
    return size


def _load_gee_groups(path: str, tokenizer):
    """Read the GEE groups JSON and map each group's words to vocabulary ids.

    The file must contain a "groups" object (group name -> list of words) and
    may contain a "pi" object (group name -> target weight, default 1.0).

    Returns:
        (group_names, group_id_tensors, target_pi) where target_pi is normalized
        to sum to 1 (uniform if the weights sum to <= 0). Groups whose words
        produce no valid token id are skipped.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if "groups" is missing/empty or no group yields any id.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"GEE groups file not found: {path}")
    with open(path, 'r', encoding='utf-8') as f:
        groups_payload = json.load(f)
    groups = groups_payload.get('groups', {})
    pis = groups_payload.get('pi', {})
    if not groups:
        raise ValueError('GEE groups.json is missing a "groups" object')

    names, id_tensors, targets = [], [], []
    for gname, words in groups.items():
        # NOTE(review): only the first sub-token of each word is used, and words
        # are tokenized in isolation (no leading space), which may map to a
        # different id than the same word mid-sentence — confirm per tokenizer.
        ids = set()
        for w in words:
            toks = tokenizer.tokenize(w)
            if toks:
                tid = tokenizer.convert_tokens_to_ids(toks[0])
                if isinstance(tid, int) and tid >= 0:
                    ids.add(tid)
        if not ids:
            continue
        names.append(gname)
        id_tensors.append(torch.tensor(sorted(ids), dtype=torch.long))
        targets.append(float(pis.get(gname, 1.0)))
    if not id_tensors:
        raise ValueError('No valid group token ids produced from groups file')

    total_pi = sum(targets)
    if total_pi <= 0:
        targets = [1.0 / len(id_tensors)] * len(id_tensors)
    else:
        targets = [p / total_pi for p in targets]
    return names, id_tensors, targets


def _gee_terms(probs, gen_mask, group_ids, all_group_ids, target_pi, tau, top_m):
    """Compute the masked Group Entropy Equalization terms for one batch.

    Args:
        probs: temperature-scaled probabilities indexed [B, T, V].
        gen_mask: bool [B, T], True at generated (non-pad) positions.
        group_ids: one LongTensor of vocabulary ids per group, on probs.device.
        all_group_ids: concatenation of group_ids (used by the top-M trigger).
        target_pi: per-group target masses.
        tau: minimum union mass for a position to trigger.
        top_m: also trigger when any group id is in the top-M tokens (<= 0 disables).

    Returns:
        dict with "union_mass" [B, T], "eff_count" (unclamped number of
        triggered generated positions, 0-dim tensor), "L_mass" and "L_gee".
    """
    # Per-group probability slices, shared by the mass and entropy terms.
    group_probs = [probs.index_select(-1, ids) for ids in group_ids]  # each [B,T,|G|]
    masses = torch.stack([p_sub.sum(-1) for p_sub in group_probs], dim=-1)  # [B,T,K]
    union_mass = masses.sum(-1)  # [B,T]

    # Trigger when union_mass >= tau OR any group id is among the top-M tokens.
    trigger = union_mass >= tau
    if top_m > 0:
        topk = probs.topk(k=min(top_m, probs.size(-1)), dim=-1).indices  # [B,T,M]
        trigger = trigger | torch.isin(topk, all_group_ids).any(-1)
    eff_mask = gen_mask & trigger
    eff_count = eff_mask.sum()
    eff_denom = eff_count.clamp_min(1)

    # L_mass: group-mass parity to target pi.
    # NOTE(review): masses are absolute probabilities while target_pi sums to 1,
    # so this also pushes the union mass toward 1 at triggered positions; if
    # parity of group *shares* was intended, compare masses / union_mass — confirm.
    pi = torch.tensor(target_pi, device=probs.device, dtype=probs.dtype).view(1, 1, -1)  # [1,1,K]
    L_mass = ((masses - pi).pow(2).sum(-1) * eff_mask).sum() / eff_denom

    # L_GEE: equalize each group's normalized within-group entropy per position.
    norm_group_entropies = []
    for p_sub in group_probs:
        p_g = p_sub / p_sub.sum(-1, keepdim=True).clamp_min(1e-12)
        H_g = -(p_g * torch.log(p_g + 1e-12)).sum(-1)  # [B,T]
        group_size = p_sub.size(-1)
        max_H = math.log(group_size) if group_size > 1 else 1.0
        norm_group_entropies.append(H_g / max(max_H, 1e-12))
    H_stack = torch.stack(norm_group_entropies, dim=-1)  # [B,T,K]
    H_bar = H_stack.mean(-1, keepdim=True)  # [B,T,1]
    L_gee = (((H_stack - H_bar) ** 2).sum(-1) * eff_mask).sum() / eff_denom

    return {"union_mass": union_mass, "eff_count": eff_count, "L_mass": L_mass, "L_gee": L_gee}


def main():
    """Fine-tune a causal LM on the entropy of its own sampled continuations.

    For each loader batch the chat-templated prompts are tokenized, a
    continuation is sampled without gradients, and a second (differentiable)
    forward pass over prompt + sample yields temperature-scaled token entropies.
    Without --gee_enable the loss is the mean entropy over generated tokens
    (minimized by the optimizer). With it, the GEE terms (group-mass parity,
    group-entropy equalization and the optional global-entropy / coverage
    anchors) form the loss, plus an optional EM mix-in.

    Checkpoints are written under save_root every --save_steps loader steps and
    once more as "final"; metrics go to W&B from the main process.
    """
    args = parse_args()
    set_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Check the data path before loading a large model or opening a W&B run.
    if not os.path.exists(args.train_data):
        raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.")

    world_size = int(os.getenv("WORLD_SIZE", "1"))
    micro_bs = _resolve_micro_batch_size(args.micro_batch_size, args.model_name, world_size)
    eff_bs = args.effective_batch
    accum_steps = max(1, eff_bs // (micro_bs * world_size))
    # DeepSpeed requires train_batch_size == micro * accum * world_size exactly,
    # which differs from eff_bs when eff_bs is not a multiple of micro_bs * world_size.
    global_bs = micro_bs * accum_steps * world_size
    temp = args.temperature
    lr = args.learning_rate

    save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
    # Fall back from bf16 when CUDA is missing or the device cannot run bf16.
    mp = args.mixed_precision
    if mp == "bf16" and (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()):
        mp = "fp16" if torch.cuda.is_available() else "no"

    if args.no_deepspeed:
        accelerator = Accelerator(mixed_precision=mp, gradient_accumulation_steps=accum_steps)
    else:
        ds_config = {
            "train_micro_batch_size_per_gpu": micro_bs,
            "train_batch_size": global_bs,
            "gradient_accumulation_steps": accum_steps,
            "bf16": {"enabled": mp == "bf16"},
            "zero_optimization": {
                "stage": 2,
                "offload_optimizer": {"device": "cpu"},
                "offload_param": {"device": "none"},
            },
            "gradient_clipping": 1.0,
        }
        accelerator = Accelerator(mixed_precision=mp,
                                  gradient_accumulation_steps=accum_steps,
                                  deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
    print = accelerator.print
    if global_bs != eff_bs:
        print(f"Note: effective_batch={eff_bs} is not a multiple of micro_batch_size*world_size={micro_bs * world_size}; using a global batch of {global_bs}.")

    model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
    config = AutoConfig.from_pretrained(model_path)
    config.use_cache = False
    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
    model.gradient_checkpointing_enable()
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

    # Prepare GEE group ids and targets if enabled
    group_id_lists, group_target_pi = [], []
    if args.gee_enable:
        _group_names, group_id_lists, group_target_pi = _load_gee_groups(args.gee_groups_path, tokenizer)

    if accelerator.is_main_process:
        wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))

    df = pd.read_parquet(args.train_data)
    train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
    train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)

    optimizer = AdamW(model.parameters(), lr=lr)
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    # Move group ids to the training device once instead of on every step.
    group_id_lists = [ids.to(accelerator.device) for ids in group_id_lists]
    all_group_ids = torch.cat(group_id_lists) if group_id_lists else None
    baseline_avg_H = None  # for L_reg
    baseline_union_mass = None  # for coverage anchor
    use_synced = accelerator.num_processes > 1
    model.train()

    # NOTE: `step` counts loader micro-batches, not optimizer updates (those
    # happen every accum_steps micro-batches), and the loader is iterated once.
    for step, batch in enumerate(train_loader, start=1):
        if step > args.max_steps:
            print(f"Exceed max step {args.max_steps}, training stopped.")
            break

        gee_metrics = {}
        with accelerator.accumulate(model):
            enc = tokenizer(batch["input"],
                            return_tensors="pt",
                            padding="longest",
                            truncation=True,
                            max_length=2048).to(accelerator.device)

            with torch.no_grad():
                gen_ids = accelerator.unwrap_model(model).generate(
                    **enc,
                    max_new_tokens=512,
                    do_sample=True,
                    top_p=0.95,
                    temperature=args.sample_temp,
                    synced_gpus=use_synced,
                    repetition_penalty=1.15,
                    pad_token_id=tokenizer.pad_token_id,
                    use_cache=False,
                )

            prompt_width = enc.input_ids.shape[1]
            seq = torch.cat([enc.input_ids, gen_ids[:, prompt_width:]], dim=1)[:, :4096]
            pad_mask = seq.ne(tokenizer.pad_token_id)
            # Prompts are left-padded to a common width, so sampled tokens start at
            # column prompt_width in every row. Using each row's non-pad prompt
            # count as the start column would mark prompt tokens of shorter
            # (padded) prompts as generated.
            token_idx = torch.arange(seq.size(1), device=seq.device)
            gen_mask = (token_idx >= prompt_width).unsqueeze(0) & pad_mask

            logits = model(seq, attention_mask=pad_mask).logits  # [B,T,V]
            probs = F.softmax(logits / temp, dim=-1)
            H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)  # [B,T]

            # Mean token entropy over generated positions (used standalone or as mixin)
            denom = gen_mask.sum().clamp_min(1)
            em_loss = (H_tok * gen_mask).sum() / denom

            if args.gee_enable:
                terms = _gee_terms(probs, gen_mask, group_id_lists, all_group_ids,
                                   group_target_pi, args.gee_tau, args.gee_top_m)
                L_mass, L_gee = terms["L_mass"], terms["L_gee"]

                # L_reg: anchor mean entropy to its value on the first step
                if args.gee_lambda > 0:
                    avg_H = em_loss  # same quantity: mean H_tok over generated positions
                    if baseline_avg_H is None:
                        baseline_avg_H = avg_H.detach()
                    L_reg = (avg_H - baseline_avg_H).pow(2)
                else:
                    L_reg = torch.zeros((), device=probs.device, dtype=probs.dtype)

                # L_cov: keep union sensitive mass near its first-step value
                if args.gee_gamma > 0:
                    avg_union = (terms["union_mass"] * gen_mask).sum() / denom
                    if baseline_union_mass is None:
                        baseline_union_mass = avg_union.detach()
                    L_cov = (avg_union - baseline_union_mass).pow(2)
                else:
                    L_cov = torch.zeros((), device=probs.device, dtype=probs.dtype)

                loss_gee = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
                # Fallback: if no positions triggered, use EM loss to ensure updates.
                # The raw count must be tested; the clamped denominator is never 0.
                eff_count = int(terms["eff_count"].item())
                if eff_count == 0:
                    loss = em_loss
                else:
                    loss = loss_gee + (args.gee_em_mix * em_loss if args.gee_em_mix > 0 else 0.0)
                if accelerator.is_main_process:
                    gee_metrics = {
                        "gee_active_ratio": eff_count / denom.item(),
                        "L_mass": float(L_mass.detach().item()),
                        "L_gee": float(L_gee.detach().item()),
                    }
            else:
                # Original One-shot EM loss
                loss = em_loss

            accelerator.backward(loss)
            # Clip only on the micro-step that applies the update; earlier
            # micro-steps hold partially accumulated gradients.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            # Release the [B,T,V] tensors before the next step's generate().
            del logits, probs

        if accelerator.is_main_process:
            if step % args.log_steps == 0:
                print(f"Step {step} | loss={loss.item():.6f}")
                # GEE metrics share this W&B step instead of advancing it separately.
                wandb.log({"step": step, "loss": loss.item(), **gee_metrics})

            if step % args.save_steps == 0:
                ckpt = Path(save_root) / f"step_{step}"
                ckpt.mkdir(parents=True, exist_ok=True)
                accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
                tokenizer.save_pretrained(ckpt)
                print(f"Checkpoint saved to {ckpt}")

    # Make sure every process has finished its last update before saving.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final = Path(save_root) / "final"
        final.mkdir(parents=True, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
        tokenizer.save_pretrained(final)
        print(f"Final checkpoint saved to {final}")
        wandb.finish()

if __name__ == "__main__":
    main()