Browse Source

!9550 Optimize performance of PyNative grad reduce

From: @jojobugfree
Reviewed-by: @kisnwang,@jjfeing
Signed-off-by: @jjfeing
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e4f1365495
1 changed files with 1 additions and 3 deletions
  1. +1
    -3
      mindspore/nn/wrap/grad_reducer.py

+ 1
- 3
mindspore/nn/wrap/grad_reducer.py View File

@@ -65,9 +65,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
cast_op = P.Cast()
mul_op = P.Mul()
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
grad = F.tensor_mul(grad, F.cast(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
return grad
return grad



Loading…
Cancel
Save