/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H #define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H #include #include #include #include "session/kernel_graph.h" #include "utils/base_ref.h" #include "utils/contract.h" namespace mindspore { namespace session { class AscendControlParser { public: static void LinkGraph(NotNull kg); static void InsertDependToGraph(NotNull kg, NotNull attch_node); static void InsertControlDependToGraph(NotNull kg, NotNull first_node, NotNull second_node); private: static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, const CNodePtr &last_label, const VectorRef &args, NotNull *> memo); static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, NotNull *> memo); static void RecurseSwitch(NotNull kg, NotNull cur_node, NotNull *> memo); static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, NotNull *> memo); static std::vector GetCNodes(const std::vector &in); static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, const CNodePtr &last_label, const VectorRef &args); static void SetSubGraphInput(NotNull kg, NotNull from_graph_call_node, const VectorRef &args); static std::tuple ParsePartial(NotNull node); static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); static size_t SetChildGraphInput(NotNull kg, NotNull node, size_t input_index); static constexpr size_t kCNodePrim = 0; static constexpr size_t kCNodeCallArg = 1; static constexpr size_t kCNodeSwitchCond = 1; static constexpr size_t kCNodeSwitchTrue = 2; static constexpr size_t kCNodeSwitchFalse = 3; static constexpr size_t kCNodeSwitchLength = 4; static constexpr size_t kCNodePartialLength = 2; static constexpr size_t kCNodePartialFunc = 1; static constexpr size_t kCNodeSwitchLayerCond = 1; static constexpr size_t kCNodeSwitchLayerBranch = 2; static constexpr size_t kCNodeSwitchLayerLength = 3; }; } // namespace session } // namespace mindspore #endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H