summaryrefslogtreecommitdiff
path: root/scripts/retrieve_oasst1.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/retrieve_oasst1.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/retrieve_oasst1.py')
-rw-r--r--scripts/retrieve_oasst1.py96
1 files changed, 96 insertions, 0 deletions
diff --git a/scripts/retrieve_oasst1.py b/scripts/retrieve_oasst1.py
new file mode 100644
index 0000000..436d329
--- /dev/null
+++ b/scripts/retrieve_oasst1.py
@@ -0,0 +1,96 @@
import json
import os
from typing import Any, Dict, Optional

from openai import OpenAI
+
# --- Configuration ---
# Both input files live under the same raw-datasets directory.
_RAW_DATASETS_DIR = "data/raw_datasets"
# JSON document listing the batch ids to retrieve.
BATCH_IDS_FILE = _RAW_DATASETS_DIR + "/submitted_oasst1_batch_ids.json"
# JSONL file with one metadata object per line, keyed by its "custom_id".
METADATA_FILE = _RAW_DATASETS_DIR + "/oasst1_metadata_map.jsonl"
# Store independently for Memory/User Modeling initialization
OUTPUT_FILE = "data/corpora/oasst1_labeled.jsonl"
+
def load_metadata(path: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
    """Load the JSONL metadata map, keyed by each entry's ``custom_id``.

    Args:
        path: JSONL file to read. Defaults to ``METADATA_FILE`` when omitted,
            so existing zero-argument callers are unaffected.

    Returns:
        A dict mapping ``custom_id`` to the full decoded JSON object of that
        line. Empty when the file does not exist. If a ``custom_id`` occurs
        more than once, the last occurrence wins.
    """
    if path is None:
        path = METADATA_FILE

    print("Loading OASST1 metadata map...")
    mapping: Dict[str, Dict[str, Any]] = {}

    if not os.path.exists(path):
        # Say so up front: with an empty map every retrieved record is dropped.
        print(f"Warning: {path} not found; metadata map is empty.")
        return mapping

    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
                mapping[item["custom_id"]] = item
            except (ValueError, KeyError, TypeError) as e:
                # One corrupt line should not discard the rest of the map.
                print(f"Warning: skipping malformed metadata line {line_no}: {e}")
    return mapping
+
def _record_from_line(
    line: str, meta_map: Dict[str, Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Turn one line of a batch output file into an output record.

    Returns the record dict, or ``None`` (after printing the reason) when the
    line is unreadable, the request did not return status 200, the model
    output is not a JSON object, or no metadata exists for its ``custom_id``.
    """
    try:
        res = json.loads(line)
        custom_id = res["custom_id"]
    except (ValueError, KeyError, TypeError) as e:
        print(f"Parse error (unreadable result line): {e}")
        return None

    # "response" may be absent or null for a request that errored.
    response = res.get("response")
    if not isinstance(response, dict) or response.get("status_code") != 200:
        print(f"Request failed {custom_id}: {res.get('error')}")
        return None

    try:
        llm_content = response["body"]["choices"][0]["message"]["content"]
        parsed_json = json.loads(llm_content)
        if not isinstance(parsed_json, dict):
            raise ValueError(
                f"expected a JSON object, got {type(parsed_json).__name__}"
            )

        meta = meta_map.get(custom_id)
        if not meta:
            print(f"Warning: Metadata missing for {custom_id}")
            return None

        # Treat an explicit null "preferences" the same as an empty list.
        preferences = parsed_json.get("preferences") or []
        return {
            "custom_id": custom_id,
            "original_query": meta["original_query"],
            "source": "oasst1",
            "user_id": meta.get("user_id"),
            "session_id": meta.get("session_id"),
            "extracted_json": parsed_json,
            "has_preference": len(preferences) > 0,
        }
    except (ValueError, KeyError, IndexError, TypeError) as e:
        print(f"Parse error {custom_id}: {e}")
        return None


def retrieve_oasst1():
    """Download finished OpenAI batch outputs and append labeled records.

    Reads the batch ids from ``BATCH_IDS_FILE``, fetches each batch's output
    file, joins every result with its entry from ``load_metadata()`` and
    appends one JSON line per usable result to ``OUTPUT_FILE``. Prints a
    success/failure summary at the end.

    Returns ``None``. Returns early, after printing an error, when
    ``OPENAI_API_KEY`` is unset or ``BATCH_IDS_FILE`` does not exist.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    meta_map = load_metadata()
    count_success = 0
    count_fail = 0

    # Opening in append mode does not create missing parent directories.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Appending OASST1 results to {OUTPUT_FILE}...")

    with open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    print(f" No output file available (status: {batch.status}); skipping.")
                    continue

                print(f" Downloading output {batch.output_file_id}...")
                content = client.files.content(batch.output_file_id).text

                for line in content.splitlines():
                    if not line.strip():
                        continue
                    record = _record_from_line(line, meta_map)
                    if record is None:
                        # Every unusable result is counted so the summary adds up.
                        count_fail += 1
                        continue
                    f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    count_success += 1
            except Exception as e:
                # Best effort: a failing batch must not stop the remaining ones.
                print(f"Error checking batch {b_id}: {e}")

    print("\n" + "=" * 50)
    print("OASST1 RETRIEVAL COMPLETE")
    print(f"Successfully processed: {count_success}")
    print(f"Failed/Parse Error: {count_fail}")
    print(f"Full dataset updated at: {OUTPUT_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    retrieve_oasst1()
+