Browse Source

Replace Square and ReduceSum with the fused SquareSumAll op in the LARS optimizer

tags/v0.3.0-alpha
Ziyan 6 years ago
parent
commit
81ce714b2d
1 changed files with 2 additions and 3 deletions
  1. +2
    -3
      mindspore/nn/optim/lars.py

+ 2
- 3
mindspore/nn/optim/lars.py View File

@@ -30,9 +30,8 @@ lars_opt = C.MultitypeFuncGraph("lars_opt")
def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
"""Apply lars optimizer to the weight parameter."""
if lars_flag:
op_reduce = P.ReduceSum()
w_square_sum = op_reduce(F.square(weight))
grad_square_sum = op_reduce(F.square(gradient))
op_reduce_sum = P.SquareSumAll()
w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
if decay_flag:
grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
else:


Loading…
Cancel
Save