@@ -273,6 +273,7 @@ class ClipByNorm(Cell):
         \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
 
     where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
 
     Args:
         axis (Union[None, int, tuple(int)]): Compute the L2-norm along the specific dimension.
             Default: None, all dimensions to calculate.
@@ -280,6 +281,7 @@ class ClipByNorm(Cell):
     Inputs:
         - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
         - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
+          Or a Tensor whose shape can be broadcast to the shape of `input`.
 
     Outputs:
         Tensor, clipped tensor with the same shape as the input, whose type is float32.
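
As a quick illustration of the formula above, here is a NumPy sketch of the clip-by-norm rule (a hypothetical reference helper, not the MindSpore kernel; it assumes values are rescaled only when the norm exceeds clip_norm, and that axis=None reduces over all dimensions):

    import numpy as np

    def clip_by_norm_ref(x, clip_norm, axis=None):
        # L2-norm over the requested axes (all dimensions when axis is None).
        l2 = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
        # Rescale only where the norm exceeds clip_norm; otherwise keep x unchanged.
        return np.where(l2 > clip_norm, clip_norm * x / l2, x).astype(np.float32)

    x = np.array([[3., 4.]], dtype=np.float32)   # L2-norm is 5
    print(clip_by_norm_ref(x, clip_norm=1.0))    # ~[[0.6, 0.8]]
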
@@ -22,7 +22,7 @@ Pre-defined combination of operators.
 from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
     core, env_get, tail, zip_operation
-from .clip_ops import clip_by_value
+from .clip_ops import clip_by_value, clip_by_global_norm
 from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
@@ -49,4 +49,5 @@ __all__ = [
     'poisson',
     'multinomial',
     'clip_by_value',
+    'clip_by_global_norm',
     'count_nonzero']
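
With the export above, the new function is reachable through the composite namespace. A minimal call sketch (PyNative-style; the tensors below are illustrative data only):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import composite as C

    grads = (Tensor(np.array([[2., 3.], [1., 2.]], np.float32)),
             Tensor(np.array([[1., 4.], [3., 1.]], np.float32)))
    clipped = C.clip_by_global_norm(grads, clip_norm=1.0)
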
@@ -14,8 +14,15 @@
 # ============================================================================
 """Operations for clipping tensors to min/max values."""
-from .. import operations as P
+from mindspore.nn.cell import Cell
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.common import dtype as mstype
+from mindspore._checkparam import Rel
+from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
 
 
 def clip_by_value(x, clip_value_min, clip_value_max):
@@ -41,3 +48,89 @@ def clip_by_value(x, clip_value_min, clip_value_max):
     x_min = min_op(x, clip_value_max)
     x_max = max_op(x_min, clip_value_min)
     return x_max
+
+
+get_square_sum = C.MultitypeFuncGraph("get_square_sum")
+
+
+@get_square_sum.register("Tensor")
+def _get_square_sum(x):
+    norm = P.ReduceSum(False)(F.square(x), ())
+    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
+    return norm
+
+
+apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
+
+
+@apply_global_norm.register("Tensor", "Tensor", "Tensor")
+def _apply_global_norm(clip_norm, global_norm, x):
+    x = x * clip_norm / global_norm
+    return x
+
+
+class _ClipByGlobalNorm(Cell):
+    r"""
+    Clips tensor values by the ratio of the sum of their norms.
+
+    Args:
+        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
+        use_norm (Union(float, None)): The global norm. Default: None.
+
+    Inputs:
+        - **x** (Union(tuple[Tensor], list[Tensor])) - Input data to clip.
+
+    Outputs:
+        tuple[Tensor], the clipped tensors.
+    """
+
+    def __init__(self, clip_norm=1.0, use_norm=None):
+        super(_ClipByGlobalNorm, self).__init__()
+        # `use_norm` is reserved in the interface and is not used at present.
+        if use_norm is not None:
+            validator.check_number("use_norm", use_norm, 0.0, Rel.GE, self.cls_name)
+        validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, self.cls_name)
+        self.clip_norm = Tensor([clip_norm], mstype.float32)
+        self.hyper_map = C.HyperMap()
+        self.greater_equal = P.GreaterEqual()
+
+    def construct(self, x):
+        square_sum = self.hyper_map(get_square_sum, x)
+        global_norm = F.sqrt(F.addn(square_sum))
+        cond = self.greater_equal(global_norm, self.clip_norm)
+        global_norm = F.select(cond, global_norm, self.clip_norm)
+        clip_x = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), x)
+        return clip_x
+
+
+@constexpr
+def _check_value(clip_norm):
+    validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, "clip_by_global_norm")
+    return clip_norm
+
+
+def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
+    r"""
+    Clips tensor values by the ratio of the sum of their norms.
+
+    Note:
+        The input `x` should be a tuple or list of tensors. Otherwise, an error will be raised.
+
+    Args:
+        x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
+        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
+        use_norm (None): The global norm. Default: None. Currently only None is supported.
+
+    Returns:
+        tuple[Tensor], the clipped tensors.
+
+    Examples:
+        >>> x1 = np.array([[2., 3.], [1., 2.]]).astype(np.float32)
+        >>> x2 = np.array([[1., 4.], [3., 1.]]).astype(np.float32)
+        >>> input_x = (Tensor(x1), Tensor(x2))
+        >>> out = clip_by_global_norm(input_x, 1.0)
+        ([[ 2.98142403e-01,  4.47213590e-01],
+          [ 1.49071202e-01,  2.98142403e-01]],
+         [[ 1.49071202e-01,  5.96284807e-01],
+          [ 4.47213590e-01,  1.49071202e-01]])
+    """
+    clip_norm = _check_value(clip_norm)
+    out = _ClipByGlobalNorm(clip_norm, use_norm)(x)
+    return out
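
The example values above can be cross-checked with a plain NumPy sketch of the same global-norm rule (a hypothetical reference, not the graph implementation):

    import numpy as np

    def clip_by_global_norm_ref(tensors, clip_norm=1.0):
        # Global norm: square root of the summed squares over every tensor in the tuple.
        global_norm = np.sqrt(sum(np.sum(np.square(t)) for t in tensors))
        # Scale only when the global norm exceeds clip_norm, mirroring the select above.
        scale = clip_norm / max(global_norm, clip_norm)
        return tuple((t * scale).astype(np.float32) for t in tensors)

    x1 = np.array([[2., 3.], [1., 2.]], np.float32)
    x2 = np.array([[1., 4.], [3., 1.]], np.float32)
    out = clip_by_global_norm_ref((x1, x2))
    # global_norm = sqrt(45) ~ 6.708, so every entry is scaled by ~0.1491,
    # matching the Examples block: 2 * 0.1491 ~ 0.2981, 3 * 0.1491 ~ 0.4472, ...
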
@@ -28,7 +28,6 @@ from mindspore.context import ParallelMode
 from mindspore.communication.management import get_group_size
 from mindspore import context
 from .bert_model import BertModel
-from .utils import ClipByGlobalNorm
 
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -565,7 +564,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
             if self.enable_global_norm:
-                grads = ClipByGlobalNorm()(grads)
+                grads = C.clip_by_global_norm(grads, 1.0, None)
             else:
                 grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
             accu_overflow = self.overflow_reducer(accu_overflow)
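
For intuition on the two branches: the per-tensor branch bounds each gradient's own norm, while clip_by_global_norm bounds the joint norm of the whole gradient tuple with one shared scale. A rough NumPy contrast (assuming, for illustration only, that the per-tensor branch clips each gradient to norm 1.0; the exact behaviour of clip_grad depends on GRADIENT_CLIP_TYPE):

    import numpy as np

    g1 = np.array([3., 4.], np.float32)     # norm 5.0
    g2 = np.array([0.3, 0.4], np.float32)   # norm 0.5

    # Per-tensor clipping: each gradient is rescaled against its own norm.
    per_tensor = [g * min(1.0, 1.0 / np.linalg.norm(g)) for g in (g1, g2)]

    # Global-norm clipping: one shared scale from the joint norm of all gradients.
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in (g1, g2)))
    global_clip = [g * (1.0 / max(global_norm, 1.0)) for g in (g1, g2)]
    # per_tensor rescales g1 but leaves g2 untouched; global_clip shrinks both
    # by the same factor, preserving their relative magnitudes.
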
@@ -24,62 +24,12 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore import log as logger
 from mindspore.ops import operations as P
-from mindspore.ops import functional as F
-from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.train.callback import Callback
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
 
-
-get_square_sum = C.MultitypeFuncGraph("get_square_sum")
-
-
-@get_square_sum.register("Tensor")
-def _get_square_sum(grad):
-    norm = P.ReduceSum(False)(F.square(grad), ())
-    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
-    return norm
-
-
-apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
-
-
-@apply_global_norm.register("Tensor", "Tensor", "Tensor")
-def _apply_global_norm(clip_norm, global_norm, grad):
-    grad = grad * clip_norm / global_norm
-    return grad
-
-
-class GlobalNorm(nn.Cell):
-    """
-    Calculate the global norm value of given tensors
-    """
-    def __init__(self):
-        super(GlobalNorm, self).__init__()
-        self.norm = nn.Norm()
-        self.hyper_map = C.HyperMap()
-
-    def construct(self, grads):
-        square_sum = self.hyper_map(get_square_sum, grads)
-        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
-        return global_norms
-
-
-class ClipByGlobalNorm(nn.Cell):
-    """
-    Clip grads by global norm
-    """
-    def __init__(self, clip_norm=1.0):
-        super(ClipByGlobalNorm, self).__init__()
-        self.global_norm = GlobalNorm()
-        self.clip_norm = Tensor([clip_norm], mstype.float32)
-        self.hyper_map = C.HyperMap()
-
-    def construct(self, grads):
-        global_norm = self.global_norm(grads)
-        cond = P.GreaterEqual()(global_norm, self.clip_norm)
-        global_norm = F.select(cond, global_norm, self.clip_norm)
-        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
-        return grads
-
-
 class CrossEntropyCalculation(nn.Cell):
     """
     Cross Entropy loss
@@ -240,6 +240,20 @@ class ClipByNorm(nn.Cell):
         return norm
 
 
+class ClipByGlobalNorm(nn.Cell):
+    """ClipByGlobalNorm net definition"""
+
+    def __init__(self, x, clip_norm=1.0, use_norm=None):
+        super(ClipByGlobalNorm, self).__init__()
+        self.x = x
+        self.clip_norm = clip_norm
+        self.use_norm = use_norm
+
+    def construct(self):
+        norm = C.clip_by_global_norm(self.x, self.clip_norm, self.use_norm)
+        return norm
+
+
 class Embedding(nn.Cell):
     """Embedding net definition"""
@@ -1130,6 +1144,11 @@ test_case_math_ops = [
         'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
                         Tensor(np.array([0.01]).astype(np.float32))],
         'skip': ['backward']}),
+    ('ClipByGlobalNorm', {
+        'block': ClipByGlobalNorm(x=Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
+                                  clip_norm=1.0, use_norm=None),
+        'desc_inputs': [],
+        'skip': ['backward']}),
     ('Embedding_1', {
         'block': Embedding(vocab_size=10, embedding_size=3),
         'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],