summaryrefslogtreecommitdiff
path: root/scripts/convert_to_llama_factory.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/convert_to_llama_factory.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/convert_to_llama_factory.py')
-rw-r--r--scripts/convert_to_llama_factory.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/scripts/convert_to_llama_factory.py b/scripts/convert_to_llama_factory.py
new file mode 100644
index 0000000..d8b7565
--- /dev/null
+++ b/scripts/convert_to_llama_factory.py
@@ -0,0 +1,62 @@
import json
import os

# Source JSONL produced by scripts/assemble_dataset.py: one JSON object per
# line, each carrying "input" and "output" keys (read by convert() below).
INPUT_FILE = "data/finetune/preference_extractor_450k.jsonl"
# Destination: a single JSON list in LLaMA-Factory's Alpaca format.
OUTPUT_FILE = "data/finetune/train_llama_factory.json"

# We embed the system prompt as "instruction" so the model learns to respond to this specific instruction.
# Or, if you plan to put this system prompt in the system slot of the chat template,
# you can leave instruction empty or simplified.
# Given 0.5B model, explicit instruction in the prompt is often helpful.
SYSTEM_INSTRUCTION = (
    "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
    "If no preferences are found, return {\"preferences\": []}."
)
+
def convert(input_file=None, output_file=None, instruction=None):
    """Convert the assembled JSONL dataset into LLaMA-Factory's Alpaca JSON format.

    Reads one JSON object per non-empty line from *input_file*, wraps each into
    an Alpaca-style record (instruction / input / output), and writes the whole
    dataset as a single JSON list to *output_file*.

    Args:
        input_file: Source JSONL path. Defaults to the module-level INPUT_FILE.
        output_file: Destination JSON path. Defaults to OUTPUT_FILE.
        instruction: Instruction text embedded in every record. Defaults to
            SYSTEM_INSTRUCTION.

    Each source record must contain "input" and "output" keys; a KeyError from
    a malformed record propagates deliberately so bad data is noticed.
    """
    # Resolve defaults lazily so the module constants stay the single source
    # of truth while callers (and tests) may override per invocation.
    input_file = INPUT_FILE if input_file is None else input_file
    output_file = OUTPUT_FILE if output_file is None else output_file
    instruction = SYSTEM_INSTRUCTION if instruction is None else instruction

    if not os.path.exists(input_file):
        print(f"Error: {input_file} not found. Run scripts/assemble_dataset.py first.")
        return

    print(f"Reading {input_file}...")
    with open(input_file, "r", encoding="utf-8") as f:
        # Alpaca format: one {instruction, input, output} record per JSONL line.
        dataset = [
            {
                "instruction": instruction,
                "input": item["input"],
                "output": item["output"],
            }
            for item in (json.loads(line) for line in f if line.strip())
        ]

    print(f"Converted {len(dataset)} items.")

    # Make sure the destination directory exists before writing
    # (dirname is "" for a bare filename — skip makedirs in that case).
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save as JSON list (LLaMA-Factory standard); keep non-ASCII readable.
    print(f"Saving to {output_file}...")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=2, ensure_ascii=False)

    print("Done!")

    print("\nNext steps for LLaMA-Factory:")
    print("1. Copy data/finetune/train_llama_factory.json to your LLaMA-Factory data/ folder.")
    print("2. Add entry to dataset_info.json:")
    print(json.dumps({
        "preference_extractor_v1": {
            "file_name": "train_llama_factory.json",
            "columns": {
                "prompt": "instruction",
                "query": "input",
                "response": "output"
            }
        }
    }, indent=2))
+
if __name__ == "__main__":
    # Script entry point: run the JSONL -> Alpaca JSON conversion directly.
    convert()