summaryrefslogtreecommitdiff
path: root/scripts/train.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:00:39 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:00:39 -0600
commit13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch)
tree073534138604c1c49021ca7e334322262129f6ac /scripts/train.py
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/train.py')
-rw-r--r--scripts/train.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000..63fb8a6
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,45 @@
+"""Entry point for DAGFormer training.
+
+Usage:
+ # Single GPU:
+ python scripts/train.py --config configs/sanity_check.yaml
+
+ # Multi-GPU (DDP):
+ torchrun --nproc_per_node=4 scripts/train.py --config configs/phase1_full.yaml
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+
+from src.training.trainer import TrainConfig, Trainer
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Train DAGFormer")
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config file")
+ args = parser.parse_args()
+
+ config = TrainConfig.from_yaml(args.config)
+
+ # DDP setup
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+ if world_size > 1:
+ dist.init_process_group(backend="nccl")
+ torch.cuda.set_device(local_rank)
+
+ trainer = Trainer(config, local_rank=local_rank, world_size=world_size)
+ trainer.train()
+
+ if world_size > 1:
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ main()