diff options
| author | zhang <zch921005@126.com> | 2022-09-01 22:39:16 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-09-01 22:39:16 +0800 |
| commit | a1930ed563d5a3905c9504dbcff2fb00653233da (patch) | |
| tree | 33ac93fab92f157de6ad1a91515ec9fcc9985e43 /learn_torch/basics/add_norm.py | |
| parent | 257cbbd2270e9a8756798e256a0bbd29c0fd83db (diff) | |
residual connection
Diffstat (limited to 'learn_torch/basics/add_norm.py')
| -rw-r--r-- | learn_torch/basics/add_norm.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/learn_torch/basics/add_norm.py b/learn_torch/basics/add_norm.py new file mode 100644 index 0000000..6d733f7 --- /dev/null +++ b/learn_torch/basics/add_norm.py @@ -0,0 +1,17 @@ + +import torch +from transformers.models.bert import BertModel, BertTokenizer + + +if __name__ == '__main__': + model_name = 'bert-base-uncased' + tokenizer = BertTokenizer.from_pretrained(model_name) + model = BertModel.from_pretrained(model_name, output_hidden_states=True) + + test_sent = 'this is a test sentence' + + model_input = tokenizer(test_sent, return_tensors='pt') + model.eval() + with torch.no_grad(): + output = model(**model_input) + |
