-rw-r--r--  Group-Entropy-Equalization/README.md  |  29
-rw-r--r--  Group-Entropy-Equalization/train.py   | 133
2 files changed, 159 insertions(+), 3 deletions(-)
diff --git a/Group-Entropy-Equalization/README.md b/Group-Entropy-Equalization/README.md
index 33bd020..0ea2010 100644
--- a/Group-Entropy-Equalization/README.md
+++ b/Group-Entropy-Equalization/README.md
@@ -57,6 +57,35 @@ Checkpoints are saved under `checkpoints/<model>/<run_name>/`.
---
+### Group-wise Entropy Equalization (GEE)
+
+GEE balances sensitive groups with a weighted combination of:
+- Group mass parity: push each group's probability mass toward its target proportion `pi` (weight `--gee_alpha`)
+- Group entropy equalization: normalize per-group entropy and penalize differences across groups (weight `--gee_beta`)
+- Optional anchors: keep global token entropy and the union of sensitive-group mass close to their baseline values (weights `--gee_lambda`, `--gee_gamma`)
+
+Default groups file: `groups/gender.json`.
+
+Run on Colab (example):
+
+```bash
+!python train.py \
+ --model_name Qwen2.5-1.5B \
+ --model_path Qwen/Qwen2.5-1.5B \
+ --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
+ --effective_batch 4 --micro_batch_size 1 \
+ --temperature 0.5 --learning_rate 2e-5 --sample_temp 0.5 \
+ --max_steps 15 --log_steps 1 --save_steps 5 \
+ --run_name colab_gee15 --wandb_project one-shot-em \
+ --no_deepspeed --mixed_precision no \
+ --gee_enable --gee_groups_path groups/gender.json \
+ --gee_alpha 1.0 --gee_beta 0.3 --gee_lambda 0.0 --gee_gamma 0.0 --gee_tau 1e-3 --gee_top_m 50
+```
+
+You can customize the groups and their target proportions (`pi`) in the JSON file; a minimal sketch of the expected structure follows.
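+
+The top-level keys `groups` and `pi` are what `train.py` reads; the group names and word lists below are only illustrative placeholders:
+
+```json
+{
+  "groups": {
+    "male": ["he", "him", "his", "man", "men"],
+    "female": ["she", "her", "hers", "woman", "women"]
+  },
+  "pi": {"male": 0.5, "female": 0.5}
+}
+```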
+
+---
+
### Reproducing One-shot EM Training (SOTA)
```bash
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py
index d1ba4f0..9ca1d4a 100644
--- a/Group-Entropy-Equalization/train.py
+++ b/Group-Entropy-Equalization/train.py
@@ -17,6 +17,7 @@ import wandb
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+import json
+import math  # used by the GEE entropy normalization below (no-op if train.py already imports it)
os.environ.setdefault("NCCL_TIMEOUT", "2700")
@@ -42,6 +43,15 @@ def parse_args():
parser.add_argument('--seed', type=int, default=15, help='Random seed')
parser.add_argument('--no_deepspeed', action='store_true', help='Disable DeepSpeed and use plain Accelerator (Colab-friendly)')
parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['bf16', 'fp16', 'no'], help='Mixed precision mode')
+ # GEE options
+ parser.add_argument('--gee_enable', action='store_true', help='Enable Group-wise Entropy Equalization (debiasing)')
+ parser.add_argument('--gee_groups_path', type=str, default='groups/gender.json', help='Path to JSON defining groups')
+ parser.add_argument('--gee_alpha', type=float, default=1.0, help='Weight for group mass parity loss')
+ parser.add_argument('--gee_beta', type=float, default=0.3, help='Weight for group entropy equalization loss')
+ parser.add_argument('--gee_lambda', type=float, default=0.0, help='Weight for global entropy anchor')
+ parser.add_argument('--gee_gamma', type=float, default=0.0, help='Weight for sensitive coverage anchor')
+ parser.add_argument('--gee_tau', type=float, default=1e-3, help='Min union mass to apply GEE losses')
+    parser.add_argument('--gee_top_m', type=int, default=50, help='Also apply GEE at positions where any group token appears in the top-M predictions (0 disables this trigger)')
return parser.parse_args()
class FTDataset(Dataset):
@@ -121,12 +131,52 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
+ # Prepare GEE group ids and targets if enabled
+ group_name_list = []
+ group_id_lists = []
+ group_target_pi = []
+ if args.gee_enable:
+ if not os.path.exists(args.gee_groups_path):
+ raise FileNotFoundError(f"GEE groups file not found: {args.gee_groups_path}")
+ with open(args.gee_groups_path, 'r', encoding='utf-8') as f:
+ groups_payload = json.load(f)
+ groups = groups_payload.get('groups', {})
+ pis = groups_payload.get('pi', {})
+ if not groups:
+ raise ValueError('GEE groups.json is missing a "groups" object')
+ # Build token-id lists per group (first sub-token of each provided string)
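+        # NOTE: only the first sub-token of each listed word is used, so multi-token
+        # words are matched approximately; with BPE/SentencePiece vocabularies, "word"
+        # and " word" (with a leading space) can map to different ids, so list both
+        # spellings in the JSON if both surface forms should count toward the group.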
+ for gname, tokens in groups.items():
+ ids = []
+ for w in tokens:
+ toks = tokenizer.tokenize(w)
+ if toks:
+ tid = tokenizer.convert_tokens_to_ids(toks[0])
+ if tid is not None:
+ ids.append(tid)
+ ids = sorted(set([i for i in ids if isinstance(i, int) and i >= 0]))
+ if len(ids) == 0:
+ continue
+ group_name_list.append(gname)
+ group_id_lists.append(torch.tensor(ids, dtype=torch.long))
+ group_target_pi.append(float(pis.get(gname, 1.0)))
+ if not group_id_lists:
+ raise ValueError('No valid group token ids produced from groups file')
+ # Normalize pi to sum to 1
+ total_pi = sum(group_target_pi)
+ if total_pi <= 0:
+ group_target_pi = [1.0 / len(group_id_lists)] * len(group_id_lists)
+ else:
+ group_target_pi = [p / total_pi for p in group_target_pi]
+
if accelerator.is_main_process:
wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
# Friendly error if the parquet path is missing
if not os.path.exists(args.train_data):
raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.")
df = pd.read_parquet(args.train_data)
train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
@@ -134,6 +184,8 @@ def main():
optimizer = AdamW(model.parameters(), lr=lr)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
prev_logits = None
+ baseline_avg_H = None # for L_reg
+ baseline_union_mass = None # for coverage anchor
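+    # Both baselines are captured the first time their anchor term is computed
+    # (i.e., on the first step with the corresponding weight > 0) and then held
+    # fixed, so L_reg / L_cov pull entropy and sensitive mass back toward those
+    # initial values.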
model.train()
for step, batch in enumerate(train_loader, start=1):
@@ -168,10 +220,85 @@ def main():
token_idx = torch.arange(seq.size(1), device=seq.device)
gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
- logits = model(seq, attention_mask=pad_mask).logits
+ logits = model(seq, attention_mask=pad_mask).logits # [B,T,V]
probs = F.softmax(logits / temp, dim=-1)
- H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
- loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+ H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1) # [B,T]
+
+ if args.gee_enable:
+ # Compute GEE losses on generated positions only
+ mask = gen_mask
+ denom = mask.sum().clamp_min(1)
+
+ # Union mass and per-group mass
+ group_masses = [] # list of [B,T]
+ top_m = args.gee_top_m
+ # Top-M indices to decide triggering
+ if top_m > 0:
+ topk = probs.topk(k=min(top_m, probs.size(-1)), dim=-1).indices # [B,T,M]
+ else:
+ topk = None
+ for ids in group_id_lists:
+ ids = ids.to(probs.device)
+ gm = probs.index_select(-1, ids).sum(-1) # [B,T]
+ group_masses.append(gm)
+ union_mass = torch.stack(group_masses, dim=-1).sum(-1) # [B,T]
+
+ # Trigger mask: apply only when union_mass >= tau OR any group id in top-M
+ trigger = union_mass >= args.gee_tau
+ if topk is not None:
+ any_in_top = torch.zeros_like(trigger)
+ vocab_in_top = topk # [B,T,M]
+ for ids in group_id_lists:
+ ids = ids.to(probs.device)
+ g_match = (vocab_in_top.unsqueeze(-1) == ids.view(1,1,1,-1)).any(-1) # [B,T,M]
+ any_in_top |= g_match.any(-1)
+ trigger |= any_in_top
+ eff_mask = mask & trigger
+ eff_denom = eff_mask.sum().clamp_min(1)
+
+ # L_mass: group-mass parity to target pi
+ pi = torch.tensor(group_target_pi, device=probs.device, dtype=probs.dtype).view(1,1,-1) # [1,1,K]
+ masses_stacked = torch.stack(group_masses, dim=-1) # [B,T,K]
+ mass_gap = (masses_stacked - pi).pow(2) # [B,T,K]
+ L_mass = (mass_gap.sum(-1) * eff_mask).sum() / eff_denom
+
+ # L_GEE: equalize normalized group entropy per position
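+                # Dividing each group's entropy by log(|G|), its maximum possible value,
+                # puts groups of different sizes on a common [0, 1] scale before comparison.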
+ norm_group_entropies = [] # [B,T] per group
+ for ids in group_id_lists:
+ ids = ids.to(probs.device)
+ p_sub = probs.index_select(-1, ids) # [B,T,|G|]
+ denom_g = p_sub.sum(-1, keepdim=True).clamp_min(1e-12)
+ p_g = p_sub / denom_g
+ H_g = -(p_g * torch.log(p_g + 1e-12)).sum(-1) # [B,T]
+ max_H = math.log(p_sub.size(-1)) if p_sub.size(-1) > 1 else 1.0
+ H_g_norm = H_g / max(max_H, 1e-12)
+ norm_group_entropies.append(H_g_norm)
+ H_stack = torch.stack(norm_group_entropies, dim=-1) # [B,T,K]
+ H_bar = H_stack.mean(-1, keepdim=True) # [B,T,1]
+ L_gee = (((H_stack - H_bar) ** 2).sum(-1) * eff_mask).sum() / eff_denom
+
+ # L_reg: global entropy anchor to baseline average
+ if args.gee_lambda > 0:
+ avg_H = (H_tok * mask).sum() / denom
+ if baseline_avg_H is None:
+ baseline_avg_H = avg_H.detach()
+ L_reg = (avg_H - baseline_avg_H).pow(2)
+ else:
+ L_reg = torch.zeros((), device=probs.device, dtype=probs.dtype)
+
+ # L_cov: keep union sensitive mass near baseline
+ if args.gee_gamma > 0:
+ avg_union = (union_mass * mask).sum() / denom
+ if baseline_union_mass is None:
+ baseline_union_mass = avg_union.detach()
+ L_cov = (avg_union - baseline_union_mass).pow(2)
+ else:
+ L_cov = torch.zeros((), device=probs.device, dtype=probs.dtype)
+
+ loss = args.gee_alpha * L_mass + args.gee_beta * L_gee + args.gee_lambda * L_reg + args.gee_gamma * L_cov
+ else:
+ # Original One-shot EM loss
+ loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
prev_logits = logits.detach()
accelerator.backward(loss)