Browse Source

add annotations for some api

pull/14983/head
wangnan39@huawei.com 4 years ago
parent
commit
7ad4fa160c
12 changed files with 76 additions and 18 deletions
  1. +11
    -2
      mindspore/nn/cell.py
  2. +3
    -0
      mindspore/nn/layer/activation.py
  3. +5
    -1
      mindspore/nn/optim/ada_grad.py
  4. +10
    -2
      mindspore/nn/optim/adam.py
  5. +5
    -1
      mindspore/nn/optim/ftrl.py
  6. +5
    -2
      mindspore/nn/optim/lazyadam.py
  7. +5
    -1
      mindspore/nn/optim/momentum.py
  8. +5
    -1
      mindspore/nn/optim/optimizer.py
  9. +5
    -1
      mindspore/nn/optim/proximal_ada_grad.py
  10. +5
    -1
      mindspore/nn/optim/rmsprop.py
  11. +5
    -2
      mindspore/nn/optim/sgd.py
  12. +12
    -4
      mindspore/train/loss_scale_manager.py

+ 11
- 2
mindspore/nn/cell.py View File

@@ -1061,7 +1061,9 @@ class Cell(Cell_):

def set_grad(self, requires_grad=True):
"""
Sets the cell flag for gradient.
Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires
gradients. If True, the backward network needed to compute the gradients will be generated when the forward
network is executed.

Args:
requires_grad (bool): Specifies whether the net needs to compute gradients, if it is
@@ -1075,7 +1077,8 @@ class Cell(Cell_):
Sets the cell to training mode.

The cell itself and all children cells will be set to training mode. Layers that have different constructions
for training and predicting , such as `BatchNorm`, will distinguish between the branches by this attribute.
for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
set to True, the training branch will be executed, otherwise another branch.

Args:
mode (bool): Specifies whether the model is training. Default: True.
@@ -1214,6 +1217,9 @@ class GraphKernel(Cell):
auto_prefix (bool): Recursively generate namespaces. Default: True.
flags (dict) : Set graph flags. Default: None.

Supported Platforms:
``Ascend`` ``GPU``

Examples:
>>> class Relu(nn.GraphKernel):
... def __init__(self):
@@ -1243,6 +1249,9 @@ class GraphCell(Cell):
Args:
graph (object): A compiled graph loaded from MindIR.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn


+ 3
- 0
mindspore/nn/layer/activation.py View File

@@ -746,6 +746,9 @@ def get_activation(name):
Returns:
Function, the activation function.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> sigmoid = nn.get_activation('sigmoid')
"""


+ 5
- 1
mindspore/nn/optim/ada_grad.py View File

@@ -95,7 +95,11 @@ class Adagrad(Optimizer):
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 0.001.
update_slots (bool): If true, update accumulation. Default: True.
loss_scale (float): Value for the loss scale. It must be greater than 0.0. Default: 1.0.
loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
Default: 0.0.



+ 10
- 2
mindspore/nn/optim/adam.py View File

@@ -264,7 +264,11 @@ class Adam(Optimizer):
If true, update the gradients using NAG.
If false, update the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.

Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -561,7 +565,11 @@ class AdamOffload(Optimizer):
If true, update the gradients using NAG.
If false, update the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.

Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.


+ 5
- 1
mindspore/nn/optim/ftrl.py View File

@@ -144,7 +144,11 @@ class FTRL(Optimizer):
l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If true, use locks for updating operation. Default: False.
loss_scale (float): Value for the loss scale. It must be equal to or greater than 1.0. Default: 1.0.
loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
Default: 0.0.



+ 5
- 2
mindspore/nn/optim/lazyadam.py View File

@@ -183,8 +183,11 @@ class LazyAdam(Optimizer):
If true, update the gradients using NAG.
If false, update the gradients without using NAG. Default: False.
weight_decay (Union[float, int]): Weight decay (L2 penalty). Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
1.0.
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. In general,
use the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update`
in `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.

Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.


+ 5
- 1
mindspore/nn/optim/momentum.py View File

@@ -101,7 +101,11 @@ class Momentum(Optimizer):
momentum (float): Hyperparameter of type float, means momentum for the moving average.
It must be at least 0.0.
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. It must be greater than 0.0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale. It must be greater than 0.0. In general, use the
default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.
use_nesterov (bool): Enable Nesterov momentum. Default: False.

Inputs:


+ 5
- 1
mindspore/nn/optim/optimizer.py View File

@@ -88,7 +88,11 @@ class Optimizer(Cell):
It must be equal to or greater than 0.
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
type of `loss_scale` input is int, it will be converted to float. Default: 1.0.
type of `loss_scale` input is int, it will be converted to float. In general, use the default value. Only
when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.

Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.


+ 5
- 1
mindspore/nn/optim/proximal_ada_grad.py View File

@@ -105,7 +105,11 @@ class ProximalAdagrad(Optimizer):
l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If true, use locks for updating operation. Default: False.
loss_scale (float): Value for the loss scale. It must be greater than 0.0. Default: 1.0.
loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
Default: 0.0.


+ 5
- 1
mindspore/nn/optim/rmsprop.py View File

@@ -129,7 +129,11 @@ class RMSProp(Optimizer):
use_locking (bool): Whether to enable a lock to protect the variable and accumulation tensors from being
updated. Default: False.
centered (bool): If true, gradients are normalized by the estimated variance of the gradient. Default: False.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.
weight_decay (Union[float, int]): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.

Inputs:


+ 5
- 2
mindspore/nn/optim/sgd.py View File

@@ -94,8 +94,11 @@ class SGD(Optimizer):
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
nesterov (bool): Enables the Nesterov momentum. If nesterov is used, momentum must be positive,
and dampening must be equal to 0.0. Default: False.
loss_scale (float): A floating point value for the loss scale, which must be larger
than 0.0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale, which must be larger than 0.0. In general, use
the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
Default: 1.0.

Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.


+ 12
- 4
mindspore/train/loss_scale_manager.py View File

@@ -39,17 +39,25 @@ class FixedLossScaleManager(LossScaleManager):
Fixed loss-scale manager.

Args:
loss_scale (float): Loss scale. Default: 128.0.
drop_overflow_update (bool): whether to execute optimizer if there is an overflow. Default: True.
loss_scale (float): Loss scale. Note that if `drop_overflow_update` is set to False, the value of `loss_scale`
in optimizer that you used need to be set to the same value as here. Default: 128.0.
drop_overflow_update (bool): Whether to execute optimizer if there is an overflow. If True, the optimizer will
not be executed when overflow occurs. Default: True.

Examples:
>>> from mindspore import Model, nn
>>> from mindspore.train.loss_scale_manager import FixedLossScaleManager
>>> from mindspore import Model, nn, FixedLossScaleManager
>>>
>>> net = Net()
>>> #1) Drop the parameter update if there is an overflow
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
>>>
>>> #2) Execute parameter update even if overflow occurs
>>> loss_scale = 1024
>>> loss_scale_manager = FixedLossScaleManager(loss_scale, False)
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
>>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
"""
def __init__(self, loss_scale=128.0, drop_overflow_update=True):
if loss_scale < 1:


Loading…
Cancel
Save