summaryrefslogtreecommitdiff
path: root/analysis/cross_model_agreement.py
blob: fb9a57198f71816cf1f13028d30f9fd0f96e4adc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Cross-model agreement analysis: which problems are universally hard?

For each (variant, problem) cell, count how many models fail (correct=False).
Identify "universally hard" problems (failed by ≥80% of models on the variant)
and "universally easy" (correct by ≥80% on the variant). Then check whether
the universally hard *flip set* is dominated by certain topics, problem types,
or years.

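Outputs:
- Per-variant histogram of correct counts across models (mean/median pass
  rate and the number of unanimous-fail / unanimous-pass problems)
- "Universal flip" cases: original correct by ≥80% of models, variant wrong
  by ≥80% (saved to universal_flips.json). These are the cleanest signals of
  variant-induced fragility because they rule out problem-specific quirks.
- Topic / problem-type / difficulty / decade breakdown of the universal flips,
  plus a topic × variant accuracy table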
"""
from __future__ import annotations

import json
import os
import statistics
import sys
from collections import Counter, defaultdict
from pathlib import Path

# Directory containing this script. It is prepended to sys.path so the sibling
# module `structural_overlap` can be imported when this file runs as a plain
# script; the insert must happen before the import below.
THIS_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(THIS_DIR))
from structural_overlap import find_variant_file, load_problems, RESULTS_DIR, SURFACE_VARIANTS

# Directory of per-problem dataset JSON files read by load_metadata().
# Set the PUTNAM_DATASET_DIR environment variable to point elsewhere; the
# default is the original (machine-specific) location.
DATASET_DIR = Path(os.environ.get("PUTNAM_DATASET_DIR", "/home/yurenh2/gap/putnam-bench-anon/dataset"))


def load_metadata():
    """Load problem-level metadata: type, tag, difficulty, year."""
    out = {}
    for f in sorted(DATASET_DIR.glob("*.json")):
        d = json.load(open(f))
        idx = d.get("index")
        out[idx] = {
            "type": d.get("type"),
            "tag": d.get("tag"),
            "difficulty": d.get("difficulty"),
            "problem_type": d.get("problem_type"),
            "year": int(idx.split("-")[0]) if idx and "-" in idx else None,
        }
    return out


# Short column labels for the topic × variant table; variants not listed here
# fall back to their first five characters instead of raising KeyError.
_VARIANT_SHORT_NAMES = {
    "original": "orig",
    "descriptive_long": "DL",
    "descriptive_long_confusing": "DLC",
    "descriptive_long_misleading": "DLM",
    "garbled_string": "GS",
    "kernel_variant": "KV",
}
# Topics shown, in this order, in the topic × variant accuracy table.
_TOPICS_TO_SHOW = ("ALG", "ANA", "NT", "COMB", "GEO")


def _pass_rate(cell):
    """Return the fraction of models in *cell* ({model: correct}) marked correct."""
    return sum(1 for ok in cell.values() if ok) / len(cell)


def _load_correct_table(base, models, variants):
    """Collect correctness into {(variant, index): {model: correct}}.

    Records whose ``index`` or ``correct`` field is missing are skipped, as are
    (model, variant) pairs for which ``find_variant_file`` returns nothing.
    """
    correct_table = defaultdict(dict)
    for m in models:
        mdir = base / m
        for v in variants:
            vp = find_variant_file(mdir, v)
            if not vp:
                continue
            for p in load_problems(vp):
                idx = p.get("index")
                correct = p.get("correct")
                if idx is None or correct is None:
                    continue
                correct_table[(v, idx)][m] = correct
    return correct_table


def _print_histogram(correct_table, variants):
    """Print mean/median pass rate and unanimous fail/pass counts per variant.

    Unanimous counts only consider cells answered by at least 3 models.
    """
    # Group once instead of rescanning the whole table for every variant.
    cells_by_variant = defaultdict(list)
    for (v, _idx), cell in correct_table.items():
        cells_by_variant[v].append(cell)

    print("=== HISTOGRAM OF CORRECT-COUNT ACROSS MODELS ===")
    print("(How many models get each problem right per variant)\n")
    print(f"{'Variant':<24} {'mean correct/N':>16} {'median':>9} {'#unanimous-fail':>17} {'#unanimous-pass':>17}")
    print("-" * 90)
    for v in variants:
        cells = cells_by_variant.get(v)
        if not cells:
            continue
        rates = [_pass_rate(cell) for cell in cells]
        unanimous_fail = sum(1 for cell in cells if not any(cell.values()) and len(cell) >= 3)
        unanimous_pass = sum(1 for cell in cells if all(cell.values()) and len(cell) >= 3)
        print(f"{v:<24} {statistics.fmean(rates)*100:>14.1f}%  {statistics.median(rates)*100:>7.1f}% "
              f"{unanimous_fail:>17} {unanimous_pass:>17}")


def _find_universal_flips(correct_table, variants, *, min_models=5, orig_min=0.80, var_max=0.20):
    """Return {variant: [(index, orig_rate, var_rate), ...]} of universal flips.

    A problem is a universal flip for a variant when, over the models that
    answered both the original and the variant (at least *min_models* of
    them), the original pass rate is >= *orig_min* and the variant pass rate
    is <= *var_max*. Every variant gets a key, even with no flips.
    """
    # Sorted (by str) so the order, and hence the saved JSON, is reproducible
    # across runs; plain set order depends on per-process string hashing.
    original_indices = sorted({idx for (v, idx) in correct_table if v == "original"}, key=str)
    flips = {}
    for v in variants:
        hits = []
        for idx in original_indices:
            orig_cell = correct_table.get(("original", idx), {})
            var_cell = correct_table.get((v, idx), {})
            common = set(orig_cell) & set(var_cell)
            if len(common) < min_models:
                continue
            orig_rate = sum(1 for m in common if orig_cell[m]) / len(common)
            var_rate = sum(1 for m in common if var_cell[m]) / len(common)
            if orig_rate >= orig_min and var_rate <= var_max:
                hits.append((idx, orig_rate, var_rate))
        flips[v] = hits
    return flips


def _print_flip_breakdown(universal_flips, metadata, variants):
    """Print topic / problem_type / difficulty / decade counts of each variant's flips."""
    print("\n\n=== TOPIC BREAKDOWN OF UNIVERSAL FLIPS ===\n")
    for v in variants:
        flips = universal_flips.get(v)
        if not flips:
            continue
        print(f"--- {v} ({len(flips)} universal flips) ---")
        topics, ptypes, difficulties, years = Counter(), Counter(), Counter(), Counter()
        for idx, _, _ in flips:
            md = metadata.get(idx, {})
            tag = md.get("tag")
            # tag may be a list (multi-tag) or a single string
            if isinstance(tag, list):
                topics.update(tag)
            elif tag:
                topics[tag] += 1
            else:
                topics["?"] += 1
            ptypes[md.get("problem_type") or "?"] += 1
            diff = md.get("difficulty")
            if isinstance(diff, list):
                diff = diff[0] if diff else "?"
            difficulties[diff or "?"] += 1
            year = md.get("year")
            if year:
                # Bin years by decade
                years[f"{(year // 10) * 10}s"] += 1
        print(f"  topics:      {dict(topics.most_common(8))}")
        print(f"  problem_type:{dict(ptypes)}")
        print(f"  difficulties:{dict(difficulties.most_common(6))}")
        print(f"  decades:     {dict(sorted(years.items()))}")
        print()


def _save_flips(universal_flips, path):
    """Write universal flips as JSON to *path*, closing the file deterministically."""
    out = {v: [{"index": idx, "orig_rate": o, "var_rate": vr}
               for (idx, o, vr) in flips]
           for v, flips in universal_flips.items()}
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(out, fh, indent=2)
    print("\nSaved -> analysis/universal_flips.json")


def _print_topic_table(correct_table, metadata, variants, topics=_TOPICS_TO_SHOW):
    """Print mean pass rate per topic × variant, plus the original→kernel delta.

    Only cells answered by at least 5 models and whose problem has a tag are
    counted; a multi-tag problem contributes its rate to every one of its tags.
    """
    print("\n\n=== ACCURACY BY TOPIC × VARIANT (mean across models) ===\n")
    by_topic_variant = defaultdict(lambda: defaultdict(list))
    for (v, idx), cell in correct_table.items():
        tag = metadata.get(idx, {}).get("tag")
        if not tag or len(cell) < 5:
            continue
        rate = _pass_rate(cell)
        for t in (tag if isinstance(tag, list) else [tag]):
            by_topic_variant[t][v].append(rate)

    print(f"{'Topic':<8}", end="")
    for v in variants:
        short = _VARIANT_SHORT_NAMES.get(v, v[:5])
        print(f"  {short:>5}", end="")
    print("  Δ_orig→KV")
    print("-" * 70)
    for t in topics:
        row = by_topic_variant.get(t)
        if not row or "original" not in row:
            continue
        orig_mean = statistics.fmean(row["original"]) * 100
        print(f"{t:<8}", end="")
        for v in variants:
            if v in row:
                print(f"  {statistics.fmean(row[v]) * 100:>4.1f}%", end="")
            else:
                print(f"  {'-':>5}", end="")
        if "kernel_variant" in row:
            kv_mean = statistics.fmean(row["kernel_variant"]) * 100
            print(f"   {kv_mean - orig_mean:+5.1f}pp")
        else:
            # No kernel data: don't report a fake delta of -orig_mean.
            print(f"   {'-':>7}")


def main():
    """Run the cross-model agreement analysis.

    Loads per-model correctness for the original, surface and kernel variants
    under ``RESULTS_DIR``, then prints the correct-count histogram, the
    universal flip counts and their metadata breakdown, and the topic ×
    variant accuracy table. Universal flips are written to
    ``THIS_DIR / "universal_flips.json"``.
    """
    base = RESULTS_DIR
    models = sorted(d.name for d in base.iterdir() if d.is_dir())
    print(f"Loading {len(models)} models ...")
    metadata = load_metadata()

    all_variants = ["original"] + list(SURFACE_VARIANTS) + ["kernel_variant"]
    flip_variants = all_variants[1:]  # every variant except "original"

    correct_table = _load_correct_table(base, models, all_variants)
    print(f"Loaded {len(correct_table)} (variant, problem) cells.\n")

    _print_histogram(correct_table, all_variants)

    # Universal flip cases: original correct by ≥80% of models, variant wrong by ≥80%
    print("\n\n=== UNIVERSAL FLIP CASES (orig ≥80% correct, variant ≥80% wrong) ===\n")
    print("These are the cleanest signals of variant-induced fragility.\n")
    print(f"{'Variant':<24} {'# universal flips':>20}")
    print("-" * 50)
    universal_flips = _find_universal_flips(correct_table, flip_variants)
    for v in flip_variants:
        print(f"{v:<24} {len(universal_flips[v]):>20}")

    _print_flip_breakdown(universal_flips, metadata, flip_variants)
    _save_flips(universal_flips, THIS_DIR / "universal_flips.json")
    _print_topic_table(correct_table, metadata, all_variants)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()