summaryrefslogtreecommitdiff
path: root/hugface/basics.py
blob: acba6ffef76459aa96ce39e5fc91e179202f9588 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

import transformers
from transformers import pipeline
import torch.nn.functional as F
import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Checkpoint identifier handed to both from_pretrained() calls below, so the
# model and the tokenizer always come from the same checkpoint.
# Disabled alternative checkpoint: "bert-base-uncased".
model_name = "distilbert-base-uncased-finetuned-sst-2-english"

# Sequence-classification model for the chosen checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Tokenizer matching that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sentences classified by the rest of this script.
#
# Disabled alternatives, kept for reference -- a single-sentence input and the
# high-level pipeline route that reuses the model/tokenizer loaded above:
#
#   test_sentence = 'today is not that bad'
#   clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
#   res = clf(test_sentences)
#   print(res)
test_sentences = [
    "today is not that bad",
    "today is so bad",
]


# Tokenize all sentences as one batch and return PyTorch tensors.
# padding=True pads only up to the longest sentence in this batch rather than
# always padding to 512 tokens, so the model does not process hundreds of
# padding positions for short inputs; truncation=True with max_length=512
# still caps over-long inputs.
batch = tokenizer(
    test_sentences,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors='pt',
)

# Inference only: torch.no_grad() disables gradient tracking for the forward
# pass.
with torch.no_grad():
    outputs = model(**batch)
    print(outputs)
    # Normalize the logits along dim=1 (presumably the class axis, i.e. logits
    # are (batch, num_labels) -- confirm against the model docs).
    scores = F.softmax(outputs.logits, dim=1)
    # Index of the highest-scoring entry along dim=1 for each sentence.
    label_ids = torch.argmax(scores, dim=1)
    # Translate numeric ids into the label names stored in the model config.
    labels = [model.config.id2label[label_id] for label_id in label_ids.tolist()]
    print(labels)