From ee2478c05d8d8d0ef9c61d4da15b9c20ea7e1edd Mon Sep 17 00:00:00 2001
From: lichenever
Date: Sat, 21 Nov 2020 17:16:03 +0800
Subject: [PATCH] change send_recv to inner

---
 .../pipeline_transformer/pipeline_transformer.cc | 13 +++++++++----
 .../pipeline_transformer/pipeline_transformer.h  |  1 +
 mindspore/ops/_grad/grad_comm_ops.py             | 15 ++++++++-------
 mindspore/ops/operations/__init__.py             |  2 +-
 mindspore/ops/operations/comm_ops.py             | 14 +++++++-------
 5 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
index eddb3f1535..9f7d3dc18a 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
@@ -165,6 +165,9 @@ void PipelineTransformer::ParameterColoring() {
         parameter->set_stage(graph->stage());
       }
     }
+    if (*parameter_stage.begin() == stage_ && !virtual_param_) {
+      virtual_param_ = parameter;
+    }
     parameter_color_map[parameter] = parameter_stage;
   }
 }
@@ -204,7 +207,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
   Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank));
   OperatorAttrs attrs = {attr_tag, attr_rank};
-  auto send_op = CreatOpInstance(attrs, "Send", "send");
+  auto send_op = CreatOpInstance(attrs, "_Send", "send");
   auto send_node = NewValueNode(send_op);
   auto prim = GetValueNode<PrimitivePtr>(send_node);
   auto shape_type_pair = GetShapeType(parameter);
@@ -230,8 +233,8 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
   Attr attr_shape = std::make_pair("shape", shape_type_pair.first);
   Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
   OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
-  auto recv_op = CreatOpInstance(attrs, "Receive", "recv");
-  std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op)};
+  auto recv_op = CreatOpInstance(attrs, "_Receive", "recv");
+  std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_};
   auto recv = graph->NewCNode(recv_input);
   manager_->SetEdge(use_node, index, recv);
 }
@@ -289,7 +292,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
       }
       if (node_stage == user_node_stage) {
         if (is_shared && (min_tag != node_stage)) {
-          InsertReceive(graph, node, user_node, user_pair.second, min_tag, stage_);
+          InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag);
         }
         continue;
       }
@@ -436,6 +439,8 @@ void PipelineTransformer::ElimParameter() {
       parameter_list.push_back(parameter);
     }
   }
+  auto del_num = parameters.size() - parameter_list.size();
+  root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
   manager_->SetParameters(root_, parameter_list);
 }
 }  // namespace parallel
diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
index 974cb6c007..a778291ca4 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
@@ -65,6 +65,7 @@ class PipelineTransformer {
   int64_t per_stage_rank_num_;
   TypePtr type_ptr_;
   ValueListPtr shape_;
+  AnfNodePtr virtual_param_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 0e9768d2a1..20881d010c 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -20,7 +20,7 @@ from .. import operations as P
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
-                                   _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
+                                   _GetTensorSlice, _MirrorOperator, ReduceOp, _Send, _Receive,
                                    ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
 from .grad_base import bprop_getters

@@ -77,12 +77,12 @@ def get_bprop_all_reduce(self):
     return bprop


-@bprop_getters.register(Send)
+@bprop_getters.register(_Send)
 def get_bprop_send(self):
     """Generate bprop for Send."""
     shape = self.get_attr_dict()["shape"]
     dtype = self.get_attr_dict()["dtype"]
-    send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
+    send_grad = _Receive(self.sr_tag, self.rank, shape, dtype, self.group)

     def bprop(x, out, dout):
         dx = send_grad()
@@ -90,15 +90,16 @@ def get_bprop_send(self):
     return bprop


-@bprop_getters.register(Receive)
+@bprop_getters.register(_Receive)
 def get_bprop_receive(self):
     """Generate bprop for Receive."""
-    receive_grad = Send(self.tag, self.rank, self.group)
+    receive_grad = _Send(self.tag, self.rank, self.group)
     depend = P.Depend()
+    cast = P.Cast()

-    def bprop(out, dout):
+    def bprop(x, out, dout):
         send_out = receive_grad(dout)
-        dx = depend(dout, send_out)
+        dx = depend(cast(zeros_like(x), F.dtype(x)), send_out)
         return (dx,)

     return bprop
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 7d7f1a2777..6af989c88f 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Unique,
                         GatherD, Identity, RepeatElements)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
-                       _VirtualDiv, _GetTensorSlice, Send, Receive,
+                       _VirtualDiv, _GetTensorSlice, _Send, _Receive,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py
index ed8e985601..7101349a8f 100644
--- a/mindspore/ops/operations/comm_ops.py
+++ b/mindspore/ops/operations/comm_ops.py
@@ -116,7 +116,7 @@ class AllReduce(PrimitiveWithInfer):
         return x_dtype


-class Send(PrimitiveWithInfer):
+class _Send(PrimitiveWithInfer):
     """
     Send tensors from src_rank to the specified dest_rank.

@@ -145,7 +145,7 @@ class Send(PrimitiveWithInfer):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
         >>>         self.depend = P.Depend()
-        >>>         self.send = P.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
+        >>>         self.send = P._Send(st_tag=0, dest_rank=8, group="hccl_world_group")
         >>>
         >>>     def construct(self, x):
         >>>         out = self.depend(x, self.send(x))
         >>>         return out
@@ -170,7 +170,7 @@ class Send(PrimitiveWithInfer):
         return x_dtype


-class Receive(PrimitiveWithInfer):
+class _Receive(PrimitiveWithInfer):
     """
     receive tensors from src_rank.
@@ -201,11 +201,11 @@ class Receive(PrimitiveWithInfer):
         >>> class Net(nn.Cell):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
-        >>>         self.send = P.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
+        >>>         self.recv = P._Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
         >>>                               group="hccl_world_group")
         >>>
         >>>     def construct(self, x):
-        >>>         out = self.depend(x, self.send(x))
+        >>>         out = self.depend(x, self.recv(x))
         >>>         return out
         >>>
         >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
@@ -220,10 +220,10 @@
         self.dtype = dtype
         self.group = group

-    def infer_shape(self):
+    def infer_shape(self, x_shape=None):
         return self.shape

-    def infer_dtype(self):
+    def infer_dtype(self, x_dtype=None):
         return self.dtype
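Note (illustration, not part of the patch): the new `_Receive` bprop returns a zero gradient for the op's dummy input, cast to that input's dtype, and ties the backward `_Send` into the graph through `Depend` so it is not pruned. The single-device Python sketch below imitates that gradient pattern with ordinary ops; `recv_like_bprop` is a hypothetical stand-in, the identity assignment replaces the real backward send, and the actual `_Send`/`_Receive` primitives additionally require an initialized HCCL communication group.

# Single-device sketch (assumption: plain ops stand in for the backward _Send).
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

def recv_like_bprop(x, dout):
    depend = P.Depend()
    cast = P.Cast()
    zeros_like = P.ZerosLike()
    send_out = dout  # stand-in for send_grad(dout); the real bprop sends dout upstream
    # Zero gradient for the dummy input, kept alive via a Depend edge on the send.
    dx = depend(cast(zeros_like(x), x.dtype), send_out)
    return (dx,)

x = Tensor(np.ones([2, 8]).astype(np.float32))
dout = Tensor(np.ones([2, 8]).astype(np.float32))
print(recv_like_bprop(x, dout)[0].asnumpy().sum())  # 0.0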