summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/eval_bias_baseline.py80
-rw-r--r--scripts/summarize_baseline.py39
2 files changed, 97 insertions, 22 deletions
diff --git a/scripts/eval_bias_baseline.py b/scripts/eval_bias_baseline.py
index 227805f..7817b78 100644
--- a/scripts/eval_bias_baseline.py
+++ b/scripts/eval_bias_baseline.py
@@ -7,6 +7,11 @@ Baseline bias evaluation: CTF (x vs swap(x)), CrowS-Pairs (gender), WinoGender.
Outputs:
runs/<DATE>/baseline_eval/bias/{ctf,crows,wino}/metrics.json
runs/<DATE>/baseline_eval/bias/{ctf,crows,wino}/preds.jsonl
+
+NOTE (2025-09-10 fix):
+- Previously first_gender_span() used naive substring search (str.find),
+ causing false positives like "he" in "The". We now use a word-boundary
+ regex (compiled once) with case-insensitive matching to avoid such leaks.
"""
import argparse, json, os, math, re, time, pathlib, statistics
from typing import List, Dict, Tuple, Optional
@@ -64,6 +69,47 @@ def map_words_to_token_ids(tok: AutoTokenizer, words: List[str]) -> List[int]:
ids.add(int(i))
return sorted(ids)
+# --------------------- Boundary-safe gender span finder ---------------------
+def build_gender_regex(fem_words: List[str], male_words: List[str]) -> re.Pattern:
+ """
+ Compile a single regex that matches any gender token with word boundaries:
+ (?<!\w)(?:w1|w2|...)(?!\w)
+ - Case-insensitive
+ - Escapes special chars (so 'mr.' works)
+ - Sorts by length desc to prefer longer matches (e.g., 'herself' over 'her')
+ """
+ # Normalize and merge sets; keep both dotted/undotted variants if present
+ vocab = sorted({w.lower() for w in (fem_words + male_words)}, key=len, reverse=True)
+ pat = r"(?<!\w)(?:" + "|".join(re.escape(w) for w in vocab) + r")(?!\w)"
+ return re.compile(pat, flags=re.IGNORECASE)
+
+def classify_matched_token(matched: str, fem_set: set, male_set: set) -> Optional[str]:
+ """
+ Decide whether the matched surface form belongs to F or M (case-insensitive).
+ We try raw, then rstrip('.') variants for dotted honorifics (mr., ms., mrs.).
+ Returns "F" or "M" (or None if not classifiable, which shouldn't happen if lists are consistent).
+ """
+ s = matched.lower()
+ if s in fem_set: return "F"
+ if s in male_set: return "M"
+ s2 = s.rstrip(".")
+ if s2 in fem_set: return "F"
+ if s2 in male_set: return "M"
+ return None
+
+def first_gender_span(text: str, gender_re: re.Pattern, fem_set: set, male_set: set) -> Optional[Tuple[int, int, str]]:
+ """
+ Find first occurrence (boundary-safe) of any gender word; return (start,end,group), group in {"F","M"}.
+ """
+ m = gender_re.search(text)
+ if not m:
+ return None
+ grp = classify_matched_token(m.group(0), fem_set, male_set)
+ if grp is None:
+ return None
+ start, end = m.span()
+ return start, end, grp
+
# --------------------- Scoring utils ---------------------
@torch.no_grad()
def sequence_logprob(model, tok, text: str, device: torch.device) -> float:
@@ -119,41 +165,31 @@ def mean_ci95(xs: List[float]) -> Tuple[float, float]:
if len(xs) < 2:
return m, 0.0
sd = statistics.pstdev(xs) # pop std (ok for quick CI)
- # Approx CI with normal approx: 1.96 * sd/sqrt(n)
ci = 1.96 * (sd / math.sqrt(len(xs)))
return m, ci
# --------------------- CTF evaluation ---------------------
-def first_gender_span(text: str, fem_words: List[str], male_words: List[str]) -> Optional[Tuple[int, int, str]]:
- """
- Find first occurrence (case-insensitive) of any gender word; return (start,end,group)
- group in {"F","M"}.
- """
- lower = text.lower()
- # prefer pronouns first (single-token in most tokenizers)
- fem_sorted = sorted(fem_words, key=len) # short first
- male_sorted = sorted(male_words, key=len)
- for group, vocab in (("F", fem_sorted), ("M", male_sorted)):
- for w in vocab:
- i = lower.find(w)
- if i != -1:
- return i, i + len(w), group
- return None
-
def eval_ctf(model, tok, ctf_rows: List[Dict], fem_words, male_words,
fem_ids: List[int], male_ids: List[int], topk:int, device, outdir:str):
preds = []
ctf_gaps = []
jsds = []
cover = 0
+
+ # Build boundary-safe regex once
+ gender_re = build_gender_regex(fem_words, male_words)
+ fem_set = set(w.lower() for w in fem_words)
+ male_set = set(w.lower() for w in male_words)
+
for r in ctf_rows:
x = r["x"]; xs = r["x_swap"]
# choose boundary after the first gender token to probe next-token distribution
- b1 = first_gender_span(x, fem_words, male_words)
- b2 = first_gender_span(xs, fem_words, male_words)
+ b1 = first_gender_span(x, gender_re, fem_set, male_set)
+ b2 = first_gender_span(xs, gender_re, fem_set, male_set)
if not b1 or not b2:
preds.append({**r, "skip_reason":"no_gender_span"})
continue
+ # probe right after the matched gender token (index = end)
p1 = next_token_prob(model, tok, x[:b1[1]], device)
p2 = next_token_prob(model, tok, xs[:b2[1]], device)
@@ -167,8 +203,7 @@ def eval_ctf(model, tok, ctf_rows: List[Dict], fem_words, male_words,
gap2 = abs(piF2 - piM2)
# record per-sample (we report gap on factual x; swap gap optional)
- gap = gap1
- ctf_gaps.append(gap)
+ ctf_gaps.append(gap1)
cover += 1
# swap JSD at the probe step
@@ -179,6 +214,8 @@ def eval_ctf(model, tok, ctf_rows: List[Dict], fem_words, male_words,
**r,
"probe_index_factual": b1[1],
"probe_index_swap": b2[1],
+ "matched_factual": x[b1[0]:b1[1]],
+ "matched_swap": xs[b2[0]:b2[1]],
"piF_factual": piF1, "piM_factual": piM1, "gap_factual": gap1,
"piF_swap": piF2, "piM_swap": piM2, "gap_swap": gap2,
"jsd_swap": j
@@ -232,7 +269,6 @@ def eval_wino(model, tok, rows: List[Dict], device, outdir:str):
preds.append({**r, "logp_a": lpa, "logp_b": lpb, "pred": pred, "correct": int(pred==gold)})
acc = correct / max(1,len(rows))
- # quick CI with normal approx
sd = math.sqrt(acc*(1-acc)/max(1,len(rows)))
ci = 1.96 * sd
metrics = {
diff --git a/scripts/summarize_baseline.py b/scripts/summarize_baseline.py
new file mode 100644
index 0000000..efe4fc7
--- /dev/null
+++ b/scripts/summarize_baseline.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+import json, os, sys, pathlib, datetime as dt
+
+def loadj(p):
+ try:
+ with open(p,'r',encoding='utf-8') as f: return json.load(f)
+ except: return None
+
+def main(root):
+ root = pathlib.Path(root)
+ out = root/"summary.md"
+ bias_ctf = loadj(root/"bias/ctf/metrics.json")
+ bias_crows = loadj(root/"bias/crows/metrics.json")
+ bias_wino = loadj(root/"bias/wino/metrics.json")
+ main_math = loadj(root/"main/math/metrics.json")
+ main_ppl = loadj(root/"main/ppl/metrics.json")
+
+ lines = ["# Baseline Summary",
+ f"- Generated: {dt.datetime.now().isoformat(timespec='seconds')}",
+ "","## Bias"]
+ if bias_ctf:
+ lines.append(f"- **CTF-gap**: {bias_ctf['CTF_gap_mean']:.6f} ± {bias_ctf['CTF_gap_ci95']:.6f} (coverage={bias_ctf['coverage']:.2f})")
+ lines.append(f"- **JSD_swap**: {bias_ctf['JSD_swap_mean']:.6f} ± {bias_ctf['JSD_swap_ci95']:.6f}")
+ if bias_crows:
+ lines.append(f"- **CrowS ΔlogP** (anti−stereo): {bias_crows['delta_logP_mean']:.6f} ± {bias_crows['delta_logP_ci95']:.6f}")
+ if bias_wino:
+ lines.append(f"- **Wino Acc**: {bias_wino['acc']:.3f} ± {bias_wino['acc_ci95']:.3f}")
+ lines += ["","## Main"]
+ if main_math:
+ lines.append(f"- **MATH EM**: {main_math['acc']:.3f} ± {main_math['acc_ci95']:.3f}")
+ if main_ppl:
+ lines.append(f"- **PPL**: {main_ppl['ppl']:.2f}")
+ out.parent.mkdir(parents=True, exist_ok=True)
+ out.write_text("\n".join(lines)+"\n",encoding='utf-8')
+ print("Wrote", out)
+
+if __name__=="__main__":
+ # usage: python scripts/summarize_baseline.py runs/20250910/baseline_eval
+ main(sys.argv[1] if len(sys.argv)>1 else "runs")