summaryrefslogtreecommitdiff
path: root/analysis/structural_overlap.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-04-08 22:06:05 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-04-08 22:06:05 -0500
commit05704d0eb2fa59fe727652465b07db40bcb06c38 (patch)
tree8904aca836cf552fd1a5ae8c2174e9f91e70bbbc /analysis/structural_overlap.py
Initial release: GAP framework
- Full pipeline: variant generation, multi-judge verification, evaluation - Loaders for OpenAI / Anthropic / Google / xAI / OpenRouter / vLLM - Framework-level mechanism analyses: paired structural overlap, repairability rescue, self-correction probe, cross-model agreement, topic x problem-type interaction - Unicode -> bare-LaTeX cleaner + audit + spot-check - Mirrors https://huggingface.co/datasets/blackhao0426/PutnamGAP
Diffstat (limited to 'analysis/structural_overlap.py')
-rw-r--r--analysis/structural_overlap.py523
1 files changed, 523 insertions, 0 deletions
diff --git a/analysis/structural_overlap.py b/analysis/structural_overlap.py
new file mode 100644
index 0000000..284c139
--- /dev/null
+++ b/analysis/structural_overlap.py
@@ -0,0 +1,523 @@
+"""Stable-vs-Brittle structural overlap analysis (label-free).
+
+Pipeline:
+1. For each (model, surface_variant) cell, load original and variant trajectories.
+2. Pull the deterministic rename map from /home/yurenh2/gap/putnam-bench-anon/dataset/.
+3. Canonicalize both trajectories: replace variant variables with placeholders
+ (via inverse rename map). Original trajectory: replace canonical variables
+ with the same placeholders. Both texts then live in a shared placeholder space.
+4. Compute multiple non-LLM structural metrics on (orig_canonical, var_canonical):
+ - Token Jaccard
+ - Bigram Jaccard
+ - Equation-set Jaccard (math-block extraction)
+ - Prefix Jaccard (first 30% of each canonical text)
+5. Stratify by group (stable vs brittle) within each (model, variant) cell.
+6. Mann-Whitney U test on each metric for stable vs brittle.
+
+Surface variants only (rename map available). Kernel handled separately.
+"""
+
+from __future__ import annotations
+import json
+import re
+import os
+from pathlib import Path
+from collections import Counter, defaultdict
+from typing import Dict, List, Tuple, Optional
+
+import statistics
+
# Source of the anonymized PutnamBench problems; each JSON holds the
# per-variant deterministic variable-rename maps consumed below.
DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
# Per-model evaluation outputs; one subdirectory per model.
RESULTS_DIR = Path("/home/yurenh2/gap/results_new")

# Surface-level rewrites that come with a rename map (kernel variants do not;
# they are handled by a separate analysis -- see module docstring).
SURFACE_VARIANTS = ["descriptive_long", "descriptive_long_confusing",
                    "descriptive_long_misleading", "garbled_string"]
+
+
+# ---------- I/O helpers ----------
+
def load_problems(path: Path) -> List[dict]:
    """Load the per-problem records from a results JSON file.

    Accepts either the {"problems": [...]} or the {"detailed_results": [...]}
    layout; returns [] when neither key holds anything truthy.
    """
    # Context manager closes the handle deterministically (the original
    # relied on refcounting via json.load(open(path))).
    with open(path) as fh:
        d = json.load(fh)
    return d.get("problems") or d.get("detailed_results") or []
+
+
def find_variant_file(model_dir: Path, variant: str) -> Optional[Path]:
    """Locate the results file for `variant` inside `model_dir`.

    Picks the alphabetically-first `*_{variant}.json`, skipping regraded /
    comparison dumps and `*_{variant}2.json` reruns.  `garbled_string`
    additionally falls back to the legacy `*_gs.json` naming.  Returns None
    when nothing matches.
    """
    listing = sorted(os.listdir(model_dir))
    wanted_suffix = f"_{variant}.json"
    rerun_suffix = f"_{variant}2.json"
    matches = []
    for name in listing:
        if not name.endswith(wanted_suffix):
            continue
        if "regraded" in name or "comparison" in name:
            continue
        if name.endswith(rerun_suffix):
            continue
        matches.append(name)
    if not matches and variant == "garbled_string":
        matches = [name for name in listing if name.endswith("_gs.json")]
    if matches:
        return model_dir / matches[0]
    return None
+
+
def load_dataset_maps() -> Dict[str, Dict[str, Dict[str, str]]]:
    """Load the per-problem variable rename maps for every surface variant.

    Returns: {problem_index: {variant: {orig_var_name: variant_var_name}}}.
    Variants whose map is absent or unparseable are silently skipped, as in
    the original best-effort behavior.
    """
    import ast  # local import, mirroring the file's style for one-off deps

    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        with open(f) as fh:  # close the handle deterministically
            d = json.load(fh)
        idx = d.get("index")
        variants = d.get("variants", {})
        cell = {}
        for v in SURFACE_VARIANTS:
            vd = variants.get(v, {})
            mp_str = vd.get("map")
            if isinstance(mp_str, str):
                # The map is stored as a Python repr string. Parse it with
                # ast.literal_eval, which accepts only literal structures --
                # unlike eval() (even with empty builtins) it cannot execute
                # arbitrary expressions on untrusted file content.
                try:
                    mp = ast.literal_eval(mp_str)
                    if isinstance(mp, dict):
                        cell[v] = {str(k): str(val) for k, val in mp.items()}
                except (ValueError, SyntaxError, TypeError):
                    pass
            elif isinstance(mp_str, dict):
                cell[v] = {str(k): str(val) for k, val in mp_str.items()}
        out[idx] = cell
    return out
+
+
+# ---------- Canonicalization ----------
+
def canonicalize_text(text: str, var_to_placeholder: Dict[str, str]) -> str:
    """Map every variable name in `text` onto its canonical placeholder.

    Replacement runs longest-name-first so a short name never clobbers the
    interior of a longer one (e.g. 'al' inside 'almondtree').  Each name is
    matched only as a whole token: it must not be flanked by an ASCII
    letter, digit, or underscore.  Variables in this dataset are all
    alphanumeric, so this boundary rule covers both ordinary identifiers
    and the garbled-string names.
    """
    if not text:
        return ""
    ordered = sorted(var_to_placeholder.items(),
                     key=lambda kv: len(kv[0]), reverse=True)
    result = text
    for name, placeholder in ordered:
        if not name:
            continue
        boundary_pat = (r"(?<![A-Za-z0-9_])" + re.escape(name)
                        + r"(?![A-Za-z0-9_])")
        result = re.sub(boundary_pat, placeholder, result)
    return result
+
+
def normalize_whitespace(text: str) -> str:
    """Collapse every whitespace run to one space and trim both ends."""
    squeezed = re.sub(r"\s+", " ", text)
    return squeezed.strip()
+
+
+# ---------- Tokenization ----------
+
# One token = identifier | run of digits | single non-space symbol char.
_TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*|\d+|[^\sA-Za-z0-9_]")

def tokens(text: str) -> List[str]:
    """Lex text into identifiers, digit runs, and single punctuation marks."""
    source = text or ""
    return _TOKEN_RE.findall(source)
+
+
def bigrams(toks: List[str]) -> List[str]:
    """Return space-joined adjacent token pairs, in order of appearance."""
    return [f"{left} {right}" for left, right in zip(toks, toks[1:])]
+
+
+# ---------- Math block extraction ----------
+
# Delimiter patterns tried in order; the same formula may be captured by
# more than one pattern (e.g. $$...$$ also partially matches $...$).
_MATH_BLOCKS = [
    re.compile(r"\$\$(.+?)\$\$", re.DOTALL),
    re.compile(r"\\\[(.+?)\\\]", re.DOTALL),
    re.compile(r"\$(.+?)\$", re.DOTALL),
    re.compile(r"\\begin\{(?:equation|align|gather)\*?\}(.+?)\\end\{(?:equation|align|gather)\*?\}", re.DOTALL),
]

def extract_math_blocks(text: str, min_len: int = 8) -> List[str]:
    """Pull LaTeX math bodies out of text, whitespace-normalized.

    Drops fragments shorter than `min_len` characters -- trivial bits like
    '$n$' or '$0$' would otherwise saturate the Jaccard metric.
    """
    source = text or ""
    raw = []
    for pattern in _MATH_BLOCKS:
        raw.extend(pattern.findall(source))
    # Collapse internal whitespace so cosmetically different spellings of
    # the same equation compare equal.
    cleaned = [re.sub(r"\s+", " ", body).strip() for body in raw if body.strip()]
    return [body for body in cleaned if len(body) >= min_len]
+
+
+# ---------- Metrics ----------
+
def jaccard(a: set, b: set) -> float:
    """Jaccard similarity |a∩b| / |a∪b|; two empty sets count as identical."""
    union = a | b
    if not union:
        return 1.0
    return len(a & b) / len(union)
+
+
def metric_token_jaccard(a: str, b: str) -> float:
    """Symmetric Jaccard over the two texts' token vocabularies."""
    vocab_a = set(tokens(a))
    vocab_b = set(tokens(b))
    return jaccard(vocab_a, vocab_b)
+
+
def metric_bigram_jaccard(a: str, b: str) -> float:
    """Symmetric Jaccard over adjacent-token-pair (bigram) vocabularies."""
    pairs_a = set(bigrams(tokens(a)))
    pairs_b = set(bigrams(tokens(b)))
    return jaccard(pairs_a, pairs_b)
+
+
def metric_prefix_token_jaccard(a: str, b: str, frac: float = 0.3) -> float:
    """Token Jaccard over the first `frac` of tokens from each side (min 1)."""
    toks_a, toks_b = tokens(a), tokens(b)
    cut_a = max(1, int(len(toks_a) * frac))
    cut_b = max(1, int(len(toks_b) * frac))
    return jaccard(set(toks_a[:cut_a]), set(toks_b[:cut_b]))
+
+
def metric_prefix_bigram_jaccard(a: str, b: str, frac: float = 0.3) -> float:
    """Bigram Jaccard over the first `frac` of tokens from each side (min 1)."""
    toks_a, toks_b = tokens(a), tokens(b)
    cut_a = max(1, int(len(toks_a) * frac))
    cut_b = max(1, int(len(toks_b) * frac))
    return jaccard(set(bigrams(toks_a[:cut_a])), set(bigrams(toks_b[:cut_b])))
+
+
def metric_equation_jaccard(a: str, b: str) -> float:
    """Jaccard over the sets of extracted, normalized math blocks."""
    return jaccard(set(extract_math_blocks(a)), set(extract_math_blocks(b)))
+
+
def metric_lcp_tokens(a: str, b: str) -> int:
    """Length of the longest common prefix of the two token streams.

    Directly tests Codex's thesis 'early loss of structural overlap with the
    model's own original reasoning under renaming'. Larger LCP -> the model
    started its variant trajectory the same way it started the original.
    """
    matched = 0
    for tok_a, tok_b in zip(tokens(a), tokens(b)):
        if tok_a != tok_b:
            break
        matched += 1
    return matched
+
+
def metric_lcp_normalized(a: str, b: str) -> float:
    """LCP length normalized by the shorter trajectory length, in [0, 1].

    Tokenizes each text exactly once and scans the prefix inline; the
    original tokenized every text twice (once here and once again inside
    metric_lcp_tokens).  Result is unchanged.
    """
    ta, tb = tokens(a), tokens(b)
    n = min(len(ta), len(tb))
    if n == 0:
        return 0.0
    i = 0
    while i < n and ta[i] == tb[i]:
        i += 1
    return i / n
+
+
def metric_lcp_first1k(a: str, b: str) -> float:
    """LCP computed over at most the first 1000 tokens, normalized to [0, 1]."""
    head_a = tokens(a)[:1000]
    head_b = tokens(b)[:1000]
    limit = min(len(head_a), len(head_b))
    if limit == 0:
        return 0.0
    match = 0
    while match < limit and head_a[match] == head_b[match]:
        match += 1
    return match / limit
+
+
def metric_directional_coverage(a: str, b: str) -> float:
    """|tokens_a ∩ tokens_b| / |tokens_a|. Length-asymmetric.

    Reads as: 'what fraction of the original's vocabulary survives in the
    variant?'  More robust to length differences than symmetric Jaccard.
    """
    vocab_a = set(tokens(a))
    if not vocab_a:
        return 0.0
    vocab_b = set(tokens(b))
    return len(vocab_a & vocab_b) / len(vocab_a)
+
+
def metric_window_token_jaccard(a: str, b: str, window: int = 600) -> float:
    """Token Jaccard restricted to the first `window` tokens on each side."""
    head_a = set(tokens(a)[:window])
    head_b = set(tokens(b)[:window])
    return jaccard(head_a, head_b)
+
+
def metric_window_bigram_jaccard(a: str, b: str, window: int = 600) -> float:
    """Bigram Jaccard restricted to the first `window` tokens on each side."""
    head_a = tokens(a)[:window]
    head_b = tokens(b)[:window]
    return jaccard(set(bigrams(head_a)), set(bigrams(head_b)))
+
+
+# ---------- Stat helpers ----------
+
def bootstrap_ci_delta_median(xs: List[float], ys: List[float],
                              n_iter: int = 1000, seed: int = 0) -> Tuple[float, float]:
    """Percentile bootstrap 95% CI on median(xs) - median(ys).

    Deterministic for a fixed seed; returns (nan, nan) if either sample is
    empty.
    """
    import random
    rng = random.Random(seed)
    if not xs or not ys:
        return float("nan"), float("nan")
    deltas = []
    for _ in range(n_iter):
        # Resample each group with replacement, same size as the original.
        resample_x = [xs[rng.randrange(len(xs))] for _ in range(len(xs))]
        resample_y = [ys[rng.randrange(len(ys))] for _ in range(len(ys))]
        deltas.append(statistics.median(resample_x)
                      - statistics.median(resample_y))
    deltas.sort()
    return deltas[int(0.025 * n_iter)], deltas[int(0.975 * n_iter)]
+
+
def bootstrap_ci_cohens_d(xs: List[float], ys: List[float],
                          n_iter: int = 1000, seed: int = 0) -> Tuple[float, float]:
    """Percentile bootstrap 95% CI on Cohen's d between xs and ys.

    Resamples whose pooled SD is zero are skipped; returns (nan, nan) when
    either sample has fewer than two points or every resample degenerates.
    """
    import random
    rng = random.Random(seed)
    if len(xs) < 2 or len(ys) < 2:
        return float("nan"), float("nan")
    effect_sizes = []
    for _ in range(n_iter):
        res_x = [xs[rng.randrange(len(xs))] for _ in range(len(xs))]
        res_y = [ys[rng.randrange(len(ys))] for _ in range(len(ys))]
        mean_x, mean_y = statistics.fmean(res_x), statistics.fmean(res_y)
        sd_x = statistics.pstdev(res_x)
        sd_y = statistics.pstdev(res_y)
        var_term = ((len(res_x) - 1) * sd_x ** 2
                    + (len(res_y) - 1) * sd_y ** 2)
        pooled_sd = (var_term / max(1, len(res_x) + len(res_y) - 2)) ** 0.5
        if pooled_sd > 0:
            effect_sizes.append((mean_x - mean_y) / pooled_sd)
    if not effect_sizes:
        return float("nan"), float("nan")
    effect_sizes.sort()
    return (effect_sizes[int(0.025 * len(effect_sizes))],
            effect_sizes[int(0.975 * len(effect_sizes))])
+
+
def mann_whitney_u(xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Returns (U, normal_approx_p_two_sided). Pure-python, no scipy.

    Used only as a screening signal — for the rebuttal we'll use scipy if
    available; this is a fallback so we don't add a dependency.
    """
    import math
    n1, n2 = len(xs), len(ys)
    if n1 == 0 or n2 == 0:
        return float("nan"), float("nan")
    tagged = [(v, 0) for v in xs] + [(v, 1) for v in ys]
    tagged.sort(key=lambda t: t[0])
    # Assign 1-indexed midranks: every member of a tied run gets the run's
    # average position.
    rank_of = [0.0] * len(tagged)
    total = len(tagged)
    start = 0
    while start < total:
        stop = start
        while stop + 1 < total and tagged[stop + 1][0] == tagged[start][0]:
            stop += 1
        midrank = (start + stop) / 2.0 + 1
        for pos in range(start, stop + 1):
            rank_of[pos] = midrank
        start = stop + 1
    rank_sum_x = sum(r for r, (_, grp) in zip(rank_of, tagged) if grp == 0)
    u_x = rank_sum_x - n1 * (n1 + 1) / 2.0
    u_y = n1 * n2 - u_x
    u_stat = min(u_x, u_y)
    # Normal approximation to U's null distribution (no tie correction).
    mean_u = n1 * n2 / 2.0
    sd_u = (n1 * n2 * (n1 + n2 + 1) / 12.0) ** 0.5
    if sd_u == 0:
        return u_stat, float("nan")
    z = (u_stat - mean_u) / sd_u
    # Two-sided p via the complementary error function.
    p = math.erfc(abs(z) / math.sqrt(2))
    return u_stat, p
+
+
+# ---------- Cell analysis ----------
+
# Thresholds for flagging a degenerately short variant response.
COLLAPSE_MIN_CHARS = 200
COLLAPSE_RATIO = 0.25  # variant_len < ratio * orig_len => collapse


def is_collapse(orig_text: str, var_text: str) -> bool:
    """True when the variant response collapsed to a degenerately short text.

    Triggers on either an absolute floor (under COLLAPSE_MIN_CHARS chars)
    or a relative one (under COLLAPSE_RATIO of the original's length).
    """
    if len(var_text) < COLLAPSE_MIN_CHARS:
        return True
    return len(var_text) < COLLAPSE_RATIO * max(1, len(orig_text))
+
+
def analyze_cell(model_name: str, variant: str, dataset_maps: dict,
                 model_dir: Path) -> Optional[dict]:
    """Compare original-vs-variant trajectories for one (model, variant) cell.

    For each problem the model answered correctly on the ORIGINAL phrasing,
    pairs the original trajectory with its trajectory under `variant`,
    canonicalizes both into a shared placeholder space via the problem's
    rename map, and stratifies the pairs into:
      - stable drift:    variant still correct, full-length answer
      - brittle drift:   variant wrong, full-length answer
      - brittle collapse: variant wrong, degenerately short answer
      - stable collapse:  counted only (rare)

    Returns a summary dict with per-metric stats (medians, means, Cohen's d,
    Mann-Whitney U/p for stable-vs-brittle drift; bootstrap CIs and a
    random-pairing noise floor for the headline token-Jaccard metric), or
    None when either results file is missing or either drift stratum is
    empty.
    """
    orig_path = find_variant_file(model_dir, "original")
    var_path = find_variant_file(model_dir, variant)
    if not orig_path or not var_path:
        return None

    # Index both runs by problem index so original/variant can be paired.
    orig_by = {p["index"]: p for p in load_problems(orig_path)}
    var_by = {p["index"]: p for p in load_problems(var_path)}

    common = set(orig_by) & set(var_by)
    pairs_stable_drift = []  # sample dicts (built below) — non-collapse stable
    pairs_brittle_drift = []  # non-collapse brittle
    pairs_brittle_collapse = []  # short variant text
    n_stable_collapse = 0  # almost always 0 but tracked for completeness

    for idx in common:
        po, pv = orig_by[idx], var_by[idx]
        # Condition on the model solving the ORIGINAL correctly; pairs where
        # it failed the original carry no stable/brittle signal.
        if po.get("correct") is not True:
            continue
        var_correct = pv.get("correct")
        if var_correct is None:
            continue
        orig_text = (po.get("solve") or {}).get("solution") or ""
        var_text = (pv.get("solve") or {}).get("solution") or ""
        if not orig_text or not var_text:
            continue
        rmap = dataset_maps.get(idx, {}).get(variant)
        if not rmap:
            continue
        # Canonicalize: map the canonical names (original side) and the
        # renamed names (variant side) onto the SAME __Vi__ placeholders so
        # both texts become comparable in a shared placeholder space.
        canon_to_ph = {k: f"__V{i}__" for i, k in enumerate(rmap.keys())}
        var_to_ph = {rmap[k]: canon_to_ph[k] for k in rmap}
        orig_canon = canonicalize_text(orig_text, canon_to_ph)
        var_canon = canonicalize_text(var_text, var_to_ph)
        sample = {
            "index": idx,
            "problem_type": po.get("problem_type"),
            "orig_canon": orig_canon,
            "var_canon": var_canon,
            "orig_len": len(orig_text),
            "var_len": len(var_text),
        }
        collapse = is_collapse(orig_text, var_text)
        if var_correct is True:
            if collapse:
                n_stable_collapse += 1
            else:
                pairs_stable_drift.append(sample)
        else:
            if collapse:
                pairs_brittle_collapse.append(sample)
            else:
                pairs_brittle_drift.append(sample)

    # Both drift strata must be populated for a stable-vs-brittle contrast.
    if not pairs_stable_drift or not pairs_brittle_drift:
        return None

    metrics = {
        "token_jaccard": metric_token_jaccard,
        "bigram_jaccard": metric_bigram_jaccard,
        "directional_coverage": metric_directional_coverage,
        "window_token_jaccard": metric_window_token_jaccard,
        "window_bigram_jaccard": metric_window_bigram_jaccard,
        "equation_jaccard": metric_equation_jaccard,
    }
    # Headline metric for bootstrap + noise floor (the others stay descriptive)
    HEADLINE = "token_jaccard"

    # Pre-tokenize once per pair to amortize cost (used by token/bigram/window metrics).
    for p in pairs_stable_drift + pairs_brittle_drift:
        p["_otok"] = tokens(p["orig_canon"])
        p["_vtok"] = tokens(p["var_canon"])
        p["_oset"] = set(p["_otok"])
        p["_vset"] = set(p["_vtok"])

    # Same formula as jaccard() but reads the precomputed token sets.
    def fast_token_jaccard(p):
        a, b = p["_oset"], p["_vset"]
        if not a and not b:
            return 1.0
        return len(a & b) / max(1, len(a | b))

    # Cross-pair variant of the above: orig tokens of pa vs variant tokens
    # of pb — used for the random-pairing noise floor.
    def fast_token_jaccard_pair(pa, pb):
        a, b = pa["_oset"], pb["_vset"]
        if not a and not b:
            return 1.0
        return len(a & b) / max(1, len(a | b))

    out = {
        "model": model_name,
        "variant": variant,
        "n_stable_drift": len(pairs_stable_drift),
        "n_brittle_drift": len(pairs_brittle_drift),
        "n_stable_collapse": n_stable_collapse,
        "n_brittle_collapse": len(pairs_brittle_collapse),
        "brittle_collapse_rate": (len(pairs_brittle_collapse)
                                  / max(1, len(pairs_brittle_collapse) + len(pairs_brittle_drift))),
        "metrics": {},
    }
    # Compute all descriptive metrics (one pass per pair, no bootstrap)
    for mname, mfn in metrics.items():
        s_vals = [mfn(p["orig_canon"], p["var_canon"]) for p in pairs_stable_drift]
        b_vals = [mfn(p["orig_canon"], p["var_canon"]) for p in pairs_brittle_drift]
        U, p = mann_whitney_u(s_vals, b_vals)
        sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals)
        ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0
        bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0
        # Pooled-SD Cohen's d (0.0 when both groups are constant).
        pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2)
                  / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5
        d = (sm - bm) / pooled if pooled > 0 else 0.0
        out["metrics"][mname] = {
            "stable_median": statistics.median(s_vals),
            "stable_mean": sm,
            "brittle_median": statistics.median(b_vals),
            "brittle_mean": bm,
            "delta_median": statistics.median(s_vals) - statistics.median(b_vals),
            "delta_mean": sm - bm,
            "cohens_d": d,
            "U": U,
            "p_two_sided": p,
        }

    # Bootstrap + noise floor only on headline metric
    s_vals = [fast_token_jaccard(p) for p in pairs_stable_drift]
    b_vals = [fast_token_jaccard(p) for p in pairs_brittle_drift]
    ci_lo, ci_hi = bootstrap_ci_delta_median(s_vals, b_vals, n_iter=400)
    d_lo, d_hi = bootstrap_ci_cohens_d(s_vals, b_vals, n_iter=400)
    out["metrics"][HEADLINE]["delta_median_ci"] = [ci_lo, ci_hi]
    out["metrics"][HEADLINE]["cohens_d_ci"] = [d_lo, d_hi]

    # Random-pairing noise floor for headline: pair stable orig with random other-problem variant
    import random as _r
    rng = _r.Random(42)
    nf_vals = []
    n = len(pairs_stable_drift)
    # Needs at least two stable pairs; otherwise the noise floor stays None.
    if n >= 2:
        for _ in range(min(400, n * (n - 1))):
            i = rng.randrange(n)
            j = rng.randrange(n)
            while j == i:
                j = rng.randrange(n)
            nf_vals.append(fast_token_jaccard_pair(pairs_stable_drift[i],
                                                   pairs_stable_drift[j]))
    out["metrics"][HEADLINE]["noise_floor_median"] = (
        statistics.median(nf_vals) if nf_vals else None)
    out["metrics"][HEADLINE]["noise_floor_mean"] = (
        statistics.fmean(nf_vals) if nf_vals else None)
    out["metrics"][HEADLINE]["noise_floor_n"] = len(nf_vals)
    return out
+
+
def main():
    """Sweep all models x surface variants and report per-cell overlap stats.

    Prints one summary row per (model, variant) cell using the headline
    token-Jaccard metric and writes the complete per-cell results to JSON.
    """
    print("Loading dataset rename maps ...", flush=True)
    dataset_maps = load_dataset_maps()
    print(f" loaded {len(dataset_maps)} problems", flush=True)

    # Multi-cell sweep across all models × 4 surface variants
    # Run all 18 models — non-LLM, fast.
    all_models = sorted([d.name for d in RESULTS_DIR.iterdir() if d.is_dir()])
    print(f"Models: {all_models}")
    all_results = []

    print(f"\n{'Cell':<46} {'nSd':>4} {'nBd':>4} {'col%':>5} "
          f"{'sMed':>6} {'bMed':>6} {'nfMed':>6} "
          f"{'d':>6} {'d95CI':>14} {'p':>9}")
    print("-" * 122)

    for m in all_models:
        # Hoisted out of the variant loop: the model directory does not
        # change per variant (the original rebuilt and re-checked it on
        # every iteration).
        mdir = RESULTS_DIR / m
        if not mdir.exists():
            continue
        for v in SURFACE_VARIANTS:
            res = analyze_cell(m, v, dataset_maps, mdir)
            if res is None:
                continue
            all_results.append(res)
            md = res["metrics"]["token_jaccard"]
            label = f"{m} / {v}"
            ci_lo, ci_hi = md["cohens_d_ci"]
            ci_str = f"[{ci_lo:+.2f}, {ci_hi:+.2f}]"
            # Guard: with a single stable pair analyze_cell leaves the noise
            # floor as None, which would crash the .3f formatting.
            nf_med = md["noise_floor_median"]
            nf_str = f"{nf_med:>6.3f}" if nf_med is not None else "   n/a"
            print(f"{label:<46} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
                  f"{res['brittle_collapse_rate']*100:>4.0f}% "
                  f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
                  f"{nf_str} "
                  f"{md['cohens_d']:>+6.2f} {ci_str:>14} {md['p_two_sided']:>9.1e}")

    out_path = Path("/home/yurenh2/gap/analysis/structural_overlap_results.json")
    # Context manager ensures the output file is flushed and closed (the
    # original passed a bare open() handle to json.dump).
    with open(out_path, "w") as fh:
        json.dump(all_results, fh, indent=2)
    print(f"\nSaved -> {out_path} ({len(all_results)} cells)")
+
+
# Script entry point: run the full sweep only when executed directly.
if __name__ == "__main__":
    main()