From 56c4145cc562fe616f40662ff9b8f1d02a926c85 Mon Sep 17 00:00:00 2001 From: dayschan Date: Mon, 22 Mar 2021 16:18:05 +0800 Subject: [PATCH] Refactor GraphKernelExpander (4th submission) Decoupled the node-expansion process from the GraphKernelExpander pass, so that an expansion can be rolled back when an error occurs. Also added support for controlling cluster ops by flags. --- .../graph_kernel/graph_kernel_expander.cc | 106 ++++++++---------- .../graph_kernel/graph_kernel_expander.h | 28 +++-- .../graph_kernel/graph_kernel_helper.cc | 22 ++++ .../graph_kernel/graph_kernel_helper.h | 6 +- .../graph_kernel/substitute_dropout.cc | 18 +-- .../graph_kernel/substitute_dropout.h | 10 +- 6 files changed, 107 insertions(+), 83 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 2b3d08bacb..d438333967 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -17,7 +17,7 @@ #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" #include -#include +#include #include #include @@ -37,8 +37,8 @@ namespace mindspore { namespace opt { namespace { -std::unordered_set GetExpandOps() { - std::unordered_set expand_ops = { +std::vector GetExpandOps() { + std::vector expand_ops = { prim::kPrimSquare, prim::kPrimGeLUGrad, #if ENABLE_D @@ -55,7 +55,7 @@ std::unordered_set GetExpandOps() { prim::kPrimReduceMean, prim::kPrimMaximumGrad, prim::kPrimMinimumGrad, - prim::kPrimGkDropout, + prim::kPrimDropout, prim::kPrimDropoutGrad, prim::kPrimSoftmax, prim::kPrimLayerNorm, @@ -68,41 +68,20 @@ std::unordered_set GetExpandOps() { prim::kPrimAssignAdd, #endif }; - auto new_prim = [](const std::string &name) { return std::make_shared(name); }; - auto &flags = context::GraphKernelFlags::GetInstance(); - auto &enable_ops_only = flags.enable_expand_ops_only; - if 
(!enable_ops_only.empty()) { - expand_ops.clear(); - std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()), - new_prim); - } else { - auto &enable_ops = flags.enable_expand_ops; - auto &disable_ops = flags.disable_expand_ops; - if (!enable_ops.empty()) { - std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim); - } - if (!disable_ops.empty()) { - for (auto iter = expand_ops.begin(); iter != expand_ops.end();) { - if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) { - expand_ops.erase(iter++); - } else { - ++iter; - } - } - } - } + const auto &flags = context::GraphKernelFlags::GetInstance(); + OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops); return expand_ops; } } // namespace -bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) { +bool DefaultExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) { DumpOption dump_option; dump_option.extract_opinfo_from_anfnode = true; kernel::AkgKernelJsonGenerator json_generator(dump_option); return json_generator.CollectJson(node, kernel_json); } -FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { +FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) { nlohmann::json kernel_json; if (!ExpandJsonInfo(node, &kernel_json)) { MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump(); @@ -121,7 +100,6 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { } std::string kernel_desc_str = py::cast(ret); if (kernel_desc_str.empty()) { - MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope(); return nullptr; } // decode json to func_graph. 
@@ -129,10 +107,10 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { return JsonDescToAnf(kernel_desc_str, ori_inputs); } -void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) { +void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) { const auto &ori_parameter = func_graph->parameters(); auto todos = TopoSort(func_graph->get_return()); - std::unordered_set used_param; + std::set used_param; for (auto node : todos) { if (node->isa()) { used_param.insert(node); @@ -152,21 +130,39 @@ void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_ *inputs = std::move(new_inputs); } -AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph, - const FuncGraphPtr &new_func_graph, const CNodePtr &node) { - std::vector inputs(node->inputs().begin() + 1, node->inputs().end()); +AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) { + auto func_graph = old_node->func_graph(); + std::vector inputs(old_node->inputs().begin() + 1, old_node->inputs().end()); AnfNodePtrList kernel_nodes; AnfNodePtrList outputs; EliminateRedundantParameters(new_func_graph, &inputs); kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes); kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); - SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node)); - MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() + SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_node)); + MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope() << " with: " << graph_kernel_node->fullname_with_scope(); return graph_kernel_node; } +AnfNodePtr DefaultExpander::Run(const 
AnfNodePtr &node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto new_func_graph = CreateExpandFuncGraph(cnode); + if (new_func_graph == nullptr) { + return nullptr; + } + new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(cnode))); + auto graph_kernel_node = CreateExpandGraphKernel(new_func_graph, cnode); + if (AnfAlgo::GetOutputTensorNum(node) != AnfAlgo::GetOutputTensorNum(graph_kernel_node)) { + MS_LOG(ERROR) << "The output num of composite node (" << AnfAlgo::GetOutputTensorNum(graph_kernel_node) + << ") does not match the original basic node (" << AnfAlgo::GetOutputTensorNum(node) << ")." + << node->fullname_with_scope(); + return nullptr; + } + return graph_kernel_node; +} + bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { bool changed = false; auto todos = TopoSort(func_graph->get_return()); @@ -181,35 +177,31 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { } MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope(); - auto new_func_graph = CreateExpandFuncGraph(node); - if (new_func_graph == nullptr) { + auto new_node = GetExpander(node)->Run(node); + if (new_node == nullptr) { + MS_LOG(INFO) << "Skipped node: " << node->fullname_with_scope(); continue; } - mng->AddFuncGraph(new_func_graph); - - auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node); - new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node))); - - // replace origin node. 
- (void)mng->Replace(node, graph_kernel_node); + (void)mng->Replace(node, new_node); changed = true; } return changed; } -bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { - expand_ops_ = GetExpandOps(); - MS_EXCEPTION_IF_NULL(func_graph); - if (expand_ops_.count(prim::kPrimGkDropout) > 0) { - std::shared_ptr pass = std::make_shared(); - pass->Run(func_graph); +ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) { + std::vector> expanders = { + {prim::kPrimDropout, std::make_shared()}, + }; + for (auto &e : expanders) { + if (IsPrimitiveCNode(node, e.first)) { + return e.second; + } } + return std::make_shared(); +} - auto mng = func_graph->manager(); - if (mng == nullptr) { - mng = Manage(func_graph, true); - func_graph->set_manager(mng); - } +bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { + expand_ops_ = GetExpandOps(); return DoExpand(func_graph); } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h index ab153fafdb..e8055bea4e 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h @@ -16,13 +16,30 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ #include -#include +#include #include #include "backend/optimizer/common/pass.h" #include "ir/func_graph.h" namespace mindspore { namespace opt { +class Expander { + public: + virtual AnfNodePtr Run(const AnfNodePtr &node) = 0; +}; +using ExpanderPtr = std::shared_ptr; + +class DefaultExpander : public Expander { + public: + AnfNodePtr Run(const AnfNodePtr &node) override; + + protected: + bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json); + void EliminateRedundantParameters(const FuncGraphPtr &func_graph, 
AnfNodePtrList *inputs); + AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node); + FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node); +}; + class GraphKernelExpander : public Pass { public: GraphKernelExpander() : Pass("graph_kernel_expander") {} @@ -30,21 +47,16 @@ class GraphKernelExpander : public Pass { bool Run(const FuncGraphPtr &func_graph); private: - FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node); + ExpanderPtr GetExpander(const AnfNodePtr &node); bool DoExpand(const FuncGraphPtr &func_graph); - void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs); - AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph, - const CNodePtr &node); bool CanExpand(const CNodePtr &node) { return std::any_of(expand_ops_.begin(), expand_ops_.end(), [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); } - bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json); private: - std::unordered_set expand_ops_; + std::vector expand_ops_; }; -using GraphKernelExpanderPtr = std::shared_ptr; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index df19410f6c..a2827f063c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -32,6 +32,7 @@ #include "ir/func_graph.h" #include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/action.h" +#include "utils/context/graph_kernel_flags.h" #include "vm/segment_runner.h" #if ENABLE_GPU #include "runtime/device/gpu/kernel_info_setter.h" @@ -587,6 +588,8 @@ std::vector GetFusibleOpList() { #else std::vector fusible_basic_ops; 
#endif + const auto &flags = context::GraphKernelFlags::GetInstance(); + OpListFilter(&fusible_basic_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops); return fusible_basic_ops; } @@ -901,5 +904,24 @@ bool IsKeepBasicNode(const AnfNodePtr &node) { return false; } + +void OpListFilter(std::vector *ops, const std::vector &enable_ops_only, + const std::vector &enable_ops, const std::vector &disable_ops) { + auto new_prim = [](const std::string &name) { return std::make_shared(name); }; + if (!enable_ops_only.empty()) { + ops->clear(); + std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim); + } else { + if (!enable_ops.empty()) { + std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim); + } + if (!disable_ops.empty()) { + auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) { + return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end(); + }); + ops->erase(iter, ops->end()); + } + } +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index 0e4ead2a4d..9b7810e71e 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -33,9 +33,6 @@ #include namespace mindspore { -namespace prim { -inline const PrimitivePtr kPrimGkDropout = std::make_shared("GkDropout"); -} // namespace prim namespace opt { using kernel::DumpOption; @@ -91,7 +88,8 @@ std::vector GetReduceAxis(const AnfNodePtr &node); CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); bool IsKeepBasicNode(const AnfNodePtr &node); - +void 
OpListFilter(std::vector *ops, const std::vector &enable_ops_only, + const std::vector &enable_ops, const std::vector &disable_ops); template ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) { // Create tensor value. diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc index d5d8d31bfd..bfdd2c93e7 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc @@ -30,13 +30,11 @@ #include "runtime/device/kernel_info.h" namespace mindspore { +namespace prim { +inline const PrimitivePtr kPrimGkDropout = std::make_shared("GkDropout"); +} // namespace prim namespace opt { -unsigned int SubstituteDropout::seed_ = time(NULL); - -const BaseRef SubstituteDropout::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimDropout, Xs}); -} +unsigned int DropoutExpander::seed_ = time(NULL); void SetNewKernelInfo(const CNodePtr &kernel_node) { std::vector inputs_format; @@ -66,8 +64,7 @@ void SetNewKernelInfo(const CNodePtr &kernel_node) { AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get()); } -const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { +AnfNodePtr DropoutExpander::PreProcess(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); CNodePtr cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); @@ -116,5 +113,10 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons SetNewKernelInfo(new_node); return new_node; } + +AnfNodePtr DropoutExpander::Run(const AnfNodePtr &node) { + auto gkdropout_node = PreProcess(node->func_graph(), node); + return DefaultExpander::Run(gkdropout_node); +} } // namespace opt } // namespace mindspore diff --git 
a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h index 33804e77a4..240c291687 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h @@ -16,18 +16,16 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_ -#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/graph_kernel/graph_kernel_expander.h" namespace mindspore { namespace opt { -class SubstituteDropout : public PatternProcessPass { +class DropoutExpander : public DefaultExpander { public: - explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {} - ~SubstituteDropout() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + AnfNodePtr Run(const AnfNodePtr &node) override; private: + AnfNodePtr PreProcess(const FuncGraphPtr &, const AnfNodePtr &); static unsigned int seed_; }; } // namespace opt