import json
import os
from typing import Any, Dict, Optional, Tuple

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_oasst1_batch_ids.json"
# Store independently for Memory/User Modeling initialization
METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
OUTPUT_FILE = "data/corpora/oasst1_labeled.jsonl"


def load_metadata(path: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
    """Load the OASST1 metadata map, keyed by ``custom_id``.

    Args:
        path: JSONL file whose non-blank lines are JSON objects carrying a
            ``custom_id`` key. Defaults to ``METADATA_FILE`` (resolved at
            call time).

    Returns:
        Mapping of ``custom_id`` -> full metadata object. Empty when the file
        does not exist; a warning is printed in that case because every batch
        result will then be reported as missing metadata.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
        KeyError: a line has no ``custom_id`` key.
    """
    if path is None:
        path = METADATA_FILE
    print("Loading OASST1 metadata map...")
    mapping: Dict[str, Dict[str, Any]] = {}
    if not os.path.exists(path):
        print(f"Warning: {path} not found; metadata map is empty.")
        return mapping
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                mapping[item["custom_id"]] = item
    return mapping


def _parse_result_line(
    line: str, meta_map: Dict[str, Dict[str, Any]]
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Convert one Batch API output line into a labeled OASST1 record.

    Args:
        line: One non-blank JSONL line from a batch output file.
        meta_map: Metadata map as returned by ``load_metadata``.

    Returns:
        ``(record, None)`` on success, otherwise ``(None, reason)`` where
        *reason* is a printable message. This never raises for malformed
        input, so one bad line cannot abort the rest of its batch.
    """
    custom_id = None
    try:
        res = json.loads(line)
        custom_id = res.get("custom_id")
        # ``response`` can be null for requests that failed outright.
        response = res.get("response") or {}
        status_code = response.get("status_code")
        if status_code != 200:
            return None, f"Request {custom_id} failed with status {status_code}"
        body = response["body"]
        llm_content = body["choices"][0]["message"]["content"]
        parsed_json = json.loads(llm_content)
        meta = meta_map.get(custom_id)
        if not meta:
            return None, f"Warning: Metadata missing for {custom_id}"
        record = {
            "custom_id": custom_id,
            "original_query": meta["original_query"],
            "source": "oasst1",
            "user_id": meta.get("user_id"),
            "session_id": meta.get("session_id"),
            "extracted_json": parsed_json,
            # bool() also tolerates "preferences": null, which len() rejects.
            "has_preference": bool(parsed_json.get("preferences")),
        }
    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
        # ValueError covers json.JSONDecodeError; the others cover an
        # unexpected shape in either the API envelope or the model's JSON.
        return None, f"Parse error {custom_id}: {e}"
    return record, None


def retrieve_oasst1():
    """Download finished OASST1 batch results and append labeled records.

    Reads batch ids from ``BATCH_IDS_FILE``, fetches each batch's output file
    through the OpenAI client, joins every result with its metadata by
    ``custom_id`` and appends one JSON line per successful record to
    ``OUTPUT_FILE``. Prints a success/failure summary at the end.

    NOTE: ``OUTPUT_FILE`` is opened in append mode, so re-running over the
    same batch ids writes duplicate records.
    """
    # Imported lazily so load_metadata and the parsing helper can be reused
    # (and tested) without the OpenAI SDK installed.
    from openai import OpenAI

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return
    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    meta_map = load_metadata()
    count_success = 0
    count_fail = 0

    # open(..., "a") does not create missing parent directories.
    out_dir = os.path.dirname(OUTPUT_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"Appending OASST1 results to {OUTPUT_FILE}...")
    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    print(f"  No output file (status: {batch.status}); skipping.")
                    continue
                print(f"  Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text
            except Exception as e:
                # API/network boundary: skip this batch, keep the others.
                print(f"Error checking batch {b_id}: {e}")
                continue

            if batch.error_file_id:
                print(
                    f"  Warning: batch has error file {batch.error_file_id}; "
                    "those requests are not included in the counts."
                )

            for line in content.splitlines():
                if not line.strip():
                    continue
                record, reason = _parse_result_line(line, meta_map)
                if record is None:
                    print(reason)
                    count_fail += 1
                    continue
                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                count_success += 1

    print("\n" + "=" * 50)
    print("OASST1 RETRIEVAL COMPLETE")
    print(f"Successfully processed: {count_success}")
    print(f"Failed/Parse Error: {count_fail}")
    print(f"Full dataset updated at: {OUTPUT_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    retrieve_oasst1()