
update the description of some operations.

tags/v1.1.0
wangshuide2020 5 years ago
commit fd74df15c8
10 changed files with 70 additions and 41 deletions
  1. mindspore/nn/layer/math.py (+4, -4)
  2. mindspore/nn/learning_rate_schedule.py (+6, -6)
  3. mindspore/nn/metrics/__init__.py (+1, -1)
  4. mindspore/nn/metrics/recall.py (+1, -1)
  5. mindspore/nn/optim/ada_grad.py (+1, -1)
  6. mindspore/nn/optim/ftrl.py (+1, -1)
  7. mindspore/nn/optim/lazyadam.py (+2, -2)
  8. mindspore/ops/operations/array_ops.py (+1, -1)
  9. mindspore/ops/operations/math_ops.py (+12, -8)
  10. mindspore/ops/operations/nn_ops.py (+41, -16)

mindspore/nn/layer/math.py (+4, -4)

@@ -284,7 +284,7 @@ class LGamma(Cell):

 class DiGamma(Cell):
     r"""
-    Calculate Digamma using Lanczos' approximation refering to "A Precision Approximationof the Gamma Function".
+    Calculates Digamma using Lanczos' approximation refering to "A Precision Approximationof the Gamma Function".
     The algorithm is:

     .. math::

@@ -549,7 +549,7 @@ def _IgammacContinuedFraction(ax, x, a, enabled):

 class IGamma(Cell):
     r"""
-    Calculate lower regularized incomplete Gamma function.
+    Calculates lower regularized incomplete Gamma function.
     The lower regularized incomplete Gamma function is defined as:

     .. math::

@@ -950,7 +950,7 @@ class Moments(Cell):

 class MatInverse(Cell):
     """
-    Calculate the inverse of Positive-Definite Hermitian matrix using Cholesky decomposition.
+    Calculates the inverse of Positive-Definite Hermitian matrix using Cholesky decomposition.

     Supported Platforms:
         ``GPU``

@@ -987,7 +987,7 @@ class MatInverse(Cell):

 class MatDet(Cell):
     """
-    Calculate the determinant of Positive-Definite Hermitian matrix using Cholesky decomposition.
+    Calculates the determinant of Positive-Definite Hermitian matrix using Cholesky decomposition.

     Supported Platforms:
         ``GPU``
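
For a quick sanity check of the Cholesky-based computation that the MatInverse and MatDet docstrings describe, here is a minimal NumPy sketch (illustrative values only, not the MindSpore kernels):

import numpy as np

a = np.array([[4.0, 2.0], [2.0, 3.0]])   # symmetric positive-definite matrix
l = np.linalg.cholesky(a)                # a = l @ l.T, with l lower-triangular
l_inv = np.linalg.inv(l)
a_inv = l_inv.T @ l_inv                  # inverse: a^{-1} = (l^{-1})^T @ l^{-1}
a_det = np.prod(np.diag(l)) ** 2         # determinant: squared product of diag(l)
print(a_inv)                             # matches np.linalg.inv(a)
print(a_det)                             # 8.0 = det([[4, 2], [2, 3]])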


mindspore/nn/learning_rate_schedule.py (+6, -6)

@@ -53,7 +53,7 @@ def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):


class ExponentialDecayLR(LearningRateSchedule): class ExponentialDecayLR(LearningRateSchedule):
r""" r"""
Calculate learning rate base on exponential decay function.
Calculates learning rate base on exponential decay function.


For the i-th step, the formula of computing decayed_learning_rate[i] is: For the i-th step, the formula of computing decayed_learning_rate[i] is:


@@ -111,7 +111,7 @@ class ExponentialDecayLR(LearningRateSchedule):


class NaturalExpDecayLR(LearningRateSchedule): class NaturalExpDecayLR(LearningRateSchedule):
r""" r"""
Calculate learning rate base on natural exponential decay function.
Calculates learning rate base on natural exponential decay function.


For the i-th step, the formula of computing decayed_learning_rate[i] is: For the i-th step, the formula of computing decayed_learning_rate[i] is:


@@ -170,7 +170,7 @@ class NaturalExpDecayLR(LearningRateSchedule):


class InverseDecayLR(LearningRateSchedule): class InverseDecayLR(LearningRateSchedule):
r""" r"""
Calculate learning rate base on inverse-time decay function.
Calculates learning rate base on inverse-time decay function.


For the i-th step, the formula of computing decayed_learning_rate[i] is: For the i-th step, the formula of computing decayed_learning_rate[i] is:


@@ -227,7 +227,7 @@ class InverseDecayLR(LearningRateSchedule):


class CosineDecayLR(LearningRateSchedule): class CosineDecayLR(LearningRateSchedule):
r""" r"""
Calculate learning rate base on cosine decay function.
Calculates learning rate base on cosine decay function.


For the i-th step, the formula of computing decayed_learning_rate[i] is: For the i-th step, the formula of computing decayed_learning_rate[i] is:


@@ -283,7 +283,7 @@ class CosineDecayLR(LearningRateSchedule):


class PolynomialDecayLR(LearningRateSchedule): class PolynomialDecayLR(LearningRateSchedule):
r""" r"""
Calculate learning rate base on polynomial decay function.
Calculates learning rate base on polynomial decay function.


For the i-th step, the formula of computing decayed_learning_rate[i] is: For the i-th step, the formula of computing decayed_learning_rate[i] is:


@@ -362,7 +362,7 @@ class PolynomialDecayLR(LearningRateSchedule):


class WarmUpLR(LearningRateSchedule): class WarmUpLR(LearningRateSchedule):
r""" r"""
Get learning rate warming up.
Gets learning rate warming up.


For the i-th step, the formula of computing warmup_learning_rate[i] is: For the i-th step, the formula of computing warmup_learning_rate[i] is:
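
For context, the decay formula the first of these docstrings refers to is decayed_learning_rate[i] = learning_rate * decay_rate^(i / decay_steps), with the exponent floored when is_stair is True. A plain NumPy sketch of that formula (not the ExponentialDecayLR cell; values are illustrative):

import numpy as np

def exponential_decay_lr(learning_rate, decay_rate, decay_steps, current_step, is_stair=False):
    # decayed_lr[i] = learning_rate * decay_rate ** (i / decay_steps)
    p = current_step / decay_steps
    if is_stair:
        p = np.floor(p)              # decay in discrete stairs
    return learning_rate * decay_rate ** p

print(exponential_decay_lr(0.1, 0.9, 4, current_step=2))                 # 0.1 * 0.9**0.5
print(exponential_decay_lr(0.1, 0.9, 4, current_step=2, is_stair=True))  # 0.1 * 0.9**0 == 0.1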




mindspore/nn/metrics/__init__.py (+1, -1)

@@ -59,7 +59,7 @@ __factory__ = {

 def names():
     """
-    Get the names of the metric methods.
+    Gets the names of the metric methods.

     Returns:
         List, the name list of metric methods.


mindspore/nn/metrics/recall.py (+1, -1)

@@ -23,7 +23,7 @@ from ._evaluation import EvaluationBase

 class Recall(EvaluationBase):
     r"""
-    Calculate recall for classification and multilabel data.
+    Calculates recall for classification and multilabel data.

     The recall class creates two local variables, :math:`\text{true_positive}` and :math:`\text{false_negative}`,
     that are used to compute the recall. This value is ultimately returned as the recall, an idempotent operation
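
As the docstring says, recall is the ratio of those two accumulated counts, true_positive / (true_positive + false_negative). A minimal NumPy sketch of that ratio for binary labels (illustrative values, not the Recall metric class):

import numpy as np

y_pred = np.array([1, 0, 1, 1, 0])
y_true = np.array([1, 0, 0, 1, 1])
true_positive = np.sum((y_pred == 1) & (y_true == 1))
false_negative = np.sum((y_pred == 0) & (y_true == 1))
recall = true_positive / (true_positive + false_negative)
print(recall)  # 2/3: two of the three actual positives were recovered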


mindspore/nn/optim/ada_grad.py (+1, -1)

@@ -37,7 +37,7 @@ def _check_param_value(accum, update_slots, prim_name=None):

 class Adagrad(Optimizer):
     """
-    Implement the Adagrad algorithm with ApplyAdagrad Operator.
+    Implements the Adagrad algorithm with ApplyAdagrad Operator.

     Adagrad is an online Learning and Stochastic Optimization.
     Refer to paper `Efficient Learning using Forward-Backward Splitting
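
For orientation, the step that the ApplyAdagrad operator performs is the standard Adagrad update: accumulate squared gradients and scale the learning rate by their inverse square root. A rough NumPy sketch with illustrative values (not the MindSpore operator):

import numpy as np

param = np.array([1.0, 2.0])
accum = np.array([0.1, 0.1])          # running sum of squared gradients
grad = np.array([0.5, -0.5])
lr = 0.01

accum += grad ** 2                    # accum_t = accum_{t-1} + grad^2
param -= lr * grad / np.sqrt(accum)   # var_t = var_{t-1} - lr * grad / sqrt(accum_t)
print(param)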


mindspore/nn/optim/ftrl.py (+1, -1)

@@ -74,7 +74,7 @@ def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None):

 class FTRL(Optimizer):
     """
-    Implement the FTRL algorithm with ApplyFtrl Operator.
+    Implements the FTRL algorithm with ApplyFtrl Operator.

     FTRL is an online convex optimization algorithm that adaptively chooses its regularization function
     based on the loss functions. Refer to paper `Adaptive Bound Optimization for Online Convex Optimization


mindspore/nn/optim/lazyadam.py (+2, -2)

@@ -104,9 +104,9 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):

 class LazyAdam(Optimizer):
     r"""
-    Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
+    This optimizer will apply a lazy adam algorithm when gradient is sparse.

-    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
+    The original adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

     The updating formulas are as follows,
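
For reference, the dense Adam step behind those formulas looks as follows (the lazy variant only updates the rows whose gradients are non-zero when the gradient is sparse). A NumPy sketch with illustrative values, not the MindSpore optimizer itself:

import numpy as np

beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 0.001
m = np.zeros(2)                       # first-moment estimate
v = np.zeros(2)                       # second-moment estimate
param = np.array([1.0, 2.0])
grad = np.array([0.5, -0.5])
t = 1                                 # step counter

m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * grad ** 2
m_hat = m / (1 - beta1 ** t)          # bias-corrected moments
v_hat = v / (1 - beta2 ** t)
param -= lr * m_hat / (np.sqrt(v_hat) + eps)
print(param)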




mindspore/ops/operations/array_ops.py (+1, -1)

@@ -733,7 +733,7 @@ class Unique(Primitive):
     - **x** (Tensor) - The input tensor.

     Outputs:
-        Tuple, containing Tensor objects `(y, idx)., `y` is a tensor with the
+        Tuple, containing Tensor objects `(y, idx), `y` is a tensor with the
         same type as `x`, and contains the unique elements in `x`, sorted in
         ascending order. `idx` is a tensor containing indices of elements in
         the input corresponding to the output tensor.
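
The (y, idx) pair described in that output behaves like what NumPy returns from np.unique with return_inverse=True, which makes a quick mental model (illustrative, not the Unique primitive):

import numpy as np

x = np.array([1, 2, 5, 2])
y, idx = np.unique(x, return_inverse=True)
print(y)    # [1 2 5]   unique elements, ascending
print(idx)  # [0 1 2 1] index into y for each element of x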


mindspore/ops/operations/math_ops.py (+12, -8)

@@ -2961,7 +2961,9 @@ class IsNan(PrimitiveWithInfer):
     Examples:
         >>> is_nan = ops.IsNan()
         >>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
-        >>> result = is_nan(input_x)
+        >>> output = is_nan(input_x)
+        >>> print(output)
+        [True False False]
     """

 @prim_attr_register
@@ -2992,7 +2994,9 @@ class IsInf(PrimitiveWithInfer):
     Examples:
         >>> is_inf = ops.IsInf()
         >>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
-        >>> result = is_inf(input_x)
+        >>> output = is_inf(input_x)
+        >>> print(output)
+        [False False True]
     """

 @prim_attr_register
@@ -3132,9 +3136,9 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
         >>> alloc_status = ops.NPUAllocFloatStatus()
         >>> get_status = ops.NPUGetFloatStatus()
         >>> init = alloc_status()
-        >>> output = get_status(init)
-        >>> print(output)
-        [0. 0. 0. 0. 0. 0. 0. 0.]
+        >>> get_status(init)
+        >>> print(init)
+        [1. 1. 1. 1. 1. 1. 1. 1.]
     """

 @prim_attr_register
@@ -3179,9 +3183,9 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
         >>> clear_status = ops.NPUClearFloatStatus()
         >>> init = alloc_status()
        >>> flag = get_status(init)
-        >>> output = clear_status(init)
-        >>> print(output)
-        [0. 0. 0. 0. 0. 0. 0. 0.]
+        >>> clear_status(init)
+        >>> print(init)
+        [1. 1. 1. 1. 1. 1. 1. 1.]
     """

 @prim_attr_register
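
The new IsNan/IsInf outputs can be cross-checked with plain NumPy, since np.log(-1) evaluates to nan and np.log(0) to -inf (NumPy emits runtime warnings for both, which is expected here):

import numpy as np

x = np.array([np.log(-1), 1, np.log(0)], dtype=np.float32)
print(np.isnan(x))  # [ True False False]
print(np.isinf(x))  # [False False  True]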


mindspore/ops/operations/nn_ops.py (+41, -16)

@@ -1512,7 +1512,14 @@ class MaxPool(_Pool):
     Examples:
         >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
         >>> maxpool_op = ops.MaxPool(padding="VALID", ksize=2, strides=1)
-        >>> output_tensor = maxpool_op(input_tensor)
+        >>> output = maxpool_op(input_tensor)
+        >>> print(output)
+        [[[[ 5. 6. 7.]
+           [ 9. 10. 11.]]
+          [[17. 18. 19.]
+           [21. 22. 23.]]
+          [[29. 30. 31.]
+           [33. 34. 35.]]]]
     """

 @prim_attr_register
@@ -1568,6 +1575,13 @@ class MaxPoolWithArgmax(_Pool):
         >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
         >>> maxpool_arg_op = ops.MaxPoolWithArgmax(padding="VALID", ksize=2, strides=1)
         >>> output_tensor, argmax = maxpool_arg_op(input_tensor)
+        >>> print(output_tensor)
+        [[[[ 5. 6. 7.]
+           [ 9. 10. 11.]]
+          [[17. 18. 19.]
+           [21. 22. 23.]]
+          [[29. 30. 31.]
+           [33. 34. 35.]]]]
     """

     def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"):
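
The printed values can be reproduced with a few lines of NumPy: a 2x2 window slid with stride 1 over each 3x4 channel of np.arange(36) (a cross-check sketch, not the MaxPool kernel):

import numpy as np

x = np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)).astype(np.float32)
out = np.empty((1, 3, 2, 3), dtype=np.float32)
for c in range(3):                       # channels
    for i in range(2):                   # output rows: 3 - 2 + 1
        for j in range(3):               # output cols: 4 - 2 + 1
            out[0, c, i, j] = x[0, c, i:i + 2, j:j + 2].max()
print(out[0, 0])                         # first channel: [5, 6, 7] and [9, 10, 11]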
@@ -2315,7 +2329,9 @@ class SGD(PrimitiveWithCheck):
         >>> accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mindspore.float32)
         >>> momentum = Tensor(0.1, mindspore.float32)
         >>> stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mindspore.float32)
-        >>> result = sgd(parameters, gradient, learning_rate, accum, momentum, stat)
+        >>> output = sgd(parameters, gradient, learning_rate, accum, momentum, stat)
+        >>> print(output[0])
+        [ 1.9899 -0.4903 1.6952001 3.9801 ]
     """

 @prim_attr_register
@@ -2931,7 +2947,7 @@ class FastGelu(PrimitiveWithInfer):
     FastGelu is defined as follows:

     .. math::
-        \text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|))`,
+        \text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|)),

     where :math:`x` is the element of the input.
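
Written out in NumPy, the corrected formula reads as below (a sketch of the math only, not the FastGelu primitive; for x >= 0 the second exponential factor equals 1, for x < 0 it damps the output):

import numpy as np

def fast_gelu(x):
    # output = x / (1 + exp(-1.702 * |x|)) * exp(0.851 * (x - |x|))
    return x / (1 + np.exp(-1.702 * np.abs(x))) * np.exp(0.851 * (x - np.abs(x)))

print(fast_gelu(np.array([-1.0, 0.0, 1.0], dtype=np.float32)))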


@@ -3466,8 +3482,8 @@ class ROIAlign(PrimitiveWithInfer):
     points. The details of (RoI) Align operator are described in `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_.

     Args:
-        pooled_height (int): The output features' height.
-        pooled_width (int): The output features' width.
+        pooled_height (int): The output features height.
+        pooled_width (int): The output features width.
         spatial_scale (float): A scaling factor that maps the raw image coordinates to the input
             feature map coordinates. Suppose the height of a RoI is `ori_h` in the raw image and `fea_h` in the
             input feature map, the `spatial_scale` must be `fea_h / ori_h`.
@@ -4046,7 +4062,7 @@ class FusedSparseFtrl(PrimitiveWithInfer):
         - **linear** (Tensor) - A Tensor with shape (1,).

     Supported Platforms:
-        ``CPU``
+        ``Ascend`` ``CPU``

     Examples:
         >>> import mindspore
@@ -4072,6 +4088,9 @@ class FusedSparseFtrl(PrimitiveWithInfer):
         >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
         >>> output = net(grad, indices)
         >>> print(output)
+        (Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]),
+         Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]),
+         Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]))
     """
     __mindspore_signature__ = (
         sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
@@ -4155,6 +4174,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
     Examples:
         >>> import numpy as np
         >>> import mindspore.nn as nn
+        >>> import mindspore.common.dtype as mstype
         >>> from mindspore import Tensor, Parameter
         >>> from mindspore.ops import operations as ops
         >>> class Net(nn.Cell):
@@ -4176,6 +4196,8 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
         >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
        >>> output = net(grad, indices)
         >>> print(output)
+        (Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]),
+         Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]))
     """
     __mindspore_signature__ = (
         sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
@@ -4225,9 +4247,9 @@ class KLDivLoss(PrimitiveWithInfer):

     .. math::
         \ell(x, y) = \begin{cases}
-        L, & \text{if reduction} = \text{`none';}\\
-        \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
-        \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
+        L, & \text{if reduction} = \text{'none';}\\
+        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+        \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
         \end{cases}

     Args:
@@ -4302,9 +4324,9 @@ class BinaryCrossEntropy(PrimitiveWithInfer):

     .. math::
         \ell(x, y) = \begin{cases}
-        L, & \text{if reduction} = \text{`none';}\\
-        \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
-        \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
+        L, & \text{if reduction} = \text{'none';}\\
+        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+        \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
         \end{cases}

     Args:
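
The three reduction branches in both formulas amount to the following, shown for an arbitrary per-element loss array L (plain NumPy, illustrative only):

import numpy as np

def reduce_loss(L, reduction='mean'):
    if reduction == 'none':
        return L                 # keep the per-element losses
    if reduction == 'mean':
        return np.mean(L)        # average over all elements
    if reduction == 'sum':
        return np.sum(L)         # total loss
    raise ValueError(reduction)

L = np.array([0.2, 0.5, 0.3])
print(reduce_loss(L, 'none'), reduce_loss(L, 'mean'), reduce_loss(L, 'sum'))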
@@ -6102,10 +6124,13 @@ class CTCLoss(PrimitiveWithInfer):
         >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
         >>> ctc_loss = ops.CTCLoss()
         >>> loss, gradient = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
-        >>> print(loss.shape)
-        (2,)
-        >>> print(gradient.shape)
-        (2, 2, 3)
+        >>> print(loss)
+        [ 0.69121575 0.5381993 ]
+        >>> print(gradient)
+        [[[ 0.25831494 0.3623634 -0.62067937 ]
+          [ 0.25187883 0.2921483 -0.5440271 ]]
+         [[ 0.43522435 0.24408469 0.07787037 ]
+          [ 0.29642645 0.4232373 0.06138104 ]]]
     """

 @prim_attr_register

