Merge pull request !839 from wangnan39/add_parameter_verification_for_rmsproptags/v0.3.0-alpha
| @@ -145,9 +145,12 @@ class Adam(Optimizer): | |||
| When the learning_rate is float or learning_rate is a Tensor | |||
| but the dims of the Tensor is 0, use fixed learning rate. | |||
| Other cases are not supported. Default: 1e-3. | |||
| beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). | |||
| beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). | |||
| eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. | |||
| beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default: | |||
| 0.9. | |||
| beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default: | |||
| 0.999. | |||
| eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default: | |||
| 1e-8. | |||
| use_locking (bool): Whether to enable a lock to protect updating variable tensors. | |||
| If True, updating of the var, m, and v tensors will be protected by a lock. | |||
| If False, the result is unpredictable. Default: False. | |||
| @@ -155,8 +158,8 @@ class Adam(Optimizer): | |||
| If True, updates the gradients using NAG. | |||
| If False, updates the gradients without using NAG. Default: False. | |||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||
| loss_scale (float): A floating point value for the loss scale. Default: 1.0. | |||
| Should be equal to or greater than 1. | |||
| loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default: | |||
| 1.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||
| lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. | |||
| @@ -46,8 +46,8 @@ class Optimizer(Cell): | |||
| learning_rate (float): A floating point value for the learning rate. Should be greater than 0. | |||
| parameters (list): A list of parameter, which will be updated. The element in `parameters` | |||
| should be class mindspore.Parameter. | |||
| weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay` | |||
| input is int, it will be convertd to float. Default: 0.0. | |||
| weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0. | |||
| If the type of `weight_decay` input is int, it will be convertd to float. Default: 0.0. | |||
| loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the | |||
| type of `loss_scale` input is int, it will be convertd to float. Default: 1.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda | |||
| @@ -87,21 +87,15 @@ class Optimizer(Cell): | |||
| if isinstance(weight_decay, int): | |||
| weight_decay = float(weight_decay) | |||
| validator.check_float_legal_value('weight_decay', weight_decay, None) | |||
| validator.check_value_type("weight_decay", weight_decay, [float], None) | |||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| if isinstance(loss_scale, int): | |||
| loss_scale = float(loss_scale) | |||
| validator.check_value_type("loss_scale", loss_scale, [float], None) | |||
| validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None) | |||
| validator.check_float_legal_value('loss_scale', loss_scale, None) | |||
| if loss_scale <= 0.0: | |||
| raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale)) | |||
| self.loss_scale = loss_scale | |||
| if weight_decay < 0.0: | |||
| raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay)) | |||
| self.learning_rate = Parameter(learning_rate, name="learning_rate") | |||
| self.parameters = ParameterTuple(parameters) | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| @@ -15,6 +15,7 @@ | |||
| """rmsprop""" | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from .optimizer import Optimizer | |||
| rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | |||
| @@ -91,14 +92,16 @@ class RMSProp(Optimizer): | |||
| take the i-th value as the learning rate. | |||
| When the learning_rate is float or learning_rate is a Tensor | |||
| but the dims of the Tensor is 0, use fixed learning rate. | |||
| Other cases are not supported. | |||
| decay (float): Decay rate. | |||
| momentum (float): Hyperparameter of type float, means momentum for the moving average. | |||
| epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0. | |||
| Other cases are not supported. Default: 0.1. | |||
| decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9. | |||
| momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or | |||
| greater than 0.Default: 0.0. | |||
| epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than | |||
| 0. Default: 1e-10. | |||
| use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False. | |||
| centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False | |||
| loss_scale (float): A floating point value for the loss scale. Default: 1.0. | |||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||
| centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False. | |||
| loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0. | |||
| weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||
| lambda x: 'beta' not in x.name and 'gamma' not in x.name. | |||
| @@ -118,17 +121,15 @@ class RMSProp(Optimizer): | |||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| if isinstance(momentum, float) and momentum < 0.0: | |||
| raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | |||
| if decay < 0.0: | |||
| raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay)) | |||
| self.decay = decay | |||
| self.epsilon = epsilon | |||
| validator.check_value_type("decay", decay, [float], self.cls_name) | |||
| validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_value_type("momentum", momentum, [float], self.cls_name) | |||
| validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_value_type("epsilon", epsilon, [float], self.cls_name) | |||
| validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) | |||
| validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | |||
| validator.check_value_type("centered", centered, [bool], self.cls_name) | |||
| self.centered = centered | |||
| if centered: | |||
| self.opt = P.ApplyCenteredRMSProp(use_locking) | |||
| @@ -137,11 +138,10 @@ class RMSProp(Optimizer): | |||
| self.opt = P.ApplyRMSProp(use_locking) | |||
| self.momentum = momentum | |||
| self.ms = self.parameters.clone(prefix="mean_square", init='zeros') | |||
| self.moment = self.parameters.clone(prefix="moment", init='zeros') | |||
| self.hyper_map = C.HyperMap() | |||
| self.epsilon = epsilon | |||
| self.decay = decay | |||
| def construct(self, gradients): | |||
| @@ -49,12 +49,12 @@ class SGD(Optimizer): | |||
| When the learning_rate is float or learning_rate is a Tensor | |||
| but the dims of the Tensor is 0, use fixed learning rate. | |||
| Other cases are not supported. Default: 0.1. | |||
| momentum (float): A floating point value the momentum. Default: 0. | |||
| dampening (float): A floating point value of dampening for momentum. Default: 0. | |||
| weight_decay (float): Weight decay (L2 penalty). Default: 0. | |||
| momentum (float): A floating point value the momentum. Default: 0.0. | |||
| dampening (float): A floating point value of dampening for momentum. Default: 0.0. | |||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||
| nesterov (bool): Enables the Nesterov momentum. Default: False. | |||
| loss_scale (float): A floating point value for the loss scale, which should be larger | |||
| than 0.0. Default: 1.0. | |||
| than 0.0. Default: 1.0. | |||
| Inputs: | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| @@ -0,0 +1,62 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test adam """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import _executor | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.optim import RMSProp | |||
| class Net(nn.Cell): | |||
| """ Net definition """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight") | |||
| self.bias = Parameter(Tensor(np.ones([10]).astype((np.float32))), name="bias") | |||
| self.matmul = P.MatMul() | |||
| self.biasAdd = P.BiasAdd() | |||
| def construct(self, x): | |||
| x = self.biasAdd(self.matmul(x, self.weight), self.bias) | |||
| return x | |||
| def test_rmsprop_compile(): | |||
| """ test_adamw_compile """ | |||
| inputs = Tensor(np.ones([1, 64]).astype(np.float32)) | |||
| label = Tensor(np.zeros([1, 10]).astype(np.float32)) | |||
| net = Net() | |||
| net.set_train() | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| optimizer = RMSProp(net.trainable_params(), learning_rate=0.1) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||
| _executor.compile(train_network, inputs, label) | |||
| def test_rmsprop_e(): | |||
| net = Net() | |||
| with pytest.raises(ValueError): | |||
| RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1) | |||
| with pytest.raises(TypeError): | |||
| RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1) | |||