Browse Source

!14198 [GraphKernel] Refactor GraphKernelExpander (4th submission)

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
pull/14198/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
203106c133
6 changed files with 107 additions and 83 deletions
  1. +49
    -57
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
  2. +20
    -8
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
  3. +22
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
  4. +2
    -4
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
  5. +10
    -8
      mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc
  6. +4
    -6
      mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h

+ 49
- 57
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc View File

@@ -17,7 +17,7 @@
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h" #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"


#include <string> #include <string>
#include <unordered_set>
#include <set>
#include <utility> #include <utility>
#include <vector> #include <vector>


@@ -37,8 +37,8 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
std::unordered_set<PrimitivePtr> GetExpandOps() {
std::unordered_set<PrimitivePtr> expand_ops = {
std::vector<PrimitivePtr> GetExpandOps() {
std::vector<PrimitivePtr> expand_ops = {
prim::kPrimSquare, prim::kPrimSquare,
prim::kPrimGeLUGrad, prim::kPrimGeLUGrad,
#if ENABLE_D #if ENABLE_D
@@ -55,7 +55,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimReduceMean, prim::kPrimReduceMean,
prim::kPrimMaximumGrad, prim::kPrimMaximumGrad,
prim::kPrimMinimumGrad, prim::kPrimMinimumGrad,
prim::kPrimGkDropout,
prim::kPrimDropout,
prim::kPrimDropoutGrad, prim::kPrimDropoutGrad,
prim::kPrimSoftmax, prim::kPrimSoftmax,
prim::kPrimLayerNorm, prim::kPrimLayerNorm,
@@ -68,41 +68,20 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimAssignAdd, prim::kPrimAssignAdd,
#endif #endif
}; };
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
auto &flags = context::GraphKernelFlags::GetInstance();
auto &enable_ops_only = flags.enable_expand_ops_only;
if (!enable_ops_only.empty()) {
expand_ops.clear();
std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()),
new_prim);
} else {
auto &enable_ops = flags.enable_expand_ops;
auto &disable_ops = flags.disable_expand_ops;
if (!enable_ops.empty()) {
std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim);
}
if (!disable_ops.empty()) {
for (auto iter = expand_ops.begin(); iter != expand_ops.end();) {
if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) {
expand_ops.erase(iter++);
} else {
++iter;
}
}
}
}
const auto &flags = context::GraphKernelFlags::GetInstance();
OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
return expand_ops; return expand_ops;
} }
} // namespace } // namespace


bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
bool DefaultExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
DumpOption dump_option; DumpOption dump_option;
dump_option.extract_opinfo_from_anfnode = true; dump_option.extract_opinfo_from_anfnode = true;
kernel::AkgKernelJsonGenerator json_generator(dump_option); kernel::AkgKernelJsonGenerator json_generator(dump_option);
return json_generator.CollectJson(node, kernel_json); return json_generator.CollectJson(node, kernel_json);
} }


FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
nlohmann::json kernel_json; nlohmann::json kernel_json;
if (!ExpandJsonInfo(node, &kernel_json)) { if (!ExpandJsonInfo(node, &kernel_json)) {
MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump(); MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
@@ -121,7 +100,6 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
} }
std::string kernel_desc_str = py::cast<std::string>(ret); std::string kernel_desc_str = py::cast<std::string>(ret);
if (kernel_desc_str.empty()) { if (kernel_desc_str.empty()) {
MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope();
return nullptr; return nullptr;
} }
// decode json to func_graph. // decode json to func_graph.
@@ -129,10 +107,10 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
return JsonDescToAnf(kernel_desc_str, ori_inputs); return JsonDescToAnf(kernel_desc_str, ori_inputs);
} }


void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
const auto &ori_parameter = func_graph->parameters(); const auto &ori_parameter = func_graph->parameters();
auto todos = TopoSort(func_graph->get_return()); auto todos = TopoSort(func_graph->get_return());
std::unordered_set<AnfNodePtr> used_param;
std::set<AnfNodePtr> used_param;
for (auto node : todos) { for (auto node : todos) {
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
used_param.insert(node); used_param.insert(node);
@@ -152,21 +130,39 @@ void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_
*inputs = std::move(new_inputs); *inputs = std::move(new_inputs);
} }


AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph,
const FuncGraphPtr &new_func_graph, const CNodePtr &node) {
std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
auto func_graph = old_node->func_graph();
std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
AnfNodePtrList kernel_nodes; AnfNodePtrList kernel_nodes;
AnfNodePtrList outputs; AnfNodePtrList outputs;
EliminateRedundantParameters(new_func_graph, &inputs); EliminateRedundantParameters(new_func_graph, &inputs);
kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes); kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_node));
MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
<< " with: " << graph_kernel_node->fullname_with_scope(); << " with: " << graph_kernel_node->fullname_with_scope();
return graph_kernel_node; return graph_kernel_node;
} }


AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_func_graph = CreateExpandFuncGraph(cnode);
if (new_func_graph == nullptr) {
return nullptr;
}
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(cnode)));
auto graph_kernel_node = CreateExpandGraphKernel(new_func_graph, cnode);
if (AnfAlgo::GetOutputTensorNum(node) != AnfAlgo::GetOutputTensorNum(graph_kernel_node)) {
MS_LOG(ERROR) << "The output num of composite node (" << AnfAlgo::GetOutputTensorNum(graph_kernel_node)
<< ") does not match the original basic node (" << AnfAlgo::GetOutputTensorNum(node) << ")."
<< node->fullname_with_scope();
return nullptr;
}
return graph_kernel_node;
}

bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
bool changed = false; bool changed = false;
auto todos = TopoSort(func_graph->get_return()); auto todos = TopoSort(func_graph->get_return());
@@ -181,35 +177,31 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
} }


MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope(); MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope();
auto new_func_graph = CreateExpandFuncGraph(node);
if (new_func_graph == nullptr) {
auto new_node = GetExpander(node)->Run(node);
if (new_node == nullptr) {
MS_LOG(INFO) << "Skipped node: " << node->fullname_with_scope();
continue; continue;
} }
mng->AddFuncGraph(new_func_graph);

auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));

// replace origin node.
(void)mng->Replace(node, graph_kernel_node);
(void)mng->Replace(node, new_node);
changed = true; changed = true;
} }
return changed; return changed;
} }


bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
expand_ops_ = GetExpandOps();
MS_EXCEPTION_IF_NULL(func_graph);
if (expand_ops_.count(prim::kPrimGkDropout) > 0) {
std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>();
pass->Run(func_graph);
ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
{prim::kPrimDropout, std::make_shared<DropoutExpander>()},
};
for (auto &e : expanders) {
if (IsPrimitiveCNode(node, e.first)) {
return e.second;
}
} }
return std::make_shared<DefaultExpander>();
}


auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
expand_ops_ = GetExpandOps();
return DoExpand(func_graph); return DoExpand(func_graph);
} }
} // namespace opt } // namespace opt


+ 20
- 8
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h View File

@@ -16,13 +16,30 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#include <memory> #include <memory>
#include <unordered_set>
#include <vector>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "backend/optimizer/common/pass.h" #include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class Expander {
public:
virtual AnfNodePtr Run(const AnfNodePtr &node) = 0;
};
using ExpanderPtr = std::shared_ptr<Expander>;

class DefaultExpander : public Expander {
public:
AnfNodePtr Run(const AnfNodePtr &node) override;

protected:
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
};

class GraphKernelExpander : public Pass { class GraphKernelExpander : public Pass {
public: public:
GraphKernelExpander() : Pass("graph_kernel_expander") {} GraphKernelExpander() : Pass("graph_kernel_expander") {}
@@ -30,21 +47,16 @@ class GraphKernelExpander : public Pass {
bool Run(const FuncGraphPtr &func_graph); bool Run(const FuncGraphPtr &func_graph);


private: private:
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
ExpanderPtr GetExpander(const AnfNodePtr &node);
bool DoExpand(const FuncGraphPtr &func_graph); bool DoExpand(const FuncGraphPtr &func_graph);
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph,
const CNodePtr &node);
bool CanExpand(const CNodePtr &node) { bool CanExpand(const CNodePtr &node) {
return std::any_of(expand_ops_.begin(), expand_ops_.end(), return std::any_of(expand_ops_.begin(), expand_ops_.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
} }
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);


private: private:
std::unordered_set<PrimitivePtr> expand_ops_;
std::vector<PrimitivePtr> expand_ops_;
}; };
using GraphKernelExpanderPtr = std::shared_ptr<GraphKernelExpander>;
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_

+ 22
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc View File

@@ -32,6 +32,7 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/action.h" #include "pipeline/jit/action.h"
#include "utils/context/graph_kernel_flags.h"
#include "vm/segment_runner.h" #include "vm/segment_runner.h"
#if ENABLE_GPU #if ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h" #include "runtime/device/gpu/kernel_info_setter.h"
@@ -587,6 +588,8 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
#else #else
std::vector<PrimitivePtr> fusible_basic_ops; std::vector<PrimitivePtr> fusible_basic_ops;
#endif #endif
const auto &flags = context::GraphKernelFlags::GetInstance();
OpListFilter(&fusible_basic_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
return fusible_basic_ops; return fusible_basic_ops;
} }


@@ -901,5 +904,24 @@ bool IsKeepBasicNode(const AnfNodePtr &node) {


return false; return false;
} }

void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops) {
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
if (!enable_ops_only.empty()) {
ops->clear();
std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim);
} else {
if (!enable_ops.empty()) {
std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim);
}
if (!disable_ops.empty()) {
auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) {
return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
});
ops->erase(iter, ops->end());
}
}
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

+ 2
- 4
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h View File

@@ -33,9 +33,6 @@
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>


namespace mindspore { namespace mindspore {
namespace prim {
inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
} // namespace prim
namespace opt { namespace opt {
using kernel::DumpOption; using kernel::DumpOption;


@@ -91,7 +88,8 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
bool IsKeepBasicNode(const AnfNodePtr &node); bool IsKeepBasicNode(const AnfNodePtr &node);

void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops);
template <typename T> template <typename T>
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) { ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
// Create tensor value. // Create tensor value.


+ 10
- 8
mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc View File

@@ -30,13 +30,11 @@
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"


namespace mindspore { namespace mindspore {
namespace prim {
inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
} // namespace prim
namespace opt { namespace opt {
unsigned int SubstituteDropout::seed_ = time(NULL);

const BaseRef SubstituteDropout::DefinePattern() const {
VarPtr Xs = std::make_shared<Var>();
return VectorRef({prim::kPrimDropout, Xs});
}
unsigned int DropoutExpander::seed_ = time(NULL);


void SetNewKernelInfo(const CNodePtr &kernel_node) { void SetNewKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
@@ -66,8 +64,7 @@ void SetNewKernelInfo(const CNodePtr &kernel_node) {
AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get()); AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get());
} }


const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
AnfNodePtr DropoutExpander::PreProcess(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
CNodePtr cnode = node->cast<CNodePtr>(); CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
@@ -116,5 +113,10 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons
SetNewKernelInfo(new_node); SetNewKernelInfo(new_node);
return new_node; return new_node;
} }

AnfNodePtr DropoutExpander::Run(const AnfNodePtr &node) {
auto gkdropout_node = PreProcess(node->func_graph(), node);
return DefaultExpander::Run(gkdropout_node);
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

+ 4
- 6
mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h View File

@@ -16,18 +16,16 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_


#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class SubstituteDropout : public PatternProcessPass {
class DropoutExpander : public DefaultExpander {
public: public:
explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {}
~SubstituteDropout() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
AnfNodePtr Run(const AnfNodePtr &node) override;


private: private:
AnfNodePtr PreProcess(const FuncGraphPtr &, const AnfNodePtr &);
static unsigned int seed_; static unsigned int seed_;
}; };
} // namespace opt } // namespace opt


Loading…
Cancel
Save