 Group-Entropy-Equalization/train.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py
index 352e916..2cc9438 100644
--- a/Group-Entropy-Equalization/train.py
+++ b/Group-Entropy-Equalization/train.py
@@ -51,8 +51,9 @@ def parse_args():
     parser.add_argument('--gee_beta', type=float, default=0.3, help='Weight for group entropy equalization loss')
     parser.add_argument('--gee_lambda', type=float, default=0.0, help='Weight for global entropy anchor')
     parser.add_argument('--gee_gamma', type=float, default=0.0, help='Weight for sensitive coverage anchor')
-    parser.add_argument('--gee_tau', type=float, default=1e-3, help='Min union mass to apply GEE losses')
-    parser.add_argument('--gee_top_m', type=int, default=50, help='Apply GEE if any group token in top-M at a position')
+    parser.add_argument('--gee_tau', type=float, default=1e-6, help='Min union mass to apply GEE losses')
+    parser.add_argument('--gee_top_m', type=int, default=200, help='Apply GEE if any group token in top-M at a position')
+    parser.add_argument('--gee_em_mix', type=float, default=0.1, help='Additive EM loss mix to stabilize training (0 to disable)')
     return parser.parse_args()
 
 class FTDataset(Dataset):
@@ -225,6 +226,9 @@ def main():
         probs = F.softmax(logits / temp, dim=-1)
         H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)  # [B,T]
 
+        # Precompute EM loss (used standalone or as mixin)
+        em_loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+
         if args.gee_enable:
             # Compute GEE losses on generated positions only
             mask = gen_mask
@@ -296,10 +300,24 @@ def main():
             else:
                 L_cov = torch.zeros((), device=probs.device, dtype=probs.dtype)
 
-            loss = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
+            loss_gee = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
+            # Fallback: if no positions triggered, use EM loss to ensure updates
+            if eff_denom.item() == 0:
+                loss = em_loss
+            else:
+                loss = loss_gee + (args.gee_em_mix * em_loss if args.gee_em_mix > 0 else 0.0)
+            # Log activation ratio if main process
+            if accelerator.is_main_process:
+                gee_active_ratio = (eff_denom / denom).item()
+                try:
+                    wandb.log({"gee_active_ratio": gee_active_ratio,
+                               "L_mass": float(L_mass.detach().item()),
+                               "L_gee": float(L_gee.detach().item())})
+                except Exception:
+                    pass
 
         else:
             # Original One-shot EM loss
-            loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+            loss = em_loss
         prev_logits = logits.detach()
         accelerator.backward(loss)
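For reference, the loss combination this patch introduces can be read as a standalone function. The sketch below is illustrative only: tensor shapes follow the comments in train.py (logits [B,T,V], gen_mask [B,T]), while the function name combine_losses and its free-standing arguments are hypothetical and not part of the repository.

import torch
import torch.nn.functional as F

def combine_losses(logits, gen_mask, loss_gee, eff_denom, em_mix=0.1, temp=1.0):
    # logits:    [B, T, V] model outputs
    # gen_mask:  [B, T] float mask, 1.0 at generated positions
    # loss_gee:  scalar GEE loss computed on the triggered positions
    # eff_denom: scalar tensor counting positions that triggered GEE
    probs = F.softmax(logits / temp, dim=-1)
    # Per-token entropy; the epsilon guards log(0).  -> [B, T]
    H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
    # EM loss: mean entropy over generated positions only
    em_loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
    if eff_denom.item() == 0:
        # Fallback: nothing triggered GEE this step, so train on EM
        # alone instead of emitting a zero-gradient update
        return em_loss
    # Otherwise mix a small EM term into the GEE objective
    return (loss_gee + em_mix * em_loss) if em_mix > 0 else loss_gee

# Example call with dummy inputs:
B, T, V = 2, 8, 100
loss = combine_losses(torch.randn(B, T, V), torch.ones(B, T),
                      loss_gee=torch.tensor(0.5), eff_denom=torch.tensor(3))

The design intent, as the help strings suggest: --gee_em_mix keeps a small entropy-minimization gradient flowing on batches where the GEE triggers are sparse, while the eff_denom == 0 branch guarantees a non-degenerate update when they never fire at all.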
