Browse Source

!1241 replace Square and ReduceSum with SquareSumAll in LARS

Merge pull request !1241 from gziyan/replace_op_in_lars
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
6dfe32e2d6
1 changed file with 2 additions and 3 deletions
  1. +2
    -3
      mindspore/nn/optim/lars.py

+ 2
- 3
mindspore/nn/optim/lars.py View File

@@ -30,9 +30,8 @@ lars_opt = C.MultitypeFuncGraph("lars_opt")
def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
"""Apply lars optimizer to the weight parameter."""
if lars_flag:
op_reduce = P.ReduceSum()
w_square_sum = op_reduce(F.square(weight))
grad_square_sum = op_reduce(F.square(gradient))
op_reduce_sum = P.SquareSumAll()
w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
if decay_flag:
grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
else:


Loading…
Cancel
Save