@@ -30,9 +30,8 @@ lars_opt = C.MultitypeFuncGraph("lars_opt")
 def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
     """Apply lars optimizer to the weight parameter."""
     if lars_flag:
-        op_reduce = P.ReduceSum()
-        w_square_sum = op_reduce(F.square(weight))
-        grad_square_sum = op_reduce(F.square(gradient))
+        op_reduce_sum = P.SquareSumAll()
+        w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
         if decay_flag:
             grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
         else:
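
For context, the fused P.SquareSumAll call computes both sums of squares in a single operator, replacing the two separate ReduceSum(F.square(...)) reductions. Below is a minimal sanity-check sketch of that equivalence, assuming SquareSumAll(x, y) returns (sum(x**2), sum(y**2)) as the new assignment implies; it is an illustration only, not part of this change.

    # Sanity-check sketch (hypothetical, PyNative mode): fused reduction vs.
    # the two reductions it replaces.
    # Assumes P.SquareSumAll(x, y) -> (sum(x**2), sum(y**2)).
    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P
    from mindspore.ops import functional as F

    weight = Tensor(np.random.rand(4, 3).astype(np.float32))
    gradient = Tensor(np.random.rand(4, 3).astype(np.float32))

    # Old path: square each tensor, then reduce over all axes.
    op_reduce = P.ReduceSum()
    w_old = op_reduce(F.square(weight))
    g_old = op_reduce(F.square(gradient))

    # New path: one fused op returning both scalar sums.
    w_new, g_new = P.SquareSumAll()(weight, gradient)

    assert np.allclose(w_old.asnumpy(), w_new.asnumpy())
    assert np.allclose(g_old.asnumpy(), g_new.asnumpy())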