summaryrefslogtreecommitdiff
path: root/scripts/retrieve_synthesis.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/retrieve_synthesis.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/retrieve_synthesis.py')
-rw-r--r--scripts/retrieve_synthesis.py118
1 files changed, 118 insertions, 0 deletions
diff --git a/scripts/retrieve_synthesis.py b/scripts/retrieve_synthesis.py
new file mode 100644
index 0000000..cbc4573
--- /dev/null
+++ b/scripts/retrieve_synthesis.py
@@ -0,0 +1,118 @@
+import json
+import os
+from openai import OpenAI
+from typing import Dict, Any
+
# --- Configuration ---
# JSON file holding the batch IDs that retrieve_synthesis() iterates over.
BATCH_IDS_FILE = "data/raw_datasets/submitted_synthesis_batch_ids.json"
# JSONL file of seed records (one JSON object per line), read by load_seeds().
SEED_FILE = "data/raw_datasets/positive_seeds.jsonl"
# Where to save the new synthesized records
OUTPUT_FILE = "data/raw_datasets/synthesized_positives.jsonl"
+
def load_seeds(seed_file: "str | None" = None) -> Dict[str, Dict[str, Any]]:
    """Load the seed records from a JSONL file, keyed by their custom ID.

    Every non-blank line is parsed as one JSON object. A record is stored
    under its ``custom_id``; a record with a missing or empty ``custom_id``
    is stored under ``"seed_{idx}"``, where ``idx`` is the zero-based line
    number in the file (blank lines are counted).

    NOTE(review): the ``seed_{idx}`` fallback is meant to reproduce the IDs
    generated by the submission script -- confirm both sides number lines
    the same way.

    Args:
        seed_file: Path of the JSONL file to read. Defaults to the
            module-level ``SEED_FILE`` when omitted.

    Returns:
        A dict mapping each custom ID to its full seed record. If two
        records share an ID, the later one wins.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    print("Loading seeds map...")
    path = SEED_FILE if seed_file is None else seed_file
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(path, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if not line.strip():
                continue
            item = json.loads(line)
            # Prefer the record's own ID; fall back to the positional one.
            cid = item.get("custom_id") or f"seed_{idx}"
            mapping[cid] = item
    return mapping
+
def _extract_rewrites(res: Dict[str, Any]) -> list:
    """Return the usable rewrite strings carried by one batch result entry.

    The model's message content is expected to be a JSON object with a
    "rewrites" list. Entries that are not non-blank strings are dropped.

    Raises:
        KeyError, IndexError, TypeError, AttributeError, ValueError: If the
            entry does not have the expected shape or the content is not
            valid JSON.
    """
    body = res["response"]["body"]
    llm_content = body["choices"][0]["message"]["content"]
    parsed_json = json.loads(llm_content)
    rewrites = parsed_json.get("rewrites", [])
    if not isinstance(rewrites, list):
        # A bare string would otherwise be iterated character by character.
        raise ValueError(f"'rewrites' is {type(rewrites).__name__}, expected list")
    return [rw for rw in rewrites if isinstance(rw, str) and rw.strip()]


def retrieve_synthesis():
    """Download finished synthesis batches and write the rewritten queries.

    Reads the batch IDs from ``BATCH_IDS_FILE``, fetches each batch's output
    file through the OpenAI client, and for every successful result writes
    one JSONL record per rewrite to ``OUTPUT_FILE``. Each record inherits
    the ``extracted_json`` of the seed whose ID is the result's ``custom_id``
    with the ``"syn_"`` prefix removed.

    The function is best-effort: a bad result line or a failing batch is
    reported and skipped rather than aborting the run. ``OUTPUT_FILE`` is
    opened in write mode, so any previous content is replaced.

    Returns:
        None. Prints an error and returns early when ``OPENAI_API_KEY`` is
        not set or an input file is missing.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    # Check both inputs up front so a missing seed file is reported the same
    # way as a missing batch-ID file instead of raising from load_seeds().
    for required in (BATCH_IDS_FILE, SEED_FILE):
        if not os.path.exists(required):
            print(f"Error: {required} not found.")
            return

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    seed_map = load_seeds()
    count_rewrites = 0
    count_source_seeds = 0
    count_failed = 0
    count_missing_seed = 0

    print(f"Processing Synthesis Batches -> {OUTPUT_FILE}...")

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    print(f"  No output file (status: {batch.status}); skipping.")
                    continue

                print(f"  Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text

                for line in content.splitlines():
                    if not line.strip():
                        continue

                    # Parse each result line on its own so that one bad line
                    # cannot abort the remainder of the batch.
                    syn_id = "<unknown>"
                    try:
                        res = json.loads(line)
                        syn_id = res["custom_id"]  # e.g. "syn_req_123"
                        response = res.get("response") or {}
                        if response.get("status_code") != 200:
                            count_failed += 1
                            continue
                        rewrites = _extract_rewrites(res)
                    except (KeyError, IndexError, TypeError,
                            AttributeError, ValueError) as e:
                        print(f"Parse error {syn_id}: {e}")
                        count_failed += 1
                        continue

                    if not rewrites:
                        continue

                    # Derive original seed ID: remove "syn_" prefix
                    orig_id = syn_id[4:] if syn_id.startswith("syn_") else syn_id

                    # Find original preference to inherit
                    seed = seed_map.get(orig_id)
                    if not seed:
                        count_missing_seed += 1
                        continue

                    prefs = seed.get("extracted_json")
                    for rw in rewrites:
                        new_record = {
                            "original_query": rw,
                            "source": "synthesis_gpt4o",
                            "parent_id": orig_id,
                            "extracted_json": prefs,  # INHERIT PREFERENCE
                            "has_preference": True,
                        }
                        f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
                        count_rewrites += 1
                    count_source_seeds += 1
            except Exception as e:
                # Top-level boundary: API or I/O failures for one batch must
                # not stop the remaining batches from being processed.
                print(f"Error checking batch {b_id}: {e}")

    print("\n" + "="*50)
    print("SYNTHESIS RETRIEVAL COMPLETE")
    print(f"Processed Source Seeds: {count_source_seeds}")
    print(f"Generated New Samples: {count_rewrites}")
    print(f"Failed/Unparsable Results: {count_failed}")
    print(f"Results With Unknown Seed: {count_missing_seed}")
    print(f"Saved to: {OUTPUT_FILE}")
    print("="*50)
+
# Run the retrieval only when executed as a script, not when imported.
if __name__ == "__main__":
    retrieve_synthesis()
+