diff options
Diffstat (limited to 'scripts/retrieve_oasst1.py')
| -rw-r--r-- | scripts/retrieve_oasst1.py | 96 |
1 files changed, 96 insertions, 0 deletions
"""Retrieve finished OpenAI batch outputs for OASST1 and append labeled records.

Reads the submitted batch ids from ``BATCH_IDS_FILE``, downloads each batch's
output file, joins every result with its entry in ``METADATA_FILE`` (keyed by
``custom_id``) and appends one JSON line per usable result to ``OUTPUT_FILE``.
"""
import json
import os
from typing import Any, Dict, Optional

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
# Store independently for Memory/User Modeling initialization
OUTPUT_FILE = "data/corpora/oasst1_labeled.jsonl"


def load_metadata() -> Dict[str, Dict[str, Any]]:
    """Load the metadata map from ``METADATA_FILE``.

    Each non-blank line of the file is parsed as a JSON object and stored
    under its ``custom_id`` value.

    Returns:
        A dict mapping ``custom_id`` to the full metadata item. Empty when the
        metadata file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an item has no ``custom_id`` key.
    """
    print("Loading OASST1 metadata map...")
    mapping: Dict[str, Dict[str, Any]] = {}
    if not os.path.exists(METADATA_FILE):
        # Without the map every result is reported as "Metadata missing",
        # so make the root cause visible up front.
        print(f"Warning: {METADATA_FILE} not found; metadata map is empty.")
        return mapping

    with open(METADATA_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            mapping[item["custom_id"]] = item
    return mapping


def _process_output_line(
    line: str, meta_map: Dict[str, Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Turn one line of a batch output file into a labeled record.

    Args:
        line: A single non-blank JSON line from the batch output file.
        meta_map: Metadata items keyed by ``custom_id``.

    Returns:
        The record to write, or ``None`` when the line is unusable: invalid
        JSON, a non-200 response, unparsable model content, or no metadata
        for its ``custom_id``. A diagnostic is printed for every case except
        a well-formed non-200 response.
    """
    custom_id = "<unknown>"
    try:
        res = json.loads(line)
        custom_id = res["custom_id"]
        response = res["response"]
        if response["status_code"] != 200:
            return None

        llm_content = response["body"]["choices"][0]["message"]["content"]
        # The model's message content is itself expected to be a JSON document.
        parsed_json = json.loads(llm_content)
        has_preference = len(parsed_json.get("preferences", [])) > 0

        meta = meta_map.get(custom_id)
        if not meta:
            # Fallback if metadata missing (unlikely)
            print(f"Warning: Metadata missing for {custom_id}")
            return None

        return {
            "custom_id": custom_id,
            "original_query": meta["original_query"],
            "source": "oasst1",
            "user_id": meta.get("user_id"),
            "session_id": meta.get("session_id"),
            "extracted_json": parsed_json,
            "has_preference": has_preference,
        }
    except (KeyError, IndexError, TypeError, ValueError, AttributeError) as e:
        # json.JSONDecodeError is a ValueError; TypeError/AttributeError cover
        # a null response or non-object model output.
        print(f"Parse error {custom_id}: {e}")
        return None


def retrieve_oasst1():
    """Download all submitted OASST1 batch outputs and append labeled records.

    Requires the ``OPENAI_API_KEY`` environment variable and ``BATCH_IDS_FILE``;
    prints an error and returns early when either is missing. Every output
    line is counted as either a success (record appended to ``OUTPUT_FILE``)
    or a failure, and a summary is printed at the end.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    # Imported lazily so the module stays importable without the OpenAI SDK.
    from openai import OpenAI

    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    meta_map = load_metadata()
    count_success = 0
    count_fail = 0

    print(f"Appending OASST1 results to {OUTPUT_FILE}...")

    # Opening in append mode does not create missing parent directories.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    status = getattr(batch, "status", "unknown")
                    print(f"  No output file available (status: {status}); skipping.")
                    continue
                print(f"  Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text
            except Exception as e:
                # API boundary: one failing batch must not stop the others.
                print(f"Error checking batch {b_id}: {e}")
                continue

            # Lines are handled independently so one malformed line cannot
            # discard the rest of the batch.
            for line in content.splitlines():
                if not line.strip():
                    continue
                record = _process_output_line(line, meta_map)
                if record is None:
                    count_fail += 1
                    continue
                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                count_success += 1

    print("\n" + "=" * 50)
    print("OASST1 RETRIEVAL COMPLETE")
    print(f"Successfully processed: {count_success}")
    print(f"Failed/Parse Error: {count_fail}")
    print(f"Full dataset updated at: {OUTPUT_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    retrieve_oasst1()
