diff options
Diffstat (limited to 'learn_torch/basics/mha.py')
| -rw-r--r-- | learn_torch/basics/mha.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/learn_torch/basics/mha.py b/learn_torch/basics/mha.py new file mode 100644 index 0000000..d9d392d --- /dev/null +++ b/learn_torch/basics/mha.py @@ -0,0 +1,16 @@ +import torch +from torch import nn + +if __name__ == '__main__': + + mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, kdim=10, vdim=20) + + query = torch.randn(10, 1, 768) + key = torch.randn(5, 1, 10) + value = torch.randn(5, 1, 20) + + attn_output, attn_output_weights = mha(query, key, value) + print(attn_output.shape) + print(attn_output_weights.shape) + + |
