/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/graph_kernel/parallel_fusion.h"

// NOTE(review): the original standard-library include targets were stripped from
// this file (bare "#include" lines); the headers below are reconstructed from the
// actual usage in this translation unit.
#include <algorithm>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "include/common/utils/context/graph_kernel_flags.h"
#include "kernel/kernel.h"
#include "common/graph_kernel/graph_kernel_helper.h"
#include "kernel/common_utils.h"
#include "frontend/operator/ops.h"
#include "ir/func_graph_cloner.h"
#include "common/graph_kernel/core/update_state_formatter.h"
#include "common/graph_kernel/core/graph_builder.h"

namespace mindspore::graphkernel {
namespace {
// Cuda's parameter table can accept maximum 4KB, so the number of parameters should be less than 512.
constexpr size_t CUDA_PARA_LIMIT = 512;

// Returns true if `node` is a CNode of any primitive in `ops_prim`.
bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
  return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

// Remove "pass-through" nodes (nodes for which pass_fn returns true, e.g. Reshape)
// from the relation graph, reconnecting each of their effective predecessors
// directly to each of their successors.
void ProcessThroughPassCNode(const std::function<bool(const AnfNodePtr &)> &pass_fn,
                             OrderedMap<AnfNodePtr, NodeRelation> *const node_rels) {
  std::set<AnfNodePtr> latter_to_be_erased;
  for (const auto &[node, node_rel] : (*node_rels)) {
    if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) {
      continue;
    }
    auto nexts = node_rel.nexts;
    std::vector<AnfNodePtr> pre_nodes;
    std::queue<AnfNodePtr> node_que;
    node_que.push(node);
    // Find until all pre nodes get false from pass_fn, and collect all these predecessor nodes.
    while (!node_que.empty()) {
      auto cur_node = node_que.front();
      node_que.pop();
      if (!pass_fn(cur_node)) {
        pre_nodes.push_back(cur_node);
        continue;
      }
      latter_to_be_erased.insert(cur_node);
      // Copy the predecessor set, since the loop below erases from the live set.
      auto predecessors = (*node_rels)[cur_node].pres;
      if (predecessors.empty()) {
        continue;
      }
      for (const auto &pre_node : predecessors) {
        (*node_rels)[cur_node].pres.erase(pre_node);
        (*node_rels)[pre_node].nexts.erase(cur_node);
        node_que.push(pre_node);
      }
    }
    // Modify the relation: delete node <-> next_node, add pre node <-> next_node.
    for (const auto &next_node : nexts) {
      (*node_rels)[next_node].pres.erase(node);
      for (const auto &cur_node : pre_nodes) {
        (*node_rels)[next_node].pres.insert(cur_node);
        (*node_rels)[cur_node].nexts.insert(next_node);
      }
    }
  }
  for (const auto &node : latter_to_be_erased) {
    node_rels->erase(node);
  }
}

// Remove a tail MakeTuple (one whose only consumers, if any, are TupleGetItem
// nodes that nobody consumes) together with those getitem nodes.
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *const node_rels) {
  AnfNodePtrList latter_to_be_erased;
  for (auto &[node, node_rel] : (*node_rels)) {
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      continue;
    }
    AnfNodePtrList check_next_list;
    check_next_list.push_back(node);
    bool disinterested = false;
    for (auto &successor : node_rel.nexts) {
      if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
        disinterested = true;
        break;
      }
      check_next_list.push_back(successor);
    }
    if (disinterested) {
      continue;
    }
    if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
                     [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
      continue;
    }
    latter_to_be_erased.push_back(node);
  }
  // Delete Tail MakeTuple(including its getitem nodes).
  for (const auto &node : latter_to_be_erased) {
    for (auto &pre : (*node_rels)[node].pres) {
      (*node_rels)[pre].nexts.erase(node);
    }
    // Tail MakeTuple is just be consumed by nothing or invalid getitem node.
    for (auto &getitem : (*node_rels)[node].nexts) {
      node_rels->erase(getitem);
    }
    node_rels->erase(node);
  }
}

bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
    return true;
  }
  return false;
}

bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
    return true;
  }
  return false;
}

bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
    return true;
  }
  return false;
}

bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
    return true;
  }
  return false;
}

bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
    return true;
  }
  return false;
}

bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
    return true;
  }
  return false;
}

void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels,
                           std::set<AnfNodePtr> *const virtual_noout_nodes,
                           std::set<AnfNodePtr> *ignore_noin_nodes) {
  // 1. Local relation
  // Graph as following left part, relation D->B and D->E(D is a no input node)
  // will make B and E to be multiply inputs node.
  // But for parallel, this local relation can ignore for B and E, which make
  // them be able to be paralleled.
  //
  // ************************************
  // *                                  *
  // *   |          |                   *
  // *   A    D     A      D            *
  // *   |   /|     |     / \           *
  // *   |  C |     |    C   F          *
  // *   |/  /      |    |   |          *
  // *   B  F  ====>B    x   x          *
  // *   |  /       |                   *
  // *   | /        |                   *
  // *   E          E                   *
  // *   |          |                   *
  // *                                  *
  // ************************************
  AnfNodePtrList no_input_nodes;
  for (const auto &node_rel : *node_rels) {
    auto &node = node_rel.first;
    if (IsNoInputsNode(*node_rels, node)) {
      no_input_nodes.push_back(node);
    }
  }
  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;
  for (const auto &ninode : no_input_nodes) {
    AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end());
    for (const auto &n : cnexts) {
      AnfNodePtr serial_tail = ninode;
      AnfNodePtr cur_node = n;
      // Walk down the single-in/single-out chain starting at the no-input node.
      while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
        serial_tail = cur_node;
        cur_node = *((*node_rels)[cur_node].nexts.begin());
      }
      latter_delete.emplace_back(serial_tail, cur_node);
    }
  }
  // Delete relation.
  for (const auto &[serial_tail, cur_node] : latter_delete) {
    virtual_noout_nodes->insert(serial_tail);
    ignore_noin_nodes->insert(cur_node);
    (*node_rels)[serial_tail].nexts.erase(cur_node);
    (*node_rels)[cur_node].pres.erase(serial_tail);
    MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
                 << cur_node->fullname_with_scope();
  }
}

// Collect the four node categories of interest:
// multi-inputs, multi-outputs, no-input (excluding ignored ones) and
// no-output (excluding virtual ones).
std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
  const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
  const std::set<AnfNodePtr> &ignore_noin_nodes) {
  AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
  std::list<std::function<void(const AnfNodePtr &)>> func_list = {
    [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
      if (IsMultiInputsNode(node_rels, node)) {
        multi_inputs_nodes.push_back(node);
      }
    },
    [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
      if (IsMultiOutputsNode(node_rels, node)) {
        multi_outputs_nodes.push_back(node);
      }
    },
    [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
      if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
        no_input_nodes.push_back(node);
      }
    },
    [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
      if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
        no_output_nodes.push_back(node);
      }
    }};
  for (const auto &node_rel : node_rels) {
    for (const auto &func : func_list) {
      func(node_rel.first);
    }
  }
  return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
}

bool WhiteOpsFilter(const AnfNodePtr &node) {
  std::vector<PrimitivePtr> whiteable_ops = {};  // Not special for now.
  return common::AnfAlgo::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
}

// Returns true when the node (or any node inside its graph-kernel sub graph)
// carries a stitch attribute. Parallel cannot work with stitching for now.
bool Unfavorable(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input = cnode->input(kAnfPrimitiveIndex);
  if (!IsValueNode<FuncGraph>(input)) {
    return common::AnfAlgo::HasNodeAttr(kAttrStitch, cnode);
  }
  auto func_graph = GetValueNode<FuncGraphPtr>(input);
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtrList sub_nodes;
  kernel::GetValidKernelNodes(func_graph, &sub_nodes);
  for (auto sub_node : sub_nodes) {
    auto sub_cnode = sub_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(sub_cnode);
    if (common::AnfAlgo::HasNodeAttr(kAttrStitch, sub_cnode)) {
      return true;
    }
  }
  return false;
}

bool Parallelizable(const AnfNodePtr &node) { return WhiteOpsFilter(node) && !Unfavorable(node); }

std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
                                            const std::function<bool(const AnfNodePtr &)> &filter_func,
                                            const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
                                            std::set<AnfNodePtr> *const seen) {
  // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
  // For backward search, the other multi-inputs node can be contained in.
  // For forward search, the other multi-outputs node can be contained in.
  auto get_contain_node_set = is_backward ? [](const NodeRelation &info) { return info.pres; }
                                          : [](const NodeRelation &info) { return info.nexts; };
  auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) { return info.nexts; }
                                          : [](const NodeRelation &info) { return info.pres; };
  std::vector<AnfNodePtrList> group;
  for (const auto &node : nodes) {
    AnfNodePtrList stream;
    AnfNodePtr n = node;
    for (auto iter = node_rels.find(n);
         seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
         iter = node_rels.find(n)) {
      if (filter_func(n)) {
        stream.push_back(n);
        seen->insert(n);
      }
      if (get_contain_node_set(iter->second).size() != 1) {
        break;
      }
      n = *(get_contain_node_set(iter->second).begin());
    }
    if (stream.size() > 0) {
      group.push_back(stream);
    }
  }
  // A single stream cannot be paralleled with anything; release its nodes
  // so that other searches may still pick them up.
  if (group.size() == 1) {
    for (const auto &drop : group[0]) {
      seen->erase(drop);
    }
    group.clear();
  }
  return group;
}

void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
                                       const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
                                       std::vector<std::vector<AnfNodePtrList>> *groups,
                                       std::set<AnfNodePtr> *const seen) {
  auto get_related_nodes = is_backward ? [](const NodeRelation &info) { return info.pres; }
                                       : [](const NodeRelation &info) { return info.nexts; };
  for (const auto &node : multi_nodes) {
    if (auto iter = node_rels.find(node); iter != node_rels.end()) {
      const auto &pre_nodes = get_related_nodes(iter->second);
      AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
      groups->push_back(SearchFromNodes(related_nodes, Parallelizable, node_rels, is_backward, seen));
    }
  }
  // Erase empty groups.
  for (auto iter = groups->begin(); iter != groups->end();) {
    if (iter->size() == 0) {
      iter = groups->erase(iter);
    } else {
      ++iter;
    }
  }
}

void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
                                        const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
                                        std::vector<std::vector<AnfNodePtrList>> *groups,
                                        std::set<AnfNodePtr> *const seen) {
  groups->push_back(SearchFromNodes(ud_nodes, Parallelizable, node_rels, is_backward, seen));
  // Erase empty groups.
  for (auto iter = groups->begin(); iter != groups->end();) {
    if (iter->size() == 0) {
      iter = groups->erase(iter);
    } else {
      ++iter;
    }
  }
}

std::string DumpNode(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::stringstream buf;
  buf << (common::AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
      << cnode->ToString();
  return buf.str();
}

void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups, const std::string &title = "") {
  MS_LOG(INFO) << "[" << title << "]"
               << "There are " << groups.size() << " parallel groups, their detail is: ";
  int i = 0;
  // NOTE(review): iterate by const reference; the original copied each group,
  // each stream and each node per iteration.
  for (const auto &group : groups) {
    std::stringstream buf;
    buf << "[" << i << " group] " << group.size() << ":\n";
    for (const auto &nodes : group) {
      buf << "  " << nodes.size() << ": [<";
      for (const auto &node : nodes) {
        buf << "(" << DumpNode(node) << ") -> ";
      }
      buf << ">]\n";
    }
    i++;
    MS_LOG(INFO) << buf.str();
  }
}

void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
  std::stringstream buf;
  buf << "Parallel fusion detail: ";
  for (const auto &node : source) {
    buf << "(" << DumpNode(node) << ") + ";
  }
  buf << "==>"
      << "(" << DumpNode(target) << ")";
  MS_LOG(INFO) << buf.str();
}

inline bool ParameterLimit(const AnfNodePtrList &nodes) {
  if (nodes.empty()) {
    MS_LOG(EXCEPTION) << "Nodes is empty, can not check condition.";
  }
  bool res = true;
  switch (AnfAlgo::GetProcessor(nodes[0])) {
    case kernel::Processor::CUDA: {
      // The number of inputs and outputs for a valid kernel should be less than cuda's limit.
      size_t para_count = 0;
      for (const auto &node : nodes) {
        para_count += common::AnfAlgo::GetInputTensorNum(node);
        para_count += common::AnfAlgo::GetOutputTensorNum(node);
      }
      res = para_count <= CUDA_PARA_LIMIT;
    } break;
    default:
      break;
  }
  return res;
}

bool ExtraFusionCondition(const AnfNodePtrList &nodes) { return ParameterLimit(nodes); }
}  // namespace

OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
  // Based on anf node input information, build a simple graph for latter analyzation.
  OrderedMap<AnfNodePtr, NodeRelation> node_rels;
  auto get_info = [&node_rels](const AnfNodePtr &node) {
    if (node_rels.count(node) == 0) {
      (void)node_rels.emplace(node, NodeRelation());
    }
    return &(node_rels[node]);
  };
  for (const auto &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto prior_node = get_info(node);
    for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
      if (!input->isa<CNode>()) {
        continue;
      }
      auto behind_node = get_info(input);
      prior_node->pres.insert(input);
      behind_node->nexts.insert(node);
    }
  }
  // Pass-through ops do not block parallelism; splice them out of the relation graph.
  ProcessThroughPassCNode(
    [](const AnfNodePtr &node) {
      return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
    },
    &node_rels);
  ProcessTailMakeTupleCNode(&node_rels);
  ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
  return node_rels;
}

std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
  const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
  // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes.
  auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
    GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);
  // Get streams and group them
  std::set<AnfNodePtr> seen;
  std::vector<std::vector<AnfNodePtrList>> groups;
  SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
  DumpParallelGroups(groups, "Dependency Analyze");
  return groups;
}

std::tuple<AnfNodePtrList, std::vector<size_t>> ParallelOpFusion::GetAvaliableNodesByOffset(
  int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
  const std::set<size_t> &excludes) {
  // Get unused nodes by offset index, the result will contain the node with start index.
  int node_limit = static_cast<int>(nodes.size());
  if (start >= node_limit) {
    MS_LOG(EXCEPTION) << "Index offset should be less than the limit of given nodes " << node_limit << ", but got "
                      << start;
  }
  AnfNodePtrList target_nodes = {nodes[IntToSize(start)]};
  std::vector<size_t> valid_indices;
  std::vector<size_t> unused;
  for (size_t i = IntToSize(start); i < used.size(); ++i) {
    if (!used[i] && excludes.count(i) == 0) {
      unused.push_back(i);
    }
  }
  size_t limit = unused.size();
  for (auto offset : offsets) {
    if (offset >= limit) {
      MS_LOG(EXCEPTION) << "Index offset should be less than the limit of unused nodes " << limit << ", but got "
                        << offset;
    }
    if (SizeToInt(unused[offset]) >= node_limit) {
      MS_LOG(EXCEPTION) << "Index offset should be less than the limit of nodes " << node_limit << ", but got "
                        << unused[offset];
    }
    valid_indices.push_back(unused[offset]);
    target_nodes.push_back(nodes[unused[offset]]);
  }
  return std::make_tuple(target_nodes, valid_indices);
}

std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
  size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
  std::map<AnfNodePtr, int> *sorted_indices) {
  auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
    MS_EXCEPTION_IF_NULL(node);
    if (indices->find(node) == indices->end()) {
      MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
    }
    return (*indices)[node];
  };
  std::vector<ParallelInfo> parallel_infos;
  std::vector<bool> origin_candidates_used(origin_size, false);
  std::vector<bool> sorted_candidates_used(candidates.size(), false);
  for (size_t i = 0; i < candidates.size(); ++i) {
    if (sorted_candidates_used[i]) {
      continue;
    }
    int max_benefit = 0;
    ParallelInfo best_parallel_info;
    size_t unused_num = 0;
    for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
      unused_num += sorted_candidates_used[j] ? 0 : 1;
    }
    if (unused_num < 1) {
      break;
    }
    unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);
    // Binary search the largest prefix of unused candidates that still brings
    // a positive fusion benefit under the cost model.
    size_t begin = 1, end = unused_num;
    while (begin <= end) {
      size_t mid = (begin + end) / 2;
      std::vector<size_t> tc(mid);
      std::iota(tc.begin(), tc.end(), 1);
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<size_t>());
      if (ExtraFusionCondition(other_candidates)) {
        int benefit;
        std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
        if (benefit > 0) {
          begin = mid + 1;
          continue;
        }
      }
      end = mid - 1;
    }
    if (begin > 1) {
      std::vector<size_t> tc(begin - 1);
      std::iota(tc.begin(), tc.end(), 1);
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<size_t>());
      auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
      if (benefit <= 0) {
        MS_LOG(EXCEPTION) << "Internal error in candidate search! benefit should be greater than 0, but got "
                          << benefit;
      }
      max_benefit = benefit;
      best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
      i += begin - 1;
    }
    if (max_benefit > 0) {
      parallel_infos.push_back(best_parallel_info);
      for (const auto &node : best_parallel_info.nodes()) {
        sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true;
        origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true;
      }
    }
  }
  // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility.
  if (parallel_infos.size() == 0) {
    origin_candidates_used[IntToSize(get_index(origin_indices, candidates[parallel_infos.size()]))] = true;
  }
  return std::make_tuple(origin_candidates_used, parallel_infos);
}

std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
  const AnfNodePtrList &cs) {
  std::map<AnfNodePtr, int> origin_indices;
  std::vector<size_t> indices;
  for (size_t i = 0; i < cs.size(); ++i) {
    if (cs[i]) {
      (void)origin_indices.emplace(cs[i], i);
      indices.push_back(i);
    }
  }
  // A calculated heavy node can cover more lighter nodes' cost, so sort them first.
  std::map<size_t, int64_t> cal_amounts;
  for (auto id : indices) {
    cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
  }
  std::sort(indices.begin(), indices.end(),
            [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });
  AnfNodePtrList candidates;
  for (size_t i = 0; i < indices.size(); ++i) {
    candidates.push_back(cs[indices[i]]);
  }
  std::map<AnfNodePtr, int> sorted_indices;
  for (size_t i = 0; i < candidates.size(); ++i) {
    (void)sorted_indices.emplace(candidates[i], i);
  }
  return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
}

void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
                                                      std::vector<ParallelInfo> *parallel_infos) {
  // One cursor per stream; candidates are the current heads of all streams.
  std::vector<AnfNodePtrList::const_iterator> tails;
  std::vector<AnfNodePtrList::const_iterator> ended;
  for (const auto &node_list : group) {
    tails.push_back(node_list.begin());
    ended.push_back(node_list.end());
  }
  auto get_candidates = [&tails, &ended]() {
    AnfNodePtrList candidates;
    for (size_t id = 0; id < tails.size(); ++id) {
      candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
    }
    return candidates;
  };
  auto update_tails = [&tails](const std::vector<bool> &used) {
    if (used.size() != tails.size()) {
      MS_LOG(EXCEPTION) << "Judged nodes size is different from left ones size: " << used.size() << " vs "
                        << tails.size();
    }
    for (size_t id = 0; id < used.size(); ++id) {
      if (used[id]) {
        ++tails[id];
      }
    }
  };
  auto valid_candidate_num = [](const AnfNodePtrList &cs) {
    return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
  };
  auto candidates = get_candidates();
  while (valid_candidate_num(candidates) > 1) {
    auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
    std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
                   [](const ParallelInfo &pi) { return pi; });
    update_tails(used);
    candidates = get_candidates();
  }
}

std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
  const std::vector<std::vector<AnfNodePtrList>> &groups) {
  // Find core-fusable groups with cost model.
  std::vector<ParallelInfo> parallel_infos;
  for (const auto &group : groups) {
    SearchFuseNodesInParallelGroup(group, &parallel_infos);
  }
  return parallel_infos;
}

void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
  AnfNodePtr attach_node;
  // Dim info should be attach to each segment's output.
  for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
    const auto &fuse_nodes = parallel_info.nodes();
    // NOTE(review): the dim-info class name was stripped from the original;
    // CommonDimInfo is reconstructed from the parallel cost model — confirm.
    std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
    if (!common::AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
      attach_node = fuse_nodes[i];
      SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
    } else {
      auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
      auto out_node = node_g->output();
      if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
        auto inputs = out_node->cast<CNodePtr>()->inputs();
        for (size_t j = 1; j < inputs.size(); ++j) {
          SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
        }
        attach_node = inputs[1];
      } else {
        attach_node = out_node;
        SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
      }
    }
  }
  // Fusion info is ok to attach to one of the segments.
  SetFusionInfoAttrToNode(attach_node, parallel_info);
}

void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
  auto fusion_type = parallel_info.fusion_info()->FusionType();
  common::AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
  if (parallel_info.fusion_info()->ExistTypeInfo()) {
    // NOTE(review): the concrete fusion-info class name was stripped from the
    // original; BlockPipelineFusionInfo is reconstructed from the cost model — confirm.
    if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
      common::AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
                                   MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
    }
  }
}

bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
                                                 const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  bool changed = false;
  for (size_t i = 0; i < parallel_infos.size(); ++i) {
    const auto &fuse_nodes = parallel_infos[i].nodes();
    if (fuse_nodes.size() <= 1) {
      continue;
    }
    changed = true;
    SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
    auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel");
    common::AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
    DumpParallelFusionDetail(fuse_nodes, sg_node);
  }
  return changed;
}

std::set<AnfNodePtr> CollectCapturedNodes(const std::vector<ParallelInfo> &infos) {
  std::set<AnfNodePtr> captured;
  (void)std::for_each(infos.cbegin(), infos.cend(), [&captured](const ParallelInfo &info) {
    captured.insert(info.nodes().begin(), info.nodes().end());
  });
  return captured;
}

// Group nodes level-by-level with a reverse BFS (starting from nodes with no
// successors); nodes in the same BFS level are independent and can run in parallel.
std::vector<std::vector<AnfNodePtrList>> GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels,
                                                                const std::set<AnfNodePtr> &exclude) {
  std::vector<std::vector<AnfNodePtrList>> groups;
  // BFS
  std::queue<AnfNodePtr> node_que;
  std::unordered_map<AnfNodePtr, int> outdegrees;
  for (const auto &[node, ref] : node_rels) {
    outdegrees[node] = SizeToInt(ref.nexts.size());
    if (outdegrees[node] == 0) {
      node_que.push(node);
    }
  }
  int total_node_num = SizeToInt(node_rels.size());
  while (!node_que.empty()) {
    std::vector<AnfNodePtrList> group;
    int node_size = SizeToInt(node_que.size());
    while (node_size--) {
      auto node = node_que.front();
      node_que.pop();
      if (exclude.count(node) == 0 && Parallelizable(node)) {
        (void)group.emplace_back(AnfNodePtrList({node}));
      }
      --total_node_num;
      auto iter = node_rels.find(node);
      if (iter == node_rels.end()) {
        MS_LOG(EXCEPTION) << "Internal error in node relationship!";
      }
      for (const auto &pre : iter->second.pres) {
        if (--outdegrees[pre] == 0) {
          node_que.push(pre);
        }
      }
    }
    if (!group.empty()) {
      groups.push_back(group);
    }
  }
  // If some node was never dequeued, the relation graph contains a cycle.
  if (total_node_num > 0) {
    MS_LOG(EXCEPTION) << "There is circle in analyze graph!";
  }
  DumpParallelGroups(groups, "BFS");
  return groups;
}

bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  parallel_level_ = GraphKernelFlags::GetInstance().parallel_ops_level;
  (void)std::make_shared<SpreadUpdateState>()->Run(graph);
  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
  MS_EXCEPTION_IF_NULL(cost_model_ptr_);
  auto nodes = TopoSort(kernel_graph->get_return());
  std::reverse(nodes.begin(), nodes.end());
  auto node_rels = GenAnalysisGraph(nodes);
  auto groups = SearchParallelGroups(node_rels);
  auto parallel_infos = SearchFusableParallelCNodes(groups);
  // Search in BFS for left nodes.
  if (parallel_level_ > 0) {
    auto exclued_nodes = CollectCapturedNodes(parallel_infos);
    auto groups_bfs = GetParallelGroupsByBfs(node_rels, exclued_nodes);
    auto bfs_parallel_infos = SearchFusableParallelCNodes(groups_bfs);
    (void)parallel_infos.insert(parallel_infos.end(), bfs_parallel_infos.begin(), bfs_parallel_infos.end());
  }
  // Create core-fuse subgraph and change origin graph.
  bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
  (void)std::make_shared<ShrinkUpdateState>()->Run(graph);
  return changed;
}
}  // namespace mindspore::graphkernel