Browse Source

!746 reducescatter backforward operator

Merge pull request !746 from lirongzhen1/bp_reducescatter
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
5b3327d103
3 changed files with 52 additions and 3 deletions
  1. +0
    -1
      mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
  2. +19
    -1
      mindspore/ops/_grad/grad_comm_ops.py
  3. +33
    -1
      tests/ut/python/communication/test_comm.py

+ 0
- 1
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc View File

@@ -29,7 +29,6 @@

namespace mindspore {
namespace parallel {

// Get the target node's weight for sorting.
double GetWeights(const Graph::NodeType &node) {
const OperatorRec &op = node.apply;


+ 19
- 1
mindspore/ops/_grad/grad_comm_ops.py View File

@@ -67,11 +67,29 @@ def get_bprop_broad_cast(self):
@bprop_getters.register(AllGather)
def get_bprop_all_gather(self):
    """Generate bprop for AllGather.

    The backward pass applies a ``ReduceScatter`` with ``ReduceOp.SUM`` on the
    same communication group as the forward primitive.

    Returns:
        function, ``bprop(x, out, dout)`` returning a one-element tuple that
        holds the gradient with respect to ``x``.
    """
    all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
    if self.instance_name:
        # Derive the backward primitive's instance name from the forward one.
        instance_name = "grad" + self.instance_name
        all_gather_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = all_gather_grad(dout)
        return (dx,)

    return bprop


@bprop_getters.register(ReduceScatter)
def get_bprop_reduce_scatter(self):
    """Generate bprop for ReduceScatter.

    The backward pass applies an ``AllGather`` on the same communication group
    as the forward primitive.

    Returns:
        function, ``bprop(x, out, dout)`` returning a one-element tuple that
        holds the gradient with respect to ``x``.

    Raises:
        RuntimeError: If the forward primitive's ``op`` is not ``ReduceOp.SUM``.
    """
    reduce_scatter_grad = AllGather(self.group)
    if self.instance_name:
        # Derive the backward primitive's instance name from the forward one.
        instance_name = "grad" + self.instance_name
        reduce_scatter_grad.set_prim_instance_name(instance_name)

    if self.op != ReduceOp.SUM:
        raise RuntimeError("The reducescatter bprop only support ReduceOp.SUM until now.")

    def bprop(x, out, dout):
        dx = reduce_scatter_grad(dout)
        return (dx,)

    # The getter must hand the closure back, as get_bprop_all_gather does;
    # without this the registered getter would yield None.
    return bprop


+ 33
- 1
tests/ut/python/communication/test_comm.py View File

@@ -14,7 +14,7 @@

""" test Communicate """
import numpy as np
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
from mindspore.ops.operations.comm_ops import Broadcast
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, GlobalComm, init
from mindspore.communication._comm_helper import Backend
@@ -78,6 +78,19 @@ class AllGatherNet(nn.Cell):
x = self.allgather(x)
return self.relu(x)

class ReduceScatterNet(nn.Cell):
    """Small test network: Dense, then ReduceScatter, then ReLU.

    Args:
        input_channel: Forwarded to ``Dense`` as its first argument.
        out_channel: Forwarded to ``Dense`` as its second argument.
        op: Reduce operation forwarded to ``ReduceScatter``.
    """

    def __init__(self, input_channel, out_channel, op):
        super().__init__()
        self.dense = Dense(input_channel, out_channel)
        self.reducescatter = ReduceScatter(op)
        self.relu = ReLU()

    def construct(self, x):
        """Apply the three layers in order and return the activated result."""
        return self.relu(self.reducescatter(self.dense(x)))

class AlltoAllNet(nn.Cell):
"""AlltoAllNet definition"""
def __init__(self, input_channel, out_channel):
@@ -126,6 +139,25 @@ def test_allgather():
network = TrainOneStepCell(network, optimizer)
_executor.compile(network, input_tensor, label_tensor)

def run_reducescatter(op):
    """Compile a one-step training graph built around ReduceScatterNet.

    Args:
        op: Reduce operation handed to ``ReduceScatterNet``.
    """
    context.set_context(mode=context.GRAPH_MODE)

    # Fixed 2x2 input and 2x1 label used only to drive compilation.
    inputs = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    labels = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))

    net = ReduceScatterNet(2, 1, op)
    criterion = nn.SoftmaxCrossEntropyWithLogits()

    # Only parameters flagged as requiring gradients are given to the optimizer.
    trainable = filter(lambda param: param.requires_grad, net.get_parameters())
    momentum_opt = Momentum(trainable,
                            learning_rate=0.1,
                            momentum=0.9)

    net_with_loss = WithLossCell(net, criterion)
    train_step = TrainOneStepCell(net_with_loss, momentum_opt)
    _executor.compile(train_step, inputs, labels)

def test_reducescatter():
    """Run the ReduceScatter compile check in graph mode with ReduceOp.SUM."""
    context.set_context(mode=context.GRAPH_MODE)
    run_reducescatter(op=ReduceOp.SUM)

def test_broadcast():
"""test_broadcast"""
context.set_context(mode=context.GRAPH_MODE)


Loading…
Cancel
Save