|
|
|
@@ -65,9 +65,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra |
|
|
|
grad = allreduce(grad) |
|
|
|
if mean: |
|
|
|
degree = F.scalar_cast(degree, F.dtype(grad)) |
|
|
|
cast_op = P.Cast() |
|
|
|
mul_op = P.Mul() |
|
|
|
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) |
|
|
|
grad = F.tensor_mul(grad, F.cast(F.scalar_to_array(1.0 / degree), F.dtype(grad))) |
|
|
|
return grad |
|
|
|
return grad |
|
|
|
|
|
|
|
|