From: @zhujingxuan Reviewed-by: @wangchengyuan, @hangangqiang Signed-off-by: @wangchengyuan (pull/15933/MERGE)
| @@ -59,9 +59,13 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/fusion/tf_gelu_fusion.cc | |||
| ../optimizer/fusion/onnx_gelu_fusion.cc | |||
| ../optimizer/fusion/squeeze_fusion.cc | |||
| ../optimizer/fisson/eliminate_concat_split.cc | |||
| ../optimizer/fisson/fisson_util.cc | |||
| ../optimizer/fisson/iter_node_outputs.cc | |||
| ../optimizer/fisson/node_out_shapes.cc | |||
| ../optimizer/parallel/dynamic_creator.cc | |||
| ../optimizer/parallel/operator_info.cc | |||
| ../optimizer/parallel/parallel_pass.cc | |||
| ../optimizer/graph/conv1d_inout_adjust_pass.cc | |||
| ../optimizer/graph/weight_format_transform_pass.cc | |||
| ../optimizer/graph/weight_format_hardcode_pass.cc | |||
| @@ -645,6 +645,13 @@ bool IsSqueezeNode(const BaseRef &n) { | |||
| return false; | |||
| } | |||
| bool IsConcatNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimConcat); | |||
| } | |||
| return false; | |||
| } | |||
| bool CheckIsAllInputsParam(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| @@ -94,6 +94,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node); | |||
| bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||
| bool IsConcatNode(const BaseRef &n); | |||
| size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); | |||
| tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node); | |||
| @@ -0,0 +1,153 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "tools/optimizer/fisson/eliminate_concat_split.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "mindspore/core/ops/split_with_overlap.h" | |||
| #include "mindspore/core/ops/concat.h" | |||
| #include "mindspore/core/base/core_ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef EliminateConcatSplit::DefinePattern() const { | |||
| auto concat_var = std::make_shared<CondVar>(IsConcatNode); | |||
| auto split_prim = std::make_shared<ops::SplitWithOverlap>(); | |||
| return VectorRef({split_prim, concat_var}); | |||
| } | |||
| CNodePtr GetRealPrevCNode(const AnfNodePtr &node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (IsRealCNodeKernel(cnode)) { | |||
| return cnode; | |||
| } | |||
| auto input0 = cnode->input(0); | |||
| if (IsPrimitive(input0, prim::kPrimMakeTuple)) { | |||
| auto temp_node = cnode->input(1); | |||
| if (temp_node == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| return GetRealPrevCNode(temp_node); | |||
| } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { | |||
| return GetRealPrevCNode(cnode->input(1)); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| auto pre_cnode = GetRealPrevCNode(cnode->input(1)); | |||
| CheckIfCNodeIsNull(pre_cnode); | |||
| if (!CheckPrimitiveType(pre_cnode, prim::kPrimConcat)) { | |||
| return; | |||
| } | |||
| auto finder = g_graph_nodes_output.find(pre_cnode->fullname_with_scope()); | |||
| if (finder == g_graph_nodes_output.end()) { | |||
| return; | |||
| } | |||
| if (finder->second.size() > 1) return; | |||
| size_t pre_inputs_size = pre_cnode->inputs().size(); | |||
| int pre_inputs_node_size = pre_inputs_size - 1; | |||
| auto pre_prim = GetValueNode<std::shared_ptr<ops::Concat>>(pre_cnode->input(kAnfPrimitiveIndex)); | |||
| auto prim = GetValueNode<std::shared_ptr<ops::SplitWithOverlap>>(cnode->input(kAnfPrimitiveIndex)); | |||
| if (prim->get_number_split() != pre_inputs_node_size) { | |||
| return; | |||
| } | |||
| // check axis NHWC | |||
| // only support axis "N" now, other axes will support when having "InferShape" | |||
| if (pre_prim->get_axis() != 0) { | |||
| return; | |||
| } | |||
| // get inputs node | |||
| auto it = g_graph_nodes_output.find(cnode->fullname_with_scope()); | |||
| if (it == g_graph_nodes_output.end()) { | |||
| return; | |||
| } | |||
| int out_num = it->second.size(); | |||
| if (out_num != prim->get_number_split()) { | |||
| return; | |||
| } | |||
| std::vector<CNodePtr> inputs_node; | |||
| for (int i = 0; i < out_num; i++) { | |||
| auto tmp = it->second[i]; | |||
| auto tmp_cnode = tmp->cast<CNodePtr>(); | |||
| if (CheckIfCNodeIsNull(tmp_cnode) != lite::RET_OK) { | |||
| return; | |||
| } | |||
| if (!CheckPrimitiveType(tmp_cnode, prim::kPrimTupleGetItem)) { | |||
| return; | |||
| } | |||
| auto tmp_it = g_graph_nodes_output.find(tmp_cnode->fullname_with_scope()); | |||
| if (tmp_it == g_graph_nodes_output.end()) { | |||
| return; | |||
| } | |||
| if (tmp_it->second.size() != 1) return; | |||
| auto next = tmp_it->second[0]; | |||
| auto next_cnode = next->cast<CNodePtr>(); | |||
| inputs_node.push_back(next_cnode); | |||
| } | |||
| // replace inputs | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return; | |||
| } | |||
| for (size_t i = 1; i < pre_inputs_size; i++) { | |||
| (void)manager->Replace((inputs_node[i - 1])->input(1), pre_cnode->input(i)); | |||
| } | |||
| } | |||
| const AnfNodePtr EliminateConcatSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_LOG(DEBUG) << "Enter EliminateConcatSplit pass process"; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| if (CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto split_cnode = node->cast<CNodePtr>(); | |||
| if (CheckIfCNodeIsNull(split_cnode) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| ConcatSplitEliminate(func_graph, split_cnode); | |||
| return node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "tools/optimizer/fisson/fisson_util.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class EliminateConcatSplit : public PatternProcessPass { | |||
| public: | |||
| explicit EliminateConcatSplit(bool multigraph = true) : PatternProcessPass("eliminate_concat_split", multigraph) {} | |||
| ~EliminateConcatSplit() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ELIMINATE_CONCAT_SPLIT_H_ | |||
| @@ -62,7 +62,6 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c | |||
| return nullptr; | |||
| } | |||
| auto conv_node = pre_node->cast<CNodePtr>(); | |||
| MS_ASSERT(primitive_c); | |||
| if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) || | |||
| CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) { | |||
| auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0)); | |||
| @@ -63,7 +63,6 @@ const AnfNodePtr PoolingActivationFusion::Process(const FuncGraphPtr &func_graph | |||
| } | |||
| auto pooling_node = pre_node->cast<CNodePtr>(); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(pooling_node->input(0)); | |||
| MS_ASSERT(primitive_c); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c); | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/parallel/dynamic_creator.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| // operator register | |||
// Map a primitive name to its distributed OperatorInfo class name:
// strip a single leading '_' (if present) and append the "Info" suffix.
std::string GetDisOpName(const std::string &prim_name) {
  const bool has_leading_underscore = !prim_name.empty() && prim_name.front() == '_';
  return (has_leading_underscore ? prim_name.substr(1) : prim_name) + "Info";
}
| // create the OperatorInfo instance | |||
| OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name, | |||
| const SplitStrategy &strategy) { | |||
| if (type_name.length() == 0) { | |||
| MS_LOG(EXCEPTION) << "Length of name is zero!"; | |||
| } | |||
| std::string distribute_opname = GetDisOpName(type_name); | |||
| OperatorInfoPtr operator_ = (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, strategy); | |||
| if (operator_ == nullptr) { | |||
| MS_LOG(INFO) << "Create " << type_name << " failed"; | |||
| return nullptr; | |||
| } | |||
| std::string origin_name = operator_->name(); | |||
| operator_->set_name(orig_name); | |||
| MS_LOG(INFO) << "Successfully created operator " << origin_name; | |||
| return operator_; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "tools/optimizer/parallel/operator_info.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| #define REGISTER(className) \ | |||
| OperatorInfoPtr objectCreator##className(std::string name, SplitStrategy strategy) { \ | |||
| return std::make_shared<className>(name, strategy); \ | |||
| } \ | |||
| RegisterAction className##Register(#className, (CreatFn)objectCreator##className); | |||
| typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const SplitStrategy &strategy); | |||
| class DynCreator { | |||
| public: | |||
| ~DynCreator() = default; | |||
| // create static singleton dyn_creator instance | |||
| static DynCreator &Instance() { | |||
| static DynCreator fac = DynCreator(); | |||
| return fac; | |||
| } | |||
| // register | |||
| void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } | |||
| // creator | |||
| OperatorInfoPtr Create(const std::string &name, const SplitStrategy &strategy) { | |||
| auto iter = Function_map_.find(name); | |||
| if (iter == Function_map_.end()) { | |||
| MS_LOG(INFO) << name << " is not register yet"; | |||
| return nullptr; | |||
| } | |||
| return iter->second(name, strategy); | |||
| } | |||
| private: | |||
| DynCreator() = default; | |||
| std::map<std::string, CreatFn> Function_map_; | |||
| }; | |||
| class RegisterAction { | |||
| public: | |||
| RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { | |||
| DynCreator::Instance().Register(name, creatfn); | |||
| } | |||
| ~RegisterAction() = default; | |||
| private: | |||
| std::string name_; | |||
| }; | |||
| OperatorInfoPtr OperatorInstance(const std::string &type_name, const std::string &orig_name, | |||
| const SplitStrategy &strategy); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_DYNAMIC_CREATOR_H_ | |||
| @@ -0,0 +1,213 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/parallel/operator_info.h" | |||
| #include <algorithm> | |||
| #include "tools/converter/ops/ops_def.h" | |||
| #include "tools/optimizer/parallel/split_strategy.h" | |||
| #include "mindspore/core/ops/concat.h" | |||
| #include "mindspore/core/ops/addn.h" | |||
| #include "mindspore/core/ops/split.h" | |||
| #include "include/lite_types.h" | |||
| #include "mindspore/ccsrc/utils/utils.h" | |||
| #include "base/core_ops.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool is_any_none(const std::vector<int64_t> &split) { | |||
| return std::any_of(split.begin(), split.end(), [](int64_t v) { return v == static_cast<int64_t>(NoSplit); }); | |||
| } | |||
| bool is_any_not_none(const std::vector<int64_t> &split) { | |||
| return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); }); | |||
| } | |||
| lite::STATUS OperatorInfo::SetCNodeBackend() { | |||
| for (size_t i = 0; i < strategy_.dev_num; ++i) { | |||
| lite::DeviceType dt_type; | |||
| std::string type = strategy_.dev_types[i]; | |||
| auto cnode = parallel_output_nodes_[i]->cast<CNodePtr>()->input(1)->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (type == "GPU") { | |||
| dt_type = lite::DeviceType::DT_GPU; | |||
| } else if (type == "CPU") { | |||
| dt_type = lite::DeviceType::DT_CPU; | |||
| } else if (type == "NPU") { | |||
| dt_type = lite::DeviceType::DT_NPU; | |||
| } else { | |||
| MS_LOG(ERROR) << "SetCnodeBackend: unknown device type."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(dt_type))); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| lite::STATUS OperatorInfo::CheckStrategyValue() { | |||
| auto strategy_size = strategy_.strategys.size(); | |||
| for (size_t index = 0; index < strategy_size; ++index) { | |||
| auto strategy = strategy_.strategys[index]; | |||
| for (const auto &s : strategy) { | |||
| if (s.size() != IntToSize(strategy_.dev_num)) { | |||
| MS_LOG(ERROR) << "Strategy split number:" << s.size() | |||
| << " is not equal to device number: " << strategy_.dev_num; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (is_any_not_none(s) && is_any_none(s)) { | |||
| MS_LOG(ERROR) << "Strategy split number must be all zero or all non-zero: " << s; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| lite::STATUS OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, | |||
| std::vector<AnfNodePtr> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| AbstractBasePtrList ptr_list; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " : Failed to get CNode."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| auto idx = NewValueNode(SizeToInt(i)); | |||
| MS_ASSERT(idx); | |||
| auto index = std::make_shared<Int32Imm>(SizeToInt(i)); | |||
| auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index); | |||
| idx->set_abstract(abstract_scalar); | |||
| auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<lite::TupleGetItem>()), node, idx}); | |||
| if (tuple_getitem == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " : Failed to create output nodes."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i)); | |||
| outputs->push_back(tuple_getitem); | |||
| ptr_list.push_back(abstract_scalar); | |||
| } | |||
| node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list)); | |||
| return lite::RET_OK; | |||
| } | |||
| AnfNodePtr OperatorInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, | |||
| std::vector<AnfNodePtr> *split_outputs, size_t split_dim, | |||
| size_t split_num, const std::vector<int64_t> &splits, bool trans_format) { | |||
| MS_EXCEPTION_IF_NULL(orig_node); | |||
| auto split_prim = std::make_shared<ops::Split>(); | |||
| split_prim->set_output_num(split_num); | |||
| split_prim->set_size_splits(splits); | |||
| split_prim->set_axis(split_dim); | |||
| auto value_node = NewValueNode(split_prim); | |||
| std::vector<AnfNodePtr> split_inputs = {value_node}; | |||
| split_inputs.push_back(orig_node->input(input_index + 1)); | |||
| auto split_cnode = func_graph_->NewCNode(split_inputs); | |||
| if (split_cnode == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " : Failed to create split node."; | |||
| return nullptr; | |||
| } | |||
| split_cnode->set_fullname_with_scope("Split_" + name_); | |||
| CreateMultipleOutputsOfAnfNode(split_cnode, split_num, split_outputs); | |||
| return split_cnode; | |||
| } | |||
| AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, | |||
| int32_t concat_dim, size_t input_nodes_num, bool trans_format) { | |||
| MS_EXCEPTION_IF_NULL(orig_node); | |||
| if (input_nodes.size() != input_nodes_num) { | |||
| MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number."; | |||
| return nullptr; | |||
| } | |||
| auto concat_prim = std::make_shared<ops::Concat>(); | |||
| concat_prim->set_axis(concat_dim); | |||
| auto value_node = NewValueNode(concat_prim); | |||
| std::vector<AnfNodePtr> concat_inputs = {value_node}; | |||
| (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs), | |||
| [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); }); | |||
| auto concat_cnode = func_graph_->NewCNode(concat_inputs); | |||
| if (concat_cnode == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " : Failed to create concat node."; | |||
| return nullptr; | |||
| } | |||
| concat_cnode->set_fullname_with_scope("Concat_" + name_); | |||
| concat_cnode->set_scope(orig_node->scope()); | |||
| return concat_cnode; | |||
| } | |||
| AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, | |||
| int32_t reduce_dim, size_t input_nodes_num, bool trans_format) { | |||
| MS_EXCEPTION_IF_NULL(orig_node); | |||
| if (input_nodes.size() != input_nodes_num) { | |||
| MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number."; | |||
| return nullptr; | |||
| } | |||
| // addup inputs element-wise | |||
| auto addn_prim = std::make_shared<ops::AddN>(); | |||
| auto value_node = NewValueNode(addn_prim); | |||
| std::vector<AnfNodePtr> addn_inputs = {value_node}; | |||
| (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs), | |||
| [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); }); | |||
| auto addn_cnode = func_graph_->NewCNode(addn_inputs); | |||
| if (addn_cnode == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " : Failed to create concat node."; | |||
| return nullptr; | |||
| } | |||
| addn_cnode->set_fullname_with_scope("AddN_" + name_); | |||
| addn_cnode->set_scope(orig_node->scope()); | |||
| return addn_cnode; | |||
| } | |||
| lite::STATUS OperatorInfo::Init() { | |||
| if (GetAttrs() != lite::RET_OK) { | |||
| MS_LOG(ERROR) << name_ << ": Parse attrs failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (CheckStrategyValue() != lite::RET_OK) { | |||
| MS_LOG(ERROR) << name_ << ": Invalid strategy values."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (CheckStrategy(strategy_) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << name_ << ": Check strategys failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (InferParallelCNodes() != lite::RET_OK) { | |||
| MS_LOG(ERROR) << name_ << ": InferReplaceGraph failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (SetCNodeBackend() != lite::RET_OK) { | |||
| MS_LOG(ERROR) << name_ << ": SetCnodeBackend failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (InferReplaceOp() != lite::RET_OK) { | |||
| MS_LOG(ERROR) << name_ << ": InferForwardOps failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "tools/optimizer/parallel/split_strategy.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| /** | |||
| * Do following steps to make a operator support parallel: | |||
| * | |||
| * 1.Add the schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST; | |||
| * 2.Add a pair of type and string name to ParallelPass::type_string; | |||
| * 3.Implement a class XXXInfo whose parent is OperatorInfo; | |||
| * 3.1.Override CheckStrategy(), InferParallelCNodes() and InferReplaceOp() | |||
| * 4.include header file of XXXInfo in ops_info_head_files.h | |||
| * 5.REGISTER XXXInfo in dynamic_creator.cc | |||
| */ | |||
| using schema::ReduceMode; | |||
| class OperatorInfo; | |||
| using OperatorInfoPtr = std::shared_ptr<OperatorInfo>; | |||
| class OperatorInfo { | |||
| public: | |||
| OperatorInfo(std::string name, SplitStrategy strategy) | |||
| : name_(std::move(name)), | |||
| strategy_(std::move(strategy)), | |||
| replace_op_(nullptr), | |||
| func_graph_(nullptr), | |||
| cnode_(nullptr) {} | |||
| virtual ~OperatorInfo() = default; | |||
| const std::string &name() const { return name_; } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } | |||
| void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } | |||
| void setFmk(const int32_t FmkType) { FmkType_ = FmkType; } | |||
| AnfNodePtr replace_op() { return replace_op_; } | |||
| lite::STATUS Init(); | |||
| protected: | |||
| lite::STATUS CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, | |||
| std::vector<AnfNodePtr> *outputs); | |||
| AnfNodePtr CreateOutputsOfSplit(const CNodePtr &orig_node, size_t input_index, std::vector<AnfNodePtr> *split_outputs, | |||
| size_t split_dim, size_t split_num, const std::vector<int64_t> &splits, | |||
| bool trans_format); | |||
| AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, | |||
| int32_t concat_dim, size_t input_nodes_num, bool trans_format); | |||
| AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim, | |||
| size_t input_nodes_num, bool trans_format); | |||
| virtual lite::STATUS GetAttrs() = 0; | |||
| virtual lite::STATUS InferReplaceOp() = 0; | |||
| virtual lite::STATUS InferParallelCNodes() = 0; | |||
| virtual lite::STATUS CheckStrategy(const SplitStrategy &strategy) = 0; | |||
| std::string name_; | |||
| SplitStrategy strategy_; | |||
| AnfNodePtr replace_op_; | |||
| std::vector<AnfNodePtr> parallel_output_nodes_; | |||
| FuncGraphPtr func_graph_; | |||
| CNodePtr cnode_; | |||
| int32_t FmkType_{}; | |||
| private: | |||
| lite::STATUS SetCNodeBackend(); | |||
| lite::STATUS CheckStrategyValue(); | |||
| }; | |||
| bool is_any_none(const std::vector<int64_t> &split); | |||
| bool is_any_not_none(const std::vector<int64_t> &split); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_ | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/parallel/parallel_pass.h" | |||
| #include "include/errorcode.h" | |||
| #include "ir/tensor.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool ParallelPass::IsParallelCareNode(const AnfNodePtr &node) { | |||
| return std::any_of(PARALLEL_LIST.begin(), PARALLEL_LIST.end(), [this, &node](auto &prim) { | |||
| if (CheckPrimitiveType(node, prim)) { | |||
| type_name_ = PrimToString(prim); | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| }); | |||
| } | |||
| std::string ParallelPass::PrimToString(const PrimitivePtr &prim) { | |||
| if (type_string.find(prim->name()) == type_string.end()) { | |||
| MS_LOG(EXCEPTION) << "String of the type not registered"; | |||
| } | |||
| return type_string.at(prim->name()); | |||
| } | |||
| AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| if (!utils::isa<CNode>(node)) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (CheckIfCNodeIsNull(cnode) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| if (!IsParallelCareNode(node)) { | |||
| return nullptr; | |||
| } | |||
| std::string cnode_name = cnode->fullname_with_scope(); | |||
| std::string name = cnode_name; | |||
| std::string orig_name = cnode_name; | |||
| // find operator name first, then operator type name. | |||
| if (split_strategys_.find(name) == split_strategys_.end()) { | |||
| name = type_name_; | |||
| } | |||
| if (cnode_name.find(PARALLEL_NAME_SUFFIX) != std::string::npos) { | |||
| MS_LOG(DEBUG) << " : Skip splited cnode " << cnode_name; | |||
| return nullptr; | |||
| } | |||
| MS_LOG(DEBUG) << " : Reached a parallel care node: " << cnode_name; | |||
| if (split_strategys_.find(name) == split_strategys_.end()) { | |||
| MS_LOG(DEBUG) << name << " : No split strategy for the current CNode."; | |||
| return nullptr; | |||
| } | |||
| cnode->set_fullname_with_scope(cnode_name + PARALLEL_NAME_SUFFIX); | |||
| OperatorInfoPtr operator_ = OperatorInstance(type_name_, orig_name, split_strategys_[name]); | |||
| if (operator_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure: Create " << name << " OperatorInstance failed"; | |||
| } | |||
| operator_->set_cnode(cnode); | |||
| operator_->set_func_graph(func_graph); | |||
| operator_->setFmk(FmkType_); | |||
| if (operator_->Init() == RET_ERROR) { | |||
| MS_LOG(EXCEPTION) << "Failure: operator " << name << " init failed"; | |||
| } | |||
| return operator_->replace_op(); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <set> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "ir/anf.h" | |||
| #include "tools/optimizer/parallel/dynamic_creator.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "mindspore/ccsrc/backend/optimizer/common/node_pass.h" | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ParallelPass : public opt::NodePass { | |||
| public: | |||
| explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> strategys, const int32_t FmkType) | |||
| : NodePass("parallel_pass"), split_strategys_(strategys), FmkType_(FmkType) {} | |||
| ~ParallelPass() override = default; | |||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; | |||
| private: | |||
| const std::set<PrimitivePtr> PARALLEL_LIST = {prim::kPrimConv2DFusion}; | |||
| const std::unordered_map<std::string, std::string> type_string = {{prim::kPrimConv2DFusion->name(), "Conv2D"}}; | |||
| bool IsParallelCareNode(const AnfNodePtr &node); | |||
| std::string PrimToString(const PrimitivePtr &prim); | |||
| std::string type_name_; | |||
| std::unordered_map<std::string, SplitStrategy> split_strategys_; | |||
| int32_t FmkType_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_PARALLEL_PASS_H_ | |||
| @@ -38,12 +38,6 @@ const std::vector<std::string> kSplitDevTypes = {"CPU", "GPU"}; | |||
| using Strategys = std::vector<std::vector<std::vector<int64_t>>>; | |||
| enum Status { | |||
| SUCCESS = 0, | |||
| FAILED, | |||
| INVALID_ARGUMENT, | |||
| }; | |||
| enum SplitMode { | |||
| SplitN = 0, | |||
| SplitH = 1, | |||