Merge pull request !5632 from chenhaozhe/add-global-norm-to-bert
@@ -179,12 +179,14 @@ def run_pretrain():
         if args_opt.accumulation_steps <= 1:
             net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                               scale_update_cell=update_cell)
+                                                               scale_update_cell=update_cell,
+                                                               enable_global_norm=cfg.enable_global_norm)
         else:
             accumulation_steps = args_opt.accumulation_steps
             net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                        scale_update_cell=update_cell,
-                                                                       accumulation_steps=accumulation_steps)
+                                                                       accumulation_steps=accumulation_steps,
+                                                                       enable_global_norm=cfg.enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
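The hunk above threads a single boolean from the training config into whichever wrapper cell is built. Because the new keyword defaults to False, callers that omit it keep the old per-tensor clipping behaviour. A minimal runnable sketch of that backward compatibility (StubCell is a hypothetical stand-in, not the real wrapper cells):

# StubCell mimics only the new keyword's default semantics.
class StubCell:
    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
        self.enable_global_norm = enable_global_norm

old_style = StubCell("net", "opt", scale_update_cell="update")
new_style = StubCell("net", "opt", scale_update_cell="update", enable_global_norm=True)
print(old_style.enable_global_norm, new_style.enable_global_norm)  # False True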
@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
 from .bert_model import BertModel
+from .utils import ClipByGlobalNorm

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
+        self.enable_global_norm = enable_global_norm
         self.grad = C.GradOperation(get_by_list=True,
                                     sens_param=True)
         self.reducer_flag = False
@@ -419,7 +421,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
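This branch is the heart of the change: instead of clipping each gradient tensor to GRADIENT_CLIP_VALUE independently, the global-norm path rescales all tensors by one shared factor, which preserves the relative magnitudes between layers. A NumPy sketch of the difference (not MindSpore code; it assumes clip_grad with GRADIENT_CLIP_TYPE = 1 behaves as per-tensor clip-by-norm):

import numpy as np

def clip_per_tensor(grads, clip_value=1.0):
    # Per-tensor clip-by-norm: each tensor is capped independently.
    out = []
    for g in grads:
        n = np.linalg.norm(g)
        out.append(g if n <= clip_value else g * (clip_value / n))
    return out

def clip_global(grads, clip_norm=1.0):
    # Textbook global-norm clipping: one shared scale across all tensors.
    gnorm = np.sqrt(sum(np.sum(g * g) for g in grads))
    scale = min(1.0, clip_norm / gnorm)
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]
print(clip_per_tensor(grads))  # [0.6, 0.8], [0.3, 0.4]: ratios between tensors change
print(clip_global(grads))      # every entry scaled by the same factor (~0.199)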
@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
             batch_size * accumulation_steps. Default: 1.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
         super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
         self.one = Tensor(np.array([1]).astype(np.int32))
         self.zero = Tensor(np.array([0]).astype(np.int32))
         self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
@@ -580,7 +586,10 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(self.accu_grads)
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
-            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+            if self.enable_global_norm:
+                grads = ClipByGlobalNorm()(grads)
+            else:
+                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
             accu_overflow = self.overflow_reducer(accu_overflow)
             F.control_depend(grads, accu_overflow)
             overflow = self.less_equal(self.base, accu_overflow)
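Note the unscaling factor in this hunk: the accumulate cell folds the loss scale, the all-reduce degree, and the number of accumulated micro-batches into one division before clipping. A plain-Python check of that arithmetic (assuming grad_scale divides the gradient by the given factor, as its use for undoing the loss scale suggests):

# Example values are hypothetical, chosen only to make the arithmetic visible.
scaling_sens = 1024.0        # current loss scale
degree = 8                   # data-parallel devices feeding the grad reducer
accumulation_steps = 4       # micro-batches summed into accu_grads
scaling = scaling_sens * degree * accumulation_steps

accumulated_scaled_grad = 98304.0     # a summed, loss-scaled gradient entry
true_grad = accumulated_scaled_grad / scaling
print(true_grad)  # 3.0: back to an unscaled gradient before clipping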
@@ -24,6 +24,7 @@ cfg = edict({
     'scale_factor': 2,
     'scale_window': 1000,
     'optimizer': 'Lamb',
+    'enable_global_norm': False,
     'AdamWeightDecay': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 0.0,
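Since the new key defaults to False, existing configs train exactly as before; enabling the global-norm path is a one-line change. A hypothetical toggle in the same edict style the config file already uses:

from easydict import EasyDict as edict

cfg = edict({
    'optimizer': 'Lamb',
    'enable_global_norm': True,  # switch from per-tensor to global-norm clipping
})
assert cfg.enable_global_norm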
@@ -115,6 +116,5 @@ if cfg.bert_network == 'large':
         input_mask_from_dataset=True,
         token_type_ids_from_dataset=True,
         dtype=mstype.float32,
-        compute_type=mstype.float16,
-        enable_fused_layernorm=True
+        compute_type=mstype.float16
     )
@@ -23,12 +23,62 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore import log as logger
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.train.callback import Callback
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
+
+get_square_sum = C.MultitypeFuncGraph("get_square_sum")
+
+
+@get_square_sum.register("Tensor")
+def _get_square_sum(grad):
+    norm = P.ReduceSum(False)(F.square(grad), ())
+    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
+    return norm
+
+
+apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
+
+
+@apply_global_norm.register("Tensor", "Tensor", "Tensor")
+def _apply_global_norm(clip_norm, global_norm, grad):
+    grad = grad * clip_norm / global_norm
+    return grad
+
+
+class GlobalNorm(nn.Cell):
+    """
+    Calculate the global norm value of given tensors
+    """
+    def __init__(self):
+        super(GlobalNorm, self).__init__()
+        self.norm = nn.Norm()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, grads):
+        square_sum = self.hyper_map(get_square_sum, grads)
+        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
+        return global_norms
+
+
+class ClipByGlobalNorm(nn.Cell):
+    """
+    Clip grads by global norm
+    """
+    def __init__(self, clip_norm=1.0):
+        super(ClipByGlobalNorm, self).__init__()
+        self.global_norm = GlobalNorm()
+        self.clip_norm = Tensor([clip_norm], mstype.float32)
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, grads):
+        global_norm = self.global_norm(grads)
+        cond = P.GreaterEqual()(global_norm, self.clip_norm)
+        global_norm = F.select(cond, global_norm, self.clip_norm)
+        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
+        return grads
+
+
 class CrossEntropyCalculation(nn.Cell):
     """
     Cross Entropy loss
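One semantic detail worth flagging for reviewers: GlobalNorm divides the summed per-tensor square sums by the number of gradient tensors before taking the square root, so it computes a root-mean-square across tensors rather than the textbook global norm sqrt(sum of squares). ClipByGlobalNorm's F.select then floors the denominator at clip_norm, so gradients are only ever scaled down, never up. A NumPy re-derivation of exactly that behaviour, outside MindSpore, for checking only:

import numpy as np

def global_norm(grads):
    # Mirrors GlobalNorm above, including the division by len(grads),
    # which differs from the usual global norm sqrt(sum of squares).
    square_sum = [np.sum(g * g) for g in grads]
    return np.sqrt(np.sum(square_sum) / len(square_sum))

def clip_by_global_norm(grads, clip_norm=1.0):
    gnorm = global_norm(grads)
    denom = gnorm if gnorm >= clip_norm else clip_norm  # the F.select equivalent
    return [g * clip_norm / denom for g in grads]

grads = [np.array([3.0, 4.0]), np.array([0.0, 0.0])]
print(global_norm(grads))             # sqrt(25 / 2) ~= 3.5355, not 5.0
print(clip_by_global_norm(grads)[0])  # scaled by 1.0 / 3.5355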