From 9f484de76dde29c22237c8acad21fb27263b79a4 Mon Sep 17 00:00:00 2001
From: zhang
Date: Mon, 27 Jun 2022 22:53:39 +0800
Subject: bert model

---
 fine_tune/bert_parameters.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 fine_tune/bert_parameters.py

(limited to 'fine_tune/bert_parameters.py')

diff --git a/fine_tune/bert_parameters.py b/fine_tune/bert_parameters.py
new file mode 100644
index 0000000..98e2b15
--- /dev/null
+++ b/fine_tune/bert_parameters.py
@@ -0,0 +1,42 @@
+
+from transformers import BertModel, BertForSequenceClassification
+from collections import defaultdict
+import matplotlib.pyplot as plt
+
+
+model_name = 'bert-base-uncased'
+
+model = BertModel.from_pretrained(model_name)
+cls_model = BertForSequenceClassification.from_pretrained(model_name)
+
+total_params = 0
+total_learnable_params = 0
+total_embedding_params = 0
+total_encoder_params = 0
+total_pooler_params = 0
+
+params_dict = defaultdict(float)
+
+for name, para in model.named_parameters():
+    print(name, para.shape, para.numel())
+    if para.requires_grad:
+        total_learnable_params += para.numel()
+    total_params += para.numel()
+    if 'embedding' in name:
+        params_dict['embedding'] += para.numel()
+        total_embedding_params += para.numel()
+    if 'encoder' in name:
+        layer_index = name.split('.')[2]
+        params_dict['encoder({})'.format(layer_index)] += para.numel()
+        total_encoder_params += para.numel()
+    if 'pooler' in name:
+        params_dict['pooler'] += para.numel()
+        total_pooler_params += para.numel()
+
+
+print(total_params)
+print(total_learnable_params)
+print(params_dict)
+print(total_embedding_params)
+print(total_encoder_params)
+print(total_pooler_params)
\ No newline at end of file
--
cgit v1.2.3