| @@ -90,6 +90,7 @@ KernelWithIndex VisitSplitKernel(const AnfNodePtr &anf_node, size_t index) { | |||||
| bool InputCheck(const AnfNodePtr &node) { | bool InputCheck(const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto in_nums = AnfAlgo::GetInputTensorNum(node); | auto in_nums = AnfAlgo::GetInputTensorNum(node); | ||||
| for (size_t i = 0; i < in_nums; i++) { | for (size_t i = 0; i < in_nums; i++) { | ||||
| auto in_node = VisitSplitKernel(AnfAlgo::GetInputNode(cnode, i), 0).first; | auto in_node = VisitSplitKernel(AnfAlgo::GetInputNode(cnode, i), 0).first; | ||||
| @@ -98,7 +99,9 @@ bool InputCheck(const AnfNodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (in_node->isa<CNode>()) { | if (in_node->isa<CNode>()) { | ||||
| auto in_node_name = AnfAlgo::GetCNodeName(in_node->cast<CNodePtr>()); | |||||
| auto in_cnode = in_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(in_cnode); | |||||
| auto in_node_name = AnfAlgo::GetCNodeName(in_cnode); | |||||
| auto trans_input = AnfAlgo::VisitKernel(in_node, 0).first; | auto trans_input = AnfAlgo::VisitKernel(in_node, 0).first; | ||||
| if (in_node_name == kTransDataOpName && (trans_input->isa<Parameter>() || trans_input->isa<ValueNode>())) { | if (in_node_name == kTransDataOpName && (trans_input->isa<Parameter>() || trans_input->isa<ValueNode>())) { | ||||
| MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; | MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; | ||||
| @@ -107,9 +110,9 @@ bool InputCheck(const AnfNodePtr &node) { | |||||
| if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) { | if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if ((AnfAlgo::HasNodeAttr("non_task", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "non_task")) || | |||||
| (AnfAlgo::HasNodeAttr("nop_node", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "nop_node"))) { | |||||
| MS_LOG(INFO) << "Input has non_task or nop_node attr, can not optimizer."; | |||||
| if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) || | |||||
| opt::IsNopNode(in_cnode)) { | |||||
| MS_LOG(INFO) << "Input is nop node or has non_task attr, can not optimizer."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -140,7 +143,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto op_name = AnfAlgo ::GetCNodeName(item); | auto op_name = AnfAlgo ::GetCNodeName(item); | ||||
| if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(node)) { | |||||
| if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(item)) { | |||||
| MS_LOG(INFO) << "Next node is " << item->fullname_with_scope() << ", not a invalid node, can not optimizer."; | MS_LOG(INFO) << "Next node is " << item->fullname_with_scope() << ", not a invalid node, can not optimizer."; | ||||
| return false; | return false; | ||||
| } | } | ||||