"""Cross-model agreement analysis: which problems are universally hard?
For each (variant, problem) cell, count how many models fail (correct=False).
Identify "universally hard" problems (failed by ≥80% of models on the variant)
and "universally easy" (correct by ≥80% on the variant). Then check whether
the universally hard *flip set* is dominated by certain topics, problem types,
or years.
Outputs:
- Per-variant histogram of failure counts
- "Universal flip" cases: original correct by ≥80% of models, variant wrong by ≥80%
- These are the cleanest signals of variant-induced fragility because they
rule out problem-specific quirks
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from collections import defaultdict, Counter
THIS_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(THIS_DIR))
from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS
DATASET_DIR = Path("/home/yurenh2/gap/putnam-bench-anon/dataset")
def load_metadata(dataset_dir=None):
    """Load problem-level metadata (type, tag, difficulty, year) keyed by index.

    Parameters
    ----------
    dataset_dir : path-like or None
        Directory containing per-problem ``*.json`` files. Defaults to the
        module-level ``DATASET_DIR`` so existing callers are unaffected.

    Returns
    -------
    dict
        ``index -> {"type", "tag", "difficulty", "problem_type", "year"}``.
        ``year`` is parsed from a leading numeric ``"YYYY-"`` prefix of the
        index when present, otherwise ``None``.
    """
    root = DATASET_DIR if dataset_dir is None else Path(dataset_dir)
    out = {}
    for f in sorted(root.glob("*.json")):
        # read_text() closes the file deterministically; the previous bare
        # open() relied on GC to release the handle.
        d = json.loads(f.read_text())
        idx = d.get("index")
        year = None
        if idx and "-" in idx:
            # Index format is presumably "YYYY-..."; guard against
            # non-numeric prefixes rather than crashing the whole load.
            prefix = idx.split("-")[0]
            if prefix.isdigit():
                year = int(prefix)
        out[idx] = {
            "type": d.get("type"),
            "tag": d.get("tag"),
            "difficulty": d.get("difficulty"),
            "problem_type": d.get("problem_type"),
            "year": year,
        }
    return out
def main():
    """Build the cross-model correctness table and print every report section.

    Reads per-model result files under ``RESULTS_DIR``, aggregates a
    ``(variant, problem) -> {model: correct}`` table, then prints:
      1. per-variant histogram statistics of correct counts,
      2. "universal flip" cases (orig >=80% correct, variant >=80% wrong),
      3. topic/type/difficulty/decade breakdown of those flips,
      4. mean accuracy by topic x variant.
    Also writes the universal flips to ``analysis/universal_flips.json``.
    """
    # Hoisted: previously this import sat inside the histogram loop body, so
    # if that loop never executed (no data for any variant) the later
    # sections that use `statistics` would raise NameError.
    import statistics

    base = RESULTS_DIR
    models = sorted([d.name for d in base.iterdir() if d.is_dir()])
    print(f"Loading {len(models)} models ...")
    metadata = load_metadata()

    all_variants = ["original"] + SURFACE_VARIANTS + ["kernel_variant"]

    # correct_table[(variant, idx)][model] = bool
    correct_table = defaultdict(dict)
    for m in models:
        mdir = base / m
        for v in all_variants:
            vp = find_variant_file(mdir, v)
            if not vp:
                continue
            for p in load_problems(vp):
                idx = p.get("index")
                correct = p.get("correct")
                # Skip records missing either key; `correct is None` also
                # filters unevaluated problems.
                if idx is None or correct is None:
                    continue
                correct_table[(v, idx)][m] = correct
    print(f"Loaded {len(correct_table)} (variant, problem) cells.\n")

    # --- Section 1: per-variant histogram of correct counts ----------------
    print("=== HISTOGRAM OF CORRECT-COUNT ACROSS MODELS ===")
    print("(How many models get each problem right per variant)\n")
    print(f"{'Variant':<24} {'mean correct/N':>16} {'median':>9} {'#unanimous-fail':>17} {'#unanimous-pass':>17}")
    print("-" * 90)
    for v in all_variants:
        cells = [d for (vv, idx), d in correct_table.items() if vv == v]
        if not cells:
            continue
        counts = [sum(1 for vv in cell.values() if vv) / len(cell) for cell in cells]
        # Require at least 3 models before calling a cell "unanimous".
        unanimous_fail = sum(1 for cell in cells if not any(cell.values()) and len(cell) >= 3)
        unanimous_pass = sum(1 for cell in cells if all(cell.values()) and len(cell) >= 3)
        print(f"{v:<24} {statistics.fmean(counts)*100:>14.1f}% {statistics.median(counts)*100:>7.1f}% "
              f"{unanimous_fail:>17} {unanimous_pass:>17}")

    # --- Section 2: universal flip cases -----------------------------------
    # Original correct by >=80% of models, variant wrong by >=80%.
    print(f"\n\n=== UNIVERSAL FLIP CASES (orig ≥80% correct, variant ≥80% wrong) ===\n")
    print("These are the cleanest signals of variant-induced fragility.\n")
    print(f"{'Variant':<24} {'# universal flips':>20}")
    print("-" * 50)
    universal_flips = defaultdict(list)
    # Hoisted out of the variant loop — the set is the same every iteration.
    original_indices = {ii for (vv, ii) in correct_table if vv == "original"}
    for v in SURFACE_VARIANTS + ["kernel_variant"]:
        for idx in original_indices:
            orig_cell = correct_table.get(("original", idx), {})
            var_cell = correct_table.get((v, idx), {})
            # Only compare models that ran both the original and the variant,
            # and require >=5 such models for the rates to mean anything.
            common = set(orig_cell) & set(var_cell)
            if len(common) < 5:
                continue
            orig_rate = sum(1 for m in common if orig_cell[m]) / len(common)
            var_rate = sum(1 for m in common if var_cell[m]) / len(common)
            if orig_rate >= 0.80 and var_rate <= 0.20:
                universal_flips[v].append((idx, orig_rate, var_rate))
        print(f"{v:<24} {len(universal_flips[v]):>20}")

    # --- Section 3: metadata breakdown of universal flips ------------------
    print(f"\n\n=== TOPIC BREAKDOWN OF UNIVERSAL FLIPS ===\n")
    for v in SURFACE_VARIANTS + ["kernel_variant"]:
        if not universal_flips[v]:
            continue
        print(f"--- {v} ({len(universal_flips[v])} universal flips) ---")
        topics = Counter()
        ptypes = Counter()
        difficulties = Counter()
        years = Counter()
        for idx, _, _ in universal_flips[v]:
            md = metadata.get(idx, {})
            tag = md.get("tag")
            # tag may be a list (multi-tag) or a string
            if isinstance(tag, list):
                for t in tag:
                    topics[t] += 1
            elif tag:
                topics[tag] += 1
            else:
                topics["?"] += 1
            ptypes[md.get("problem_type") or "?"] += 1
            diff = md.get("difficulty")
            if isinstance(diff, list):
                diff = diff[0] if diff else "?"
            difficulties[diff or "?"] += 1
            year = md.get("year")
            if year:
                # Bin years by decade
                decade = (year // 10) * 10
                years[f"{decade}s"] += 1
        print(f" topics: {dict(topics.most_common(8))}")
        print(f" problem_type:{dict(ptypes)}")
        print(f" difficulties:{dict(difficulties.most_common(6))}")
        print(f" decades: {dict(sorted(years.items()))}")
        print()

    # --- Save universal flips for later analysis ---------------------------
    out = {v: [{"index": idx, "orig_rate": o, "var_rate": vr}
               for (idx, o, vr) in flips]
           for v, flips in universal_flips.items()}
    # `with` closes the output file deterministically; the previous bare
    # open() could leave the handle (and buffered data) to the GC.
    with open(THIS_DIR / "universal_flips.json", "w") as fh:
        json.dump(out, fh, indent=2)
    print(f"\nSaved -> analysis/universal_flips.json")

    # --- Section 4: accuracy by topic x variant ----------------------------
    print(f"\n\n=== ACCURACY BY TOPIC × VARIANT (mean across models) ===\n")
    by_topic_variant = defaultdict(lambda: defaultdict(list))
    for (v, idx), cell in correct_table.items():
        md = metadata.get(idx, {})
        tag = md.get("tag")
        if not tag or len(cell) < 5:
            continue
        # If multiple tags, attribute the same rate to each — keeps it simple
        topics_for_problem = tag if isinstance(tag, list) else [tag]
        rate = sum(1 for vv in cell.values() if vv) / len(cell)
        for t in topics_for_problem:
            by_topic_variant[t][v].append(rate)

    topics_to_show = ["ALG", "ANA", "NT", "COMB", "GEO"]
    short_names = {"original": "orig", "descriptive_long": "DL",
                   "descriptive_long_confusing": "DLC",
                   "descriptive_long_misleading": "DLM",
                   "garbled_string": "GS", "kernel_variant": "KV"}
    print(f"{'Topic':<8}", end="")
    for v in all_variants:
        print(f" {short_names[v]:>5}", end="")
    print(" Δ_orig→KV")
    print("-" * 70)
    for t in topics_to_show:
        if t not in by_topic_variant:
            continue
        row = by_topic_variant[t]
        if "original" not in row:
            continue
        orig_mean = statistics.fmean(row["original"]) * 100
        print(f"{t:<8}", end="")
        for v in all_variants:
            if v in row:
                m = statistics.fmean(row[v]) * 100
                print(f" {m:>4.1f}%", end="")
            else:
                print(f" {'-':>5}", end="")
        if "kernel_variant" in row:
            kv_mean = statistics.fmean(row["kernel_variant"]) * 100
            print(f" {kv_mean - orig_mean:+5.1f}pp")
        else:
            # Previously fmean([0]) made a missing kernel_variant look like a
            # real delta of -orig_mean; print a placeholder instead.
            print(f" {'n/a':>7}")
# Entry point: run the full report when executed as a script.
if __name__ == "__main__":
    main()