Browse Source

[GraphKernel] clean code for graph_kernel_splitter* & add_stitch_atomic_clean_gpu*

pull/15910/head
r1chardf1d0 4 years ago
parent
commit
c4c69bf5e8
4 changed files with 52 additions and 47 deletions
  1. +7
    -7
      mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc
  2. +6
    -5
      mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h
  3. +38
    -34
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
  4. +1
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h

+ 7
- 7
mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc View File

@@ -75,8 +75,8 @@ void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &compos
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get()); AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
} }


CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter) {
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter) {
// add inplaceassign // add inplaceassign
AnfNodePtr out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true. AnfNodePtr out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true.
auto inplace_assign_node = auto inplace_assign_node =
@@ -88,8 +88,8 @@ CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(con
return inplace_assign_node; return inplace_assign_node;
} }


void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) {
void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node); auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager(); auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) { if (mng_sub == nullptr) {
@@ -107,7 +107,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g
parameter->set_abstract(new_input->abstract()); parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr()); parameter->set_kernel_info(new_input->kernel_info_ptr());


auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);
auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter);


// Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid // Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
// elimination. // elimination.
@@ -116,7 +116,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g
for (const auto &[user_node, index] : reduce_user_nodes) { for (const auto &[user_node, index] : reduce_user_nodes) {
auto user_cnode = user_node->cast<CNodePtr>(); auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode); MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, parameter);
user_cnode->set_input(static_cast<size_t>(index), parameter);
if (!connected) { if (!connected) {
std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode); std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
if (!user_user.empty()) { if (!user_user.empty()) {
@@ -135,7 +135,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g
} }


std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node, std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) {
const CNodePtr &target) const {
auto node = inner_node->cast<CNodePtr>(); auto node = inner_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);


+ 6
- 5
mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h View File

@@ -34,11 +34,12 @@ class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
bool Run(const FuncGraphPtr &func_graph) override; bool Run(const FuncGraphPtr &func_graph) override;


private: private:
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target);
void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng);
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) override;
CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) const;
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng);
bool IsStitchWithAtomic(const AnfNodePtr &anf_node); bool IsStitchWithAtomic(const AnfNodePtr &anf_node);


AnfNodePtr stitch_node_{nullptr}; AnfNodePtr stitch_node_{nullptr};


+ 38
- 34
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc View File

@@ -34,7 +34,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNodePtr &)> callback) {
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
std::unordered_set<AnfNodePtr> visited; std::unordered_set<AnfNodePtr> visited;
std::queue<AnfNodePtr> que; std::queue<AnfNodePtr> que;
que.push(cnode); que.push(cnode);
@@ -55,7 +55,7 @@ void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNod
} }


// Visited each AnfNode once, use callback to do the job on AnfNode // Visited each AnfNode once, use callback to do the job on AnfNode
inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function<void(AnfNodePtr &)> callback) {
inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function<void(AnfNodePtr &)> &callback) {
TraverseFuncGraphFromCNode(root->get_return(), callback); TraverseFuncGraphFromCNode(root->get_return(), callback);
} }


@@ -68,7 +68,7 @@ class Area {
if (cnode == nullptr) continue; if (cnode == nullptr) continue;
const auto &inputs = cnode->inputs(); const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) { if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
spy_cnodes_.push_back(node);
spy_cnodes_.emplace_back(node);
} }
} }
} }
@@ -126,15 +126,15 @@ class Area {
size_t i = 0; size_t i = 0;
for (auto &traitor : traitor_nodes_) { for (auto &traitor : traitor_nodes_) {
tuple_node_index->insert(std::make_pair(traitor, i++)); tuple_node_index->insert(std::make_pair(traitor, i++));
maketuple_inputs.push_back(traitor);
abstracts.push_back(traitor->abstract());
maketuple_inputs.emplace_back(traitor);
abstracts.emplace_back(traitor->abstract());
} }
auto maketuple_node = func_graph->NewCNode(maketuple_inputs); auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts)); maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
nodes_.insert(maketuple_node); nodes_.insert(maketuple_node);
return_inputs.push_back(maketuple_node);
return_inputs.emplace_back(maketuple_node);
} else { } else {
return_inputs.push_back(traitor_nodes_[0]);
return_inputs.emplace_back(traitor_nodes_[0]);
} }
auto return_node = func_graph->NewCNode(return_inputs); auto return_node = func_graph->NewCNode(return_inputs);
return_node->set_abstract(return_inputs.back()->abstract()); return_node->set_abstract(return_inputs.back()->abstract());
@@ -146,7 +146,7 @@ class Area {


void AddTraitor(const AnfNodePtr &node) { void AddTraitor(const AnfNodePtr &node) {
if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) { if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
traitor_nodes_.push_back(node);
traitor_nodes_.emplace_back(node);
} }
} }


@@ -185,7 +185,7 @@ class AreaGraph {
// The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs. // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
// The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting. // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes, void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
std::vector<size_t> *cnode_group_id, std::function<void(const Area &)> expand_callback) {
std::vector<size_t> *cnode_group_id, const std::function<void(const Area &)> &expand_callback) {
main_cnodes->clear(); main_cnodes->clear();
main_cnodes->resize(areas_.size(), nullptr); main_cnodes->resize(areas_.size(), nullptr);


@@ -228,7 +228,7 @@ class AreaGraph {
if (u == v) continue; if (u == v) continue;
areas_[u].AddTraitor(in_node); areas_[u].AddTraitor(in_node);
if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) { if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
edge_prev_[v].push_back(u);
edge_prev_[v].emplace_back(u);
} }
} }
} }
@@ -252,7 +252,7 @@ class AreaGraph {
while (!que.empty()) { while (!que.empty()) {
size_t u = que.front(); size_t u = que.front();
que.pop(); que.pop();
topo_order_.push_back(u);
topo_order_.emplace_back(u);
for (size_t i : edge_prev_[u]) { for (size_t i : edge_prev_[u]) {
if (--out_degree[i] == 0) que.push(i); if (--out_degree[i] == 0) que.push(i);
} }
@@ -280,9 +280,9 @@ class AreaGraph {
TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info())); TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info()));
auto getitem_node = main_func_graph->NewCNode(getitem_inputs); auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
getitem_node->set_abstract(main_cnodes[input_area]->abstract()); getitem_node->set_abstract(main_cnodes[input_area]->abstract());
main_cnode_inputs.push_back(getitem_node);
main_cnode_inputs.emplace_back(getitem_node);
} else { } else {
main_cnode_inputs.push_back(main_cnodes[input_area]);
main_cnode_inputs.emplace_back(main_cnodes[input_area]);
} }
} }
auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs); auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
@@ -293,7 +293,7 @@ class AreaGraph {
void SortCNodes(std::vector<CNodePtr> *main_cnodes) { void SortCNodes(std::vector<CNodePtr> *main_cnodes) {
std::vector<CNodePtr> main_cnodes_sorted; std::vector<CNodePtr> main_cnodes_sorted;
std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted), std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
[main_cnodes](int index) { return main_cnodes->at(index); });
[main_cnodes](size_t index) { return main_cnodes->at(index); });
*main_cnodes = std::move(main_cnodes_sorted); *main_cnodes = std::move(main_cnodes_sorted);
} }


@@ -309,17 +309,20 @@ class AreaGraph {
std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_; std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_;
}; };


class SplitSchemer {
public:
SplitSchemer() = default;
virtual ~SplitSchemer() = default;
virtual bool Split(const FuncGraphPtr &func_graph) = 0;
virtual bool NeedInline(size_t group_id) const { return false; }
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }

protected:
std::vector<AnfNodePtrList> split_plan_;
};

class Splitter { class Splitter {
public: public:
class SplitSchemer {
public:
virtual bool Split(const FuncGraphPtr &func_graph) = 0;
virtual bool NeedInline(size_t group_id) const { return false; }
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }

protected:
std::vector<AnfNodePtrList> split_plan_;
};
using SplitSchemerPtr = std::shared_ptr<SplitSchemer>; using SplitSchemerPtr = std::shared_ptr<SplitSchemer>;
using SplitterPtr = std::shared_ptr<Splitter>; using SplitterPtr = std::shared_ptr<Splitter>;


@@ -345,14 +348,14 @@ class Splitter {
return true; return true;
} }


static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) {
static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) {
MS_EXCEPTION_IF_NULL(main_cnode); MS_EXCEPTION_IF_NULL(main_cnode);
MS_EXCEPTION_IF_NULL(main_cnode->func_graph()); MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
MS_EXCEPTION_IF_NULL(split_schemer); MS_EXCEPTION_IF_NULL(split_schemer);
return std::make_shared<Splitter>(main_cnode, split_schemer); return std::make_shared<Splitter>(main_cnode, split_schemer);
} }


Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer)
Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer)
: main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {} : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
~Splitter() = default; ~Splitter() = default;


@@ -376,7 +379,7 @@ class Splitter {
void BindFuncGraph() { void BindFuncGraph() {
for (const auto &cnode : new_subgraph_cnodes_) { for (const auto &cnode : new_subgraph_cnodes_) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
auto callback = [&sub_func_graph, this](const AnfNodePtr &node) {
auto callback = [&sub_func_graph](const AnfNodePtr &node) {
if (!node->isa<ValueNode>()) { if (!node->isa<ValueNode>()) {
node->set_func_graph(sub_func_graph); node->set_func_graph(sub_func_graph);
} }
@@ -425,7 +428,7 @@ class Splitter {
} }
} }
if (AnfAlgo::IsRealKernel(node)) { if (AnfAlgo::IsRealKernel(node)) {
inlined_nodes_.push_back(node);
inlined_nodes_.emplace_back(node);
} }
} }
} }
@@ -454,7 +457,7 @@ class Splitter {
if (i + 1 == new_subgraph_cnodes_.size()) { if (i + 1 == new_subgraph_cnodes_.size()) {
replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back(); replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
} }
tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
tmp_subgraph_cnodes.emplace_back(new_subgraph_cnodes_[i]);
} }
} }
new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes); new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);
@@ -543,8 +546,9 @@ class Splitter {
std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_; std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;
}; };


class CostModelSplitSchemer : public Splitter::SplitSchemer {
class CostModelSplitSchemer : public SplitSchemer {
public: public:
virtual ~CostModelSplitSchemer() = default;
bool Split(const FuncGraphPtr &func_graph) override { bool Split(const FuncGraphPtr &func_graph) override {
if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node."; MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node.";
@@ -620,7 +624,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc; MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
return false; return false;
} }
split_plan_.push_back(std::move(res_graph));
split_plan_.emplace_back(std::move(res_graph));
} }


// ops to be inlined. // ops to be inlined.
@@ -687,14 +691,14 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {


if (IsValidKernelNode(output)) { if (IsValidKernelNode(output)) {
auto group_id = node_group_[ret_node] = node_group_[output]; auto group_id = node_group_[ret_node] = node_group_[output];
split_plan_[group_id].push_back(ret_node);
split_plan_[group_id].emplace_back(ret_node);
return; return;
} }
// assign the make_tuple node to a new group. // assign the make_tuple node to a new group.
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) { if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
auto group_id = split_plan_.size(); auto group_id = split_plan_.size();
split_plan_.push_back({output, ret_node});
need_inline_.push_back(1);
split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
need_inline_.emplace_back(1);
node_group_[ret_node] = node_group_[output] = group_id; node_group_[ret_node] = node_group_[output] = group_id;
return; return;
} }
@@ -711,7 +715,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
auto iter = node_group_.find(input); auto iter = node_group_.find(input);
if (iter != node_group_.end()) { if (iter != node_group_.end()) {
node_group_[node] = iter->second; node_group_[node] = iter->second;
split_plan_[iter->second].push_back(node);
split_plan_[iter->second].emplace_back(node);
found = true; found = true;
break; break;
} }


+ 1
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h View File

@@ -25,7 +25,7 @@ class GraphKernelSplitter : public Pass {
public: public:
GraphKernelSplitter() : Pass("graph_kernel_splitter") {} GraphKernelSplitter() : Pass("graph_kernel_splitter") {}
~GraphKernelSplitter() override = default; ~GraphKernelSplitter() override = default;
bool Run(const FuncGraphPtr &func_graph);
bool Run(const FuncGraphPtr &func_graph) override;
}; };
using GraphKernelSplitterPtr = std::shared_ptr<GraphKernelSplitter>; using GraphKernelSplitterPtr = std::shared_ptr<GraphKernelSplitter>;
} // namespace opt } // namespace opt


Loading…
Cancel
Save