|
|
|
@@ -401,6 +401,15 @@ bool BestFitMemReuse::IsReusableStream(uint32_t curr_stream_id, uint32_t target_ |
|
|
|
return curr_parallel_set.find(target_stream_id) == curr_parallel_set.end(); |
|
|
|
} |
|
|
|
|
|
|
|
bool BestFitMemReuse::IsRelease(const std::string &kernel_name) { |
|
|
|
// unable_used_node include the node type that output tensor cannot be released, |
|
|
|
// even if its refcount is equal to zero. |
|
|
|
std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(), |
|
|
|
prim::kPrimFusedBatchNorm->name(), |
|
|
|
prim::kPrimFusedBatchNormGrad->name()}; |
|
|
|
return unable_used_node.find(kernel_name) == unable_used_node.end(); |
|
|
|
} |
|
|
|
|
|
|
|
void BestFitMemReuse::CheckTensorIndex(int tensor_index) const { |
|
|
|
if (tensor_index < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "warning, please check tensor info."; |
|
|
|
@@ -437,6 +446,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { |
|
|
|
// update node input tensor refcount, and membuf list status |
|
|
|
UpdateNodeInputAndMembuf(op_def_ptr.get()); |
|
|
|
// check node output tensor which refcount is equal to zero |
|
|
|
if (IsRelease(op_def_ptr->kernel_name())) { |
|
|
|
ReleaseNodeUnusedOutput(op_def_ptr.get()); |
|
|
|
} |
|
|
|
#ifdef MEM_REUSE_DEBUG |
|
|
|
MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_); |
|
|
|
++op_num; |
|
|
|
|