|
|
|
@@ -27,7 +27,12 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel); |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
auto address = AnfAlgo::GetPrevNodeOutputAddr(kernel, i); |
|
|
|
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_with_index.first); |
|
|
|
if (kernel_with_index.first->isa<Parameter>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto address = AnfAlgo::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, true); |
|
|
|
MS_EXCEPTION_IF_NULL(address); |
|
|
|
if (address->ptr_ == nullptr) { |
|
|
|
total_mem_size += address->size_; |
|
|
|
@@ -73,7 +78,12 @@ void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *bas |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel); |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); |
|
|
|
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_with_index.first); |
|
|
|
if (kernel_with_index.first->isa<Parameter>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto address = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, true); |
|
|
|
MS_EXCEPTION_IF_NULL(address); |
|
|
|
if (address->ptr_ == nullptr) { |
|
|
|
address->ptr_ = mem_ptr; |
|
|
|
|