summaryrefslogtreecommitdiff
path: root/analysis
diff options
context:
space:
mode:
Diffstat (limited to 'analysis')
-rw-r--r--analysis/aggregate_overlap.py91
-rw-r--r--analysis/balance_diff.py109
-rw-r--r--analysis/cross_model_agreement.py180
-rw-r--r--analysis/kv_overlap.py332
-rw-r--r--analysis/make_figures.py272
-rw-r--r--analysis/normalization_analysis.py189
-rw-r--r--analysis/rescue_analyze.py161
-rw-r--r--analysis/rescue_api.py373
-rw-r--r--analysis/rescue_pooled.py174
-rw-r--r--analysis/rescue_prompts.py267
-rw-r--r--analysis/rescue_runner.py341
-rw-r--r--analysis/sc_success_and_difficulty.py192
-rw-r--r--analysis/self_correction.py202
-rw-r--r--analysis/spotcheck_clean.py181
-rw-r--r--analysis/structural_overlap.py523
-rw-r--r--analysis/topic_problemtype_interaction.py112
-rw-r--r--analysis/unicode_audit.py238
-rw-r--r--analysis/unicode_clean.py729
18 files changed, 4666 insertions, 0 deletions
diff --git a/analysis/aggregate_overlap.py b/analysis/aggregate_overlap.py
new file mode 100644
index 0000000..cd6b53e
--- /dev/null
+++ b/analysis/aggregate_overlap.py
@@ -0,0 +1,91 @@
+"""Aggregate structural_overlap results by variant type and by model.
+
+Produces a clean rebuttal table.
+"""
+from __future__ import annotations
+import json
+import statistics
+from pathlib import Path
+from collections import defaultdict
+
+RESULTS = Path("/home/yurenh2/gap/analysis/structural_overlap_results.json")
+SHORT = {"descriptive_long":"DL","descriptive_long_confusing":"DLC",
+ "descriptive_long_misleading":"DLM","garbled_string":"GS"}
+
+
def main():
    """Aggregate structural_overlap cells and print three rebuttal tables.

    Reads RESULTS (a JSON list of per-(model, variant) cells produced by
    structural_overlap.py) and prints:
      1. a headline per-variant table,
      2. an aggregate over all cells,
      3. a per-model table sorted by mean Cohen's d (descending).
    """
    # Context manager closes the handle deterministically
    # (was: json.load(open(RESULTS)), which leaks it).
    with open(RESULTS) as fh:
        cells = json.load(fh)
    print(f"Loaded {len(cells)} cells.\n")

    # Per-variant aggregate
    per_variant = defaultdict(list)
    for c in cells:
        per_variant[c["variant"]].append(c)

    print("=" * 90)
    print("HEADLINE TABLE: Surface variants — stable vs brittle structural overlap")
    print("(token Jaccard on canonicalized trajectories, drift cases only)")
    print("=" * 90)
    print(f"\n{'Variant':<6} {'#cells':>7} {'#dir+':>6} {'#p<.05':>8} "
          f"{'med-d':>7} {'mean-d':>7} {'mean-dlt':>9} "
          f"{'mean-stbl':>10} {'mean-brit':>10} {'mean-noise':>11} "
          f"{'mean-collapse%':>14}")
    print("-" * 100)
    for v, cs in per_variant.items():
        ds = [c["metrics"]["token_jaccard"]["cohens_d"] for c in cs]
        ps = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in cs]
        n_pos = sum(1 for d in ds if d > 0)
        n_sig = sum(1 for p in ps if p < 0.05)
        deltas = [c["metrics"]["token_jaccard"]["delta_median"] for c in cs]
        stbl = [c["metrics"]["token_jaccard"]["stable_median"] for c in cs]
        brit = [c["metrics"]["token_jaccard"]["brittle_median"] for c in cs]
        noise = [c["metrics"]["token_jaccard"]["noise_floor_median"] for c in cs
                 if c["metrics"]["token_jaccard"].get("noise_floor_median") is not None]
        collapse = [c["brittle_collapse_rate"] for c in cs]
        # fmean raises StatisticsError on an empty list, and every cell of a
        # variant may legitimately lack a noise floor — print NaN instead.
        noise_mean = statistics.fmean(noise) if noise else float("nan")
        print(f"{SHORT[v]:<6} {len(cs):>7} {n_pos:>6} {n_sig:>8} "
              f"{statistics.median(ds):>+7.2f} {statistics.fmean(ds):>+7.2f} "
              f"{statistics.fmean(deltas):>+9.4f} "
              f"{statistics.fmean(stbl):>10.3f} {statistics.fmean(brit):>10.3f} "
              f"{noise_mean:>11.3f} "
              f"{statistics.fmean(collapse)*100:>13.1f}%")

    # Variant-aggregate (across all models, n-weighted)
    print("\n" + "=" * 90)
    print("ALL CELLS (18 models × 4 surface variants)")
    print("=" * 90)
    all_d = [c["metrics"]["token_jaccard"]["cohens_d"] for c in cells]
    all_p = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in cells]
    print(f" cells: {len(cells)}")
    print(f" direction-positive: {sum(1 for d in all_d if d>0)}/{len(cells)}")
    print(f" p<0.05: {sum(1 for p in all_p if p<0.05)}/{len(cells)}")
    print(f" p<0.001: {sum(1 for p in all_p if p<0.001)}/{len(cells)}")
    print(f" p<1e-6: {sum(1 for p in all_p if p<1e-6)}/{len(cells)}")
    print(f" Cohen's d median: {statistics.median(all_d):+.3f}")
    print(f" Cohen's d mean: {statistics.fmean(all_d):+.3f}")
    print(f" Cohen's d range: [{min(all_d):+.2f}, {max(all_d):+.2f}]")

    # Per-model aggregate (averaged across 4 surface variants)
    per_model = defaultdict(list)
    for c in cells:
        per_model[c["model"]].append(c)
    print("\n" + "=" * 90)
    print("PER MODEL (averaged across 4 surface variants)")
    print("=" * 90)
    print(f"\n{'Model':<25} {'mean-d':>7} {'mean-stbl':>10} {'mean-brit':>10} "
          f"{'mean-coll%':>11} {'min-p':>9}")
    print("-" * 80)
    rows = []
    for m, cs in per_model.items():
        if not cs:  # defensive; defaultdict lists are never empty
            continue
        d = statistics.fmean(c["metrics"]["token_jaccard"]["cohens_d"] for c in cs)
        s = statistics.fmean(c["metrics"]["token_jaccard"]["stable_median"] for c in cs)
        b = statistics.fmean(c["metrics"]["token_jaccard"]["brittle_median"] for c in cs)
        col = statistics.fmean(c["brittle_collapse_rate"] for c in cs) * 100
        mp = min(c["metrics"]["token_jaccard"]["p_two_sided"] for c in cs)
        rows.append((m, d, s, b, col, mp))
    # Sort by mean Cohen's d, largest first.
    for r in sorted(rows, key=lambda r: -r[1]):
        print(f"{r[0]:<25} {r[1]:>+7.2f} {r[2]:>10.3f} {r[3]:>10.3f} {r[4]:>10.1f}% {r[5]:>9.1e}")


if __name__ == "__main__":
    main()
diff --git a/analysis/balance_diff.py b/analysis/balance_diff.py
new file mode 100644
index 0000000..f420d46
--- /dev/null
+++ b/analysis/balance_diff.py
@@ -0,0 +1,109 @@
+"""Compare brace/paren/bracket balance BEFORE vs AFTER cleaning to check
+whether the cleaner introduced any new imbalance."""
+from __future__ import annotations
+import json
+import tarfile
+from pathlib import Path
+from collections import Counter
+
+CURRENT_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
+BACKUP_TAR = sorted(Path("/home/yurenh2/gap/analysis/dataset_backups").glob(
+ "putnam-bench-anon_dataset_*.tar.gz"))[-1]
+
+
def all_text(d: dict) -> str:
    """Concatenate every question/solution string in a problem record.

    Collects the top-level question and solution, then the question and
    solution of each dict-valued variant, and joins them with newlines.
    Missing or None fields contribute an empty string.
    """
    fields = ("question", "solution")
    pieces = [d.get(field) or "" for field in fields]
    for entry in (d.get("variants") or {}).values():
        # Non-dict variant payloads carry no text and are skipped.
        if isinstance(entry, dict):
            pieces.extend(entry.get(field) or "" for field in fields)
    return "\n".join(pieces)
+
+
def balance(text: str):
    """Return (brace, paren, bracket) open-minus-close counts for *text*.

    A perfectly balanced text yields (0, 0, 0); positive entries mean
    unmatched openers, negative entries unmatched closers.
    """
    counts = Counter(text)
    return tuple(counts[opener] - counts[closer]
                 for opener, closer in (("{", "}"), ("(", ")"), ("[", "]")))
+
+
def main():
    """Diff per-problem bracket balance between the newest backup and current dataset.

    Loads every problem's concatenated text from the most recent backup
    tarball and from the live dataset directory, computes (brace, paren,
    bracket) balances for each, and prints a summary of cleaner-introduced
    vs cleaner-fixed imbalances.
    """
    print("Loading backup ...")
    backup = {}
    with tarfile.open(BACKUP_TAR, "r:gz") as tar:
        for member in tar.getmembers():
            if not member.isfile() or not member.name.endswith(".json"):
                continue
            f = tar.extractfile(member)
            if not f:
                continue
            # Close the member stream deterministically (was left open).
            with f:
                d = json.load(f)
            backup[d.get("index")] = all_text(d)
    print(f" loaded {len(backup)} backup problems")

    print("Loading current ...")
    current = {}
    for f in sorted(CURRENT_DIR.glob("*.json")):
        # read_text avoids the unclosed handle from json.load(open(f)).
        d = json.loads(f.read_text())
        current[d.get("index")] = all_text(d)
    print(f" loaded {len(current)} current problems")

    # Per-file balance diff.  NOTE(review): iteration is over backup indices
    # only — problems that exist only in the current dataset are never
    # compared; confirm that is intended.
    introduced_imbalance = []
    fixed_imbalance = []
    same_imbalance = 0
    same_balanced = 0

    n_brace_changed = 0
    n_paren_changed = 0
    n_brack_changed = 0

    for idx in sorted(backup):
        b_before = balance(backup[idx])
        # A problem missing from the current dir compares as "" (balanced).
        b_after = balance(current.get(idx, ""))
        was_bal = b_before == (0, 0, 0)
        is_bal = b_after == (0, 0, 0)
        if b_before != b_after:
            if was_bal and not is_bal:
                introduced_imbalance.append((idx, b_before, b_after))
            elif not was_bal and is_bal:
                fixed_imbalance.append((idx, b_before, b_after))
        else:
            if is_bal:
                same_balanced += 1
            else:
                same_imbalance += 1
        if b_before[0] != b_after[0]: n_brace_changed += 1
        if b_before[1] != b_after[1]: n_paren_changed += 1
        if b_before[2] != b_after[2]: n_brack_changed += 1

    print(f"\n=== Per-file balance change summary ===")
    print(f" Files with no change in any balance:")
    print(f" balanced both before and after: {same_balanced}")
    print(f" imbalanced before and after (same imbalance): {same_imbalance}")
    print(f" Files where cleaner INTRODUCED new imbalance: "
          f"{len(introduced_imbalance)}")
    print(f" Files where cleaner FIXED prior imbalance: {len(fixed_imbalance)}")
    print()
    print(f" Files where {{ balance changed: {n_brace_changed}")
    print(f" Files where ( balance changed: {n_paren_changed}")
    print(f" Files where [ balance changed: {n_brack_changed}")

    if introduced_imbalance:
        print(f"\n!!! Cleaner-introduced imbalances ({len(introduced_imbalance)}):")
        for idx, before, after in introduced_imbalance[:10]:
            print(f" {idx}: before={before}, after={after}")
    else:
        print("\n ✓ No cleaner-introduced imbalances found.")

    if fixed_imbalance:
        print(f"\n Cleaner-fixed imbalances (top 10):")
        for idx, before, after in fixed_imbalance[:10]:
            print(f" {idx}: before={before}, after={after}")


if __name__ == "__main__":
    main()
diff --git a/analysis/cross_model_agreement.py b/analysis/cross_model_agreement.py
new file mode 100644
index 0000000..fb9a571
--- /dev/null
+++ b/analysis/cross_model_agreement.py
@@ -0,0 +1,180 @@
+"""Cross-model agreement analysis: which problems are universally hard?
+
+For each (variant, problem) cell, count how many models fail (correct=False).
+Identify "universally hard" problems (failed by ≥80% of models on the variant)
+and "universally easy" (correct by ≥80% on the variant). Then check whether
+the universally hard *flip set* is dominated by certain topics, problem types,
+or years.
+
+Outputs:
+- Per-variant histogram of failure counts
+- "Universal flip" cases: original correct by ≥80% of models, variant wrong by ≥80%
+- These are the cleanest signals of variant-induced fragility because they
+ rule out problem-specific quirks
+"""
+from __future__ import annotations
+import json
+import sys
+from pathlib import Path
+from collections import defaultdict, Counter
+
+THIS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(THIS_DIR))
+from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS
+
+DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
+
+
def load_metadata():
    """Load problem-level metadata: type, tag, difficulty, year.

    Returns {problem_index: {"type", "tag", "difficulty", "problem_type",
    "year"}} where "year" is parsed from the "YYYY-..." index prefix
    (None when the index is missing or has no dash).
    """
    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        # read_text avoids the unclosed handle from json.load(open(f)).
        d = json.loads(f.read_text())
        idx = d.get("index")
        out[idx] = {
            "type": d.get("type"),
            "tag": d.get("tag"),
            "difficulty": d.get("difficulty"),
            "problem_type": d.get("problem_type"),
            "year": int(idx.split("-")[0]) if idx and "-" in idx else None,
        }
    return out
+
+
def main():
    """Run the cross-model agreement analysis and print all tables.

    Builds a (variant, problem) -> {model: correct} table across every
    model's result files, then reports per-variant histograms, universal
    flip cases, their topic breakdown, and a topic × variant accuracy table.
    Saves the universal flips to analysis/universal_flips.json.
    """
    # Hoisted: the original did "import statistics" inside the histogram
    # loop, which re-imports each iteration and leaves the name unbound
    # (NameError in the topic table below) if that loop body never runs.
    import statistics

    base = RESULTS_DIR
    models = sorted([d.name for d in base.iterdir() if d.is_dir()])
    print(f"Loading {len(models)} models ...")
    metadata = load_metadata()

    # correct_table[(variant, idx)][model] = bool
    correct_table = defaultdict(dict)
    for m in models:
        mdir = base / m
        for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
            vp = find_variant_file(mdir, v)
            if not vp:
                continue
            for p in load_problems(vp):
                idx = p.get("index")
                correct = p.get("correct")
                if idx is None or correct is None:
                    continue
                correct_table[(v, idx)][m] = correct

    print(f"Loaded {len(correct_table)} (variant, problem) cells.\n")

    # Per-variant histogram of correct counts (out of N models)
    print("=== HISTOGRAM OF CORRECT-COUNT ACROSS MODELS ===")
    print("(How many models get each problem right per variant)\n")
    print(f"{'Variant':<24} {'mean correct/N':>16} {'median':>9} {'#unanimous-fail':>17} {'#unanimous-pass':>17}")
    print("-" * 90)
    for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
        cells = [d for (vv, idx), d in correct_table.items() if vv == v]
        if not cells:
            continue
        counts = [sum(1 for ok in cell.values() if ok) / len(cell) for cell in cells]
        # "Unanimous" requires at least 3 models to have answered the cell.
        unanimous_fail = sum(1 for cell in cells if not any(cell.values()) and len(cell) >= 3)
        unanimous_pass = sum(1 for cell in cells if all(cell.values()) and len(cell) >= 3)
        print(f"{v:<24} {statistics.fmean(counts)*100:>14.1f}% {statistics.median(counts)*100:>7.1f}% "
              f"{unanimous_fail:>17} {unanimous_pass:>17}")

    # Universal flip cases: original correct by ≥80% of models, variant wrong by ≥80%
    print(f"\n\n=== UNIVERSAL FLIP CASES (orig ≥80% correct, variant ≥80% wrong) ===\n")
    print("These are the cleanest signals of variant-induced fragility.\n")
    print(f"{'Variant':<24} {'# universal flips':>20}")
    print("-" * 50)
    universal_flips = defaultdict(list)
    # Hoisted out of the variant loop: this set is loop-invariant.
    orig_indices = {ii for (vv, ii) in correct_table if vv == "original"}
    for v in SURFACE_VARIANTS + ["kernel_variant"]:
        for idx in orig_indices:
            orig_cell = correct_table.get(("original", idx), {})
            var_cell = correct_table.get((v, idx), {})
            common = set(orig_cell) & set(var_cell)
            if len(common) < 5:  # require enough shared models for a rate
                continue
            orig_rate = sum(1 for m in common if orig_cell[m]) / len(common)
            var_rate = sum(1 for m in common if var_cell[m]) / len(common)
            if orig_rate >= 0.80 and var_rate <= 0.20:
                universal_flips[v].append((idx, orig_rate, var_rate))
        print(f"{v:<24} {len(universal_flips[v]):>20}")

    # Topic / problem_type / difficulty / year breakdown for universal flips
    print(f"\n\n=== TOPIC BREAKDOWN OF UNIVERSAL FLIPS ===\n")
    for v in SURFACE_VARIANTS + ["kernel_variant"]:
        if not universal_flips[v]:
            continue
        print(f"--- {v} ({len(universal_flips[v])} universal flips) ---")
        topics = Counter()
        ptypes = Counter()
        difficulties = Counter()
        years = Counter()
        for idx, _, _ in universal_flips[v]:
            md = metadata.get(idx, {})
            tag = md.get("tag")
            # tag may be a list (multi-tag) or a string
            if isinstance(tag, list):
                for t in tag:
                    topics[t] += 1
            elif tag:
                topics[tag] += 1
            else:
                topics["?"] += 1
            ptypes[md.get("problem_type") or "?"] += 1
            diff = md.get("difficulty")
            if isinstance(diff, list):
                diff = diff[0] if diff else "?"
            difficulties[diff or "?"] += 1
            year = md.get("year")
            if year:
                # Bin years by decade
                decade = (year // 10) * 10
                years[f"{decade}s"] += 1
        print(f" topics: {dict(topics.most_common(8))}")
        print(f" problem_type:{dict(ptypes)}")
        print(f" difficulties:{dict(difficulties.most_common(6))}")
        print(f" decades: {dict(sorted(years.items()))}")
        print()

    # Save universal flips for later analysis
    out = {v: [{"index": idx, "orig_rate": o, "var_rate": vr}
               for (idx, o, vr) in flips]
           for v, flips in universal_flips.items()}
    with open(THIS_DIR / "universal_flips.json", "w") as fh:
        json.dump(out, fh, indent=2)
    print(f"\nSaved -> analysis/universal_flips.json")

    # Topic-stratified analysis: failure rate per topic per variant
    print(f"\n\n=== ACCURACY BY TOPIC × VARIANT (mean across models) ===\n")
    by_topic_variant = defaultdict(lambda: defaultdict(list))
    for (v, idx), cell in correct_table.items():
        md = metadata.get(idx, {})
        tag = md.get("tag")
        if not tag or len(cell) < 5:
            continue
        # If multiple tags, attribute the same rate to each — keeps it simple
        topics_for_problem = tag if isinstance(tag, list) else [tag]
        rate = sum(1 for ok in cell.values() if ok) / len(cell)
        for t in topics_for_problem:
            by_topic_variant[t][v].append(rate)

    topics_to_show = ["ALG", "ANA", "NT", "COMB", "GEO"]
    print(f"{'Topic':<8}", end="")
    for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
        short = {"original":"orig","descriptive_long":"DL","descriptive_long_confusing":"DLC",
                 "descriptive_long_misleading":"DLM","garbled_string":"GS","kernel_variant":"KV"}[v]
        print(f" {short:>5}", end="")
    print(" Δ_orig→KV")
    print("-" * 70)
    for t in topics_to_show:
        if t not in by_topic_variant:
            continue
        row = by_topic_variant[t]
        if "original" not in row:
            continue
        orig_mean = statistics.fmean(row["original"]) * 100
        for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
            if v in row:
                m = statistics.fmean(row[v]) * 100
                print(f" {m:>4.1f}%", end="")
            else:
                print(f" {'-':>5}", end="")
        # NOTE(review): a topic with no KV data prints a delta of -orig_mean
        # (fmean([0])) rather than "-"; confirm that is acceptable.
        kv_mean = statistics.fmean(row.get("kernel_variant", [0])) * 100
        print(f" {kv_mean - orig_mean:+5.1f}pp")


if __name__ == "__main__":
    main()
diff --git a/analysis/kv_overlap.py b/analysis/kv_overlap.py
new file mode 100644
index 0000000..137e61f
--- /dev/null
+++ b/analysis/kv_overlap.py
@@ -0,0 +1,332 @@
+"""Kernel-variant structural-overlap analysis (label-free).
+
+Unlike surface variants, kernel variants change the math, so we cannot use the
+model's own original-correct trajectory as a reference. Instead we use the
+dataset's canonical kernel-variant solution as the reference.
+
+Hypothesis: stable (correct on KV) trajectories have higher structural overlap
+with the canonical KV solution than brittle (wrong on KV) trajectories.
+
+For comparability we also recompute the surface analyses using the same
+'overlap with canonical solution' metric, so we can compare apples-to-apples
+the magnitude of stable-vs-brittle gap between surface and kernel.
+"""
+from __future__ import annotations
+import json
+import os
+import statistics
+from pathlib import Path
+from collections import defaultdict
+from typing import Optional
+
+# Reuse helpers from the sibling module
+import sys
+sys.path.insert(0, str(Path(__file__).parent))
+from structural_overlap import (
+ DATASET_DIR, RESULTS_DIR,
+ load_problems, find_variant_file,
+ canonicalize_text, normalize_whitespace,
+ tokens, bigrams, jaccard, extract_math_blocks,
+ metric_token_jaccard, metric_bigram_jaccard,
+ metric_directional_coverage, metric_equation_jaccard,
+ mann_whitney_u, bootstrap_ci_cohens_d,
+ is_collapse, COLLAPSE_MIN_CHARS, COLLAPSE_RATIO,
+ SURFACE_VARIANTS,
+)
+
+
def load_dataset_variant_solutions() -> dict:
    """Returns: {problem_index: {variant_name: canonical_solution_text}}.

    Includes 'original' (from the top-level "solution" field) plus every
    variant that carries a dict payload.  The problem_type is stashed under
    the "_problem_type" key for downstream filtering.
    """
    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        # read_text avoids the unclosed handle from json.load(open(f)).
        d = json.loads(f.read_text())
        idx = d.get("index")
        cell = {"original": d.get("solution") or "",
                "_problem_type": d.get("problem_type")}
        for v, vd in d.get("variants", {}).items():
            if isinstance(vd, dict):
                cell[v] = vd.get("solution") or ""
        out[idx] = cell
    return out
+
+
def load_dataset_maps() -> dict:
    """Mirrors structural_overlap.load_dataset_maps but localized for safety.

    Returns {problem_index: {variant_name: {old: new}}} rename maps for the
    surface variants; "map" entries may be serialized dict strings or dicts.
    """
    import ast  # local: only needed here, keeps module-level deps unchanged

    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        # read_text avoids the unclosed handle from json.load(open(f)).
        d = json.loads(f.read_text())
        idx = d.get("index")
        variants = d.get("variants", {})
        cell = {}
        for v in SURFACE_VARIANTS:
            vd = variants.get(v, {})
            mp_str = vd.get("map")
            if isinstance(mp_str, str):
                try:
                    # literal_eval accepts only Python literals — unlike the
                    # original eval() it cannot execute arbitrary expressions
                    # embedded in the dataset.
                    mp = ast.literal_eval(mp_str)
                    if isinstance(mp, dict):
                        cell[v] = {str(k): str(val) for k, val in mp.items()}
                except (ValueError, SyntaxError):
                    pass  # malformed map string: skip, as before
            elif isinstance(mp_str, dict):
                cell[v] = {str(k): str(val) for k, val in mp_str.items()}
        out[idx] = cell
    return out
+
+
+# ---------- Cell analyzer ----------
+
def analyze_kv_cell(model_name: str, model_dir: Path,
                    canonical_solutions: dict) -> Optional[dict]:
    """Compare model's KV trajectory to dataset canonical KV solution.

    No canonicalization (no rename map for KV — variables match by
    construction).  Restricted to problems the model solves in original
    form; each KV attempt is classified as stable/brittle × drift/collapse.

    Returns a result-cell dict with counts and per-metric statistics, or
    None when result files are missing or either drift group is empty.
    """
    orig_path = find_variant_file(model_dir, "original")
    var_path = find_variant_file(model_dir, "kernel_variant")
    if not orig_path or not var_path:
        return None
    orig_by = {p["index"]: p for p in load_problems(orig_path)}
    var_by = {p["index"]: p for p in load_problems(var_path)}

    pairs_stable_drift = []
    pairs_brittle_drift = []
    n_brittle_collapse = 0
    n_stable_collapse = 0

    for idx in set(orig_by) & set(var_by):
        po, pv = orig_by[idx], var_by[idx]
        if po.get("correct") is not True:
            continue  # Restrict to "model already gets the original"
        var_correct = pv.get("correct")
        if var_correct is None:
            continue
        var_text = (pv.get("solve") or {}).get("solution") or ""
        if not var_text:
            continue
        canon_kv = canonical_solutions.get(idx, {}).get("kernel_variant", "")
        if not canon_kv or len(canon_kv) < 200:
            continue  # canonical solution too short to anchor against
        # Collapse rule: variant text shorter than COLLAPSE_MIN_CHARS or
        # shorter than COLLAPSE_RATIO of the canonical solution.
        collapse = (len(var_text) < COLLAPSE_MIN_CHARS or
                    len(var_text) < COLLAPSE_RATIO * len(canon_kv))
        sample = {"index": idx, "var_text": var_text, "canon": canon_kv}
        if var_correct is True:
            if collapse:
                n_stable_collapse += 1
            else:
                pairs_stable_drift.append(sample)
        else:
            if collapse:
                n_brittle_collapse += 1
            else:
                pairs_brittle_drift.append(sample)

    if not pairs_stable_drift or not pairs_brittle_drift:
        return None

    metrics = {
        "token_jaccard": metric_token_jaccard,
        "bigram_jaccard": metric_bigram_jaccard,
        "equation_jaccard": metric_equation_jaccard,
        "directional_coverage": metric_directional_coverage,
    }

    out = {
        "model": model_name,
        "variant": "kernel_variant",
        "n_stable_drift": len(pairs_stable_drift),
        "n_brittle_drift": len(pairs_brittle_drift),
        "n_brittle_collapse": n_brittle_collapse,
        "n_stable_collapse": n_stable_collapse,
        "brittle_collapse_rate": n_brittle_collapse /
            max(1, n_brittle_collapse + len(pairs_brittle_drift)),
        "metrics": {},
    }
    # Remember the token_jaccard samples so the headline bootstrap below can
    # reuse them (the original recomputed the metric over all pairs twice).
    headline_vals = None
    for mname, mfn in metrics.items():
        s_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_stable_drift]
        b_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_brittle_drift]
        if mname == "token_jaccard":
            headline_vals = (s_vals, b_vals)
        U, p = mann_whitney_u(s_vals, b_vals)
        sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals)
        ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0
        bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0
        # Pooled standard deviation for Cohen's d.
        pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2)
                  / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5
        d = (sm - bm) / pooled if pooled > 0 else 0.0
        out["metrics"][mname] = {
            "stable_median": statistics.median(s_vals),
            "stable_mean": sm,
            "brittle_median": statistics.median(b_vals),
            "brittle_mean": bm,
            "delta_median": statistics.median(s_vals) - statistics.median(b_vals),
            "cohens_d": d,
            "U": U,
            "p_two_sided": p,
        }
    # Headline bootstrap CI on token_jaccard (values reused, not recomputed).
    d_lo, d_hi = bootstrap_ci_cohens_d(*headline_vals, n_iter=400)
    out["metrics"]["token_jaccard"]["cohens_d_ci"] = [d_lo, d_hi]
    return out
+
+
+# ---------- Surface re-analysis with canonical reference ----------
+
def analyze_surface_cell_against_canonical(model_name: str, variant: str,
                                           model_dir: Path,
                                           canonical_solutions: dict) -> Optional[dict]:
    """Compare model variant trajectory to dataset canonical variant solution.

    For comparability with KV. No rename canonicalization needed since both
    sides use the same variant naming.  Returns a result-cell dict, or None
    when result files are missing or either drift group is empty.
    """
    variant_path = find_variant_file(model_dir, variant)
    original_path = find_variant_file(model_dir, "original")
    if not variant_path or not original_path:
        return None
    by_index_var = {rec["index"]: rec for rec in load_problems(variant_path)}
    by_index_orig = {rec["index"]: rec for rec in load_problems(original_path)}

    stable, brittle = [], []
    collapsed_brittle = 0
    for idx in set(by_index_var):
        orig_rec = by_index_orig.get(idx)
        # Restrict to problems the model already solves in original form.
        if orig_rec is None or orig_rec.get("correct") is not True:
            continue
        var_rec = by_index_var[idx]
        verdict = var_rec.get("correct")
        if verdict is None:
            continue
        text = (var_rec.get("solve") or {}).get("solution") or ""
        if not text:
            continue
        reference = canonical_solutions.get(idx, {}).get(variant, "")
        if not reference or len(reference) < 200:
            continue
        too_short = (len(text) < COLLAPSE_MIN_CHARS or
                     len(text) < COLLAPSE_RATIO * len(reference))
        if too_short:
            if verdict is False:
                collapsed_brittle += 1
            continue
        record = {"index": idx, "var_text": text, "canon": reference}
        (stable if verdict is True else brittle).append(record)

    if not stable or not brittle:
        return None

    def _metric_stats(fn):
        # Summarize one overlap metric over the stable vs brittle groups.
        s_vals = [fn(r["var_text"], r["canon"]) for r in stable]
        b_vals = [fn(r["var_text"], r["canon"]) for r in brittle]
        U, p = mann_whitney_u(s_vals, b_vals)
        s_mean, b_mean = statistics.fmean(s_vals), statistics.fmean(b_vals)
        s_sd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0
        b_sd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0
        pooled = (((len(s_vals) - 1) * s_sd ** 2 + (len(b_vals) - 1) * b_sd ** 2)
                  / max(1, len(s_vals) + len(b_vals) - 2)) ** 0.5
        return {
            "stable_median": statistics.median(s_vals),
            "stable_mean": s_mean,
            "brittle_median": statistics.median(b_vals),
            "brittle_mean": b_mean,
            "delta_median": statistics.median(s_vals) - statistics.median(b_vals),
            "cohens_d": (s_mean - b_mean) / pooled if pooled > 0 else 0.0,
            "U": U,
            "p_two_sided": p,
        }

    return {
        "model": model_name,
        "variant": variant,
        "n_stable_drift": len(stable),
        "n_brittle_drift": len(brittle),
        "n_brittle_collapse": collapsed_brittle,
        "brittle_collapse_rate": collapsed_brittle /
            max(1, collapsed_brittle + len(brittle)),
        "metrics": {
            name: _metric_stats(fn)
            for name, fn in (
                ("token_jaccard", metric_token_jaccard),
                ("bigram_jaccard", metric_bigram_jaccard),
                ("equation_jaccard", metric_equation_jaccard),
                ("directional_coverage", metric_directional_coverage),
            )
        },
    }
+
+
def main():
    """Run KV and surface canonical-anchor analyses over all models.

    Prints per-cell tables, saves the two result lists as JSON, and prints
    an aggregate comparison of the surface vs kernel effect sizes.
    """
    print("Loading canonical solutions ...")
    canon = load_dataset_variant_solutions()
    print(f" loaded {len(canon)} problems")

    all_models = sorted([d.name for d in RESULTS_DIR.iterdir() if d.is_dir()])

    kv_results = []
    surface_results = []

    print(f"\n{'KERNEL VARIANT — variant trajectory vs canonical KV solution':<70}")
    print(f"{'Cell':<32} {'nSd':>4} {'nBd':>4} {'col%':>5} "
          f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}")
    print("-" * 90)
    for m in all_models:
        mdir = RESULTS_DIR / m
        if not mdir.exists():
            continue
        res = analyze_kv_cell(m, mdir, canon)
        if res is None:
            continue
        kv_results.append(res)
        md = res["metrics"]["token_jaccard"]
        print(f"{m+' / KV':<32} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
              f"{res['brittle_collapse_rate']*100:>4.0f}% "
              f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
              f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}")

    print(f"\n{'SURFACE VARIANT — variant trajectory vs canonical variant solution':<70}")
    print(f"{'Cell':<46} {'nSd':>4} {'nBd':>4} {'col%':>5} "
          f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}")
    print("-" * 95)
    for m in all_models:
        mdir = RESULTS_DIR / m
        if not mdir.exists():
            continue
        for v in SURFACE_VARIANTS:
            res = analyze_surface_cell_against_canonical(m, v, mdir, canon)
            if res is None:
                continue
            surface_results.append(res)
            md = res["metrics"]["token_jaccard"]
            print(f"{m+' / '+v:<46} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
                  f"{res['brittle_collapse_rate']*100:>4.0f}% "
                  f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
                  f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}")

    # Save — with-blocks close (and flush) the output handles; the original
    # json.dump(..., open(..., "w")) never closed them explicitly.
    with open("/home/yurenh2/gap/analysis/kv_overlap_results.json", "w") as fh:
        json.dump(kv_results, fh, indent=2)
    with open("/home/yurenh2/gap/analysis/surface_canonical_results.json", "w") as fh:
        json.dump(surface_results, fh, indent=2)

    # Aggregate compare
    print("\n" + "=" * 80)
    print("AGGREGATE: surface (vs canonical) vs kernel (vs canonical)")
    print("=" * 80)
    for tag, results in [("surface", surface_results), ("kernel", kv_results)]:
        ds = [c["metrics"]["token_jaccard"]["cohens_d"] for c in results]
        ps = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in results]
        col = [c["brittle_collapse_rate"] for c in results]
        if not ds:
            continue
        print(f"{tag:<8} cells={len(ds):>3} d_pos={sum(1 for d in ds if d>0):>3}/{len(ds):<3} "
              f"p<.05={sum(1 for p in ps if p<0.05):>3}/{len(ps):<3} "
              f"d_med={statistics.median(ds):+.2f} d_mean={statistics.fmean(ds):+.2f} "
              f"collapse_mean={statistics.fmean(col)*100:.1f}%")


if __name__ == "__main__":
    main()
diff --git a/analysis/make_figures.py b/analysis/make_figures.py
new file mode 100644
index 0000000..4ff598d
--- /dev/null
+++ b/analysis/make_figures.py
@@ -0,0 +1,272 @@
+"""Three rebuttal figures.
+
+Fig1 — Structural Cohen's d heatmap
+ 18 models × 5 variants (4 surface + KV).
+ Surface cells use the self-anchor metric (model's own original under
+ inverse rename). KV uses the canonical-anchor metric.
+
+Fig2 — Rescue rebound rates by variant + condition
+ Pooled across 4 models. Bar plot with Wilson 95 % CI.
+ Three bars per variant: null / canonical_T2 / own_T2 (KV: only 2).
+
+Fig3 — own_T2 vs canonical_T2 per (model, variant)
+ Scatter plot of own_T2 rebound rate vs canonical_T2 rebound rate per
+ cell, with the y=x line. Points above the diagonal: own outperforms
+ canonical (rare); below: canonical outperforms own (typical).
+"""
+from __future__ import annotations
+import json
+import math
+import statistics
+from pathlib import Path
+from collections import defaultdict
+
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
+ROOT = Path("/home/yurenh2/gap/analysis")
+FIG_DIR = ROOT / "figures"
+FIG_DIR.mkdir(parents=True, exist_ok=True)
+
+VARIANT_LABELS = {
+ "descriptive_long": "DL",
+ "descriptive_long_confusing": "DLC",
+ "descriptive_long_misleading": "DLM",
+ "garbled_string": "GS",
+ "kernel_variant": "KV",
+}
+VARIANT_ORDER_SURF = ["descriptive_long", "descriptive_long_confusing",
+ "descriptive_long_misleading", "garbled_string"]
+VARIANT_ORDER_ALL = VARIANT_ORDER_SURF + ["kernel_variant"]
+
+# ----------------------------------------------------------------------
+# Fig 1 — Structural Cohen's d heatmap
+# ----------------------------------------------------------------------
+
def fig1_structural_d_heatmap():
    """Heatmap of Cohen's d for the stable-vs-brittle structural metric.

    Surface cells: self-anchor (token Jaccard between model's variant
    trajectory and its own original-correct trajectory after canonicalization).
    Source file: structural_overlap_results.json.

    KV cells: canonical-anchor (token Jaccard between model's KV trajectory and
    the dataset's canonical KV solution).
    Source file: kv_overlap_results.json.

    Writes fig1_structural_d_heatmap.png into FIG_DIR.
    """
    surf = json.load(open(ROOT / "structural_overlap_results.json"))
    kv = json.load(open(ROOT / "kv_overlap_results.json"))

    # Build matrix: rows = models (sorted by mean d), cols = variants (DL, DLC, DLM, GS, KV)
    by_cell = {}
    for c in surf:
        by_cell[(c["model"], c["variant"])] = c["metrics"]["token_jaccard"]["cohens_d"]
    for c in kv:
        by_cell[(c["model"], "kernel_variant")] = c["metrics"]["token_jaccard"]["cohens_d"]

    models = sorted({k[0] for k in by_cell})
    # Sort by mean d across surface variants only (so KV doesn't bias the order)
    def mean_surface_d(m):
        ds = [by_cell.get((m, v)) for v in VARIANT_ORDER_SURF
              if by_cell.get((m, v)) is not None]
        return statistics.fmean(ds) if ds else 0.0
    models.sort(key=mean_surface_d, reverse=True)

    # Missing (model, variant) cells stay NaN and render as blanks.
    M = np.full((len(models), len(VARIANT_ORDER_ALL)), np.nan)
    for i, m in enumerate(models):
        for j, v in enumerate(VARIANT_ORDER_ALL):
            d = by_cell.get((m, v))
            if d is not None:
                M[i, j] = d

    fig, ax = plt.subplots(figsize=(7, 9))
    # NOTE(review): vmin=0 maps any negative d to the same colour as 0; the
    # cell annotations below still show the signed value, so information is
    # not lost — confirm this clipping is intended.
    vmin = 0.0
    vmax = 1.4
    cmap = plt.cm.viridis
    im = ax.imshow(M, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto")
    ax.set_xticks(range(len(VARIANT_ORDER_ALL)))
    ax.set_xticklabels([VARIANT_LABELS[v] for v in VARIANT_ORDER_ALL])
    ax.set_yticks(range(len(models)))
    ax.set_yticklabels(models, fontsize=9)
    # Annotate values
    for i in range(len(models)):
        for j in range(len(VARIANT_ORDER_ALL)):
            v = M[i, j]
            if not math.isnan(v):
                # White text on the dark (low) end of viridis, black on the bright end.
                color = "white" if v < 0.7 else "black"
                ax.text(j, i, f"{v:+.2f}", ha="center", va="center",
                        fontsize=8, color=color)
    # Vertical line separating surface from KV
    ax.axvline(x=3.5, color="white", linewidth=2)
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("Cohen's d (stable − brittle)\non canonicalized token Jaccard",
                   fontsize=9)
    ax.set_title("Structural overlap effect size: stable vs brittle\n"
                 "(surface = self-anchor; KV = canonical-anchor)",
                 fontsize=11)
    ax.set_xlabel("Variant family", fontsize=10)
    plt.tight_layout()
    out = FIG_DIR / "fig1_structural_d_heatmap.png"
    plt.savefig(out, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved {out}")
+
+
+# ----------------------------------------------------------------------
+# Fig 2 — Rescue rebound rates with Wilson CI
+# ----------------------------------------------------------------------
+
def wilson_ci(k: int, n: int, z: float = 1.96):
    """Wilson score interval for k successes out of n trials.

    Returns (point estimate, lower bound, upper bound), clipped to [0, 1].
    An empty cell (n == 0) yields (0, 0, 0).
    """
    if not n:
        return (0.0, 0.0, 0.0)
    phat = k / n
    z2 = z * z
    adj = 1 + z2 / n
    mid = (phat + z2 / (2 * n)) / adj
    spread = z * math.sqrt(phat * (1 - phat) / n + z2 / (4 * n * n)) / adj
    return (phat, max(0.0, mid - spread), min(1.0, mid + spread))
+
+
def fig2_rescue_rates():
    """Grouped bar plot of rebound rate per (variant, condition).

    Pools rescue_results/rescue_30.jsonl across models. One bar group per
    variant, one bar per prefix condition; error bars are 95% Wilson
    intervals. Cells absent from the data (e.g. KV/own_T2) draw as
    zero-height bars. Writes fig2_rescue_rebound.png into FIG_DIR.
    """
    rows = [json.loads(l) for l in open(ROOT / "rescue_results/rescue_30.jsonl")]

    # (variant, condition) -> k = #CORRECT rows, n = total rows
    counts = defaultdict(lambda: {"k": 0, "n": 0})
    for r in rows:
        counts[(r["variant"], r["condition"])]["n"] += 1
        if r.get("grade") == "CORRECT":
            counts[(r["variant"], r["condition"])]["k"] += 1

    conds_full = ["null", "canonical_T2", "own_T2"]
    cond_color = {"null": "#888888", "canonical_T2": "#1f77b4", "own_T2": "#d62728"}
    cond_label = {"null": "null (generic scaffold)",
                  "canonical_T2": "canonical_T2 (item-specific, expert prose)",
                  "own_T2": "own_T2 (item-specific, model's own work, renamed)"}

    fig, ax = plt.subplots(figsize=(8, 5))
    n_var = len(VARIANT_ORDER_ALL)
    width = 0.27
    x = np.arange(n_var)
    for ci, cond in enumerate(conds_full):
        ks, lows, highs, ps = [], [], [], []
        for v in VARIANT_ORDER_ALL:
            d = counts.get((v, cond))
            if d is None:
                # Missing cell: zero-height bar with no CI.
                ks.append(0); lows.append(0); highs.append(0); ps.append(0)
                continue
            p, lo, hi = wilson_ci(d["k"], d["n"])
            ps.append(p * 100)
            lows.append((p - lo) * 100)
            highs.append((hi - p) * 100)
            ks.append(d["k"])
        # Center the three condition bars around each variant tick.
        offset = (ci - 1) * width
        ax.bar(x + offset, ps, width=width, color=cond_color[cond], label=cond_label[cond],
               yerr=[lows, highs], capsize=3, error_kw={"elinewidth": 1, "ecolor": "#444444"})
        # Annotate the rate above each bar that has at least one success
        for xi, p, k in zip(x + offset, ps, ks):
            if k > 0:
                ax.text(xi, p + 0.5, f"{p:.0f}%", ha="center", va="bottom", fontsize=8)

    ax.set_xticks(x)
    ax.set_xticklabels([VARIANT_LABELS[v] for v in VARIANT_ORDER_ALL], fontsize=10)
    ax.set_ylabel("Rebound rate (%) on flip cases", fontsize=10)
    ax.set_title("Repairability rescue: rebound rate by variant and prefix condition\n"
                 "(pooled across 4 models, n ≈ 100–120 per cell, 95% Wilson CI)",
                 fontsize=11)
    ax.set_ylim(0, 60)
    ax.legend(loc="upper right", fontsize=8, framealpha=0.95)
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.set_axisbelow(True)
    plt.tight_layout()
    out = FIG_DIR / "fig2_rescue_rebound.png"
    plt.savefig(out, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved {out}")
+
+
+# ----------------------------------------------------------------------
+# Fig 3 — own_T2 vs canonical_T2 scatter
+# ----------------------------------------------------------------------
+
def fig3_own_vs_canonical_scatter():
    """Scatter of own_T2 vs canonical_T2 rebound rates per (model, variant).

    One point per surface-variant cell; colour encodes the model, marker the
    variant. Points below the y = x diagonal mean the canonical prefix
    rescued more cases than the model's own prefix. Writes
    fig3_own_vs_canonical_scatter.png into FIG_DIR.
    """
    rows = [json.loads(l) for l in open(ROOT / "rescue_results/rescue_30.jsonl")]

    # (model, variant, condition) -> k = #CORRECT rows, n = total rows
    counts = defaultdict(lambda: {"k": 0, "n": 0})
    for r in rows:
        counts[(r["model"], r["variant"], r["condition"])]["n"] += 1
        if r.get("grade") == "CORRECT":
            counts[(r["model"], r["variant"], r["condition"])]["k"] += 1

    fig, ax = plt.subplots(figsize=(7, 7))

    models_in_data = sorted({k[0] for k in counts})
    model_color = {
        "claude-sonnet-4": "#ff7f0e",
        "gemini-2.5-flash": "#2ca02c",
        "gpt-4.1-mini": "#1f77b4",
        "gpt-4o-mini": "#d62728",
    }
    var_marker = {
        "descriptive_long": "o",
        "descriptive_long_confusing": "s",
        "descriptive_long_misleading": "^",
        "garbled_string": "D",
    }

    # Diagonal
    ax.plot([0, 0.7], [0, 0.7], "k--", lw=1, alpha=0.5)
    ax.text(0.62, 0.66, "y = x", fontsize=8, alpha=0.6)

    for m in models_in_data:
        for v in VARIANT_ORDER_SURF:
            own = counts.get((m, v, "own_T2"))
            can = counts.get((m, v, "canonical_T2"))
            # Skip cells missing either condition entirely.
            if own is None or can is None or own["n"] == 0 or can["n"] == 0:
                continue
            x = can["k"] / can["n"]
            y = own["k"] / own["n"]
            ax.scatter(x, y, s=110, c=model_color.get(m, "gray"),
                       marker=var_marker[v], alpha=0.85,
                       edgecolors="black", linewidths=0.6)

    # Build legend: two separate legends (colour = model, marker = variant)
    from matplotlib.lines import Line2D
    model_handles = [Line2D([], [], marker="o", linestyle="", markersize=9,
                            markerfacecolor=c, markeredgecolor="black",
                            markeredgewidth=0.6, label=m)
                     for m, c in model_color.items() if m in models_in_data]
    variant_handles = [Line2D([], [], marker=mk, linestyle="", markersize=9,
                              markerfacecolor="lightgray", markeredgecolor="black",
                              markeredgewidth=0.6, label=VARIANT_LABELS[v])
                       for v, mk in var_marker.items()]
    leg1 = ax.legend(handles=model_handles, loc="upper left", title="Model",
                     fontsize=8, title_fontsize=9, framealpha=0.95)
    # add_artist keeps the first legend alive when the second ax.legend call
    # would otherwise replace it.
    ax.add_artist(leg1)
    ax.legend(handles=variant_handles, loc="lower right", title="Variant",
              fontsize=8, title_fontsize=9, framealpha=0.95)

    ax.set_xlim(0, 0.7)
    ax.set_ylim(0, 0.7)
    ax.set_xlabel("canonical_T2 rebound rate", fontsize=10)
    ax.set_ylabel("own_T2 rebound rate", fontsize=10)
    ax.set_title("Per-cell rescue rates: model's own prefix vs canonical prefix\n"
                 "(below diagonal = canonical wins; gpt-4o-mini is the only family above)",
                 fontsize=11)
    ax.grid(linestyle="--", alpha=0.4)
    ax.set_axisbelow(True)
    plt.tight_layout()
    out = FIG_DIR / "fig3_own_vs_canonical_scatter.png"
    plt.savefig(out, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved {out}")
+
+
def main():
    """Generate the three rebuttal figures and report the output directory."""
    for build_figure in (fig1_structural_d_heatmap,
                         fig2_rescue_rates,
                         fig3_own_vs_canonical_scatter):
        build_figure()
    print("\nAll figures written to:", FIG_DIR)
+
+
# Script entry point: build all figures when executed directly.
if __name__ == "__main__":
    main()
diff --git a/analysis/normalization_analysis.py b/analysis/normalization_analysis.py
new file mode 100644
index 0000000..8fb4f48
--- /dev/null
+++ b/analysis/normalization_analysis.py
@@ -0,0 +1,189 @@
+"""Quantify spontaneous variant->canonical name normalization in own_T2 outputs.
+
+For each own_T2 case, check whether the model's student_solution preserves the
+variant variable names from its prefix or normalizes them back to the canonical
+names from the dataset's rename map.
+
+For each variant variable name in the rename map:
+- count its occurrences in the prefix (as injected)
+- count its occurrences in the model's student_solution
+- count occurrences of the corresponding CANONICAL name in the student_solution
+
+If the model preserves variant naming: variant_name count in solution should be
+proportionally similar to the prefix count.
+If the model normalizes back: canonical_name count in solution should rise while
+variant_name count drops.
+"""
+from __future__ import annotations
+import json
+import re
+import sys
+from pathlib import Path
+from collections import defaultdict
+import statistics
+
+THIS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(THIS_DIR))
+from rescue_runner import load_dataset_full, find_flip_cases, build_case_prompts
+
+PILOT_PATH = Path("/home/yurenh2/gap/analysis/rescue_results/rescue_5.jsonl")
+
+
def count_word(text: str, word: str) -> int:
    """Count whole-word occurrences of `word` in `text`.

    Word characters are [A-Za-z0-9_]; a match must not be flanked by them,
    so identifier fragments (e.g. "foo" inside "foobar") do not count.
    Empty `text` or `word` yields 0.
    """
    if not (text and word):
        return 0
    pattern = r"(?<![A-Za-z0-9_])" + re.escape(word) + r"(?![A-Za-z0-9_])"
    return sum(1 for _ in re.finditer(pattern, text))
+
+
def analyze_one(row: dict, ds_cell: dict) -> dict:
    """Compute variant-name preservation stats for one own_T2 row.

    Reconstructs the exact "PARTIAL WORK" prefix the model was given, then
    for every renamed variable counts how often the variant name survives in
    the model's student_solution versus being normalized back to the
    canonical name.

    Returns {} when the case cannot be analyzed: no rename map for the
    variant, empty student solution, no matching flip case, or a prefix
    that cannot be reconstructed from the prompt.
    """
    variant = row["variant"]
    var_info = ds_cell["variants"].get(variant, {})
    rmap = var_info.get("map") or {}
    if not rmap:
        return {}
    student = row.get("student_solution") or ""
    if not student:
        return {}

    # Recover the original solution from results_new via find_flip_cases so
    # the prompt can be rebuilt exactly as the model saw it.
    cases = find_flip_cases(row["model"], variant, 100)
    matched = next((c for c in cases if c["index"] == row["index"]), None)
    if matched is None:
        return {}
    prompts = build_case_prompts(matched, variant, ds_cell)
    own_prompt = prompts.get("own_T2", "")
    if "PARTIAL WORK" not in own_prompt:
        return {}
    # Extract just the partial-work text between the section header and the
    # trailing instruction boilerplate.
    section = own_prompt.split("PARTIAL WORK")[1].split("Provide a complete")[0]
    section = section.split("(to copy verbatim")[1] if "(to copy verbatim" in section else section
    section = section.split("):", 1)[1] if "):" in section else section
    prefix = section.strip()

    # For each variant variable, count occurrences in prefix and in student
    per_var = {}
    for canon_name, var_name in rmap.items():
        if not var_name:
            continue
        prefix_v = count_word(prefix, var_name)
        student_v = count_word(student, var_name)
        student_c = count_word(student, canon_name)
        # Only meaningful if the variant name actually appeared in the prefix
        if prefix_v == 0:
            continue
        per_var[var_name] = {
            "canon_name": canon_name,
            "prefix_count_variant": prefix_v,
            "student_count_variant": student_v,
            "student_count_canonical": student_c,
            # Preservation ratio: how much of the variant naming survived
            # capped to 1.0 (model may use the variable many more times in
            # its continuation, which inflates the count)
            "preservation_ratio": min(1.0, student_v / max(1, prefix_v)),
            "normalization_ratio": min(1.0, student_c / max(1, prefix_v)),
        }
    if not per_var:
        return {}
    # Aggregate per case: median preservation
    pres_vals = [v["preservation_ratio"] for v in per_var.values()]
    norm_vals = [v["normalization_ratio"] for v in per_var.values()]
    return {
        "model": row["model"],
        "variant": variant,
        "index": row["index"],
        "grade": row.get("grade"),
        "n_vars_in_prefix": len(per_var),
        "median_preservation": statistics.median(pres_vals),
        "median_normalization": statistics.median(norm_vals),
        "mean_preservation": statistics.fmean(pres_vals),
        "mean_normalization": statistics.fmean(norm_vals),
        "per_var": per_var,
    }
+
+
def main():
    """Measure spontaneous renaming in own_T2 outputs and print summaries.

    Loads the pilot rescue rows, analyzes every own_T2 row, and prints
    preservation/normalization aggregates by variant, by model, and split by
    rebound outcome. Detailed per-case results (minus the per-variable
    breakdown) are saved to normalization_results.json.

    Fix vs previous version: the rebound-outcome section no longer crashes
    with StatisticsError when every analyzed case shares a single grade.
    """
    print("Loading dataset ...")
    ds = load_dataset_full()
    print(f"Loaded {len(ds)} problems")
    print(f"\nLoading pilot rows from {PILOT_PATH} ...")
    rows = [json.loads(l) for l in open(PILOT_PATH)]
    own_rows = [r for r in rows if r["condition"] == "own_T2"]
    print(f"  total rows: {len(rows)}, own_T2 rows: {len(own_rows)}")

    analyses = []
    skipped = 0
    for r in own_rows:
        ds_cell = ds.get(r["index"])
        if ds_cell is None:
            skipped += 1
            continue
        a = analyze_one(r, ds_cell)
        if a:
            analyses.append(a)
        else:
            skipped += 1
    print(f"  analyzed: {len(analyses)}, skipped: {skipped}")

    # Aggregate by variant
    print("\n=== SPONTANEOUS NORMALIZATION (own_T2 condition only) ===\n")
    print("Per case: median across variant variables of preservation ratio")
    print("(higher = more variant naming preserved; lower = normalized back to canonical)")
    print()
    print(f"{'Variant':<32} {'n':>4} {'median_pres':>12} {'mean_pres':>10} "
          f"{'median_norm':>12} {'mean_norm':>10}")
    print("-" * 90)
    by_variant = defaultdict(list)
    for a in analyses:
        by_variant[a["variant"]].append(a)
    for v in sorted(by_variant):
        cs = by_variant[v]
        mp_vals = [c["median_preservation"] for c in cs]
        mn_vals = [c["median_normalization"] for c in cs]
        print(f"{v:<32} {len(cs):>4} "
              f"{statistics.median(mp_vals):>12.3f} {statistics.fmean(mp_vals):>10.3f} "
              f"{statistics.median(mn_vals):>12.3f} {statistics.fmean(mn_vals):>10.3f}")

    # Aggregate by model
    print(f"\n{'Model':<22} {'n':>4} {'median_pres':>12} {'mean_pres':>10} "
          f"{'median_norm':>12} {'mean_norm':>10}")
    print("-" * 80)
    by_model = defaultdict(list)
    for a in analyses:
        by_model[a["model"]].append(a)
    for m in sorted(by_model):
        cs = by_model[m]
        mp_vals = [c["median_preservation"] for c in cs]
        mn_vals = [c["median_normalization"] for c in cs]
        print(f"{m:<22} {len(cs):>4} "
              f"{statistics.median(mp_vals):>12.3f} {statistics.fmean(mp_vals):>10.3f} "
              f"{statistics.median(mn_vals):>12.3f} {statistics.fmean(mn_vals):>10.3f}")

    # Effect of normalization on rebound: do cases that normalized more often FAIL?
    print("\n=== RELATION TO REBOUND ===")
    for label in ("CORRECT", "INCORRECT"):
        vals = [a["median_preservation"] for a in analyses if a["grade"] == label]
        # Guard: statistics.median/fmean raise StatisticsError on empty input,
        # which happens when no analyzed case carries this grade.
        if vals:
            print(f"  median_preservation among rebound {label} (n={len(vals)}): "
                  f"median={statistics.median(vals):.3f} mean={statistics.fmean(vals):.3f}")
        else:
            print(f"  median_preservation among rebound {label}: no cases")

    # Save detailed results
    out = Path("/home/yurenh2/gap/analysis/normalization_results.json")
    json.dump([{k: v for k, v in a.items() if k != "per_var"} for a in analyses],
              open(out, "w"), indent=2)
    print(f"\nSaved -> {out}")
+
+
# Script entry point: run the normalization analysis when executed directly.
if __name__ == "__main__":
    main()
diff --git a/analysis/rescue_analyze.py b/analysis/rescue_analyze.py
new file mode 100644
index 0000000..5fe97b6
--- /dev/null
+++ b/analysis/rescue_analyze.py
@@ -0,0 +1,161 @@
+"""Analyze full rescue results: per-cell rebound rates, Wilson CIs, McNemar."""
+from __future__ import annotations
+import json
+import math
+import statistics
+from collections import defaultdict
+from pathlib import Path
+
+PATH = Path("/home/yurenh2/gap/analysis/rescue_results/rescue_30.jsonl")
+
+
def wilson_ci(k: int, n: int, z: float = 1.96) -> tuple:
    """Return (rate, lo, hi): the k/n point estimate with Wilson score bounds."""
    if n == 0:
        return (0.0, 0.0, 0.0)
    rate = k / n
    zz = z * z
    scale = 1 + zz / n
    centre = (rate + zz / (2 * n)) / scale
    margin = z * math.sqrt(rate * (1 - rate) / n + zz / (4 * n * n)) / scale
    lo = max(0.0, centre - margin)
    hi = min(1.0, centre + margin)
    return (rate, lo, hi)
+
+
def mcnemar_p(b: int, c: int) -> float:
    """Two-sided exact McNemar p-value via the binomial distribution.

    b and c are the discordant-pair counts (A-correct/B-wrong and the
    reverse). Tests H0: b == c by doubling the lower tail of
    Bin(b + c, 0.5) at min(b, c), capped at 1. Returns 1.0 when there are
    no discordant pairs at all.
    """
    n = b + c
    if n == 0:
        return 1.0
    k = min(b, c)
    tail = sum(math.comb(n, i) for i in range(k + 1)) * (0.5 ** n)
    return min(1.0, 2 * tail)
+
+
def _discordant_counts(case_grades, model, variant, cond_a, cond_b):
    """Tally the 2x2 McNemar table for one (model, variant) cell.

    Returns (b, c, both_correct, both_incorrect) where b = cond_a CORRECT
    while cond_b INCORRECT and c = the reverse. Cases missing a grade for
    either condition are skipped; non-CORRECT/INCORRECT grades are ignored.
    """
    b = c = both_corr = both_inc = 0
    for (m, v, _idx), grds in case_grades.items():
        if m != model or v != variant:
            continue
        ga = grds.get(cond_a)
        gb = grds.get(cond_b)
        if ga is None or gb is None:
            continue
        if ga == "CORRECT" and gb == "INCORRECT":
            b += 1
        elif ga == "INCORRECT" and gb == "CORRECT":
            c += 1
        elif ga == "CORRECT" and gb == "CORRECT":
            both_corr += 1
        elif ga == "INCORRECT" and gb == "INCORRECT":
            both_inc += 1
    return b, c, both_corr, both_inc


def main():
    """Load graded rescue rows and print rebound-rate and McNemar tables.

    Output sections:
      1. Sanity counters over solve/grade statuses.
      2. Rebound rate by (variant, condition), pooled across models, with
         95% Wilson CIs.
      3. Rebound rate per (model, variant, condition).
      4. Paired McNemar tests between prefix conditions on shared cases.
    """
    rows = [json.loads(l) for l in open(PATH)]
    print(f"Loaded {len(rows)} rows")

    # Quick sanity
    from collections import Counter
    print("Solve status:", Counter(r.get("solve_status") for r in rows))
    print("Grade status:", Counter(r.get("grade_status") for r in rows))

    # Per-cell counts. A row only counts as correct when grade == "CORRECT",
    # so solve/grade failures already land in the denominator as INCORRECT
    # (conservative) — no special-casing required.
    counts = defaultdict(lambda: {"total": 0, "correct": 0})
    for r in rows:
        key = (r["model"], r["variant"], r["condition"])
        counts[key]["total"] += 1
        if r.get("grade") == "CORRECT":
            counts[key]["correct"] += 1

    # Aggregated by (variant, condition)
    by_var_cond = defaultdict(lambda: {"total": 0, "correct": 0})
    for (m, v, c), d in counts.items():
        by_var_cond[(v, c)]["total"] += d["total"]
        by_var_cond[(v, c)]["correct"] += d["correct"]

    print("\n" + "=" * 90)
    print("REBOUND RATE BY (VARIANT, CONDITION) [aggregated across 4 models]")
    print("=" * 90)
    print(f"{'Variant':<32} {'Condition':<14} {'k/n':>10} {'rate':>7} {'95% Wilson CI':>20}")
    print("-" * 90)
    variants_order = ["descriptive_long", "descriptive_long_confusing",
                      "descriptive_long_misleading", "garbled_string", "kernel_variant"]
    conds_order = ["null", "canonical_T2", "own_T2"]
    for v in variants_order:
        for c in conds_order:
            d = by_var_cond.get((v, c))
            if not d:
                continue
            p, lo, hi = wilson_ci(d["correct"], d["total"])
            print(f"{v:<32} {c:<14} {d['correct']:>4}/{d['total']:>4} "
                  f"{p*100:>5.1f}% [{lo*100:>5.1f}%, {hi*100:>5.1f}%]")
    print()

    # Per-model breakdown
    print("\n" + "=" * 90)
    print("REBOUND RATE PER (MODEL, VARIANT, CONDITION)")
    print("=" * 90)
    models_order = sorted({k[0] for k in counts})
    print(f"{'Model':<22} {'Variant':<32} {'cond':<14} {'k/n':>10} {'rate':>7}")
    for m in models_order:
        for v in variants_order:
            for c in conds_order:
                d = counts.get((m, v, c))
                if not d:
                    continue
                p, _lo, _hi = wilson_ci(d["correct"], d["total"])
                print(f"  {m:<20} {v:<32} {c:<14} {d['correct']:>3}/{d['total']:>3} "
                      f"{p*100:>5.1f}%")
    print()

    # Paired McNemar tests: the same case graded under two conditions
    print("\n" + "=" * 90)
    print("PAIRED MCNEMAR TESTS")
    print("=" * 90)
    case_grades = defaultdict(dict)  # (model, variant, index) -> {cond: grade}
    for r in rows:
        case_grades[(r["model"], r["variant"], r["index"])][r["condition"]] = r.get("grade")

    # own_T2 only exists for surface variants, so KV is excluded there.
    surface_variants = [vv for vv in variants_order if vv != "kernel_variant"]
    comparisons = [
        ("canonical_T2", "null", "b (can-only)", "c (null-only)", variants_order),
        ("own_T2", "null", "b (own-only)", "c (null-only)", surface_variants),
        ("own_T2", "canonical_T2", "b (own-only)", "c (can-only)", surface_variants),
    ]
    for cond_a, cond_b, lab_b, lab_c, vlist in comparisons:
        print(f"\n{cond_a} vs {cond_b}:")
        print(f"  {'cell':<46} {lab_b:>12} {lab_c:>13} "
              f"{'both-CORR':>10} {'both-INC':>10} {'McNemar p':>11}")
        for m in models_order:
            for v in vlist:
                b, c, both_corr, both_inc = _discordant_counts(case_grades, m, v,
                                                               cond_a, cond_b)
                p = mcnemar_p(b, c)
                print(f"  {m+'/'+v:<46} {b:>12} {c:>13} {both_corr:>10} {both_inc:>10} {p:>11.3f}")
+
+
# Script entry point: run the rescue analysis when executed directly.
if __name__ == "__main__":
    main()
diff --git a/analysis/rescue_api.py b/analysis/rescue_api.py
new file mode 100644
index 0000000..4641655
--- /dev/null
+++ b/analysis/rescue_api.py
@@ -0,0 +1,373 @@
+"""Async API caller for rescue experiment.
+
+Supports OpenAI, Anthropic, Google. All callers return a unified dict:
+ {"status": "success"|"failed", "content": str, "error": str|None}
+
+Concurrency is controlled per-provider via asyncio.Semaphore so we don't
+saturate rate limits in any one provider.
+"""
+from __future__ import annotations
+import asyncio
+import json
+import os
+import random
+from typing import Optional
+
+# ---------- Provider constants ----------
+
+# Solver model -> provider mapping
+SOLVER_PROVIDERS = {
+ "gpt-4.1-mini": "openai",
+ "gpt-4o-mini": "openai",
+ "claude-sonnet-4": "anthropic",
+ "gemini-2.5-flash": "google",
+}
+
+# API model strings (the canonical IDs to send)
+API_MODEL_NAMES = {
+ "gpt-4.1-mini": "gpt-4.1-mini",
+ "gpt-4o-mini": "gpt-4o-mini",
+ "claude-sonnet-4": "claude-sonnet-4-20250514",
+ "gemini-2.5-flash": "gemini-2.5-flash",
+}
+
+GRADER_MODEL = "gpt-4o"
+GRADER_PROVIDER = "openai"
+
+PER_PROVIDER_CONCURRENCY = {
+ "openai": 500,
+ "anthropic": 25, # 90k tok/min cap; 25 in flight keeps us comfortably under
+ "google": 300,
+}
+
+DEFAULT_RETRIES = 6
+DEFAULT_BASE_TIMEOUT = 300.0
+RATE_LIMIT_BACKOFF_SECONDS = 60.0 # min sleep on rate limit hits
+
+
+# ---------- Solver / grader prompts (consistent with paper) ----------
+
+SOLVER_SYSTEM_PROMPT = """You are an expert mathematician solving competition-level problems.
+Provide detailed, step-by-step solutions with clear mathematical reasoning.
+
+Requirements:
+- Show all your work and intermediate steps
+- Justify each major step of your reasoning
+- Use proper mathematical notation
+- Be thorough but concise
+- State your final answer clearly
+
+Solve the problem completely and rigorously."""
+
+PROOF_GRADER_SYSTEM_PROMPT = """You are an extremely strict mathematical grader evaluating competition-level PROOF problems.
+
+GRADING STANDARDS (BE VERY STRICT):
+- Mathematical rigor: Every step must be mathematically sound and justified
+- Logical flow: The reasoning must be clear, complete, and logically connected
+- Correctness: All calculations, algebraic manipulations, and conclusions must be correct
+- Completeness: The solution must address all parts of the problem fully
+- Precision: Mathematical statements must be precise and unambiguous
+
+FAILING CRITERIA (Mark as INCORRECT if ANY of these apply):
+- Any unjustified logical leap or gap in reasoning
+- Any computational error, no matter how small
+- Missing steps in critical parts of the argument
+- Imprecise or ambiguous mathematical statements
+- Incorrect final answer, even if approach is partially correct
+- Circular reasoning or logical fallacies
+- Misuse of mathematical theorems or definitions
+
+BE EXTREMELY STRICT. Competition mathematics proofs require perfect precision."""
+
+CALCULATION_GRADER_SYSTEM_PROMPT = """You are a mathematical grader evaluating competition-level CALCULATION problems.
+
+GRADING STANDARDS FOR CALCULATION PROBLEMS:
+- Primary focus: Is the final answer correct?
+- Secondary focus: Is the overall approach reasonable and mathematically sound?
+- Computation: Allow minor computational slips if the method is correct and final answer is right
+
+GRADING CRITERIA:
+- CORRECT: Final answer is correct AND approach is fundamentally sound
+- INCORRECT: Final answer is wrong OR approach is fundamentally flawed
+
+For calculation problems, the final numerical answer is the most important criterion.
+Minor intermediate errors are acceptable if they don't affect the final result."""
+
+PROOF_GRADER_USER_TEMPLATE = """Grade this PROOF solution with extreme strictness.
+
+PROBLEM:
+{problem_statement}
+
+STUDENT SOLUTION:
+{solution}
+
+CORRECT REFERENCE SOLUTION:
+{reference_solution}
+
+Evaluate with maximum strictness. Every logical step must be perfect. Return JSON with:
+{{"grade": "CORRECT" or "INCORRECT",
+ "detailed_feedback": "specific detailed analysis of what is right/wrong",
+ "major_issues": "list of significant mathematical errors or gaps",
+ "final_answer_correct": true or false,
+ "reasoning_rigor_score": 0-10 integer (10=perfect rigor, 0=severely flawed),
+ "overall_assessment": "comprehensive evaluation summary"}}"""
+
+CALCULATION_GRADER_USER_TEMPLATE = """Grade this CALCULATION solution with focus on final answer correctness.
+
+PROBLEM:
+{problem_statement}
+
+STUDENT SOLUTION:
+{solution}
+
+CORRECT REFERENCE SOLUTION:
+{reference_solution}
+
+Focus primarily on whether the final answer is correct. Return JSON with:
+{{"grade": "CORRECT" or "INCORRECT",
+ "detailed_feedback": "specific detailed analysis of what is right/wrong",
+ "major_issues": "list of significant mathematical errors or gaps",
+ "final_answer_correct": true or false,
+ "reasoning_rigor_score": 0-10 integer (10=perfect rigor, 0=severely flawed),
+ "overall_assessment": "comprehensive evaluation summary"}}"""
+
+
+# ---------- Lazy client builders ----------
+
+_openai_client = None
+_anthropic_client = None
+_google_client = None
+
def _get_openai_client():
    """Lazily build (once) and return the shared AsyncOpenAI client."""
    global _openai_client
    if _openai_client is not None:
        return _openai_client
    from openai import AsyncOpenAI
    import httpx
    # Generous pools: hundreds of OpenAI calls may be in flight at once.
    pool = httpx.Limits(max_connections=2000, max_keepalive_connections=1000)
    deadline = httpx.Timeout(timeout=DEFAULT_BASE_TIMEOUT, connect=30.0,
                             read=DEFAULT_BASE_TIMEOUT, write=30.0)
    _openai_client = AsyncOpenAI(
        http_client=httpx.AsyncClient(limits=pool, timeout=deadline))
    return _openai_client
+
+
def _get_anthropic_client():
    """Lazily build (once) and return the shared AsyncAnthropic client."""
    global _anthropic_client
    if _anthropic_client is not None:
        return _anthropic_client
    from anthropic import AsyncAnthropic
    _anthropic_client = AsyncAnthropic()
    return _anthropic_client
+
+
def _get_google_client():
    """Lazily build (once) and return the shared Google GenAI client.

    Requires GOOGLE_API_KEY in the environment; raises KeyError otherwise.
    """
    global _google_client
    if _google_client is not None:
        return _google_client
    from google import genai
    _google_client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
    return _google_client
+
+
+# ---------- Per-provider call functions ----------
+
async def _call_openai(model: str, system: str, user: str,
                       temperature: float, max_tokens: int = 16000) -> dict:
    """Single OpenAI chat-completions call; returns the unified result dict.

    o-series reasoning models (o1/o3/o4 families) reject `max_tokens` and
    force temperature=1, so those parameters are adjusted for them. All
    other models are pinned to JSON output via `response_format`.

    Fix vs previous version: the o-series check matches the model-id family
    prefix (first dash-separated token) instead of a bare substring test,
    which could misfire on any id merely containing "o1"/"o3"/"o4".
    """
    client = _get_openai_client()
    api_params = {
        "model": model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        "max_tokens": max_tokens,
    }
    # o-series ids look like "o1", "o3-mini", "o4-mini-2025-...": the family
    # name is always the first dash-separated component.
    if model.lower().split("-", 1)[0] in {"o1", "o3", "o4"}:
        api_params.pop("max_tokens", None)
        api_params["temperature"] = 1.0
    else:
        api_params["temperature"] = temperature
        api_params["response_format"] = {"type": "json_object"}
    resp = await client.chat.completions.create(**api_params)
    content = resp.choices[0].message.content or ""
    return {"status": "success", "content": content, "error": None}
+
+
async def _call_anthropic(model: str, system: str, user: str,
                          temperature: float, max_tokens: int = 16000) -> dict:
    """Single Anthropic messages call; returns the unified result dict."""
    client = _get_anthropic_client()
    resp = await client.messages.create(
        model=model,
        system=system,
        messages=[{"role": "user", "content": user}],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    # Concatenate every text block in the response; non-text blocks have no
    # .text attribute and are skipped.
    parts = [blk.text for blk in (resp.content or []) if hasattr(blk, "text")]
    return {"status": "success", "content": "".join(parts), "error": None}
+
+
async def _call_google(model: str, system: str, user: str,
                       temperature: float, max_tokens: int = 16000) -> dict:
    """Single Gemini generate_content call; returns the unified result dict."""
    from google.genai.types import GenerateContentConfig
    cfg = GenerateContentConfig(
        system_instruction=system,
        temperature=temperature,
        max_output_tokens=max_tokens,
        response_mime_type="application/json",
    )
    client = _get_google_client()
    resp = await client.aio.models.generate_content(
        model=model, contents=user, config=cfg,
    )
    return {"status": "success", "content": resp.text or "", "error": None}
+
+
+# ---------- Unified caller with retries and per-provider semaphore ----------
+
+_provider_sems: dict = {}
+
def _sem_for(provider: str) -> asyncio.Semaphore:
    """Return the provider's shared concurrency semaphore, creating it lazily."""
    try:
        return _provider_sems[provider]
    except KeyError:
        sem = asyncio.Semaphore(PER_PROVIDER_CONCURRENCY[provider])
        _provider_sems[provider] = sem
        return sem
+
+
async def call_model(model_short: str, system: str, user: str,
                     temperature: float = 0.0, max_tokens: int = 16000,
                     retries: int = DEFAULT_RETRIES) -> dict:
    """Call any supported model by short alias, with retries and backoff.

    Resolves the provider and API model id, acquires the provider's
    concurrency semaphore, then retries up to `retries` times. Rate-limit
    errors back off for RATE_LIMIT_BACKOFF_SECONDS (+ jitter) so the
    per-minute window can refill; other errors use capped exponential
    backoff. Always returns the unified {"status", "content", "error"} dict.

    Fix vs previous version: no backoff sleep is performed after the final
    failed attempt (the old code slept once more before giving up).
    """
    if model_short == GRADER_MODEL:
        provider = GRADER_PROVIDER
        api_model = GRADER_MODEL
    else:
        provider = SOLVER_PROVIDERS[model_short]
        api_model = API_MODEL_NAMES[model_short]

    callers = {
        "openai": _call_openai,
        "anthropic": _call_anthropic,
        "google": _call_google,
    }
    caller = callers.get(provider)
    if caller is None:
        return {"status": "failed", "content": "",
                "error": f"unknown provider {provider}"}

    last_err = None
    async with _sem_for(provider):
        for attempt in range(retries):
            try:
                return await caller(api_model, system, user, temperature, max_tokens)
            except Exception as e:
                last_err = e
                if attempt == retries - 1:
                    break  # no point sleeping when no retry will follow
                err_str = str(e).lower()
                # Longer backoff for rate-limit-style errors so the
                # per-minute window has time to refill.
                if "rate_limit" in err_str or "429" in err_str or "rate limit" in err_str:
                    await asyncio.sleep(RATE_LIMIT_BACKOFF_SECONDS + random.random() * 10)
                else:
                    await asyncio.sleep(min(2 ** attempt + random.random(), 30))
    return {"status": "failed", "content": "",
            "error": f"{type(last_err).__name__}: {str(last_err)[:300]}"}
+
+
+# ---------- High-level helpers ----------
+
async def solve(model_short: str, problem_user_msg: str) -> dict:
    """Run the solver on a user message that already embeds the problem (and any prefix)."""
    return await call_model(
        model_short, SOLVER_SYSTEM_PROMPT, problem_user_msg, temperature=0.0,
    )
+
+
async def grade(problem_type: str, problem_statement: str,
                solution: str, reference_solution: str) -> dict:
    """Run the grader model, choosing proof vs. calculation prompts by problem type."""
    is_proof = problem_type == "proof"
    sys = PROOF_GRADER_SYSTEM_PROMPT if is_proof else CALCULATION_GRADER_SYSTEM_PROMPT
    tmpl = PROOF_GRADER_USER_TEMPLATE if is_proof else CALCULATION_GRADER_USER_TEMPLATE
    user = tmpl.format(problem_statement=problem_statement,
                       solution=solution,
                       reference_solution=reference_solution)
    return await call_model(GRADER_MODEL, sys, user, temperature=0.0)
+
+
def parse_solution(content: str) -> dict:
    """Parse JSON {solution, final_answer} from model output, with tolerance.

    Falls back to extracting the first {...} span when the raw content is not
    valid JSON; on total failure the raw content is kept as the solution text.
    """
    import re

    def _ok(d: dict) -> dict:
        # Shared shape for every successful parse.
        return {"solution": d.get("solution", ""),
                "final_answer": d.get("final_answer", ""),
                "_parse_error": None}

    if not content:
        return {"solution": "", "final_answer": "", "_parse_error": "empty"}
    try:
        return _ok(json.loads(content))
    except Exception:
        pass
    # Try to extract a JSON object substring
    m = re.search(r"\{.*\}", content, re.DOTALL)
    if m:
        try:
            return _ok(json.loads(m.group(0)))
        except Exception as e:
            return {"solution": content, "final_answer": "",
                    "_parse_error": f"json parse: {e}"}
    return {"solution": content, "final_answer": "",
            "_parse_error": "no JSON object found"}
+
+
def parse_grade(content: str) -> dict:
    """Parse the grader's JSON verdict, defaulting to INCORRECT on any failure."""
    import re

    def _ok(d: dict) -> dict:
        # Normalize the grade; anything outside the two legal labels is INCORRECT.
        g = (d.get("grade") or "").strip().upper()
        return {
            "grade": g if g in ("CORRECT", "INCORRECT") else "INCORRECT",
            "final_answer_correct": d.get("final_answer_correct"),
            "detailed_feedback": d.get("detailed_feedback", ""),
            "_parse_error": None,
        }

    if not content:
        return {"grade": "INCORRECT", "_parse_error": "empty"}
    try:
        return _ok(json.loads(content))
    except Exception:
        pass
    m = re.search(r"\{.*\}", content, re.DOTALL)
    if m:
        try:
            return _ok(json.loads(m.group(0)))
        except Exception as e:
            return {"grade": "INCORRECT", "_parse_error": f"json parse: {e}"}
    return {"grade": "INCORRECT", "_parse_error": "no JSON object found"}
+
+
+# ---------- Standalone health check ----------
+
async def _health_check():
    """Fire one trivial request at each solver and at the grader.

    Prints per-model status/content/error so API keys, aliases and provider
    plumbing can be checked before a full run.
    """
    print("Running health checks ...")
    msg = ('Reply with JSON {"status": "ok"} only.')
    # One round-trip per solver alias (low max_tokens, few retries: fail fast).
    for short in ["gpt-4o-mini", "claude-sonnet-4", "gemini-2.5-flash"]:
        r = await call_model(short, "You are a test. Reply only the requested JSON.",
                             msg, temperature=0.0, max_tokens=200, retries=2)
        print(f"  {short}: {r['status']} - {r['content'][:200]!r} err={r['error']}")
    # Grader
    r = await call_model(GRADER_MODEL, "You are a test.", msg, temperature=0.0,
                         max_tokens=200, retries=2)
    print(f"  {GRADER_MODEL} (grader): {r['status']} - {r['content'][:200]!r} err={r['error']}")
+
+
# Entry point: standalone health check of all provider endpoints.
if __name__ == "__main__":
    asyncio.run(_health_check())
diff --git a/analysis/rescue_pooled.py b/analysis/rescue_pooled.py
new file mode 100644
index 0000000..cc9f782
--- /dev/null
+++ b/analysis/rescue_pooled.py
@@ -0,0 +1,174 @@
+"""Pooled rescue analysis for the rebuttal headline.
+
+Reports:
+1. Per-variant pooled rebound rates with Wilson 95% CI for each condition
+2. Pooled McNemar (paired) tests across all 4 models per variant
+3. Pooled McNemar across all 5 surface variants for each model
+4. Headline single-cell numbers
+"""
+from __future__ import annotations
+import json
+import math
+import statistics
+from collections import defaultdict
+from pathlib import Path
+
+PATH = Path("/home/yurenh2/gap/analysis/rescue_results/rescue_30.jsonl")
+OUT_PATH = Path("/home/yurenh2/gap/analysis/rescue_pooled_summary.json")
+
+
def wilson_ci(k: int, n: int, z: float = 1.96):
    """Return (point estimate, lower, upper) of the Wilson score interval for k/n.

    n == 0 yields the degenerate triple (0.0, 0.0, 0.0).
    """
    if n == 0:
        return (0.0, 0.0, 0.0)
    p = k / n
    z2 = z * z
    denom = 1 + z2 / n
    center = (p + z2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z2 / (4 * n * n)) / denom
    lo = max(0.0, center - half)
    hi = min(1.0, center + half)
    return (p, lo, hi)
+
+
def mcnemar_p(b: int, c: int) -> float:
    """Exact two-sided McNemar p-value from discordant counts b and c.

    Doubles the binomial(n, 0.5) lower-tail mass at min(b, c), clipped at 1.
    With no discordant pairs the test is vacuous and p = 1.
    """
    n = b + c
    if n == 0:
        return 1.0
    tail = 0.0
    for i in range(min(b, c) + 1):
        tail += math.comb(n, i) * (0.5 ** n)
    return min(1.0, 2 * tail)
+
+
def main():
    """Print the pooled rescue tables and save a JSON summary to OUT_PATH.

    Sections: (1) per-variant rebound rates with Wilson CIs per condition,
    (2) pooled McNemar tests per variant, (3) pooled McNemar across the four
    surface variants plus kernel_variant alone, (4) per-model breakdown
    averaged over the surface variants.
    """
    rows = [json.loads(l) for l in open(PATH)]
    print(f"Loaded {len(rows)} rows\n")

    # case_grades[(model, variant, index)] = {cond: grade}
    case_grades = defaultdict(dict)
    for r in rows:
        case_grades[(r["model"], r["variant"], r["index"])][r["condition"]] = r.get("grade")

    variants_order = ["descriptive_long", "descriptive_long_confusing",
                      "descriptive_long_misleading", "garbled_string", "kernel_variant"]
    short = {"descriptive_long":"DL","descriptive_long_confusing":"DLC",
             "descriptive_long_misleading":"DLM","garbled_string":"GS","kernel_variant":"KV"}

    summary = {}

    print("=" * 92)
    print("HEADLINE: Rescue rebound by variant (pooled across 4 models)")
    print("=" * 92)
    print(f"{'Variant':<6} {'Condition':<14} {'k/n':>10} {'rate':>7} "
          f"{'95% Wilson CI':>20} {'Δ vs null':>11}")
    print("-" * 80)
    var_summary = {}
    for v in variants_order:
        # Pool counts across models
        cell_counts = defaultdict(lambda: {"k": 0, "n": 0})
        for k, grds in case_grades.items():
            if k[1] != v: continue
            for cond in ("null", "canonical_T2", "own_T2"):
                if cond in grds:
                    cell_counts[cond]["n"] += 1
                    if grds[cond] == "CORRECT":
                        cell_counts[cond]["k"] += 1
        # Wilson CIs; delta is percentage points vs. the null condition.
        per_cond = {}
        null_p = cell_counts["null"]["k"] / max(1, cell_counts["null"]["n"])
        for cond in ("null", "canonical_T2", "own_T2"):
            if cond not in cell_counts: continue
            c = cell_counts[cond]
            if c["n"] == 0: continue
            p, lo, hi = wilson_ci(c["k"], c["n"])
            delta = (p - null_p) * 100 if cond != "null" else 0.0
            per_cond[cond] = {"k": c["k"], "n": c["n"], "p": p, "ci": [lo, hi], "delta_pp": delta}
            print(f"{short[v]:<6} {cond:<14} {c['k']:>4}/{c['n']:>4} "
                  f"{p*100:>5.1f}%  [{lo*100:>5.1f}%, {hi*100:>5.1f}%]  "
                  f"{'+' if delta > 0 else ('' if delta == 0 else '-')}{abs(delta):>5.1f} pp")
        # Pooled McNemar (own vs null, can vs null, own vs can)
        # b = first condition CORRECT where second INCORRECT; c = the reverse.
        mc = {}
        for a, b in [("canonical_T2", "null"), ("own_T2", "null"),
                     ("own_T2", "canonical_T2")]:
            b_count = c_count = 0
            for k, grds in case_grades.items():
                if k[1] != v: continue
                ga = grds.get(a); gb = grds.get(b)
                if ga is None or gb is None: continue
                if ga == "CORRECT" and gb == "INCORRECT": b_count += 1
                elif ga == "INCORRECT" and gb == "CORRECT": c_count += 1
            p = mcnemar_p(b_count, c_count)
            mc[f"{a}_vs_{b}"] = {"b": b_count, "c": c_count, "p": p}
        var_summary[v] = {"per_cond": per_cond, "mcnemar": mc}
        print()

    summary["per_variant"] = var_summary

    # Pooled McNemar across all surface variants for canonical vs null and own vs null
    print("\n" + "=" * 92)
    print("POOLED McNEMAR (across all 4 surface variants × 4 models)")
    print("=" * 92)
    surface_vs = ["descriptive_long", "descriptive_long_confusing",
                  "descriptive_long_misleading", "garbled_string"]
    for a, b in [("canonical_T2", "null"), ("own_T2", "null"),
                 ("own_T2", "canonical_T2")]:
        b_count = c_count = 0
        for k, grds in case_grades.items():
            if k[1] not in surface_vs: continue
            ga = grds.get(a); gb = grds.get(b)
            if ga is None or gb is None: continue
            if ga == "CORRECT" and gb == "INCORRECT": b_count += 1
            elif ga == "INCORRECT" and gb == "CORRECT": c_count += 1
        p = mcnemar_p(b_count, c_count)
        n = b_count + c_count
        # NOTE: ratio of discordant counts, floored denominator to avoid /0.
        odds_ratio = b_count / max(1, c_count)
        print(f"  {a:<14} > {b:<14} b={b_count:>4}, c={c_count:>4} "
              f"OR={odds_ratio:>4.2f}  McNemar p={p:.2e}  (n_discordant={n})")
    # KV separately
    print()
    for a, b in [("canonical_T2", "null")]:
        b_count = c_count = 0
        for k, grds in case_grades.items():
            if k[1] != "kernel_variant": continue
            ga = grds.get(a); gb = grds.get(b)
            if ga is None or gb is None: continue
            if ga == "CORRECT" and gb == "INCORRECT": b_count += 1
            elif ga == "INCORRECT" and gb == "CORRECT": c_count += 1
        p = mcnemar_p(b_count, c_count)
        odds_ratio = b_count / max(1, c_count)
        print(f"  KV: {a:<14} > {b:<14} b={b_count:>4}, c={c_count:>4} "
              f"OR={odds_ratio:>4.2f}  McNemar p={p:.2e}")

    # Per model summary
    print("\n" + "=" * 92)
    print("PER MODEL (averaged across 4 surface variants)")
    print("=" * 92)
    print(f"{'Model':<22} {'null':>10} {'canonical_T2':>14} {'own_T2':>10} "
          f"{'can-null':>10} {'own-null':>10}")
    per_model = {}
    for model in sorted({k[0] for k in case_grades}):
        cnts = defaultdict(lambda: {"k": 0, "n": 0})
        for k, grds in case_grades.items():
            if k[0] != model: continue
            if k[1] not in surface_vs: continue
            for cond in ("null", "canonical_T2", "own_T2"):
                if cond in grds:
                    cnts[cond]["n"] += 1
                    if grds[cond] == "CORRECT":
                        cnts[cond]["k"] += 1
        nul_p = cnts["null"]["k"] / max(1, cnts["null"]["n"])
        can_p = cnts["canonical_T2"]["k"] / max(1, cnts["canonical_T2"]["n"])
        own_p = cnts["own_T2"]["k"] / max(1, cnts["own_T2"]["n"])
        per_model[model] = {
            "null": {"k": cnts["null"]["k"], "n": cnts["null"]["n"], "p": nul_p},
            "canonical_T2": {"k": cnts["canonical_T2"]["k"], "n": cnts["canonical_T2"]["n"], "p": can_p},
            "own_T2": {"k": cnts["own_T2"]["k"], "n": cnts["own_T2"]["n"], "p": own_p},
            "can_minus_null_pp": (can_p - nul_p) * 100,
            "own_minus_null_pp": (own_p - nul_p) * 100,
        }
        print(f"  {model:<20} {nul_p*100:>9.1f}% {can_p*100:>13.1f}% {own_p*100:>9.1f}% "
              f"{(can_p-nul_p)*100:>+9.1f}pp {(own_p-nul_p)*100:>+9.1f}pp")
    summary["per_model"] = per_model

    json.dump(summary, open(OUT_PATH, "w"), indent=2)
    print(f"\nSaved -> {OUT_PATH}")
+
+
# Entry point.
if __name__ == "__main__":
    main()
diff --git a/analysis/rescue_prompts.py b/analysis/rescue_prompts.py
new file mode 100644
index 0000000..8e8f65c
--- /dev/null
+++ b/analysis/rescue_prompts.py
@@ -0,0 +1,267 @@
+"""Rescue-experiment prompt construction.
+
+For each (model, variant, flip-case) we build prompts under three conditions:
+- own_T2: model's own original-correct trajectory truncated at first
+ formal equation (with leakage filter), variables auto-renamed
+ to variant names via the dataset's rename map
+- canonical_T2: the dataset's canonical variant solution truncated at first
+ formal equation (no rename needed; already in variant naming)
+- null: generic content-free scaffold
+
+Truncation rule (event-boundary):
+ 1. Find the FIRST display-math block ($$...$$, \\[...\\], \\begin{equation/align/...})
+ 2. If none, fall back to the first line containing a substantive math relation
+ (>=, <=, =, <, >, ≡, ∈) that is not merely a definition (e.g., 'let x:=...')
+ 3. The T2 prefix INCLUDES that first formal relation
+ 4. Apply leakage filter BEFORE returning: stop at the earliest of:
+ - any line containing \\boxed
+ - any line containing 'therefore', 'hence', 'we conclude', 'the answer',
+ 'we obtain', 'thus', 'it suffices', 'we have proved', 'as a result'
+ - any line containing the dataset's recorded final_answer string
+"""
+from __future__ import annotations
+import re
+from typing import Optional, Dict
+
+
+# ---------- Display-math detection ----------
+
# Order matters: try richest patterns first.
# Each pattern matches one complete display-math block; DOTALL lets a block
# span multiple lines. Non-greedy bodies stop at the first matching closer.
_DISPLAY_MATH_PATTERNS = [
    re.compile(r"\$\$.+?\$\$", re.DOTALL),
    re.compile(r"\\\[.+?\\\]", re.DOTALL),
    re.compile(r"\\begin\{equation\*?\}.+?\\end\{equation\*?\}", re.DOTALL),
    re.compile(r"\\begin\{align\*?\}.+?\\end\{align\*?\}", re.DOTALL),
    re.compile(r"\\begin\{gather\*?\}.+?\\end\{gather\*?\}", re.DOTALL),
    re.compile(r"\\begin\{eqnarray\*?\}.+?\\end\{eqnarray\*?\}", re.DOTALL),
]
+
+
def _first_display_math_end(text: str) -> Optional[int]:
    """Return the end position of the first display-math block, or None."""
    matches = (pat.search(text) for pat in _DISPLAY_MATH_PATTERNS)
    ends = [m.end() for m in matches if m is not None]
    return min(ends) if ends else None
+
+
# Inline relation fallback: first line with a "real" relation — an operand
# character required on BOTH sides of =, <, >, \le(q), \ge(q), \equiv or \in,
# so bare punctuation or prose dashes don't trigger it.
_INLINE_REL_RE = re.compile(
    r"[A-Za-z\)\]\}\d_]\s*(?:=|<|>|\\le[q]?|\\ge[q]?|\\equiv|\\in)\s*[A-Za-z\(\[\{\d\\\-]"
)
# Definition exclusion: lines that are 'let x = ...' or 'denote ...' are setup,
# not actual derivations. We allow them in the prefix but don't stop on them.
_DEFINITION_RE = re.compile(
    r"^\s*(?:let|denote|define|set|put|call|consider|introduce|let us)\b",
    re.IGNORECASE
)
+
+
def _first_inline_relation_line_end(text: str) -> Optional[int]:
    """Find the end of the first line containing a non-definition math relation.

    Returns an absolute character offset (one past the newline), or None."""
    offset = 0
    for line in text.split("\n"):
        end = offset + len(line)
        if _INLINE_REL_RE.search(line) and not _DEFINITION_RE.search(line):
            # Include the trailing newline when one exists.
            return end + 1 if end < len(text) else end
        offset = end + 1
    return None
+
+
+# ---------- Leakage detection ----------
+
# Conclusion-style markers: if any of these appears in a candidate prefix, the
# prefix may leak the final answer and must be cut before the marker.
LEAKAGE_PATTERNS = [
    re.compile(r"\\boxed\b", re.IGNORECASE),
    re.compile(r"\btherefore\b", re.IGNORECASE),
    re.compile(r"\bhence\b", re.IGNORECASE),
    re.compile(r"\bwe conclude\b", re.IGNORECASE),
    re.compile(r"\bthe answer\b", re.IGNORECASE),
    re.compile(r"\bwe obtain\b", re.IGNORECASE),
    re.compile(r"\bthus\b", re.IGNORECASE),
    re.compile(r"\bit suffices\b", re.IGNORECASE),
    re.compile(r"\bwe have proved\b", re.IGNORECASE),
    re.compile(r"\bwe have shown\b", re.IGNORECASE),
    re.compile(r"\bas a result\b", re.IGNORECASE),
    re.compile(r"\bin conclusion\b", re.IGNORECASE),
    re.compile(r"\bthe final answer\b", re.IGNORECASE),
    re.compile(r"\bso the answer\b", re.IGNORECASE),
]
+
+
def _first_leakage_pos(text: str, final_answer: Optional[str] = None) -> Optional[int]:
    """Return the starting char position of the earliest leakage marker, or None.

    Considers every LEAKAGE_PATTERNS hit plus (optionally) a literal occurrence
    of the recorded final-answer string when that string is non-trivial.
    """
    candidates = [m.start()
                  for m in (pat.search(text) for pat in LEAKAGE_PATTERNS)
                  if m is not None]
    if final_answer:
        # Final-answer leakage: only check if the answer string is non-trivial
        fa = final_answer.strip()
        if 8 <= len(fa) <= 200:
            idx = text.find(fa)
            if idx != -1:
                candidates.append(idx)
    return min(candidates) if candidates else None
+
+
+# ---------- T2 truncation ----------
+
# Bounds on the T2 prefix length, in characters.
MIN_PREFIX_CHARS = 50
MAX_PREFIX_CHARS = 2400  # roughly 600 tokens
+
+
def truncate_T2(text: str, final_answer: Optional[str] = None) -> Optional[str]:
    """Return the T2 (after-first-equation) prefix, or None if not detectable.

    T2 = up to and including the first formal equation, then capped by leakage
    filter and MAX_PREFIX_CHARS.
    """
    if not text:
        return None
    # Prefer a display-math block; fall back to the first inline relation line.
    end = _first_display_math_end(text)
    if end is None:
        end = _first_inline_relation_line_end(text)
    if end is None:
        return None
    prefix = text[:end]
    # Apply leakage filter BEFORE the equation if a leakage marker appears earlier
    leak = _first_leakage_pos(prefix, final_answer)
    if leak is not None and leak < end:
        prefix = text[:leak].rstrip()
    # Cap length
    if len(prefix) > MAX_PREFIX_CHARS:
        prefix = prefix[:MAX_PREFIX_CHARS]
        # Trim at last newline to avoid cutting mid-sentence.
        # NOTE(review): the trim is applied only to over-length prefixes here,
        # so a short prefix keeps its first equation intact — confirm this
        # nesting matches the original intent.
        last_nl = prefix.rfind("\n")
        if last_nl > MIN_PREFIX_CHARS:
            prefix = prefix[:last_nl]
    if len(prefix) < MIN_PREFIX_CHARS:
        return None
    return prefix.rstrip()
+
+
+# ---------- Variable rename for own prefix ----------
+
def rename_own_prefix(prefix: str, rename_map: Dict[str, str]) -> str:
    """Apply orig->variant rename mapping to the model's own prefix.

    Longest source names are substituted first so a short name never eats a
    longer one that contains it. Matches are delimited by word boundaries, and
    replacements go through a callable so target names are inserted literally
    (no backslash-escape interpretation for names like '\\xi').
    """
    if not prefix or not rename_map:
        return prefix
    ordered = sorted(rename_map.items(), key=lambda kv: len(kv[0]), reverse=True)
    result = prefix
    for src, dst in ordered:
        if not src:
            continue
        pattern = r"(?<![A-Za-z0-9_])" + re.escape(src) + r"(?![A-Za-z0-9_])"
        result = re.sub(pattern, lambda _m, _d=dst: _d, result)
    return result
+
+
+# ---------- Null scaffold ----------
+
# Content-free control scaffold: supplies "structure" words with no math, so
# any rebound under the T2 conditions cannot be attributed to mere scaffolding.
NULL_SCAFFOLD = (
    "Let us proceed carefully. We will first identify the relevant variables "
    "and their roles, then state the governing relations of the problem, and "
    "finally develop the argument step by step."
)


# ---------- Prompt builders ----------

# We tell the model to PRODUCE the complete solution that begins with the
# provided prefix verbatim. This means the grader will see one continuous
# solution that starts with the injected setup. The instruction to begin
# verbatim avoids the model paraphrasing the prefix and removing the very
# representational anchor we are testing.

RESCUE_USER_TEMPLATE = """Please solve the following mathematical problem.

PROBLEM:
{problem_statement}

You must structure your solution as a continuation of the partial work below.
Begin your solution with the partial work copied verbatim, then continue
seamlessly to a complete answer.

PARTIAL WORK (to copy verbatim at the start of your solution):
{prefix}

Provide a complete, rigorous solution. Return your response in JSON format:
{{"solution": "your complete solution starting with the partial work above and continuing to the end",
 "final_answer": "your final answer in clear, concise form"}}"""


NULL_USER_TEMPLATE = """Please solve the following mathematical problem.

PROBLEM:
{problem_statement}

{scaffold}

Provide a complete, rigorous solution. Return your response in JSON format:
{{"solution": "your complete step-by-step solution",
 "final_answer": "your final answer in clear, concise form"}}"""
+
+
def build_rescue_prompt(problem_statement: str, prefix: str) -> str:
    """Render the rescue prompt: variant problem plus the injected partial work."""
    return RESCUE_USER_TEMPLATE.format(problem_statement=problem_statement,
                                       prefix=prefix)
+
+
def build_null_prompt(problem_statement: str) -> str:
    """Render the control prompt carrying only the content-free scaffold."""
    return NULL_USER_TEMPLATE.format(problem_statement=problem_statement,
                                     scaffold=NULL_SCAFFOLD)
+
+
+# ---------- Smoke test ----------
+
if __name__ == "__main__":
    # Quick smoke test on a real flip case
    import json
    import sys
    sys.path.insert(0, "/home/yurenh2/gap/analysis")
    from structural_overlap import find_variant_file, load_problems

    # Pick gpt-4.1-mini original on a known problem
    op = find_variant_file(
        __import__("pathlib").Path("/home/yurenh2/gap/results_new/gpt-4.1-mini"),
        "original")
    probs = {p["index"]: p for p in load_problems(op)}
    # First original-correct problem that actually has solution text.
    sample = next(p for idx, p in probs.items()
                  if p.get("correct") is True and (p.get("solve") or {}).get("solution"))
    text = sample["solve"]["solution"]
    fa = sample["solve"].get("final_answer")
    print(f"Sample index: {sample['index']}, type: {sample['problem_type']}")
    print(f"Original solution length: {len(text)} chars")
    print(f"Recorded final_answer: {fa[:200] if fa else None!r}")
    pre = truncate_T2(text, fa)
    print(f"\n--- T2 PREFIX ({len(pre or '')} chars) ---")
    print(pre)
    print("--- END ---")

    # Test rename: load 1987-B-2 dataset to get a sample map
    ds = json.load(open("/home/yurenh2/gap/putnam-bench-anon/dataset/1987-B-2.json"))
    rmap_raw = ds["variants"]["garbled_string"]["map"]
    # NOTE(review): eval on a local dataset file with builtins stripped;
    # ast.literal_eval would be strictly safer — confirm map strings are
    # plain dict literals.
    rmap = (eval(rmap_raw, {"__builtins__": {}}, {})
            if isinstance(rmap_raw, str) else rmap_raw)
    print(f"\nRename map: {rmap}")
    test_text = "Let n be a positive integer and let f be a continuous function. Then $f(n) = 0$."
    print(f"\nOriginal: {test_text}")
    print(f"Renamed: {rename_own_prefix(test_text, rmap)}")
diff --git a/analysis/rescue_runner.py b/analysis/rescue_runner.py
new file mode 100644
index 0000000..9c9f226
--- /dev/null
+++ b/analysis/rescue_runner.py
@@ -0,0 +1,341 @@
+"""End-to-end rescue experiment runner.
+
+For each (model, variant, flip-case):
+ - Build 3 prompts: own_T2, canonical_T2, null (KV: only canonical_T2 + null)
+ - Solve with the same model the case originally failed under
+ - Grade with gpt-4o using the variant problem + canonical variant solution as reference
+ - Save per-case results immediately to a jsonl checkpoint (resumable)
+
+Usage:
+ python rescue_runner.py --pilot # 5 cases per cell (smoke test)
+ python rescue_runner.py # 30 cases per cell (full run)
+"""
+from __future__ import annotations
+import argparse
+import asyncio
+import json
+import os
+import random
+import sys
+import time
+from pathlib import Path
+from typing import Optional
+
+# Local imports
+THIS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(THIS_DIR))
+from rescue_prompts import (
+ truncate_T2, rename_own_prefix,
+ build_rescue_prompt, build_null_prompt, NULL_SCAFFOLD,
+)
+from rescue_api import (
+ SOLVER_PROVIDERS, solve, grade, parse_solution, parse_grade,
+)
+from structural_overlap import (
+ DATASET_DIR, RESULTS_DIR, find_variant_file, load_problems, SURFACE_VARIANTS,
+)
+
+
# Short model name -> directory name in results_new
MODEL_RESULTS_DIRS = {
    "gpt-4.1-mini": "gpt-4.1-mini",
    "gpt-4o-mini": "gpt-4o-mini",
    "claude-sonnet-4": "claude-sonnet-4",
    "gemini-2.5-flash": "gemini_2.5_flash",  # historical underscore naming
}
# The four solver models under study, and the variant / condition grids.
SELECTED_MODELS = ["gpt-4.1-mini", "gpt-4o-mini", "claude-sonnet-4", "gemini-2.5-flash"]
ALL_VARIANTS = SURFACE_VARIANTS + ["kernel_variant"]
# Kernel variants carry no rename map, so the own_T2 condition does not apply.
SURFACE_CONDITIONS = ["own_T2", "canonical_T2", "null"]
KV_CONDITIONS = ["canonical_T2", "null"]
+
+
+# ---------- Dataset loading ----------
+
def load_dataset_full() -> dict:
    """Load every problem JSON in DATASET_DIR, keyed by problem index.

    Returns: {idx: {problem_type, original_question, original_solution,
    variants: {v: {question, solution, map}}}}.

    A variant's rename "map" is stored either as a dict or as its string
    repr. String maps are parsed with ast.literal_eval — unlike eval, it
    accepts only Python literals, so a corrupt or malicious map string in a
    data file cannot execute code. Unparseable or non-dict maps become None.
    """
    import ast
    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        # read_text + loads avoids leaving an unclosed file handle behind.
        d = json.loads(f.read_text())
        idx = d.get("index")
        cell = {
            "problem_type": d.get("problem_type"),
            "original_question": d.get("question") or "",
            "original_solution": d.get("solution") or "",
            "variants": {},
        }
        for v, vd in d.get("variants", {}).items():
            if isinstance(vd, dict):
                rmap = vd.get("map")
                if isinstance(rmap, str):
                    try:
                        rmap = ast.literal_eval(rmap)
                    except Exception:
                        rmap = None
                cell["variants"][v] = {
                    "question": vd.get("question") or "",
                    "solution": vd.get("solution") or "",
                    "map": rmap if isinstance(rmap, dict) else None,
                }
        out[idx] = cell
    return out
+
+
+# ---------- Flip case selection ----------
+
def find_flip_cases(model: str, variant: str, max_cases: int,
                    seed: int = 42) -> list:
    """Identify (orig_correct, var_wrong) flip cases for the cell.

    Returns a seeded random sample (at most max_cases) of dicts with:
    index, problem_type, orig_solution, orig_final_answer. Cases whose
    original solution yields no T2 prefix are dropped up front.
    """
    mdir = RESULTS_DIR / MODEL_RESULTS_DIRS.get(model, model)
    orig_path = find_variant_file(mdir, "original")
    var_path = find_variant_file(mdir, variant)
    if not orig_path or not var_path:
        return []
    orig_by = {p["index"]: p for p in load_problems(orig_path)}
    var_by = {p["index"]: p for p in load_problems(var_path)}
    cases = []
    for idx in sorted(orig_by.keys() & var_by.keys()):
        po = orig_by[idx]
        pv = var_by[idx]
        # Flip = solved on the original but failed on the variant.
        if po.get("correct") is not True:
            continue
        if pv.get("correct") is not False:
            continue
        orig_text = (po.get("solve") or {}).get("solution") or ""
        if not orig_text:
            continue
        fa = (po.get("solve") or {}).get("final_answer") or ""
        # Skip cases where no T2 prefix can be extracted from the original.
        if truncate_T2(orig_text, fa) is None:
            continue
        cases.append({
            "index": idx,
            "problem_type": po.get("problem_type"),
            "orig_solution": orig_text,
            "orig_final_answer": fa,
        })
    random.Random(seed).shuffle(cases)
    return cases[:max_cases]
+
+
+# ---------- Prompt construction per case ----------
+
def build_case_prompts(case: dict, variant: str, ds_cell: dict) -> dict:
    """Return {condition_name: user_message_string} for one flip case.

    Surface variants can get own_T2 / canonical_T2 / null; kernel_variant has
    no rename map, so own_T2 is never built for it. A missing variant question
    yields an empty dict.
    """
    var_info = ds_cell["variants"].get(variant, {})
    var_question = var_info.get("question", "")
    if not var_question:
        return {}
    prompts = {}

    # canonical_T2: dataset's canonical variant solution truncated
    canon_sol = var_info.get("solution", "")
    canon_pre = truncate_T2(canon_sol, None) if canon_sol else None
    if canon_pre:
        prompts["canonical_T2"] = build_rescue_prompt(var_question, canon_pre)

    # own_T2: only for surface variants — model's own original-correct prefix renamed
    if variant != "kernel_variant":
        rmap = var_info.get("map") or {}
        own_pre = truncate_T2(case["orig_solution"], case.get("orig_final_answer"))
        if own_pre and rmap:
            renamed = rename_own_prefix(own_pre, rmap)
            prompts["own_T2"] = build_rescue_prompt(var_question, renamed)

    # null: always available
    prompts["null"] = build_null_prompt(var_question)
    return prompts
+
+
+# ---------- Per-condition runner ----------
+
async def run_one_condition(model: str, condition: str, user_msg: str,
                            case: dict, variant: str, ds_cell: dict) -> dict:
    """Solve + grade a single condition for a single case. Returns a result dict.

    The result always carries the identifying fields (model, variant,
    condition, index, problem_type) plus status/timing; "grade" stays None
    unless both the solve call and the grade call succeeded.
    """
    var_info = ds_cell["variants"].get(variant, {})
    var_question = var_info.get("question", "")
    canon_sol = var_info.get("solution", "")
    problem_type = case["problem_type"]
    t0 = time.time()
    solve_resp = await solve(model, user_msg)
    solve_dt = time.time() - t0
    # Early exit: the solver API call itself failed.
    if solve_resp["status"] != "success":
        return {
            "model": model, "variant": variant, "condition": condition,
            "index": case["index"], "problem_type": problem_type,
            "solve_status": "failed",
            "solve_error": solve_resp["error"],
            "solve_seconds": solve_dt,
            "grade": None,
        }
    parsed = parse_solution(solve_resp["content"])
    # Early exit: response arrived but no solution text could be parsed out;
    # keep a truncated raw sample for debugging.
    if not parsed["solution"]:
        return {
            "model": model, "variant": variant, "condition": condition,
            "index": case["index"], "problem_type": problem_type,
            "solve_status": "parse_failed",
            "solve_error": parsed.get("_parse_error"),
            "solve_seconds": solve_dt,
            "raw_solve_content": solve_resp["content"][:500],
            "grade": None,
        }
    student_solution = parsed["solution"]
    t1 = time.time()
    grade_resp = await grade(problem_type, var_question, student_solution, canon_sol)
    grade_dt = time.time() - t1
    # Early exit: grading call failed; record the solution length only.
    if grade_resp["status"] != "success":
        return {
            "model": model, "variant": variant, "condition": condition,
            "index": case["index"], "problem_type": problem_type,
            "solve_status": "success",
            "solve_seconds": solve_dt,
            "grade_seconds": grade_dt,
            "grade_status": "failed",
            "grade_error": grade_resp["error"],
            "student_solution_len": len(student_solution),
            "student_final_answer": parsed["final_answer"],
            "grade": None,
        }
    parsed_grade = parse_grade(grade_resp["content"])
    return {
        "model": model, "variant": variant, "condition": condition,
        "index": case["index"], "problem_type": problem_type,
        "solve_status": "success",
        "solve_seconds": solve_dt,
        "grade_seconds": grade_dt,
        "grade_status": "success",
        "student_solution_len": len(student_solution),
        "student_solution": student_solution,  # full text for downstream analysis
        "student_final_answer": parsed["final_answer"][:500],
        "grade": parsed_grade["grade"],
        "final_answer_correct": parsed_grade.get("final_answer_correct"),
        "grade_feedback": (parsed_grade.get("detailed_feedback") or "")[:1000],
    }
+
+
+# ---------- Main run ----------
+
# Checkpoint directory for per-case jsonl results; created eagerly at import.
OUT_DIR = Path("/home/yurenh2/gap/analysis/rescue_results")
OUT_DIR.mkdir(parents=True, exist_ok=True)
+
+
def load_existing_keys(path: Path) -> set:
    """Read a jsonl checkpoint; return the set of (model, variant, condition, index).

    Malformed lines and rows missing any key field are silently skipped, so a
    partially written checkpoint never blocks a resume.
    """
    if not path.exists():
        return set()
    keys = set()
    with open(path) as f:
        for line in f:
            try:
                row = json.loads(line)
            except Exception:
                continue
            try:
                keys.add((row["model"], row["variant"], row["condition"], row["index"]))
            except Exception:
                continue
    return keys
+
+
async def run_all(num_cases_per_cell: int, dry_run: bool = False, models=None,
                  variants=None):
    """Plan and execute all outstanding (model, variant, condition, case) tasks.

    Resumable: rows already present in the jsonl checkpoint are skipped.
    Each finished condition is appended as one json line; appends are
    serialized through an asyncio.Lock so concurrent tasks never interleave.
    With dry_run=True only the per-cell plan is printed.

    Fix: the checkpoint file is now opened with a `with` block so it is
    closed even if asyncio.gather raises (previously it leaked on error).
    """
    print(f"Loading dataset ...", flush=True)
    ds = load_dataset_full()
    print(f"  loaded {len(ds)} problems", flush=True)

    out_path = OUT_DIR / f"rescue_{num_cases_per_cell}.jsonl"
    existing = load_existing_keys(out_path)
    print(f"Output: {out_path} (existing rows: {len(existing)})")

    models = models or SELECTED_MODELS
    variants = variants or ALL_VARIANTS

    # Build the full task list, skipping anything already checkpointed.
    tasks_to_run = []
    cell_summary = {}
    for model in models:
        for variant in variants:
            cases = find_flip_cases(model, variant, num_cases_per_cell)
            cell_key = f"{model}/{variant}"
            cell_summary[cell_key] = {"flip_cases_found": len(cases),
                                      "added_tasks": 0}
            for case in cases:
                ds_cell = ds.get(case["index"])
                if ds_cell is None:
                    continue
                prompts = build_case_prompts(case, variant, ds_cell)
                for cond, user_msg in prompts.items():
                    key = (model, variant, cond, case["index"])
                    if key in existing:
                        continue
                    tasks_to_run.append((model, variant, cond, case, ds_cell, user_msg))
                    cell_summary[cell_key]["added_tasks"] += 1

    print(f"\nCell-level plan ({num_cases_per_cell} flip cases each):")
    for k, v in sorted(cell_summary.items()):
        print(f"  {k:<46} found={v['flip_cases_found']:>3} new_tasks={v['added_tasks']:>4}")
    total = len(tasks_to_run)
    print(f"\nTotal new tasks: {total}")
    if dry_run:
        return

    if not tasks_to_run:
        print("Nothing to do.")
        return

    # Execute concurrently; a lock serializes appends to the jsonl checkpoint.
    write_lock = asyncio.Lock()
    completed = 0
    failed = 0
    started_at = time.time()

    with open(out_path, "a") as fout:

        async def run_and_write(model, variant, cond, case, ds_cell, user_msg):
            nonlocal completed, failed
            try:
                res = await run_one_condition(model, cond, user_msg, case, variant, ds_cell)
            except Exception as e:
                # Unexpected failure: record it as a row so the run can resume.
                res = {
                    "model": model, "variant": variant, "condition": cond,
                    "index": case["index"], "problem_type": case.get("problem_type"),
                    "solve_status": "exception",
                    "solve_error": f"{type(e).__name__}: {str(e)[:300]}",
                    "grade": None,
                }
                failed += 1
            async with write_lock:
                fout.write(json.dumps(res) + "\n")
                fout.flush()
                completed += 1
                if completed % 25 == 0 or completed == total:
                    elapsed = time.time() - started_at
                    rate = completed / elapsed if elapsed > 0 else 0
                    eta = (total - completed) / rate if rate > 0 else 0
                    print(f"  [{completed:>4}/{total}] elapsed={elapsed:>5.0f}s "
                          f"rate={rate:>4.1f}/s eta={eta:>5.0f}s "
                          f"failed_so_far={failed}", flush=True)

        await asyncio.gather(*[run_and_write(*t) for t in tasks_to_run])
    print(f"\nDone. {completed}/{total} written. Failed: {failed}.")
+
+
def main():
    """CLI entry point: parse arguments and launch the async runner."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pilot", action="store_true", help="run only 5 cases per cell")
    parser.add_argument("--cases", type=int, default=30, help="cases per cell (full run)")
    parser.add_argument("--dry-run", action="store_true", help="print plan, don't call APIs")
    parser.add_argument("--models", nargs="+", default=None)
    parser.add_argument("--variants", nargs="+", default=None)
    args = parser.parse_args()
    cases_per_cell = 5 if args.pilot else args.cases
    asyncio.run(run_all(cases_per_cell, dry_run=args.dry_run,
                        models=args.models, variants=args.variants))
+
+
# Entry point.
if __name__ == "__main__":
    main()
diff --git a/analysis/sc_success_and_difficulty.py b/analysis/sc_success_and_difficulty.py
new file mode 100644
index 0000000..a8b44db
--- /dev/null
+++ b/analysis/sc_success_and_difficulty.py
@@ -0,0 +1,192 @@
+"""Two follow-up analyses (zero API):
+1. Per-model self-correction success rate: P(correct | SC) vs P(correct | no SC)
+2. Difficulty-stratified surface vs kernel dichotomy
+"""
+from __future__ import annotations
+import json
+import sys
+import statistics
+from pathlib import Path
+from collections import defaultdict
+
+THIS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(THIS_DIR))
+from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS
+from self_correction import has_self_correction
+
+
+# ----------------- 1. SC success rate per model -----------------
+
def sc_success_rate():
    """Per-model: does a self-correction attempt raise P(correct)?

    Pools all variants per model, splits graded trajectories by SC-marker
    presence, and reports P(correct | SC) vs P(correct | no SC) plus their
    delta. Models with < 5 trials on either arm are dropped. Rows are also
    persisted to sc_success_per_model.json.
    """
    base = RESULTS_DIR
    models = sorted([d.name for d in base.iterdir() if d.is_dir()])

    print("=" * 80)
    print("PER-MODEL SELF-CORRECTION SUCCESS RATE")
    print("(does an SC attempt improve probability of being correct?)")
    print("=" * 80)
    print()

    rows = []
    for m in models:
        mdir = base / m
        # Aggregate over all variants
        n_sc_correct = 0
        n_sc_total = 0
        n_nosc_correct = 0
        n_nosc_total = 0
        for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
            vp = find_variant_file(mdir, v)
            if not vp: continue
            for p in load_problems(vp):
                text = (p.get("solve") or {}).get("solution") or ""
                if not text: continue
                correct = p.get("correct")
                if correct is None: continue  # ungraded — excluded from both arms
                if has_self_correction(text):
                    n_sc_total += 1
                    if correct: n_sc_correct += 1
                else:
                    n_nosc_total += 1
                    if correct: n_nosc_correct += 1
        # Require minimal support on both arms before reporting a delta.
        if n_sc_total < 5 or n_nosc_total < 5:
            continue
        p_sc = n_sc_correct / n_sc_total
        p_nosc = n_nosc_correct / n_nosc_total
        delta = p_sc - p_nosc
        # NOTE(review): a Wilson 95% CI was evidently planned here but is not
        # computed; rows carry raw rates only — confirm whether CIs are needed.
        rows.append({
            "model": m,
            "sc_n": n_sc_total, "sc_correct": n_sc_correct, "p_sc": p_sc,
            "nosc_n": n_nosc_total, "nosc_correct": n_nosc_correct, "p_nosc": p_nosc,
            "delta": delta,
        })

    rows.sort(key=lambda r: -r["sc_n"])  # most SC trials first
    print(f"{'Model':<22} {'#SC trials':>11} {'P(corr|SC)':>12} {'P(corr|noSC)':>13} {'Δ':>9}")
    print("-" * 75)
    for r in rows:
        print(f"{r['model']:<22} {r['sc_n']:>11} "
              f"{r['p_sc']*100:>10.1f}% {r['p_nosc']*100:>11.1f}% "
              f"{r['delta']*100:>+7.1f}pp")

    json.dump(rows, open(THIS_DIR / "sc_success_per_model.json", "w"), indent=2)
    return rows
+
+
+# ----------------- 2. Difficulty stratified dichotomy -----------------
+
# Root of the anonymized PutnamBench dataset (one JSON file per problem).
DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
+
def load_difficulty_metadata():
    """Per-problem difficulty assignment using year/section/index heuristic.

    Per the paper's existing exposition, we derive Easy/Medium/Hard from the
    problem index (1-2 = Easy, 3-4 = Medium, 5-6 = Hard, 7-8 = extra-hard tail)
    because the dataset's `difficulty` field is heterogeneous.
    """
    buckets = {}
    for path in sorted(DATASET_DIR.glob("*.json")):
        with open(path) as fh:
            record = json.load(fh)
        idx = record.get("index")
        if not idx:
            continue
        # Extract problem number from "YEAR-PART-NUM"
        pieces = idx.split("-")
        if len(pieces) != 3:
            continue
        try:
            num = int(pieces[2])
        except ValueError:
            continue
        if num <= 2:
            bucket = "Easy"
        elif num <= 4:
            bucket = "Medium"
        elif num <= 6:
            bucket = "Hard"
        else:
            bucket = "ExtraHard"
        buckets[idx] = bucket
    return buckets
+
+
def difficulty_stratified_dichotomy():
    """Accuracy stratified by Easy/Medium/Hard/ExtraHard difficulty buckets.

    Counts graded problems per (model, variant, bucket), then averages
    per-model accuracy rates per (variant, bucket); cells with fewer than 5
    graded problems are dropped. Also prints Δ(original → kernel_variant)
    and Δ(original → garbled_string) per bucket.
    """
    print("\n\n" + "=" * 80)
    print("DIFFICULTY-STRATIFIED ACCURACY (mean across 18 models)")
    print("Easy/Medium/Hard buckets defined by problem index 1-2/3-4/5-6")
    print("=" * 80)
    print()

    diff = load_difficulty_metadata()
    base = RESULTS_DIR
    models = sorted([d.name for d in base.iterdir() if d.is_dir()])
    variants = ["original"] + SURFACE_VARIANTS + ["kernel_variant"]
    buckets = ["Easy", "Medium", "Hard", "ExtraHard"]

    # cells[(model, variant, difficulty)] = [n_graded, n_correct]
    cells = defaultdict(lambda: [0, 0])
    for m in models:
        mdir = base / m
        for v in variants:
            vp = find_variant_file(mdir, v)
            if not vp:
                continue
            for p in load_problems(vp):
                idx = p.get("index")
                correct = p.get("correct")
                if idx is None or correct is None:
                    continue
                bucket = diff.get(idx, "Unknown")
                cells[(m, v, bucket)][0] += 1
                if correct:
                    cells[(m, v, bucket)][1] += 1

    # Aggregate per (variant, difficulty) by averaging per-model rates
    print(f"{'Variant':<24} {'Easy':>8} {'Medium':>8} {'Hard':>8} {'XHard':>8}")
    print("-" * 60)
    for v in variants:
        row = {}
        for bucket in buckets:
            rates = [c / n for m in models
                     for n, c in [cells.get((m, v, bucket), [0, 0])]
                     if n >= 5]
            row[bucket] = statistics.fmean(rates) * 100 if rates else None
        # BUGFIX: the original mixed implicitly-concatenated f-strings with a
        # conditional expression, so the whole first print() argument was the
        # ternary and column alignment broke when a bucket had no data.
        cols = " ".join(f"{row[b]:>7.1f}%" if row[b] is not None else f"{'-':>8}"
                        for b in buckets)
        print(f"{v:<24} {cols}")

    def _print_delta(target_variant, header, label):
        # Δ(original → target) per bucket, averaged over models that have
        # >= 5 graded problems in BOTH cells (paired per-model comparison).
        print(f"\n--- {header} ---")
        for bucket in buckets:
            orig_rates = []
            tgt_rates = []
            for m in models:
                no, co = cells.get((m, "original", bucket), [0, 0])
                nt, ct = cells.get((m, target_variant, bucket), [0, 0])
                if no >= 5 and nt >= 5:
                    orig_rates.append(co / no)
                    tgt_rates.append(ct / nt)
            if orig_rates:
                mo = statistics.fmean(orig_rates) * 100
                mt = statistics.fmean(tgt_rates) * 100
                print(f" {bucket:<10} orig={mo:5.1f}% {label}={mt:5.1f}% Δ={mt-mo:+.1f}pp")

    _print_delta("kernel_variant",
                 "Δ original → KV per difficulty bucket", "kv")
    _print_delta("garbled_string",
                 "Δ original → GS (surface, hardest renamer) per difficulty bucket", "GS")
+
+
def main():
    # Run both zero-API analyses in sequence.
    sc_success_rate()
    difficulty_stratified_dichotomy()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/analysis/self_correction.py b/analysis/self_correction.py
new file mode 100644
index 0000000..5769647
--- /dev/null
+++ b/analysis/self_correction.py
@@ -0,0 +1,202 @@
+"""Self-correction / metacognition probe.
+
+Scan model trajectories for self-correction markers and compute:
+1. Attempt rate (trajectory contains a self-correction marker) per (model, variant, group)
+2. Whether self-correction attempt rate differs between stable / brittle-drift / rescued cases
+3. Conditional success: among trajectories with a self-correction attempt, what fraction is graded CORRECT?
+
+Self-correction markers (case-insensitive, word-boundary):
+- "wait" (e.g., "Wait, let me reconsider")
+- "actually" (e.g., "Actually, I think...")
+- "let me reconsider"
+- "let me redo"
+- "let me try again"
+- "I made a mistake"
+- "this is wrong"
+- "on second thought"
+- "correction:"
+- "scratch that"
+- "I was wrong"
+- "let me start over"
+
+Uses three data sources:
+A. The original 18-model results in /home/yurenh2/gap/results_new/ (stable + brittle drift + collapse)
+B. The rescue trajectories in analysis/rescue_results/rescue_30.jsonl (3 conditions × 4 models × 5 variants)
+"""
+from __future__ import annotations
+import json
+import re
+import os
+import sys
+import statistics
+from pathlib import Path
+from collections import defaultdict, Counter
+
+THIS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(THIS_DIR))
+from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS
+
# Compiled self-correction marker regexes (all case-insensitive). A trajectory
# "attempts self-correction" iff at least one of these matches anywhere.
SC_PATTERNS = [
    # "wait" only counts when followed by a continuation word, to avoid
    # matching unrelated uses of the word in problem text.
    re.compile(r"\bwait\b[,.]?\s+(let|actually|that|i)", re.IGNORECASE),
    re.compile(r"\bactually[,.]\s", re.IGNORECASE),
    re.compile(r"\blet\s+me\s+reconsider", re.IGNORECASE),
    re.compile(r"\blet\s+me\s+redo", re.IGNORECASE),
    re.compile(r"\blet\s+me\s+try\s+(this\s+)?again", re.IGNORECASE),
    re.compile(r"\bi\s+made\s+a\s+mistake", re.IGNORECASE),
    re.compile(r"\bthis\s+is\s+(wrong|incorrect)", re.IGNORECASE),
    re.compile(r"\bon\s+second\s+thought", re.IGNORECASE),
    re.compile(r"\bcorrection[:\s]", re.IGNORECASE),
    re.compile(r"\bscratch\s+that", re.IGNORECASE),
    re.compile(r"\bi\s+was\s+wrong", re.IGNORECASE),
    re.compile(r"\blet\s+me\s+start\s+over", re.IGNORECASE),
    # Hesitation markers only count with a follow-up revision word.
    re.compile(r"\bhmm[,.]\s+(actually|wait|that)", re.IGNORECASE),
    re.compile(r"\bi\s+need\s+to\s+(redo|reconsider)", re.IGNORECASE),
    re.compile(r"\boh\s+wait", re.IGNORECASE),
]
+
+
def has_self_correction(text: str) -> bool:
    """True iff any self-correction marker regex matches `text`."""
    if not text:
        return False
    return any(pattern.search(text) for pattern in SC_PATTERNS)
+
+
def count_sc_markers(text: str) -> int:
    """Total number of self-correction marker matches (0 for empty text)."""
    if not text:
        return 0
    total = 0
    for pattern in SC_PATTERNS:
        total += len(pattern.findall(text))
    return total
+
+
+# ---------- Source A: 18-model original results ----------
+
def analyze_18_models():
    """Self-correction rates in original solver runs across all 18 models.

    For each (model, variant) result file: overall SC-attempt rate, plus the
    conditional SC rates among correct vs wrong trajectories. Prints a
    per-variant average table and a per-model leaderboard; returns the
    per-(model, variant) rows.
    """
    base = RESULTS_DIR
    models = sorted([d.name for d in base.iterdir() if d.is_dir()])
    print(f"\n=== SELF-CORRECTION IN 18-MODEL ORIGINAL RUNS ===\n")
    print(f"Markers used: {len(SC_PATTERNS)} regex patterns")
    print(f"Definition: trajectory contains at least one match.\n")

    rows = []
    for m in models:
        mdir = base / m
        for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
            vp = find_variant_file(mdir, v)
            if not vp:
                continue
            problems = load_problems(vp)
            # Tallies: overall, and split by graded outcome.
            n_total = 0
            n_sc = 0
            n_correct_sc = 0
            n_correct_total = 0
            n_wrong_sc = 0
            n_wrong_total = 0
            for p in problems:
                text = (p.get("solve") or {}).get("solution") or ""
                if not text:
                    continue
                correct = p.get("correct")
                if correct is None:
                    continue  # ungraded trajectories are excluded
                n_total += 1
                sc = has_self_correction(text)
                if sc: n_sc += 1
                if correct is True:
                    n_correct_total += 1
                    if sc: n_correct_sc += 1
                else:
                    n_wrong_total += 1
                    if sc: n_wrong_sc += 1
            if n_total > 0:
                rows.append({
                    "model": m, "variant": v, "n": n_total,
                    "sc_rate": n_sc / n_total,
                    "n_correct": n_correct_total,
                    # max(1, ...) guards the divide when a split is empty
                    "n_correct_sc_rate": n_correct_sc / max(1, n_correct_total),
                    "n_wrong": n_wrong_total,
                    "n_wrong_sc_rate": n_wrong_sc / max(1, n_wrong_total),
                })

    # Print compact table: per (variant) average across models
    print(f"{'Variant':<24} {'mean SC%':>10} {'SC%|correct':>14} {'SC%|wrong':>12} {'asym (wrong-correct)':>22}")
    print("-" * 90)
    by_var = defaultdict(list)
    for r in rows:
        by_var[r["variant"]].append(r)
    for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
        rs = by_var.get(v, [])
        if not rs:
            continue
        m_sc = statistics.fmean(r["sc_rate"] for r in rs) * 100
        m_sc_c = statistics.fmean(r["n_correct_sc_rate"] for r in rs) * 100
        m_sc_w = statistics.fmean(r["n_wrong_sc_rate"] for r in rs) * 100
        asym = m_sc_w - m_sc_c  # positive => models second-guess more when wrong
        print(f"{v:<24} {m_sc:>9.1f}% {m_sc_c:>13.1f}% {m_sc_w:>11.1f}% {asym:>+21.1f}pp")

    # Per-model leader board
    print(f"\n{'Model':<22} {'mean SC% (all variants)':>26}")
    print("-" * 50)
    by_model = defaultdict(list)
    for r in rows:
        by_model[r["model"]].append(r["sc_rate"])
    model_avgs = sorted([(m, statistics.fmean(vs) * 100) for m, vs in by_model.items()],
                        key=lambda t: -t[1])
    for m, avg in model_avgs:
        print(f"{m:<22} {avg:>25.1f}%")

    return rows
+
+
+# ---------- Source B: rescue trajectories ----------
+
def analyze_rescue():
    """Self-correction rates in the rescue trajectories (source B).

    Reads rescue_results/rescue_30.jsonl and prints, per (variant, condition),
    the SC-marker rate split by final grade (CORRECT vs INCORRECT).
    Returns the raw counts keyed by (model, variant, condition, grade).
    """
    path = THIS_DIR / "rescue_results/rescue_30.jsonl"
    # BUGFIX: close the file promptly (was a bare open() handle leak).
    with open(path) as fh:
        rows = [json.loads(line) for line in fh]
    # BUGFIX: the old header f-string used "1{{,}}529", which renders the
    # literal "1{,}529" (a LaTeX-ism); report the actual count instead.
    print(f"\n\n=== SELF-CORRECTION IN {len(rows):,} RESCUE TRAJECTORIES ===\n")

    # Group by (model, variant, condition, grade)
    counts = defaultdict(lambda: {"n": 0, "sc": 0})
    for r in rows:
        text = r.get("student_solution") or ""
        if not text:
            continue
        key = (r["model"], r["variant"], r["condition"], r.get("grade"))
        counts[key]["n"] += 1
        if has_self_correction(text):
            counts[key]["sc"] += 1

    # Collapse the model axis: aggregate per (variant, condition, grade)
    by_vcg = defaultdict(lambda: {"n": 0, "sc": 0})
    for (m, v, c, g), d in counts.items():
        by_vcg[(v, c, g)]["n"] += d["n"]
        by_vcg[(v, c, g)]["sc"] += d["sc"]

    print(f"{'Variant':<24} {'Condition':<14} {'CORRECT-SC%':>14} {'INCORRECT-SC%':>16}")
    print("-" * 80)
    for v in ["descriptive_long","descriptive_long_confusing","descriptive_long_misleading","garbled_string","kernel_variant"]:
        for c in ["null", "canonical_T2", "own_T2"]:
            cor = by_vcg.get((v, c, "CORRECT"), {"n": 0, "sc": 0})
            inc = by_vcg.get((v, c, "INCORRECT"), {"n": 0, "sc": 0})
            if cor["n"] == 0 and inc["n"] == 0:
                continue
            sc_c = cor["sc"] / max(1, cor["n"]) * 100 if cor["n"] else 0
            sc_i = inc["sc"] / max(1, inc["n"]) * 100 if inc["n"] else 0
            print(f"{v:<24} {c:<14} {sc_c:>11.1f}% (n={cor['n']:>3}) {sc_i:>13.1f}% (n={inc['n']:>3})")
    print()

    return counts
+
+
def main():
    # Source A: 18-model runs, persisted to JSON for downstream figures.
    rows_18 = analyze_18_models()
    json.dump(rows_18, open(THIS_DIR / "self_correction_18models.json", "w"), indent=2)
    # Source B: rescue trajectories (printed only; counts are not persisted).
    counts_rescue = analyze_rescue()
    print("\nSaved -> analysis/self_correction_18models.json")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/analysis/spotcheck_clean.py b/analysis/spotcheck_clean.py
new file mode 100644
index 0000000..52ddc43
--- /dev/null
+++ b/analysis/spotcheck_clean.py
@@ -0,0 +1,181 @@
+"""Spot-check Unicode cleaning by side-by-side comparison.
+
+For a stratified sample of problems, load:
+ - the ORIGINAL kernel_variant.solution from the backup tarball
+ - the CLEANED kernel_variant.solution from the current dataset
+and print them side-by-side so the user can verify that the cleaner
+preserved meaning.
+
+Sampling strategy:
+ - 5 most complex (by original Unicode count) — stress test
+ - 3 medium complexity — typical case
+ - 2 surface-variant samples — to confirm rename + LaTeX preserved
+"""
+from __future__ import annotations
+import json
+import sys
+import tarfile
+from pathlib import Path
+
# Cleaned dataset currently on disk.
CURRENT_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
# Newest pre-cleaning snapshot (lexicographically last tarball wins).
BACKUP_TAR = sorted(Path("/home/yurenh2/gap/analysis/dataset_backups").glob(
    "putnam-bench-anon_dataset_*.tar.gz"))[-1]
+
+
def count_unicode(text: str) -> int:
    """Count characters outside the 7-bit ASCII range (None-safe)."""
    total = 0
    for ch in text or "":
        if ord(ch) > 127:
            total += 1
    return total
+
+
def load_backup_problems():
    """Yield (idx, problem_dict) from the backup tarball."""
    with tarfile.open(BACKUP_TAR, "r:gz") as tar:
        for member in tar.getmembers():
            # Only regular .json member files; skip dirs/links and non-JSON.
            if not member.isfile() or not member.name.endswith(".json"):
                continue
            f = tar.extractfile(member)
            if not f:
                continue
            try:
                d = json.load(f)
                yield d.get("index"), d
            except Exception:
                # Best-effort: skip members that are not valid JSON.
                continue
+
+
def main():
    """Run the spot-check: sample problems by Unicode complexity, print the
    original vs cleaned kernel_variant.solution side-by-side, then run an
    aggregate delimiter-balance sanity pass over the whole cleaned dataset.
    """
    print(f"Backup tar: {BACKUP_TAR}")
    print("Building Unicode-count index over 1051 problems ...")

    # Index originals by Unicode count in kernel_variant.solution
    by_uni_count = [] # (unicode_count, idx, solution_len)
    backup_data = {}
    for idx, d in load_backup_problems():
        if not idx:
            continue
        backup_data[idx] = d
        kv_sol = (d.get("variants") or {}).get("kernel_variant", {}).get("solution", "")
        uc = count_unicode(kv_sol)
        by_uni_count.append((uc, idx, len(kv_sol)))

    by_uni_count.sort(reverse=True)  # most non-ASCII first
    print(f" loaded {len(backup_data)} problems from backup")

    # Pick samples: 5 hardest, 3 from the middle, 2 lightest-but-nonzero.
    samples = []
    samples.extend([(idx, "TOP COMPLEXITY") for _, idx, _ in by_uni_count[:5]])
    mid = len(by_uni_count) // 2
    samples.extend([(idx, "MEDIUM COMPLEXITY")
                    for _, idx, _ in by_uni_count[mid:mid + 3]])
    # Bottom = least Unicode but still non-zero
    nonzero = [t for t in by_uni_count if t[0] > 0]
    samples.extend([(idx, "LOW COMPLEXITY")
                    for _, idx, _ in nonzero[-2:]])

    print(f"\nSelected {len(samples)} samples:\n")
    for idx, label in samples:
        print(f" {label:<20} {idx}")

    print("\n" + "=" * 80)
    print("SIDE-BY-SIDE SPOT-CHECK")
    print("=" * 80)

    for case_idx, (idx, label) in enumerate(samples, 1):
        print(f"\n{'#' * 80}")
        print(f"# CASE {case_idx}/{len(samples)}: {idx} ({label})")
        print(f"{'#' * 80}")

        backup_problem = backup_data.get(idx)
        current_path = CURRENT_DIR / f"{idx}.json"
        if not backup_problem or not current_path.exists():
            print(f" ! missing data for {idx}")
            continue
        current_problem = json.load(open(current_path))

        # Compare kernel_variant.solution by default. For LOW COMPLEXITY cases
        # we also show the original `solution` field if it differs.
        for field_path in [("variants", "kernel_variant", "solution")]:
            # Walk the nested key path on both sides; None if any hop missing.
            orig_text = backup_problem
            curr_text = current_problem
            for key in field_path:
                orig_text = (orig_text or {}).get(key) if isinstance(orig_text, dict) else None
                curr_text = (curr_text or {}).get(key) if isinstance(curr_text, dict) else None
            if not orig_text and not curr_text:
                continue
            orig_text = orig_text or ""
            curr_text = curr_text or ""
            field_label = ".".join(field_path)
            uni_before = count_unicode(orig_text)
            uni_after = count_unicode(curr_text)
            len_before = len(orig_text)
            len_after = len(curr_text)
            print(f"\n--- field: {field_label} ---")
            print(f" before: {len_before} chars, {uni_before} non-ASCII")
            print(f" after: {len_after} chars, {uni_after} non-ASCII "
                  f"(Δ len {len_after - len_before:+d})")
            print(f"\n >>> ORIGINAL (first 600 chars) <<<")
            print(" " + orig_text[:600].replace("\n", "\n "))
            print(f"\n >>> CLEANED (first 600 chars) <<<")
            print(" " + curr_text[:600].replace("\n", "\n "))

            if uni_after > 0:
                print(f" !!! WARNING: cleaned output still has {uni_after} non-ASCII chars")

            # Sanity: are LaTeX braces balanced in the cleaned text?
            n_open = curr_text.count("{")
            n_close = curr_text.count("}")
            n_lparen = curr_text.count("(")
            n_rparen = curr_text.count(")")
            n_lbrack = curr_text.count("[")
            n_rbrack = curr_text.count("]")
            print(f" brace balance: {{ {n_open} | }} {n_close} "
                  f"( {n_lparen} | ) {n_rparen} "
                  f"[ {n_lbrack} | ] {n_rbrack}")

    # Final aggregate balance check across the entire cleaned dataset
    print("\n" + "=" * 80)
    print("AGGREGATE BRACE BALANCE CHECK (entire cleaned dataset)")
    print("=" * 80)
    total_diff_brace = 0
    total_diff_paren = 0
    total_diff_brack = 0
    files_with_brace_imbalance = 0
    files_with_paren_imbalance = 0
    files_with_brack_imbalance = 0
    for f in sorted(CURRENT_DIR.glob("*.json")):
        d = json.load(open(f))
        # Concatenate all text fields
        bag = []
        for k in ("question", "solution"):
            bag.append(d.get(k) or "")
        for vk, vd in (d.get("variants") or {}).items():
            if isinstance(vd, dict):
                for k in ("question", "solution"):
                    bag.append(vd.get(k) or "")
        all_text = "\n".join(bag)
        diff_brace = all_text.count("{") - all_text.count("}")
        diff_paren = all_text.count("(") - all_text.count(")")
        diff_brack = all_text.count("[") - all_text.count("]")
        if diff_brace != 0:
            files_with_brace_imbalance += 1
            total_diff_brace += abs(diff_brace)
        if diff_paren != 0:
            files_with_paren_imbalance += 1
            total_diff_paren += abs(diff_paren)
        if diff_brack != 0:
            files_with_brack_imbalance += 1
            total_diff_brack += abs(diff_brack)

    print(f" files with unbalanced {{...}}: {files_with_brace_imbalance}/1051"
          f" (total |Δ| = {total_diff_brace})")
    print(f" files with unbalanced (...): {files_with_paren_imbalance}/1051"
          f" (total |Δ| = {total_diff_paren})")
    print(f" files with unbalanced [...]: {files_with_brack_imbalance}/1051"
          f" (total |Δ| = {total_diff_brack})")
    print()
    print(" (Imbalance is not necessarily a bug — math text often legitimately")
    print(" contains unbalanced delimiters in display formulas; this is just")
    print(" an order-of-magnitude check.)")
diff --git a/analysis/structural_overlap.py b/analysis/structural_overlap.py
new file mode 100644
index 0000000..284c139
--- /dev/null
+++ b/analysis/structural_overlap.py
@@ -0,0 +1,523 @@
+"""Stable-vs-Brittle structural overlap analysis (label-free).
+
+Pipeline:
+1. For each (model, surface_variant) cell, load original and variant trajectories.
+2. Pull the deterministic rename map from /home/yurenh2/gap/putnam-bench-anon/dataset/.
+3. Canonicalize both trajectories: replace variant variables with placeholders
+ (via inverse rename map). Original trajectory: replace canonical variables
+ with the same placeholders. Both texts then live in a shared placeholder space.
+4. Compute multiple non-LLM structural metrics on (orig_canonical, var_canonical):
+ - Token Jaccard
+ - Bigram Jaccard
+ - Equation-set Jaccard (math-block extraction)
+ - Prefix Jaccard (first 30% of each canonical text)
+5. Stratify by group (stable vs brittle) within each (model, variant) cell.
+6. Mann-Whitney U test on each metric for stable vs brittle.
+
+Surface variants only (rename map available). Kernel handled separately.
+"""
+
+from __future__ import annotations
+import json
+import re
+import os
+from pathlib import Path
+from collections import Counter, defaultdict
+from typing import Dict, List, Tuple, Optional
+
+import statistics
+
# Anonymized dataset (per-problem JSON, including variant rename maps).
DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
# Per-model solver result files.
RESULTS_DIR = Path("/home/yurenh2/gap/results_new")

# Variants that only rename variables on the surface (rename map available);
# the kernel variant is handled separately.
SURFACE_VARIANTS = ["descriptive_long", "descriptive_long_confusing",
                    "descriptive_long_misleading", "garbled_string"]
+
+
+# ---------- I/O helpers ----------
+
def load_problems(path: Path) -> List[dict]:
    """Load the problem list from a results JSON file.

    Accepts either schema key ("problems" or "detailed_results"); returns []
    when neither is present. BUGFIX: uses a context manager so the file
    handle is closed promptly (was `json.load(open(path))`).
    """
    with open(path) as fh:
        d = json.load(fh)
    return d.get("problems") or d.get("detailed_results") or []
+
+
def find_variant_file(model_dir: Path, variant: str) -> Optional[Path]:
    """Locate the results JSON for `variant` inside a model directory.

    Skips regraded/comparison files and "<variant>2" re-runs; falls back to
    the legacy "_gs.json" suffix for garbled_string. Returns None when no
    candidate exists.
    """
    names = sorted(os.listdir(model_dir))
    suffix = f"_{variant}.json"
    matches = [
        n for n in names
        if n.endswith(suffix)
        and "regraded" not in n
        and "comparison" not in n
        and not n.endswith(f"_{variant}2.json")
    ]
    if not matches and variant == "garbled_string":
        matches = [n for n in names if n.endswith("_gs.json")]
    if matches:
        return model_dir / matches[0]
    return None
+
+
def load_dataset_maps() -> Dict[str, Dict[str, Dict[str, str]]]:
    """Returns: {problem_index: {variant: {orig_var_name: variant_var_name}}}

    The per-variant rename map is stored either as a dict or as a Python
    repr string. SECURITY/BUGFIX: repr strings are parsed with
    ast.literal_eval instead of eval(); literal_eval accepts only Python
    literals and cannot execute code (the old eval() with stripped builtins
    was still not a safe sandbox).
    """
    import ast  # local import: only needed for repr-string maps

    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        with open(f) as fh:
            d = json.load(fh)
        idx = d.get("index")
        variants = d.get("variants", {})
        cell = {}
        for v in SURFACE_VARIANTS:
            vd = variants.get(v, {})
            mp_str = vd.get("map")
            if isinstance(mp_str, str):
                try:
                    mp = ast.literal_eval(mp_str)
                    if isinstance(mp, dict):
                        # `val` (not `v`) to avoid shadowing the variant name
                        cell[v] = {str(k): str(val) for k, val in mp.items()}
                except (ValueError, SyntaxError):
                    pass  # malformed map string — skip this variant
            elif isinstance(mp_str, dict):
                cell[v] = {str(k): str(val) for k, val in mp_str.items()}
        out[idx] = cell
    return out
+
+
+# ---------- Canonicalization ----------
+
def canonicalize_text(text: str, var_to_placeholder: Dict[str, str]) -> str:
    """Replace each variable name in text with its canonical placeholder.

    Names are substituted longest-first so that e.g. 'xs' is rewritten before
    'x' can clobber its prefix. Matching is anchored on both sides by negative
    lookarounds for identifier characters, so only whole tokens are replaced
    (variables in this dataset are all alphanumeric).
    """
    if not text:
        return ""
    ordered = sorted(var_to_placeholder.items(),
                     key=lambda kv: len(kv[0]), reverse=True)
    result = text
    for name, placeholder in ordered:
        if not name:
            continue
        pattern = r"(?<![A-Za-z0-9_])" + re.escape(name) + r"(?![A-Za-z0-9_])"
        result = re.sub(pattern, placeholder, result)
    return result
+
+
def normalize_whitespace(text: str) -> str:
    """Collapse every whitespace run to a single space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
+
+
+# ---------- Tokenization ----------
+
# A token is an identifier, an integer literal, or a single non-space symbol.
_TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*|\d+|[^\sA-Za-z0-9_]")

def tokens(text: str) -> List[str]:
    """Tokenize `text` (None-safe) into identifiers, numbers, and symbols."""
    return _TOKEN_RE.findall(text or "")


def bigrams(toks: List[str]) -> List[str]:
    """Adjacent token pairs, each joined by a single space."""
    return [f"{left} {right}" for left, right in zip(toks, toks[1:])]
+
+
+# ---------- Math block extraction ----------
+
# Display math ($$...$$, \[...\]), inline math ($...$), and LaTeX environments.
_MATH_BLOCKS = [
    re.compile(r"\$\$(.+?)\$\$", re.DOTALL),
    re.compile(r"\\\[(.+?)\\\]", re.DOTALL),
    re.compile(r"\$(.+?)\$", re.DOTALL),
    re.compile(r"\\begin\{(?:equation|align|gather)\*?\}(.+?)\\end\{(?:equation|align|gather)\*?\}", re.DOTALL),
]

def extract_math_blocks(text: str, min_len: int = 8) -> List[str]:
    """Extract math fragments from `text`, whitespace-normalized.

    Fragments shorter than `min_len` (e.g. '$n$', '$0$', '$x$') are dropped
    because they would saturate Jaccard overlap without carrying structure.
    """
    raw = []
    for pattern in _MATH_BLOCKS:
        raw.extend(pattern.findall(text or ""))
    normalized = (normalize_whitespace(frag) for frag in raw if frag.strip())
    return [frag for frag in normalized if len(frag) >= min_len]
+
+
+# ---------- Metrics ----------
+
def jaccard(a: set, b: set) -> float:
    """Jaccard similarity |a∩b| / |a∪b|; two empty sets count as identical."""
    if not a and not b:
        return 1.0
    union = a | b
    return len(a & b) / max(1, len(union))
+
+
def metric_token_jaccard(a: str, b: str) -> float:
    """Symmetric token-set overlap between two canonicalized texts."""
    set_a = set(tokens(a))
    set_b = set(tokens(b))
    return jaccard(set_a, set_b)
+
+
def metric_bigram_jaccard(a: str, b: str) -> float:
    """Symmetric overlap of adjacent-token-pair sets."""
    pairs_a = set(bigrams(tokens(a)))
    pairs_b = set(bigrams(tokens(b)))
    return jaccard(pairs_a, pairs_b)
+
+
def metric_prefix_token_jaccard(a: str, b: str, frac: float = 0.3) -> float:
    """Jaccard over the first frac of tokens from each side."""
    toks_a = tokens(a)
    toks_b = tokens(b)
    cut_a = max(1, int(len(toks_a) * frac))
    cut_b = max(1, int(len(toks_b) * frac))
    return jaccard(set(toks_a[:cut_a]), set(toks_b[:cut_b]))
+
+
def metric_prefix_bigram_jaccard(a: str, b: str, frac: float = 0.3) -> float:
    """Bigram Jaccard over the first frac of tokens from each side."""
    toks_a = tokens(a)
    toks_b = tokens(b)
    cut_a = max(1, int(len(toks_a) * frac))
    cut_b = max(1, int(len(toks_b) * frac))
    return jaccard(set(bigrams(toks_a[:cut_a])), set(bigrams(toks_b[:cut_b])))
+
+
def metric_equation_jaccard(a: str, b: str) -> float:
    """Jaccard over the sets of extracted math blocks."""
    blocks_a = set(extract_math_blocks(a))
    blocks_b = set(extract_math_blocks(b))
    return jaccard(blocks_a, blocks_b)
+
+
def metric_lcp_tokens(a: str, b: str) -> int:
    """Length of the longest common prefix of canonicalized token streams.

    Directly tests Codex's thesis 'early loss of structural overlap with the
    model's own original reasoning under renaming'. Larger LCP -> the model
    started its variant trajectory the same way it started the original.
    """
    length = 0
    for tok_a, tok_b in zip(tokens(a), tokens(b)):
        if tok_a != tok_b:
            break
        length += 1
    return length
+
+
def metric_lcp_normalized(a: str, b: str) -> float:
    """LCP length normalized by the shorter trajectory length, in [0, 1]."""
    shorter = min(len(tokens(a)), len(tokens(b)))
    if shorter == 0:
        return 0.0
    return metric_lcp_tokens(a, b) / shorter
+
+
def metric_lcp_first1k(a: str, b: str) -> float:
    """LCP length capped to first-1000-token comparison, normalized to [0, 1]."""
    head_a = tokens(a)[:1000]
    head_b = tokens(b)[:1000]
    limit = min(len(head_a), len(head_b))
    if limit == 0:
        return 0.0
    matched = 0
    for tok_a, tok_b in zip(head_a, head_b):
        if tok_a != tok_b:
            break
        matched += 1
    return matched / limit
+
+
def metric_directional_coverage(a: str, b: str) -> float:
    """|tokens_a ∩ tokens_b| / |tokens_a|. Length-asymmetric.

    Reads as: 'what fraction of the original's vocabulary survives in the
    variant?' More robust to length differences than symmetric Jaccard.
    """
    vocab_a = set(tokens(a))
    if not vocab_a:
        return 0.0
    vocab_b = set(tokens(b))
    return len(vocab_a & vocab_b) / len(vocab_a)
+
+
def metric_window_token_jaccard(a: str, b: str, window: int = 600) -> float:
    """Jaccard restricted to the first `window` tokens on each side."""
    head_a = set(tokens(a)[:window])
    head_b = set(tokens(b)[:window])
    return jaccard(head_a, head_b)
+
+
def metric_window_bigram_jaccard(a: str, b: str, window: int = 600) -> float:
    """Bigram Jaccard restricted to the first `window` tokens on each side."""
    head_a = tokens(a)[:window]
    head_b = tokens(b)[:window]
    return jaccard(set(bigrams(head_a)), set(bigrams(head_b)))
+
+
+# ---------- Stat helpers ----------
+
def bootstrap_ci_delta_median(xs: List[float], ys: List[float],
                              n_iter: int = 1000, seed: int = 0) -> Tuple[float, float]:
    """Percentile bootstrap 95% CI on median(xs) - median(ys)."""
    import random
    rng = random.Random(seed)
    if not xs or not ys:
        return float("nan"), float("nan")
    deltas = []
    for _ in range(n_iter):
        # Resample xs first, then ys: the draw order is part of the seeded
        # RNG contract and must stay fixed for reproducibility.
        sample_x = [xs[rng.randrange(len(xs))] for _ in range(len(xs))]
        sample_y = [ys[rng.randrange(len(ys))] for _ in range(len(ys))]
        deltas.append(statistics.median(sample_x) - statistics.median(sample_y))
    deltas.sort()
    return deltas[int(0.025 * n_iter)], deltas[int(0.975 * n_iter)]
+
+
def bootstrap_ci_cohens_d(xs: List[float], ys: List[float],
                          n_iter: int = 1000, seed: int = 0) -> Tuple[float, float]:
    """Percentile bootstrap 95% CI on Cohen's d between xs and ys.

    NOTE(review): the pooled SD mixes population stdev (pstdev) with
    Bessel-style (n-1) weights; for large resamples the difference is
    negligible, but it is not a textbook-pure estimator — confirm intended.
    """
    import random
    rng = random.Random(seed)
    if len(xs) < 2 or len(ys) < 2:
        return float("nan"), float("nan")
    ds = []
    for _ in range(n_iter):
        rs = [xs[rng.randrange(len(xs))] for _ in range(len(xs))]
        rb = [ys[rng.randrange(len(ys))] for _ in range(len(ys))]
        sm, bm = statistics.fmean(rs), statistics.fmean(rb)
        ssd = statistics.pstdev(rs)
        bsd = statistics.pstdev(rb)
        pooled = (((len(rs)-1)*ssd**2 + (len(rb)-1)*bsd**2)
                  / max(1, len(rs)+len(rb)-2)) ** 0.5
        if pooled > 0:  # degenerate (all-equal) resamples are skipped
            ds.append((sm - bm) / pooled)
    if not ds:
        return float("nan"), float("nan")
    ds.sort()
    # Percentile CI over however many non-degenerate resamples survived.
    lo = ds[int(0.025 * len(ds))]
    hi = ds[int(0.975 * len(ds))]
    return lo, hi
+
+
def mann_whitney_u(xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Returns (U, normal_approx_p_two_sided). Pure-python, no scipy.

    Used only as a screening signal — for the rebuttal we'll use scipy if
    available; this is a fallback so we don't add a dependency.
    """
    n1, n2 = len(xs), len(ys)
    if n1 == 0 or n2 == 0:
        return float("nan"), float("nan")
    # Tag each value with its group (0 = xs, 1 = ys) and rank jointly.
    combined = [(v, 0) for v in xs] + [(v, 1) for v in ys]
    combined.sort(key=lambda t: t[0])
    # Average ranks for ties
    ranks = [0.0] * len(combined)
    i = 0
    while i < len(combined):
        # [i, j] is a maximal run of tied values; assign each its mean rank.
        j = i
        while j + 1 < len(combined) and combined[j + 1][0] == combined[i][0]:
            j += 1
        avg = (i + j) / 2.0 + 1 # 1-indexed
        for k in range(i, j + 1):
            ranks[k] = avg
        i = j + 1
    # Rank-sum of group 0 gives U1; U2 is its mirror; U = min of the two.
    R1 = sum(ranks[k] for k in range(len(combined)) if combined[k][1] == 0)
    U1 = R1 - n1 * (n1 + 1) / 2.0
    U2 = n1 * n2 - U1
    U = min(U1, U2)
    # Normal approx (no tie correction)
    mu = n1 * n2 / 2.0
    sd = (n1 * n2 * (n1 + n2 + 1) / 12.0) ** 0.5
    if sd == 0:
        return U, float("nan")
    z = (U - mu) / sd
    # Two-sided p via erf approx
    import math
    p = math.erfc(abs(z) / math.sqrt(2))
    return U, p
+
+
+# ---------- Cell analysis ----------
+
# A variant trajectory "collapses" when it is drastically shorter than the
# original: either under an absolute floor or under a fixed fraction.
COLLAPSE_MIN_CHARS = 200
COLLAPSE_RATIO = 0.25 # variant_len < ratio * orig_len => collapse


def is_collapse(orig_text: str, var_text: str) -> bool:
    """True when the variant output is too short to count as real drift."""
    if len(var_text) < COLLAPSE_MIN_CHARS:
        return True
    return len(var_text) < COLLAPSE_RATIO * max(1, len(orig_text))
+
+
def analyze_cell(model_name: str, variant: str, dataset_maps: dict,
                 model_dir: Path) -> Optional[dict]:
    """Compute stable-vs-brittle structural-overlap statistics for one cell.

    A "cell" is one (model, surface-variant) combination. For every problem
    the model solved correctly on the original phrasing, the original and
    variant solution texts are canonicalized through the problem's rename
    map and bucketed into:

      * stable drift   -- variant answer still correct, output not collapsed
      * brittle drift  -- variant answer wrong, output not collapsed
      * collapse       -- variant output degenerately short (see is_collapse)

    Returns a dict with bucket counts, per-metric stable/brittle comparisons
    (Mann-Whitney U, Cohen's d), and -- for the headline token-Jaccard
    metric only -- bootstrap CIs plus a random-pairing noise floor.
    Returns None when either result file is missing or either drift bucket
    is empty (no comparison possible).
    """
    orig_path = find_variant_file(model_dir, "original")
    var_path = find_variant_file(model_dir, variant)
    if not orig_path or not var_path:
        return None

    orig_by = {p["index"]: p for p in load_problems(orig_path)}
    var_by = {p["index"]: p for p in load_problems(var_path)}

    common = set(orig_by) & set(var_by)
    pairs_stable_drift = []  # (orig_canon, var_canon, problem_type) — non-collapse
    pairs_brittle_drift = []  # non-collapse brittle
    pairs_brittle_collapse = []  # short variant text
    n_stable_collapse = 0  # almost always 0 but tracked for completeness

    for idx in common:
        po, pv = orig_by[idx], var_by[idx]
        # Gate on original correctness: only problems the model solved on the
        # original phrasing enter the stable/brittle comparison.
        if po.get("correct") is not True:
            continue
        var_correct = pv.get("correct")
        if var_correct is None:  # variant run ungraded — skip
            continue
        orig_text = (po.get("solve") or {}).get("solution") or ""
        var_text = (pv.get("solve") or {}).get("solution") or ""
        if not orig_text or not var_text:
            continue
        rmap = dataset_maps.get(idx, {}).get(variant)
        if not rmap:
            continue
        # Canonicalize: map each original name and its variant rename onto the
        # same positional placeholder (__V0__, __V1__, ...) so token overlap
        # is invariant to the surface renaming.
        canon_to_ph = {k: f"__V{i}__" for i, k in enumerate(rmap.keys())}
        var_to_ph = {rmap[k]: canon_to_ph[k] for k in rmap}
        orig_canon = canonicalize_text(orig_text, canon_to_ph)
        var_canon = canonicalize_text(var_text, var_to_ph)
        sample = {
            "index": idx,
            "problem_type": po.get("problem_type"),
            "orig_canon": orig_canon,
            "var_canon": var_canon,
            "orig_len": len(orig_text),
            "var_len": len(var_text),
        }
        collapse = is_collapse(orig_text, var_text)
        if var_correct is True:
            if collapse:
                n_stable_collapse += 1
            else:
                pairs_stable_drift.append(sample)
        else:
            if collapse:
                pairs_brittle_collapse.append(sample)
            else:
                pairs_brittle_drift.append(sample)

    if not pairs_stable_drift or not pairs_brittle_drift:
        return None

    metrics = {
        "token_jaccard": metric_token_jaccard,
        "bigram_jaccard": metric_bigram_jaccard,
        "directional_coverage": metric_directional_coverage,
        "window_token_jaccard": metric_window_token_jaccard,
        "window_bigram_jaccard": metric_window_bigram_jaccard,
        "equation_jaccard": metric_equation_jaccard,
    }
    # Headline metric for bootstrap + noise floor (the others stay descriptive)
    HEADLINE = "token_jaccard"

    # Pre-tokenize once per pair. NOTE(review): the cached token lists/sets
    # are consumed only by the local fast_token_jaccard helpers below; the
    # metric_* functions above receive the raw canonical strings and tokenize
    # internally themselves.
    for p in pairs_stable_drift + pairs_brittle_drift:
        p["_otok"] = tokens(p["orig_canon"])
        p["_vtok"] = tokens(p["var_canon"])
        p["_oset"] = set(p["_otok"])
        p["_vset"] = set(p["_vtok"])

    def fast_token_jaccard(p):
        # Jaccard of the cached token sets of one (orig, variant) pair.
        a, b = p["_oset"], p["_vset"]
        if not a and not b:
            return 1.0
        return len(a & b) / max(1, len(a | b))

    def fast_token_jaccard_pair(pa, pb):
        # Cross-pair Jaccard: orig tokens of `pa` vs variant tokens of `pb`.
        # Used only for the random-pairing noise floor.
        a, b = pa["_oset"], pb["_vset"]
        if not a and not b:
            return 1.0
        return len(a & b) / max(1, len(a | b))

    out = {
        "model": model_name,
        "variant": variant,
        "n_stable_drift": len(pairs_stable_drift),
        "n_brittle_drift": len(pairs_brittle_drift),
        "n_stable_collapse": n_stable_collapse,
        "n_brittle_collapse": len(pairs_brittle_collapse),
        "brittle_collapse_rate": (len(pairs_brittle_collapse)
                                  / max(1, len(pairs_brittle_collapse) + len(pairs_brittle_drift))),
        "metrics": {},
    }
    # Compute all descriptive metrics (one pass per pair, no bootstrap)
    for mname, mfn in metrics.items():
        s_vals = [mfn(p["orig_canon"], p["var_canon"]) for p in pairs_stable_drift]
        b_vals = [mfn(p["orig_canon"], p["var_canon"]) for p in pairs_brittle_drift]
        U, p = mann_whitney_u(s_vals, b_vals)
        sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals)
        ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0
        bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0
        # Pooled SD for Cohen's d; when it is 0 we report d = 0 rather
        # than dividing by zero.
        pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2)
                  / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5
        d = (sm - bm) / pooled if pooled > 0 else 0.0
        out["metrics"][mname] = {
            "stable_median": statistics.median(s_vals),
            "stable_mean": sm,
            "brittle_median": statistics.median(b_vals),
            "brittle_mean": bm,
            "delta_median": statistics.median(s_vals) - statistics.median(b_vals),
            "delta_mean": sm - bm,
            "cohens_d": d,
            "U": U,
            "p_two_sided": p,
        }

    # Bootstrap + noise floor only on headline metric
    s_vals = [fast_token_jaccard(p) for p in pairs_stable_drift]
    b_vals = [fast_token_jaccard(p) for p in pairs_brittle_drift]
    ci_lo, ci_hi = bootstrap_ci_delta_median(s_vals, b_vals, n_iter=400)
    d_lo, d_hi = bootstrap_ci_cohens_d(s_vals, b_vals, n_iter=400)
    out["metrics"][HEADLINE]["delta_median_ci"] = [ci_lo, ci_hi]
    out["metrics"][HEADLINE]["cohens_d_ci"] = [d_lo, d_hi]

    # Random-pairing noise floor for headline: pair stable orig with random other-problem variant
    import random as _r
    rng = _r.Random(42)  # fixed seed: noise floor is deterministic per cell
    nf_vals = []
    n = len(pairs_stable_drift)
    if n >= 2:
        for _ in range(min(400, n * (n - 1))):
            i = rng.randrange(n)
            j = rng.randrange(n)
            while j == i:
                j = rng.randrange(n)
            nf_vals.append(fast_token_jaccard_pair(pairs_stable_drift[i],
                                                   pairs_stable_drift[j]))
    # NOTE: with n < 2 the noise-floor fields below are None — callers that
    # format them numerically must handle that.
    out["metrics"][HEADLINE]["noise_floor_median"] = (
        statistics.median(nf_vals) if nf_vals else None)
    out["metrics"][HEADLINE]["noise_floor_mean"] = (
        statistics.fmean(nf_vals) if nf_vals else None)
    out["metrics"][HEADLINE]["noise_floor_n"] = len(nf_vals)
    return out
+
+
def main():
    """Sweep all models × surface variants and save per-cell overlap stats.

    For every model directory under RESULTS_DIR and every surface variant,
    runs analyze_cell(), prints a one-line summary per cell, and writes the
    collected results to structural_overlap_results.json.
    """
    print("Loading dataset rename maps ...", flush=True)
    dataset_maps = load_dataset_maps()
    print(f" loaded {len(dataset_maps)} problems", flush=True)

    # Multi-cell sweep across all models × 4 surface variants
    # Run all 18 models — non-LLM, fast.
    all_models = sorted([d.name for d in RESULTS_DIR.iterdir() if d.is_dir()])
    print(f"Models: {all_models}")
    all_results = []

    print(f"\n{'Cell':<46} {'nSd':>4} {'nBd':>4} {'col%':>5} "
          f"{'sMed':>6} {'bMed':>6} {'nfMed':>6} "
          f"{'d':>6} {'d95CI':>14} {'p':>9}")
    print("-" * 122)

    for m in all_models:
        mdir = RESULTS_DIR / m  # loop-invariant across variants — hoisted
        if not mdir.exists():
            continue
        for v in SURFACE_VARIANTS:
            res = analyze_cell(m, v, dataset_maps, mdir)
            if res is None:
                continue
            all_results.append(res)
            md = res["metrics"]["token_jaccard"]
            label = f"{m} / {v}"
            ci_lo, ci_hi = md["cohens_d_ci"]
            ci_str = f"[{ci_lo:+.2f}, {ci_hi:+.2f}]"
            # Fix: noise_floor_median is None when a cell has < 2 stable
            # pairs; formatting None with :>6.3f raised TypeError before.
            nf = md["noise_floor_median"]
            nf_str = f"{nf:>6.3f}" if nf is not None else f"{'n/a':>6}"
            print(f"{label:<46} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
                  f"{res['brittle_collapse_rate']*100:>4.0f}% "
                  f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
                  f"{nf_str} "
                  f"{md['cohens_d']:>+6.2f} {ci_str:>14} {md['p_two_sided']:>9.1e}")

    out_path = Path("/home/yurenh2/gap/analysis/structural_overlap_results.json")
    # Fix: write through a context manager so the handle is flushed and
    # closed deterministically (was json.dump(..., open(out_path, "w"))).
    with out_path.open("w") as fh:
        json.dump(all_results, fh, indent=2)
    print(f"\nSaved -> {out_path} ({len(all_results)} cells)")


if __name__ == "__main__":
    main()
diff --git a/analysis/topic_problemtype_interaction.py b/analysis/topic_problemtype_interaction.py
new file mode 100644
index 0000000..405b33a
--- /dev/null
+++ b/analysis/topic_problemtype_interaction.py
@@ -0,0 +1,112 @@
+"""KV fragility broken down by Topic × Problem-type (proof vs calculation)."""
+from __future__ import annotations
+import json
+import sys
+import statistics
+from pathlib import Path
+from collections import defaultdict
+
+THIS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(THIS_DIR))
+from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS
+
+DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
+
+
def load_metadata() -> dict:
    """Load per-problem metadata from the dataset directory.

    Returns a dict mapping problem index -> {"tag": ..., "problem_type": ...}.
    Files without an "index" key are skipped.
    """
    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        # Fix: open via a context manager — json.load(open(f)) leaked one
        # file handle per dataset file.
        with f.open() as fh:
            d = json.load(fh)
        idx = d.get("index")
        if not idx:
            continue
        out[idx] = {
            "tag": d.get("tag"),
            "problem_type": d.get("problem_type"),
        }
    return out
+
+
def main():
    """Print accuracy tables broken down by topic × problem-type × variant.

    Accuracy is computed per model (fraction correct), then averaged across
    the models that have at least 5 graded problems in every variant of the
    comparison, so tiny cells do not dominate the means.
    """
    metadata = load_metadata()
    base = RESULTS_DIR
    models = sorted([d.name for d in base.iterdir() if d.is_dir()])

    # cells[(topic, ptype, model, variant)] = (n, n_correct)
    cells = defaultdict(lambda: [0, 0])
    for m in models:
        mdir = base / m
        for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]:
            vp = find_variant_file(mdir, v)
            if not vp: continue
            for p in load_problems(vp):
                idx = p.get("index")
                correct = p.get("correct")
                if idx is None or correct is None: continue
                md = metadata.get(idx, {})
                tag = md.get("tag")
                ptype = md.get("problem_type")
                if not tag or not ptype: continue
                # A problem may carry several topic tags; count it under each.
                tags = tag if isinstance(tag, list) else [tag]
                for t in tags:
                    if t not in ["ALG", "ANA", "NT", "COMB", "GEO"]: continue
                    cells[(t, ptype, m, v)][0] += 1
                    if correct: cells[(t, ptype, m, v)][1] += 1

    print("=" * 80)
    print("ACCURACY BY TOPIC × PROBLEM-TYPE × VARIANT (mean across 18 models)")
    print("=" * 80)
    print()

    for ptype in ["proof", "calculation"]:
        print(f"\n--- {ptype.upper()} ---\n")
        print(f"{'Topic':<6}", end="")
        for v in ["original", "garbled_string", "kernel_variant"]:
            short = {"original":"orig","garbled_string":"GS","kernel_variant":"KV"}[v]
            print(f" {short:>6}", end="")
        print(f" {'Δ_GS':>7} {'Δ_KV':>7}")
        print("-" * 50)
        for t in ["ALG", "ANA", "NT", "COMB", "GEO"]:
            orig_rates = []
            gs_rates = []
            kv_rates = []
            for m in models:
                no, co = cells.get((t, ptype, m, "original"), [0, 0])
                ng, cg = cells.get((t, ptype, m, "garbled_string"), [0, 0])
                nk, ck = cells.get((t, ptype, m, "kernel_variant"), [0, 0])
                # Require >= 5 graded problems in all three variants so the
                # per-model rates are comparable across columns.
                if no >= 5 and ng >= 5 and nk >= 5:
                    orig_rates.append(co / no)
                    gs_rates.append(cg / ng)
                    kv_rates.append(ck / nk)
            if not orig_rates: continue
            mo = statistics.fmean(orig_rates) * 100
            mg = statistics.fmean(gs_rates) * 100
            mk = statistics.fmean(kv_rates) * 100
            print(f"{t:<6} {mo:>5.1f}% {mg:>5.1f}% {mk:>5.1f}% {mg-mo:>+5.1f}pp {mk-mo:>+5.1f}pp")

    print("\n\n=== KEY DIFFERENTIAL: Δ KV by Topic for proof vs calculation ===\n")
    print(f"{'Topic':<6} {'proof Δ':>10} {'calc Δ':>10} {'(calc - proof)':>16}")
    print("-" * 50)
    for t in ["ALG", "ANA", "NT", "COMB", "GEO"]:
        deltas = {}
        for ptype in ["proof", "calculation"]:
            orig_rates = []
            kv_rates = []
            for m in models:
                no, co = cells.get((t, ptype, m, "original"), [0, 0])
                nk, ck = cells.get((t, ptype, m, "kernel_variant"), [0, 0])
                if no >= 5 and nk >= 5:
                    orig_rates.append(co / no)
                    kv_rates.append(ck / nk)
            if orig_rates:
                deltas[ptype] = (statistics.fmean(kv_rates) - statistics.fmean(orig_rates)) * 100
        # Print whichever of the two problem-type deltas are available;
        # '-' marks a topic/type combination with no qualifying models.
        if "proof" in deltas and "calculation" in deltas:
            diff = deltas["calculation"] - deltas["proof"]
            print(f"{t:<6} {deltas['proof']:>+9.1f}pp {deltas['calculation']:>+9.1f}pp {diff:>+15.1f}pp")
        elif "proof" in deltas:
            print(f"{t:<6} {deltas['proof']:>+9.1f}pp {'-':>10} {'-':>16}")
        elif "calculation" in deltas:
            print(f"{t:<6} {'-':>10} {deltas['calculation']:>+9.1f}pp {'-':>16}")


if __name__ == "__main__":
    main()
diff --git a/analysis/unicode_audit.py b/analysis/unicode_audit.py
new file mode 100644
index 0000000..afe5679
--- /dev/null
+++ b/analysis/unicode_audit.py
@@ -0,0 +1,238 @@
+"""Unicode audit for PutnamGAP dataset.
+
+Scans all JSON files in the dataset, finds all non-ASCII characters in text
+fields (question, solution across all variants), and reports:
+
+1. How many files contain Unicode
+2. Top Unicode characters by total frequency with suggested LaTeX replacements
+3. Which fields are most affected
+4. Per-file tallies
+5. Samples of lines showing each unusual character in context
+6. A machine-readable JSON report for downstream cleaning
+
+Does NOT modify any file. Read-only audit.
+"""
+from __future__ import annotations
+import json
+import sys
+import unicodedata
+from pathlib import Path
+from collections import defaultdict, Counter
+
# Both copies of the dataset (audited independently, reported separately).
DIRS = [
    Path("/home/yurenh2/gap/putnam-bench-anon/dataset"),
    Path("/home/yurenh2/gap/putnamsup/PutnamGAP"),
]

# Text-bearing fields we care about (top level and inside each variant).
TOP_LEVEL_TEXT_FIELDS = ["question", "solution"]
VARIANT_TEXT_FIELDS = ["question", "solution"]
VARIANT_KEYS = [
    "descriptive_long",
    "descriptive_long_confusing",
    "descriptive_long_misleading",
    "garbled_string",
    "kernel_variant",
    "original_kernel_variant",
]

# Suggested LaTeX replacements for common math Unicode. (Informational — the
# audit does not apply these.) Each entry is (unicode_char, latex_suggestion).
# Keys are literal Unicode characters; audit_dir() looks them up per char.
SUGGESTED_LATEX = {
    # Greek lower case
    "α": r"\alpha", "β": r"\beta", "γ": r"\gamma", "δ": r"\delta",
    "ε": r"\varepsilon", "ζ": r"\zeta", "η": r"\eta", "θ": r"\theta",
    "ι": r"\iota", "κ": r"\kappa", "λ": r"\lambda", "μ": r"\mu",
    "ν": r"\nu", "ξ": r"\xi", "π": r"\pi", "ρ": r"\rho", "σ": r"\sigma",
    "τ": r"\tau", "υ": r"\upsilon", "φ": r"\varphi", "χ": r"\chi",
    "ψ": r"\psi", "ω": r"\omega",
    # Greek upper case
    "Α": "A", "Β": "B", "Γ": r"\Gamma", "Δ": r"\Delta", "Ε": "E",
    "Ζ": "Z", "Η": "H", "Θ": r"\Theta", "Λ": r"\Lambda", "Ξ": r"\Xi",
    "Π": r"\Pi", "Σ": r"\Sigma", "Φ": r"\Phi", "Ψ": r"\Psi",
    "Ω": r"\Omega",
    # Math operators & relations
    "≤": r"\leq", "≥": r"\geq", "≠": r"\neq", "≈": r"\approx",
    "≡": r"\equiv", "±": r"\pm", "∓": r"\mp", "×": r"\times",
    "÷": r"\div", "·": r"\cdot", "∙": r"\cdot",
    "∞": r"\infty", "∂": r"\partial", "∇": r"\nabla", "∆": r"\Delta",
    "∑": r"\sum", "∏": r"\prod", "∫": r"\int", "√": r"\sqrt{}",
    "∮": r"\oint", "∴": r"\therefore", "∵": r"\because",
    "∈": r"\in", "∉": r"\notin", "⊂": r"\subset", "⊆": r"\subseteq",
    "⊃": r"\supset", "⊇": r"\supseteq", "∪": r"\cup", "∩": r"\cap",
    "∧": r"\land", "∨": r"\lor", "¬": r"\neg",
    "→": r"\to", "←": r"\leftarrow", "↔": r"\leftrightarrow",
    "⇒": r"\Rightarrow", "⇐": r"\Leftarrow", "⇔": r"\Leftrightarrow",
    "⟨": r"\langle", "⟩": r"\rangle", "⌊": r"\lfloor", "⌋": r"\rfloor",
    "⌈": r"\lceil", "⌉": r"\rceil",
    "∅": r"\emptyset", "ℝ": r"\mathbb{R}", "ℂ": r"\mathbb{C}",
    "ℕ": r"\mathbb{N}", "ℤ": r"\mathbb{Z}", "ℚ": r"\mathbb{Q}",
    # Subscripts / superscripts (common ones only)
    "₀": "_0", "₁": "_1", "₂": "_2", "₃": "_3", "₄": "_4", "₅": "_5",
    "₆": "_6", "₇": "_7", "₈": "_8", "₉": "_9",
    "⁰": "^0", "¹": "^1", "²": "^2", "³": "^3", "⁴": "^4", "⁵": "^5",
    "⁶": "^6", "⁷": "^7", "⁸": "^8", "⁹": "^9",
    "ₐ": "_a", "ᵢ": "_i", "ⱼ": "_j", "ₖ": "_k", "ₙ": "_n",
    # Fractions
    "½": r"\frac{1}{2}", "⅓": r"\frac{1}{3}", "⅔": r"\frac{2}{3}",
    "¼": r"\frac{1}{4}", "¾": r"\frac{3}{4}",
    # Punctuation / whitespace
    "—": "---", "–": "--", "…": r"\ldots",
    "‘": "`", "’": "'", "“": "``", "”": "''",
    "°": r"^\circ",
    "\u00A0": " (nbsp)",  # non-breaking space
    "\u2009": " (thin space)",
    "\u200b": " (zero-width space)",
    # NOTE: U+2026 also appears above as the literal "…"; both map to
    # \ldots, so the duplicate dict key is harmless (later entry wins).
    "\u2026": r"\ldots",
    "\u2212": "-",  # Unicode minus vs hyphen
}
+
+
def is_non_ascii(ch: str) -> bool:
    """Return True when the single character *ch* lies outside 7-bit ASCII."""
    return ord(ch) >= 0x80
+
+
def extract_text_fields(problem: dict):
    """Yield (field_path, text) pairs for each string field of *problem*.

    Covers the top-level question/solution fields, then the same fields
    nested under each known variant key. Non-string values and missing
    variants are skipped. Field paths are prefixed with the problem index
    ("?" when absent).
    """
    prefix = problem.get("index", "?")
    for field in TOP_LEVEL_TEXT_FIELDS:
        value = problem.get(field)
        if isinstance(value, str):
            yield f"{prefix}:{field}", value
    variants = problem.get("variants") or {}
    for variant_key in VARIANT_KEYS:
        variant = variants.get(variant_key)
        if isinstance(variant, dict):
            for field in VARIANT_TEXT_FIELDS:
                value = variant.get(field)
                if isinstance(value, str):
                    yield f"{prefix}:variants.{variant_key}.{field}", value
+
+
def audit_dir(dataset_dir: Path, label: str):
    """Scan every JSON file in *dataset_dir* for non-ASCII characters.

    Prints a human-readable report (frequency table, per-field breakdown,
    example contexts) and returns a machine-readable summary dict for the
    aggregate JSON report. Read-only: no file is modified.
    """
    print(f"\n{'=' * 76}")
    print(f"Auditing {label}: {dataset_dir}")
    print(f"{'=' * 76}")

    files = sorted(dataset_dir.glob("*.json"))
    print(f"Files: {len(files)}")

    char_counter = Counter()  # unicode char -> total occurrences
    field_char_counter = defaultdict(Counter)  # field_name -> Counter
    files_with_unicode = set()  # set of problem indices
    per_field_counts = Counter()  # field label -> occurrence tally (not reported; kept for parity)
    examples = defaultdict(list)  # char -> list of (context, path)
    total_chars = 0
    total_unicode = 0

    for f in files:
        try:
            # Fix: open via a context manager — json.load(open(f)) leaked
            # one file handle per dataset file.
            with f.open() as fh:
                d = json.load(fh)
        except Exception as e:
            print(f" ! {f.name}: JSON parse error: {e}")
            continue
        file_had_unicode = False
        for path, text in extract_text_fields(d):
            if not text:
                continue
            total_chars += len(text)
            nas = [c for c in text if is_non_ascii(c)]
            if not nas:
                continue
            file_had_unicode = True
            total_unicode += len(nas)
            # tally
            for c in nas:
                char_counter[c] += 1
                # short field label (strip problem index prefix)
                short = path.split(":", 1)[1]
                field_char_counter[short][c] += 1
                per_field_counts[short] += 1
                # collect up to 3 examples per char with ±25 char context
                # (finds the first occurrence in the field, not this one)
                if len(examples[c]) < 3:
                    idx = text.find(c)
                    start = max(0, idx - 25)
                    end = min(len(text), idx + 25)
                    ctx = text[start:end].replace("\n", " ")
                    examples[c].append((ctx, path))
        if file_had_unicode:
            files_with_unicode.add(d.get("index", f.name))

    # Report. Fix: denominators floored at 1 so an empty or absent dataset
    # prints 0% instead of raising ZeroDivisionError (the summary dict
    # below already guarded len(files) this way).
    print(f"\nTotal characters scanned: {total_chars:,}")
    print(f"Non-ASCII characters: {total_unicode:,} "
          f"({total_unicode/max(1, total_chars)*100:.2f}%)")
    print(f"Files with any Unicode: {len(files_with_unicode)}/{len(files)} "
          f"({len(files_with_unicode)/max(1, len(files))*100:.1f}%)")
    print(f"Distinct Unicode code points: {len(char_counter)}")

    print(f"\n--- Top 40 Unicode characters by frequency ---")
    print(f"{'char':<6} {'hex':<8} {'count':>8} name / suggested LaTeX")
    print("-" * 76)
    for c, n in char_counter.most_common(40):
        name = unicodedata.name(c, "?")
        hex_val = f"U+{ord(c):04X}"
        suggestion = SUGGESTED_LATEX.get(c, "")
        # repr() for control/space-like chars so the column stays readable
        display_c = c if c.isprintable() and ord(c) > 0x20 else repr(c)
        print(f"{display_c:<6} {hex_val:<8} {n:>8} {name[:45]:<45} {suggestion}")

    # Per-field breakdown
    print(f"\n--- Unicode per field (top 15 fields with most Unicode) ---")
    print(f"{'field':<50} {'total unicode':>15}")
    print("-" * 70)
    for field, cnt in Counter({f: sum(c.values()) for f, c in field_char_counter.items()}).most_common(15):
        print(f"{field:<50} {cnt:>15}")

    # Examples for top 10 chars
    print(f"\n--- Example contexts for top 10 Unicode chars ---")
    for c, n in char_counter.most_common(10):
        name = unicodedata.name(c, "?")
        display_c = c if c.isprintable() and ord(c) > 0x20 else repr(c)
        print(f"\n {display_c} (U+{ord(c):04X}, {name}, n={n}):")
        for ctx, path in examples[c][:2]:
            print(f" [{path}]")
            print(f" …{ctx}…")

    # Machine-readable summary
    summary = {
        "dataset_dir": str(dataset_dir),
        "n_files": len(files),
        "n_files_with_unicode": len(files_with_unicode),
        "pct_files_with_unicode": 100 * len(files_with_unicode) / max(1, len(files)),
        "total_chars": total_chars,
        "total_unicode": total_unicode,
        "distinct_codepoints": len(char_counter),
        "top_chars": [
            {"char": c, "codepoint": f"U+{ord(c):04X}",
             "name": unicodedata.name(c, "?"),
             "count": n,
             "suggested_latex": SUGGESTED_LATEX.get(c, ""),
             "examples": [{"path": path, "context": ctx}
                          for ctx, path in examples[c][:3]]}
            for c, n in char_counter.most_common(80)
        ],
        "per_field_unicode_counts": dict(
            Counter({f: sum(c.values()) for f, c in field_char_counter.items()})
            .most_common(30)),
        "files_with_unicode_indices": sorted(files_with_unicode),
    }
    return summary
+
+
def main():
    """Audit every configured dataset copy and save the combined summary.

    Missing directories are skipped with a notice. Summaries from all
    present copies are written to a single JSON file.
    """
    all_summaries = []
    for d in DIRS:
        if d.exists():
            s = audit_dir(d, d.name)
            s["label"] = d.name
            all_summaries.append(s)
        else:
            print(f" (skipping missing dir {d})")

    out_path = Path("/home/yurenh2/gap/analysis/unicode_audit.json")
    # Fix: write through a context manager so the handle is flushed and
    # closed deterministically (was json.dump(..., open(out_path, "w"))).
    with out_path.open("w") as fh:
        json.dump(all_summaries, fh, indent=2, ensure_ascii=False)
    print(f"\n\nSaved machine-readable summary -> {out_path}")


if __name__ == "__main__":
    main()
diff --git a/analysis/unicode_clean.py b/analysis/unicode_clean.py
new file mode 100644
index 0000000..cea3cbe
--- /dev/null
+++ b/analysis/unicode_clean.py
@@ -0,0 +1,729 @@
+"""Unicode -> LaTeX cleaner for PutnamGAP dataset (v2).
+
+Improvements over v1:
+ - Pre-normalize via NFKD then strip combining diacritics so accented
+ letters collapse to their ASCII base.
+ - Group adjacent subscript/superscript runs into {...}: x_1_0 -> x_{10},
+ x^2^3 -> x^{23}.
+ - Wrap the argument of radical commands: \\sqrt-followed-by-X -> \\sqrt{X}
+ where X is either an identifier/number run or a balanced paren/bracket
+ group or a single \\-command (optionally followed by {...} arguments).
+ - Explicit replacements for symbols that previously fell through:
+ star, blacksquare/QED, fraction slash, dagger, etc.
+ - Deletes lone combining diacritics and decorative box-drawing characters.
+
+Operates IN PLACE on both dataset copies. Backup in a tarball first.
+"""
+from __future__ import annotations
+import json
+import re
+import sys
+import unicodedata
+from pathlib import Path
+from collections import Counter
+
# The two dataset copies that are cleaned IN PLACE (see module docstring:
# back up via tarball before running).
DIRS = [
    Path("/home/yurenh2/gap/putnam-bench-anon/dataset"),
    Path("/home/yurenh2/gap/putnamsup/PutnamGAP"),
]

# Text-bearing fields to clean, at the top level and inside each variant.
TOP_LEVEL_TEXT_FIELDS = ["question", "solution"]
VARIANT_TEXT_FIELDS = ["question", "solution"]
VARIANT_KEYS = [
    "descriptive_long",
    "descriptive_long_confusing",
    "descriptive_long_misleading",
    "garbled_string",
    "kernel_variant",
    "original_kernel_variant",
]


# Sentinels placed during char substitution, resolved in a later pass that
# can look at the following characters to extract the radical argument.
# \x01 delimiters cannot occur in the source text, so the sentinels are
# unambiguous during the wrap-up pass.
SENT_SQRT = "\x01SQRT\x01"
SENT_CBRT = "\x01CBRT\x01"
SENT_FRT = "\x01FRT\x01"
+
+REPLACEMENTS: dict = {
+ # Whitespace -> normal space
+ "\u00A0": " ", "\u2002": " ", "\u2003": " ", "\u2004": " ",
+ "\u2005": " ", "\u2006": " ", "\u2007": " ", "\u2008": " ",
+ "\u2009": " ", "\u200A": " ", "\u200B": "", "\u200C": "",
+ "\u200D": "", "\u202F": " ", "\u205F": " ", "\u3000": " ",
+ "\uFEFF": "",
+
+ # Dashes / hyphens
+ # NOTE: in this dataset (kernel-variant LLM-generated math text) the
+ # EN DASH is used pervasively as a math minus sign, not a typographic
+ # en-dash, so we map it to a single hyphen-minus rather than the
+ # typographic `--`. The EM DASH stays as `---` (prose convention).
+ "\u2010": "-", "\u2011": "-",
+ "\u2012": "-", # FIGURE DASH
+ "\u2013": "-", # EN DASH (was `--`; common usage here is math minus)
+ "\u2014": "---", # EM DASH (typographic prose break)
+ "\u2015": "---", # HORIZONTAL BAR
+ "\u2212": "-",
+
+ # Quotation marks
+ "\u2018": "`", "\u2019": "'", "\u201A": ",", "\u201B": "`",
+ "\u201C": "``", "\u201D": "''", "\u201E": ",,",
+ "\u00AB": "<<", "\u00BB": ">>",
+
+ # Punctuation / miscellany
+ "\u2022": "*",
+ "\u2023": "*",
+ "\u2027": ".",
+ "\u2026": r"\ldots",
+ "\u00B7": r"\cdot",
+ "\u00B0": r"^\circ",
+ "\u2032": "'", "\u2033": "''", "\u2034": "'''", "\u2035": "`",
+ "\u2605": r"\star",
+ "\u2606": r"\star",
+ "\u25A0": r"\blacksquare",
+ "\u25A1": r"\square",
+ "\u220E": r"\blacksquare",
+ "\u2020": r"\dagger",
+ "\u2021": r"\ddagger",
+ "\u2044": "/",
+
+ # Sub/super digits
+ "\u2070": "^0", "\u00B9": "^1", "\u00B2": "^2", "\u00B3": "^3",
+ "\u2074": "^4", "\u2075": "^5", "\u2076": "^6", "\u2077": "^7",
+ "\u2078": "^8", "\u2079": "^9",
+ "\u207A": "^+", "\u207B": "^-", "\u207C": "^=", "\u207D": "^(", "\u207E": "^)",
+ "\u2080": "_0", "\u2081": "_1", "\u2082": "_2", "\u2083": "_3",
+ "\u2084": "_4", "\u2085": "_5", "\u2086": "_6", "\u2087": "_7",
+ "\u2088": "_8", "\u2089": "_9",
+ "\u208A": "_+", "\u208B": "_-", "\u208C": "_=", "\u208D": "_(", "\u208E": "_)",
+
+ # Latin sub/super letters
+ "\u2090": "_a", "\u2091": "_e", "\u2092": "_o", "\u2093": "_x",
+ "\u2095": "_h", "\u2096": "_k", "\u2097": "_l", "\u2098": "_m",
+ "\u2099": "_n", "\u209A": "_p", "\u209B": "_s", "\u209C": "_t",
+ "\u2C7C": "_j", # LATIN SUBSCRIPT SMALL LETTER J
+ "\u1D30": "^D", "\u1D31": "^E", "\u1D33": "^G", "\u1D34": "^H",
+ "\u1D35": "^I", "\u1D36": "^J", "\u1D37": "^K", "\u1D38": "^L",
+ "\u1D39": "^M", "\u1D3A": "^N", "\u1D3C": "^O", "\u1D3E": "^P",
+ "\u1D3F": "^R", "\u1D40": "^T", "\u1D41": "^U", "\u1D42": "^W",
+ "\u1D43": "^a", "\u1D47": "^b", "\u1D48": "^d", "\u1D49": "^e",
+ "\u1D4D": "^g", "\u1D4F": "^k", "\u1D50": "^m", "\u1D52": "^o",
+ "\u1D56": "^p", "\u1D57": "^t", "\u1D58": "^u", "\u1D5B": "^v",
+ "\u1D62": "_i", "\u1D63": "_r", "\u1D64": "_u", "\u1D65": "_v",
+ "\u2071": "^i", "\u207F": "^n",
+
+ # Greek lower case
+ "\u03B1": r"\alpha", "\u03B2": r"\beta", "\u03B3": r"\gamma",
+ "\u03B4": r"\delta", "\u03B5": r"\varepsilon", "\u03B6": r"\zeta",
+ "\u03B7": r"\eta", "\u03B8": r"\theta", "\u03B9": r"\iota",
+ "\u03BA": r"\kappa", "\u03BB": r"\lambda", "\u03BC": r"\mu",
+ "\u03BD": r"\nu", "\u03BE": r"\xi", "\u03BF": "o",
+ "\u03C0": r"\pi", "\u03C1": r"\rho", "\u03C2": r"\varsigma",
+ "\u03C3": r"\sigma", "\u03C4": r"\tau", "\u03C5": r"\upsilon",
+ "\u03C6": r"\varphi", "\u03C7": r"\chi", "\u03C8": r"\psi",
+ "\u03C9": r"\omega",
+ "\u03D5": r"\phi", "\u03D1": r"\vartheta", "\u03D6": r"\varpi",
+ "\u03F1": r"\varrho", "\u03F5": r"\epsilon",
+ # Greek upper case
+ "\u0391": "A", "\u0392": "B", "\u0393": r"\Gamma",
+ "\u0394": r"\Delta", "\u0395": "E", "\u0396": "Z",
+ "\u0397": "H", "\u0398": r"\Theta", "\u0399": "I",
+ "\u039A": "K", "\u039B": r"\Lambda", "\u039C": "M",
+ "\u039D": "N", "\u039E": r"\Xi", "\u039F": "O",
+ "\u03A0": r"\Pi", "\u03A1": "P", "\u03A3": r"\Sigma",
+ "\u03A4": "T", "\u03A5": r"\Upsilon", "\u03A6": r"\Phi",
+ "\u03A7": "X", "\u03A8": r"\Psi", "\u03A9": r"\Omega",
+
+ # Math operators / relations
+ "\u2200": r"\forall", "\u2203": r"\exists", "\u2204": r"\nexists",
+ "\u2205": r"\emptyset",
+ "\u2208": r"\in", "\u2209": r"\notin", "\u220B": r"\ni",
+ "\u220F": r"\prod", "\u2210": r"\coprod", "\u2211": r"\sum",
+ "\u2213": r"\mp", "\u00B1": r"\pm",
+ "\u2214": r"\dotplus",
+ "\u2217": "*", "\u2218": r"\circ", "\u2219": r"\cdot",
+ "\u221D": r"\propto",
+ "\u221E": r"\infty",
+ "\u2220": r"\angle", "\u2221": r"\measuredangle",
+ "\u2225": r"\parallel", "\u2226": r"\nparallel",
+ "\u2227": r"\land", "\u2228": r"\lor",
+ "\u2229": r"\cap", "\u222A": r"\cup",
+ "\u222B": r"\int", "\u222C": r"\iint", "\u222D": r"\iiint",
+ "\u222E": r"\oint", "\u222F": r"\oiint",
+ "\u2234": r"\therefore", "\u2235": r"\because",
+ "\u2236": ":", "\u2237": "::",
+ "\u223C": r"\sim", "\u2243": r"\simeq", "\u2245": r"\cong",
+ "\u2248": r"\approx", "\u224D": r"\asymp",
+ "\u2250": r"\doteq",
+ "\u2260": r"\neq", "\u2261": r"\equiv", "\u2262": r"\not\equiv",
+ "\u2264": r"\leq", "\u2265": r"\geq",
+ "\u2266": r"\leqq", "\u2267": r"\geqq",
+ "\u226A": r"\ll", "\u226B": r"\gg",
+ "\u2270": r"\not\leq", "\u2271": r"\not\geq",
+ "\u2282": r"\subset", "\u2283": r"\supset",
+ "\u2284": r"\not\subset", "\u2285": r"\not\supset",
+ "\u2286": r"\subseteq", "\u2287": r"\supseteq",
+ "\u2288": r"\not\subseteq", "\u2289": r"\not\supseteq",
+ "\u228A": r"\subsetneq", "\u228B": r"\supsetneq",
+ "\u2295": r"\oplus", "\u2296": r"\ominus",
+ "\u2297": r"\otimes", "\u2298": r"\oslash", "\u2299": r"\odot",
+ "\u22A2": r"\vdash", "\u22A3": r"\dashv",
+ "\u22A4": r"\top", "\u22A5": r"\bot",
+ "\u22A8": r"\models",
+ "\u22C0": r"\bigwedge", "\u22C1": r"\bigvee",
+ "\u22C2": r"\bigcap", "\u22C3": r"\bigcup",
+ "\u22C5": r"\cdot", "\u22C6": r"\star",
+ "\u22EE": r"\vdots", "\u22EF": r"\cdots",
+ "\u22F1": r"\ddots",
+
+ # Arrows
+ "\u2190": r"\leftarrow", "\u2192": r"\to",
+ "\u2191": r"\uparrow", "\u2193": r"\downarrow",
+ "\u2194": r"\leftrightarrow", "\u2195": r"\updownarrow",
+ "\u21A0": r"\twoheadrightarrow",
+ "\u21A6": r"\mapsto",
+ "\u21D0": r"\Leftarrow", "\u21D2": r"\Rightarrow",
+ "\u21D1": r"\Uparrow", "\u21D3": r"\Downarrow",
+ "\u21D4": r"\Leftrightarrow",
+ "\u27F6": r"\longrightarrow", "\u27F5": r"\longleftarrow",
+ "\u27F9": r"\Longrightarrow", "\u27F8": r"\Longleftarrow",
+ "\u27FA": r"\Longleftrightarrow",
+
+ # Delimiters
+ "\u2016": r"\|",
+ "\u2308": r"\lceil", "\u2309": r"\rceil",
+ "\u230A": r"\lfloor", "\u230B": r"\rfloor",
+ "\u27E8": r"\langle", "\u27E9": r"\rangle",
+ "\u27EA": r"\llangle", "\u27EB": r"\rrangle",
+
+ # Blackboard / script letters
+ "\u2102": r"\mathbb{C}", "\u210D": r"\mathbb{H}",
+ "\u2115": r"\mathbb{N}", "\u2119": r"\mathbb{P}",
+ "\u211A": r"\mathbb{Q}", "\u211D": r"\mathbb{R}",
+ "\u2124": r"\mathbb{Z}",
+ "\u2113": r"\ell", "\u210F": r"\hbar",
+ "\u2202": r"\partial", "\u2207": r"\nabla", "\u2118": r"\wp",
+ "\u2133": r"\mathcal{M}", "\u2112": r"\mathcal{L}",
+ "\u211B": r"\mathcal{R}", "\u2110": r"\mathcal{I}",
+ "\u2130": r"\mathcal{E}", "\u2132": "F",
+
+ # Fractions with precomposed forms
+ "\u00BC": r"\frac{1}{4}", "\u00BD": r"\frac{1}{2}", "\u00BE": r"\frac{3}{4}",
+ "\u2153": r"\frac{1}{3}", "\u2154": r"\frac{2}{3}",
+ "\u2155": r"\frac{1}{5}", "\u2156": r"\frac{2}{5}",
+ "\u2157": r"\frac{3}{5}", "\u2158": r"\frac{4}{5}",
+ "\u2159": r"\frac{1}{6}", "\u215A": r"\frac{5}{6}",
+ "\u215B": r"\frac{1}{8}", "\u215C": r"\frac{3}{8}",
+ "\u215D": r"\frac{5}{8}", "\u215E": r"\frac{7}{8}",
+
+ # Multiplication / division
+ "\u00D7": r"\times", "\u00F7": r"\div",
+
+ # Misc
+ "\u00A7": r"\S",
+ "\u00B6": r"\P",
+ "\u00A9": "(c)", "\u00AE": "(R)", "\u2122": "(TM)",
+ "\u00A3": r"\pounds", "\u20AC": "EUR",
+ "\u00B5": r"\mu",
+
+ # Additional math symbols
+ "\u2216": r"\setminus",
+ "\u2223": r"\mid",
+ "\u2224": r"\nmid",
+ "\u2225": r"\parallel", # duplicate of above, safe
+ "\u2226": r"\nparallel",
+ "\u22BB": r"\veebar",
+ "\u22BC": r"\barwedge",
+ "\u2238": r"\dot{-}",
+ "\u22C8": r"\bowtie",
+ "\u22CE": r"\curlyvee",
+ "\u22CF": r"\curlywedge",
+
+ # Perp and triangle family
+ "\u27C2": r"\perp",
+ "\u22A5": r"\bot", # already present but safe
+ "\u25B3": r"\triangle",
+ "\u25B4": r"\blacktriangle",
+ "\u25BD": r"\triangledown",
+ "\u25BE": r"\blacktriangledown",
+ "\u25C1": r"\triangleleft",
+ "\u25C2": r"\blacktriangleleft",
+ "\u25B7": r"\triangleright",
+ "\u25B8": r"\blacktriangleright",
+
+ # Square / box operators
+ "\u2293": r"\sqcap",
+ "\u2294": r"\sqcup",
+ "\u22A1": r"\boxdot",
+ "\u229E": r"\boxplus",
+ "\u229F": r"\boxminus",
+ "\u22A0": r"\boxtimes",
+
+ # Preceq / succeq family
+ "\u227A": r"\prec",
+ "\u227B": r"\succ",
+ "\u227C": r"\preceq",
+ "\u227D": r"\succeq",
+ "\u2280": r"\nprec",
+ "\u2281": r"\nsucc",
+ "\u22E0": r"\npreceq",
+ "\u22E1": r"\nsucceq",
+
+ # Double-square brackets
+ "\u27E6": r"\llbracket",
+ "\u27E7": r"\rrbracket",
+
+ # Card-suit decorative (drop)
+ "\u2660": "", # spade
+ "\u2661": "",
+ "\u2662": "",
+ "\u2663": "", # club
+ "\u2664": "",
+ "\u2665": "", # heart
+ "\u2666": "", # diamond
+
+ # Musical / dingbat decorations (drop)
+ "\u266A": "", # eighth note
+ "\u266B": "", # beamed eighth notes
+ "\u2713": r"\checkmark",
+ "\u2717": r"\times",
+
+ # Curved delimiters / bracket extension pieces -- these are used by the
+ # kernel generator to draw big parentheses/brackets around multi-line
+ # expressions (like matrices). They are purely decorative in plain text
+ # and we drop them.
+ "\u239B": "", "\u239C": "", "\u239D": "", # ( upper/mid/lower
+ "\u239E": "", "\u239F": "", "\u23A0": "", # ) upper/mid/lower
+ "\u23A1": "", "\u23A2": "", "\u23A3": "", # [ upper/mid/lower
+ "\u23A4": "", "\u23A5": "", "\u23A6": "", # ] upper/mid/lower
+ "\u23A7": "", "\u23A8": "", "\u23A9": "", # { upper/middle/lower
+ "\u23AA": "", # { extension
+ "\u23AB": "", "\u23AC": "", "\u23AD": "", # } upper/middle/lower
+ "\u23AE": "", # integral extension
+ "\u23AF": "", # horizontal line extension
+ "\u23B0": "", "\u23B1": "", # upper/lower curly bracket
+ "\u23B2": "", "\u23B3": "", # summation top/bottom
+ "\u23B4": "", "\u23B5": "", # top/bottom square bracket
+ "\u23B6": "", "\u23B7": "", # bottom square bracket w/tick
+ "\u23D0": "", # vertical line extension
+
+ # Combining over/underlines are stripped by the combining-mark regex
+
+ # Additional remaining symbols found after first clean pass
+ "\u00AD": "", # SOFT HYPHEN -> delete
+ "\u2215": "/", # DIVISION SLASH
+ "\u25A2": r"\square", # WHITE SQUARE WITH ROUNDED CORNERS
+ "\u2718": r"\times", # HEAVY BALLOT X
+ "\u3008": r"\langle", # CJK LEFT ANGLE BRACKET
+ "\u3009": r"\rangle", # CJK RIGHT ANGLE BRACKET
+ "\u2254": ":=", # COLON EQUALS
+ "\u2255": "=:", # EQUALS COLON
+ "\u2198": r"\searrow", # SOUTH EAST ARROW
+ "\u2197": r"\nearrow", # NORTH EAST ARROW
+ "\u2199": r"\swarrow",
+ "\u2196": r"\nwarrow",
+ "\u21A9": r"\hookleftarrow",
+ "\u21AA": r"\hookrightarrow",
+ "\u21BC": r"\leftharpoonup",
+ "\u21BD": r"\leftharpoondown",
+ "\u21BE": r"\upharpoonright",
+ "\u21BF": r"\upharpoonleft",
+ "\u21C0": r"\rightharpoonup",
+ "\u21C1": r"\rightharpoondown",
+ "\u21C2": r"\downharpoonright",
+ "\u21C3": r"\downharpoonleft",
+ "\u21CC": r"\rightleftharpoons",
+ "\u21E2": r"\dashrightarrow",
+ "\u21E0": r"\dashleftarrow",
+ "\u2277": r"\gtrless",
+ "\u2276": r"\lessgtr",
+
+ # Private Use Area characters are almost always OCR garbage or
+ # font-specific glyphs; drop them.
+ "\uF8EB": "", "\uF8F6": "",
+ "\uF8FE": "", "\uF8FD": "", "\uF8FC": "", "\uF8FB": "",
+ "\uF8EF": "", "\uF8F0": "", "\uF8F1": "", "\uF8F2": "",
+
+ # A few more rare but meaningful math symbols
+ "\u2322": r"\frown",
+ "\u2323": r"\smile",
+ "\u226D": r"\not\asymp",
+ "\u22A7": r"\models",
+ "\u22B2": r"\vartriangleleft",
+ "\u22B3": r"\vartriangleright",
+ "\u22B4": r"\trianglelefteq",
+ "\u22B5": r"\trianglerighteq",
+
+ # Small-caps letters sometimes emitted by OCR (collapse to plain letter)
+ "\u026A": "I", # LATIN LETTER SMALL CAPITAL I
+ "\u1D00": "A",
+ "\u1D04": "C",
+ "\u1D05": "D",
+ "\u1D07": "E",
+ "\u0262": "G",
+ "\u029C": "H",
+
+ # Remaining math symbols found after pass 2
+ "\u2A01": r"\bigoplus",
+ "\u2A02": r"\bigotimes",
+ "\u2A00": r"\bigodot",
+ "\u2A03": r"\biguplus",
+ "\u2A04": r"\biguplus",
+ "\u2A05": r"\bigsqcap",
+ "\u2A06": r"\bigsqcup",
+ "\u2272": r"\lesssim",
+ "\u2273": r"\gtrsim",
+ "\u226E": r"\not<",
+ "\u226F": r"\not>",
+ "\u27EE": "(", # MATHEMATICAL LEFT FLATTENED PARENTHESIS
+ "\u27EF": ")", # MATHEMATICAL RIGHT FLATTENED PARENTHESIS
+ "\u2610": r"\square", # BALLOT BOX
+ "\u2611": r"\checkmark",
+ "\u2612": r"\times",
+
+ # Root sentinels (wrapped in a later pass)
+ "\u221A": SENT_SQRT,
+ "\u221B": SENT_CBRT,
+ "\u221C": SENT_FRT,
+}
+
+
# All Unicode combining-mark blocks (diacritics, enclosing marks, half marks).
# Stripped in prestrip (orphaned marks) and again after NFKD in _nfkd_fallback.
_COMBINING_MARK_RE = re.compile(
    r"[\u0300-\u036F\u1AB0-\u1AFF\u1DC0-\u1DFF\u20D0-\u20FF\uFE20-\uFE2F]")
# Box-drawing and block-element glyphs — purely decorative (ASCII-art frames);
# dropped outright in prestrip.
_BOX_DRAWING_RE = re.compile(r"[\u2500-\u257F\u2580-\u259F]")
+
# Characters from scripts that have no place in English/Greek mathematics
# and are clearly OCR noise when they appear. Drop them wholesale. Latin and
# Greek are preserved; extended Latin letters with diacritics are still
# handled by the NFKD fallback.
# NOTE(review): the Private Use Area range (\uE000-\uF8FF) below already
# covers the explicit "\uF8xx" -> "" entries in REPLACEMENTS; since prestrip
# runs before char_substitute, those table entries look redundant (harmless,
# as both map to deletion) — confirm before removing either side.
_OCR_NOISE_SCRIPTS_RE = re.compile(
    r"[\u0400-\u04FF"   # Cyrillic
    r"\u0500-\u052F"    # Cyrillic Supplement
    r"\u0530-\u058F"    # Armenian
    r"\u0590-\u05FF"    # Hebrew
    r"\u0600-\u06FF"    # Arabic
    r"\u0700-\u074F"    # Syriac
    r"\u0750-\u077F"    # Arabic Supplement
    r"\u0780-\u07BF"    # Thaana
    r"\u0900-\u097F"    # Devanagari
    r"\u0B80-\u0BFF"    # Tamil
    r"\u0C00-\u0C7F"    # Telugu
    r"\u0C80-\u0CFF"    # Kannada
    r"\u0D00-\u0D7F"    # Malayalam
    r"\u0D80-\u0DFF"    # Sinhala
    r"\u0E00-\u0E7F"    # Thai
    r"\u0E80-\u0EFF"    # Lao
    r"\u0F00-\u0FFF"    # Tibetan
    r"\u1000-\u109F"    # Myanmar
    r"\u10A0-\u10FF"    # Georgian
    r"\u1100-\u11FF"    # Hangul Jamo
    r"\u1400-\u167F"    # Unified Canadian Aboriginal Syllabics
    r"\u1680-\u169F"    # Ogham
    r"\u16A0-\u16FF"    # Runic
    r"\u1700-\u171F"    # Tagalog
    r"\u1780-\u17FF"    # Khmer
    r"\u1800-\u18AF"    # Mongolian
    r"\u1900-\u194F"    # Limbu
    r"\u3040-\u309F"    # Hiragana
    r"\u30A0-\u30FF"    # Katakana
    r"\u3000-\u303F"    # CJK Symbols and Punctuation (incl. ideographic full stop)
    r"\u3100-\u312F"    # Bopomofo
    r"\u3130-\u318F"    # Hangul Compatibility Jamo
    r"\u3190-\u319F"    # Kanbun
    r"\u3400-\u4DBF"    # CJK Extension A
    r"\u4E00-\u9FFF"    # CJK Unified Ideographs
    r"\uA000-\uA48F"    # Yi Syllables
    r"\uAC00-\uD7AF"    # Hangul Syllables
    r"\uE000-\uF8FF"    # Private Use Area
    r"\uFE00-\uFE0F"    # Variation Selectors
    r"\uFE30-\uFE4F"    # CJK Compatibility Forms (vertical presentation
                        # brackets that NFKD-decompose to literal { } [ ] etc.,
                        # which would corrupt our brace balance — drop them)
    r"\uFE50-\uFE6F"    # Small Form Variants (compatibility forms)
    r"\uFFFC\uFFFD"     # Object/Replacement Character
    r"]"
)
+
# Emoji and pictographs. These are astral-plane (non-BMP) code points, hence
# the \U0001xxxx escapes; Python 3 `str` stores them as single code points,
# so no surrogate-pair handling is needed here.
# NOTE(review): the regional-indicator and Fitzpatrick-modifier ranges are
# subsets of \U0001F000-\U0001F9FF — listed separately only for explicitness.
_EMOJI_RE = re.compile(
    "["
    "\U0001F000-\U0001F9FF"  # Emoji blocks
    "\U0001FA00-\U0001FAFF"  # Symbols & Pictographs Extended-A
    "\U0001F1E6-\U0001F1FF"  # Regional indicator symbols
    "\U0001F3FB-\U0001F3FF"  # Emoji modifier fitzpatrick
    "\U00020000-\U0002FA1F"  # CJK Extensions B-F
    "]",
    flags=re.UNICODE  # redundant for str patterns (UNICODE is the default)
)
+
+
def prestrip(text: str) -> str:
    """Strip decorative and OCR-noise characters BEFORE char substitution.

    Important: NFKD is deliberately NOT applied at this stage, because it
    would decompose subscript/superscript digits (e.g. U+2080 -> '0') before
    the explicit REPLACEMENTS entries can rewrite them as `_0`.  NFKD runs
    later, only as a fallback for characters that survive the explicit
    substitution pass (e.g. accented Latin letters).
    """
    if not text:
        return text
    # Same order as before: box drawing, orphaned combining marks (their base
    # character may have been transformed away), noise scripts (Cyrillic /
    # Arabic / CJK / ...), then emoji/pictographs.
    for noise_re in (_BOX_DRAWING_RE, _COMBINING_MARK_RE,
                     _OCR_NOISE_SCRIPTS_RE, _EMOJI_RE):
        text = noise_re.sub("", text)
    return text
+
+
def char_substitute(text: str, unmapped: Counter) -> str:
    """Apply REPLACEMENTS character by character.

    ASCII characters and the \\x01 sentinel pass through untouched.  A char
    with no REPLACEMENTS entry is also left in place so that _nfkd_fallback
    (run next) can attempt a compatibility decomposition.  A trailing space
    is appended to bare `\\word` LaTeX commands so a following letter is not
    absorbed into the command name (`\\cdot t`, not `\\cdott`).
    """
    def _render(ch: str) -> str:
        # One output fragment per input character.
        if ord(ch) <= 127 or ch == "\x01":
            return ch
        try:
            val = REPLACEMENTS[ch]
        except KeyError:
            # Unmapped: keep as-is for the NFKD fallback pass.
            return ch
        bare_command = (len(val) >= 2
                        and val[0] == "\\"
                        and val[-1].isalpha()
                        and not val.startswith("\x01"))
        return val + " " if bare_command else val

    return "".join(_render(ch) for ch in text)
+
+
+def _merge_sub_sup(text: str) -> str:
+ def _do(prefix, m):
+ # Extract each ^X or _X token and concatenate the X parts.
+ vals = re.findall(r"[\+\-\=\(\)a-zA-Z0-9]", m.group(0))
+ # The regex captures the X char from each ^X or _X; above regex
+ # finds ALL alnum/sign chars in the match. But `^+` etc. we want
+ # to keep as-is. Simplest: split on the prefix.
+ pieces = [p for p in re.split(r"[\^_]", m.group(0)) if p]
+ joined = "".join(pieces)
+ return f"{prefix}{{{joined}}}"
+
+ text = re.sub(
+ r"(?:\^[\+\-\=\(\)a-zA-Z0-9])(?:\^[\+\-\=\(\)a-zA-Z0-9])+",
+ lambda m: _do("^", m), text)
+ text = re.sub(
+ r"(?:_[\+\-\=\(\)a-zA-Z0-9])(?:_[\+\-\=\(\)a-zA-Z0-9])+",
+ lambda m: _do("_", m), text)
+ return text
+
+
+_SENTINEL_RE = re.compile(r"\x01(SQRT|CBRT|FRT)\x01")
+
+
+def _skip_spaces(s: str, i: int) -> int:
+ while i < len(s) and s[i] in " \t":
+ i += 1
+ return i
+
+
+def _read_balanced(s: str, i: int, open_ch: str, close_ch: str):
+ depth = 0
+ j = i
+ while j < len(s):
+ if s[j] == open_ch:
+ depth += 1
+ elif s[j] == close_ch:
+ depth -= 1
+ if depth == 0:
+ return j + 1
+ j += 1
+ return -1
+
+
def _read_latex_command(s: str, i: int):
    """Consume a LaTeX command starting at the backslash at index *i*:
    `\\name` (letters and `@`) plus any immediately following balanced
    `{...}` argument groups.  Returns the index just past the command, the
    index of an unbalanced `{` if one is found, or -1 when s[i] is not a
    backslash."""
    n = len(s)
    if i >= n or s[i] != "\\":
        return -1
    # Command name: letters / `@` (internal-macro convention).
    j = i + 1
    while j < n and (s[j] == "@" or s[j].isalpha()):
        j += 1
    # Trailing brace groups, e.g. \frac{a}{b}.
    while j < n and s[j] == "{":
        close = _read_balanced(s, j, "{", "}")
        if close == -1:
            # Unbalanced group: stop at its opening brace.
            return j
        j = close
    return j
+
+
def _wrap_radical_arguments(text: str) -> str:
    """Rewrite radical sentinels (\\x01SQRT\\x01 etc.) as LaTeX \\sqrt forms,
    wrapping whatever argument follows.

    After optional spaces, the argument is taken as the first of: a balanced
    `(...)`, `[...]` or `{...}` group (delimiters stripped), a single LaTeX
    command with its brace groups, or a bare run of alphanumerics and dots
    (e.g. `3.14`).  When nothing usable follows, an empty `\\sqrt{}` is
    emitted and scanning resumes right after the sentinel, so any skipped
    spaces are re-processed verbatim.
    """
    LATEX_FOR = {"SQRT": r"\sqrt", "CBRT": r"\sqrt[3]", "FRT": r"\sqrt[4]"}
    # The three bracket styles share one code path (previously triplicated).
    CLOSER_FOR = {"(": ")", "[": "]", "{": "}"}
    out = []
    i = 0
    while i < len(text):
        m = _SENTINEL_RE.match(text, i)
        if not m:
            out.append(text[i])
            i += 1
            continue
        latex_prefix = LATEX_FOR[m.group(1)]
        j = _skip_spaces(text, m.end())
        if j >= len(text):
            # Sentinel at end of string: no argument available.
            out.append(latex_prefix + "{}")
            i = j
            continue
        ch = text[j]
        closer = CLOSER_FOR.get(ch)
        if closer is not None:
            arg_end = _read_balanced(text, j, ch, closer)
            if arg_end != -1:
                # Strip the delimiters; keep the interior as the argument.
                out.append(f"{latex_prefix}{{{text[j + 1 : arg_end - 1]}}}")
                i = arg_end
                continue
        elif ch == "\\":
            arg_end = _read_latex_command(text, j)
            if arg_end != -1:
                out.append(f"{latex_prefix}{{{text[j:arg_end]}}}")
                i = arg_end
                continue
        # Fallback: bare alnum run (dots allowed for literals like 3.14).
        k = j
        while k < len(text) and (text[k].isalnum() or text[k] == "."):
            k += 1
        if k > j:
            out.append(f"{latex_prefix}{{{text[j:k]}}}")
            i = k
            continue
        # Unusable argument: emit an empty radical and resume after the
        # sentinel (matches the original's behavior exactly).
        out.append(latex_prefix + "{}")
        i = m.end()
    return "".join(out)
+
+
def _nfkd_fallback(text: str, unmapped: Counter) -> str:
    """Last-chance handling for characters the explicit pass left behind.

    Precomposed forms (e.g. U+00E9 e-acute, or classical Greek letters with
    breathing marks) are NFKD-decomposed, their combining marks stripped,
    and REPLACEMENTS re-applied — decomposition can surface characters that
    DO have an entry (e.g. U+1F42 -> U+03B3).  Whatever is still non-ASCII
    after that is counted in *unmapped* and dropped.
    """
    if all(ord(c) <= 127 or c == "\x01" for c in text):
        return text  # fast path: nothing to do
    text = _COMBINING_MARK_RE.sub("", unicodedata.normalize("NFKD", text))
    # Second substitution pass over the decomposed form; the same counter
    # keeps accumulating.
    text = char_substitute(text, unmapped)
    kept = []
    for c in text:
        if ord(c) <= 127 or c == "\x01":
            kept.append(c)
        else:
            unmapped[c] += 1
    return "".join(kept)
+
+
def clean_text(text: str, unmapped: Counter) -> str:
    """Run the full cleaning pipeline over one text field.

    Order matters: prestrip -> explicit char substitution -> NFKD fallback
    -> merge of repeated ^/_ tokens -> radical-sentinel wrapping.
    Empty / falsy input is returned unchanged.
    """
    if not text:
        return text
    text = char_substitute(prestrip(text), unmapped)
    text = _nfkd_fallback(text, unmapped)
    return _wrap_radical_arguments(_merge_sub_sup(text))
+
+
def clean_problem(problem: dict, unmapped: Counter):
    """Clean all configured text fields of one problem record, in place.

    Top-level fields (TOP_LEVEL_TEXT_FIELDS) and each variant's fields
    (VARIANT_TEXT_FIELDS, for every key in VARIANT_KEYS) are passed through
    clean_text; non-string values and missing variants are skipped.
    Returns the (mutated) problem for convenience.
    """
    for field in TOP_LEVEL_TEXT_FIELDS:
        value = problem.get(field)
        if isinstance(value, str):
            problem[field] = clean_text(value, unmapped)
    variants = problem.get("variants") or {}
    for vk in VARIANT_KEYS:
        variant = variants.get(vk)
        if not isinstance(variant, dict):
            continue
        for field in VARIANT_TEXT_FIELDS:
            value = variant.get(field)
            if isinstance(value, str):
                variant[field] = clean_text(value, unmapped)
    return problem
+
+
def process_dir(dataset_dir: Path):
    """Clean every *.json problem file in *dataset_dir*, in place.

    Each file is parsed, run through clean_problem, and rewritten only when
    cleaning actually changed it (detected by comparing canonical JSON
    dumps).  Unreadable files are reported and skipped (best-effort).
    Returns a Counter of characters that could not be mapped to ASCII.
    """
    print(f"\n=== Cleaning {dataset_dir} ===")
    files = sorted(dataset_dir.glob("*.json"))
    unmapped = Counter()
    n_modified = 0
    for f in files:
        try:
            # Context manager closes the handle even on parse errors
            # (previously leaked via json.load(open(f))).
            with open(f) as fh:
                d = json.load(fh)
        except Exception as e:
            print(f" ! skip {f.name}: {e}")
            continue
        before = json.dumps(d, ensure_ascii=False)
        d = clean_problem(d, unmapped)
        after = json.dumps(d, ensure_ascii=False)
        if before != after:
            n_modified += 1
            with open(f, "w") as fh:
                json.dump(d, fh, ensure_ascii=False, indent=2)
    print(f" files modified: {n_modified}/{len(files)}")
    if unmapped:
        print(f" unmapped characters: {sum(unmapped.values())} occurrences, "
              f"{len(unmapped)} distinct")
        print(" top 20 unmapped:")
        for ch, n in unmapped.most_common(20):
            name = unicodedata.name(ch, "?")
            print(f" {ch!r:<10} U+{ord(ch):04X} n={n} ({name})")
    else:
        print(" no unmapped characters")
    return unmapped
+
+
def main():
    """Clean every existing directory in DIRS and report overall totals.

    Characters that survived every mapping pass are dumped to a JSON
    sidecar (unmapped_chars.json) so the REPLACEMENTS table can be extended
    in a later pass.
    """
    all_unmapped = Counter()
    for d in DIRS:
        if d.exists():
            all_unmapped.update(process_dir(d))
    print(f"\n=== OVERALL ===")
    print(f"Total unmapped characters across both dataset copies: {sum(all_unmapped.values())}")
    print(f"Distinct unmapped: {len(all_unmapped)}")
    if all_unmapped:
        out_path = Path("/home/yurenh2/gap/analysis/unmapped_chars.json")
        report = {
            f"U+{ord(c):04X}": {"char": c,
                                "name": unicodedata.name(c, "?"),
                                "count": n}
            for c, n in all_unmapped.most_common()
        }
        # Context manager fixes the file handle previously leaked by
        # json.dump(..., open(out_path, "w"), ...).
        with open(out_path, "w") as fh:
            json.dump(report, fh, indent=2, ensure_ascii=False)
        print(f"Saved unmapped list -> {out_path}")
+
+
+if __name__ == "__main__":
+ main()