#!/usr/bin/env python3
"""SFT Training for Session-Level Reflection using TRL.

This trains the model to generate reflections given conversations, which
serves as initialization for GRPO training.
"""

import argparse
import json
import os
from pathlib import Path

# Must be set before the trainer initializes wandb, hence before the
# transformers/trl imports below.
os.environ["WANDB_PROJECT"] = "collaborative-agent-reflection-sft"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from datasets import Dataset


def load_sft_data(data_path: str) -> Dataset:
    """Load SFT training data.

    Args:
        data_path: Path to a JSON file containing a list of items, each
            with a "messages" key (chat-format list of role/content dicts).

    Returns:
        A ``datasets.Dataset`` with a single "messages" column, the format
        expected by ``SFTTrainer`` for conversational data.
    """
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)

    # Keep only the "messages" field; drop any extra metadata keys so the
    # trainer sees a uniform schema.
    processed = [{"messages": item["messages"]} for item in data]
    return Dataset.from_list(processed)


def main():
    """Parse CLI args, then run SFT over reflection conversations."""
    parser = argparse.ArgumentParser(description="Train SFT for session-level reflection")
    parser.add_argument(
        "--model-path",
        type=str,
        default="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct",
        help="Path to base model",
    )
    parser.add_argument("--data-path", type=str, required=True, help="Path to SFT training data JSON")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="collaborativeagents/training/outputs/sft_reflection",
        help="Output directory for checkpoints",
    )
    parser.add_argument("--num-epochs", type=int, default=4, help="Number of training epochs")
    parser.add_argument("--learning-rate", type=float, default=1e-6, help="Learning rate")
    parser.add_argument("--batch-size", type=int, default=1, help="Per-device batch size")
    parser.add_argument("--gradient-accumulation", type=int, default=64, help="Gradient accumulation steps")
    args = parser.parse_args()

    print(f"Loading model from {args.model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # For distributed training (torchrun), don't use device_map;
    # the trainer handles device placement.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # PyTorch native scaled dot product attention
    )

    print(f"Loading training data from {args.data_path}...")
    dataset = load_sft_data(args.data_path)
    print(f"Loaded {len(dataset)} examples")

    # Training arguments (from paper Table 4).
    # NOTE(review): an earlier comment claimed assistant_only_loss=True
    # (replacing DataCollatorForCompletionOnlyLM), but the flag was never
    # actually passed — as written, loss is computed over ALL tokens, not
    # just assistant turns. On TRL versions that support it, add
    # assistant_only_loss=True here (and rename max_seq_length -> max_length
    # on TRL >= 0.18); verify against the installed TRL version. TODO confirm.
    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=10,
        save_steps=100,
        save_total_limit=3,
        report_to="wandb",
        max_grad_norm=1.0,
        weight_decay=0.01,
        max_seq_length=4096,
        remove_unused_columns=False,
    )

    def formatting_func(example):
        """Render one example's chat messages into a single training string."""
        # NOTE(review): some TRL versions expect formatting_func to return a
        # list of strings for a batched example rather than a single string —
        # verify against the installed TRL version.
        return tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        formatting_func=formatting_func,
    )

    print("Starting SFT training...")
    trainer.train()

    # Save final model weights and tokenizer side by side for later loading.
    final_path = Path(args.output_dir) / "final"
    trainer.save_model(str(final_path))
    tokenizer.save_pretrained(str(final_path))
    print(f"Saved final model to {final_path}")


if __name__ == "__main__":
    main()