/** * Copyright 2021-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_ #include #include #include #include #include #include #include #include "base/base.h" #include "backend/common/session/anf_runtime_algorithm.h" #include "include/common/utils/anfalgo.h" #include "backend/common/optimizer/optimizer.h" #include "common/graph_kernel/parallel_cost_model.h" #include "backend/common/session/kernel_graph.h" #include "utils/ms_context.h" namespace mindspore::graphkernel { class ParallelInfo { public: ParallelInfo() = default; ParallelInfo(const AnfNodePtrList &nodes, const std::vector &dims, const FusionInfoPtr &fusion_info) : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {} ParallelInfo(const ParallelInfo &obj) { nodes_ = obj.nodes_; dims_ = obj.dims_; fusion_info_ = obj.fusion_info_; } ~ParallelInfo() = default; size_t GetSize() const { if (nodes_.size() != dims_.size()) { MS_LOG(EXCEPTION) << "Internal error in parallel info! nodes' size is different from dims' size: " << nodes_.size() << " vs " << dims_.size(); } return nodes_.size(); } const AnfNodePtrList &nodes() const { return nodes_; } const std::vector &dims() const { return dims_; } const FusionInfoPtr &fusion_info() const { return fusion_info_; } private: AnfNodePtrList nodes_; std::vector dims_; FusionInfoPtr fusion_info_; }; class ParallelConfig { public: ParallelConfig() = default; explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {} explicit ParallelConfig(const ParallelConfig &obj) { max_num_for_fuse_ = obj.max_num_for_fuse_; } ~ParallelConfig() = default; size_t max_num_for_fuse() const { return max_num_for_fuse_; } private: size_t max_num_for_fuse_{10}; // Too many nodes to fuse together may produce bad result. }; struct NodeRelation { public: NodeRelation() {} ~NodeRelation() = default; OrderedSet pres; OrderedSet nexts; }; class ParallelOpFusion : public opt::Pass { public: ParallelOpFusion(const std::string &target, const ParallelConfig &config) : Pass("parallel_fusion"), target_(target), config_(config) {} ~ParallelOpFusion() override = default; bool Run(const FuncGraphPtr &graph) override; private: std::tuple> GetAvaliableNodesByOffset(int start, const std::vector &offsets, const std::vector &used, const AnfNodePtrList &nodes, const std::set &excludes); std::tuple, std::vector> DoSearchInSortedCandidates( size_t origin_size, const AnfNodePtrList &candidates, std::map *origin_indices, std::map *sorted_indices); std::tuple, std::vector> SearchFuseNodesInCandidates(const AnfNodePtrList &cs); void SearchFuseNodesInParallelGroup(const std::vector &group, std::vector *parallel_infos); std::vector SearchFusableParallelCNodes(const std::vector> &groups); void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info); void SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info); bool CreateParallelOpSubGraphs(const std::vector ¶llel_infos, const std::shared_ptr &kernel_graph); OrderedMap GenAnalysisGraph(const AnfNodePtrList &nodes); std::vector> SearchParallelGroups(const OrderedMap &node_rels); std::string target_; ParallelConfig config_; ParallelCostModelPtr cost_model_ptr_; std::set virtual_noout_nodes_; std::set ignore_noin_nodes_; unsigned int parallel_level_{0}; }; using ParallelOpFusionPtr = std::shared_ptr; } // namespace mindspore::graphkernel #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_