|
|
@@ -103,6 +103,7 @@ bool MemReuseUtil::InitDynamicWorkspaceKernelRef() { |
|
|
bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) { |
|
|
bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) { |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
graph_ = graph; |
|
|
graph_ = graph; |
|
|
|
|
|
is_all_nop_node_ = opt::IsAllNopNode(graph); |
|
|
if (!InitDynamicOutputKernelRef()) { |
|
|
if (!InitDynamicOutputKernelRef()) { |
|
|
MS_LOG(INFO) << "InitDynamicOutputKernelRef fail"; |
|
|
MS_LOG(INFO) << "InitDynamicOutputKernelRef fail"; |
|
|
return false; |
|
|
return false; |
|
|
@@ -223,7 +224,6 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { |
|
|
KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { |
|
|
auto is_all_nop_node = opt::IsAllNopNode(graph_); |
|
|
|
|
|
if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { |
|
|
if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { |
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " |
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " |
|
|
<< AnfAlgo::GetInputTensorNum(kernel); |
|
|
<< AnfAlgo::GetInputTensorNum(kernel); |
|
|
@@ -231,7 +231,7 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t |
|
|
auto input_node = kernel->input(input_idx + 1); |
|
|
auto input_node = kernel->input(input_idx + 1); |
|
|
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. |
|
|
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. |
|
|
session::KernelWithIndex kernel_input; |
|
|
session::KernelWithIndex kernel_input; |
|
|
if (is_all_nop_node) { |
|
|
|
|
|
|
|
|
if (is_all_nop_node_) { |
|
|
// The graph does not remove the nop node. |
|
|
// The graph does not remove the nop node. |
|
|
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); |
|
|
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); |
|
|
} else { |
|
|
} else { |
|
|
@@ -265,7 +265,6 @@ void MemReuseUtil::SetKernelDefMap() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void MemReuseUtil::SetKernelDefInputs() { |
|
|
void MemReuseUtil::SetKernelDefInputs() { |
|
|
auto is_all_nop_node = opt::IsAllNopNode(graph_); |
|
|
|
|
|
for (const auto &kernel : graph_->execution_order()) { |
|
|
for (const auto &kernel : graph_->execution_order()) { |
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
auto key = kernel.get(); |
|
|
auto key = kernel.get(); |
|
|
@@ -282,7 +281,7 @@ void MemReuseUtil::SetKernelDefInputs() { |
|
|
auto input_node = AnfAlgo::GetInputNode(kernel, i); |
|
|
auto input_node = AnfAlgo::GetInputNode(kernel, i); |
|
|
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. |
|
|
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. |
|
|
session::KernelWithIndex input; |
|
|
session::KernelWithIndex input; |
|
|
if (is_all_nop_node) { |
|
|
|
|
|
|
|
|
if (is_all_nop_node_) { |
|
|
// The graph does not remove the nop node. |
|
|
// The graph does not remove the nop node. |
|
|
input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); |
|
|
input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); |
|
|
} else { |
|
|
} else { |
|
|
@@ -349,11 +348,10 @@ void MemReuseUtil::SetSummaryNodesRefCount() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void MemReuseUtil::SetGraphOutputRefCount() { |
|
|
void MemReuseUtil::SetGraphOutputRefCount() { |
|
|
auto is_all_nop_node = opt::IsAllNopNode(graph_); |
|
|
|
|
|
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); |
|
|
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); |
|
|
for (const auto &node : nodes) { |
|
|
for (const auto &node : nodes) { |
|
|
session::KernelWithIndex kernel_input; |
|
|
session::KernelWithIndex kernel_input; |
|
|
if (is_all_nop_node) { |
|
|
|
|
|
|
|
|
if (is_all_nop_node_) { |
|
|
// The graph does not remove the nop node. |
|
|
// The graph does not remove the nop node. |
|
|
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); |
|
|
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); |
|
|
} else { |
|
|
} else { |
|
|
|