| @@ -78,7 +78,7 @@ It contains of parameters of BERT model and options for training, which is set i | |||||
| ### Options: | ### Options: | ||||
| ``` | ``` | ||||
| Pre-Training: | Pre-Training: | ||||
| bert_network version of BERT model: base | large, default is base | |||||
| bert_network version of BERT model: base | nezha, default is base | |||||
| loss_scale_value initial value of loss scale: N, default is 2^32 | loss_scale_value initial value of loss scale: N, default is 2^32 | ||||
| scale_factor factor used to update loss scale: N, default is 2 | scale_factor factor used to update loss scale: N, default is 2 | ||||
| scale_window steps for once updatation of loss scale: N, default is 1000 | scale_window steps for once updatation of loss scale: N, default is 1000 | ||||
| @@ -26,30 +26,36 @@ cfg = edict({ | |||||
| 'optimizer': 'Lamb', | 'optimizer': 'Lamb', | ||||
| 'AdamWeightDecayDynamicLR': edict({ | 'AdamWeightDecayDynamicLR': edict({ | ||||
| 'learning_rate': 3e-5, | 'learning_rate': 3e-5, | ||||
| 'end_learning_rate': 0.0, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 5.0, | 'power': 5.0, | ||||
| 'weight_decay': 1e-5, | 'weight_decay': 1e-5, | ||||
| 'eps': 1e-6, | 'eps': 1e-6, | ||||
| }), | }), | ||||
| 'Lamb': edict({ | 'Lamb': edict({ | ||||
| 'start_learning_rate': 3e-5, | 'start_learning_rate': 3e-5, | ||||
| 'end_learning_rate': 0.0, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 10.0, | 'power': 10.0, | ||||
| 'warmup_steps': 10000, | 'warmup_steps': 10000, | ||||
| 'weight_decay': 0.01, | 'weight_decay': 0.01, | ||||
| 'eps': 1e-6, | 'eps': 1e-6, | ||||
| 'decay_filter': lambda x: False, | |||||
| }), | }), | ||||
| 'Momentum': edict({ | 'Momentum': edict({ | ||||
| 'learning_rate': 2e-5, | 'learning_rate': 2e-5, | ||||
| 'momentum': 0.9, | 'momentum': 0.9, | ||||
| }), | }), | ||||
| }) | }) | ||||
| ''' | |||||
| Including two kinds of network: \ | |||||
| base: Goole BERT-base(the base version of BERT model). | |||||
| large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ | |||||
| Functional Relative Posetional Encoding as an effective positional encoding scheme). | |||||
| ''' | |||||
| if cfg.bert_network == 'base': | if cfg.bert_network == 'base': | ||||
| bert_net_cfg = BertConfig( | bert_net_cfg = BertConfig( | ||||
| batch_size=16, | |||||
| batch_size=32, | |||||
| seq_length=128, | seq_length=128, | ||||
| vocab_size=21136, | |||||
| vocab_size=21128, | |||||
| hidden_size=768, | hidden_size=768, | ||||
| num_hidden_layers=12, | num_hidden_layers=12, | ||||
| num_attention_heads=12, | num_attention_heads=12, | ||||
| @@ -66,13 +72,13 @@ if cfg.bert_network == 'base': | |||||
| dtype=mstype.float32, | dtype=mstype.float32, | ||||
| compute_type=mstype.float16, | compute_type=mstype.float16, | ||||
| ) | ) | ||||
| else: | |||||
| if cfg.bert_network == 'nezha': | |||||
| bert_net_cfg = BertConfig( | bert_net_cfg = BertConfig( | ||||
| batch_size=16, | |||||
| batch_size=32, | |||||
| seq_length=128, | seq_length=128, | ||||
| vocab_size=21136, | |||||
| vocab_size=21128, | |||||
| hidden_size=1024, | hidden_size=1024, | ||||
| num_hidden_layers=12, | |||||
| num_hidden_layers=24, | |||||
| num_attention_heads=16, | num_attention_heads=16, | ||||
| intermediate_size=4096, | intermediate_size=4096, | ||||
| hidden_act="gelu", | hidden_act="gelu", | ||||
| @@ -31,7 +31,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e | |||||
| files = os.listdir(data_dir) | files = os.listdir(data_dir) | ||||
| data_files = [] | data_files = [] | ||||
| for file_name in files: | for file_name in files: | ||||
| data_files.append(data_dir+file_name) | |||||
| data_files.append(os.path.join(data_dir, file_name)) | |||||
| ds = de.TFRecordDataset(data_files, schema_dir, | ds = de.TFRecordDataset(data_files, schema_dir, | ||||
| columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", | columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", | ||||
| "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], | "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], | ||||
| @@ -16,17 +16,15 @@ | |||||
| echo "==============================================================================================================" | echo "==============================================================================================================" | ||||
| echo "Please run the scipt as: " | echo "Please run the scipt as: " | ||||
| echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH" | |||||
| echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json /path/mindspore" | |||||
| echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH" | |||||
| echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json" | |||||
| echo "It is better to use absolute path." | echo "It is better to use absolute path." | ||||
| echo "==============================================================================================================" | echo "==============================================================================================================" | ||||
| EPOCH_SIZE=$2 | EPOCH_SIZE=$2 | ||||
| DATA_DIR=$3 | DATA_DIR=$3 | ||||
| SCHEMA_DIR=$4 | SCHEMA_DIR=$4 | ||||
| MINDSPORE_PATH=$6 | |||||
| export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH | |||||
| export MINDSPORE_HCCL_CONFIG_PATH=$5 | export MINDSPORE_HCCL_CONFIG_PATH=$5 | ||||
| export RANK_SIZE=$1 | export RANK_SIZE=$1 | ||||
| @@ -16,16 +16,14 @@ | |||||
| echo "==============================================================================================================" | echo "==============================================================================================================" | ||||
| echo "Please run the scipt as: " | echo "Please run the scipt as: " | ||||
| echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH" | |||||
| echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json /path/mindspore" | |||||
| echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR" | |||||
| echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json" | |||||
| echo "==============================================================================================================" | echo "==============================================================================================================" | ||||
| DEVICE_ID=$1 | DEVICE_ID=$1 | ||||
| EPOCH_SIZE=$2 | EPOCH_SIZE=$2 | ||||
| DATA_DIR=$3 | DATA_DIR=$3 | ||||
| SCHEMA_DIR=$4 | SCHEMA_DIR=$4 | ||||
| MINDSPORE_PATH=$5 | |||||
| export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH | |||||
| python run_pretrain.py \ | python run_pretrain.py \ | ||||
| --distribute="false" \ | --distribute="false" \ | ||||
| @@ -135,9 +135,10 @@ class ModelCallback(Callback): | |||||
| def step_end(self, run_context): | def step_end(self, run_context): | ||||
| cb_params = run_context.original_args() | cb_params = run_context.original_args() | ||||
| self.loss_list.append(cb_params.net_outputs[0]) | |||||
| self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0]) | |||||
| self.overflow_list.append(cb_params.net_outputs[1]) | self.overflow_list.append(cb_params.net_outputs[1]) | ||||
| self.lossscale_list.append(cb_params.net_outputs[2]) | self.lossscale_list.append(cb_params.net_outputs[2]) | ||||
| print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @@ -192,7 +193,11 @@ def test_bert_tdt(): | |||||
| if count == scale_window: | if count == scale_window: | ||||
| count = 0 | count = 0 | ||||
| assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32) | assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32) | ||||
| # assertion occurs while the loss value is wrong | |||||
| loss_value = np.array(callback.loss_list) | |||||
| expect_value = [12.1918125, 11.966035, 11.972114, 11.982671, 11.976399, 12.616986, 12.180658, 12.850562, 12.415608, 12.640145] | |||||
| print("loss value: {}".format(loss_value)) | |||||
| assert np.allclose(loss_value, expect_value, 0.00001, 0.00001) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_bert_tdt() | test_bert_tdt() | ||||
| @@ -1,190 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train bert network without lossscale""" | |||||
| import os | |||||
| import pytest | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset.engine.datasets as de | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| from mindspore import Tensor | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import Callback | |||||
| from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell | |||||
| from mindspore.nn.optim import Momentum | |||||
| from mindspore import log as logger | |||||
| _current_dir = os.path.dirname(os.path.realpath(__file__)) | |||||
| DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] | |||||
| SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" | |||||
| def get_config(version='base', batch_size=1): | |||||
| """get config""" | |||||
| if version == 'base': | |||||
| bert_config = BertConfig( | |||||
| batch_size=batch_size, | |||||
| seq_length=128, | |||||
| vocab_size=21136, | |||||
| hidden_size=768, | |||||
| num_hidden_layers=2, | |||||
| num_attention_heads=12, | |||||
| intermediate_size=3072, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout_prob=0.1, | |||||
| attention_probs_dropout_prob=0.1, | |||||
| max_position_embeddings=512, | |||||
| type_vocab_size=2, | |||||
| initializer_range=0.02, | |||||
| use_relative_positions=True, | |||||
| input_mask_from_dataset=True, | |||||
| token_type_ids_from_dataset=True, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float32) | |||||
| elif version == 'large': | |||||
| bert_config = BertConfig( | |||||
| batch_size=batch_size, | |||||
| seq_length=128, | |||||
| vocab_size=21136, | |||||
| hidden_size=1024, | |||||
| num_hidden_layers=2, | |||||
| num_attention_heads=16, | |||||
| intermediate_size=4096, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout_prob=0.0, | |||||
| attention_probs_dropout_prob=0.0, | |||||
| max_position_embeddings=512, | |||||
| type_vocab_size=2, | |||||
| initializer_range=0.02, | |||||
| use_relative_positions=True, | |||||
| input_mask_from_dataset=True, | |||||
| token_type_ids_from_dataset=True, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float16) | |||||
| elif version == 'large_mixed': | |||||
| bert_config = BertConfig( | |||||
| batch_size=batch_size, | |||||
| seq_length=128, | |||||
| vocab_size=21136, | |||||
| hidden_size=1024, | |||||
| num_hidden_layers=24, | |||||
| num_attention_heads=16, | |||||
| intermediate_size=4096, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout_prob=0.0, | |||||
| attention_probs_dropout_prob=0.0, | |||||
| max_position_embeddings=512, | |||||
| type_vocab_size=2, | |||||
| initializer_range=0.02, | |||||
| use_relative_positions=True, | |||||
| input_mask_from_dataset=True, | |||||
| token_type_ids_from_dataset=True, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float32) | |||||
| else: | |||||
| bert_config = BertConfig(batch_size=batch_size) | |||||
| return bert_config | |||||
| def me_de_train_dataset(): | |||||
| """test me de train dataset""" | |||||
| # apply repeat operations | |||||
| repeat_count = 1 | |||||
| ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", | |||||
| "next_sentence_labels", "masked_lm_positions", | |||||
| "masked_lm_ids", "masked_lm_weights"], shuffle=False) | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||||
| # apply batch operations | |||||
| batch_size = 16 | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| ds = ds.repeat(repeat_count) | |||||
| return ds | |||||
| def weight_variable(shape): | |||||
| """weight variable""" | |||||
| np.random.seed(1) | |||||
| ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32) | |||||
| return Tensor(ones) | |||||
| class ModelCallback(Callback): | |||||
| def __init__(self): | |||||
| super(ModelCallback, self).__init__() | |||||
| self.loss_list = [] | |||||
| def step_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| self.loss_list.append(cb_params.net_outputs.asnumpy()[0]) | |||||
| logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_bert_tdt(): | |||||
| """test bert tdt""" | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) | |||||
| context.set_context(enable_task_sink=True) | |||||
| context.set_context(enable_loop_sink=True) | |||||
| context.set_context(enable_mem_reuse=True) | |||||
| parallel_callback = ModelCallback() | |||||
| ds = me_de_train_dataset() | |||||
| version = os.getenv('VERSION', 'large') | |||||
| batch_size = int(os.getenv('BATCH_SIZE', '16')) | |||||
| config = get_config(version=version, batch_size=batch_size) | |||||
| netwithloss = BertNetworkWithLoss(config, True) | |||||
| optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9) | |||||
| netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer) | |||||
| netwithgrads.set_train(True) | |||||
| model = Model(netwithgrads) | |||||
| params = netwithloss.trainable_params() | |||||
| for param in params: | |||||
| value = param.default_input | |||||
| name = param.name | |||||
| if isinstance(value, Tensor): | |||||
| if name.split('.')[-1] in ['weight']: | |||||
| if name.split('.')[-3] in ['cls2']: | |||||
| logger.info("***************** BERT param name is 1 {}".format(name)) | |||||
| param.default_input = weight_variable(value.asnumpy().shape) | |||||
| else: | |||||
| logger.info("***************** BERT param name is 2 {}".format(name)) | |||||
| tempshape = value.asnumpy().shape | |||||
| shape = (tempshape[1], tempshape[0]) | |||||
| weight_value = weight_variable(shape).asnumpy() | |||||
| param.default_input = Tensor(np.transpose(weight_value, [1, 0])) | |||||
| else: | |||||
| logger.info("***************** BERT param name is 3 {}".format(name)) | |||||
| param.default_input = weight_variable(value.asnumpy().shape) | |||||
| model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False) | |||||
| loss_value = np.array(parallel_callback.loss_list) | |||||
| expect_out = [12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289, 12.165594, | |||||
| 12.824818, 12.38842, 12.604046] | |||||
| logger.info("expected loss value output: {}".format(expect_out)) | |||||
| assert np.allclose(loss_value, expect_out, 0.00001, 0.00001) | |||||
| if __name__ == '__main__': | |||||
| test_bert_tdt() | |||||