diff options
Diffstat (limited to 'scripts/eval_bias_baseline.py')
| -rw-r--r-- | scripts/eval_bias_baseline.py | 80 |
1 files changed, 58 insertions, 22 deletions
diff --git a/scripts/eval_bias_baseline.py b/scripts/eval_bias_baseline.py index 227805f..7817b78 100644 --- a/scripts/eval_bias_baseline.py +++ b/scripts/eval_bias_baseline.py @@ -7,6 +7,11 @@ Baseline bias evaluation: CTF (x vs swap(x)), CrowS-Pairs (gender), WinoGender. Outputs: runs/<DATE>/baseline_eval/bias/{ctf,crows,wino}/metrics.json runs/<DATE>/baseline_eval/bias/{ctf,crows,wino}/preds.jsonl + +NOTE (2025-09-10 fix): +- Previously first_gender_span() used naive substring search (str.find), + causing false positives like "he" in "The". We now use a word-boundary + regex (compiled once) with case-insensitive matching to avoid such leaks. """ import argparse, json, os, math, re, time, pathlib, statistics from typing import List, Dict, Tuple, Optional @@ -64,6 +69,47 @@ def map_words_to_token_ids(tok: AutoTokenizer, words: List[str]) -> List[int]: ids.add(int(i)) return sorted(ids) +# --------------------- Boundary-safe gender span finder --------------------- +def build_gender_regex(fem_words: List[str], male_words: List[str]) -> re.Pattern: + """ + Compile a single regex that matches any gender token with word boundaries: + (?<!\w)(?:w1|w2|...)(?!\w) + - Case-insensitive + - Escapes special chars (so 'mr.' works) + - Sorts by length desc to prefer longer matches (e.g., 'herself' over 'her') + """ + # Normalize and merge sets; keep both dotted/undotted variants if present + vocab = sorted({w.lower() for w in (fem_words + male_words)}, key=len, reverse=True) + pat = r"(?<!\w)(?:" + "|".join(re.escape(w) for w in vocab) + r")(?!\w)" + return re.compile(pat, flags=re.IGNORECASE) + +def classify_matched_token(matched: str, fem_set: set, male_set: set) -> Optional[str]: + """ + Decide whether the matched surface form belongs to F or M (case-insensitive). + We try raw, then rstrip('.') variants for dotted honorifics (mr., ms., mrs.). + Returns "F" or "M" (or None if not classifiable, which shouldn't happen if lists are consistent). + """ + s = matched.lower() + if s in fem_set: return "F" + if s in male_set: return "M" + s2 = s.rstrip(".") + if s2 in fem_set: return "F" + if s2 in male_set: return "M" + return None + +def first_gender_span(text: str, gender_re: re.Pattern, fem_set: set, male_set: set) -> Optional[Tuple[int, int, str]]: + """ + Find first occurrence (boundary-safe) of any gender word; return (start,end,group), group in {"F","M"}. + """ + m = gender_re.search(text) + if not m: + return None + grp = classify_matched_token(m.group(0), fem_set, male_set) + if grp is None: + return None + start, end = m.span() + return start, end, grp + # --------------------- Scoring utils --------------------- @torch.no_grad() def sequence_logprob(model, tok, text: str, device: torch.device) -> float: @@ -119,41 +165,31 @@ def mean_ci95(xs: List[float]) -> Tuple[float, float]: if len(xs) < 2: return m, 0.0 sd = statistics.pstdev(xs) # pop std (ok for quick CI) - # Approx CI with normal approx: 1.96 * sd/sqrt(n) ci = 1.96 * (sd / math.sqrt(len(xs))) return m, ci # --------------------- CTF evaluation --------------------- -def first_gender_span(text: str, fem_words: List[str], male_words: List[str]) -> Optional[Tuple[int, int, str]]: - """ - Find first occurrence (case-insensitive) of any gender word; return (start,end,group) - group in {"F","M"}. - """ - lower = text.lower() - # prefer pronouns first (single-token in most tokenizers) - fem_sorted = sorted(fem_words, key=len) # short first - male_sorted = sorted(male_words, key=len) - for group, vocab in (("F", fem_sorted), ("M", male_sorted)): - for w in vocab: - i = lower.find(w) - if i != -1: - return i, i + len(w), group - return None - def eval_ctf(model, tok, ctf_rows: List[Dict], fem_words, male_words, fem_ids: List[int], male_ids: List[int], topk:int, device, outdir:str): preds = [] ctf_gaps = [] jsds = [] cover = 0 + + # Build boundary-safe regex once + gender_re = build_gender_regex(fem_words, male_words) + fem_set = set(w.lower() for w in fem_words) + male_set = set(w.lower() for w in male_words) + for r in ctf_rows: x = r["x"]; xs = r["x_swap"] # choose boundary after the first gender token to probe next-token distribution - b1 = first_gender_span(x, fem_words, male_words) - b2 = first_gender_span(xs, fem_words, male_words) + b1 = first_gender_span(x, gender_re, fem_set, male_set) + b2 = first_gender_span(xs, gender_re, fem_set, male_set) if not b1 or not b2: preds.append({**r, "skip_reason":"no_gender_span"}) continue + # probe right after the matched gender token (index = end) p1 = next_token_prob(model, tok, x[:b1[1]], device) p2 = next_token_prob(model, tok, xs[:b2[1]], device) @@ -167,8 +203,7 @@ def eval_ctf(model, tok, ctf_rows: List[Dict], fem_words, male_words, gap2 = abs(piF2 - piM2) # record per-sample (we report gap on factual x; swap gap optional) - gap = gap1 - ctf_gaps.append(gap) + ctf_gaps.append(gap1) cover += 1 # swap JSD at the probe step @@ -179,6 +214,8 @@ def eval_ctf(model, tok, ctf_rows: List[Dict], fem_words, male_words, **r, "probe_index_factual": b1[1], "probe_index_swap": b2[1], + "matched_factual": x[b1[0]:b1[1]], + "matched_swap": xs[b2[0]:b2[1]], "piF_factual": piF1, "piM_factual": piM1, "gap_factual": gap1, "piF_swap": piF2, "piM_swap": piM2, "gap_swap": gap2, "jsd_swap": j @@ -232,7 +269,6 @@ def eval_wino(model, tok, rows: List[Dict], device, outdir:str): preds.append({**r, "logp_a": lpa, "logp_b": lpb, "pred": pred, "correct": int(pred==gold)}) acc = correct / max(1,len(rows)) - # quick CI with normal approx sd = math.sqrt(acc*(1-acc)/max(1,len(rows))) ci = 1.96 * sd metrics = { |
