
[bugfix] Momentum does not support ps cache, leading to process abort

tags/v1.1.0
lizhenyu · 5 years ago
commit a8e15e2dd1
1 changed file with 5 additions and 5 deletions

mindspore/nn/optim/momentum.py  (+5 -5)

@@ -24,11 +24,11 @@ from .optimizer import Optimizer
 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
 
 
-@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter):
+@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
+def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter, cache_enable):
     """Apply momentum optimizer to the weight parameter using Tensor."""
     success = True
-    if ps_parameter:
+    if ps_parameter and not cache_enable:
         op_shape = P.Shape()
         _ps_pull = P.Pull()
         _ps_push = P.Push("ApplyMomentum", [])
@@ -146,8 +146,8 @@ class Momentum(Optimizer):
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
-                                     self.ps_parameters)
+                                     self.ps_parameters, self.cache_enable)
         else:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments,
-                                     self.ps_parameters)
+                                     self.ps_parameters, self.cache_enable)
         return success
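
For context, a minimal hypothetical sketch (plain Python, not MindSpore operators; helper names such as local_momentum_update and parameter_server_update are stand-ins, not part of the repository) of the dispatch rule this commit changes: a parameter hosted on the parameter server (ps_parameter=True) whose cache is enabled (cache_enable=True) must skip the Push/Pull path and fall back to the local momentum update.

def local_momentum_update(weight, moment, lr, grad, momentum):
    """Plain in-place SGD-with-momentum update (stand-in for the fused kernel)."""
    for i in range(len(weight)):
        moment[i] = momentum * moment[i] + grad[i]
        weight[i] -= lr * moment[i]
    return True

def parameter_server_update(weight, grad, lr, momentum):
    """Stand-in for the Push/Pull round trip; the path a cached parameter must avoid."""
    raise RuntimeError("ps cache not supported on this path -> process aborted")

def run_opt(lr, grad, weight, moment, ps_parameter, cache_enable, momentum=0.9):
    # The fix: take the parameter-server path only when the cache is NOT enabled.
    if ps_parameter and not cache_enable:
        return parameter_server_update(weight, grad, lr, momentum)
    return local_momentum_update(weight, moment, lr, grad, momentum)

# Before the fix a cache-enabled PS parameter hit the Push/Pull path and aborted;
# with the added flag it updates locally.
w, m = [1.0, 2.0], [0.0, 0.0]
assert run_opt(0.1, [0.5, 0.5], w, m, ps_parameter=True, cache_enable=True)
print(w)  # [0.95, 1.95]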
