From a8e15e2dd1bab8afe48bff1dfe8fd23467e3ca4c Mon Sep 17 00:00:00 2001
From: lizhenyu
Date: Tue, 15 Dec 2020 18:50:07 +0800
Subject: [PATCH] [bugfix] Momentum does not support PS cache, leading to process abort

---
 mindspore/nn/optim/momentum.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py
index 2e74f03264..0b73d5c0c8 100755
--- a/mindspore/nn/optim/momentum.py
+++ b/mindspore/nn/optim/momentum.py
@@ -24,11 +24,11 @@ from .optimizer import Optimizer
 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
 
 
-@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter):
+@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
+def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter, cache_enable):
     """Apply momentum optimizer to the weight parameter using Tensor."""
     success = True
-    if ps_parameter:
+    if ps_parameter and not cache_enable:
         op_shape = P.Shape()
         _ps_pull = P.Pull()
         _ps_push = P.Push("ApplyMomentum", [])
@@ -146,8 +146,8 @@ class Momentum(Optimizer):
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
-                                     self.ps_parameters)
+                                     self.ps_parameters, self.cache_enable)
         else:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments,
-                                     self.ps_parameters)
+                                     self.ps_parameters, self.cache_enable)
         return success
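
Reviewer note (not part of the patch): a minimal, self-contained sketch of the dispatch rule this change introduces. The helper name select_momentum_path and the returned labels are hypothetical stand-ins for the two code paths in _tensor_run_opt_ext, the P.Push("ApplyMomentum")/P.Pull() parameter-server path and the regular ApplyMomentum update.

def select_momentum_path(ps_parameter, cache_enable):
    """Illustrative only: which update path a parameter takes after this fix."""
    if ps_parameter and not cache_enable:
        # Parameter-server path: push the gradient to the PS and pull the updated weight back.
        return "ps_push_pull"
    # PS parameters with the cache enabled (and ordinary parameters) fall through
    # to the regular ApplyMomentum update instead of taking the unsupported PS path.
    return "apply_momentum"

# Before the fix there was no cache_enable check, so a PS parameter always took the
# PS push/pull path even with the cache enabled, which is what triggered the abort.
assert select_momentum_path(ps_parameter=True, cache_enable=True) == "apply_momentum"
assert select_momentum_path(ps_parameter=True, cache_enable=False) == "ps_push_pull"
assert select_momentum_path(ps_parameter=False, cache_enable=False) == "apply_momentum"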