Merge pull request !839 from wangnan39/add_parameter_verification_for_rmsprop
@@ -145,9 +145,12 @@ class Adam(Optimizer):
            When the learning_rate is float or learning_rate is a Tensor
            but the dims of the Tensor is 0, use fixed learning rate.
            Other cases are not supported. Default: 1e-3.
-        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
-        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
-        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
+            Default: 0.9.
+        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
+            Default: 0.999.
+        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+            Default: 1e-8.
        use_locking (bool): Whether to enable a lock to protect updating variable tensors.
            If True, updating of the var, m, and v tensors will be protected by a lock.
            If False, the result is unpredictable. Default: False.

@@ -155,8 +158,8 @@ class Adam(Optimizer):
            If True, updates the gradients using NAG.
            If False, updates the gradients without using NAG. Default: False.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-            Should be equal to or greater than 1.
+        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1.
+            Default: 1.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
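As a quick reference for the documented ranges above, here is a minimal construction sketch. It is my own illustration, not part of the PR; it assumes `Adam` is importable from `mindspore.nn.optim`, by analogy with the `RMSProp` import in the test added at the bottom of this change, and uses a placeholder `nn.Dense` network.

    import mindspore.nn as nn
    from mindspore.nn.optim import Adam

    net = nn.Dense(64, 10)  # placeholder network, not part of the PR
    # Every argument below sits inside its documented range:
    # beta1, beta2 in (0.0, 1.0); eps > 0; weight_decay >= 0; loss_scale >= 1.
    optimizer = Adam(net.trainable_params(),
                     learning_rate=1e-3,
                     beta1=0.9,
                     beta2=0.999,
                     eps=1e-8,
                     weight_decay=0.0,
                     loss_scale=1.0)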
@@ -46,8 +46,8 @@ class Optimizer(Cell):
        learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
        parameters (list): A list of parameters, which will be updated. The element in `parameters`
            should be class mindspore.Parameter.
-        weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay`
-            input is int, it will be convertd to float. Default: 0.0.
+        weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
+            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. Default: 1.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda

@@ -87,21 +87,15 @@ class Optimizer(Cell):
        if isinstance(weight_decay, int):
            weight_decay = float(weight_decay)
-        validator.check_float_legal_value('weight_decay', weight_decay, None)
+        validator.check_value_type("weight_decay", weight_decay, [float], None)
+        validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
-        validator.check_float_legal_value('loss_scale', loss_scale, None)
-        if loss_scale <= 0.0:
-            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
+        validator.check_value_type("loss_scale", loss_scale, [float], None)
+        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None)
        self.loss_scale = loss_scale
-        if weight_decay < 0.0:
-            raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))

        self.learning_rate = Parameter(learning_rate, name="learning_rate")
        self.parameters = ParameterTuple(parameters)
        self.reciprocal_scale = 1.0 / loss_scale
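For readers unfamiliar with the `Validator` helpers, the following plain-Python sketch shows what the two new calls enforce for each argument. It is illustrative only, not the MindSpore implementation; it assumes `Rel.INC_LEFT` means a left-closed range [0, inf) and `Rel.INC_NEITHER` an open range (0, inf), and that type and range violations surface as TypeError and ValueError respectively, which is the behaviour the new RMSProp test below asserts for `momentum`.

    def check_weight_decay(weight_decay):
        # Roughly check_value_type(..., [float]) + check_number_range(..., 0.0, inf, Rel.INC_LEFT)
        if isinstance(weight_decay, int):
            weight_decay = float(weight_decay)
        if not isinstance(weight_decay, float):
            raise TypeError("weight_decay should be a float, but got {}".format(type(weight_decay)))
        if weight_decay < 0.0:  # left-inclusive: 0.0 itself is allowed
            raise ValueError("weight_decay should be >= 0, but got {}".format(weight_decay))
        return weight_decay

    def check_loss_scale(loss_scale):
        # Roughly check_value_type(..., [float]) + check_number_range(..., 0.0, inf, Rel.INC_NEITHER)
        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        if not isinstance(loss_scale, float):
            raise TypeError("loss_scale should be a float, but got {}".format(type(loss_scale)))
        if loss_scale <= 0.0:  # neither bound included: 0.0 is rejected
            raise ValueError("loss_scale should be > 0, but got {}".format(loss_scale))
        return loss_scale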
@@ -15,6 +15,7 @@
"""rmsprop"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
from .optimizer import Optimizer

rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -91,14 +92,16 @@ class RMSProp(Optimizer):
            take the i-th value as the learning rate.
            When the learning_rate is float or learning_rate is a Tensor
            but the dims of the Tensor is 0, use fixed learning rate.
-            Other cases are not supported.
-        decay (float): Decay rate.
-        momentum (float): Hyperparameter of type float, means momentum for the moving average.
-        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+            Other cases are not supported. Default: 0.1.
+        decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
+        momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to
+            or greater than 0. Default: 0.0.
+        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+            Default: 1e-10.
        use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False.
-        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
+        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
+        weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'beta' not in x.name and 'gamma' not in x.name.
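A minimal construction sketch matching the documented defaults and ranges above (my own illustration, not part of the PR; the placeholder network is arbitrary):

    import mindspore.nn as nn
    from mindspore.nn.optim import RMSProp

    net = nn.Dense(64, 10)  # placeholder network, not part of the PR
    optimizer = RMSProp(net.trainable_params(),
                        learning_rate=0.1,  # Default: 0.1
                        decay=0.9,          # should be >= 0
                        momentum=0.0,       # should be >= 0
                        epsilon=1e-10,      # should be > 0
                        use_locking=False,
                        centered=False,
                        loss_scale=1.0,     # should be > 0
                        weight_decay=0.0)   # should be >= 0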
@@ -118,17 +121,15 @@ class RMSProp(Optimizer):
                 use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
-        if isinstance(momentum, float) and momentum < 0.0:
-            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-        if decay < 0.0:
-            raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay))
-        self.decay = decay
-        self.epsilon = epsilon
+        validator.check_value_type("decay", decay, [float], self.cls_name)
+        validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("momentum", momentum, [float], self.cls_name)
+        validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("epsilon", epsilon, [float], self.cls_name)
+        validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("centered", centered, [bool], self.cls_name)
        self.centered = centered
        if centered:
            self.opt = P.ApplyCenteredRMSProp(use_locking)
@@ -137,11 +138,10 @@ class RMSProp(Optimizer):
            self.opt = P.ApplyRMSProp(use_locking)

        self.momentum = momentum
        self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
        self.moment = self.parameters.clone(prefix="moment", init='zeros')
        self.hyper_map = C.HyperMap()
-        self.epsilon = epsilon
        self.decay = decay

    def construct(self, gradients):
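With the validator calls in place, bad `decay` or `epsilon` values now fail at construction time rather than propagating into training. A short illustration of the expected failure modes (my own snippet, not part of the PR; the exception types follow the pattern the new test below asserts for `momentum`):

    import mindspore.nn as nn
    from mindspore.nn.optim import RMSProp

    net = nn.Dense(64, 10)  # placeholder network, not part of the PR

    try:
        # check_number_range(..., Rel.INC_LEFT) rejects a negative decay.
        RMSProp(net.trainable_params(), learning_rate=0.1, decay=-0.5)
    except ValueError as err:
        print("decay rejected:", err)

    try:
        # check_value_type(..., [float]) rejects an int epsilon.
        RMSProp(net.trainable_params(), learning_rate=0.1, epsilon=1)
    except TypeError as err:
        print("epsilon rejected:", err)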
@@ -49,12 +49,12 @@ class SGD(Optimizer):
            When the learning_rate is float or learning_rate is a Tensor
            but the dims of the Tensor is 0, use fixed learning rate.
            Other cases are not supported. Default: 0.1.
-        momentum (float): A floating point value the momentum. Default: 0.
-        dampening (float): A floating point value of dampening for momentum. Default: 0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.
+        momentum (float): A floating point value for the momentum. Default: 0.0.
+        dampening (float): A floating point value of dampening for momentum. Default: 0.0.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        nesterov (bool): Enables the Nesterov momentum. Default: False.
        loss_scale (float): A floating point value for the loss scale, which should be larger
            than 0.0. Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
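For completeness, the corresponding SGD construction with the corrected default notation (again my own illustration, not part of the PR; it assumes `SGD` is importable from `mindspore.nn.optim` like the other optimizers touched here):

    import mindspore.nn as nn
    from mindspore.nn.optim import SGD

    net = nn.Dense(64, 10)  # placeholder network, not part of the PR
    optimizer = SGD(net.trainable_params(),
                    learning_rate=0.1,
                    momentum=0.0,       # Default: 0.0
                    dampening=0.0,      # Default: 0.0
                    weight_decay=0.0,   # Default: 0.0
                    nesterov=False,
                    loss_scale=1.0)     # should be larger than 0.0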
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test rmsprop """
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore import Tensor, Parameter
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.ops import operations as P
+from mindspore.nn.optim import RMSProp
+
+
+class Net(nn.Cell):
+    """ Net definition """
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
+        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
+        self.matmul = P.MatMul()
+        self.biasAdd = P.BiasAdd()
+
+    def construct(self, x):
+        x = self.biasAdd(self.matmul(x, self.weight), self.bias)
+        return x
+
+
+def test_rmsprop_compile():
+    """ test_rmsprop_compile """
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_rmsprop_e():
+    """ test_rmsprop_e: invalid momentum arguments are rejected """
+    net = Net()
+    with pytest.raises(ValueError):
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+
+    with pytest.raises(TypeError):
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
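The new test exercises only the RMSProp-level check on `momentum`. Below is a sketch of a companion case for the base-class `Optimizer` checks introduced earlier in this change, routed through RMSProp since it forwards `weight_decay` and `loss_scale` to `super().__init__`. It is not part of the PR; it reuses the imports and `Net` from the test file above and assumes `check_number_range` raises ValueError, as the momentum case does.

    def test_optimizer_arg_check():
        """ sketch: base-class weight_decay / loss_scale range checks """
        net = Net()
        # weight_decay is checked against [0, inf) in Optimizer.__init__.
        with pytest.raises(ValueError):
            RMSProp(net.get_parameters(), learning_rate=0.1, weight_decay=-0.1)
        # loss_scale is checked against (0, inf), so 0.0 is rejected.
        with pytest.raises(ValueError):
            RMSProp(net.get_parameters(), learning_rate=0.1, loss_scale=0.0)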