Browse Source

!2116 Change multitypefungraph to internal interface

Merge pull request !2116 from ghzl/change-multitypefungraph-to-internal-interface
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
e4af8b5be2
10 changed files with 64 additions and 64 deletions
  1. +10
    -10
      mindspore/nn/optim/adam.py
  2. +8
    -8
      mindspore/nn/optim/ftrl.py
  3. +4
    -4
      mindspore/nn/optim/lamb.py
  4. +5
    -5
      mindspore/nn/optim/lars.py
  5. +7
    -7
      mindspore/nn/optim/lazyadam.py
  6. +4
    -4
      mindspore/nn/optim/momentum.py
  7. +8
    -8
      mindspore/nn/optim/optimizer.py
  8. +3
    -3
      mindspore/nn/optim/proximal_ada_grad.py
  9. +11
    -11
      mindspore/nn/optim/rmsprop.py
  10. +4
    -4
      mindspore/nn/optim/sgd.py

+ 10
- 10
mindspore/nn/optim/adam.py View File

@@ -26,10 +26,10 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer from .optimizer import Optimizer


adam_opt = C.MultitypeFuncGraph("adam_opt")
_adam_opt = C.MultitypeFuncGraph("adam_opt")




@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
""" """
Update parameters. Update parameters.
@@ -101,8 +101,8 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)




@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
"Tensor", "Tensor", "Tensor")
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2): moment1, moment2):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
@@ -112,8 +112,8 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2
return success return success




@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor")
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2): moment1, moment2):
"""Apply adam optimizer to the weight parameter using Tensor.""" """Apply adam optimizer to the weight parameter using Tensor."""
@@ -261,11 +261,11 @@ class Adam(Optimizer):
beta2_power = self.beta2_power * self.beta2 beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power self.beta2_power = beta2_power
if self.is_group_lr: if self.is_group_lr:
success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
self.beta1, self.beta2, self.eps), self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2) lr, gradients, params, moment1, moment2)
else: else:
success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
self.beta1, self.beta2, self.eps, lr), self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2) gradients, params, moment1, moment2)
return success return success
@@ -328,7 +328,7 @@ class AdamWeightDecay(Optimizer):


def construct(self, gradients): def construct(self, gradients):
lr = self.get_lr() lr = self.get_lr()
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor), self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag) self.params, self.moments1, self.moments2, gradients, self.decay_flag)


@@ -424,7 +424,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
warmup_lr = self.start_learning_rate * warmup_percent warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor), self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag) self.params, self.moments1, self.moments2, gradients, self.decay_flag)




+ 8
- 8
mindspore/nn/optim/ftrl.py View File

@@ -18,13 +18,13 @@ from mindspore.common import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer, apply_decay, grad_scale
from .optimizer import Optimizer, _apply_decay, _grad_scale


ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
_ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")




@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
"Tensor")
@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
"Tensor")
def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment): def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
success = True success = True
@@ -32,8 +32,8 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power,
return success return success




@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
"Tensor")
@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
"Tensor")
def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment): def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
"""Apply ftrl optimizer to the weight parameter.""" """Apply ftrl optimizer to the weight parameter."""
success = True success = True
@@ -124,9 +124,9 @@ class FTRL(Optimizer):
linear = self.linear linear = self.linear
lr = self.learning_rate lr = self.learning_rate
if self.weight_decay > 0.0: if self.weight_decay > 0.0:
grads = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, grads)
grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)


grads = self.scale_grad(grads) grads = self.scale_grad(grads)
success = self.map_(F.partial(ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
linear, grads, params, moments) linear, grads, params, moments)
return success return success

+ 4
- 4
mindspore/nn/optim/lamb.py View File

@@ -28,10 +28,10 @@ from .. import layer


num_one = Tensor(np.ones([1]), mstype.float32) num_one = Tensor(np.ones([1]), mstype.float32)


lamb_opt = C.MultitypeFuncGraph("lamb_opt")
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")


@lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool")
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v, def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
gradient, decay_flag): gradient, decay_flag):
""" """
@@ -227,7 +227,7 @@ class Lamb(Optimizer):
warmup_lr = self.start_learning_rate * warmup_percent warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, lr,
updated_velocity = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step), self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag) self.params, self.moments1, self.moments2, gradients, self.decay_flag)




+ 5
- 5
mindspore/nn/optim/lars.py View File

@@ -22,12 +22,12 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from .optimizer import grad_scale, Optimizer
from .optimizer import _grad_scale, Optimizer


lars_opt = C.MultitypeFuncGraph("lars_opt")
_lars_opt = C.MultitypeFuncGraph("lars_opt")




@lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag): def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
"""Apply lars optimizer to the weight parameter.""" """Apply lars optimizer to the weight parameter."""
if lars_flag: if lars_flag:
@@ -119,9 +119,9 @@ class LARS(Optimizer):
else: else:
lr = self.learning_rate lr = self.learning_rate
if self.reciprocal_scale != 1.0: if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)


grad_t = self.hyper_map(F.partial(lars_opt, self.lars, self.weight_decay, lr),
grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, self.weight_decay, lr),
gradients, params, self.decay_flag, self.lars_flag) gradients, params, self.decay_flag, self.lars_flag)
success = self.opt(grad_t) success = self.opt(grad_t)




+ 7
- 7
mindspore/nn/optim/lazyadam.py View File

@@ -24,11 +24,11 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer from .optimizer import Optimizer


lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")




@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
"Tensor", "Tensor", "Tensor")
@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2): moment1, moment2):
"""Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse.""" """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
@@ -38,8 +38,8 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2
return success return success




@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor")
@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2): moment1, moment2):
"""Apply adam optimizer to the weight parameter using Tensor.""" """Apply adam optimizer to the weight parameter using Tensor."""
@@ -189,11 +189,11 @@ class LazyAdam(Optimizer):
self.beta2_power = self.beta2_power * self.beta2 self.beta2_power = self.beta2_power * self.beta2


if self.is_group_lr: if self.is_group_lr:
success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
self.beta2_power, self.beta1, self.beta2, self.eps), self.beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, self.parameters, self.moment1, self.moment2) lr, gradients, self.parameters, self.moment1, self.moment2)
else: else:
success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
self.beta2_power, self.beta1, self.beta2, self.eps, lr), self.beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, self.parameters, self.moment1, self.moment2) gradients, self.parameters, self.moment1, self.moment2)
return success return success

+ 4
- 4
mindspore/nn/optim/momentum.py View File

@@ -21,10 +21,10 @@ from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer from .optimizer import Optimizer


momentum_opt = C.MultitypeFuncGraph("momentum_opt")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")




@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment): def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor.""" """Apply momentum optimizer to the weight parameter using Tensor."""
success = True success = True
@@ -129,7 +129,7 @@ class Momentum(Optimizer):
gradients = self.scale_grad(gradients) gradients = self.scale_grad(gradients)
lr = self.get_lr() lr = self.get_lr()
if self.is_group_lr: if self.is_group_lr:
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
else: else:
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success return success

+ 8
- 8
mindspore/nn/optim/optimizer.py View File

@@ -171,11 +171,11 @@ class Optimizer(Cell):
params = self.parameters params = self.parameters
if self.is_group: if self.is_group:
if self.exec_weight_decay: if self.exec_weight_decay:
gradients = self.hyper_map(F.partial(apply_decay), self.weight_decay, self.decay_flags,
gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
params, gradients) params, gradients)
else: else:
if self.weight_decay > 0: if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
params, gradients) params, gradients)


return gradients return gradients
@@ -196,7 +196,7 @@ class Optimizer(Cell):


""" """
if self.reciprocal_scale != 1.0: if self.reciprocal_scale != 1.0:
gradients = self.map_(F.partial(grad_scale, self.reciprocal_scale), gradients)
gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)


return gradients return gradients


@@ -390,10 +390,10 @@ class Optimizer(Cell):


op_add = P.AddN() op_add = P.AddN()


apply_decay = C.MultitypeFuncGraph("apply_decay")
_apply_decay = C.MultitypeFuncGraph("apply_decay")




@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay.""" """Get grad with weight_decay."""
if if_apply: if if_apply:
@@ -401,10 +401,10 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
return gradient return gradient




grad_scale = C.MultitypeFuncGraph("grad_scale")
_grad_scale = C.MultitypeFuncGraph("grad_scale")




@grad_scale.register("Number", "Tensor")
@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad): def tensor_grad_scale(scale, grad):
"""Get grad with scale.""" """Get grad with scale."""
if scale == 1.0: if scale == 1.0:
@@ -412,7 +412,7 @@ def tensor_grad_scale(scale, grad):
return grad * scale return grad * scale




@grad_scale.register("Number", "Tuple")
@_grad_scale.register("Number", "Tuple")
def tensor_grad_scale_with_sparse(scale, grad): def tensor_grad_scale_with_sparse(scale, grad):
"""Get grad with scale.""" """Get grad with scale."""
if scale == 1.0: if scale == 1.0:


+ 3
- 3
mindspore/nn/optim/proximal_ada_grad.py View File

@@ -20,10 +20,10 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer from .optimizer import Optimizer
proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
@proximal_ada_grad_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
@_proximal_ada_grad_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt(opt, learning_rate, l1, l2, gradient, weight, accum): def _tensor_run_opt(opt, learning_rate, l1, l2, gradient, weight, accum):
"""Apply proximal_ada_grad optimizer to the weight parameter.""" """Apply proximal_ada_grad optimizer to the weight parameter."""
success = True success = True
@@ -94,6 +94,6 @@ class ProximalAdagrad(Optimizer):
grads = self.decay_weight(grads) grads = self.decay_weight(grads)
grads = self.scale_grad(grads) grads = self.scale_grad(grads)
lr = self.learning_rate lr = self.learning_rate
success = self.hyper_map(F.partial(proximal_ada_grad_opt, self.opt, lr, self.l1, self.l2),
success = self.hyper_map(F.partial(_proximal_ada_grad_opt, self.opt, lr, self.l1, self.l2),
grads, params, accum) grads, params, accum)
return success return success

+ 11
- 11
mindspore/nn/optim/rmsprop.py View File

@@ -18,21 +18,21 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer from .optimizer import Optimizer


rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
_centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")




@rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
@_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt_(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
"""Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True success = True
success = F.depend(success, opt(weight, ms, mom, learning_rate, grad, decay, momentum, epsilon)) success = F.depend(success, opt(weight, ms, mom, learning_rate, grad, decay, momentum, epsilon))
return success return success




@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _centered_rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad):
@_centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _centered_rmsprop_opt_(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad):
"""Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate.""" """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True success = True
success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
@@ -187,17 +187,17 @@ class RMSProp(Optimizer):
lr = self.get_lr() lr = self.get_lr()
if self.centered: if self.centered:
if self.is_group_lr: if self.is_group_lr:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.mg, self.ms, self.moment, gradients) self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
else: else:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr), params, self.mg, self.ms, self.moment, gradients) self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)


else: else:
if self.is_group_lr: if self.is_group_lr:
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.ms, self.moment, gradients) self.momentum), lr, params, self.ms, self.moment, gradients)
else: else:
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr), params, self.ms, self.moment, gradients) self.momentum, lr), params, self.ms, self.moment, gradients)
return success return success

+ 4
- 4
mindspore/nn/optim/sgd.py View File

@@ -20,10 +20,10 @@ import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer from .optimizer import Optimizer


sgd_opt = C.MultitypeFuncGraph("sgd_opt")
_sgd_opt = C.MultitypeFuncGraph("sgd_opt")




@sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
@_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat): def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
"""Apply sgd optimizer to the weight parameter using Tensor.""" """Apply sgd optimizer to the weight parameter using Tensor."""
success = True success = True
@@ -154,7 +154,7 @@ class SGD(Optimizer):
gradients = self.scale_grad(gradients) gradients = self.scale_grad(gradients)
lr = self.get_lr() lr = self.get_lr()
if self.is_group_lr: if self.is_group_lr:
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
else: else:
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
return success return success

Loading…
Cancel
Save