1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
import json
import os
from openai import OpenAI
from typing import Dict, Any
# --- Configuration ---
# JSON file holding the batch IDs to retrieve; it is read with json.load and
# iterated one batch ID at a time.
BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
# JSONL file with one metadata object per line; entries are looked up by their
# "custom_id" field.
METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
# Labeled OASST1 records are appended here, one JSON object per line.
# Stored independently, for Memory/User Modeling initialization.
OUTPUT_FILE = "data/corpora/oasst1_labeled.jsonl"
def load_metadata() -> Dict[str, Dict[str, Any]]:
    """Read the OASST1 metadata map into a dict keyed by ``custom_id``.

    Each non-blank line of ``METADATA_FILE`` is decoded as a JSON object and
    stored under its ``"custom_id"`` value; a later line with the same
    ``custom_id`` replaces an earlier one.

    Returns:
        The ``custom_id`` -> metadata mapping, or an empty dict when
        ``METADATA_FILE`` does not exist.
    """
    print("Loading OASST1 metadata map...")
    if not os.path.exists(METADATA_FILE):
        return {}
    with open(METADATA_FILE, "r", encoding="utf-8") as handle:
        # Blank lines are skipped before decoding.
        entries = (json.loads(raw) for raw in handle if raw.strip())
        return {entry["custom_id"]: entry for entry in entries}
def _build_oasst1_record(res, meta_map):
    """Convert one decoded batch-output entry into an output record.

    Args:
        res: One JSON object decoded from a line of the batch output file,
            already known to carry a 200 response.
        meta_map: Mapping from ``custom_id`` to its metadata entry.

    Returns:
        The record dict to write, or ``None`` when ``meta_map`` has no
        (non-empty) entry for the entry's ``custom_id``.

    Raises:
        KeyError, IndexError, TypeError, AttributeError, ValueError: If the
            entry does not have the expected shape or the message content
            is not valid JSON.
    """
    custom_id = res["custom_id"]
    llm_content = res["response"]["body"]["choices"][0]["message"]["content"]
    # The message content is itself expected to be a JSON document; it is
    # decoded before the metadata lookup so malformed content is always
    # reported as a parse error.
    parsed_json = json.loads(llm_content)
    meta = meta_map.get(custom_id)
    if not meta:
        return None
    return {
        "custom_id": custom_id,
        "original_query": meta["original_query"],
        "source": "oasst1",
        "user_id": meta.get("user_id"),
        "session_id": meta.get("session_id"),
        "extracted_json": parsed_json,
        # bool() rather than len() so a null "preferences" value means
        # "no preference" instead of failing the whole record.
        "has_preference": bool(parsed_json.get("preferences")),
    }


def retrieve_oasst1():
    """Download finished OASST1 batch outputs and append labeled records.

    Reads the batch IDs from ``BATCH_IDS_FILE``, retrieves each batch with
    the OpenAI client, downloads its output file when one is available and
    appends one JSON record per successful request to ``OUTPUT_FILE``.
    Prints progress and a final summary; returns ``None``.

    Returns early (after printing an error) when ``OPENAI_API_KEY`` is not
    set or ``BATCH_IDS_FILE`` does not exist. Errors are isolated so that a
    bad line does not discard the rest of its batch and a failing batch does
    not stop the remaining batches.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return
    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    meta_map = load_metadata()
    count_success = 0
    count_fail = 0
    count_skipped = 0

    # Append mode creates the file but not its parent directory.
    out_dir = os.path.dirname(OUTPUT_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"Appending OASST1 results to {OUTPUT_FILE}...")
    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    # Say why nothing was downloaded instead of skipping
                    # the batch silently.
                    status = getattr(batch, "status", "unknown")
                    print(f" No output file for batch {b_id} (status: {status}); skipping.")
                    continue

                print(f" Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text
                for line in content.splitlines():
                    if not line.strip():
                        continue
                    custom_id = "<unknown>"
                    try:
                        res = json.loads(line)
                        custom_id = res["custom_id"]
                        # "response" may be null for a request that errored.
                        response = res.get("response") or {}
                        if response.get("status_code") != 200:
                            count_fail += 1
                            continue
                        record = _build_oasst1_record(res, meta_map)
                    except (KeyError, IndexError, TypeError,
                            AttributeError, ValueError) as e:
                        # Per-line handling: one malformed entry must not
                        # abort the remaining lines of this batch.
                        print(f"Parse error {custom_id}: {e}")
                        count_fail += 1
                        continue

                    if record is None:
                        print(f"Warning: Metadata missing for {custom_id}")
                        count_skipped += 1
                        continue

                    f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    count_success += 1
            except Exception as e:
                # Best-effort boundary: report the batch-level failure
                # (API or I/O) and continue with the next batch.
                print(f"Error checking batch {b_id}: {e}")

    print("\n" + "="*50)
    print("OASST1 RETRIEVAL COMPLETE")
    print(f"Successfully processed: {count_success}")
    print(f"Failed/Parse Error: {count_fail}")
    print(f"Skipped (metadata missing): {count_skipped}")
    print(f"Full dataset updated at: {OUTPUT_FILE}")
    print("="*50)
# Run the retrieval only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    retrieve_oasst1()
|