summaryrefslogtreecommitdiff
path: root/pretrain.py
diff options
context:
space:
mode:
Diffstat (limited to 'pretrain.py')
-rw-r--r--pretrain.py454
1 files changed, 454 insertions, 0 deletions
diff --git a/pretrain.py b/pretrain.py
new file mode 100644
index 0000000..b939318
--- /dev/null
+++ b/pretrain.py
@@ -0,0 +1,454 @@
+from typing import Optional, Any, Sequence, List
+from dataclasses import dataclass
+import os
+import math
+import yaml
+import shutil
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.utils.data import DataLoader
+
+import tqdm
+import wandb
+import coolname
+import hydra
+import pydantic
+from omegaconf import DictConfig
+from wandb.util import make_artifact_name_safe
+from adam_atan2 import AdamATan2
+
+from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
+from utils.functions import load_model_class, get_model_source_path
+from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
+
+
class LossConfig(pydantic.BaseModel):
    """Selects the loss head that wraps the model.

    Unknown keys are kept (``extra="allow"``); ``create_model`` forwards them
    as keyword arguments to the loss-head constructor.
    """

    # Keep undeclared keys instead of rejecting them.
    model_config = pydantic.ConfigDict(extra="allow")

    # Identifier resolved through ``load_model_class``.
    name: str
+
+
class ArchConfig(pydantic.BaseModel):
    """Selects the model architecture and its loss head.

    Unknown keys are kept (``extra="allow"``); ``create_model`` merges them
    into the configuration dict handed to the model class.
    """

    # Keep undeclared keys instead of rejecting them.
    model_config = pydantic.ConfigDict(extra="allow")

    # Identifier resolved through ``load_model_class``.
    name: str
    # Loss head wrapped around the instantiated model.
    loss: LossConfig
+
+
class PretrainConfig(pydantic.BaseModel):
    """Validated settings for one pretraining run.

    Built on rank 0 from the hydra config (see ``load_synced_config``) and
    then shared with every other rank.
    """

    # --- Model ---------------------------------------------------------
    arch: ArchConfig

    # --- Data ----------------------------------------------------------
    # Dataset location, passed to PuzzleDatasetConfig as ``dataset_path``.
    data_path: str

    # --- Optimisation --------------------------------------------------
    # Batch size summed over all ranks; each rank gets global_batch_size // world_size.
    global_batch_size: int
    epochs: int

    # Base learning rate of the AdamATan2 optimizer and its schedule settings.
    lr: float
    lr_min_ratio: float  # passed as ``min_ratio`` to the cosine schedule
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # --- Puzzle embedding (sparse optimizer) -----------------------------
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # --- Naming; filled with generated defaults on rank 0 when left unset ---
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # --- Extras --------------------------------------------------------
    seed: int = 0
    # Save a checkpoint after every evaluation rather than only the last one.
    checkpoint_every_eval: bool = False
    # Epochs trained between evaluations; ``launch`` requires it to divide ``epochs``.
    eval_interval: Optional[int] = None
    # Keys of batch/prediction tensors that ``evaluate`` writes to disk.
    eval_save_outputs: List[str] = []
+
+
@dataclass
class TrainState:
    """Mutable bundle of everything the training loop updates."""

    # Model wrapped by its loss head (and possibly compiled).
    model: nn.Module
    # Optimizers and their base learning rates, index-aligned.
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]
    # Recurrent state passed from one training batch to the next; None until the first batch.
    carry: Any

    # Batches consumed so far (incremented at the start of ``train_batch``).
    step: int
    # Estimated step budget; batches beyond it are skipped.
    total_steps: int
+
+
def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
    """Build the dataset for ``split`` and a DataLoader over it.

    Args:
        config: Run configuration; supplies the seed and dataset path.
        split: Dataset split name handed to ``PuzzleDataset``.
        rank: Rank of this process, forwarded to the dataset config.
        world_size: Number of processes, forwarded as ``num_replicas``.
        **kwargs: Additional ``PuzzleDatasetConfig`` fields.

    Returns:
        A ``(dataloader, metadata)`` tuple, where ``metadata`` is
        ``dataset.metadata``.
    """
    dataset_config = PuzzleDatasetConfig(
        seed=config.seed,
        dataset_path=config.data_path,
        rank=rank,
        num_replicas=world_size,
        **kwargs
    )
    dataset = PuzzleDataset(dataset_config, split=split)

    loader_options = dict(
        # batch_size=None turns off automatic batching: items are yielded
        # exactly as the dataset produces them.
        batch_size=None,
        num_workers=1,
        prefetch_factor=8,
        pin_memory=True,
        persistent_workers=True,
    )
    return DataLoader(dataset, **loader_options), dataset.metadata
+
+
def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    """Instantiate the model with its loss head, plus the optimizers.

    Args:
        config: Run configuration.
        train_metadata: Training-set metadata (vocabulary size, sequence
            length, number of puzzle identifiers).
        world_size: Number of participating processes.

    Returns:
        ``(model, optimizers, optimizer_lrs)``; the two lists are
        index-aligned (puzzle-embedding optimizer first, then AdamATan2).
    """
    # Architecture settings: user-supplied extras plus values derived from the run and dataset.
    arch_kwargs = dict(
        **config.arch.__pydantic_extra__,  # type: ignore
        batch_size=config.global_batch_size // world_size,  # per-rank share
        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False,  # Non-autoregressive
    )

    backbone_cls = load_model_class(config.arch.name)
    head_cls = load_model_class(config.arch.loss.name)

    # Everything is constructed with CUDA as the default device.
    with torch.device("cuda"):
        model: nn.Module = backbone_cls(arch_kwargs)
        model = head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model, dynamic=False, fullgraph=True)  # type: ignore

        # Overwrite every rank's parameters and buffers with rank 0's copy.
        if world_size > 1:
            with torch.no_grad():
                for tensor in [*model.parameters(), *model.buffers()]:
                    dist.broadcast(tensor, src=0)

    # Learning rates start at 0; train_batch sets them from the schedule on every step.
    puzzle_emb_optimizer = CastedSparseEmbeddingSignSGD_Distributed(
        model.model.puzzle_emb.buffers(),  # type: ignore
        lr=0,
        weight_decay=config.puzzle_emb_weight_decay,
        world_size=world_size,
    )
    dense_optimizer = AdamATan2(
        model.parameters(),
        lr=0,
        weight_decay=config.weight_decay,
        betas=(config.beta1, config.beta2),
    )

    optimizers = [puzzle_emb_optimizer, dense_optimizer]
    optimizer_lrs = [config.puzzle_emb_lr, config.lr]
    return model, optimizers, optimizer_lrs
+
+
def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    """Return the learning rate for ``current_step``: linear warmup, then cosine decay.

    Args:
        current_step: Step whose learning rate is requested.
        base_lr: Learning rate reached at the end of warmup.
        num_warmup_steps: Steps over which the rate ramps linearly from 0.
        num_training_steps: Step at which the decay phase ends.
        min_ratio: The cosine part is scaled into ``[min_ratio, 1]``.
        num_cycles: Number of cosine periods over the decay phase; the
            default 0.5 is a single half-wave from ``base_lr`` down to
            ``base_lr * min_ratio``.

    Returns:
        The learning rate as a float.
    """
    in_warmup = current_step < num_warmup_steps
    if in_warmup:
        warmup_span = float(max(1, num_warmup_steps))
        return base_lr * float(current_step) / warmup_span

    # Fraction of the decay phase completed so far.
    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / decay_span

    cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
    decayed = max(0.0, (1 - min_ratio) * 0.5 * (1.0 + cosine))
    return base_lr * (min_ratio + decayed)
+
+
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    """Create the initial ``TrainState`` for a fresh run.

    The step budget is an estimate: epochs x groups x mean examples per
    puzzle, divided by the global batch size and truncated to an integer.
    """
    estimated_examples = config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples
    total_steps = int(estimated_examples / config.global_batch_size)

    model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)

    return TrainState(
        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None,  # created lazily on the first training batch
        step=0,
        total_steps=total_steps,
    )
+
+
def save_train_state(config: PretrainConfig, train_state: TrainState):
    """Write the model's ``state_dict`` to ``<checkpoint_path>/step_<step>``.

    Does nothing when ``config.checkpoint_path`` is None.
    """
    # FIXME: Only saved model.
    checkpoint_dir = config.checkpoint_path
    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)
        target_file = os.path.join(checkpoint_dir, f"step_{train_state.step}")
        torch.save(train_state.model.state_dict(), target_file)
+
+
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    """Evaluate the warmup + cosine schedule for ``base_lr`` at the current step."""
    schedule_settings = dict(
        base_lr=base_lr,
        num_warmup_steps=round(config.lr_warmup_steps),
        num_training_steps=train_state.total_steps,
        min_ratio=config.lr_min_ratio,
    )
    return cosine_schedule_with_warmup_lr_lambda(train_state.step, **schedule_settings)
+
+
def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    """Run one optimisation step and return the metrics to log.

    Args:
        config: Run configuration (learning-rate schedule settings).
        train_state: Mutable training state; ``step`` and ``carry`` are updated.
        batch: Mapping of tensors for this rank.
        global_batch_size: Batch size across all ranks, used to scale the loss.
        rank: Rank of this process.
        world_size: Number of participating processes.

    Returns:
        On rank 0, a dict of ``train/``-prefixed metrics including
        ``train/lr``. ``None`` on other ranks, when the step budget is
        exhausted, or when the model reports no metrics.
    """
    train_state.step += 1
    if train_state.step > train_state.total_steps:  # At most train_total_steps
        return

    model = train_state.model
    is_distributed = world_size > 1

    batch = {name: tensor.cuda() for name, tensor in batch.items()}

    # The carry is created once and then threaded through successive batches.
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = model.initial_carry(batch)  # type: ignore

    train_state.carry, loss, metrics, _, _ = model(carry=train_state.carry, batch=batch, return_keys=[])

    # Scale by the global batch size before backpropagating.
    ((1 / global_batch_size) * loss).backward()

    # all_reduce (default op: sum) combines each gradient across ranks.
    if is_distributed:
        for grad in (param.grad for param in model.parameters()):
            if grad is not None:
                dist.all_reduce(grad)

    # Every optimizer follows the same schedule shape, scaled by its own base rate.
    lr_this_step = None
    for optimizer, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)
        for group in optimizer.param_groups:
            group['lr'] = lr_this_step

        optimizer.step()
        optimizer.zero_grad()

    if not len(metrics):
        return None

    assert not any(v.requires_grad for v in metrics.values())

    # Sorted keys so that every process stacks the metrics in the same order.
    ordered_keys = sorted(metrics.keys())
    stacked = torch.stack([metrics[key] for key in ordered_keys])
    if is_distributed:
        dist.reduce(stacked, dst=0)

    if rank != 0:
        return None

    totals = dict(zip(ordered_keys, stacked.cpu().numpy()))

    # Losses are normalised by the global batch size, everything else by "count".
    count = max(totals["count"], 1)  # Avoid NaNs
    reduced_metrics = {}
    for key, value in totals.items():
        denominator = global_batch_size if key.endswith("loss") else count
        reduced_metrics[f"train/{key}"] = value / denominator

    # Rate of the last optimizer in the list.
    reduced_metrics["train/lr"] = lr_this_step
    return reduced_metrics
+
+
def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    """Run the model over the evaluation loader and aggregate metrics per set.

    Args:
        config: Run configuration; ``eval_save_outputs`` selects the tensors
            to store and ``checkpoint_path`` is where they are written.
        train_state: Provides the model and the current step (used in the
            output file name).
        eval_loader: Yields ``(set_name, batch, global_batch_size)`` tuples.
        eval_metadata: Provides ``sets``, the names of the evaluation sets.
        rank: Rank of this process.
        world_size: Number of participating processes.

    Returns:
        On rank 0, ``{set_name: {metric_name: value}}`` with every metric
        divided by that set's ``count``. ``None`` on other ranks or when the
        loader produced no batches.
    """
    with torch.inference_mode():
        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}

        all_preds = {}

        metric_keys = []
        metric_values = None

        for set_name, batch, _global_batch_size in eval_loader:
            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Step the model until it reports that every sequence has finished.
            while True:
                carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)

                if all_finish:
                    break

            # Collect the requested tensors from both the inputs and the predictions.
            for collection in (batch, preds):
                for k, v in collection.items():
                    if k in config.eval_save_outputs:
                        all_preds.setdefault(k, []).append(v.cpu())  # Move to CPU for saving GPU memory

            del carry, preds, batch, all_finish

            # Aggregate
            set_id = set_ids[set_name]

            if metric_values is None:
                metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.
                metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")

            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])

        # Each rank writes its own shard of the saved outputs.
        if len(all_preds) and config.checkpoint_path is not None:
            all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}

            os.makedirs(config.checkpoint_path, exist_ok=True)
            torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))

        # Logging
        # Reduce to rank 0
        if metric_values is not None:
            if world_size > 1:
                dist.reduce(metric_values, dst=0)

            if rank == 0:
                metric_table = metric_values.cpu().numpy()
                reduced_metrics = {
                    set_name: {metric_name: metric_table[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
                    for set_id, set_name in enumerate(set_ids)
                }

                # Postprocess
                for set_name, set_metrics in reduced_metrics.items():
                    # A set with no counted examples would otherwise divide by
                    # zero; same guard as in train_batch.
                    count = max(set_metrics.pop("count"), 1)  # Avoid NaNs
                    reduced_metrics[set_name] = {k: v / count for k, v in set_metrics.items()}

                return reduced_metrics

    return None
+
+
def save_code_and_config(config: PretrainConfig):
    """Copy the model source files and the resolved config into the checkpoint directory.

    Does nothing when no checkpoint path is set or no wandb run is active.
    The directory is then handed to ``wandb.run.log_code``.
    """
    if config.checkpoint_path is None or wandb.run is None:
        return

    output_dir = config.checkpoint_path
    os.makedirs(output_dir, exist_ok=True)

    # Source files of the architecture and of the loss head; either may be None.
    source_paths = [
        get_model_source_path(config.arch.name),
        get_model_source_path(config.arch.loss.name),
    ]
    for source_path in source_paths:
        if source_path is None:
            continue
        destination = os.path.join(output_dir, os.path.basename(source_path))
        shutil.copy(source_path, destination)

    # Dump config as yaml
    with open(os.path.join(output_dir, "all_config.yaml"), "wt") as f:
        yaml.dump(config.model_dump(), f)

    # Log code
    wandb.run.log_code(output_dir)
+
+
def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
    """Build the run configuration on rank 0 and share it with every rank.

    Rank 0 validates the hydra config and fills in the project name, run
    name and checkpoint path when they are unset. The run name includes a
    generated slug, so the finished object is broadcast to keep all ranks
    on the same names.

    Args:
        hydra_config: Raw configuration supplied by hydra.
        rank: Rank of this process.
        world_size: Number of participating processes.

    Returns:
        The ``PretrainConfig`` created on rank 0.
    """
    objects = [None]
    if rank == 0:
        config = PretrainConfig(**hydra_config)  # type: ignore

        # Naming
        if config.project_name is None:
            # normpath strips trailing separators, so "data/foo/" still
            # yields "foo" rather than an empty dataset name.
            dataset_name = os.path.basename(os.path.normpath(config.data_path))
            config.project_name = f"{dataset_name.capitalize()} ACT-torch"
        if config.run_name is None:
            config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
        if config.checkpoint_path is None:
            config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

        objects = [config]

    # Non-zero ranks receive rank 0's object in place of their None placeholder.
    if world_size > 1:
        dist.broadcast_object_list(objects, src=0)

    return objects[0]  # type: ignore
+
+
@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    """Entry point: set up (optionally distributed) training, then alternate training and evaluation.

    Each outer iteration trains on everything ``train_loader`` yields,
    evaluates, and optionally checkpoints. Rank 0 owns the progress bar,
    wandb logging and checkpoint writing.
    """
    rank, world_size = 0, 1

    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        dist.init_process_group(backend="nccl")

        rank = dist.get_rank()
        world_size = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    is_main = rank == 0

    # Load sync'ed config
    config = load_synced_config(hydra_config, rank=rank, world_size=world_size)

    # Per-rank seed offset.
    torch.random.manual_seed(config.seed + rank)

    # Without an eval interval, all epochs are trained in a single iteration.
    if config.eval_interval is not None:
        train_epochs_per_iter = config.eval_interval
    else:
        train_epochs_per_iter = config.epochs
    total_iters = config.epochs // train_epochs_per_iter

    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    shared_loader_args = dict(global_batch_size=config.global_batch_size, rank=rank, world_size=world_size)
    train_loader, train_metadata = create_dataloader(
        config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, **shared_loader_args
    )
    eval_loader, eval_metadata = create_dataloader(
        config, "test", test_set_mode=True, epochs_per_iter=1, **shared_loader_args
    )

    # Train state
    train_state = init_train_state(config, train_metadata, world_size=world_size)

    # Progress bar and logger
    progress_bar = None
    if is_main:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)

        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        parameter_count = sum(x.numel() for x in train_state.model.parameters())
        wandb.log({"num_params": parameter_count}, step=0)
        save_code_and_config(config)

    # Training Loop
    final_iter = total_iters - 1
    for _iter_id in range(total_iters):
        print (f"[Rank {rank}, World Size {world_size}]: Epoch {_iter_id * train_epochs_per_iter}")

        ############ Train Iter
        train_state.model.train()
        for _set_name, batch, global_batch_size in train_loader:
            train_metrics = train_batch(config, train_state, batch, global_batch_size, rank=rank, world_size=world_size)

            if is_main and train_metrics is not None:
                wandb.log(train_metrics, step=train_state.step)
                # Advance the bar to the current step.
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore

        ############ Evaluation
        train_state.model.eval()
        eval_metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=rank, world_size=world_size)

        if is_main and eval_metrics is not None:
            wandb.log(eval_metrics, step=train_state.step)

        ############ Checkpointing
        should_checkpoint = config.checkpoint_every_eval or (_iter_id == final_iter)
        if is_main and should_checkpoint:
            save_train_state(config, train_state)

    # finalize
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()
+
+
# Script entry point. ``launch`` is wrapped by ``@hydra.main``, which is why it
# is called here without arguments.
if __name__ == "__main__":
    launch()