@@ -273,6 +273,7 @@ class ClipByNorm(Cell):

        \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},

    where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int)]): Dimensions along which to compute the L2-norm.
            Default: None, which computes the norm over all dimensions.

@@ -280,6 +281,7 @@ class ClipByNorm(Cell):

    Inputs:
        - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
          or a Tensor whose shape can be broadcast to the shape of `input`.

    Outputs:
        Tensor, clipped Tensor with the same shape as the input, whose type is float32.
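A minimal usage sketch of ClipByNorm with the new `axis` argument (the exact behaviour is an assumption based on the docstring above; `axis=None` keeps the previous all-dimension reduction):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, dtype as mstype

    # Norm is computed along axis 0 only; output keeps the input shape, dtype float32.
    net = nn.ClipByNorm(axis=0)
    x = Tensor(np.random.rand(4, 16), mstype.float32)
    clip_norm = Tensor(np.array([100.0]), mstype.float32)
    out = net(x, clip_norm)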
@@ -22,7 +22,7 @@ Pre-defined combination of operators.

from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
    core, env_get, tail, zip_operation
from .clip_ops import clip_by_value
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like

@@ -49,4 +49,5 @@ __all__ = [
    'poisson',
    'multinomial',
    'clip_by_value',
    'clip_by_global_norm',
    'count_nonzero']
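With the export above, the new helper is importable from the composite namespace; a minimal call-site sketch (the gradient values are made up):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import composite as C

    grads = (Tensor(np.array([[2.0, 3.0]], np.float32)),
             Tensor(np.array([[1.0, 4.0]], np.float32)))
    clipped = C.clip_by_global_norm(grads, clip_norm=1.0)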
@@ -14,8 +14,15 @@
# ============================================================================
"""Operations for clipping tensors to min/max values."""

from .. import operations as P
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr


def clip_by_value(x, clip_value_min, clip_value_max):

@@ -41,3 +48,89 @@ def clip_by_value(x, clip_value_min, clip_value_max):

    x_min = min_op(x, clip_value_max)
    x_max = max_op(x_min, clip_value_min)
    return x_max
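For reference, a minimal sketch of a clip_by_value call (bounds are passed as Tensors here; the values are made up):

    import numpy as np
    from mindspore import Tensor, dtype as mstype
    from mindspore.ops import composite as C

    x = Tensor(np.array([[-1.5, 0.3], [2.7, 0.9]], np.float32))
    out = C.clip_by_value(x, Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32))
    # Every element of `out` is clamped into [0.0, 1.0].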
get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor")
def _get_square_sum(x):
    norm = P.ReduceSum(False)(F.square(x), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, x):
    x = x * clip_norm / global_norm
    return x
class _ClipByGlobalNorm(Cell):
    r"""
    Clips tensor values by the ratio of the sum of their norms.

    Args:
        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
        use_norm (Union(float, None)): The global norm. Default: None.

    Inputs:
        - **x** (Union(tuple[Tensor], list[Tensor])) - Input data to clip.

    Outputs:
        tuple[Tensor], a tuple of Tensors clipped by the global norm.
    """

    def __init__(self, clip_norm=1.0, use_norm=None):
        super(_ClipByGlobalNorm, self).__init__()
        # Interface reserved for future use; this parameter is not used at present.
        if use_norm is not None:
            validator.check_number("use_norm", use_norm, 0.0, Rel.GE, self.cls_name)
        validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, self.cls_name)
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()
        self.greater_equal = P.GreaterEqual()

    def construct(self, x):
        square_sum = self.hyper_map(get_square_sum, x)
        global_norm = F.sqrt(F.addn(square_sum))
        cond = self.greater_equal(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        clip_x = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), x)
        return clip_x
@constexpr
def _check_value(clip_norm):
    validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, "clip_by_global_norm")
    return clip_norm
def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
    r"""
    Clips tensor values by the ratio of the sum of their norms.

    Note:
        The input 'x' should be a tuple or list of Tensors. Otherwise, an error will be raised.

    Args:
        x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
        use_norm (None): The global norm. Default: None. Currently only None is supported.

    Returns:
        tuple[Tensor], a tuple of Tensors clipped by the global norm.

    Examples:
        >>> x1 = np.array([[2., 3.], [1., 2.]]).astype(np.float32)
        >>> x2 = np.array([[1., 4.], [3., 1.]]).astype(np.float32)
        >>> input_x = (Tensor(x1), Tensor(x2))
        >>> out = clip_by_global_norm(input_x, 1.0)
        ([[ 2.98142403e-01, 4.47213590e-01],
          [ 1.49071202e-01, 2.98142403e-01]],
         [[ 1.49071202e-01, 5.96284807e-01],
          [ 4.47213590e-01, 1.49071202e-01]])
    """
    clip_norm = _check_value(clip_norm)
    out = _ClipByGlobalNorm(clip_norm, use_norm)(x)
    return out
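A NumPy reference of the computation above (an illustrative sketch, not part of the patch; the helper name is made up): the global norm is the square root of the summed squared elements over all inputs, and every tensor is scaled by clip_norm / max(global_norm, clip_norm), which reproduces the docstring example.

    import numpy as np

    def clip_by_global_norm_np(inputs, clip_norm=1.0):
        # Global norm: sqrt of the sum of squared elements over every tensor.
        global_norm = np.sqrt(sum(np.sum(np.square(t)) for t in inputs))
        # Scaling never amplifies: the ratio is 1.0 when global_norm <= clip_norm.
        scale = clip_norm / max(global_norm, clip_norm)
        return tuple(t * scale for t in inputs)

    x1 = np.array([[2., 3.], [1., 2.]], np.float32)
    x2 = np.array([[1., 4.], [3., 1.]], np.float32)
    out = clip_by_global_norm_np((x1, x2), 1.0)
    # out[0] ~ [[0.2981, 0.4472], [0.1491, 0.2981]], matching the values above.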
@@ -28,7 +28,6 @@ from mindspore.context import ParallelMode

from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_model import BertModel
from .utils import ClipByGlobalNorm

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

@@ -565,7 +564,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):

        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
        if self.enable_global_norm:
            grads = ClipByGlobalNorm()(grads)
            grads = C.clip_by_global_norm(grads, 1.0, None)
        else:
            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        accu_overflow = self.overflow_reducer(accu_overflow)
@@ -24,62 +24,12 @@ import numpy as np

import mindspore.nn as nn
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of given tensors.
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(nn.Cell):
    """
    Clip grads by global norm.
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads


class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
@@ -240,6 +240,20 @@ class ClipByNorm(nn.Cell):

        return norm


class ClipByGlobalNorm(nn.Cell):
    """ClipByGlobalNorm net definition"""

    def __init__(self, x, clip_norm=1.0, use_norm=None):
        super(ClipByGlobalNorm, self).__init__()
        self.x = x
        self.clip_norm = clip_norm
        self.use_norm = use_norm

    def construct(self):
        norm = C.clip_by_global_norm(self.x, self.clip_norm, self.use_norm)
        return norm


class Embedding(nn.Cell):
    """Embedding net definition"""
@@ -1130,6 +1144,11 @@ test_case_math_ops = [

        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
                        Tensor(np.array([0.01]).astype(np.float32))],
        'skip': ['backward']}),
    ('ClipByGlobalNorm', {
        'block': ClipByGlobalNorm(x=Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
                                  clip_norm=1.0, use_norm=None),
        'desc_inputs': [],
        'skip': ['backward']}),
    ('Embedding_1', {
        'block': Embedding(vocab_size=10, embedding_size=3),
        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],