summaryrefslogtreecommitdiff
path: root/scripts/retrieve_oasst1.py
blob: 436d329239800c108856b43d37b21bf87643c4f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import os
from openai import OpenAI
from typing import Dict, Any

# --- Configuration ---
# JSON file containing the list of OpenAI Batch IDs to retrieve; loaded with
# json.load and iterated in order by retrieve_oasst1().
BATCH_IDS_FILE: str = "data/raw_datasets/submitted_oasst1_batch_ids.json"
# JSONL file with one metadata object per line, keyed by "custom_id". The
# retrieval step reads "original_query" (required) and "user_id"/"session_id"
# (optional) from each entry.
METADATA_FILE: str = "data/raw_datasets/oasst1_metadata_map.jsonl"
# Destination JSONL for labeled records; the retrieval step opens it in append
# mode. Store independently for Memory/User Modeling initialization
OUTPUT_FILE: str = "data/corpora/oasst1_labeled.jsonl" 

def load_metadata(path=None) -> Dict[str, Dict[str, Any]]:
    """Load the OASST1 metadata map from a JSONL file, keyed by ``custom_id``.

    Each non-blank line must be a JSON object containing a ``custom_id`` key;
    the whole object is stored as the value. If a ``custom_id`` occurs more than
    once, the last occurrence wins.

    Args:
        path: JSONL file to read. Defaults to the module-level ``METADATA_FILE``
            when omitted or ``None``.

    Returns:
        Mapping from ``custom_id`` to its full metadata record. Empty when the
        file does not exist; a warning is printed in that case because every
        downstream result would otherwise be dropped with no obvious cause.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``custom_id`` key.
    """
    if path is None:
        path = METADATA_FILE
    print("Loading OASST1 metadata map...")
    mapping: Dict[str, Dict[str, Any]] = {}
    if not os.path.exists(path):
        # Previously this returned {} silently; surface the root cause instead.
        print(f"Warning: {path} not found; no metadata loaded.")
        return mapping
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            mapping[item["custom_id"]] = item
    return mapping

def _load_existing_ids(path: str) -> set:
    """Return the ``custom_id`` values already present in the JSONL file at *path*.

    Used to make re-runs idempotent: results appended by a previous run are
    skipped rather than written twice. Blank or undecodable lines are ignored.
    Returns an empty set when the file does not exist.
    """
    existing = set()
    if not os.path.exists(path):
        return existing
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(item, dict) and "custom_id" in item:
                existing.add(item["custom_id"])
    return existing


def _build_record(res: Dict[str, Any], meta: Dict[str, Any]) -> Dict[str, Any]:
    """Build the output record for one successful (HTTP 200) batch result line.

    Raises KeyError/IndexError/TypeError/AttributeError/ValueError (incl.
    json.JSONDecodeError) when the response body or the model's JSON content
    does not have the expected shape.
    """
    body = res["response"]["body"]
    llm_content = body["choices"][0]["message"]["content"]
    parsed_json = json.loads(llm_content)
    return {
        "custom_id": res["custom_id"],
        "original_query": meta["original_query"],
        "source": "oasst1",
        "user_id": meta.get("user_id"),
        "session_id": meta.get("session_id"),
        "extracted_json": parsed_json,
        # `or []` treats an explicit null as "no preferences" instead of raising.
        "has_preference": len(parsed_json.get("preferences") or []) > 0,
    }


def retrieve_oasst1():
    """Download OASST1 batch results and append labeled records to OUTPUT_FILE.

    Reads the batch IDs from BATCH_IDS_FILE. For each batch that has an output
    file, it downloads the file and parses each result line. Each successful
    result is joined with its metadata from METADATA_FILE (by ``custom_id``) and
    appended as one JSON line.

    Results whose ``custom_id`` is already present in OUTPUT_FILE are skipped,
    so the script can be re-run safely as more batches finish. Prints a summary
    of successes, failures and skips. Returns None. Prints an error and returns
    early if OPENAI_API_KEY is unset or BATCH_IDS_FILE is missing.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    meta_map = load_metadata()
    already_written = _load_existing_ids(OUTPUT_FILE)
    count_success = 0
    count_fail = 0
    count_skipped = 0

    # Appending to a path whose directory does not exist would raise.
    out_dir = os.path.dirname(OUTPUT_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"Appending OASST1 results to {OUTPUT_FILE}...")

    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    # Not finished (or produced no output); re-run later.
                    print(f"  No output file yet (status: {batch.status}); skipping.")
                    continue
                print(f"  Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text
            except Exception as e:
                print(f"Error checking batch {b_id}: {e}")
                continue

            # Per-line handling: one bad line must not abort the rest of the batch.
            for line in content.splitlines():
                if not line.strip():
                    continue
                try:
                    res = json.loads(line)
                    custom_id = res["custom_id"]
                except (ValueError, KeyError, TypeError) as e:
                    print(f"Malformed result line in batch {b_id}: {e}")
                    count_fail += 1
                    continue

                if custom_id in already_written:
                    count_skipped += 1
                    continue

                # "response" can be null for errored requests (the error is then
                # reported in a separate "error" field).
                response = res.get("response") or {}
                status_code = response.get("status_code")
                if status_code != 200:
                    print(f"Request failed {custom_id}: status {status_code}")
                    count_fail += 1
                    continue

                meta = meta_map.get(custom_id)
                if not meta:
                    # No fallback record is produced, so count it as a failure.
                    print(f"Warning: Metadata missing for {custom_id}")
                    count_fail += 1
                    continue

                try:
                    record = _build_record(res, meta)
                except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
                    print(f"Parse error {custom_id}: {e}")
                    count_fail += 1
                    continue

                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                already_written.add(custom_id)
                count_success += 1

    print("\n" + "="*50)
    print("OASST1 RETRIEVAL COMPLETE")
    print(f"Successfully processed: {count_success}")
    print(f"Failed/Parse Error:     {count_fail}")
    print(f"Already saved (skipped): {count_skipped}")
    print(f"Full dataset updated at: {OUTPUT_FILE}")
    print("="*50)

# Script entry point: run the retrieval only when executed directly, not on import.
if __name__ == "__main__":
    retrieve_oasst1()