Diffstat (limited to 'scripts/full_labeling.py')
| -rw-r--r-- | scripts/full_labeling.py | 125 |
1 file changed, 125 insertions, 0 deletions
diff --git a/scripts/full_labeling.py b/scripts/full_labeling.py
new file mode 100644
index 0000000..1c52819
--- /dev/null
+++ b/scripts/full_labeling.py
@@ -0,0 +1,125 @@
+import json
+import os
+import asyncio
+import aiofiles
+from typing import Any, Dict
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
+# --- Configuration ---
+INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset.jsonl"
+CHECKPOINT_FILE = "data/raw_datasets/labeling_checkpoint.txt"  # Reserved; resume currently counts lines in OUTPUT_FILE
+MODEL_NAME = "gpt-5.1"  # Or "gpt-4o"
+MAX_CONCURRENCY = 500   # Adjust based on rate limits
+SAVE_INTERVAL = 1000    # Save batch to disk every N items
+
+# --- Load System Prompt ---
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+    SYSTEM_PROMPT = f.read()
+
+
+async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
+    query = item["query"]
+    async with sem:
+        try:
+            # The OpenAI client retries transient failures a few times on its own;
+            # beyond that, for bulk processing, skipping errors beats stalling the run.
+            response = await client.chat.completions.create(
+                model=MODEL_NAME,
+                messages=[
+                    {"role": "system", "content": SYSTEM_PROMPT},
+                    {"role": "user", "content": query},
+                ],
+                temperature=0.0,
+                response_format={"type": "json_object"},
+            )
+            result_text = response.choices[0].message.content
+
+            try:
+                parsed = json.loads(result_text)
+                prefs = parsed.get("preferences", [])
+                has_pref = len(prefs) > 0
+            except (json.JSONDecodeError, AttributeError, TypeError):
+                # Covers malformed JSON and valid JSON of an unexpected shape
+                # (e.g. a bare list, or a non-list "preferences" value)
+                parsed = {"error": "json_parse_fail", "raw": result_text}
+                has_pref = False
+
+            return {
+                "original_query": query,
+                "source": item.get("source"),
+                "extracted_json": parsed,
+                "has_preference": has_pref,
+            }
+        except Exception as e:
+            return {
+                "original_query": query,
+                "source": item.get("source"),
+                "error": str(e),
+                "has_preference": False,
+            }
+
+
+async def main():
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        print("Error: OPENAI_API_KEY not set.")
+        return
+
+    # 1. Determine start position (Resume logic)
+    processed_count = 0
+    if os.path.exists(OUTPUT_FILE):
+        # Quick line count to see how many we've done
+        # (valid because every result, including failures, is appended in input order)
+        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
+            for _ in f:
+                processed_count += 1
+
+        print(f"Resuming from index {processed_count}...")
+
+    # 2. Load Data (skip already processed)
+    # Reading ~400k lines is fast, so we just read everything and slice
+    all_items = []
+    with open(INPUT_FILE, "r", encoding="utf-8") as f:
+        for line in f:
+            if line.strip():
+                all_items.append(json.loads(line))
+
+    total_items = len(all_items)
+    remaining_items = all_items[processed_count:]
+
+    if not remaining_items:
+        print("All items processed!")
+        return
+
+    print(f"Total: {total_items}, Remaining: {len(remaining_items)}")
+
+    # 3. Setup Client
+    client = AsyncOpenAI(api_key=api_key)
+    sem = asyncio.Semaphore(MAX_CONCURRENCY)
+
+    # 4. Batch Processing
+    # Process in chunks to allow periodic saving and bounded memory use
+    batch_size = SAVE_INTERVAL
+
+    # Open output in append mode so completed batches survive a crash
+    async with aiofiles.open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
+        for i in range(0, len(remaining_items), batch_size):
+            batch = remaining_items[i : i + batch_size]
+            tasks = [label_query(client, sem, item) for item in batch]
+
+            # Run the batch; gather preserves task order, so output stays aligned with input
+            results = await tqdm_asyncio.gather(*tasks, desc=f"Batch {i // batch_size}", leave=False)
+
+            # Write the batch
+            lines = [json.dumps(res, ensure_ascii=False) + "\n" for res in results]
+            await f_out.writelines(lines)
+            await f_out.flush()  # Ensure it is actually on disk before the next batch
+
+            # Optional: print stats every now and then
+            # pos_in_batch = sum(1 for r in results if r.get("has_preference"))
+            # print(f"Batch saved. Positive in this batch: {pos_in_batch}/{len(batch)}")
+
+    print(f"Done! Saved to {OUTPUT_FILE}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
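
A note on the error handling in the patch: label_query swallows every failure into an error record so a single bad item cannot stall a 400k-item run, and it otherwise leans on the OpenAI client's built-in retries. If a run ever needs explicit retries for transient errors instead of dropping items, something like the following generic wrapper could be used. This is a sketch, not part of the commit: with_retry and max_attempts are names invented here, and it assumes a variant of the labeling call that raises instead of returning an error record.

import asyncio
import random
from typing import Any, Awaitable, Callable

async def with_retry(call: Callable[[], Awaitable[Any]], max_attempts: int = 3) -> Any:
    """Sketch: retry an async call with exponential backoff and jitter.

    label_query in the patch returns error records instead of raising, so
    using this would require a variant that propagates exceptions.
    """
    for attempt in range(max_attempts):
        try:
            return await call()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts; let the caller record the failure
            # Sleep 1s, 2s, 4s, ... plus jitter so retries do not stampede the API
            await asyncio.sleep(2 ** attempt + random.random())

Usage would be along the lines of await with_retry(lambda: label_query(client, sem, item)) once the call is changed to raise on failure.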
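
The resume logic also rests on one assumption worth checking: the output file must hold exactly one line per processed input item, in input order. That holds here because gather preserves task order and failures are still written, but a truncated write (e.g. a kill mid-flush) would silently shift every subsequent label. A pre-flight check in this spirit, with hypothetical names and the same JSONL layout assumed, could guard a resume:

import json

def check_resume_alignment(input_path: str, output_path: str) -> bool:
    """Sketch: verify the last output record matches the input row at its index.

    Assumes the layout used above: inputs carry "query", outputs carry
    "original_query". Returns True if resuming by line count looks safe.
    """
    with open(output_path, "r", encoding="utf-8") as f:
        out_lines = [line for line in f if line.strip()]
    if not out_lines:
        return True  # nothing written yet; starting from zero is fine

    with open(input_path, "r", encoding="utf-8") as f:
        in_rows = [json.loads(line) for line in f if line.strip()]

    idx = len(out_lines) - 1
    if idx >= len(in_rows):
        return False  # output longer than input: something is corrupted

    last_out = json.loads(out_lines[-1])
    return in_rows[idx]["query"] == last_out.get("original_query")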
