Diffstat (limited to 'scripts/recover_and_merge.py')
-rw-r--r-- scripts/recover_and_merge.py | 160
1 file changed, 160 insertions(+), 0 deletions(-)
diff --git a/scripts/recover_and_merge.py b/scripts/recover_and_merge.py
new file mode 100644
index 0000000..b0f37f7
--- /dev/null
+++ b/scripts/recover_and_merge.py
@@ -0,0 +1,160 @@
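+"""Recover completed OpenAI batch outputs and merge them into the labeled
+dataset, appending to the output file that already holds the retry results."""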
+import json
+import os
+from typing import Any, Dict
+
+from openai import OpenAI
+
+# --- Configuration ---
+# 1. Main Batch IDs (the 340k successful results we lost)
+MAIN_BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
+# 2. OASST1 Batch IDs (New)
+OASST1_BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
+OASST1_METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
+
+# The file we want to APPEND to (currently has 68k retry items)
+OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+
+# Original queries map for main batch reconstruction
+ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+
+def load_original_queries() -> Dict[str, Dict[str, Any]]:
+ print("Loading original queries map (Main)...")
+ mapping = {}
+ with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f:
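+        # Assumes the batch requests were submitted with custom_ids of the
+        # form "req_{idx}", assigned in file order.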
+ for idx, line in enumerate(f):
+ if line.strip():
+ mapping[f"req_{idx}"] = json.loads(line)
+ return mapping
+
+def load_oasst1_metadata() -> Dict[str, Dict[str, Any]]:
+ print("Loading OASST1 metadata map...")
+ mapping = {}
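+    # The metadata file may be absent (the OASST1 merge is optional); return an empty map then.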
+ if os.path.exists(OASST1_METADATA_FILE):
+ with open(OASST1_METADATA_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ item = json.loads(line)
+ mapping[item["custom_id"]] = item
+ return mapping
+
+def recover_and_merge():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+ client = OpenAI(api_key=api_key)
+
+ # Load Maps
+ main_query_map = load_original_queries()
+ oasst1_meta_map = load_oasst1_metadata()
+
+    # Append to the existing file, which already holds the 68k retry results,
+    # so nothing recovered earlier is lost.
+ print(f"Appending recovered data to {OUTPUT_FILE}...")
+
+ count_main = 0
+ count_oasst1 = 0
+
+ with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
+
+ # --- 1. Recover Main Batches ---
+        main_ids = []  # default so the loop below is a no-op if the IDs file is missing
+        if os.path.exists(MAIN_BATCH_IDS_FILE):
+            with open(MAIN_BATCH_IDS_FILE, "r") as f:
+                main_ids = json.load(f)
+
+        print(f"\nRecovering {len(main_ids)} Main Batches...")
+        for b_id in main_ids:
+ try:
+                batch = client.batches.retrieve(b_id)
+                if batch.status == "completed" and batch.output_file_id:
+ print(f" Downloading {b_id} (Output: {batch.output_file_id})...")
+ content = client.files.content(batch.output_file_id).text
+
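+                    # the output file is JSONL: one result object per request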
+ for line in content.splitlines():
+ if not line.strip(): continue
+ res = json.loads(line)
+ custom_id = res["custom_id"]
+
+                        if res.get("response") and res["response"]["status_code"] == 200:
+ try:
+                                body = res["response"]["body"]
+                                llm_content = body["choices"][0]["message"]["content"]
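+                                # the model is expected to return a JSON object; anything else lands in the except below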
+                                parsed_json = json.loads(llm_content)
+
+ original = main_query_map.get(custom_id)
+ if original:
+ record = {
+ "custom_id": custom_id,
+ "original_query": original["query"],
+ "source": original.get("source"),
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ count_main += 1
+                            except (KeyError, IndexError, json.JSONDecodeError):
+                                pass  # skip rows with missing fields or non-JSON output
+ except Exception as e:
+ print(f" Error {b_id}: {e}")
+
+ # --- 2. Retrieve OASST1 Batches ---
+ # User requested to skip OASST1 merge for now.
+ # if os.path.exists(OASST1_BATCH_IDS_FILE):
+ # with open(OASST1_BATCH_IDS_FILE, "r") as f:
+ # oasst_ids = json.load(f)
+
+ # print(f"\nRetrieving {len(oasst_ids)} OASST1 Batches...")
+ # for b_id in oasst_ids:
+ # try:
+ # batch = client.batches.retrieve(b_id)
+ # if batch.status == "completed" and batch.output_file_id:
+ # print(f" Downloading {b_id}...")
+ # content = client.files.content(batch.output_file_id).text
+
+ # for line in content.splitlines():
+ # if not line.strip(): continue
+ # res = json.loads(line)
+ # custom_id = res["custom_id"]
+
+ # if res["response"]["status_code"] == 200:
+ # try:
+ # body = res["response"]["body"]
+ # llm_content = body["choices"][0]["message"]["content"]
+ # parsed_json = json.loads(llm_content)
+
+ # meta = oasst1_meta_map.get(custom_id)
+ # if meta:
+ # record = {
+ # "custom_id": custom_id,
+ # "original_query": meta["original_query"],
+ # "source": "oasst1",
+ # "user_id": meta.get("user_id"), # Preserve User ID!
+ # "session_id": meta.get("session_id"),
+ # "extracted_json": parsed_json,
+ # "has_preference": len(parsed_json.get("preferences", [])) > 0
+ # }
+ # f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ # count_oasst1 += 1
+    #                     except (KeyError, IndexError, json.JSONDecodeError):
+    #                         pass  # skip malformed rows
+ # except Exception as e:
+ # print(f" Error {b_id}: {e}")
+
+ print("\n" + "="*50)
+ print("RECOVERY & MERGE COMPLETE")
+ print(f"Recovered Main: {count_main}")
+    print(f"New OASST1: {count_oasst1} (merge currently disabled)")
+ print(f"Full dataset updated at: {OUTPUT_FILE}")
+ print("="*50)
+
+if __name__ == "__main__":
+ recover_and_merge()
+