Browse Source

[GraphKernel] clean code for graph_kernel_splitter* & add_stitch_atomic_clean_gpu*

pull/15910/head
r1chardf1d0 4 years ago
parent
commit
c4c69bf5e8
4 changed files with 52 additions and 47 deletions
  1. +7
    -7
      mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc
  2. +6
    -5
      mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h
  3. +38
    -34
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
  4. +1
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h

+ 7
- 7
mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc View File

@@ -75,8 +75,8 @@ void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &compos
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}

CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter) {
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter) {
// add inplaceassign
AnfNodePtr out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true.
auto inplace_assign_node =
@@ -88,8 +88,8 @@ CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(con
return inplace_assign_node;
}

void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) {
void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
@@ -107,7 +107,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());

auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);
auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter);

// Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
// elimination.
@@ -116,7 +116,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g
for (const auto &[user_node, index] : reduce_user_nodes) {
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, parameter);
user_cnode->set_input(static_cast<size_t>(index), parameter);
if (!connected) {
std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
if (!user_user.empty()) {
@@ -135,7 +135,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g
}

std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) {
const CNodePtr &target) const {
auto node = inner_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);


+ 6
- 5
mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h View File

@@ -34,11 +34,12 @@ class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
bool Run(const FuncGraphPtr &func_graph) override;

private:
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target);
void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng);
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) override;
CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) const;
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng);
bool IsStitchWithAtomic(const AnfNodePtr &anf_node);

AnfNodePtr stitch_node_{nullptr};


+ 38
- 34
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc View File

@@ -34,7 +34,7 @@
namespace mindspore {
namespace opt {
namespace {
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNodePtr &)> callback) {
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
std::unordered_set<AnfNodePtr> visited;
std::queue<AnfNodePtr> que;
que.push(cnode);
@@ -55,7 +55,7 @@ void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNod
}

// Visited each AnfNode once, use callback to do the job on AnfNode
inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function<void(AnfNodePtr &)> callback) {
inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function<void(AnfNodePtr &)> &callback) {
TraverseFuncGraphFromCNode(root->get_return(), callback);
}

@@ -68,7 +68,7 @@ class Area {
if (cnode == nullptr) continue;
const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
spy_cnodes_.push_back(node);
spy_cnodes_.emplace_back(node);
}
}
}
@@ -126,15 +126,15 @@ class Area {
size_t i = 0;
for (auto &traitor : traitor_nodes_) {
tuple_node_index->insert(std::make_pair(traitor, i++));
maketuple_inputs.push_back(traitor);
abstracts.push_back(traitor->abstract());
maketuple_inputs.emplace_back(traitor);
abstracts.emplace_back(traitor->abstract());
}
auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
nodes_.insert(maketuple_node);
return_inputs.push_back(maketuple_node);
return_inputs.emplace_back(maketuple_node);
} else {
return_inputs.push_back(traitor_nodes_[0]);
return_inputs.emplace_back(traitor_nodes_[0]);
}
auto return_node = func_graph->NewCNode(return_inputs);
return_node->set_abstract(return_inputs.back()->abstract());
@@ -146,7 +146,7 @@ class Area {

void AddTraitor(const AnfNodePtr &node) {
if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
traitor_nodes_.push_back(node);
traitor_nodes_.emplace_back(node);
}
}

@@ -185,7 +185,7 @@ class AreaGraph {
// The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
// The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
std::vector<size_t> *cnode_group_id, std::function<void(const Area &)> expand_callback) {
std::vector<size_t> *cnode_group_id, const std::function<void(const Area &)> &expand_callback) {
main_cnodes->clear();
main_cnodes->resize(areas_.size(), nullptr);

@@ -228,7 +228,7 @@ class AreaGraph {
if (u == v) continue;
areas_[u].AddTraitor(in_node);
if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
edge_prev_[v].push_back(u);
edge_prev_[v].emplace_back(u);
}
}
}
@@ -252,7 +252,7 @@ class AreaGraph {
while (!que.empty()) {
size_t u = que.front();
que.pop();
topo_order_.push_back(u);
topo_order_.emplace_back(u);
for (size_t i : edge_prev_[u]) {
if (--out_degree[i] == 0) que.push(i);
}
@@ -280,9 +280,9 @@ class AreaGraph {
TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info()));
auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
getitem_node->set_abstract(main_cnodes[input_area]->abstract());
main_cnode_inputs.push_back(getitem_node);
main_cnode_inputs.emplace_back(getitem_node);
} else {
main_cnode_inputs.push_back(main_cnodes[input_area]);
main_cnode_inputs.emplace_back(main_cnodes[input_area]);
}
}
auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
@@ -293,7 +293,7 @@ class AreaGraph {
void SortCNodes(std::vector<CNodePtr> *main_cnodes) {
std::vector<CNodePtr> main_cnodes_sorted;
std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
[main_cnodes](int index) { return main_cnodes->at(index); });
[main_cnodes](size_t index) { return main_cnodes->at(index); });
*main_cnodes = std::move(main_cnodes_sorted);
}

@@ -309,17 +309,20 @@ class AreaGraph {
std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_;
};

class SplitSchemer {
public:
SplitSchemer() = default;
virtual ~SplitSchemer() = default;
virtual bool Split(const FuncGraphPtr &func_graph) = 0;
virtual bool NeedInline(size_t group_id) const { return false; }
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }

protected:
std::vector<AnfNodePtrList> split_plan_;
};

class Splitter {
public:
class SplitSchemer {
public:
virtual bool Split(const FuncGraphPtr &func_graph) = 0;
virtual bool NeedInline(size_t group_id) const { return false; }
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }

protected:
std::vector<AnfNodePtrList> split_plan_;
};
using SplitSchemerPtr = std::shared_ptr<SplitSchemer>;
using SplitterPtr = std::shared_ptr<Splitter>;

@@ -345,14 +348,14 @@ class Splitter {
return true;
}

static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) {
static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) {
MS_EXCEPTION_IF_NULL(main_cnode);
MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
MS_EXCEPTION_IF_NULL(split_schemer);
return std::make_shared<Splitter>(main_cnode, split_schemer);
}

Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer)
Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer)
: main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
~Splitter() = default;

@@ -376,7 +379,7 @@ class Splitter {
void BindFuncGraph() {
for (const auto &cnode : new_subgraph_cnodes_) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
auto callback = [&sub_func_graph, this](const AnfNodePtr &node) {
auto callback = [&sub_func_graph](const AnfNodePtr &node) {
if (!node->isa<ValueNode>()) {
node->set_func_graph(sub_func_graph);
}
@@ -425,7 +428,7 @@ class Splitter {
}
}
if (AnfAlgo::IsRealKernel(node)) {
inlined_nodes_.push_back(node);
inlined_nodes_.emplace_back(node);
}
}
}
@@ -454,7 +457,7 @@ class Splitter {
if (i + 1 == new_subgraph_cnodes_.size()) {
replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
}
tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
tmp_subgraph_cnodes.emplace_back(new_subgraph_cnodes_[i]);
}
}
new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);
@@ -543,8 +546,9 @@ class Splitter {
std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;
};

class CostModelSplitSchemer : public Splitter::SplitSchemer {
class CostModelSplitSchemer : public SplitSchemer {
public:
virtual ~CostModelSplitSchemer() = default;
bool Split(const FuncGraphPtr &func_graph) override {
if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node.";
@@ -620,7 +624,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
return false;
}
split_plan_.push_back(std::move(res_graph));
split_plan_.emplace_back(std::move(res_graph));
}

// ops to be inlined.
@@ -687,14 +691,14 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {

if (IsValidKernelNode(output)) {
auto group_id = node_group_[ret_node] = node_group_[output];
split_plan_[group_id].push_back(ret_node);
split_plan_[group_id].emplace_back(ret_node);
return;
}
// assign the make_tuple node to a new group.
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
auto group_id = split_plan_.size();
split_plan_.push_back({output, ret_node});
need_inline_.push_back(1);
split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
need_inline_.emplace_back(1);
node_group_[ret_node] = node_group_[output] = group_id;
return;
}
@@ -711,7 +715,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
auto iter = node_group_.find(input);
if (iter != node_group_.end()) {
node_group_[node] = iter->second;
split_plan_[iter->second].push_back(node);
split_plan_[iter->second].emplace_back(node);
found = true;
break;
}


+ 1
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h View File

@@ -25,7 +25,7 @@ class GraphKernelSplitter : public Pass {
public:
GraphKernelSplitter() : Pass("graph_kernel_splitter") {}
~GraphKernelSplitter() override = default;
bool Run(const FuncGraphPtr &func_graph);
bool Run(const FuncGraphPtr &func_graph) override;
};
using GraphKernelSplitterPtr = std::shared_ptr<GraphKernelSplitter>;
} // namespace opt


Loading…
Cancel
Save