author    blackhao <13851610112@163.com>    2025-08-23 15:41:02 -0500
committer blackhao <13851610112@163.com>    2025-08-23 15:41:02 -0500
commit    cd4c558c080bfe3575ec0e7ad85d66a5b7d47aae (patch)
tree      e7bc629d47cebba10897146accfebdc3a8da58af /Group-Entropy-Equalization/train.py
parent    5539e387c08a54c77b9cb5e92faf7d878bff40ad (diff)
Diffstat (limited to 'Group-Entropy-Equalization/train.py')
-rw-r--r--  Group-Entropy-Equalization/train.py  |  26
1 file changed, 22 insertions(+), 4 deletions(-)
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py
index 352e916..2cc9438 100644
--- a/Group-Entropy-Equalization/train.py
+++ b/Group-Entropy-Equalization/train.py
@@ -51,8 +51,9 @@ def parse_args():
parser.add_argument('--gee_beta', type=float, default=0.3, help='Weight for group entropy equalization loss')
parser.add_argument('--gee_lambda', type=float, default=0.0, help='Weight for global entropy anchor')
parser.add_argument('--gee_gamma', type=float, default=0.0, help='Weight for sensitive coverage anchor')
- parser.add_argument('--gee_tau', type=float, default=1e-3, help='Min union mass to apply GEE losses')
- parser.add_argument('--gee_top_m', type=int, default=50, help='Apply GEE if any group token in top-M at a position')
+ parser.add_argument('--gee_tau', type=float, default=1e-6, help='Min union mass to apply GEE losses')
+ parser.add_argument('--gee_top_m', type=int, default=200, help='Apply GEE if any group token in top-M at a position')
+ parser.add_argument('--gee_em_mix', type=float, default=0.1, help='Additive EM loss mix to stabilize training (0 to disable)')
return parser.parse_args()
class FTDataset(Dataset):
@@ -225,6 +226,9 @@ def main():
probs = F.softmax(logits / temp, dim=-1)
H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1) # [B,T]
+ # Precompute EM loss (used standalone or as mixin)
+ em_loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+
if args.gee_enable:
# Compute GEE losses on generated positions only
mask = gen_mask
@@ -296,10 +300,24 @@ def main():
else:
L_cov = torch.zeros((), device=probs.device, dtype=probs.dtype)
- loss = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
+ loss_gee = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
+ # Fallback: if no positions triggered, use EM loss to ensure updates
+ if eff_denom.item() == 0:
+ loss = em_loss
+ else:
+ loss = loss_gee + (args.gee_em_mix * em_loss if args.gee_em_mix > 0 else 0.0)
+ # Log activation ratio if main process
+ if accelerator.is_main_process:
+ gee_active_ratio = (eff_denom / denom).item()
+ try:
+ wandb.log({"gee_active_ratio": gee_active_ratio,
+ "L_mass": float(L_mass.detach().item()),
+ "L_gee": float(L_gee.detach().item())})
+ except Exception:
+ pass
else:
# Original One-shot EM loss
- loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+ loss = em_loss
prev_logits = logits.detach()
accelerator.backward(loss)
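
Note: below is a minimal sketch (not part of the commit) of the loss logic these hunks add, assuming the PyTorch training loop in train.py defines H_tok, gen_mask, eff_denom, and the L_* terms as shown above; the wrapper function name combined_loss is hypothetical.

    def combined_loss(args, H_tok, gen_mask, eff_denom,
                      L_mass, L_gee, L_reg, L_cov):
        # Mean token entropy over generated positions (the EM loss).
        em_loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
        # Weighted sum of the GEE terms, as in the hunk above.
        loss_gee = (args.gee_alpha * L_mass + args.gee_beta * L_gee
                    + args.gee_lambda * L_reg + args.gee_gamma * L_cov)
        if eff_denom.item() == 0:
            # No group token reached the top-M at any position, so the GEE
            # terms carry no signal; fall back to plain EM to keep updating.
            return em_loss
        # Otherwise mix in a small EM term; --gee_em_mix 0 disables it.
        return loss_gee + args.gee_em_mix * em_loss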
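
For reference, a sample invocation exercising the retuned defaults and the new mixin flag (the exact form of --gee_enable is an assumption; the remaining flags are defined in parse_args above):

    python Group-Entropy-Equalization/train.py \
        --gee_enable \
        --gee_tau 1e-6 \
        --gee_top_m 200 \
        --gee_em_mix 0.1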