/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 *conv_activation_fusion.h * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_ #define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_ #include #include #include #include #include "backend/optimizer/common/pass.h" #include "tools/converter/converter_flags.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/graph/functionalize_control_op_pass.h" using mindspore::lite::converter::FmkType; namespace mindspore::opt { constexpr const int POS_INVALID = -1; class FunctionalizeWhile { public: FunctionalizeWhile(std::vector node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg) : node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {} // while STATUS BuildWhileNode(); STATUS IdentifyWhileNodeInput(); STATUS IdentifyWhileNodeExternalInput(); STATUS IdentifyWhileNodeOutput(); STATUS UpdateExitNodeUser(); STATUS NewWhileNode(); STATUS InsertFuncGraphToWhileInput(); bool WhileNodeExternalInputIsContain(const AnfNodePtr &node); // cond subgraph STATUS BuildCondGraph(); STATUS CondSubgraphAddNodes(); STATUS IdentifyCondSubgraphInput(); STATUS IdentifyCondSubgraphOutput(); // body subgraph STATUS BuildBodyGraph(); STATUS BodySubgraphAddNodes(); STATUS IdentifyBodySubgraphInput(); STATUS IdentifyBodySubgraphOutput(); CNodePtr BlongToWhichSwitch(const CNodePtr &node); CNodePtr BlongToWhichMerge(const CNodePtr &node); CNodePtr BlongToWhichEnter(const CNodePtr &node); CNodePtr BlongToWhichExternalEnter(const CNodePtr &node); int PosInInputEnterNodes(const CNodePtr &node); STATUS DropUselessNodesInMainGraph(); STATUS Process(); private: std::vector node_cluster_{}; const CNodePtr loop_cond_node_; FuncGraphPtr fg_; FuncGraphPtr cond_sub_func_graph_ = nullptr; FuncGraphPtr body_sub_func_graph_ = nullptr; CNodePtr while_node_ = nullptr; std::string cond_subgraph_name_{}; std::string body_subgraph_name_{}; // while std::vector input_enter_nodes_{}; std::vector external_input_enter_nodes_{}; std::vector output_exit_nodes_{}; // pair (next iteration node, next iteration node input) std::map body_subgraph_output_map_{}; // pair (switch node, switch output in body graph) std::map body_subgraph_input_map_{}; // pair (switch node, switch output in body graph) std::map cond_subgraph_input_map_{}; }; } // namespace mindspore::opt #endif // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_PASS_H_