Browse Source

add check and fix bug

tags/v1.2.0-rc1
Jiaqi 5 years ago
parent
commit
cbd717ea23
14 changed files with 124 additions and 59 deletions
  1. +37
    -26
      mindspore/nn/loss/loss.py
  2. +9
    -0
      mindspore/nn/metrics/hausdorff_distance.py
  3. +5
    -2
      mindspore/nn/metrics/mean_surface_distance.py
  4. +5
    -2
      mindspore/nn/metrics/root_mean_square_surface_distance.py
  5. +7
    -2
      mindspore/nn/optim/ada_grad.py
  6. +7
    -2
      mindspore/nn/optim/adam.py
  7. +7
    -2
      mindspore/nn/optim/ftrl.py
  8. +7
    -2
      mindspore/nn/optim/lamb.py
  9. +7
    -2
      mindspore/nn/optim/lazyadam.py
  10. +7
    -2
      mindspore/nn/optim/momentum.py
  11. +4
    -3
      mindspore/nn/optim/optimizer.py
  12. +7
    -2
      mindspore/nn/optim/proximal_ada_grad.py
  13. +7
    -2
      mindspore/nn/optim/rmsprop.py
  14. +8
    -10
      mindspore/nn/optim/sgd.py

+ 37
- 26
mindspore/nn/loss/loss.py View File

@@ -178,11 +178,12 @@ class RMSELoss(_Loss):
RMSELoss creates a standard to measure the root mean square error between :math:`x` and :math:`y`
element-wise, where :math:`x` is the input and :math:`y` is the target.

For simplicity, let :math:`x` and :math:`y` be 1-dimensional Tensor with length :math:`N`,
For simplicity, let :math:`x` and :math:`y` be 1-dimensional Tensor with length :math:`M` and :math:`N`,
the unreduced loss (i.e. with argument reduction set to 'none') of :math:`x` and :math:`y` is given as:

.. math::
loss = \sqrt{\frac{1}{M}\sum_{m=1}^{M}{(x_m-y_m)^2}}
loss = \begin{cases} \sqrt{\frac{1}{M}\sum_{m=1,n=1}^{M,N}{(x_m-y_n)^2}}, & \text {if M > N }
\\\\ \sqrt{\frac{1}{N}\sum_{m=1,n=1}^{M,N}{(x_m-y_n)^2}}, &\text{if M < N } \end{cases}


Inputs:
@@ -249,9 +250,12 @@ class MAELoss(_Loss):
>>> print(output)
0.33333334
"""
def __init__(self, reduction='mean'):
super(MAELoss, self).__init__(reduction)
self.abs = P.Abs()

def construct(self, logits, label):
x = F.absolute(logits - label)
x = self.abs(logits - label)
return self.get_loss(x)


@@ -484,10 +488,10 @@ class MultiClassDiceLoss(_Loss):
Default: 'softmax'. Choose from: ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh','Sigmoid']

Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, C, ...). y_pred dimension should be greater than 1.
The data type must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, C, ...). y dimension should be greater than 1.
The data type must be float16 or float32.
- **y_pred** (Tensor) - Tensor of shape (N, C, ...). y_pred dimension should be greater than 1. The data type
must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, C, ...). y dimension should be greater than 1. The data type must be
float16 or float32.

Outputs:
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.
@@ -1002,21 +1006,25 @@ class BCEWithLogitsLoss(_Loss):

@constexpr
def _check_ndim(predict_nidm, target_ndim):
if predict_nidm < 2 or predict_nidm > 4:
raise ValueError("The dimensions of predict and target should be between 2 and 4, but got"
"predict dim {}.".format(predict_nidm))
if target_ndim < 2 or target_ndim > 4:
raise ValueError("The dimensions of target and target should be between 2 and 4, but got"
"target dim {}.".format(target_ndim))
if predict_nidm != target_ndim:
raise ValueError("The dim of the predicted value and the dim of the target value must be equal, but got"
"predict dim {} and target dim {}.".format(predict_nidm, target_ndim))


@constexpr
def _check_channel_and_shape(target, predict):
if target not in (predict, 1):
raise ValueError("The target must have a channel or the same shape as predict.")


@constexpr
def _check_predict_channel(predict):
def _check_channel_and_shape(predict, target):
if predict == 1:
raise ValueError("Single channel prediction is not supported.")
if target not in (1, predict):
raise ValueError("The target must have a channel or the same shape as predict."
"If it has a channel, it should be the range [0, C-1], where C is the number of classes "
f"inferred from 'predict': C={predict}.")


class FocalLoss(_Loss):
@@ -1027,20 +1035,22 @@ class FocalLoss(_Loss):

Args:
gamma (float): Gamma is used to adjust the steepness of weight curve in focal loss. Default: 2.0.
weight (Union[Tensor, None]): A rescaling weight applied to the loss of each batch element. If None, no weights
are applied. Default: None.
weight (Union[Tensor, None]): A rescaling weight applied to the loss of each batch element. The dimension of
weight should be 1. If None, no weights are applied. Default: None.
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
If "none", do not perform reduction. Default: "mean".

Inputs:
- **predict** (Tensor) - Input logits. Tensor of shape should be BCH[WD]. Where C is the number of classes.
Its value is greater than 1.
- **target** (Tensor) - Tensor of shape should be B1H[WD] or BCH[WD]. If the target shape is B1H[WD], the
expected target of this loss should be the class index within the range of [0, C-1],
where C is the number of classes.
- **predict** (Tensor) - Tensor of shape should be (B, C) or (B, C, H) or (B, C, H, W). Where C is the number
of classes. Its value is greater than 1. If the shape is (B, C, H, W) or (B, C, H), the H or product of H
and W should be the same as target.
- **target** (Tensor) - Tensor of shape should be (B, C) or (B, C, H) or (B, C, H, W). The value of C is 1 or
it needs to be the same as predict's C. If C is not 1, the shape of target should be the same as that of
predict, where C is the number of classes. If the shape is (B, C, H, W) or (B, C, H), the H or product of H
and W should be the same as predict.

Outputs:
Tensor, a tensor of shape with the per-example sampled Focal losses.
Tensor, it's a tensor with the same shape and type as input `predict`.

Raises:
TypeError: If the data type of ``gamma`` is not float..
@@ -1056,9 +1066,9 @@ class FocalLoss(_Loss):
>>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
>>> target = Tensor([[1], [1], [0]], mstype.int32)
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
>>> output = focalloss(inputs, labels)
>>> output = focalloss(predict, target)
>>> print(output)
0.33365273
1.6610543
"""

def __init__(self, weight=None, gamma=2.0, reduction='mean'):
@@ -1067,6 +1077,8 @@ class FocalLoss(_Loss):
self.gamma = validator.check_value_type("gamma", gamma, [float])
if weight is not None and not isinstance(weight, Tensor):
raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight)))
if isinstance(weight, Tensor) and weight.ndim != 1:
raise ValueError("The dimension of weight should be 1, but got {}.".format(weight.ndim))
self.weight = weight
self.expand_dims = P.ExpandDims()
self.gather_d = P.GatherD()
@@ -1077,8 +1089,7 @@ class FocalLoss(_Loss):
def construct(self, predict, target):
targets = target
_check_ndim(predict.ndim, targets.ndim)
_check_channel_and_shape(targets.shape[1], predict.shape[1])
_check_predict_channel(predict.shape[1])
_check_channel_and_shape(predict.shape[1], targets.shape[1])

if predict.ndim > 2:
predict = predict.view(predict.shape[0], predict.shape[1], -1)


+ 9
- 0
mindspore/nn/metrics/hausdorff_distance.py View File

@@ -259,17 +259,26 @@ class HausdorffDistance(Metric):
ValueError: If the number of the inputs is not 3.
"""
self._is_update = True

if len(inputs) != 3:
raise ValueError('HausdorffDistance need 3 inputs (y_pred, y, label), but got {}'.format(len(inputs)))

y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
label_idx = inputs[2]

if not isinstance(label_idx, (int, float)):
raise TypeError("The data type of label_idx must be int or float, but got {}.".format(type(label_idx)))

if label_idx not in y_pred and label_idx not in y:
raise ValueError("The label_idx should be in y_pred or y, but {} is not.".format(label_idx))

if y_pred.size == 0 or y_pred.shape != y.shape:
raise ValueError("Labelfields should have the same shape, but got {}, {}".format(y_pred.shape, y.shape))

y_pred = (y_pred == label_idx) if y_pred.dtype is not bool else y_pred
y = (y == label_idx) if y.dtype is not bool else y

self.y_pred_edges, self.y_edges = self._get_mask_edges_distance(y_pred, y)

def eval(self):


+ 5
- 2
mindspore/nn/metrics/mean_surface_distance.py View File

@@ -99,8 +99,11 @@ class MeanSurfaceDistance(Metric):
y = self._convert_data(inputs[1])
label_idx = inputs[2]

if not isinstance(label_idx, int):
raise TypeError("The data type of label_idx must be int, but got {}.".format(type(label_idx)))
if not isinstance(label_idx, (int, float)):
raise TypeError("The data type of label_idx must be int or float, but got {}.".format(type(label_idx)))

if label_idx not in y_pred and label_idx not in y:
raise ValueError("The label_idx should be in y_pred or y, but {} is not.".format(label_idx))

if y_pred.size == 0 or y_pred.shape != y.shape:
raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape))


+ 5
- 2
mindspore/nn/metrics/root_mean_square_surface_distance.py View File

@@ -101,8 +101,11 @@ class RootMeanSquareDistance(Metric):
y = self._convert_data(inputs[1])
label_idx = inputs[2]

if not isinstance(label_idx, int):
raise TypeError("The data type of label_idx must be int, but got {}.".format(type(label_idx)))
if not isinstance(label_idx, (int, float)):
raise TypeError("The data type of label_idx must be int or float, but got {}.".format(type(label_idx)))

if label_idx not in y_pred and label_idx not in y:
raise ValueError("The label_idx should be in y_pred or y, but {} is not.".format(label_idx))

if y_pred.size == 0 or y_pred.shape != y.shape:
raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape))


+ 7
- 2
mindspore/nn/optim/ada_grad.py View File

@@ -48,6 +48,10 @@ class Adagrad(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

Args:
@@ -67,8 +71,9 @@ class Adagrad(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.


+ 7
- 2
mindspore/nn/optim/adam.py View File

@@ -212,6 +212,10 @@ class Adam(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters is supported.

The sparse strategy is applied while the SparseGatherV2 operator is used for forward network.
@@ -235,8 +239,9 @@ class Adam(Optimizer):
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
which in the 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then


+ 7
- 2
mindspore/nn/optim/ftrl.py View File

@@ -106,6 +106,10 @@ class FTRL(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on all of the parameters.

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
@@ -128,8 +132,9 @@ class FTRL(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (float): The learning rate value, must be zero or positive, dynamic learning rate is currently


+ 7
- 2
mindspore/nn/optim/lamb.py View File

@@ -181,6 +181,10 @@ class Lamb(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

Args:
@@ -200,8 +204,9 @@ class Lamb(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then


+ 7
- 2
mindspore/nn/optim/lazyadam.py View File

@@ -130,6 +130,10 @@ class LazyAdam(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
@@ -154,8 +158,9 @@ class LazyAdam(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then


+ 7
- 2
mindspore/nn/optim/momentum.py View File

@@ -64,6 +64,10 @@ class Momentum(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

Args:
@@ -83,8 +87,9 @@ class Momentum(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then


+ 4
- 3
mindspore/nn/optim/optimizer.py View File

@@ -52,7 +52,7 @@ class Optimizer(Cell):

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported. Default: False.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

@@ -80,8 +80,9 @@ class Optimizer(Cell):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

weight_decay (float): A floating point value for the weight decay. It must be equal to or greater than 0.
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.


+ 7
- 2
mindspore/nn/optim/proximal_ada_grad.py View File

@@ -62,6 +62,10 @@ class ProximalAdagrad(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
@@ -85,8 +89,9 @@ class ProximalAdagrad(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.
accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.


+ 7
- 2
mindspore/nn/optim/rmsprop.py View File

@@ -86,6 +86,10 @@ class RMSProp(Optimizer):
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

Args:
@@ -105,8 +109,9 @@ class RMSProp(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then


+ 8
- 10
mindspore/nn/optim/sgd.py View File

@@ -57,9 +57,9 @@ class SGD(Optimizer):
Here : where p, v and u denote the parameters, accum, and momentum respectively.

Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.

To improve parameter groups performance, the customized order of parameters can be supported.

@@ -73,15 +73,13 @@ class SGD(Optimizer):
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.

- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.

- order_params: Optional. If "order_params" in the keys, the value must be the order of parameters and
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.

- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
- grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
This parameter only works on the convolution layer.

learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
@@ -119,11 +117,11 @@ class SGD(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
>>> group_params = [{'params': conv_params,'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # The conv_params's parameters will use default learning rate of 0.1 default weight decay of 0.0 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.


Loading…
Cancel
Save