/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 *conv_activation_fusion.h * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include "tools/optimizer/graph/functionalize_while.h" #include "include/errorcode.h" #include "ops/make_tuple.h" #include "ops/return.h" #include "ops/tuple_get_item.h" #include "tools/converter/ops/while.h" namespace { mindspore::ValueNodePtr GetWhileAnfPrim() { auto while_primc = std::make_shared(); if (while_primc == nullptr) { MS_LOG(ERROR) << "new while_primitive failed"; return nullptr; } while_primc->set_cond_subgraph_index(mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex()); while_primc->set_body_subgraph_index(mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex()); mindspore::ValueNodePtr partial_anf_prim = NewValueNode(while_primc); return partial_anf_prim; } } // namespace namespace mindspore::opt { using mindspore::lite::RET_NULL_PTR; CNodePtr FunctionalizeWhile::BlongToWhichSwitch(const CNodePtr &node) { return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsSwitch); } CNodePtr FunctionalizeWhile::BlongToWhichMerge(const CNodePtr &node) { return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsMerge); } CNodePtr FunctionalizeWhile::BlongToWhichEnter(const CNodePtr &node) { return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsEnter); } CNodePtr FunctionalizeWhile::BlongToWhichExternalEnter(const CNodePtr &node) { if (node == nullptr) { return nullptr; } if (FunctionalizeControlOpPass::IsEnter(node)) { return node; } CNodePtr aim_node = nullptr; std::deque todo(256); todo.clear(); for (auto &input_node : node->inputs()) { if (FunctionalizeControlOpPass::IsEnter(input_node) && WhileNodeExternalInputIsContain(input_node)) { aim_node = utils::cast(input_node); todo.clear(); break; } todo.push_back(input_node); } while (!todo.empty()) { AnfNodePtr todo_node = todo.front(); todo.pop_front(); if (FunctionalizeControlOpPass::IsEnter(todo_node) && WhileNodeExternalInputIsContain(todo_node)) { aim_node = utils::cast(todo_node); todo.clear(); break; } if (utils::isa(todo_node)) { auto cnode = utils::cast(todo_node); for (size_t i = 0; i < cnode->inputs().size(); i++) { todo.push_back(cnode->input(i)); } } } if (aim_node == nullptr) { MS_LOG(WARNING) << "not found belonging enter node."; return nullptr; } return aim_node; } int FunctionalizeWhile::PosInInputEnterNodes(const CNodePtr &node) { auto index = std::find(input_enter_nodes_.begin(), input_enter_nodes_.end(), node); if (index == input_enter_nodes_.end()) { return POS_INVALID; } return index - input_enter_nodes_.begin(); } STATUS FunctionalizeWhile::NewWhileNode() { ValueNodePtr while_anf_primitive = GetWhileAnfPrim(); if (while_anf_primitive == nullptr) { MS_LOG(ERROR) << "Get while anf primitive failed."; return RET_NULL_PTR; } static int count = 0; std::vector while_op_inputs = {while_anf_primitive}; while_node_ = fg_->NewCNode(while_op_inputs); while_node_->set_fullname_with_scope(loop_cond_node_->fullname_with_scope() + "-while-" + std::to_string(count++)); return RET_OK; } STATUS FunctionalizeWhile::IdentifyWhileNodeInput() { for (auto &node : node_cluster_) { if (FunctionalizeControlOpPass::IsEnter(node)) { auto enter_cnode = node->cast(); input_enter_nodes_.push_back(enter_cnode); while_node_->add_input(enter_cnode->input(1)); } } if (input_enter_nodes_.empty()) { MS_LOG(ERROR) << "not found input of while node."; return RET_ERROR; } return RET_OK; } STATUS FunctionalizeWhile::IdentifyWhileNodeExternalInput() { std::deque todo(128); std::vector merge_nodes{}; todo.clear(); for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) { todo.push_back(loop_cond_node_->input(i)); } while (!todo.empty()) { AnfNodePtr node = todo.front(); todo.pop_front(); if (FunctionalizeControlOpPass::IsMerge(node)) { merge_nodes.push_back(node->cast()); continue; } if (utils::isa(node)) { auto cnode = utils::cast(node); for (size_t i = 1; i < cnode->inputs().size(); i++) { todo.push_back(cnode->input(i)); } } } for (auto &node : merge_nodes) { external_input_enter_nodes_.push_back(node->input(1)->cast()); } return RET_OK; } bool FunctionalizeWhile::WhileNodeExternalInputIsContain(const AnfNodePtr &node) { auto cnode = node->cast(); return std::find(external_input_enter_nodes_.begin(), external_input_enter_nodes_.end(), cnode) != external_input_enter_nodes_.end(); } STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() { output_exit_nodes_.resize(external_input_enter_nodes_.size()); for (auto &node : node_cluster_) { // exit ->switch->merge->enter if (FunctionalizeControlOpPass::IsExit(node)) { auto exit_node = node->cast(); auto switch_node = BlongToWhichSwitch(exit_node); auto merge_node = BlongToWhichMerge(switch_node); auto enter_node = BlongToWhichExternalEnter(merge_node); int pos = PosInInputEnterNodes(enter_node); if (pos == -1) { MS_LOG(ERROR) << "not find in input enter nodes."; return RET_ERROR; } output_exit_nodes_.at(pos) = exit_node; } } if (output_exit_nodes_.size() == 1) { while_node_->set_abstract(output_exit_nodes_[0]->abstract()); } else { AbstractBasePtrList abstract_list; abstract_list.resize(output_exit_nodes_.size()); std::transform(output_exit_nodes_.begin(), output_exit_nodes_.end(), abstract_list.begin(), [](const CNodePtr &cnode) { return cnode->abstract(); }); while_node_->set_abstract(std::make_shared(abstract_list)); } return RET_OK; } STATUS FunctionalizeWhile::UpdateExitNodeUser() { auto manager = fg_->manager(); if (output_exit_nodes_.size() == 1) { if (!manager->Replace(output_exit_nodes_[0], while_node_)) { MS_LOG(ERROR) << "replace node failed."; return RET_ERROR; } } else { for (auto &node : output_exit_nodes_) { auto node_users = manager->node_users()[node]; for (auto &node_user : node_users) { // new getitem AbstractBasePtrList abstractList; std::vector shape_vector; abstractList.emplace_back(std::make_shared(kFloat32, shape_vector)); auto tuple_get_item_prim_ptr = std::make_shared(); if (tuple_get_item_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; return RET_NULL_PTR; } auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); const auto &exit_node = node; auto switch_node = BlongToWhichSwitch(exit_node); auto merge_node = BlongToWhichMerge(switch_node); auto enter_node = BlongToWhichExternalEnter(merge_node); int output_idx = PosInInputEnterNodes(enter_node); auto getItemValue = NewValueNode(MakeValue(output_idx)); std::vector inputs{tuple_get_item_prim, while_node_, getItemValue}; CNodePtr get_item_node = fg_->NewCNode(inputs); std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); auto abstract = std::make_shared(kFloat32, shape_vector); if (abstract == nullptr) { MS_LOG(ERROR) << "create AbstractTensor failed"; return RET_NULL_PTR; } get_item_node->set_abstract(abstract); get_item_node->set_fullname_with_scope(output_item_name); // set if (fg_->nodes().contains(node_user.first)) { if (!manager->Replace(output_exit_nodes_[0], while_node_)) { MS_LOG(ERROR) << "replace node failed."; return RET_ERROR; } } } } } return RET_OK; } STATUS FunctionalizeWhile::BuildWhileNode() { int ret = NewWhileNode(); if (ret != RET_OK) { MS_LOG(ERROR) << "new while node failed, ret:" << ret; return ret; } ret = IdentifyWhileNodeInput(); if (ret != RET_OK) { MS_LOG(ERROR) << "identify while node input failed, ret:" << ret; return ret; } ret = IdentifyWhileNodeExternalInput(); if (ret != RET_OK) { MS_LOG(ERROR) << "identify while node external input failed, ret:" << ret; return ret; } ret = IdentifyWhileNodeOutput(); if (ret != RET_OK) { MS_LOG(ERROR) << "identify while node output failed, ret:" << ret; return ret; } // update exit node user from exit to while ret = UpdateExitNodeUser(); if (ret != RET_OK) { MS_LOG(ERROR) << "update while node users, ret:" << ret; return ret; } return ret; } // nodes between loop_cond op and merge op be added into cond_func_graph STATUS FunctionalizeWhile::CondSubgraphAddNodes() { std::deque todo(512); todo.clear(); for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) { todo.push_back(loop_cond_node_->input(i)); } while (!todo.empty()) { AnfNodePtr node = todo.back(); todo.pop_back(); if (FunctionalizeControlOpPass::IsMerge(node)) { continue; } if (utils::isa(node)) { cond_sub_func_graph_->add_parameter(node->cast()); } else { cond_sub_func_graph_->AddNode(node); } node->set_func_graph(cond_sub_func_graph_); if (utils::isa(node)) { auto cnode = utils::cast(node); for (size_t i = 1; i < cnode->inputs().size(); i++) { todo.push_back(cnode->input(i)); } } } return RET_OK; } STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() { std::vector nodes_need_drop{}; for (auto &cnode : cond_sub_func_graph_->GetOrderedCnodes()) { for (auto &input_node : cnode->inputs()) { if (FunctionalizeControlOpPass::IsMerge(input_node)) { auto merge_node = input_node->cast(); auto enter_node = BlongToWhichEnter(merge_node); int pos = PosInInputEnterNodes(enter_node); nodes_need_drop.push_back(cnode); // set parameter auto parameter = cond_sub_func_graph_->add_parameter(); parameter->set_abstract(cnode->abstract()); // hardcode for subgraph input name parameter->set_name(cond_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter"); // replace merge auto manager = fg_->manager(); auto node_users = manager->node_users()[cnode]; for (auto &node_user : node_users) { if (cond_sub_func_graph_->nodes().contains(node_user.first)) { manager->SetEdge(node_user.first, node_user.second, parameter); } } } } } // drop node from cond_func_graph for (const auto &node : nodes_need_drop) { cond_sub_func_graph_->DropNode(node); } return RET_OK; } STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() { auto return_prim_ptr = std::make_shared(); if (return_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetReturnPrim return nullptr"; return RET_NULL_PTR; } auto value_node = NewValueNode(return_prim_ptr); if (value_node == nullptr) { MS_LOG(ERROR) << "new value_node failed."; return RET_NULL_PTR; } // cond subgraph output is LoopCond's input std::vector op_inputs{value_node, loop_cond_node_->input(1)}; auto return_cnode = cond_sub_func_graph_->NewCNode(op_inputs); return_cnode->set_fullname_with_scope(cond_subgraph_name_ + "-return"); cond_sub_func_graph_->set_return(return_cnode); // hardcode subgraph outputs name cond_sub_func_graph_->output()->cast()->set_fullname_with_scope(cond_subgraph_name_ + "_output_0_cnode"); return RET_OK; } STATUS FunctionalizeWhile::BuildCondGraph() { cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond"; cond_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, mindspore::lite::converter::FmkType_TF); if (cond_sub_func_graph_ == nullptr) { MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr"; return RET_NULL_PTR; } cond_sub_func_graph_->set_manager(fg_->manager()); int ret = CondSubgraphAddNodes(); if (ret != RET_OK) { MS_LOG(ERROR) << "add cond_subgraph node failed, ret:" << ret; return ret; } ret = IdentifyCondSubgraphOutput(); if (ret != RET_OK) { MS_LOG(ERROR) << "identify cond_subgraph output failed, ret:" << ret; return ret; } ret = IdentifyCondSubgraphInput(); if (ret != RET_OK) { MS_LOG(ERROR) << "identify cond_subgraph input failed, ret:" << ret; return ret; } return ret; } // nodes between next_iteration op and switch op will be added into body_func_graph STATUS FunctionalizeWhile::BodySubgraphAddNodes() { std::deque todo(512); todo.clear(); for (auto &node : node_cluster_) { if (FunctionalizeControlOpPass::IsNextIteration(node)) { auto next_iteration_cnode = node->cast(); for (size_t i = 1; i < next_iteration_cnode->inputs().size(); i++) { todo.push_back(next_iteration_cnode->input(i)); } body_subgraph_output_map_[node] = next_iteration_cnode->input(1); } } while (!todo.empty()) { AnfNodePtr node = todo.back(); todo.pop_back(); if (FunctionalizeControlOpPass::IsSwitch(node)) { continue; } if (utils::isa(node)) { body_sub_func_graph_->add_parameter(node->cast()); } else { body_sub_func_graph_->AddNode(node); } node->set_func_graph(body_sub_func_graph_); if (utils::isa(node)) { auto cnode = utils::cast(node); for (size_t i = 1; i < cnode->inputs().size(); i++) { todo.push_back(cnode->input(i)); } } } return RET_OK; } STATUS FunctionalizeWhile::IdentifyBodySubgraphInput() { std::vector nodes_need_drop{}; for (auto &cnode : body_sub_func_graph_->GetOrderedCnodes()) { for (auto &input_node : cnode->inputs()) { if (FunctionalizeControlOpPass::IsSwitch(input_node)) { auto switch_node = input_node->cast(); auto merge_node = BlongToWhichMerge(switch_node); auto enter_node = BlongToWhichEnter(merge_node); int pos = PosInInputEnterNodes(enter_node); if (pos == POS_INVALID) { continue; } nodes_need_drop.push_back(cnode); // set parameter auto parameter = body_sub_func_graph_->add_parameter(); parameter->set_abstract(cnode->abstract()); // hardcode for subgraph input name parameter->set_name(body_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter"); // replace switch auto manager = fg_->manager(); auto node_users = manager->node_users()[cnode]; for (auto &node_user : node_users) { if (body_sub_func_graph_->nodes().contains(node_user.first)) { manager->SetEdge(node_user.first, node_user.second, parameter); } } } } } // drop node from cond_func_graph for (const auto &node : nodes_need_drop) { body_sub_func_graph_->DropNode(node); } return RET_OK; } STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() { std::vector tmp_output{}; tmp_output.resize(input_enter_nodes_.size()); for (auto &node_pair : body_subgraph_output_map_) { auto next_iteration_cnode = utils::cast(node_pair.first); auto switch_node = BlongToWhichSwitch(next_iteration_cnode); auto merge_node = BlongToWhichMerge(switch_node); auto enter_node = BlongToWhichEnter(merge_node); int pos = PosInInputEnterNodes(enter_node); if (pos == POS_INVALID) { continue; } tmp_output[pos] = node_pair.second; // hard code. set cnode output name node_pair.second->cast()->set_fullname_with_scope(body_subgraph_name_ + "_output_" + std::to_string(pos) + "_cnode"); } auto return_prim_ptr = std::make_shared(); if (return_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetReturnPrim return nullptr"; return RET_NULL_PTR; } auto value_node = NewValueNode(return_prim_ptr); // cond subgraph output is LoopCond's input std::vector op_inputs{value_node}; auto return_cnode = body_sub_func_graph_->NewCNode(op_inputs); return_cnode->set_fullname_with_scope(body_subgraph_name_ + "-return"); if (tmp_output.size() == 1) { return_cnode->add_input(tmp_output[0]); } else { std::vector make_tuple_inputs = tmp_output; auto make_tuple_prim_ptr = std::make_shared(); if (make_tuple_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; return RET_NULL_PTR; } auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim); auto make_tuple_cnode = body_sub_func_graph_->NewCNode(make_tuple_inputs); make_tuple_cnode->set_fullname_with_scope(return_cnode->fullname_with_scope() + "tuple"); return_cnode->add_input(make_tuple_cnode); } body_sub_func_graph_->set_return(return_cnode); return RET_OK; } STATUS FunctionalizeWhile::BuildBodyGraph() { body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body"; body_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, mindspore::lite::converter::FmkType_TF); if (body_sub_func_graph_ == nullptr) { MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr"; return RET_NULL_PTR; } body_sub_func_graph_->set_manager(fg_->manager()); int ret = BodySubgraphAddNodes(); if (ret != RET_OK) { MS_LOG(ERROR) << "add body_subgraph node failed, ret:" << ret; return ret; } ret = IdentifyBodySubgraphOutput(); if (ret != RET_OK) { MS_LOG(ERROR) << "identify body_subgraph output failed, ret:" << ret; return ret; } ret = IdentifyBodySubgraphInput(); if (ret != RET_OK) { MS_LOG(ERROR) << "identify body_subgraph input failed, ret:" << ret; return ret; } return ret; } STATUS FunctionalizeWhile::InsertFuncGraphToWhileInput() { // set while input cond and body vnode auto cond_value_node = NewValueNode(cond_sub_func_graph_); auto body_value_node = NewValueNode(body_sub_func_graph_); auto inputs = while_node_->inputs(); inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node}); while_node_->set_inputs(inputs); return RET_OK; } STATUS FunctionalizeWhile::DropUselessNodesInMainGraph() { // fg_ drop cluster node for (auto &node : node_cluster_) { fg_->DropNode(node); } return RET_OK; } STATUS FunctionalizeWhile::Process() { int ret = BuildWhileNode(); if (ret != RET_OK) { MS_LOG(ERROR) << "build while node failed, ret:" << ret; return ret; } ret = BuildCondGraph(); if (ret != RET_OK) { MS_LOG(ERROR) << "build condition graph failed, ret:" << ret; return ret; } ret = BuildBodyGraph(); if (ret != RET_OK) { MS_LOG(ERROR) << "build body graph failed, ret:" << ret; return ret; } ret = InsertFuncGraphToWhileInput(); if (ret != RET_OK) { MS_LOG(ERROR) << "insert func_graph to while input failed, ret:" << ret; return ret; } ret = DropUselessNodesInMainGraph(); if (ret != RET_OK) { MS_LOG(ERROR) << "main func_graph drop nodes failed, ret:" << ret; return ret; } return ret; } } // namespace mindspore::opt