Diffstat (limited to 'collaborativeagents/training/train_sft.py')
| -rw-r--r-- | collaborativeagents/training/train_sft.py | 124 |
1 file changed, 124 insertions, 0 deletions
diff --git a/collaborativeagents/training/train_sft.py b/collaborativeagents/training/train_sft.py
new file mode 100644
index 0000000..cff146a
--- /dev/null
+++ b/collaborativeagents/training/train_sft.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+"""
+SFT Training for Session-Level Reflection using TRL.
+
+This trains the model to generate reflections given conversations,
+which serves as initialization for GRPO training.
+"""
+
+import os
+import json
+import argparse
+from pathlib import Path
+
+os.environ["WANDB_PROJECT"] = "collaborative-agent-reflection-sft"
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import SFTTrainer, SFTConfig
+from datasets import Dataset
+
+
+def load_sft_data(data_path: str):
+    """Load SFT training data."""
+    with open(data_path) as f:
+        data = json.load(f)
+
+    # Convert to the conversational format expected by SFTTrainer
+    processed = []
+    for item in data:
+        messages = item["messages"]
+        processed.append({"messages": messages})
+
+    return Dataset.from_list(processed)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train SFT for session-level reflection")
+    parser.add_argument("--model-path", type=str,
+                        default="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct",
+                        help="Path to base model")
+    parser.add_argument("--data-path", type=str, required=True,
+                        help="Path to SFT training data JSON")
+    parser.add_argument("--output-dir", type=str,
+                        default="collaborativeagents/training/outputs/sft_reflection",
+                        help="Output directory for checkpoints")
+    parser.add_argument("--num-epochs", type=int, default=4,
+                        help="Number of training epochs")
+    parser.add_argument("--learning-rate", type=float, default=1e-6,
+                        help="Learning rate")
+    parser.add_argument("--batch-size", type=int, default=1,
+                        help="Per-device batch size")
+    parser.add_argument("--gradient-accumulation", type=int, default=64,
+                        help="Gradient accumulation steps")
+    args = parser.parse_args()
+
+    print(f"Loading model from {args.model_path}...")
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # For distributed training (torchrun), don't use device_map;
+    # the trainer handles device placement
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_path,
+        torch_dtype=torch.bfloat16,
+        attn_implementation="sdpa",  # PyTorch native scaled dot product attention
+    )
+
+    print(f"Loading training data from {args.data_path}...")
+    dataset = load_sft_data(args.data_path)
+    print(f"Loaded {len(dataset)} examples")
+
+    # Training arguments (from paper Table 4)
+    # Completion-only loss would require assistant_only_loss=True (TRL's replacement for DataCollatorForCompletionOnlyLM)
+    training_args = SFTConfig(
+        output_dir=args.output_dir,
+        num_train_epochs=args.num_epochs,
+        per_device_train_batch_size=args.batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation,
+        learning_rate=args.learning_rate,
+        lr_scheduler_type="cosine",
+        warmup_ratio=0.1,
+        bf16=True,
+        logging_steps=10,
+        save_steps=100,
+        save_total_limit=3,
+        report_to="wandb",
+        max_grad_norm=1.0,
+        weight_decay=0.01,
+        max_seq_length=4096,
+        # Note: assistant_only_loss is not enabled here, so with the formatting_func
+        # below the loss is computed over the full formatted sequence
+        remove_unused_columns=False,
+    )
+
+    # Format each example into a single string via the chat template
+    def formatting_func(example):
+        return tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+            add_generation_prompt=False
+        )
+
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        processing_class=tokenizer,
+        formatting_func=formatting_func,
+    )
+
+    print("Starting SFT training...")
+    trainer.train()
+
+    # Save final model and tokenizer
+    final_path = Path(args.output_dir) / "final"
+    trainer.save_model(str(final_path))
+    tokenizer.save_pretrained(str(final_path))
+    print(f"Saved final model to {final_path}")
+
+
+if __name__ == "__main__":
+    main()
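
A minimal launch sketch, not part of the commit: the device_map comment indicates the script is meant to be launched with torchrun, so a run could look like the line below, where the GPU count and the data path are placeholders and the other flags come from the script's own argparse definitions.

torchrun --nproc_per_node=4 collaborativeagents/training/train_sft.py \
    --data-path /path/to/sft_reflection_data.json \
    --output-dir collaborativeagents/training/outputs/sft_reflection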
