Browse Source

!14198 [GraphKernel] Refactor GraphKernelExpander (4th submission)

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
pull/14198/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
203106c133
6 changed files with 107 additions and 83 deletions
  1. +49
    -57
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
  2. +20
    -8
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
  3. +22
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
  4. +2
    -4
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
  5. +10
    -8
      mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc
  6. +4
    -6
      mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h

+ 49
- 57
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc View File

@@ -17,7 +17,7 @@
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"

#include <string>
#include <unordered_set>
#include <set>
#include <utility>
#include <vector>

@@ -37,8 +37,8 @@
namespace mindspore {
namespace opt {
namespace {
std::unordered_set<PrimitivePtr> GetExpandOps() {
std::unordered_set<PrimitivePtr> expand_ops = {
std::vector<PrimitivePtr> GetExpandOps() {
std::vector<PrimitivePtr> expand_ops = {
prim::kPrimSquare,
prim::kPrimGeLUGrad,
#if ENABLE_D
@@ -55,7 +55,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimReduceMean,
prim::kPrimMaximumGrad,
prim::kPrimMinimumGrad,
prim::kPrimGkDropout,
prim::kPrimDropout,
prim::kPrimDropoutGrad,
prim::kPrimSoftmax,
prim::kPrimLayerNorm,
@@ -68,41 +68,20 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimAssignAdd,
#endif
};
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
auto &flags = context::GraphKernelFlags::GetInstance();
auto &enable_ops_only = flags.enable_expand_ops_only;
if (!enable_ops_only.empty()) {
expand_ops.clear();
std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()),
new_prim);
} else {
auto &enable_ops = flags.enable_expand_ops;
auto &disable_ops = flags.disable_expand_ops;
if (!enable_ops.empty()) {
std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim);
}
if (!disable_ops.empty()) {
for (auto iter = expand_ops.begin(); iter != expand_ops.end();) {
if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) {
expand_ops.erase(iter++);
} else {
++iter;
}
}
}
}
const auto &flags = context::GraphKernelFlags::GetInstance();
OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
return expand_ops;
}
} // namespace

bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
bool DefaultExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
DumpOption dump_option;
dump_option.extract_opinfo_from_anfnode = true;
kernel::AkgKernelJsonGenerator json_generator(dump_option);
return json_generator.CollectJson(node, kernel_json);
}

FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
nlohmann::json kernel_json;
if (!ExpandJsonInfo(node, &kernel_json)) {
MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
@@ -121,7 +100,6 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
}
std::string kernel_desc_str = py::cast<std::string>(ret);
if (kernel_desc_str.empty()) {
MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope();
return nullptr;
}
// decode json to func_graph.
@@ -129,10 +107,10 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
return JsonDescToAnf(kernel_desc_str, ori_inputs);
}

void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
const auto &ori_parameter = func_graph->parameters();
auto todos = TopoSort(func_graph->get_return());
std::unordered_set<AnfNodePtr> used_param;
std::set<AnfNodePtr> used_param;
for (auto node : todos) {
if (node->isa<Parameter>()) {
used_param.insert(node);
@@ -152,21 +130,39 @@ void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_
*inputs = std::move(new_inputs);
}

AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph,
const FuncGraphPtr &new_func_graph, const CNodePtr &node) {
std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
auto func_graph = old_node->func_graph();
std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
AnfNodePtrList kernel_nodes;
AnfNodePtrList outputs;
EliminateRedundantParameters(new_func_graph, &inputs);
kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_node));
MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
<< " with: " << graph_kernel_node->fullname_with_scope();
return graph_kernel_node;
}

AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_func_graph = CreateExpandFuncGraph(cnode);
if (new_func_graph == nullptr) {
return nullptr;
}
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(cnode)));
auto graph_kernel_node = CreateExpandGraphKernel(new_func_graph, cnode);
if (AnfAlgo::GetOutputTensorNum(node) != AnfAlgo::GetOutputTensorNum(graph_kernel_node)) {
MS_LOG(ERROR) << "The output num of composite node (" << AnfAlgo::GetOutputTensorNum(graph_kernel_node)
<< ") does not match the original basic node (" << AnfAlgo::GetOutputTensorNum(node) << ")."
<< node->fullname_with_scope();
return nullptr;
}
return graph_kernel_node;
}

bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
@@ -181,35 +177,31 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
}

MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope();
auto new_func_graph = CreateExpandFuncGraph(node);
if (new_func_graph == nullptr) {
auto new_node = GetExpander(node)->Run(node);
if (new_node == nullptr) {
MS_LOG(INFO) << "Skipped node: " << node->fullname_with_scope();
continue;
}
mng->AddFuncGraph(new_func_graph);

auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));

// replace origin node.
(void)mng->Replace(node, graph_kernel_node);
(void)mng->Replace(node, new_node);
changed = true;
}
return changed;
}

bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
expand_ops_ = GetExpandOps();
MS_EXCEPTION_IF_NULL(func_graph);
if (expand_ops_.count(prim::kPrimGkDropout) > 0) {
std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>();
pass->Run(func_graph);
ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
{prim::kPrimDropout, std::make_shared<DropoutExpander>()},
};
for (auto &e : expanders) {
if (IsPrimitiveCNode(node, e.first)) {
return e.second;
}
}
return std::make_shared<DefaultExpander>();
}

auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
expand_ops_ = GetExpandOps();
return DoExpand(func_graph);
}
} // namespace opt


+ 20
- 8
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h View File

@@ -16,13 +16,30 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#include <memory>
#include <unordered_set>
#include <vector>
#include <nlohmann/json.hpp>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"

namespace mindspore {
namespace opt {
class Expander {
public:
virtual AnfNodePtr Run(const AnfNodePtr &node) = 0;
};
using ExpanderPtr = std::shared_ptr<Expander>;

class DefaultExpander : public Expander {
public:
AnfNodePtr Run(const AnfNodePtr &node) override;

protected:
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
};

class GraphKernelExpander : public Pass {
public:
GraphKernelExpander() : Pass("graph_kernel_expander") {}
@@ -30,21 +47,16 @@ class GraphKernelExpander : public Pass {
bool Run(const FuncGraphPtr &func_graph);

private:
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
ExpanderPtr GetExpander(const AnfNodePtr &node);
bool DoExpand(const FuncGraphPtr &func_graph);
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph,
const CNodePtr &node);
bool CanExpand(const CNodePtr &node) {
return std::any_of(expand_ops_.begin(), expand_ops_.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);

private:
std::unordered_set<PrimitivePtr> expand_ops_;
std::vector<PrimitivePtr> expand_ops_;
};
using GraphKernelExpanderPtr = std::shared_ptr<GraphKernelExpander>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_

+ 22
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc View File

@@ -32,6 +32,7 @@
#include "ir/func_graph.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/action.h"
#include "utils/context/graph_kernel_flags.h"
#include "vm/segment_runner.h"
#if ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h"
@@ -587,6 +588,8 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
#else
std::vector<PrimitivePtr> fusible_basic_ops;
#endif
const auto &flags = context::GraphKernelFlags::GetInstance();
OpListFilter(&fusible_basic_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
return fusible_basic_ops;
}

@@ -901,5 +904,24 @@ bool IsKeepBasicNode(const AnfNodePtr &node) {

return false;
}

void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops) {
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
if (!enable_ops_only.empty()) {
ops->clear();
std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim);
} else {
if (!enable_ops.empty()) {
std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim);
}
if (!disable_ops.empty()) {
auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) {
return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
});
ops->erase(iter, ops->end());
}
}
}
} // namespace opt
} // namespace mindspore

+ 2
- 4
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h View File

@@ -33,9 +33,6 @@
#include <nlohmann/json.hpp>

namespace mindspore {
namespace prim {
inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
} // namespace prim
namespace opt {
using kernel::DumpOption;

@@ -91,7 +88,8 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
bool IsKeepBasicNode(const AnfNodePtr &node);

void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops);
template <typename T>
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
// Create tensor value.


+ 10
- 8
mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc View File

@@ -30,13 +30,11 @@
#include "runtime/device/kernel_info.h"

namespace mindspore {
namespace prim {
inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
} // namespace prim
namespace opt {
unsigned int SubstituteDropout::seed_ = time(NULL);

const BaseRef SubstituteDropout::DefinePattern() const {
VarPtr Xs = std::make_shared<Var>();
return VectorRef({prim::kPrimDropout, Xs});
}
unsigned int DropoutExpander::seed_ = time(NULL);

void SetNewKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format;
@@ -66,8 +64,7 @@ void SetNewKernelInfo(const CNodePtr &kernel_node) {
AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get());
}

const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
AnfNodePtr DropoutExpander::PreProcess(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
@@ -116,5 +113,10 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons
SetNewKernelInfo(new_node);
return new_node;
}

AnfNodePtr DropoutExpander::Run(const AnfNodePtr &node) {
auto gkdropout_node = PreProcess(node->func_graph(), node);
return DefaultExpander::Run(gkdropout_node);
}
} // namespace opt
} // namespace mindspore

+ 4
- 6
mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h View File

@@ -16,18 +16,16 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_

#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"

namespace mindspore {
namespace opt {
class SubstituteDropout : public PatternProcessPass {
class DropoutExpander : public DefaultExpander {
public:
explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {}
~SubstituteDropout() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
AnfNodePtr Run(const AnfNodePtr &node) override;

private:
AnfNodePtr PreProcess(const FuncGraphPtr &, const AnfNodePtr &);
static unsigned int seed_;
};
} // namespace opt


Loading…
Cancel
Save