From 81ce714b2dc9b4d47cba43c4f1ed846b95c41c4d Mon Sep 17 00:00:00 2001 From: Ziyan Date: Tue, 19 May 2020 09:39:40 +0800 Subject: [PATCH] replace square and reducesum with squaresumall in lars --- mindspore/nn/optim/lars.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index 73451f3bf5..e3ab616ddd 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -30,9 +30,8 @@ lars_opt = C.MultitypeFuncGraph("lars_opt") def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag): """Apply lars optimizer to the weight parameter.""" if lars_flag: - op_reduce = P.ReduceSum() - w_square_sum = op_reduce(F.square(weight)) - grad_square_sum = op_reduce(F.square(gradient)) + op_reduce_sum = P.SquareSumAll() + w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient) if decay_flag: grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate) else: