summaryrefslogtreecommitdiff
path: root/fine_tune/bert/tutorials/03_bert_input_embedding.py
blob: 95da9ef33b17c4e5d25b22d9173e5fe0884d1258 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

from transformers import BertTokenizer, BertModel
from transformers.models.bert import BertModel
import torch
from torch import nn


# Name of the pretrained checkpoint; the tokenizer and the model below are
# both loaded from it so their vocabularies match.
model_name = "bert-base-uncased"

# Load the tokenizer and the encoder. output_hidden_states=True asks the model
# to include the hidden states in the object returned by its forward pass.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(
    model_name,
    output_hidden_states=True,
)

# Encode one sentence; return_tensors="pt" makes the tokenizer return
# PyTorch tensors that can be fed straight into the model.
test_sent = "this is a test sentence"
model_input = tokenizer(test_sent, return_tensors="pt")

# Single forward pass: switch the model to evaluation mode and disable
# autograd for the duration of the call. The tokenizer's output mapping is
# unpacked as keyword arguments to the model.
model.eval()
with torch.set_grad_enabled(False):
    output = model(**model_input)