| @@ -19,6 +19,7 @@ file(GLOB GRAPH_PASS | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/select_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/select_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/nested_loop_expand_pass.cc | |||||
| ) | ) | ||||
| set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | ||||
| add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | ||||
| @@ -0,0 +1,98 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| bool NestedLoopExpandPass::IsNestedPartial(const std::unique_ptr<CNodeT> &node) { | |||||
| if (node->primitive->value.type != PrimitiveType_Partial) { | |||||
| return false; | |||||
| } | |||||
| auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex; | |||||
| auto &this_subgraph = graph_->subGraph.at(subgraph_idx); | |||||
| for (auto &node_idx : this_subgraph->nodeIndices) { | |||||
| auto &cnode = graph_->nodes.at(node_idx); | |||||
| if (cnode->primitive->value.type == PrimitiveType_Partial) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void NestedLoopExpandPass::ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph) { | |||||
| bool is_changed = false; | |||||
| for (auto &node_idx : main_graph->nodeIndices) { | |||||
| auto &node = graph_->nodes.at(node_idx); | |||||
| if (!IsNestedPartial(node)) { | |||||
| continue; | |||||
| } | |||||
| is_changed = true; | |||||
| auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex; | |||||
| auto &this_subgraph = graph_->subGraph.at(subgraph_idx); | |||||
| subgraph_to_drop_.push_back(subgraph_idx); | |||||
| auto partial_pos = std::find(main_graph->nodeIndices.begin(), main_graph->nodeIndices.end(), node_idx); | |||||
| std::vector<uint32_t> tmp; | |||||
| tmp.assign(main_graph->nodeIndices.begin(), partial_pos); | |||||
| tmp.insert(tmp.end(), this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end()); | |||||
| tmp.insert(tmp.end(), partial_pos + 1, main_graph->nodeIndices.end()); | |||||
| main_graph->nodeIndices.assign(tmp.begin(), tmp.end()); | |||||
| } | |||||
| if (is_changed) { | |||||
| ReplacePartialNodeWithSubgraph(main_graph); | |||||
| } | |||||
| } | |||||
| STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) { | |||||
| graph_ = graph; | |||||
| auto &main_graph = graph_->subGraph[0]; | |||||
| ReplacePartialNodeWithSubgraph(main_graph); | |||||
| for (auto idx : subgraph_to_drop_) { | |||||
| graph_->subGraph.at(idx) = nullptr; | |||||
| } | |||||
| for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) { | |||||
| if ((*it) == nullptr) { | |||||
| it = graph_->subGraph.erase(it); | |||||
| } else { | |||||
| it++; | |||||
| } | |||||
| } | |||||
| for (auto &node : graph_->nodes) { | |||||
| if (node->primitive->value.type == PrimitiveType_Partial) { | |||||
| ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex -= subgraph_to_drop_.size(); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H | |||||
| #define MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include <set> | |||||
| #include <memory> | |||||
| #include "tools/converter/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class NestedLoopExpandPass : public GraphPass { | |||||
| public: | |||||
| NestedLoopExpandPass() = default; | |||||
| ~NestedLoopExpandPass() override = default; | |||||
| STATUS Run(schema::MetaGraphT *graph) override; | |||||
| private: | |||||
| bool IsNestedPartial(const std::unique_ptr<CNodeT> &node); | |||||
| void ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph); | |||||
| schema::MetaGraphT *graph_ = nullptr; | |||||
| std::vector<int> subgraph_to_drop_{}; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H | |||||
| @@ -35,7 +35,7 @@ STATUS TensorNamePass::Run(schema::MetaGraphT *graph) { | |||||
| auto tensor_id = node->inputIndex.at(i); | auto tensor_id = node->inputIndex.at(i); | ||||
| auto &tensor = graph->allTensors.at(tensor_id); | auto &tensor = graph->allTensors.at(tensor_id); | ||||
| if (tensor->name.empty()) { | if (tensor->name.empty()) { | ||||
| MS_LOG(WARNING) << "input tensor (id = " << tensor_id << ") name is null"; | |||||
| MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null"; | |||||
| tensor->name = node->name + "/input-" + std::to_string(i); | tensor->name = node->name + "/input-" + std::to_string(i); | ||||
| } | } | ||||
| } | } | ||||
| @@ -57,27 +57,27 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) { | |||||
| auto conv2d_cnode = node->cast<CNodePtr>(); | auto conv2d_cnode = node->cast<CNodePtr>(); | ||||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv2d_cnode->input(0)); | auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv2d_cnode->input(0)); | ||||
| if (primitive_c == nullptr) { | if (primitive_c == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv2D node has no primitiveC."; | |||||
| MS_LOG(DEBUG) << "Conv2D node has no primitiveC."; | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto primT = primitive_c->primitiveT(); | auto primT = primitive_c->primitiveT(); | ||||
| if (primT == nullptr) { | if (primT == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv2D node has no primitiveT."; | |||||
| MS_LOG(DEBUG) << "Conv2D node has no primitiveT."; | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto conv2d_primt = primT->value.AsConv2D(); | auto conv2d_primt = primT->value.AsConv2D(); | ||||
| auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo); | auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo); | ||||
| if (weight_node == nullptr) { | if (weight_node == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv2D weight node is nullptr."; | |||||
| MS_LOG(DEBUG) << "Conv2D weight node is nullptr."; | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (!weight_node->isa<Parameter>()) { | if (!weight_node->isa<Parameter>()) { | ||||
| MS_LOG(ERROR) << "Conv2D weight node is not parameter."; | |||||
| MS_LOG(DEBUG) << "Conv2D weight node is not parameter."; | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto weight_param = weight_node->cast<ParameterPtr>(); | auto weight_param = weight_node->cast<ParameterPtr>(); | ||||
| if (!weight_param->has_default()) { | if (!weight_param->has_default()) { | ||||
| MS_LOG(ERROR) << "Conv2D weight node is not parameter."; | |||||
| MS_LOG(DEBUG) << "Conv2D weight node is not parameter."; | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto default_param = weight_param->default_param(); | auto default_param = weight_param->default_param(); | ||||
| @@ -44,29 +44,11 @@ ValueNodePtr WhilePass::GetSwitchAnfPrim() { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto partial_prim = std::make_shared<lite::Partial>(switch_primitiveT); | |||||
| auto partial_prim = std::make_shared<lite::Switch>(switch_primitiveT); | |||||
| ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); | ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); | ||||
| return partial_anf_prim; | return partial_anf_prim; | ||||
| } | } | ||||
| void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, | |||||
| std::string para_name) { | |||||
| for (auto &node : node_list) { | |||||
| if (utils::isa<CNodePtr>(node)) { | |||||
| auto cnode = utils::cast<CNodePtr>(node); | |||||
| for (size_t k = 0; k < cnode->inputs().size(); k++) { | |||||
| if (!utils::isa<ParameterPtr>(cnode->input(k))) { | |||||
| continue; | |||||
| } | |||||
| auto para_input = utils::cast<ParameterPtr>(cnode->input(k)); | |||||
| if (para_input->name() == para_name) { | |||||
| cnode->set_input(k, new_input_cnode); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| bool WhilePass::Run(const FuncGraphPtr &graph) { | bool WhilePass::Run(const FuncGraphPtr &graph) { | ||||
| auto node_list = TopoSort(graph->get_return()); | auto node_list = TopoSort(graph->get_return()); | ||||
| static int count = 0; | static int count = 0; | ||||
| @@ -87,34 +69,23 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { | |||||
| // the order is fixed. | // the order is fixed. | ||||
| auto cond_vnode = while_cnode->input(kWhileCondIndex); | auto cond_vnode = while_cnode->input(kWhileCondIndex); | ||||
| auto body_vnode = while_cnode->input(kWhileBodyIndex); | auto body_vnode = while_cnode->input(kWhileBodyIndex); | ||||
| // body_vnode->cast<ValueNodePtr>()->set_value() | |||||
| auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode); | auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode); | ||||
| auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode); | auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode); | ||||
| if (cond_fg == nullptr || body_fg == nullptr) { | if (cond_fg == nullptr || body_fg == nullptr) { | ||||
| MS_LOG(ERROR) << "Get value as func_graph failed."; | MS_LOG(ERROR) << "Get value as func_graph failed."; | ||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED); | lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED); | ||||
| return false; | return false; | ||||
| } | } | ||||
| // create cond partial cnode | |||||
| std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode}; | std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode}; | ||||
| // create body partial cnode | |||||
| std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode}; | std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode}; | ||||
| // add while op input to cond_cnode and body_cnode | |||||
| cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | ||||
| while_cnode->inputs().end()); | while_cnode->inputs().end()); | ||||
| body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | ||||
| while_cnode->inputs().end()); | while_cnode->inputs().end()); | ||||
| static int idx = 0; | static int idx = 0; | ||||
| auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs); | auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs); | ||||
| cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx)); | cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx)); | ||||
| cond_partial_node->set_abstract(cond_fg->output()->abstract()); | cond_partial_node->set_abstract(cond_fg->output()->abstract()); | ||||
| auto body_partial_node = graph->NewCNode(body_partial_op_inputs); | auto body_partial_node = graph->NewCNode(body_partial_op_inputs); | ||||
| body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx)); | body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx)); | ||||
| idx++; | idx++; | ||||
| @@ -166,7 +137,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| abstract_list.push_back(cnode->abstract()); | abstract_list.push_back(cnode->abstract()); | ||||
| } | } | ||||
| switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | ||||
| // create cond partial cnode | // create cond partial cnode | ||||
| @@ -176,7 +146,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) { | |||||
| manager->SetEdge(node_user.first, node_user.second, switch_cnode); | manager->SetEdge(node_user.first, node_user.second, switch_cnode); | ||||
| } | } | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace mindspore::opt | } // namespace mindspore::opt | ||||
| @@ -32,7 +32,6 @@ class WhilePass : public Pass { | |||||
| bool Run(const FuncGraphPtr &graph) override; | bool Run(const FuncGraphPtr &graph) override; | ||||
| private: | private: | ||||
| void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name); | |||||
| ValueNodePtr GetSwitchAnfPrim(); | ValueNodePtr GetSwitchAnfPrim(); | ||||
| const size_t kWhileMinInputSize = 3; | const size_t kWhileMinInputSize = 3; | ||||