diff options
Diffstat (limited to 'analysis/cross_model_agreement.py')
| -rw-r--r-- | analysis/cross_model_agreement.py | 180 |
1 files changed, 180 insertions, 0 deletions
diff --git a/analysis/cross_model_agreement.py b/analysis/cross_model_agreement.py new file mode 100644 index 0000000..fb9a571 --- /dev/null +++ b/analysis/cross_model_agreement.py @@ -0,0 +1,180 @@ +"""Cross-model agreement analysis: which problems are universally hard? + +For each (variant, problem) cell, count how many models fail (correct=False). +Identify "universally hard" problems (failed by ≥80% of models on the variant) +and "universally easy" (correct by ≥80% on the variant). Then check whether +the universally hard *flip set* is dominated by certain topics, problem types, +or years. + +Outputs: +- Per-variant histogram of failure counts +- "Universal flip" cases: original correct by ≥80% of models, variant wrong by ≥80% +- These are the cleanest signals of variant-induced fragility because they + rule out problem-specific quirks +""" +from __future__ import annotations +import json +import sys +from pathlib import Path +from collections import defaultdict, Counter + +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS + +DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset") + + +def load_metadata(): + """Load problem-level metadata: type, tag, difficulty, year.""" + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + out[idx] = { + "type": d.get("type"), + "tag": d.get("tag"), + "difficulty": d.get("difficulty"), + "problem_type": d.get("problem_type"), + "year": int(idx.split("-")[0]) if idx and "-" in idx else None, + } + return out + + +def main(): + base = RESULTS_DIR + models = sorted([d.name for d in base.iterdir() if d.is_dir()]) + print(f"Loading {len(models)} models ...") + metadata = load_metadata() + + # correct_table[(variant, idx)][model] = bool + correct_table = defaultdict(dict) + for m in models: + mdir = base / m + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + vp = find_variant_file(mdir, v) + if not vp: + continue + for p in load_problems(vp): + idx = p.get("index") + correct = p.get("correct") + if idx is None or correct is None: + continue + correct_table[(v, idx)][m] = correct + + print(f"Loaded {len(correct_table)} (variant, problem) cells.\n") + + # Per-variant histogram of correct counts (out of N models) + print("=== HISTOGRAM OF CORRECT-COUNT ACROSS MODELS ===") + print("(How many models get each problem right per variant)\n") + print(f"{'Variant':<24} {'mean correct/N':>16} {'median':>9} {'#unanimous-fail':>17} {'#unanimous-pass':>17}") + print("-" * 90) + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + cells = [d for (vv, idx), d in correct_table.items() if vv == v] + if not cells: + continue + counts = [sum(1 for vv in cell.values() if vv) / len(cell) for cell in cells] + unanimous_fail = sum(1 for cell in cells if not any(cell.values()) and len(cell) >= 3) + unanimous_pass = sum(1 for cell in cells if all(cell.values()) and len(cell) >= 3) + import statistics + print(f"{v:<24} {statistics.fmean(counts)*100:>14.1f}% {statistics.median(counts)*100:>7.1f}% " + f"{unanimous_fail:>17} {unanimous_pass:>17}") + + # Universal flip cases: original correct by ≥80% of models, variant wrong by ≥80% + print(f"\n\n=== UNIVERSAL FLIP CASES (orig ≥80% correct, variant ≥80% wrong) ===\n") + print("These are the cleanest signals of variant-induced fragility.\n") + print(f"{'Variant':<24} {'# universal flips':>20}") + print("-" * 50) + universal_flips = defaultdict(list) + for v in SURFACE_VARIANTS + ["kernel_variant"]: + for idx in {ii for (vv, ii) in correct_table if vv == "original"}: + orig_cell = correct_table.get(("original", idx), {}) + var_cell = correct_table.get((v, idx), {}) + common = set(orig_cell) & set(var_cell) + if len(common) < 5: continue + orig_rate = sum(1 for m in common if orig_cell[m]) / len(common) + var_rate = sum(1 for m in common if var_cell[m]) / len(common) + if orig_rate >= 0.80 and var_rate <= 0.20: + universal_flips[v].append((idx, orig_rate, var_rate)) + print(f"{v:<24} {len(universal_flips[v]):>20}") + + # Topic / problem_type / difficulty / year breakdown for universal flips + print(f"\n\n=== TOPIC BREAKDOWN OF UNIVERSAL FLIPS ===\n") + for v in SURFACE_VARIANTS + ["kernel_variant"]: + if not universal_flips[v]: continue + print(f"--- {v} ({len(universal_flips[v])} universal flips) ---") + topics = Counter() + ptypes = Counter() + difficulties = Counter() + years = Counter() + for idx, _, _ in universal_flips[v]: + md = metadata.get(idx, {}) + tag = md.get("tag") + # tag may be a list (multi-tag) or a string + if isinstance(tag, list): + for t in tag: topics[t] += 1 + elif tag: + topics[tag] += 1 + else: + topics["?"] += 1 + ptypes[md.get("problem_type") or "?"] += 1 + diff = md.get("difficulty") + if isinstance(diff, list): diff = diff[0] if diff else "?" + difficulties[diff or "?"] += 1 + year = md.get("year") + if year: + # Bin years by decade + decade = (year // 10) * 10 + years[f"{decade}s"] += 1 + print(f" topics: {dict(topics.most_common(8))}") + print(f" problem_type:{dict(ptypes)}") + print(f" difficulties:{dict(difficulties.most_common(6))}") + print(f" decades: {dict(sorted(years.items()))}") + print() + + # Save universal flips for later analysis + out = {v: [{"index": idx, "orig_rate": o, "var_rate": vr} + for (idx, o, vr) in flips] + for v, flips in universal_flips.items()} + json.dump(out, open(THIS_DIR / "universal_flips.json", "w"), indent=2) + print(f"\nSaved -> analysis/universal_flips.json") + + # Topic-stratified analysis: failure rate per topic per variant + print(f"\n\n=== ACCURACY BY TOPIC × VARIANT (mean across models) ===\n") + by_topic_variant = defaultdict(lambda: defaultdict(list)) + for (v, idx), cell in correct_table.items(): + md = metadata.get(idx, {}) + tag = md.get("tag") + if not tag or len(cell) < 5: continue + # If multiple tags, attribute the same rate to each — keeps it simple + topics_for_problem = tag if isinstance(tag, list) else [tag] + rate = sum(1 for vv in cell.values() if vv) / len(cell) + for t in topics_for_problem: + by_topic_variant[t][v].append(rate) + + topics_to_show = ["ALG", "ANA", "NT", "COMB", "GEO"] + print(f"{'Topic':<8}", end="") + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + short = {"original":"orig","descriptive_long":"DL","descriptive_long_confusing":"DLC", + "descriptive_long_misleading":"DLM","garbled_string":"GS","kernel_variant":"KV"}[v] + print(f" {short:>5}", end="") + print(" Δ_orig→KV") + print("-" * 70) + for t in topics_to_show: + if t not in by_topic_variant: continue + row = by_topic_variant[t] + if "original" not in row: continue + orig_mean = statistics.fmean(row["original"]) * 100 + print(f"{t:<8}", end="") + for v in ["original"] + SURFACE_VARIANTS + ["kernel_variant"]: + if v in row: + m = statistics.fmean(row[v]) * 100 + print(f" {m:>4.1f}%", end="") + else: + print(f" {'-':>5}", end="") + kv_mean = statistics.fmean(row.get("kernel_variant", [0])) * 100 + print(f" {kv_mean - orig_mean:+5.1f}pp") + + +if __name__ == "__main__": + main() |
