From 1b4041a8f1fc2a8e438a04a58d84ba3e092ac5e3 Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Fri, 3 Apr 2020 11:45:49 +0800
Subject: [PATCH] add weight decay in RMSProp optimizer

---
 mindspore/nn/optim/rmsprop.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py
index 3000fdeeee..faaeacfaa8 100644
--- a/mindspore/nn/optim/rmsprop.py
+++ b/mindspore/nn/optim/rmsprop.py
@@ -18,7 +18,8 @@ from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore._checkparam import ParamValidator as validator
 import mindspore.common.dtype as mstype
-from .optimizer import Optimizer, grad_scale
+from mindspore.common import Tensor
+from .optimizer import Optimizer, grad_scale, apply_decay
 
 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -118,6 +119,9 @@ class RMSProp(Optimizer):
         use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
         centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
+            lambda x: 'beta' not in x.name and 'gamma' not in x.name.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -132,7 +136,8 @@ class RMSProp(Optimizer):
         >>> model = Model(net, loss, opt)
     """
     def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
-                 use_locking=False, centered=False, loss_scale=1.0):
+                 use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(RMSProp, self).__init__(learning_rate, params)
 
         if isinstance(momentum, float) and momentum < 0.0:
@@ -159,6 +164,7 @@ class RMSProp(Optimizer):
         self.assignadd = P.AssignAdd()
         self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
         self.axis = 0
+        self.one = Tensor(1, mstype.int32)
 
         self.momentum = momentum
 
@@ -167,10 +173,14 @@ class RMSProp(Optimizer):
         self.hyper_map = C.HyperMap()
 
         self.decay = decay
+        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.reciprocal_scale = 1.0 / loss_scale
+        self.weight_decay = weight_decay * loss_scale
 
     def construct(self, gradients):
         params = self.parameters
+        if self.weight_decay > 0:
+            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
         if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
         if self.dynamic_lr:
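
A minimal usage sketch of the new weight_decay and decay_filter arguments, assuming the patched signature above. The network, loss, and Model import path are illustrative placeholders, not part of this patch:

    import mindspore.nn as nn
    from mindspore.train.model import Model

    # Hypothetical network and loss, used only to show how the optimizer is wired up.
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()

    # weight_decay adds an L2 penalty to every parameter accepted by decay_filter;
    # the default filter skips parameters whose names contain 'beta' or 'gamma'
    # (typically BatchNorm scale/shift parameters).
    opt = nn.RMSProp(params=net.trainable_params(),
                     learning_rate=0.1,
                     decay=0.9,
                     momentum=0.9,
                     weight_decay=1e-4)

    model = Model(net, loss, opt)

Because apply_decay runs on the still-scaled gradients before they are multiplied by reciprocal_scale, the constructor stores weight_decay * loss_scale; the two factors cancel, so the effective penalty is weight_decay * parameter regardless of the loss scale.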