diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc index 39748a8993..4aeb234432 100644 --- a/mindspore/lite/src/executor.cc +++ b/mindspore/lite/src/executor.cc @@ -49,11 +49,12 @@ int Executor::Run(std::vector &in_tensors, std::vector &out_ MS_LOG(ERROR) << "CheckInputs failed"; return ret; } - MS_ASSERT(std::all_of(kernels.begin(), kernels.end(), [](kernel::LiteKernel *kernel) { - return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().end(), [](Tensor *in_tensor) { - return in_tensor->IsConst() || in_tensor->IsGraphInput() || in_tensor->ref_count() == 0; - }); - })); + // clear ref_count + for (auto *kernel : kernels) { + for (auto *tensor : kernel->in_tensors()) { + tensor->set_ref_count(0); + } + } std::queue kernel_queue; for (auto kernel : kernels) { if (kernel->IsReady(kernel->in_tensors())) { diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index 08e756c8f7..98e0a4ce61 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -115,7 +115,7 @@ int LiteKernel::PreProcess() { for (auto *output : this->out_tensors()) { MS_ASSERT(output != nullptr); - if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast(sizeof(int64_t))) { + if (output->ElementsNum() >= lite::MAX_MALLOC_SIZE / static_cast(sizeof(int64_t))) { MS_LOG(ERROR) << "The size of output tensor is too big"; return RET_ERROR; } diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index ff7558a7bf..c589c2dccc 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -251,6 +251,9 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) { } auto tensor_name = in_node->name_ + std::to_string(i); this->input_map_[tensor_name] = in_tensor; + if (!in_tensor->tensor_name().empty()) { + this->input_map_[in_tensor->tensor_name()] = in_tensor; + } } } } diff --git a/mindspore/lite/src/runtime/allocator.cc b/mindspore/lite/src/runtime/allocator.cc index 35ac8a7854..75a033de9e 100644 --- a/mindspore/lite/src/runtime/allocator.cc +++ b/mindspore/lite/src/runtime/allocator.cc @@ -49,6 +49,10 @@ void *DefaultAllocator::Malloc(size_t size) { MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; return nullptr; } + if (this->GetTotalSize() >= MAX_THREAD_POOL_SIZE) { + MS_LOG(ERROR) << "Memory pool is exhausted"; + return nullptr; + } Lock(); auto iter = freeList_.lower_bound(size); if (iter != freeList_.end() && (iter->second->size >= size) && (iter->second->size < (size << shiftFactor_))) { diff --git a/mindspore/lite/src/runtime/allocator.h b/mindspore/lite/src/runtime/allocator.h index da52943ad6..9d324747bd 100644 --- a/mindspore/lite/src/runtime/allocator.h +++ b/mindspore/lite/src/runtime/allocator.h @@ -39,7 +39,6 @@ class Allocator { virtual void Free(void *ptr) = 0; virtual void SetContext(const AllocatorContext &ctx) {} virtual size_t GetTotalSize() { return 0; } - virtual void Clear() {} static std::shared_ptr Create(); virtual void *Prepare(void *ptr) { return ptr; } std::string name; @@ -53,7 +52,7 @@ class DefaultAllocator : public Allocator { void *Malloc(size_t size) override; void Free(void *ptr) override; size_t GetTotalSize() override; - void Clear() override; + void Clear(); private: void Lock(); @@ -72,7 +71,8 @@ class DefaultAllocator : public Allocator { bool lockFlag_ = false; }; -#define MAX_MALLOC_SIZE (2000 * 1024 * 1024) +constexpr int64_t MAX_MALLOC_SIZE = static_cast(2000) * 1024 * 1024; +constexpr int64_t MAX_THREAD_POOL_SIZE = static_cast(3000) * 1024 * 1024; } // namespace mindspore::lite diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.h b/mindspore/lite/src/runtime/opencl/opencl_allocator.h index ab73a31073..f339f15b35 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_allocator.h +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.h @@ -42,7 +42,7 @@ class OpenCLAllocator : public Allocator { void Free(void *ptr) override; size_t GetTotalSize() override; - void Clear() override; + void Clear(); void *GetImage(void *host_ptr); void *GetBuffer(void *host_ptr); void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true);