summaryrefslogtreecommitdiff
path: root/src/training/checkpointing.py
blob: 9ff02dfe7db53dd20cbfe4a71a89918fa8e93116 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Checkpoint save/load for predictor + optimizer + schedule state.

Only saves trainable components (predictor MLP, optimizer, schedule state).
Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace.
"""

from __future__ import annotations

import os
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(
    save_dir: str,
    step: int,
    predictor: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: Any,
    best_eval_nll: float,
    extra: Optional[dict] = None,
) -> str:
    """Save a training checkpoint to ``<save_dir>/checkpoint_step<step>.pt``.

    The file is first written to a temporary sibling path and then moved into
    place with ``os.replace``, so an interrupted save never leaves a truncated
    file under the final checkpoint name.

    Args:
        save_dir: directory to save the checkpoint in (created if missing;
            an empty string means the current working directory).
        step: current global step; also used in the checkpoint file name.
        predictor: the structure predictor. Its full ``state_dict()`` is
            saved. NOTE(review): earlier docs said only MLP params are saved;
            that holds only if the predictor registers no frozen submodules —
            confirm against the predictor class.
        optimizer: optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler; if None, None is stored in its slot.
        best_eval_nll: best eval NLL so far.
        extra: additional top-level entries to store alongside the core state.
            Keys must not collide with the reserved checkpoint keys.

    Returns:
        Path of the saved checkpoint file.

    Raises:
        ValueError: if ``extra`` contains a reserved checkpoint key.
    """
    state = {
        "step": step,
        "predictor_state_dict": predictor.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        "best_eval_nll": best_eval_nll,
    }
    if extra:
        # Silently letting extras overwrite core entries (e.g. the predictor
        # weights or the step counter) would corrupt the checkpoint.
        clashes = [key for key in extra if key in state]
        if clashes:
            raise ValueError(f"extra contains reserved checkpoint keys: {clashes}")
        state.update(extra)

    # os.makedirs("") raises, so only create a directory when one is named.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f"checkpoint_step{step}.pt")

    # Write-then-rename: os.replace is atomic on POSIX/Windows within a
    # filesystem, so readers see either the old file or the complete new one.
    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is
        # what matters, so it is always re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    print(f"Checkpoint saved: {path}")
    return path


def load_checkpoint(
    path: str,
    predictor: nn.Module,
    optimizer: Optional[optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    device: Optional[torch.device] = None,
) -> dict:
    """Load a training checkpoint and restore the given components in place.

    Args:
        path: path to the checkpoint file.
        predictor: structure predictor to load weights into.
        optimizer: optimizer to restore state into (optional — skip for eval).
            Skipped if the checkpoint has no (or a None) optimizer state.
        scheduler: LR scheduler to restore state into (optional). Skipped if
            the checkpoint has no (or a None) scheduler state.
        device: device to map tensors to; defaults to CPU.

    Returns:
        Dict with ``step``, ``best_eval_nll`` (``inf`` if absent), plus every
        extra top-level entry stored in the checkpoint (everything except the
        component state dicts).

    Raises:
        KeyError: if the checkpoint lacks ``predictor_state_dict`` or ``step``.
    """
    map_location = device if device is not None else "cpu"
    # SECURITY NOTE: torch.load unpickles the file — only load checkpoints
    # from trusted sources. NOTE(review): on torch >= 2.6 the default is
    # weights_only=True, which may reject non-primitive `extra` values.
    state = torch.load(path, map_location=map_location)

    predictor.load_state_dict(state["predictor_state_dict"])
    print(f"Predictor state loaded from {path}")

    # Both optional components use the same "present and not None" rule.
    optimizer_state = state.get("optimizer_state_dict")
    if optimizer is not None and optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    scheduler_state = state.get("scheduler_state_dict")
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)

    reserved = {
        "step",
        "predictor_state_dict",
        "optimizer_state_dict",
        "scheduler_state_dict",
        "best_eval_nll",
    }
    info = {
        "step": state["step"],
        "best_eval_nll": state.get("best_eval_nll", float("inf")),
    }
    # Surface caller-supplied extras (the contract documented above).
    info.update({key: value for key, value in state.items() if key not in reserved})
    return info