path: root/scripts/split_train_test.py
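"""Split the preference-extractor JSONL dataset into train and test sets and
convert each record to the instruction/input/output format used by
LLaMA-Factory.

Expects INPUT_FILE to contain one JSON object per line with "input" and
"output" keys; writes the two splits as JSON arrays and prints a
dataset_info.json snippet for registering them.
"""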
import json
import os
import random

INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl"
TRAIN_FILE = "data/finetune/train_llama_factory.json"
TEST_FILE = "data/finetune/test_llama_factory.json"
TEST_SIZE = 1000

SYSTEM_INSTRUCTION = (
    "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
    "If no preferences are found, return {\"preferences\": []}."
)

def split_and_convert():
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        return

    print(f"Reading {INPUT_FILE}...")
    all_data = []
    
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                # Convert to LLaMA-Factory format immediately
                record = {
                    "instruction": SYSTEM_INSTRUCTION,
                    "input": item["input"],
                    "output": item["output"]
                }
                all_data.append(record)
    
    print(f"Total records: {len(all_data)}")
    
    # Shuffle
    random.seed(42) # Fixed seed for reproducibility
    random.shuffle(all_data)
    
    # Split
    test_data = all_data[:TEST_SIZE]
    train_data = all_data[TEST_SIZE:]
    
    print(f"Train size: {len(train_data)}")
    print(f"Test size:  {len(test_data)}")
    
    # Save Train
    print(f"Saving train set to {TRAIN_FILE}...")
    with open(TRAIN_FILE, "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)
        
    # Save Test
    print(f"Saving test set to {TEST_FILE}...")
    with open(TEST_FILE, "w", encoding="utf-8") as f:
        json.dump(test_data, f, indent=2, ensure_ascii=False)

    print("Done!")
    
    # Print dataset_info.json entries so the new files can be registered with LLaMA-Factory
    print("\nAdd the following entries to dataset_info.json:")
    info = {
        "preference_extractor_train": {
            "file_name": "train_llama_factory.json",
            "columns": {"prompt": "instruction", "query": "input", "response": "output"}
        },
        "preference_extractor_test": {
            "file_name": "test_llama_factory.json",
            "columns": {"prompt": "instruction", "query": "input", "response": "output"}
        }
    }
    print(json.dumps(info, indent=2))

if __name__ == "__main__":
    split_and_convert()
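
# Example usage (assuming the script is run from the repository root and the
# input/output paths above match your layout):
#   python scripts/split_train_test.py
# The generated train/test JSON files can then be placed in LLaMA-Factory's
# data/ directory, and the printed entries merged into data/dataset_info.json.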