/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "tools/optimizer/fisson/fisson_util.h" #include #include #include "base/core_ops.h" #include "src/common/utils.h" #include "ops/split_with_overlap.h" #include "tools/common/node_util.h" #include "ops/concat.h" #include "tools/optimizer/parallel/spliter.h" #include "tools/optimizer/parallel/split_strategy.h" #include "nnacl/op_base.h" #include "src/common/log_util.h" using mindspore::converter::FmkType; namespace mindspore { namespace opt { std::vector GetSplitPadList(const std::shared_ptr &ori_conv_prim, int64_t input_h, int64_t input_w) { if (ori_conv_prim == nullptr) { MS_LOG(DEBUG) << "input Conv2DFusion is nullptr"; return {}; } if (ori_conv_prim->get_pad_mode() != SAME) { return ori_conv_prim->get_pad_list(); } if (ori_conv_prim->get_stride().size() < kIndexW || ori_conv_prim->get_kernel_size().size() < kIndexW || ori_conv_prim->get_dilation().size() < kIndexW) { MS_LOG(ERROR) << "Index out of range"; return {}; } int64_t output_h = static_cast( std::ceil(static_cast(input_h) / static_cast(ori_conv_prim->get_stride().at(kIndexH)))); int64_t output_w = static_cast( std::ceil(static_cast(input_w) / static_cast(ori_conv_prim->get_stride().at(kIndexW)))); auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH); auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH); auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW); auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW); if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) || INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) { MS_LOG(ERROR) << "int mul overflow"; return {}; } std::vector new_pad_list; int64_t pad_up = 0, pad_down = 0, pad_left = 0, pad_right = 0; int64_t pad_h_all = (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h; int64_t pad_w_all = (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w; // only check pad_up and pad_down is positive // if compute overflowed, we will get abnormal it in infer_shape if (pad_h_all >= 0) { pad_up = pad_h_all / 2; pad_down = pad_h_all - pad_up; } new_pad_list.push_back(pad_up); new_pad_list.push_back(pad_down); if (pad_w_all >= 0) { pad_left = pad_w_all / 2; pad_right = pad_w_all - pad_left; } new_pad_list.push_back(pad_left); new_pad_list.push_back(pad_right); return new_pad_list; } namespace { bool CalSplitOutputShape(int64_t splited_axis_value, const SplitInfo *split_info, std::vector *split_axis_out_shape, std::vector *split_axis_reduce_out_shape) { MS_ASSERT(split_info != nullptr && split_axis_out_shape != nullptr && split_axis_reduce_out_shape != nullptr); // ori ratio int64_t split_num = split_info->out_num; int64_t split_len = 0; for (int64_t i = 0; i < split_num; i++) { split_len += split_info->size_splits[i]; } if (split_len > splited_axis_value) { return false; } // out-shape after splited int64_t tmp_value = 0; MS_CHECK_TRUE_MSG(split_num > 0, false, "out_num of split_info should be greater than zero"); MS_CHECK_TRUE_MSG(split_len > 0, false, "split_len should be greater than zero"); for (int64_t i = 0; i < split_num - 1; i++) { if (INT_MUL_OVERFLOW_THRESHOLD(split_info->size_splits[i], splited_axis_value, INT64_MAX)) { MS_LOG(ERROR) << "int mul overflow"; return false; } int64_t tmp = UP_DIV(split_info->size_splits[i] * splited_axis_value, split_len); tmp_value += tmp; split_axis_out_shape->push_back(tmp); split_axis_reduce_out_shape->push_back(tmp_value); } split_axis_out_shape->push_back(splited_axis_value - tmp_value); split_axis_reduce_out_shape->push_back(splited_axis_value); return true; } bool CalSplitInShape(const std::vector> &node_in_out_shapes, const SplitInfo *split_info, const std::shared_ptr &ori_conv_prim, size_t index_node, std::vector> *split_axis_inputs_shape, std::vector> *split_axis_reduce_inputs_shape) { MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr && split_axis_reduce_inputs_shape != nullptr); MS_ASSERT(node_in_out_shapes.size() > index_node); auto in_out_shape = node_in_out_shapes.at(index_node); MS_ASSERT(!in_out_shape.empty()); auto in_shape = in_out_shape.front(); if (in_shape.size() < kAxisW) { MS_LOG(DEBUG) << "out of in_shape range"; return false; } int64_t input_h = in_shape.at(kAxisH); int64_t input_w = in_shape.at(kAxisW); auto new_pad_list = GetSplitPadList(ori_conv_prim, input_h, input_w); ori_conv_prim->set_pad_list(new_pad_list); int64_t split_num = split_info->out_num; int64_t tmp = 0; std::vector split_axis_shape; std::vector split_axis_reduce_shape; // iter splited_num for (int64_t index = 0; index < split_num; index++) { // shape auto stride_h = ori_conv_prim->get_stride()[kIndexH]; auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1; if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) { MS_LOG(ERROR) << "int mul overflow"; return false; } if (split_info->axis == CuttingStragedy::CUT_H) { // H if (index == 0) { tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH]; } else if (index == split_num - 1) { tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] + ori_conv_prim->get_kernel_size()[kIndexH]; } else { tmp = stride_h * split_axis_dim + ori_conv_prim->get_kernel_size()[kIndexH]; } } split_axis_shape.push_back(tmp); // reduce shape auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1; if (split_info->axis == CuttingStragedy::CUT_H) { // H if (index == split_num - 1) { tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH]; } else { tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH]; } } split_axis_reduce_shape.push_back(tmp); } split_axis_inputs_shape->push_back(split_axis_shape); split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape); return true; } } // namespace bool IsConv2D(const AnfNodePtr &node) { return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion)); } std::shared_ptr CopyConvPrim(const std::shared_ptr &ori_conv_prim) { MS_CHECK_TRUE_MSG(ori_conv_prim != nullptr, nullptr, "input Conv2DFusion is nullptr"); auto new_prim = std::make_shared(); MS_CHECK_TRUE_MSG(new_prim != nullptr, nullptr, "create Conv2DFusion return nullptr"); new_prim->set_pad(ori_conv_prim->get_pad()); new_prim->set_in_channel(ori_conv_prim->get_in_channel()); new_prim->set_out_channel(ori_conv_prim->get_out_channel()); new_prim->set_dilation(ori_conv_prim->get_dilation()); new_prim->set_format(ori_conv_prim->get_format()); new_prim->set_group(ori_conv_prim->get_group()); new_prim->set_kernel_size(ori_conv_prim->get_kernel_size()); if (ori_conv_prim->get_pad_mode() == SAME) { new_prim->set_pad_mode(PAD); } else { new_prim->set_pad_mode(ori_conv_prim->get_pad_mode()); } new_prim->set_stride(ori_conv_prim->get_stride()); new_prim->set_activation_type(ori_conv_prim->get_activation_type()); new_prim->set_pad_list(ori_conv_prim->get_pad_list()); auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise); if (is_depth_value != nullptr) { bool is_depth_wise = GetValue(is_depth_value); new_prim->AddAttr(ops::kIsDepthWise, MakeValue(is_depth_wise)); } return new_prim; } bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector &conv_nodes, SplitInfo *split_info) { MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr"); MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input SplitInfo is nullptr"); if (split_info->axis != CuttingStragedy::CUT_H) { return false; } auto splited_axis = split_info->axis; // need to check if (split_info->fmk_type == FmkType::kFmkTypeCaffe || split_info->fmk_type == FmkType::kFmkTypeOnnx) { // NHWC -> NCHW splited_axis += 1; } size_t node_size = conv_nodes.size(); size_t index_node = 0; std::vector> node_in_out_shapes; while (index_node < node_size) { // [conv3, conv2, conv1] conv1->conv2->conv3 auto out_node_name = conv_nodes[index_node]->fullname_with_scope(); auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name]; auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name]; // 0-> in-shape 1->out-shape // only one in and one output MS_ASSERT(!input_shapes.empty() && !output_shapes.empty()); std::vector shape_vec = {input_shapes.front(), output_shapes.front()}; node_in_out_shapes.emplace_back(shape_vec); index_node++; } if (node_in_out_shapes.empty() || node_in_out_shapes.size() < (node_size - 1) || node_in_out_shapes[0].size() <= 1 || node_in_out_shapes[0][1].size() <= static_cast(splited_axis) || node_in_out_shapes[node_size - 1].empty() || node_in_out_shapes[node_size - 1][0].size() <= static_cast(splited_axis)) { MS_LOG(ERROR) << "out of node_in_out_shapes range"; return false; } int64_t splited_axis_value = node_in_out_shapes[0][1][splited_axis]; int64_t final_split_axis_value = node_in_out_shapes[node_size - 1][0][splited_axis]; split_info->ori_split_axis_value = final_split_axis_value; size_t split_num = split_info->size_splits.size(); std::vector split_axis_out_shape; std::vector split_axis_reduce_out_shape; if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) { return false; } // infer in-shape after splited std::vector> split_axis_inputs_shape{split_axis_out_shape}; std::vector> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape}; index_node = 0; // iter node while (index_node < node_size) { auto conv_cnode = conv_nodes[index_node]->cast(); MS_ASSERT(conv_cnode != nullptr); auto ori_conv_prim = GetValueNode>(conv_cnode->input(kAnfPrimitiveIndex)); MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false); if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape, &split_axis_reduce_inputs_shape)) { MS_LOG(ERROR) << "CalSplitInShape failed"; return false; } index_node++; } // update ratio split_info->size_splits.clear(); split_info->extend_top.clear(); split_info->extend_bottom.clear(); int64_t top = 0; int32_t bottom = 0; split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]); split_info->extend_top.push_back(top); split_info->extend_bottom.push_back(bottom); for (size_t i = 1; i < split_num; i++) { auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1; top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1; auto value = split_axis_inputs_shape[node_size][i] - top; split_info->size_splits.push_back(value); split_info->extend_top.push_back(top); split_info->extend_bottom.push_back(bottom); } return true; } bool GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num, std::vector *outputs) { MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr"); MS_CHECK_TRUE_MSG(node != nullptr, false, "input AnfNodePtr is nullptr"); MS_CHECK_TRUE_MSG(outputs != nullptr, false, "input std::vector is nullptr"); auto cnode = node->cast(); MS_CHECK_TRUE_MSG(cnode != nullptr, false, "create CNode return nullptr"); for (size_t i = 0; i < output_num; i++) { auto index = NewValueNode(SizeToInt(i)); MS_CHECK_TRUE_MSG(index != nullptr, false, "create ValueNode return nullptr"); auto temp = SizeToInt(i); auto imm = std::make_shared(temp); MS_CHECK_TRUE_MSG(imm != nullptr, false, "create Int32Imm return nullptr"); auto abstract_scalar = std::make_shared(imm); MS_CHECK_TRUE_MSG(abstract_scalar != nullptr, false, "create AbstractScalar return nullptr"); index->set_abstract(abstract_scalar); auto tuple_getitem_primitive = NewValueNode(prim::kPrimTupleGetItem); MS_CHECK_TRUE_MSG(tuple_getitem_primitive != nullptr, false, "create PrimTupleGetItem return nullptr"); auto tuple_getitem = func_graph->NewCNode({tuple_getitem_primitive, node, index}); MS_CHECK_TRUE_MSG(tuple_getitem != nullptr, false, "create CNode return nullptr"); tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1)); outputs->push_back(tuple_getitem); } return true; } AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode, const std::vector &conv_outputs, SplitInfo *split_info, const std::string &node_name) { MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "input FuncGraphPtr is nullptr"); MS_CHECK_TRUE_MSG(conv_cnode != nullptr, nullptr, "input AnfNodePtr is nullptr"); MS_CHECK_TRUE_MSG(split_info != nullptr, nullptr, "input SplitInfo is nullptr"); auto nodes_num = static_cast(conv_outputs.size()); if (nodes_num != split_info->out_num) { MS_LOG(ERROR) << "Conv outputs has wrong input size"; return nullptr; } auto concat_prim = std::make_shared(); MS_CHECK_TRUE_MSG(concat_prim != nullptr, nullptr, "create ops::Concat return nullptr"); concat_prim->set_axis(split_info->axis); // the inputs of concate are from the outputs of conv auto concate_primitive = NewValueNode(concat_prim); MS_CHECK_TRUE_MSG(concate_primitive != nullptr, nullptr, "create concate_primitive return nullptr"); std::vector concate_inputs = {concate_primitive}; for (size_t i = 0; i < static_cast(nodes_num); i++) { concate_inputs.push_back(conv_outputs[i]); } auto concate_cnode = func_graph->NewCNode(concate_inputs); MS_CHECK_TRUE_MSG(concate_cnode != nullptr, nullptr, "create concate_cnode return nullptr"); concate_cnode->set_fullname_with_scope(node_name + "_Concat"); concate_cnode->set_scope(conv_cnode->scope()); std::vector outputs; if (!GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs)) { MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed"; return nullptr; } return concate_cnode; } bool CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node, std::vector *split_outputs, SplitInfo *split_info, const std::string &node_name) { MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr"); MS_CHECK_TRUE_MSG(conv_node != nullptr, false, "input conv_node is nullptr"); MS_CHECK_TRUE_MSG(split_outputs != nullptr, false, "input split_outputs is nullptr"); MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input split_info is nullptr"); // attr of split auto split_prim = std::make_shared(); MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::SplitWithOverlap return nullptr"); split_prim->set_split_dim(split_info->axis); split_prim->set_number_split(split_info->out_num); split_prim->set_ratio(split_info->size_splits); split_prim->set_extend_top(split_info->extend_top); split_prim->set_extend_bottom(split_info->extend_bottom); auto conv_cnode = conv_node->cast(); // the inputs of split is from the inputs of conv auto split_primitive = NewValueNode(split_prim); MS_CHECK_TRUE_MSG(split_primitive != nullptr, false, "create split_primitive return nullptr"); std::vector split_inputs = {split_primitive}; // this conv only has one input, which has been ensured before split_inputs.push_back(conv_cnode->input(1)); auto split_cnode = func_graph->NewCNode(split_inputs); MS_CHECK_TRUE_MSG(split_cnode != nullptr, false, "create split_cnode return nullptr"); split_cnode->set_fullname_with_scope(node_name + "_Split"); if (split_info->out_num < 0) { MS_LOG(ERROR) << "out_num should greater then zero"; return false; } // create outputs op split if (!GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info->out_num, split_outputs)) { MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed"; return false; } AbstractBasePtrList ptr_list; for (int64_t i = 0; i < split_info->out_num; i++) { // set date_type same with weight auto type_id = static_cast(kNumberTypeFloat32); auto type_ptr = TypeIdToType(type_id); std::vector shape_vector; auto value_node = std::make_shared(type_ptr, shape_vector); MS_CHECK_TRUE_MSG(value_node != nullptr, false, "create abstract::AbstractTensor return nullptr"); ptr_list.push_back(value_node); } split_cnode->set_abstract(std::make_shared(ptr_list)); return true; } bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_size, int split_dim_size) { MS_CHECK_TRUE_MSG(ratio != nullptr, false, "input ratio is nullptr"); int64_t total_block_count = 0; for (size_t i = 0; i < split_size; i++) { total_block_count += ratio[i]; } if (ratio_len < split_size) { MS_LOG(ERROR) << "out of ratio range"; return false; } if (total_block_count < 0) { MS_LOG(ERROR) << "divide by zero"; return false; } std::vector new_ratio(split_size); int64_t visited_block = 0; for (size_t i = 0; i < split_size - 1; i++) { visited_block += ratio[i]; if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) { MS_LOG(ERROR) << "int mul overflow"; return false; } int64_t cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); new_ratio[i + 1] = cur_border; } for (size_t i = 0; i < split_size; i++) { ratio[i] = new_ratio[i]; } return true; } } // namespace opt } // namespace mindspore