diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 055eaae7c6..87c46380f6 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -145,9 +145,12 @@ class Adam(Optimizer):
             When the learning_rate is float or learning_rate is a Tensor but the dims of the Tensor is 0, use fixed
             learning rate. Other cases are not supported. Default: 1e-3.
-        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
-        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
-        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
+            0.9.
+        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
+            0.999.
+        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
+            1e-8.
         use_locking (bool): Whether to enable a lock to protect updating variable tensors. If True, updating of the
             var, m, and v tensors will be protected by a lock. If False, the result is unpredictable. Default: False.
@@ -155,8 +158,8 @@ class Adam(Optimizer):
             If True, updates the gradients using NAG.
             If False, updates the gradients without using NAG. Default: False.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-            Should be equal to or greater than 1.
+        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
+            1.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
             x: 'LayerNorm' not in x.name and 'bias' not in x.name.
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index bab539461e..34abc2b1c2 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -46,8 +46,8 @@ class Optimizer(Cell):
         learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
         parameters (list): A list of parameter, which will be updated. The element in `parameters`
            should be class mindspore.Parameter.
-        weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay`
-            input is int, it will be convertd to float. Default: 0.0.
+        weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
+            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the type of
             `loss_scale` input is int, it will be convertd to float. Default: 1.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
@@ -87,21 +87,15 @@ class Optimizer(Cell):
         if isinstance(weight_decay, int):
             weight_decay = float(weight_decay)
-
-        validator.check_float_legal_value('weight_decay', weight_decay, None)
+        validator.check_value_type("weight_decay", weight_decay, [float], None)
+        validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
 
         if isinstance(loss_scale, int):
             loss_scale = float(loss_scale)
+        validator.check_value_type("loss_scale", loss_scale, [float], None)
+        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None)
-        validator.check_float_legal_value('loss_scale', loss_scale, None)
-
-        if loss_scale <= 0.0:
-            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
         self.loss_scale = loss_scale
-
-        if weight_decay < 0.0:
-            raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
-
         self.learning_rate = Parameter(learning_rate, name="learning_rate")
         self.parameters = ParameterTuple(parameters)
         self.reciprocal_scale = 1.0 / loss_scale
diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py
index a8f118b709..b1271587b4 100644
--- a/mindspore/nn/optim/rmsprop.py
+++ b/mindspore/nn/optim/rmsprop.py
@@ -15,6 +15,7 @@
 """rmsprop"""
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
 from .optimizer import Optimizer
 
 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -91,14 +92,16 @@ class RMSProp(Optimizer):
             take the i-th value as the learning rate. When the learning_rate is float or learning_rate is a Tensor
             but the dims of the Tensor is 0, use fixed learning rate.
-            Other cases are not supported.
-        decay (float): Decay rate.
-        momentum (float): Hyperparameter of type float, means momentum for the moving average.
-        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+            Other cases are not supported. Default: 0.1.
+        decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
+        momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
+            greater than 0. Default: 0.0.
+        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than
+            0. Default: 1e-10.
         use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
-        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
+        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
+        weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
             x: 'beta' not in x.name and 'gamma' not in x.name.
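
For reference, a minimal sketch (not from this changeset) of how the reworked checks in Optimizer.__init__ above surface to callers once the hand-rolled range checks are replaced by validator.check_value_type and check_number_range. It assumes an RMSProp optimizer over a plain nn.Dense layer purely to have parameters to pass in; the error types follow from the validators used in the hunk.

import pytest

import mindspore.nn as nn
from mindspore.nn.optim import RMSProp

net = nn.Dense(4, 2)  # illustration only: any Cell with trainable parameters works

# int weight_decay / loss_scale are still converted to float before the type check,
# so these construct fine.
RMSProp(net.trainable_params(), learning_rate=0.1, weight_decay=1, loss_scale=2)

# Rel.INC_LEFT on weight_decay: 0.0 is allowed, negative values raise ValueError.
with pytest.raises(ValueError):
    RMSProp(net.trainable_params(), learning_rate=0.1, weight_decay=-0.1)

# Rel.INC_NEITHER on loss_scale: it must be strictly greater than 0.
with pytest.raises(ValueError):
    RMSProp(net.trainable_params(), learning_rate=0.1, loss_scale=0.0)
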
@@ -118,17 +121,15 @@ class RMSProp(Optimizer):
                  use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
-
-        if isinstance(momentum, float) and momentum < 0.0:
-            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-
-        if decay < 0.0:
-            raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay))
-        self.decay = decay
-        self.epsilon = epsilon
-
+        validator.check_value_type("decay", decay, [float], self.cls_name)
+        validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("momentum", momentum, [float], self.cls_name)
+        validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("epsilon", epsilon, [float], self.cls_name)
+        validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
         validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
         validator.check_value_type("centered", centered, [bool], self.cls_name)
+        self.centered = centered
         if centered:
             self.opt = P.ApplyCenteredRMSProp(use_locking)
@@ -137,11 +138,10 @@ class RMSProp(Optimizer):
             self.opt = P.ApplyRMSProp(use_locking)
 
         self.momentum = momentum
-
         self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
         self.moment = self.parameters.clone(prefix="moment", init='zeros')
         self.hyper_map = C.HyperMap()
-
+        self.epsilon = epsilon
         self.decay = decay
 
     def construct(self, gradients):
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index bf2ed21d50..388fe5db47 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -49,12 +49,12 @@ class SGD(Optimizer):
             When the learning_rate is float or learning_rate is a Tensor but the dims of the Tensor is 0, use fixed
             learning rate. Other cases are not supported. Default: 0.1.
-        momentum (float): A floating point value the momentum. Default: 0.
-        dampening (float): A floating point value of dampening for momentum. Default: 0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.
+        momentum (float): A floating point value for the momentum. Default: 0.0.
+        dampening (float): A floating point value of dampening for momentum. Default: 0.0.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         nesterov (bool): Enables the Nesterov momentum. Default: False.
         loss_scale (float): A floating point value for the loss scale, which should be larger
-            than 0.0. Default: 1.0. 
+            than 0.0. Default: 1.0.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
diff --git a/tests/ut/python/nn/optim/test_rmsprop.py b/tests/ut/python/nn/optim/test_rmsprop.py
new file mode 100644
index 0000000000..647f1e8d45
--- /dev/null
+++ b/tests/ut/python/nn/optim/test_rmsprop.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test rmsprop """
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore import Tensor, Parameter
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.ops import operations as P
+from mindspore.nn.optim import RMSProp
+
+
+class Net(nn.Cell):
+    """ Net definition """
+    def __init__(self):
+        super(Net, self).__init__()
+        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
+        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
+        self.matmul = P.MatMul()
+        self.biasAdd = P.BiasAdd()
+
+    def construct(self, x):
+        x = self.biasAdd(self.matmul(x, self.weight), self.bias)
+        return x
+
+
+def test_rmsprop_compile():
+    """ test_rmsprop_compile """
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_rmsprop_e():
+    net = Net()
+    with pytest.raises(ValueError):
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+
+    with pytest.raises(TypeError):
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
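
test_rmsprop_e above only exercises the momentum checks: a negative float fails check_number_range and raises ValueError, while an int fails check_value_type and raises TypeError. The decay and epsilon arguments go through the same validator pattern in the RMSProp.__init__ hunk, so analogous assertions could be written along the lines below. This is a hedged sketch, not code from this changeset; it uses an nn.Dense layer just to have parameters, and the test name is hypothetical.

import pytest

import mindspore.nn as nn
from mindspore.nn.optim import RMSProp


def test_rmsprop_decay_epsilon_errors():
    params = nn.Dense(4, 2).trainable_params()

    # decay uses Rel.INC_LEFT: 0.0 is accepted, negative values are rejected.
    with pytest.raises(ValueError):
        RMSProp(params, decay=-0.5, learning_rate=0.1)

    # epsilon uses Rel.INC_NEITHER: 0.0 itself is rejected.
    with pytest.raises(ValueError):
        RMSProp(params, epsilon=0.0, learning_rate=0.1)

    # A non-float decay fails the type check rather than the range check.
    with pytest.raises(TypeError):
        RMSProp(params, decay=1, learning_rate=0.1)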