From 23898c735f68a6888905bec9a59d643f0052a42e Mon Sep 17 00:00:00 2001 From: caifubi Date: Sat, 5 Dec 2020 14:29:33 +0800 Subject: [PATCH] Optimize performance of PyNative grad reduce --- mindspore/nn/wrap/grad_reducer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 9f1e8a16a7..bdfb660db5 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -65,9 +65,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra grad = allreduce(grad) if mean: degree = F.scalar_cast(degree, F.dtype(grad)) - cast_op = P.Cast() - mul_op = P.Mul() - grad = mul_op(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) + grad = F.tensor_mul(grad, F.cast(F.scalar_to_array(1.0 / degree), F.dtype(grad))) return grad return grad