diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md
index a5bc379141..1fb895933f 100644
--- a/model_zoo/official/nlp/bert/README.md
+++ b/model_zoo/official/nlp/bert/README.md
@@ -239,30 +239,32 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_n
                         [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
                         [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                         [--accumulation_steps N]
+                        [--allreduce_post_accumulation ALLREDUCE_POST_ACCUMULATION]
                         [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
                         [--load_checkpoint_path LOAD_CHECKPOINT_PATH]
                         [--save_checkpoint_steps N] [--save_checkpoint_num N]
                         [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [train_steps N]

 options:
-    --device_target            device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
-    --distribute               pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false"
-    --epoch_size               epoch size: N, default is 1
-    --device_num               number of used devices: N, default is 1
-    --device_id                device id: N, default is 0
-    --enable_save_ckpt         enable save checkpoint: "true" | "false", default is "true"
-    --enable_lossscale         enable lossscale: "true" | "false", default is "true"
-    --do_shuffle               enable shuffle: "true" | "false", default is "true"
-    --enable_data_sink         enable data sink: "true" | "false", default is "true"
-    --data_sink_steps          set data sink steps: N, default is 1
-    --accumulation_steps       accumulate gradients N times before weight update: N, default is 1
-    --save_checkpoint_path     path to save checkpoint files: PATH, default is ""
-    --load_checkpoint_path     path to load checkpoint files: PATH, default is ""
-    --save_checkpoint_steps    steps for saving checkpoint files: N, default is 1000
-    --save_checkpoint_num      number for saving checkpoint files: N, default is 1
-    --train_steps              Training Steps: N, default is -1
-    --data_dir                 path to dataset directory: PATH, default is ""
-    --schema_dir               path to schema.json file, PATH, default is ""
+    --device_target                device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
+    --distribute                   pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
+    --epoch_size                   epoch size: N, default is 1
+    --device_num                   number of used devices: N, default is 1
+    --device_id                    device id: N, default is 0
+    --enable_save_ckpt             enable save checkpoint: "true" | "false", default is "true"
+    --enable_lossscale             enable lossscale: "true" | "false", default is "true"
+    --do_shuffle                   enable shuffle: "true" | "false", default is "true"
+    --enable_data_sink             enable data sink: "true" | "false", default is "true"
+    --data_sink_steps              set data sink steps: N, default is 1
+    --accumulation_steps           accumulate gradients N times before weight update: N, default is 1
+    --allreduce_post_accumulation  allreduce after accumulation of N steps or after each step: "true" | "false", default is "true"
+    --save_checkpoint_path         path to save checkpoint files: PATH, default is ""
+    --load_checkpoint_path         path to load checkpoint files: PATH, default is ""
+    --save_checkpoint_steps        steps for saving checkpoint files: N, default is 1000
+    --save_checkpoint_num          number for saving checkpoint files: N, default is 1
+    --train_steps                  Training Steps: N, default is -1
+    --data_dir                     path to dataset directory: PATH, default is ""
+    --schema_dir                   path to schema.json file: PATH, default is ""

 ```

 ### Fine-Tuning and Evaluation
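A note on the two related options above: `--accumulation_steps N` mimics a global batch size of `batch_size * N` by summing gradients over N forward/backward passes before a single optimizer step, and `--allreduce_post_accumulation` only takes effect when both `--distribute true` and `--accumulation_steps > 1` are set. A minimal, framework-agnostic sketch of the accumulation idea (plain Python, not the MindSpore cells touched by this patch; `toy_grad` is a made-up gradient used only for illustration):

```python
def toy_grad(w, sample):
    """Gradient of the squared error 0.5 * (w * x - y) ** 2 on a single sample."""
    x, y = sample
    return x * (w * x - y)

def accumulated_step(w, micro_batch, lr=0.1):
    """Sum gradients over the micro-batch, then apply one weight update.

    For a mean-reduced loss this matches one optimizer step taken on the whole
    micro_batch treated as a single, larger batch.
    """
    accu = 0.0
    for sample in micro_batch:
        accu += toy_grad(w, sample)          # accumulate only, no weight update yet
    return w - lr * accu / len(micro_batch)  # one update with the averaged gradient
```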
diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 52c14a2819..ba43c9e7f6 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -32,7 +32,9 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore import log as logger
 from mindspore.common import set_seed
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
-    BertTrainAccumulateStepsWithLossScaleCell, BertTrainOneStepWithLossScaleCellForAdam, \
+    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
+    BertTrainAccumulationAllReducePostWithLossScaleCell, \
+    BertTrainOneStepWithLossScaleCellForAdam, \
     AdamWeightDecayForBert
 from src.dataset import create_bert_dataset
 from src.config import cfg, bert_net_cfg
@@ -122,6 +124,8 @@ def run_pretrain():
     parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
     parser.add_argument("--accumulation_steps", type=int, default="1",
                         help="Accumulating gradients N times before weight update, default is 1.")
+    parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
+                        help="Whether to allreduce after accumulation of N steps or after each step, default is true.")
     parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
     parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
@@ -207,8 +211,9 @@ def run_pretrain():
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-
-        if args_opt.accumulation_steps <= 1:
+        accumulation_steps = args_opt.accumulation_steps
+        enable_global_norm = cfg.enable_global_norm
+        if accumulation_steps <= 1:
             if cfg.optimizer == 'AdamWeightDecay':
                 net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
                                                                           scale_update_cell=update_cell)
@@ -216,11 +221,13 @@ def run_pretrain():
                 net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                    scale_update_cell=update_cell)
         else:
-            accumulation_steps = args_opt.accumulation_steps
-            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                                       scale_update_cell=update_cell,
-                                                                       accumulation_steps=accumulation_steps,
-                                                                       enable_global_norm=cfg.enable_global_norm)
+            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
+            net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
+                                     BertTrainAccumulationAllReduceEachWithLossScaleCell)
+            net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
+                                                   scale_update_cell=update_cell,
+                                                   accumulation_steps=accumulation_steps,
+                                                   enable_global_norm=enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

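The new flag only matters in the `accumulation_steps > 1` branch above; a standalone run (`--distribute false`) always takes the "post" cell. A standalone mirror of that selection logic (illustrative helper, not part of the patch):

```python
def pick_accumulation_cell(distribute: str, allreduce_post_accumulation: str) -> str:
    """Return the name of the training cell chosen when accumulation_steps > 1.

    Mirrors the condition added to run_pretrain.py; both arguments are the
    "true"/"false" strings parsed from the command line.
    """
    allreduce_post = distribute == "false" or allreduce_post_accumulation == "true"
    return ("BertTrainAccumulationAllReducePostWithLossScaleCell" if allreduce_post
            else "BertTrainAccumulationAllReduceEachWithLossScaleCell")

# Only a distributed run that explicitly sets allreduce_post_accumulation="false"
# falls back to the "each" cell:
assert pick_accumulation_cell("false", "false").endswith("PostWithLossScaleCell")
assert pick_accumulation_cell("true", "true").endswith("PostWithLossScaleCell")
assert pick_accumulation_cell("true", "false").endswith("EachWithLossScaleCell")
```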
diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
index a7c7030492..5489c9ca75 100644
--- a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
+++ b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
@@ -7,6 +7,7 @@ do_shuffle=true
 enable_data_sink=true
 data_sink_steps=100
 accumulation_steps=1
+allreduce_post_accumulation=true
 save_checkpoint_path=./
 save_checkpoint_steps=10000
 save_checkpoint_num=1
diff --git a/model_zoo/official/nlp/bert/src/__init__.py b/model_zoo/official/nlp/bert/src/__init__.py
index aa5003a2b2..da365e41b5 100644
--- a/model_zoo/official/nlp/bert/src/__init__.py
+++ b/model_zoo/official/nlp/bert/src/__init__.py
@@ -16,7 +16,8 @@
 from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
     BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
     BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
-    BertTrainAccumulateStepsWithLossScaleCell
+    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
+    BertTrainAccumulationAllReducePostWithLossScaleCell
 from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
     BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
     EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
@@ -25,7 +26,8 @@ from .adam import AdamWeightDecayForBert
 __all__ = [
     "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
     "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
-    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
+    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulationAllReduceEachWithLossScaleCell",
+    "BertTrainAccumulationAllReducePostWithLossScaleCell",
     "BertAttention", "BertConfig", "BertEncoderCell", "BertModel",
     "BertOutput", "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
     "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
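The next file holds the training cells themselves. The first hunk below splits the old `update_accu_grads` helper into three `MultitypeFuncGraph` kernels with distinct semantics; here is a sketch of those semantics using plain-NumPy stand-ins (illustration only, not the MindSpore code):

```python
import numpy as np

def add_grads(accu_grad, grad):
    """Return accu_grad + grad without touching the stored buffer.

    Used by the "each" cell, which allreduces the running sum on every sub-step."""
    return accu_grad + grad.astype(np.float32)

def accumulate_accu_grads(accu_grad, grad):
    """In-place accu_grad += grad (F.assign_add in the real kernel).

    Used by the "post" cell, which communicates only once per N sub-steps."""
    accu_grad += grad.astype(np.float32)

def update_accu_grads(accu_grad, grad):
    """In-place overwrite accu_grad = grad (F.assign in the real kernel).

    Used by the "each" cell to store the freshly summed gradients back."""
    accu_grad[...] = grad.astype(np.float32)
```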
diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
index 2ab077d4a4..cbe47aa1b1 100644
--- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
@@ -556,11 +556,24 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         return F.depend(ret, succ)


 cast = P.Cast()
-update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
+add_grads = C.MultitypeFuncGraph("add_grads")
+
+@add_grads.register("Tensor", "Tensor")
+def _add_grads(accu_grad, grad):
+    return accu_grad + cast(grad, mstype.float32)
+
+update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")

 @update_accu_grads.register("Tensor", "Tensor")
 def _update_accu_grads(accu_grad, grad):
+    succ = True
+    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
+
+accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
+
+@accumulate_accu_grads.register("Tensor", "Tensor")
+def _accumulate_accu_grads(accu_grad, grad):
     succ = True
     return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))

@@ -575,13 +588,17 @@ def _reset_accu_grads(accu_grad):
     return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


-class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
+class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
     """
     Encapsulation class of bert network training.

     Append an optimizer to the training network after that the construct
-    function can be called to create the backward graph. To mimic higher batch size, gradients are
-    accumulated N times before weight update.
+    function can be called to create the backward graph.
+
+    To mimic higher batch size, gradients are accumulated N times before weight update.
+
+    For distribution mode, allreduce will only be implemented in the weight update step,
+    i.e. the sub-step after gradients have been accumulated N times.

     Args:
         network (Cell): The training network. Note that loss function should have been added.
@@ -591,7 +608,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                   batch_size * accumulation_steps. Default: 1.
     """
     def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
-        super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
+        super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
         self.weights = optimizer.parameters
@@ -680,7 +697,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                                  self.cast(scaling_sens,
                                                            mstype.float32))

-        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
+        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
         mean_loss = F.depend(mean_loss, accu_succ)

         self.get_status(init)
@@ -716,3 +733,151 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):

         ret = (mean_loss, overflow, scaling_sens)
         return F.depend(ret, succ)
+
+
+class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
+    """
+    Encapsulation class of bert network training.
+
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    To mimic higher batch size, gradients are accumulated N times before weight update.
+
+    For distribution mode, allreduce will be implemented after each sub-step and the trailing time
+    will be overlapped by backend optimization pass.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
+        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
+                                  batch_size * accumulation_steps. Default: 1.
+    """
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
+        super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_grad()
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
+        self.one = Tensor(np.array([1]).astype(np.int32))
+        self.zero = Tensor(np.array([0]).astype(np.int32))
+        self.local_step = Parameter(initializer(0, [1], mstype.int32))
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
+        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
+        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
+
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.overflow_reducer = F.identity
+        if self.is_distributed:
+            self.overflow_reducer = P.AllReduce()
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.logical_or = P.LogicalOr()
+        self.not_equal = P.NotEqual()
+        self.select = P.Select()
+        self.reshape = P.Reshape()
+        self.hyper_map = C.HyperMap()
+        self.loss_scale = None
+        self.loss_scaling_manager = scale_update_cell
+        if scale_update_cell:
+            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
+
+    @C.add_flags(has_effect=True)
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  next_sentence_labels,
+                  masked_lm_positions,
+                  masked_lm_ids,
+                  masked_lm_weights,
+                  sens=None):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            next_sentence_labels,
+                            masked_lm_positions,
+                            masked_lm_ids,
+                            masked_lm_weights)
+        if sens is None:
+            scaling_sens = self.loss_scale
+        else:
+            scaling_sens = sens
+
+        # update accumulation parameters
+        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
+        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
+        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
+        mean_loss = self.accu_loss / self.local_step
+        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
+
+        # alloc status and clear should be right before gradoperation
+        init = self.alloc_status()
+        self.clear_before_grad(init)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 next_sentence_labels,
+                                                 masked_lm_positions,
+                                                 masked_lm_ids,
+                                                 masked_lm_weights,
+                                                 self.cast(scaling_sens,
+                                                           mstype.float32))
+
+
+        accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
+        scaling = scaling_sens * self.degree * self.accumulation_steps
+        grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
+        grads = self.grad_reducer(grads)
+
+        self.get_status(init)
+        flag_sum = self.reduce_sum(init, (0,))
+        flag_reduce = self.overflow_reducer(flag_sum)
+        overflow = self.less_equal(self.base, flag_reduce)
+        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
+        accu_overflow = self.select(overflow, self.one, self.zero)
+        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
+        overflow = self.reshape(overflow, (()))
+
+        if is_accu_step:
+            succ = False
+            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
+            succ = F.depend(succ, accu_succ)
+        else:
+            if sens is None:
+                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
+            if overflow:
+                succ = False
+            else:
+                if self.enable_global_norm:
+                    grads = C.clip_by_global_norm(grads, 1.0, None)
+                else:
+                    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

+                succ = self.optimizer(grads)
+
+            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
+            succ = F.depend(succ, accu_succ)
+
+        ret = (mean_loss, overflow, scaling_sens)
+        return F.depend(ret, succ)
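Both accumulation cells hand the optimizer the same averaged gradient; they differ only in when the allreduce happens. The "post" cell accumulates locally and communicates once per weight update, while the "each" cell allreduces the running sum on every sub-step so the communication can overlap with computation. A NumPy simulation of that equivalence (not part of the patch; it ignores loss scaling and overflow handling, and assumes both cells normalize by `degree * accumulation_steps`, as the "each" cell above does explicitly):

```python
import numpy as np

def simulate(strategy, worker_grads):
    """Compare the two allreduce schedules on toy per-worker, per-sub-step gradients.

    worker_grads has shape (degree, accumulation_steps): one scalar "gradient"
    per device and sub-step. Returns (applied_gradient, allreduce_calls).
    """
    degree, accumulation_steps = worker_grads.shape
    allreduce_calls = 0
    if strategy == "post":
        # Post: accumulate locally, allreduce once on the weight-update sub-step.
        local_sums = worker_grads.sum(axis=1)   # per-device accumulation buffer
        reduced = local_sums.sum()              # the single allreduce (sum over devices)
        allreduce_calls = 1
    else:
        # Each: allreduce the running sum after every sub-step; only the result
        # of the final sub-step is handed to the optimizer.
        running = np.zeros(degree)
        for step in range(accumulation_steps):
            running += worker_grads[:, step]    # add_grads on each device
            reduced = running.sum()             # allreduce on every sub-step
            allreduce_calls += 1
    # Normalize by degree * accumulation_steps before the (hypothetical) optimizer call.
    return reduced / (degree * accumulation_steps), allreduce_calls

grads = np.random.default_rng(0).normal(size=(8, 4))   # 8 devices, 4 sub-steps
post_grad, post_calls = simulate("post", grads)
each_grad, each_calls = simulate("each", grads)
assert np.isclose(post_grad, each_grad)                # same applied update
print(post_calls, each_calls)                          # 1 vs. 4 allreduce calls
```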