@@ -165,6 +165,9 @@ void PipelineTransformer::ParameterColoring() {
         parameter->set_stage(graph->stage());
       }
     }
+    if (*parameter_stage.begin() == stage_ && !virtual_param_) {
+      virtual_param_ = parameter;
+    }
     parameter_color_map[parameter] = parameter_stage;
   }
 }
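Note on the hunk above: ParameterColoring now caches, as virtual_param_, the first parameter whose lowest assigned stage matches the local stage_; the InsertReceive hunk below wires that parameter in as the Receive node's data input. A minimal pure-Python sketch of the selection rule (illustrative names, not the real API):

def pick_virtual_param(colored_params, stage):
    """Return the first parameter whose smallest stage id equals `stage`."""
    for param, stages in colored_params:  # stages: the parameter's stage set
        if min(stages) == stage:
            return param
    return None

# e.g. with the stage sets gathered by the coloring pass:
params = [("w0", {0, 1}), ("w1", {1}), ("w2", {1, 2})]
assert pick_virtual_param(params, 1) == "w1"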
@@ -204,7 +207,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
   Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank));
   OperatorAttrs attrs = {attr_tag, attr_rank};
-  auto send_op = CreatOpInstance(attrs, "Send", "send");
+  auto send_op = CreatOpInstance(attrs, "_Send", "send");
   auto send_node = NewValueNode(send_op);
   auto prim = GetValueNode<PrimitivePtr>(send_node);
   auto shape_type_pair = GetShapeType(parameter);
@@ -230,8 +233,8 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
   Attr attr_shape = std::make_pair("shape", shape_type_pair.first);
   Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
   OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
-  auto recv_op = CreatOpInstance(attrs, "Receive", "recv");
-  std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op)};
+  auto recv_op = CreatOpInstance(attrs, "_Receive", "recv");
+  std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_};
   auto recv = graph->NewCNode(recv_input);
   manager_->SetEdge(use_node, index, recv);
 }
@@ -289,7 +292,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
     }
     if (node_stage == user_node_stage) {
       if (is_shared && (min_tag != node_stage)) {
-        InsertReceive(graph, node, user_node, user_pair.second, min_tag, stage_);
+        InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag);
       }
       continue;
     }
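The swapped arguments above fix the shared-parameter path. Mirroring the InsertSend hunk earlier (dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_), InsertReceive presumably derives its source rank from the same stage difference with the opposite sign, so its last two parameters are (user_node_stage, node_stage); passing (min_tag, stage_) flipped the sign and addressed a rank outside the peer stage. A hedged arithmetic sketch (the receive-side formula is inferred from the send side, not shown in this diff):

global_rank, per_stage_rank_num = 8, 4  # e.g. the first rank of stage 2

def src_rank(user_node_stage, node_stage):
    # assumed mirror of the send-side formula above
    return global_rank - (user_node_stage - node_stage) * per_stage_rank_num

stage_, min_tag = 2, 0
print(src_rank(min_tag, stage_))  # old order: 16, not a valid peer rank
print(src_rank(stage_, min_tag))  # fixed order: 0, the stage that owns the shared copy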
@@ -436,6 +439,8 @@ void PipelineTransformer::ElimParameter() {
       parameter_list.push_back(parameter);
     }
   }
+  auto del_num = parameters.size() - parameter_list.size();
+  root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
   manager_->SetParameters(root_, parameter_list);
 }
 }  // namespace parallel
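On the two added lines above: parameter_list is a filtered copy of parameters, so the unsigned subtraction cannot underflow, and the root graph's hyper_param_count (apparently the number of weight parameters it carries) is shrunk by the same amount to stay consistent with the new parameter list. A toy sketch of the bookkeeping, assuming every eliminated parameter was counted there:

def elim_parameters(parameters, used):
    kept = [p for p in parameters if p in used]
    return kept, len(parameters) - len(kept)  # kept list and del_num

kept, del_num = elim_parameters(["x", "w0", "w1", "w2"], used={"x", "w1"})
assert (kept, del_num) == (["x", "w1"], 2)
hyper_param_count = 3 - del_num  # 3 weights before, 1 after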
@@ -65,6 +65,7 @@ class PipelineTransformer {
   int64_t per_stage_rank_num_;
   TypePtr type_ptr_;
   ValueListPtr shape_;
+  AnfNodePtr virtual_param_;
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -20,7 +20,7 @@ from .. import operations as P
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
-                                   _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
+                                   _GetTensorSlice, _MirrorOperator, ReduceOp, _Send, _Receive,
                                    ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
 from .grad_base import bprop_getters
@@ -77,12 +77,12 @@ def get_bprop_all_reduce(self):
     return bprop


-@bprop_getters.register(Send)
+@bprop_getters.register(_Send)
 def get_bprop_send(self):
     """Generate bprop for Send."""
     shape = self.get_attr_dict()["shape"]
     dtype = self.get_attr_dict()["dtype"]
-    send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
+    send_grad = _Receive(self.sr_tag, self.rank, shape, dtype, self.group)

     def bprop(x, out, dout):
         dx = send_grad()
@@ -90,15 +90,16 @@ def get_bprop_send(self):
     return bprop


-@bprop_getters.register(Receive)
+@bprop_getters.register(_Receive)
 def get_bprop_receive(self):
     """Generate bprop for Receive."""
-    receive_grad = Send(self.tag, self.rank, self.group)
+    receive_grad = _Send(self.tag, self.rank, self.group)
     depend = P.Depend()
+    cast = P.Cast()

-    def bprop(out, dout):
+    def bprop(x, out, dout):
         send_out = receive_grad(dout)
-        dx = depend(dout, send_out)
+        dx = depend(cast(zeros_like(x), F.dtype(x)), send_out)
         return (dx,)

     return bprop
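A self-contained sketch (NumPy stand-ins, not MindSpore) of what the reworked _Receive bprop computes: the received value does not depend numerically on the new dummy input x, so the true gradient for x is zero; the Depend edge only keeps the backward _Send of dout from being pruned as dead code, and the astype mirrors the cast(zeros_like(x), F.dtype(x)) above:

import numpy as np

def depend(value, expr):
    _ = expr  # P.Depend-style: force `expr` to stay live, return `value` unchanged
    return value

def fake_send(grad):
    # stand-in for _Send(...)(dout): ships the gradient to the peer stage
    return ("sent", grad.shape)

def receive_bprop(x, out, dout):
    send_out = fake_send(dout)
    # zero gradient for the dummy input, anchored to the send via Depend
    dx = depend(np.zeros_like(x).astype(x.dtype), send_out)
    return (dx,)

x = np.ones([2, 8], np.float32)
(dx,) = receive_bprop(x, None, np.full([2, 8], 0.5, np.float32))
assert dx.shape == x.shape and not dx.any()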
@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Unique, GatherD, Identity, RepeatElements)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
-                       _VirtualDiv, _GetTensorSlice, Send, Receive,
+                       _VirtualDiv, _GetTensorSlice, _Send, _Receive,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
@@ -116,7 +116,7 @@ class AllReduce(PrimitiveWithInfer):
         return x_dtype


-class Send(PrimitiveWithInfer):
+class _Send(PrimitiveWithInfer):
     """
     Send tensors from src_rank to the specified dest_rank.
@@ -145,7 +145,7 @@ class Send(PrimitiveWithInfer):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
         >>>         self.depend = P.Depend()
-        >>>         self.send = P.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
+        >>>         self.send = P._Send(st_tag=0, dest_rank=8, group="hccl_world_group")
         >>>
         >>>     def construct(self, x):
         >>>         out = self.depend(x, self.send(x))
@@ -170,7 +170,7 @@ class Send(PrimitiveWithInfer):
         return x_dtype


-class Receive(PrimitiveWithInfer):
+class _Receive(PrimitiveWithInfer):
     """
     receive tensors from src_rank.
@@ -201,11 +201,11 @@ class Receive(PrimitiveWithInfer):
         >>> class Net(nn.Cell):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
-        >>>         self.send = P.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
+        >>>         self.recv = P._Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
         >>>                               group="hccl_world_group")
         >>>
         >>>     def construct(self, x):
-        >>>         out = self.depend(x, self.send(x))
+        >>>         out = self.depend(x, self.recv(x))
         >>>         return out
         >>>
         >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
@@ -220,10 +220,10 @@ class Receive(PrimitiveWithInfer):
         self.dtype = dtype
         self.group = group

-    def infer_shape(self):
+    def infer_shape(self, x_shape=None):
         return self.shape

-    def infer_dtype(self):
+    def infer_dtype(self, x_dtype=None):
         return self.dtype
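The =None defaults above keep both call patterns valid: when the Receive CNode carries the virtual parameter, shape/dtype inference is invoked with that input's shape and dtype, which the op ignores since its output is fully described by its own shape and dtype attributes; an input-less Receive still infers the same way. A small stand-in illustrating the contract (not the real primitive):

class ReceiveSketch:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

    def infer_shape(self, x_shape=None):  # input shape, if any, is ignored
        return self.shape

    def infer_dtype(self, x_dtype=None):  # input dtype, if any, is ignored
        return self.dtype

recv = ReceiveSketch([2, 8], "float32")
assert recv.infer_shape() == recv.infer_shape([1]) == [2, 8]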