diff options
| author | zhang <zch921005@126.com> | 2022-09-01 22:39:16 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-09-01 22:39:16 +0800 |
| commit | a1930ed563d5a3905c9504dbcff2fb00653233da (patch) | |
| tree | 33ac93fab92f157de6ad1a91515ec9fcc9985e43 /learn_torch/basics | |
| parent | 257cbbd2270e9a8756798e256a0bbd29c0fd83db (diff) | |
residual connection
Diffstat (limited to 'learn_torch/basics')
| -rw-r--r-- | learn_torch/basics/add_norm.py | 17 | ||||
| -rw-r--r-- | learn_torch/basics/mha.py | 16 |
2 files changed, 33 insertions, 0 deletions
diff --git a/learn_torch/basics/add_norm.py b/learn_torch/basics/add_norm.py new file mode 100644 index 0000000..6d733f7 --- /dev/null +++ b/learn_torch/basics/add_norm.py @@ -0,0 +1,17 @@ + +import torch +from transformers.models.bert import BertModel, BertTokenizer + + +if __name__ == '__main__': + model_name = 'bert-base-uncased' + tokenizer = BertTokenizer.from_pretrained(model_name) + model = BertModel.from_pretrained(model_name, output_hidden_states=True) + + test_sent = 'this is a test sentence' + + model_input = tokenizer(test_sent, return_tensors='pt') + model.eval() + with torch.no_grad(): + output = model(**model_input) + diff --git a/learn_torch/basics/mha.py b/learn_torch/basics/mha.py new file mode 100644 index 0000000..d9d392d --- /dev/null +++ b/learn_torch/basics/mha.py @@ -0,0 +1,16 @@ +import torch +from torch import nn + +if __name__ == '__main__': + + mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, kdim=10, vdim=20) + + query = torch.randn(10, 1, 768) + key = torch.randn(5, 1, 10) + value = torch.randn(5, 1, 20) + + attn_output, attn_output_weights = mha(query, key, value) + print(attn_output.shape) + print(attn_output_weights.shape) + + |
