summaryrefslogtreecommitdiff
path: root/scripts/retrieve_synthesis.py
blob: cbc457384263555d4c464f513928845f4c6a9489 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import os
from openai import OpenAI
from typing import Dict, Any

# --- Configuration ---
# JSON file with the batch ids to fetch; retrieve_synthesis() reads it with
# json.load and iterates over the result.
BATCH_IDS_FILE = "data/raw_datasets/submitted_synthesis_batch_ids.json"
# JSONL file of seed records (one JSON object per line); load_seeds() indexes
# it by each record's "custom_id".
SEED_FILE = "data/raw_datasets/positive_seeds.jsonl"
# Where to save the new synthesized records (JSONL; opened in "w" mode, so it
# is overwritten on every run).
OUTPUT_FILE = "data/raw_datasets/synthesized_positives.jsonl"

def load_seeds() -> Dict[str, Dict[str, Any]]:
    """Index the seed records in SEED_FILE by their id.

    Each non-blank line of SEED_FILE is parsed as one JSON object (the full
    seed record, including its 'extracted_json' ground-truth preferences).
    The record is stored under its "custom_id"; when that field is missing or
    empty, the key falls back to "seed_{i}", where i is the zero-based line
    index in the file.

    NOTE(review): the fallback is meant to mirror the "seed_{i}" ids that the
    submission script generates for seeds without a custom_id. The index used
    here also counts blank lines -- confirm the submit script enumerates the
    file the same way.

    Returns:
        Mapping of seed id -> parsed seed record. If two lines share an id,
        the later one wins.

    Raises:
        OSError: if SEED_FILE cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    print("Loading seeds map...")
    seeds_by_id = {}
    with open(SEED_FILE, "r", encoding="utf-8") as seed_file:
        for line_no, raw_line in enumerate(seed_file):
            if not raw_line.strip():
                continue
            record = json.loads(raw_line)
            # Falsy custom_id (absent, None, "") -> positional fallback id.
            key = record.get("custom_id") or f"seed_{line_no}"
            seeds_by_id[key] = record
    return seeds_by_id

def _records_from_result(res: Dict[str, Any], seed_map: Dict[str, Dict[str, Any]]) -> tuple:
    """Convert one parsed batch-output entry into synthesized records.

    Args:
        res: One JSON object from a batch output file. It is expected to carry
            "custom_id" and a "response" object with "status_code" and "body".
        seed_map: Seed id -> seed record, as produced by load_seeds().

    Returns:
        A ``(records, problem)`` tuple. ``records`` is the list of new record
        dicts (one per rewrite). ``problem`` is ``None`` on success or when the
        model simply returned no rewrites; otherwise it is a short description
        of why the entry was skipped (failed request, unknown seed).

    Raises:
        ValueError, KeyError, IndexError, TypeError, AttributeError: if the
            entry or the model's JSON payload does not have the expected shape.
    """
    syn_id = res["custom_id"]  # e.g. "syn_req_123"
    # Derive the original seed id by dropping the "syn_" prefix, if present.
    orig_id = syn_id[4:] if syn_id.startswith("syn_") else syn_id

    response = res.get("response")
    if not isinstance(response, dict) or response.get("status_code") != 200:
        status = response.get("status_code") if isinstance(response, dict) else None
        return [], f"request failed (status={status}, error={res.get('error')})"

    llm_content = response["body"]["choices"][0]["message"]["content"]
    parsed_json = json.loads(llm_content)
    if not isinstance(parsed_json, dict):
        raise ValueError("model output is not a JSON object")

    rewrites = parsed_json.get("rewrites", [])
    if not rewrites:
        return [], None
    if not isinstance(rewrites, list):
        # A bare string would otherwise be iterated character by character.
        raise ValueError('"rewrites" is not a list')

    seed = seed_map.get(orig_id)
    if not seed:
        return [], f"seed {orig_id} not found in seed map"

    # Every rewrite inherits the ground-truth preferences of its parent seed.
    prefs = seed.get("extracted_json")
    records = [
        {
            "original_query": rw,
            "source": "synthesis_gpt4o",
            "parent_id": orig_id,
            "extracted_json": prefs,  # INHERIT PREFERENCE
            "has_preference": True,
        }
        for rw in rewrites
    ]
    return records, None


def retrieve_synthesis():
    """Download finished synthesis batches and write the rewrites to OUTPUT_FILE.

    Reads the batch ids from BATCH_IDS_FILE, fetches each batch's output file
    through the OpenAI client, and for every successful result writes one JSONL
    record per rewrite, inheriting 'extracted_json' from the parent seed.
    OUTPUT_FILE is overwritten. Progress, skipped entries and a final summary
    are printed to stdout.

    The function returns None in all cases. It returns early (after printing
    an error) when OPENAI_API_KEY is unset or an input file is missing. A
    failure while handling one batch or one result line is reported and does
    not stop the remaining batches/lines.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return
    # Same treatment as the batch-id file: a clean message instead of a traceback.
    if not os.path.exists(SEED_FILE):
        print(f"Error: {SEED_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    seed_map = load_seeds()
    count_rewrites = 0
    count_source_seeds = 0
    count_skipped = 0

    print(f"Processing Synthesis Batches -> {OUTPUT_FILE}...")

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            # Top-level boundary per batch: API/network failures for one batch
            # must not prevent the other batches from being processed.
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    status = getattr(batch, "status", "unknown")
                    print(f"  No output file available (status: {status}); skipping.")
                    continue

                print(f"  Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text

                for line_no, line in enumerate(content.splitlines(), start=1):
                    if not line.strip():
                        continue

                    # A single bad line is skipped instead of aborting the batch.
                    try:
                        res = json.loads(line)
                        syn_id = res["custom_id"]
                    except (ValueError, KeyError, TypeError) as e:
                        print(f"  Skipping malformed result line {line_no}: {e!r}")
                        count_skipped += 1
                        continue

                    try:
                        records, problem = _records_from_result(res, seed_map)
                    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
                        print(f"Parse error {syn_id}: {e}")
                        count_skipped += 1
                        continue

                    if problem:
                        print(f"  Warning: skipping {syn_id}: {problem}")
                        count_skipped += 1
                        continue
                    if not records:
                        continue

                    for record in records:
                        f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    count_rewrites += len(records)
                    count_source_seeds += 1
            except Exception as e:
                print(f"Error checking batch {b_id}: {e}")

    print("\n" + "="*50)
    print("SYNTHESIS RETRIEVAL COMPLETE")
    print(f"Processed Source Seeds: {count_source_seeds}")
    print(f"Generated New Samples:  {count_rewrites}")
    print(f"Skipped Results:        {count_skipped}")
    print(f"Saved to: {OUTPUT_FILE}")
    print("="*50)

# Run the retrieval only when this file is executed as a script, not on import.
if __name__ == "__main__":
    retrieve_synthesis()