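"""Retrieve completed synthesis batches from the OpenAI Batch API.

For every batch id recorded by submit_synthesis_batch.py, download the batch
output, parse the generated query rewrites, attach the ground-truth
'extracted_json' preferences of the corresponding seed record, and write the
resulting positive samples to OUTPUT_FILE as JSONL.
"""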
import json
import os
from typing import Any, Dict

from openai import OpenAI

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_synthesis_batch_ids.json"
SEED_FILE = "data/raw_datasets/positive_seeds.jsonl"
# Where to save the new synthesized records
OUTPUT_FILE = "data/raw_datasets/synthesized_positives.jsonl"


def load_seeds() -> Dict[str, Dict[str, Any]]:
    """Build a map from custom_id to the full seed record.

    submit_synthesis_batch.py created each synthesis request with
    custom_id "syn_{original_id}", so this map lets us find the original
    seed again and recover its ground-truth 'extracted_json' preferences.
    """
    print("Loading seeds map...")
    mapping = {}
    with open(SEED_FILE, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if line.strip():
                item = json.loads(line)
                # Seeds produced by the retrieve script carry a custom_id.
                # If it is missing (e.g. from an older run), fall back to the
                # positional "seed_{i}" id the submit script generated.
                cid = item.get("custom_id")
                if not cid:
                    cid = f"seed_{idx}"
                mapping[cid] = item
    return mapping

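# Illustrative ids and batch-output shape this script assumes (example values
# only; the field paths match what retrieve_synthesis() reads below):
#   seed in positive_seeds.jsonl:    custom_id = "req_123"
#   synthesis request in the batch:  custom_id = "syn_req_123"
#   batch output line:
#     {"custom_id": "syn_req_123",
#      "response": {"status_code": 200,
#                   "body": {"choices": [{"message": {"content": "{\"rewrites\": [...]}"}}]}}}
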
def retrieve_synthesis():
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return
    with open(BATCH_IDS_FILE, "r") as f:
        batch_ids = json.load(f)

    seed_map = load_seeds()

    count_rewrites = 0
    count_source_seeds = 0

    print(f"Processing Synthesis Batches -> {OUTPUT_FILE}...")
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)
                if batch.output_file_id:
                    print(f" Downloading output {batch.output_file_id}...")
                    content = client.files.content(batch.output_file_id).text
                    for line in content.splitlines():
                        if not line.strip():
                            continue
                        res = json.loads(line)
                        syn_id = res["custom_id"]  # e.g. "syn_req_123"
                        # Derive the original seed ID by removing the "syn_" prefix.
                        if syn_id.startswith("syn_"):
                            orig_id = syn_id[4:]
                        else:
                            orig_id = syn_id
                        if res["response"]["status_code"] == 200:
                            try:
                                body = res["response"]["body"]
                                llm_content = body["choices"][0]["message"]["content"]
                                parsed_json = json.loads(llm_content)
                                rewrites = parsed_json.get("rewrites", [])
                                if not rewrites:
                                    continue
                                # Find the original seed whose preferences the rewrites inherit.
                                seed = seed_map.get(orig_id)
                                if seed:
                                    prefs = seed.get("extracted_json")
                                    # Create one new record per rewrite.
                                    for rw in rewrites:
                                        new_record = {
                                            "original_query": rw,
                                            "source": "synthesis_gpt4o",
                                            "parent_id": orig_id,
                                            "extracted_json": prefs,  # inherit the seed's preference
                                            "has_preference": True,
                                        }
                                        f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
                                        count_rewrites += 1
                                    count_source_seeds += 1
                                else:
                                    # print(f"Warning: Seed {orig_id} not found in map")
                                    pass
                            except Exception as e:
                                print(f"Parse error {syn_id}: {e}")
            except Exception as e:
                print(f"Error checking batch {b_id}: {e}")

    print("\n" + "=" * 50)
    print("SYNTHESIS RETRIEVAL COMPLETE")
    print(f"Processed Source Seeds: {count_source_seeds}")
    print(f"Generated New Samples: {count_rewrites}")
    print(f"Saved to: {OUTPUT_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    retrieve_synthesis()
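
# Illustrative shape of one line in synthesized_positives.jsonl (values are
# hypothetical; the keys are the ones set on new_record above):
#   {"original_query": "rewritten query text", "source": "synthesis_gpt4o",
#    "parent_id": "req_123", "extracted_json": {...}, "has_preference": true}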