diff options
Diffstat (limited to 'analysis/rescue_runner.py')
| -rw-r--r-- | analysis/rescue_runner.py | 341 |
1 file changed, 341 insertions, 0 deletions
diff --git a/analysis/rescue_runner.py b/analysis/rescue_runner.py new file mode 100644 index 0000000..9c9f226 --- /dev/null +++ b/analysis/rescue_runner.py @@ -0,0 +1,341 @@ +"""End-to-end rescue experiment runner. + +For each (model, variant, flip-case): + - Build 3 prompts: own_T2, canonical_T2, null (KV: only canonical_T2 + null) + - Solve with the same model the case originally failed under + - Grade with gpt-4o using the variant problem + canonical variant solution as reference + - Save per-case results immediately to a jsonl checkpoint (resumable) + +Usage: + python rescue_runner.py --pilot # 5 cases per cell (smoke test) + python rescue_runner.py # 30 cases per cell (full run) +""" +from __future__ import annotations +import argparse +import asyncio +import json +import os +import random +import sys +import time +from pathlib import Path +from typing import Optional + +# Local imports +THIS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(THIS_DIR)) +from rescue_prompts import ( + truncate_T2, rename_own_prefix, + build_rescue_prompt, build_null_prompt, NULL_SCAFFOLD, +) +from rescue_api import ( + SOLVER_PROVIDERS, solve, grade, parse_solution, parse_grade, +) +from structural_overlap import ( + DATASET_DIR, RESULTS_DIR, find_variant_file, load_problems, SURFACE_VARIANTS, +) + + +# Short model name -> directory name in results_new +MODEL_RESULTS_DIRS = { + "gpt-4.1-mini": "gpt-4.1-mini", + "gpt-4o-mini": "gpt-4o-mini", + "claude-sonnet-4": "claude-sonnet-4", + "gemini-2.5-flash": "gemini_2.5_flash", # historical underscore naming +} +SELECTED_MODELS = ["gpt-4.1-mini", "gpt-4o-mini", "claude-sonnet-4", "gemini-2.5-flash"] +ALL_VARIANTS = SURFACE_VARIANTS + ["kernel_variant"] +SURFACE_CONDITIONS = ["own_T2", "canonical_T2", "null"] +KV_CONDITIONS = ["canonical_T2", "null"] + + +# ---------- Dataset loading ---------- + +def load_dataset_full() -> dict: + """Returns: {idx: {original: {...}, variants: {v: {map, question, solution}}}}. 
+ + The dataset stores top-level question/solution and variant-keyed question/solution/map. + """ + out = {} + for f in sorted(DATASET_DIR.glob("*.json")): + d = json.load(open(f)) + idx = d.get("index") + cell = { + "problem_type": d.get("problem_type"), + "original_question": d.get("question") or "", + "original_solution": d.get("solution") or "", + "variants": {}, + } + for v, vd in d.get("variants", {}).items(): + if isinstance(vd, dict): + rmap = vd.get("map") + if isinstance(rmap, str): + try: + rmap = eval(rmap, {"__builtins__": {}}, {}) + except Exception: + rmap = None + cell["variants"][v] = { + "question": vd.get("question") or "", + "solution": vd.get("solution") or "", + "map": rmap if isinstance(rmap, dict) else None, + } + out[idx] = cell + return out + + +# ---------- Flip case selection ---------- + +def find_flip_cases(model: str, variant: str, max_cases: int, + seed: int = 42) -> list: + """Identify (orig_correct, var_wrong) flip cases for the cell. + + Returns list of dicts with: index, problem_type, model_orig_solution, + final_answer (recorded), variant_problem_statement (from results). 
+ """ + mdir = RESULTS_DIR / MODEL_RESULTS_DIRS.get(model, model) + op = find_variant_file(mdir, "original") + vp = find_variant_file(mdir, variant) + if not op or not vp: + return [] + orig_by = {p["index"]: p for p in load_problems(op)} + var_by = {p["index"]: p for p in load_problems(vp)} + cases = [] + for idx in sorted(set(orig_by) & set(var_by)): + po, pv = orig_by[idx], var_by[idx] + if po.get("correct") is not True or pv.get("correct") is not False: + continue + orig_text = (po.get("solve") or {}).get("solution") or "" + if not orig_text: + continue + # Skip cases where we couldn't extract a T2 prefix from the original + fa = (po.get("solve") or {}).get("final_answer") or "" + if truncate_T2(orig_text, fa) is None: + continue + cases.append({ + "index": idx, + "problem_type": po.get("problem_type"), + "orig_solution": orig_text, + "orig_final_answer": fa, + }) + rng = random.Random(seed) + rng.shuffle(cases) + return cases[:max_cases] + + +# ---------- Prompt construction per case ---------- + +def build_case_prompts(case: dict, variant: str, ds_cell: dict) -> dict: + """Returns: {condition_name: user_message_string}.""" + var_info = ds_cell["variants"].get(variant, {}) + var_question = var_info.get("question", "") + if not var_question: + return {} + prompts = {} + is_kv = (variant == "kernel_variant") + + # canonical_T2: dataset's canonical variant solution truncated + canon_sol = var_info.get("solution", "") + if canon_sol: + canon_pre = truncate_T2(canon_sol, None) + if canon_pre: + prompts["canonical_T2"] = build_rescue_prompt(var_question, canon_pre) + + # own_T2: only for surface variants — model's own original-correct prefix renamed + if not is_kv: + rmap = var_info.get("map") or {} + own_pre = truncate_T2(case["orig_solution"], case.get("orig_final_answer")) + if own_pre and rmap: + renamed = rename_own_prefix(own_pre, rmap) + prompts["own_T2"] = build_rescue_prompt(var_question, renamed) + + # null: always available + prompts["null"] = 
build_null_prompt(var_question) + return prompts + + +# ---------- Per-condition runner ---------- + +async def run_one_condition(model: str, condition: str, user_msg: str, + case: dict, variant: str, ds_cell: dict) -> dict: + """Solve + grade a single condition for a single case. Returns a result dict.""" + var_info = ds_cell["variants"].get(variant, {}) + var_question = var_info.get("question", "") + canon_sol = var_info.get("solution", "") + problem_type = case["problem_type"] + t0 = time.time() + solve_resp = await solve(model, user_msg) + solve_dt = time.time() - t0 + if solve_resp["status"] != "success": + return { + "model": model, "variant": variant, "condition": condition, + "index": case["index"], "problem_type": problem_type, + "solve_status": "failed", + "solve_error": solve_resp["error"], + "solve_seconds": solve_dt, + "grade": None, + } + parsed = parse_solution(solve_resp["content"]) + if not parsed["solution"]: + return { + "model": model, "variant": variant, "condition": condition, + "index": case["index"], "problem_type": problem_type, + "solve_status": "parse_failed", + "solve_error": parsed.get("_parse_error"), + "solve_seconds": solve_dt, + "raw_solve_content": solve_resp["content"][:500], + "grade": None, + } + student_solution = parsed["solution"] + t1 = time.time() + grade_resp = await grade(problem_type, var_question, student_solution, canon_sol) + grade_dt = time.time() - t1 + if grade_resp["status"] != "success": + return { + "model": model, "variant": variant, "condition": condition, + "index": case["index"], "problem_type": problem_type, + "solve_status": "success", + "solve_seconds": solve_dt, + "grade_seconds": grade_dt, + "grade_status": "failed", + "grade_error": grade_resp["error"], + "student_solution_len": len(student_solution), + "student_final_answer": parsed["final_answer"], + "grade": None, + } + parsed_grade = parse_grade(grade_resp["content"]) + return { + "model": model, "variant": variant, "condition": condition, + 
"index": case["index"], "problem_type": problem_type, + "solve_status": "success", + "solve_seconds": solve_dt, + "grade_seconds": grade_dt, + "grade_status": "success", + "student_solution_len": len(student_solution), + "student_solution": student_solution, # full text for downstream analysis + "student_final_answer": parsed["final_answer"][:500], + "grade": parsed_grade["grade"], + "final_answer_correct": parsed_grade.get("final_answer_correct"), + "grade_feedback": (parsed_grade.get("detailed_feedback") or "")[:1000], + } + + +# ---------- Main run ---------- + +OUT_DIR = Path("/home/yurenh2/gap/analysis/rescue_results") +OUT_DIR.mkdir(parents=True, exist_ok=True) + + +def load_existing_keys(path: Path) -> set: + """Read jsonl checkpoint and return set of (cell_key, condition, index).""" + keys = set() + if not path.exists(): + return keys + with open(path) as f: + for line in f: + try: + d = json.loads(line) + keys.add((d["model"], d["variant"], d["condition"], d["index"])) + except Exception: + pass + return keys + + +async def run_all(num_cases_per_cell: int, dry_run: bool = False, models=None, + variants=None): + print(f"Loading dataset ...", flush=True) + ds = load_dataset_full() + print(f" loaded {len(ds)} problems", flush=True) + + out_path = OUT_DIR / f"rescue_{num_cases_per_cell}.jsonl" + existing = load_existing_keys(out_path) + print(f"Output: {out_path} (existing rows: {len(existing)})") + + models = models or SELECTED_MODELS + variants = variants or ALL_VARIANTS + + # Build the full task list + tasks_to_run = [] + cell_summary = {} + for model in models: + for variant in variants: + cases = find_flip_cases(model, variant, num_cases_per_cell) + cell_key = f"{model}/{variant}" + cell_summary[cell_key] = {"flip_cases_found": len(cases), + "added_tasks": 0} + for case in cases: + ds_cell = ds.get(case["index"]) + if ds_cell is None: + continue + prompts = build_case_prompts(case, variant, ds_cell) + for cond, user_msg in prompts.items(): + key = (model, 
variant, cond, case["index"]) + if key in existing: + continue + tasks_to_run.append((model, variant, cond, case, ds_cell, user_msg)) + cell_summary[cell_key]["added_tasks"] += 1 + + print(f"\nCell-level plan ({num_cases_per_cell} flip cases each):") + for k, v in sorted(cell_summary.items()): + print(f" {k:<46} found={v['flip_cases_found']:>3} new_tasks={v['added_tasks']:>4}") + total = len(tasks_to_run) + print(f"\nTotal new tasks: {total}") + if dry_run: + return + + if not tasks_to_run: + print("Nothing to do.") + return + + # Execute concurrently. Use a writer task to drain results into the jsonl. + fout = open(out_path, "a") + write_lock = asyncio.Lock() + completed = 0 + failed = 0 + started_at = time.time() + + async def run_and_write(model, variant, cond, case, ds_cell, user_msg): + nonlocal completed, failed + try: + res = await run_one_condition(model, cond, user_msg, case, variant, ds_cell) + except Exception as e: + res = { + "model": model, "variant": variant, "condition": cond, + "index": case["index"], "problem_type": case.get("problem_type"), + "solve_status": "exception", + "solve_error": f"{type(e).__name__}: {str(e)[:300]}", + "grade": None, + } + failed += 1 + async with write_lock: + fout.write(json.dumps(res) + "\n") + fout.flush() + completed += 1 + if completed % 25 == 0 or completed == total: + elapsed = time.time() - started_at + rate = completed / elapsed if elapsed > 0 else 0 + eta = (total - completed) / rate if rate > 0 else 0 + print(f" [{completed:>4}/{total}] elapsed={elapsed:>5.0f}s " + f"rate={rate:>4.1f}/s eta={eta:>5.0f}s " + f"failed_so_far={failed}", flush=True) + + awaitables = [run_and_write(*t) for t in tasks_to_run] + await asyncio.gather(*awaitables) + fout.close() + print(f"\nDone. {completed}/{total} written. 
Failed: {failed}.") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--pilot", action="store_true", help="run only 5 cases per cell") + ap.add_argument("--cases", type=int, default=30, help="cases per cell (full run)") + ap.add_argument("--dry-run", action="store_true", help="print plan, don't call APIs") + ap.add_argument("--models", nargs="+", default=None) + ap.add_argument("--variants", nargs="+", default=None) + args = ap.parse_args() + n = 5 if args.pilot else args.cases + asyncio.run(run_all(n, dry_run=args.dry_run, + models=args.models, variants=args.variants)) + + +if __name__ == "__main__": + main() |
