
add weight decay in RMSProp optimizer

tags/v0.3.0-alpha
zhaoting chang zherui, 6 years ago
commit 2a82eb450e
1 changed file with 1 addition and 16 deletions

mindspore/nn/optim/rmsprop.py  (+1, -16)

@@ -18,12 +18,8 @@ from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore._checkparam import ParamValidator as validator
 import mindspore.common.dtype as mstype
-<<<<<<< HEAD
 from mindspore.common import Tensor
 from .optimizer import Optimizer, grad_scale, apply_decay
-=======
-from .optimizer import Optimizer, grad_scale
->>>>>>> add RMSProp optimizer


 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -123,12 +119,9 @@ class RMSProp(Optimizer):
         use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
         centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-<<<<<<< HEAD
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
                                  lambda x: 'beta' not in x.name and 'gamma' not in x.name.
-=======
->>>>>>> add RMSProp optimizer

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
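Note on the resolved docstring above: weight_decay adds an L2 penalty to the gradient of every parameter whose name passes decay_filter, and leaves the rest (by default, the beta and gamma parameters of normalization layers) untouched. A minimal plain-Python sketch of that filtering idea, using a hypothetical named-parameter container rather than the actual apply_decay helper imported in the first hunk:

from collections import namedtuple
import numpy as np

# Hypothetical stand-in for a framework Parameter: a name plus a value array.
Param = namedtuple("Param", ["name", "data"])

def decayed_gradients(params, gradients, weight_decay,
                      decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
    """Add an L2 penalty (weight_decay * param) to the gradient of every
    parameter that passes decay_filter; leave the others unchanged."""
    out = []
    for param, grad in zip(params, gradients):
        if weight_decay > 0.0 and decay_filter(param):
            grad = grad + weight_decay * param.data
        out.append(grad)
    return tuple(out)

# Example: the 'gamma' parameter is excluded by the default filter.
params = [Param("fc.weight", np.ones(2)), Param("bn.gamma", np.ones(2))]
grads = [np.zeros(2), np.zeros(2)]
print(decayed_gradients(params, grads, weight_decay=0.1))
# -> (array([0.1, 0.1]), array([0., 0.]))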
@@ -139,20 +132,12 @@ class RMSProp(Optimizer):
     Examples:
         >>> net = Net()
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
-<<<<<<< HEAD
         >>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr)
         >>> model = Model(net, loss, opt)
     """
     def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
                  use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-=======
-        >>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr)
-        >>> model = Model(net, loss, opt)
-    """
-    def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
-                 use_locking=False, centered=False, loss_scale=1.0):
->>>>>>> add RMSProp optimizer
         super(RMSProp, self).__init__(learning_rate, params)

         if isinstance(momentum, float) and momentum < 0.0:
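With the conflict resolved in favor of the extended signature, the optimizer accepts the new weight_decay and decay_filter arguments. A usage sketch, assuming a MindSpore build that includes this change; the Dense layer, input sizes, and hyperparameter values are illustrative placeholders:

import mindspore.nn as nn

# Minimal placeholder network so the example is self-contained.
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Dense(10, 2)

    def construct(self, x):
        return self.fc(x)

net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
# weight_decay and decay_filter are the arguments added by this commit;
# the optimizer would then be passed to Model(net, loss, opt) as in the docstring.
opt = nn.RMSProp(params=net.trainable_params(),
                 learning_rate=0.01,
                 weight_decay=1e-4,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name)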
@@ -209,4 +194,4 @@ class RMSProp(Optimizer):
         else:
             success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                                                self.momentum), params, self.ms, self.moment, gradients)
-        return success
+        return success
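For context, the rmsprop_opt and centered_rmsprop_opt branches applied in construct follow the conventional RMSProp update rule. A NumPy sketch of that rule, illustrative only and not the fused device kernel:

import numpy as np

def rmsprop_step(param, grad, ms, moment, lr, decay, momentum, epsilon):
    """One plain (non-centered) RMSProp step: accumulate the squared-gradient
    moving average, then take a momentum-smoothed scaled step."""
    ms = decay * ms + (1.0 - decay) * grad * grad
    moment = momentum * moment + lr * grad / np.sqrt(ms + epsilon)
    param = param - moment
    return param, ms, moment

def centered_rmsprop_step(param, grad, mg, ms, moment, lr, decay, momentum, epsilon):
    """Centered variant: the gradient is normalized by its estimated
    variance (ms - mg**2) rather than the raw second moment."""
    mg = decay * mg + (1.0 - decay) * grad
    ms = decay * ms + (1.0 - decay) * grad * grad
    moment = momentum * moment + lr * grad / np.sqrt(ms - mg * mg + epsilon)
    param = param - moment
    return param, mg, ms, moment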
