@@ -24,11 +24,11 @@ from .optimizer import Optimizer
 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
-@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter):
+@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
+def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter, cache_enable):
     """Apply momentum optimizer to the weight parameter using Tensor."""
     success = True
-    if ps_parameter:
+    if ps_parameter and not cache_enable:
         op_shape = P.Shape()
         _ps_pull = P.Pull()
         _ps_push = P.Push("ApplyMomentum", [])
@@ -146,8 +146,8 @@ class Momentum(Optimizer):
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
-                                     self.ps_parameters)
+                                     self.ps_parameters, self.cache_enable)
         else:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments,
-                                     self.ps_parameters)
+                                     self.ps_parameters, self.cache_enable)
         return success