# path: learn_torch/basics/add_norm.py
# blob: 6d733f7ad36431ac1e491350a763384b3f36eca7

import torch
from transformers.models.bert import BertModel, BertTokenizer


if __name__ == '__main__':
    # Load a pretrained BERT checkpoint together with its matching tokenizer.
    # output_hidden_states=True asks the model to also return the hidden
    # states of every layer, not just the final one.
    checkpoint = 'bert-base-uncased'
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    bert = BertModel.from_pretrained(checkpoint, output_hidden_states=True)
    # Evaluation mode (e.g. dropout disabled) so the forward pass is deterministic.
    bert.eval()

    # Encode a single sentence as PyTorch tensors ('pt').
    encoded = bert_tokenizer('this is a test sentence', return_tensors='pt')

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # NOTE(review): the result is not used anywhere in this script yet.
        output = bert(**encoded)