| author | Yuren Hao <yurenh2@timan108.cs.illinois.edu> | 2025-09-10 12:09:06 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@timan108.cs.illinois.edu> | 2025-09-10 12:09:06 -0500 |
| commit | 5bfd92f6c28530482a765252a4497cfedacad25a (patch) | |
| tree | c24a9aaa21fdfdff0a91bdeeb02432679904bc8a /scripts | |
| parent | 523b1747ee27b60d06424dcabd47a309cda80536 (diff) | |
smoke tests
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/eval_bias_baseline.py | 292 |
| -rw-r--r-- | scripts/eval_main_baseline.py | 138 |
2 files changed, 430 insertions, 0 deletions
diff --git a/scripts/eval_bias_baseline.py b/scripts/eval_bias_baseline.py
new file mode 100644
index 0000000..227805f
--- /dev/null
+++ b/scripts/eval_bias_baseline.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Baseline bias evaluation: CTF (x vs swap(x)), CrowS-Pairs (gender), WinoGender.
+- No training; pure scoring with Hugging Face Transformers.
+- T = 0 decoding policy: we don't sample; we compute log-probs directly.
+Outputs:
+  runs/<DATE>/baseline_eval/bias/{ctf,crows,wino}/metrics.json
+  runs/<DATE>/baseline_eval/bias/{ctf,crows,wino}/preds.jsonl
+"""
+import argparse, json, os, math, re, time, pathlib, statistics
+from typing import List, Dict, Tuple, Optional
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# --------------------- IO utils ---------------------
+def read_jsonl(path: str) -> List[Dict]:
+    with open(path, "r", encoding="utf-8") as f:
+        return [json.loads(line) for line in f if line.strip()]
+
+def write_json(path: str, obj: Dict):
+    pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(obj, f, indent=2, ensure_ascii=False)
+
+def write_jsonl(path: str, rows: List[Dict]):
+    pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "w", encoding="utf-8") as f:
+        for r in rows:
+            f.write(json.dumps(r, ensure_ascii=False) + "\n")
+
+def now_ts() -> str:
+    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+
+# --------------------- Token set mapping ---------------------
+def load_word_list(path: str) -> List[str]:
+    words = []
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f:
+            w = line.strip().lower()
+            if w:
+                words.append(w)
+    return words
+
+def map_words_to_token_ids(tok: AutoTokenizer, words: List[str]) -> List[int]:
+    """
+    Map words to token ids as single-token variants if possible.
+    Try with and without a leading space; if both single-tokenize, include both.
+    Fallback: if a form tokenizes to multiple tokens, include the FIRST token id
+    (approximation for group-mass aggregation).
+    """
+    ids = set()
+    for w in words:
+        for form in (w, " " + w):
+            enc = tok(form, add_special_tokens=False, return_tensors=None)
+            # single-token id when possible; otherwise the first piece (fallback)
+            ids.add(int(enc["input_ids"][0]))
+    return sorted(ids)
+
+# --------------------- Scoring utils ---------------------
+@torch.no_grad()
+def sequence_logprob(model, tok, text: str, device: torch.device) -> float:
+    """ Sum log p(y_t | y_<t) over the full sequence (excluding the first token). """
+    enc = tok(text, return_tensors="pt")
+    input_ids = enc.input_ids.to(device)
+    attn_mask = enc.attention_mask.to(device)
+    out = model(input_ids=input_ids, attention_mask=attn_mask)
+    logits = out.logits  # [1, T, V]
+    logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)  # last position has no target
+    tgt = input_ids[:, 1:]  # shift
+    ll = logprobs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1).sum().item()
+    return float(ll)
+
+@torch.no_grad()
+def conditional_logprob(model, tok, prompt: str, cont: str, device: torch.device) -> float:
+    """ log p(cont | prompt): score the concatenation, keep only the continuation part. """
+    e_prompt = tok(prompt, return_tensors="pt", add_special_tokens=False)
+    e_cont = tok(" " + cont, return_tensors="pt", add_special_tokens=False)
+    input_ids = torch.cat([e_prompt.input_ids, e_cont.input_ids], dim=-1).to(device)
+    attn_mask = torch.ones_like(input_ids)
+
+    out = model(input_ids=input_ids, attention_mask=attn_mask)
+    logits = out.logits  # [1, T, V]
+    logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)
+    tgt = input_ids[:, 1:]
+    ll_all = logprobs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)
+
+    Lp = e_prompt.input_ids.size(-1)
+    # After the shift, the first continuation token is scored at target index Lp - 1.
+    ll = ll_all[:, Lp - 1:].sum().item()
+    return float(ll)
+
+@torch.no_grad()
+def next_token_prob(model, tok, prefix: str, device: torch.device) -> torch.Tensor:
+    """ p_t(. | prefix) distribution for the next token. """
+    enc = tok(prefix, return_tensors="pt")
+    input_ids = enc.input_ids.to(device)
+    attn_mask = enc.attention_mask.to(device)
+    out = model(input_ids=input_ids, attention_mask=attn_mask)
+    logits = out.logits[:, -1, :]  # last position
+    probs = F.softmax(logits, dim=-1).squeeze(0)  # [V]
+    return probs  # 1-D tensor
+
+def jsd(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> float:
+    p = p.clamp_min(eps); q = q.clamp_min(eps)
+    m = 0.5 * (p + q)
+    return 0.5 * (p * (p.log() - m.log())).sum().item() + 0.5 * (q * (q.log() - m.log())).sum().item()
+
+def mean_ci95(xs: List[float]) -> Tuple[float, float]:
+    if not xs:
+        return 0.0, 0.0
+    m = statistics.fmean(xs)
+    if len(xs) < 2:
+        return m, 0.0
+    sd = statistics.pstdev(xs)  # population std (fine for a quick CI)
+    # Normal approximation: 1.96 * sd / sqrt(n)
+    ci = 1.96 * (sd / math.sqrt(len(xs)))
+    return m, ci
+
+# --------------------- CTF evaluation ---------------------
+def first_gender_span(text: str, fem_words: List[str], male_words: List[str]) -> Optional[Tuple[int, int, str]]:
+    """
+    Find a whole-word occurrence (case-insensitive) of any gender word;
+    return (start, end, group) with group in {"F", "M"}.
+    Female vocabulary is checked first, shortest words first.
+    """
+    lower = text.lower()
+    # prefer short words first (pronouns are single-token in most tokenizers)
+    fem_sorted = sorted(fem_words, key=len)
+    male_sorted = sorted(male_words, key=len)
+    for group, vocab in (("F", fem_sorted), ("M", male_sorted)):
+        for w in vocab:
+            # word-boundary match so "he" cannot fire inside "the"
+            hit = re.search(r"\b" + re.escape(w) + r"\b", lower)
+            if hit:
+                return hit.start(), hit.end(), group
+    return None
+
+def eval_ctf(model, tok, ctf_rows: List[Dict], fem_words, male_words,
+             fem_ids: List[int], male_ids: List[int], topk: int, device, outdir: str):
+    preds = []
+    ctf_gaps = []
+    jsds = []
+    cover = 0
+    for r in ctf_rows:
+        x = r["x"]; xs = r["x_swap"]
+        # choose the boundary after the first gender token to probe the next-token distribution
+        b1 = first_gender_span(x, fem_words, male_words)
+        b2 = first_gender_span(xs, fem_words, male_words)
+        if not b1 or not b2:
+            preds.append({**r, "skip_reason": "no_gender_span"})
+            continue
+        p1 = next_token_prob(model, tok, x[:b1[1]], device)
+        p2 = next_token_prob(model, tok, xs[:b2[1]], device)
+
+        # group masses
+        piF1 = float(p1[fem_ids].sum().item()) if fem_ids else 0.0
+        piM1 = float(p1[male_ids].sum().item()) if male_ids else 0.0
+        gap1 = abs(piF1 - piM1)
+
+        piF2 = float(p2[fem_ids].sum().item()) if fem_ids else 0.0
+        piM2 = float(p2[male_ids].sum().item()) if male_ids else 0.0
+        gap2 = abs(piF2 - piM2)
+
+        # record per-sample (we report the gap on the factual x; the swap gap is optional)
+        gap = gap1
+        ctf_gaps.append(gap)
+        cover += 1
+
+        # swap JSD at the probe step
+        j = jsd(p1, p2)
+        jsds.append(float(j))
+
+        preds.append({
+            **r,
+            "probe_index_factual": b1[1],
+            "probe_index_swap": b2[1],
+            "piF_factual": piF1, "piM_factual": piM1, "gap_factual": gap1,
+            "piF_swap": piF2, "piM_swap": piM2, "gap_swap": gap2,
+            "jsd_swap": j
+        })
+
+    m_gap, ci_gap = mean_ci95(ctf_gaps)
+    m_jsd, ci_jsd = mean_ci95(jsds)
+    metrics = {
+        "timestamp": now_ts(),
+        "count": len(ctf_rows),
+        "covered": cover,
+        "coverage": (cover / max(1, len(ctf_rows))),
+        "CTF_gap_mean": m_gap, "CTF_gap_ci95": ci_gap,
+        "JSD_swap_mean": m_jsd, "JSD_swap_ci95": ci_jsd,
+        "topk": topk
+    }
+    write_json(os.path.join(outdir, "ctf", "metrics.json"), metrics)
+    write_jsonl(os.path.join(outdir, "ctf", "preds.jsonl"), preds)
+
+# --------------------- CrowS-Pairs (gender) ---------------------
+def eval_crows(model, tok, rows: List[Dict], device, outdir: str):
+    deltas = []
+    preds = []
+    for r in rows:
+        s_st = r["sentence_stereo"]; s_anti = r["sentence_anti"]
+        ll_st = sequence_logprob(model, tok, s_st, device)
+        ll_anti = sequence_logprob(model, tok, s_anti, device)
+        delta = ll_anti - ll_st
+        deltas.append(float(delta))
+        preds.append({**r, "logp_stereo": ll_st, "logp_anti": ll_anti, "delta": delta})
+    m, ci = mean_ci95(deltas)
+    metrics = {
+        "timestamp": now_ts(),
+        "count": len(rows),
+        "delta_logP_mean": m, "delta_logP_ci95": ci
+    }
+    write_json(os.path.join(outdir, "crows", "metrics.json"), metrics)
+    write_jsonl(os.path.join(outdir, "crows", "preds.jsonl"), preds)
+
+# --------------------- WinoGender ---------------------
+def eval_wino(model, tok, rows: List[Dict], device, outdir: str):
+    correct = 0
+    preds = []
+    for r in rows:
+        prompt = r["premise"]
+        a = r["option_a"]; b = r["option_b"]; gold = r["gold"].strip().lower()
+        lpa = conditional_logprob(model, tok, prompt, a, device)
+        lpb = conditional_logprob(model, tok, prompt, b, device)
+        pred = "a" if lpa > lpb else "b"
+        correct += int(pred == gold)
+        preds.append({**r, "logp_a": lpa, "logp_b": lpb, "pred": pred, "correct": int(pred == gold)})
+
+    acc = correct / max(1, len(rows))
+    # quick CI with normal approximation
+    sd = math.sqrt(acc * (1 - acc) / max(1, len(rows)))
+    ci = 1.96 * sd
+    metrics = {
+        "timestamp": now_ts(),
+        "count": len(rows),
+        "acc": acc, "acc_ci95": ci
+    }
+    write_json(os.path.join(outdir, "wino", "metrics.json"), metrics)
+    write_jsonl(os.path.join(outdir, "wino", "preds.jsonl"), preds)
+
+# --------------------- Main ---------------------
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--model", type=str, required=True, help="HF model id, e.g., Qwen/Qwen2.5-7B-Instruct")
+    ap.add_argument("--ctf", type=str, required=True)
+    ap.add_argument("--crows", type=str, required=True)
+    ap.add_argument("--wino", type=str, required=True)
+    ap.add_argument("--groups_dir", type=str, required=True, help="assets/groups/")
+    ap.add_argument("--out", type=str, required=True)
+    ap.add_argument("--top_k", type=int, default=20)
+    ap.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"])
+    args = ap.parse_args()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[args.dtype]
+
+    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model,
+        torch_dtype=dtype if device.type == "cuda" else torch.float32,
+        device_map=None
+    ).to(device)
+    model.eval()
+
+    fem_words = load_word_list(os.path.join(args.groups_dir, "en_female.txt"))
+    male_words = load_word_list(os.path.join(args.groups_dir, "en_male.txt"))
+    fem_ids = map_words_to_token_ids(tok, fem_words)
+    male_ids = map_words_to_token_ids(tok, male_words)
+
+    outdir = args.out
+
+    # CTF
+    ctf_rows = read_jsonl(args.ctf)
+    eval_ctf(model, tok, ctf_rows, fem_words, male_words, fem_ids, male_ids, args.top_k, device, outdir)
+
+    # CrowS
+    crows_rows = read_jsonl(args.crows)
+    eval_crows(model, tok, crows_rows, device, outdir)
+
+    # Wino
+    wino_rows = read_jsonl(args.wino)
+    eval_wino(model, tok, wino_rows, device, outdir)
+
+    print("[DONE] Bias baseline written to", outdir)
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/eval_main_baseline.py b/scripts/eval_main_baseline.py
new file mode 100644
index 0000000..0d16f79
--- /dev/null
+++ b/scripts/eval_main_baseline.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Baseline main-task eval: MATH (EM, greedy) and PPL (LM perplexity).
+Outputs:
+  runs/<DATE>/baseline_eval/main/{math,ppl}/metrics.json
+  runs/<DATE>/baseline_eval/main/{math,ppl}/preds.jsonl
+"""
+import argparse, json, os, math, re, time, pathlib, statistics
+from typing import List, Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+def read_jsonl(path: str) -> List[Dict]:
+    with open(path, "r", encoding="utf-8") as f:
+        return [json.loads(line) for line in f if line.strip()]
+
+def write_json(path: str, obj):
+    pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(obj, f, indent=2, ensure_ascii=False)
+
+def write_jsonl(path: str, rows: List[Dict]):
+    pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "w", encoding="utf-8") as f:
+        for r in rows:
+            f.write(json.dumps(r, ensure_ascii=False) + "\n")
+
+def now_ts() -> str:
+    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+
+# ---------- PPL ----------
+@torch.no_grad()
+def sequence_nll(model, tok, text: str, device: torch.device) -> Tuple[float, int]:
+    enc = tok(text, return_tensors="pt")
+    input_ids = enc.input_ids.to(device)
+    attn_mask = enc.attention_mask.to(device)
+    out = model(input_ids=input_ids, attention_mask=attn_mask)
+    logits = out.logits  # [1, T, V]
+    logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)
+    tgt = input_ids[:, 1:]
+    nll = -(logprobs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)).sum().item()
+    return float(nll), int(tgt.numel())
+
+def eval_ppl(model, tok, rows: List[Dict], device, outdir: str):
+    nll_sum = 0.0
+    tok_count = 0
+    preds = []
+    for r in rows:
+        nll, n = sequence_nll(model, tok, r["text"], device)
+        nll_sum += nll; tok_count += n
+        preds.append({**r, "nll": nll, "tokens": n})
+    ppl = math.exp(nll_sum / max(1, tok_count))
+    write_json(os.path.join(outdir, "ppl", "metrics.json"), {
+        "timestamp": now_ts(), "count": len(rows), "tokens": tok_count, "ppl": ppl
+    })
+    write_jsonl(os.path.join(outdir, "ppl", "preds.jsonl"), preds)
+
+# ---------- MATH ----------
+def canon_num(s: str) -> str:
+    # Extract the last integer/decimal (simple heuristic for our 5 examples).
+    # Remove commas; keep a leading minus.
+    s = s.strip()
+    # pick the last number-like pattern
+    nums = re.findall(r"-?\d+(?:\.\d+)?", s.replace(",", ""))
+    return nums[-1] if nums else s.strip().lower()
+
+@torch.no_grad()
+def greedy_generate(model, tok, prompt: str, device, max_new_tokens: int) -> str:
+    enc = tok(prompt, return_tensors="pt").to(device)
+    out = model.generate(
+        **enc,
+        do_sample=False,  # greedy decoding (T = 0); sampling flags are ignored here
+        max_new_tokens=max_new_tokens,
+        eos_token_id=tok.eos_token_id
+    )
+    text = tok.decode(out[0], skip_special_tokens=True)
+    # return only the newly generated tail (after the prompt)
+    prompt_text = tok.decode(enc.input_ids[0], skip_special_tokens=True)
+    if text.startswith(prompt_text):
+        return text[len(prompt_text):].strip()
+    return text.strip()
+
+def eval_math(model, tok, rows: List[Dict], device, outdir: str, max_new_tokens: int):
+    correct = 0
+    preds = []
+    for r in rows:
+        q = r["question"]; gold = r["gold"]
+        gen = greedy_generate(model, tok, q, device, max_new_tokens=max_new_tokens)
+        pred = canon_num(gen); gold_c = canon_num(gold)
+        is_ok = int(pred == gold_c)
+        correct += is_ok
+        preds.append({**r, "gen": gen, "pred": pred, "gold_canon": gold_c, "correct": is_ok})
+    acc = correct / max(1, len(rows))
+    sd = math.sqrt(acc * (1 - acc) / max(1, len(rows)))
+    ci = 1.96 * sd
+    write_json(os.path.join(outdir, "math", "metrics.json"), {
+        "timestamp": now_ts(), "count": len(rows),
+        "acc": acc, "acc_ci95": ci
+    })
+    write_jsonl(os.path.join(outdir, "math", "preds.jsonl"), preds)
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--model", type=str, required=True)
+    ap.add_argument("--math", type=str, required=True)
+    ap.add_argument("--ppl", type=str, required=True)
+    ap.add_argument("--out", type=str, required=True)
+    ap.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"])
+    ap.add_argument("--max_new_tokens_math", type=int, default=512)
+    args = ap.parse_args()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[args.dtype]
+
+    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model,
+        torch_dtype=dtype if device.type == "cuda" else torch.float32,
+        device_map=None
+    ).to(device)
+    model.eval()
+
+    # MATH
+    math_rows = read_jsonl(args.math)
+    eval_math(model, tok, math_rows, device, args.out, max_new_tokens=args.max_new_tokens_math)
+
+    # PPL
+    ppl_rows = read_jsonl(args.ppl)
+    eval_ppl(model, tok, ppl_rows, device, args.out)
+
+    print("[DONE] Main baseline written to", args.out)
+
+if __name__ == "__main__":
+    main()
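In the spirit of the commit message ("smoke tests"), the pure-math helpers (`jsd`, `mean_ci95`, `canon_num`) can be sanity-checked offline without downloading a model. A minimal sketch, assuming it runs from the repo root; the modules are loaded by file path since `scripts/` is not a package (executing them still imports `torch`/`transformers`, but fetches nothing):

```python
# Offline smoke test for the metric helpers; no model weights required.
import importlib.util
import math
import torch

def load(name, path):
    # Load a script as a module by file path (scripts/ has no __init__.py).
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

bias = load("eval_bias_baseline", "scripts/eval_bias_baseline.py")
mainb = load("eval_main_baseline", "scripts/eval_main_baseline.py")

# jsd: zero for identical inputs, symmetric, bounded by ln 2 (nats)
p = torch.tensor([0.5, 0.5]); q = torch.tensor([1.0, 0.0])
assert abs(bias.jsd(p, p)) < 1e-6
assert abs(bias.jsd(p, q) - bias.jsd(q, p)) < 1e-9
assert bias.jsd(p, q) <= math.log(2) + 1e-6

# mean_ci95: normal-approximation interval around the mean
m, ci = bias.mean_ci95([0.0, 1.0])
assert m == 0.5 and ci > 0.0

# canon_num: keeps the last number-like token, commas stripped
assert mainb.canon_num("The answer is 1,234.") == "1234"
assert mainb.canon_num("x = -3.5 (approx)") == "-3.5"

print("helper smoke tests passed")
```

The `if __name__ == "__main__"` guards in both scripts keep `main()` from running during these imports, so the test only exercises the helper functions.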

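For reference, the five input files are newline-delimited JSON whose required fields follow from the keys each evaluator reads: `x`/`x_swap` (CTF), `sentence_stereo`/`sentence_anti` (CrowS), `premise`/`option_a`/`option_b`/`gold` (WinoGender), `question`/`gold` (MATH), and `text` (PPL). A sketch that writes one-row fixtures; the sentences are invented placeholders, not shipped data:

```python
# Write minimal one-row fixtures matching the fields each evaluator reads.
import json

examples = {
    "ctf.jsonl":   {"x": "She is a nurse.", "x_swap": "He is a nurse."},
    "crows.jsonl": {"sentence_stereo": "Women are bad drivers.",
                    "sentence_anti": "Men are bad drivers."},
    "wino.jsonl":  {"premise": "The nurse said that",
                    "option_a": "her shift was over.",
                    "option_b": "his shift was over.",
                    "gold": "a"},
    "math.jsonl":  {"question": "What is 2 + 3?", "gold": "5"},
    "ppl.jsonl":   {"text": "The quick brown fox jumps over the lazy dog."},
}
for name, row in examples.items():
    with open(name, "w", encoding="utf-8") as f:
        f.write(json.dumps(row) + "\n")
```

With fixtures like these in place (plus `en_female.txt`/`en_male.txt` word lists under `--groups_dir`), an end-to-end smoke run would look like `python scripts/eval_bias_baseline.py --model <hf-id> --ctf ctf.jsonl --crows crows.jsonl --wino wino.jsonl --groups_dir assets/groups/ --out runs/<DATE>/baseline_eval/bias` and `python scripts/eval_main_baseline.py --model <hf-id> --math math.jsonl --ppl ppl.jsonl --out runs/<DATE>/baseline_eval/main`.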