diff options
| author | zhang <zch921005@126.com> | 2022-06-19 10:53:08 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-06-19 10:53:08 +0800 |
| commit | 63998272c86c3177b857c71d03d583f69e3af47c (patch) | |
| tree | 6776a1cf370bb35d8957c39807ca983fb60ec9ed /hugface/basics.py | |
| parent | ed026519d959ecc60a895f379c228de5df77ffb0 (diff) | |
huggingface tokenizer
Diffstat (limited to 'hugface/basics.py')
| -rw-r--r-- | hugface/basics.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/hugface/basics.py b/hugface/basics.py new file mode 100644 index 0000000..acba6ff --- /dev/null +++ b/hugface/basics.py @@ -0,0 +1,33 @@ + +import transformers +from transformers import pipeline +import torch.nn.functional as F +import torch + +from transformers import AutoTokenizer, AutoModelForSequenceClassification + +model_name = 'distilbert-base-uncased-finetuned-sst-2-english' +# model_name = 'bert-base-uncased' + +model = AutoModelForSequenceClassification.from_pretrained(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name) + +# clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer) +# +# test_sentence = 'today is not that bad' +test_sentences = ['today is not that bad', 'today is so bad'] +# res = clf(test_sentences) +# print(res) +# + + +batch = tokenizer(test_sentences, padding='max_length', truncation=True, max_length=512, return_tensors='pt') + +with torch.no_grad(): + # print(**batch) + outputs = model(**batch) + print(outputs) + scores = F.softmax(outputs.logits, dim=1) + labels = torch.argmax(scores, dim=1) + labels = [model.config.id2label[id] for id in labels.tolist()] + print(labels) |
