@@ -42,7 +42,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
         ds.set_dataset_size(new_size)
-        repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
+        new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -55,4 +55,4 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.repeat(repeat_count)
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
-    return ds
+    return ds, new_repeat_count
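Note on the hunks above: with data sink enabled the effective per-epoch size is reduced via set_dataset_size, and the rescaled count is now kept in a separate new_repeat_count and returned alongside the dataset, so run_pretrain() can drive model.train() with it instead of ds.get_repeat_count(). A minimal arithmetic sketch of the rescaling, using made-up sizes rather than anything from the real dataset:

    # Illustrative numbers only (assumed, not taken from this change).
    repeat_count = 40             # requested epoch_size
    ori_dataset_size = 10000      # ds.get_dataset_size() before shrinking
    data_sink_steps = 100
    batch_size = 32
    new_size = data_sink_steps * batch_size                               # 3200
    new_repeat_count = int(repeat_count * ori_dataset_size // new_size)   # 125
    # Up to integer division, new_repeat_count * new_size matches
    # repeat_count * ori_dataset_size, which is what the rescaling preserves.
    print(new_repeat_count)  # 125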
@@ -24,7 +24,7 @@ from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
-from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
+from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
@@ -87,8 +87,9 @@ def run_pretrain():
         rank = 0
         device_num = 1
 
-    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
-                             args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)
+    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
+                                               args_opt.enable_data_sink, args_opt.data_sink_steps,
+                                               args_opt.data_dir, args_opt.schema_dir)
 
     netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
@@ -112,7 +113,7 @@ def run_pretrain():
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                          format(cfg.optimizer))
-    callback = [LossCallBack()]
+    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
@@ -133,6 +134,6 @@ def run_pretrain():
         netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
-    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+    model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
 
 if __name__ == '__main__':
     run_pretrain()
@@ -99,6 +99,9 @@ class Dropout(Cell):
             out, _ = self.dropout(x)
             return out
 
+        if self.keep_prob == 1:
+            return x
+
         shape = self.get_shape(x)
         dtype = P.DType()(x)
         keep_prob = self.cast(self.keep_prob, dtype)
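The added branch above turns keep_prob == 1 into a cheap identity: no mask is generated and no division by keep_prob is performed. A toy NumPy version of the same shortcut, purely illustrative and not the MindSpore kernel:

    import numpy as np

    def dropout(x, keep_prob=0.9, training=True):
        # keep_prob == 1 keeps every element, so skip the mask work entirely.
        if not training or keep_prob == 1:
            return x
        mask = np.random.uniform(size=x.shape) < keep_prob
        return x * mask / keep_prob  # inverted dropout: rescale the kept values

    x = np.ones((2, 3), dtype=np.float32)
    print(dropout(x, keep_prob=1.0))  # returned unchanged, no random mask drawn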
@@ -26,7 +26,7 @@ from mindspore import context
 from mindspore import log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
-from mindspore.nn.optim import Momentum
+from mindspore.nn.optim import Lamb
 from mindspore.train.callback import Callback
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.model import Model
@@ -73,7 +73,7 @@ def get_config(version='base', batch_size=1):
             max_position_embeddings=512,
             type_vocab_size=2,
             initializer_range=0.02,
-            use_relative_positions=True,
+            use_relative_positions=False,
             input_mask_from_dataset=True,
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
@@ -138,7 +138,9 @@ def test_bert_tdt():
     batch_size = int(os.getenv('BATCH_SIZE', '16'))
     config = get_config(version=version, batch_size=batch_size)
     netwithloss = BertNetworkWithLoss(config, True)
-    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
+    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(),
+                     start_learning_rate=5e-5, end_learning_rate=1e-9,
+                     power=10.0, warmup_steps=0, weight_decay=0.01)
     scale_window = 3
     scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
     netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
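On the optimizer swap above, the test now exercises Lamb with a decayed learning rate spread over decay_steps = dataset_size * repeat_count total steps. Assuming the usual polynomial-decay form for such a schedule (an assumption about the optimizer internals, not something visible in this diff), the per-step rate would behave roughly like this sketch:

    # Assumed schedule: lr(t) = (start - end) * (1 - t / decay_steps) ** power + end
    start_learning_rate, end_learning_rate = 5e-5, 1e-9   # values from the hunk above
    power = 10.0                                           # value from the hunk above
    decay_steps = 10                                       # illustrative only
    for step in range(decay_steps + 1):
        frac = 1.0 - step / decay_steps
        lr = (start_learning_rate - end_learning_rate) * frac ** power + end_learning_rate
        print(step, lr)  # with power=10.0 the rate drops steeply toward end_learning_rate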
@@ -169,10 +171,10 @@ def test_bert_tdt():
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [12.191826, 11.966009, 11.972208, 11.98216, 11.973932, 12.611078, 12.17554, 12.840299,
-                         12.403329, 12.621632]
+    expect_loss_value = [12.207201, 11.980862, 11.984737, 11.879344, 11.832838, 12.411388,
+                         12.009449, 12.621273, 12.223175, 12.427313]
     print("loss value: {}".format(loss_value))
-    assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
+    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 
     overflow = np.array(callback.overflow_list)
     expect_overflow = [True, True, False, False, False, True, False, False, False, True]
@@ -182,7 +184,7 @@ def test_bert_tdt():
     loss_scale = np.array(callback.lossscale_list)
     expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
     print("loss scale: {}".format(loss_scale))
-    assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
+    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
 
 if __name__ == '__main__':
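On the two adjusted asserts above: np.allclose(a, b, rtol, atol) passes when |a - b| <= atol + rtol * |b|, so the loss comparison now uses a pure absolute tolerance of 5e-4, while the loss-scale comparison (rtol=0, atol=0) demands exact equality. A small self-contained illustration with made-up values:

    import numpy as np

    a = np.array([12.2072, 11.9808])
    b = np.array([12.207201, 11.980862])
    # rtol=0, atol=5e-4: differences here are well under the absolute tolerance.
    print(np.allclose(a, b, 0, 0.0005))   # True
    # rtol=0, atol=0: only exact matches pass, as required for the loss-scale values.
    print(np.allclose(a, b, 0, 0))        # False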