| @@ -165,6 +165,9 @@ void PipelineTransformer::ParameterColoring() { | |||
| parameter->set_stage(graph->stage()); | |||
| } | |||
| } | |||
| if (*parameter_stage.begin() == stage_ && !virtual_param_) { | |||
| virtual_param_ = parameter; | |||
| } | |||
| parameter_color_map[parameter] = parameter_stage; | |||
| } | |||
| } | |||
| @@ -204,7 +207,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod | |||
| auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | |||
| Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank)); | |||
| OperatorAttrs attrs = {attr_tag, attr_rank}; | |||
| auto send_op = CreatOpInstance(attrs, "Send", "send"); | |||
| auto send_op = CreatOpInstance(attrs, "_Send", "send"); | |||
| auto send_node = NewValueNode(send_op); | |||
| auto prim = GetValueNode<PrimitivePtr>(send_node); | |||
| auto shape_type_pair = GetShapeType(parameter); | |||
| @@ -230,8 +233,8 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode | |||
| Attr attr_shape = std::make_pair("shape", shape_type_pair.first); | |||
| Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second); | |||
| OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype}; | |||
| auto recv_op = CreatOpInstance(attrs, "Receive", "recv"); | |||
| std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op)}; | |||
| auto recv_op = CreatOpInstance(attrs, "_Receive", "recv"); | |||
| std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_}; | |||
| auto recv = graph->NewCNode(recv_input); | |||
| manager_->SetEdge(use_node, index, recv); | |||
| } | |||
| @@ -289,7 +292,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { | |||
| } | |||
| if (node_stage == user_node_stage) { | |||
| if (is_shared && (min_tag != node_stage)) { | |||
| InsertReceive(graph, node, user_node, user_pair.second, min_tag, stage_); | |||
| InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag); | |||
| } | |||
| continue; | |||
| } | |||
| @@ -436,6 +439,8 @@ void PipelineTransformer::ElimParameter() { | |||
| parameter_list.push_back(parameter); | |||
| } | |||
| } | |||
| auto del_num = parameters.size() - parameter_list.size(); | |||
| root_->set_hyper_param_count(root_->hyper_param_count() - del_num); | |||
| manager_->SetParameters(root_, parameter_list); | |||
| } | |||
| } // namespace parallel | |||
| @@ -65,6 +65,7 @@ class PipelineTransformer { | |||
| int64_t per_stage_rank_num_; | |||
| TypePtr type_ptr_; | |||
| ValueListPtr shape_; | |||
| AnfNodePtr virtual_param_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -20,7 +20,7 @@ from .. import operations as P | |||
| from ...common.tensor import RowTensor | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive, | |||
| _GetTensorSlice, _MirrorOperator, ReduceOp, _Send, _Receive, | |||
| ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) | |||
| from .grad_base import bprop_getters | |||
| @@ -77,12 +77,12 @@ def get_bprop_all_reduce(self): | |||
| return bprop | |||
| @bprop_getters.register(Send) | |||
| @bprop_getters.register(_Send) | |||
| def get_bprop_send(self): | |||
| """Generate bprop for Send.""" | |||
| shape = self.get_attr_dict()["shape"] | |||
| dtype = self.get_attr_dict()["dtype"] | |||
| send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group) | |||
| send_grad = _Receive(self.sr_tag, self.rank, shape, dtype, self.group) | |||
| def bprop(x, out, dout): | |||
| dx = send_grad() | |||
| @@ -90,15 +90,16 @@ def get_bprop_send(self): | |||
| return bprop | |||
| @bprop_getters.register(Receive) | |||
| @bprop_getters.register(_Receive) | |||
| def get_bprop_receive(self): | |||
| """Generate bprop for Receive.""" | |||
| receive_grad = Send(self.tag, self.rank, self.group) | |||
| receive_grad = _Send(self.tag, self.rank, self.group) | |||
| depend = P.Depend() | |||
| cast = P.Cast() | |||
| def bprop(out, dout): | |||
| def bprop(x, out, dout): | |||
| send_out = receive_grad(dout) | |||
| dx = depend(dout, send_out) | |||
| dx = depend(cast(zeros_like(x), F.dtype(x)), send_out) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Unique, GatherD, Identity, RepeatElements) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, Send, Receive, | |||
| _VirtualDiv, _GetTensorSlice, _Send, _Receive, | |||
| _HostAllGather, _HostReduceScatter) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print, Assert) | |||
| @@ -116,7 +116,7 @@ class AllReduce(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class Send(PrimitiveWithInfer): | |||
| class _Send(PrimitiveWithInfer): | |||
| """ | |||
| Send tensors from src_rank to the specified dest_rank. | |||
| @@ -145,7 +145,7 @@ class Send(PrimitiveWithInfer): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.depend = P.Depend() | |||
| >>> self.send = P.Send(st_tag=0, dest_rank=8, group="hccl_world_group") | |||
| >>> self.send = P._Send(st_tag=0, dest_rank=8, group="hccl_world_group") | |||
| >>> | |||
| >>> def construct(self, x): | |||
| >>> out = self.depend(x, self.send(x)) | |||
| @@ -170,7 +170,7 @@ class Send(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class Receive(PrimitiveWithInfer): | |||
| class _Receive(PrimitiveWithInfer): | |||
| """ | |||
| Receive tensors from src_rank. | |||
| @@ -201,11 +201,11 @@ class Receive(PrimitiveWithInfer): | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.depend = P.Depend() | |||
| >>> self.send = P.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32, | |||
| >>> self.recv = P._Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32, | |||
| >>> group="hccl_world_group") | |||
| >>> | |||
| >>> def construct(self, x): | |||
| >>> out = self.depend(x, self.send(x)) | |||
| >>> out = self.depend(x, self.recv(x)) | |||
| >>> return out | |||
| >>> | |||
| >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) | |||
| @@ -220,10 +220,10 @@ class Receive(PrimitiveWithInfer): | |||
| self.dtype = dtype | |||
| self.group = group | |||
| def infer_shape(self): | |||
| def infer_shape(self, x_shape=None): | |||
| return self.shape | |||
| def infer_dtype(self): | |||
| def infer_dtype(self, x_dtype=None): | |||
| return self.dtype | |||