summary refs log tree commit diff
path: root/scripts/split_train_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/split_train_test.py')
-rw-r--r-- scripts/split_train_test.py 76
1 files changed, 76 insertions, 0 deletions
diff --git a/scripts/split_train_test.py b/scripts/split_train_test.py
new file mode 100644
index 0000000..ccb4cb1
--- /dev/null
+++ b/scripts/split_train_test.py
@@ -0,0 +1,76 @@
+import json
+import os
+import random
+
# Default locations and split size used when the script is run directly.
INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl"
TRAIN_FILE = "data/finetune/train_llama_factory.json"
TEST_FILE = "data/finetune/test_llama_factory.json"
TEST_SIZE = 1000

# Written into the "instruction" field of every converted record.
SYSTEM_INSTRUCTION = (
    "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
    "If no preferences are found, return {\"preferences\": []}."
)


def _load_records(input_file, instruction):
    """Read a JSONL file and return its rows as instruction/input/output dicts.

    Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line is not valid JSON or lacks the
            "input" / "output" keys. The message carries the file name and
            1-based line number so the offending row can be located.
    """
    records = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
                record = {
                    "instruction": instruction,
                    "input": item["input"],
                    "output": item["output"],
                }
            except (json.JSONDecodeError, KeyError, TypeError) as err:
                # TypeError covers rows that are valid JSON but not objects.
                raise ValueError(
                    f"{input_file}:{line_no}: invalid record ({err!r})"
                ) from err
            records.append(record)
    return records


def _write_json(path, records):
    """Write *records* to *path* as indented UTF-8 JSON, creating its directory."""
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)


def split_and_convert(
    input_file=INPUT_FILE,
    train_file=TRAIN_FILE,
    test_file=TEST_FILE,
    test_size=TEST_SIZE,
    seed=42,
    instruction=SYSTEM_INSTRUCTION,
):
    """Convert a JSONL dataset to instruction records and split it into train/test.

    Each non-blank input line must be a JSON object with "input" and "output"
    keys. Every row becomes ``{"instruction", "input", "output"}``; the rows
    are shuffled with a fixed seed, the first ``test_size`` rows form the test
    set and the remainder the train set. Both sets are written as JSON arrays,
    and a ``dataset_info.json`` snippet describing them is printed.

    Args:
        input_file: Path of the source JSONL file.
        train_file: Destination path of the train JSON file.
        test_file: Destination path of the test JSON file.
        test_size: Number of shuffled records reserved for the test set.
        seed: Seed of the shuffle; the same seed and input give the same split.
        instruction: Text stored in the "instruction" field of every record.

    Returns:
        None. If ``input_file`` does not exist an error is printed and nothing
        is written.

    Raises:
        ValueError: if ``test_size`` is negative or an input line is malformed.
    """
    if test_size < 0:
        # A negative bound would silently slice from the end of the list.
        raise ValueError(f"test_size must be non-negative, got {test_size}")

    if not os.path.exists(input_file):
        print(f"Error: {input_file} not found.")
        return

    print(f"Reading {input_file}...")
    all_data = _load_records(input_file, instruction)
    print(f"Total records: {len(all_data)}")

    if len(all_data) <= test_size:
        print(
            f"Warning: only {len(all_data)} records for a test size of "
            f"{test_size}; the train set will be empty."
        )

    # A private generator yields the same permutation as seeding the module
    # RNG, without disturbing random state used elsewhere in the process.
    random.Random(seed).shuffle(all_data)

    test_data = all_data[:test_size]
    train_data = all_data[test_size:]

    print(f"Train size: {len(train_data)}")
    print(f"Test size: {len(test_data)}")

    print(f"Saving train set to {train_file}...")
    _write_json(train_file, train_data)

    print(f"Saving test set to {test_file}...")
    _write_json(test_file, test_data)

    print("Done!")

    # Print the entries to add to dataset_info.json for the two new files.
    print("\nUpdate dataset_info.json with:")
    columns = {"prompt": "instruction", "query": "input", "response": "output"}
    info = {
        "preference_extractor_train": {
            "file_name": os.path.basename(train_file),
            "columns": dict(columns),
        },
        "preference_extractor_test": {
            "file_name": os.path.basename(test_file),
            "columns": dict(columns),
        },
    }
    print(json.dumps(info, indent=2))


if __name__ == "__main__":
    split_and_convert()
+