| @@ -18,12 +18,8 @@ from mindspore.common.initializer import initializer | |||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore._checkparam import ParamValidator as validator | from mindspore._checkparam import ParamValidator as validator | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| <<<<<<< HEAD | |||||
| from mindspore.common import Tensor | from mindspore.common import Tensor | ||||
| from .optimizer import Optimizer, grad_scale, apply_decay | from .optimizer import Optimizer, grad_scale, apply_decay | ||||
| ======= | |||||
| from .optimizer import Optimizer, grad_scale | |||||
| >>>>>>> add RMSProp optimizer | |||||
| rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | ||||
| centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | ||||
| @@ -123,12 +119,9 @@ class RMSProp(Optimizer): | |||||
| use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False. | use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False. | ||||
| centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False | centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False | ||||
| loss_scale (float): A floating point value for the loss scale. Default: 1.0. | loss_scale (float): A floating point value for the loss scale. Default: 1.0. | ||||
| <<<<<<< HEAD | |||||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | ||||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | ||||
| lambda x: 'beta' not in x.name and 'gamma' not in x.name. | lambda x: 'beta' not in x.name and 'gamma' not in x.name. | ||||
| ======= | |||||
| >>>>>>> add RMSProp optimizer | |||||
| Inputs: | Inputs: | ||||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | ||||
| @@ -139,20 +132,12 @@ class RMSProp(Optimizer): | |||||
| Examples: | Examples: | ||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | >>> loss = nn.SoftmaxCrossEntropyWithLogits() | ||||
| <<<<<<< HEAD | |||||
| >>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr) | >>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr) | ||||
| >>> model = Model(net, loss, opt) | >>> model = Model(net, loss, opt) | ||||
| """ | """ | ||||
| def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | ||||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0, | use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0, | ||||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | ||||
| ======= | |||||
| >>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr) | |||||
| >>> model = Model(net, loss, opt) | |||||
| """ | |||||
| def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | |||||
| use_locking=False, centered=False, loss_scale=1.0): | |||||
| >>>>>>> add RMSProp optimizer | |||||
| super(RMSProp, self).__init__(learning_rate, params) | super(RMSProp, self).__init__(learning_rate, params) | ||||
| if isinstance(momentum, float) and momentum < 0.0: | if isinstance(momentum, float) and momentum < 0.0: | ||||
| @@ -209,4 +194,4 @@ class RMSProp(Optimizer): | |||||
| else: | else: | ||||
| success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon, | success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon, | ||||
| self.momentum), params, self.ms, self.moment, gradients) | self.momentum), params, self.ms, self.moment, gradients) | ||||
| return success | |||||
| return success | |||||