From 05704d0eb2fa59fe727652465b07db40bcb06c38 Mon Sep 17 00:00:00 2001 From: Yuren Hao Date: Wed, 8 Apr 2026 22:06:05 -0500 Subject: Initial release: GAP framework - Full pipeline: variant generation, multi-judge verification, evaluation - Loaders for OpenAI / Anthropic / Google / xAI / OpenRouter / vLLM - Framework-level mechanism analyses: paired structural overlap, repairability rescue, self-correction probe, cross-model agreement, topic x problem-type interaction - Unicode -> bare-LaTeX cleaner + audit + spot-check - Mirrors https://huggingface.co/datasets/blackhao0426/PutnamGAP --- analysis/aggregate_overlap.py | 91 ++++ analysis/balance_diff.py | 109 +++++ analysis/cross_model_agreement.py | 180 ++++++++ analysis/kv_overlap.py | 332 ++++++++++++++ analysis/make_figures.py | 272 +++++++++++ analysis/normalization_analysis.py | 189 ++++++++ analysis/rescue_analyze.py | 161 +++++++ analysis/rescue_api.py | 373 +++++++++++++++ analysis/rescue_pooled.py | 174 +++++++ analysis/rescue_prompts.py | 267 +++++++++++ analysis/rescue_runner.py | 341 ++++++++++++++ analysis/sc_success_and_difficulty.py | 192 ++++++++ analysis/self_correction.py | 202 +++++++++ analysis/spotcheck_clean.py | 181 ++++++++ analysis/structural_overlap.py | 523 +++++++++++++++++++++ analysis/topic_problemtype_interaction.py | 112 +++++ analysis/unicode_audit.py | 238 ++++++++++ analysis/unicode_clean.py | 729 ++++++++++++++++++++++++++++++ 18 files changed, 4666 insertions(+) create mode 100644 analysis/aggregate_overlap.py create mode 100644 analysis/balance_diff.py create mode 100644 analysis/cross_model_agreement.py create mode 100644 analysis/kv_overlap.py create mode 100644 analysis/make_figures.py create mode 100644 analysis/normalization_analysis.py create mode 100644 analysis/rescue_analyze.py create mode 100644 analysis/rescue_api.py create mode 100644 analysis/rescue_pooled.py create mode 100644 analysis/rescue_prompts.py create mode 100644 analysis/rescue_runner.py create mode 100644 analysis/sc_success_and_difficulty.py create mode 100644 analysis/self_correction.py create mode 100644 analysis/spotcheck_clean.py create mode 100644 analysis/structural_overlap.py create mode 100644 analysis/topic_problemtype_interaction.py create mode 100644 analysis/unicode_audit.py create mode 100644 analysis/unicode_clean.py (limited to 'analysis') diff --git a/analysis/aggregate_overlap.py b/analysis/aggregate_overlap.py new file mode 100644 index 0000000..cd6b53e --- /dev/null +++ b/analysis/aggregate_overlap.py @@ -0,0 +1,91 @@ +"""Aggregate structural_overlap results by variant type and by model. + +Produces a clean rebuttal table. +""" +from __future__ import annotations +import json +import statistics +from pathlib import Path +from collections import defaultdict + +RESULTS = Path("/home/yurenh2/gap/analysis/structural_overlap_results.json") +SHORT = {"descriptive_long":"DL","descriptive_long_confusing":"DLC", + "descriptive_long_misleading":"DLM","garbled_string":"GS"} + + +def main(): + cells = json.load(open(RESULTS)) + print(f"Loaded {len(cells)} cells.\n") + + # Per-variant aggregate + per_variant = defaultdict(list) + for c in cells: + per_variant[c["variant"]].append(c) + + print("=" * 90) + print("HEADLINE TABLE: Surface variants — stable vs brittle structural overlap") + print("(token Jaccard on canonicalized trajectories, drift cases only)") + print("=" * 90) + print(f"\n{'Variant':<6} {'#cells':>7} {'#dir+':>6} {'#p<.05':>8} " + f"{'med-d':>7} {'mean-d':>7} {'mean-dlt':>9} " + f"{'mean-stbl':>10} {'mean-brit':>10} {'mean-noise':>11} " + f"{'mean-collapse%':>14}") + print("-" * 100) + for v, cs in per_variant.items(): + ds = [c["metrics"]["token_jaccard"]["cohens_d"] for c in cs] + ps = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in cs] + n_pos = sum(1 for d in ds if d > 0) + n_sig = sum(1 for p in ps if p < 0.05) + deltas = [c["metrics"]["token_jaccard"]["delta_median"] for c in cs] + stbl = [c["metrics"]["token_jaccard"]["stable_median"] for c in cs] + brit = [c["metrics"]["token_jaccard"]["brittle_median"] for c in cs] + noise = [c["metrics"]["token_jaccard"]["noise_floor_median"] for c in cs + if c["metrics"]["token_jaccard"].get("noise_floor_median") is not None] + collapse = [c["brittle_collapse_rate"] for c in cs] + print(f"{SHORT[v]:<6} {len(cs):>7} {n_pos:>6} {n_sig:>8} " + f"{statistics.median(ds):>+7.2f} {statistics.fmean(ds):>+7.2f} " + f"{statistics.fmean(deltas):>+9.4f} " + f"{statistics.fmean(stbl):>10.3f} {statistics.fmean(brit):>10.3f} " + f"{statistics.fmean(noise):>11.3f} " + f"{statistics.fmean(collapse)*100:>13.1f}%") + + # Variant-aggregate (across all models, n-weighted) + print("\n" + "=" * 90) + print("ALL CELLS (18 models × 4 surface variants)") + print("=" * 90) + all_d = [c["metrics"]["token_jaccard"]["cohens_d"] for c in cells] + all_p = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in cells] + print(f" cells: {len(cells)}") + print(f" direction-positive: {sum(1 for d in all_d if d>0)}/{len(cells)}") + print(f" p<0.05: {sum(1 for p in all_p if p<0.05)}/{len(cells)}") + print(f" p<0.001: {sum(1 for p in all_p if p<0.001)}/{len(cells)}") + print(f" p<1e-6: {sum(1 for p in all_p if p<1e-6)}/{len(cells)}") + print(f" Cohen's d median: {statistics.median(all_d):+.3f}") + print(f" Cohen's d mean: {statistics.fmean(all_d):+.3f}") + print(f" Cohen's d range: [{min(all_d):+.2f}, {max(all_d):+.2f}]") + + # Per-model aggregate (averaged across 4 surface variants) + per_model = defaultdict(list) + for c in cells: + per_model[c["model"]].append(c) + print("\n" + "=" * 90) + print("PER MODEL (averaged across 4 surface variants)") + print("=" * 90) + print(f"\n{'Model':<25} {'mean-d':>7} {'mean-stbl':>10} {'mean-brit':>10} " + f"{'mean-coll%':>11} {'min-p':>9}") + print("-" * 80) + rows = [] + for m, cs in per_model.items(): + if len(cs) == 0: continue + d = statistics.fmean(c["metrics"]["token_jaccard"]["cohens_d"] for c in cs) + s = statistics.fmean(c["metrics"]["token_jaccard"]["stable_median"] for c in cs) + b = statistics.fmean(c["metrics"]["token_jaccard"]["brittle_median"] for c in cs) + col = statistics.fmean(c["brittle_collapse_rate"] for c in cs) * 100 + mp = min(c["metrics"]["token_jaccard"]["p_two_sided"] for c in cs) + rows.append((m, d, s, b, col, mp)) + for r in sorted(rows, key=lambda r: -r[1]): + print(f"{r[0]:<25} {r[1]:>+7.2f} {r[2]:>10.3f} {r[3]:>10.3f} {r[4]:>10.1f}% {r[5]:>9.1e}") + + +if __name__ == "__main__": + main() diff --git a/analysis/balance_diff.py b/analysis/balance_diff.py new file mode 100644 index 0000000..f420d46 --- /dev/null +++ b/analysis/balance_diff.py @@ -0,0 +1,109 @@ +"""Compare brace/paren/bracket balance BEFORE vs AFTER cleaning to check +whether the cleaner introduced any new imbalance.""" +from __future__ import annotations +import json +import tarfile +from pathlib import Path +from collections import Counter + +CURRENT_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset") +BACKUP_TAR = sorted(Path("/home/yurenh2/gap/analysis/dataset_backups").glob( + "putnam-bench-anon_dataset_*.tar.gz"))[-1] + + +def all_text(d: dict) -> str: + out = [] + for k in ("question", "solution"): + out.append(d.get(k) or "") + for vk, vd in (d.get("variants") or {}).items(): + if isinstance(vd, dict): + for k in ("question", "solution"): + out.append(vd.get(k) or "") + return "\n".join(out) + + +def balance(text: str): + return ( + text.count("{") - text.count("}"), + text.count("(") - text.count(")"), + text.count("[") - text.count("]"), + ) + + +def main(): + print("Loading backup ...") + backup = {} + with tarfile.open(BACKUP_TAR, "r:gz") as tar: + for member in tar.getmembers(): + if not member.isfile() or not member.name.endswith(".json"): + continue + f = tar.extractfile(member) + if not f: + continue + d = json.load(f) + backup[d.get("index")] = all_text(d) + print(f" loaded {len(backup)} backup problems") + + print("Loading current ...") + current = {} + for f in sorted(CURRENT_DIR.glob("*.json")): + d = json.load(open(f)) + current[d.get("index")] = all_text(d) + print(f" loaded {len(current)} current problems") + + # Per-file balance diff + introduced_imbalance = [] + fixed_imbalance = [] + same_imbalance = 0 + same_balanced = 0 + + n_brace_changed = 0 + n_paren_changed = 0 + n_brack_changed = 0 + + for idx in sorted(backup): + b_before = balance(backup[idx]) + b_after = balance(current.get(idx, "")) + was_bal = b_before == (0, 0, 0) + is_bal = b_after == (0, 0, 0) + if b_before != b_after: + if was_bal and not is_bal: + introduced_imbalance.append((idx, b_before, b_after)) + elif not was_bal and is_bal: + fixed_imbalance.append((idx, b_before, b_after)) + else: + if is_bal: + same_balanced += 1 + else: + same_imbalance += 1 + if b_before[0] != b_after[0]: n_brace_changed += 1 + if b_before[1] != b_after[1]: n_paren_changed += 1 + if b_before[2] != b_after[2]: n_brack_changed += 1 + + print(f"\n=== Per-file balance change summary ===") + print(f" Files with no change in any balance:") + print(f" balanced both before and after: {same_balanced}") + print(f" imbalanced before and after (same imbalance): {same_imbalance}") + print(f" Files where cleaner INTRODUCED new imbalance: " + f"{len(introduced_imbalance)}") + print(f" Files where cleaner FIXED prior imbalance: {len(fixed_imbalance)}") + print() + print(f" Files where {{ balance changed: {n_brace_changed}") + print(f" Files where ( balance changed: {n_paren_changed}") + print(f" Files where [ balance changed: {n_brack_changed}") + + if introduced_imbalance: + print(f"\n!!! Cleaner-introduced imbalances ({len(introduced_imbalance)}):") + for idx, before, after in introduced_imbalance[:10]: + print(f" {idx}: before={before}, after={after}") + else: + print("\n ✓ No cleaner-introduced imbalances found.") + + if fixed_imbalance: + print(f"\n Cleaner-fixed imbalances (top 10):") + for idx, before, after in fixed_imbalance[:10]: + print(f" {idx}: before={before}, after={after}") + + +if __name__ == "__main__": + main() diff --git a/analysis/cross_model_agreement.py b/analysis/cross_model_agreement.py new file mode 100644 index 0000000..fb9a571 --- /dev/null +++ b/analysis/cross_model_agreement.py @@ -0,0 +1,180 @@ +"""Cross-model agreement analysis: which problems are universally hard? + +For each (variant, problem) cell, count how many models fail (correct=False). +Identify "universally hard" problems (failed by ≥80% of models on the variant) +and "universally easy" (correct by ≥80% on the variant). Then check whether +the universally hard *flip set* is dominated by certain topics, problem types, +or years. + +Outputs: +- Per-variant histogram of failure counts +- "Universal flip" cases: original correct by ≥80% of models, variant wrong by ≥80% +- These are the cleanest signals of variant-induced fragility because they + rule out problem-specific quirks +""" +from __future__ import annotations +import json +import sys +from pathlib import Path +from collections import defaultdict, Counter + +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS + +DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset") + + +def load_metadata(): + """Load problem-level metadata: type, tag, difficulty, year.""" + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + out[idx] = { + "type": d.get("type"), + "tag": d.get("tag"), + "difficulty": d.get("difficulty"), + "problem_type": d.get("problem_type"), + "year": int(idx.split("-")[0]) if idx and "-" in idx else None, + } + return out + + +def main(): + base = RESULTS_DIR + models = sorted([d.name for d in base.iterdir() if d.is_dir()]) + print(f"Loading {len(models)} models ...") + metadata = load_metadata() + + # correct_table[(variant, idx)][model] = bool + correct_table = defaultdict(dict) + for m in models: + mdir = base / m + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + vp = find_variant_file(mdir, v) + if not vp: + continue + for p in load_problems(vp): + idx = p.get("index") + correct = p.get("correct") + if idx is None or correct is None: + continue + correct_table[(v, idx)][m] = correct + + print(f"Loaded {len(correct_table)} (variant, problem) cells.\n") + + # Per-variant histogram of correct counts (out of N models) + print("=== HISTOGRAM OF CORRECT-COUNT ACROSS MODELS ===") + print("(How many models get each problem right per variant)\n") + print(f"{'Variant':<24} {'mean correct/N':>16} {'median':>9} {'#unanimous-fail':>17} {'#unanimous-pass':>17}") + print("-" * 90) + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + cells = [d for (vv, idx), d in correct_table.items() if vv == v] + if not cells: + continue + counts = [sum(1 for vv in cell.values() if vv) / len(cell) for cell in cells] + unanimous_fail = sum(1 for cell in cells if not any(cell.values()) and len(cell) >= 3) + unanimous_pass = sum(1 for cell in cells if all(cell.values()) and len(cell) >= 3) + import statistics + print(f"{v:<24} {statistics.fmean(counts)*100:>14.1f}% {statistics.median(counts)*100:>7.1f}% " + f"{unanimous_fail:>17} {unanimous_pass:>17}") + + # Universal flip cases: original correct by ≥80% of models, variant wrong by ≥80% + print(f"\n\n=== UNIVERSAL FLIP CASES (orig ≥80% correct, variant ≥80% wrong) ===\n") + print("These are the cleanest signals of variant-induced fragility.\n") + print(f"{'Variant':<24} {'# universal flips':>20}") + print("-" * 50) + universal_flips = defaultdict(list) + for v in SURFACE_VARIANTS + ["kernel_variant"]: + for idx in {ii for (vv, ii) in correct_table if vv == "original"}: + orig_cell = correct_table.get(("original", idx), {}) + var_cell = correct_table.get((v, idx), {}) + common = set(orig_cell) & set(var_cell) + if len(common) < 5: continue + orig_rate = sum(1 for m in common if orig_cell[m]) / len(common) + var_rate = sum(1 for m in common if var_cell[m]) / len(common) + if orig_rate >= 0.80 and var_rate <= 0.20: + universal_flips[v].append((idx, orig_rate, var_rate)) + print(f"{v:<24} {len(universal_flips[v]):>20}") + + # Topic / problem_type / difficulty / year breakdown for universal flips + print(f"\n\n=== TOPIC BREAKDOWN OF UNIVERSAL FLIPS ===\n") + for v in SURFACE_VARIANTS + ["kernel_variant"]: + if not universal_flips[v]: continue + print(f"--- {v} ({len(universal_flips[v])} universal flips) ---") + topics = Counter() + ptypes = Counter() + difficulties = Counter() + years = Counter() + for idx, _, _ in universal_flips[v]: + md = metadata.get(idx, {}) + tag = md.get("tag") + # tag may be a list (multi-tag) or a string + if isinstance(tag, list): + for t in tag: topics[t] += 1 + elif tag: + topics[tag] += 1 + else: + topics["?"] += 1 + ptypes[md.get("problem_type") or "?"] += 1 + diff = md.get("difficulty") + if isinstance(diff, list): diff = diff[0] if diff else "?" + difficulties[diff or "?"] += 1 + year = md.get("year") + if year: + # Bin years by decade + decade = (year // 10) * 10 + years[f"{decade}s"] += 1 + print(f" topics: {dict(topics.most_common(8))}") + print(f" problem_type:{dict(ptypes)}") + print(f" difficulties:{dict(difficulties.most_common(6))}") + print(f" decades: {dict(sorted(years.items()))}") + print() + + # Save universal flips for later analysis + out = {v: [{"index": idx, "orig_rate": o, "var_rate": vr} + for (idx, o, vr) in flips] + for v, flips in universal_flips.items()} + json.dump(out, open(THIS_DIR / "universal_flips.json", "w"), indent=2) + print(f"\nSaved -> analysis/universal_flips.json") + + # Topic-stratified analysis: failure rate per topic per variant + print(f"\n\n=== ACCURACY BY TOPIC × VARIANT (mean across models) ===\n") + by_topic_variant = defaultdict(lambda: defaultdict(list)) + for (v, idx), cell in correct_table.items(): + md = metadata.get(idx, {}) + tag = md.get("tag") + if not tag or len(cell) < 5: continue + # If multiple tags, attribute the same rate to each — keeps it simple + topics_for_problem = tag if isinstance(tag, list) else [tag] + rate = sum(1 for vv in cell.values() if vv) / len(cell) + for t in topics_for_problem: + by_topic_variant[t][v].append(rate) + + topics_to_show = ["ALG", "ANA", "NT", "COMB", "GEO"] + print(f"{'Topic':<8}", end="") + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + short = {"original":"orig","descriptive_long":"DL","descriptive_long_confusing":"DLC", + "descriptive_long_misleading":"DLM","garbled_string":"GS","kernel_variant":"KV"}[v] + print(f" {short:>5}", end="") + print(" Δ_orig→KV") + print("-" * 70) + for t in topics_to_show: + if t not in by_topic_variant: continue + row = by_topic_variant[t] + if "original" not in row: continue + orig_mean = statistics.fmean(row["original"]) * 100 + print(f"{t:<8}", end="") + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + if v in row: + m = statistics.fmean(row[v]) * 100 + print(f" {m:>4.1f}%", end="") + else: + print(f" {'-':>5}", end="") + kv_mean = statistics.fmean(row.get("kernel_variant", [0])) * 100 + print(f" {kv_mean - orig_mean:+5.1f}pp") + + +if __name__ == "__main__": + main() diff --git a/analysis/kv_overlap.py b/analysis/kv_overlap.py new file mode 100644 index 0000000..137e61f --- /dev/null +++ b/analysis/kv_overlap.py @@ -0,0 +1,332 @@ +"""Kernel-variant structural-overlap analysis (label-free). + +Unlike surface variants, kernel variants change the math, so we cannot use the +model's own original-correct trajectory as a reference. Instead we use the +dataset's canonical kernel-variant solution as the reference. + +Hypothesis: stable (correct on KV) trajectories have higher structural overlap +with the canonical KV solution than brittle (wrong on KV) trajectories. + +For comparability we also recompute the surface analyses using the same +'overlap with canonical solution' metric, so we can compare apples-to-apples +the magnitude of stable-vs-brittle gap between surface and kernel. +""" +from __future__ import annotations +import json +import os +import statistics +from pathlib import Path +from collections import defaultdict +from typing import Optional + +# Reuse helpers from the sibling module +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from structural_overlap import ( + DATASET_DIR, RESULTS_DIR, + load_problems, find_variant_file, + canonicalize_text, normalize_whitespace, + tokens, bigrams, jaccard, extract_math_blocks, + metric_token_jaccard, metric_bigram_jaccard, + metric_directional_coverage, metric_equation_jaccard, + mann_whitney_u, bootstrap_ci_cohens_d, + is_collapse, COLLAPSE_MIN_CHARS, COLLAPSE_RATIO, + SURFACE_VARIANTS, +) + + +def load_dataset_variant_solutions() -> dict: + """Returns: {problem_index: {variant_name: canonical_solution_text}}. + + Includes 'original' (from top-level field) plus all 5 variants. + """ + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + cell = {"original": d.get("solution") or "", + "_problem_type": d.get("problem_type")} + for v, vd in d.get("variants", {}).items(): + if isinstance(vd, dict): + cell[v] = vd.get("solution") or "" + out[idx] = cell + return out + + +def load_dataset_maps() -> dict: + """Mirrors structural_overlap.load_dataset_maps but localized for safety.""" + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + variants = d.get("variants", {}) + cell = {} + for v in SURFACE_VARIANTS: + vd = variants.get(v, {}) + mp_str = vd.get("map") + if isinstance(mp_str, str): + try: + mp = eval(mp_str, {"__builtins__": {}}, {}) + if isinstance(mp, dict): + cell[v] = {str(k): str(v) for k, v in mp.items()} + except Exception: + pass + elif isinstance(mp_str, dict): + cell[v] = {str(k): str(v) for k, v in mp_str.items()} + out[idx] = cell + return out + + +# ---------- Cell analyzer ---------- + +def analyze_kv_cell(model_name: str, model_dir: Path, + canonical_solutions: dict) -> Optional[dict]: + """Compare model's KV trajectory to dataset canonical KV solution. + + No canonicalization (no rename map for KV — variables match by construction). + """ + orig_path = find_variant_file(model_dir, "original") + var_path = find_variant_file(model_dir, "kernel_variant") + if not orig_path or not var_path: + return None + orig_by = {p["index"]: p for p in load_problems(orig_path)} + var_by = {p["index"]: p for p in load_problems(var_path)} + + pairs_stable_drift = [] + pairs_brittle_drift = [] + n_brittle_collapse = 0 + n_stable_collapse = 0 + + for idx in set(orig_by) & set(var_by): + po, pv = orig_by[idx], var_by[idx] + if po.get("correct") is not True: + continue # Restrict to "model already gets the original" + var_correct = pv.get("correct") + if var_correct is None: + continue + var_text = (pv.get("solve") or {}).get("solution") or "" + if not var_text: + continue + canon_kv = canonical_solutions.get(idx, {}).get("kernel_variant", "") + if not canon_kv or len(canon_kv) < 200: + continue + # Collapse rule: variant text < 200 chars OR < 25% of canonical solution + collapse = (len(var_text) < COLLAPSE_MIN_CHARS or + len(var_text) < COLLAPSE_RATIO * len(canon_kv)) + sample = {"index": idx, "var_text": var_text, "canon": canon_kv} + if var_correct is True: + if collapse: + n_stable_collapse += 1 + else: + pairs_stable_drift.append(sample) + else: + if collapse: + n_brittle_collapse += 1 + else: + pairs_brittle_drift.append(sample) + + if not pairs_stable_drift or not pairs_brittle_drift: + return None + + metrics = { + "token_jaccard": metric_token_jaccard, + "bigram_jaccard": metric_bigram_jaccard, + "equation_jaccard": metric_equation_jaccard, + "directional_coverage": metric_directional_coverage, + } + + out = { + "model": model_name, + "variant": "kernel_variant", + "n_stable_drift": len(pairs_stable_drift), + "n_brittle_drift": len(pairs_brittle_drift), + "n_brittle_collapse": n_brittle_collapse, + "n_stable_collapse": n_stable_collapse, + "brittle_collapse_rate": n_brittle_collapse / + max(1, n_brittle_collapse + len(pairs_brittle_drift)), + "metrics": {}, + } + for mname, mfn in metrics.items(): + s_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_stable_drift] + b_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_brittle_drift] + U, p = mann_whitney_u(s_vals, b_vals) + sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals) + ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0 + bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0 + pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2) + / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5 + d = (sm - bm) / pooled if pooled > 0 else 0.0 + out["metrics"][mname] = { + "stable_median": statistics.median(s_vals), + "stable_mean": sm, + "brittle_median": statistics.median(b_vals), + "brittle_mean": bm, + "delta_median": statistics.median(s_vals) - statistics.median(b_vals), + "cohens_d": d, + "U": U, + "p_two_sided": p, + } + # Headline bootstrap + s_vals = [metric_token_jaccard(p["var_text"], p["canon"]) for p in pairs_stable_drift] + b_vals = [metric_token_jaccard(p["var_text"], p["canon"]) for p in pairs_brittle_drift] + d_lo, d_hi = bootstrap_ci_cohens_d(s_vals, b_vals, n_iter=400) + out["metrics"]["token_jaccard"]["cohens_d_ci"] = [d_lo, d_hi] + return out + + +# ---------- Surface re-analysis with canonical reference ---------- + +def analyze_surface_cell_against_canonical(model_name: str, variant: str, + model_dir: Path, + canonical_solutions: dict) -> Optional[dict]: + """Compare model variant trajectory to dataset canonical variant solution. + + For comparability with KV. No rename canonicalization needed since both + sides use the same variant naming. + """ + var_path = find_variant_file(model_dir, variant) + orig_path = find_variant_file(model_dir, "original") + if not var_path or not orig_path: + return None + var_by = {p["index"]: p for p in load_problems(var_path)} + orig_by = {p["index"]: p for p in load_problems(orig_path)} + + pairs_stable, pairs_brittle = [], [] + n_brittle_collapse = 0 + for idx in set(var_by): + if idx not in orig_by: + continue + if orig_by[idx].get("correct") is not True: + continue # restrict to model-knows-original + pv = var_by[idx] + var_correct = pv.get("correct") + if var_correct is None: + continue + var_text = (pv.get("solve") or {}).get("solution") or "" + if not var_text: + continue + canon_var = canonical_solutions.get(idx, {}).get(variant, "") + if not canon_var or len(canon_var) < 200: + continue + if (len(var_text) < COLLAPSE_MIN_CHARS or + len(var_text) < COLLAPSE_RATIO * len(canon_var)): + if var_correct is False: + n_brittle_collapse += 1 + continue + sample = {"index": idx, "var_text": var_text, "canon": canon_var} + if var_correct is True: + pairs_stable.append(sample) + else: + pairs_brittle.append(sample) + + if not pairs_stable or not pairs_brittle: + return None + + metrics = { + "token_jaccard": metric_token_jaccard, + "bigram_jaccard": metric_bigram_jaccard, + "equation_jaccard": metric_equation_jaccard, + "directional_coverage": metric_directional_coverage, + } + out = { + "model": model_name, + "variant": variant, + "n_stable_drift": len(pairs_stable), + "n_brittle_drift": len(pairs_brittle), + "n_brittle_collapse": n_brittle_collapse, + "brittle_collapse_rate": n_brittle_collapse / + max(1, n_brittle_collapse + len(pairs_brittle)), + "metrics": {}, + } + for mname, mfn in metrics.items(): + s_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_stable] + b_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_brittle] + U, p = mann_whitney_u(s_vals, b_vals) + sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals) + ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0 + bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0 + pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2) + / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5 + d = (sm - bm) / pooled if pooled > 0 else 0.0 + out["metrics"][mname] = { + "stable_median": statistics.median(s_vals), + "stable_mean": sm, + "brittle_median": statistics.median(b_vals), + "brittle_mean": bm, + "delta_median": statistics.median(s_vals) - statistics.median(b_vals), + "cohens_d": d, + "U": U, + "p_two_sided": p, + } + return out + + +def main(): + print("Loading canonical solutions ...") + canon = load_dataset_variant_solutions() + print(f" loaded {len(canon)} problems") + + all_models = sorted([d.name for d in RESULTS_DIR.iterdir() if d.is_dir()]) + + kv_results = [] + surface_results = [] + + print(f"\n{'KERNEL VARIANT — variant trajectory vs canonical KV solution':<70}") + print(f"{'Cell':<32} {'nSd':>4} {'nBd':>4} {'col%':>5} " + f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}") + print("-" * 90) + for m in all_models: + mdir = RESULTS_DIR / m + if not mdir.exists(): + continue + res = analyze_kv_cell(m, mdir, canon) + if res is None: + continue + kv_results.append(res) + md = res["metrics"]["token_jaccard"] + print(f"{m+' / KV':<32} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} " + f"{res['brittle_collapse_rate']*100:>4.0f}% " + f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} " + f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}") + + print(f"\n{'SURFACE VARIANT — variant trajectory vs canonical variant solution':<70}") + print(f"{'Cell':<46} {'nSd':>4} {'nBd':>4} {'col%':>5} " + f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}") + print("-" * 95) + for m in all_models: + mdir = RESULTS_DIR / m + if not mdir.exists(): + continue + for v in SURFACE_VARIANTS: + res = analyze_surface_cell_against_canonical(m, v, mdir, canon) + if res is None: + continue + surface_results.append(res) + md = res["metrics"]["token_jaccard"] + print(f"{m+' / '+v:<46} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} " + f"{res['brittle_collapse_rate']*100:>4.0f}% " + f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} " + f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}") + + # Save + json.dump(kv_results, open("/home/yurenh2/gap/analysis/kv_overlap_results.json", "w"), indent=2) + json.dump(surface_results, open("/home/yurenh2/gap/analysis/surface_canonical_results.json", "w"), indent=2) + + # Aggregate compare + print("\n" + "=" * 80) + print("AGGREGATE: surface (vs canonical) vs kernel (vs canonical)") + print("=" * 80) + for tag, results in [("surface", surface_results), ("kernel", kv_results)]: + ds = [c["metrics"]["token_jaccard"]["cohens_d"] for c in results] + ps = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in results] + col = [c["brittle_collapse_rate"] for c in results] + if not ds: + continue + print(f"{tag:<8} cells={len(ds):>3} d_pos={sum(1 for d in ds if d>0):>3}/{len(ds):<3} " + f"p<.05={sum(1 for p in ps if p<0.05):>3}/{len(ps):<3} " + f"d_med={statistics.median(ds):+.2f} d_mean={statistics.fmean(ds):+.2f} " + f"collapse_mean={statistics.fmean(col)*100:.1f}%") + + +if __name__ == "__main__": + main() diff --git a/analysis/make_figures.py b/analysis/make_figures.py new file mode 100644 index 0000000..4ff598d --- /dev/null +++ b/analysis/make_figures.py @@ -0,0 +1,272 @@ +"""Three rebuttal figures. + +Fig1 — Structural Cohen's d heatmap + 18 models × 5 variants (4 surface + KV). + Surface cells use the self-anchor metric (model's own original under + inverse rename). KV uses the canonical-anchor metric. + +Fig2 — Rescue rebound rates by variant + condition + Pooled across 4 models. Bar plot with Wilson 95 % CI. + Three bars per variant: null / canonical_T2 / own_T2 (KV: only 2). + +Fig3 — own_T2 vs canonical_T2 per (model, variant) + Scatter plot of own_T2 rebound rate vs canonical_T2 rebound rate per + cell, with the y=x line. Points above the diagonal: own outperforms + canonical (rare); below: canonical outperforms own (typical). +""" +from __future__ import annotations +import json +import math +import statistics +from pathlib import Path +from collections import defaultdict + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +ROOT = Path("/home/yurenh2/gap/analysis") +FIG_DIR = ROOT / "figures" +FIG_DIR.mkdir(parents=True, exist_ok=True) + +VARIANT_LABELS = { + "descriptive_long": "DL", + "descriptive_long_confusing": "DLC", + "descriptive_long_misleading": "DLM", + "garbled_string": "GS", + "kernel_variant": "KV", +} +VARIANT_ORDER_SURF = ["descriptive_long", "descriptive_long_confusing", + "descriptive_long_misleading", "garbled_string"] +VARIANT_ORDER_ALL = VARIANT_ORDER_SURF + ["kernel_variant"] + +# ---------------------------------------------------------------------- +# Fig 1 — Structural Cohen's d heatmap +# ---------------------------------------------------------------------- + +def fig1_structural_d_heatmap(): + """Heatmap of Cohen's d for the stable-vs-brittle structural metric. + + Surface cells: self-anchor (token Jaccard between model's variant + trajectory and its own original-correct trajectory after canonicalization). + Source file: structural_overlap_results.json. + + KV cells: canonical-anchor (token Jaccard between model's KV trajectory and + the dataset's canonical KV solution). + Source file: kv_overlap_results.json. + """ + surf = json.load(open(ROOT / "structural_overlap_results.json")) + kv = json.load(open(ROOT / "kv_overlap_results.json")) + + # Build matrix: rows = models (sorted by mean d), cols = variants (DL, DLC, DLM, GS, KV) + by_cell = {} + for c in surf: + by_cell[(c["model"], c["variant"])] = c["metrics"]["token_jaccard"]["cohens_d"] + for c in kv: + by_cell[(c["model"], "kernel_variant")] = c["metrics"]["token_jaccard"]["cohens_d"] + + models = sorted({k[0] for k in by_cell}) + # Sort by mean d across surface variants only (so KV doesn't bias the order) + def mean_surface_d(m): + ds = [by_cell.get((m, v)) for v in VARIANT_ORDER_SURF + if by_cell.get((m, v)) is not None] + return statistics.fmean(ds) if ds else 0.0 + models.sort(key=mean_surface_d, reverse=True) + + M = np.full((len(models), len(VARIANT_ORDER_ALL)), np.nan) + for i, m in enumerate(models): + for j, v in enumerate(VARIANT_ORDER_ALL): + d = by_cell.get((m, v)) + if d is not None: + M[i, j] = d + + fig, ax = plt.subplots(figsize=(7, 9)) + vmin = 0.0 + vmax = 1.4 + cmap = plt.cm.viridis + im = ax.imshow(M, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto") + ax.set_xticks(range(len(VARIANT_ORDER_ALL))) + ax.set_xticklabels([VARIANT_LABELS[v] for v in VARIANT_ORDER_ALL]) + ax.set_yticks(range(len(models))) + ax.set_yticklabels(models, fontsize=9) + # Annotate values + for i in range(len(models)): + for j in range(len(VARIANT_ORDER_ALL)): + v = M[i, j] + if not math.isnan(v): + color = "white" if v < 0.7 else "black" + ax.text(j, i, f"{v:+.2f}", ha="center", va="center", + fontsize=8, color=color) + # Vertical line separating surface from KV + ax.axvline(x=3.5, color="white", linewidth=2) + cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + cbar.set_label("Cohen's d (stable − brittle)\non canonicalized token Jaccard", + fontsize=9) + ax.set_title("Structural overlap effect size: stable vs brittle\n" + "(surface = self-anchor; KV = canonical-anchor)", + fontsize=11) + ax.set_xlabel("Variant family", fontsize=10) + plt.tight_layout() + out = FIG_DIR / "fig1_structural_d_heatmap.png" + plt.savefig(out, dpi=200, bbox_inches="tight") + plt.close() + print(f"Saved {out}") + + +# ---------------------------------------------------------------------- +# Fig 2 — Rescue rebound rates with Wilson CI +# ---------------------------------------------------------------------- + +def wilson_ci(k: int, n: int, z: float = 1.96): + if n == 0: + return (0.0, 0.0, 0.0) + p = k / n + denom = 1 + z * z / n + center = (p + z * z / (2 * n)) / denom + half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom + return (p, max(0.0, center - half), min(1.0, center + half)) + + +def fig2_rescue_rates(): + rows = [json.loads(l) for l in open(ROOT / "rescue_results/rescue_30.jsonl")] + + counts = defaultdict(lambda: {"k": 0, "n": 0}) + for r in rows: + counts[(r["variant"], r["condition"])]["n"] += 1 + if r.get("grade") == "CORRECT": + counts[(r["variant"], r["condition"])]["k"] += 1 + + conds_full = ["null", "canonical_T2", "own_T2"] + cond_color = {"null": "#888888", "canonical_T2": "#1f77b4", "own_T2": "#d62728"} + cond_label = {"null": "null (generic scaffold)", + "canonical_T2": "canonical_T2 (item-specific, expert prose)", + "own_T2": "own_T2 (item-specific, model's own work, renamed)"} + + fig, ax = plt.subplots(figsize=(8, 5)) + n_var = len(VARIANT_ORDER_ALL) + width = 0.27 + x = np.arange(n_var) + for ci, cond in enumerate(conds_full): + ks, lows, highs, ps = [], [], [], [] + for v in VARIANT_ORDER_ALL: + d = counts.get((v, cond)) + if d is None: + ks.append(0); lows.append(0); highs.append(0); ps.append(0) + continue + p, lo, hi = wilson_ci(d["k"], d["n"]) + ps.append(p * 100) + lows.append((p - lo) * 100) + highs.append((hi - p) * 100) + ks.append(d["k"]) + offset = (ci - 1) * width + ax.bar(x + offset, ps, width=width, color=cond_color[cond], label=cond_label[cond], + yerr=[lows, highs], capsize=3, error_kw={"elinewidth": 1, "ecolor": "#444444"}) + # Annotate counts above each bar + for xi, p, k in zip(x + offset, ps, ks): + if k > 0: + ax.text(xi, p + 0.5, f"{p:.0f}%", ha="center", va="bottom", fontsize=8) + + ax.set_xticks(x) + ax.set_xticklabels([VARIANT_LABELS[v] for v in VARIANT_ORDER_ALL], fontsize=10) + ax.set_ylabel("Rebound rate (%) on flip cases", fontsize=10) + ax.set_title("Repairability rescue: rebound rate by variant and prefix condition\n" + "(pooled across 4 models, n ≈ 100–120 per cell, 95% Wilson CI)", + fontsize=11) + ax.set_ylim(0, 60) + ax.legend(loc="upper right", fontsize=8, framealpha=0.95) + ax.grid(axis="y", linestyle="--", alpha=0.4) + ax.set_axisbelow(True) + plt.tight_layout() + out = FIG_DIR / "fig2_rescue_rebound.png" + plt.savefig(out, dpi=200, bbox_inches="tight") + plt.close() + print(f"Saved {out}") + + +# ---------------------------------------------------------------------- +# Fig 3 — own_T2 vs canonical_T2 scatter +# ---------------------------------------------------------------------- + +def fig3_own_vs_canonical_scatter(): + rows = [json.loads(l) for l in open(ROOT / "rescue_results/rescue_30.jsonl")] + + counts = defaultdict(lambda: {"k": 0, "n": 0}) + for r in rows: + counts[(r["model"], r["variant"], r["condition"])]["n"] += 1 + if r.get("grade") == "CORRECT": + counts[(r["model"], r["variant"], r["condition"])]["k"] += 1 + + fig, ax = plt.subplots(figsize=(7, 7)) + + models_in_data = sorted({k[0] for k in counts}) + model_color = { + "claude-sonnet-4": "#ff7f0e", + "gemini-2.5-flash": "#2ca02c", + "gpt-4.1-mini": "#1f77b4", + "gpt-4o-mini": "#d62728", + } + var_marker = { + "descriptive_long": "o", + "descriptive_long_confusing": "s", + "descriptive_long_misleading": "^", + "garbled_string": "D", + } + + # Diagonal + ax.plot([0, 0.7], [0, 0.7], "k--", lw=1, alpha=0.5) + ax.text(0.62, 0.66, "y = x", fontsize=8, alpha=0.6) + + for m in models_in_data: + for v in VARIANT_ORDER_SURF: + own = counts.get((m, v, "own_T2")) + can = counts.get((m, v, "canonical_T2")) + if own is None or can is None or own["n"] == 0 or can["n"] == 0: + continue + x = can["k"] / can["n"] + y = own["k"] / own["n"] + ax.scatter(x, y, s=110, c=model_color.get(m, "gray"), + marker=var_marker[v], alpha=0.85, + edgecolors="black", linewidths=0.6) + + # Build legend + from matplotlib.lines import Line2D + model_handles = [Line2D([], [], marker="o", linestyle="", markersize=9, + markerfacecolor=c, markeredgecolor="black", + markeredgewidth=0.6, label=m) + for m, c in model_color.items() if m in models_in_data] + variant_handles = [Line2D([], [], marker=mk, linestyle="", markersize=9, + markerfacecolor="lightgray", markeredgecolor="black", + markeredgewidth=0.6, label=VARIANT_LABELS[v]) + for v, mk in var_marker.items()] + leg1 = ax.legend(handles=model_handles, loc="upper left", title="Model", + fontsize=8, title_fontsize=9, framealpha=0.95) + ax.add_artist(leg1) + ax.legend(handles=variant_handles, loc="lower right", title="Variant", + fontsize=8, title_fontsize=9, framealpha=0.95) + + ax.set_xlim(0, 0.7) + ax.set_ylim(0, 0.7) + ax.set_xlabel("canonical_T2 rebound rate", fontsize=10) + ax.set_ylabel("own_T2 rebound rate", fontsize=10) + ax.set_title("Per-cell rescue rates: model's own prefix vs canonical prefix\n" + "(below diagonal = canonical wins; gpt-4o-mini is the only family above)", + fontsize=11) + ax.grid(linestyle="--", alpha=0.4) + ax.set_axisbelow(True) + plt.tight_layout() + out = FIG_DIR / "fig3_own_vs_canonical_scatter.png" + plt.savefig(out, dpi=200, bbox_inches="tight") + plt.close() + print(f"Saved {out}") + + +def main(): + fig1_structural_d_heatmap() + fig2_rescue_rates() + fig3_own_vs_canonical_scatter() + print("\nAll figures written to:", FIG_DIR) + + +if __name__ == "__main__": + main() diff --git a/analysis/normalization_analysis.py b/analysis/normalization_analysis.py new file mode 100644 index 0000000..8fb4f48 --- /dev/null +++ b/analysis/normalization_analysis.py @@ -0,0 +1,189 @@ +"""Quantify spontaneous variant->canonical name normalization in own_T2 outputs. + +For each own_T2 case, check whether the model's student_solution preserves the +variant variable names from its prefix or normalizes them back to the canonical +names from the dataset's rename map. + +For each variant variable name in the rename map: +- count its occurrences in the prefix (as injected) +- count its occurrences in the model's student_solution +- count occurrences of the corresponding CANONICAL name in the student_solution + +If the model preserves variant naming: variant_name count in solution should be +proportionally similar to the prefix count. +If the model normalizes back: canonical_name count in solution should rise while +variant_name count drops. +""" +from __future__ import annotations +import json +import re +import sys +from pathlib import Path +from collections import defaultdict +import statistics + +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from rescue_runner import load_dataset_full, find_flip_cases, build_case_prompts + +PILOT_PATH = Path("/home/yurenh2/gap/analysis/rescue_results/rescue_5.jsonl") + + +def count_word(text: str, word: str) -> int: + """Whole-word count of `word` in `text`.""" + if not text or not word: + return 0 + pat = r"(? dict: + """For one own_T2 row, compute name preservation stats.""" + variant = row["variant"] + var_info = ds_cell["variants"].get(variant, {}) + rmap = var_info.get("map") or {} + if not rmap: + return {} + student = row.get("student_solution") or "" + if not student: + return {} + + # Build the prefix that the model was given + case = { + "index": row["index"], + "problem_type": row["problem_type"], + "orig_solution": "", + "orig_final_answer": "", + } + # We need the original solution from results_new to reconstruct the prefix. + # Use find_flip_cases to recover it cleanly. + cases = find_flip_cases(row["model"], variant, 100) + matched = next((c for c in cases if c["index"] == row["index"]), None) + if matched is None: + return {} + prompts = build_case_prompts(matched, variant, ds_cell) + own_prompt = prompts.get("own_T2", "") + if "PARTIAL WORK" not in own_prompt: + return {} + # Extract just the partial work text + section = own_prompt.split("PARTIAL WORK")[1].split("Provide a complete")[0] + section = section.split("(to copy verbatim")[1] if "(to copy verbatim" in section else section + section = section.split("):", 1)[1] if "):" in section else section + prefix = section.strip() + + # For each variant variable, count occurrences in prefix and in student + per_var = {} + for canon_name, var_name in rmap.items(): + if not var_name: + continue + prefix_v = count_word(prefix, var_name) + student_v = count_word(student, var_name) + student_c = count_word(student, canon_name) + # Only meaningful if the variant name actually appeared in the prefix + if prefix_v == 0: + continue + per_var[var_name] = { + "canon_name": canon_name, + "prefix_count_variant": prefix_v, + "student_count_variant": student_v, + "student_count_canonical": student_c, + # Preservation ratio: how much of the variant naming survived + # capped to 1.0 (model may use the variable many more times in + # its continuation, which inflates the count) + "preservation_ratio": min(1.0, student_v / max(1, prefix_v)), + "normalization_ratio": min(1.0, student_c / max(1, prefix_v)), + } + if not per_var: + return {} + # Aggregate per case: median preservation + pres_vals = [v["preservation_ratio"] for v in per_var.values()] + norm_vals = [v["normalization_ratio"] for v in per_var.values()] + return { + "model": row["model"], + "variant": variant, + "index": row["index"], + "grade": row.get("grade"), + "n_vars_in_prefix": len(per_var), + "median_preservation": statistics.median(pres_vals), + "median_normalization": statistics.median(norm_vals), + "mean_preservation": statistics.fmean(pres_vals), + "mean_normalization": statistics.fmean(norm_vals), + "per_var": per_var, + } + + +def main(): + print("Loading dataset ...") + ds = load_dataset_full() + print(f"Loaded {len(ds)} problems") + print(f"\nLoading pilot rows from {PILOT_PATH} ...") + rows = [json.loads(l) for l in open(PILOT_PATH)] + own_rows = [r for r in rows if r["condition"] == "own_T2"] + print(f" total rows: {len(rows)}, own_T2 rows: {len(own_rows)}") + + analyses = [] + skipped = 0 + for r in own_rows: + ds_cell = ds.get(r["index"]) + if ds_cell is None: + skipped += 1 + continue + a = analyze_one(r, ds_cell) + if a: + analyses.append(a) + else: + skipped += 1 + print(f" analyzed: {len(analyses)}, skipped: {skipped}") + + # Aggregate by variant + print("\n=== SPONTANEOUS NORMALIZATION (own_T2 condition only) ===\n") + print("Per case: median across variant variables of preservation ratio") + print("(higher = more variant naming preserved; lower = normalized back to canonical)") + print() + print(f"{'Variant':<32} {'n':>4} {'median_pres':>12} {'mean_pres':>10} " + f"{'median_norm':>12} {'mean_norm':>10}") + print("-" * 90) + by_variant = defaultdict(list) + for a in analyses: + by_variant[a["variant"]].append(a) + for v in sorted(by_variant): + cs = by_variant[v] + mp_vals = [c["median_preservation"] for c in cs] + mn_vals = [c["median_normalization"] for c in cs] + print(f"{v:<32} {len(cs):>4} " + f"{statistics.median(mp_vals):>12.3f} {statistics.fmean(mp_vals):>10.3f} " + f"{statistics.median(mn_vals):>12.3f} {statistics.fmean(mn_vals):>10.3f}") + + # Aggregate by model + print(f"\n{'Model':<22} {'n':>4} {'median_pres':>12} {'mean_pres':>10} " + f"{'median_norm':>12} {'mean_norm':>10}") + print("-" * 80) + by_model = defaultdict(list) + for a in analyses: + by_model[a["model"]].append(a) + for m in sorted(by_model): + cs = by_model[m] + mp_vals = [c["median_preservation"] for c in cs] + mn_vals = [c["median_normalization"] for c in cs] + print(f"{m:<22} {len(cs):>4} " + f"{statistics.median(mp_vals):>12.3f} {statistics.fmean(mp_vals):>10.3f} " + f"{statistics.median(mn_vals):>12.3f} {statistics.fmean(mn_vals):>10.3f}") + + # Effect of normalization on rebound: do cases that normalized more often FAIL? + print("\n=== RELATION TO REBOUND ===") + pass_pres = [a["median_preservation"] for a in analyses if a["grade"] == "CORRECT"] + fail_pres = [a["median_preservation"] for a in analyses if a["grade"] == "INCORRECT"] + print(f" median_preservation among rebound CORRECT (n={len(pass_pres)}): " + f"median={statistics.median(pass_pres):.3f} mean={statistics.fmean(pass_pres):.3f}") + print(f" median_preservation among rebound INCORRECT (n={len(fail_pres)}): " + f"median={statistics.median(fail_pres):.3f} mean={statistics.fmean(fail_pres):.3f}") + + # Save detailed results + out = Path("/home/yurenh2/gap/analysis/normalization_results.json") + json.dump([{k: v for k, v in a.items() if k != "per_var"} for a in analyses], + open(out, "w"), indent=2) + print(f"\nSaved -> {out}") + + +if __name__ == "__main__": + main() diff --git a/analysis/rescue_analyze.py b/analysis/rescue_analyze.py new file mode 100644 index 0000000..5fe97b6 --- /dev/null +++ b/analysis/rescue_analyze.py @@ -0,0 +1,161 @@ +"""Analyze full rescue results: per-cell rebound rates, Wilson CIs, McNemar.""" +from __future__ import annotations +import json +import math +import statistics +from collections import defaultdict +from pathlib import Path + +PATH = Path("/home/yurenh2/gap/analysis/rescue_results/rescue_30.jsonl") + + +def wilson_ci(k: int, n: int, z: float = 1.96) -> tuple: + if n == 0: + return (0.0, 0.0, 0.0) + p = k / n + denom = 1 + z * z / n + center = (p + z * z / (2 * n)) / denom + half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom + return (p, max(0.0, center - half), min(1.0, center + half)) + + +def mcnemar_p(b: int, c: int) -> float: + """McNemar exact-ish p (binomial two-sided). b = treat A correct, B wrong; + c = treat A wrong, B correct. Returns p value testing b == c.""" + n = b + c + if n == 0: + return 1.0 + # Two-sided binomial test on min(b,c) ~ Bin(n, 0.5) + k = min(b, c) + # cumulative + cum = 0.0 + for i in range(k + 1): + cum += math.comb(n, i) * (0.5 ** n) + p = min(1.0, 2 * cum) + return p + + +def main(): + rows = [json.loads(l) for l in open(PATH)] + print(f"Loaded {len(rows)} rows") + + # Quick sanity + from collections import Counter + print("Solve status:", Counter(r.get("solve_status") for r in rows)) + print("Grade status:", Counter(r.get("grade_status") for r in rows)) + + # Per-cell counts + counts = defaultdict(lambda: {"total": 0, "correct": 0}) + for r in rows: + if r.get("grade_status") != "success" and r.get("grade") not in ("CORRECT", "INCORRECT"): + # Treat solve failures / parse failures as INCORRECT (conservative) + pass + key = (r["model"], r["variant"], r["condition"]) + counts[key]["total"] += 1 + if r.get("grade") == "CORRECT": + counts[key]["correct"] += 1 + + # Aggregated by (variant, condition) + by_var_cond = defaultdict(lambda: {"total": 0, "correct": 0}) + for (m, v, c), d in counts.items(): + by_var_cond[(v, c)]["total"] += d["total"] + by_var_cond[(v, c)]["correct"] += d["correct"] + + print("\n" + "=" * 90) + print("REBOUND RATE BY (VARIANT, CONDITION) [aggregated across 4 models]") + print("=" * 90) + print(f"{'Variant':<32} {'Condition':<14} {'k/n':>10} {'rate':>7} {'95% Wilson CI':>20}") + print("-" * 90) + variants_order = ["descriptive_long", "descriptive_long_confusing", + "descriptive_long_misleading", "garbled_string", "kernel_variant"] + conds_order = ["null", "canonical_T2", "own_T2"] + for v in variants_order: + for c in conds_order: + d = by_var_cond.get((v, c)) + if not d: + continue + p, lo, hi = wilson_ci(d["correct"], d["total"]) + print(f"{v:<32} {c:<14} {d['correct']:>4}/{d['total']:>4} " + f"{p*100:>5.1f}% [{lo*100:>5.1f}%, {hi*100:>5.1f}%]") + print() + + # Per-model aggregated by (variant, condition) + print("\n" + "=" * 90) + print("REBOUND RATE PER (MODEL, VARIANT, CONDITION)") + print("=" * 90) + models_order = sorted({k[0] for k in counts}) + print(f"{'Model':<22} {'Variant':<32} {'cond':<14} {'k/n':>10} {'rate':>7}") + for m in models_order: + for v in variants_order: + for c in conds_order: + d = counts.get((m, v, c)) + if not d: + continue + p, lo, hi = wilson_ci(d["correct"], d["total"]) + print(f" {m:<20} {v:<32} {c:<14} {d['correct']:>3}/{d['total']:>3} " + f"{p*100:>5.1f}%") + print() + + # Paired McNemar test: same case, different conditions + # Pair canonical_T2 vs null, and own_T2 vs null + print("\n" + "=" * 90) + print("PAIRED MCNEMAR TESTS") + print("=" * 90) + case_grades = defaultdict(dict) # (model, variant, index) -> {cond: grade} + for r in rows: + case_grades[(r["model"], r["variant"], r["index"])][r["condition"]] = r.get("grade") + + print("\ncanonical_T2 vs null:") + print(f" {'cell':<46} {'b (can-only)':>12} {'c (null-only)':>13} " + f"{'both-CORR':>10} {'both-INC':>10} {'McNemar p':>11}") + for m in models_order: + for v in variants_order: + b = c = both_corr = both_inc = 0 + for k, grds in case_grades.items(): + if k[0] != m or k[1] != v: continue + ca = grds.get("canonical_T2"); nu = grds.get("null") + if ca is None or nu is None: continue + if ca == "CORRECT" and nu == "INCORRECT": b += 1 + elif ca == "INCORRECT" and nu == "CORRECT": c += 1 + elif ca == "CORRECT" and nu == "CORRECT": both_corr += 1 + elif ca == "INCORRECT" and nu == "INCORRECT": both_inc += 1 + p = mcnemar_p(b, c) + print(f" {m+'/'+v:<46} {b:>12} {c:>13} {both_corr:>10} {both_inc:>10} {p:>11.3f}") + + print("\nown_T2 vs null:") + print(f" {'cell':<46} {'b (own-only)':>12} {'c (null-only)':>13} " + f"{'both-CORR':>10} {'both-INC':>10} {'McNemar p':>11}") + for m in models_order: + for v in [vv for vv in variants_order if vv != "kernel_variant"]: + b = c = both_corr = both_inc = 0 + for k, grds in case_grades.items(): + if k[0] != m or k[1] != v: continue + ow = grds.get("own_T2"); nu = grds.get("null") + if ow is None or nu is None: continue + if ow == "CORRECT" and nu == "INCORRECT": b += 1 + elif ow == "INCORRECT" and nu == "CORRECT": c += 1 + elif ow == "CORRECT" and nu == "CORRECT": both_corr += 1 + elif ow == "INCORRECT" and nu == "INCORRECT": both_inc += 1 + p = mcnemar_p(b, c) + print(f" {m+'/'+v:<46} {b:>12} {c:>13} {both_corr:>10} {both_inc:>10} {p:>11.3f}") + + print("\nown_T2 vs canonical_T2:") + print(f" {'cell':<46} {'b (own-only)':>12} {'c (can-only)':>13} " + f"{'both-CORR':>10} {'both-INC':>10} {'McNemar p':>11}") + for m in models_order: + for v in [vv for vv in variants_order if vv != "kernel_variant"]: + b = c = both_corr = both_inc = 0 + for k, grds in case_grades.items(): + if k[0] != m or k[1] != v: continue + ow = grds.get("own_T2"); ca = grds.get("canonical_T2") + if ow is None or ca is None: continue + if ow == "CORRECT" and ca == "INCORRECT": b += 1 + elif ow == "INCORRECT" and ca == "CORRECT": c += 1 + elif ow == "CORRECT" and ca == "CORRECT": both_corr += 1 + elif ow == "INCORRECT" and ca == "INCORRECT": both_inc += 1 + p = mcnemar_p(b, c) + print(f" {m+'/'+v:<46} {b:>12} {c:>13} {both_corr:>10} {both_inc:>10} {p:>11.3f}") + + +if __name__ == "__main__": + main() diff --git a/analysis/rescue_api.py b/analysis/rescue_api.py new file mode 100644 index 0000000..4641655 --- /dev/null +++ b/analysis/rescue_api.py @@ -0,0 +1,373 @@ +"""Async API caller for rescue experiment. + +Supports OpenAI, Anthropic, Google. All callers return a unified dict: + {"status": "success"|"failed", "content": str, "error": str|None} + +Concurrency is controlled per-provider via asyncio.Semaphore so we don't +saturate rate limits in any one provider. +""" +from __future__ import annotations +import asyncio +import json +import os +import random +from typing import Optional + +# ---------- Provider constants ---------- + +# Solver model -> provider mapping +SOLVER_PROVIDERS = { + "gpt-4.1-mini": "openai", + "gpt-4o-mini": "openai", + "claude-sonnet-4": "anthropic", + "gemini-2.5-flash": "google", +} + +# API model strings (the canonical IDs to send) +API_MODEL_NAMES = { + "gpt-4.1-mini": "gpt-4.1-mini", + "gpt-4o-mini": "gpt-4o-mini", + "claude-sonnet-4": "claude-sonnet-4-20250514", + "gemini-2.5-flash": "gemini-2.5-flash", +} + +GRADER_MODEL = "gpt-4o" +GRADER_PROVIDER = "openai" + +PER_PROVIDER_CONCURRENCY = { + "openai": 500, + "anthropic": 25, # 90k tok/min cap; 25 in flight keeps us comfortably under + "google": 300, +} + +DEFAULT_RETRIES = 6 +DEFAULT_BASE_TIMEOUT = 300.0 +RATE_LIMIT_BACKOFF_SECONDS = 60.0 # min sleep on rate limit hits + + +# ---------- Solver / grader prompts (consistent with paper) ---------- + +SOLVER_SYSTEM_PROMPT = """You are an expert mathematician solving competition-level problems. +Provide detailed, step-by-step solutions with clear mathematical reasoning. + +Requirements: +- Show all your work and intermediate steps +- Justify each major step of your reasoning +- Use proper mathematical notation +- Be thorough but concise +- State your final answer clearly + +Solve the problem completely and rigorously.""" + +PROOF_GRADER_SYSTEM_PROMPT = """You are an extremely strict mathematical grader evaluating competition-level PROOF problems. + +GRADING STANDARDS (BE VERY STRICT): +- Mathematical rigor: Every step must be mathematically sound and justified +- Logical flow: The reasoning must be clear, complete, and logically connected +- Correctness: All calculations, algebraic manipulations, and conclusions must be correct +- Completeness: The solution must address all parts of the problem fully +- Precision: Mathematical statements must be precise and unambiguous + +FAILING CRITERIA (Mark as INCORRECT if ANY of these apply): +- Any unjustified logical leap or gap in reasoning +- Any computational error, no matter how small +- Missing steps in critical parts of the argument +- Imprecise or ambiguous mathematical statements +- Incorrect final answer, even if approach is partially correct +- Circular reasoning or logical fallacies +- Misuse of mathematical theorems or definitions + +BE EXTREMELY STRICT. Competition mathematics proofs require perfect precision.""" + +CALCULATION_GRADER_SYSTEM_PROMPT = """You are a mathematical grader evaluating competition-level CALCULATION problems. + +GRADING STANDARDS FOR CALCULATION PROBLEMS: +- Primary focus: Is the final answer correct? +- Secondary focus: Is the overall approach reasonable and mathematically sound? +- Computation: Allow minor computational slips if the method is correct and final answer is right + +GRADING CRITERIA: +- CORRECT: Final answer is correct AND approach is fundamentally sound +- INCORRECT: Final answer is wrong OR approach is fundamentally flawed + +For calculation problems, the final numerical answer is the most important criterion. +Minor intermediate errors are acceptable if they don't affect the final result.""" + +PROOF_GRADER_USER_TEMPLATE = """Grade this PROOF solution with extreme strictness. + +PROBLEM: +{problem_statement} + +STUDENT SOLUTION: +{solution} + +CORRECT REFERENCE SOLUTION: +{reference_solution} + +Evaluate with maximum strictness. Every logical step must be perfect. Return JSON with: +{{"grade": "CORRECT" or "INCORRECT", + "detailed_feedback": "specific detailed analysis of what is right/wrong", + "major_issues": "list of significant mathematical errors or gaps", + "final_answer_correct": true or false, + "reasoning_rigor_score": 0-10 integer (10=perfect rigor, 0=severely flawed), + "overall_assessment": "comprehensive evaluation summary"}}""" + +CALCULATION_GRADER_USER_TEMPLATE = """Grade this CALCULATION solution with focus on final answer correctness. + +PROBLEM: +{problem_statement} + +STUDENT SOLUTION: +{solution} + +CORRECT REFERENCE SOLUTION: +{reference_solution} + +Focus primarily on whether the final answer is correct. Return JSON with: +{{"grade": "CORRECT" or "INCORRECT", + "detailed_feedback": "specific detailed analysis of what is right/wrong", + "major_issues": "list of significant mathematical errors or gaps", + "final_answer_correct": true or false, + "reasoning_rigor_score": 0-10 integer (10=perfect rigor, 0=severely flawed), + "overall_assessment": "comprehensive evaluation summary"}}""" + + +# ---------- Lazy client builders ---------- + +_openai_client = None +_anthropic_client = None +_google_client = None + +def _get_openai_client(): + global _openai_client + if _openai_client is None: + from openai import AsyncOpenAI + import httpx + limits = httpx.Limits(max_connections=2000, max_keepalive_connections=1000) + timeout = httpx.Timeout(timeout=DEFAULT_BASE_TIMEOUT, connect=30.0, + read=DEFAULT_BASE_TIMEOUT, write=30.0) + _openai_client = AsyncOpenAI(http_client=httpx.AsyncClient(limits=limits, timeout=timeout)) + return _openai_client + + +def _get_anthropic_client(): + global _anthropic_client + if _anthropic_client is None: + from anthropic import AsyncAnthropic + _anthropic_client = AsyncAnthropic() + return _anthropic_client + + +def _get_google_client(): + global _google_client + if _google_client is None: + from google import genai + _google_client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"]) + return _google_client + + +# ---------- Per-provider call functions ---------- + +async def _call_openai(model: str, system: str, user: str, + temperature: float, max_tokens: int = 16000) -> dict: + client = _get_openai_client() + api_params = { + "model": model, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + "max_tokens": max_tokens, + } + # o-series models force temperature=1 and don't accept max_tokens + if any(p in model.lower() for p in ["o1", "o3", "o4"]): + api_params.pop("max_tokens", None) + api_params["temperature"] = 1.0 + else: + api_params["temperature"] = temperature + api_params["response_format"] = {"type": "json_object"} + resp = await client.chat.completions.create(**api_params) + content = resp.choices[0].message.content or "" + return {"status": "success", "content": content, "error": None} + + +async def _call_anthropic(model: str, system: str, user: str, + temperature: float, max_tokens: int = 16000) -> dict: + client = _get_anthropic_client() + resp = await client.messages.create( + model=model, + system=system, + messages=[{"role": "user", "content": user}], + temperature=temperature, + max_tokens=max_tokens, + ) + content = "" + if resp.content: + for block in resp.content: + if hasattr(block, "text"): + content += block.text + return {"status": "success", "content": content, "error": None} + + +async def _call_google(model: str, system: str, user: str, + temperature: float, max_tokens: int = 16000) -> dict: + client = _get_google_client() + from google.genai.types import GenerateContentConfig + config = GenerateContentConfig( + system_instruction=system, + temperature=temperature, + max_output_tokens=max_tokens, + response_mime_type="application/json", + ) + resp = await client.aio.models.generate_content( + model=model, contents=user, config=config, + ) + content = resp.text or "" + return {"status": "success", "content": content, "error": None} + + +# ---------- Unified caller with retries and per-provider semaphore ---------- + +_provider_sems: dict = {} + +def _sem_for(provider: str) -> asyncio.Semaphore: + if provider not in _provider_sems: + _provider_sems[provider] = asyncio.Semaphore(PER_PROVIDER_CONCURRENCY[provider]) + return _provider_sems[provider] + + +async def call_model(model_short: str, system: str, user: str, + temperature: float = 0.0, max_tokens: int = 16000, + retries: int = DEFAULT_RETRIES) -> dict: + """Call any supported model by short alias. Includes retries.""" + if model_short == GRADER_MODEL: + provider = GRADER_PROVIDER + api_model = GRADER_MODEL + else: + provider = SOLVER_PROVIDERS[model_short] + api_model = API_MODEL_NAMES[model_short] + sem = _sem_for(provider) + + async with sem: + last_err = None + for attempt in range(retries): + try: + if provider == "openai": + return await _call_openai(api_model, system, user, temperature, max_tokens) + elif provider == "anthropic": + return await _call_anthropic(api_model, system, user, temperature, max_tokens) + elif provider == "google": + return await _call_google(api_model, system, user, temperature, max_tokens) + else: + return {"status": "failed", "content": "", + "error": f"unknown provider {provider}"} + except Exception as e: + last_err = e + err_str = str(e).lower() + # Longer backoff for rate-limit-style errors so the per-minute + # window has time to refill. + if "rate_limit" in err_str or "429" in err_str or "rate limit" in err_str: + await asyncio.sleep(RATE_LIMIT_BACKOFF_SECONDS + random.random() * 10) + else: + await asyncio.sleep(min(2 ** attempt + random.random(), 30)) + return {"status": "failed", "content": "", + "error": f"{type(last_err).__name__}: {str(last_err)[:300]}"} + + +# ---------- High-level helpers ---------- + +async def solve(model_short: str, problem_user_msg: str) -> dict: + """Run the solver. The user message already contains problem + any prefix.""" + return await call_model(model_short, SOLVER_SYSTEM_PROMPT, problem_user_msg, temperature=0.0) + + +async def grade(problem_type: str, problem_statement: str, + solution: str, reference_solution: str) -> dict: + """Run the grader (gpt-4o).""" + if problem_type == "proof": + sys = PROOF_GRADER_SYSTEM_PROMPT + tmpl = PROOF_GRADER_USER_TEMPLATE + else: + sys = CALCULATION_GRADER_SYSTEM_PROMPT + tmpl = CALCULATION_GRADER_USER_TEMPLATE + user = tmpl.format(problem_statement=problem_statement, + solution=solution, + reference_solution=reference_solution) + return await call_model(GRADER_MODEL, sys, user, temperature=0.0) + + +def parse_solution(content: str) -> dict: + """Parse JSON {solution, final_answer} from model output, with tolerance.""" + if not content: + return {"solution": "", "final_answer": "", "_parse_error": "empty"} + try: + d = json.loads(content) + return {"solution": d.get("solution", ""), + "final_answer": d.get("final_answer", ""), + "_parse_error": None} + except Exception: + # Try to extract a JSON object substring + import re + m = re.search(r"\{.*\}", content, re.DOTALL) + if m: + try: + d = json.loads(m.group(0)) + return {"solution": d.get("solution", ""), + "final_answer": d.get("final_answer", ""), + "_parse_error": None} + except Exception as e: + return {"solution": content, "final_answer": "", + "_parse_error": f"json parse: {e}"} + return {"solution": content, "final_answer": "", + "_parse_error": "no JSON object found"} + + +def parse_grade(content: str) -> dict: + """Parse JSON grade output.""" + if not content: + return {"grade": "INCORRECT", "_parse_error": "empty"} + try: + d = json.loads(content) + # Normalize grade + g = (d.get("grade") or "").strip().upper() + return { + "grade": g if g in ("CORRECT", "INCORRECT") else "INCORRECT", + "final_answer_correct": d.get("final_answer_correct"), + "detailed_feedback": d.get("detailed_feedback", ""), + "_parse_error": None, + } + except Exception: + import re + m = re.search(r"\{.*\}", content, re.DOTALL) + if m: + try: + d = json.loads(m.group(0)) + g = (d.get("grade") or "").strip().upper() + return { + "grade": g if g in ("CORRECT", "INCORRECT") else "INCORRECT", + "final_answer_correct": d.get("final_answer_correct"), + "detailed_feedback": d.get("detailed_feedback", ""), + "_parse_error": None, + } + except Exception as e: + return {"grade": "INCORRECT", "_parse_error": f"json parse: {e}"} + return {"grade": "INCORRECT", "_parse_error": "no JSON object found"} + + +# ---------- Standalone health check ---------- + +async def _health_check(): + print("Running health checks ...") + msg = ('Reply with JSON {"status": "ok"} only.') + for short in ["gpt-4o-mini", "claude-sonnet-4", "gemini-2.5-flash"]: + r = await call_model(short, "You are a test. Reply only the requested JSON.", + msg, temperature=0.0, max_tokens=200, retries=2) + print(f" {short}: {r['status']} - {r['content'][:200]!r} err={r['error']}") + # Grader + r = await call_model(GRADER_MODEL, "You are a test.", msg, temperature=0.0, + max_tokens=200, retries=2) + print(f" {GRADER_MODEL} (grader): {r['status']} - {r['content'][:200]!r} err={r['error']}") + + +if __name__ == "__main__": + asyncio.run(_health_check()) diff --git a/analysis/rescue_pooled.py b/analysis/rescue_pooled.py new file mode 100644 index 0000000..cc9f782 --- /dev/null +++ b/analysis/rescue_pooled.py @@ -0,0 +1,174 @@ +"""Pooled rescue analysis for the rebuttal headline. + +Reports: +1. Per-variant pooled rebound rates with Wilson 95% CI for each condition +2. Pooled McNemar (paired) tests across all 4 models per variant +3. Pooled McNemar across all 5 surface variants for each model +4. Headline single-cell numbers +""" +from __future__ import annotations +import json +import math +import statistics +from collections import defaultdict +from pathlib import Path + +PATH = Path("/home/yurenh2/gap/analysis/rescue_results/rescue_30.jsonl") +OUT_PATH = Path("/home/yurenh2/gap/analysis/rescue_pooled_summary.json") + + +def wilson_ci(k: int, n: int, z: float = 1.96): + if n == 0: + return (0.0, 0.0, 0.0) + p = k / n + denom = 1 + z * z / n + center = (p + z * z / (2 * n)) / denom + half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom + return (p, max(0.0, center - half), min(1.0, center + half)) + + +def mcnemar_p(b: int, c: int) -> float: + n = b + c + if n == 0: + return 1.0 + k = min(b, c) + cum = sum(math.comb(n, i) * (0.5 ** n) for i in range(k + 1)) + return min(1.0, 2 * cum) + + +def main(): + rows = [json.loads(l) for l in open(PATH)] + print(f"Loaded {len(rows)} rows\n") + + # case_grades[(model, variant, index)] = {cond: grade} + case_grades = defaultdict(dict) + for r in rows: + case_grades[(r["model"], r["variant"], r["index"])][r["condition"]] = r.get("grade") + + variants_order = ["descriptive_long", "descriptive_long_confusing", + "descriptive_long_misleading", "garbled_string", "kernel_variant"] + short = {"descriptive_long":"DL","descriptive_long_confusing":"DLC", + "descriptive_long_misleading":"DLM","garbled_string":"GS","kernel_variant":"KV"} + + summary = {} + + print("=" * 92) + print("HEADLINE: Rescue rebound by variant (pooled across 4 models)") + print("=" * 92) + print(f"{'Variant':<6} {'Condition':<14} {'k/n':>10} {'rate':>7} " + f"{'95% Wilson CI':>20} {'Δ vs null':>11}") + print("-" * 80) + var_summary = {} + for v in variants_order: + # Pool counts across models + cell_counts = defaultdict(lambda: {"k": 0, "n": 0}) + for k, grds in case_grades.items(): + if k[1] != v: continue + for cond in ("null", "canonical_T2", "own_T2"): + if cond in grds: + cell_counts[cond]["n"] += 1 + if grds[cond] == "CORRECT": + cell_counts[cond]["k"] += 1 + # Wilson CIs + per_cond = {} + null_p = cell_counts["null"]["k"] / max(1, cell_counts["null"]["n"]) + for cond in ("null", "canonical_T2", "own_T2"): + if cond not in cell_counts: continue + c = cell_counts[cond] + if c["n"] == 0: continue + p, lo, hi = wilson_ci(c["k"], c["n"]) + delta = (p - null_p) * 100 if cond != "null" else 0.0 + per_cond[cond] = {"k": c["k"], "n": c["n"], "p": p, "ci": [lo, hi], "delta_pp": delta} + print(f"{short[v]:<6} {cond:<14} {c['k']:>4}/{c['n']:>4} " + f"{p*100:>5.1f}% [{lo*100:>5.1f}%, {hi*100:>5.1f}%] " + f"{'+' if delta > 0 else ('' if delta == 0 else '-')}{abs(delta):>5.1f} pp") + # Pooled McNemar (own vs null, can vs null, own vs can) + mc = {} + for a, b in [("canonical_T2", "null"), ("own_T2", "null"), + ("own_T2", "canonical_T2")]: + b_count = c_count = 0 + for k, grds in case_grades.items(): + if k[1] != v: continue + ga = grds.get(a); gb = grds.get(b) + if ga is None or gb is None: continue + if ga == "CORRECT" and gb == "INCORRECT": b_count += 1 + elif ga == "INCORRECT" and gb == "CORRECT": c_count += 1 + p = mcnemar_p(b_count, c_count) + mc[f"{a}_vs_{b}"] = {"b": b_count, "c": c_count, "p": p} + var_summary[v] = {"per_cond": per_cond, "mcnemar": mc} + print() + + summary["per_variant"] = var_summary + + # Pooled McNemar across all surface variants for canonical vs null and own vs null + print("\n" + "=" * 92) + print("POOLED McNEMAR (across all 4 surface variants × 4 models)") + print("=" * 92) + surface_vs = ["descriptive_long", "descriptive_long_confusing", + "descriptive_long_misleading", "garbled_string"] + for a, b in [("canonical_T2", "null"), ("own_T2", "null"), + ("own_T2", "canonical_T2")]: + b_count = c_count = 0 + for k, grds in case_grades.items(): + if k[1] not in surface_vs: continue + ga = grds.get(a); gb = grds.get(b) + if ga is None or gb is None: continue + if ga == "CORRECT" and gb == "INCORRECT": b_count += 1 + elif ga == "INCORRECT" and gb == "CORRECT": c_count += 1 + p = mcnemar_p(b_count, c_count) + n = b_count + c_count + odds_ratio = b_count / max(1, c_count) + print(f" {a:<14} > {b:<14} b={b_count:>4}, c={c_count:>4} " + f"OR={odds_ratio:>4.2f} McNemar p={p:.2e} (n_discordant={n})") + # KV separately + print() + for a, b in [("canonical_T2", "null")]: + b_count = c_count = 0 + for k, grds in case_grades.items(): + if k[1] != "kernel_variant": continue + ga = grds.get(a); gb = grds.get(b) + if ga is None or gb is None: continue + if ga == "CORRECT" and gb == "INCORRECT": b_count += 1 + elif ga == "INCORRECT" and gb == "CORRECT": c_count += 1 + p = mcnemar_p(b_count, c_count) + odds_ratio = b_count / max(1, c_count) + print(f" KV: {a:<14} > {b:<14} b={b_count:>4}, c={c_count:>4} " + f"OR={odds_ratio:>4.2f} McNemar p={p:.2e}") + + # Per model summary + print("\n" + "=" * 92) + print("PER MODEL (averaged across 4 surface variants)") + print("=" * 92) + print(f"{'Model':<22} {'null':>10} {'canonical_T2':>14} {'own_T2':>10} " + f"{'can-null':>10} {'own-null':>10}") + per_model = {} + for model in sorted({k[0] for k in case_grades}): + cnts = defaultdict(lambda: {"k": 0, "n": 0}) + for k, grds in case_grades.items(): + if k[0] != model: continue + if k[1] not in surface_vs: continue + for cond in ("null", "canonical_T2", "own_T2"): + if cond in grds: + cnts[cond]["n"] += 1 + if grds[cond] == "CORRECT": + cnts[cond]["k"] += 1 + nul_p = cnts["null"]["k"] / max(1, cnts["null"]["n"]) + can_p = cnts["canonical_T2"]["k"] / max(1, cnts["canonical_T2"]["n"]) + own_p = cnts["own_T2"]["k"] / max(1, cnts["own_T2"]["n"]) + per_model[model] = { + "null": {"k": cnts["null"]["k"], "n": cnts["null"]["n"], "p": nul_p}, + "canonical_T2": {"k": cnts["canonical_T2"]["k"], "n": cnts["canonical_T2"]["n"], "p": can_p}, + "own_T2": {"k": cnts["own_T2"]["k"], "n": cnts["own_T2"]["n"], "p": own_p}, + "can_minus_null_pp": (can_p - nul_p) * 100, + "own_minus_null_pp": (own_p - nul_p) * 100, + } + print(f" {model:<20} {nul_p*100:>9.1f}% {can_p*100:>13.1f}% {own_p*100:>9.1f}% " + f"{(can_p-nul_p)*100:>+9.1f}pp {(own_p-nul_p)*100:>+9.1f}pp") + summary["per_model"] = per_model + + json.dump(summary, open(OUT_PATH, "w"), indent=2) + print(f"\nSaved -> {OUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/analysis/rescue_prompts.py b/analysis/rescue_prompts.py new file mode 100644 index 0000000..8e8f65c --- /dev/null +++ b/analysis/rescue_prompts.py @@ -0,0 +1,267 @@ +"""Rescue-experiment prompt construction. + +For each (model, variant, flip-case) we build prompts under three conditions: +- own_T2: model's own original-correct trajectory truncated at first + formal equation (with leakage filter), variables auto-renamed + to variant names via the dataset's rename map +- canonical_T2: the dataset's canonical variant solution truncated at first + formal equation (no rename needed; already in variant naming) +- null: generic content-free scaffold + +Truncation rule (event-boundary): + 1. Find the FIRST display-math block ($$...$$, \\[...\\], \\begin{equation/align/...}) + 2. If none, fall back to the first line containing a substantive math relation + (>=, <=, =, <, >, ≡, ∈) that is not merely a definition (e.g., 'let x:=...') + 3. The T2 prefix INCLUDES that first formal relation + 4. Apply leakage filter BEFORE returning: stop at the earliest of: + - any line containing \\boxed + - any line containing 'therefore', 'hence', 'we conclude', 'the answer', + 'we obtain', 'thus', 'it suffices', 'we have proved', 'as a result' + - any line containing the dataset's recorded final_answer string +""" +from __future__ import annotations +import re +from typing import Optional, Dict + + +# ---------- Display-math detection ---------- + +# Order matters: try richest patterns first +_DISPLAY_MATH_PATTERNS = [ + re.compile(r"\$\$.+?\$\$", re.DOTALL), + re.compile(r"\\\[.+?\\\]", re.DOTALL), + re.compile(r"\\begin\{equation\*?\}.+?\\end\{equation\*?\}", re.DOTALL), + re.compile(r"\\begin\{align\*?\}.+?\\end\{align\*?\}", re.DOTALL), + re.compile(r"\\begin\{gather\*?\}.+?\\end\{gather\*?\}", re.DOTALL), + re.compile(r"\\begin\{eqnarray\*?\}.+?\\end\{eqnarray\*?\}", re.DOTALL), +] + + +def _first_display_math_end(text: str) -> Optional[int]: + """Return the end position of the first display-math block, or None.""" + earliest = None + for pat in _DISPLAY_MATH_PATTERNS: + m = pat.search(text) + if m: + if earliest is None or m.end() < earliest: + earliest = m.end() + return earliest + + +# Inline relation fallback: first line with a "real" relation +_INLINE_REL_RE = re.compile( + r"[A-Za-z\)\]\}\d_]\s*(?:=|<|>|\\le[q]?|\\ge[q]?|\\equiv|\\in)\s*[A-Za-z\(\[\{\d\\\-]" +) +# Definition exclusion: lines that are 'let x = ...' or 'denote ...' are setup, +# not actual derivations. We allow them in the prefix but don't stop on them. +_DEFINITION_RE = re.compile( + r"^\s*(?:let|denote|define|set|put|call|consider|introduce|let us)\b", + re.IGNORECASE +) + + +def _first_inline_relation_line_end(text: str) -> Optional[int]: + """Find the end of the first line containing a non-definition math relation. + + Returns absolute character offset (one past the newline).""" + pos = 0 + while pos < len(text): + nl = text.find("\n", pos) + line_end = nl if nl != -1 else len(text) + line = text[pos:line_end] + if _INLINE_REL_RE.search(line) and not _DEFINITION_RE.search(line): + return line_end + 1 if nl != -1 else line_end + pos = line_end + 1 + if nl == -1: + break + return None + + +# ---------- Leakage detection ---------- + +LEAKAGE_PATTERNS = [ + re.compile(r"\\boxed\b", re.IGNORECASE), + re.compile(r"\btherefore\b", re.IGNORECASE), + re.compile(r"\bhence\b", re.IGNORECASE), + re.compile(r"\bwe conclude\b", re.IGNORECASE), + re.compile(r"\bthe answer\b", re.IGNORECASE), + re.compile(r"\bwe obtain\b", re.IGNORECASE), + re.compile(r"\bthus\b", re.IGNORECASE), + re.compile(r"\bit suffices\b", re.IGNORECASE), + re.compile(r"\bwe have proved\b", re.IGNORECASE), + re.compile(r"\bwe have shown\b", re.IGNORECASE), + re.compile(r"\bas a result\b", re.IGNORECASE), + re.compile(r"\bin conclusion\b", re.IGNORECASE), + re.compile(r"\bthe final answer\b", re.IGNORECASE), + re.compile(r"\bso the answer\b", re.IGNORECASE), +] + + +def _first_leakage_pos(text: str, final_answer: Optional[str] = None) -> Optional[int]: + """Return the starting char position of the earliest leakage marker.""" + earliest = None + for pat in LEAKAGE_PATTERNS: + m = pat.search(text) + if m: + if earliest is None or m.start() < earliest: + earliest = m.start() + if final_answer: + # Final-answer leakage: only check if the answer string is non-trivial + fa = final_answer.strip() + if 8 <= len(fa) <= 200: + idx = text.find(fa) + if idx != -1: + if earliest is None or idx < earliest: + earliest = idx + return earliest + + +# ---------- T2 truncation ---------- + +MIN_PREFIX_CHARS = 50 +MAX_PREFIX_CHARS = 2400 # roughly 600 tokens + + +def truncate_T2(text: str, final_answer: Optional[str] = None) -> Optional[str]: + """Return the T2 (after-first-equation) prefix, or None if not detectable. + + T2 = up to and including the first formal equation, then capped by leakage + filter and MAX_PREFIX_CHARS. + """ + if not text: + return None + end = _first_display_math_end(text) + if end is None: + end = _first_inline_relation_line_end(text) + if end is None: + return None + prefix = text[:end] + # Apply leakage filter BEFORE the equation if a leakage marker appears earlier + leak = _first_leakage_pos(prefix, final_answer) + if leak is not None and leak < end: + prefix = text[:leak].rstrip() + # Cap length + if len(prefix) > MAX_PREFIX_CHARS: + prefix = prefix[:MAX_PREFIX_CHARS] + # Trim at last newline to avoid cutting mid-sentence + last_nl = prefix.rfind("\n") + if last_nl > MIN_PREFIX_CHARS: + prefix = prefix[:last_nl] + if len(prefix) < MIN_PREFIX_CHARS: + return None + return prefix.rstrip() + + +# ---------- Variable rename for own prefix ---------- + +def rename_own_prefix(prefix: str, rename_map: Dict[str, str]) -> str: + """Apply orig->variant rename mapping to the model's own prefix. + + Sort longest-first to avoid prefix collisions (e.g., 'al' eating 'almondtree'). + Use word-boundary regex. Pass replacement via lambda to avoid escape-sequence + interpretation when the variant name starts with '\\x', '\\g', etc. + """ + if not prefix or not rename_map: + return prefix + items = sorted(rename_map.items(), key=lambda kv: -len(kv[0])) + out = prefix + for src, dst in items: + if not src: + continue + pat = r"(? str: + return RESCUE_USER_TEMPLATE.format( + problem_statement=problem_statement, prefix=prefix) + + +def build_null_prompt(problem_statement: str) -> str: + return NULL_USER_TEMPLATE.format( + problem_statement=problem_statement, scaffold=NULL_SCAFFOLD) + + +# ---------- Smoke test ---------- + +if __name__ == "__main__": + # Quick smoke test on a real flip case + import json + import sys + sys.path.insert(0, "/home/yurenh2/gap/analysis") + from structural_overlap import find_variant_file, load_problems + + # Pick gpt-4.1-mini original on a known problem + op = find_variant_file( + __import__("pathlib").Path("/home/yurenh2/gap/results_new/gpt-4.1-mini"), + "original") + probs = {p["index"]: p for p in load_problems(op)} + sample = next(p for idx, p in probs.items() + if p.get("correct") is True and (p.get("solve") or {}).get("solution")) + text = sample["solve"]["solution"] + fa = sample["solve"].get("final_answer") + print(f"Sample index: {sample['index']}, type: {sample['problem_type']}") + print(f"Original solution length: {len(text)} chars") + print(f"Recorded final_answer: {fa[:200] if fa else None!r}") + pre = truncate_T2(text, fa) + print(f"\n--- T2 PREFIX ({len(pre or '')} chars) ---") + print(pre) + print("--- END ---") + + # Test rename: load 1987-B-2 dataset to get a sample map + ds = json.load(open("/home/yurenh2/gap/putnam-bench-anon/dataset/1987-B-2.json")) + rmap_raw = ds["variants"]["garbled_string"]["map"] + rmap = (eval(rmap_raw, {"__builtins__": {}}, {}) + if isinstance(rmap_raw, str) else rmap_raw) + print(f"\nRename map: {rmap}") + test_text = "Let n be a positive integer and let f be a continuous function. Then $f(n) = 0$." + print(f"\nOriginal: {test_text}") + print(f"Renamed: {rename_own_prefix(test_text, rmap)}") diff --git a/analysis/rescue_runner.py b/analysis/rescue_runner.py new file mode 100644 index 0000000..9c9f226 --- /dev/null +++ b/analysis/rescue_runner.py @@ -0,0 +1,341 @@ +"""End-to-end rescue experiment runner. + +For each (model, variant, flip-case): + - Build 3 prompts: own_T2, canonical_T2, null (KV: only canonical_T2 + null) + - Solve with the same model the case originally failed under + - Grade with gpt-4o using the variant problem + canonical variant solution as reference + - Save per-case results immediately to a jsonl checkpoint (resumable) + +Usage: + python rescue_runner.py --pilot # 5 cases per cell (smoke test) + python rescue_runner.py # 30 cases per cell (full run) +""" +from __future__ import annotations +import argparse +import asyncio +import json +import os +import random +import sys +import time +from pathlib import Path +from typing import Optional + +# Local imports +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from rescue_prompts import ( + truncate_T2, rename_own_prefix, + build_rescue_prompt, build_null_prompt, NULL_SCAFFOLD, +) +from rescue_api import ( + SOLVER_PROVIDERS, solve, grade, parse_solution, parse_grade, +) +from structural_overlap import ( + DATASET_DIR, RESULTS_DIR, find_variant_file, load_problems, SURFACE_VARIANTS, +) + + +# Short model name -> directory name in results_new +MODEL_RESULTS_DIRS = { + "gpt-4.1-mini": "gpt-4.1-mini", + "gpt-4o-mini": "gpt-4o-mini", + "claude-sonnet-4": "claude-sonnet-4", + "gemini-2.5-flash": "gemini_2.5_flash", # historical underscore naming +} +SELECTED_MODELS = ["gpt-4.1-mini", "gpt-4o-mini", "claude-sonnet-4", "gemini-2.5-flash"] +ALL_VARIANTS = SURFACE_VARIANTS + ["kernel_variant"] +SURFACE_CONDITIONS = ["own_T2", "canonical_T2", "null"] +KV_CONDITIONS = ["canonical_T2", "null"] + + +# ---------- Dataset loading ---------- + +def load_dataset_full() -> dict: + """Returns: {idx: {original: {...}, variants: {v: {map, question, solution}}}}. + + The dataset stores top-level question/solution and variant-keyed question/solution/map. + """ + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + cell = { + "problem_type": d.get("problem_type"), + "original_question": d.get("question") or "", + "original_solution": d.get("solution") or "", + "variants": {}, + } + for v, vd in d.get("variants", {}).items(): + if isinstance(vd, dict): + rmap = vd.get("map") + if isinstance(rmap, str): + try: + rmap = eval(rmap, {"__builtins__": {}}, {}) + except Exception: + rmap = None + cell["variants"][v] = { + "question": vd.get("question") or "", + "solution": vd.get("solution") or "", + "map": rmap if isinstance(rmap, dict) else None, + } + out[idx] = cell + return out + + +# ---------- Flip case selection ---------- + +def find_flip_cases(model: str, variant: str, max_cases: int, + seed: int = 42) -> list: + """Identify (orig_correct, var_wrong) flip cases for the cell. + + Returns list of dicts with: index, problem_type, model_orig_solution, + final_answer (recorded), variant_problem_statement (from results). + """ + mdir = RESULTS_DIR / MODEL_RESULTS_DIRS.get(model, model) + op = find_variant_file(mdir, "original") + vp = find_variant_file(mdir, variant) + if not op or not vp: + return [] + orig_by = {p["index"]: p for p in load_problems(op)} + var_by = {p["index"]: p for p in load_problems(vp)} + cases = [] + for idx in sorted(set(orig_by) & set(var_by)): + po, pv = orig_by[idx], var_by[idx] + if po.get("correct") is not True or pv.get("correct") is not False: + continue + orig_text = (po.get("solve") or {}).get("solution") or "" + if not orig_text: + continue + # Skip cases where we couldn't extract a T2 prefix from the original + fa = (po.get("solve") or {}).get("final_answer") or "" + if truncate_T2(orig_text, fa) is None: + continue + cases.append({ + "index": idx, + "problem_type": po.get("problem_type"), + "orig_solution": orig_text, + "orig_final_answer": fa, + }) + rng = random.Random(seed) + rng.shuffle(cases) + return cases[:max_cases] + + +# ---------- Prompt construction per case ---------- + +def build_case_prompts(case: dict, variant: str, ds_cell: dict) -> dict: + """Returns: {condition_name: user_message_string}.""" + var_info = ds_cell["variants"].get(variant, {}) + var_question = var_info.get("question", "") + if not var_question: + return {} + prompts = {} + is_kv = (variant == "kernel_variant") + + # canonical_T2: dataset's canonical variant solution truncated + canon_sol = var_info.get("solution", "") + if canon_sol: + canon_pre = truncate_T2(canon_sol, None) + if canon_pre: + prompts["canonical_T2"] = build_rescue_prompt(var_question, canon_pre) + + # own_T2: only for surface variants — model's own original-correct prefix renamed + if not is_kv: + rmap = var_info.get("map") or {} + own_pre = truncate_T2(case["orig_solution"], case.get("orig_final_answer")) + if own_pre and rmap: + renamed = rename_own_prefix(own_pre, rmap) + prompts["own_T2"] = build_rescue_prompt(var_question, renamed) + + # null: always available + prompts["null"] = build_null_prompt(var_question) + return prompts + + +# ---------- Per-condition runner ---------- + +async def run_one_condition(model: str, condition: str, user_msg: str, + case: dict, variant: str, ds_cell: dict) -> dict: + """Solve + grade a single condition for a single case. Returns a result dict.""" + var_info = ds_cell["variants"].get(variant, {}) + var_question = var_info.get("question", "") + canon_sol = var_info.get("solution", "") + problem_type = case["problem_type"] + t0 = time.time() + solve_resp = await solve(model, user_msg) + solve_dt = time.time() - t0 + if solve_resp["status"] != "success": + return { + "model": model, "variant": variant, "condition": condition, + "index": case["index"], "problem_type": problem_type, + "solve_status": "failed", + "solve_error": solve_resp["error"], + "solve_seconds": solve_dt, + "grade": None, + } + parsed = parse_solution(solve_resp["content"]) + if not parsed["solution"]: + return { + "model": model, "variant": variant, "condition": condition, + "index": case["index"], "problem_type": problem_type, + "solve_status": "parse_failed", + "solve_error": parsed.get("_parse_error"), + "solve_seconds": solve_dt, + "raw_solve_content": solve_resp["content"][:500], + "grade": None, + } + student_solution = parsed["solution"] + t1 = time.time() + grade_resp = await grade(problem_type, var_question, student_solution, canon_sol) + grade_dt = time.time() - t1 + if grade_resp["status"] != "success": + return { + "model": model, "variant": variant, "condition": condition, + "index": case["index"], "problem_type": problem_type, + "solve_status": "success", + "solve_seconds": solve_dt, + "grade_seconds": grade_dt, + "grade_status": "failed", + "grade_error": grade_resp["error"], + "student_solution_len": len(student_solution), + "student_final_answer": parsed["final_answer"], + "grade": None, + } + parsed_grade = parse_grade(grade_resp["content"]) + return { + "model": model, "variant": variant, "condition": condition, + "index": case["index"], "problem_type": problem_type, + "solve_status": "success", + "solve_seconds": solve_dt, + "grade_seconds": grade_dt, + "grade_status": "success", + "student_solution_len": len(student_solution), + "student_solution": student_solution, # full text for downstream analysis + "student_final_answer": parsed["final_answer"][:500], + "grade": parsed_grade["grade"], + "final_answer_correct": parsed_grade.get("final_answer_correct"), + "grade_feedback": (parsed_grade.get("detailed_feedback") or "")[:1000], + } + + +# ---------- Main run ---------- + +OUT_DIR = Path("/home/yurenh2/gap/analysis/rescue_results") +OUT_DIR.mkdir(parents=True, exist_ok=True) + + +def load_existing_keys(path: Path) -> set: + """Read jsonl checkpoint and return set of (cell_key, condition, index).""" + keys = set() + if not path.exists(): + return keys + with open(path) as f: + for line in f: + try: + d = json.loads(line) + keys.add((d["model"], d["variant"], d["condition"], d["index"])) + except Exception: + pass + return keys + + +async def run_all(num_cases_per_cell: int, dry_run: bool = False, models=None, + variants=None): + print(f"Loading dataset ...", flush=True) + ds = load_dataset_full() + print(f" loaded {len(ds)} problems", flush=True) + + out_path = OUT_DIR / f"rescue_{num_cases_per_cell}.jsonl" + existing = load_existing_keys(out_path) + print(f"Output: {out_path} (existing rows: {len(existing)})") + + models = models or SELECTED_MODELS + variants = variants or ALL_VARIANTS + + # Build the full task list + tasks_to_run = [] + cell_summary = {} + for model in models: + for variant in variants: + cases = find_flip_cases(model, variant, num_cases_per_cell) + cell_key = f"{model}/{variant}" + cell_summary[cell_key] = {"flip_cases_found": len(cases), + "added_tasks": 0} + for case in cases: + ds_cell = ds.get(case["index"]) + if ds_cell is None: + continue + prompts = build_case_prompts(case, variant, ds_cell) + for cond, user_msg in prompts.items(): + key = (model, variant, cond, case["index"]) + if key in existing: + continue + tasks_to_run.append((model, variant, cond, case, ds_cell, user_msg)) + cell_summary[cell_key]["added_tasks"] += 1 + + print(f"\nCell-level plan ({num_cases_per_cell} flip cases each):") + for k, v in sorted(cell_summary.items()): + print(f" {k:<46} found={v['flip_cases_found']:>3} new_tasks={v['added_tasks']:>4}") + total = len(tasks_to_run) + print(f"\nTotal new tasks: {total}") + if dry_run: + return + + if not tasks_to_run: + print("Nothing to do.") + return + + # Execute concurrently. Use a writer task to drain results into the jsonl. + fout = open(out_path, "a") + write_lock = asyncio.Lock() + completed = 0 + failed = 0 + started_at = time.time() + + async def run_and_write(model, variant, cond, case, ds_cell, user_msg): + nonlocal completed, failed + try: + res = await run_one_condition(model, cond, user_msg, case, variant, ds_cell) + except Exception as e: + res = { + "model": model, "variant": variant, "condition": cond, + "index": case["index"], "problem_type": case.get("problem_type"), + "solve_status": "exception", + "solve_error": f"{type(e).__name__}: {str(e)[:300]}", + "grade": None, + } + failed += 1 + async with write_lock: + fout.write(json.dumps(res) + "\n") + fout.flush() + completed += 1 + if completed % 25 == 0 or completed == total: + elapsed = time.time() - started_at + rate = completed / elapsed if elapsed > 0 else 0 + eta = (total - completed) / rate if rate > 0 else 0 + print(f" [{completed:>4}/{total}] elapsed={elapsed:>5.0f}s " + f"rate={rate:>4.1f}/s eta={eta:>5.0f}s " + f"failed_so_far={failed}", flush=True) + + awaitables = [run_and_write(*t) for t in tasks_to_run] + await asyncio.gather(*awaitables) + fout.close() + print(f"\nDone. {completed}/{total} written. Failed: {failed}.") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--pilot", action="store_true", help="run only 5 cases per cell") + ap.add_argument("--cases", type=int, default=30, help="cases per cell (full run)") + ap.add_argument("--dry-run", action="store_true", help="print plan, don't call APIs") + ap.add_argument("--models", nargs="+", default=None) + ap.add_argument("--variants", nargs="+", default=None) + args = ap.parse_args() + n = 5 if args.pilot else args.cases + asyncio.run(run_all(n, dry_run=args.dry_run, + models=args.models, variants=args.variants)) + + +if __name__ == "__main__": + main() diff --git a/analysis/sc_success_and_difficulty.py b/analysis/sc_success_and_difficulty.py new file mode 100644 index 0000000..a8b44db --- /dev/null +++ b/analysis/sc_success_and_difficulty.py @@ -0,0 +1,192 @@ +"""Two follow-up analyses (zero API): +1. Per-model self-correction success rate: P(correct | SC) vs P(correct | no SC) +2. Difficulty-stratified surface vs kernel dichotomy +""" +from __future__ import annotations +import json +import sys +import statistics +from pathlib import Path +from collections import defaultdict + +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS +from self_correction import has_self_correction + + +# ----------------- 1. SC success rate per model ----------------- + +def sc_success_rate(): + base = RESULTS_DIR + models = sorted([d.name for d in base.iterdir() if d.is_dir()]) + + print("=" * 80) + print("PER-MODEL SELF-CORRECTION SUCCESS RATE") + print("(does an SC attempt improve probability of being correct?)") + print("=" * 80) + print() + + rows = [] + for m in models: + mdir = base / m + # Aggregate over all variants + n_sc_correct = 0 + n_sc_total = 0 + n_nosc_correct = 0 + n_nosc_total = 0 + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + vp = find_variant_file(mdir, v) + if not vp: continue + for p in load_problems(vp): + text = (p.get("solve") or {}).get("solution") or "" + if not text: continue + correct = p.get("correct") + if correct is None: continue + if has_self_correction(text): + n_sc_total += 1 + if correct: n_sc_correct += 1 + else: + n_nosc_total += 1 + if correct: n_nosc_correct += 1 + if n_sc_total < 5 or n_nosc_total < 5: + continue + p_sc = n_sc_correct / n_sc_total + p_nosc = n_nosc_correct / n_nosc_total + delta = p_sc - p_nosc + # Wilson 95% CI on each rate + rows.append({ + "model": m, + "sc_n": n_sc_total, "sc_correct": n_sc_correct, "p_sc": p_sc, + "nosc_n": n_nosc_total, "nosc_correct": n_nosc_correct, "p_nosc": p_nosc, + "delta": delta, + }) + + rows.sort(key=lambda r: -r["sc_n"]) + print(f"{'Model':<22} {'#SC trials':>11} {'P(corr|SC)':>12} {'P(corr|noSC)':>13} {'Δ':>9}") + print("-" * 75) + for r in rows: + print(f"{r['model']:<22} {r['sc_n']:>11} " + f"{r['p_sc']*100:>10.1f}% {r['p_nosc']*100:>11.1f}% " + f"{r['delta']*100:>+7.1f}pp") + + json.dump(rows, open(THIS_DIR / "sc_success_per_model.json", "w"), indent=2) + return rows + + +# ----------------- 2. Difficulty stratified dichotomy ----------------- + +DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset") + +def load_difficulty_metadata(): + """Per-problem difficulty assignment using year/section/index heuristic. + + Per the paper's existing exposition, we derive Easy/Medium/Hard from the + problem index (1-2 = Easy, 3-4 = Medium, 5-6 = Hard, 7-8 = extra-hard tail) + because the dataset's `difficulty` field is heterogeneous. + """ + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + if not idx: continue + # Extract problem number from "YEAR-PART-NUM" + parts = idx.split("-") + if len(parts) != 3: continue + try: + num = int(parts[2]) + except ValueError: + continue + if num <= 2: bucket = "Easy" + elif num <= 4: bucket = "Medium" + elif num <= 6: bucket = "Hard" + else: bucket = "ExtraHard" + out[idx] = bucket + return out + + +def difficulty_stratified_dichotomy(): + print("\n\n" + "=" * 80) + print("DIFFICULTY-STRATIFIED ACCURACY (mean across 18 models)") + print("Easy/Medium/Hard buckets defined by problem index 1-2/3-4/5-6") + print("=" * 80) + print() + + diff = load_difficulty_metadata() + base = RESULTS_DIR + models = sorted([d.name for d in base.iterdir() if d.is_dir()]) + + # buckets[(model, variant, difficulty)] = (n, n_correct) + cells = defaultdict(lambda: [0, 0]) + for m in models: + mdir = base / m + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + vp = find_variant_file(mdir, v) + if not vp: continue + for p in load_problems(vp): + idx = p.get("index") + correct = p.get("correct") + if idx is None or correct is None: continue + bucket = diff.get(idx, "Unknown") + cells[(m, v, bucket)][0] += 1 + if correct: cells[(m, v, bucket)][1] += 1 + + # Aggregate per (variant, difficulty) by averaging per-model rates + print(f"{'Variant':<24} {'Easy':>8} {'Medium':>8} {'Hard':>8} {'XHard':>8}") + print("-" * 60) + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + row = {} + for bucket in ["Easy", "Medium", "Hard", "ExtraHard"]: + rates = [] + for m in models: + n, c = cells.get((m, v, bucket), [0, 0]) + if n >= 5: + rates.append(c / n) + row[bucket] = statistics.fmean(rates) * 100 if rates else None + print(f"{v:<24} " + f"{row['Easy']:>7.1f}% " if row['Easy'] is not None else f"{v:<24} {'-':>8}", + end="") + for bucket in ["Medium", "Hard", "ExtraHard"]: + print(f"{row[bucket]:>7.1f}% " if row[bucket] is not None else f"{'-':>8}", end="") + print() + + # Compute Δ_orig→KV per difficulty bucket + print(f"\n--- Δ original → KV per difficulty bucket ---") + for bucket in ["Easy", "Medium", "Hard", "ExtraHard"]: + orig_rates = [] + kv_rates = [] + for m in models: + no, co = cells.get((m, "original", bucket), [0, 0]) + nk, ck = cells.get((m, "kernel_variant", bucket), [0, 0]) + if no >= 5 and nk >= 5: + orig_rates.append(co / no) + kv_rates.append(ck / nk) + if orig_rates: + mo = statistics.fmean(orig_rates) * 100 + mk = statistics.fmean(kv_rates) * 100 + print(f" {bucket:<10} orig={mo:5.1f}% kv={mk:5.1f}% Δ={mk-mo:+.1f}pp") + + # Compute Δ_orig→GS per difficulty bucket + print(f"\n--- Δ original → GS (surface, hardest renamer) per difficulty bucket ---") + for bucket in ["Easy", "Medium", "Hard", "ExtraHard"]: + orig_rates = [] + gs_rates = [] + for m in models: + no, co = cells.get((m, "original", bucket), [0, 0]) + ng, cg = cells.get((m, "garbled_string", bucket), [0, 0]) + if no >= 5 and ng >= 5: + orig_rates.append(co / no) + gs_rates.append(cg / ng) + if orig_rates: + mo = statistics.fmean(orig_rates) * 100 + mg = statistics.fmean(gs_rates) * 100 + print(f" {bucket:<10} orig={mo:5.1f}% GS={mg:5.1f}% Δ={mg-mo:+.1f}pp") + + +def main(): + sc_success_rate() + difficulty_stratified_dichotomy() + + +if __name__ == "__main__": + main() diff --git a/analysis/self_correction.py b/analysis/self_correction.py new file mode 100644 index 0000000..5769647 --- /dev/null +++ b/analysis/self_correction.py @@ -0,0 +1,202 @@ +"""Self-correction / metacognition probe. + +Scan model trajectories for self-correction markers and compute: +1. Attempt rate (trajectory contains a self-correction marker) per (model, variant, group) +2. Whether self-correction attempt rate differs between stable / brittle-drift / rescued cases +3. Conditional success: among trajectories with a self-correction attempt, what fraction is graded CORRECT? + +Self-correction markers (case-insensitive, word-boundary): +- "wait" (e.g., "Wait, let me reconsider") +- "actually" (e.g., "Actually, I think...") +- "let me reconsider" +- "let me redo" +- "let me try again" +- "I made a mistake" +- "this is wrong" +- "on second thought" +- "correction:" +- "scratch that" +- "I was wrong" +- "let me start over" + +Uses three data sources: +A. The original 18-model results in /home/yurenh2/gap/results_new/ (stable + brittle drift + collapse) +B. The rescue trajectories in analysis/rescue_results/rescue_30.jsonl (3 conditions × 4 models × 5 variants) +""" +from __future__ import annotations +import json +import re +import os +import sys +import statistics +from pathlib import Path +from collections import defaultdict, Counter + +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS + +SC_PATTERNS = [ + re.compile(r"\bwait\b[,.]?\s+(let|actually|that|i)", re.IGNORECASE), + re.compile(r"\bactually[,.]\s", re.IGNORECASE), + re.compile(r"\blet\s+me\s+reconsider", re.IGNORECASE), + re.compile(r"\blet\s+me\s+redo", re.IGNORECASE), + re.compile(r"\blet\s+me\s+try\s+(this\s+)?again", re.IGNORECASE), + re.compile(r"\bi\s+made\s+a\s+mistake", re.IGNORECASE), + re.compile(r"\bthis\s+is\s+(wrong|incorrect)", re.IGNORECASE), + re.compile(r"\bon\s+second\s+thought", re.IGNORECASE), + re.compile(r"\bcorrection[:\s]", re.IGNORECASE), + re.compile(r"\bscratch\s+that", re.IGNORECASE), + re.compile(r"\bi\s+was\s+wrong", re.IGNORECASE), + re.compile(r"\blet\s+me\s+start\s+over", re.IGNORECASE), + re.compile(r"\bhmm[,.]\s+(actually|wait|that)", re.IGNORECASE), + re.compile(r"\bi\s+need\s+to\s+(redo|reconsider)", re.IGNORECASE), + re.compile(r"\boh\s+wait", re.IGNORECASE), +] + + +def has_self_correction(text: str) -> bool: + if not text: + return False + for pat in SC_PATTERNS: + if pat.search(text): + return True + return False + + +def count_sc_markers(text: str) -> int: + if not text: + return 0 + return sum(len(pat.findall(text)) for pat in SC_PATTERNS) + + +# ---------- Source A: 18-model original results ---------- + +def analyze_18_models(): + """Self-correction rates in original solver runs across all 18 models.""" + base = RESULTS_DIR + models = sorted([d.name for d in base.iterdir() if d.is_dir()]) + print(f"\n=== SELF-CORRECTION IN 18-MODEL ORIGINAL RUNS ===\n") + print(f"Markers used: {len(SC_PATTERNS)} regex patterns") + print(f"Definition: trajectory contains at least one match.\n") + + rows = [] + for m in models: + mdir = base / m + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + vp = find_variant_file(mdir, v) + if not vp: + continue + problems = load_problems(vp) + n_total = 0 + n_sc = 0 + n_correct_sc = 0 + n_correct_total = 0 + n_wrong_sc = 0 + n_wrong_total = 0 + for p in problems: + text = (p.get("solve") or {}).get("solution") or "" + if not text: + continue + correct = p.get("correct") + if correct is None: + continue + n_total += 1 + sc = has_self_correction(text) + if sc: n_sc += 1 + if correct is True: + n_correct_total += 1 + if sc: n_correct_sc += 1 + else: + n_wrong_total += 1 + if sc: n_wrong_sc += 1 + if n_total > 0: + rows.append({ + "model": m, "variant": v, "n": n_total, + "sc_rate": n_sc / n_total, + "n_correct": n_correct_total, + "n_correct_sc_rate": n_correct_sc / max(1, n_correct_total), + "n_wrong": n_wrong_total, + "n_wrong_sc_rate": n_wrong_sc / max(1, n_wrong_total), + }) + + # Print compact table: per (variant) average across models + print(f"{'Variant':<24} {'mean SC%':>10} {'SC%|correct':>14} {'SC%|wrong':>12} {'asym (wrong-correct)':>22}") + print("-" * 90) + by_var = defaultdict(list) + for r in rows: + by_var[r["variant"]].append(r) + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + rs = by_var.get(v, []) + if not rs: + continue + m_sc = statistics.fmean(r["sc_rate"] for r in rs) * 100 + m_sc_c = statistics.fmean(r["n_correct_sc_rate"] for r in rs) * 100 + m_sc_w = statistics.fmean(r["n_wrong_sc_rate"] for r in rs) * 100 + asym = m_sc_w - m_sc_c + print(f"{v:<24} {m_sc:>9.1f}% {m_sc_c:>13.1f}% {m_sc_w:>11.1f}% {asym:>+21.1f}pp") + + # Per-model leader board + print(f"\n{'Model':<22} {'mean SC% (all variants)':>26}") + print("-" * 50) + by_model = defaultdict(list) + for r in rows: + by_model[r["model"]].append(r["sc_rate"]) + model_avgs = sorted([(m, statistics.fmean(vs) * 100) for m, vs in by_model.items()], + key=lambda t: -t[1]) + for m, avg in model_avgs: + print(f"{m:<22} {avg:>25.1f}%") + + return rows + + +# ---------- Source B: rescue trajectories ---------- + +def analyze_rescue(): + path = THIS_DIR / "rescue_results/rescue_30.jsonl" + rows = [json.loads(l) for l in open(path)] + print(f"\n\n=== SELF-CORRECTION IN 1{{,}}529 RESCUE TRAJECTORIES ===\n") + + # Group by (model, variant, condition, grade) + counts = defaultdict(lambda: {"n": 0, "sc": 0}) + for r in rows: + text = r.get("student_solution") or "" + if not text: + continue + key = (r["model"], r["variant"], r["condition"], r.get("grade")) + counts[key]["n"] += 1 + if has_self_correction(text): + counts[key]["sc"] += 1 + + # Aggregate per (variant, condition, grade) + by_vcg = defaultdict(lambda: {"n": 0, "sc": 0}) + for k, d in counts.items(): + m, v, c, g = k + by_vcg[(v, c, g)]["n"] += d["n"] + by_vcg[(v, c, g)]["sc"] += d["sc"] + + print(f"{'Variant':<24} {'Condition':<14} {'CORRECT-SC%':>14} {'INCORRECT-SC%':>16}") + print("-" * 80) + for v in ["descriptive_long","descriptive_long_confusing","descriptive_long_misleading","garbled_string","kernel_variant"]: + for c in ["null", "canonical_T2", "own_T2"]: + cor = by_vcg.get((v, c, "CORRECT"), {"n": 0, "sc": 0}) + inc = by_vcg.get((v, c, "INCORRECT"), {"n": 0, "sc": 0}) + if cor["n"] == 0 and inc["n"] == 0: + continue + sc_c = cor["sc"] / max(1, cor["n"]) * 100 if cor["n"] else 0 + sc_i = inc["sc"] / max(1, inc["n"]) * 100 if inc["n"] else 0 + print(f"{v:<24} {c:<14} {sc_c:>11.1f}% (n={cor['n']:>3}) {sc_i:>13.1f}% (n={inc['n']:>3})") + print() + + return counts + + +def main(): + rows_18 = analyze_18_models() + json.dump(rows_18, open(THIS_DIR / "self_correction_18models.json", "w"), indent=2) + counts_rescue = analyze_rescue() + print("\nSaved -> analysis/self_correction_18models.json") + + +if __name__ == "__main__": + main() diff --git a/analysis/spotcheck_clean.py b/analysis/spotcheck_clean.py new file mode 100644 index 0000000..52ddc43 --- /dev/null +++ b/analysis/spotcheck_clean.py @@ -0,0 +1,181 @@ +"""Spot-check Unicode cleaning by side-by-side comparison. + +For a stratified sample of problems, load: + - the ORIGINAL kernel_variant.solution from the backup tarball + - the CLEANED kernel_variant.solution from the current dataset +and print them side-by-side so the user can verify that the cleaner +preserved meaning. + +Sampling strategy: + - 5 most complex (by original Unicode count) — stress test + - 3 medium complexity — typical case + - 2 surface-variant samples — to confirm rename + LaTeX preserved +""" +from __future__ import annotations +import json +import sys +import tarfile +from pathlib import Path + +CURRENT_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset") +BACKUP_TAR = sorted(Path("/home/yurenh2/gap/analysis/dataset_backups").glob( + "putnam-bench-anon_dataset_*.tar.gz"))[-1] + + +def count_unicode(text: str) -> int: + return sum(1 for c in (text or "") if ord(c) > 127) + + +def load_backup_problems(): + """Yield (idx, problem_dict) from the backup tarball.""" + with tarfile.open(BACKUP_TAR, "r:gz") as tar: + for member in tar.getmembers(): + if not member.isfile() or not member.name.endswith(".json"): + continue + f = tar.extractfile(member) + if not f: + continue + try: + d = json.load(f) + yield d.get("index"), d + except Exception: + continue + + +def main(): + print(f"Backup tar: {BACKUP_TAR}") + print("Building Unicode-count index over 1051 problems ...") + + # Index originals by Unicode count in kernel_variant.solution + by_uni_count = [] # (unicode_count, idx, solution_len) + backup_data = {} + for idx, d in load_backup_problems(): + if not idx: + continue + backup_data[idx] = d + kv_sol = (d.get("variants") or {}).get("kernel_variant", {}).get("solution", "") + uc = count_unicode(kv_sol) + by_uni_count.append((uc, idx, len(kv_sol))) + + by_uni_count.sort(reverse=True) + print(f" loaded {len(backup_data)} problems from backup") + + # Pick samples + samples = [] + samples.extend([(idx, "TOP COMPLEXITY") for _, idx, _ in by_uni_count[:5]]) + mid = len(by_uni_count) // 2 + samples.extend([(idx, "MEDIUM COMPLEXITY") + for _, idx, _ in by_uni_count[mid:mid + 3]]) + # Bottom = least Unicode but still non-zero + nonzero = [t for t in by_uni_count if t[0] > 0] + samples.extend([(idx, "LOW COMPLEXITY") + for _, idx, _ in nonzero[-2:]]) + + print(f"\nSelected {len(samples)} samples:\n") + for idx, label in samples: + print(f" {label:<20} {idx}") + + print("\n" + "=" * 80) + print("SIDE-BY-SIDE SPOT-CHECK") + print("=" * 80) + + for case_idx, (idx, label) in enumerate(samples, 1): + print(f"\n{'#' * 80}") + print(f"# CASE {case_idx}/{len(samples)}: {idx} ({label})") + print(f"{'#' * 80}") + + backup_problem = backup_data.get(idx) + current_path = CURRENT_DIR / f"{idx}.json" + if not backup_problem or not current_path.exists(): + print(f" ! missing data for {idx}") + continue + current_problem = json.load(open(current_path)) + + # Compare kernel_variant.solution by default. For LOW COMPLEXITY cases + # we also show the original `solution` field if it differs. + for field_path in [("variants", "kernel_variant", "solution")]: + orig_text = backup_problem + curr_text = current_problem + for key in field_path: + orig_text = (orig_text or {}).get(key) if isinstance(orig_text, dict) else None + curr_text = (curr_text or {}).get(key) if isinstance(curr_text, dict) else None + if not orig_text and not curr_text: + continue + orig_text = orig_text or "" + curr_text = curr_text or "" + field_label = ".".join(field_path) + uni_before = count_unicode(orig_text) + uni_after = count_unicode(curr_text) + len_before = len(orig_text) + len_after = len(curr_text) + print(f"\n--- field: {field_label} ---") + print(f" before: {len_before} chars, {uni_before} non-ASCII") + print(f" after: {len_after} chars, {uni_after} non-ASCII " + f"(Δ len {len_after - len_before:+d})") + print(f"\n >>> ORIGINAL (first 600 chars) <<<") + print(" " + orig_text[:600].replace("\n", "\n ")) + print(f"\n >>> CLEANED (first 600 chars) <<<") + print(" " + curr_text[:600].replace("\n", "\n ")) + + if uni_after > 0: + print(f" !!! WARNING: cleaned output still has {uni_after} non-ASCII chars") + + # Sanity: are LaTeX braces balanced in the cleaned text? + n_open = curr_text.count("{") + n_close = curr_text.count("}") + n_lparen = curr_text.count("(") + n_rparen = curr_text.count(")") + n_lbrack = curr_text.count("[") + n_rbrack = curr_text.count("]") + print(f" brace balance: {{ {n_open} | }} {n_close} " + f"( {n_lparen} | ) {n_rparen} " + f"[ {n_lbrack} | ] {n_rbrack}") + + # Final aggregate balance check across the entire cleaned dataset + print("\n" + "=" * 80) + print("AGGREGATE BRACE BALANCE CHECK (entire cleaned dataset)") + print("=" * 80) + total_diff_brace = 0 + total_diff_paren = 0 + total_diff_brack = 0 + files_with_brace_imbalance = 0 + files_with_paren_imbalance = 0 + files_with_brack_imbalance = 0 + for f in sorted(CURRENT_DIR.glob("*.json")): + d = json.load(open(f)) + # Concatenate all text fields + bag = [] + for k in ("question", "solution"): + bag.append(d.get(k) or "") + for vk, vd in (d.get("variants") or {}).items(): + if isinstance(vd, dict): + for k in ("question", "solution"): + bag.append(vd.get(k) or "") + all_text = "\n".join(bag) + diff_brace = all_text.count("{") - all_text.count("}") + diff_paren = all_text.count("(") - all_text.count(")") + diff_brack = all_text.count("[") - all_text.count("]") + if diff_brace != 0: + files_with_brace_imbalance += 1 + total_diff_brace += abs(diff_brace) + if diff_paren != 0: + files_with_paren_imbalance += 1 + total_diff_paren += abs(diff_paren) + if diff_brack != 0: + files_with_brack_imbalance += 1 + total_diff_brack += abs(diff_brack) + + print(f" files with unbalanced {{...}}: {files_with_brace_imbalance}/1051" + f" (total |Δ| = {total_diff_brace})") + print(f" files with unbalanced (...): {files_with_paren_imbalance}/1051" + f" (total |Δ| = {total_diff_paren})") + print(f" files with unbalanced [...]: {files_with_brack_imbalance}/1051" + f" (total |Δ| = {total_diff_brack})") + print() + print(" (Imbalance is not necessarily a bug — math text often legitimately") + print(" contains unbalanced delimiters in display formulas; this is just") + print(" an order-of-magnitude check.)") + + +if __name__ == "__main__": + main() diff --git a/analysis/structural_overlap.py b/analysis/structural_overlap.py new file mode 100644 index 0000000..284c139 --- /dev/null +++ b/analysis/structural_overlap.py @@ -0,0 +1,523 @@ +"""Stable-vs-Brittle structural overlap analysis (label-free). + +Pipeline: +1. For each (model, surface_variant) cell, load original and variant trajectories. +2. Pull the deterministic rename map from /home/yurenh2/gap/putnam-bench-anon/dataset/. +3. Canonicalize both trajectories: replace variant variables with placeholders + (via inverse rename map). Original trajectory: replace canonical variables + with the same placeholders. Both texts then live in a shared placeholder space. +4. Compute multiple non-LLM structural metrics on (orig_canonical, var_canonical): + - Token Jaccard + - Bigram Jaccard + - Equation-set Jaccard (math-block extraction) + - Prefix Jaccard (first 30% of each canonical text) +5. Stratify by group (stable vs brittle) within each (model, variant) cell. +6. Mann-Whitney U test on each metric for stable vs brittle. + +Surface variants only (rename map available). Kernel handled separately. +""" + +from __future__ import annotations +import json +import re +import os +from pathlib import Path +from collections import Counter, defaultdict +from typing import Dict, List, Tuple, Optional + +import statistics + +DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset") +RESULTS_DIR = Path("/home/yurenh2/gap/results_new") + +SURFACE_VARIANTS = ["descriptive_long", "descriptive_long_confusing", + "descriptive_long_misleading", "garbled_string"] + + +# ---------- I/O helpers ---------- + +def load_problems(path: Path) -> List[dict]: + d = json.load(open(path)) + return d.get("problems") or d.get("detailed_results") or [] + + +def find_variant_file(model_dir: Path, variant: str) -> Optional[Path]: + files = sorted(os.listdir(model_dir)) + cands = [f for f in files + if f.endswith(f"_{variant}.json") + and "regraded" not in f and "comparison" not in f + and not f.endswith(f"_{variant}2.json")] + if not cands and variant == "garbled_string": + cands = [f for f in files if f.endswith("_gs.json")] + return model_dir / cands[0] if cands else None + + +def load_dataset_maps() -> Dict[str, Dict[str, Dict[str, str]]]: + """Returns: {problem_index: {variant: {orig_var_name: variant_var_name}}}""" + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + variants = d.get("variants", {}) + cell = {} + for v in SURFACE_VARIANTS: + vd = variants.get(v, {}) + mp_str = vd.get("map") + if isinstance(mp_str, str): + # The map is stored as a Python repr string; eval it safely + try: + mp = eval(mp_str, {"__builtins__": {}}, {}) + if isinstance(mp, dict): + cell[v] = {str(k): str(v) for k, v in mp.items()} + except Exception: + pass + elif isinstance(mp_str, dict): + cell[v] = {str(k): str(v) for k, v in mp_str.items()} + out[idx] = cell + return out + + +# ---------- Canonicalization ---------- + +def canonicalize_text(text: str, var_to_placeholder: Dict[str, str]) -> str: + """Replace each variable name in text with its canonical placeholder. + + Sort by length desc to avoid prefix collisions (e.g., 'xs' before 'x'). + Use word-boundary regex for ASCII-identifier-like names; literal replace + for non-identifier names (like garbled strings, which are also alpha). + """ + if not text: + return "" + # Sort longest-first to avoid 'al' eating into 'almondtree' + items = sorted(var_to_placeholder.items(), key=lambda kv: -len(kv[0])) + out = text + for var, ph in items: + if not var: + continue + # Use word-boundary so we only replace whole tokens. Variables in this + # dataset are all alphanumeric. + pat = r"(? str: + return re.sub(r"\s+", " ", text).strip() + + +# ---------- Tokenization ---------- + +_TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*|\d+|[^\sA-Za-z0-9_]") + +def tokens(text: str) -> List[str]: + return _TOKEN_RE.findall(text or "") + + +def bigrams(toks: List[str]) -> List[str]: + return [f"{toks[i]} {toks[i+1]}" for i in range(len(toks) - 1)] + + +# ---------- Math block extraction ---------- + +_MATH_BLOCKS = [ + re.compile(r"\$\$(.+?)\$\$", re.DOTALL), + re.compile(r"\\\[(.+?)\\\]", re.DOTALL), + re.compile(r"\$(.+?)\$", re.DOTALL), + re.compile(r"\\begin\{(?:equation|align|gather)\*?\}(.+?)\\end\{(?:equation|align|gather)\*?\}", re.DOTALL), +] + +def extract_math_blocks(text: str, min_len: int = 8) -> List[str]: + found = [] + for pat in _MATH_BLOCKS: + found.extend(pat.findall(text or "")) + # Lightweight normalization: collapse whitespace, strip + out = [normalize_whitespace(b) for b in found if b.strip()] + # Filter trivial fragments like '$n$', '$0$', '$x$' that saturate Jaccard + return [b for b in out if len(b) >= min_len] + + +# ---------- Metrics ---------- + +def jaccard(a: set, b: set) -> float: + if not a and not b: + return 1.0 + return len(a & b) / max(1, len(a | b)) + + +def metric_token_jaccard(a: str, b: str) -> float: + return jaccard(set(tokens(a)), set(tokens(b))) + + +def metric_bigram_jaccard(a: str, b: str) -> float: + return jaccard(set(bigrams(tokens(a))), set(bigrams(tokens(b)))) + + +def metric_prefix_token_jaccard(a: str, b: str, frac: float = 0.3) -> float: + """Jaccard over the first frac of tokens from each side.""" + ta, tb = tokens(a), tokens(b) + na, nb = max(1, int(len(ta) * frac)), max(1, int(len(tb) * frac)) + return jaccard(set(ta[:na]), set(tb[:nb])) + + +def metric_prefix_bigram_jaccard(a: str, b: str, frac: float = 0.3) -> float: + ta, tb = tokens(a), tokens(b) + na, nb = max(1, int(len(ta) * frac)), max(1, int(len(tb) * frac)) + return jaccard(set(bigrams(ta[:na])), set(bigrams(tb[:nb]))) + + +def metric_equation_jaccard(a: str, b: str) -> float: + ea = set(extract_math_blocks(a)) + eb = set(extract_math_blocks(b)) + return jaccard(ea, eb) + + +def metric_lcp_tokens(a: str, b: str) -> int: + """Length of the longest common prefix of canonicalized token streams. + + Directly tests Codex's thesis 'early loss of structural overlap with the + model's own original reasoning under renaming'. Larger LCP -> the model + started its variant trajectory the same way it started the original. + """ + ta, tb = tokens(a), tokens(b) + n = min(len(ta), len(tb)) + i = 0 + while i < n and ta[i] == tb[i]: + i += 1 + return i + + +def metric_lcp_normalized(a: str, b: str) -> float: + """LCP length normalized by the shorter trajectory length, in [0, 1].""" + ta, tb = tokens(a), tokens(b) + n = min(len(ta), len(tb)) + if n == 0: + return 0.0 + return metric_lcp_tokens(a, b) / n + + +def metric_lcp_first1k(a: str, b: str) -> float: + """LCP length capped to first-1000-token comparison, normalized to [0, 1].""" + ta, tb = tokens(a), tokens(b) + ta, tb = ta[:1000], tb[:1000] + n = min(len(ta), len(tb)) + if n == 0: + return 0.0 + i = 0 + while i < n and ta[i] == tb[i]: + i += 1 + return i / n + + +def metric_directional_coverage(a: str, b: str) -> float: + """|tokens_a ∩ tokens_b| / |tokens_a|. Length-asymmetric. + + Reads as: 'what fraction of the original's vocabulary survives in the variant?' + More robust to length differences than symmetric Jaccard. + """ + ta = set(tokens(a)) + tb = set(tokens(b)) + if not ta: + return 0.0 + return len(ta & tb) / len(ta) + + +def metric_window_token_jaccard(a: str, b: str, window: int = 600) -> float: + """Jaccard restricted to the first `window` tokens on each side.""" + ta = tokens(a)[:window] + tb = tokens(b)[:window] + return jaccard(set(ta), set(tb)) + + +def metric_window_bigram_jaccard(a: str, b: str, window: int = 600) -> float: + ta = tokens(a)[:window] + tb = tokens(b)[:window] + return jaccard(set(bigrams(ta)), set(bigrams(tb))) + + +# ---------- Stat helpers ---------- + +def bootstrap_ci_delta_median(xs: List[float], ys: List[float], + n_iter: int = 1000, seed: int = 0) -> Tuple[float, float]: + """Percentile bootstrap 95% CI on median(xs) - median(ys).""" + import random + rng = random.Random(seed) + if not xs or not ys: + return float("nan"), float("nan") + ds = [] + for _ in range(n_iter): + rs = [xs[rng.randrange(len(xs))] for _ in range(len(xs))] + rb = [ys[rng.randrange(len(ys))] for _ in range(len(ys))] + ds.append(statistics.median(rs) - statistics.median(rb)) + ds.sort() + lo = ds[int(0.025 * n_iter)] + hi = ds[int(0.975 * n_iter)] + return lo, hi + + +def bootstrap_ci_cohens_d(xs: List[float], ys: List[float], + n_iter: int = 1000, seed: int = 0) -> Tuple[float, float]: + import random + rng = random.Random(seed) + if len(xs) < 2 or len(ys) < 2: + return float("nan"), float("nan") + ds = [] + for _ in range(n_iter): + rs = [xs[rng.randrange(len(xs))] for _ in range(len(xs))] + rb = [ys[rng.randrange(len(ys))] for _ in range(len(ys))] + sm, bm = statistics.fmean(rs), statistics.fmean(rb) + ssd = statistics.pstdev(rs) + bsd = statistics.pstdev(rb) + pooled = (((len(rs)-1)*ssd**2 + (len(rb)-1)*bsd**2) + / max(1, len(rs)+len(rb)-2)) ** 0.5 + if pooled > 0: + ds.append((sm - bm) / pooled) + if not ds: + return float("nan"), float("nan") + ds.sort() + lo = ds[int(0.025 * len(ds))] + hi = ds[int(0.975 * len(ds))] + return lo, hi + + +def mann_whitney_u(xs: List[float], ys: List[float]) -> Tuple[float, float]: + """Returns (U, normal_approx_p_two_sided). Pure-python, no scipy. + + Used only as a screening signal — for the rebuttal we'll use scipy if + available; this is a fallback so we don't add a dependency. + """ + n1, n2 = len(xs), len(ys) + if n1 == 0 or n2 == 0: + return float("nan"), float("nan") + combined = [(v, 0) for v in xs] + [(v, 1) for v in ys] + combined.sort(key=lambda t: t[0]) + # Average ranks for ties + ranks = [0.0] * len(combined) + i = 0 + while i < len(combined): + j = i + while j + 1 < len(combined) and combined[j + 1][0] == combined[i][0]: + j += 1 + avg = (i + j) / 2.0 + 1 # 1-indexed + for k in range(i, j + 1): + ranks[k] = avg + i = j + 1 + R1 = sum(ranks[k] for k in range(len(combined)) if combined[k][1] == 0) + U1 = R1 - n1 * (n1 + 1) / 2.0 + U2 = n1 * n2 - U1 + U = min(U1, U2) + # Normal approx (no tie correction) + mu = n1 * n2 / 2.0 + sd = (n1 * n2 * (n1 + n2 + 1) / 12.0) ** 0.5 + if sd == 0: + return U, float("nan") + z = (U - mu) / sd + # Two-sided p via erf approx + import math + p = math.erfc(abs(z) / math.sqrt(2)) + return U, p + + +# ---------- Cell analysis ---------- + +COLLAPSE_MIN_CHARS = 200 +COLLAPSE_RATIO = 0.25 # variant_len < ratio * orig_len => collapse + + +def is_collapse(orig_text: str, var_text: str) -> bool: + return (len(var_text) < COLLAPSE_MIN_CHARS + or len(var_text) < COLLAPSE_RATIO * max(1, len(orig_text))) + + +def analyze_cell(model_name: str, variant: str, dataset_maps: dict, + model_dir: Path) -> Optional[dict]: + orig_path = find_variant_file(model_dir, "original") + var_path = find_variant_file(model_dir, variant) + if not orig_path or not var_path: + return None + + orig_by = {p["index"]: p for p in load_problems(orig_path)} + var_by = {p["index"]: p for p in load_problems(var_path)} + + common = set(orig_by) & set(var_by) + pairs_stable_drift = [] # (orig_canon, var_canon, problem_type) — non-collapse + pairs_brittle_drift = [] # non-collapse brittle + pairs_brittle_collapse = [] # short variant text + n_stable_collapse = 0 # almost always 0 but tracked for completeness + + for idx in common: + po, pv = orig_by[idx], var_by[idx] + if po.get("correct") is not True: + continue + var_correct = pv.get("correct") + if var_correct is None: + continue + orig_text = (po.get("solve") or {}).get("solution") or "" + var_text = (pv.get("solve") or {}).get("solution") or "" + if not orig_text or not var_text: + continue + rmap = dataset_maps.get(idx, {}).get(variant) + if not rmap: + continue + # Canonicalize + canon_to_ph = {k: f"__V{i}__" for i, k in enumerate(rmap.keys())} + var_to_ph = {rmap[k]: canon_to_ph[k] for k in rmap} + orig_canon = canonicalize_text(orig_text, canon_to_ph) + var_canon = canonicalize_text(var_text, var_to_ph) + sample = { + "index": idx, + "problem_type": po.get("problem_type"), + "orig_canon": orig_canon, + "var_canon": var_canon, + "orig_len": len(orig_text), + "var_len": len(var_text), + } + collapse = is_collapse(orig_text, var_text) + if var_correct is True: + if collapse: + n_stable_collapse += 1 + else: + pairs_stable_drift.append(sample) + else: + if collapse: + pairs_brittle_collapse.append(sample) + else: + pairs_brittle_drift.append(sample) + + if not pairs_stable_drift or not pairs_brittle_drift: + return None + + metrics = { + "token_jaccard": metric_token_jaccard, + "bigram_jaccard": metric_bigram_jaccard, + "directional_coverage": metric_directional_coverage, + "window_token_jaccard": metric_window_token_jaccard, + "window_bigram_jaccard": metric_window_bigram_jaccard, + "equation_jaccard": metric_equation_jaccard, + } + # Headline metric for bootstrap + noise floor (the others stay descriptive) + HEADLINE = "token_jaccard" + + # Pre-tokenize once per pair to amortize cost (used by token/bigram/window metrics). + for p in pairs_stable_drift + pairs_brittle_drift: + p["_otok"] = tokens(p["orig_canon"]) + p["_vtok"] = tokens(p["var_canon"]) + p["_oset"] = set(p["_otok"]) + p["_vset"] = set(p["_vtok"]) + + def fast_token_jaccard(p): + a, b = p["_oset"], p["_vset"] + if not a and not b: + return 1.0 + return len(a & b) / max(1, len(a | b)) + + def fast_token_jaccard_pair(pa, pb): + a, b = pa["_oset"], pb["_vset"] + if not a and not b: + return 1.0 + return len(a & b) / max(1, len(a | b)) + + out = { + "model": model_name, + "variant": variant, + "n_stable_drift": len(pairs_stable_drift), + "n_brittle_drift": len(pairs_brittle_drift), + "n_stable_collapse": n_stable_collapse, + "n_brittle_collapse": len(pairs_brittle_collapse), + "brittle_collapse_rate": (len(pairs_brittle_collapse) + / max(1, len(pairs_brittle_collapse) + len(pairs_brittle_drift))), + "metrics": {}, + } + # Compute all descriptive metrics (one pass per pair, no bootstrap) + for mname, mfn in metrics.items(): + s_vals = [mfn(p["orig_canon"], p["var_canon"]) for p in pairs_stable_drift] + b_vals = [mfn(p["orig_canon"], p["var_canon"]) for p in pairs_brittle_drift] + U, p = mann_whitney_u(s_vals, b_vals) + sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals) + ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0 + bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0 + pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2) + / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5 + d = (sm - bm) / pooled if pooled > 0 else 0.0 + out["metrics"][mname] = { + "stable_median": statistics.median(s_vals), + "stable_mean": sm, + "brittle_median": statistics.median(b_vals), + "brittle_mean": bm, + "delta_median": statistics.median(s_vals) - statistics.median(b_vals), + "delta_mean": sm - bm, + "cohens_d": d, + "U": U, + "p_two_sided": p, + } + + # Bootstrap + noise floor only on headline metric + s_vals = [fast_token_jaccard(p) for p in pairs_stable_drift] + b_vals = [fast_token_jaccard(p) for p in pairs_brittle_drift] + ci_lo, ci_hi = bootstrap_ci_delta_median(s_vals, b_vals, n_iter=400) + d_lo, d_hi = bootstrap_ci_cohens_d(s_vals, b_vals, n_iter=400) + out["metrics"][HEADLINE]["delta_median_ci"] = [ci_lo, ci_hi] + out["metrics"][HEADLINE]["cohens_d_ci"] = [d_lo, d_hi] + + # Random-pairing noise floor for headline: pair stable orig with random other-problem variant + import random as _r + rng = _r.Random(42) + nf_vals = [] + n = len(pairs_stable_drift) + if n >= 2: + for _ in range(min(400, n * (n - 1))): + i = rng.randrange(n) + j = rng.randrange(n) + while j == i: + j = rng.randrange(n) + nf_vals.append(fast_token_jaccard_pair(pairs_stable_drift[i], + pairs_stable_drift[j])) + out["metrics"][HEADLINE]["noise_floor_median"] = ( + statistics.median(nf_vals) if nf_vals else None) + out["metrics"][HEADLINE]["noise_floor_mean"] = ( + statistics.fmean(nf_vals) if nf_vals else None) + out["metrics"][HEADLINE]["noise_floor_n"] = len(nf_vals) + return out + + +def main(): + print("Loading dataset rename maps ...", flush=True) + dataset_maps = load_dataset_maps() + print(f" loaded {len(dataset_maps)} problems", flush=True) + + # Multi-cell sweep across all models × 4 surface variants + # Run all 18 models — non-LLM, fast. + all_models = sorted([d.name for d in RESULTS_DIR.iterdir() if d.is_dir()]) + print(f"Models: {all_models}") + all_results = [] + + print(f"\n{'Cell':<46} {'nSd':>4} {'nBd':>4} {'col%':>5} " + f"{'sMed':>6} {'bMed':>6} {'nfMed':>6} " + f"{'d':>6} {'d95CI':>14} {'p':>9}") + print("-" * 122) + + for m in all_models: + for v in SURFACE_VARIANTS: + mdir = RESULTS_DIR / m + if not mdir.exists(): + continue + res = analyze_cell(m, v, dataset_maps, mdir) + if res is None: + continue + all_results.append(res) + md = res["metrics"]["token_jaccard"] + label = f"{m} / {v}" + ci_lo, ci_hi = md["cohens_d_ci"] + ci_str = f"[{ci_lo:+.2f}, {ci_hi:+.2f}]" + print(f"{label:<46} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} " + f"{res['brittle_collapse_rate']*100:>4.0f}% " + f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} " + f"{md['noise_floor_median']:>6.3f} " + f"{md['cohens_d']:>+6.2f} {ci_str:>14} {md['p_two_sided']:>9.1e}") + + out_path = Path("/home/yurenh2/gap/analysis/structural_overlap_results.json") + json.dump(all_results, open(out_path, "w"), indent=2) + print(f"\nSaved -> {out_path} ({len(all_results)} cells)") + + +if __name__ == "__main__": + main() diff --git a/analysis/topic_problemtype_interaction.py b/analysis/topic_problemtype_interaction.py new file mode 100644 index 0000000..405b33a --- /dev/null +++ b/analysis/topic_problemtype_interaction.py @@ -0,0 +1,112 @@ +"""KV fragility broken down by Topic × Problem-type (proof vs calculation).""" +from __future__ import annotations +import json +import sys +import statistics +from pathlib import Path +from collections import defaultdict + +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS + +DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset") + + +def load_metadata(): + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + if not idx: continue + out[idx] = { + "tag": d.get("tag"), + "problem_type": d.get("problem_type"), + } + return out + + +def main(): + metadata = load_metadata() + base = RESULTS_DIR + models = sorted([d.name for d in base.iterdir() if d.is_dir()]) + + # cells[(topic, ptype, model, variant)] = (n, n_correct) + cells = defaultdict(lambda: [0, 0]) + for m in models: + mdir = base / m + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + vp = find_variant_file(mdir, v) + if not vp: continue + for p in load_problems(vp): + idx = p.get("index") + correct = p.get("correct") + if idx is None or correct is None: continue + md = metadata.get(idx, {}) + tag = md.get("tag") + ptype = md.get("problem_type") + if not tag or not ptype: continue + tags = tag if isinstance(tag, list) else [tag] + for t in tags: + if t not in ["ALG", "ANA", "NT", "COMB", "GEO"]: continue + cells[(t, ptype, m, v)][0] += 1 + if correct: cells[(t, ptype, m, v)][1] += 1 + + print("=" * 80) + print("ACCURACY BY TOPIC × PROBLEM-TYPE × VARIANT (mean across 18 models)") + print("=" * 80) + print() + + for ptype in ["proof", "calculation"]: + print(f"\n--- {ptype.upper()} ---\n") + print(f"{'Topic':<6}", end="") + for v in ["original", "garbled_string", "kernel_variant"]: + short = {"original":"orig","garbled_string":"GS","kernel_variant":"KV"}[v] + print(f" {short:>6}", end="") + print(f" {'Δ_GS':>7} {'Δ_KV':>7}") + print("-" * 50) + for t in ["ALG", "ANA", "NT", "COMB", "GEO"]: + orig_rates = [] + gs_rates = [] + kv_rates = [] + for m in models: + no, co = cells.get((t, ptype, m, "original"), [0, 0]) + ng, cg = cells.get((t, ptype, m, "garbled_string"), [0, 0]) + nk, ck = cells.get((t, ptype, m, "kernel_variant"), [0, 0]) + if no >= 5 and ng >= 5 and nk >= 5: + orig_rates.append(co / no) + gs_rates.append(cg / ng) + kv_rates.append(ck / nk) + if not orig_rates: continue + mo = statistics.fmean(orig_rates) * 100 + mg = statistics.fmean(gs_rates) * 100 + mk = statistics.fmean(kv_rates) * 100 + print(f"{t:<6} {mo:>5.1f}% {mg:>5.1f}% {mk:>5.1f}% {mg-mo:>+5.1f}pp {mk-mo:>+5.1f}pp") + + print("\n\n=== KEY DIFFERENTIAL: Δ KV by Topic for proof vs calculation ===\n") + print(f"{'Topic':<6} {'proof Δ':>10} {'calc Δ':>10} {'(calc - proof)':>16}") + print("-" * 50) + for t in ["ALG", "ANA", "NT", "COMB", "GEO"]: + deltas = {} + for ptype in ["proof", "calculation"]: + orig_rates = [] + kv_rates = [] + for m in models: + no, co = cells.get((t, ptype, m, "original"), [0, 0]) + nk, ck = cells.get((t, ptype, m, "kernel_variant"), [0, 0]) + if no >= 5 and nk >= 5: + orig_rates.append(co / no) + kv_rates.append(ck / nk) + if orig_rates: + deltas[ptype] = (statistics.fmean(kv_rates) - statistics.fmean(orig_rates)) * 100 + if "proof" in deltas and "calculation" in deltas: + diff = deltas["calculation"] - deltas["proof"] + print(f"{t:<6} {deltas['proof']:>+9.1f}pp {deltas['calculation']:>+9.1f}pp {diff:>+15.1f}pp") + elif "proof" in deltas: + print(f"{t:<6} {deltas['proof']:>+9.1f}pp {'-':>10} {'-':>16}") + elif "calculation" in deltas: + print(f"{t:<6} {'-':>10} {deltas['calculation']:>+9.1f}pp {'-':>16}") + + +if __name__ == "__main__": + main() diff --git a/analysis/unicode_audit.py b/analysis/unicode_audit.py new file mode 100644 index 0000000..afe5679 --- /dev/null +++ b/analysis/unicode_audit.py @@ -0,0 +1,238 @@ +"""Unicode audit for PutnamGAP dataset. + +Scans all JSON files in the dataset, finds all non-ASCII characters in text +fields (question, solution across all variants), and reports: + +1. How many files contain Unicode +2. Top Unicode characters by total frequency with suggested LaTeX replacements +3. Which fields are most affected +4. Per-file tallies +5. Samples of lines showing each unusual character in context +6. A machine-readable JSON report for downstream cleaning + +Does NOT modify any file. Read-only audit. +""" +from __future__ import annotations +import json +import sys +import unicodedata +from pathlib import Path +from collections import defaultdict, Counter + +# Both copies of the dataset +DIRS = [ + Path("/home/yurenh2/gap/putnam-bench-anon/dataset"), + Path("/home/yurenh2/gap/putnamsup/PutnamGAP"), +] + +# Text-bearing fields we care about +TOP_LEVEL_TEXT_FIELDS = ["question", "solution"] +VARIANT_TEXT_FIELDS = ["question", "solution"] +VARIANT_KEYS = [ + "descriptive_long", + "descriptive_long_confusing", + "descriptive_long_misleading", + "garbled_string", + "kernel_variant", + "original_kernel_variant", +] + +# Suggested LaTeX replacements for common math Unicode. (Informational — the +# audit does not apply these.) Each entry is (unicode_char, latex_suggestion). +SUGGESTED_LATEX = { + # Greek lower case + "α": r"\alpha", "β": r"\beta", "γ": r"\gamma", "δ": r"\delta", + "ε": r"\varepsilon", "ζ": r"\zeta", "η": r"\eta", "θ": r"\theta", + "ι": r"\iota", "κ": r"\kappa", "λ": r"\lambda", "μ": r"\mu", + "ν": r"\nu", "ξ": r"\xi", "π": r"\pi", "ρ": r"\rho", "σ": r"\sigma", + "τ": r"\tau", "υ": r"\upsilon", "φ": r"\varphi", "χ": r"\chi", + "ψ": r"\psi", "ω": r"\omega", + # Greek upper case + "Α": "A", "Β": "B", "Γ": r"\Gamma", "Δ": r"\Delta", "Ε": "E", + "Ζ": "Z", "Η": "H", "Θ": r"\Theta", "Λ": r"\Lambda", "Ξ": r"\Xi", + "Π": r"\Pi", "Σ": r"\Sigma", "Φ": r"\Phi", "Ψ": r"\Psi", + "Ω": r"\Omega", + # Math operators & relations + "≤": r"\leq", "≥": r"\geq", "≠": r"\neq", "≈": r"\approx", + "≡": r"\equiv", "±": r"\pm", "∓": r"\mp", "×": r"\times", + "÷": r"\div", "·": r"\cdot", "∙": r"\cdot", + "∞": r"\infty", "∂": r"\partial", "∇": r"\nabla", "∆": r"\Delta", + "∑": r"\sum", "∏": r"\prod", "∫": r"\int", "√": r"\sqrt{}", + "∮": r"\oint", "∴": r"\therefore", "∵": r"\because", + "∈": r"\in", "∉": r"\notin", "⊂": r"\subset", "⊆": r"\subseteq", + "⊃": r"\supset", "⊇": r"\supseteq", "∪": r"\cup", "∩": r"\cap", + "∧": r"\land", "∨": r"\lor", "¬": r"\neg", + "→": r"\to", "←": r"\leftarrow", "↔": r"\leftrightarrow", + "⇒": r"\Rightarrow", "⇐": r"\Leftarrow", "⇔": r"\Leftrightarrow", + "⟨": r"\langle", "⟩": r"\rangle", "⌊": r"\lfloor", "⌋": r"\rfloor", + "⌈": r"\lceil", "⌉": r"\rceil", + "∅": r"\emptyset", "ℝ": r"\mathbb{R}", "ℂ": r"\mathbb{C}", + "ℕ": r"\mathbb{N}", "ℤ": r"\mathbb{Z}", "ℚ": r"\mathbb{Q}", + # Subscripts / superscripts (common ones only) + "₀": "_0", "₁": "_1", "₂": "_2", "₃": "_3", "₄": "_4", "₅": "_5", + "₆": "_6", "₇": "_7", "₈": "_8", "₉": "_9", + "⁰": "^0", "¹": "^1", "²": "^2", "³": "^3", "⁴": "^4", "⁵": "^5", + "⁶": "^6", "⁷": "^7", "⁸": "^8", "⁹": "^9", + "ₐ": "_a", "ᵢ": "_i", "ⱼ": "_j", "ₖ": "_k", "ₙ": "_n", + # Fractions + "½": r"\frac{1}{2}", "⅓": r"\frac{1}{3}", "⅔": r"\frac{2}{3}", + "¼": r"\frac{1}{4}", "¾": r"\frac{3}{4}", + # Punctuation / whitespace + "—": "---", "–": "--", "…": r"\ldots", + "‘": "`", "’": "'", "“": "``", "”": "''", + "°": r"^\circ", + "\u00A0": " (nbsp)", # non-breaking space + "\u2009": " (thin space)", + "\u200b": " (zero-width space)", + "\u2026": r"\ldots", + "\u2212": "-", # Unicode minus vs hyphen +} + + +def is_non_ascii(ch: str) -> bool: + return ord(ch) > 127 + + +def extract_text_fields(problem: dict): + """Yield (field_path, text) for every text-bearing field in a problem.""" + idx = problem.get("index", "?") + for k in TOP_LEVEL_TEXT_FIELDS: + v = problem.get(k) + if isinstance(v, str): + yield f"{idx}:{k}", v + for vk in VARIANT_KEYS: + vd = (problem.get("variants") or {}).get(vk) + if not isinstance(vd, dict): + continue + for k in VARIANT_TEXT_FIELDS: + v = vd.get(k) + if isinstance(v, str): + yield f"{idx}:variants.{vk}.{k}", v + + +def audit_dir(dataset_dir: Path, label: str): + print(f"\n{'=' * 76}") + print(f"Auditing {label}: {dataset_dir}") + print(f"{'=' * 76}") + + files = sorted(dataset_dir.glob("*.json")) + print(f"Files: {len(files)}") + + char_counter = Counter() # unicode char -> total occurrences + field_char_counter = defaultdict(Counter) # field_name -> Counter + files_with_unicode = set() # set of problem indices + per_field_counts = Counter() # {question, solution, variants.DL.question, ...} -> n files with unicode + examples = defaultdict(list) # char -> list of (context, path) + total_chars = 0 + total_unicode = 0 + + for f in files: + try: + d = json.load(open(f)) + except Exception as e: + print(f" ! {f.name}: JSON parse error: {e}") + continue + file_had_unicode = False + for path, text in extract_text_fields(d): + if not text: + continue + total_chars += len(text) + nas = [c for c in text if is_non_ascii(c)] + if not nas: + continue + file_had_unicode = True + total_unicode += len(nas) + # tally + for c in nas: + char_counter[c] += 1 + # short field label (strip problem index prefix) + short = path.split(":", 1)[1] + field_char_counter[short][c] += 1 + per_field_counts[short] += 1 + # collect up to 3 examples per char with ±20 char context + if len(examples[c]) < 3: + idx = text.find(c) + start = max(0, idx - 25) + end = min(len(text), idx + 25) + ctx = text[start:end].replace("\n", " ") + examples[c].append((ctx, path)) + if file_had_unicode: + files_with_unicode.add(d.get("index", f.name)) + + # Report + print(f"\nTotal characters scanned: {total_chars:,}") + print(f"Non-ASCII characters: {total_unicode:,} ({total_unicode/total_chars*100:.2f}%)") + print(f"Files with any Unicode: {len(files_with_unicode)}/{len(files)} " + f"({len(files_with_unicode)/len(files)*100:.1f}%)") + print(f"Distinct Unicode code points: {len(char_counter)}") + + print(f"\n--- Top 40 Unicode characters by frequency ---") + print(f"{'char':<6} {'hex':<8} {'count':>8} name / suggested LaTeX") + print("-" * 76) + for c, n in char_counter.most_common(40): + name = unicodedata.name(c, "?") + hex_val = f"U+{ord(c):04X}" + suggestion = SUGGESTED_LATEX.get(c, "") + display_c = c if c.isprintable() and ord(c) > 0x20 else repr(c) + print(f"{display_c:<6} {hex_val:<8} {n:>8} {name[:45]:<45} {suggestion}") + + # Per-field breakdown + print(f"\n--- Unicode per field (top 15 fields with most Unicode) ---") + print(f"{'field':<50} {'total unicode':>15}") + print("-" * 70) + for field, cnt in Counter({f: sum(c.values()) for f, c in field_char_counter.items()}).most_common(15): + print(f"{field:<50} {cnt:>15}") + + # Examples for top 10 chars + print(f"\n--- Example contexts for top 10 Unicode chars ---") + for c, n in char_counter.most_common(10): + name = unicodedata.name(c, "?") + display_c = c if c.isprintable() and ord(c) > 0x20 else repr(c) + print(f"\n {display_c} (U+{ord(c):04X}, {name}, n={n}):") + for ctx, path in examples[c][:2]: + print(f" [{path}]") + print(f" …{ctx}…") + + # Machine-readable summary + summary = { + "dataset_dir": str(dataset_dir), + "n_files": len(files), + "n_files_with_unicode": len(files_with_unicode), + "pct_files_with_unicode": 100 * len(files_with_unicode) / max(1, len(files)), + "total_chars": total_chars, + "total_unicode": total_unicode, + "distinct_codepoints": len(char_counter), + "top_chars": [ + {"char": c, "codepoint": f"U+{ord(c):04X}", + "name": unicodedata.name(c, "?"), + "count": n, + "suggested_latex": SUGGESTED_LATEX.get(c, ""), + "examples": [{"path": path, "context": ctx} + for ctx, path in examples[c][:3]]} + for c, n in char_counter.most_common(80) + ], + "per_field_unicode_counts": dict( + Counter({f: sum(c.values()) for f, c in field_char_counter.items()}) + .most_common(30)), + "files_with_unicode_indices": sorted(files_with_unicode), + } + return summary + + +def main(): + all_summaries = [] + for d in DIRS: + if d.exists(): + s = audit_dir(d, d.name) + s["label"] = d.name + all_summaries.append(s) + else: + print(f" (skipping missing dir {d})") + + out_path = Path("/home/yurenh2/gap/analysis/unicode_audit.json") + json.dump(all_summaries, open(out_path, "w"), indent=2, ensure_ascii=False) + print(f"\n\nSaved machine-readable summary -> {out_path}") + + +if __name__ == "__main__": + main() diff --git a/analysis/unicode_clean.py b/analysis/unicode_clean.py new file mode 100644 index 0000000..cea3cbe --- /dev/null +++ b/analysis/unicode_clean.py @@ -0,0 +1,729 @@ +"""Unicode -> LaTeX cleaner for PutnamGAP dataset (v2). + +Improvements over v1: + - Pre-normalize via NFKD then strip combining diacritics so accented + letters collapse to their ASCII base. + - Group adjacent subscript/superscript runs into {...}: x_1_0 -> x_{10}, + x^2^3 -> x^{23}. + - Wrap the argument of radical commands: \\sqrt-followed-by-X -> \\sqrt{X} + where X is either an identifier/number run or a balanced paren/bracket + group or a single \\-command (optionally followed by {...} arguments). + - Explicit replacements for symbols that previously fell through: + star, blacksquare/QED, fraction slash, dagger, etc. + - Deletes lone combining diacritics and decorative box-drawing characters. + +Operates IN PLACE on both dataset copies. Backup in a tarball first. +""" +from __future__ import annotations +import json +import re +import sys +import unicodedata +from pathlib import Path +from collections import Counter + +DIRS = [ + Path("/home/yurenh2/gap/putnam-bench-anon/dataset"), + Path("/home/yurenh2/gap/putnamsup/PutnamGAP"), +] + +TOP_LEVEL_TEXT_FIELDS = ["question", "solution"] +VARIANT_TEXT_FIELDS = ["question", "solution"] +VARIANT_KEYS = [ + "descriptive_long", + "descriptive_long_confusing", + "descriptive_long_misleading", + "garbled_string", + "kernel_variant", + "original_kernel_variant", +] + + +# Sentinels placed during char substitution, resolved in a later pass that +# can look at the following characters to extract the radical argument. +SENT_SQRT = "\x01SQRT\x01" +SENT_CBRT = "\x01CBRT\x01" +SENT_FRT = "\x01FRT\x01" + +REPLACEMENTS: dict = { + # Whitespace -> normal space + "\u00A0": " ", "\u2002": " ", "\u2003": " ", "\u2004": " ", + "\u2005": " ", "\u2006": " ", "\u2007": " ", "\u2008": " ", + "\u2009": " ", "\u200A": " ", "\u200B": "", "\u200C": "", + "\u200D": "", "\u202F": " ", "\u205F": " ", "\u3000": " ", + "\uFEFF": "", + + # Dashes / hyphens + # NOTE: in this dataset (kernel-variant LLM-generated math text) the + # EN DASH is used pervasively as a math minus sign, not a typographic + # en-dash, so we map it to a single hyphen-minus rather than the + # typographic `--`. The EM DASH stays as `---` (prose convention). + "\u2010": "-", "\u2011": "-", + "\u2012": "-", # FIGURE DASH + "\u2013": "-", # EN DASH (was `--`; common usage here is math minus) + "\u2014": "---", # EM DASH (typographic prose break) + "\u2015": "---", # HORIZONTAL BAR + "\u2212": "-", + + # Quotation marks + "\u2018": "`", "\u2019": "'", "\u201A": ",", "\u201B": "`", + "\u201C": "``", "\u201D": "''", "\u201E": ",,", + "\u00AB": "<<", "\u00BB": ">>", + + # Punctuation / miscellany + "\u2022": "*", + "\u2023": "*", + "\u2027": ".", + "\u2026": r"\ldots", + "\u00B7": r"\cdot", + "\u00B0": r"^\circ", + "\u2032": "'", "\u2033": "''", "\u2034": "'''", "\u2035": "`", + "\u2605": r"\star", + "\u2606": r"\star", + "\u25A0": r"\blacksquare", + "\u25A1": r"\square", + "\u220E": r"\blacksquare", + "\u2020": r"\dagger", + "\u2021": r"\ddagger", + "\u2044": "/", + + # Sub/super digits + "\u2070": "^0", "\u00B9": "^1", "\u00B2": "^2", "\u00B3": "^3", + "\u2074": "^4", "\u2075": "^5", "\u2076": "^6", "\u2077": "^7", + "\u2078": "^8", "\u2079": "^9", + "\u207A": "^+", "\u207B": "^-", "\u207C": "^=", "\u207D": "^(", "\u207E": "^)", + "\u2080": "_0", "\u2081": "_1", "\u2082": "_2", "\u2083": "_3", + "\u2084": "_4", "\u2085": "_5", "\u2086": "_6", "\u2087": "_7", + "\u2088": "_8", "\u2089": "_9", + "\u208A": "_+", "\u208B": "_-", "\u208C": "_=", "\u208D": "_(", "\u208E": "_)", + + # Latin sub/super letters + "\u2090": "_a", "\u2091": "_e", "\u2092": "_o", "\u2093": "_x", + "\u2095": "_h", "\u2096": "_k", "\u2097": "_l", "\u2098": "_m", + "\u2099": "_n", "\u209A": "_p", "\u209B": "_s", "\u209C": "_t", + "\u2C7C": "_j", # LATIN SUBSCRIPT SMALL LETTER J + "\u1D30": "^D", "\u1D31": "^E", "\u1D33": "^G", "\u1D34": "^H", + "\u1D35": "^I", "\u1D36": "^J", "\u1D37": "^K", "\u1D38": "^L", + "\u1D39": "^M", "\u1D3A": "^N", "\u1D3C": "^O", "\u1D3E": "^P", + "\u1D3F": "^R", "\u1D40": "^T", "\u1D41": "^U", "\u1D42": "^W", + "\u1D43": "^a", "\u1D47": "^b", "\u1D48": "^d", "\u1D49": "^e", + "\u1D4D": "^g", "\u1D4F": "^k", "\u1D50": "^m", "\u1D52": "^o", + "\u1D56": "^p", "\u1D57": "^t", "\u1D58": "^u", "\u1D5B": "^v", + "\u1D62": "_i", "\u1D63": "_r", "\u1D64": "_u", "\u1D65": "_v", + "\u2071": "^i", "\u207F": "^n", + + # Greek lower case + "\u03B1": r"\alpha", "\u03B2": r"\beta", "\u03B3": r"\gamma", + "\u03B4": r"\delta", "\u03B5": r"\varepsilon", "\u03B6": r"\zeta", + "\u03B7": r"\eta", "\u03B8": r"\theta", "\u03B9": r"\iota", + "\u03BA": r"\kappa", "\u03BB": r"\lambda", "\u03BC": r"\mu", + "\u03BD": r"\nu", "\u03BE": r"\xi", "\u03BF": "o", + "\u03C0": r"\pi", "\u03C1": r"\rho", "\u03C2": r"\varsigma", + "\u03C3": r"\sigma", "\u03C4": r"\tau", "\u03C5": r"\upsilon", + "\u03C6": r"\varphi", "\u03C7": r"\chi", "\u03C8": r"\psi", + "\u03C9": r"\omega", + "\u03D5": r"\phi", "\u03D1": r"\vartheta", "\u03D6": r"\varpi", + "\u03F1": r"\varrho", "\u03F5": r"\epsilon", + # Greek upper case + "\u0391": "A", "\u0392": "B", "\u0393": r"\Gamma", + "\u0394": r"\Delta", "\u0395": "E", "\u0396": "Z", + "\u0397": "H", "\u0398": r"\Theta", "\u0399": "I", + "\u039A": "K", "\u039B": r"\Lambda", "\u039C": "M", + "\u039D": "N", "\u039E": r"\Xi", "\u039F": "O", + "\u03A0": r"\Pi", "\u03A1": "P", "\u03A3": r"\Sigma", + "\u03A4": "T", "\u03A5": r"\Upsilon", "\u03A6": r"\Phi", + "\u03A7": "X", "\u03A8": r"\Psi", "\u03A9": r"\Omega", + + # Math operators / relations + "\u2200": r"\forall", "\u2203": r"\exists", "\u2204": r"\nexists", + "\u2205": r"\emptyset", + "\u2208": r"\in", "\u2209": r"\notin", "\u220B": r"\ni", + "\u220F": r"\prod", "\u2210": r"\coprod", "\u2211": r"\sum", + "\u2213": r"\mp", "\u00B1": r"\pm", + "\u2214": r"\dotplus", + "\u2217": "*", "\u2218": r"\circ", "\u2219": r"\cdot", + "\u221D": r"\propto", + "\u221E": r"\infty", + "\u2220": r"\angle", "\u2221": r"\measuredangle", + "\u2225": r"\parallel", "\u2226": r"\nparallel", + "\u2227": r"\land", "\u2228": r"\lor", + "\u2229": r"\cap", "\u222A": r"\cup", + "\u222B": r"\int", "\u222C": r"\iint", "\u222D": r"\iiint", + "\u222E": r"\oint", "\u222F": r"\oiint", + "\u2234": r"\therefore", "\u2235": r"\because", + "\u2236": ":", "\u2237": "::", + "\u223C": r"\sim", "\u2243": r"\simeq", "\u2245": r"\cong", + "\u2248": r"\approx", "\u224D": r"\asymp", + "\u2250": r"\doteq", + "\u2260": r"\neq", "\u2261": r"\equiv", "\u2262": r"\not\equiv", + "\u2264": r"\leq", "\u2265": r"\geq", + "\u2266": r"\leqq", "\u2267": r"\geqq", + "\u226A": r"\ll", "\u226B": r"\gg", + "\u2270": r"\not\leq", "\u2271": r"\not\geq", + "\u2282": r"\subset", "\u2283": r"\supset", + "\u2284": r"\not\subset", "\u2285": r"\not\supset", + "\u2286": r"\subseteq", "\u2287": r"\supseteq", + "\u2288": r"\not\subseteq", "\u2289": r"\not\supseteq", + "\u228A": r"\subsetneq", "\u228B": r"\supsetneq", + "\u2295": r"\oplus", "\u2296": r"\ominus", + "\u2297": r"\otimes", "\u2298": r"\oslash", "\u2299": r"\odot", + "\u22A2": r"\vdash", "\u22A3": r"\dashv", + "\u22A4": r"\top", "\u22A5": r"\bot", + "\u22A8": r"\models", + "\u22C0": r"\bigwedge", "\u22C1": r"\bigvee", + "\u22C2": r"\bigcap", "\u22C3": r"\bigcup", + "\u22C5": r"\cdot", "\u22C6": r"\star", + "\u22EE": r"\vdots", "\u22EF": r"\cdots", + "\u22F1": r"\ddots", + + # Arrows + "\u2190": r"\leftarrow", "\u2192": r"\to", + "\u2191": r"\uparrow", "\u2193": r"\downarrow", + "\u2194": r"\leftrightarrow", "\u2195": r"\updownarrow", + "\u21A0": r"\twoheadrightarrow", + "\u21A6": r"\mapsto", + "\u21D0": r"\Leftarrow", "\u21D2": r"\Rightarrow", + "\u21D1": r"\Uparrow", "\u21D3": r"\Downarrow", + "\u21D4": r"\Leftrightarrow", + "\u27F6": r"\longrightarrow", "\u27F5": r"\longleftarrow", + "\u27F9": r"\Longrightarrow", "\u27F8": r"\Longleftarrow", + "\u27FA": r"\Longleftrightarrow", + + # Delimiters + "\u2016": r"\|", + "\u2308": r"\lceil", "\u2309": r"\rceil", + "\u230A": r"\lfloor", "\u230B": r"\rfloor", + "\u27E8": r"\langle", "\u27E9": r"\rangle", + "\u27EA": r"\llangle", "\u27EB": r"\rrangle", + + # Blackboard / script letters + "\u2102": r"\mathbb{C}", "\u210D": r"\mathbb{H}", + "\u2115": r"\mathbb{N}", "\u2119": r"\mathbb{P}", + "\u211A": r"\mathbb{Q}", "\u211D": r"\mathbb{R}", + "\u2124": r"\mathbb{Z}", + "\u2113": r"\ell", "\u210F": r"\hbar", + "\u2202": r"\partial", "\u2207": r"\nabla", "\u2118": r"\wp", + "\u2133": r"\mathcal{M}", "\u2112": r"\mathcal{L}", + "\u211B": r"\mathcal{R}", "\u2110": r"\mathcal{I}", + "\u2130": r"\mathcal{E}", "\u2132": "F", + + # Fractions with precomposed forms + "\u00BC": r"\frac{1}{4}", "\u00BD": r"\frac{1}{2}", "\u00BE": r"\frac{3}{4}", + "\u2153": r"\frac{1}{3}", "\u2154": r"\frac{2}{3}", + "\u2155": r"\frac{1}{5}", "\u2156": r"\frac{2}{5}", + "\u2157": r"\frac{3}{5}", "\u2158": r"\frac{4}{5}", + "\u2159": r"\frac{1}{6}", "\u215A": r"\frac{5}{6}", + "\u215B": r"\frac{1}{8}", "\u215C": r"\frac{3}{8}", + "\u215D": r"\frac{5}{8}", "\u215E": r"\frac{7}{8}", + + # Multiplication / division + "\u00D7": r"\times", "\u00F7": r"\div", + + # Misc + "\u00A7": r"\S", + "\u00B6": r"\P", + "\u00A9": "(c)", "\u00AE": "(R)", "\u2122": "(TM)", + "\u00A3": r"\pounds", "\u20AC": "EUR", + "\u00B5": r"\mu", + + # Additional math symbols + "\u2216": r"\setminus", + "\u2223": r"\mid", + "\u2224": r"\nmid", + "\u2225": r"\parallel", # duplicate of above, safe + "\u2226": r"\nparallel", + "\u22BB": r"\veebar", + "\u22BC": r"\barwedge", + "\u2238": r"\dot{-}", + "\u22C8": r"\bowtie", + "\u22CE": r"\curlyvee", + "\u22CF": r"\curlywedge", + + # Perp and triangle family + "\u27C2": r"\perp", + "\u22A5": r"\bot", # already present but safe + "\u25B3": r"\triangle", + "\u25B4": r"\blacktriangle", + "\u25BD": r"\triangledown", + "\u25BE": r"\blacktriangledown", + "\u25C1": r"\triangleleft", + "\u25C2": r"\blacktriangleleft", + "\u25B7": r"\triangleright", + "\u25B8": r"\blacktriangleright", + + # Square / box operators + "\u2293": r"\sqcap", + "\u2294": r"\sqcup", + "\u22A1": r"\boxdot", + "\u229E": r"\boxplus", + "\u229F": r"\boxminus", + "\u22A0": r"\boxtimes", + + # Preceq / succeq family + "\u227A": r"\prec", + "\u227B": r"\succ", + "\u227C": r"\preceq", + "\u227D": r"\succeq", + "\u2280": r"\nprec", + "\u2281": r"\nsucc", + "\u22E0": r"\npreceq", + "\u22E1": r"\nsucceq", + + # Double-square brackets + "\u27E6": r"\llbracket", + "\u27E7": r"\rrbracket", + + # Card-suit decorative (drop) + "\u2660": "", # spade + "\u2661": "", + "\u2662": "", + "\u2663": "", # club + "\u2664": "", + "\u2665": "", # heart + "\u2666": "", # diamond + + # Musical / dingbat decorations (drop) + "\u266A": "", # eighth note + "\u266B": "", # beamed eighth notes + "\u2713": r"\checkmark", + "\u2717": r"\times", + + # Curved delimiters / bracket extension pieces -- these are used by the + # kernel generator to draw big parentheses/brackets around multi-line + # expressions (like matrices). They are purely decorative in plain text + # and we drop them. + "\u239B": "", "\u239C": "", "\u239D": "", # ( upper/mid/lower + "\u239E": "", "\u239F": "", "\u23A0": "", # ) upper/mid/lower + "\u23A1": "", "\u23A2": "", "\u23A3": "", # [ upper/mid/lower + "\u23A4": "", "\u23A5": "", "\u23A6": "", # ] upper/mid/lower + "\u23A7": "", "\u23A8": "", "\u23A9": "", # { upper/middle/lower + "\u23AA": "", # { extension + "\u23AB": "", "\u23AC": "", "\u23AD": "", # } upper/middle/lower + "\u23AE": "", # integral extension + "\u23AF": "", # horizontal line extension + "\u23B0": "", "\u23B1": "", # upper/lower curly bracket + "\u23B2": "", "\u23B3": "", # summation top/bottom + "\u23B4": "", "\u23B5": "", # top/bottom square bracket + "\u23B6": "", "\u23B7": "", # bottom square bracket w/tick + "\u23D0": "", # vertical line extension + + # Combining over/underlines are stripped by the combining-mark regex + + # Additional remaining symbols found after first clean pass + "\u00AD": "", # SOFT HYPHEN -> delete + "\u2215": "/", # DIVISION SLASH + "\u25A2": r"\square", # WHITE SQUARE WITH ROUNDED CORNERS + "\u2718": r"\times", # HEAVY BALLOT X + "\u3008": r"\langle", # CJK LEFT ANGLE BRACKET + "\u3009": r"\rangle", # CJK RIGHT ANGLE BRACKET + "\u2254": ":=", # COLON EQUALS + "\u2255": "=:", # EQUALS COLON + "\u2198": r"\searrow", # SOUTH EAST ARROW + "\u2197": r"\nearrow", # NORTH EAST ARROW + "\u2199": r"\swarrow", + "\u2196": r"\nwarrow", + "\u21A9": r"\hookleftarrow", + "\u21AA": r"\hookrightarrow", + "\u21BC": r"\leftharpoonup", + "\u21BD": r"\leftharpoondown", + "\u21BE": r"\upharpoonright", + "\u21BF": r"\upharpoonleft", + "\u21C0": r"\rightharpoonup", + "\u21C1": r"\rightharpoondown", + "\u21C2": r"\downharpoonright", + "\u21C3": r"\downharpoonleft", + "\u21CC": r"\rightleftharpoons", + "\u21E2": r"\dashrightarrow", + "\u21E0": r"\dashleftarrow", + "\u2277": r"\gtrless", + "\u2276": r"\lessgtr", + + # Private Use Area characters are almost always OCR garbage or + # font-specific glyphs; drop them. + "\uF8EB": "", "\uF8F6": "", + "\uF8FE": "", "\uF8FD": "", "\uF8FC": "", "\uF8FB": "", + "\uF8EF": "", "\uF8F0": "", "\uF8F1": "", "\uF8F2": "", + + # A few more rare but meaningful math symbols + "\u2322": r"\frown", + "\u2323": r"\smile", + "\u226D": r"\not\asymp", + "\u22A7": r"\models", + "\u22B2": r"\vartriangleleft", + "\u22B3": r"\vartriangleright", + "\u22B4": r"\trianglelefteq", + "\u22B5": r"\trianglerighteq", + + # Small-caps letters sometimes emitted by OCR (collapse to plain letter) + "\u026A": "I", # LATIN LETTER SMALL CAPITAL I + "\u1D00": "A", + "\u1D04": "C", + "\u1D05": "D", + "\u1D07": "E", + "\u0262": "G", + "\u029C": "H", + + # Remaining math symbols found after pass 2 + "\u2A01": r"\bigoplus", + "\u2A02": r"\bigotimes", + "\u2A00": r"\bigodot", + "\u2A03": r"\biguplus", + "\u2A04": r"\biguplus", + "\u2A05": r"\bigsqcap", + "\u2A06": r"\bigsqcup", + "\u2272": r"\lesssim", + "\u2273": r"\gtrsim", + "\u226E": r"\not<", + "\u226F": r"\not>", + "\u27EE": "(", # MATHEMATICAL LEFT FLATTENED PARENTHESIS + "\u27EF": ")", # MATHEMATICAL RIGHT FLATTENED PARENTHESIS + "\u2610": r"\square", # BALLOT BOX + "\u2611": r"\checkmark", + "\u2612": r"\times", + + # Root sentinels (wrapped in a later pass) + "\u221A": SENT_SQRT, + "\u221B": SENT_CBRT, + "\u221C": SENT_FRT, +} + + +_COMBINING_MARK_RE = re.compile( + r"[\u0300-\u036F\u1AB0-\u1AFF\u1DC0-\u1DFF\u20D0-\u20FF\uFE20-\uFE2F]") +_BOX_DRAWING_RE = re.compile(r"[\u2500-\u257F\u2580-\u259F]") + +# Characters from scripts that have no place in English/Greek mathematics +# and are clearly OCR noise when they appear. Drop them wholesale. Latin and +# Greek are preserved; extended Latin letters with diacritics are still +# handled by the NFKD fallback. +_OCR_NOISE_SCRIPTS_RE = re.compile( + r"[\u0400-\u04FF" # Cyrillic + r"\u0500-\u052F" # Cyrillic Supplement + r"\u0530-\u058F" # Armenian + r"\u0590-\u05FF" # Hebrew + r"\u0600-\u06FF" # Arabic + r"\u0700-\u074F" # Syriac + r"\u0750-\u077F" # Arabic Supplement + r"\u0780-\u07BF" # Thaana + r"\u0900-\u097F" # Devanagari + r"\u0B80-\u0BFF" # Tamil + r"\u0C00-\u0C7F" # Telugu + r"\u0C80-\u0CFF" # Kannada + r"\u0D00-\u0D7F" # Malayalam + r"\u0D80-\u0DFF" # Sinhala + r"\u0E00-\u0E7F" # Thai + r"\u0E80-\u0EFF" # Lao + r"\u0F00-\u0FFF" # Tibetan + r"\u1000-\u109F" # Myanmar + r"\u10A0-\u10FF" # Georgian + r"\u1100-\u11FF" # Hangul Jamo + r"\u1400-\u167F" # Unified Canadian Aboriginal Syllabics + r"\u1680-\u169F" # Ogham + r"\u16A0-\u16FF" # Runic + r"\u1700-\u171F" # Tagalog + r"\u1780-\u17FF" # Khmer + r"\u1800-\u18AF" # Mongolian + r"\u1900-\u194F" # Limbu + r"\u3040-\u309F" # Hiragana + r"\u30A0-\u30FF" # Katakana + r"\u3000-\u303F" # CJK Symbols and Punctuation (incl. ideographic full stop) + r"\u3100-\u312F" # Bopomofo + r"\u3130-\u318F" # Hangul Compatibility Jamo + r"\u3190-\u319F" # Kanbun + r"\u3400-\u4DBF" # CJK Extension A + r"\u4E00-\u9FFF" # CJK Unified Ideographs + r"\uA000-\uA48F" # Yi Syllables + r"\uAC00-\uD7AF" # Hangul Syllables + r"\uE000-\uF8FF" # Private Use Area + r"\uFE00-\uFE0F" # Variation Selectors + r"\uFE30-\uFE4F" # CJK Compatibility Forms (vertical presentation + # brackets that NFKD-decompose to literal { } [ ] etc., + # which would corrupt our brace balance — drop them) + r"\uFE50-\uFE6F" # Small Form Variants (compatibility forms) + r"\uFFFC\uFFFD" # Object/Replacement Character + r"]" +) + +# Emoji and pictographs (outside the BMP, need surrogate handling) +_EMOJI_RE = re.compile( + "[" + "\U0001F000-\U0001F9FF" # Emoji blocks + "\U0001FA00-\U0001FAFF" # Symbols & Pictographs Extended-A + "\U0001F1E6-\U0001F1FF" # Regional indicator symbols + "\U0001F3FB-\U0001F3FF" # Emoji modifier fitzpatrick + "\U00020000-\U0002FA1F" # CJK Extensions B-F + "]", + flags=re.UNICODE +) + + +def prestrip(text: str) -> str: + """Strip decorative and OCR-noise characters BEFORE char substitution. + + Important: we do NOT run NFKD here because NFKD decomposes subscript / + superscript digits (e.g. \u2080 -> '0') before our explicit REPLACEMENTS + entries can rewrite them as `_0`. NFKD is applied later only as a + fallback for characters that survive the explicit substitution pass + (e.g. accented Latin letters). + """ + if not text: + return text + text = _BOX_DRAWING_RE.sub("", text) + # Lone combining marks are orphaned when the base character was something + # we otherwise transformed; strip them up front. + text = _COMBINING_MARK_RE.sub("", text) + # Strip OCR-noise scripts (Cyrillic / Arabic / CJK / etc.) that have no + # place in English-Greek mathematical prose. + text = _OCR_NOISE_SCRIPTS_RE.sub("", text) + # Strip emoji / pictographs (clearly LLM-emitted noise in math text). + text = _EMOJI_RE.sub("", text) + return text + + +def char_substitute(text: str, unmapped: Counter) -> str: + """Apply REPLACEMENTS char-by-char. Any char not in REPLACEMENTS is left + in place so that _nfkd_fallback (run next) has a chance to handle it + via compatibility decomposition. A trailing space is appended to bare + `\\word` LaTeX commands so subsequent letters do not get absorbed into + the command name. + """ + out = [] + for ch in text: + if ord(ch) <= 127 or ch == "\x01": + out.append(ch) + continue + if ch in REPLACEMENTS: + val = REPLACEMENTS[ch] + # Bare `\word` (starts with `\\`, ends in a letter) needs a + # trailing space so that `\cdot t` does not become `\cdott`. + if (len(val) >= 2 and val[0] == "\\" + and val[-1].isalpha() + and not val.startswith("\x01")): + val = val + " " + out.append(val) + continue + # Unmapped: keep as-is and let _nfkd_fallback try compat decomposition. + out.append(ch) + return "".join(out) + + +def _merge_sub_sup(text: str) -> str: + def _do(prefix, m): + # Extract each ^X or _X token and concatenate the X parts. + vals = re.findall(r"[\+\-\=\(\)a-zA-Z0-9]", m.group(0)) + # The regex captures the X char from each ^X or _X; above regex + # finds ALL alnum/sign chars in the match. But `^+` etc. we want + # to keep as-is. Simplest: split on the prefix. + pieces = [p for p in re.split(r"[\^_]", m.group(0)) if p] + joined = "".join(pieces) + return f"{prefix}{{{joined}}}" + + text = re.sub( + r"(?:\^[\+\-\=\(\)a-zA-Z0-9])(?:\^[\+\-\=\(\)a-zA-Z0-9])+", + lambda m: _do("^", m), text) + text = re.sub( + r"(?:_[\+\-\=\(\)a-zA-Z0-9])(?:_[\+\-\=\(\)a-zA-Z0-9])+", + lambda m: _do("_", m), text) + return text + + +_SENTINEL_RE = re.compile(r"\x01(SQRT|CBRT|FRT)\x01") + + +def _skip_spaces(s: str, i: int) -> int: + while i < len(s) and s[i] in " \t": + i += 1 + return i + + +def _read_balanced(s: str, i: int, open_ch: str, close_ch: str): + depth = 0 + j = i + while j < len(s): + if s[j] == open_ch: + depth += 1 + elif s[j] == close_ch: + depth -= 1 + if depth == 0: + return j + 1 + j += 1 + return -1 + + +def _read_latex_command(s: str, i: int): + if i >= len(s) or s[i] != "\\": + return -1 + j = i + 1 + while j < len(s) and (s[j].isalpha() or s[j] == "@"): + j += 1 + while j < len(s) and s[j] == "{": + end = _read_balanced(s, j, "{", "}") + if end == -1: + return j + j = end + return j + + +def _wrap_radical_arguments(text: str) -> str: + out = [] + i = 0 + LATEX_FOR = {"SQRT": r"\sqrt", "CBRT": r"\sqrt[3]", "FRT": r"\sqrt[4]"} + while i < len(text): + m = _SENTINEL_RE.match(text, i) + if not m: + out.append(text[i]) + i += 1 + continue + kind = m.group(1) + latex_prefix = LATEX_FOR[kind] + j = _skip_spaces(text, m.end()) + if j >= len(text): + out.append(latex_prefix + "{}") + i = j + continue + ch = text[j] + if ch == "(": + arg_end = _read_balanced(text, j, "(", ")") + if arg_end != -1: + arg = text[j + 1 : arg_end - 1] + out.append(f"{latex_prefix}{{{arg}}}") + i = arg_end + continue + if ch == "[": + arg_end = _read_balanced(text, j, "[", "]") + if arg_end != -1: + arg = text[j + 1 : arg_end - 1] + out.append(f"{latex_prefix}{{{arg}}}") + i = arg_end + continue + if ch == "{": + arg_end = _read_balanced(text, j, "{", "}") + if arg_end != -1: + arg = text[j + 1 : arg_end - 1] + out.append(f"{latex_prefix}{{{arg}}}") + i = arg_end + continue + if ch == "\\": + arg_end = _read_latex_command(text, j) + if arg_end != -1: + arg = text[j:arg_end] + out.append(f"{latex_prefix}{{{arg}}}") + i = arg_end + continue + # Fallback: alnum run (and dots for things like 3.14) + k = j + while k < len(text) and (text[k].isalnum() or text[k] in "."): + k += 1 + if k > j: + arg = text[j:k] + out.append(f"{latex_prefix}{{{arg}}}") + i = k + continue + out.append(latex_prefix + "{}") + i = m.end() + return "".join(out) + + +def _nfkd_fallback(text: str, unmapped: Counter) -> str: + """For characters that survived explicit substitution and are still + non-ASCII (e.g. precomposed accented Latin letters like \u00E9 / e-acute, + or classical Greek letters with breathing marks like \u1F42), run NFKD + and drop combining marks, then re-apply REPLACEMENTS (because NFKD can + unmask characters that do appear in REPLACEMENTS, e.g. \u1F42 -> \u03B3). + Finally, any character that is still non-ASCII is logged and dropped. + """ + has_non_ascii = any(ord(c) > 127 and c != "\x01" for c in text) + if not has_non_ascii: + return text + text = unicodedata.normalize("NFKD", text) + text = _COMBINING_MARK_RE.sub("", text) + # Second pass of char_substitute now that NFKD has possibly surfaced + # characters that were previously embedded in precomposed forms. + text = char_substitute(text, unmapped) # unmapped counter accumulates + # Final drop of anything still non-ASCII + out = [] + for c in text: + if ord(c) <= 127 or c == "\x01": + out.append(c) + else: + unmapped[c] += 1 + return "".join(out) + + +def clean_text(text: str, unmapped: Counter) -> str: + if not text: + return text + text = prestrip(text) + text = char_substitute(text, unmapped) + text = _nfkd_fallback(text, unmapped) + text = _merge_sub_sup(text) + text = _wrap_radical_arguments(text) + return text + + +def clean_problem(problem: dict, unmapped: Counter): + for k in TOP_LEVEL_TEXT_FIELDS: + if isinstance(problem.get(k), str): + problem[k] = clean_text(problem[k], unmapped) + variants = problem.get("variants") or {} + for vk in VARIANT_KEYS: + vd = variants.get(vk) + if not isinstance(vd, dict): + continue + for k in VARIANT_TEXT_FIELDS: + if isinstance(vd.get(k), str): + vd[k] = clean_text(vd[k], unmapped) + return problem + + +def process_dir(dataset_dir: Path): + print(f"\n=== Cleaning {dataset_dir} ===") + files = sorted(dataset_dir.glob("*.json")) + unmapped = Counter() + n_modified = 0 + for f in files: + try: + d = json.load(open(f)) + except Exception as e: + print(f" ! skip {f.name}: {e}") + continue + before = json.dumps(d, ensure_ascii=False) + d = clean_problem(d, unmapped) + after = json.dumps(d, ensure_ascii=False) + if before != after: + n_modified += 1 + with open(f, "w") as fh: + json.dump(d, fh, ensure_ascii=False, indent=2) + print(f" files modified: {n_modified}/{len(files)}") + if unmapped: + print(f" unmapped characters: {sum(unmapped.values())} occurrences, " + f"{len(unmapped)} distinct") + print(f" top 20 unmapped:") + for ch, n in unmapped.most_common(20): + name = unicodedata.name(ch, "?") + print(f" {ch!r:<10} U+{ord(ch):04X} n={n} ({name})") + else: + print(f" no unmapped characters") + return unmapped + + +def main(): + all_unmapped = Counter() + for d in DIRS: + if d.exists(): + u = process_dir(d) + all_unmapped.update(u) + print(f"\n=== OVERALL ===") + print(f"Total unmapped characters across both dataset copies: {sum(all_unmapped.values())}") + print(f"Distinct unmapped: {len(all_unmapped)}") + if all_unmapped: + out_path = Path("/home/yurenh2/gap/analysis/unmapped_chars.json") + json.dump({f"U+{ord(c):04X}": {"char": c, "name": unicodedata.name(c, "?"), + "count": n} + for c, n in all_unmapped.most_common()}, + open(out_path, "w"), indent=2, ensure_ascii=False) + print(f"Saved unmapped list -> {out_path}") + + +if __name__ == "__main__": + main() -- cgit v1.2.3