From cdde7fcc7baf86e3c35ddbe95154deffbd9fd918 Mon Sep 17 00:00:00 2001 From: zhujingxuan Date: Fri, 30 Apr 2021 09:38:07 +0800 Subject: [PATCH] add single conv split --- mindspore/lite/tools/converter/CMakeLists.txt | 4 + .../lite/tools/optimizer/common/gllo_utils.cc | 7 + .../lite/tools/optimizer/common/gllo_utils.h | 2 + .../fisson/eliminate_concat_split.cc | 153 +++++++++++++ .../optimizer/fisson/eliminate_concat_split.h | 34 +++ .../fusion/conv_activation_fusion.cc | 1 - .../fusion/pooling_activation_fusion.cc | 1 - .../optimizer/parallel/dynamic_creator.cc | 50 ++++ .../optimizer/parallel/dynamic_creator.h | 80 +++++++ .../tools/optimizer/parallel/operator_info.cc | 213 ++++++++++++++++++ .../tools/optimizer/parallel/operator_info.h | 101 +++++++++ .../tools/optimizer/parallel/parallel_pass.cc | 86 +++++++ .../tools/optimizer/parallel/parallel_pass.h | 55 +++++ .../tools/optimizer/parallel/split_strategy.h | 6 - 14 files changed, 785 insertions(+), 8 deletions(-) create mode 100644 mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.cc create mode 100644 mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.h create mode 100644 mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc create mode 100644 mindspore/lite/tools/optimizer/parallel/dynamic_creator.h create mode 100644 mindspore/lite/tools/optimizer/parallel/operator_info.cc create mode 100644 mindspore/lite/tools/optimizer/parallel/operator_info.h create mode 100644 mindspore/lite/tools/optimizer/parallel/parallel_pass.cc create mode 100644 mindspore/lite/tools/optimizer/parallel/parallel_pass.h diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 54e47dd367..00652fffed 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -59,9 +59,13 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/tf_gelu_fusion.cc ../optimizer/fusion/onnx_gelu_fusion.cc ../optimizer/fusion/squeeze_fusion.cc + ../optimizer/fisson/eliminate_concat_split.cc ../optimizer/fisson/fisson_util.cc ../optimizer/fisson/iter_node_outputs.cc ../optimizer/fisson/node_out_shapes.cc + ../optimizer/parallel/dynamic_creator.cc + ../optimizer/parallel/operator_info.cc + ../optimizer/parallel/parallel_pass.cc ../optimizer/graph/conv1d_inout_adjust_pass.cc ../optimizer/graph/weight_format_transform_pass.cc ../optimizer/graph/weight_format_hardcode_pass.cc diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index e80e89ef2e..a12f302163 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -645,6 +645,13 @@ bool IsSqueezeNode(const BaseRef &n) { return false; } +bool IsConcatNode(const BaseRef &n) { + if (utils::isa(n)) { + return CheckPrimitiveType(utils::cast(n), prim::kPrimConcat); + } + return false; +} + bool CheckIsAllInputsParam(const AnfNodePtr &node) { if (node == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 3bb51fb022..f12e01ab86 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -94,6 +94,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node); bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); +bool IsConcatNode(const BaseRef &n); + size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node); diff --git a/mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.cc b/mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.cc new file mode 100644 index 0000000000..ff527ed29e --- /dev/null +++ b/mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.cc @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/optimizer/fisson/eliminate_concat_split.h" +#include "schema/inner/model_generated.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "mindspore/core/ops/split_with_overlap.h" +#include "mindspore/core/ops/concat.h" +#include "mindspore/core/base/core_ops.h" + +namespace mindspore { + +namespace opt { + +const BaseRef EliminateConcatSplit::DefinePattern() const { + auto concat_var = std::make_shared(IsConcatNode); + auto split_prim = std::make_shared(); + + return VectorRef({split_prim, concat_var}); +} + +CNodePtr GetRealPrevCNode(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + + if (IsRealCNodeKernel(cnode)) { + return cnode; + } + + auto input0 = cnode->input(0); + + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + auto temp_node = cnode->input(1); + if (temp_node == nullptr) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return nullptr; + } + return GetRealPrevCNode(temp_node); + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + return GetRealPrevCNode(cnode->input(1)); + } else { + return nullptr; + } +} + +void ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + auto pre_cnode = GetRealPrevCNode(cnode->input(1)); + CheckIfCNodeIsNull(pre_cnode); + + if (!CheckPrimitiveType(pre_cnode, prim::kPrimConcat)) { + return; + } + + auto finder = g_graph_nodes_output.find(pre_cnode->fullname_with_scope()); + if (finder == g_graph_nodes_output.end()) { + return; + } + if (finder->second.size() > 1) return; + + size_t pre_inputs_size = pre_cnode->inputs().size(); + int pre_inputs_node_size = pre_inputs_size - 1; + auto pre_prim = GetValueNode>(pre_cnode->input(kAnfPrimitiveIndex)); + auto prim = GetValueNode>(cnode->input(kAnfPrimitiveIndex)); + + if (prim->get_number_split() != pre_inputs_node_size) { + return; + } + + // check axis NHWC + // only support axis "N" now, other axes will support when having "InferShape" + if (pre_prim->get_axis() != 0) { + return; + } + + // get inputs node + auto it = g_graph_nodes_output.find(cnode->fullname_with_scope()); + if (it == g_graph_nodes_output.end()) { + return; + } + int out_num = it->second.size(); + if (out_num != prim->get_number_split()) { + return; + } + + std::vector inputs_node; + for (int i = 0; i < out_num; i++) { + auto tmp = it->second[i]; + auto tmp_cnode = tmp->cast(); + if (CheckIfCNodeIsNull(tmp_cnode) != lite::RET_OK) { + return; + } + if (!CheckPrimitiveType(tmp_cnode, prim::kPrimTupleGetItem)) { + return; + } + auto tmp_it = g_graph_nodes_output.find(tmp_cnode->fullname_with_scope()); + if (tmp_it == g_graph_nodes_output.end()) { + return; + } + if (tmp_it->second.size() != 1) return; + + auto next = tmp_it->second[0]; + auto next_cnode = next->cast(); + + inputs_node.push_back(next_cnode); + } + // replace inputs + auto manager = func_graph->manager(); + if (manager == nullptr) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return; + } + for (size_t i = 1; i < pre_inputs_size; i++) { + (void)manager->Replace((inputs_node[i - 1])->input(1), pre_cnode->input(i)); + } +} + +const AnfNodePtr EliminateConcatSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "Enter EliminateConcatSplit pass process"; + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK) { + return nullptr; + } + if (CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + return nullptr; + } + auto split_cnode = node->cast(); + if (CheckIfCNodeIsNull(split_cnode) != lite::RET_OK) { + return nullptr; + } + ConcatSplitEliminate(func_graph, split_cnode); + + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.h b/mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.h new file mode 100644 index 0000000000..74e8bdcff1 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fisson/eliminate_concat_split.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_ + +#include "backend/optimizer/common/optimizer.h" +#include "tools/optimizer/fisson/fisson_util.h" + +namespace mindspore { +namespace opt { +class EliminateConcatSplit : public PatternProcessPass { + public: + explicit EliminateConcatSplit(bool multigraph = true) : PatternProcessPass("eliminate_concat_split", multigraph) {} + ~EliminateConcatSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc index fda4f685b5..901fb01408 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -62,7 +62,6 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c return nullptr; } auto conv_node = pre_node->cast(); - MS_ASSERT(primitive_c); if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) || CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) { auto prim = GetValueNode(conv_node->input(0)); diff --git a/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.cc index e7872c89f9..a97d40ab8b 100644 --- a/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.cc @@ -63,7 +63,6 @@ const AnfNodePtr PoolingActivationFusion::Process(const FuncGraphPtr &func_graph } auto pooling_node = pre_node->cast(); auto primitive_c = GetValueNode>(pooling_node->input(0)); - MS_ASSERT(primitive_c); MS_ASSERT(utils::isa>(primitive_c)); auto primc = utils::cast>(primitive_c); diff --git a/mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc b/mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc new file mode 100644 index 0000000000..eb741a1e07 --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/dynamic_creator.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/parallel/dynamic_creator.h" + +namespace mindspore { +namespace opt { +// operator register + +std::string GetDisOpName(const std::string &prim_name) { + std::string op_name = prim_name; + if (!prim_name.empty() && (prim_name[0] == '_')) { + op_name = prim_name.substr(1); + } + return op_name + "Info"; +} + +// create the OperatorInfo instance +OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name, + const SplitStrategy &strategy) { + if (type_name.length() == 0) { + MS_LOG(EXCEPTION) << "Length of name is zero!"; + } + std::string distribute_opname = GetDisOpName(type_name); + OperatorInfoPtr operator_ = (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, strategy); + if (operator_ == nullptr) { + MS_LOG(INFO) << "Create " << type_name << " failed"; + return nullptr; + } + std::string origin_name = operator_->name(); + operator_->set_name(orig_name); + MS_LOG(INFO) << "Successfully created operator " << origin_name; + return operator_; +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/parallel/dynamic_creator.h b/mindspore/lite/tools/optimizer/parallel/dynamic_creator.h new file mode 100644 index 0000000000..fbda223cfa --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/dynamic_creator.h @@ -0,0 +1,80 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_ +#define MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_ + +#include +#include +#include +#include + +#include "tools/optimizer/parallel/operator_info.h" + +namespace mindspore { +namespace opt { +#define REGISTER(className) \ + OperatorInfoPtr objectCreator##className(std::string name, SplitStrategy strategy) { \ + return std::make_shared(name, strategy); \ + } \ + RegisterAction className##Register(#className, (CreatFn)objectCreator##className); + +typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const SplitStrategy &strategy); + +class DynCreator { + public: + ~DynCreator() = default; + + // create static singleton dyn_creator instance + static DynCreator &Instance() { + static DynCreator fac = DynCreator(); + return fac; + } + // register + void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } + // creator + OperatorInfoPtr Create(const std::string &name, const SplitStrategy &strategy) { + auto iter = Function_map_.find(name); + if (iter == Function_map_.end()) { + MS_LOG(INFO) << name << " is not register yet"; + return nullptr; + } + return iter->second(name, strategy); + } + + private: + DynCreator() = default; + std::map Function_map_; +}; + +class RegisterAction { + public: + RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { + DynCreator::Instance().Register(name, creatfn); + } + ~RegisterAction() = default; + + private: + std::string name_; +}; + +OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name, + const SplitStrategy &strategy); + +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_ diff --git a/mindspore/lite/tools/optimizer/parallel/operator_info.cc b/mindspore/lite/tools/optimizer/parallel/operator_info.cc new file mode 100644 index 0000000000..11140e1024 --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/operator_info.cc @@ -0,0 +1,213 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/parallel/operator_info.h" +#include +#include "tools/converter/ops/ops_def.h" +#include "tools/optimizer/parallel/split_strategy.h" +#include "mindspore/core/ops/concat.h" +#include "mindspore/core/ops/addn.h" +#include "mindspore/core/ops/split.h" +#include "include/lite_types.h" +#include "mindspore/ccsrc/utils/utils.h" +#include "base/core_ops.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace opt { +bool is_any_none(const std::vector &split) { + return std::any_of(split.begin(), split.end(), [](int64_t v) { return v == static_cast(NoSplit); }); +} + +bool is_any_not_none(const std::vector &split) { + return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast(NoSplit); }); +} + +lite::STATUS OperatorInfo::SetCNodeBackend() { + for (size_t i = 0; i < strategy_.dev_num; ++i) { + lite::DeviceType dt_type; + std::string type = strategy_.dev_types[i]; + auto cnode = parallel_output_nodes_[i]->cast()->input(1)->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (type == "GPU") { + dt_type = lite::DeviceType::DT_GPU; + } else if (type == "CPU") { + dt_type = lite::DeviceType::DT_CPU; + } else if (type == "NPU") { + dt_type = lite::DeviceType::DT_NPU; + } else { + MS_LOG(ERROR) << "SetCnodeBackend: unknown device type."; + return lite::RET_ERROR; + } + cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast(dt_type))); + } + return lite::RET_OK; +} + +lite::STATUS OperatorInfo::CheckStrategyValue() { + auto strategy_size = strategy_.strategys.size(); + + for (size_t index = 0; index < strategy_size; ++index) { + auto strategy = strategy_.strategys[index]; + for (const auto &s : strategy) { + if (s.size() != IntToSize(strategy_.dev_num)) { + MS_LOG(ERROR) << "Strategy split number:" << s.size() + << " is not equal to device number: " << strategy_.dev_num; + return lite::RET_ERROR; + } + if (is_any_not_none(s) && is_any_none(s)) { + MS_LOG(ERROR) << "Strategy split number must be all zero or all non-zero: " << s; + return lite::RET_ERROR; + } + } + } + return lite::RET_OK; +} + +lite::STATUS OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, + std::vector *outputs) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(outputs); + AbstractBasePtrList ptr_list; + auto cnode = node->cast(); + if (cnode == nullptr) { + MS_LOG(ERROR) << name_ << " : Failed to get CNode."; + return lite::RET_ERROR; + } + + for (size_t i = 0; i < output_num; ++i) { + auto idx = NewValueNode(SizeToInt(i)); + MS_ASSERT(idx); + auto index = std::make_shared(SizeToInt(i)); + auto abstract_scalar = std::make_shared(index); + idx->set_abstract(abstract_scalar); + auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared()), node, idx}); + if (tuple_getitem == nullptr) { + MS_LOG(ERROR) << name_ << " : Failed to create output nodes."; + return lite::RET_ERROR; + } + tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i)); + outputs->push_back(tuple_getitem); + ptr_list.push_back(abstract_scalar); + } + node->set_abstract(std::make_shared(ptr_list)); + return lite::RET_OK; +} + +AnfNodePtr OperatorInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, + std::vector *split_outputs, size_t split_dim, + size_t split_num, const std::vector &splits, bool trans_format) { + MS_EXCEPTION_IF_NULL(orig_node); + + auto split_prim = std::make_shared(); + split_prim->set_output_num(split_num); + split_prim->set_size_splits(splits); + split_prim->set_axis(split_dim); + auto value_node = NewValueNode(split_prim); + std::vector split_inputs = {value_node}; + split_inputs.push_back(orig_node->input(input_index + 1)); + auto split_cnode = func_graph_->NewCNode(split_inputs); + if (split_cnode == nullptr) { + MS_LOG(ERROR) << name_ << " : Failed to create split node."; + return nullptr; + } + split_cnode->set_fullname_with_scope("Split_" + name_); + CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs); + + return split_cnode; +} + +AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector &input_nodes, + int32_t concat_dim, size_t input_nodes_num, bool trans_format) { + MS_EXCEPTION_IF_NULL(orig_node); + + if (input_nodes.size() != input_nodes_num) { + MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number."; + return nullptr; + } + auto concat_prim = std::make_shared(); + concat_prim->set_axis(concat_dim); + auto value_node = NewValueNode(concat_prim); + std::vector concat_inputs = {value_node}; + (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs), + [](const AnfNodePtr &p) { return p->cast()->input(1); }); + auto concat_cnode = func_graph_->NewCNode(concat_inputs); + if (concat_cnode == nullptr) { + MS_LOG(ERROR) << name_ << " : Failed to create concat node."; + return nullptr; + } + concat_cnode->set_fullname_with_scope("Concat_" + name_); + concat_cnode->set_scope(orig_node->scope()); + + return concat_cnode; +} + +AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector &input_nodes, + int32_t reduce_dim, size_t input_nodes_num, bool trans_format) { + MS_EXCEPTION_IF_NULL(orig_node); + + if (input_nodes.size() != input_nodes_num) { + MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number."; + return nullptr; + } + // addup inputs element-wise + auto addn_prim = std::make_shared(); + auto value_node = NewValueNode(addn_prim); + std::vector addn_inputs = {value_node}; + (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs), + [](const AnfNodePtr &p) { return p->cast()->input(1); }); + auto addn_cnode = func_graph_->NewCNode(addn_inputs); + if (addn_cnode == nullptr) { + MS_LOG(ERROR) << name_ << " : Failed to create concat node."; + return nullptr; + } + addn_cnode->set_fullname_with_scope("AddN_" + name_); + addn_cnode->set_scope(orig_node->scope()); + + return addn_cnode; +} + +lite::STATUS OperatorInfo::Init() { + if (GetAttrs() != lite::RET_OK) { + MS_LOG(ERROR) << name_ << ": Parse attrs failed."; + return lite::RET_ERROR; + } + if (CheckStrategyValue() != lite::RET_OK) { + MS_LOG(ERROR) << name_ << ": Invalid strategy values."; + return lite::RET_ERROR; + } + if (CheckStrategy(strategy_) != lite::RET_OK) { + MS_LOG(ERROR) << name_ << ": Check strategys failed."; + return lite::RET_ERROR; + } + if (InferParallelCNodes() != lite::RET_OK) { + MS_LOG(ERROR) << name_ << ": InferReplaceGraph failed."; + return lite::RET_ERROR; + } + if (SetCNodeBackend() != lite::RET_OK) { + MS_LOG(ERROR) << name_ << ": SetCnodeBackend failed."; + return lite::RET_ERROR; + } + if (InferReplaceOp() != lite::RET_OK) { + MS_LOG(ERROR) << name_ << ": InferForwardOps failed."; + return lite::RET_ERROR; + } + + return lite::RET_OK; +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/parallel/operator_info.h b/mindspore/lite/tools/optimizer/parallel/operator_info.h new file mode 100644 index 0000000000..3f90d4f99f --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/operator_info.h @@ -0,0 +1,101 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_ +#define MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_ + +#include +#include +#include +#include +#include + +#include "tools/optimizer/parallel/split_strategy.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "schema/inner/model_generated.h" +#include "include/context.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace opt { +/** + * Do following steps to make a operator support parallel: + * + * 1.Add the schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST; + * 2.Add a pair of type and string name to ParallelPass::type_string; + * 3.Implement a class XXXInfo whose parent is OperatorInfo; + * 3.1.Override CheckStrategy(), InferParallelCNodes() and InferReplaceOp() + * 4.include header file of XXXInfo in ops_info_head_files.h + * 5.REGISTER XXXInfo in dynamic_creator.cc + */ +using schema::ReduceMode; + +class OperatorInfo; +using OperatorInfoPtr = std::shared_ptr; + +class OperatorInfo { + public: + OperatorInfo(std::string name, SplitStrategy strategy) + : name_(std::move(name)), + strategy_(std::move(strategy)), + replace_op_(nullptr), + func_graph_(nullptr), + cnode_(nullptr) {} + virtual ~OperatorInfo() = default; + const std::string &name() const { return name_; } + void set_name(const std::string &name) { name_ = name; } + void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } + void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } + void setFmk(const int32_t FmkType) { FmkType_ = FmkType; } + AnfNodePtr replace_op() { return replace_op_; } + lite::STATUS Init(); + + protected: + lite::STATUS CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, + std::vector *outputs); + AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector *split_outputs, + size_t split_dim, size_t split_num, const std::vector &splits, + bool trans_format); + AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector &input_nodes, + int32_t concat_dim, size_t input_nodes_num, bool trans_format); + AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector &input_nodes, int32_t reduce_dim, + size_t input_nodes_num, bool trans_format); + virtual lite::STATUS GetAttrs() = 0; + virtual lite::STATUS InferReplaceOp() = 0; + virtual lite::STATUS InferParallelCNodes() = 0; + virtual lite::STATUS CheckStrategy(const SplitStrategy &strategy) = 0; + + std::string name_; + SplitStrategy strategy_; + AnfNodePtr replace_op_; + std::vector parallel_output_nodes_; + FuncGraphPtr func_graph_; + CNodePtr cnode_; + int32_t FmkType_{}; + + private: + lite::STATUS SetCNodeBackend(); + lite::STATUS CheckStrategyValue(); +}; + +bool is_any_none(const std::vector &split); +bool is_any_not_none(const std::vector &split); + +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_ diff --git a/mindspore/lite/tools/optimizer/parallel/parallel_pass.cc b/mindspore/lite/tools/optimizer/parallel/parallel_pass.cc new file mode 100644 index 0000000000..4d9dabb913 --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/parallel_pass.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/parallel/parallel_pass.h" +#include "include/errorcode.h" +#include "ir/tensor.h" + +namespace mindspore { +namespace opt { +bool ParallelPass::IsParallelCareNode(const AnfNodePtr &node) { + return std::any_of(PARALLEL_LIST.begin(), PARALLEL_LIST.end(), [this, &node](auto &prim) { + if (CheckPrimitiveType(node, prim)) { + type_name_ = PrimToString(prim); + return true; + } else { + return false; + } + }); +} + +std::string ParallelPass::PrimToString(const PrimitivePtr &prim) { + if (type_string.find(prim->name()) == type_string.end()) { + MS_LOG(EXCEPTION) << "String of the type not registered"; + } + return type_string.at(prim->name()); +} + +AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + return nullptr; + } + if (!utils::isa(node)) { + return nullptr; + } + auto cnode = node->cast(); + if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) { + return nullptr; + } + if (!IsParallelCareNode(node)) { + return nullptr; + } + std::string cnode_name = cnode->fullname_with_scope(); + std::string name = cnode_name; + std::string orig_name = cnode_name; + // find operator name first, then operator type name. + if (split_strategys_.find(name) == split_strategys_.end()) { + name = type_name_; + } + if (cnode_name.find(PARALLEL_NAME_SUFFIX) != std::string::npos) { + MS_LOG(DEBUG) << " : Skip splited cnode " << cnode_name; + return nullptr; + } + MS_LOG(DEBUG) << " : Reached a parallel care node: " << cnode_name; + if (split_strategys_.find(name) == split_strategys_.end()) { + MS_LOG(DEBUG) << name << " : No split strategy for the current CNode."; + return nullptr; + } + cnode->set_fullname_with_scope(cnode_name + PARALLEL_NAME_SUFFIX); + OperatorInfoPtr operator_ = OperatorInstance(type_name_, orig_name, split_strategys_[name]); + if (operator_ == nullptr) { + MS_LOG(EXCEPTION) << "Failure: Create " << name << " OperatorInstance failed"; + } + operator_->set_cnode(cnode); + operator_->set_func_graph(func_graph); + operator_->setFmk(FmkType_); + if (operator_->Init() == RET_ERROR) { + MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed"; + } + return operator_->replace_op(); +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/parallel/parallel_pass.h b/mindspore/lite/tools/optimizer/parallel/parallel_pass.h new file mode 100644 index 0000000000..d8af45ad76 --- /dev/null +++ b/mindspore/lite/tools/optimizer/parallel/parallel_pass.h @@ -0,0 +1,55 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "tools/optimizer/parallel/dynamic_creator.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h" + +#ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_ + +namespace mindspore { +namespace opt { +class ParallelPass : public opt::NodePass { + public: + explicit ParallelPass(const std::unordered_map strategys, const int32_t FmkType) + : NodePass("parallel_pass"), split_strategys_(strategys), FmkType_(FmkType) {} + ~ParallelPass() override = default; + AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; + + private: + const std::set PARALLEL_LIST = {prim::kPrimConv2DFusion}; + const std::unordered_map type_string = {{prim::kPrimConv2DFusion->name(), "Conv2D"}}; + + bool IsParallelCareNode(const AnfNodePtr &node); + std::string PrimToString(const PrimitivePtr &prim); + + std::string type_name_; + std::unordered_map split_strategys_; + int32_t FmkType_; +}; + +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/parallel/split_strategy.h b/mindspore/lite/tools/optimizer/parallel/split_strategy.h index aa541f1b95..7b69cc3197 100644 --- a/mindspore/lite/tools/optimizer/parallel/split_strategy.h +++ b/mindspore/lite/tools/optimizer/parallel/split_strategy.h @@ -38,12 +38,6 @@ const std::vector kSplitDevTypes = {"CPU", "GPU"}; using Strategys = std::vector>>; -enum Status { - SUCCESS = 0, - FAILED, - INVALID_ARGUMENT, -}; - enum SplitMode { SplitN = 0, SplitH = 1,