"""Download finished synthesis batches and save the rewrites as new positive records.

Each batch result line carries a ``custom_id`` of the form ``syn_<seed id>``.
The rewrites returned by the model inherit the ``extracted_json`` preference of
the seed they were generated from and are written to ``OUTPUT_FILE`` as JSONL.
"""

import json
import os
from collections import Counter
from typing import Any, Dict, List

from openai import OpenAI

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_synthesis_batch_ids.json"
SEED_FILE = "data/raw_datasets/positive_seeds.jsonl"
# Where to save the new synthesized records
OUTPUT_FILE = "data/raw_datasets/synthesized_positives.jsonl"

# Prefix that turns a seed ID into a synthesis custom_id ("syn_<seed id>").
_SYN_PREFIX = "syn_"


def load_seeds() -> Dict[str, Dict[str, Any]]:
    """Build a lookup from seed ID to the full seed record stored in SEED_FILE.

    The key is the record's ``custom_id``; records without one fall back to
    ``seed_<line index>``. Blank lines are skipped but still advance the line
    index. If two records share an ID, the later one wins.

    NOTE(review): the ``seed_<line index>`` fallback assumes the submit script
    numbered seeds by the same raw line index -- confirm against that script.

    Returns:
        Mapping of seed ID to the parsed JSON record.

    Raises:
        OSError: if SEED_FILE cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    print("Loading seeds map...")
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(SEED_FILE, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if not line.strip():
                continue
            item = json.loads(line)
            cid = item.get("custom_id") or f"seed_{idx}"
            mapping[cid] = item
    return mapping


def _seed_id_from_custom_id(syn_id: str) -> str:
    """Strip the synthesis prefix from a custom_id to recover the seed ID."""
    if syn_id.startswith(_SYN_PREFIX):
        return syn_id[len(_SYN_PREFIX):]
    return syn_id


def _extract_rewrites(response: Dict[str, Any]) -> List[Any]:
    """Return the ``rewrites`` list from a chat-completion batch response.

    Raises KeyError / IndexError / TypeError / AttributeError / ValueError when
    the response body or the JSON inside the message content is not shaped as
    expected; the caller reports these as parse errors.
    """
    content = response["body"]["choices"][0]["message"]["content"]
    return json.loads(content).get("rewrites", [])


def retrieve_synthesis():
    """Fetch every submitted synthesis batch and write its rewrites to OUTPUT_FILE.

    Reads the batch IDs from BATCH_IDS_FILE, downloads each batch's output file
    through the OpenAI client, and emits one JSONL record per rewrite carrying
    the parent seed's ``extracted_json``. Problems with a single batch or a
    single result line are reported and skipped so the remaining data is still
    processed. OUTPUT_FILE is overwritten on every run.

    Returns early, after printing an error, when OPENAI_API_KEY is unset or an
    input file is missing. Errors while writing OUTPUT_FILE propagate.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    # Check both inputs up front so nothing is truncated or downloaded in vain.
    for required in (BATCH_IDS_FILE, SEED_FILE):
        if not os.path.exists(required):
            print(f"Error: {required} not found.")
            return

    client = OpenAI(api_key=api_key)

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    seed_map = load_seeds()

    count_rewrites = 0
    count_source_seeds = 0
    skipped: Counter = Counter()

    print(f"Processing Synthesis Batches -> {OUTPUT_FILE}...")

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")

            # Only the API calls are guarded here, so a failure is reported
            # against the batch and never hides a parsing or writing problem.
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    status = getattr(batch, "status", "unknown")
                    print(f"  No output file available (status: {status}); skipping.")
                    continue
                print(f"  Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text
            except Exception as e:
                print(f"Error checking batch {b_id}: {e}")
                continue

            for line in content.splitlines():
                if not line.strip():
                    continue

                # Parse the envelope per line: one malformed entry must not
                # discard the rest of the batch.
                try:
                    res = json.loads(line)
                    syn_id = res["custom_id"]  # e.g. "syn_req_123"
                    response = res["response"]
                    status_code = response["status_code"]
                except (ValueError, KeyError, TypeError) as e:
                    print(f"  Skipping malformed result line: {e}")
                    skipped["malformed result line"] += 1
                    continue

                if status_code != 200:
                    skipped["non-200 response"] += 1
                    continue

                try:
                    rewrites = _extract_rewrites(response)
                except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
                    print(f"Parse error {syn_id}: {e}")
                    skipped["parse error"] += 1
                    continue

                if not rewrites:
                    continue

                # Find the original seed whose preference the rewrites inherit.
                orig_id = _seed_id_from_custom_id(syn_id)
                seed = seed_map.get(orig_id)
                if not seed:
                    skipped["seed not found"] += 1
                    continue

                prefs = seed.get("extracted_json")
                for rw in rewrites:
                    new_record = {
                        "original_query": rw,
                        "source": "synthesis_gpt4o",
                        "parent_id": orig_id,
                        "extracted_json": prefs,  # INHERIT PREFERENCE
                        "has_preference": True,
                    }
                    f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
                    count_rewrites += 1
                count_source_seeds += 1

    print("\n" + "="*50)
    print("SYNTHESIS RETRIEVAL COMPLETE")
    print(f"Processed Source Seeds: {count_source_seeds}")
    print(f"Generated New Samples:  {count_rewrites}")
    for reason, n in sorted(skipped.items()):
        print(f"Skipped ({reason}): {n}")
    print(f"Saved to: {OUTPUT_FILE}")
    print("="*50)


if __name__ == "__main__":
    retrieve_synthesis()