path: root/scripts/full_labeling.py
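"""Bulk preference-labeling of raw queries via the OpenAI chat completions API.

Reads queries from INPUT_FILE (JSONL), labels each with MODEL_NAME using the
system prompt in fine_tuning_prompt_template.txt, and appends one JSON result
per query to OUTPUT_FILE. If interrupted, a re-run resumes by counting the
lines already written and skipping that many input items.
"""
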
import json
import os
import asyncio
import aiofiles
from typing import Dict, Any
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

# --- Configuration ---
INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
OUTPUT_FILE = "data/raw_datasets/labeled_full_dataset.jsonl"
MODEL_NAME = "gpt-5.1"  # Or "gpt-4o"
MAX_CONCURRENCY = 500   # Adjust based on rate limits
SAVE_INTERVAL = 1000     # Save batch to disk every N items

# --- Load System Prompt ---
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()

async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
    query = item["query"]
    async with sem:
        try:
            # The client's built-in retries cover transient errors; for bulk
            # processing, recording a failure and moving on beats stalling.
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": query}
                ],
                temperature=0.0,
                response_format={"type": "json_object"}
            )
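            # JSON mode (response_format={"type": "json_object"}) requires the
            # word "JSON" to appear in the messages; SYSTEM_PROMPT must mention
            # it or the API rejects the request.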
            result_text = response.choices[0].message.content
            
            try:
                parsed = json.loads(result_text)
                prefs = parsed.get("preferences", [])
                has_pref = len(prefs) > 0
            except (json.JSONDecodeError, TypeError):
                # TypeError: message.content can be None (e.g. on a refusal)
                parsed = {"error": "json_parse_fail", "raw": result_text}
                has_pref = False
                
            return {
                "original_query": query,
                "source": item.get("source"),
                "extracted_json": parsed,
                "has_preference": has_pref
            }
        except Exception as e:
            return {
                "original_query": query,
                "source": item.get("source"),
                "error": str(e),
                "has_preference": False
            }

async def main():
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    # 1. Determine start position (Resume logic)
    processed_count = 0
    if os.path.exists(OUTPUT_FILE):
        # Count output lines to see how many items are already done
        # (valid because results are only ever appended, in input order)
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            for _ in f:
                processed_count += 1
    
    print(f"Resuming from index {processed_count}...")

    # 2. Load Data (skip already processed)
    # Since reading 400k lines is fast, we just read all and slice
    all_items = []
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                all_items.append(json.loads(line))
    
    total_items = len(all_items)
    remaining_items = all_items[processed_count:]
    
    if not remaining_items:
        print("All items processed!")
        return

    print(f"Total: {total_items}, Remaining: {len(remaining_items)}")

    # 3. Setup Client
    client = AsyncOpenAI(api_key=api_key)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
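    # The semaphore caps in-flight requests: each label_query() call acquires
    # a slot before calling the API, so at most MAX_CONCURRENCY run at once.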

    # 4. Batch Processing
    # Process in chunks so results are saved periodically and memory stays bounded
    batch_size = SAVE_INTERVAL
    
    # Open file in append mode
    async with aiofiles.open(OUTPUT_FILE, "a", encoding="utf-8") as f_out:
        
        for i in range(0, len(remaining_items), batch_size):
            batch = remaining_items[i : i + batch_size]
            tasks = [label_query(client, sem, item) for item in batch]
            
            # Run batch
            results = await tqdm_asyncio.gather(*tasks, desc=f"Batch {i//batch_size}", leave=False)
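            # tqdm_asyncio.gather preserves argument order (like asyncio.gather),
            # so output lines stay aligned with input items and the line-count
            # resume above stays valid.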
            
            # Write batch
            lines = [json.dumps(res, ensure_ascii=False) + "\n" for res in results]
            await f_out.writelines(lines)
            await f_out.flush()  # Push this batch out of the write buffer
            
            # Per-batch stats to monitor label balance
            pos_in_batch = sum(1 for r in results if r.get("has_preference"))
            print(f"Batch saved. Positives in this batch: {pos_in_batch}/{len(batch)}")

    print(f"Done! Saved to {OUTPUT_FILE}")

if __name__ == "__main__":
    asyncio.run(main())
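
# Usage (assumes OPENAI_API_KEY is exported and the input JSONL is in place):
#   python scripts/full_labeling.py
# Safe to interrupt; a re-run resumes after the last line already written.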