diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.cc index 364c6468ab..59a1b58380 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.cc @@ -17,7 +17,9 @@ #include "backend/optimizer/graph_kernel/adapter/callback_impl.h" #include -#include +#include +#include +#include #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/common_utils.h" @@ -76,4 +78,62 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) { std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return kernel::GetProcessorStr(node); } std::string CallbackImpl::GetProcessorFromContext() { return kernel::GetStrProcessorFromContext(); } + +void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) { + std::vector graph_input_format; + std::vector graph_input_type; + std::vector graph_output_format; + std::vector graph_output_type; + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto fg = GetCNodeFuncGraph(node); + MS_EXCEPTION_IF_NULL(fg); + auto &inputs = cnode->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto kernel_with_index = AnfUtils::VisitKernel(inputs[i], 0); + if (kernel_with_index.first->isa()) { + auto tensor = GetValueNode(kernel_with_index.first); + MS_EXCEPTION_IF_NULL(tensor); + (void)graph_input_format.emplace_back(kOpFormat_DEFAULT); + (void)graph_input_type.emplace_back(tensor->data_type()); + } else { + auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); + (void)graph_input_format.emplace_back(std::move(input_format)); + auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); + (void)graph_input_type.emplace_back(input_type); + } + fg->parameters()[i - 1]->set_kernel_info(std::make_shared()); + kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder; + para_info_builder.SetOutputsFormat({graph_input_format.back()}); + para_info_builder.SetOutputsDeviceType({graph_input_type.back()}); + para_info_builder.SetKernelType(KernelType::AKG_KERNEL); + para_info_builder.SetProcessor(kernel::GetProcessorFromContext()); + AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i - 1].get()); + } + AnfNodePtrList outputs; + if (IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) { + auto fg_output = fg->output()->cast(); + MS_EXCEPTION_IF_NULL(fg_output); + outputs.assign(fg_output->inputs().begin() + 1, fg_output->inputs().end()); + } else { + outputs.push_back(fg->output()); + } + for (size_t i = 0; i < outputs.size(); ++i) { + auto kernel_with_index = AnfAlgo::VisitKernel(outputs[i], 0); + auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); + graph_output_format.push_back(output_format); + graph_output_type.push_back(output_type); + } + kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; + graph_info_builder.SetInputsFormat(graph_input_format); + graph_info_builder.SetInputsDeviceType(graph_input_type); + graph_info_builder.SetOutputsFormat(graph_output_format); + graph_info_builder.SetOutputsDeviceType(graph_output_type); + graph_info_builder.SetProcessor(kernel::GetProcessorFromContext()); + graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); + graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); + auto graph_selected_info = graph_info_builder.Build(); + AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, node.get()); +} } // namespace mindspore::graphkernel diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.h index a08d35d1eb..726903fcc4 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.h @@ -34,6 +34,7 @@ class CallbackImpl : public Callback { std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override; std::string GetProcessor(const AnfNodePtr &node) override; std::string GetProcessorFromContext() override; + void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override; }; } // namespace mindspore::graphkernel #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc index d2e2fbbcaf..8d88d1f0d6 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean.cc @@ -29,6 +29,7 @@ #include "utils/log_adapter.h" #include "backend/kernel_compiler/kernel.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/kernel_graph.h" #include "debug/anf_ir_dump.h" @@ -360,7 +361,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, CorrectKernelBuildInfo(composite_node, new_input); auto old_graph_name = GetValue(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); + auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; } @@ -435,7 +436,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)}); broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract()); SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner}); - auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); + auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean")); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc index 70adb6aaa1..5373965f7b 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc @@ -25,6 +25,7 @@ #include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/common_utils.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/kernel_graph.h" @@ -82,7 +83,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_ } auto old_graph_name = GetValue(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); + auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc index fc00f525c9..6ad5d4ea71 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -25,6 +25,7 @@ #include #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/graph_kernel/core/graph_builder.h" #include "backend/session/anf_runtime_algorithm.h" #include "ir/anf.h" #include "utils/context/graph_kernel_flags.h" @@ -645,14 +646,12 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) { } if (!change_anf_graph) continue; ReorganizeEmptyGraph(lg); - AnfNodePtrList outputs; - auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs); + auto new_funcgraph = LiteGraph2AnfGraph(lg); new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); auto cnode = node->cast(); AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); EliminateRedundantParameters(new_funcgraph, &inputs); - auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs); - SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); + auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs); mng->Replace(node, new_node); mng->AddFuncGraph(new_funcgraph); do_simplify = true; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.cc index 79f8dcd459..2da028f930 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.cc @@ -25,29 +25,27 @@ #include "base/core_ops.h" #include "ir/func_graph.h" #include "utils/utils.h" +#include "utils/anf_utils.h" +#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h" +#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" namespace mindspore::graphkernel { -AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, - const std::unordered_set &seen) { +namespace { +// find outputs of nodes +AnfNodePtrList FindOutputs(const AnfNodePtrList &nodes, const AnfNodePtrToAnfNodePtrMap &eqv) { AnfNodePtrList output; - if (users.size() == 0) { - return output; - } + auto mng = nodes[0]->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto &users = mng->node_users(); for (auto &node : nodes) { - if (!node->isa()) { - continue; - } + // only CNode can be an output. + if (!node->isa()) continue; auto iter = users.find(node); - if (iter == users.end()) { - continue; - } + if (iter == users.end()) continue; auto &node_users = iter->second; - const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users), - [&seen](const std::pair &u) -> bool { - const bool is_outer_user = (seen.find(u.first) == seen.end()); - return is_outer_user; - }); - if (has_outer_user) { + // if any user of the `node` is not in the nodes list, the `node` is an output. + if (std::any_of(std::begin(node_users), std::end(node_users), + [&eqv](const std::pair &u) { return eqv.find(u.first) == eqv.end(); })) { output.emplace_back(node); } } @@ -56,12 +54,11 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr, AnfNodePtrToAnfNodePtrMap *eqv_ptr) { - auto &input_list = *inputs_ptr; auto &eqv = *eqv_ptr; if (node->isa() && !IsValueNode(node)) { eqv[node] = node; } else if (eqv.find(node) == eqv.end()) { - input_list.push_back(node); + inputs_ptr->push_back(node); eqv[node] = fg->add_parameter(); eqv[node]->set_abstract(node->abstract()); eqv[node]->set_kernel_info(node->kernel_info_ptr()); @@ -69,43 +66,150 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo return eqv[node]; } -std::tuple BuildGraphFromNodes(const AnfNodePtrList &node_list) { +bool InlineInnerFuncGraph(const FuncGraphPtr &fg) { + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + bool changed = false; + auto cnodes = fg->GetOrderedCnodes(); + for (const auto &n : cnodes) { + auto graph_kernel_g = GetCNodeFuncGraph(n); + if (graph_kernel_g == nullptr) continue; + AnfNodePtrList inp(n->inputs().begin() + 1, n->inputs().end()); + auto out = InlineClone(graph_kernel_g, fg, inp, n->input(0)->scope()); + mng->Replace(n, out); + changed = true; + } + return changed; +} + +void EliminateMakeTuple(const FuncGraphPtr &fg) { + if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) { + return; + } + auto out_cnode = fg->output()->cast(); + MS_EXCEPTION_IF_NULL(out_cnode); + AnfNodePtrList new_args = GkUtils::SpreadTuples(out_cnode->inputs()); + if (new_args.size() != out_cnode->size()) { + auto new_out = fg->NewCNode(new_args); + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + mng->Replace(out_cnode, new_out); + } + AbstractBasePtrList abs_list; + std::transform(new_args.begin() + 1, new_args.end(), std::back_inserter(abs_list), + [](const AnfNodePtr &node) { return node->abstract(); }); + fg->output()->set_abstract(std::make_shared(abs_list)); +} + +bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) { + auto cnodes = fg->GetOrderedCnodes(); + AnfNodePtrList value_nodes; + for (const auto &cnode : cnodes) { + auto &inputs = cnode->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + const auto &tnode = inputs[i]; + auto tensor = GetValueNode(tnode); + if (tensor == nullptr || tensor->DataSize() == 1) { + continue; + } + value_nodes.push_back(tnode); + } + } + if (value_nodes.empty()) return false; + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + for (auto &vnode : value_nodes) { + auto parameter = fg->add_parameter(); + parameter->set_abstract(vnode->abstract()); + parameter->set_kernel_info(vnode->kernel_info_ptr()); + mng->Replace(vnode, parameter); + inputs_ptr->push_back(vnode); + } + return true; +} + +bool IsTupleOutput(const AnfNodePtr &out, AnfNodePtrList *real_outs) { + if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { + auto &inputs = out->cast()->inputs(); + real_outs->assign(inputs.begin() + 1, inputs.end()); + return true; + } + if (auto fg = GetCNodeFuncGraph(out); fg != nullptr) { + return IsTupleOutput(fg->output(), real_outs); + } + return false; +} + +void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode, + const AnfNodePtrList &outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + // single out + if (outputs.size() == 1) { + mng->Replace(outputs[0], new_fuse_cnode); + return; + } + + size_t offset = 0; + for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { + AnfNodePtrList real_outs; + // the output is a single tensor + if (!IsTupleOutput(outputs[out_idx], &real_outs)) { + auto gt_idx = MakeValue(SizeToLong(out_idx + offset)); + AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)}; + gt_inputs.back()->set_abstract(gt_idx->ToAbstract()); + auto new_out = func_graph->NewCNode(gt_inputs); + new_out->set_abstract(outputs[out_idx]->abstract()); + mng->Replace(outputs[out_idx], new_out); + continue; + } + + // the out is make tuple , modify the get_item node's value + auto users = mng->node_users()[outputs[out_idx]]; // use a copy, the original user map is changed in for-loop. + for (auto &user : users) { + auto getitem_node = user.first; + if (!getitem_node->isa() || !IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) { + continue; + } + auto value_ptr = GetValueNode(getitem_node->cast()->input(kInputNodeOutputIndexInTupleGetItem)); + MS_EXCEPTION_IF_NULL(value_ptr); + auto old_gt_idx = GetValue(value_ptr); + auto gt_idx = MakeValue(SizeToLong(out_idx + offset) + old_gt_idx); + AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)}; + gt_inputs.back()->set_abstract(gt_idx->ToAbstract()); + auto new_getitem_node = func_graph->NewCNode(gt_inputs); + new_getitem_node->set_abstract(getitem_node->abstract()); + mng->Replace(getitem_node, new_getitem_node); + } + + offset += real_outs.size() - 1; + } +} +} // namespace + +std::tuple BuildGraphFromNodes(const AnfNodePtrList &nodes) { FuncGraphPtr fg = nullptr; { // limit the lifetime of guard. - TraceGuard guard( - std::make_shared(node_list[0]->cast()->func_graph()->debug_info())); + TraceGuard guard(std::make_shared(nodes[0]->cast()->func_graph()->debug_info())); fg = std::make_shared(); } AnfNodePtrList input_list; AnfNodePtrToAnfNodePtrMap eqv; // Merge CNodes into a AnfGraph that represents a linear instruction segment - for (auto node : node_list) { - auto &input_nodes = node->cast()->inputs(); - auto fn = input_nodes[0]; - std::vector new_args{fn}; - if (IsPrimitive(fn, prim::kPrimDepend) && input_nodes.size() >= kDependInputSize && - eqv.find(input_nodes[kDependAttachNodeIndex]) == eqv.end()) { - new_args.emplace_back(RefSubGraphNode(fg, input_nodes[kRealInputIndexInDepend], &input_list, &eqv)); - const size_t value_start_index = 2; - for (size_t i = value_start_index; i < input_nodes.size(); ++i) { - new_args.emplace_back(NewValueNode(MakeValue(0))); - } - } else { - (void)std::transform( - std::begin(input_nodes) + 1, std::end(input_nodes), std::back_inserter(new_args), - [&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); }); - } + for (auto &node : nodes) { + auto &node_inputs = node->cast()->inputs(); + std::vector new_args{node_inputs[0]}; + (void)std::transform( + std::begin(node_inputs) + 1, std::end(node_inputs), std::back_inserter(new_args), + [&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); }); TraceGuard tg(std::make_shared(node->debug_info())); eqv[node] = fg->NewCNode(new_args); eqv[node]->set_abstract(node->abstract()); eqv[node]->set_kernel_info(node->kernel_info_ptr()); } - std::unordered_set eqv_keys; - (void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()), - [](const std::pair &elem) -> AnfNodePtr { return elem.first; }); - auto mgr = node_list[0]->func_graph()->manager(); - auto outputs = GetOutput(node_list, mgr->node_users(), eqv_keys); + auto outputs = FindOutputs(nodes, eqv); AnfNodePtr fg_output; if (outputs.size() > 1) { std::vector output_args; @@ -120,4 +224,52 @@ std::tuple BuildGraphFromNodes(con fg->set_output(fg_output); return std::make_tuple(fg, input_list, outputs); } + +// Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs. +std::tuple BuildSingleGraphFromNodes(const AnfNodePtrList &nodes) { + FuncGraphPtr fg; + AnfNodePtrList inputs; + AnfNodePtrList outputs; + std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes); + + FuncGraphManagerPtr mng = fg->manager(); + if (mng == nullptr) { + mng = Manage(fg, false); + fg->set_manager(mng); + } + + InlineInnerFuncGraph(fg); + // eliminate tuple of tuple, and set Abstract for output MakeTuple + EliminateMakeTuple(fg); + ConvertNonscalarTensorToParameter(fg, &inputs); + + return std::make_tuple(fg, inputs, outputs); +} + +AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) { + std::vector fn_inputs{NewValueNode(sub_fg)}; + fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); + auto fuse_cnode = main_fg->NewCNode(fn_inputs); + fuse_cnode->set_abstract(sub_fg->output()->abstract()); + Callback::Instance()->SetGraphKernelNodeKernelInfo(fuse_cnode); + return fuse_cnode; +} + +AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph, + const std::string &postfix) { + auto mng = main_graph->manager(); + if (mng == nullptr) { + mng = Manage(main_graph, true); + main_graph->set_manager(mng); + } + FuncGraphPtr fg; + AnfNodePtrList inputs; + AnfNodePtrList outputs; + std::tie(fg, inputs, outputs) = BuildSingleGraphFromNodes(nodes); + auto fuse_new_node = CreateNewFuseCNode(main_graph, fg, inputs); + ReplaceNewFuseCNode(main_graph, fuse_new_node, outputs); + auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", postfix); + fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); + return fuse_new_node; +} } // namespace mindspore::graphkernel diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.h index 5293dc6fcb..197a3f1af8 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_builder.h @@ -18,11 +18,16 @@ #include #include +#include #include "ir/anf.h" namespace mindspore::graphkernel { using AnfNodePtrToAnfNodePtrMap = std::unordered_map; -std::tuple BuildGraphFromNodes(const AnfNodePtrList &lst); +std::tuple BuildGraphFromNodes(const AnfNodePtrList &nodes); +std::tuple BuildSingleGraphFromNodes(const AnfNodePtrList &nodes); +AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs); +AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph, + const std::string &postfix = ""); } // namespace mindspore::graphkernel #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_BUILDER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.h index 035a8ff532..dd107520f7 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.h @@ -120,6 +120,13 @@ class Callback { */ virtual std::string GetProcessorFromContext() = 0; + /** + * @brief Set KernelInfo for a GraphKernel node, the info is extract from its inputs/outputs. + * + * @param[in] node the GraphKernel CNode. + */ + virtual void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) = 0; + private: friend class CallbackImplRegister; static void RegImpl(Callback *cb) { instance_.reset(cb); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.cc new file mode 100644 index 0000000000..f073e978fa --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" +#include +#include "base/core_ops.h" +#include "utils/anf_utils.h" + +namespace mindspore::graphkernel { +std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix, + const std::string &postfix) { + std::stringstream name; + if (!prefix.empty()) { + name << prefix << "_"; + } + for (const auto &node : nodes) { + if (AnfUtils::IsGraphKernel(node)) { + auto fg_flag_val = GetCNodeFuncGraph(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + name << GetValue(fg_flag_val) << "_"; + } else if (node->isa() && AnfUtils::IsRealKernel(node)) { + name << GetCNodePrimitive(node)->name() << "_"; + } + } + if (!postfix.empty()) { + name << postfix; + } + return name.str(); +} + +AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) { + AnfNodePtrList result; + for (size_t i = begin_index; i < nodes.size(); i++) { + if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) { + auto mt = nodes[i]->cast(); + // recursively spread all inner tuples. + auto mt_inputs = SpreadTuples(mt->inputs(), 1); + result.insert(result.end(), mt_inputs.begin(), mt_inputs.end()); + } else { + result.push_back(nodes[i]); + } + } + return result; +} +} // namespace mindspore::graphkernel diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.h new file mode 100644 index 0000000000..63f67913ec --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.h @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ + +#include +#include "ir/anf.h" +#include "ir/func_graph.h" + +namespace mindspore::graphkernel { +class GkUtils { + public: + /** + * @brief Extract kernel name from nodes, only the real kernel CNode is processed. + * @param[in] nodes The node list + * @param[in] prefix The prefix of result name + * @param[in] postfix The postfix of result name + * @return The string concatenated by the names of all cnodes + */ + static std::string ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix = "", + const std::string &postfix = ""); + + /** + * @brief Spread the MakeTuple in node list + * @param[in] nodes + * @param[in] begin_index + * @example + * input + * nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ] + * begin_index: 1 + * output + * [b, i, j, c, d, x, y, z] + * @return std::vector + */ + static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0); +}; +} // namespace mindspore::graphkernel +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc index f95aca3414..490e8ada16 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc @@ -29,6 +29,7 @@ #include "backend/kernel_compiler/common_utils.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" #include "backend/optimizer/graph_kernel/update_state_formatter.h" +#include "backend/optimizer/graph_kernel/core/graph_builder.h" namespace mindspore::graphkernel { namespace { @@ -257,10 +258,7 @@ AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, cons auto old_cnode = node->cast(); MS_EXCEPTION_IF_NULL(old_cnode); AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end()); - AnfNodePtrList outputs; - kernel::GetFuncGraphOutputNodes(func_graph, &outputs); - auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs); - SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs); + auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs); return graph_kernel_node; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc index 45c0b0c53b..f13eb12aa0 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc @@ -33,6 +33,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" #include "backend/optimizer/pass/getitem_tuple.h" +#include "backend/optimizer/graph_kernel/core/graph_builder.h" namespace mindspore::graphkernel { std::vector GraphKernelCluster::GetClusterableOpList() { @@ -404,10 +405,9 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) { void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector &nodes_id) { AnfNodePtrList old_nodes; - AnfNodePtr new_node; (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes), [this](size_t id) { return this->nodes_[id]; }); - std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion"); + auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion"); std::shared_ptr eliminate_getitem_pass = std::make_shared(); (void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node)); if (GraphKernelFlags::GetInstance().dump_as_text) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 89150567b6..976864b11f 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -37,6 +37,7 @@ #include "pybind_api/ir/primitive_py.h" #include "runtime/device/kernel_info.h" #include "backend/optimizer/graph_kernel/expanders/expander_factory.h" +#include "backend/optimizer/graph_kernel/core/graph_builder.h" namespace mindspore::graphkernel { namespace { @@ -164,12 +165,9 @@ AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_grap auto func_graph = old_node->func_graph(); std::vector inputs(old_node->inputs().begin() + 1, old_node->inputs().end()); AnfNodePtrList kernel_nodes; - AnfNodePtrList outputs; EliminateRedundantParameters(new_func_graph, &inputs); kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes); - kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); - auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); - SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs); + auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs); MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope() << " with: " << graph_kernel_node->fullname_with_scope(); return graph_kernel_node; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 73df895a41..b6deca374e 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -67,37 +67,6 @@ bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { return false; } -AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { - AnfNodePtrList outs; - auto out_node = fg->output(); - if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { - std::vector output_args; - auto out_cnode = out_node->cast(); - for (auto out : out_cnode->inputs()) { - if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { - auto inputs = out->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - output_args.push_back(inputs[i]); - } - } else { - output_args.push_back(out); - } - } - if (output_args.size() != out_cnode->inputs().size()) { - auto new_out = fg->NewCNode(output_args); - mng->Replace(out_node, new_out); - } - - for (size_t i = 1; i < output_args.size(); ++i) { - outs.push_back(output_args[i]); - } - return outs; - } - - outs.push_back(out_node); - return outs; -} - bool GenJson(const AnfNodePtrList &op_nodes, const std::pair &in_and_out, const DumpOption &dump_option, nlohmann::json *op_desc, std::map *address_node_map = nullptr) { @@ -128,100 +97,6 @@ AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { return out_spec; } -bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) { - MS_EXCEPTION_IF_NULL(inputs_ptr); - auto nodes = TopoSort(fg->get_return()); - - std::vector> v_replace; - for (const auto &node : nodes) { - if (!node->isa()) { - continue; - } - auto &inputs = node->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - const auto &tnode = inputs[i]; - auto tensor = GetValueNode(tnode); - if (tensor == nullptr || tensor->DataSize() == 1) { - continue; - } - auto tensor_iter = std::find_if( - v_replace.begin(), v_replace.end(), - [&tensor](const std::pair &vl) { return vl.first->ValueEqual(*tensor); }); - if (tensor_iter == v_replace.end()) { - (void)v_replace.emplace_back(tensor, AnfNodePtrList{tnode}); - } else { - tensor_iter->second.push_back(tnode); - } - } - } - - if (v_replace.empty()) { - return false; - } - - auto mng = fg->manager(); - if (mng == nullptr) { - mng = Manage(fg, false); - fg->set_manager(mng); - } - - auto &inputs = *inputs_ptr; - for (auto iter : v_replace) { - auto value_nodes = iter.second; - if (value_nodes.empty()) { - MS_LOG(EXCEPTION) << "Invalid value in map!"; - } - - auto vnode = value_nodes[0]; - auto parameter = fg->add_parameter(); - parameter->set_abstract(vnode->abstract()); - parameter->set_kernel_info(vnode->kernel_info_ptr()); - for (const auto &value_node : value_nodes) { - mng->Replace(value_node, parameter); - } - - inputs.push_back(vnode); - } - - return true; -} - -// Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs. -std::tuple MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes, - AnfNodePtrList *src_outputs) { - FuncGraphPtr fg; - AnfNodePtrList inputs; - AnfNodePtrList outputs; - AnfNodePtrList *soutputs = (src_outputs != nullptr) ? src_outputs : &outputs; - std::tie(fg, inputs, *soutputs) = BuildGraphFromNodes(fuse_nodes); - - FuncGraphManagerPtr mng = fg->manager(); - if (mng == nullptr) { - mng = Manage(fg, false); - fg->set_manager(mng); - } - - // Inline origin graphkernel - auto cnodes = fg->GetOrderedCnodes(); - for (const auto &n : cnodes) { - if (!AnfAlgo::IsGraphKernel(n)) { - continue; - } - auto graph_kernel_g = GetValueNode(n->input(0)); - AnfNodePtrList ins; - ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); - auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); - mng->Replace(n, out); - } - - EliminateMakeTuple(fg, mng); - ConvertNonscalarTensorToParameter(fg, &inputs); - - outputs.clear(); - kernel::GetFuncGraphOutputNodes(fg, &outputs); - return std::make_tuple(fg, inputs, outputs); -} - // Rebuild as node inputs or outputs have changed, processor comes from node itself kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector &inputs_format, const std::vector &inputs_type, @@ -254,6 +129,7 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vectorSetGraphKernelNodeKernelInfo. void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs) { std::vector graph_input_format; @@ -309,127 +185,6 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get()); } -AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, - const AnfNodePtrList &outputs) { - auto func_node = NewValueNode(fg); - std::vector fn_inputs; - fn_inputs.push_back(func_node); - fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); - auto fuse_cnode = func_graph->NewCNode(fn_inputs); - // Set output abstract - if (outputs.size() > 1) { - std::vector out_specs; - for (size_t i = 0; i < outputs.size(); ++i) { - out_specs.push_back(outputs[i]->abstract()); - } - auto out_spec = std::make_shared(out_specs); - fuse_cnode->set_abstract(out_spec); - } else { - fuse_cnode->set_abstract(outputs[0]->abstract()); - } - // Set parameter abstract. - for (size_t i = 0; i < inputs.size(); ++i) { - auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); - auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); - fg->parameters()[i]->set_abstract(input_abs); - } - return fuse_cnode; -} - -void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode, - const AnfNodePtrList &outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - // single out - if (outputs.size() == 1) { - mng->Replace(outputs[0], new_fuse_cnode); - return; - } - - std::vector fn_inputs; - size_t offset = 0; - for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { - AnfNodePtrList real_outs; - // not make tuple out, replace - if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) { - fn_inputs.clear(); - fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); - fn_inputs.push_back(new_fuse_cnode); - fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx + offset)))); - auto new_out = func_graph->NewCNode(fn_inputs); - new_out->set_abstract(outputs[out_idx]->abstract()); - mng->Replace(outputs[out_idx], new_out); - continue; - } - - // the out is make tuple , modify the get_item node's value - auto users = mng->node_users()[outputs[out_idx]]; - for (auto &user : users) { - auto use_node = user.first; - if (!use_node->isa() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) { - continue; - } - auto get_item_cnode = use_node->cast(); - auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(value_input); - auto value_node = value_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto item_idx = GetValue(value_node->value()); - int64_t new_item_idx = SizeToLong(out_idx + offset) + item_idx; - fn_inputs.clear(); - fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); - fn_inputs.push_back(new_fuse_cnode); - fn_inputs.push_back(NewValueNode(new_item_idx)); - auto new_out = func_graph->NewCNode(fn_inputs); - new_out->set_abstract(get_item_cnode->abstract()); - mng->Replace(get_item_cnode, new_out); - } - - offset += real_outs.size() - 1; - } -} - -std::tuple FuseNodesToSubGraph(const std::vector &fuse_nodes, - const FuncGraphPtr &kernel_graph, - const std::string &postfix) { - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - - FuncGraphPtr fg; - AnfNodePtrList inputs; - AnfNodePtrList src_outputs; - AnfNodePtrList outputs; - - std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs); - auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs); - SetNewKernelInfo(fuse_new_node, fg, inputs, outputs); - // Handle get-item probleam. - ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs); - - // set graphKernel attr - std::string fuse_op_name = ""; - for (auto &fuse_node : fuse_nodes) { - if (IsPrimitiveCNode(fuse_node)) { - fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_"; - } else if (AnfAlgo::IsGraphKernel(fuse_node)) { - auto fuse_cnode = fuse_node->cast(); - MS_EXCEPTION_IF_NULL(fuse_cnode); - auto graph_kernel_fg = GetValueNode(fuse_cnode->input(kAnfPrimitiveIndex)); - auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - auto fuse_fg_name = GetValue(fg_flag_val); - fuse_op_name += fuse_fg_name + "_"; - } - } - fuse_op_name += postfix; - fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); - - return std::make_tuple(fuse_new_node, src_outputs); -} - bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, std::map *address_node_map) { MS_EXCEPTION_IF_NULL(op_desc); @@ -466,15 +221,14 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n } FuncGraphPtr fg; - AnfNodePtrList op_nodes, inputs, outputs; + if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) { fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]); } else { - std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(nodes); - inputs.clear(); - outputs.clear(); + std::tie(fg, std::ignore, std::ignore) = BuildSingleGraphFromNodes(nodes); } + AnfNodePtrList op_nodes, inputs, outputs; kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); auto mng = fg->manager(); @@ -524,22 +278,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc) { return fg; } -std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) { - std::stringstream name; - if (prefix != "") { - name << prefix << "_"; - } - for (const auto &node : cnodes) { - if (node->isa() && AnfUtils::IsRealKernel(node)) { - name << AnfAlgo::GetCNodeName(node) << "_"; - } - } - if (postfix != "") { - name << postfix; - } - return name.str(); -} - void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index 993e995a25..18a5e731c6 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -53,9 +53,6 @@ struct DataInfo { TypePtr type{nullptr}; }; -bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr); -std::tuple MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes, - AnfNodePtrList *src_outputs = nullptr); void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs); kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector &inputs_format, @@ -66,19 +63,11 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector &inputs_type, const std::vector &output_formats, const std::vector &output_types); -AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, - const AnfNodePtrList &outputs); -void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, - const AnfNodePtrList &outputs); -std::tuple FuseNodesToSubGraph(const std::vector &fuse_nodes, - const FuncGraphPtr &kernel_graph, - const std::string &postfix = ""); bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc); bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, std::map *address_node_map); bool AnfToJsonDesc(const std::vector &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); FuncGraphPtr JsonDescToAnf(const std::string &json_desc); -std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); std::string GetFormat(const AnfNodePtr &node); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc index e66a386710..38d2b86cc1 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc @@ -27,6 +27,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/common_utils.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" #include "debug/anf_ir_dump.h" #include "utils/context/graph_kernel_flags.h" @@ -632,7 +633,7 @@ class Splitter { graph_manager->AddFuncGraph(sub_func_graph); // set GraphKernel attr - auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split"); + auto attr = GkUtils::ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split"); sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr)); // set kernel info diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc index e10476f2aa..dd9917f006 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc @@ -28,6 +28,7 @@ #include "frontend/operator/ops.h" #include "ir/func_graph_cloner.h" #include "backend/optimizer/graph_kernel/update_state_formatter.h" +#include "backend/optimizer/graph_kernel/core/graph_builder.h" namespace mindspore::graphkernel { namespace { @@ -746,8 +747,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector } changed = true; SetFusedParallelOpAttrToReturnNode(parallel_infos[i]); - AnfNodePtr sg_node; - std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel"); + auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel"); AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node); DumpParallelFusionDetail(fuse_nodes, sg_node); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc index 376d98e9cb..3f37d232c2 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc @@ -31,6 +31,7 @@ #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" #include "backend/optimizer/graph_kernel/model/lite_graph.h" #include "backend/optimizer/graph_kernel/model/op_register.h" +#include "backend/optimizer/graph_kernel/core/graph_builder.h" namespace mindspore::graphkernel { namespace { @@ -438,13 +439,11 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) { auto litegraph = AnfGraph2LiteGraph(sub_func_graph); if (Process(litegraph)) { changed = true; - AnfNodePtrList outputs; - auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs); + auto new_funcgraph = LiteGraph2AnfGraph(litegraph); new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); auto cnode = node->cast(); AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); - auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs); - SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); + auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs); (void)mng->Replace(node, new_node); mng->AddFuncGraph(new_funcgraph); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.cc index e618d5b860..8f61e548a9 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.cc @@ -35,6 +35,7 @@ #include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/common_utils.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" namespace mindspore::graphkernel { class TsaChecker : public AtomicAddChecker { @@ -133,7 +134,7 @@ AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr & auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input}); new_composite_node->set_abstract(identity_node->abstract()); SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node}); - auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity"); + auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity"); new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity")); @@ -198,7 +199,8 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n CorrectKernelBuildInfo(composite_node, outter_node); auto old_graph_name = GetValue(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified"); + auto new_graph_name = + GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified"); sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc index ecdd2c0338..aa0528e868 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc @@ -23,6 +23,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/common_utils.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" namespace mindspore::graphkernel { @@ -34,21 +35,6 @@ AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) { return result; } -AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) { - AnfNodePtrList result; - for (size_t i = begin_index; i < nodes.size(); i++) { - if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) { - auto mt = nodes[i]->cast(); - // recursively spread all inner tuples. - auto mt_inputs = SpreadTuples(mt->inputs(), 1); - result.insert(result.end(), mt_inputs.begin(), mt_inputs.end()); - } else { - result.push_back(nodes[i]); - } - } - return result; -} - AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) { AnfNodePtrList result; @@ -85,7 +71,7 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (cnode->size() <= kUpdateStateRealInput) continue; - auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); + auto inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput); // extend inputs of UpdateState if which have multiple outputs inputs = ExtendInputsOfUpdateState(inputs, func_graph); if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) { @@ -110,7 +96,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (cnode->size() <= kUpdateStateRealInput + 1) continue; - AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); + AnfNodePtrList mt_inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput); AbstractBasePtrList abs_list; std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list), [](const AnfNodePtr &inp) { return inp->abstract(); }); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h index f0ccf0f54a..5cb4976e9c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h @@ -61,20 +61,6 @@ class ShrinkUpdateState : public opt::Pass { bool Run(const FuncGraphPtr &func_graph) override; }; -/** - * @brief Spread the MakeTuple in node list - * @param nodes - * @param begin_index - * @example - * input - * nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ] - * begin_index: 1 - * output - * [b, i, j, c, d, x, y, z] - * @return std::vector - */ -AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0); - /** * @brief Extend the getitem for UpdateState * @example