import asyncio
import json
import os
from typing import Any, Dict, List

import aiofiles
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

# --- Configuration ---
INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset.jsonl"
CHECKPOINT_FILE = "data/raw_datasets/labeling_checkpoint.txt"  # Currently unused; resume position is inferred from OUTPUT_FILE
MODEL_NAME = "gpt-5.1"  # Or "gpt-4o"
MAX_CONCURRENCY = 500   # Adjust based on rate limits
SAVE_INTERVAL = 1000    # Save batch to disk every N items

# --- Load System Prompt ---
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()


async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore,
                      item: Dict[str, Any]) -> Dict[str, Any]:
    """Label a single query, returning a result record even on failure."""
    query = item["query"]
    async with sem:
        try:
            # The SDK retries transient errors internally; for bulk processing,
            # skipping hard failures is better than stalling the whole run.
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": query},
                ],
                temperature=0.0,
                response_format={"type": "json_object"},
            )
            result_text = response.choices[0].message.content
            try:
                parsed = json.loads(result_text)
                prefs = parsed.get("preferences", [])
                has_pref = len(prefs) > 0
            except json.JSONDecodeError:
                parsed = {"error": "json_parse_fail", "raw": result_text}
                has_pref = False
            return {
                "original_query": query,
                "source": item.get("source"),
                "extracted_json": parsed,
                "has_preference": has_pref,
            }
        except Exception as e:
            return {
                "original_query": query,
                "source": item.get("source"),
                "error": str(e),
                "has_preference": False,
            }


async def main():
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    # 1. Determine start position (resume logic).
    # A quick line count of the output tells us how many items are done,
    # since results are strictly appended in input order.
    processed_count = 0
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            for _ in f:
                processed_count += 1
        print(f"Resuming from index {processed_count}...")

    # 2. Load data and skip already-processed items.
    # Reading ~400k lines is fast, so we read everything and slice.
    all_items = []
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                all_items.append(json.loads(line))

    total_items = len(all_items)
    remaining_items = all_items[processed_count:]
    if not remaining_items:
        print("All items processed!")
        return
    print(f"Total: {total_items}, Remaining: {len(remaining_items)}")

    # 3. Set up the client and the concurrency limit.
    client = AsyncOpenAI(api_key=api_key)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    # 4. Batch processing.
    # Chunking allows periodic saving and keeps memory bounded.
    batch_size = SAVE_INTERVAL
    async with aiofiles.open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        for i in range(0, len(remaining_items), batch_size):
            batch = remaining_items[i : i + batch_size]
            tasks = [label_query(client, sem, item) for item in batch]

            # gather() preserves input order, so output lines stay aligned
            # with the input file (which the resume logic relies on).
            results = await tqdm_asyncio.gather(
                *tasks, desc=f"Batch {i // batch_size}", leave=False
            )

            # Write the batch and flush so a crash loses at most one batch.
            lines = [json.dumps(res, ensure_ascii=False) + "\n" for res in results]
            await f_out.writelines(lines)
            await f_out.flush()

            # Per-batch stats.
            pos_in_batch = sum(1 for r in results if r.get("has_preference"))
            print(f"Batch saved. Positive in this batch: {pos_in_batch}/{len(batch)}")

    print(f"Done! Saved to {OUTPUT_FILE}")


if __name__ == "__main__":
    asyncio.run(main())
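
# --- Optional: explicit retry with backoff (illustrative sketch) ---
# label_query() above leans on the SDK's built-in retries and records hard
# failures rather than stalling the run. If you would rather retry transient
# errors yourself, one minimal approach is jittered exponential backoff around
# the completion call, as sketched below. This helper is hypothetical and not
# wired into main(); `attempts` and `base_delay` are illustrative values, and
# you would define it above label_query() and call it from there.
import random


async def complete_with_retry(client: AsyncOpenAI, messages: List[Dict[str, str]],
                              attempts: int = 3, base_delay: float = 1.0):
    """Retry the chat completion with jittered exponential backoff."""
    for attempt in range(attempts):
        try:
            return await client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=0.0,
                response_format={"type": "json_object"},
            )
        except Exception:
            if attempt == attempts - 1:
                raise  # Out of attempts: let the caller record the error
            # Sleep ~1s, 2s, 4s, ... plus jitter to avoid synchronized retries.
            await asyncio.sleep(base_delay * (2 ** attempt) + random.random())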