summaryrefslogtreecommitdiff
path: root/models/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/common.py')
-rw-r--r--models/common.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/models/common.py b/models/common.py
new file mode 100644
index 0000000..1a04505
--- /dev/null
+++ b/models/common.py
@@ -0,0 +1,32 @@
+import math
+
+import torch
+from torch import nn
+
+
+def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
+ # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
+ # This function is a PyTorch version of jax truncated normal init (default init method in flax)
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
+
+ with torch.no_grad():
+ if std == 0:
+ tensor.zero_()
+ else:
+ sqrt2 = math.sqrt(2)
+ a = math.erf(lower / sqrt2)
+ b = math.erf(upper / sqrt2)
+ z = (b - a) / 2
+
+ c = (2 * math.pi) ** -0.5
+ pdf_u = c * math.exp(-0.5 * lower ** 2)
+ pdf_l = c * math.exp(-0.5 * upper ** 2)
+ comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
+
+ tensor.uniform_(a, b)
+ tensor.erfinv_()
+ tensor.mul_(sqrt2 * comp_std)
+ tensor.clip_(lower * comp_std, upper * comp_std)
+
+ return tensor