summaryrefslogtreecommitdiff
path: root/scripts/retrieve_batch_results.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/retrieve_batch_results.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/retrieve_batch_results.py')
-rw-r--r--scripts/retrieve_batch_results.py151
1 files changed, 151 insertions, 0 deletions
diff --git a/scripts/retrieve_batch_results.py b/scripts/retrieve_batch_results.py
new file mode 100644
index 0000000..aa26e28
--- /dev/null
+++ b/scripts/retrieve_batch_results.py
@@ -0,0 +1,151 @@
+import json
+import os
+import time
+from typing import Dict, Any, List, Set
+from openai import OpenAI
+
# --- Configuration ---
# JSON list of batch job IDs written by the submission script.
BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
# JSONL of the original queries; line index i maps to custom_id "req_{i}".
ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
# JSONL of successfully labeled records (overwritten on each run).
OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
# JSONL of reconstructed batch requests for every query that failed.
RETRY_INPUT_FILE = "data/raw_datasets/retry_requests.jsonl"
MODEL_NAME = "gpt-5.1"  # Need this for reconstruction

# Load System Prompt locally to avoid import errors
# NOTE: read at import time — the script fails immediately (FileNotFoundError)
# if not run from the repository root where the template lives.
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()
+
def load_original_queries() -> Dict[str, Dict[str, Any]]:
    """Return a mapping from custom_id ("req_<line-index>") to the original record.

    Blank lines are skipped but still consume an index — presumably this
    mirrors how custom_ids were assigned at submission time; confirm against
    the submit script.
    """
    print("Loading original queries map...")
    with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as handle:
        return {
            f"req_{line_no}": json.loads(raw)
            for line_no, raw in enumerate(handle)
            if raw.strip()
        }
+
def _harvest_batch(client, b_id, query_map, processed_ids, f_success) -> int:
    """Download one batch's output, append good records to f_success.

    Returns the number of records written.  Non-200 and unparseable responses
    are simply not recorded; they are picked up later by the missing-check.
    """
    batch = client.batches.retrieve(b_id)
    written = 0

    # 1. Output File (Success)
    if batch.output_file_id:
        print(f"  Downloading output {batch.output_file_id}...")
        content = client.files.content(batch.output_file_id).text

        for line in content.splitlines():
            if not line.strip():
                continue
            res = json.loads(line)
            custom_id = res["custom_id"]

            if res["response"]["status_code"] != 200:
                # Non-200 -> fail; handled by the global missing-check since
                # the id never enters processed_ids.
                continue
            try:
                body = res["response"]["body"]
                llm_content = body["choices"][0]["message"]["content"]
                parsed_json = json.loads(llm_content)

                original_item = query_map.get(custom_id)
                if original_item:
                    record = {
                        "custom_id": custom_id,  # kept to help debug later
                        "original_query": original_item["query"],
                        "source": original_item.get("source"),
                        "extracted_json": parsed_json,
                        "has_preference": len(parsed_json.get("preferences", [])) > 0,
                    }
                    f_success.write(json.dumps(record, ensure_ascii=False) + "\n")
                    processed_ids.add(custom_id)
                    written += 1
            except Exception as e:
                # Parse error -> fail; also left to the missing-check.
                print(f"  Parse Error {custom_id}: {e}")

    # 2. Error File (Explicit Failures)
    # No need to read it to build retries — the global missing-check covers
    # every unprocessed id — but the message helps debugging.
    if batch.error_file_id:
        print(f"  Downloading ERROR {batch.error_file_id}...")

    return written


def _write_retry_file(query_map, processed_ids) -> int:
    """Rewrite RETRY_INPUT_FILE with a batch request per unprocessed query.

    Requests are reconstructed from the original query plus the module-level
    SYSTEM_PROMPT / MODEL_NAME, matching the original submission format.
    Returns the number of retry requests written.
    """
    fail_count = 0
    with open(RETRY_INPUT_FILE, "w", encoding="utf-8") as f_retry:
        for custom_id, original_item in query_map.items():
            if custom_id in processed_ids:
                continue
            fail_count += 1

            # Reconstruct Request
            request_obj = {
                "custom_id": custom_id,
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": MODEL_NAME,
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": original_item["query"]},
                    ],
                    "temperature": 0.0,
                    "response_format": {"type": "json_object"},
                },
            }
            # ensure_ascii=False for consistency with the success file.
            f_retry.write(json.dumps(request_obj, ensure_ascii=False) + "\n")
    return fail_count


def process_batch_results():
    """Download results of previously submitted OpenAI batch jobs.

    Writes every successful, parseable response to OUTPUT_LABEL_FILE
    (overwriting it), then regenerates RETRY_INPUT_FILE with a request for
    every original query that did not yield a usable response.  Requires
    OPENAI_API_KEY in the environment and BATCH_IDS_FILE on disk.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r") as f:
        batch_ids = json.load(f)

    query_map = load_original_queries()
    processed_ids: Set[str] = set()

    # Records now carry custom_id, so a partial run could in principle be
    # resumed; we still overwrite because a crash may have left a truncated
    # trailing line, and re-downloading is cheap.
    print("Starting fresh download/processing (Overwriting output)...")

    success_count = 0

    with open(OUTPUT_LABEL_FILE, "w", encoding="utf-8") as f_success:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                success_count += _harvest_batch(
                    client, b_id, query_map, processed_ids, f_success
                )
            except Exception as e:
                print(f"  CRITICAL ERROR processing batch {b_id}: {e}")

    # --- Missing Check & Retry Generation ---
    print(f"\nVerifying completeness... (Total Queries: {len(query_map)})")
    print(f"Successful processed: {len(processed_ids)}")

    fail_count = _write_retry_file(query_map, processed_ids)

    print("\n" + "=" * 50)
    print(f"Processing Complete.")
    print(f"Successful: {success_count} (Saved to {OUTPUT_LABEL_FILE})")
    print(f"To Retry: {fail_count} (Saved to {RETRY_INPUT_FILE})")
    print("=" * 50)
+
# Script entry point: download batch results and generate the retry file.
if __name__ == "__main__":
    process_batch_results()