summaryrefslogtreecommitdiff
path: root/scripts/train.py
blob: 63fb8a6b6976f369556fec5dc3ade3642bb421c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Entry point for DAGFormer training.

Usage:
    # Single GPU:
    python scripts/train.py --config configs/sanity_check.yaml

    # Multi-GPU (DDP):
    torchrun --nproc_per_node=4 scripts/train.py --config configs/phase1_full.yaml
"""

from __future__ import annotations

import argparse
import os

import torch
import torch.distributed as dist

from src.training.trainer import TrainConfig, Trainer


def main() -> None:
    """Parse CLI args, set up (optional) DDP, and run training.

    Reads LOCAL_RANK / WORLD_SIZE from the environment (exported by
    torchrun); falls back to single-process defaults when absent.
    """
    parser = argparse.ArgumentParser(description="Train DAGFormer")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config file")
    args = parser.parse_args()

    config = TrainConfig.from_yaml(args.config)

    # DDP setup: plain `python scripts/train.py` runs with rank 0 / world 1.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    distributed = world_size > 1
    if distributed:
        # NCCL backend for multi-GPU training; pin this process to its GPU
        # before any CUDA work so allocations land on the right device.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    try:
        trainer = Trainer(config, local_rank=local_rank, world_size=world_size)
        trainer.train()
    finally:
        # Tear down the process group even if training raises, so other
        # ranks are not left hanging on collectives at exit.
        if distributed:
            dist.destroy_process_group()


# Script entry-point guard: runs whether launched directly
# (python scripts/train.py) or per-rank via torchrun.
if __name__ == "__main__":
    main()