1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
|
import json
import os
from typing import Any, Callable, Dict, Optional, TextIO, Tuple

from openai import OpenAI
# --- Configuration ---
# 1. Main batch IDs: JSON file holding the IDs of the main batches whose
#    results (the ~340k successful ones that were lost) are re-downloaded.
MAIN_BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
# 2. OASST1 batch IDs (new): JSON file holding the IDs of the OASST1 batches.
OASST1_BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
# JSONL metadata for the OASST1 requests; every line carries a "custom_id"
# field that is used as the lookup key.
OASST1_METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
# JSONL file that recovered records are APPENDED to (it currently holds the
# ~68k retry items, which must not be overwritten).
OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
# JSONL file of the original queries, used to rebuild main-batch records:
# the line at index N is looked up under the custom id "req_N".
ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
def load_original_queries(path: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
    """Build the ``custom_id -> original query`` map for the main batches.

    Every non-blank line of the JSONL file is parsed as JSON and stored under
    the key ``req_<idx>``, where ``idx`` is the zero-based index of the
    physical line. Blank lines are skipped but still advance the index, so a
    key always matches the line's position in the file.

    Args:
        path: JSONL file to read. Defaults to ``ORIGINAL_INPUT_FILE``.

    Returns:
        Mapping from ``req_<idx>`` to the parsed JSON value of that line.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if path is None:
        path = ORIGINAL_INPUT_FILE
    print("Loading original queries map (Main)...")
    with open(path, "r", encoding="utf-8") as f:
        return {
            f"req_{idx}": json.loads(line)
            for idx, line in enumerate(f)
            if line.strip()
        }
def load_oasst1_metadata(path: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
    """Build the ``custom_id -> metadata`` map for the OASST1 batches.

    Every non-blank line of the JSONL file is parsed as JSON and stored under
    its own ``"custom_id"`` value; a later line with the same id replaces an
    earlier one.

    Args:
        path: JSONL file to read. Defaults to ``OASST1_METADATA_FILE``.

    Returns:
        Mapping from custom id to the parsed line. Empty when the file does
        not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a parsed line has no ``"custom_id"`` field.
    """
    if path is None:
        path = OASST1_METADATA_FILE
    print("Loading OASST1 metadata map...")
    mapping: Dict[str, Dict[str, Any]] = {}
    # A missing metadata file is not an error: the map is simply empty.
    if not os.path.exists(path):
        return mapping
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            mapping[item["custom_id"]] = item
    return mapping
def _build_main_record(
    custom_id: str, parsed_json: Any, original: Dict[str, Any]
) -> Dict[str, Any]:
    """Assemble the output record for one main-batch result."""
    return {
        "custom_id": custom_id,
        "original_query": original["query"],
        "source": original.get("source"),
        "extracted_json": parsed_json,
        "has_preference": len(parsed_json.get("preferences", [])) > 0,
    }


def _build_oasst1_record(
    custom_id: str, parsed_json: Any, meta: Dict[str, Any]
) -> Dict[str, Any]:
    """Assemble the output record for one OASST1 result (keeps user/session ids)."""
    return {
        "custom_id": custom_id,
        "original_query": meta["original_query"],
        "source": "oasst1",
        "user_id": meta.get("user_id"),  # Preserve User ID!
        "session_id": meta.get("session_id"),
        "extracted_json": parsed_json,
        "has_preference": len(parsed_json.get("preferences", [])) > 0,
    }


def _download_batch_output(
    client: Any, batch_id: str, require_completed: bool
) -> Optional[str]:
    """Return the text of a batch's output file, or None if it is unavailable.

    API failures are reported and turned into None so that one bad batch does
    not stop the recovery of the others.
    """
    try:
        batch = client.batches.retrieve(batch_id)
        if require_completed and batch.status != "completed":
            print(f" Skipping {batch_id}: status is {batch.status}")
            return None
        if not batch.output_file_id:
            print(f" Skipping {batch_id}: no output file (status: {batch.status})")
            return None
        print(f" Downloading {batch_id} (Output: {batch.output_file_id})...")
        return client.files.content(batch.output_file_id).text
    except Exception as e:  # boundary around the network calls only
        print(f" Error {batch_id}: {e}")
        return None


def _append_batch_records(
    content: str,
    lookup: Dict[str, Dict[str, Any]],
    build_record: Callable[[str, Any, Dict[str, Any]], Dict[str, Any]],
    f_out: TextIO,
) -> Tuple[int, int]:
    """Write one record per usable line of a batch output file.

    A line is usable when it is valid JSON, its response has status 200, its
    custom id is present in ``lookup``, and the model's message content is
    itself valid JSON from which ``build_record`` can build a record.
    Unusable lines are skipped individually, so one bad line never discards
    the rest of the batch. ``OSError`` from the output file (e.g. disk full)
    is deliberately not caught.

    Returns:
        ``(written, skipped)`` line counts.
    """
    written = 0
    skipped = 0
    for line in content.splitlines():
        if not line.strip():
            continue
        try:
            res = json.loads(line)
            custom_id = res["custom_id"]
            response = res["response"]  # may be null for a failed request
            source_item = lookup.get(custom_id)
            if response["status_code"] != 200 or not source_item:
                skipped += 1
                continue
            llm_content = response["body"]["choices"][0]["message"]["content"]
            parsed_json = json.loads(llm_content)
            record = build_record(custom_id, parsed_json, source_item)
            f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
        except (KeyError, IndexError, TypeError, AttributeError, ValueError):
            # Malformed line, unexpected response shape, or non-JSON model
            # output (json.JSONDecodeError is a ValueError).
            skipped += 1
            continue
        written += 1
    return written, skipped


def _recover_batches(
    client: Any,
    ids_file: str,
    verb: str,
    label: str,
    lookup: Dict[str, Dict[str, Any]],
    build_record: Callable[[str, Any, Dict[str, Any]], Dict[str, Any]],
    f_out: TextIO,
    require_completed: bool,
) -> int:
    """Download every batch listed in ``ids_file`` and append its records.

    Returns the number of records written; 0 when ``ids_file`` does not exist.
    """
    if not os.path.exists(ids_file):
        return 0
    with open(ids_file, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)
    print(f"\n{verb} {len(batch_ids)} {label} Batches...")

    total = 0
    for batch_id in batch_ids:
        content = _download_batch_output(client, batch_id, require_completed)
        if content is None:
            continue
        written, skipped = _append_batch_records(
            content, lookup, build_record, f_out
        )
        total += written
        if skipped:
            print(f" {batch_id}: skipped {skipped} unusable line(s)")
    return total


def recover_and_merge(*, include_oasst1: bool = False) -> None:
    """Re-download batch results from OpenAI and append them to ``OUTPUT_FILE``.

    Reads ``OPENAI_API_KEY`` from the environment; if it is unset, prints an
    error and returns without touching any file. Otherwise every batch listed
    in ``MAIN_BATCH_IDS_FILE`` is fetched and its successful results are
    joined with the original queries and appended as JSONL records.

    The output file is opened in append mode so the retry results it already
    holds are kept. Nothing is de-duplicated, so running this twice appends
    the same records twice.

    Args:
        include_oasst1: Also merge the batches listed in
            ``OASST1_BATCH_IDS_FILE`` (only batches whose status is
            ``"completed"``). Off by default: the OASST1 merge was
            requested to be skipped for now.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    client = OpenAI(api_key=api_key)

    # Load maps. The OASST1 metadata is only needed when that merge is enabled.
    main_query_map = load_original_queries()
    oasst1_meta_map = load_oasst1_metadata() if include_oasst1 else {}

    # We append to the existing file which holds the RETRY results,
    # so we don't lose the 68k we just fixed.
    print(f"Appending recovered data to {OUTPUT_FILE}...")

    count_main = 0
    count_oasst1 = 0
    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        # --- 1. Recover Main Batches ---
        count_main = _recover_batches(
            client,
            MAIN_BATCH_IDS_FILE,
            "Recovering",
            "Main",
            main_query_map,
            _build_main_record,
            f_out,
            require_completed=False,
        )

        # --- 2. Retrieve OASST1 Batches (opt-in) ---
        if include_oasst1:
            count_oasst1 = _recover_batches(
                client,
                OASST1_BATCH_IDS_FILE,
                "Retrieving",
                "OASST1",
                oasst1_meta_map,
                _build_oasst1_record,
                f_out,
                require_completed=True,
            )

    print("\n" + "=" * 50)
    print("RECOVERY & MERGE COMPLETE")
    print(f"Recovered Main: {count_main}")
    print(f"New OASST1: {count_oasst1}")
    print(f"Full dataset updated at: {OUTPUT_FILE}")
    print("=" * 50)
# Script entry point: run the recovery only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    recover_and_merge()
|