Merge pull request !66 from zhaoting/add-RMSProptags/v0.2.0-alpha
| @@ -183,6 +183,8 @@ const char kNameDiagPart[] = "DiagPart"; | |||
| const char kNameSpaceToBatch[] = "SpaceToBatch"; | |||
| const char kNameBatchToSpace[] = "BatchToSpace"; | |||
| const char kNameAtan2[] = "Atan2"; | |||
| const char kNameApplyRMSProp[] = "ApplyRMSProp"; | |||
| const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; | |||
| // -----------------OpAdapter initialization-------------- | |||
| std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() { | |||
| @@ -367,7 +369,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {string(kNameDiagPart), ADPT_DESC(DiagPart)}, | |||
| {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)}, | |||
| {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, | |||
| {string(kNameAtan2), ADPT_DESC(Atan2)}}; | |||
| {string(kNameAtan2), ADPT_DESC(Atan2)}, | |||
| {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, | |||
| {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}}; | |||
| #ifdef ENABLE_GE | |||
| adpt_map[string(kNamePrint)] = ADPT_DESC(Print); | |||
| #endif | |||
| @@ -1202,6 +1202,22 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||
| ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}}; | |||
| // ApplyRMSPropD | |||
| INPUT_MAP(ApplyRMSPropD) = { | |||
| {1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}}; | |||
| INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())}, | |||
| {7, ATTR_DESC(momentum, AnyTraits<float>())}, | |||
| {8, ATTR_DESC(epsilon, AnyTraits<float>())}}; | |||
| ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}}; | |||
| // ApplyCenteredRMSProp | |||
| INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, | |||
| {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, | |||
| {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; | |||
| ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; | |||
| #ifdef ENABLE_GE | |||
| INPUT_MAP(Print) = EMPTY_INPUT_MAP; | |||
| @@ -445,6 +445,12 @@ DECLARE_OP_ADAPTER(BatchToSpaceD) | |||
| DECLARE_OP_USE_OUTPUT(BatchToSpaceD) | |||
| DECLARE_OP_ADAPTER(Atan2) | |||
| DECLARE_OP_USE_OUTPUT(Atan2) | |||
| DECLARE_OP_ADAPTER(ApplyRMSPropD) | |||
| DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) | |||
| DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) | |||
| DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) | |||
| DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) | |||
| #ifdef ENABLE_GE | |||
| DECLARE_OP_ADAPTER(Print) | |||
| DECLARE_OP_USE_DYN_INPUT(Print) | |||
| @@ -25,6 +25,7 @@ from .lamb import Lamb | |||
| from .sgd import SGD | |||
| from .lars import LARS | |||
| from .ftrl import FTRL | |||
| from .rmsprop import RMSProp | |||
| __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', | |||
| 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL'] | |||
| 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp'] | |||
| @@ -0,0 +1,187 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """rmsprop""" | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore._checkparam import ParamValidator as validator | |||
| import mindspore.common.dtype as mstype | |||
| from .optimizer import Optimizer, grad_scale | |||
| rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | |||
| centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | |||
| @rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") | |||
| def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): | |||
| """Apply rmsprop optimizer to the weight parameter.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) | |||
| return success | |||
| @rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") | |||
| def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): | |||
| """Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) | |||
| return success | |||
| @centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor") | |||
| def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): | |||
| """Apply centered rmsprop optimizer to the weight parameter.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) | |||
| return success | |||
| @centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor") | |||
| def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): | |||
| """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) | |||
| return success | |||
| class RMSProp(Optimizer): | |||
| """ | |||
| Implements Root Mean Squared Propagation (RMSProp) algorithm. | |||
| Note: | |||
| Update `params` according to the RMSProp algorithm. | |||
| The equation is as follows: | |||
| .. math:: | |||
| s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 | |||
| .. math:: | |||
| m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w) | |||
| .. math:: | |||
| w = w - m_{t} | |||
| The first equation calculates moving average of the squared gradient for | |||
| each weight. Then dividing the gradient by :math:`\\sqrt{ms_{t} + \\epsilon}`. | |||
| if centered is True: | |||
| .. math:: | |||
| g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w) | |||
| .. math:: | |||
| s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 | |||
| .. math:: | |||
| m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w) | |||
| .. math:: | |||
| w = w - m_{t} | |||
| where, :math:`w` represents `params`, which will be updated. | |||
| :math:`g_{t}` is mean gradients, :math:`g_{t-1}` is the last moment of :math:`g_{t}`. | |||
| :math:`s_{t}` is the mean square gradients, :math:`s_{t-1}` is the last moment of :math:`s_{t}`, | |||
| :math:`m_{t}` is moment, the delta of `w`, :math:`m_{t-1}` is the last moment of :math:`m_{t}`. | |||
| :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`. | |||
| :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`. | |||
| :math:`\\eta` is learning rate, represents `learning_rate`. :math:`\\nabla Q_{i}(w)` is gradientse, | |||
| represents `gradients`. | |||
| Args: | |||
| params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters` | |||
| should be class mindspore.Parameter. | |||
| learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is | |||
| Iterable or a Tensor and the dims of the Tensor is 1, | |||
| use dynamic learning rate, then the i-th step will | |||
| take the i-th value as the learning rate. | |||
| When the learning_rate is float or learning_rate is a Tensor | |||
| but the dims of the Tensor is 0, use fixed learning rate. | |||
| Other cases are not supported. | |||
| decay (float): Decay rate. | |||
| momentum (float): Hyperparameter of type float, means momentum for the moving average. | |||
| epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0. | |||
| use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False. | |||
| centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False | |||
| loss_scale (float): A floating point value for the loss scale. Default: 1.0. | |||
| Inputs: | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| Outputs: | |||
| Tensor[bool], the value is True. | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr) | |||
| >>> model = Model(net, loss, opt) | |||
| """ | |||
| def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | |||
| use_locking=False, centered=False, loss_scale=1.0): | |||
| super(RMSProp, self).__init__(learning_rate, params) | |||
| if isinstance(momentum, float) and momentum < 0.0: | |||
| raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | |||
| if decay < 0.0: | |||
| raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay)) | |||
| self.decay = decay | |||
| self.epsilon = epsilon | |||
| validator.check_type("use_locking", use_locking, [bool]) | |||
| validator.check_type("centered", centered, [bool]) | |||
| self.centered = centered | |||
| if centered: | |||
| self.opt = P.ApplyCenteredRMSProp(use_locking) | |||
| self.mg = self.parameters.clone(prefix="mean_grad", init='zeros') | |||
| else: | |||
| self.opt = P.ApplyRMSProp(use_locking) | |||
| self.dynamic_lr = False | |||
| if not isinstance(learning_rate, float): | |||
| self.dynamic_lr = True | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") | |||
| self.axis = 0 | |||
| self.momentum = momentum | |||
| self.ms = self.parameters.clone(prefix="mean_square", init='zeros') | |||
| self.moment = self.parameters.clone(prefix="moment", init='zeros') | |||
| self.hyper_map = C.HyperMap() | |||
| self.decay = decay | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| def construct(self, gradients): | |||
| params = self.parameters | |||
| if self.reciprocal_scale != 1.0: | |||
| gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, self.axis) | |||
| F.control_depend(lr, self.assignadd(self.global_step, self.one)) | |||
| else: | |||
| lr = self.learning_rate | |||
| if self.centered: | |||
| success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon, | |||
| self.momentum), params, self.mg, self.ms, self.moment, gradients) | |||
| else: | |||
| success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon, | |||
| self.momentum), params, self.ms, self.moment, gradients) | |||
| return success | |||
| @@ -394,8 +394,8 @@ def _split_shape_index(input_shape, axis): | |||
| axis = tuple([axis]) | |||
| reduction_indices = tuple([(i + rank) % rank for i in axis]) | |||
| other_indices = tuple(set(range(rank)) - set(reduction_indices)) | |||
| reduced_num = reduce(lambda x, y: x * y, [input_shape[i] for i in reduction_indices]) | |||
| other_num = reduce(lambda x, y: x * y, [input_shape[i] for i in other_indices]) | |||
| reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices]) | |||
| other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices]) | |||
| perm = reduction_indices + other_indices | |||
| return tuple([reduced_num, other_num]), perm | |||
| @@ -65,7 +65,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||
| SmoothL1Loss, Softmax, | |||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | |||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | |||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl) | |||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, | |||
| ApplyRMSProp, ApplyCenteredRMSProp) | |||
| from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey | |||
| @@ -228,6 +229,8 @@ __all__ = [ | |||
| "SpaceToBatch", | |||
| "BatchToSpace", | |||
| "Atan2", | |||
| "ApplyRMSProp", | |||
| "ApplyCenteredRMSProp" | |||
| ] | |||
| __all__.sort() | |||
| @@ -1359,6 +1359,158 @@ class SGD(PrimitiveWithInfer): | |||
| validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32]) | |||
| return parameters_dtype | |||
| class ApplyRMSProp(PrimitiveWithInfer): | |||
| """ | |||
| Optimizer that implements the Root Mean Square prop(RMSProp) algorithm. | |||
| Note: | |||
| Update `var` according to the RMSProp algorithm. | |||
| .. math:: | |||
| s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 | |||
| .. math:: | |||
| m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w) | |||
| .. math:: | |||
| w = w - m_{t} | |||
| where, :math:`w` represents `var`, which will be updated. | |||
| :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`, | |||
| :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`. | |||
| :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`. | |||
| :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`. | |||
| :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`. | |||
| Args: | |||
| use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False. | |||
| Inputs: | |||
| - **var** (Tensor) - Weights to be update. | |||
| - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`. | |||
| - **moment** (Tensor) - Delta of `var`, must have the same type as `var`. | |||
| - **grad** (Tensor) - Gradients, must have the same type as `var`. | |||
| - **learning_rate** (Union[Number, Tensor]) - Learning rate. | |||
| - **decay** (float) - Decay rate. | |||
| - **momentum** (float) - Momentum. | |||
| - **epsilon** (float) - Ridge term. | |||
| Outputs: | |||
| Tensor, parameters to be update. | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate) | |||
| >>> model = Model(net, loss, opt) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| self.use_locking = validator.check_type("use_locking", use_locking, [bool]) | |||
| def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, | |||
| momentum_shape, epsilon_shape): | |||
| validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) | |||
| validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) | |||
| validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) | |||
| return var_shape | |||
| def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype, | |||
| momentum_dtype, epsilon_dtype): | |||
| validator.check_subclass("var_dtype", var_dtype, mstype.tensor) | |||
| validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) | |||
| validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) | |||
| validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) | |||
| args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, | |||
| "grad_dtype": grad_dtype} | |||
| validator.check_type_same(args, mstype.number_type) | |||
| args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype, | |||
| 'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype} | |||
| validator.check_type_same(args, [mstype.float16, mstype.float32]) | |||
| return var_dtype | |||
| class ApplyCenteredRMSProp(PrimitiveWithInfer): | |||
| """ | |||
| Optimizer that implements the centered RMSProp algorithm. | |||
| Note: | |||
| Update `var` according to the centered RMSProp algorithm. | |||
| .. math:: | |||
| g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w) | |||
| .. math:: | |||
| s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 | |||
| .. math:: | |||
| m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w) | |||
| .. math:: | |||
| w = w - m_{t} | |||
| where, :math:`w` represents `var`, which will be updated. | |||
| :math:`g_{t}` represents `mean_gradient`, :math:`g_{t-1}` is the last momentent of :math:`g_{t}`. | |||
| :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`, | |||
| :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`. | |||
| :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`. | |||
| :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`. | |||
| :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`. | |||
| Args: | |||
| use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False. | |||
| Inputs: | |||
| - **var** (Tensor) - Weights to be update. | |||
| - **mean_gradient** (Tensor) - Mean gradients, must have the same type as `var`. | |||
| - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`. | |||
| - **moment** (Tensor) - Delta of `var`, must have the same type as `var`. | |||
| - **grad** (Tensor) - Gradients, must have the same type as `var`. | |||
| - **learning_rate** (Union[Number, Tensor]) - Learning rate. | |||
| - **decay** (float) - Decay rate. | |||
| - **momentum** (float) - Momentum. | |||
| - **epsilon** (float) - Ridge term. | |||
| Outputs: | |||
| Tensor, parameters to be update. | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate, centered=True) | |||
| >>> model = Model(net, loss, opt) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| self.use_locking = validator.check_type("use_locking", use_locking, [bool]) | |||
| def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, | |||
| learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): | |||
| validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape) | |||
| validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) | |||
| validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) | |||
| validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) | |||
| return var_shape | |||
| def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, | |||
| learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype): | |||
| validator.check_subclass("var_dtype", var_dtype, mstype.tensor) | |||
| validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor) | |||
| validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) | |||
| validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) | |||
| validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) | |||
| args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype, | |||
| "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype} | |||
| validator.check_type_same(args, mstype.number_type) | |||
| args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype, 'momentum_dtype': momentum_dtype, | |||
| "epsilon_dtype": epsilon_dtype} | |||
| validator.check_type_same(args, [mstype.float16, mstype.float32]) | |||
| return var_dtype | |||
| class LayerNorm(Primitive): | |||
| r""" | |||
| @@ -223,6 +223,10 @@ class InputOpNet(nn.Cell): | |||
| x = self.op(x1, x2, x3, x4, x5, self.c1) | |||
| return x | |||
| def construct5_c4(self, x1, x2, x3, x4, x5): | |||
| x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4) | |||
| return x | |||
| def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False): | |||
| if isinstance(op, nn.Cell): | |||
| return op | |||
| @@ -805,6 +805,18 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], | |||
| 'desc_bprop': [3, 3], | |||
| 'skip': ['backward']}), | |||
| ('ApplyRMSProp', { | |||
| 'block': P.ApplyRMSProp(), | |||
| 'desc_const': [0.9, 0.0, 1e-10, 0.001], | |||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], | |||
| 'desc_bprop': [3, 3], | |||
| 'skip': ['backward']}), | |||
| ('ApplyCenteredRMSProp', { | |||
| 'block': P.ApplyCenteredRMSProp(), | |||
| 'desc_const': [0.9, 0.0, 1e-10, 0.001], | |||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], | |||
| 'desc_bprop': [3, 3], | |||
| 'skip': ['backward']}), | |||
| ] | |||
| test_case_array_ops = [ | |||