Browse Source

!8675 support allreduce prod

From: @yao_yf
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8980bc3de7
3 changed files with 10 additions and 4 deletions
  1. +9
    -2
      mindspore/ops/_grad/grad_comm_ops.py
  2. +0
    -2
      mindspore/ops/operations/comm_ops.py
  3. +1
    -0
      tests/ut/python/communication/test_comm.py

+ 9
- 2
mindspore/ops/_grad/grad_comm_ops.py View File

@@ -37,11 +37,18 @@ def get_bprop_all_reduce(self):
equal = P.Equal()
cast = P.Cast()
mul = P.Mul()
div = P.RealDiv()
dtype = P.DType()

if self.op == ReduceOp.PROD:
raise RuntimeError("The bprop of ReduceOp.PROD is not supported yet.")
if self.op == ReduceOp.SUM:

def bprop(x, out, dout):
dy1 = mul(dout, out)
dy2 = all_reduce_grad(dy1)
dx = div(dy2, x)
return (dx,)

elif self.op == ReduceOp.SUM:

def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):


+ 0
- 2
mindspore/ops/operations/comm_ops.py View File

@@ -95,8 +95,6 @@ class AllReduce(PrimitiveWithInfer):
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
if not isinstance(op, type(ReduceOp.SUM)):
raise TypeError("The operation of AllReduce should be str.")
if op == ReduceOp.PROD:
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.")
if not isinstance(_get_group(group), str):
raise TypeError("The group of AllReduce should be str.")
self.op = op


+ 1
- 0
tests/ut/python/communication/test_comm.py View File

@@ -159,6 +159,7 @@ def test_allreduce():
run_allreduce(ReduceOp.SUM)
run_allreduce(ReduceOp.MAX)
run_allreduce(ReduceOp.MIN)
run_allreduce(ReduceOp.PROD)


def test_allgather():


Loading…
Cancel
Save