summaryrefslogtreecommitdiff
path: root/scripts/full_labeling.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/full_labeling.py')
-rw-r--r--scripts/full_labeling.py125
1 files changed, 125 insertions, 0 deletions
diff --git a/scripts/full_labeling.py b/scripts/full_labeling.py
new file mode 100644
index 0000000..1c52819
--- /dev/null
+++ b/scripts/full_labeling.py
@@ -0,0 +1,125 @@
+import json
+import os
+import asyncio
+import aiofiles
+from typing import List, Dict, Any
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
# --- Configuration ---
# Paths are relative to the repo root; run this script from there.
INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset.jsonl"
# NOTE(review): CHECKPOINT_FILE is defined but never read or written in this
# script (resume logic counts lines in OUTPUT_FILE instead) — confirm whether
# it is still needed.
CHECKPOINT_FILE = "data/raw_datasets/labeling_checkpoint.txt"
MODEL_NAME = "gpt-5.1" # Or "gpt-4o"
MAX_CONCURRENCY = 500 # Adjust based on rate limits
SAVE_INTERVAL = 1000 # Save batch to disk every N items

# --- Load System Prompt ---
# Read once at import time; the same system prompt is sent with every
# labeling request.
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()
+
async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
    """Label one query via the chat-completions API.

    Args:
        client: shared async OpenAI client.
        sem: semaphore bounding the number of in-flight API calls.
        item: raw dataset row; must contain "query", may contain "source".

    Returns:
        A dict with the original query, its source, the parsed JSON
        extraction (or an error record), and a boolean "has_preference"
        flag. Failures are recorded in the result instead of raised, so one
        bad item cannot stall a bulk run.
    """
    query = item["query"]
    async with sem:
        try:
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": query}
                ],
                temperature=0.0,  # deterministic labeling
                response_format={"type": "json_object"}
            )
            # message.content can be None (e.g. filtered/refused responses);
            # normalize to "" so json.loads raises JSONDecodeError instead of
            # TypeError and the failure is recorded as a parse failure.
            result_text = response.choices[0].message.content or ""

            try:
                parsed = json.loads(result_text)
                # The model may return valid JSON that is not an object;
                # only a dict can carry a "preferences" key.
                prefs = parsed.get("preferences", []) if isinstance(parsed, dict) else []
                has_pref = isinstance(prefs, list) and len(prefs) > 0
            except json.JSONDecodeError:
                # Was a bare `except:`, which also swallows KeyboardInterrupt
                # and SystemExit; catch only JSON parse failures here.
                parsed = {"error": "json_parse_fail", "raw": result_text}
                has_pref = False

            return {
                "original_query": query,
                "source": item.get("source"),
                "extracted_json": parsed,
                "has_preference": has_pref
            }
        except Exception as e:
            # Best-effort bulk labeling: record the error and move on rather
            # than aborting a long run.
            return {
                "original_query": query,
                "source": item.get("source"),
                "error": str(e),
                "has_preference": False
            }
+
async def main():
    """Label every query in INPUT_FILE and append results to OUTPUT_FILE.

    Resumable: results are appended strictly in input order, one JSON line
    per item, so the number of lines already in OUTPUT_FILE is the resume
    offset. Re-running after an interruption continues where it left off,
    losing at most one unflushed batch.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    # Create the output directory up front; opening OUTPUT_FILE in append
    # mode on a fresh checkout would otherwise raise FileNotFoundError.
    out_dir = os.path.dirname(OUTPUT_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # 1. Determine start position (resume logic): one output line per item.
    processed_count = 0
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            for _ in f:
                processed_count += 1

    print(f"Resuming from index {processed_count}...")

    # 2. Load data (skip already processed). Reading a few hundred thousand
    # JSONL lines is fast, so we read everything and slice.
    all_items = []
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # tolerate blank lines in the input
                all_items.append(json.loads(line))

    total_items = len(all_items)
    remaining_items = all_items[processed_count:]

    if not remaining_items:
        print("All items processed!")
        return

    print(f"Total: {total_items}, Remaining: {len(remaining_items)}")

    # 3. One client and one concurrency limiter shared by all tasks.
    client = AsyncOpenAI(api_key=api_key)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    # 4. Process in chunks so results hit disk periodically; a crash loses
    # at most one batch, and the resume logic above picks up from there.
    batch_size = SAVE_INTERVAL

    # Open file in append mode so resumed runs extend prior output.
    async with aiofiles.open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:

        for i in range(0, len(remaining_items), batch_size):
            batch = remaining_items[i : i + batch_size]
            tasks = [label_query(client, sem, item) for item in batch]

            # Run the whole batch concurrently (bounded by the semaphore).
            results = await tqdm_asyncio.gather(*tasks, desc=f"Batch {i//batch_size}", leave=False)

            # Write the batch as JSONL and force it to disk.
            lines = [json.dumps(res, ensure_ascii=False) + "\n" for res in results]
            await f_out.writelines(lines)
            await f_out.flush()

            # Per-batch stats (was computed but dead behind a commented-out
            # print — now actually reported).
            pos_in_batch = sum(1 for r in results if r.get("has_preference"))
            print(f"Batch saved. Positive in this batch: {pos_in_batch}/{len(batch)}")

    print(f"Done! Saved to {OUTPUT_FILE}")

if __name__ == "__main__":
    asyncio.run(main())
+