path: root/scripts/make_train_sets.py
author     Yuren Hao <yurenh2@timan108.cs.illinois.edu>  2025-09-10 14:58:32 -0500
committer  Yuren Hao <yurenh2@timan108.cs.illinois.edu>  2025-09-10 14:58:32 -0500
commit     9bdd1091a8045d74823d1f5de94b041805e42c3f (patch)
tree       b63843e7bce35b95bc3bd87aaee19059fdf2982c /scripts/make_train_sets.py
parent     9d5f2379ac25b4b58e2600544f61172dbb15b67a (diff)
fuck nvidia drivers (HEAD, main)
Diffstat (limited to 'scripts/make_train_sets.py')
-rw-r--r--  scripts/make_train_sets.py  200
1 file changed, 200 insertions, 0 deletions
diff --git a/scripts/make_train_sets.py b/scripts/make_train_sets.py
new file mode 100644
index 0000000..218ac81
--- /dev/null
+++ b/scripts/make_train_sets.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Generate training sets for:
+- EM / Group-EM: 1, 5, 20 prompts (English only)
+- JSD: 100 factual/counterfactual pairs (x, x_swap) with gender swap
+
+Data sources:
+ - assets/triggers/occupations_en.txt
+ - assets/groups/en_definitional_pairs.json (female<->male definitional pairs; supports multiple schemas)
+ - assets/groups/en_female.txt / en_male.txt (for pronouns; we mainly use 'she'/'he')
+
+Avoids overlap with eval CTF examples in data/bias/ctf/ctf_en.jsonl.
+
+Outputs:
+ data/train/em_group/train_en_size{1,5,20}.jsonl
+ data/train/jsd/train_pairs_en_size100.jsonl
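+
+Example output records (illustrative; fields mirror the writers in main() below,
+with a placeholder occupation):
+ EM : {"id": "em_000000", "lang": "en", "occupation": "<occ>", "prompt": "The <occ> said that she ..."}
+ JSD: {"id": "jsd_000000", "lang": "en", "occupation": "<occ>", "prompt": "... she ...", "prompt_swap": "... he ..."}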
+"""
+import json, random, pathlib, re, sys
+from typing import List, Tuple, Dict, Any
+
+ROOT = pathlib.Path(__file__).resolve().parents[1]
+ASSETS = ROOT / "assets"
+DATA = ROOT / "data"
+OUT_EM = DATA / "train" / "em_group"
+OUT_JSD = DATA / "train" / "jsd"
+EVAL_CTF = DATA / "bias" / "ctf" / "ctf_en.jsonl"
+
+random.seed(2025)  # fixed seed so the sampled training sets are reproducible
+
+def read_lines(p: pathlib.Path) -> List[str]:
+ return [x.strip() for x in p.read_text(encoding="utf-8").splitlines() if x.strip()]
+
+def read_jsonl(p: pathlib.Path) -> List[dict]:
+ if not p.exists(): return []
+ return [json.loads(x) for x in p.read_text(encoding="utf-8").splitlines() if x.strip()]
+
+def _normalize_pair(a: Any, b: Any) -> Tuple[str,str]:
+ return (str(a).strip().lower(), str(b).strip().lower())
+
+def load_pairs_json_any(p: pathlib.Path) -> List[Tuple[str,str]]:
+ """
+ Robustly load definitional gender pairs from various schemas seen in the wild:
+ - [ ["woman","man"], ["girl","boy"], ... ]
+ - { "definitional": [ ["woman","man"], ... ] }
+ - [ {"f":"woman","m":"man"}, ... ] or keys {"female":...,"male":...}
+ - { "pairs":[... same as above ...] }
+ Returns a list of (female, male) lower-cased tuples.
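+ Example (illustrative): a file containing
+ [["She", "He"], {"female": "woman", "male": "man"}]
+ loads as [("she", "he"), ("woman", "man")].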
+ """
+ if not p.exists():
+ return []
+ data = json.loads(p.read_text(encoding="utf-8"))
+ pairs: List[Tuple[str,str]] = []
+
+ def add_from_list(lst: List[Any]):
+ for item in lst:
+ if isinstance(item, list) and len(item) == 2:
+ a, b = item
+ pairs.append(_normalize_pair(a, b))
+ elif isinstance(item, dict):
+ # common key patterns
+ if "f" in item and "m" in item:
+ pairs.append(_normalize_pair(item["f"], item["m"]))
+ elif "female" in item and "male" in item:
+ pairs.append(_normalize_pair(item["female"], item["male"]))
+ elif "a" in item and "b" in item:
+ pairs.append(_normalize_pair(item["a"], item["b"]))
+
+ if isinstance(data, list):
+ add_from_list(data)
+ elif isinstance(data, dict):
+ for k in ("definitional", "definitional_pairs", "pairs"):
+ if isinstance(data.get(k), list):
+ add_from_list(data[k])
+
+ # de-duplicate and drop empty or self-identical pairs
+ out = []
+ seen = set()
+ for a,b in pairs:
+ if not a or not b or a == b:
+ continue
+ key = (a,b)
+ if key not in seen:
+ seen.add(key); out.append(key)
+ return out
+
+def minimal_swap(text: str, pairs: List[Tuple[str,str]]) -> str:
+ """
+ Swap gendered tokens with word-boundary regex, preferring longer strings first.
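+ Example (illustrative): with pairs=[("she", "he"), ("woman", "man")],
+ "The woman said she agreed." -> "The man said he agreed."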
+ """
+ if not pairs:
+ return text  # no-op; the caller falls back to the template-based swap
+ mapping: Dict[str,str] = {}
+ for a,b in pairs:
+ mapping[a] = b
+ mapping[b] = a
+ # sort keys by length desc so longer tokens are matched before their substrings
+ keys = sorted(mapping.keys(), key=len, reverse=True)
+ # single boundary-safe, case-insensitive pass; substituting one key at a time
+ # would undo the swap on bidirectional pairs (she -> he, then he -> she)
+ pat = re.compile(r"(?<!\w)(" + "|".join(re.escape(k) for k in keys) + r")(?!\w)", flags=re.IGNORECASE)
+ # mapping keys are lower-cased; normalize the match before lookup
+ return pat.sub(lambda m: mapping.get(m.group(0).lower(), m.group(0)), text)
+
+def sample_templates(occupations: List[str], n: int) -> List[Tuple[str,str,str]]:
+ """
+ Return list of tuples: (occupation, she_template, he_template)
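+ Example element (illustrative, with a placeholder occupation):
+ ("engineer",
+  "The engineer said that she would review the design.",
+  "The engineer said that he would review the design.")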
+ """
+ SHE_FORMS = [
+ "The {occ} said that she would review the design.",
+ "The {occ} explained that she needed more time.",
+ "The {occ} promised that she will lead the next meeting.",
+ "The {occ} reported that she completed the task.",
+ "The {occ} noted that she prefers the alternative plan.",
+ ]
+ HE_FORMS = [
+ t.replace(" she ", " he ").replace("She ", "He ").replace(" she", " he")
+ for t in SHE_FORMS
+ ]
+ random.shuffle(occupations)
+ out = []
+ i = 0
+ while len(out) < n and i < 10*n:
+ occ = occupations[i % len(occupations)]
+ idx = random.randrange(len(SHE_FORMS))
+ s_she = SHE_FORMS[idx].format(occ=occ)
+ s_he = HE_FORMS[idx].format(occ=occ)
+ out.append((occ, s_she, s_he))
+ i += 1
+ return out[:n]
+
+def main():
+ OUT_EM.mkdir(parents=True, exist_ok=True)
+ OUT_JSD.mkdir(parents=True, exist_ok=True)
+
+ # sources
+ occs = read_lines(ASSETS / "triggers" / "occupations_en.txt")
+ pairs_json = ASSETS / "groups" / "en_definitional_pairs.json"
+ pairs = load_pairs_json_any(pairs_json)
+ eval_ctf = read_jsonl(EVAL_CTF)
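+ # every x / x_swap already used by the eval CTF set; training prompts must not overlap with it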
+ eval_x = set([r.get("x","").strip() for r in eval_ctf] + [r.get("x_swap","").strip() for r in eval_ctf])
+
+ # ---- EM / Group-EM prompts (she-variant prompts; labels are not needed)
+ for size in [1,5,20]:
+ triples = sample_templates(occs, size*3) # oversample for filtering
+ rows = []
+ for occ, s_she, s_he in triples:
+ if s_she in eval_x or s_he in eval_x:
+ continue
+ rows.append({"id": f"em_{len(rows):06d}", "lang":"en", "occupation": occ, "prompt": s_she})
+ if len(rows) >= size: break
+ outp = OUT_EM / f"train_en_size{size}.jsonl"
+ outp.write_text("\n".join(json.dumps(r) for r in rows) + ("\n" if rows else ""), encoding="utf-8")
+ print("Wrote", outp, "N=", len(rows))
+
+ # ---- JSD pairs (x, x_swap)
+ size_jsd = 100
+ triples = sample_templates(occs, size_jsd*4) # oversample more to be safe
+ pairs_out = []
+ for occ, s_she, s_he in triples:
+ x = s_she
+ if x in eval_x:
+ continue
+
+ # try definitional-pair swap first
+ x_swap = minimal_swap(x, pairs)
+ # fallback: if no change, use our explicit he-template
+ if x.strip().lower() == x_swap.strip().lower():
+ x_swap = s_he
+
+ if x_swap in eval_x:
+ continue
+ if x.strip().lower() == x_swap.strip().lower(): # still identical? extremely unlikely
+ continue
+
+ pairs_out.append({
+ "id": f"jsd_{len(pairs_out):06d}",
+ "lang":"en",
+ "occupation": occ,
+ "prompt": x,
+ "prompt_swap": x_swap
+ })
+ if len(pairs_out) >= size_jsd: break
+
+ outp2 = OUT_JSD / "train_pairs_en_size100.jsonl"
+ outp2.write_text("\n".join(json.dumps(r) for r in pairs_out) + ("\n" if pairs_out else ""), encoding="utf-8")
+ print("Wrote", outp2, "N=", len(pairs_out))
+
+ # quick diagnostics
+ if len(pairs_out) == 0:
+ print("[WARN] JSD pairs = 0. Diagnostics:")
+ print(" - occupations:", len(occs))
+ print(" - definitional pairs loaded:", len(pairs))
+ print(" - eval_x size:", len(eval_x))
+ print(" - Check assets/groups/en_definitional_pairs.json schema.")
+ sys.exit(2)
+
+if __name__ == "__main__":
+ main()