| field | value | date |
|---|---|---|
| author | blackhao <13851610112@163.com> | 2025-08-23 14:17:47 -0500 |
| committer | blackhao <13851610112@163.com> | 2025-08-23 14:17:47 -0500 |
| commit | 8b2cf4b226de17227aa95abb637d1410bc2e57e3 (patch) | |
| tree | 6301c78a7c98b590f4a14487e4547f561389075d /Group-Entropy-Equalization/train.py | |
| parent | f21f7dd85365b10505bbd1cfa28f6a8648ba1b7e (diff) | |
feat(gee): add GEE objective and flags; add groups/gender.json; docs for Colab and GEE
Diffstat (limited to 'Group-Entropy-Equalization/train.py')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | Group-Entropy-Equalization/train.py | 133 |

1 file changed, 130 insertions(+), 3 deletions(-)
```diff
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py
index d1ba4f0..9ca1d4a 100644
--- a/Group-Entropy-Equalization/train.py
+++ b/Group-Entropy-Equalization/train.py
@@ -17,6 +17,7 @@ import wandb
 from accelerate import Accelerator, DeepSpeedPlugin
 from accelerate.utils import set_seed
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+import json
 
 os.environ.setdefault("NCCL_TIMEOUT", "2700")
 
@@ -42,6 +43,15 @@ def parse_args():
     parser.add_argument('--seed', type=int, default=15, help='Random seed')
     parser.add_argument('--no_deepspeed', action='store_true', help='Disable DeepSpeed and use plain Accelerator (Colab-friendly)')
     parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['bf16', 'fp16', 'no'], help='Mixed precision mode')
+    # GEE options
+    parser.add_argument('--gee_enable', action='store_true', help='Enable Group-wise Entropy Equalization (debiasing)')
+    parser.add_argument('--gee_groups_path', type=str, default='groups/gender.json', help='Path to JSON defining groups')
+    parser.add_argument('--gee_alpha', type=float, default=1.0, help='Weight for group mass parity loss')
+    parser.add_argument('--gee_beta', type=float, default=0.3, help='Weight for group entropy equalization loss')
+    parser.add_argument('--gee_lambda', type=float, default=0.0, help='Weight for global entropy anchor')
+    parser.add_argument('--gee_gamma', type=float, default=0.0, help='Weight for sensitive coverage anchor')
+    parser.add_argument('--gee_tau', type=float, default=1e-3, help='Min union mass to apply GEE losses')
+    parser.add_argument('--gee_top_m', type=int, default=50, help='Apply GEE if any group token in top-M at a position')
     return parser.parse_args()
 
 class FTDataset(Dataset):
@@ -121,12 +131,52 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
     tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
+    # Prepare GEE group ids and targets if enabled
+    group_name_list = []
+    group_id_lists = []
+    group_target_pi = []
+    if args.gee_enable:
+        if not os.path.exists(args.gee_groups_path):
+            raise FileNotFoundError(f"GEE groups file not found: {args.gee_groups_path}")
+        with open(args.gee_groups_path, 'r', encoding='utf-8') as f:
+            groups_payload = json.load(f)
+        groups = groups_payload.get('groups', {})
+        pis = groups_payload.get('pi', {})
+        if not groups:
+            raise ValueError('GEE groups.json is missing a "groups" object')
+        # Build token-id lists per group (first sub-token of each provided string)
+        for gname, tokens in groups.items():
+            ids = []
+            for w in tokens:
+                toks = tokenizer.tokenize(w)
+                if toks:
+                    tid = tokenizer.convert_tokens_to_ids(toks[0])
+                    if tid is not None:
+                        ids.append(tid)
+            ids = sorted(set([i for i in ids if isinstance(i, int) and i >= 0]))
+            if len(ids) == 0:
+                continue
+            group_name_list.append(gname)
+            group_id_lists.append(torch.tensor(ids, dtype=torch.long))
+            group_target_pi.append(float(pis.get(gname, 1.0)))
+        if not group_id_lists:
+            raise ValueError('No valid group token ids produced from groups file')
+        # Normalize pi to sum to 1
+        total_pi = sum(group_target_pi)
+        if total_pi <= 0:
+            group_target_pi = [1.0 / len(group_id_lists)] * len(group_id_lists)
+        else:
+            group_target_pi = [p / total_pi for p in group_target_pi]
+
     if accelerator.is_main_process:
         wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
 
     # Friendly error if the parquet path is missing
     if not os.path.exists(args.train_data):
         raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.")
+    # Friendly error if the parquet path is missing
+    if not os.path.exists(args.train_data):
+        raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.")
     df = pd.read_parquet(args.train_data)
     train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
     train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
@@ -134,6 +184,8 @@ def main():
     optimizer = AdamW(model.parameters(), lr=lr)
     model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
     prev_logits = None
+    baseline_avg_H = None  # for L_reg
+    baseline_union_mass = None  # for coverage anchor
 
     model.train()
     for step, batch in enumerate(train_loader, start=1):
@@ -168,10 +220,85 @@ def main():
             token_idx = torch.arange(seq.size(1), device=seq.device)
             gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
 
-            logits = model(seq, attention_mask=pad_mask).logits
+            logits = model(seq, attention_mask=pad_mask).logits  # [B,T,V]
             probs = F.softmax(logits / temp, dim=-1)
-            H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
-            loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+            H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)  # [B,T]
+
+            if args.gee_enable:
+                # Compute GEE losses on generated positions only
+                mask = gen_mask
+                denom = mask.sum().clamp_min(1)
+
+                # Union mass and per-group mass
+                group_masses = []  # list of [B,T]
+                top_m = args.gee_top_m
+                # Top-M indices to decide triggering
+                if top_m > 0:
+                    topk = probs.topk(k=min(top_m, probs.size(-1)), dim=-1).indices  # [B,T,M]
+                else:
+                    topk = None
+                for ids in group_id_lists:
+                    ids = ids.to(probs.device)
+                    gm = probs.index_select(-1, ids).sum(-1)  # [B,T]
+                    group_masses.append(gm)
+                union_mass = torch.stack(group_masses, dim=-1).sum(-1)  # [B,T]
+
+                # Trigger mask: apply only when union_mass >= tau OR any group id in top-M
+                trigger = union_mass >= args.gee_tau
+                if topk is not None:
+                    any_in_top = torch.zeros_like(trigger)
+                    vocab_in_top = topk  # [B,T,M]
+                    for ids in group_id_lists:
+                        ids = ids.to(probs.device)
+                        g_match = (vocab_in_top.unsqueeze(-1) == ids.view(1,1,1,-1)).any(-1)  # [B,T,M]
+                        any_in_top |= g_match.any(-1)
+                    trigger |= any_in_top
+                eff_mask = mask & trigger
+                eff_denom = eff_mask.sum().clamp_min(1)
+
+                # L_mass: group-mass parity to target pi
+                pi = torch.tensor(group_target_pi, device=probs.device, dtype=probs.dtype).view(1,1,-1)  # [1,1,K]
+                masses_stacked = torch.stack(group_masses, dim=-1)  # [B,T,K]
+                mass_gap = (masses_stacked - pi).pow(2)  # [B,T,K]
+                L_mass = (mass_gap.sum(-1) * eff_mask).sum() / eff_denom
+
+                # L_GEE: equalize normalized group entropy per position
+                norm_group_entropies = []  # [B,T] per group
+                for ids in group_id_lists:
+                    ids = ids.to(probs.device)
+                    p_sub = probs.index_select(-1, ids)  # [B,T,|G|]
+                    denom_g = p_sub.sum(-1, keepdim=True).clamp_min(1e-12)
+                    p_g = p_sub / denom_g
+                    H_g = -(p_g * torch.log(p_g + 1e-12)).sum(-1)  # [B,T]
+                    max_H = math.log(p_sub.size(-1)) if p_sub.size(-1) > 1 else 1.0
+                    H_g_norm = H_g / max(max_H, 1e-12)
+                    norm_group_entropies.append(H_g_norm)
+                H_stack = torch.stack(norm_group_entropies, dim=-1)  # [B,T,K]
+                H_bar = H_stack.mean(-1, keepdim=True)  # [B,T,1]
+                L_gee = (((H_stack - H_bar) ** 2).sum(-1) * eff_mask).sum() / eff_denom
+
+                # L_reg: global entropy anchor to baseline average
+                if args.gee_lambda > 0:
+                    avg_H = (H_tok * mask).sum() / denom
+                    if baseline_avg_H is None:
+                        baseline_avg_H = avg_H.detach()
+                    L_reg = (avg_H - baseline_avg_H).pow(2)
+                else:
+                    L_reg = torch.zeros((), device=probs.device, dtype=probs.dtype)
+
+                # L_cov: keep union sensitive mass near baseline
+                if args.gee_gamma > 0:
+                    avg_union = (union_mass * mask).sum() / denom
+                    if baseline_union_mass is None:
+                        baseline_union_mass = avg_union.detach()
+                    L_cov = (avg_union - baseline_union_mass).pow(2)
+                else:
+                    L_cov = torch.zeros((), device=probs.device, dtype=probs.dtype)
+
+                loss = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
+            else:
+                # Original One-shot EM loss
+                loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
 
             prev_logits = logits.detach()
             accelerator.backward(loss)
```
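The commit message adds `groups/gender.json`, but the file itself is outside this diffstat. From the parsing code above, the expected schema is a `"groups"` object mapping each group name to a list of surface strings (only the first sub-token of each string is kept, then de-duplicated) and an optional `"pi"` object of target masses (normalized to sum to 1; missing entries default to 1.0). A minimal sketch of such a file, with illustrative word lists rather than the committed ones:

```python
# Sketch of the schema train.py expects at --gee_groups_path.
# The group names and word lists are illustrative placeholders,
# not the contents of the committed groups/gender.json.
import json
import os

payload = {
    # "groups": group name -> surface strings; train.py keeps only the
    # first sub-token of each string and de-duplicates the resulting ids
    "groups": {
        "female": ["she", "her", "woman", "women"],
        "male": ["he", "his", "man", "men"],
    },
    # "pi": optional target mass per group; train.py normalizes these
    # to sum to 1 (absent entries default to 1.0 before normalizing)
    "pi": {"female": 0.5, "male": 0.5},
}

os.makedirs("groups", exist_ok=True)
with open("groups/gender.json", "w", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)
```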
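The new objective combines four terms, `loss = gee_alpha * L_mass + gee_beta * L_gee + gee_lambda * L_reg + gee_gamma * L_cov`. To make the two main terms concrete, the following self-contained sketch replays `L_mass` (squared gap between each group's probability mass and its target share) and `L_gee` (variance of normalized within-group entropies across groups) on toy tensors. Shapes and names mirror the diff, but the token ids, targets, and the omission of the top-M trigger are simplifications for illustration:

```python
# Sketch: the L_mass and L_gee terms from the diff on toy tensors.
# Shapes follow the diff's comments: probs is [B,T,V], masks are [B,T].
import math
import torch

B, T, V = 2, 5, 100
probs = torch.softmax(torch.randn(B, T, V), dim=-1)              # [B,T,V]
gen_mask = torch.ones(B, T, dtype=torch.bool)                    # pretend every position is generated
group_id_lists = [torch.tensor([3, 7]), torch.tensor([11, 42])]  # toy token ids per group
group_target_pi = [0.5, 0.5]                                     # normalized targets
tau = 1e-3

# Per-group probability mass and the union mass at each position
group_masses = [probs.index_select(-1, ids).sum(-1) for ids in group_id_lists]  # K x [B,T]
masses = torch.stack(group_masses, dim=-1)                       # [B,T,K]
union_mass = masses.sum(-1)                                      # [B,T]

# Only positions whose union mass clears tau contribute (top-M trigger omitted here)
eff_mask = gen_mask & (union_mass >= tau)
eff_denom = eff_mask.sum().clamp_min(1)

# L_mass: squared gap between each group's mass and its target share
pi = torch.tensor(group_target_pi).view(1, 1, -1)                # [1,1,K]
L_mass = (((masses - pi) ** 2).sum(-1) * eff_mask).sum() / eff_denom

# L_gee: variance of normalized within-group entropies across groups
H_norm = []
for ids in group_id_lists:
    p_g = probs.index_select(-1, ids)
    p_g = p_g / p_g.sum(-1, keepdim=True).clamp_min(1e-12)       # renormalize within the group
    H_g = -(p_g * torch.log(p_g + 1e-12)).sum(-1)                # [B,T]
    H_norm.append(H_g / math.log(ids.numel()))                   # divide by max entropy log|G|
H_stack = torch.stack(H_norm, dim=-1)                            # [B,T,K]
L_gee = (((H_stack - H_stack.mean(-1, keepdim=True)) ** 2).sum(-1) * eff_mask).sum() / eff_denom

print(float(L_mass), float(L_gee))
```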
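With these flags in place, a Colab-style smoke test might look like `python Group-Entropy-Equalization/train.py --no_deepspeed --gee_enable --gee_groups_path groups/gender.json --gee_alpha 1.0 --gee_beta 0.3`, with the remaining arguments (training data path, model name, and so on) left at the defaults defined in `parse_args`; only the flags shown here appear in this diff.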
