Merge pull request !4799 from shibeiji/bert_accu_steps
@@ -63,6 +63,7 @@ check_bprop = P.CheckBprop()
equal = P.Equal()
not_equal = P.NotEqual()
assign_sub = P.AssignSub()
assign_add = P.AssignAdd()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
@@ -123,6 +123,7 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_n
                        [--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET]
                        [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
                        [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                        [--accumulation_steps N]
                        [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
                        [--load_checkpoint_path LOAD_CHECKPOINT_PATH]
                        [--save_checkpoint_steps N] [--save_checkpoint_num N]
@@ -139,6 +140,7 @@ options:
    --do_shuffle               enable shuffle: "true" | "false", default is "true"
    --enable_data_sink         enable data sink: "true" | "false", default is "true"
    --data_sink_steps          set data sink steps: N, default is 1
    --accumulation_steps       accumulate gradients N times before weight update: N, default is 1
    --save_checkpoint_path     path to save checkpoint files: PATH, default is ""
    --load_checkpoint_path     path to load checkpoint files: PATH, default is ""
    --save_checkpoint_steps    steps for saving checkpoint files: N, default is 1000
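A quick illustration (not part of the PR) of what the new option implies: each optimizer update now consumes accumulation_steps micro-batches, so the effective batch size grows and step-based intervals have to be rescaled, which is what run_pretrain.py does further down. The values in this sketch are placeholders, not defaults from the config.

# Illustrative sketch only; batch_size, accumulation_steps and data_sink_steps are example values.
batch_size = 32              # per-device micro-batch size from bert_net_cfg
accumulation_steps = 4       # gradients accumulated over 4 micro-batches per weight update
data_sink_steps = 100        # sink interval, expressed in optimizer updates

global_batch_size = batch_size * accumulation_steps                  # 128 samples per weight update
sink_steps_in_micro_batches = data_sink_steps * accumulation_steps   # 400 dataset steps per sink cycle
print(global_batch_size, sink_steps_in_micro_batches)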
@@ -30,7 +30,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
    BertTrainAccumulateStepsWithLossScaleCell
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack, BertLearningRate
| @@ -51,6 +52,8 @@ def run_pretrain(): | |||
| parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") | |||
| parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") | |||
| parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") | |||
| parser.add_argument("--accumulation_steps", type=int, default="1", | |||
| help="Accumulating gradients N times before weight update, default is 1.") | |||
| parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") | |||
| parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") | |||
| parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " | |||
@@ -98,6 +101,16 @@ def run_pretrain():
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(bert_net_cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
@@ -157,8 +170,15 @@ def run_pretrain():
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                           scale_update_cell=update_cell)
        if args_opt.accumulation_steps <= 1:
            net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                               scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                       scale_update_cell=update_cell,
                                                                       accumulation_steps=accumulation_steps)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
@@ -6,6 +6,7 @@ enable_lossscale=true
do_shuffle=true
enable_data_sink=true
data_sink_steps=100
accumulation_steps=1
save_checkpoint_path=./checkpoint/
save_checkpoint_steps=10000
save_checkpoint_num=1
@@ -39,6 +39,7 @@ python ${PROJECT_DIR}/../run_pretrain.py \
    --do_shuffle="true" \
    --enable_data_sink="true" \
    --data_sink_steps=1 \
    --accumulation_steps=1 \
    --load_checkpoint_path="" \
    --save_checkpoint_steps=10000 \
    --save_checkpoint_num=1 \
@@ -15,7 +15,8 @@
"""Bert Init."""
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
    BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
    BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
    BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
    BertTrainAccumulateStepsWithLossScaleCell
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
    BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
    EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
@@ -23,7 +24,8 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
__all__ = [
    "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
    "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell",
    "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
    "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
    "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
    "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
@@ -438,3 +438,164 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
        succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
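
# Helpers for gradient accumulation: update_accu_grads adds each freshly computed (still
# loss-scaled) gradient into its float32 accumulation buffer, and reset_accu_grads zeroes
# the buffers once the accumulated gradients have been applied.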
cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph. To mimic a higher batch size, gradients are
    accumulated N times before the weight update.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
            batch_size * accumulation_steps. Default: 1.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
        super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
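        # one float32 buffer per trainable weight; micro-batch gradients are summed into these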
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow")
        self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss")

        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
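        # local_step cycles through 1..accumulation_steps and self.loss keeps a running sum of the
        # loss over the current window, so mean_loss is the window-average loss; after the increment,
        # is_accu_step is False only on the last micro-batch, which triggers the weight update below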
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.loss = self.select(is_accu_step, self.loss + loss, loss)
        mean_loss = self.loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
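
        # add this micro-batch's (still loss-scaled) gradients into the accumulation buffers;
        # mean_loss is made to depend on accu_succ so the buffer updates are not pruned from the graph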
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
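        # accu_overflow stays set if any micro-batch in the current window overflowed; it is
        # cleared again on the update step (is_accu_step is False), once it has been consumed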
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
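            # the accumulated gradients carry the loss scale, are summed across `degree` devices by
            # the reducer and over accumulation_steps micro-batches, so divide by all three factors
            # to recover the mean gradient before clipping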
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
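For orientation, a minimal sketch (not part of the diff) of how this cell is wired up by run_pretrain.py when --accumulation_steps > 1; it assumes net_with_loss, optimizer, update_cell, ds and LossCallBack have been created as in the script excerpts above, and uses placeholder values for the epoch count and accumulation steps.

# Sketch only: wiring BertTrainAccumulateStepsWithLossScaleCell into a Model.
from mindspore.train.model import Model

net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss,
                                                           optimizer=optimizer,
                                                           scale_update_cell=update_cell,
                                                           accumulation_steps=4)
model = Model(net_with_grads)
model.train(1, ds, callbacks=[LossCallBack()], dataset_sink_mode=True)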
@@ -50,7 +50,7 @@ cfg = edict({
'''
Including two kinds of network: \
base: Goole BERT-base(the base version of BERT model).
base: Google BERT-base(the base version of BERT model).
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced an improvement of \
Functional Relative Positional Encoding as an effective positional encoding scheme).
'''