diff options
Diffstat (limited to 'kv_math_redo.py')
| -rw-r--r-- | kv_math_redo.py | 275 |
1 files changed, 275 insertions, 0 deletions
#!/usr/bin/env python3
"""
KV redo: Re-run slot discovery with o3 on no_slots problems, then evaluate all accepted.

Pipeline per problem (see ``process_one``):
  1. o3 lists the "mutable slots" of the original problem.
  2. o3 synthesises a variant problem + solution with new slot values.
  3. Three o3 judges verify the variant; on disagreement up to three repair
     rounds are attempted.
Accepted variants from this run are merged with the previously accepted ones
and both the original and the variant problems are solved by gpt-4o and
gpt-4o-mini, graded by gpt-4o (see ``evaluate_all``).
"""
import asyncio
import json
import os
import random
import re
from collections import Counter

# The OpenAI client is created on first use (see _get_client) so that importing
# this module does not require the `openai` package or an API key.
client = None
SEM_O3 = asyncio.Semaphore(3)      # concurrent o3 requests
SEM_EVAL = asyncio.Semaphore(40)   # concurrent requests for every other model
# Seeded at import time, exactly as before, so the shuffle in main() keeps
# selecting the same problems.
random.seed(42)

OUTPUT_DIR = '/home/yurenh2/gap/mini_gap_math_results/kv_200'
REDO_FILE = os.path.join(OUTPUT_DIR, 'kv_redo.json')
LOCK = asyncio.Lock()

API_MAX_ATTEMPTS = 5     # total tries per request before giving up
N_JUDGES = 3             # all of them must accept
MAX_REPAIR_ROUNDS = 3    # repair attempts after the first judging round
BATCH_SIZE = 8           # problems processed concurrently in main()
# NOTE(review): for reasoning models the completion budget also has to cover
# hidden reasoning tokens, so a budget this small may produce empty judge
# replies (counted as "not accepted"). Value kept unchanged -- confirm.
JUDGE_MAX_TOKENS = 500

# Prompts
SLOT_DISCOVERY_O3 = """You are a world-class mathematician. Given a math problem and its reference solution, find ALL numerical constants, coefficients, parameters, or specific values that could be changed to create a structurally equivalent but numerically different problem.

Be AGGRESSIVE in finding slots. Even if a value seems "natural" (like 2, 3, etc.), if changing it to another value would still yield a solvable problem with the same solution technique, list it.

Examples of mutable slots:
- Coefficients in equations (2x+3 → 5x+7)
- Exponents (x^3 → x^5)
- Bounds or limits (sum from 1 to 100 → sum from 1 to 200)
- Specific numbers in word problems
- Dimensions or sizes
- Modular bases

Return ONLY valid JSON:
{"mutable_slots": [{"value": "...", "role": "...", "constraints": "..."}, ...]}
If truly no slots exist (every constant is mathematically forced), return: {"mutable_slots": []}"""

BACK_SYNTHESIS = """You are creating a mathematical variant. Given the original problem, solution, and mutable slots:
- Choose NEW values satisfying constraints
- Rewrite the full problem with new values
- Solve it completely step by step
- The solution MUST use the same mathematical technique

Return ONLY valid JSON:
{"new_problem": "...", "new_solution": "...", "new_answer": "...", "slot_changes": [{"original": "...", "new": "..."}]}"""

VERIFY = """You are a rigorous mathematical verifier. Check:
1. Is the problem well-defined?
2. Is every solution step correct?
3. Does it reach the stated answer?

Reply EXACTLY: VERDICT: ACCEPT or VERDICT: REJECT
REASON: [explanation]"""

REPAIR = """Fix this rejected variant.
Problem: {problem}
Solution: {solution}
Reason: {reason}
Return ONLY JSON: {{"new_problem": "...", "new_solution": "...", "new_answer": "..."}}"""

_FENCED_JSON_RE = re.compile(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```')
_VERDICT_RE = re.compile(r'VERDICT\W*(ACCEPT|REJECT)', re.I)
_REASON_RE = re.compile(r'REASON:\s*(.*)', re.I)
_JSON_DECODER = json.JSONDecoder()


def _get_client():
    """Return the shared AsyncOpenAI client, creating it on first use."""
    global client
    if client is None:
        from openai import AsyncOpenAI
        client = AsyncOpenAI()
    return client


def _loads_dict(candidate):
    """Parse *candidate* as JSON; return it only if it is a JSON object."""
    try:
        obj = json.loads(candidate)
    except (TypeError, ValueError):
        return None
    return obj if isinstance(obj, dict) else None


def extract_json(text):
    """Extract a JSON object (dict) from a model reply.

    Tried in order: the whole text, the first fenced ``` block, then the first
    non-empty object that parses starting at any ``{`` in the text. The last
    step tolerates prose or LaTeX braces after the object.

    Returns the dict, or None when no JSON object can be found. Non-object
    JSON values (lists, numbers, ...) are never returned because every caller
    uses ``.get`` on the result.
    """
    if not text:
        return None
    obj = _loads_dict(text)
    if obj is not None:
        return obj
    fenced = _FENCED_JSON_RE.search(text)
    if fenced:
        obj = _loads_dict(fenced.group(1))
        if obj is not None:
            return obj
    pos = text.find('{')
    while pos != -1:
        try:
            obj, _ = _JSON_DECODER.raw_decode(text, pos)
        except ValueError:
            obj = None
        # Skip empty objects such as the "{}" of "\boxed{}" in surrounding prose.
        if isinstance(obj, dict) and obj:
            return obj
        pos = text.find('{', pos + 1)
    return None


def is_accept_verdict(reply):
    """Return True when a judge reply accepts the variant.

    The explicit ``VERDICT: ...`` marker decides when present, so words like
    "reject" inside the REASON text cannot flip an ACCEPT. Without a marker
    the reply is accepted only if it mentions ACCEPT and never REJECT.
    Empty or missing replies count as not accepted.
    """
    if not reply:
        return False
    marker = _VERDICT_RE.search(reply)
    if marker:
        return marker.group(1).upper() == 'ACCEPT'
    upper = reply.upper()
    return 'ACCEPT' in upper and 'REJECT' not in upper


async def api_call(messages, model="o3", max_tokens=4000):
    """Send one chat-completion request with retries.

    o3 requests use SEM_O3 and ``max_completion_tokens``; every other model
    uses SEM_EVAL, ``max_tokens`` and ``temperature=0``.

    Returns the reply text, or None after API_MAX_ATTEMPTS failed attempts
    (the last error is printed). The semaphore slot is held across the
    back-off sleeps, as before, so a failing request keeps throttling.
    """
    api = _get_client()  # outside the retry loop: setup errors should surface at once
    sem = SEM_O3 if model == "o3" else SEM_EVAL
    kw = {"model": model, "messages": messages}
    if model == "o3":
        kw["max_completion_tokens"] = max_tokens
    else:
        kw["max_tokens"] = max_tokens
        kw["temperature"] = 0
    async with sem:
        for attempt in range(API_MAX_ATTEMPTS):
            try:
                r = await api.chat.completions.create(**kw)
                return r.choices[0].message.content
            except Exception as e:  # retry boundary: any API/transport error
                if attempt == API_MAX_ATTEMPTS - 1:
                    print(f"[api_call] {model} failed after {API_MAX_ATTEMPTS} attempts: {e!r}")
                    return None
                await asyncio.sleep(min(60, (2 ** attempt) * 3))
    return None


async def save(result, results_list):
    """Append *result* to *results_list* and persist the list to REDO_FILE.

    The file is written to a temporary sibling and then renamed over the
    target so an interrupted run never leaves a truncated resume file.
    """
    async with LOCK:
        results_list.append(result)
        tmp_path = REDO_FILE + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results_list, f)
        os.replace(tmp_path, REDO_FILE)


async def process_one(problem, solution, idx, results_list):
    """Run slot discovery, synthesis and verification for one problem.

    Exactly one record is saved per call, with status 'no_slots', 'error',
    'accepted' or 'rejected' and the given *idx*.
    """
    # Stage 1: o3 slot discovery
    slot_text = await api_call(
        [{"role": "system", "content": SLOT_DISCOVERY_O3},
         {"role": "user", "content": f"Problem:\n{problem}\n\nSolution:\n{solution}"}],
        model="o3", max_tokens=2000)
    if slot_text is None:
        # The request itself failed; do not report that as "no slots found".
        await save({'status': 'error', 'idx': idx, 'reason': 'slot discovery failed'}, results_list)
        print(f"[{idx}] slot_error")
        return
    slots = extract_json(slot_text)
    mutable_slots = slots.get('mutable_slots') if slots else None
    if not mutable_slots:
        await save({'status': 'no_slots', 'idx': idx}, results_list)
        print(f"[{idx}] no_slots (o3)")
        return

    n = len(mutable_slots)

    # Stage 2: o3 back-synthesis
    synth_text = await api_call(
        [{"role": "system", "content": BACK_SYNTHESIS},
         {"role": "user", "content": f"Original:\n{problem}\n\nSolution:\n{solution}\n\nSlots:\n{json.dumps(mutable_slots)}\n\nCreate variant."}],
        model="o3", max_tokens=6000)
    synth = extract_json(synth_text)
    # A variant without a solution cannot be judged, so treat it as a failure
    # instead of raising KeyError and aborting the whole batch.
    if not synth or not synth.get('new_problem') or not synth.get('new_solution'):
        await save({'status': 'error', 'idx': idx, 'reason': 'synthesis failed'}, results_list)
        print(f"[{idx}] synth_error")
        return

    new_p, new_s, new_a = synth['new_problem'], synth['new_solution'], synth.get('new_answer', '')

    # Stage 3: 3 judges + 3 repair rounds
    for rr in range(MAX_REPAIR_ROUNDS + 1):
        judges = await asyncio.gather(*[api_call(
            [{"role": "system", "content": VERIFY},
             {"role": "user", "content": f"Problem:\n{new_p}\n\nSolution:\n{new_s}"}],
            model="o3", max_tokens=JUDGE_MAX_TOKENS) for _ in range(N_JUDGES)])
        verdicts = [is_accept_verdict(j) for j in judges]
        if all(verdicts):
            await save({
                'status': 'accepted', 'idx': idx,
                'original_problem': problem, 'original_solution': solution,
                'kv_problem': new_p, 'kv_solution': new_s, 'kv_answer': new_a,
                'mutable_slots': mutable_slots,
                'slot_changes': synth.get('slot_changes', []),
                'repair_rounds': rr, 'n_slots': n,
            }, results_list)
            print(f"[{idx}] ACCEPTED (round {rr}, {n} slots)")
            return
        if rr < MAX_REPAIR_ROUNDS:
            # Only the judges that did not accept have anything to repair.
            reasons = [_REASON_RE.search(j or '') for j, ok in zip(judges, verdicts) if not ok]
            reason_str = '; '.join(m.group(1)[:200] for m in reasons if m)[:500]
            fix = await api_call(
                [{"role": "system", "content": REPAIR.format(problem=new_p, solution=new_s, reason=reason_str)},
                 {"role": "user", "content": "Fix."}],
                model="o3", max_tokens=6000)
            fd = extract_json(fix)
            if fd:
                # Keep the previous text when the repair omits or blanks a field.
                new_p = fd.get('new_problem') or new_p
                new_s = fd.get('new_solution') or new_s
                new_a = fd.get('new_answer') or new_a

    await save({'status': 'rejected', 'idx': idx}, results_list)
    print(f"[{idx}] REJECTED")


def extract_boxed(text):
    """Return the content of the last balanced ``\\boxed{...}`` in *text*.

    Nested braces inside the box are kept. Returns None for empty input or
    when no box with balanced braces exists.
    """
    if not text:
        return None
    marker = '\\boxed{'
    matches = []
    pos = 0
    while True:
        start = text.find(marker, pos)
        if start == -1:
            break
        depth = 1
        j = start + len(marker)
        while j < len(text) and depth > 0:
            ch = text[j]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
            j += 1
        if depth == 0:
            matches.append(text[start + len(marker):j - 1].strip())
        pos = j
    return matches[-1] if matches else None


async def evaluate_all(all_accepted):
    """Solve original and variant problems with each model and grade them.

    Each record needs 'original_problem', 'original_solution', 'kv_problem'
    and either 'kv_answer' or a boxed answer in 'kv_solution'.

    Returns ``{model: {'orig', 'kv', 'delta', 'n', 'orig_c', 'kv_c'}}`` with
    accuracies in percent, or ``{}`` for empty input. A request that still
    fails after the retries in ``api_call`` is graded as incorrect instead of
    aborting the whole evaluation; the number of such failures is printed.
    """
    if not all_accepted:
        return {}

    async def solve(problem, model):
        return await api_call(
            [{"role": "system", "content": "Solve step by step. Final answer in \\boxed{}."},
             {"role": "user", "content": problem}],
            model=model, max_tokens=2048)

    async def grade(ref, stu):
        # Nothing to compare when either side has no extractable answer.
        if not (ref and stu):
            return False
        reply = await api_call(
            [{"role": "user", "content": f"Are these equivalent? Ref: {ref}\nStudent: {stu}\nCORRECT or INCORRECT."}],
            model="gpt-4o", max_tokens=10)
        t = (reply or '').upper()
        return 'INCORRECT' not in t and 'CORRECT' in t

    total = len(all_accepted)
    results = {}
    for model in ['gpt-4o', 'gpt-4o-mini']:
        print(f"\nEval {total} with {model}...")
        orig_sols = await asyncio.gather(*[solve(a['original_problem'], model) for a in all_accepted])
        kv_sols = await asyncio.gather(*[solve(a['kv_problem'], model) for a in all_accepted])
        failed = sum(s is None for s in orig_sols) + sum(s is None for s in kv_sols)
        if failed:
            print(f"  WARNING: {failed} solve requests failed and are graded as incorrect")

        grading = []
        for a, orig_sol in zip(all_accepted, orig_sols):
            grading.append(grade(extract_boxed(a['original_solution']), extract_boxed(orig_sol)))
        for a, kv_sol in zip(all_accepted, kv_sols):
            ref_kv = a.get('kv_answer') or extract_boxed(a.get('kv_solution', ''))
            grading.append(grade(ref_kv, extract_boxed(kv_sol)))
        graded = await asyncio.gather(*grading)
        og, kg = graded[:total], graded[total:]

        oa = sum(og) / total * 100
        ka = sum(kg) / total * 100
        results[model] = {'orig': oa, 'kv': ka, 'delta': ka - oa, 'n': total,
                          'orig_c': sum(og), 'kv_c': sum(kg)}
        print(f"  {model}: orig={oa:.1f}% kv={ka:.1f}% Δ={ka-oa:+.1f}pp (n={total})")
    return results


async def main():
    """Redo the no_slots problems with o3, then evaluate every accepted variant."""
    # Imported here so the helpers above can be imported without `datasets`.
    from datasets import load_dataset

    # Load all Level 5 problems (same seed as kv_math_200.py)
    subsets = ['algebra', 'number_theory', 'precalculus', 'intermediate_algebra', 'counting_and_probability', 'geometry']
    all_l5 = []
    for s in subsets:
        ds = load_dataset('EleutherAI/hendrycks_math', s, split='test')
        for item in ds:
            if item.get('level') == 'Level 5' and len(item.get('solution', '')) > 50:
                item['subject'] = s
                all_l5.append(item)
    random.shuffle(all_l5)
    selected = all_l5[:200]
    print(f"Total pool: {len(selected)} Level 5 problems")

    # Load previous kv_200 results to find no_slots indices
    with open(os.path.join(OUTPUT_DIR, 'kv_generation.json'), encoding='utf-8') as f:
        prev = json.load(f)
    no_slots_indices = [r['original_index'] for r in prev if r['status'] == 'no_slots']
    prev_accepted = [r for r in prev if r['status'] == 'accepted']
    print(f"Previous: {len(prev_accepted)} accepted, {len(no_slots_indices)} no_slots to redo with o3")

    # Also load kv_50 accepted
    kv50_file = '/home/yurenh2/gap/mini_gap_math_results/kv_50/kv_final_results.json'
    kv50_accepted = []
    if os.path.exists(kv50_file):
        with open(kv50_file, encoding='utf-8') as f:
            kv50 = json.load(f)
        kv50_accepted = kv50.get('accepted_variants', [])
        print(f"kv_50 accepted: {len(kv50_accepted)}")

    # Resume redo progress
    redo_results = []
    done_indices = set()
    if os.path.exists(REDO_FILE):
        with open(REDO_FILE, encoding='utf-8') as f:
            redo_results = json.load(f)
        done_indices = {r['idx'] for r in redo_results}
        print(f"Resuming redo: {len(redo_results)} done")

    remaining = [i for i in no_slots_indices if i not in done_indices]
    print(f"Remaining to redo: {len(remaining)}")

    # Process in batches of 8
    for batch_start in range(0, len(remaining), BATCH_SIZE):
        batch = remaining[batch_start:batch_start + BATCH_SIZE]
        tasks = [process_one(selected[i]['problem'], selected[i]['solution'], i, redo_results) for i in batch]
        await asyncio.gather(*tasks)
        st = Counter(r['status'] for r in redo_results)
        print(f"--- Redo progress: {len(redo_results)}/{len(no_slots_indices)}, {dict(st)} ---")

    # Combine all accepted
    redo_accepted = [r for r in redo_results if r['status'] == 'accepted']
    all_accepted = kv50_accepted + prev_accepted + redo_accepted
    print(f"\nTotal accepted: {len(all_accepted)} (kv50={len(kv50_accepted)}, kv200={len(prev_accepted)}, redo={len(redo_accepted)})")

    # Evaluate
    if all_accepted:
        eval_results = await evaluate_all(all_accepted)
        final = {
            'total_accepted': len(all_accepted),
            'sources': {'kv50': len(kv50_accepted), 'kv200': len(prev_accepted), 'redo': len(redo_accepted)},
            'evaluation': eval_results,
        }
        with open(os.path.join(OUTPUT_DIR, 'kv_combined_final.json'), 'w', encoding='utf-8') as f:
            json.dump(final, f, indent=2)
        print(f"\n{'='*60}\nFINAL COMBINED RESULTS ({len(all_accepted)} KV variants)\n{'='*60}")
        for m, r in eval_results.items():
            print(f"  {m}: orig={r['orig']:.1f}% kv={r['kv']:.1f}% Δ={r['delta']:+.1f}pp (n={r['n']})")


if __name__ == "__main__":
    asyncio.run(main())
