summaryrefslogtreecommitdiff
path: root/analysis/cross_model_agreement.py
diff options
context:
space:
mode:
author Yuren Hao <yurenh2@illinois.edu> 2026-04-08 22:06:05 -0500
committer Yuren Hao <yurenh2@illinois.edu> 2026-04-08 22:06:05 -0500
commit 05704d0eb2fa59fe727652465b07db40bcb06c38 (patch)
tree 8904aca836cf552fd1a5ae8c2174e9f91e70bbbc /analysis/cross_model_agreement.py
Initial release: GAP framework
- Full pipeline: variant generation, multi-judge verification, evaluation
- Loaders for OpenAI / Anthropic / Google / xAI / OpenRouter / vLLM
- Framework-level mechanism analyses: paired structural overlap, repairability rescue, self-correction probe, cross-model agreement, topic x problem-type interaction
- Unicode -> bare-LaTeX cleaner + audit + spot-check
- Mirrors https://huggingface.co/datasets/blackhao0426/PutnamGAP
Diffstat (limited to 'analysis/cross_model_agreement.py')
-rw-r--r-- analysis/cross_model_agreement.py 180
1 files changed, 180 insertions, 0 deletions
diff --git a/analysis/cross_model_agreement.py b/analysis/cross_model_agreement.py
new file mode 100644
index 0000000..fb9a571
--- /dev/null
+++ b/analysis/cross_model_agreement.py
@@ -0,0 +1,180 @@
+"""Cross-model agreement analysis: which problems are universally hard?
+
+For each (variant, problem) cell, count how many models fail (correct=False).
+Identify "universally hard" problems (failed by ≥80% of models on the variant)
+and "universally easy" (correct by ≥80% on the variant). Then check whether
+the universally hard *flip set* is dominated by certain topics, problem types,
+or years.
+
+Outputs:
+- Per-variant summary of correct counts across models (mean and median pass rate, number of unanimous-fail and unanimous-pass problems)
+- "Universal flip" cases: original correct by ≥80% of models, variant wrong by ≥80%
+- These are the cleanest signals of variant-induced fragility because they
+ rule out problem-specific quirks
+"""
+from __future__ import annotations
+import json
+import sys
+from pathlib import Path
+from collections import defaultdict, Counter
+
+THIS_DIR = Path(__file__).resolve().parent  # absolute directory containing this script
+sys.path.insert(0, str(THIS_DIR))  # search this directory first so the import on the next line resolves (presumably structural_overlap.py sits next to this file - confirm)
+from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS
+
+DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")  # NOTE(review): hard-coded absolute path under one user's home directory; load_metadata() globs *.json here, so it must be adjusted to run elsewhere
+
+
def _year_from_index(idx):
    """Return the integer before the first "-" in *idx*, or None.

    None is returned when *idx* is not a string, contains no "-", or its
    leading segment is not an integer (instead of raising).
    """
    if not isinstance(idx, str) or "-" not in idx:
        return None
    try:
        return int(idx.split("-")[0])
    except ValueError:
        return None


def load_metadata(dataset_dir=None):
    """Load problem-level metadata: type, tag, difficulty, problem_type, year.

    Every ``*.json`` file directly inside the dataset directory is read in
    sorted filename order.  Each file is expected to hold one JSON object.

    Args:
        dataset_dir: Directory to scan (``str`` or ``Path``).  When omitted,
            the module-level ``DATASET_DIR`` is used, which keeps existing
            zero-argument callers working unchanged.

    Returns:
        A dict keyed by each file's ``"index"`` value.  Each entry has the
        keys ``"type"``, ``"tag"``, ``"difficulty"`` and ``"problem_type"``
        (copied as-is, ``None`` when the field is absent) plus ``"year"``,
        the integer prefix of the index before its first ``"-"`` or ``None``
        when no such prefix can be parsed.  If several files share an index
        (including a missing one, stored under ``None``), the file that
        sorts last wins.
    """
    root = DATASET_DIR if dataset_dir is None else Path(dataset_dir)
    out = {}
    for path in sorted(root.glob("*.json")):
        # Close the handle deterministically and decode as UTF-8 explicitly
        # rather than relying on the platform's locale encoding.
        with open(path, encoding="utf-8") as fh:
            d = json.load(fh)
        idx = d.get("index")
        out[idx] = {
            "type": d.get("type"),
            "tag": d.get("tag"),
            "difficulty": d.get("difficulty"),
            "problem_type": d.get("problem_type"),
            "year": _year_from_index(idx),
        }
    return out
+
+
def _pass_rate(cell, models=None):
    """Return the fraction of truthy ``correct`` flags in *cell*.

    *cell* maps model name -> correct flag.  When *models* is given, only
    those models are counted.  Callers must pass at least one model.
    """
    keys = cell if models is None else models
    return sum(1 for m in keys if cell[m]) / len(keys)


def _find_universal_flips(correct_table, variant, orig_indices,
                          min_models=5, high=0.80, low=0.20):
    """Return ``(idx, orig_rate, var_rate)`` tuples for flips on *variant*.

    Rates are computed only over models that have a result for BOTH the
    original and the variant of a problem; problems with fewer than
    *min_models* such models are skipped.  A flip requires the original
    pass rate to be >= *high* and the variant pass rate to be <= *low*.
    """
    flips = []
    for idx in orig_indices:
        # .get() (not indexing) so the defaultdict is not polluted with
        # empty cells for variants that have no data for this problem.
        orig_cell = correct_table.get(("original", idx), {})
        var_cell = correct_table.get((variant, idx), {})
        common = orig_cell.keys() & var_cell.keys()
        if len(common) < min_models:
            continue
        orig_rate = _pass_rate(orig_cell, common)
        var_rate = _pass_rate(var_cell, common)
        if orig_rate >= high and var_rate <= low:
            flips.append((idx, orig_rate, var_rate))
    return flips


def main():
    """Run the cross-model agreement analysis and print its report.

    Steps, in order:
      1. Treat every sub-directory of ``RESULTS_DIR`` as one model and load
         its per-variant results into ``correct_table[(variant, idx)][model]``.
      2. Print, per variant, the mean/median per-problem pass rate and the
         number of unanimous-fail / unanimous-pass problems (cells with at
         least 3 models).
      3. Find "universal flips" (original pass rate >= 80%, variant pass
         rate <= 20%, over at least 5 models common to both) and print a
         tag / problem_type / difficulty / decade breakdown of them.
      4. Write the flips to ``universal_flips.json`` next to this script.
      5. Print a topic x variant accuracy table for a fixed list of topics.

    Returns:
        None.  All results go to stdout and to ``universal_flips.json``.
    """
    # Imported once at function scope; the module header does not import it.
    import statistics

    all_variants = ["original", *SURFACE_VARIANTS, "kernel_variant"]
    perturbed_variants = all_variants[1:]

    base = RESULTS_DIR
    models = sorted(d.name for d in base.iterdir() if d.is_dir())
    print(f"Loading {len(models)} models ...")
    metadata = load_metadata()

    # correct_table[(variant, idx)][model] = correct flag
    correct_table = defaultdict(dict)
    for m in models:
        mdir = base / m
        for v in all_variants:
            vp = find_variant_file(mdir, v)
            if not vp:
                continue
            for p in load_problems(vp):
                idx = p.get("index")
                correct = p.get("correct")
                # Skip records that are unidentified or ungraded.
                if idx is None or correct is None:
                    continue
                correct_table[(v, idx)][m] = correct

    print(f"Loaded {len(correct_table)} (variant, problem) cells.\n")

    # Group cells by variant in one pass instead of rescanning per variant.
    cells_by_variant = defaultdict(list)
    for (v, _idx), cell in correct_table.items():
        cells_by_variant[v].append(cell)

    # Per-variant summary of correct counts (out of N models)
    print("=== HISTOGRAM OF CORRECT-COUNT ACROSS MODELS ===")
    print("(How many models get each problem right per variant)\n")
    print(f"{'Variant':<24} {'mean correct/N':>16} {'median':>9} {'#unanimous-fail':>17} {'#unanimous-pass':>17}")
    print("-" * 90)
    for v in all_variants:
        cells = cells_by_variant.get(v)
        if not cells:
            continue
        rates = [_pass_rate(cell) for cell in cells]
        # Unanimity is only counted for cells covered by at least 3 models.
        unanimous_fail = sum(1 for cell in cells if len(cell) >= 3 and not any(cell.values()))
        unanimous_pass = sum(1 for cell in cells if len(cell) >= 3 and all(cell.values()))
        print(f"{v:<24} {statistics.fmean(rates)*100:>14.1f}% {statistics.median(rates)*100:>7.1f}% "
              f"{unanimous_fail:>17} {unanimous_pass:>17}")

    # Universal flip cases: original correct by ≥80% of models, variant wrong by ≥80%
    print("\n\n=== UNIVERSAL FLIP CASES (orig ≥80% correct, variant ≥80% wrong) ===\n")
    print("These are the cleanest signals of variant-induced fragility.\n")
    print(f"{'Variant':<24} {'# universal flips':>20}")
    print("-" * 50)
    # Sorted so the printed breakdown and the saved JSON are reproducible
    # across runs (set iteration order of strings varies with hash seed).
    orig_indices = sorted(
        (idx for (v, idx) in correct_table if v == "original"), key=str
    )
    universal_flips = {}
    for v in perturbed_variants:
        universal_flips[v] = _find_universal_flips(correct_table, v, orig_indices)
        print(f"{v:<24} {len(universal_flips[v]):>20}")

    # Topic / problem_type / difficulty / year breakdown for universal flips
    print("\n\n=== TOPIC BREAKDOWN OF UNIVERSAL FLIPS ===\n")
    for v in perturbed_variants:
        flips = universal_flips[v]
        if not flips:
            continue
        print(f"--- {v} ({len(flips)} universal flips) ---")
        topics = Counter()
        ptypes = Counter()
        difficulties = Counter()
        years = Counter()
        for idx, _, _ in flips:
            md = metadata.get(idx, {})
            tag = md.get("tag")
            # tag may be a list (multi-tag) or a string
            if isinstance(tag, list):
                topics.update(tag)
            else:
                topics[tag or "?"] += 1
            ptypes[md.get("problem_type") or "?"] += 1
            diff = md.get("difficulty")
            if isinstance(diff, list):
                diff = diff[0] if diff else "?"
            difficulties[diff or "?"] += 1
            year = md.get("year")
            if year:
                # Bin years by decade
                decade = (year // 10) * 10
                years[f"{decade}s"] += 1
        print(f" topics: {dict(topics.most_common(8))}")
        print(f" problem_type:{dict(ptypes)}")
        print(f" difficulties:{dict(difficulties.most_common(6))}")
        print(f" decades: {dict(sorted(years.items()))}")
        print()

    # Save universal flips for later analysis
    out = {
        v: [{"index": idx, "orig_rate": o, "var_rate": vr} for (idx, o, vr) in flips]
        for v, flips in universal_flips.items()
    }
    # Context manager so the file is flushed and closed before returning.
    with open(THIS_DIR / "universal_flips.json", "w", encoding="utf-8") as fh:
        json.dump(out, fh, indent=2)
    print("\nSaved -> analysis/universal_flips.json")

    # Topic-stratified analysis: pass rate per topic per variant
    print("\n\n=== ACCURACY BY TOPIC × VARIANT (mean across models) ===\n")
    by_topic_variant = defaultdict(lambda: defaultdict(list))
    for (v, idx), cell in correct_table.items():
        tag = metadata.get(idx, {}).get("tag")
        if not tag or len(cell) < 5:
            continue
        # If multiple tags, attribute the same rate to each — keeps it simple
        rate = _pass_rate(cell)
        for t in (tag if isinstance(tag, list) else [tag]):
            by_topic_variant[t][v].append(rate)

    topics_to_show = ["ALG", "ANA", "NT", "COMB", "GEO"]
    short_names = {
        "original": "orig",
        "descriptive_long": "DL",
        "descriptive_long_confusing": "DLC",
        "descriptive_long_misleading": "DLM",
        "garbled_string": "GS",
        "kernel_variant": "KV",
    }
    print(f"{'Topic':<8}", end="")
    for v in all_variants:
        # Fall back to a truncated name instead of raising KeyError if
        # SURFACE_VARIANTS ever contains a variant not listed above.
        short = short_names.get(v, v[:5])
        print(f" {short:>5}", end="")
    print(" Δ_orig→KV")
    print("-" * 70)
    for t in topics_to_show:
        row = by_topic_variant.get(t)
        if not row or "original" not in row:
            continue
        orig_mean = statistics.fmean(row["original"]) * 100
        print(f"{t:<8}", end="")
        for v in all_variants:
            if v in row:
                m = statistics.fmean(row[v]) * 100
                print(f" {m:>4.1f}%", end="")
            else:
                print(f" {'-':>5}", end="")
        if "kernel_variant" in row:
            kv_mean = statistics.fmean(row["kernel_variant"]) * 100
            print(f" {kv_mean - orig_mean:+5.1f}pp")
        else:
            # No kernel-variant data for this topic: report that explicitly
            # rather than a delta computed against a made-up 0% accuracy.
            print(f" {'n/a':>7}")
+
+
+if __name__ == "__main__":  # script entry point: run the analysis only when executed directly, not on import
+ main()