summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/train_sft.py
blob: cff146ac0ce7e6937a30fed837ae073e0a75fcb9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""
SFT Training for Session-Level Reflection using TRL.

This trains the model to generate reflections given conversations,
which serves as initialization for GRPO training.
"""

import os
import json
import argparse
from pathlib import Path

os.environ["WANDB_PROJECT"] = "collaborative-agent-reflection-sft"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import Dataset


def load_sft_data(data_path: str):
    """Load SFT training data from a JSON file into a ``datasets.Dataset``.

    Args:
        data_path: Path to a JSON file containing a list of items, each of
            which must have a ``"messages"`` key (chat-format list of
            role/content dicts).

    Returns:
        A ``datasets.Dataset`` with a single ``"messages"`` column, the
        conversational format expected by ``SFTTrainer``.

    Raises:
        KeyError: If any item lacks a ``"messages"`` key.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding avoids platform-dependent default-codec surprises.
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)

    # Keep only the "messages" field; any other keys in each item are dropped.
    return Dataset.from_list([{"messages": item["messages"]} for item in data])


def main() -> None:
    """CLI entry point: parse args, load model and data, run SFT, save the result.

    Trains a causal LM to generate session-level reflections from
    conversations (see module docstring); the resulting checkpoint is used
    to initialize subsequent GRPO training.
    """
    parser = argparse.ArgumentParser(description="Train SFT for session-level reflection")
    # NOTE(review): default model path is a cluster-specific absolute path;
    # callers on other machines must always pass --model-path explicitly.
    parser.add_argument("--model-path", type=str,
                       default="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct",
                       help="Path to base model")
    parser.add_argument("--data-path", type=str, required=True,
                       help="Path to SFT training data JSON")
    parser.add_argument("--output-dir", type=str,
                       default="collaborativeagents/training/outputs/sft_reflection",
                       help="Output directory for checkpoints")
    parser.add_argument("--num-epochs", type=int, default=4,
                       help="Number of training epochs")
    parser.add_argument("--learning-rate", type=float, default=1e-6,
                       help="Learning rate")
    parser.add_argument("--batch-size", type=int, default=1,
                       help="Per-device batch size")
    parser.add_argument("--gradient-accumulation", type=int, default=64,
                       help="Gradient accumulation steps")
    args = parser.parse_args()

    print(f"Loading model from {args.model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # Llama-style tokenizers ship without a pad token; padding is required
    # for batched training, so reuse EOS as the pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # For distributed training (torchrun), don't use device_map
    # The trainer handles device placement
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # Use PyTorch native scaled dot product attention
    )

    print(f"Loading training data from {args.data_path}...")
    dataset = load_sft_data(args.data_path)
    print(f"Loaded {len(dataset)} examples")

    # Training arguments (from paper Table 4)
    # Using SFTConfig with assistant_only_loss=True (replaces DataCollatorForCompletionOnlyLM)
    # NOTE(review): despite the comment above, `assistant_only_loss=True` is
    # NOT actually set in this config, so as written the loss is computed
    # over the full formatted sequence (prompt + response), not assistant
    # turns only. Confirm whether completion-only loss was intended and, if
    # so, set the flag (and drop formatting_func — assistant_only_loss needs
    # the raw conversational "messages" column, not pre-rendered text).
    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=True,  # matches the bfloat16 model weights loaded above
        logging_steps=10,
        save_steps=100,
        save_total_limit=3,  # keep only the 3 most recent checkpoints on disk
        report_to="wandb",  # project name is set via WANDB_PROJECT at module import
        max_grad_norm=1.0,
        weight_decay=0.01,
        # NOTE(review): newer TRL versions renamed this kwarg to `max_length`;
        # verify against the pinned TRL version before upgrading.
        max_seq_length=4096,
        # Train only on assistant responses (completion-only loss)
        # Note: assistant_only_loss needs the tokenizer to identify assistant turns
        remove_unused_columns=False,
    )

    # Format function for chat template
    # Renders each example's "messages" into a single training string using
    # the model's chat template. add_generation_prompt=False because the
    # assistant reply is already present in the messages.
    def formatting_func(example):
        return tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False
        )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,  # TRL >= 0.12 name for the tokenizer argument
        formatting_func=formatting_func,
    )

    print("Starting SFT training...")
    trainer.train()

    # Save final model
    # Save both model and tokenizer so the checkpoint is self-contained
    # (loadable with from_pretrained without the original base-model path).
    final_path = Path(args.output_dir) / "final"
    trainer.save_model(str(final_path))
    tokenizer.save_pretrained(str(final_path))
    print(f"Saved final model to {final_path}")


# Script entry point: run training only when executed directly (e.g. via torchrun).
if __name__ == "__main__":
    main()