diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index 2b3d08bacb..d438333967 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -17,7 +17,7 @@
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 
 #include <string>
-#include <unordered_set>
+#include <set>
 #include <utility>
 #include <vector>
@@ -37,8 +37,8 @@
 namespace mindspore {
 namespace opt {
 namespace {
-std::unordered_set<PrimitivePtr> GetExpandOps() {
-  std::unordered_set<PrimitivePtr> expand_ops = {
+std::vector<PrimitivePtr> GetExpandOps() {
+  std::vector<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
     prim::kPrimGeLUGrad,
 #if ENABLE_D
@@ -55,7 +55,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimReduceMean,
     prim::kPrimMaximumGrad,
     prim::kPrimMinimumGrad,
-    prim::kPrimGkDropout,
+    prim::kPrimDropout,
     prim::kPrimDropoutGrad,
     prim::kPrimSoftmax,
     prim::kPrimLayerNorm,
@@ -68,41 +68,20 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimAssignAdd,
 #endif
   };
-  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
-  auto &flags = context::GraphKernelFlags::GetInstance();
-  auto &enable_ops_only = flags.enable_expand_ops_only;
-  if (!enable_ops_only.empty()) {
-    expand_ops.clear();
-    std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()),
-                   new_prim);
-  } else {
-    auto &enable_ops = flags.enable_expand_ops;
-    auto &disable_ops = flags.disable_expand_ops;
-    if (!enable_ops.empty()) {
-      std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim);
-    }
-    if (!disable_ops.empty()) {
-      for (auto iter = expand_ops.begin(); iter != expand_ops.end();) {
-        if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) {
-          expand_ops.erase(iter++);
-        } else {
-          ++iter;
-        }
-      }
-    }
-  }
+  const auto &flags = context::GraphKernelFlags::GetInstance();
+  OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
   return expand_ops;
 }
 }  // namespace
 
-bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
+bool DefaultExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
   DumpOption dump_option;
   dump_option.extract_opinfo_from_anfnode = true;
   kernel::AkgKernelJsonGenerator json_generator(dump_option);
   return json_generator.CollectJson(node, kernel_json);
 }
 
-FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
+FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   nlohmann::json kernel_json;
   if (!ExpandJsonInfo(node, &kernel_json)) {
     MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
@@ -121,7 +100,6 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   }
   std::string kernel_desc_str = py::cast<std::string>(ret);
   if (kernel_desc_str.empty()) {
-    MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope();
     return nullptr;
   }
   // decode json to func_graph.
@@ -129,10 +107,10 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   return JsonDescToAnf(kernel_desc_str, ori_inputs);
 }
 
-void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
+void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
   const auto &ori_parameter = func_graph->parameters();
   auto todos = TopoSort(func_graph->get_return());
-  std::unordered_set<AnfNodePtr> used_param;
+  std::set<AnfNodePtr> used_param;
   for (auto node : todos) {
     if (node->isa<Parameter>()) {
       used_param.insert(node);
@@ -152,21 +130,39 @@ void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_
   *inputs = std::move(new_inputs);
 }
 
-AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph,
-                                                        const FuncGraphPtr &new_func_graph, const CNodePtr &node) {
-  std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
+AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
+  auto func_graph = old_node->func_graph();
+  std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
   AnfNodePtrList kernel_nodes;
   AnfNodePtrList outputs;
   EliminateRedundantParameters(new_func_graph, &inputs);
   kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
   kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
   auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
-  SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
-  MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
+  SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_node));
+  MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
                 << " with: " << graph_kernel_node->fullname_with_scope();
   return graph_kernel_node;
 }
 
+AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto new_func_graph = CreateExpandFuncGraph(cnode);
+  if (new_func_graph == nullptr) {
+    return nullptr;
+  }
+  new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(cnode)));
+  auto graph_kernel_node = CreateExpandGraphKernel(new_func_graph, cnode);
+  if (AnfAlgo::GetOutputTensorNum(node) != AnfAlgo::GetOutputTensorNum(graph_kernel_node)) {
+    MS_LOG(ERROR) << "The output num of composite node (" << AnfAlgo::GetOutputTensorNum(graph_kernel_node)
+                  << ") does not match the original basic node (" << AnfAlgo::GetOutputTensorNum(node) << ")."
+                  << node->fullname_with_scope();
+    return nullptr;
+  }
+  return graph_kernel_node;
+}
+
 bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
   bool changed = false;
   auto todos = TopoSort(func_graph->get_return());
@@ -181,35 +177,31 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
     }
 
     MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope();
-    auto new_func_graph = CreateExpandFuncGraph(node);
-    if (new_func_graph == nullptr) {
+    auto new_node = GetExpander(node)->Run(node);
+    if (new_node == nullptr) {
+      MS_LOG(INFO) << "Skipped node: " << node->fullname_with_scope();
       continue;
     }
-    mng->AddFuncGraph(new_func_graph);
-
-    auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
-    new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));
-
-    // replace origin node.
-    (void)mng->Replace(node, graph_kernel_node);
+    (void)mng->Replace(node, new_node);
     changed = true;
   }
   return changed;
 }
 
-bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
-  expand_ops_ = GetExpandOps();
-  MS_EXCEPTION_IF_NULL(func_graph);
-  if (expand_ops_.count(prim::kPrimGkDropout) > 0) {
-    std::shared_ptr<Pass> pass = std::make_shared<SubstituteDropout>();
-    pass->Run(func_graph);
+ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
+  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
+    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
+  };
+  for (auto &e : expanders) {
+    if (IsPrimitiveCNode(node, e.first)) {
+      return e.second;
+    }
   }
+  return std::make_shared<DefaultExpander>();
+}
 
-  auto mng = func_graph->manager();
-  if (mng == nullptr) {
-    mng = Manage(func_graph, true);
-    func_graph->set_manager(mng);
-  }
+bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
+  expand_ops_ = GetExpandOps();
   return DoExpand(func_graph);
 }
 }  // namespace opt
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
index ab153fafdb..e8055bea4e 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
@@ -16,13 +16,30 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
 #include <memory>
-#include <unordered_set>
+#include <vector>
 #include <nlohmann/json.hpp>
 #include "backend/optimizer/common/pass.h"
 #include "ir/func_graph.h"
 
 namespace mindspore {
 namespace opt {
+class Expander {
+ public:
+  virtual AnfNodePtr Run(const AnfNodePtr &node) = 0;
+};
+using ExpanderPtr = std::shared_ptr<Expander>;
+
+class DefaultExpander : public Expander {
+ public:
+  AnfNodePtr Run(const AnfNodePtr &node) override;
+
+ protected:
+  bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
+  void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
+  AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
+  FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
+};
+
 class GraphKernelExpander : public Pass {
  public:
   GraphKernelExpander() : Pass("graph_kernel_expander") {}
@@ -30,21 +47,16 @@ class GraphKernelExpander : public Pass {
   bool Run(const FuncGraphPtr &func_graph);
 
  private:
-  FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
+  ExpanderPtr GetExpander(const AnfNodePtr &node);
   bool DoExpand(const FuncGraphPtr &func_graph);
-  void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
-  AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph,
-                                     const CNodePtr &node);
   bool CanExpand(const CNodePtr &node) {
     return std::any_of(expand_ops_.begin(), expand_ops_.end(),
                        [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
   }
-  bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
 
  private:
-  std::unordered_set<PrimitivePtr> expand_ops_;
+  std::vector<PrimitivePtr> expand_ops_;
 };
-using GraphKernelExpanderPtr = std::shared_ptr<GraphKernelExpander>;
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 3de37c8798..d018f09f17 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -32,6 +32,7 @@
 #include "ir/func_graph.h"
 #include "pipeline/jit/parse/python_adapter.h"
 #include "pipeline/jit/action.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "vm/segment_runner.h"
 #if ENABLE_GPU
 #include "runtime/device/gpu/kernel_info_setter.h"
@@ -587,6 +588,8 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
 #else
   std::vector<PrimitivePtr> fusible_basic_ops;
 #endif
+  const auto &flags = context::GraphKernelFlags::GetInstance();
+  OpListFilter(&fusible_basic_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
   return fusible_basic_ops;
 }
@@ -901,5 +904,24 @@ bool IsKeepBasicNode(const AnfNodePtr &node) {
   return false;
 }
+
+void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
+                  const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops) {
+  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
+  if (!enable_ops_only.empty()) {
+    ops->clear();
+    std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim);
+  } else {
+    if (!enable_ops.empty()) {
+      std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim);
+    }
+    if (!disable_ops.empty()) {
+      auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) {
+        return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
+      });
+      ops->erase(iter, ops->end());
+    }
+  }
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
index 0e4ead2a4d..9b7810e71e 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
@@ -33,9 +33,6 @@
 #include <nlohmann/json.hpp>
 
 namespace mindspore {
-namespace prim {
-inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
-}  // namespace prim
 namespace opt {
 using kernel::DumpOption;
@@ -91,7 +88,8 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
 CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
 void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
 bool IsKeepBasicNode(const AnfNodePtr &node);
-
+void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
+                  const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops);
 template <typename T>
 ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
   // Create tensor value.
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc
index d5d8d31bfd..bfdd2c93e7 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc
@@ -30,13 +30,11 @@
 #include "runtime/device/kernel_info.h"
 
 namespace mindspore {
+namespace prim {
+inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
+}  // namespace prim
 namespace opt {
-unsigned int SubstituteDropout::seed_ = time(NULL);
-
-const BaseRef SubstituteDropout::DefinePattern() const {
-  VarPtr Xs = std::make_shared<Var>();
-  return VectorRef({prim::kPrimDropout, Xs});
-}
+unsigned int DropoutExpander::seed_ = time(NULL);
 
 void SetNewKernelInfo(const CNodePtr &kernel_node) {
   std::vector<std::string> inputs_format;
@@ -66,8 +64,7 @@ void SetNewKernelInfo(const CNodePtr &kernel_node) {
   AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get());
 }
 
-const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                            const EquivPtr &) const {
+AnfNodePtr DropoutExpander::PreProcess(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   CNodePtr cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
@@ -116,5 +113,10 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons
   SetNewKernelInfo(new_node);
   return new_node;
 }
+
+AnfNodePtr DropoutExpander::Run(const AnfNodePtr &node) {
+  auto gkdropout_node = PreProcess(node->func_graph(), node);
+  return DefaultExpander::Run(gkdropout_node);
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h
index 33804e77a4..240c291687 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h
@@ -16,18 +16,16 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
 
-#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 
 namespace mindspore {
 namespace opt {
-class SubstituteDropout : public PatternProcessPass {
+class DropoutExpander : public DefaultExpander {
  public:
-  explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {}
-  ~SubstituteDropout() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  AnfNodePtr Run(const AnfNodePtr &node) override;
 
  private:
+  AnfNodePtr PreProcess(const FuncGraphPtr &, const AnfNodePtr &);
   static unsigned int seed_;
 };
 }  // namespace opt
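
Reviewer note: the precedence implemented by the new OpListFilter helper is enable_*_ops_only first (when non-empty it replaces the default list wholesale), otherwise enable_*_ops entries are appended and disable_*_ops entries removed. A minimal standalone sketch of that behavior, using std::string in place of PrimitivePtr and a hypothetical FilterSketch name (illustrative only, not part of the patch):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    void FilterSketch(std::vector<std::string> *ops, const std::vector<std::string> &enable_only,
                      const std::vector<std::string> &enable, const std::vector<std::string> &disable) {
      if (!enable_only.empty()) {
        *ops = enable_only;  // highest precedence: the default list is replaced wholesale
        return;
      }
      // Append explicitly enabled ops, then drop explicitly disabled ones.
      ops->insert(ops->end(), enable.begin(), enable.end());
      ops->erase(std::remove_if(ops->begin(), ops->end(),
                                [&disable](const std::string &op) {
                                  return std::find(disable.begin(), disable.end(), op) != disable.end();
                                }),
                 ops->end());
    }

    int main() {
      std::vector<std::string> ops = {"Square", "Softmax", "LayerNorm"};
      FilterSketch(&ops, {}, {"Tanh"}, {"Softmax"});
      for (const auto &op : ops) std::cout << op << ' ';  // prints: Square LayerNorm Tanh
      std::cout << '\n';
      return 0;
    }

Note that std::remove_if/erase also replaces the old erase-in-loop over the unordered_set, which is why the containers switch to std::vector here.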
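Reviewer note: the Expander hierarchy introduced here is a small strategy pattern. GraphKernelExpander::GetExpander picks a per-op expander, and DropoutExpander::Run first rewrites Dropout into GkDropout (the old SubstituteDropout pass, now a PreProcess step) before delegating to DefaultExpander::Run. A minimal sketch with simplified types (std::string standing in for AnfNodePtr; stub bodies, illustrative only):

    #include <iostream>
    #include <memory>
    #include <string>

    struct Expander {
      virtual ~Expander() = default;
      virtual std::string Run(const std::string &node) = 0;
    };
    using ExpanderPtr = std::shared_ptr<Expander>;

    struct DefaultExpander : Expander {
      // Stands in for CreateExpandFuncGraph + CreateExpandGraphKernel.
      std::string Run(const std::string &node) override { return "GraphKernel(" + node + ")"; }
    };

    struct DropoutExpander : DefaultExpander {
      // Mirrors DropoutExpander::Run: substitute GkDropout first, then expand as usual.
      std::string Run(const std::string &node) override { return DefaultExpander::Run("GkDropout"); }
    };

    // Mirrors GraphKernelExpander::GetExpander: per-op strategy lookup with a default.
    ExpanderPtr GetExpander(const std::string &op) {
      if (op == "Dropout") return std::make_shared<DropoutExpander>();
      return std::make_shared<DefaultExpander>();
    }

    int main() {
      std::cout << GetExpander("Dropout")->Run("Dropout") << "\n";  // GraphKernel(GkDropout)
      std::cout << GetExpander("Square")->Run("Square") << "\n";    // GraphKernel(Square)
      return 0;
    }

This keeps kPrimGkDropout private to substitute_dropout.cc and lets a nullptr return from any Run step simply skip the node in DoExpand.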