# path: root/learn_torch/basics/mha.py
# blob: d9d392d5db6193d5cd8e6ea21f0fea45bd76b4f8
import torch
from torch import nn

if __name__ == "__main__":
    # Demo: multi-head attention where the key and value feature sizes
    # differ from the query's embedding size (kdim / vdim are set explicitly).
    embed_dim, key_dim, value_dim = 768, 10, 20
    target_len, source_len, batch = 10, 5, 1

    attention = nn.MultiheadAttention(
        embed_dim=embed_dim,
        num_heads=12,
        kdim=key_dim,
        vdim=value_dim,
    )

    # batch_first is left at its default, so inputs are laid out as
    # (sequence, batch, feature).
    q = torch.randn(target_len, batch, embed_dim)
    k = torch.randn(source_len, batch, key_dim)
    v = torch.randn(source_len, batch, value_dim)

    output, weights = attention(q, k, v)

    # Prints torch.Size([10, 1, 768]) followed by torch.Size([1, 10, 5]).
    for tensor in (output, weights):
        print(tensor.shape)