summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/train_sft.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/training/train_sft.py
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/training/train_sft.py')
-rw-r--r--collaborativeagents/training/train_sft.py124
1 files changed, 124 insertions, 0 deletions
diff --git a/collaborativeagents/training/train_sft.py b/collaborativeagents/training/train_sft.py
new file mode 100644
index 0000000..cff146a
--- /dev/null
+++ b/collaborativeagents/training/train_sft.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+"""
+SFT Training for Session-Level Reflection using TRL.
+
+This trains the model to generate reflections given conversations,
+which serves as initialization for GRPO training.
+"""
+
+import os
+import json
+import argparse
+from pathlib import Path
+
+os.environ["WANDB_PROJECT"] = "collaborative-agent-reflection-sft"
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import SFTTrainer, SFTConfig
+from datasets import Dataset
+
+
+def load_sft_data(data_path: str):
+ """Load SFT training data."""
+ with open(data_path) as f:
+ data = json.load(f)
+
+ # Convert to format expected by SFTTrainer
+ processed = []
+ for item in data:
+ messages = item["messages"]
+ processed.append({"messages": messages})
+
+ return Dataset.from_list(processed)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Train SFT for session-level reflection")
+ parser.add_argument("--model-path", type=str,
+ default="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct",
+ help="Path to base model")
+ parser.add_argument("--data-path", type=str, required=True,
+ help="Path to SFT training data JSON")
+ parser.add_argument("--output-dir", type=str,
+ default="collaborativeagents/training/outputs/sft_reflection",
+ help="Output directory for checkpoints")
+ parser.add_argument("--num-epochs", type=int, default=4,
+ help="Number of training epochs")
+ parser.add_argument("--learning-rate", type=float, default=1e-6,
+ help="Learning rate")
+ parser.add_argument("--batch-size", type=int, default=1,
+ help="Per-device batch size")
+ parser.add_argument("--gradient-accumulation", type=int, default=64,
+ help="Gradient accumulation steps")
+ args = parser.parse_args()
+
+ print(f"Loading model from {args.model_path}...")
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # For distributed training (torchrun), don't use device_map
+ # The trainer handles device placement
+ model = AutoModelForCausalLM.from_pretrained(
+ args.model_path,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="sdpa", # Use PyTorch native scaled dot product attention
+ )
+
+ print(f"Loading training data from {args.data_path}...")
+ dataset = load_sft_data(args.data_path)
+ print(f"Loaded {len(dataset)} examples")
+
+ # Training arguments (from paper Table 4)
+ # Using SFTConfig with assistant_only_loss=True (replaces DataCollatorForCompletionOnlyLM)
+ training_args = SFTConfig(
+ output_dir=args.output_dir,
+ num_train_epochs=args.num_epochs,
+ per_device_train_batch_size=args.batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation,
+ learning_rate=args.learning_rate,
+ lr_scheduler_type="cosine",
+ warmup_ratio=0.1,
+ bf16=True,
+ logging_steps=10,
+ save_steps=100,
+ save_total_limit=3,
+ report_to="wandb",
+ max_grad_norm=1.0,
+ weight_decay=0.01,
+ max_seq_length=4096,
+ # Train only on assistant responses (completion-only loss)
+ # Note: assistant_only_loss needs the tokenizer to identify assistant turns
+ remove_unused_columns=False,
+ )
+
+ # Format function for chat template
+ def formatting_func(example):
+ return tokenizer.apply_chat_template(
+ example["messages"],
+ tokenize=False,
+ add_generation_prompt=False
+ )
+
+ trainer = SFTTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ formatting_func=formatting_func,
+ )
+
+ print("Starting SFT training...")
+ trainer.train()
+
+ # Save final model
+ final_path = Path(args.output_dir) / "final"
+ trainer.save_model(str(final_path))
+ tokenizer.save_pretrained(str(final_path))
+ print(f"Saved final model to {final_path}")
+
+
+if __name__ == "__main__":
+ main()