From 4851a67bb567511a2d2e9d3faa2d9e3d0dad7f81 Mon Sep 17 00:00:00 2001
From: jzg
Date: Wed, 28 Oct 2020 11:16:21 +0800
Subject: [PATCH] add layer of clipbyglobalnorm.

---
 mindspore/nn/layer/basic.py                 |  2 +
 mindspore/ops/composite/__init__.py         |  3 +-
 mindspore/ops/composite/clip_ops.py         | 97 ++++++++++++++++++-
 .../nlp/bert/src/bert_for_pre_training.py   |  3 +-
 model_zoo/official/nlp/bert/src/utils.py    | 50 ----------
 tests/ut/python/ops/test_ops.py             | 19 ++++
 6 files changed, 119 insertions(+), 55 deletions(-)

diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index dc17bec715..1195435bd1 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -273,6 +273,7 @@ class ClipByNorm(Cell):
         \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
 
     where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
+
     Args:
         axis (Union[None, int, tuple(int)): Compute the L2-norm along the Specific dimension.
             Default: None, all dimensions to calculate.
@@ -280,6 +281,7 @@
     Inputs:
         - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
         - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
+          Or a Tensor whose shape can be broadcast to the shape of `input`.
 
     Outputs:
         Tensor, clipped tensor with the same shape as the input, whose type is float32.

diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py
index 3b97058df4..498f14f660 100644
--- a/mindspore/ops/composite/__init__.py
+++ b/mindspore/ops/composite/__init__.py
@@ -22,7 +22,7 @@ Pre-defined combination of operators.
 
 from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
     core, env_get, tail, zip_operation
-from .clip_ops import clip_by_value
+from .clip_ops import clip_by_value, clip_by_global_norm
 from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
@@ -49,4 +49,5 @@ __all__ = [
     'poisson',
     'multinomial',
     'clip_by_value',
+    'clip_by_global_norm',
     'count_nonzero']

diff --git a/mindspore/ops/composite/clip_ops.py b/mindspore/ops/composite/clip_ops.py
index a87b1c1678..dbc65e0fc0 100644
--- a/mindspore/ops/composite/clip_ops.py
+++ b/mindspore/ops/composite/clip_ops.py
@@ -14,8 +14,15 @@
 # ============================================================================
 """Operations for clipping tensors to min/max values."""
-
-from .. import operations as P
+from mindspore.nn.cell import Cell
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.common import dtype as mstype
+from mindspore._checkparam import Rel
+from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
 
 
 def clip_by_value(x, clip_value_min, clip_value_max):
@@ -41,3 +48,89 @@
     x_min = min_op(x, clip_value_max)
     x_max = max_op(x_min, clip_value_min)
     return x_max
+
+
+get_square_sum = C.MultitypeFuncGraph("get_square_sum")
+@get_square_sum.register("Tensor")
+def _get_square_sum(x):
+    norm = P.ReduceSum(False)(F.square(x), ())
+    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
+    return norm
+
+
+apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
+@apply_global_norm.register("Tensor", "Tensor", "Tensor")
+def _apply_global_norm(clip_norm, global_norm, x):
+    x = x * clip_norm / global_norm
+    return x
+
+
+class _ClipByGlobalNorm(Cell):
+    r"""
+    Clips tensor values by the ratio of the sum of their norms.
+
+    Args:
+        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
+        use_norm (Union(float, None)): The global norm. Default: None.
+
+    Inputs:
+        - **x** (Union(tuple[Tensor], list[Tensor])) - Input data to clip.
+
+    Outputs:
+        tuple[Tensor], a tuple of clipped Tensors.
+    """
+
+    def __init__(self, clip_norm=1.0, use_norm=None):
+        super(_ClipByGlobalNorm, self).__init__()
+        # `use_norm` is kept for interface compatibility and is not used at present.
+        if use_norm is not None:
+            validator.check_number("use_norm", use_norm, 0.0, Rel.GE, self.cls_name)
+        validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, self.cls_name)
+        self.clip_norm = Tensor([clip_norm], mstype.float32)
+        self.hyper_map = C.HyperMap()
+        self.greater_equal = P.GreaterEqual()
+
+    def construct(self, x):
+        square_sum = self.hyper_map(get_square_sum, x)
+        global_norm = F.sqrt(F.addn(square_sum))
+        cond = self.greater_equal(global_norm, self.clip_norm)
+        global_norm = F.select(cond, global_norm, self.clip_norm)
+        clip_x = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), x)
+        return clip_x
+
+
+@constexpr
+def _check_value(clip_norm):
+    validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, "clip_by_global_norm")
+    return clip_norm
+
+
+def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
+    r"""
+    Clips tensor values by the ratio of the sum of their norms.
+    Note:
+        `x` should be a tuple or list of Tensors. Otherwise, an error will be raised.
+
+    Args:
+        x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
+        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
+        use_norm (None): The global norm. Default: None. Currently only None is supported.
+
+    Returns:
+        tuple[Tensor], a tuple of clipped Tensors.
+
+    Examples:
+        >>> x1 = np.array([[2., 3.], [1., 2.]]).astype(np.float32)
+        >>> x2 = np.array([[1., 4.], [3., 1.]]).astype(np.float32)
+        >>> input_x = (Tensor(x1), Tensor(x2))
+        >>> out = clip_by_global_norm(input_x, 1.0)
+        ([[ 2.98142403e-01,  4.47213590e-01],
+          [ 1.49071202e-01,  2.98142403e-01]],
+
+         [[ 1.49071202e-01,  5.96284807e-01],
+          [ 4.47213590e-01,  1.49071202e-01]])
+    """
+
+    clip_norm = _check_value(clip_norm)
+    out = _ClipByGlobalNorm(clip_norm, use_norm)(x)
+    return out
diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
index a3dd3e10b6..1d3606921d 100644
--- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
@@ -28,7 +28,6 @@ from mindspore.context import ParallelMode
 from mindspore.communication.management import get_group_size
 from mindspore import context
 from .bert_model import BertModel
-from .utils import ClipByGlobalNorm
 
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -565,7 +564,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
             if self.enable_global_norm:
-                grads = ClipByGlobalNorm()(grads)
+                grads = C.clip_by_global_norm(grads, 1.0, None)
             else:
                 grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
             accu_overflow = self.overflow_reducer(accu_overflow)
diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py
index 62aaeca1dd..94ce01b17b 100644
--- a/model_zoo/official/nlp/bert/src/utils.py
+++ b/model_zoo/official/nlp/bert/src/utils.py
@@ -24,62 +24,12 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore import log as logger
 from mindspore.ops import operations as P
-from mindspore.ops import functional as F
-from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.train.callback import Callback
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
 
 
-get_square_sum = C.MultitypeFuncGraph("get_square_sum")
-@get_square_sum.register("Tensor")
-def _get_square_sum(grad):
-    norm = P.ReduceSum(False)(F.square(grad), ())
-    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
-    return norm
-
-
-apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
-@apply_global_norm.register("Tensor", "Tensor", "Tensor")
-def _apply_global_norm(clip_norm, global_norm, grad):
-    grad = grad * clip_norm / global_norm
-    return grad
-
-
-class GlobalNorm(nn.Cell):
-    """
-    Calculate the global norm value of given tensors
-    """
-    def __init__(self):
-        super(GlobalNorm, self).__init__()
-        self.norm = nn.Norm()
-        self.hyper_map = C.HyperMap()
-
-    def construct(self, grads):
-        square_sum = self.hyper_map(get_square_sum, grads)
-        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
-        return global_norms
-
-
-class ClipByGlobalNorm(nn.Cell):
-    """
-    Clip grads by global norm
-    """
-    def __init__(self, clip_norm=1.0):
-        super(ClipByGlobalNorm, self).__init__()
-        self.global_norm = GlobalNorm()
-        self.clip_norm = Tensor([clip_norm], mstype.float32)
-        self.hyper_map = C.HyperMap()
-
-    def construct(self, grads):
-        global_norm = self.global_norm(grads)
-        cond = P.GreaterEqual()(global_norm, self.clip_norm)
-        global_norm = F.select(cond, global_norm, self.clip_norm)
-        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
-        return grads
-
-
 class CrossEntropyCalculation(nn.Cell):
     """
     Cross Entropy loss
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index e4193c1643..946d644b5c 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -240,6 +240,20 @@ class ClipByNorm(nn.Cell):
         return norm
 
 
+class ClipByGlobalNorm(nn.Cell):
+    """ClipByGlobalNorm net definition"""
+
+    def __init__(self, x, clip_norm=1.0, use_norm=None):
+        super(ClipByGlobalNorm, self).__init__()
+        self.x = x
+        self.clip_norm = clip_norm
+        self.use_norm = use_norm
+
+    def construct(self):
+        norm = C.clip_by_global_norm(self.x, self.clip_norm, self.use_norm)
+        return norm
+
+
 class Embedding(nn.Cell):
     """Embedding net definition"""
 
@@ -1130,6 +1144,11 @@
         'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
                         Tensor(np.array([0.01]).astype(np.float32))],
         'skip': ['backward']}),
+    ('ClipByGlobalNorm', {
+        'block': ClipByGlobalNorm(x=(Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),),
+                                  clip_norm=1.0, use_norm=None),
+        'desc_inputs': [],
+        'skip': ['backward']}),
     ('Embedding_1', {
         'block': Embedding(vocab_size=10, embedding_size=3),
         'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
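
For reference, the clipping rule added in clip_ops.py can be sanity-checked with a small NumPy sketch. This is an illustration only, not part of the patch; the helper name numpy_clip_by_global_norm is hypothetical and simply mirrors the math in _ClipByGlobalNorm.construct: every tensor is scaled by clip_norm / max(global_norm, clip_norm), where global_norm is the L2 norm over all entries of all input tensors, so values are only ever scaled down.

import numpy as np

def numpy_clip_by_global_norm(tensors, clip_norm=1.0):
    """Hypothetical NumPy reference for the clipping rule in _ClipByGlobalNorm."""
    # Global norm: L2 norm over the concatenation of all input tensors.
    global_norm = np.sqrt(sum(np.sum(np.square(t)) for t in tensors))
    # If global_norm < clip_norm the scaling ratio becomes 1, so nothing is scaled up.
    global_norm = max(global_norm, clip_norm)
    return tuple(t * clip_norm / global_norm for t in tensors)

x1 = np.array([[2., 3.], [1., 2.]], dtype=np.float32)
x2 = np.array([[1., 4.], [3., 1.]], dtype=np.float32)
out = numpy_clip_by_global_norm((x1, x2), clip_norm=1.0)
print(out[0])  # ~[[0.298142, 0.447214], [0.149071, 0.298142]], matching the docstring example

With clip_norm=1.0 and these inputs, the sum of squares is 18 + 27 = 45, so global_norm = sqrt(45) ~= 6.708 and every entry is multiplied by roughly 0.149071, which is where the values in the clip_by_global_norm docstring example come from.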