Diffstat (limited to 'scripts/pilot_study.py')
| -rw-r--r-- | scripts/pilot_study.py | 109 |
1 file changed, 109 insertions, 0 deletions
diff --git a/scripts/pilot_study.py b/scripts/pilot_study.py
new file mode 100644
index 0000000..9754c42
--- /dev/null
+++ b/scripts/pilot_study.py
@@ -0,0 +1,109 @@
+import json
+import os
+import random
+import asyncio
+from typing import Any, Dict
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
+# --- Configuration ---
+INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+OUTPUT_FILE = "data/raw_datasets/pilot_study_1000.jsonl"
+SAMPLE_SIZE = 1000
+MODEL_NAME = "gpt-5.1"  # Or your specific model ID
+MAX_CONCURRENCY = 100  # Adjust based on your rate limits
+
+# --- Load System Prompt ---
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+    # The template could be trimmed to the system prompt alone (dropping
+    # the examples to save tokens), but the full text is used as the
+    # system instruction here so the few-shot examples are retained
+    # and extraction quality stays high.
+    SYSTEM_PROMPT = f.read()
+
+# --- Async Worker ---
+async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
+    query = item["query"]
+    async with sem:
+        try:
+            response = await client.chat.completions.create(
+                model=MODEL_NAME,
+                messages=[
+                    {"role": "system", "content": SYSTEM_PROMPT},
+                    {"role": "user", "content": query}
+                ],
+                temperature=0.0,  # Deterministic for extraction
+                response_format={"type": "json_object"}  # Enforce JSON
+            )
+            result_text = response.choices[0].message.content
+
+            # Parse to ensure validity (TypeError covers a None response)
+            try:
+                parsed = json.loads(result_text)
+                prefs = parsed.get("preferences", [])
+                has_pref = len(prefs) > 0
+            except (json.JSONDecodeError, TypeError):
+                parsed = {"error": "json_parse_fail", "raw": result_text}
+                has_pref = False
+
+            return {
+                "original_query": query,
+                "source": item.get("source"),
+                "extracted_json": parsed,
+                "has_preference": has_pref
+            }
+        except Exception as e:
+            return {
+                "original_query": query,
+                "source": item.get("source"),
+                "error": str(e),
+                "has_preference": False
+            }
+
+async def main():
+    # 1. Load and Sample
+    print(f"Loading data from {INPUT_FILE}...")
+    all_lines = []
+    with open(INPUT_FILE, "r", encoding="utf-8") as f:
+        for line in f:
+            if line.strip():
+                all_lines.append(json.loads(line))
+
+    if len(all_lines) > SAMPLE_SIZE:
+        sampled_data = random.sample(all_lines, SAMPLE_SIZE)
+    else:
+        sampled_data = all_lines
+    print(f"Sampled {len(sampled_data)} items.")
+
+    # 2. Setup OpenAI Client
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        print("Error: OPENAI_API_KEY environment variable not set.")
+        return
+
+    client = AsyncOpenAI(api_key=api_key)
+    sem = asyncio.Semaphore(MAX_CONCURRENCY)
+
+    # 3. Run Labeling
+    tasks = [label_query(client, sem, item) for item in sampled_data]
+    results = await tqdm_asyncio.gather(*tasks, desc="Labeling")
+
+    # 4. Statistics & Save
+    pos_count = sum(1 for r in results if r.get("has_preference"))
+    total = len(results)
+    ratio = (pos_count / total) * 100 if total > 0 else 0
+
+    print("\n--- Results ---")
+    print(f"Total processed: {total}")
+    print(f"Positive (has preferences): {pos_count}")
+    print(f"Negative (empty): {total - pos_count}")
+    print(f"Positive Ratio: {ratio:.2f}%")
+
+    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
+        for res in results:
+            f.write(json.dumps(res, ensure_ascii=False) + "\n")
+    print(f"Saved detailed results to {OUTPUT_FILE}")

+if __name__ == "__main__":
+    asyncio.run(main())
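With MAX_CONCURRENCY at 100, transient 429 responses from the API are likely, and the script above simply records them as error rows. A minimal backoff sketch, assuming the RateLimitError class exported by the openai v1 package; the retry count and delays are illustrative, and create_with_backoff is a hypothetical helper that could stand in for the direct client.chat.completions.create(...) call inside label_query:

import asyncio
import random

from openai import AsyncOpenAI, RateLimitError

MAX_RETRIES = 5    # illustrative; tune against your actual rate limits
BASE_DELAY = 1.0   # seconds

async def create_with_backoff(client: AsyncOpenAI, **kwargs):
    # Retry only on rate-limit errors, using exponential backoff plus
    # jitter (1s, 2s, 4s, ... +/- noise). Any other exception propagates
    # to label_query's outer try/except and is recorded as an error row.
    for attempt in range(MAX_RETRIES):
        try:
            return await client.chat.completions.create(**kwargs)
        except RateLimitError:
            if attempt == MAX_RETRIES - 1:
                raise
            delay = BASE_DELAY * (2 ** attempt) + random.uniform(0, 0.5)
            await asyncio.sleep(delay)

This keeps the semaphore as the hard concurrency cap while the backoff absorbs short rate-limit bursts, so a pilot run degrades gracefully instead of silently losing samples.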
