From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 9 Feb 2026 11:00:39 -0600
Subject: Initial implementation: DAGFormer Phase 1

- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A
- Proportional attribution for post-norm decomposition
- All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6
---
 scripts/train.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 scripts/train.py

diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000..63fb8a6
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,45 @@
+"""Entry point for DAGFormer training.
+
+Usage:
+    # Single GPU:
+    python scripts/train.py --config configs/sanity_check.yaml
+
+    # Multi-GPU (DDP):
+    torchrun --nproc_per_node=4 scripts/train.py --config configs/phase1_full.yaml
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+
+from src.training.trainer import TrainConfig, Trainer
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train DAGFormer")
+    parser.add_argument("--config", type=str, required=True, help="Path to YAML config file")
+    args = parser.parse_args()
+
+    config = TrainConfig.from_yaml(args.config)
+
+    # DDP setup
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+    if world_size > 1:
+        dist.init_process_group(backend="nccl")
+        torch.cuda.set_device(local_rank)
+
+    trainer = Trainer(config, local_rank=local_rank, world_size=world_size)
+    trainer.train()
+
+    if world_size > 1:
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
--
cgit v1.2.3
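
Note on the imported interface: the entry point above assumes that src/training/trainer.py (added elsewhere in this commit) exposes a TrainConfig with a from_yaml constructor and a Trainer taking (config, local_rank, world_size). The sketch below is only an illustration of that assumed contract; the field names, defaults, and YAML schema are hypothetical and not taken from the patch.

# Hypothetical sketch of the TrainConfig/Trainer contract that scripts/train.py
# relies on; the actual trainer.py in this commit may differ in fields and behavior.
from dataclasses import dataclass, fields

import yaml  # PyYAML


@dataclass
class TrainConfig:
    # Illustrative fields only; the real config schema is defined in trainer.py.
    lr: float = 1e-4
    max_steps: int = 1000
    grad_accum_steps: int = 1

    @classmethod
    def from_yaml(cls, path: str) -> "TrainConfig":
        # Load a YAML file and keep only keys that match the dataclass fields.
        with open(path) as f:
            raw = yaml.safe_load(f) or {}
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})


class Trainer:
    def __init__(self, config: TrainConfig, local_rank: int = 0, world_size: int = 1):
        self.config = config
        self.local_rank = local_rank
        self.world_size = world_size

    def train(self) -> None:
        # Stand-in for the DDP training loop described in the commit message
        # (gradient accumulation, eval, checkpointing).
        raise NotImplementedError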