path: root/scripts/pilot_study.py
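"""Pilot labeling study.

Sample up to SAMPLE_SIZE queries from the combined raw query dump, label each
one with the preference-extraction system prompt via the OpenAI chat API, and
write per-query results plus a positive/negative ratio summary to OUTPUT_FILE.
"""
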
import json
import os
import random
import asyncio
from typing import Any, Dict
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

# --- Configuration ---
INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
OUTPUT_FILE = "data/raw_datasets/pilot_study_1000.jsonl"
SAMPLE_SIZE = 1000
MODEL_NAME = "gpt-5.1"  # Or your specific model ID
MAX_CONCURRENCY = 100    # Adjust based on your rate limits
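
# Each input line is expected to be a JSON object with at least a "query"
# field and optionally a "source" field (inferred from how items are used below).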

# --- Load System Prompt ---
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    # Use the full prompt template (instructions plus few-shot examples) as the
    # system prompt; trimming the examples would save tokens but may reduce
    # label quality.
    SYSTEM_PROMPT = f.read()
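
# The model is expected to return a JSON object with a "preferences" list
# (empty when no preference is expressed); the exact item schema is defined by
# the prompt template.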

# --- Async Worker ---
async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
    query = item["query"]
    async with sem:
        try:
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": query}
                ],
                temperature=0.0, # Deterministic for extraction
                response_format={"type": "json_object"} # Enforce JSON
            )
            result_text = response.choices[0].message.content
            
            # Parse to ensure validity
            try:
                parsed = json.loads(result_text)
                prefs = parsed.get("preferences", [])
                has_pref = len(prefs) > 0
            except (json.JSONDecodeError, TypeError):
                parsed = {"error": "json_parse_fail", "raw": result_text}
                has_pref = False
                
            return {
                "original_query": query,
                "source": item.get("source"),
                "extracted_json": parsed,
                "has_preference": has_pref
            }
        except Exception as e:
            return {
                "original_query": query,
                "source": item.get("source"),
                "error": str(e),
                "has_preference": False
            }

async def main():
    # 1. Load and Sample
    print(f"Loading data from {INPUT_FILE}...")
    all_lines = []
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                all_lines.append(json.loads(line))
    
    if len(all_lines) > SAMPLE_SIZE:
        sampled_data = random.sample(all_lines, SAMPLE_SIZE)
    else:
        sampled_data = all_lines
    print(f"Sampled {len(sampled_data)} items.")

    # 2. Setup OpenAI Client
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY environment variable not set.")
        return
    
    client = AsyncOpenAI(api_key=api_key)
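    # The semaphore caps the number of in-flight requests at MAX_CONCURRENCY so
    # the gathered tasks below do not exceed API rate limits.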
    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    # 3. Run Labeling
    tasks = [label_query(client, sem, item) for item in sampled_data]
    results = await tqdm_asyncio.gather(*tasks, desc="Labeling")

    # 4. Statistics & Save
    pos_count = sum(1 for r in results if r.get("has_preference"))
    total = len(results)
    ratio = (pos_count / total) * 100 if total > 0 else 0

    print(f"\n--- Results ---")
    print(f"Total processed: {total}")
    print(f"Positive (has preferences): {pos_count}")
    print(f"Negative (empty): {total - pos_count}")
    print(f"Positive Ratio: {ratio:.2f}%")
    
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        for res in results:
            f.write(json.dumps(res, ensure_ascii=False) + "\n")
    print(f"Saved detailed results to {OUTPUT_FILE}")

if __name__ == "__main__":
    asyncio.run(main())