|
|
|
@@ -148,6 +148,9 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy; |
|
|
|
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy; |
|
|
|
|
|
|
|
if (pad_mode_ == 0) { // 'pad' mode |
|
|
|
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W"; |
|
|
|
return FAILED; |
|
|
|
@@ -160,8 +163,6 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { |
|
|
|
} |
|
|
|
|
|
|
|
if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) { |
|
|
|
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy; |
|
|
|
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy; |
|
|
|
if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) { |
|
|
|
MS_LOG(ERROR) << name_ |
|
|
|
<< ": The 'same' mode do not support to split H or W when kernel_size <= stride but slice shape " |
|
|
|
@@ -177,24 +178,18 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (kernel_size_[0] <= stride_[2]) { |
|
|
|
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy; |
|
|
|
if (h_slice_shape % stride_[2] != 0) { |
|
|
|
MS_LOG(ERROR) << name_ |
|
|
|
<< ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is " |
|
|
|
"not divisible by stride "; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) { |
|
|
|
MS_LOG(ERROR) << name_ |
|
|
|
<< ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is " |
|
|
|
"not divisible by stride "; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (kernel_size_[1] <= stride_[3]) { |
|
|
|
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy; |
|
|
|
if (w_slice_shape % stride_[3] != 0) { |
|
|
|
MS_LOG(ERROR) << name_ |
|
|
|
<< ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is " |
|
|
|
"not divisible by stride "; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) { |
|
|
|
MS_LOG(ERROR) << name_ |
|
|
|
<< ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is " |
|
|
|
"not divisible by stride "; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -234,6 +229,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) { |
|
|
|
new_out_channel_ = out_channel_ / weight_strategy[0]; |
|
|
|
} else { |
|
|
|
out_channel_shard_ = false; |
|
|
|
new_out_channel_ = out_channel_; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
@@ -527,7 +523,19 @@ void Conv2DInfo::InferOverlapShapes() { |
|
|
|
right_recv_shape[3] = overlap_right_size_; |
|
|
|
recv_shapes_.push_back(right_recv_shape); |
|
|
|
} |
|
|
|
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_; |
|
|
|
|
|
|
|
if (left_need_send_) { |
|
|
|
Shape left_send_shape = input_slice_shape_; |
|
|
|
left_send_shape[3] = left_rank_overlap_right_size_; |
|
|
|
send_shapes_.push_back(left_send_shape); |
|
|
|
} |
|
|
|
|
|
|
|
if (right_need_send_) { |
|
|
|
Shape right_send_shape = input_slice_shape_; |
|
|
|
right_send_shape[3] = right_rank_overlap_left_size_; |
|
|
|
send_shapes_.push_back(right_send_shape); |
|
|
|
} |
|
|
|
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferStridedSliceAttrs() { |
|
|
|
@@ -536,9 +544,6 @@ void Conv2DInfo::InferStridedSliceAttrs() { |
|
|
|
left_strided_slice_end_ = input_slice_shape_; |
|
|
|
left_strided_slice_end_[3] = left_rank_overlap_right_size_; |
|
|
|
left_strided_slice_strides_ = {1, 1, 1, 1}; |
|
|
|
Shape left_send_shape = input_slice_shape_; |
|
|
|
left_send_shape[3] = left_rank_overlap_right_size_; |
|
|
|
send_shapes_.push_back(left_send_shape); |
|
|
|
MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is " |
|
|
|
<< left_strided_slice_end_; |
|
|
|
} |
|
|
|
@@ -548,9 +553,6 @@ void Conv2DInfo::InferStridedSliceAttrs() { |
|
|
|
right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_; |
|
|
|
right_strided_slice_end_ = input_slice_shape_; |
|
|
|
right_strided_slice_strides_ = {1, 1, 1, 1}; |
|
|
|
Shape right_send_shape = input_slice_shape_; |
|
|
|
right_send_shape[3] = right_rank_overlap_left_size_; |
|
|
|
send_shapes_.push_back(right_send_shape); |
|
|
|
MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is " |
|
|
|
<< right_strided_slice_end_; |
|
|
|
} |
|
|
|
@@ -566,7 +568,7 @@ void Conv2DInfo::InferNewOperatorAttrs() { |
|
|
|
InferStridedSliceAttrs(); |
|
|
|
} |
|
|
|
|
|
|
|
OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) { |
|
|
|
OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) { |
|
|
|
auto type = cnode->Type(); |
|
|
|
MS_EXCEPTION_IF_NULL(type); |
|
|
|
auto tensor_type = type->cast<mindspore::TensorTypePtr>(); |
|
|
|
@@ -582,7 +584,7 @@ OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) { |
|
|
|
return attrs; |
|
|
|
} |
|
|
|
|
|
|
|
OperatorAttrs Conv2DInfo::CreatConv2DAttrs() { |
|
|
|
OperatorAttrs Conv2DInfo::CreateConv2DAttrs() { |
|
|
|
Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)}; |
|
|
|
Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)}; |
|
|
|
Attr mode = {MODE, MakeValue(mode_)}; |
|
|
|
@@ -592,65 +594,130 @@ OperatorAttrs Conv2DInfo::CreatConv2DAttrs() { |
|
|
|
Attr dilation = {DILATION, MakeValue(dilation_)}; |
|
|
|
Attr group = {GROUP, MakeValue(group_)}; |
|
|
|
Attr data_format = {DATA_FORMAT, MakeValue(format_)}; |
|
|
|
OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format}; |
|
|
|
|
|
|
|
OperatorAttrs attrs; |
|
|
|
if (name_.find(CONV2D_INFO) != std::string::npos) { |
|
|
|
attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format}; |
|
|
|
} else { // Conv2DTranspose |
|
|
|
attrs = {out_channel, kernel_size, pad_mode, pad, pad, mode, stride, dilation, group, data_format}; |
|
|
|
} |
|
|
|
|
|
|
|
return attrs; |
|
|
|
} |
|
|
|
|
|
|
|
std::string Conv2DInfo::ReplaceNodeName() { |
|
|
|
if (name_.find(CONV2D_INFO) != std::string::npos) { |
|
|
|
return CONV2D; |
|
|
|
} |
|
|
|
|
|
|
|
if (name_.find(CONV2D_BACK_PROP_INPUT_INFO) != std::string::npos) { |
|
|
|
return CONV2D_BACK_PROP_INPUT; |
|
|
|
} |
|
|
|
|
|
|
|
if (name_.find(CONV2D_TRANSPOSE_INFO) != std::string::npos) { |
|
|
|
return CONV2D_TRANSPOSE; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid name: " << name_; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr Conv2DInfo::GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode) { |
|
|
|
auto conv2d_attrs = CreateConv2DAttrs(); |
|
|
|
auto node_name = ReplaceNodeName(); |
|
|
|
|
|
|
|
// conv2d |
|
|
|
if (name_.find(CONV2D_INFO) != std::string::npos) { |
|
|
|
if (cnode->size() < 3) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size(); |
|
|
|
} |
|
|
|
return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2)}); |
|
|
|
} |
|
|
|
|
|
|
|
// conv2dtranspose |
|
|
|
if (cnode->size() < 4) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size(); |
|
|
|
} |
|
|
|
return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2), cnode->input(3)}); |
|
|
|
} |
|
|
|
|
|
|
|
Status Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
auto graph = cnode->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
GenerateGraph gen_g = GenerateGraph(attrs_); |
|
|
|
if (gen_g.Init(cnode) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << "GenerateGraph Init failed"; |
|
|
|
return FAILED; |
|
|
|
|
|
|
|
if (gen_g_.Init(cnode) != SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "GenerateGraph Init failed"; |
|
|
|
} |
|
|
|
|
|
|
|
if (!left_need_send_ && !right_need_send_) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send"; |
|
|
|
} |
|
|
|
|
|
|
|
if (!left_need_recv_ && !right_need_recv_) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv"; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes; |
|
|
|
std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
|
if (left_need_send_) { |
|
|
|
auto slice_left_begin = CreatTuple(left_strided_slice_begin_); |
|
|
|
auto slice_left_end = CreatTuple(left_strided_slice_end_); |
|
|
|
auto slice_left_strided = CreatTuple(left_strided_slice_strides_); |
|
|
|
auto slice_left = gen_g.PushBack( |
|
|
|
{gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_left_begin, slice_left_end, slice_left_strided}); |
|
|
|
auto slice_left_begin = CreateTuple(left_strided_slice_begin_); |
|
|
|
auto slice_left_end = CreateTuple(left_strided_slice_end_); |
|
|
|
auto slice_left_strided = CreateTuple(left_strided_slice_strides_); |
|
|
|
auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin, |
|
|
|
slice_left_end, slice_left_strided}); |
|
|
|
make_tuple_a_inputs.push_back(slice_left); |
|
|
|
input_nodes.push_back(std::make_pair(slice_left, 1)); |
|
|
|
} |
|
|
|
if (right_need_send_) { |
|
|
|
auto slice_right_begin = CreatTuple(right_strided_slice_begin_); |
|
|
|
auto slice_right_end = CreatTuple(right_strided_slice_end_); |
|
|
|
auto slice_right_strided = CreatTuple(right_strided_slice_strides_); |
|
|
|
auto slice_right = gen_g.PushBack( |
|
|
|
{gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_right_begin, slice_right_end, slice_right_strided}); |
|
|
|
auto slice_right_begin = CreateTuple(right_strided_slice_begin_); |
|
|
|
auto slice_right_end = CreateTuple(right_strided_slice_end_); |
|
|
|
auto slice_right_strided = CreateTuple(right_strided_slice_strides_); |
|
|
|
auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_right_begin, |
|
|
|
slice_right_end, slice_right_strided}); |
|
|
|
make_tuple_a_inputs.push_back(slice_right); |
|
|
|
input_nodes.push_back(std::make_pair(slice_right, 1)); |
|
|
|
} |
|
|
|
|
|
|
|
auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs); |
|
|
|
auto alltoall_attrs = CreatNeighborExchangeAttrs(cnode); |
|
|
|
auto alltoall_v = gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a}); |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
|
auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode); |
|
|
|
auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a}); |
|
|
|
|
|
|
|
AnfNodePtr conv2d; |
|
|
|
Attr concat_axis = {AXIS, MakeValue(-1)}; |
|
|
|
OperatorAttrs concat_attrs = {concat_axis}; |
|
|
|
|
|
|
|
if (left_need_recv_) { |
|
|
|
std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v, |
|
|
|
CreatInt64Imm(0)}; |
|
|
|
auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs); |
|
|
|
std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), cnode->input(1), |
|
|
|
tuple_getitem_l}; |
|
|
|
std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l, |
|
|
|
cnode->input(1)}; |
|
|
|
auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs); |
|
|
|
auto concat_l = gen_g.PushBack({gen_g.NewOpInst(CONCAT), make_tuple_l}); |
|
|
|
make_tuple_inputs.push_back(concat_l); |
|
|
|
} |
|
|
|
if (right_need_recv_) { |
|
|
|
std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v, |
|
|
|
CreatInt64Imm(0)}; |
|
|
|
auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs); |
|
|
|
make_tuple_inputs.push_back(tuple_getitem_r); |
|
|
|
} else { |
|
|
|
make_tuple_inputs.push_back(cnode->input(1)); |
|
|
|
auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l}); |
|
|
|
|
|
|
|
if (right_need_recv_) { |
|
|
|
std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v, |
|
|
|
CreatInt64Imm(1)}; |
|
|
|
auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs); |
|
|
|
std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r}; |
|
|
|
auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs); |
|
|
|
auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r}); |
|
|
|
conv2d = GenerateConv2DNode(concat_r, cnode); |
|
|
|
} else { |
|
|
|
conv2d = GenerateConv2DNode(concat_l, cnode); |
|
|
|
} |
|
|
|
} else { // left no need recv, and right need recv |
|
|
|
std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v, |
|
|
|
CreatInt64Imm(0)}; |
|
|
|
auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1); |
|
|
|
std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(), |
|
|
|
tuple_getitem_r_1}; |
|
|
|
auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1); |
|
|
|
input_nodes.push_back(std::make_pair(make_tuple_r_1, 1)); |
|
|
|
|
|
|
|
auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1}); |
|
|
|
conv2d = GenerateConv2DNode(concat_r_1, cnode); |
|
|
|
} |
|
|
|
auto make_tuple = graph->NewCNode(make_tuple_inputs); |
|
|
|
Attr concat_axis = {AXIS, MakeValue(-1)}; |
|
|
|
OperatorAttrs concat_attrs = {concat_axis}; |
|
|
|
std::vector<AnfNodePtr> concat_inputs = {gen_g.NewOpInst(CONCAT, concat_attrs), make_tuple}; |
|
|
|
auto concat = graph->NewCNode(concat_inputs); |
|
|
|
auto conv2d_attrs = CreatConv2DAttrs(); |
|
|
|
auto conv2d = gen_g.PushBack({gen_g.NewOpInst(CONV2D, conv2d_attrs), concat, cnode->input(2)}); |
|
|
|
|
|
|
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( |
|
|
|
std::make_pair(input_nodes, conv2d)); |
|
|
|
return SUCCESS; |
|
|
|
|