@@ -18,7 +18,8 @@ from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore._checkparam import ParamValidator as validator
 import mindspore.common.dtype as mstype
-from .optimizer import Optimizer, grad_scale
+from mindspore.common import Tensor
+from .optimizer import Optimizer, grad_scale, apply_decay

 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
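Note on the newly imported apply_decay: construct below calls it through self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients), so it is presumably a MultitypeFuncGraph defined in optimizer.py, analogous to the rmsprop_opt graphs above. A minimal sketch under that assumption (the registration signature and the _tensor_apply_decay name are illustrative, not taken from this diff):

# Hedged sketch (assumption, not part of this diff): one way apply_decay could
# be defined in optimizer.py, mirroring the MultitypeFuncGraph pattern above.
from mindspore.ops import composite as C

apply_decay = C.MultitypeFuncGraph("apply_decay")

@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Add weight_decay * weight to the gradient when the decay flag is set."""
    if if_apply:
        return gradient + weight * weight_decay
    return gradient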
@@ -118,6 +119,9 @@ class RMSProp(Optimizer):
         use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False.
         centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
+            lambda x: 'beta' not in x.name and 'gamma' not in x.name.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
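To make the new weight_decay and decay_filter docstring entries concrete, here is a hypothetical construction call (the network and the extra 'bias' exclusion are illustrative, not from this diff):

# Hypothetical usage of the new weight_decay / decay_filter arguments.
from mindspore.nn import RMSProp

# Apply a 1e-4 L2 penalty, but skip beta/gamma (the documented default)
# and, in this illustration, bias parameters as well.
opt = RMSProp(net.trainable_params(),
              learning_rate=0.1,
              weight_decay=1e-4,
              decay_filter=lambda x: 'beta' not in x.name
                                     and 'gamma' not in x.name
                                     and 'bias' not in x.name)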
@@ -132,7 +136,8 @@ class RMSProp(Optimizer):
         >>> model = Model(net, loss, opt)
     """

     def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
-                 use_locking=False, centered=False, loss_scale=1.0):
+                 use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(RMSProp, self).__init__(learning_rate, params)
         if isinstance(momentum, float) and momentum < 0.0:
@@ -159,6 +164,7 @@ class RMSProp(Optimizer):
         self.assignadd = P.AssignAdd()
         self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
         self.axis = 0
+        self.one = Tensor(1, mstype.int32)
         self.momentum = momentum
@@ -167,10 +173,14 @@ class RMSProp(Optimizer):
         self.hyper_map = C.HyperMap()
         self.decay = decay
+        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.reciprocal_scale = 1.0 / loss_scale
+        self.weight_decay = weight_decay * loss_scale

     def construct(self, gradients):
         params = self.parameters
+        if self.weight_decay > 0:
+            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
         if self.dynamic_lr:
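One detail worth calling out: self.weight_decay is stored as weight_decay * loss_scale because apply_decay runs before the gradients are unscaled by reciprocal_scale; scaling the decay term by the same factor means that, once everything is divided by loss_scale, the effective update is gradient + weight_decay * weight. A plain-Python sketch of the per-parameter preprocessing that construct performs (names are illustrative, and this ignores the RMSProp state update itself):

# Illustrative, eager-mode equivalent of the gradient preprocessing above.
# decay_flags corresponds to self.decay_tf; scaled_weight_decay corresponds to
# self.weight_decay (already multiplied by loss_scale).
def preprocess_gradients(params, gradients, decay_flags,
                         scaled_weight_decay, loss_scale):
    out = []
    for param, grad, use_decay in zip(params, gradients, decay_flags):
        if scaled_weight_decay > 0 and use_decay:
            # Decay is added while gradients are still scaled by loss_scale.
            grad = grad + scaled_weight_decay * param
        if loss_scale != 1.0:
            # Undo the loss scaling; the decay term is rescaled as well,
            # leaving an effective grad + weight_decay * param update.
            grad = grad * (1.0 / loss_scale)
        out.append(grad)
    return out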