|
|
|
@@ -773,5 +773,30 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN |
|
|
|
// Set attr secondly. |
|
|
|
AnfAlgo::SetNodeAttr(key, value, node); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsKeepBasicNode(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
|
|
static std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index", |
|
|
|
"aggregate", "aggregate_input_indexx"}; |
|
|
|
static std::vector<std::function<bool(const AnfNodePtr &node)>> attrs_with_value = { |
|
|
|
[](const AnfNodePtr &n) -> bool { return AnfAlgo::GetBooleanAttr(n, "skip"); }}; |
|
|
|
|
|
|
|
// If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is. |
|
|
|
// If node contain attribute in attrs_with_value, it only have to keep basic when the check result is true. |
|
|
|
if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(), |
|
|
|
[&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); }) || |
|
|
|
std::any_of(attrs_with_value.cbegin(), attrs_with_value.cend(), |
|
|
|
[&cnode](std::function<bool(const AnfNodePtr &node)> func) -> bool { return func(cnode); })) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |