author    zhang <zch921005@126.com>    2022-06-27 22:53:39 +0800
committer zhang <zch921005@126.com>    2022-06-27 22:53:39 +0800
commit    9f484de76dde29c22237c8acad21fb27263b79a4 (patch)
tree      baa867f1311b40385604dda1a576ac7ba78db429 /fine_tune/bert/demo.py
parent    462bd2c944beed1667df30aa7fa626b972a5dfc9 (diff)
bert model
Diffstat (limited to 'fine_tune/bert/demo.py')
-rw-r--r--    fine_tune/bert/demo.py    81
1 file changed, 81 insertions(+), 0 deletions(-)
diff --git a/fine_tune/bert/demo.py b/fine_tune/bert/demo.py
new file mode 100644
index 0000000..c64f30f
--- /dev/null
+++ b/fine_tune/bert/demo.py
@@ -0,0 +1,81 @@
+
+import torch
+import re
+from transformers import BertTokenizer
+import pandas as pd
+
+
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
+    print('Device name:', torch.cuda.get_device_name(0))
+
+else:
+    print('No GPU available, using the CPU instead.')
+    device = torch.device("cpu")
+
+
+def text_preprocessing(text):
+    """
+    - Remove entity mentions (e.g. '@united')
+    - Correct errors (e.g. '&amp;' to '&')
+    @param text (str): a string to be processed.
+    @return text (str): the processed string.
+    """
+    # Remove '@name' mentions
+    text = re.sub(r'(@.*?)[\s]', ' ', text)
+
+    # Replace '&amp;' with '&'
+    text = re.sub(r'&amp;', '&', text)
+
+    # Collapse runs of whitespace and strip leading/trailing spaces
+    text = re.sub(r'\s+', ' ', text).strip()
+
+    return text
+
+
+model_name = 'bert-base-uncased'
+tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)
+# Specify `MAX_LEN`
+MAX_LEN = 64
+
+def preprocessing_for_bert(data):
+    """Perform required preprocessing steps for pretrained BERT.
+    @param data (np.array): Array of texts to be processed.
+    @return input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
+    @return attention_masks (torch.Tensor): Tensor of indices specifying which
+        tokens should be attended to by the model.
+    """
+    # Create empty lists to store outputs
+    input_ids = []
+    attention_masks = []
+
+    # For every sentence...
+    for sent in data:
+        # `encode_plus` will:
+        #    (1) Tokenize the sentence
+        #    (2) Add the `[CLS]` and `[SEP]` tokens to the start and end
+        #    (3) Truncate/pad the sentence to max length
+        #    (4) Map tokens to their IDs
+        #    (5) Create the attention mask
+        #    (6) Return a dictionary of outputs
+        encoded_sent = tokenizer.encode_plus(
+            text=text_preprocessing(sent),  # Preprocess sentence
+            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
+            max_length=MAX_LEN,             # Max length to truncate/pad
+            padding='max_length', truncation=True,  # Pad/truncate to max length
+            # return_tensors='pt',          # Return PyTorch tensor
+            return_attention_mask=True      # Return attention mask
+        )
+
+        # Add the outputs to the lists
+        input_ids.append(encoded_sent.get('input_ids'))
+        attention_masks.append(encoded_sent.get('attention_mask'))
+
+    # Convert lists to tensors
+    input_ids = torch.tensor(input_ids)
+    attention_masks = torch.tensor(attention_masks)
+
+    return input_ids, attention_masks
+
+
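For reference, a minimal usage sketch of the functions this commit adds; the sample sentences and the `.to(device)` step are illustrative assumptions, not part of the commit:

# Minimal usage sketch (sample sentences are made up for illustration).
sentences = ["@united thanks for the quick reply &amp; the refund!",
             "Flight delayed three hours. Not happy."]

input_ids, attention_masks = preprocessing_for_bert(sentences)
print(input_ids.shape)        # torch.Size([2, 64]) -> (batch_size, MAX_LEN)
print(attention_masks.shape)  # torch.Size([2, 64])

# The tensors can then be moved to the selected device before a forward
# pass through a BERT model, e.g. input_ids.to(device).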