summaryrefslogtreecommitdiff
path: root/pretrain.py
diff options
context:
space:
mode:
Diffstat (limited to 'pretrain.py')
-rw-r--r--pretrain.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/pretrain.py b/pretrain.py
index b939318..245cb5c 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -16,7 +16,6 @@ import coolname
import hydra
import pydantic
from omegaconf import DictConfig
-from wandb.util import make_artifact_name_safe
from adam_atan2 import AdamATan2
from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
@@ -126,7 +125,7 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
model: nn.Module = model_cls(model_cfg)
model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
if "DISABLE_COMPILE" not in os.environ:
- model = torch.compile(model, dynamic=False, fullgraph=True) # type: ignore
+ model = torch.compile(model, dynamic=False) # type: ignore
# Broadcast parameters from rank 0
if world_size > 1: