/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/graph/functionalize_cond.h"
#include <algorithm>
#include <deque>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "include/errorcode.h"
#include "tools/converter/ops/ops_def.h"

namespace mindspore::opt {
// Determines which branch of the cond a Switch node feeds.
// The Switch must have exactly one user, and that user must be a TupleGetItem CNode whose
// output index selects the branch: 0 -> kElseBranch, 1 -> kThenBranch.
// @param switch_cnode Switch node to inspect (must be non-null).
// @param branch_type  [out] receives the detected branch on success (must be non-null).
// @return RET_OK on success; RET_ERROR if fg_ has no manager or the user pattern is unexpected.
STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) {
  MS_ASSERT(switch_cnode != nullptr);
  MS_ASSERT(branch_type != nullptr);
  auto manager = fg_->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return RET_ERROR;
  }
  auto node_users = manager->node_users()[switch_cnode];
  if (node_users.size() != 1) {  // only one output of switch is referenced in cond
    MS_LOG(ERROR) << "switch's node users is not correct";
    return RET_ERROR;
  }
  auto node_user = node_users.front();
  auto tuple_get_item = node_user.first;
  if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
    MS_LOG(ERROR) << "switch's node user is not TupleGetItem";
    return RET_ERROR;
  }
  auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item);
  auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode);
  if (idx == 0) {
    *branch_type = kElseBranch;
  } else if (idx == 1) {
    *branch_type = kThenBranch;
  } else {
    MS_LOG(ERROR) << "wrong tuple_get_item index";
    return RET_ERROR;
  }
  return RET_OK;
}

// Moves every node reachable from root_node into graph.
// The walk is breadth-first over CNode inputs (skipping input 0, the operator slot) and
// stops at Switch nodes. Switch nodes are not moved; each one reached must feed
// branch_type, otherwise the branch is rejected. Parameters are registered via
// add_parameter, every other node via AddNode.
// @return RET_OK on success; RET_ERROR if a Switch on the boundary belongs to another branch.
STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
                                                 BranchType branch_type) {
  std::deque<AnfNodePtr> q;
  std::unordered_set<AnfNodePtr> vis;
  q.push_back(root_node);
  while (!q.empty()) {
    auto node = q.front();
    q.pop_front();
    // A node reachable through several paths (e.g. a diamond) can be queued more than once
    // before it is popped; process it only the first time so a Parameter is not added twice.
    if (!vis.insert(node).second) {
      continue;
    }
    if (FunctionalizeControlOpPass::IsSwitch(node)) {
      auto cnode = utils::cast<CNodePtr>(node);
      BranchType this_type;
      if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
        MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
        return RET_ERROR;
      }
      continue;  // the Switch is the branch boundary; do not walk past it
    }
    if (utils::isa<ParameterPtr>(node)) {
      graph->add_parameter(node->cast<ParameterPtr>());
    } else {
      graph->AddNode(node);
    }
    node->set_func_graph(graph);
    if (utils::isa<CNodePtr>(node)) {
      auto cnode = utils::cast<CNodePtr>(node);
      for (size_t i = 1; i < cnode->inputs().size(); i++) {
        auto input = cnode->input(i);
        if (vis.find(input) == vis.end()) {
          q.push_back(input);
        }
      }
    }
  }
  return RET_OK;
}

// Returns the index of node in input_nodes_, appending it first if it is not present yet.
int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) {
  auto iter = std::find(input_nodes_.begin(), input_nodes_.end(), node);
  if (iter == input_nodes_.end()) {
    input_nodes_.push_back(node);
    return static_cast<int>(input_nodes_.size()) - 1;
  }
  return static_cast<int>(std::distance(input_nodes_.begin(), iter));
}

// Turns the Switch outputs consumed inside a branch subgraph into subgraph parameters.
// For each CNode in graph that takes a Switch as input, this function does three things:
//   - records the Switch's input 1 in input_nodes_ (via PosInInputNodes);
//   - records the Switch's input 2 in pred_nodes_;
//   - adds a new parameter named "<graph_name>_input_<pos>_parameter" to graph, and rewires
//     that CNode's users inside graph to read the new parameter.
// @return RET_OK on success; RET_ERROR if fg_ has no manager.
STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) {
  auto manager = fg_->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return RET_ERROR;
  }
  for (auto &cnode : graph->GetOrderedCnodes()) {
    for (auto &input_node : cnode->inputs()) {
      if (!FunctionalizeControlOpPass::IsSwitch(input_node)) {
        continue;
      }
      auto switch_node = input_node->cast<CNodePtr>();
      // NOTE(review): this assumes the Switch's data input is a CNode; confirm how a
      // non-CNode data input (e.g. a graph Parameter) is expected to be handled here.
      auto switch_input = utils::cast<CNodePtr>(switch_node->input(1));
      auto pos = PosInInputNodes(switch_input);
      pred_nodes_.push_back(switch_node->input(2));
      // set parameter
      auto parameter = graph->add_parameter();
      parameter->set_abstract(cnode->abstract());
      // hardcode for subgraph input name
      parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter");

      // Replace the Switch consumer with the parameter for users inside this subgraph.
      // node_users is copied on purpose: SetEdge may update the manager's user map while we iterate.
      auto node_users = manager->node_users()[cnode];
      for (auto &node_user : node_users) {
        if (graph->nodes().contains(node_user.first)) {
          manager->SetEdge(node_user.first, node_user.second, parameter);
        }
      }
    }
  }
  return RET_OK;
}

// Builds one branch subgraph named name whose output is node.
// Nodes reachable from node (up to the Switch boundary) are moved into the new graph. When
// node is itself a Switch, the branch is empty and no return node is created.
// @return the new graph, or nullptr on failure.
FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
  auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, mindspore::lite::converter::FmkType_TF);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "new graph Partial Node return nullptr";
    return nullptr;
  }
  graph->set_manager(fg_->manager());
  auto status = BranchSubGraphAddNodes(graph, node, branch_type);
  if (status != RET_OK) {
    return nullptr;
  }

  if (!CheckPrimitiveType(node, prim::kPrimSwitch)) {  // graph is not empty
    auto return_prim_ptr = std::make_shared<lite::Return>();
    if (return_prim_ptr == nullptr) {
      MS_LOG(ERROR) << "GetReturnPrim return nullptr";
      return nullptr;
    }
    auto value_node = NewValueNode(return_prim_ptr);
    std::vector<AnfNodePtr> op_inputs{value_node, node};  // If subgraph only has one output tensor
    auto return_cnode = graph->NewCNode(op_inputs);
    return_cnode->set_fullname_with_scope(name + "-return");
    return_cnode->set_func_graph(graph);
    graph->set_return(return_cnode);
    // The branch output need not be a CNode (e.g. a Parameter); only rename it when it is,
    // instead of dereferencing a null cast result.
    auto output = graph->output();
    if (output != nullptr && utils::isa<CNodePtr>(output)) {
      utils::cast<CNodePtr>(output)->set_fullname_with_scope(name + "_output_0_cnode");
    }
  }
  return graph;
}

// Creates the If CNode in fg_.
// The inputs are, in order: If primitive, then graph, else graph, pred_node_, and then every
// node collected in input_nodes_.
CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) {
  MS_ASSERT(else_branch != nullptr);
  MS_ASSERT(then_branch != nullptr);

  auto if_primc = std::make_shared<mindspore::lite::If>();
  if (if_primc == nullptr) {
    MS_LOG(ERROR) << "new if_primitive failed";
    return nullptr;
  }
  auto if_value_node = NewValueNode(if_primc);
  if (if_value_node == nullptr) {
    return nullptr;
  }
  auto then_value_node = NewValueNode(then_branch);
  auto else_value_node = NewValueNode(else_branch);
  std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_};
  std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs));
  return fg_->NewCNode(if_op_inputs);
}

// Checks that every collected Switch predicate is the same CNode and stores it in pred_node_.
// @return RET_OK on success; RET_ERROR if no predicate was collected, the predicates differ,
//         or the predicate is not a CNode.
STATUS FunctionalizeCond::VerifyPredictNode() {
  if (pred_nodes_.empty()) {
    MS_LOG(ERROR) << "no predicate node collected for cond";
    return RET_ERROR;
  }
  for (size_t i = 1; i < pred_nodes_.size(); ++i) {
    if (pred_nodes_[i] != pred_nodes_[0]) {
      MS_LOG(ERROR) << "switch nodes of one cond use different predicate nodes";
      return RET_ERROR;
    }
  }
  if (!utils::isa<CNodePtr>(pred_nodes_[0])) {
    MS_LOG(ERROR) << "predicate node is not a CNode";
    return RET_ERROR;
  }
  pred_node_ = utils::cast<CNodePtr>(pred_nodes_[0]);
  return RET_OK;
}

// Functionalizes the cond that ends at merge_node_.
// Steps:
//   1. Build the else subgraph from merge input 1 and the then subgraph from merge input 2.
//   2. Turn their Switch inputs into subgraph parameters.
//   3. Verify that both branches share one predicate.
//   4. Create an If node and redirect all users of merge_node_ to it.
// Preconditions (fg_, manager, merge_node_ and its abstract) are checked before the graph is
// modified.
// @return RET_OK on success, or an error status from the failing step.
STATUS FunctionalizeCond::Process() {
  if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->inputs().size() != 3) {
    MS_LOG(ERROR) << "fg or merge is not correct";
    return RET_ERROR;
  }
  auto manager = fg_->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return RET_ERROR;
  }
  // The If node inherits the merge node's abstract; check it before mutating the graph.
  auto merge_abstract = merge_node_->abstract();
  if (merge_abstract == nullptr) {
    MS_LOG(ERROR) << "merge node's abstract is nullptr";
    return RET_ERROR;
  }

  auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
  auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";

  auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
  if (else_branch == nullptr) {
    MS_LOG(ERROR) << "create else branch failed";
    return RET_ERROR;
  }
  auto then_branch = CreateBranchGraph(merge_node_->input(2), then_branch_name, kThenBranch);
  if (then_branch == nullptr) {
    MS_LOG(ERROR) << "create then branch failed";
    return RET_ERROR;
  }

  auto status = IdentifySubgraphInput(else_branch, else_branch_name);
  if (status != RET_OK) {
    return status;
  }
  status = IdentifySubgraphInput(then_branch, then_branch_name);
  if (status != RET_OK) {
    return status;
  }
  status = VerifyPredictNode();
  if (status != RET_OK) {
    return status;
  }

  auto if_node = CreateNewIf(else_branch, then_branch);
  if (if_node == nullptr) {
    MS_LOG(ERROR) << "create if node error";
    return RET_ERROR;
  }
  if_node->set_abstract(merge_abstract->Clone());
  // node_users is copied on purpose: SetEdge may update the manager's user map while we iterate.
  auto node_users = manager->node_users()[merge_node_];
  for (auto &node_user : node_users) {
    manager->SetEdge(node_user.first, node_user.second, if_node);
  }
  return RET_OK;
}
}  // namespace mindspore::opt