| author | zhang <zch921005@126.com> | 2022-06-27 22:53:39 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-06-27 22:53:39 +0800 |
| commit | 9f484de76dde29c22237c8acad21fb27263b79a4 (patch) | |
| tree | baa867f1311b40385604dda1a576ac7ba78db429 /fine_tune/bert/demo.py | |
| parent | 462bd2c944beed1667df30aa7fa626b972a5dfc9 (diff) | |
bert model
Diffstat (limited to 'fine_tune/bert/demo.py')
| -rw-r--r-- | fine_tune/bert/demo.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/fine_tune/bert/demo.py b/fine_tune/bert/demo.py
new file mode 100644
index 0000000..c64f30f
--- /dev/null
+++ b/fine_tune/bert/demo.py
@@ -0,0 +1,81 @@
+
+import torch
+import re
+from transformers import BertTokenizer
+import pandas as pd
+
+
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
+    print('Device name:', torch.cuda.get_device_name(0))
+
+else:
+    print('No GPU available, using the CPU instead.')
+    device = torch.device("cpu")
+
+
+def text_preprocessing(text):
+    """
+    - Remove entity mentions (e.g. '@united')
+    - Correct errors (e.g. '&amp;' to '&')
+    @param text (str): a string to be processed.
+    @return text (str): the processed string.
+    """
+    # Remove '@name'
+    text = re.sub(r'(@.*?)[\s]', ' ', text)
+
+    # Replace '&amp;' with '&'
+    text = re.sub(r'&amp;', '&', text)
+
+    # Remove trailing whitespace
+    text = re.sub(r'\s+', ' ', text).strip()
+
+    return text
+
+
+model_name = 'bert-base-uncased'
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
+# Specify `MAX_LEN`
+MAX_LEN = 64
+
+def preprocessing_for_bert(data):
+    """Perform required preprocessing steps for pretrained BERT.
+    @param data (np.array): Array of texts to be processed.
+    @return input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
+    @return attention_masks (torch.Tensor): Tensor of indices specifying which
+        tokens should be attended to by the model.
+    """
+    # Create empty lists to store outputs
+    input_ids = []
+    attention_masks = []
+
+    # For every sentence...
+    for sent in data:
+        # `encode_plus` will:
+        #   (1) Tokenize the sentence
+        #   (2) Add the `[CLS]` and `[SEP]` tokens to the start and end
+        #   (3) Truncate/pad the sentence to max length
+        #   (4) Map tokens to their IDs
+        #   (5) Create the attention mask
+        #   (6) Return a dictionary of outputs
+        encoded_sent = tokenizer.encode_plus(
+            text=text_preprocessing(sent),  # Preprocess sentence
+            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
+            max_length=MAX_LEN,             # Max length to truncate/pad
+            pad_to_max_length=True,         # Pad sentence to max length
+            # return_tensors='pt',          # Return PyTorch tensor
+            return_attention_mask=True      # Return attention mask
+        )
+
+        # Add the outputs to the lists
+        input_ids.append(encoded_sent.get('input_ids'))
+        attention_masks.append(encoded_sent.get('attention_mask'))
+
+    # Convert lists to tensors
+    input_ids = torch.tensor(input_ids)
+    attention_masks = torch.tensor(attention_masks)
+
+    return input_ids, attention_masks
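For context, here is a minimal usage sketch of the helper added in this commit. The example texts and the final pass through a plain `BertModel` are assumptions for illustration and are not part of the commit; `preprocessing_for_bert`, `model_name`, and `device` are taken from demo.py, so the snippet assumes it runs in the same module (or after `from fine_tune.bert.demo import *`).

```python
import torch
import numpy as np
from transformers import BertModel

# Hypothetical input texts; in the real project these would likely come
# from a pandas DataFrame (hence the `pandas` import in demo.py).
texts = np.array([
    "@united thanks for the quick reply &amp; the rebooking!",
    "Flight delayed again. Not happy.",
])

# Tokenize and build attention masks with the helper defined in demo.py.
input_ids, attention_masks = preprocessing_for_bert(texts)
print(input_ids.shape)        # torch.Size([2, 64]) given MAX_LEN = 64
print(attention_masks.shape)  # torch.Size([2, 64])

# Feed the encoded batch through a plain BERT encoder to obtain contextual
# embeddings; demo.py itself stops after preprocessing, so this part is
# only a sketch of how the tensors would typically be consumed.
model = BertModel.from_pretrained(model_name).to(device)
model.eval()
with torch.no_grad():
    outputs = model(input_ids=input_ids.to(device),
                    attention_mask=attention_masks.to(device))

# [CLS] token representation, commonly used as input to a classification head.
cls_embeddings = outputs.last_hidden_state[:, 0, :]
print(cls_embeddings.shape)   # torch.Size([2, 768]) for bert-base-uncased
```

Note that `pad_to_max_length=True` is deprecated in newer versions of `transformers`; the equivalent modern call uses `padding='max_length'` together with `truncation=True`, but the diff is shown as committed.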
