summaryrefslogtreecommitdiff
path: root/analysis/kv_overlap.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-04-08 22:06:05 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-04-08 22:06:05 -0500
commit05704d0eb2fa59fe727652465b07db40bcb06c38 (patch)
tree8904aca836cf552fd1a5ae8c2174e9f91e70bbbc /analysis/kv_overlap.py
Initial release: GAP framework
- Full pipeline: variant generation, multi-judge verification, evaluation - Loaders for OpenAI / Anthropic / Google / xAI / OpenRouter / vLLM - Framework-level mechanism analyses: paired structural overlap, repairability rescue, self-correction probe, cross-model agreement, topic x problem-type interaction - Unicode -> bare-LaTeX cleaner + audit + spot-check - Mirrors https://huggingface.co/datasets/blackhao0426/PutnamGAP
Diffstat (limited to 'analysis/kv_overlap.py')
-rw-r--r--analysis/kv_overlap.py332
1 files changed, 332 insertions, 0 deletions
diff --git a/analysis/kv_overlap.py b/analysis/kv_overlap.py
new file mode 100644
index 0000000..137e61f
--- /dev/null
+++ b/analysis/kv_overlap.py
@@ -0,0 +1,332 @@
+"""Kernel-variant structural-overlap analysis (label-free).
+
+Unlike surface variants, kernel variants change the math, so we cannot use the
+model's own original-correct trajectory as a reference. Instead we use the
+dataset's canonical kernel-variant solution as the reference.
+
+Hypothesis: stable (correct on KV) trajectories have higher structural overlap
+with the canonical KV solution than brittle (wrong on KV) trajectories.
+
+For comparability we also recompute the surface analyses using the same
+'overlap with canonical solution' metric, so we can compare apples-to-apples
+the magnitude of stable-vs-brittle gap between surface and kernel.
+"""
+from __future__ import annotations
+import json
+import os
+import statistics
+from pathlib import Path
+from collections import defaultdict
+from typing import Optional
+
+# Reuse helpers from the sibling module
+import sys
+sys.path.insert(0, str(Path(__file__).parent))
+from structural_overlap import (
+ DATASET_DIR, RESULTS_DIR,
+ load_problems, find_variant_file,
+ canonicalize_text, normalize_whitespace,
+ tokens, bigrams, jaccard, extract_math_blocks,
+ metric_token_jaccard, metric_bigram_jaccard,
+ metric_directional_coverage, metric_equation_jaccard,
+ mann_whitney_u, bootstrap_ci_cohens_d,
+ is_collapse, COLLAPSE_MIN_CHARS, COLLAPSE_RATIO,
+ SURFACE_VARIANTS,
+)
+
+
+def load_dataset_variant_solutions() -> dict:
+ """Returns: {problem_index: {variant_name: canonical_solution_text}}.
+
+ Includes 'original' (from top-level field) plus all 5 variants.
+ """
+ out = {}
+ for f in sorted(DATASET_DIR.glob("*.json")):
+ d = json.load(open(f))
+ idx = d.get("index")
+ cell = {"original": d.get("solution") or "",
+ "_problem_type": d.get("problem_type")}
+ for v, vd in d.get("variants", {}).items():
+ if isinstance(vd, dict):
+ cell[v] = vd.get("solution") or ""
+ out[idx] = cell
+ return out
+
+
+def load_dataset_maps() -> dict:
+ """Mirrors structural_overlap.load_dataset_maps but localized for safety."""
+ out = {}
+ for f in sorted(DATASET_DIR.glob("*.json")):
+ d = json.load(open(f))
+ idx = d.get("index")
+ variants = d.get("variants", {})
+ cell = {}
+ for v in SURFACE_VARIANTS:
+ vd = variants.get(v, {})
+ mp_str = vd.get("map")
+ if isinstance(mp_str, str):
+ try:
+ mp = eval(mp_str, {"__builtins__": {}}, {})
+ if isinstance(mp, dict):
+ cell[v] = {str(k): str(v) for k, v in mp.items()}
+ except Exception:
+ pass
+ elif isinstance(mp_str, dict):
+ cell[v] = {str(k): str(v) for k, v in mp_str.items()}
+ out[idx] = cell
+ return out
+
+
+# ---------- Cell analyzer ----------
+
+def analyze_kv_cell(model_name: str, model_dir: Path,
+ canonical_solutions: dict) -> Optional[dict]:
+ """Compare model's KV trajectory to dataset canonical KV solution.
+
+ No canonicalization (no rename map for KV — variables match by construction).
+ """
+ orig_path = find_variant_file(model_dir, "original")
+ var_path = find_variant_file(model_dir, "kernel_variant")
+ if not orig_path or not var_path:
+ return None
+ orig_by = {p["index"]: p for p in load_problems(orig_path)}
+ var_by = {p["index"]: p for p in load_problems(var_path)}
+
+ pairs_stable_drift = []
+ pairs_brittle_drift = []
+ n_brittle_collapse = 0
+ n_stable_collapse = 0
+
+ for idx in set(orig_by) & set(var_by):
+ po, pv = orig_by[idx], var_by[idx]
+ if po.get("correct") is not True:
+ continue # Restrict to "model already gets the original"
+ var_correct = pv.get("correct")
+ if var_correct is None:
+ continue
+ var_text = (pv.get("solve") or {}).get("solution") or ""
+ if not var_text:
+ continue
+ canon_kv = canonical_solutions.get(idx, {}).get("kernel_variant", "")
+ if not canon_kv or len(canon_kv) < 200:
+ continue
+ # Collapse rule: variant text < 200 chars OR < 25% of canonical solution
+ collapse = (len(var_text) < COLLAPSE_MIN_CHARS or
+ len(var_text) < COLLAPSE_RATIO * len(canon_kv))
+ sample = {"index": idx, "var_text": var_text, "canon": canon_kv}
+ if var_correct is True:
+ if collapse:
+ n_stable_collapse += 1
+ else:
+ pairs_stable_drift.append(sample)
+ else:
+ if collapse:
+ n_brittle_collapse += 1
+ else:
+ pairs_brittle_drift.append(sample)
+
+ if not pairs_stable_drift or not pairs_brittle_drift:
+ return None
+
+ metrics = {
+ "token_jaccard": metric_token_jaccard,
+ "bigram_jaccard": metric_bigram_jaccard,
+ "equation_jaccard": metric_equation_jaccard,
+ "directional_coverage": metric_directional_coverage,
+ }
+
+ out = {
+ "model": model_name,
+ "variant": "kernel_variant",
+ "n_stable_drift": len(pairs_stable_drift),
+ "n_brittle_drift": len(pairs_brittle_drift),
+ "n_brittle_collapse": n_brittle_collapse,
+ "n_stable_collapse": n_stable_collapse,
+ "brittle_collapse_rate": n_brittle_collapse /
+ max(1, n_brittle_collapse + len(pairs_brittle_drift)),
+ "metrics": {},
+ }
+ for mname, mfn in metrics.items():
+ s_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_stable_drift]
+ b_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_brittle_drift]
+ U, p = mann_whitney_u(s_vals, b_vals)
+ sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals)
+ ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0
+ bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0
+ pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2)
+ / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5
+ d = (sm - bm) / pooled if pooled > 0 else 0.0
+ out["metrics"][mname] = {
+ "stable_median": statistics.median(s_vals),
+ "stable_mean": sm,
+ "brittle_median": statistics.median(b_vals),
+ "brittle_mean": bm,
+ "delta_median": statistics.median(s_vals) - statistics.median(b_vals),
+ "cohens_d": d,
+ "U": U,
+ "p_two_sided": p,
+ }
+ # Headline bootstrap
+ s_vals = [metric_token_jaccard(p["var_text"], p["canon"]) for p in pairs_stable_drift]
+ b_vals = [metric_token_jaccard(p["var_text"], p["canon"]) for p in pairs_brittle_drift]
+ d_lo, d_hi = bootstrap_ci_cohens_d(s_vals, b_vals, n_iter=400)
+ out["metrics"]["token_jaccard"]["cohens_d_ci"] = [d_lo, d_hi]
+ return out
+
+
+# ---------- Surface re-analysis with canonical reference ----------
+
+def analyze_surface_cell_against_canonical(model_name: str, variant: str,
+ model_dir: Path,
+ canonical_solutions: dict) -> Optional[dict]:
+ """Compare model variant trajectory to dataset canonical variant solution.
+
+ For comparability with KV. No rename canonicalization needed since both
+ sides use the same variant naming.
+ """
+ var_path = find_variant_file(model_dir, variant)
+ orig_path = find_variant_file(model_dir, "original")
+ if not var_path or not orig_path:
+ return None
+ var_by = {p["index"]: p for p in load_problems(var_path)}
+ orig_by = {p["index"]: p for p in load_problems(orig_path)}
+
+ pairs_stable, pairs_brittle = [], []
+ n_brittle_collapse = 0
+ for idx in set(var_by):
+ if idx not in orig_by:
+ continue
+ if orig_by[idx].get("correct") is not True:
+ continue # restrict to model-knows-original
+ pv = var_by[idx]
+ var_correct = pv.get("correct")
+ if var_correct is None:
+ continue
+ var_text = (pv.get("solve") or {}).get("solution") or ""
+ if not var_text:
+ continue
+ canon_var = canonical_solutions.get(idx, {}).get(variant, "")
+ if not canon_var or len(canon_var) < 200:
+ continue
+ if (len(var_text) < COLLAPSE_MIN_CHARS or
+ len(var_text) < COLLAPSE_RATIO * len(canon_var)):
+ if var_correct is False:
+ n_brittle_collapse += 1
+ continue
+ sample = {"index": idx, "var_text": var_text, "canon": canon_var}
+ if var_correct is True:
+ pairs_stable.append(sample)
+ else:
+ pairs_brittle.append(sample)
+
+ if not pairs_stable or not pairs_brittle:
+ return None
+
+ metrics = {
+ "token_jaccard": metric_token_jaccard,
+ "bigram_jaccard": metric_bigram_jaccard,
+ "equation_jaccard": metric_equation_jaccard,
+ "directional_coverage": metric_directional_coverage,
+ }
+ out = {
+ "model": model_name,
+ "variant": variant,
+ "n_stable_drift": len(pairs_stable),
+ "n_brittle_drift": len(pairs_brittle),
+ "n_brittle_collapse": n_brittle_collapse,
+ "brittle_collapse_rate": n_brittle_collapse /
+ max(1, n_brittle_collapse + len(pairs_brittle)),
+ "metrics": {},
+ }
+ for mname, mfn in metrics.items():
+ s_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_stable]
+ b_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_brittle]
+ U, p = mann_whitney_u(s_vals, b_vals)
+ sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals)
+ ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0
+ bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0
+ pooled = (((len(s_vals)-1)*ssd**2 + (len(b_vals)-1)*bsd**2)
+ / max(1, len(s_vals)+len(b_vals)-2)) ** 0.5
+ d = (sm - bm) / pooled if pooled > 0 else 0.0
+ out["metrics"][mname] = {
+ "stable_median": statistics.median(s_vals),
+ "stable_mean": sm,
+ "brittle_median": statistics.median(b_vals),
+ "brittle_mean": bm,
+ "delta_median": statistics.median(s_vals) - statistics.median(b_vals),
+ "cohens_d": d,
+ "U": U,
+ "p_two_sided": p,
+ }
+ return out
+
+
+def main():
+ print("Loading canonical solutions ...")
+ canon = load_dataset_variant_solutions()
+ print(f" loaded {len(canon)} problems")
+
+ all_models = sorted([d.name for d in RESULTS_DIR.iterdir() if d.is_dir()])
+
+ kv_results = []
+ surface_results = []
+
+ print(f"\n{'KERNEL VARIANT — variant trajectory vs canonical KV solution':<70}")
+ print(f"{'Cell':<32} {'nSd':>4} {'nBd':>4} {'col%':>5} "
+ f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}")
+ print("-" * 90)
+ for m in all_models:
+ mdir = RESULTS_DIR / m
+ if not mdir.exists():
+ continue
+ res = analyze_kv_cell(m, mdir, canon)
+ if res is None:
+ continue
+ kv_results.append(res)
+ md = res["metrics"]["token_jaccard"]
+ print(f"{m+' / KV':<32} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
+ f"{res['brittle_collapse_rate']*100:>4.0f}% "
+ f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
+ f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}")
+
+ print(f"\n{'SURFACE VARIANT — variant trajectory vs canonical variant solution':<70}")
+ print(f"{'Cell':<46} {'nSd':>4} {'nBd':>4} {'col%':>5} "
+ f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}")
+ print("-" * 95)
+ for m in all_models:
+ mdir = RESULTS_DIR / m
+ if not mdir.exists():
+ continue
+ for v in SURFACE_VARIANTS:
+ res = analyze_surface_cell_against_canonical(m, v, mdir, canon)
+ if res is None:
+ continue
+ surface_results.append(res)
+ md = res["metrics"]["token_jaccard"]
+ print(f"{m+' / '+v:<46} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
+ f"{res['brittle_collapse_rate']*100:>4.0f}% "
+ f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
+ f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}")
+
+ # Save
+ json.dump(kv_results, open("/home/yurenh2/gap/analysis/kv_overlap_results.json", "w"), indent=2)
+ json.dump(surface_results, open("/home/yurenh2/gap/analysis/surface_canonical_results.json", "w"), indent=2)
+
+ # Aggregate compare
+ print("\n" + "=" * 80)
+ print("AGGREGATE: surface (vs canonical) vs kernel (vs canonical)")
+ print("=" * 80)
+ for tag, results in [("surface", surface_results), ("kernel", kv_results)]:
+ ds = [c["metrics"]["token_jaccard"]["cohens_d"] for c in results]
+ ps = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in results]
+ col = [c["brittle_collapse_rate"] for c in results]
+ if not ds:
+ continue
+ print(f"{tag:<8} cells={len(ds):>3} d_pos={sum(1 for d in ds if d>0):>3}/{len(ds):<3} "
+ f"p<.05={sum(1 for p in ps if p<0.05):>3}/{len(ps):<3} "
+ f"d_med={statistics.median(ds):+.2f} d_mean={statistics.fmean(ds):+.2f} "
+ f"collapse_mean={statistics.fmean(col)*100:.1f}%")
+
+
+if __name__ == "__main__":
+ main()