Diffstat (limited to 'scripts/pilot_study.py')
-rw-r--r--  scripts/pilot_study.py  109
1 file changed, 109 insertions, 0 deletions
diff --git a/scripts/pilot_study.py b/scripts/pilot_study.py
new file mode 100644
index 0000000..9754c42
--- /dev/null
+++ b/scripts/pilot_study.py
@@ -0,0 +1,109 @@
+import json
+import os
+import random
+import asyncio
+from typing import Any, Dict
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
+# --- Configuration ---
+INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
+OUTPUT_FILE = "data/raw_datasets/pilot_study_1000.jsonl"
+SAMPLE_SIZE = 1000
+MODEL_NAME = "gpt-5.1" # Or your specific model ID
+MAX_CONCURRENCY = 100 # Adjust based on your rate limits
+
+# --- Load System Prompt ---
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+    # Use the full prompt template (system instructions plus few-shot examples)
+    # as the system message. Trimming the examples would save tokens, but the
+    # few-shot behavior helps extraction quality.
+ SYSTEM_PROMPT = f.read()
+
+# --- Async Worker ---
+async def label_query(client: AsyncOpenAI, sem: asyncio.Semaphore, item: Dict[str, Any]) -> Dict[str, Any]:
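+    """Send one sampled query to the labeling model and return a result record."""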
+ query = item["query"]
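+    # The semaphore caps the number of in-flight API requests at MAX_CONCURRENCY.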
+ async with sem:
+ try:
+ response = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ],
+ temperature=0.0, # Deterministic for extraction
+ response_format={"type": "json_object"} # Enforce JSON
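+                # Note: json_object mode requires the word "JSON" to appear in the
+                # prompt; the extraction template is assumed to satisfy this.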
+ )
+ result_text = response.choices[0].message.content
+
+ # Parse to ensure validity
+ try:
+ parsed = json.loads(result_text)
+ prefs = parsed.get("preferences", [])
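+                # A query counts as positive if at least one preference was extracted.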
+ has_pref = len(prefs) > 0
+            except json.JSONDecodeError:
+ parsed = {"error": "json_parse_fail", "raw": result_text}
+ has_pref = False
+
+ return {
+ "original_query": query,
+ "source": item.get("source"),
+ "extracted_json": parsed,
+ "has_preference": has_pref
+ }
+ except Exception as e:
+ return {
+ "original_query": query,
+ "source": item.get("source"),
+ "error": str(e),
+ "has_preference": False
+ }
+
+async def main():
+ # 1. Load and Sample
+ print(f"Loading data from {INPUT_FILE}...")
+ all_lines = []
+ with open(INPUT_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ all_lines.append(json.loads(line))
+
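+    # No RNG seed is set, so each run draws a different sample; call
+    # random.seed(...) beforehand if the pilot needs to be reproducible.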
+ if len(all_lines) > SAMPLE_SIZE:
+ sampled_data = random.sample(all_lines, SAMPLE_SIZE)
+ else:
+ sampled_data = all_lines
+ print(f"Sampled {len(sampled_data)} items.")
+
+ # 2. Setup OpenAI Client
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY environment variable not set.")
+ return
+
+ client = AsyncOpenAI(api_key=api_key)
+ sem = asyncio.Semaphore(MAX_CONCURRENCY)
+
+ # 3. Run Labeling
+ tasks = [label_query(client, sem, item) for item in sampled_data]
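+    # tqdm_asyncio.gather mirrors asyncio.gather (results are returned in
+    # input order) while displaying a progress bar.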
+ results = await tqdm_asyncio.gather(*tasks, desc="Labeling")
+
+ # 4. Statistics & Save
+ pos_count = sum(1 for r in results if r.get("has_preference"))
+ total = len(results)
+ ratio = (pos_count / total) * 100 if total > 0 else 0
+
+ print(f"\n--- Results ---")
+ print(f"Total processed: {total}")
+ print(f"Positive (has preferences): {pos_count}")
+ print(f"Negative (empty): {total - pos_count}")
+ print(f"Positive Ratio: {ratio:.2f}%")
+
+ with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
+ for res in results:
+ f.write(json.dumps(res, ensure_ascii=False) + "\n")
+ print(f"Saved detailed results to {OUTPUT_FILE}")
+
+if __name__ == "__main__":
+ asyncio.run(main())
+