Browse Source

change send_recv to inner

tags/v1.1.0
lichenever 5 years ago
parent
commit
ee2478c05d
5 changed files with 26 additions and 19 deletions
  1. +9
    -4
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
  2. +1
    -0
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
  3. +8
    -7
      mindspore/ops/_grad/grad_comm_ops.py
  4. +1
    -1
      mindspore/ops/operations/__init__.py
  5. +7
    -7
      mindspore/ops/operations/comm_ops.py

+ 9
- 4
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc View File

@@ -165,6 +165,9 @@ void PipelineTransformer::ParameterColoring() {
parameter->set_stage(graph->stage()); parameter->set_stage(graph->stage());
} }
} }
if (*parameter_stage.begin() == stage_ && !virtual_param_) {
virtual_param_ = parameter;
}
parameter_color_map[parameter] = parameter_stage; parameter_color_map[parameter] = parameter_stage;
} }
} }
@@ -204,7 +207,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank)); Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank));
OperatorAttrs attrs = {attr_tag, attr_rank}; OperatorAttrs attrs = {attr_tag, attr_rank};
auto send_op = CreatOpInstance(attrs, "Send", "send");
auto send_op = CreatOpInstance(attrs, "_Send", "send");
auto send_node = NewValueNode(send_op); auto send_node = NewValueNode(send_op);
auto prim = GetValueNode<PrimitivePtr>(send_node); auto prim = GetValueNode<PrimitivePtr>(send_node);
auto shape_type_pair = GetShapeType(parameter); auto shape_type_pair = GetShapeType(parameter);
@@ -230,8 +233,8 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
Attr attr_shape = std::make_pair("shape", shape_type_pair.first); Attr attr_shape = std::make_pair("shape", shape_type_pair.first);
Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second); Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype}; OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
auto recv_op = CreatOpInstance(attrs, "Receive", "recv");
std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op)};
auto recv_op = CreatOpInstance(attrs, "_Receive", "recv");
std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_};
auto recv = graph->NewCNode(recv_input); auto recv = graph->NewCNode(recv_input);
manager_->SetEdge(use_node, index, recv); manager_->SetEdge(use_node, index, recv);
} }
@@ -289,7 +292,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
} }
if (node_stage == user_node_stage) { if (node_stage == user_node_stage) {
if (is_shared && (min_tag != node_stage)) { if (is_shared && (min_tag != node_stage)) {
InsertReceive(graph, node, user_node, user_pair.second, min_tag, stage_);
InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag);
} }
continue; continue;
} }
@@ -436,6 +439,8 @@ void PipelineTransformer::ElimParameter() {
parameter_list.push_back(parameter); parameter_list.push_back(parameter);
} }
} }
auto del_num = parameters.size() - parameter_list.size();
root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
manager_->SetParameters(root_, parameter_list); manager_->SetParameters(root_, parameter_list);
} }
} // namespace parallel } // namespace parallel


+ 1
- 0
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h View File

@@ -65,6 +65,7 @@ class PipelineTransformer {
int64_t per_stage_rank_num_; int64_t per_stage_rank_num_;
TypePtr type_ptr_; TypePtr type_ptr_;
ValueListPtr shape_; ValueListPtr shape_;
AnfNodePtr virtual_param_;
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore


+ 8
- 7
mindspore/ops/_grad/grad_comm_ops.py View File

@@ -20,7 +20,7 @@ from .. import operations as P
from ...common.tensor import RowTensor from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
_GetTensorSlice, _MirrorOperator, ReduceOp, _Send, _Receive,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters from .grad_base import bprop_getters


@@ -77,12 +77,12 @@ def get_bprop_all_reduce(self):
return bprop return bprop




@bprop_getters.register(Send)
@bprop_getters.register(_Send)
def get_bprop_send(self): def get_bprop_send(self):
"""Generate bprop for Send.""" """Generate bprop for Send."""
shape = self.get_attr_dict()["shape"] shape = self.get_attr_dict()["shape"]
dtype = self.get_attr_dict()["dtype"] dtype = self.get_attr_dict()["dtype"]
send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
send_grad = _Receive(self.sr_tag, self.rank, shape, dtype, self.group)


def bprop(x, out, dout): def bprop(x, out, dout):
dx = send_grad() dx = send_grad()
@@ -90,15 +90,16 @@ def get_bprop_send(self):
return bprop return bprop




@bprop_getters.register(Receive)
@bprop_getters.register(_Receive)
def get_bprop_receive(self): def get_bprop_receive(self):
"""Generate bprop for Receive.""" """Generate bprop for Receive."""
receive_grad = Send(self.tag, self.rank, self.group)
receive_grad = _Send(self.tag, self.rank, self.group)
depend = P.Depend() depend = P.Depend()
cast = P.Cast()


def bprop(out, dout):
def bprop(x, out, dout):
send_out = receive_grad(dout) send_out = receive_grad(dout)
dx = depend(dout, send_out)
dx = depend(cast(zeros_like(x), F.dtype(x)), send_out)
return (dx,) return (dx,)
return bprop return bprop




+ 1
- 1
mindspore/ops/operations/__init__.py View File

@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Unique, GatherD, Identity, RepeatElements) Unique, GatherD, Identity, RepeatElements)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset, _MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice, Send, Receive,
_VirtualDiv, _GetTensorSlice, _Send, _Receive,
_HostAllGather, _HostReduceScatter) _HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert) TensorSummary, HistogramSummary, Print, Assert)


+ 7
- 7
mindspore/ops/operations/comm_ops.py View File

@@ -116,7 +116,7 @@ class AllReduce(PrimitiveWithInfer):
return x_dtype return x_dtype




class Send(PrimitiveWithInfer):
class _Send(PrimitiveWithInfer):
""" """
Send tensors from src_rank to the specified dest_rank. Send tensors from src_rank to the specified dest_rank.


@@ -145,7 +145,7 @@ class Send(PrimitiveWithInfer):
>>> def __init__(self): >>> def __init__(self):
>>> super(Net, self).__init__() >>> super(Net, self).__init__()
>>> self.depend = P.Depend() >>> self.depend = P.Depend()
>>> self.send = P.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
>>> self.send = P._Send(st_tag=0, dest_rank=8, group="hccl_world_group")
>>> >>>
>>> def construct(self, x): >>> def construct(self, x):
>>> out = self.depend(x, self.send(x)) >>> out = self.depend(x, self.send(x))
@@ -170,7 +170,7 @@ class Send(PrimitiveWithInfer):
return x_dtype return x_dtype




class Receive(PrimitiveWithInfer):
class _Receive(PrimitiveWithInfer):
""" """
receive tensors from src_rank. receive tensors from src_rank.


@@ -201,11 +201,11 @@ class Receive(PrimitiveWithInfer):
>>> class Net(nn.Cell): >>> class Net(nn.Cell):
>>> def __init__(self): >>> def __init__(self):
>>> super(Net, self).__init__() >>> super(Net, self).__init__()
>>> self.send = P.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
>>> self.recv = P._Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
>>> group="hccl_world_group") >>> group="hccl_world_group")
>>> >>>
>>> def construct(self, x): >>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> out = self.depend(x, self.recv(x))
>>> return out >>> return out
>>> >>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
@@ -220,10 +220,10 @@ class Receive(PrimitiveWithInfer):
self.dtype = dtype self.dtype = dtype
self.group = group self.group = group


def infer_shape(self):
def infer_shape(self, x_shape=None):
return self.shape return self.shape


def infer_dtype(self):
def infer_dtype(self, x_dtype=None):
return self.dtype return self.dtype






Loading…
Cancel
Save