|
|
|
@@ -24,11 +24,11 @@ from .optimizer import Optimizer |
|
|
|
_momentum_opt = C.MultitypeFuncGraph("momentum_opt") |
|
|
|
|
|
|
|
|
|
|
|
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") |
|
|
|
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter): |
|
|
|
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") |
|
|
|
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter, cache_enable): |
|
|
|
"""Apply momentum optimizer to the weight parameter using Tensor.""" |
|
|
|
success = True |
|
|
|
if ps_parameter: |
|
|
|
if ps_parameter and not cache_enable: |
|
|
|
op_shape = P.Shape() |
|
|
|
_ps_pull = P.Pull() |
|
|
|
_ps_push = P.Push("ApplyMomentum", []) |
|
|
|
@@ -146,8 +146,8 @@ class Momentum(Optimizer): |
|
|
|
lr = self.get_lr() |
|
|
|
if self.is_group_lr: |
|
|
|
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments, |
|
|
|
self.ps_parameters) |
|
|
|
self.ps_parameters, self.cache_enable) |
|
|
|
else: |
|
|
|
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments, |
|
|
|
self.ps_parameters) |
|
|
|
self.ps_parameters, self.cache_enable) |
|
|
|
return success |