"""Recover OpenAI Batch API results and merge them into the labeled dataset.

Downloads the output files of the main batches (the 340k successful requests
whose local results were lost) and appends the parsed records to OUTPUT_FILE,
which already holds the 68k retry results. The OASST1 merge is kept disabled
below, per request.
"""

import json
import os
from typing import Any, Dict

from openai import OpenAI

# --- Configuration ---
# 1. Main batch IDs (the 340k successful requests we lost)
MAIN_BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
# 2. OASST1 batch IDs (new)
OASST1_BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
OASST1_METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
# The file we APPEND to (it currently holds the 68k retry items)
OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
# Original queries, used to reconstruct the main batch records
ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"


def load_original_queries() -> Dict[str, Dict[str, Any]]:
    """Map each custom_id ("req_<line index>") back to its original query record."""
    print("Loading original queries map (Main)...")
    mapping = {}
    with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if line.strip():
                mapping[f"req_{idx}"] = json.loads(line)
    return mapping


def load_oasst1_metadata() -> Dict[str, Dict[str, Any]]:
    """Map each custom_id to its OASST1 metadata record, if the map file exists."""
    print("Loading OASST1 metadata map...")
    mapping = {}
    if os.path.exists(OASST1_METADATA_FILE):
        with open(OASST1_METADATA_FILE, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    item = json.loads(line)
                    mapping[item["custom_id"]] = item
    return mapping


def recover_and_merge():
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    # Load lookup maps
    main_query_map = load_original_queries()
    oasst1_meta_map = load_oasst1_metadata()

    # Append to the existing file, which holds the RETRY results,
    # so we don't lose the 68k we just fixed.
    print(f"Appending recovered data to {OUTPUT_FILE}...")
    count_main = 0
    count_oasst1 = 0

    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        # --- 1. Recover Main Batches ---
        if os.path.exists(MAIN_BATCH_IDS_FILE):
            with open(MAIN_BATCH_IDS_FILE, "r") as f:
                main_ids = json.load(f)
            print(f"\nRecovering {len(main_ids)} Main Batches...")
            for b_id in main_ids:
                try:
                    batch = client.batches.retrieve(b_id)
                    if batch.output_file_id:
                        print(f"  Downloading {b_id} (Output: {batch.output_file_id})...")
                        content = client.files.content(batch.output_file_id).text
                        for line in content.splitlines():
                            if not line.strip():
                                continue
                            res = json.loads(line)
                            custom_id = res["custom_id"]
                            if res["response"]["status_code"] == 200:
                                try:
                                    body = res["response"]["body"]
                                    llm_content = body["choices"][0]["message"]["content"]
                                    parsed_json = json.loads(llm_content)
                                    original = main_query_map.get(custom_id)
                                    if original:
                                        record = {
                                            "custom_id": custom_id,
                                            "original_query": original["query"],
                                            "source": original.get("source"),
                                            "extracted_json": parsed_json,
                                            "has_preference": len(parsed_json.get("preferences", [])) > 0,
                                        }
                                        f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                                        count_main += 1
                                except Exception:
                                    # Malformed response body or non-JSON model output; skip it.
                                    pass
                except Exception as e:
                    print(f"  Error {b_id}: {e}")
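        # For reference, each line of a downloaded batch output file is a JSON
        # object shaped roughly like the sketch below (illustrative only, based
        # on the fields accessed above; real records carry additional fields):
        #
        # {
        #   "custom_id": "req_0",
        #   "response": {
        #     "status_code": 200,
        #     "body": {"choices": [{"message": {"content": "<JSON string with a \"preferences\" list>"}}]}
        #   }
        # }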
        # --- 2. Retrieve OASST1 Batches ---
        # User requested to skip OASST1 merge for now.
        # if os.path.exists(OASST1_BATCH_IDS_FILE):
        #     with open(OASST1_BATCH_IDS_FILE, "r") as f:
        #         oasst_ids = json.load(f)
        #     print(f"\nRetrieving {len(oasst_ids)} OASST1 Batches...")
        #     for b_id in oasst_ids:
        #         try:
        #             batch = client.batches.retrieve(b_id)
        #             if batch.status == "completed" and batch.output_file_id:
        #                 print(f"  Downloading {b_id}...")
        #                 content = client.files.content(batch.output_file_id).text
        #                 for line in content.splitlines():
        #                     if not line.strip():
        #                         continue
        #                     res = json.loads(line)
        #                     custom_id = res["custom_id"]
        #                     if res["response"]["status_code"] == 200:
        #                         try:
        #                             body = res["response"]["body"]
        #                             llm_content = body["choices"][0]["message"]["content"]
        #                             parsed_json = json.loads(llm_content)
        #                             meta = oasst1_meta_map.get(custom_id)
        #                             if meta:
        #                                 record = {
        #                                     "custom_id": custom_id,
        #                                     "original_query": meta["original_query"],
        #                                     "source": "oasst1",
        #                                     "user_id": meta.get("user_id"),  # Preserve user ID!
        #                                     "session_id": meta.get("session_id"),
        #                                     "extracted_json": parsed_json,
        #                                     "has_preference": len(parsed_json.get("preferences", [])) > 0,
        #                                 }
        #                                 f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
        #                                 count_oasst1 += 1
        #                         except Exception:
        #                             # Malformed record; skip it.
        #                             pass
        #         except Exception as e:
        #             print(f"  Error {b_id}: {e}")

    print("\n" + "=" * 50)
    print("RECOVERY & MERGE COMPLETE")
    print(f"Recovered Main: {count_main}")
    print(f"New OASST1: {count_oasst1}")
    print(f"Full dataset updated at: {OUTPUT_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    recover_and_merge()
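
# Usage sketch (assumes OPENAI_API_KEY is exported; the filename below is an
# assumption, adjust it to wherever this script actually lives):
#
#   export OPENAI_API_KEY=sk-...
#   python recover_and_merge.py
#
# Note: OUTPUT_FILE is opened in append mode and nothing deduplicates, so
# re-running the script writes already-recovered records a second time.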