summaryrefslogtreecommitdiff
path: root/models/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/layers.py')
-rw-r--r--models/layers.py1
1 files changed, 0 insertions, 1 deletions
diff --git a/models/layers.py b/models/layers.py
index 008a172..0394744 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -131,7 +131,6 @@ class Attention(nn.Module):
if isinstance(attn_output, tuple): # fa2 and fa3 compatibility
attn_output = attn_output[0]
- # attn_output: [batch_size, num_heads, seq_len, head_dim]
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
return self.o_proj(attn_output)