/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H #define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H #include #include #include #include #include #include #include #include "backend/session/kernel_graph.h" #include "base/base_ref.h" #include "utils/contract.h" #include "utils/union_find_set.h" namespace mindspore { namespace session { class AscendControlParser { public: static void LinkGraph(NotNull kg); static void InsertDependToGraph(NotNull kg, NotNull attch_node); static void InsertControlDependToGraph(NotNull kg, NotNull first_node, NotNull second_node); static void ExecutorValidate(NotNull root_graph); static void InsertMultipleAssignToGraph(NotNull from_graph, const AnfNodePtr &jump_node, NotNull from, NotNull to); private: class ReferenceCounter; static void EraseParameter(NotNull root_graph, const std::set &graph_list); static void EraseAssign(std::shared_ptr parameter_count, const std::set &all_nodes, const std::map ¶_to_written_node, NotNull root_graph, const std::set &graph_list); static void EraseLabel(NotNull root_graph); static void ChildGraphDataAssign(NotNull kg, const NotNull> *> link_list, const NotNull *> memo); static NotNull GetStartLabel(NotNull kg, const CNodePtr &last_node, const CNodePtr &last_label); static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, const CNodePtr &last_label, const NotNull *> memo); static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, const NotNull *> memo); static void RecurseSwitch(NotNull kg, NotNull cur_node, const CNodePtr &next_node, const NotNull *> memo); static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, const CNodePtr &next_node, const NotNull *> memo); static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, const CNodePtr &last_label); static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); static std::vector>> ParseCallSwitchNode( NotNull call_node); static std::tuple> ParsePartial(NotNull node); static void AttachChildGraphToReturnNode(NotNull graph, const NotNull *> memo); // root graph order static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode, KernelGraphPtr *cur_child_graph); static std::vector RecurseGraph(NotNull graph, const NotNull *> memo); static void AttachOriginalInputsToGraph(NotNull graph, const std::vector orig_inputs); }; class AscendControlParser::ReferenceCounter { public: explicit ReferenceCounter(std::function func) : predicate_(func), count_() {} ~ReferenceCounter() = default; void AddReadCount(const AnfNodePtr &key, int64_t num); void AddWriteCount(const AnfNodePtr &key, int64_t num); void EraseElem(const AnfNodePtr &key); bool HasValidElem() const; std::tuple GetOneValidElem() const; private: std::function predicate_; std::map> count_; }; } // namespace session } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H