1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
from models.common import trunc_normal_init_
# Pair of RoPE lookup tables ``(cos, sin)``, the form returned by
# ``RotaryEmbedding.forward`` and unpacked by ``Attention.forward``.
CosSin = Tuple[torch.Tensor, torch.Tensor]
def _find_multiple(a, b):
return (-(a // -b)) * b
def rotate_half(x: torch.Tensor):
    """Rotate the last dimension by half: ``[x1, x2] -> [-x2, x1]``.

    ``x1`` is the first ``d // 2`` features and ``x2`` the remainder, so for
    an odd width ``x2`` is one element longer. The output width equals the
    input width.
    """
    split = x.shape[-1] // 2
    first, second = x[..., :split], x[..., split:]
    return torch.cat((second.neg(), first), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embedding to a query/key pair.

    Each tensor is converted to ``cos.dtype`` and rotated as
    ``t * cos + rotate_half(t) * sin``. Both results are then cast back to
    ``q``'s original dtype; ``k``'s output also takes ``q``'s dtype.

    Shapes as used by ``Attention``: q, k are [bs, num_heads, seq_len, head_dim];
    cos and sin must broadcast against that, e.g. [seq_len, head_dim].
    """
    out_dtype = q.dtype
    compute_dtype = cos.dtype

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        t = t.to(compute_dtype)
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q).to(out_dtype), _rotate(k).to(out_dtype)
class CastedLinear(nn.Module):
    """Linear layer whose parameters are cast to the input's dtype on each call.

    The weight has shape ``(out_features, in_features)``. It is initialised by
    ``trunc_normal_init_`` with std ``1 / sqrt(in_features)`` (truncated LeCun
    normal). The optional bias starts at zero; when disabled, ``self.bias`` is
    ``None``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        lecun_std = 1.0 / (in_features ** 0.5)
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=lecun_std)
        )
        self.bias = nn.Parameter(torch.zeros((out_features,))) if bias else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``F.linear`` with weight (and bias, if present) cast to ``input.dtype``."""
        target = input.dtype
        cast_bias = None if self.bias is None else self.bias.to(target)
        return F.linear(input, self.weight.to(target), bias=cast_bias)
class CastedEmbedding(nn.Module):
    """Embedding table that is converted to a fixed dtype before every lookup.

    Args:
        num_embeddings: Number of rows in the table.
        embedding_dim: Width of each row.
        init_std: Std passed to ``trunc_normal_init_`` for the initial weights.
        cast_to: dtype the table is converted to in ``forward``.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to
        table = torch.empty((num_embeddings, embedding_dim))
        self.embedding_weight = nn.Parameter(trunc_normal_init_(table, std=init_std))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Gather the rows indexed by ``input`` from the table cast to ``self.cast_to``."""
        weight = self.embedding_weight.to(self.cast_to)
        return F.embedding(input, weight)
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings, base, device=None):
super().__init__()
# RoPE
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
def forward(self):
return self.cos_cached, self.sin_cached
class Attention(nn.Module):
    """Multi-head self-attention with optional RoPE and grouped-query attention.

    A single fused ``qkv_proj`` produces ``num_heads`` query heads and
    ``num_key_value_heads`` key and value heads. Attention is computed with
    ``F.scaled_dot_product_attention``, and ``o_proj`` maps the concatenated
    heads back to ``hidden_size``.

    Args:
        hidden_size: Width of the input and output features.
        head_dim: Per-head feature width.
        num_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads. Must be positive and
            divide ``num_heads``. When smaller than ``num_heads``, each
            key/value head is shared by ``num_heads // num_key_value_heads``
            consecutive query heads.
        causal: Forwarded to SDPA as ``is_causal``.

    Raises:
        ValueError: if ``num_key_value_heads`` does not evenly divide
            ``num_heads``.
    """

    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()
        # Fail at construction: any other ratio could never run in forward().
        if num_key_value_heads <= 0 or num_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_key_value_heads ({num_key_value_heads}) must be positive and divide num_heads ({num_heads})"
            )

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        """Attend over ``hidden_states``.

        Args:
            cos_sin: ``(cos, sin)`` RoPE tables, or ``None`` to skip RoPE. The
                tables must broadcast against [bs, num_heads, seq_len, head_dim].
            hidden_states: [bs, seq_len, hidden_size].

        Returns:
            Tensor of shape [bs, seq_len, hidden_size].
        """
        batch_size, seq_len, _ = hidden_states.shape

        # [bs, seq_len, hidden_size] -> [bs, seq_len, (num_heads + 2 * kv_heads) * head_dim]
        qkv = self.qkv_proj(hidden_states)

        # Split heads: -> [bs, num_heads + 2 * kv_heads, seq_len, head_dim]
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3)
        query = qkv[:, :self.num_heads]
        key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Grouped-query attention: give key/value as many heads as query, with
        # query head i reading key/value head i // group_size.
        if self.num_key_value_heads != self.num_heads:
            group_size = self.num_heads // self.num_key_value_heads
            key = key.repeat_interleave(group_size, dim=1)
            value = value.repeat_interleave(group_size, dim=1)

        attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)

        # [bs, num_heads, seq_len, head_dim] -> [bs, seq_len, num_heads * head_dim].
        # reshape rather than view: depending on the SDPA backend's output
        # layout, the transposed tensor may not be contiguous, and view would raise.
        attn_output = attn_output.transpose(-2, -3).reshape(batch_size, seq_len, self.output_size)
        return self.o_proj(attn_output)
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate) * up)``.

    The gate and up projections are fused into one ``gate_up_proj`` whose
    output is split in half.

    Args:
        hidden_size: Input and output width.
        expansion: Nominal expansion ratio. The intermediate width is
            ``round(expansion * hidden_size * 2 / 3)`` rounded up to a multiple
            of ``multiple_of``.
        multiple_of: Rounding granularity for the intermediate width. Defaults
            to 256, the value previously hard-coded.

    Raises:
        ValueError: if ``multiple_of`` is not positive.
    """

    def __init__(self, hidden_size: int, expansion: float, multiple_of: int = 256):
        super().__init__()
        if multiple_of <= 0:
            raise ValueError(f"multiple_of must be positive, got {multiple_of}")

        # With three hidden x inter matrices, the 2/3 factor makes the parameter
        # count ~= 2 * expansion * hidden_size**2, the size of a two-matrix MLP.
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), multiple_of)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        """Apply the gated MLP over the last dimension of ``x``."""
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    """Scale ``hidden_states`` to unit root-mean-square over the last dimension.

    The computation runs in float32 and the result is cast back to the input
    dtype. There is no learnable gain; ``variance_epsilon`` is added to the
    mean square before the reciprocal square root.
    """
    x32 = hidden_states.float()
    mean_square = x32.square().mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + variance_epsilon)
    return (x32 * inv_rms).to(hidden_states.dtype)
|