From 9bdd1091a8045d74823d1f5de94b041805e42c3f Mon Sep 17 00:00:00 2001
From: Yuren Hao
Date: Wed, 10 Sep 2025 14:58:32 -0500
Subject: fuck nvidia drivers

---
 src/__init__.py |   0
 src/losses.py   | 248 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 248 insertions(+)
 create mode 100644 src/__init__.py
 create mode 100644 src/losses.py
(limited to 'src')

diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/losses.py b/src/losses.py
new file mode 100644
index 0000000..f8f491e
--- /dev/null
+++ b/src/losses.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+"""
+Losses for:
+- EM (entropy minimization)
+- Group-EM (entropy difference between female/male token groups)
+- JSD counterfactual invariance (x vs swap(x)), with optional Top-K restriction
+Guards:
+- Mass parity: (piF - piM)^2
+- Stability: KL( p_theta || p_base )
+Gating:
+- Top-K trigger on {F ∪ M} at each step (boundary-safe matching happens at the
+  text level during data build/eval)
+Note:
+- All losses are averaged over steps where gate==1 AND (optionally) generation mask==1.
+"""
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def map_words_to_token_ids(tok, words: List[str]) -> List[int]:
+    """Map each word (with and without a leading space) to a token id."""
+    ids = set()
+    for w in words:
+        for form in (w, " " + w):
+            enc = tok(form, add_special_tokens=False, return_tensors=None)
+            toks = enc["input_ids"]
+            if len(toks) >= 1:
+                ids.add(int(toks[0]))  # single token, or first-piece fallback for multi-piece words
+    return sorted(ids)
+
+
+def probs_from_logits(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
+    if temperature <= 0:
+        # Guard against division by zero: treat T <= 0 as T = 1 (no scaling).
+        return F.softmax(logits, dim=-1)
+    return F.softmax(logits / temperature, dim=-1)
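+
+# Example (illustrative; assumes a Hugging Face tokenizer `tok` and placeholder word lists):
+#   fem_ids  = map_words_to_token_ids(tok, ["she", "her", "woman"])
+#   male_ids = map_words_to_token_ids(tok, ["he", "him", "man"])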
+
+def topk_gate(logits: torch.Tensor, fem_ids: List[int], male_ids: List[int], k: int = 20) -> torch.Tensor:
+    """
+    logits: [B,T,V]
+    Return a gate mask [B,T] that is 1 where the step's top-k contains any F∪M id.
+    """
+    B, T, V = logits.shape
+    topk = torch.topk(logits, k=min(k, V), dim=-1).indices  # [B,T,k]
+    ids = torch.tensor(sorted(set(fem_ids) | set(male_ids)), device=logits.device, dtype=torch.long)
+    if ids.numel() == 0:
+        return torch.zeros(B, T, dtype=torch.float32, device=logits.device)
+    # Compare with broadcasting: [B,T,k,1] against [1,1,1,|ids|], then reduce over
+    # both the candidate-id and top-k dimensions to get a per-step flag.
+    match = (topk.unsqueeze(-1) == ids.view(1, 1, 1, -1)).any(dim=-1).any(dim=-1)  # [B,T]
+    return match.float()
+
+def group_masses(probs: torch.Tensor, fem_ids: List[int], male_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    probs: [B,T,V]
+    Returns piF, piM of shape [B,T].
+    """
+    if len(fem_ids) == 0 and len(male_ids) == 0:
+        return torch.zeros_like(probs[..., 0]), torch.zeros_like(probs[..., 0])
+    idxF = torch.tensor(fem_ids, device=probs.device, dtype=torch.long) if len(fem_ids) > 0 else None
+    idxM = torch.tensor(male_ids, device=probs.device, dtype=torch.long) if len(male_ids) > 0 else None
+    piF = probs[..., idxF].sum(dim=-1) if idxF is not None else torch.zeros_like(probs[..., 0])
+    piM = probs[..., idxM].sum(dim=-1) if idxM is not None else torch.zeros_like(probs[..., 0])
+    return piF, piM
+
+def normalized_entropy(sub_probs: torch.Tensor) -> torch.Tensor:
+    """
+    sub_probs: [*, K]
+    Return normalized entropy in [0,1]: H(p) / log(K).
+    """
+    eps = 1e-12
+    K = sub_probs.size(-1)
+    H = -(sub_probs.clamp_min(eps) * sub_probs.clamp_min(eps).log()).sum(dim=-1)
+    denom = torch.log(torch.tensor(float(K), device=sub_probs.device))
+    return H / (denom + eps)
+
+def group_entropies(probs: torch.Tensor, fem_ids: List[int], male_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    probs: [B,T,V] -> qF [B,T,|F|], qM [B,T,|M|] -> HF, HM in [0,1]
+    """
+    eps = 1e-12
+    idxF = torch.tensor(fem_ids, device=probs.device, dtype=torch.long) if len(fem_ids) > 0 else None
+    idxM = torch.tensor(male_ids, device=probs.device, dtype=torch.long) if len(male_ids) > 0 else None
+
+    if idxF is None:
+        HF = torch.zeros(probs.shape[:2], device=probs.device)
+    else:
+        pF = probs[..., idxF]                     # [B,T,|F|]
+        piF = pF.sum(dim=-1, keepdim=True) + eps
+        qF = pF / piF                             # renormalize within the group
+        HF = normalized_entropy(qF)
+
+    if idxM is None:
+        HM = torch.zeros(probs.shape[:2], device=probs.device)
+    else:
+        pM = probs[..., idxM]
+        piM = pM.sum(dim=-1, keepdim=True) + eps
+        qM = pM / piM
+        HM = normalized_entropy(qM)
+
+    return HF, HM
+
+def reduce_steps(x: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor:
+    """
+    x: [B,T], step_mask: [B,T] in {0,1}
+    Return the mean over steps where mask==1 (guarding against division by zero).
+    """
+    w = step_mask
+    s = (x * w).sum()
+    d = w.sum().clamp_min(1.0)
+    return s / d
+
+# ---------------- EM ----------------
+def loss_em(logits: torch.Tensor, gen_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+    """
+    Entropy minimization over generation steps (no gating).
+    logits: [B,T,V]; gen_mask: [B,T], 1 for generation (non-prompt) steps.
+    """
+    probs = probs_from_logits(logits)  # [B,T,V]
+    eps = 1e-12
+    Ht = -(probs.clamp_min(eps) * probs.clamp_min(eps).log()).sum(dim=-1)  # [B,T]
+    L = reduce_steps(Ht, gen_mask)
+    return L, {"H_mean": float(L.item())}
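+
+# Example (illustrative): with `logits` [B,T,V] from a causal LM and `gen_mask` [B,T]
+# marking completion steps,
+#   gate = topk_gate(logits, fem_ids, male_ids, k=20)
+#   L_em, stats = loss_em(logits, gen_mask)
+# EM itself is ungated; `gate` is consumed by the group losses below.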
+
+# ------------- Group-EM -------------
+def loss_group_em(
+    logits: torch.Tensor,
+    gen_mask: torch.Tensor,
+    fem_ids: List[int],
+    male_ids: List[int],
+    gate_mask: Optional[torch.Tensor] = None,
+    lambda_mass: float = 0.0,
+    beta_kl: float = 0.0,
+    ref_probs: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Dict]:
+    """
+    Group-EM loss with optional guards.
+    - core: (H_F - H_M)^2
+    - mass guard: (piF - piM)^2
+    - stability: KL( p || p_ref )
+    """
+    probs = probs_from_logits(logits)                   # [B,T,V]
+    HF, HM = group_entropies(probs, fem_ids, male_ids)  # [B,T], [B,T]
+    core = (HF - HM) ** 2                               # [B,T]
+
+    piF, piM = group_masses(probs, fem_ids, male_ids)
+    Lmass = (piF - piM) ** 2                            # [B,T]
+
+    if gate_mask is None:
+        step_mask = gen_mask
+    else:
+        step_mask = (gen_mask * gate_mask).float()
+
+    L_core = reduce_steps(core, step_mask)
+    L_mass = reduce_steps(Lmass, step_mask)
+
+    L_kl = torch.tensor(0.0, device=logits.device)
+    if beta_kl > 0.0 and ref_probs is not None:
+        eps = 1e-12
+        p = probs.clamp_min(eps)
+        q = ref_probs.clamp_min(eps)
+        KL = (p * (p.log() - q.log())).sum(dim=-1)      # [B,T]
+        L_kl = reduce_steps(KL, step_mask)
+
+    loss = L_core + lambda_mass * L_mass + beta_kl * L_kl
+    extras = {
+        "L_core": float(L_core.item()),
+        "L_mass": float(L_mass.item()),
+        "L_kl": float(L_kl.item()),
+        "piF_mean": float(reduce_steps(piF, step_mask).item()),
+        "piM_mean": float(reduce_steps(piM, step_mask).item()),
+        "HF_mean": float(reduce_steps(HF, step_mask).item()),
+        "HM_mean": float(reduce_steps(HM, step_mask).item()),
+    }
+    return loss, extras
+
+# --------------- JSD ---------------
+def _jsd_full(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+    m = 0.5 * (p + q)
+    return 0.5 * (p * (p.clamp_min(eps).log() - m.clamp_min(eps).log())).sum(dim=-1) + \
+           0.5 * (q * (q.clamp_min(eps).log() - m.clamp_min(eps).log())).sum(dim=-1)
+
+def _jsd_topk(p: torch.Tensor, q: torch.Tensor, K: int) -> torch.Tensor:
+    """JSD restricted to the union of the two distributions' top-K supports."""
+    V = p.size(-1)
+    K = min(K, V)
+    idx_p = torch.topk(p, k=K, dim=-1).indices
+    idx_q = torch.topk(q, k=K, dim=-1).indices
+    # Build a boolean union mask over the vocabulary; using a mask (rather than
+    # concatenating the two index sets) avoids double-counting ids that appear
+    # in both top-K sets.
+    union = torch.zeros_like(p, dtype=torch.bool)
+    union.scatter_(-1, idx_p, True)
+    union.scatter_(-1, idx_q, True)
+    eps = 1e-12
+    pK = p * union
+    qK = q * union
+    mK = 0.5 * (pK + qK)
+    return 0.5 * (pK * (pK.clamp_min(eps).log() - mK.clamp_min(eps).log())).sum(dim=-1) + \
+           0.5 * (qK * (qK.clamp_min(eps).log() - mK.clamp_min(eps).log())).sum(dim=-1)
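+
+# Note: the Top-K variant restricts both distributions to the union of their top-K
+# supports without renormalizing, so it lower-bounds the full-vocabulary JSD; the
+# bound is tight when the top-K sets carry most of the probability mass.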
+
+def loss_jsd(
+    logits_f: torch.Tensor,                 # [B,T,V] factual branch
+    logits_c: torch.Tensor,                 # [B,T,V] counterfactual (gender-swapped) branch
+    gen_mask: torch.Tensor,                 # [B,T]
+    fem_ids: List[int],
+    male_ids: List[int],
+    gate_mask_f: Optional[torch.Tensor] = None,
+    gate_mask_c: Optional[torch.Tensor] = None,
+    lambda_mass: float = 0.0,
+    beta_kl: float = 0.0,
+    ref_probs_f: Optional[torch.Tensor] = None,
+    topk_jsd: int = 0,
+) -> Tuple[torch.Tensor, Dict]:
+    """
+    JSD(p||q) averaged over gated steps; the factual and counterfactual branches can
+    each supply a gate mask, and a step is kept only where every provided gate fires.
+    Also includes mass parity on both branches and optional stability to the base
+    model on the factual branch.
+    """
+    p = probs_from_logits(logits_f)  # [B,T,V]
+    q = probs_from_logits(logits_c)  # [B,T,V]
+
+    if topk_jsd > 0:
+        J = _jsd_topk(p, q, K=topk_jsd)  # [B,T]
+    else:
+        J = _jsd_full(p, q)              # [B,T]
+
+    # Step mask: require the gate on the factual branch and, if provided, also on
+    # the counterfactual branch.
+    step_mask = gen_mask
+    if gate_mask_f is not None:
+        step_mask = step_mask * gate_mask_f
+    if gate_mask_c is not None:
+        step_mask = step_mask * gate_mask_c
+    step_mask = step_mask.float()
+
+    L_jsd = reduce_steps(J, step_mask)
+
+    # Mass parity on each branch
+    piF_f, piM_f = group_masses(p, fem_ids, male_ids)
+    piF_c, piM_c = group_masses(q, fem_ids, male_ids)
+    L_mass = reduce_steps((piF_f - piM_f) ** 2, step_mask) + reduce_steps((piF_c - piM_c) ** 2, step_mask)
+
+    # Stability to the base model (factual branch)
+    L_kl = torch.tensor(0.0, device=logits_f.device)
+    if beta_kl > 0.0 and ref_probs_f is not None:
+        eps = 1e-12
+        p0 = ref_probs_f.clamp_min(eps)
+        L_kl = reduce_steps((p.clamp_min(eps) * (p.clamp_min(eps).log() - p0.log())).sum(dim=-1), step_mask)
+
+    loss = L_jsd + lambda_mass * L_mass + beta_kl * L_kl
+    extras = {
+        "L_jsd": float(L_jsd.item()),
+        "L_mass": float(L_mass.item()),
+        "L_kl": float(L_kl.item()),
+        "piF_f": float(reduce_steps(piF_f, step_mask).item()),
+        "piM_f": float(reduce_steps(piM_f, step_mask).item()),
+        "piF_c": float(reduce_steps(piF_c, step_mask).item()),
+        "piM_c": float(reduce_steps(piM_c, step_mask).item()),
+    }
+    return loss, extras
--
cgit v1.2.3
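For context, a minimal sketch of how these losses might be wired into a fine-tuning step. This is illustrative only: the model/tokenizer name, word lists, and hyperparameters below are placeholders, not part of the patch.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from src.losses import map_words_to_token_ids, topk_gate, loss_group_em

    # Placeholder model and word lists (assumptions for illustration only).
    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    fem_ids = map_words_to_token_ids(tok, ["she", "her", "woman"])
    male_ids = map_words_to_token_ids(tok, ["he", "him", "man"])

    batch = tok(["The doctor said that"], return_tensors="pt")
    logits = model(**batch).logits                     # [B, T, V]

    # In training, gen_mask would mark completion tokens only; here every
    # attended position is treated as a generation step for simplicity.
    gen_mask = batch["attention_mask"].float()

    gate = topk_gate(logits, fem_ids, male_ids, k=20)  # [B, T] in {0, 1}
    loss, stats = loss_group_em(
        logits, gen_mask, fem_ids, male_ids,
        gate_mask=gate, lambda_mass=0.1,
    )
    loss.backward()
    print({k: round(v, 4) for k, v in stats.items()})

The JSD objective would be driven analogously, with a second forward pass on the gender-swapped input supplying logits_c and its own gate mask.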