/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ #define MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ #include #include #include #include #include #include #include #include #include "ir/anf.h" #include "ir/primitive.h" #include "ir/scalar.h" #include "ir/meta_tensor.h" #include "debug/label.h" namespace mindspore { enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE }; using IncludeFunc = std::function; using SuccFunc = std::function(AnfNodePtr)>; using SearchFunc = std::function(const AnfNodePtr &, const IncludeFunc &)>; std::vector SuccDeeper(const AnfNodePtr &node); std::vector SuccDeeperSimple(const AnfNodePtr &node); std::vector SuccIncoming(const AnfNodePtr &node); std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); IncludeType AlwaysInclude(const AnfNodePtr &node); IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, const IncludeFunc &include = AlwaysInclude); class FuncGraphIndex { public: explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, const IncludeFunc &include = AlwaysInclude); FuncGraphIndex(const FuncGraphIndex &) = delete; FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; virtual ~FuncGraphIndex() {} std::set GetFuncGraphs(const std::string &key); std::set GetNodes(const std::string &key); FuncGraphPtr GetFirstFuncGraph(const std::string &key); AnfNodePtr GetFirstNode(const std::string &key); private: void Acquire(const FuncGraphPtr &key); void Acquire(const AnfNodePtr &key); std::map> index_func_graph_; std::map> index_node_; }; // Isomorphism struct PairHasher { template std::size_t operator()(const std::pair &p) const { auto h1 = std::hash{}(p.first); auto h2 = std::hash{}(p.second); return h1 ^ h2; } }; enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 }; using FuncGraphPairMapEquiv = std::unordered_map, EquivState, PairHasher>; using NodeMapEquiv = std::unordered_map; bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_