Browse Source

exclude special node when expand or basic fusion

tags/v1.2.0-rc1
tronzhang 4 years ago
parent
commit
be2b9978be
4 changed files with 29 additions and 2 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc
  2. +2
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
  3. +25
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
  4. +1
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h

+ 1
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc View File

@@ -140,7 +140,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr


for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>(); auto node = (*iter)->cast<CNodePtr>();
if (node == nullptr || fused_ops->count(node)) {
if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) {
continue; continue;
} }
bool is_fusible_op = IsFusibleOp(node); bool is_fusible_op = IsFusibleOp(node);


+ 2
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc View File

@@ -147,7 +147,8 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
for (const auto &n : todos) { for (const auto &n : todos) {
auto node = n->cast<CNodePtr>(); auto node = n->cast<CNodePtr>();
if (node == nullptr || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || !CanExpand(node)) {
if (node == nullptr || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) ||
!CanExpand(node)) {
continue; continue;
} }




+ 25
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc View File

@@ -773,5 +773,30 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
// Set attr secondly. // Set attr secondly.
AnfAlgo::SetNodeAttr(key, value, node); AnfAlgo::SetNodeAttr(key, value, node);
} }

bool IsKeepBasicNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);

static std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
"aggregate", "aggregate_input_indexx"};
static std::vector<std::function<bool(const AnfNodePtr &node)>> attrs_with_value = {
[](const AnfNodePtr &n) -> bool { return AnfAlgo::GetBooleanAttr(n, "skip"); }};

// If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
// If node contain attribute in attrs_with_value, it only have to keep basic when the check result is true.
if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
[&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); }) ||
std::any_of(attrs_with_value.cbegin(), attrs_with_value.cend(),
[&cnode](std::function<bool(const AnfNodePtr &node)> func) -> bool { return func(cnode); })) {
return true;
}

return false;
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

+ 1
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h View File

@@ -89,6 +89,7 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);


CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
bool IsKeepBasicNode(const AnfNodePtr &node);


template <typename T> template <typename T>
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) { ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {


Loading…
Cancel
Save