
!12394 add raises description for Adam, Lamb, Momentum, etc. operators

From: @wangshuide2020
tags/v1.2.0-rc1
Committed by mindspore-ci-bot
parent commit fd8d767bbe
22 changed files with 354 additions and 28 deletions
1. +12 -0 mindspore/nn/layer/combined.py
2. +1 -1 mindspore/nn/layer/conv.py
3. +31 -1 mindspore/nn/layer/image.py
4. +21 -6 mindspore/nn/layer/math.py
5. +28 -1 mindspore/nn/layer/pooling.py
6. +47 -2 mindspore/nn/layer/quant.py
7. +54 -0 mindspore/nn/learning_rate_schedule.py
8. +11 -1 mindspore/nn/optim/ada_grad.py
9. +29 -0 mindspore/nn/optim/adam.py
10. +11 -1 mindspore/nn/optim/ftrl.py
11. +9 -0 mindspore/nn/optim/lamb.py
12. +11 -1 mindspore/nn/optim/lazyadam.py
13. +8 -3 mindspore/nn/optim/momentum.py
14. +9 -3 mindspore/nn/optim/optimizer.py
15. +10 -1 mindspore/nn/optim/proximal_ada_grad.py
16. +10 -1 mindspore/nn/optim/rmsprop.py
17. +14 -1 mindspore/nn/wrap/cell_wrapper.py
18. +8 -0 mindspore/nn/wrap/loss_scale.py
19. +1 -1 mindspore/ops/operations/array_ops.py
20. +1 -1 mindspore/ops/operations/comm_ops.py
21. +8 -1 mindspore/ops/operations/math_ops.py
22. +20 -2 mindspore/ops/operations/nn_ops.py

+12 -0 mindspore/nn/layer/combined.py

@@ -76,6 +76,12 @@ class Conv2dBnAct(Cell):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is not an int.
TypeError: If `has_bias` is not a bool.
ValueError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is less than 1.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

Supported Platforms:
``Ascend`` ``GPU``

@@ -170,6 +176,12 @@ class DenseBnAct(Cell):
Outputs:
Tensor of shape :math:`(N, out\_channels)`.

Raises:
TypeError: If `in_channels` or `out_channels` is not an int.
TypeError: If `has_bias`, `has_bn` or `after_fake` is not a bool.
TypeError: If `momentum` or `eps` is not a float.
ValueError: If `momentum` is not in range [0, 1.0].

Supported Platforms:
``Ascend`` ``GPU``
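
Illustration (not part of the diff): a minimal doctest-style sketch of how the new `Raises` entries for Conv2dBnAct surface; the failing call is an assumption based on the docstring, not a verified error message.
>>> import mindspore.nn as nn
>>> # Valid: channel counts, stride, padding and dilation are ints >= 1.
>>> net = nn.Conv2dBnAct(120, 240, 4, has_bias=False)
>>> # A pad_mode outside {'same', 'valid', 'pad'} is expected to raise the documented ValueError:
>>> # nn.Conv2dBnAct(120, 240, 4, pad_mode='reflect')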



+1 -1 mindspore/nn/layer/conv.py

@@ -372,7 +372,7 @@ class Conv1d(_Conv):
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')


+31 -1 mindspore/nn/layer/image.py

@@ -29,6 +29,7 @@ from ..cell import Cell

__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop']


class ImageGradients(Cell):
r"""
Returns two tensors: the first is along the height dimension and the second is along the width dimension.
@@ -50,6 +51,9 @@ class ImageGradients(Cell):
- **dy** (Tensor) - vertical image gradients, the same type and shape as input.
- **dx** (Tensor) - horizontal image gradients, the same type and shape as input.

Raises:
ValueError: If length of shape of `images` is not equal to 4.

Supported Platforms:
``Ascend`` ``GPU``

@@ -89,7 +93,7 @@ class ImageGradients(Cell):

def _convert_img_dtype_to_float32(img, max_val):
"""convert img dtype to float32"""
# Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
# Usually max_val is 1.0 or 255, we will do the scaling if max_val > 1.
# We will scale img pixel value if max_val > 1. and just cast otherwise.
ret = F.cast(img, mstype.float32)
max_val = F.scalar_cast(max_val, mstype.float32)
@@ -214,6 +218,13 @@ class SSIM(Cell):
Outputs:
Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1.

Raises:
TypeError: If `max_val` is neither int nor float.
TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
TypeError: If `filter_size` is not an int.
ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
ValueError: If `filter_size` is less than 0.

Supported Platforms:
``Ascend`` ``GPU``

@@ -296,6 +307,15 @@ class MSSSIM(Cell):
Outputs:
Tensor, the value is in range [0, 1]. It is a 1-D tensor with shape N, where N is the batch num of img1.

Raises:
TypeError: If `max_val` is neither int nor float.
TypeError: If `power_factors` is neither tuple nor list.
TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
TypeError: If `filter_size` is not an int.
ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
ValueError: If `filter_size` is less than 0.
ValueError: If length of shape of `img1` or `img2` is not equal to 4.

Supported Platforms:
``Ascend``

@@ -391,6 +411,11 @@ class PSNR(Cell):
Outputs:
Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.

Raises:
TypeError: If `max_val` is neither int nor float.
ValueError: If `max_val` is less than or equal to 0.
ValueError: If length of shape of `img1` or `img2` is not equal to 4.

Supported Platforms:
``Ascend`` ``GPU``

@@ -451,6 +476,7 @@ def _get_bbox(rank, shape, central_fraction):

return bbox_begin, bbox_size


class CentralCrop(Cell):
"""
Crop the central region of the images with the central_fraction.
@@ -464,6 +490,10 @@ class CentralCrop(Cell):
Outputs:
Tensor, 3-D or 4-D float tensor, according to the input.

Raises:
TypeError: If `central_fraction` is not a float.
ValueError: If `central_fraction` is not in range (0, 1.0].

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
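
Not part of the diff — a short sketch of the CentralCrop contract the new `Raises` lines describe; the failure case is an assumption drawn from the docstring.
>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> net = nn.CentralCrop(central_fraction=0.5)  # must be a float in (0, 1.0]
>>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
>>> print(net(image).shape)
(4, 3, 2, 2)
>>> # central_fraction=0.0 falls outside (0, 1.0] and is expected to raise ValueError.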



+21 -6 mindspore/nn/layer/math.py

@@ -75,6 +75,11 @@ class ReduceLogSumExp(Cell):
- If axis is tuple(int), set as (2, 3), and keep_dims is False,
the shape of output is :math:`(x_1, x_4, ..., x_R)`.

Raises:
TypeError: If `axis` is not one of int, list, tuple.
TypeError: If `keep_dims` is not bool.
TypeError: If dtype of `x` is neither float16 nor float32.

Supported Platforms:
``Ascend`` ``GPU``

@@ -205,7 +210,7 @@ class LGamma(Cell):
Tensor, has the same shape and dtype as the `x`.

Raises:
TypeError: If dtype of input x is not float16 nor float32.
TypeError: If dtype of `x` is neither float16 nor float32.

Supported Platforms:
``Ascend`` ``GPU``
@@ -323,7 +328,7 @@ class DiGamma(Cell):
Tensor, has the same shape and dtype as the `x`.

Raises:
TypeError: If dtype of input x is not float16 nor float32.
TypeError: If dtype of `x` is neither float16 nor float32.

Supported Platforms:
``Ascend`` ``GPU``
@@ -679,8 +684,8 @@ class LBeta(Cell):
Tensor, has the same dtype as `x` and `y`.

Raises:
TypeError: If dtype of input x and a is not float16 nor float32,
or if x has different dtype with a.
TypeError: If dtype of `x` or `y` is neither float16 nor float32,
or if `x` and `y` have different dtypes.

Supported Platforms:
``Ascend`` ``GPU``
@@ -857,6 +862,11 @@ class MatMul(Cell):
Outputs:
Tensor, the shape of the output tensor depends on the dimension of input tensors.

Raises:
TypeError: If `transpose_x1` or `transpose_x2` is not a bool.
ValueError: If the column of matrix dimensions of `input_x1` is not equal to
the row of matrix dimensions of `input_x2`.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -933,6 +943,11 @@ class Moments(Cell):
- **mean** (Tensor) - The mean of `input_x`, with the same data type as `input_x`.
- **variance** (Tensor) - The variance of `input_x`, with the same data type as `input_x`.

Raises:
TypeError: If `axis` is not one of int, tuple, None.
TypeError: If `keep_dims` is neither bool nor None.
TypeError: If dtype of `input_x` is neither float16 nor float32.

Supported Platforms:
``Ascend`` ``GPU``

@@ -993,7 +1008,7 @@ class MatInverse(Cell):
Tensor, has the same dtype as the `a`.

Raises:
TypeError: If dtype of input x is not float16 nor float32.
TypeError: If dtype of `a` is neither float16 nor float32.

Supported Platforms:
``GPU``
@@ -1033,7 +1048,7 @@ class MatDet(Cell):
Tensor, has the same dtype as the `a`.

Raises:
TypeError: If dtype of input x is not float16 nor float32.
TypeError: If dtype of `a` is neither float16 nor float32.

Supported Platforms:
``GPU``
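
For context (not in the diff): a sketch of the nn.MatMul checks documented above; the mismatch comment is an assumption from the docstring.
>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> net = nn.MatMul()  # transpose_x1 and transpose_x2 must be bools
>>> x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
>>> print(net(x1, x2).shape)
(3, 2, 4)
>>> # If the last dimension of x1 did not match the first matrix dimension of x2,
>>> # the documented ValueError would apply.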


+28 -1 mindspore/nn/layer/pooling.py

@@ -104,6 +104,13 @@ class MaxPool2d(_PoolNd):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `kernel_size` or `stride` is neither int nor tuple.
ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If length of shape of `input` is not equal to 4.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -164,8 +171,15 @@ class MaxPool1d(_PoolNd):
Outputs:
Tensor of shape :math:`(N, C, L_{out})`.

Raises:
TypeError: If `kernel_size` or `stride` is not an int.
ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If length of shape of `input` is not equal to 3.

Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> max_pool = nn.MaxPool1d(kernel_size=3, stride=1)
@@ -246,6 +260,13 @@ class AvgPool2d(_PoolNd):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `kernel_size` or `stride` is neither int nor tuple.
ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If length of shape of `input` is not equal to 4.

Supported Platforms:
``Ascend`` ``GPU``

@@ -311,6 +332,12 @@ class AvgPool1d(_PoolNd):
Outputs:
Tensor of shape :math:`(N, C_{out}, L_{out})`.

Raises:
TypeError: If `kernel_size` or `stride` is not an int.
ValueError: If `pad_mode` is neither 'same' nor 'valid' (case insensitive).
ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If length of shape of `input` is not equal to 3.

Supported Platforms:
``Ascend`` ``GPU``
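
Not part of the diff — a minimal sketch of the pooling constraints listed above; the failing variant is an assumption from the docstring.
>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> pool = nn.MaxPool2d(kernel_size=3, stride=1)  # pad_mode defaults to 'valid'
>>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
>>> print(pool(x).shape)
(1, 2, 2, 2)
>>> # A 3-D input would trip the "length of shape of `input` is not equal to 4" check.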



+47 -2 mindspore/nn/layer/quant.py

@@ -286,6 +286,15 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
Outputs:
Tensor, with the same type and shape as the `input`.

Raises:
TypeError: If `min_init` or `max_init` is neither int nor float.
TypeError: If `quant_delay` is not an int.
ValueError: If `min_init` is not less than `max_init`.
ValueError: If `quant_delay` is less than 0.

Supported Platforms:
``Ascend`` ``GPU``

Examples:
>>> fake_quant = nn.FakeQuantWithMinMaxObserver()
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
@@ -432,10 +441,19 @@ class Conv2dBnFoldQuantOneConv(Cell):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is not an int.
TypeError: If `has_bias` is not a bool.
ValueError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is less than 1.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

Supported Platforms:
``Ascend`` ``GPU``

Examples:
>>> qconfig = compression.quant.create_quant_config()
>>> conv2d_bnfold = nn.Conv2dBnFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
... quant_config=qconfig)
>>> conv2d_bnfold = nn.Conv2dBnFoldQuantOneConv(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
... quant_config=qconfig)
>>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32)
>>> result = conv2d_bnfold(input)
>>> output = result.shape
@@ -634,6 +652,12 @@ class Conv2dBnFoldQuant(Cell):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is not an int.
TypeError: If `has_bias` is not a bool.
ValueError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is less than 1.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

Supported Platforms:
``Ascend`` ``GPU``

@@ -926,6 +950,12 @@ class Conv2dQuant(Cell):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is not an int.
TypeError: If `has_bias` is not a bool.
ValueError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is less than 1.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

Supported Platforms:
``Ascend`` ``GPU``

@@ -1033,6 +1063,11 @@ class DenseQuant(Cell):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `in_channels` or `out_channels` is not an int.
TypeError: If `has_bias` is not a bool.
ValueError: If `in_channels` or `out_channels` is less than 1.

Supported Platforms:
``Ascend`` ``GPU``

@@ -1145,6 +1180,10 @@ class ActQuant(_QuantActivation):
Outputs:
Tensor, with the same type and shape as the `input`.

Raises:
TypeError: If `activation` is not an instance of Cell.
TypeError: If `fake_before` is not a bool.

Supported Platforms:
``Ascend`` ``GPU``

@@ -1212,6 +1251,9 @@ class TensorAddQuant(Cell):
Outputs:
Tensor, with the same type and shape as the `input_x1`.

Raises:
TypeError: If `ema_decay` is not a float.

Supported Platforms:
``Ascend`` ``GPU``

@@ -1265,6 +1307,9 @@ class MulQuant(Cell):
Outputs:
Tensor, with the same type and shape as the `input_x1`.

Raises:
TypeError: If `ema_decay` is not a float.

Supported Platforms:
``Ascend`` ``GPU``
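
Illustration only (not in the diff), assuming the bounds check described above rejects inverted limits; mirrors the existing FakeQuantWithMinMaxObserver example.
>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> fake_quant = nn.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6)  # min_init must stay below max_init
>>> x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> output = fake_quant(x)
>>> # Swapping the bounds (min_init=6, max_init=-6) is expected to raise the documented error.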



+54 -0 mindspore/nn/learning_rate_schedule.py

@@ -82,6 +82,15 @@ class ExponentialDecayLR(LearningRateSchedule):
Outputs:
Tensor. The learning rate value for the current step.

Raises:
TypeError: If `learning_rate` or `decay_rate` is not a float.
TypeError: If `decay_steps` is not an int or `is_stair` is not a bool.
ValueError: If `decay_steps` is less than 1.
ValueError: If `learning_rate` or `decay_rate` is less than or equal to 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> learning_rate = 0.1
>>> decay_rate = 0.9
@@ -140,6 +149,15 @@ class NaturalExpDecayLR(LearningRateSchedule):
Outputs:
Tensor. The learning rate value for the current step.

Raises:
TypeError: If `learning_rate` or `decay_rate` is not a float.
TypeError: If `decay_steps` is not an int or `is_stair` is not a bool.
ValueError: If `decay_steps` is less than 1.
ValueError: If `learning_rate` or `decay_rate` is less than or equal to 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> learning_rate = 0.1
>>> decay_rate = 0.9
@@ -199,6 +217,15 @@ class InverseDecayLR(LearningRateSchedule):
Outputs:
Tensor. The learning rate value for the current step.

Raises:
TypeError: If `learning_rate` or `decay_rate` is not a float.
TypeError: If `decay_steps` is not an int or `is_stair` is not a bool.
ValueError: If `decay_steps` is less than 1.
ValueError: If `learning_rate` or `decay_rate` is less than or equal to 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> learning_rate = 0.1
>>> decay_rate = 0.9
@@ -247,6 +274,15 @@ class CosineDecayLR(LearningRateSchedule):
Outputs:
Tensor. The learning rate value for the current step.

Raises:
TypeError: If `min_lr` or `max_lr` is not a float.
TypeError: If `decay_steps` is not an int.
ValueError: If `min_lr` is less than 0 or `decay_steps` is less than 1.
ValueError: If `max_lr` is less than or equal to 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> min_lr = 0.01
>>> max_lr = 0.1
@@ -314,6 +350,15 @@ class PolynomialDecayLR(LearningRateSchedule):
Outputs:
Tensor. The learning rate value for the current step.

Raises:
TypeError: If `learning_rate`, `end_learning_rate` or `power` is not a float.
TypeError: If `decay_steps` is not an int or `update_decay_steps` is not a bool.
ValueError: If `end_learning_rate` is less than 0 or `decay_steps` is less than 1.
ValueError: If `learning_rate` or `power` is less than or equal to 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> learning_rate = 0.1
>>> end_learning_rate = 0.01
@@ -384,6 +429,15 @@ class WarmUpLR(LearningRateSchedule):
Outputs:
Tensor. The learning rate value for the current step.

Raises:
TypeError: If `learning_rate` is not a float.
TypeError: If `warmup_steps` is not an int.
ValueError: If `warmup_steps` is less than 1.
ValueError: If `learning_rate` is less than or equal to 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> learning_rate = 0.1
>>> warmup_steps = 2
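
Not part of the diff — a runnable sketch of one schedule's positivity constraints, using the values from the existing ExponentialDecayLR example; the failure note is an assumption.
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> lr_schedule = nn.ExponentialDecayLR(0.1, 0.9, 4)  # learning_rate and decay_rate must be > 0
>>> global_step = Tensor(2, mindspore.int32)
>>> print(lr_schedule(global_step))
0.09486833
>>> # decay_steps=0 would violate the "less than 1" constraint and is expected to raise ValueError.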


+11 -1 mindspore/nn/optim/ada_grad.py

@@ -81,7 +81,8 @@ class Adagrad(Optimizer):
Default: 0.001.
update_slots (bool): If true, update accumulation. Default: True.
loss_scale (float): Value for the loss scale. It must be greater than 0.0. Default: 1.0.
weight_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0.
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
Default: 0.0.

Inputs:
- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
@@ -90,6 +91,15 @@ class Adagrad(Optimizer):
Outputs:
Tensor[bool], the value is True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `accum` or `loss_scale` is not a float.
TypeError: If `update_slots` is not a bool.
TypeError: If `weight_decay` is neither float nor int.
ValueError: If `loss_scale` is less than or equal to 0.
ValueError: If `accum` or `weight_decay` is less than 0.

Supported Platforms:
``Ascend`` ``CPU`` ``GPU``



+29 -0 mindspore/nn/optim/adam.py

@@ -267,6 +267,16 @@ class Adam(Optimizer):
Outputs:
Tensor[bool], the value is True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
TypeError: If `weight_decay` is neither float nor int.
TypeError: If `use_locking` or `use_nesterov` is not a bool.
ValueError: If `loss_scale` or `eps` is less than or equal to 0.
ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
ValueError: If `weight_decay` is less than 0.

Supported Platforms:
``Ascend`` ``GPU``

@@ -414,6 +424,15 @@ class AdamWeightDecay(Optimizer):
Outputs:
tuple[bool], all elements are True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `beta1`, `beta2` or `eps` is not a float.
TypeError: If `weight_decay` is neither float nor int.
ValueError: If `eps` is less than or equal to 0.
ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
ValueError: If `weight_decay` is less than 0.

Supported Platforms:
``Ascend`` ``GPU``

@@ -545,6 +564,16 @@ class AdamOffload(Optimizer):
Outputs:
Tensor[bool], the value is True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
TypeError: If `weight_decay` is neither float nor int.
TypeError: If `use_locking` or `use_nesterov` is not a bool.
ValueError: If `loss_scale` or `eps` is less than or equal to 0.
ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
ValueError: If `weight_decay` is less than 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
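
A sketch (not in the diff) of the Adam constructor checks listed above; `net` is a stand-in Cell and the failure note is an assumption.
>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)  # any Cell with trainable parameters
>>> optim = nn.Adam(params=net.trainable_params(), learning_rate=0.1, beta1=0.9, beta2=0.999)
>>> # beta1=1.0 lies outside (0.0, 1.0) and is expected to raise the documented ValueError.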



+11 -1 mindspore/nn/optim/ftrl.py

@@ -140,7 +140,8 @@ class FTRL(Optimizer):
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If true, use locks for updating operation. Default: False.
loss_scale (float): Value for the loss scale. It must be equal to or greater than 1.0. Default: 1.0.
weight_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0.
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
Default: 0.0.

Inputs:
- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
@@ -149,6 +150,15 @@ class FTRL(Optimizer):
Outputs:
tuple[Parameter], the updated parameters, the shape is the same as `params`.

Raises:
TypeError: If `initial_accum`, `learning_rate`, `lr_power`, `l1`, `l2` or `loss_scale` is not a float.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `weight_decay` is neither float nor int.
TypeError: If `use_locking` is not a bool.
ValueError: If `lr_power` is greater than 0.
ValueError: If `loss_scale` is less than or equal to 0.
ValueError: If `initial_accum`, `l1` or `l2` is less than 0.

Supported Platforms:
``Ascend`` ``GPU``



+9 -0 mindspore/nn/optim/lamb.py

@@ -224,6 +224,15 @@ class Lamb(Optimizer):
Outputs:
tuple[bool], all elements are True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `beta1`, `beta2` or `eps` is not a float.
TypeError: If `weight_decay` is neither float nor int.
ValueError: If `eps` is less than or equal to 0.
ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
ValueError: If `weight_decay` is less than 0.

Supported Platforms:
``Ascend`` ``GPU``



+11 -1 mindspore/nn/optim/lazyadam.py

@@ -177,7 +177,7 @@ class LazyAdam(Optimizer):
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
If true, update the gradients using NAG.
If false, update the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
weight_decay (Union[float, int]): Weight decay (L2 penalty). Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
1.0.

@@ -187,6 +187,16 @@ class LazyAdam(Optimizer):
Outputs:
Tensor[bool], the value is True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
TypeError: If `weight_decay` is neither float nor int.
TypeError: If `use_locking` or `use_nesterov` is not a bool.
ValueError: If `loss_scale` or `eps` is less than or equal to 0.
ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
ValueError: If `weight_decay` is less than 0.

Supported Platforms:
``Ascend`` ``GPU``



+8 -3 mindspore/nn/optim/momentum.py

@@ -96,7 +96,7 @@ class Momentum(Optimizer):
momentum (float): Hyperparameter of type float, means momentum for the moving average.
It must be at least 0.0.
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
loss_scale (int, float): A floating point value for the loss scale. It must be greater than 0.0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale. It must be greater than 0.0. Default: 1.0.
use_nesterov (bool): Enable Nesterov momentum. Default: False.

Inputs:
@@ -106,8 +106,13 @@ class Momentum(Optimizer):
tuple[bool], all elements are True.

Raises:
ValueError: If the momentum is less than 0.0.
TypeError: If the momentum is not a float or use_nesterov is not a bool.
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `loss_scale` or `momentum` is not a float.
TypeError: If `weight_decay` is neither float nor int.
TypeError: If `use_nesterov` is not a bool.
ValueError: If `loss_scale` is less than or equal to 0.
ValueError: If `weight_decay` or `momentum` is less than 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
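
Not part of the diff — a minimal sketch of the reworked Momentum `Raises` list; the failure note is an assumption.
>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, weight_decay=0.0)
>>> # A negative momentum or weight_decay is expected to raise the documented ValueError.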


+9 -3 mindspore/nn/optim/optimizer.py

@@ -83,14 +83,20 @@ class Optimizer(Cell):
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default.

weight_decay (float): A floating point value for the weight decay. It must be equal to or greater than 0.
weight_decay (Union[float, int]): An int or a floating point value for the weight decay.
It must be equal to or greater than 0.
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
type of `loss_scale` input is int, it will be converted to float. Default: 1.0.

Raises:
ValueError: If the learning_rate is a Tensor, but the dimension of tensor is greater than 1.
TypeError: If the learning_rate is not any of the three types: float, Tensor, nor Iterable.
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `loss_scale` is not a float.
TypeError: If `weight_decay` is neither float nor int.
ValueError: If `loss_scale` is less than or equal to 0.
ValueError: If `weight_decay` is less than 0.
ValueError: If `learning_rate` is a Tensor, but the dimension of tensor is greater than 1.

Supported Platforms:
``Ascend`` ``GPU``
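
For context (not in the diff): `Optimizer` is a base class, so this sketch uses a subclass to show the int-to-float conversion of `weight_decay` described above.
>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> # weight_decay given as an int is converted to float internally (0 -> 0.0):
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, weight_decay=0)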


+10 -1 mindspore/nn/optim/proximal_ada_grad.py

@@ -101,7 +101,8 @@ class ProximalAdagrad(Optimizer):
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If true, use locks for updating operation. Default: False.
loss_scale (float): Value for the loss scale. It must be greater than 0.0. Default: 1.0.
weight_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0.
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
Default: 0.0.

Inputs:
- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
@@ -110,6 +111,14 @@ class ProximalAdagrad(Optimizer):
Outputs:
Tensor[bool], the value is True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `accum`, `l1`, `l2` or `loss_scale` is not a float.
TypeError: If `weight_decay` is neither float nor int.
ValueError: If `loss_scale` is less than or equal to 0.
ValueError: If `accum`, `l1`, `l2` or `weight_decay` is less than 0.

Supported Platforms:
``Ascend``


+10 -1 mindspore/nn/optim/rmsprop.py

@@ -125,7 +125,7 @@ class RMSProp(Optimizer):
updated. Default: False.
centered (bool): If true, gradients are normalized by the estimated variance of the gradient. Default: False.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
weight_decay (Union[float, int]): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.

Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -133,6 +133,15 @@ class RMSProp(Optimizer):
Outputs:
Tensor[bool], the value is True.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If `decay`, `momentum`, `epsilon` or `loss_scale` is not a float.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `weight_decay` is neither float nor int.
TypeError: If `use_locking` or `centered` is not a bool.
ValueError: If `epsilon` is less than or equal to 0.
ValueError: If `decay` or `momentum` is less than 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``



+14 -1 mindspore/nn/wrap/cell_wrapper.py

@@ -18,6 +18,7 @@ from types import FunctionType, MethodType
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
@@ -80,6 +81,9 @@ class WithLossCell(Cell):
Outputs:
Tensor, a scalar tensor with shape :math:`()`.

Raises:
TypeError: If dtype of `data` or `label` is neither float16 nor float32.

Supported Platforms:
``Ascend`` ``GPU``

@@ -139,6 +143,9 @@ class WithGradCell(Cell):
Outputs:
list, a list of Tensors whose shapes are identical to those of the trainable weights.

Raises:
TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple.

Supported Platforms:
``Ascend`` ``GPU``

@@ -294,6 +301,9 @@ class TrainOneStepCell(Cell):
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.

Raises:
TypeError: If `sens` is not a number.

Supported Platforms:
``Ascend`` ``GPU``

@@ -468,6 +478,9 @@ class WithEvalCell(Cell):
Tuple, containing a scalar loss Tensor, a network output Tensor of shape :math:`(N, \ldots)`
and a label Tensor of shape :math:`(N, \ldots)`.

Raises:
TypeError: If `add_cast_fp32` is not a bool.

Supported Platforms:
``Ascend`` ``GPU``

@@ -482,7 +495,7 @@ class WithEvalCell(Cell):
super(WithEvalCell, self).__init__(auto_prefix=False)
self._network = network
self._loss_fn = loss_fn
self.add_cast_fp32 = add_cast_fp32
self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)

def construct(self, data, label):
outputs = self._network(data)
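
Not in the diff — a sketch of the new `add_cast_fp32` validation; the failing call is an assumption based on the added check.
>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> eval_net = nn.WithEvalCell(net, loss_fn, add_cast_fp32=True)  # must be a bool
>>> # nn.WithEvalCell(net, loss_fn, add_cast_fp32=1) is now expected to raise TypeError.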


+8 -0 mindspore/nn/wrap/loss_scale.py

@@ -50,6 +50,7 @@ def _tensor_grad_overflow(grad):
def _tensor_grad_overflow_row_tensor(grad):
return grad_overflow(grad.values)


class DynamicLossScaleUpdateCell(Cell):
r"""
Dynamic Loss scale update cell.
@@ -73,6 +74,9 @@ class DynamicLossScaleUpdateCell(Cell):
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.

Raises:
TypeError: If dtype of `inputs` or `label` is neither float16 nor float32.

Supported Platforms:
``Ascend`` ``GPU``

@@ -227,6 +231,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`

Raises:
TypeError: If `scale_sense` is neither Cell nor Tensor.
ValueError: If shape of `scale_sense` is neither (1,) nor ().

Supported Platforms:
``Ascend`` ``GPU``



+1 -1 mindspore/ops/operations/array_ops.py

@@ -821,7 +821,7 @@ class Gather(PrimitiveWithCheck):
TypeError: If `axis` is not an int.

Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)


+1 -1 mindspore/ops/operations/comm_ops.py

@@ -428,7 +428,7 @@ class Broadcast(PrimitiveWithInfer):
TypeError: If `root_rank` is not an integer or `group` is not a string.

Supported Platforms:
``Ascend``, ``GPU``
``Ascend`` ``GPU``

Examples:
>>> # This example should be run with multiple processes.


+8 -1 mindspore/ops/operations/math_ops.py

@@ -747,6 +747,12 @@ class MatMul(PrimitiveWithCheck):
Outputs:
Tensor, the shape of the output tensor is :math:`(N, M)`.

Raises:
TypeError: If `transpose_a` or `transpose_b` is not a bool.
ValueError: If the column of matrix dimensions of `input_x` is not equal to
the row of matrix dimensions of `input_y`.
ValueError: If length of shape of `input_x` or `input_y` is not equal to 2.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -827,7 +833,8 @@ class BatchMatMul(MatMul):

Raises:
TypeError: If `transpose_a` or `transpose_b` is not a bool.
ValueError: If length of shape of `input_x` is less than 3 or not equal to length of shape of `input_y`.
ValueError: If length of shape of `input_x` is not equal to length of shape of `input_y` or
length of shape of `input_x` is less than 3.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
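
Illustration (not part of the diff) of the ops.MatMul shape contract documented above; the mismatch note is an assumption.
>>> import numpy as np
>>> import mindspore
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> matmul = ops.MatMul(transpose_a=False, transpose_b=False)  # both flags must be bools
>>> x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
>>> y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
>>> print(matmul(x, y).shape)
(1, 4)
>>> # Non-2-D inputs, or inner dimensions that disagree, hit the documented ValueErrors.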


+20 -2 mindspore/ops/operations/nn_ops.py

@@ -1722,7 +1722,7 @@ class MaxPool(_Pool):
represent height and width of movement respectively. Default: 1.
pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
format (str) : The optional value for data format, is 'NHWC' or 'NCHW'.
data_format (str) : The optional value for data format, is 'NHWC' or 'NCHW'.
Default: 'NCHW'.

- same: Adopts the way of completion. The height and width of the output will be the same as
@@ -1739,6 +1739,13 @@ class MaxPool(_Pool):
Outputs:
Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `kernel_size` or `strides` is neither int nor tuple.
ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
ValueError: If `kernel_size` or `strides` is less than 1.
ValueError: If length of shape of `input` is not equal to 4.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1973,6 +1980,13 @@ class AvgPool(_Pool):
Outputs:
Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Raises:
TypeError: If `kernel_size` or `strides` is neither int nor tuple.
ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
ValueError: If `kernel_size` or `strides` is less than 1.
ValueError: If length of shape of `input` is not equal to 4.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3096,7 +3110,7 @@ class L2Normalize(PrimitiveWithInfer):
Tensor, with the same type and shape as the input.

Supported Platforms:
``Ascend``
``Ascend`` ``GPU``

Examples:
>>> l2_normalize = ops.L2Normalize()
@@ -4237,6 +4251,10 @@ class Adam(PrimitiveWithInfer):
- **m** (Tensor) - The same shape and data type as `m`.
- **v** (Tensor) - The same shape and data type as `v`.

Raises:
TypeError: If `use_locking` or `use_nesterov` is not a bool.
ValueError: If the shapes of `var`, `m` and `v` are not the same.

Supported Platforms:
``Ascend`` ``GPU``
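
Not part of the diff — a short sketch of the ops.MaxPool checks added above; the data_format note is an assumption from the docstring.
>>> import numpy as np
>>> import mindspore
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> max_pool = ops.MaxPool(kernel_size=2, strides=1, pad_mode='valid')  # 'valid'/'same', case insensitive
>>> x = Tensor(np.arange(9).reshape(1, 1, 3, 3), mindspore.float32)
>>> print(max_pool(x).shape)
(1, 1, 2, 2)
>>> # A data_format other than 'NCHW'/'NHWC' is expected to raise the documented ValueError.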


