summaryrefslogtreecommitdiff
path: root/scripts/assemble_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/assemble_dataset.py')
-rw-r--r--scripts/assemble_dataset.py85
1 file changed, 85 insertions, 0 deletions
diff --git a/scripts/assemble_dataset.py b/scripts/assemble_dataset.py
new file mode 100644
index 0000000..024f91f
--- /dev/null
+++ b/scripts/assemble_dataset.py
@@ -0,0 +1,85 @@
import json
import os
import random

# Source Files
FILE_ORIGINAL = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
FILE_SYNTHESIS = "data/raw_datasets/synthesized_positives.jsonl"

# Output Files
OUTPUT_RAW = "data/finetune/preference_extractor_450k.jsonl"


def _load_records(path, source_override=None):
    """Load one JSONL file and normalize each row to the training format.

    Each output record is {"input": <user query>, "output": <compact JSON
    string of extracted preferences>, "source": <provenance tag>}.

    Args:
        path: JSONL file where each non-blank line has "original_query"
            and "extracted_json" keys.
        source_override: When given, forces the "source" field (used for
            the synthesis file); otherwise the row's own "source" is kept,
            defaulting to "original".

    Returns:
        List of normalized record dicts; empty (with a warning printed)
        if the file is missing.
    """
    records = []
    if not os.path.exists(path):
        print(f"Warning: {path} missing!")
        return records
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():  # skip blank lines
                continue
            item = json.loads(line)
            query = item.get("original_query", "")
            output_json = item.get("extracted_json", {"preferences": []})
            # Serialize with minimal whitespace to save tokens;
            # ensure_ascii=False keeps non-ASCII text readable.
            output_str = json.dumps(output_json, ensure_ascii=False)
            records.append({
                "input": query,
                "output": output_str,
                "source": source_override or item.get("source", "original"),
            })
    return records


def assemble_dataset():
    """Merge the original and synthesized JSONL files, shuffle, and save.

    Writes one shuffled JSONL dataset to OUTPUT_RAW. A missing input file
    produces a warning rather than an error, so a partial dataset can
    still be assembled.
    """
    os.makedirs(os.path.dirname(OUTPUT_RAW), exist_ok=True)

    print("Assembling final dataset...")

    # 1. Load Original (Pos + Neg)
    print(f"Loading {FILE_ORIGINAL}...")
    records = _load_records(FILE_ORIGINAL)
    print(f"Loaded {len(records)} from original.")

    # 2. Load Synthesis (Pos) — always tagged "synthesis"
    print(f"Loading {FILE_SYNTHESIS}...")
    syn_records = _load_records(FILE_SYNTHESIS, source_override="synthesis")
    print(f"Loaded {len(syn_records)} from synthesis.")
    records.extend(syn_records)

    # 3. Shuffle so the two sources are interleaved for training
    print("Shuffling...")
    random.shuffle(records)

    # 4. Save
    print(f"Saving {len(records)} records to {OUTPUT_RAW}...")
    with open(OUTPUT_RAW, "w", encoding="utf-8") as f:
        for r in records:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    print("Done!")
    print("\nTo upload to Hugging Face, run:")
    # Reference OUTPUT_RAW so the hint cannot drift from the actual path.
    print(f"huggingface-cli upload <repo_id> {OUTPUT_RAW} --repo-type dataset")


if __name__ == "__main__":
    assemble_dataset()
+