| @@ -223,13 +223,21 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { | |||||
| } | } | ||||
| KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { | KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { | ||||
| auto is_all_nop_node = opt::IsAllNopNode(graph_); | |||||
| if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { | if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { | ||||
| MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " | MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " | ||||
| << AnfAlgo::GetInputTensorNum(kernel); | << AnfAlgo::GetInputTensorNum(kernel); | ||||
| } | } | ||||
| auto input_node = kernel->input(input_idx + 1); | auto input_node = kernel->input(input_idx + 1); | ||||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | ||||
| auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||||
| session::KernelWithIndex kernel_input; | |||||
| if (is_all_nop_node) { | |||||
| // The graph does not remove the nop node. | |||||
| kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||||
| } else { | |||||
| // The graph removes the nop node. | |||||
| kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); | |||||
| } | |||||
| if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { | if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { | ||||
| MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; | MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; | ||||
| } | } | ||||
| @@ -257,6 +265,7 @@ void MemReuseUtil::SetKernelDefMap() { | |||||
| } | } | ||||
| void MemReuseUtil::SetKernelDefInputs() { | void MemReuseUtil::SetKernelDefInputs() { | ||||
| auto is_all_nop_node = opt::IsAllNopNode(graph_); | |||||
| for (const auto &kernel : graph_->execution_order()) { | for (const auto &kernel : graph_->execution_order()) { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| auto key = kernel.get(); | auto key = kernel.get(); | ||||
| @@ -272,7 +281,14 @@ void MemReuseUtil::SetKernelDefInputs() { | |||||
| // set the inputs of this kernel_def | // set the inputs of this kernel_def | ||||
| auto input_node = AnfAlgo::GetInputNode(kernel, i); | auto input_node = AnfAlgo::GetInputNode(kernel, i); | ||||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | ||||
| auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||||
| session::KernelWithIndex input; | |||||
| if (is_all_nop_node) { | |||||
| // The graph does not remove the nop node. | |||||
| input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||||
| } else { | |||||
| // The graph removes the nop node. | |||||
| input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); | |||||
| } | |||||
| if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { | if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { | ||||
| MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; | MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; | ||||
| } | } | ||||