diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/retrieve_synthesis.py | |
Diffstat (limited to 'scripts/retrieve_synthesis.py')
| -rw-r--r-- | scripts/retrieve_synthesis.py | 118 |
1 files changed, 118 insertions, 0 deletions
# scripts/retrieve_synthesis.py
"""Retrieve finished synthesis batches and expand them into new records.

For every batch ID listed in ``BATCH_IDS_FILE`` the script downloads the
batch output, reads the ``rewrites`` list produced by the model for each
request, and writes one JSONL record per rewrite to ``OUTPUT_FILE``.  Each
record inherits the ``extracted_json`` of the seed it was rewritten from.
"""
import json
import os
from typing import Any, Dict, List, Optional

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_synthesis_batch_ids.json"
SEED_FILE = "data/raw_datasets/positive_seeds.jsonl"
# Where to save the new synthesized records
OUTPUT_FILE = "data/raw_datasets/synthesized_positives.jsonl"

# Prefix put in front of the seed ID when the synthesis request was submitted
# (custom_id == "syn_{original_id}").
_SYN_PREFIX = "syn_"


def load_seeds(seed_file: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
    """Build a ``custom_id -> seed record`` map from a JSONL seed file.

    The map is used to look up the ground-truth ``extracted_json`` of the
    seed a synthesized rewrite came from.

    Args:
        seed_file: Path of the JSONL file to read.  Defaults to ``SEED_FILE``
            (resolved at call time).

    Returns:
        Dict keyed by the record's ``custom_id``.  Records without a
        ``custom_id`` are keyed as ``"seed_{idx}"`` where ``idx`` is the
        zero-based physical line index (blank lines are counted).
        NOTE(review): this fallback must match the ID scheme used by the
        submit script -- confirm against submit_synthesis_batch.py.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    path = SEED_FILE if seed_file is None else seed_file
    print("Loading seeds map...")
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(path, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if not line.strip():
                continue
            item = json.loads(line)
            # Fall back to a positional ID for records that carry no
            # custom_id (e.g. produced by an older process).
            cid = item.get("custom_id") or f"seed_{idx}"
            mapping[cid] = item
    return mapping


def _original_seed_id(syn_id: str) -> str:
    """Return the seed ID for a synthesis custom_id (drops the "syn_" prefix)."""
    if syn_id.startswith(_SYN_PREFIX):
        return syn_id[len(_SYN_PREFIX):]
    return syn_id


def _records_from_line(
    line: str, seed_map: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Turn one batch-output line into the list of new records it yields.

    Every failure mode is reported and results in an empty list, so that a
    single bad line never aborts the rest of the batch.
    """
    try:
        res = json.loads(line)
    except ValueError as e:
        print(f"Parse error (invalid output line): {e}")
        return []

    syn_id = res.get("custom_id") if isinstance(res, dict) else None
    if not isinstance(syn_id, str):
        print("Parse error: output line has no custom_id")
        return []
    orig_id = _original_seed_id(syn_id)

    response = res.get("response")
    if not isinstance(response, dict) or response.get("status_code") != 200:
        status = response.get("status_code") if isinstance(response, dict) else None
        print(f"Warning: request {syn_id} did not succeed (status: {status})")
        return []

    try:
        llm_content = response["body"]["choices"][0]["message"]["content"]
        parsed_json = json.loads(llm_content)
    except (KeyError, IndexError, TypeError, ValueError) as e:
        print(f"Parse error {syn_id}: {e}")
        return []

    if not isinstance(parsed_json, dict):
        print(f"Parse error {syn_id}: model output is not a JSON object")
        return []

    rewrites = parsed_json.get("rewrites", [])
    if not rewrites:
        return []
    if not isinstance(rewrites, list):
        # Iterating e.g. a string would emit one record per character.
        print(f"Parse error {syn_id}: 'rewrites' is not a list")
        return []
    queries = [rw for rw in rewrites if isinstance(rw, str) and rw.strip()]

    # Find original preference to inherit
    seed = seed_map.get(orig_id)
    if not seed:
        print(f"Warning: Seed {orig_id} not found in map")
        return []
    prefs = seed.get("extracted_json")

    return [
        {
            "original_query": rw,
            "source": "synthesis_gpt4o",
            "parent_id": orig_id,
            "extracted_json": prefs,  # INHERIT PREFERENCE
            "has_preference": True,
        }
        for rw in queries
    ]


def retrieve_synthesis():
    """Download all submitted synthesis batches and write the new records.

    Reads ``OPENAI_API_KEY`` from the environment, the batch IDs from
    ``BATCH_IDS_FILE`` and the seeds from ``SEED_FILE``; overwrites
    ``OUTPUT_FILE``.  Prints an error and returns early when the API key or
    one of the input files is missing.  A batch whose retrieval fails is
    reported and skipped; the remaining batches are still processed.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    for required in (BATCH_IDS_FILE, SEED_FILE):
        if not os.path.exists(required):
            print(f"Error: {required} not found.")
            return

    # Imported here so that load_seeds() stays usable without the openai
    # package installed.
    from openai import OpenAI

    client = OpenAI(api_key=api_key)

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    seed_map = load_seeds()
    count_rewrites = 0
    count_source_seeds = 0
    count_skipped = 0

    print(f"Processing Synthesis Batches -> {OUTPUT_FILE}...")

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            # Only the remote calls are best-effort; local write errors
            # below are allowed to propagate.
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    status = getattr(batch, "status", "unknown")
                    print(f"  No output file available (status: {status}); skipping.")
                    continue
                print(f"  Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text
            except Exception as e:
                print(f"Error checking batch {b_id}: {e}")
                continue

            for line in content.splitlines():
                if not line.strip():
                    continue
                records = _records_from_line(line, seed_map)
                if not records:
                    count_skipped += 1
                    continue
                for record in records:
                    f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                count_rewrites += len(records)
                count_source_seeds += 1

    print("\n" + "=" * 50)
    print("SYNTHESIS RETRIEVAL COMPLETE")
    print(f"Processed Source Seeds: {count_source_seeds}")
    print(f"Generated New Samples: {count_rewrites}")
    print(f"Skipped Results: {count_skipped}")
    print(f"Saved to: {OUTPUT_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    retrieve_synthesis()
