| @@ -140,7 +140,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | ||||
| auto node = (*iter)->cast<CNodePtr>(); | auto node = (*iter)->cast<CNodePtr>(); | ||||
| if (node == nullptr || fused_ops->count(node)) { | |||||
| if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| bool is_fusible_op = IsFusibleOp(node); | bool is_fusible_op = IsFusibleOp(node); | ||||
| @@ -147,7 +147,8 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| for (const auto &n : todos) { | for (const auto &n : todos) { | ||||
| auto node = n->cast<CNodePtr>(); | auto node = n->cast<CNodePtr>(); | ||||
| if (node == nullptr || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || !CanExpand(node)) { | |||||
| if (node == nullptr || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || | |||||
| !CanExpand(node)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -773,5 +773,30 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN | |||||
| // Set attr secondly. | // Set attr secondly. | ||||
| AnfAlgo::SetNodeAttr(key, value, node); | AnfAlgo::SetNodeAttr(key, value, node); | ||||
| } | } | ||||
| bool IsKeepBasicNode(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| static std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index", | |||||
| "aggregate", "aggregate_input_indexx"}; | |||||
| static std::vector<std::function<bool(const AnfNodePtr &node)>> attrs_with_value = { | |||||
| [](const AnfNodePtr &n) -> bool { return AnfAlgo::GetBooleanAttr(n, "skip"); }}; | |||||
| // If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is. | |||||
| // If node contain attribute in attrs_with_value, it only have to keep basic when the check result is true. | |||||
| if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(), | |||||
| [&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); }) || | |||||
| std::any_of(attrs_with_value.cbegin(), attrs_with_value.cend(), | |||||
| [&cnode](std::function<bool(const AnfNodePtr &node)> func) -> bool { return func(cnode); })) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -89,6 +89,7 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node); | |||||
| CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); | CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); | ||||
| void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); | void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); | ||||
| bool IsKeepBasicNode(const AnfNodePtr &node); | |||||
| template <typename T> | template <typename T> | ||||
| ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) { | ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) { | ||||