/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/parallel/conv2d_info.h"
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "mindspore/core/ops/fusion/conv2d_fusion.h"
#include "mindspore/core/ops/split_with_overlap.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "tools/converter/converter_flags.h"
#include "include/errorcode.h"

namespace mindspore {
namespace opt {
// strategy format is NHWC-KHWC
constexpr int32_t kAxisN = 0;
constexpr int32_t kAxisCIn = 3;
constexpr int32_t kAxisCOut = 0;
constexpr int32_t kAxisH = 1;
constexpr int32_t kAxisW = 2;
constexpr auto kIndexH = 0;
constexpr auto kIndexW = 1;
constexpr auto kPadUp = 0;
constexpr auto kPadDown = 1;
constexpr auto kPadLeft = 2;
constexpr auto kPadRight = 3;

int Conv2DInfo::GetAttrs() { return lite::RET_OK; }

int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
  int split_count = 0;
  Strategys strategys = strategy.strategys;
  // if split N
  if (is_any_not_none(strategys[0][kAxisN])) {
    split_count++;
    splitMode_ = SplitN;
  }
  // if split C_in
  if (is_any_not_none(strategys[0][kAxisCIn])) {
    split_count++;
    splitMode_ = SplitCIN;
    if (strategys[0][kAxisCIn] != strategys[1][kAxisCIn]) {
      MS_LOG(ERROR) << "Strategy ERROR, split C_in, input and kernel must use same strategy.";
      return lite::RET_ERROR;
    }
  }
  // if split C_out
  if (is_any_not_none(strategys[1][kAxisCOut])) {
    split_count++;
    splitMode_ = SplitCOUT;
  }
  // if split H
  if (is_any_not_none(strategys[0][kAxisH])) {
    split_count++;
    splitMode_ = SplitH;
  }
  if (is_any_not_none(strategys[0][kAxisW])) {
    MS_LOG(ERROR) << "Strategy ERROR, doesn't support split W.";
    return lite::RET_ERROR;
  }
  if (is_any_not_none(strategys[1][kAxisH])) {
    MS_LOG(ERROR) << "Strategy ERROR, doesn't support split kernel H.";
    return lite::RET_ERROR;
  }
  if (is_any_not_none(strategys[1][kAxisW])) {
    MS_LOG(ERROR) << "Strategy ERROR, doesn't support split kernel W.";
    return lite::RET_ERROR;
  }
  if (split_count > 1) {
    MS_LOG(ERROR) << "Strategy ERROR, only support split one dimension.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
                                            std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
                                            size_t split_num, const std::vector<int64_t> &splits, bool trans_format) {
  if (orig_node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
  // prim of split
  auto split_prim = std::make_shared<ops::SplitWithOverlap>();
  split_prim->set_split_dim(split_dim);
  split_prim->set_number_split(split_num);
  split_prim->set_ratio(splits);
  split_prim->set_trans_format(trans_format);
  if (splitMode_ == SplitH) {
    split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
    auto extend_bottom = conv_prim->get_kernel_size().at(kIndexH) - conv_prim->get_stride().at(kIndexH);
    auto bottom_vector = std::vector<int64_t>(split_num, extend_bottom);
    bottom_vector[split_num - 1] = 0;
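    // The extensions above become the overlap that SplitWithOverlap copies
    // between adjacent H slices: every slice but the last keeps
    // (kernel_h - stride_h) extra input rows, e.g. 2 rows for kernel_h = 3
    // with stride_h = 1, while the last slice has no successor to borrow from.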
    split_prim->set_extend_bottom(bottom_vector);
    split_prim->set_stride(conv_prim->get_stride().at(kIndexH));
    split_prim->set_pad_top(conv_prim->get_pad_list().at(kPadUp));
  } else {
    split_prim->set_extend_top(std::vector<int64_t>(split_num, 0));
    split_prim->set_extend_bottom(std::vector<int64_t>(split_num, 0));
    split_prim->set_stride(0);
    split_prim->set_pad_top(0);
  }
  std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
  split_inputs.push_back(orig_node->input(input_index + 1));
  auto split_cnode = func_graph_->NewCNode(split_inputs);
  if (split_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create split node.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  split_cnode->set_fullname_with_scope("Split_" + name_);
  CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs);
  return split_cnode;
}

int Conv2DInfo::CheckConv2DPrimitiveType() {
  if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  if (CheckIfAnfNodeIsNull(cnode_) != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  if (!CheckPrimitiveType(cnode_, prim::kPrimConv2D) && !CheckPrimitiveType(cnode_, prim::kPrimConv2DFusion)) {
    return RET_ERROR;
  }
  return RET_OK;
}

int Conv2DInfo::InferParallelCNodes() {
  if (CheckConv2DPrimitiveType() != RET_OK) {
    return lite::RET_ERROR;
  }
  Strategys strategys = strategy_.strategys;
  size_t dev_num = strategy_.dev_num;
  std::vector<AnfNodePtr> feature_split_outputs;
  std::vector<AnfNodePtr> kernel_split_outputs;
  std::vector<AnfNodePtr> bias_split_outputs;
  std::string orig_name = name_;
  parallel_output_nodes_.clear();
  auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
  // split feature and kernel
  switch (splitMode_) {
    case SplitH: {
      name_ = orig_name + "_input";
      auto feature_split_cnode =
        CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisH, dev_num, strategys[0][kAxisH], true);
      if ((feature_split_cnode == nullptr) || (feature_split_outputs.size() != IntToSize(dev_num))) {
        MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
        return lite::RET_ERROR;
      }
    } break;
    case SplitN: {
      name_ = orig_name + "_input";
      auto feature_split_cnode =
        CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisN, dev_num, strategys[0][kAxisN], true);
      if ((feature_split_cnode == nullptr) || (feature_split_outputs.size() != IntToSize(dev_num))) {
        MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
        return lite::RET_ERROR;
      }
    } break;
    case SplitCOUT: {
      name_ = orig_name + "_kernel";
      auto kernel_split_cnode =
        CreateOutputsOfSplit(cnode_, 1, &kernel_split_outputs, kAxisCOut, dev_num, strategys[1][kAxisCOut], false);
      if ((kernel_split_cnode == nullptr) || (kernel_split_outputs.size() != IntToSize(dev_num))) {
        MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
        return lite::RET_ERROR;
      }
      if (cnode_->size() >= 4) {
        name_ = orig_name + "_bias";
        auto bias_split_cnode =
          CreateOutputsOfSplit(cnode_, 2, &bias_split_outputs, 0, dev_num, strategys[1][kAxisCOut], false);
        if ((bias_split_cnode == nullptr) || (bias_split_outputs.size() != IntToSize(dev_num))) {
          MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
          return lite::RET_ERROR;
        }
      }
    } break;
    case SplitCIN: {
      name_ = orig_name + "_input";
      auto feature_split_cnode =
        CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisCIn, dev_num, strategys[0][kAxisCIn], true);
      if ((feature_split_cnode == nullptr) || (feature_split_outputs.size() != IntToSize(dev_num))) {
        MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
        return lite::RET_ERROR;
      }
      name_ = orig_name + "_kernel";
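      // The kernel must be sliced along its input-channel axis with the same
      // ratio as the feature map; CheckStrategy has already rejected
      // mismatched C_in strategies, so each parallel conv sees matching
      // channel counts.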
      auto kernel_split_cnode =
        CreateOutputsOfSplit(cnode_, 1, &kernel_split_outputs, kAxisCIn, dev_num, strategys[1][kAxisCIn], false);
      if ((kernel_split_cnode == nullptr) || (kernel_split_outputs.size() != IntToSize(dev_num))) {
        MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
        return lite::RET_ERROR;
      }
    } break;
    default:
      MS_LOG(DEBUG) << "No split mode chosen.";
  }
  name_ = orig_name;
  return ConstructOutputCNodes(conv_prim, feature_split_outputs, kernel_split_outputs, bias_split_outputs);
}

int Conv2DInfo::ConstructOutputCNodes(const std::shared_ptr<ops::Conv2DFusion> &conv_prim,
                                      const std::vector<AnfNodePtr> &feature_split_outputs,
                                      const std::vector<AnfNodePtr> &kernel_split_outputs,
                                      const std::vector<AnfNodePtr> &bias_split_outputs) {
  Strategys strategys = strategy_.strategys;
  size_t dev_num = strategy_.dev_num;
  int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0);
  int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0);
  std::string conv_cnode_name = cnode_->fullname_with_scope();
  // construct parallel Conv2D nodes
  for (size_t i = 0; i < dev_num; ++i) {
    std::vector<AnfNodePtr> tmp_outputs;
    bool has_bias = cnode_->size() >= 4;
    // if split C_in, only the first parallel operator keeps the bias
    if ((i != 0) && splitMode_ == SplitCIN) {
      has_bias = false;
    }
    // copy attr
    auto prim = std::make_shared<ops::Conv2DFusion>();
    prim->set_pad(conv_prim->get_pad());
    prim->set_in_channel(conv_prim->get_in_channel());
    prim->set_out_channel(conv_prim->get_out_channel());
    prim->set_dilation(conv_prim->get_dilation());
    prim->set_format(conv_prim->get_format());
    prim->set_group(conv_prim->get_group());
    prim->set_kernel_size(conv_prim->get_kernel_size());
    prim->set_pad_mode(conv_prim->get_pad_mode());
    prim->set_pad_list(conv_prim->get_pad_list());
    prim->set_stride(conv_prim->get_stride());
    prim->set_activation_type(conv_prim->get_activation_type());
    switch (splitMode_) {
      case SplitH: {
        // Interior H slices must not re-apply the original top/bottom padding.
        if (i != 0) {
          auto pad = prim->get_pad_list();
          pad.at(kPadUp) = 0;
          prim->set_pad_list(pad);
        }
        if (i != (dev_num - 1)) {
          auto pad = prim->get_pad_list();
          pad.at(kPadDown) = 0;
          prim->set_pad_list(pad);
        }
      } break;
      case SplitCIN: {
        // The first slice gets its share by ratio; the remainder goes to the other slice.
        auto in_channel = prim->get_in_channel();
        if (i == 0) {
          prim->set_in_channel(in_channel * strategys[0][kAxisCIn][0] / cin_strategy_sum);
        } else {
          prim->set_in_channel(in_channel - (in_channel * strategys[0][kAxisCIn][0] / cin_strategy_sum));
        }
      } break;
      case SplitCOUT: {
        auto out_channel = prim->get_out_channel();
        if (i == 0) {
          prim->set_out_channel(out_channel * strategys[1][kAxisCOut][0] / cout_strategy_sum);
        } else {
          prim->set_out_channel(out_channel - (out_channel * strategys[1][kAxisCOut][0] / cout_strategy_sum));
        }
      } break;
      default:
        break;
    }
    std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
    // if split C_out, the feature map is not split
    if (splitMode_ == SplitCOUT) {
      conv_inputs.push_back(cnode_->input(1));
    } else {
      conv_inputs.push_back(feature_split_outputs[i]);
    }
    // the kernel is split only when splitting C_in or C_out
    if (splitMode_ == SplitCIN || splitMode_ == SplitCOUT) {
      conv_inputs.push_back(kernel_split_outputs[i]);
    } else {
      conv_inputs.push_back(cnode_->input(2));
    }
    if (has_bias) {
      if (splitMode_ == SplitCOUT) {
        conv_inputs.push_back(bias_split_outputs[i]);
      } else {
        conv_inputs.push_back(cnode_->input(3));
      }
    }
    auto conv_cnode = func_graph_->NewCNode(conv_inputs);
    if (conv_cnode == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create parallel Conv2D node " << i;
      return lite::RET_ERROR;
    }
    conv_cnode->set_fullname_with_scope(conv_cnode_name + std::to_string(i));
    CreateMultipleOutputsOfAnfNode(conv_cnode, 1, &tmp_outputs);
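    // Each parallel conv has a single output; collect it so that
    // InferReplaceOp can later join the slices with a Concat (or a Reduce
    // when splitting C_in).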
    parallel_output_nodes_.push_back(tmp_outputs[0]);
  }
  return lite::RET_OK;
}

int Conv2DInfo::InferReplaceOp() {
  size_t dev_num = strategy_.dev_num;
  if (splitMode_ == SplitCIN) {
    MS_LOG(DEBUG) << name_ << " : Split Cin, infer Forward op.";
    // C_in slices each produce a full-sized partial result, so they are summed.
    replace_op_ = CreateReduceNode(cnode_, parallel_output_nodes_, kAxisCIn, dev_num, true);
  } else {
    int32_t concat_dim;
    if (splitMode_ == SplitN) {
      concat_dim = kAxisN;
    } else if (splitMode_ == SplitCOUT) {
      // C_out of the kernel maps to the channel axis of the NHWC output,
      // so the slices are concatenated along that axis.
      concat_dim = kAxisCIn;
    } else {
      concat_dim = kAxisH;
    }
    replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, concat_dim, dev_num, true);
  }
  if (replace_op_ == nullptr) {
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

AnfNodePtr DepthwiseConv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index,
                                                     std::vector<AnfNodePtr> *split_outputs, size_t split_dim,
                                                     size_t split_num, const std::vector<int64_t> &splits,
                                                     bool trans_format) {
  // Not implemented yet: returning nullptr makes InferParallelCNodes below
  // treat any requested split as a failure.
  return nullptr;
}

int DepthwiseConv2DInfo::InferParallelCNodes() {
  if (CheckIfFuncGraphIsNull(func_graph_) != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  if (CheckIfAnfNodeIsNull(cnode_) != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  Strategys strategys = strategy_.strategys;
  size_t dev_num = strategy_.dev_num;
  std::vector<AnfNodePtr> feature_split_outputs;
  std::string orig_name = name_;
  switch (splitMode_) {
    case SplitCIN: {
      MS_LOG(ERROR) << "DepthwiseConv2DInfo doesn't support split Cin.";
      return lite::RET_ERROR;
    } break;
    case SplitCOUT: {
      MS_LOG(ERROR) << "DepthwiseConv2DInfo doesn't support split Cout.";
      return lite::RET_ERROR;
    } break;
    case SplitH: {
      MS_LOG(ERROR) << "DepthwiseConv2DInfo doesn't support split H.";
      return lite::RET_ERROR;
    } break;
    default:
      break;
  }
  parallel_output_nodes_.clear();
  std::string conv_cnode_name = cnode_->fullname_with_scope();
  auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
  if (!IsDwConvNode(conv_prim)) {
    return lite::RET_ERROR;
  }
  // split feature and kernel
  if (splitMode_ == SplitN) {
    name_ = orig_name + "_input";
    auto feature_split_cnode =
      CreateOutputsOfSplit(cnode_, 0, &feature_split_outputs, kAxisN, dev_num, strategys[0][kAxisN], true);
    if ((feature_split_cnode == nullptr) || (feature_split_outputs.size() != IntToSize(dev_num))) {
      MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
      return lite::RET_ERROR;
    }
  }
  name_ = orig_name;
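  // Only the feature map is split (along N); every parallel depthwise conv
  // below reuses the original kernel and bias inputs unchanged.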
  // construct parallel Conv2D nodes
  for (size_t i = 0; i < dev_num; ++i) {
    std::vector<AnfNodePtr> tmp_outputs;
    // copy attr
    auto prim = std::make_shared<ops::Conv2DFusion>();
    prim->AddAttr(ops::kIsDepthWise, MakeValue(true));
    prim->set_pad(conv_prim->get_pad());
    prim->set_in_channel(conv_prim->get_in_channel());
    prim->set_out_channel(conv_prim->get_out_channel());
    prim->set_dilation(conv_prim->get_dilation());
    prim->set_format(conv_prim->get_format());
    prim->set_group(conv_prim->get_group());
    prim->set_kernel_size(conv_prim->get_kernel_size());
    prim->set_pad_mode(conv_prim->get_pad_mode());
    prim->set_pad_list(conv_prim->get_pad_list());
    prim->set_stride(conv_prim->get_stride());
    prim->set_activation_type(conv_prim->get_activation_type());
    std::vector<AnfNodePtr> conv_inputs = {NewValueNode(prim)};
    conv_inputs.push_back(feature_split_outputs[i]);
    conv_inputs.push_back(cnode_->input(2));
    if (cnode_->size() >= 4) {
      conv_inputs.push_back(cnode_->input(3));
    }
    auto conv_cnode = func_graph_->NewCNode(conv_inputs);
    if (conv_cnode == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create parallel DepthwiseConv2D node " << i;
      return lite::RET_ERROR;
    }
    conv_cnode->set_fullname_with_scope(conv_cnode_name + std::to_string(i));
    CreateMultipleOutputsOfAnfNode(conv_cnode, 1, &tmp_outputs);
    parallel_output_nodes_.push_back(tmp_outputs[0]);
  }
  return lite::RET_OK;
}

int DepthwiseConv2DInfo::InferReplaceOp() {
  size_t dev_num = strategy_.dev_num;
  // batch slices are joined back along N
  replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, kAxisN, dev_num, true);
  if (replace_op_ == nullptr) {
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
}  // namespace opt
}  // namespace mindspore