1. Move functions from graph_kernel_helper.cc to graph_builder.cc: EliminateMakeTuple (now implemented with SpreadTuples); ConvertNonscalarTensorToParameter (the equal-tensor check is removed); IsTupleOutput (originally IsMakeTupleOut), now using recursion; CreateNewFuseCNode (the "output" argument is removed and SetNewKernelInfo is called inside it); ReplaceNewFuseCNode; BuildSingleGraphFromNodes (originally MixedNodesTransToGraph); ReplaceNodesWithGraphKernelNode (originally FuseNodesToSubGraph). 2. Create graph_kernel_utils.cc; ExtractGraphKernelName and SpreadTuples were moved into it. 3. Add SetNewKernelInfo to the callback functions. (tags/v1.6.0)
| @@ -17,7 +17,9 @@ | |||
| #include "backend/optimizer/graph_kernel/adapter/callback_impl.h" | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| @@ -76,4 +78,62 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) { | |||
| std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return kernel::GetProcessorStr(node); } | |||
| std::string CallbackImpl::GetProcessorFromContext() { return kernel::GetStrProcessorFromContext(); } | |||
// Build and attach the KernelBuildInfo of a GraphKernel composite node.
// Input formats/dtypes are taken from the node's real input kernels (constant
// tensor inputs fall back to kOpFormat_DEFAULT), output formats/dtypes from
// the sub-graph's real output kernels. The per-input info is also mirrored
// onto the corresponding sub-graph parameter.
void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
  std::vector<std::string> graph_input_format;
  std::vector<TypeId> graph_input_type;
  std::vector<std::string> graph_output_format;
  std::vector<TypeId> graph_output_type;
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto fg = GetCNodeFuncGraph(node);
  MS_EXCEPTION_IF_NULL(fg);
  auto &inputs = cnode->inputs();
  // inputs[0] is the sub-FuncGraph value node, so real inputs start at 1.
  for (size_t i = 1; i < inputs.size(); ++i) {
    auto kernel_with_index = AnfUtils::VisitKernel(inputs[i], 0);
    if (kernel_with_index.first->isa<ValueNode>()) {
      // Constant tensor input: it has no device info yet, use default format
      // and the tensor's own data type.
      auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
      MS_EXCEPTION_IF_NULL(tensor);
      (void)graph_input_format.emplace_back(kOpFormat_DEFAULT);
      (void)graph_input_type.emplace_back(tensor->data_type());
    } else {
      auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
      (void)graph_input_format.emplace_back(std::move(input_format));
      auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
      (void)graph_input_type.emplace_back(input_type);
    }
    // Mirror the i-th input's format/type onto the (i-1)-th sub-graph parameter.
    fg->parameters()[i - 1]->set_kernel_info(std::make_shared<device::KernelInfo>());
    kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
    para_info_builder.SetOutputsFormat({graph_input_format.back()});
    para_info_builder.SetOutputsDeviceType({graph_input_type.back()});
    para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
    para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
    AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i - 1].get());
  }
  // Collect the real output nodes: the MakeTuple elements for a multi-output
  // graph, otherwise the single output itself.
  AnfNodePtrList outputs;
  if (IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
    auto fg_output = fg->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(fg_output);
    outputs.assign(fg_output->inputs().begin() + 1, fg_output->inputs().end());
  } else {
    outputs.push_back(fg->output());
  }
  for (size_t i = 0; i < outputs.size(); ++i) {
    // NOTE(review): inputs use AnfUtils::VisitKernel while outputs use
    // AnfAlgo::VisitKernel — presumably equivalent here; confirm.
    auto kernel_with_index = AnfAlgo::VisitKernel(outputs[i], 0);
    auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
    graph_output_format.push_back(output_format);
    graph_output_type.push_back(output_type);
  }
  // Build and attach the composite node's own kernel build info.
  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
  graph_info_builder.SetInputsFormat(graph_input_format);
  graph_info_builder.SetInputsDeviceType(graph_input_type);
  graph_info_builder.SetOutputsFormat(graph_output_format);
  graph_info_builder.SetOutputsDeviceType(graph_output_type);
  graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto graph_selected_info = graph_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, node.get());
}
| } // namespace mindspore::graphkernel | |||
| @@ -34,6 +34,7 @@ class CallbackImpl : public Callback { | |||
| std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override; | |||
| std::string GetProcessor(const AnfNodePtr &node) override; | |||
| std::string GetProcessorFromContext() override; | |||
| void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override; | |||
| }; | |||
| } // namespace mindspore::graphkernel | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_ | |||
| @@ -29,6 +29,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| @@ -360,7 +361,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, | |||
| CorrectKernelBuildInfo(composite_node, new_input); | |||
| auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); | |||
| auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); | |||
| sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); | |||
| MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; | |||
| } | |||
| @@ -435,7 +436,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP | |||
| auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)}); | |||
| broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract()); | |||
| SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner}); | |||
| auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); | |||
| auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); | |||
| new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); | |||
| new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean")); | |||
| @@ -25,6 +25,7 @@ | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| @@ -82,7 +83,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_ | |||
| } | |||
| auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); | |||
| auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); | |||
| sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); | |||
| MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; | |||
| } | |||
| @@ -25,6 +25,7 @@ | |||
| #include <unordered_map> | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_builder.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/anf.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| @@ -645,14 +646,12 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) { | |||
| } | |||
| if (!change_anf_graph) continue; | |||
| ReorganizeEmptyGraph(lg); | |||
| AnfNodePtrList outputs; | |||
| auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs); | |||
| auto new_funcgraph = LiteGraph2AnfGraph(lg); | |||
| new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| EliminateRedundantParameters(new_funcgraph, &inputs); | |||
| auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs); | |||
| SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); | |||
| auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs); | |||
| mng->Replace(node, new_node); | |||
| mng->AddFuncGraph(new_funcgraph); | |||
| do_simplify = true; | |||
| @@ -25,29 +25,27 @@ | |||
| #include "base/core_ops.h" | |||
| #include "ir/func_graph.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/anf_utils.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" | |||
| namespace mindspore::graphkernel { | |||
| AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, | |||
| const std::unordered_set<AnfNodePtr> &seen) { | |||
| namespace { | |||
| // find outputs of nodes | |||
| AnfNodePtrList FindOutputs(const AnfNodePtrList &nodes, const AnfNodePtrToAnfNodePtrMap &eqv) { | |||
| AnfNodePtrList output; | |||
| if (users.size() == 0) { | |||
| return output; | |||
| } | |||
| auto mng = nodes[0]->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &users = mng->node_users(); | |||
| for (auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| // only CNode can be an output. | |||
| if (!node->isa<CNode>()) continue; | |||
| auto iter = users.find(node); | |||
| if (iter == users.end()) { | |||
| continue; | |||
| } | |||
| if (iter == users.end()) continue; | |||
| auto &node_users = iter->second; | |||
| const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users), | |||
| [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool { | |||
| const bool is_outer_user = (seen.find(u.first) == seen.end()); | |||
| return is_outer_user; | |||
| }); | |||
| if (has_outer_user) { | |||
| // if any user of the `node` is not in the nodes list, the `node` is an output. | |||
| if (std::any_of(std::begin(node_users), std::end(node_users), | |||
| [&eqv](const std::pair<AnfNodePtr, int> &u) { return eqv.find(u.first) == eqv.end(); })) { | |||
| output.emplace_back(node); | |||
| } | |||
| } | |||
| @@ -56,12 +54,11 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, | |||
| AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr, | |||
| AnfNodePtrToAnfNodePtrMap *eqv_ptr) { | |||
| auto &input_list = *inputs_ptr; | |||
| auto &eqv = *eqv_ptr; | |||
| if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { | |||
| eqv[node] = node; | |||
| } else if (eqv.find(node) == eqv.end()) { | |||
| input_list.push_back(node); | |||
| inputs_ptr->push_back(node); | |||
| eqv[node] = fg->add_parameter(); | |||
| eqv[node]->set_abstract(node->abstract()); | |||
| eqv[node]->set_kernel_info(node->kernel_info_ptr()); | |||
| @@ -69,43 +66,150 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo | |||
| return eqv[node]; | |||
| } | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &node_list) { | |||
| bool InlineInnerFuncGraph(const FuncGraphPtr &fg) { | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| bool changed = false; | |||
| auto cnodes = fg->GetOrderedCnodes(); | |||
| for (const auto &n : cnodes) { | |||
| auto graph_kernel_g = GetCNodeFuncGraph(n); | |||
| if (graph_kernel_g == nullptr) continue; | |||
| AnfNodePtrList inp(n->inputs().begin() + 1, n->inputs().end()); | |||
| auto out = InlineClone(graph_kernel_g, fg, inp, n->input(0)->scope()); | |||
| mng->Replace(n, out); | |||
| changed = true; | |||
| } | |||
| return changed; | |||
| } | |||
// Flatten a nested MakeTuple output of `fg` into a single-level MakeTuple,
// then rebuild the output's abstract from the flattened element list.
// No-op when the graph's output is not a MakeTuple.
void EliminateMakeTuple(const FuncGraphPtr &fg) {
  if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
    return;
  }
  auto out_cnode = fg->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(out_cnode);
  // SpreadTuples recursively expands inner MakeTuple inputs; new_args[0] is
  // still the MakeTuple primitive itself.
  AnfNodePtrList new_args = GkUtils::SpreadTuples(out_cnode->inputs());
  // A size change means some nested tuple was flattened, so a new output
  // node is required.
  if (new_args.size() != out_cnode->size()) {
    auto new_out = fg->NewCNode(new_args);
    auto mng = fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    mng->Replace(out_cnode, new_out);
  }
  // Abstracts start from index 1 (index 0 is the MakeTuple primitive).
  AbstractBasePtrList abs_list;
  std::transform(new_args.begin() + 1, new_args.end(), std::back_inserter(abs_list),
                 [](const AnfNodePtr &node) { return node->abstract(); });
  fg->output()->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
}
| bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) { | |||
| auto cnodes = fg->GetOrderedCnodes(); | |||
| AnfNodePtrList value_nodes; | |||
| for (const auto &cnode : cnodes) { | |||
| auto &inputs = cnode->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| const auto &tnode = inputs[i]; | |||
| auto tensor = GetValueNode<tensor::TensorPtr>(tnode); | |||
| if (tensor == nullptr || tensor->DataSize() == 1) { | |||
| continue; | |||
| } | |||
| value_nodes.push_back(tnode); | |||
| } | |||
| } | |||
| if (value_nodes.empty()) return false; | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| for (auto &vnode : value_nodes) { | |||
| auto parameter = fg->add_parameter(); | |||
| parameter->set_abstract(vnode->abstract()); | |||
| parameter->set_kernel_info(vnode->kernel_info_ptr()); | |||
| mng->Replace(vnode, parameter); | |||
| inputs_ptr->push_back(vnode); | |||
| } | |||
| return true; | |||
| } | |||
| bool IsTupleOutput(const AnfNodePtr &out, AnfNodePtrList *real_outs) { | |||
| if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { | |||
| auto &inputs = out->cast<CNodePtr>()->inputs(); | |||
| real_outs->assign(inputs.begin() + 1, inputs.end()); | |||
| return true; | |||
| } | |||
| if (auto fg = GetCNodeFuncGraph(out); fg != nullptr) { | |||
| return IsTupleOutput(fg->output(), real_outs); | |||
| } | |||
| return false; | |||
| } | |||
// Redirect all users of the fused area's original output nodes to
// `new_fuse_cnode`. A single output is replaced directly; with multiple
// outputs each one is rewired through a TupleGetItem on the composite node.
// `offset` compensates the index shift introduced by outputs that were
// themselves tuples and got flattened into the composite's output tuple.
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
                         const AnfNodePtrList &outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  // single out
  if (outputs.size() == 1) {
    mng->Replace(outputs[0], new_fuse_cnode);
    return;
  }
  size_t offset = 0;
  for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
    AnfNodePtrList real_outs;
    // the output is a single tensor
    if (!IsTupleOutput(outputs[out_idx], &real_outs)) {
      // Replace the old output with getitem(new_fuse_cnode, flattened index).
      auto gt_idx = MakeValue(SizeToLong(out_idx + offset));
      AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
      gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
      auto new_out = func_graph->NewCNode(gt_inputs);
      new_out->set_abstract(outputs[out_idx]->abstract());
      mng->Replace(outputs[out_idx], new_out);
      continue;
    }
    // the out is make tuple , modify the get_item node's value
    auto users = mng->node_users()[outputs[out_idx]];  // use a copy, the original user map is changed in for-loop.
    for (auto &user : users) {
      auto getitem_node = user.first;
      // Only TupleGetItem users need rewiring; other users are left alone.
      if (!getitem_node->isa<CNode>() || !IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
        continue;
      }
      auto value_ptr = GetValueNode(getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
      MS_EXCEPTION_IF_NULL(value_ptr);
      auto old_gt_idx = GetValue<int64_t>(value_ptr);
      // New index = position of this tuple in the flattened outputs + the
      // element index inside the tuple.
      auto gt_idx = MakeValue(SizeToLong(out_idx + offset) + old_gt_idx);
      AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
      gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
      auto new_getitem_node = func_graph->NewCNode(gt_inputs);
      new_getitem_node->set_abstract(getitem_node->abstract());
      mng->Replace(getitem_node, new_getitem_node);
    }
    // A tuple of N elements occupies N flattened slots instead of 1.
    offset += real_outs.size() - 1;
  }
}
| } // namespace | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes) { | |||
| FuncGraphPtr fg = nullptr; | |||
| { | |||
| // limit the lifetime of guard. | |||
| TraceGuard guard( | |||
| std::make_shared<TraceSegmentTransform>(node_list[0]->cast<CNodePtr>()->func_graph()->debug_info())); | |||
| TraceGuard guard(std::make_shared<TraceSegmentTransform>(nodes[0]->cast<CNodePtr>()->func_graph()->debug_info())); | |||
| fg = std::make_shared<FuncGraph>(); | |||
| } | |||
| AnfNodePtrList input_list; | |||
| AnfNodePtrToAnfNodePtrMap eqv; | |||
| // Merge CNodes into a AnfGraph that represents a linear instruction segment | |||
| for (auto node : node_list) { | |||
| auto &input_nodes = node->cast<CNodePtr>()->inputs(); | |||
| auto fn = input_nodes[0]; | |||
| std::vector<AnfNodePtr> new_args{fn}; | |||
| if (IsPrimitive(fn, prim::kPrimDepend) && input_nodes.size() >= kDependInputSize && | |||
| eqv.find(input_nodes[kDependAttachNodeIndex]) == eqv.end()) { | |||
| new_args.emplace_back(RefSubGraphNode(fg, input_nodes[kRealInputIndexInDepend], &input_list, &eqv)); | |||
| const size_t value_start_index = 2; | |||
| for (size_t i = value_start_index; i < input_nodes.size(); ++i) { | |||
| new_args.emplace_back(NewValueNode(MakeValue(0))); | |||
| } | |||
| } else { | |||
| (void)std::transform( | |||
| std::begin(input_nodes) + 1, std::end(input_nodes), std::back_inserter(new_args), | |||
| [&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); }); | |||
| } | |||
| for (auto &node : nodes) { | |||
| auto &node_inputs = node->cast<CNodePtr>()->inputs(); | |||
| std::vector<AnfNodePtr> new_args{node_inputs[0]}; | |||
| (void)std::transform( | |||
| std::begin(node_inputs) + 1, std::end(node_inputs), std::back_inserter(new_args), | |||
| [&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); }); | |||
| TraceGuard tg(std::make_shared<TraceSegmentTransform>(node->debug_info())); | |||
| eqv[node] = fg->NewCNode(new_args); | |||
| eqv[node]->set_abstract(node->abstract()); | |||
| eqv[node]->set_kernel_info(node->kernel_info_ptr()); | |||
| } | |||
| std::unordered_set<AnfNodePtr> eqv_keys; | |||
| (void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()), | |||
| [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; }); | |||
| auto mgr = node_list[0]->func_graph()->manager(); | |||
| auto outputs = GetOutput(node_list, mgr->node_users(), eqv_keys); | |||
| auto outputs = FindOutputs(nodes, eqv); | |||
| AnfNodePtr fg_output; | |||
| if (outputs.size() > 1) { | |||
| std::vector<AnfNodePtr> output_args; | |||
| @@ -120,4 +224,52 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(con | |||
| fg->set_output(fg_output); | |||
| return std::make_tuple(fg, input_list, outputs); | |||
| } | |||
| // Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs. | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes) { | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes); | |||
| FuncGraphManagerPtr mng = fg->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(fg, false); | |||
| fg->set_manager(mng); | |||
| } | |||
| InlineInnerFuncGraph(fg); | |||
| // eliminate tuple of tuple, and set Abstract for output MakeTuple | |||
| EliminateMakeTuple(fg); | |||
| ConvertNonscalarTensorToParameter(fg, &inputs); | |||
| return std::make_tuple(fg, inputs, outputs); | |||
| } | |||
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) { | |||
| std::vector<AnfNodePtr> fn_inputs{NewValueNode(sub_fg)}; | |||
| fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); | |||
| auto fuse_cnode = main_fg->NewCNode(fn_inputs); | |||
| fuse_cnode->set_abstract(sub_fg->output()->abstract()); | |||
| Callback::Instance()->SetGraphKernelNodeKernelInfo(fuse_cnode); | |||
| return fuse_cnode; | |||
| } | |||
| AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph, | |||
| const std::string &postfix) { | |||
| auto mng = main_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(main_graph, true); | |||
| main_graph->set_manager(mng); | |||
| } | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = BuildSingleGraphFromNodes(nodes); | |||
| auto fuse_new_node = CreateNewFuseCNode(main_graph, fg, inputs); | |||
| ReplaceNewFuseCNode(main_graph, fuse_new_node, outputs); | |||
| auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", postfix); | |||
| fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); | |||
| return fuse_new_node; | |||
| } | |||
| } // namespace mindspore::graphkernel | |||
| @@ -18,11 +18,16 @@ | |||
| #include <unordered_map> | |||
| #include <tuple> | |||
| #include <string> | |||
| #include "ir/anf.h" | |||
| namespace mindspore::graphkernel { | |||
| using AnfNodePtrToAnfNodePtrMap = std::unordered_map<AnfNodePtr, AnfNodePtr>; | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &lst); | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes); | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes); | |||
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs); | |||
| AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph, | |||
| const std::string &postfix = ""); | |||
| } // namespace mindspore::graphkernel | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_BUILDER_H_ | |||
| @@ -120,6 +120,13 @@ class Callback { | |||
| */ | |||
| virtual std::string GetProcessorFromContext() = 0; | |||
| /** | |||
| * @brief Set KernelInfo for a GraphKernel node, the info is extract from its inputs/outputs. | |||
| * | |||
| * @param[in] node the GraphKernel CNode. | |||
| */ | |||
| virtual void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) = 0; | |||
| private: | |||
| friend class CallbackImplRegister; | |||
| static void RegImpl(Callback *cb) { instance_.reset(cb); } | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" | |||
| #include <sstream> | |||
| #include "base/core_ops.h" | |||
| #include "utils/anf_utils.h" | |||
| namespace mindspore::graphkernel { | |||
| std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix, | |||
| const std::string &postfix) { | |||
| std::stringstream name; | |||
| if (!prefix.empty()) { | |||
| name << prefix << "_"; | |||
| } | |||
| for (const auto &node : nodes) { | |||
| if (AnfUtils::IsGraphKernel(node)) { | |||
| auto fg_flag_val = GetCNodeFuncGraph(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| name << GetValue<std::string>(fg_flag_val) << "_"; | |||
| } else if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) { | |||
| name << GetCNodePrimitive(node)->name() << "_"; | |||
| } | |||
| } | |||
| if (!postfix.empty()) { | |||
| name << postfix; | |||
| } | |||
| return name.str(); | |||
| } | |||
| AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) { | |||
| AnfNodePtrList result; | |||
| for (size_t i = begin_index; i < nodes.size(); i++) { | |||
| if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) { | |||
| auto mt = nodes[i]->cast<CNodePtr>(); | |||
| // recursively spread all inner tuples. | |||
| auto mt_inputs = SpreadTuples(mt->inputs(), 1); | |||
| result.insert(result.end(), mt_inputs.begin(), mt_inputs.end()); | |||
| } else { | |||
| result.push_back(nodes[i]); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| } // namespace mindspore::graphkernel | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ | |||
| #include <string> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore::graphkernel { | |||
// Stateless helper utilities shared by the graph-kernel passes.
class GkUtils {
 public:
  /**
   * @brief Extract kernel name from nodes, only the real kernel CNode is processed.
   * @param[in] nodes The node list
   * @param[in] prefix The prefix of result name
   * @param[in] postfix The postfix of result name
   * @return The names of all real-kernel cnodes, joined by underscores,
   *         optionally wrapped with prefix/postfix
   */
  static std::string ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix = "",
                                            const std::string &postfix = "");
  /**
   * @brief Spread the MakeTuple in node list, recursively flattening nested tuples
   * @param[in] nodes Node list that may contain MakeTuple nodes
   * @param[in] begin_index Index of the first element to process (earlier elements are dropped)
   * @example
   * input
   * nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
   * begin_index: 1
   * output
   * [b, i, j, c, d, x, y, z]
   * @return The flattened node list
   */
  static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
};
| } // namespace mindspore::graphkernel | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ | |||
| @@ -29,6 +29,7 @@ | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/update_state_formatter.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_builder.h" | |||
| namespace mindspore::graphkernel { | |||
| namespace { | |||
| @@ -257,10 +258,7 @@ AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, cons | |||
| auto old_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(old_cnode); | |||
| AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end()); | |||
| AnfNodePtrList outputs; | |||
| kernel::GetFuncGraphOutputNodes(func_graph, &outputs); | |||
| auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs); | |||
| SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs); | |||
| auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs); | |||
| return graph_kernel_node; | |||
| } | |||
| @@ -33,6 +33,7 @@ | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_builder.h" | |||
| namespace mindspore::graphkernel { | |||
| std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() { | |||
| @@ -404,10 +405,9 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) { | |||
| void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) { | |||
| AnfNodePtrList old_nodes; | |||
| AnfNodePtr new_node; | |||
| (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes), | |||
| [this](size_t id) { return this->nodes_[id]; }); | |||
| std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion"); | |||
| auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion"); | |||
| std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>(); | |||
| (void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node)); | |||
| if (GraphKernelFlags::GetInstance().dump_as_text) { | |||
| @@ -37,6 +37,7 @@ | |||
| #include "pybind_api/ir/primitive_py.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/optimizer/graph_kernel/expanders/expander_factory.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_builder.h" | |||
| namespace mindspore::graphkernel { | |||
| namespace { | |||
| @@ -164,12 +165,9 @@ AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_grap | |||
| auto func_graph = old_node->func_graph(); | |||
| std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end()); | |||
| AnfNodePtrList kernel_nodes; | |||
| AnfNodePtrList outputs; | |||
| EliminateRedundantParameters(new_func_graph, &inputs); | |||
| kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes); | |||
| kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); | |||
| auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); | |||
| SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs); | |||
| auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs); | |||
| MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope() | |||
| << " with: " << graph_kernel_node->fullname_with_scope(); | |||
| return graph_kernel_node; | |||
| @@ -67,37 +67,6 @@ bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { | |||
| return false; | |||
| } | |||
| AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { | |||
| AnfNodePtrList outs; | |||
| auto out_node = fg->output(); | |||
| if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { | |||
| std::vector<AnfNodePtr> output_args; | |||
| auto out_cnode = out_node->cast<CNodePtr>(); | |||
| for (auto out : out_cnode->inputs()) { | |||
| if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { | |||
| auto inputs = out->cast<CNodePtr>()->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| output_args.push_back(inputs[i]); | |||
| } | |||
| } else { | |||
| output_args.push_back(out); | |||
| } | |||
| } | |||
| if (output_args.size() != out_cnode->inputs().size()) { | |||
| auto new_out = fg->NewCNode(output_args); | |||
| mng->Replace(out_node, new_out); | |||
| } | |||
| for (size_t i = 1; i < output_args.size(); ++i) { | |||
| outs.push_back(output_args[i]); | |||
| } | |||
| return outs; | |||
| } | |||
| outs.push_back(out_node); | |||
| return outs; | |||
| } | |||
| bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, AnfNodePtrList> &in_and_out, | |||
| const DumpOption &dump_option, nlohmann::json *op_desc, | |||
| std::map<std::string, AnfNodePtr> *address_node_map = nullptr) { | |||
| @@ -128,100 +97,6 @@ AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { | |||
| return out_spec; | |||
| } | |||
| bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) { | |||
| MS_EXCEPTION_IF_NULL(inputs_ptr); | |||
| auto nodes = TopoSort(fg->get_return()); | |||
| std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace; | |||
| for (const auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| const auto &tnode = inputs[i]; | |||
| auto tensor = GetValueNode<tensor::TensorPtr>(tnode); | |||
| if (tensor == nullptr || tensor->DataSize() == 1) { | |||
| continue; | |||
| } | |||
| auto tensor_iter = std::find_if( | |||
| v_replace.begin(), v_replace.end(), | |||
| [&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); }); | |||
| if (tensor_iter == v_replace.end()) { | |||
| (void)v_replace.emplace_back(tensor, AnfNodePtrList{tnode}); | |||
| } else { | |||
| tensor_iter->second.push_back(tnode); | |||
| } | |||
| } | |||
| } | |||
| if (v_replace.empty()) { | |||
| return false; | |||
| } | |||
| auto mng = fg->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(fg, false); | |||
| fg->set_manager(mng); | |||
| } | |||
| auto &inputs = *inputs_ptr; | |||
| for (auto iter : v_replace) { | |||
| auto value_nodes = iter.second; | |||
| if (value_nodes.empty()) { | |||
| MS_LOG(EXCEPTION) << "Invalid value in map!"; | |||
| } | |||
| auto vnode = value_nodes[0]; | |||
| auto parameter = fg->add_parameter(); | |||
| parameter->set_abstract(vnode->abstract()); | |||
| parameter->set_kernel_info(vnode->kernel_info_ptr()); | |||
| for (const auto &value_node : value_nodes) { | |||
| mng->Replace(value_node, parameter); | |||
| } | |||
| inputs.push_back(vnode); | |||
| } | |||
| return true; | |||
| } | |||
| // Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs. | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes, | |||
| AnfNodePtrList *src_outputs) { | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| AnfNodePtrList *soutputs = (src_outputs != nullptr) ? src_outputs : &outputs; | |||
| std::tie(fg, inputs, *soutputs) = BuildGraphFromNodes(fuse_nodes); | |||
| FuncGraphManagerPtr mng = fg->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(fg, false); | |||
| fg->set_manager(mng); | |||
| } | |||
| // Inline origin graphkernel | |||
| auto cnodes = fg->GetOrderedCnodes(); | |||
| for (const auto &n : cnodes) { | |||
| if (!AnfAlgo::IsGraphKernel(n)) { | |||
| continue; | |||
| } | |||
| auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0)); | |||
| AnfNodePtrList ins; | |||
| ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); | |||
| auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); | |||
| mng->Replace(n, out); | |||
| } | |||
| EliminateMakeTuple(fg, mng); | |||
| ConvertNonscalarTensorToParameter(fg, &inputs); | |||
| outputs.clear(); | |||
| kernel::GetFuncGraphOutputNodes(fg, &outputs); | |||
| return std::make_tuple(fg, inputs, outputs); | |||
| } | |||
| // Rebuild as node inputs or outputs have changed, processor comes from node itself | |||
| kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format, | |||
| const std::vector<TypeId> &inputs_type, | |||
| @@ -254,6 +129,7 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str | |||
| return graph_info_builder.Build(); | |||
| } | |||
| // Deprecated. use Callback->SetGraphKernelNodeKernelInfo. | |||
| void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs) { | |||
| std::vector<std::string> graph_input_format; | |||
| @@ -309,127 +185,6 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const | |||
| AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get()); | |||
| } | |||
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs) { | |||
| auto func_node = NewValueNode(fg); | |||
| std::vector<AnfNodePtr> fn_inputs; | |||
| fn_inputs.push_back(func_node); | |||
| fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); | |||
| auto fuse_cnode = func_graph->NewCNode(fn_inputs); | |||
| // Set output abstract | |||
| if (outputs.size() > 1) { | |||
| std::vector<AbstractBasePtr> out_specs; | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| out_specs.push_back(outputs[i]->abstract()); | |||
| } | |||
| auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs); | |||
| fuse_cnode->set_abstract(out_spec); | |||
| } else { | |||
| fuse_cnode->set_abstract(outputs[0]->abstract()); | |||
| } | |||
| // Set parameter abstract. | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); | |||
| auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); | |||
| fg->parameters()[i]->set_abstract(input_abs); | |||
| } | |||
| return fuse_cnode; | |||
| } | |||
| void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode, | |||
| const AnfNodePtrList &outputs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| // single out | |||
| if (outputs.size() == 1) { | |||
| mng->Replace(outputs[0], new_fuse_cnode); | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> fn_inputs; | |||
| size_t offset = 0; | |||
| for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { | |||
| AnfNodePtrList real_outs; | |||
| // not make tuple out, replace | |||
| if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) { | |||
| fn_inputs.clear(); | |||
| fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); | |||
| fn_inputs.push_back(new_fuse_cnode); | |||
| fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx + offset)))); | |||
| auto new_out = func_graph->NewCNode(fn_inputs); | |||
| new_out->set_abstract(outputs[out_idx]->abstract()); | |||
| mng->Replace(outputs[out_idx], new_out); | |||
| continue; | |||
| } | |||
| // the out is make tuple , modify the get_item node's value | |||
| auto users = mng->node_users()[outputs[out_idx]]; | |||
| for (auto &user : users) { | |||
| auto use_node = user.first; | |||
| if (!use_node->isa<CNode>() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) { | |||
| continue; | |||
| } | |||
| auto get_item_cnode = use_node->cast<CNodePtr>(); | |||
| auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(value_input); | |||
| auto value_node = value_input->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto item_idx = GetValue<int64_t>(value_node->value()); | |||
| int64_t new_item_idx = SizeToLong(out_idx + offset) + item_idx; | |||
| fn_inputs.clear(); | |||
| fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); | |||
| fn_inputs.push_back(new_fuse_cnode); | |||
| fn_inputs.push_back(NewValueNode(new_item_idx)); | |||
| auto new_out = func_graph->NewCNode(fn_inputs); | |||
| new_out->set_abstract(get_item_cnode->abstract()); | |||
| mng->Replace(get_item_cnode, new_out); | |||
| } | |||
| offset += real_outs.size() - 1; | |||
| } | |||
| } | |||
| std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | |||
| const FuncGraphPtr &kernel_graph, | |||
| const std::string &postfix) { | |||
| auto mng = kernel_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(kernel_graph, true); | |||
| kernel_graph->set_manager(mng); | |||
| } | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList src_outputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs); | |||
| auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs); | |||
| SetNewKernelInfo(fuse_new_node, fg, inputs, outputs); | |||
| // Handle get-item probleam. | |||
| ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs); | |||
| // set graphKernel attr | |||
| std::string fuse_op_name = ""; | |||
| for (auto &fuse_node : fuse_nodes) { | |||
| if (IsPrimitiveCNode(fuse_node)) { | |||
| fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_"; | |||
| } else if (AnfAlgo::IsGraphKernel(fuse_node)) { | |||
| auto fuse_cnode = fuse_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(fuse_cnode); | |||
| auto graph_kernel_fg = GetValueNode<FuncGraphPtr>(fuse_cnode->input(kAnfPrimitiveIndex)); | |||
| auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| auto fuse_fg_name = GetValue<std::string>(fg_flag_val); | |||
| fuse_op_name += fuse_fg_name + "_"; | |||
| } | |||
| } | |||
| fuse_op_name += postfix; | |||
| fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); | |||
| return std::make_tuple(fuse_new_node, src_outputs); | |||
| } | |||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, | |||
| std::map<std::string, AnfNodePtr> *address_node_map) { | |||
| MS_EXCEPTION_IF_NULL(op_desc); | |||
| @@ -466,15 +221,14 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n | |||
| } | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList op_nodes, inputs, outputs; | |||
| if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) { | |||
| fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]); | |||
| } else { | |||
| std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(nodes); | |||
| inputs.clear(); | |||
| outputs.clear(); | |||
| std::tie(fg, std::ignore, std::ignore) = BuildSingleGraphFromNodes(nodes); | |||
| } | |||
| AnfNodePtrList op_nodes, inputs, outputs; | |||
| kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); | |||
| auto mng = fg->manager(); | |||
| @@ -524,22 +278,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc) { | |||
| return fg; | |||
| } | |||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) { | |||
| std::stringstream name; | |||
| if (prefix != "") { | |||
| name << prefix << "_"; | |||
| } | |||
| for (const auto &node : cnodes) { | |||
| if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) { | |||
| name << AnfAlgo::GetCNodeName(node) << "_"; | |||
| } | |||
| } | |||
| if (postfix != "") { | |||
| name << postfix; | |||
| } | |||
| return name.str(); | |||
| } | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -53,9 +53,6 @@ struct DataInfo { | |||
| TypePtr type{nullptr}; | |||
| }; | |||
| bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr); | |||
| std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes, | |||
| AnfNodePtrList *src_outputs = nullptr); | |||
| void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs); | |||
| kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format, | |||
| @@ -66,19 +63,11 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str | |||
| const std::vector<TypeId> &inputs_type, | |||
| const std::vector<std::string> &output_formats, | |||
| const std::vector<TypeId> &output_types); | |||
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs); | |||
| void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, | |||
| const AnfNodePtrList &outputs); | |||
| std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | |||
| const FuncGraphPtr &kernel_graph, | |||
| const std::string &postfix = ""); | |||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc); | |||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, | |||
| std::map<std::string, AnfNodePtr> *address_node_map); | |||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); | |||
| FuncGraphPtr JsonDescToAnf(const std::string &json_desc); | |||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| std::string GetFormat(const AnfNodePtr &node); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| @@ -632,7 +633,7 @@ class Splitter { | |||
| graph_manager->AddFuncGraph(sub_func_graph); | |||
| // set GraphKernel attr | |||
| auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split"); | |||
| auto attr = GkUtils::ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split"); | |||
| sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr)); | |||
| // set kernel info | |||
| @@ -28,6 +28,7 @@ | |||
| #include "frontend/operator/ops.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "backend/optimizer/graph_kernel/update_state_formatter.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_builder.h" | |||
| namespace mindspore::graphkernel { | |||
| namespace { | |||
| @@ -746,8 +747,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> | |||
| } | |||
| changed = true; | |||
| SetFusedParallelOpAttrToReturnNode(parallel_infos[i]); | |||
| AnfNodePtr sg_node; | |||
| std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel"); | |||
| auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel"); | |||
| AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node); | |||
| DumpParallelFusionDetail(fuse_nodes, sg_node); | |||
| } | |||
| @@ -31,6 +31,7 @@ | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/model/lite_graph.h" | |||
| #include "backend/optimizer/graph_kernel/model/op_register.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_builder.h" | |||
| namespace mindspore::graphkernel { | |||
| namespace { | |||
| @@ -438,13 +439,11 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) { | |||
| auto litegraph = AnfGraph2LiteGraph(sub_func_graph); | |||
| if (Process(litegraph)) { | |||
| changed = true; | |||
| AnfNodePtrList outputs; | |||
| auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs); | |||
| auto new_funcgraph = LiteGraph2AnfGraph(litegraph); | |||
| new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs); | |||
| SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); | |||
| auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs); | |||
| (void)mng->Replace(node, new_node); | |||
| mng->AddFuncGraph(new_funcgraph); | |||
| } | |||
| @@ -35,6 +35,7 @@ | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" | |||
| namespace mindspore::graphkernel { | |||
| class TsaChecker : public AtomicAddChecker { | |||
| @@ -133,7 +134,7 @@ AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr & | |||
| auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input}); | |||
| new_composite_node->set_abstract(identity_node->abstract()); | |||
| SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node}); | |||
| auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity"); | |||
| auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity"); | |||
| new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); | |||
| new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity")); | |||
| @@ -198,7 +199,8 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n | |||
| CorrectKernelBuildInfo(composite_node, outter_node); | |||
| auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified"); | |||
| auto new_graph_name = | |||
| GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified"); | |||
| sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); | |||
| MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name; | |||
| } | |||
| @@ -23,6 +23,7 @@ | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" | |||
| #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" | |||
| namespace mindspore::graphkernel { | |||
| @@ -34,21 +35,6 @@ AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) { | |||
| return result; | |||
| } | |||
| AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) { | |||
| AnfNodePtrList result; | |||
| for (size_t i = begin_index; i < nodes.size(); i++) { | |||
| if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) { | |||
| auto mt = nodes[i]->cast<CNodePtr>(); | |||
| // recursively spread all inner tuples. | |||
| auto mt_inputs = SpreadTuples(mt->inputs(), 1); | |||
| result.insert(result.end(), mt_inputs.begin(), mt_inputs.end()); | |||
| } else { | |||
| result.push_back(nodes[i]); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, | |||
| const FuncGraphPtr &func_graph) { | |||
| AnfNodePtrList result; | |||
| @@ -85,7 +71,7 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->size() <= kUpdateStateRealInput) continue; | |||
| auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); | |||
| auto inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput); | |||
| // extend inputs of UpdateState if which have multiple outputs | |||
| inputs = ExtendInputsOfUpdateState(inputs, func_graph); | |||
| if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) { | |||
| @@ -110,7 +96,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->size() <= kUpdateStateRealInput + 1) continue; | |||
| AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); | |||
| AnfNodePtrList mt_inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput); | |||
| AbstractBasePtrList abs_list; | |||
| std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list), | |||
| [](const AnfNodePtr &inp) { return inp->abstract(); }); | |||
| @@ -61,20 +61,6 @@ class ShrinkUpdateState : public opt::Pass { | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| }; | |||
| /** | |||
| * @brief Spread the MakeTuple in node list | |||
| * @param nodes | |||
| * @param begin_index | |||
| * @example | |||
| * input | |||
| * nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ] | |||
| * begin_index: 1 | |||
| * output | |||
| * [b, i, j, c, d, x, y, z] | |||
| * @return std::vector<AnfNodePtr> | |||
| */ | |||
| AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0); | |||
| /** | |||
| * @brief Extend the getitem for UpdateState | |||
| * @example | |||