|
|
|
@@ -34,7 +34,7 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNodePtr &)> callback) { |
|
|
|
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) { |
|
|
|
std::unordered_set<AnfNodePtr> visited; |
|
|
|
std::queue<AnfNodePtr> que; |
|
|
|
que.push(cnode); |
|
|
|
@@ -55,7 +55,7 @@ void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNod |
|
|
|
} |
|
|
|
|
|
|
|
// Visited each AnfNode once, use callback to do the job on AnfNode |
|
|
|
inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function<void(AnfNodePtr &)> callback) { |
|
|
|
inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function<void(AnfNodePtr &)> &callback) { |
|
|
|
TraverseFuncGraphFromCNode(root->get_return(), callback); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -68,7 +68,7 @@ class Area { |
|
|
|
if (cnode == nullptr) continue; |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) { |
|
|
|
spy_cnodes_.push_back(node); |
|
|
|
spy_cnodes_.emplace_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -126,15 +126,15 @@ class Area { |
|
|
|
size_t i = 0; |
|
|
|
for (auto &traitor : traitor_nodes_) { |
|
|
|
tuple_node_index->insert(std::make_pair(traitor, i++)); |
|
|
|
maketuple_inputs.push_back(traitor); |
|
|
|
abstracts.push_back(traitor->abstract()); |
|
|
|
maketuple_inputs.emplace_back(traitor); |
|
|
|
abstracts.emplace_back(traitor->abstract()); |
|
|
|
} |
|
|
|
auto maketuple_node = func_graph->NewCNode(maketuple_inputs); |
|
|
|
maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts)); |
|
|
|
nodes_.insert(maketuple_node); |
|
|
|
return_inputs.push_back(maketuple_node); |
|
|
|
return_inputs.emplace_back(maketuple_node); |
|
|
|
} else { |
|
|
|
return_inputs.push_back(traitor_nodes_[0]); |
|
|
|
return_inputs.emplace_back(traitor_nodes_[0]); |
|
|
|
} |
|
|
|
auto return_node = func_graph->NewCNode(return_inputs); |
|
|
|
return_node->set_abstract(return_inputs.back()->abstract()); |
|
|
|
@@ -146,7 +146,7 @@ class Area { |
|
|
|
|
|
|
|
void AddTraitor(const AnfNodePtr &node) { |
|
|
|
if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) { |
|
|
|
traitor_nodes_.push_back(node); |
|
|
|
traitor_nodes_.emplace_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -185,7 +185,7 @@ class AreaGraph { |
|
|
|
// The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs. |
|
|
|
// The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting. |
|
|
|
void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes, |
|
|
|
std::vector<size_t> *cnode_group_id, std::function<void(const Area &)> expand_callback) { |
|
|
|
std::vector<size_t> *cnode_group_id, const std::function<void(const Area &)> &expand_callback) { |
|
|
|
main_cnodes->clear(); |
|
|
|
main_cnodes->resize(areas_.size(), nullptr); |
|
|
|
|
|
|
|
@@ -228,7 +228,7 @@ class AreaGraph { |
|
|
|
if (u == v) continue; |
|
|
|
areas_[u].AddTraitor(in_node); |
|
|
|
if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) { |
|
|
|
edge_prev_[v].push_back(u); |
|
|
|
edge_prev_[v].emplace_back(u); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -252,7 +252,7 @@ class AreaGraph { |
|
|
|
while (!que.empty()) { |
|
|
|
size_t u = que.front(); |
|
|
|
que.pop(); |
|
|
|
topo_order_.push_back(u); |
|
|
|
topo_order_.emplace_back(u); |
|
|
|
for (size_t i : edge_prev_[u]) { |
|
|
|
if (--out_degree[i] == 0) que.push(i); |
|
|
|
} |
|
|
|
@@ -280,9 +280,9 @@ class AreaGraph { |
|
|
|
TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info())); |
|
|
|
auto getitem_node = main_func_graph->NewCNode(getitem_inputs); |
|
|
|
getitem_node->set_abstract(main_cnodes[input_area]->abstract()); |
|
|
|
main_cnode_inputs.push_back(getitem_node); |
|
|
|
main_cnode_inputs.emplace_back(getitem_node); |
|
|
|
} else { |
|
|
|
main_cnode_inputs.push_back(main_cnodes[input_area]); |
|
|
|
main_cnode_inputs.emplace_back(main_cnodes[input_area]); |
|
|
|
} |
|
|
|
} |
|
|
|
auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs); |
|
|
|
@@ -293,7 +293,7 @@ class AreaGraph { |
|
|
|
void SortCNodes(std::vector<CNodePtr> *main_cnodes) { |
|
|
|
std::vector<CNodePtr> main_cnodes_sorted; |
|
|
|
std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted), |
|
|
|
[main_cnodes](int index) { return main_cnodes->at(index); }); |
|
|
|
[main_cnodes](size_t index) { return main_cnodes->at(index); }); |
|
|
|
*main_cnodes = std::move(main_cnodes_sorted); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -309,17 +309,20 @@ class AreaGraph { |
|
|
|
std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_; |
|
|
|
}; |
|
|
|
|
|
|
|
class SplitSchemer { |
|
|
|
public: |
|
|
|
SplitSchemer() = default; |
|
|
|
virtual ~SplitSchemer() = default; |
|
|
|
virtual bool Split(const FuncGraphPtr &func_graph) = 0; |
|
|
|
virtual bool NeedInline(size_t group_id) const { return false; } |
|
|
|
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; } |
|
|
|
|
|
|
|
protected: |
|
|
|
std::vector<AnfNodePtrList> split_plan_; |
|
|
|
}; |
|
|
|
|
|
|
|
class Splitter { |
|
|
|
public: |
|
|
|
class SplitSchemer { |
|
|
|
public: |
|
|
|
virtual bool Split(const FuncGraphPtr &func_graph) = 0; |
|
|
|
virtual bool NeedInline(size_t group_id) const { return false; } |
|
|
|
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; } |
|
|
|
|
|
|
|
protected: |
|
|
|
std::vector<AnfNodePtrList> split_plan_; |
|
|
|
}; |
|
|
|
using SplitSchemerPtr = std::shared_ptr<SplitSchemer>; |
|
|
|
using SplitterPtr = std::shared_ptr<Splitter>; |
|
|
|
|
|
|
|
@@ -345,14 +348,14 @@ class Splitter { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// Factory for Splitter. Checks `main_cnode`, its func graph and `split_schemer` with
// MS_EXCEPTION_IF_NULL before constructing the Splitter, so the constructor can
// dereference `main_cnode` safely.
// `split_schemer` is taken by const reference to avoid an extra shared_ptr copy at the call boundary.
static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) {
  MS_EXCEPTION_IF_NULL(main_cnode);
  MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
  MS_EXCEPTION_IF_NULL(split_schemer);
  return std::make_shared<Splitter>(main_cnode, split_schemer);
}
|
|
|
|
|
|
|
Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) |
|
|
|
Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) |
|
|
|
: main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {} |
|
|
|
~Splitter() = default; |
|
|
|
|
|
|
|
@@ -376,7 +379,7 @@ class Splitter { |
|
|
|
void BindFuncGraph() { |
|
|
|
for (const auto &cnode : new_subgraph_cnodes_) { |
|
|
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); |
|
|
|
auto callback = [&sub_func_graph, this](const AnfNodePtr &node) { |
|
|
|
auto callback = [&sub_func_graph](const AnfNodePtr &node) { |
|
|
|
if (!node->isa<ValueNode>()) { |
|
|
|
node->set_func_graph(sub_func_graph); |
|
|
|
} |
|
|
|
@@ -425,7 +428,7 @@ class Splitter { |
|
|
|
} |
|
|
|
} |
|
|
|
if (AnfAlgo::IsRealKernel(node)) { |
|
|
|
inlined_nodes_.push_back(node); |
|
|
|
inlined_nodes_.emplace_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -454,7 +457,7 @@ class Splitter { |
|
|
|
if (i + 1 == new_subgraph_cnodes_.size()) { |
|
|
|
replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back(); |
|
|
|
} |
|
|
|
tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]); |
|
|
|
tmp_subgraph_cnodes.emplace_back(new_subgraph_cnodes_[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes); |
|
|
|
@@ -543,8 +546,9 @@ class Splitter { |
|
|
|
std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_; |
|
|
|
}; |
|
|
|
|
|
|
|
class CostModelSplitSchemer : public Splitter::SplitSchemer { |
|
|
|
class CostModelSplitSchemer : public SplitSchemer { |
|
|
|
public: |
|
|
|
virtual ~CostModelSplitSchemer() = default; |
|
|
|
bool Split(const FuncGraphPtr &func_graph) override { |
|
|
|
if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { |
|
|
|
MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node."; |
|
|
|
@@ -620,7 +624,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { |
|
|
|
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc; |
|
|
|
return false; |
|
|
|
} |
|
|
|
split_plan_.push_back(std::move(res_graph)); |
|
|
|
split_plan_.emplace_back(std::move(res_graph)); |
|
|
|
} |
|
|
|
|
|
|
|
// ops to be inlined. |
|
|
|
@@ -687,14 +691,14 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { |
|
|
|
|
|
|
|
if (IsValidKernelNode(output)) { |
|
|
|
auto group_id = node_group_[ret_node] = node_group_[output]; |
|
|
|
split_plan_[group_id].push_back(ret_node); |
|
|
|
split_plan_[group_id].emplace_back(ret_node); |
|
|
|
return; |
|
|
|
} |
|
|
|
// assign the make_tuple node to a new group. |
|
|
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) { |
|
|
|
auto group_id = split_plan_.size(); |
|
|
|
split_plan_.push_back({output, ret_node}); |
|
|
|
need_inline_.push_back(1); |
|
|
|
split_plan_.emplace_back(AnfNodePtrList{output, ret_node}); |
|
|
|
need_inline_.emplace_back(1); |
|
|
|
node_group_[ret_node] = node_group_[output] = group_id; |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -711,7 +715,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { |
|
|
|
auto iter = node_group_.find(input); |
|
|
|
if (iter != node_group_.end()) { |
|
|
|
node_group_[node] = iter->second; |
|
|
|
split_plan_[iter->second].push_back(node); |
|
|
|
split_plan_[iter->second].emplace_back(node); |
|
|
|
found = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
|