diff options
Diffstat (limited to 'analysis/kv_overlap.py')
| -rw-r--r-- | analysis/kv_overlap.py | 332 |
1 file changed, 332 insertions, 0 deletions
"""Kernel-variant structural-overlap analysis (label-free).

Unlike surface variants, kernel variants change the math, so we cannot use the
model's own original-correct trajectory as a reference. Instead we use the
dataset's canonical kernel-variant solution as the reference.

Hypothesis: stable (correct on KV) trajectories have higher structural overlap
with the canonical KV solution than brittle (wrong on KV) trajectories.

For comparability we also recompute the surface analyses using the same
'overlap with canonical solution' metric, so we can compare apples-to-apples
the magnitude of stable-vs-brittle gap between surface and kernel.
"""
from __future__ import annotations

import ast
import json
import os
import statistics
import sys
from collections import defaultdict
from pathlib import Path
from typing import Optional

# Reuse helpers from the sibling module.
sys.path.insert(0, str(Path(__file__).parent))
from structural_overlap import (
    DATASET_DIR, RESULTS_DIR,
    load_problems, find_variant_file,
    canonicalize_text, normalize_whitespace,
    tokens, bigrams, jaccard, extract_math_blocks,
    metric_token_jaccard, metric_bigram_jaccard,
    metric_directional_coverage, metric_equation_jaccard,
    mann_whitney_u, bootstrap_ci_cohens_d,
    is_collapse, COLLAPSE_MIN_CHARS, COLLAPSE_RATIO,
    SURFACE_VARIANTS,
)

# Result JSON files are written next to this script (previously hard-coded
# absolute paths into one user's home directory; this resolves to the same
# location there and stays portable elsewhere).
_OUT_DIR = Path(__file__).parent

# Canonical solutions shorter than this many characters are treated as
# unusable references and the problem is skipped.
_MIN_CANON_CHARS = 200

# Metric name -> metric function; shared by both cell analyzers so the two
# result dicts always carry the same metric set.
_METRICS = {
    "token_jaccard": metric_token_jaccard,
    "bigram_jaccard": metric_bigram_jaccard,
    "equation_jaccard": metric_equation_jaccard,
    "directional_coverage": metric_directional_coverage,
}


def load_dataset_variant_solutions() -> dict:
    """Load the canonical solution text for every problem and variant.

    Returns:
        {problem_index: {"original": text, "_problem_type": str,
                         <variant_name>: text, ...}}
        The 'original' entry comes from the top-level "solution" field; each
        variant's entry from its per-variant "solution" field.
    """
    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        d = json.loads(f.read_text())
        cell = {"original": d.get("solution") or "",
                "_problem_type": d.get("problem_type")}
        for v, vd in d.get("variants", {}).items():
            if isinstance(vd, dict):
                cell[v] = vd.get("solution") or ""
        out[d.get("index")] = cell
    return out


def load_dataset_maps() -> dict:
    """Mirrors structural_overlap.load_dataset_maps but localized for safety.

    Returns {problem_index: {variant_name: {old: new}}} rename maps for the
    surface variants. Maps stored as strings are parsed with
    ast.literal_eval — never eval — since they come from data files.
    """
    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        d = json.loads(f.read_text())
        variants = d.get("variants", {})
        cell = {}
        for v in SURFACE_VARIANTS:
            vd = variants.get(v, {})
            mp_str = vd.get("map")
            if isinstance(mp_str, str):
                try:
                    mp = ast.literal_eval(mp_str)
                except (ValueError, SyntaxError):
                    continue  # malformed map string: skip this variant
                if isinstance(mp, dict):
                    cell[v] = {str(k): str(val) for k, val in mp.items()}
            elif isinstance(mp_str, dict):
                cell[v] = {str(k): str(val) for k, val in mp_str.items()}
        out[d.get("index")] = cell
    return out


def _is_collapsed(var_text: str, canon: str) -> bool:
    """Collapse rule: trajectory is too short absolutely (< COLLAPSE_MIN_CHARS)
    or relative to the canonical solution (< COLLAPSE_RATIO of its length)."""
    return (len(var_text) < COLLAPSE_MIN_CHARS
            or len(var_text) < COLLAPSE_RATIO * len(canon))


def _metric_stats(s_vals: list, b_vals: list) -> dict:
    """Summary statistics comparing stable vs brittle metric values.

    Computes medians, means, the median gap, pooled-SD Cohen's d, and a
    Mann-Whitney U test (two-sided p). Both inputs must be non-empty.
    """
    U, p = mann_whitney_u(s_vals, b_vals)
    sm, bm = statistics.fmean(s_vals), statistics.fmean(b_vals)
    ssd = statistics.pstdev(s_vals) if len(s_vals) > 1 else 0
    bsd = statistics.pstdev(b_vals) if len(b_vals) > 1 else 0
    pooled = (((len(s_vals) - 1) * ssd ** 2 + (len(b_vals) - 1) * bsd ** 2)
              / max(1, len(s_vals) + len(b_vals) - 2)) ** 0.5
    d = (sm - bm) / pooled if pooled > 0 else 0.0
    return {
        "stable_median": statistics.median(s_vals),
        "stable_mean": sm,
        "brittle_median": statistics.median(b_vals),
        "brittle_mean": bm,
        "delta_median": statistics.median(s_vals) - statistics.median(b_vals),
        "cohens_d": d,
        "U": U,
        "p_two_sided": p,
    }


# ---------- Cell analyzer ----------

def analyze_kv_cell(model_name: str, model_dir: Path,
                    canonical_solutions: dict) -> Optional[dict]:
    """Compare model's KV trajectory to dataset canonical KV solution.

    No canonicalization (no rename map for KV — variables match by
    construction). Only problems the model solved correctly in the original
    form are considered; each is bucketed stable/brittle by KV correctness
    and drift/collapse by the collapse rule.

    Returns the per-cell result dict, or None when a results file is missing
    or either drift group is empty.
    """
    orig_path = find_variant_file(model_dir, "original")
    var_path = find_variant_file(model_dir, "kernel_variant")
    if not orig_path or not var_path:
        return None
    orig_by = {p["index"]: p for p in load_problems(orig_path)}
    var_by = {p["index"]: p for p in load_problems(var_path)}

    pairs_stable_drift = []
    pairs_brittle_drift = []
    n_brittle_collapse = 0
    n_stable_collapse = 0

    for idx in set(orig_by) & set(var_by):
        po, pv = orig_by[idx], var_by[idx]
        if po.get("correct") is not True:
            continue  # Restrict to "model already gets the original"
        var_correct = pv.get("correct")
        if var_correct is None:
            continue
        var_text = (pv.get("solve") or {}).get("solution") or ""
        if not var_text:
            continue
        canon_kv = canonical_solutions.get(idx, {}).get("kernel_variant", "")
        if not canon_kv or len(canon_kv) < _MIN_CANON_CHARS:
            continue
        collapse = _is_collapsed(var_text, canon_kv)
        sample = {"index": idx, "var_text": var_text, "canon": canon_kv}
        if var_correct is True:
            if collapse:
                n_stable_collapse += 1
            else:
                pairs_stable_drift.append(sample)
        else:
            if collapse:
                n_brittle_collapse += 1
            else:
                pairs_brittle_drift.append(sample)

    if not pairs_stable_drift or not pairs_brittle_drift:
        return None

    out = {
        "model": model_name,
        "variant": "kernel_variant",
        "n_stable_drift": len(pairs_stable_drift),
        "n_brittle_drift": len(pairs_brittle_drift),
        "n_brittle_collapse": n_brittle_collapse,
        "n_stable_collapse": n_stable_collapse,
        "brittle_collapse_rate": n_brittle_collapse /
            max(1, n_brittle_collapse + len(pairs_brittle_drift)),
        "metrics": {},
    }
    for mname, mfn in _METRICS.items():
        s_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_stable_drift]
        b_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_brittle_drift]
        out["metrics"][mname] = _metric_stats(s_vals, b_vals)

    # Headline bootstrap CI on the token-jaccard effect size.
    s_vals = [metric_token_jaccard(p["var_text"], p["canon"])
              for p in pairs_stable_drift]
    b_vals = [metric_token_jaccard(p["var_text"], p["canon"])
              for p in pairs_brittle_drift]
    d_lo, d_hi = bootstrap_ci_cohens_d(s_vals, b_vals, n_iter=400)
    out["metrics"]["token_jaccard"]["cohens_d_ci"] = [d_lo, d_hi]
    return out


# ---------- Surface re-analysis with canonical reference ----------

def analyze_surface_cell_against_canonical(model_name: str, variant: str,
                                           model_dir: Path,
                                           canonical_solutions: dict) -> Optional[dict]:
    """Compare model variant trajectory to dataset canonical variant solution.

    For comparability with KV. No rename canonicalization needed since both
    sides use the same variant naming. Unlike analyze_kv_cell, collapsed
    stable trajectories are silently dropped (only brittle collapses are
    counted), matching the original surface analysis.

    Returns the per-cell result dict, or None when a results file is missing
    or either group is empty.
    """
    var_path = find_variant_file(model_dir, variant)
    orig_path = find_variant_file(model_dir, "original")
    if not var_path or not orig_path:
        return None
    var_by = {p["index"]: p for p in load_problems(var_path)}
    orig_by = {p["index"]: p for p in load_problems(orig_path)}

    pairs_stable, pairs_brittle = [], []
    n_brittle_collapse = 0
    for idx in set(var_by):
        if idx not in orig_by:
            continue
        if orig_by[idx].get("correct") is not True:
            continue  # restrict to model-knows-original
        pv = var_by[idx]
        var_correct = pv.get("correct")
        if var_correct is None:
            continue
        var_text = (pv.get("solve") or {}).get("solution") or ""
        if not var_text:
            continue
        canon_var = canonical_solutions.get(idx, {}).get(variant, "")
        if not canon_var or len(canon_var) < _MIN_CANON_CHARS:
            continue
        if _is_collapsed(var_text, canon_var):
            if var_correct is False:
                n_brittle_collapse += 1
            continue
        sample = {"index": idx, "var_text": var_text, "canon": canon_var}
        if var_correct is True:
            pairs_stable.append(sample)
        else:
            pairs_brittle.append(sample)

    if not pairs_stable or not pairs_brittle:
        return None

    out = {
        "model": model_name,
        "variant": variant,
        "n_stable_drift": len(pairs_stable),
        "n_brittle_drift": len(pairs_brittle),
        "n_brittle_collapse": n_brittle_collapse,
        "brittle_collapse_rate": n_brittle_collapse /
            max(1, n_brittle_collapse + len(pairs_brittle)),
        "metrics": {},
    }
    for mname, mfn in _METRICS.items():
        s_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_stable]
        b_vals = [mfn(p["var_text"], p["canon"]) for p in pairs_brittle]
        out["metrics"][mname] = _metric_stats(s_vals, b_vals)
    return out


def main():
    """Run the KV and surface analyses over every model in RESULTS_DIR,
    print per-cell tables, save JSON dumps, and print an aggregate compare."""
    print("Loading canonical solutions ...")
    canon = load_dataset_variant_solutions()
    print(f"  loaded {len(canon)} problems")

    all_models = sorted([d.name for d in RESULTS_DIR.iterdir() if d.is_dir()])

    kv_results = []
    surface_results = []

    print(f"\n{'KERNEL VARIANT — variant trajectory vs canonical KV solution':<70}")
    print(f"{'Cell':<32} {'nSd':>4} {'nBd':>4} {'col%':>5} "
          f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}")
    print("-" * 90)
    for m in all_models:
        mdir = RESULTS_DIR / m
        if not mdir.exists():
            continue
        res = analyze_kv_cell(m, mdir, canon)
        if res is None:
            continue
        kv_results.append(res)
        md = res["metrics"]["token_jaccard"]
        print(f"{m+' / KV':<32} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
              f"{res['brittle_collapse_rate']*100:>4.0f}% "
              f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
              f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}")

    print(f"\n{'SURFACE VARIANT — variant trajectory vs canonical variant solution':<70}")
    print(f"{'Cell':<46} {'nSd':>4} {'nBd':>4} {'col%':>5} "
          f"{'sMed':>6} {'bMed':>6} {'d':>6} {'p':>9}")
    print("-" * 95)
    for m in all_models:
        mdir = RESULTS_DIR / m
        if not mdir.exists():
            continue
        for v in SURFACE_VARIANTS:
            res = analyze_surface_cell_against_canonical(m, v, mdir, canon)
            if res is None:
                continue
            surface_results.append(res)
            md = res["metrics"]["token_jaccard"]
            print(f"{m+' / '+v:<46} {res['n_stable_drift']:>4} {res['n_brittle_drift']:>4} "
                  f"{res['brittle_collapse_rate']*100:>4.0f}% "
                  f"{md['stable_median']:>6.3f} {md['brittle_median']:>6.3f} "
                  f"{md['cohens_d']:>+6.2f} {md['p_two_sided']:>9.1e}")

    # Save (context-managed so the handles are closed; paths relative to
    # this script instead of a hard-coded home directory).
    with (_OUT_DIR / "kv_overlap_results.json").open("w") as fh:
        json.dump(kv_results, fh, indent=2)
    with (_OUT_DIR / "surface_canonical_results.json").open("w") as fh:
        json.dump(surface_results, fh, indent=2)

    # Aggregate compare
    print("\n" + "=" * 80)
    print("AGGREGATE: surface (vs canonical) vs kernel (vs canonical)")
    print("=" * 80)
    for tag, results in [("surface", surface_results), ("kernel", kv_results)]:
        ds = [c["metrics"]["token_jaccard"]["cohens_d"] for c in results]
        ps = [c["metrics"]["token_jaccard"]["p_two_sided"] for c in results]
        col = [c["brittle_collapse_rate"] for c in results]
        if not ds:
            continue
        print(f"{tag:<8} cells={len(ds):>3} d_pos={sum(1 for d in ds if d>0):>3}/{len(ds):<3} "
              f"p<.05={sum(1 for p in ps if p<0.05):>3}/{len(ps):<3} "
              f"d_med={statistics.median(ds):+.2f} d_mean={statistics.fmean(ds):+.2f} "
              f"collapse_mean={statistics.fmean(col)*100:.1f}%")


if __name__ == "__main__":
    main()
