# Recovered from the patch that added scripts/split_train_test.py
# (commit e43b3f8, "Initial commit (clean history)").
"""Split a preference-extractor JSONL file into LLaMA-Factory train/test JSON files."""

import json
import os
import random

INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl"
TRAIN_FILE = "data/finetune/train_llama_factory.json"
TEST_FILE = "data/finetune/test_llama_factory.json"
TEST_SIZE = 1000

SYSTEM_INSTRUCTION = (
    "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
    "If no preferences are found, return {\"preferences\": []}."
)


def _load_records(path):
    """Read the JSONL file at *path* and return its rows as LLaMA-Factory records."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # blank lines are skipped, not parsed
            item = json.loads(line)
            # Convert to LLaMA-Factory format immediately so only the
            # fields that are written out are kept in memory.
            records.append(
                {
                    "instruction": SYSTEM_INSTRUCTION,
                    "input": item["input"],
                    "output": item["output"],
                }
            )
    return records


def _write_json(records, path):
    """Write *records* to *path* as indented UTF-8 JSON, creating the parent directory."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)


def split_and_convert(
    input_file: str = INPUT_FILE,
    train_file: str = TRAIN_FILE,
    test_file: str = TEST_FILE,
    test_size: int = TEST_SIZE,
    seed: int = 42,
) -> None:
    """Shuffle the input records and write a train/test split as JSON.

    Each non-blank line of *input_file* is parsed as a JSON object with
    ``"input"`` and ``"output"`` keys and converted to a record with
    ``instruction``/``input``/``output`` fields. The records are shuffled
    deterministically; the first *test_size* go to *test_file* and the rest
    to *train_file*. Finally a ``dataset_info.json`` snippet referring to the
    two output files is printed.

    Called with no arguments, this uses the module-level constants and a
    seed of 42.

    Args:
        input_file: Path of the JSONL file to read.
        train_file: Path of the JSON file receiving the training records.
        test_file: Path of the JSON file receiving the test records.
        test_size: Number of shuffled records to hold out for the test set.
        seed: Seed for the shuffle.

    Raises:
        ValueError: If *test_size* is negative, or (as
            ``json.JSONDecodeError``) if a line is not valid JSON.
        KeyError: If a row lacks the ``"input"`` or ``"output"`` key.
    """
    if test_size < 0:
        # A negative size would silently slice from the end of the list.
        raise ValueError(f"test_size must be non-negative, got {test_size}")

    if not os.path.exists(input_file):
        print(f"Error: {input_file} not found.")
        return

    print(f"Reading {input_file}...")
    all_data = _load_records(input_file)
    print(f"Total records: {len(all_data)}")

    # Shuffle with a private generator: same permutation as seeding the
    # module-level generator, but without altering global random state.
    random.Random(seed).shuffle(all_data)

    # Split
    test_data = all_data[:test_size]
    train_data = all_data[test_size:]

    if not train_data:
        print(
            f"Warning: {len(all_data)} record(s) with test size {test_size} "
            "leaves the train set empty."
        )

    print(f"Train size: {len(train_data)}")
    print(f"Test size: {len(test_data)}")

    # Save Train
    print(f"Saving train set to {train_file}...")
    _write_json(train_data, train_file)

    # Save Test
    print(f"Saving test set to {test_file}...")
    _write_json(test_data, test_file)

    print("Done!")

    # Update dataset_info advice
    print("\nUpdate dataset_info.json with:")
    info = {
        "preference_extractor_train": {
            "file_name": os.path.basename(train_file),
            "columns": {"prompt": "instruction", "query": "input", "response": "output"},
        },
        "preference_extractor_test": {
            "file_name": os.path.basename(test_file),
            "columns": {"prompt": "instruction", "query": "input", "response": "output"},
        },
    }
    print(json.dumps(info, indent=2))


if __name__ == "__main__":
    split_and_convert()