# models/common.py
# Provenance: commit bd6222774edcec1608a6842d0b06a637a4acef59 ("Release", 2025-07-09).

import math

import torch
from torch import nn


def trunc_normal_init_(
    tensor: torch.Tensor,
    std: float = 1.0,
    lower: float = -2.0,
    upper: float = 2.0,
) -> torch.Tensor:
    """Fill ``tensor`` in place with truncated-normal samples of std ``std``.

    Samples are drawn from a standard normal truncated to ``[lower, upper]``
    and then rescaled so that the standard deviation of the *result* equals
    ``std``. The final values therefore lie in
    ``[lower * comp_std, upper * comp_std]``, where ``comp_std`` is the
    compensated scale computed below. No re-centering is applied, so with
    asymmetric bounds the result has a non-zero mean.

    Args:
        tensor: Tensor to overwrite in place.
        std: Target standard deviation of the initialized values. ``0``
            zero-fills the tensor.
        lower: Lower truncation bound, in units of the untruncated normal.
        upper: Upper truncation bound, in units of the untruncated normal.

    Returns:
        The same ``tensor`` object, to allow call chaining.
    """
    # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
    # This function is a PyTorch version of jax truncated normal init (default init method in flax)
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            # erf of the bounds; Phi(x) = (1 + erf(x / sqrt2)) / 2, so
            # z = Phi(upper) - Phi(lower) is the mass kept by the truncation.
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            # Standard normal density at each bound. Each density must be
            # evaluated at its own bound: swapping them only cancels out
            # when the bounds are symmetric about zero.
            c = (2 * math.pi) ** -0.5
            pdf_l = c * math.exp(-0.5 * lower ** 2)
            pdf_u = c * math.exp(-0.5 * upper ** 2)

            # Variance of the truncated standard normal:
            #   1 - (upper*pdf_u - lower*pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2
            # Dividing by its square root makes the output std equal `std`.
            comp_std = std / math.sqrt(
                1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2
            )

            # Inverse-CDF sampling: uniform in erf-space -> erfinv -> scale.
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            # Guard against round-off pushing values past the scaled bounds.
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor