diff options
| author | One <imone@tuta.io> | 2025-07-21 18:40:40 +0800 |
|---|---|---|
| committer | One <imone@tuta.io> | 2025-07-21 18:40:40 +0800 |
| commit | 171e2fcde636bcb7e6c0073a9983ed5252f04753 (patch) | |
| tree | d7844d28ad5f289c25a046e58ec9d20216cfba44 /models | |
| parent | bd6222774edcec1608a6842d0b06a637a4acef59 (diff) | |
Update
Diffstat (limited to 'models')
| -rw-r--r-- | models/layers.py | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/models/layers.py b/models/layers.py index 4f7dee4..008a172 100644 --- a/models/layers.py +++ b/models/layers.py @@ -4,6 +4,12 @@ import torch from torch import nn import torch.nn.functional as F +try: + from flash_attn_interface import flash_attn_func # type: ignore[import] +except ImportError: + # Fallback to FlashAttention 2 + from flash_attn import flash_attn_func # type: ignore[import] + from models.common import trunc_normal_init_ @@ -22,14 +28,14 @@ def rotate_half(x: torch.Tensor): def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): - # q, k: [bs, num_heads, seq_len, head_dim] + # q, k: [bs, seq_len, num_heads, head_dim] # cos, sin: [seq_len, head_dim] orig_dtype = q.dtype q = q.to(cos.dtype) k = k.to(cos.dtype) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2)) + k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2)) return q_embed.to(orig_dtype), k_embed.to(orig_dtype) @@ -110,10 +116,10 @@ class Attention(nn.Module): qkv = self.qkv_proj(hidden_states) # Split head - qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3) - query = qkv[:, :self.num_heads] - key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads] - value = qkv[:, self.num_heads + self.num_key_value_heads:] + qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + query = qkv[:, :, :self.num_heads] + key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads] + value = qkv[:, :, self.num_heads + self.num_key_value_heads:] # RoPE if cos_sin is not None: @@ -121,10 +127,12 @@ class Attention(nn.Module): query, key = apply_rotary_pos_emb(query, key, cos, sin) # flash attn - attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal) + attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal) + if isinstance(attn_output, tuple): # fa2 and fa3 compatibility + attn_output = attn_output[0] # attn_output: [batch_size, num_heads, seq_len, head_dim] - attn_output = attn_output.transpose(-2, -3).view(batch_size, seq_len, self.output_size) # type: ignore + attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore return self.o_proj(attn_output) |
