summaryrefslogtreecommitdiff
path: root/models/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/layers.py')
-rw-r--r--models/layers.py150
1 files changed, 150 insertions, 0 deletions
diff --git a/models/layers.py b/models/layers.py
new file mode 100644
index 0000000..4f7dee4
--- /dev/null
+++ b/models/layers.py
@@ -0,0 +1,150 @@
+from typing import Tuple
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.common import trunc_normal_init_
+
+
+CosSin = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _find_multiple(a, b):
+ return (-(a // -b)) * b
+
+
+def rotate_half(x: torch.Tensor):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+ # q, k: [bs, num_heads, seq_len, head_dim]
+ # cos, sin: [seq_len, head_dim]
+ orig_dtype = q.dtype
+ q = q.to(cos.dtype)
+ k = k.to(cos.dtype)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+
+ return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
+
+
+class CastedLinear(nn.Module):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ bias: bool):
+ super().__init__()
+ # Truncated LeCun normal init
+ self.weight = nn.Parameter(
+ trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
+ )
+ self.bias = None
+ if bias:
+ # Zero init bias
+ self.bias = nn.Parameter(torch.zeros((out_features, )))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
+
+
+class CastedEmbedding(nn.Module):
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ init_std: float,
+ cast_to: torch.dtype):
+ super().__init__()
+ self.cast_to = cast_to
+
+ # Truncated LeCun normal init
+ self.embedding_weight = nn.Parameter(
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.embedding(input, self.embedding_weight.to(self.cast_to))
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings, base, device=None):
+ super().__init__()
+
+ # RoPE
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+ t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
+ freqs = torch.outer(t, inv_freq)
+
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
+ self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
+
+ def forward(self):
+ return self.cos_cached, self.sin_cached
+
+
+class Attention(nn.Module):
+ def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+ self.output_size = head_dim * num_heads
+ self.num_heads = num_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.causal = causal
+
+ self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
+ self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
+
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, seq_len, _ = hidden_states.shape
+
+ # hidden_states: [bs, seq_len, num_heads, head_dim]
+ qkv = self.qkv_proj(hidden_states)
+
+ # Split head
+ qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3)
+ query = qkv[:, :self.num_heads]
+ key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads]
+ value = qkv[:, self.num_heads + self.num_key_value_heads:]
+
+ # RoPE
+ if cos_sin is not None:
+ cos, sin = cos_sin
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+ # flash attn
+ attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
+
+ # attn_output: [batch_size, num_heads, seq_len, head_dim]
+ attn_output = attn_output.transpose(-2, -3).view(batch_size, seq_len, self.output_size) # type: ignore
+ return self.o_proj(attn_output)
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, hidden_size: int, expansion: float):
+ super().__init__()
+ inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
+
+ self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
+ self.down_proj = CastedLinear(inter, hidden_size, bias=False)
+
+ def forward(self, x):
+ gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
+ return self.down_proj(F.silu(gate) * up)
+
+
+def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+
+ variance = hidden_states.square().mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+ return hidden_states.to(input_dtype)