From e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Wed, 17 Dec 2025 04:29:37 -0600
Subject: Initial commit (clean history)

---
 scripts/recover_and_merge.py | 162 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 162 insertions(+)
 create mode 100644 scripts/recover_and_merge.py

diff --git a/scripts/recover_and_merge.py b/scripts/recover_and_merge.py
new file mode 100644
index 0000000..b0f37f7
--- /dev/null
+++ b/scripts/recover_and_merge.py
@@ -0,0 +1,162 @@
+import json
+import os
+from openai import OpenAI
+from typing import Dict, Any
+
+# --- Configuration ---
+# 1. Main batch IDs (the 340k successful requests whose merged output we lost)
+MAIN_BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
+# 2. OASST1 batch IDs (new)
+OASST1_BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
+OASST1_METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
+
+# The file we APPEND to (it currently holds the 68k retry items)
+OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+
+# Original queries, used to reconstruct main-batch records by custom_id
+ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+
+def load_original_queries() -> Dict[str, Dict[str, Any]]:
+    print("Loading original queries map (Main)...")
+    mapping = {}
+    with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f:
+        for idx, line in enumerate(f):
+            if line.strip():
+                mapping[f"req_{idx}"] = json.loads(line)
+    return mapping
+
+def load_oasst1_metadata() -> Dict[str, Dict[str, Any]]:
+    print("Loading OASST1 metadata map...")
+    mapping = {}
+    if os.path.exists(OASST1_METADATA_FILE):
+        with open(OASST1_METADATA_FILE, "r", encoding="utf-8") as f:
+            for line in f:
+                if line.strip():
+                    item = json.loads(line)
+                    mapping[item["custom_id"]] = item
+    return mapping
+
+def recover_and_merge():
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        print("Error: OPENAI_API_KEY not set.")
+        return
+    client = OpenAI(api_key=api_key)
+
+    # Load lookup maps
+    main_query_map = load_original_queries()
+    oasst1_meta_map = load_oasst1_metadata()
+
+    # Append to the existing file, which holds the RETRY results,
+    # so we don't lose the 68k items we just fixed.
+    print(f"Appending recovered data to {OUTPUT_FILE}...")
+
+    count_main = 0
+    count_oasst1 = 0
+
+    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
+
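+        # NOTE (assumed Batch API output shape, not verified here): each line
+        # of an output file is expected to look roughly like
+        #   {"custom_id": "req_0",
+        #    "response": {"status_code": 200,
+        #                 "body": {"choices": [{"message": {"content": "<json string>"}}]}},
+        #    "error": null}
+        # so each row is parsed twice: the envelope, then the model's JSON answer.
+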
+        # --- 1. Recover Main Batches ---
+        if os.path.exists(MAIN_BATCH_IDS_FILE):
+            with open(MAIN_BATCH_IDS_FILE, "r") as f:
+                main_ids = json.load(f)
+
+            print(f"\nRecovering {len(main_ids)} Main Batches...")
+            for b_id in main_ids:
+                try:
+                    batch = client.batches.retrieve(b_id)
+                    if batch.output_file_id:
+                        print(f"  Downloading {b_id} (Output: {batch.output_file_id})...")
+                        content = client.files.content(batch.output_file_id).text
+
+                        for line in content.splitlines():
+                            if not line.strip():
+                                continue
+                            res = json.loads(line)
+                            custom_id = res["custom_id"]
+
+                            if res["response"]["status_code"] == 200:
+                                try:
+                                    body = res["response"]["body"]
+                                    llm_content = body["choices"][0]["message"]["content"]
+                                    parsed_json = json.loads(llm_content)
+
+                                    original = main_query_map.get(custom_id)
+                                    if original:
+                                        record = {
+                                            "custom_id": custom_id,
+                                            "original_query": original["query"],
+                                            "source": original.get("source"),
+                                            "extracted_json": parsed_json,
+                                            "has_preference": len(parsed_json.get("preferences", [])) > 0,
+                                        }
+                                        f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+                                        count_main += 1
+                                except (KeyError, IndexError, json.JSONDecodeError):
+                                    # Malformed body or non-JSON model output: skip this row.
+                                    pass
+                except Exception as e:
+                    print(f"  Error {b_id}: {e}")
+
+        # --- 2. Retrieve OASST1 Batches ---
+        # User requested to skip the OASST1 merge for now.
+        # if os.path.exists(OASST1_BATCH_IDS_FILE):
+        #     with open(OASST1_BATCH_IDS_FILE, "r") as f:
+        #         oasst_ids = json.load(f)
+
+        #     print(f"\nRetrieving {len(oasst_ids)} OASST1 Batches...")
+        #     for b_id in oasst_ids:
+        #         try:
+        #             batch = client.batches.retrieve(b_id)
+        #             if batch.status == "completed" and batch.output_file_id:
+        #                 print(f"  Downloading {b_id}...")
+        #                 content = client.files.content(batch.output_file_id).text
+
+        #                 for line in content.splitlines():
+        #                     if not line.strip():
+        #                         continue
+        #                     res = json.loads(line)
+        #                     custom_id = res["custom_id"]
+
+        #                     if res["response"]["status_code"] == 200:
+        #                         try:
+        #                             body = res["response"]["body"]
+        #                             llm_content = body["choices"][0]["message"]["content"]
+        #                             parsed_json = json.loads(llm_content)
+
+        #                             meta = oasst1_meta_map.get(custom_id)
+        #                             if meta:
+        #                                 record = {
+        #                                     "custom_id": custom_id,
+        #                                     "original_query": meta["original_query"],
+        #                                     "source": "oasst1",
+        #                                     "user_id": meta.get("user_id"),  # Preserve the user ID!
+        #                                     "session_id": meta.get("session_id"),
+        #                                     "extracted_json": parsed_json,
+        #                                     "has_preference": len(parsed_json.get("preferences", [])) > 0,
+        #                                 }
+        #                                 f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+        #                                 count_oasst1 += 1
+        #                         except (KeyError, IndexError, json.JSONDecodeError):
+        #                             pass
+        #         except Exception as e:
+        #             print(f"  Error {b_id}: {e}")
+
+    print("\n" + "=" * 50)
+    print("RECOVERY & MERGE COMPLETE")
+    print(f"Recovered Main: {count_main}")
+    print(f"New OASST1: {count_oasst1}")
+    print(f"Full dataset updated at: {OUTPUT_FILE}")
+    print("=" * 50)
+
+if __name__ == "__main__":
+    recover_and_merge()
+
--
cgit v1.2.3