summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorOne <imone@tuta.io>2025-07-21 18:40:40 +0800
committerOne <imone@tuta.io>2025-07-21 18:40:40 +0800
commit171e2fcde636bcb7e6c0073a9983ed5252f04753 (patch)
treed7844d28ad5f289c25a046e58ec9d20216cfba44 /models
parentbd6222774edcec1608a6842d0b06a637a4acef59 (diff)
Update
Diffstat (limited to 'models')
-rw-r--r--models/layers.py26
1 files changed, 17 insertions, 9 deletions
diff --git a/models/layers.py b/models/layers.py
index 4f7dee4..008a172 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -4,6 +4,12 @@ import torch
from torch import nn
import torch.nn.functional as F
+try:
+ from flash_attn_interface import flash_attn_func # type: ignore[import]
+except ImportError:
+ # Fallback to FlashAttention 2
+ from flash_attn import flash_attn_func # type: ignore[import]
+
from models.common import trunc_normal_init_
@@ -22,14 +28,14 @@ def rotate_half(x: torch.Tensor):
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
- # q, k: [bs, num_heads, seq_len, head_dim]
+ # q, k: [bs, seq_len, num_heads, head_dim]
# cos, sin: [seq_len, head_dim]
orig_dtype = q.dtype
q = q.to(cos.dtype)
k = k.to(cos.dtype)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
+ k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
@@ -110,10 +116,10 @@ class Attention(nn.Module):
qkv = self.qkv_proj(hidden_states)
# Split head
- qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3)
- query = qkv[:, :self.num_heads]
- key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads]
- value = qkv[:, self.num_heads + self.num_key_value_heads:]
+ qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+ query = qkv[:, :, :self.num_heads]
+ key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
+ value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
# RoPE
if cos_sin is not None:
@@ -121,10 +127,12 @@ class Attention(nn.Module):
query, key = apply_rotary_pos_emb(query, key, cos, sin)
# flash attn
- attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
+ attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
+ if isinstance(attn_output, tuple): # fa2 and fa3 compatibility
+ attn_output = attn_output[0]
# attn_output: [batch_size, num_heads, seq_len, head_dim]
- attn_output = attn_output.transpose(-2, -3).view(batch_size, seq_len, self.output_size) # type: ignore
+ attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
return self.o_proj(attn_output)