summaryrefslogtreecommitdiff
path: root/evaluate.py
diff options
context:
space:
mode:
authorOne <imone@tuta.io>2025-07-09 10:13:51 +0800
committerOne <imone@tuta.io>2025-07-09 10:13:51 +0800
commitbd6222774edcec1608a6842d0b06a637a4acef59 (patch)
tree3b95517044286d82a9166bcce3134bbea099fcfe /evaluate.py
parentcaa00bb77fbfce0cc14a45d5ec1c754394f96c1a (diff)
Release
Diffstat (limited to 'evaluate.py')
-rw-r--r--evaluate.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..9bc6ba0
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,68 @@
+from typing import List
+import yaml
+import os
+
+import torch
+import torch.distributed as dist
+
+import pydantic
+from omegaconf import OmegaConf
+from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
+
+
+class EvalConfig(pydantic.BaseModel):
+ checkpoint: str
+
+ save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]
+
+
+def launch():
+ eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore
+
+ RANK = 0
+ WORLD_SIZE = 1
+ # Initialize distributed training if in distributed environment (e.g. torchrun)
+ if "LOCAL_RANK" in os.environ:
+ # Initialize distributed, default device and dtype
+ dist.init_process_group(backend="nccl")
+
+ RANK = dist.get_rank()
+ WORLD_SIZE = dist.get_world_size()
+
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+ with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
+ config = PretrainConfig(**yaml.safe_load(f))
+
+ config.eval_save_outputs = eval_cfg.save_outputs
+ config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)
+
+ # Dataloader
+ train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+ eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, test_set_limit_examples=LIMIT_EXAMPLES, rank=RANK, world_size=WORLD_SIZE)
+
+ # Models
+ train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
+ # Try unwrap torch.compile
+ try:
+ train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
+ except:
+ train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
+
+ train_state.step = 0
+ ckpt_filename = os.path.basename(eval_cfg.checkpoint)
+ if ckpt_filename.startswith("step_"):
+ train_state.step = int(ckpt_filename.removeprefix("step_"))
+
+ # Evaluate
+ print ("Starting evaluation")
+
+ train_state.model.eval()
+ metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
+
+ if metrics is not None:
+ print (metrics)
+
+
+if __name__ == "__main__":
+ launch()