|
|
|
@@ -189,7 +189,8 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() { |
|
|
|
size *= input_shape[j]; |
|
|
|
} |
|
|
|
} |
|
|
|
int32_t *buffer = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t))); |
|
|
|
MS_ASSERT(context_->allocator != nullptr); |
|
|
|
int32_t *buffer = reinterpret_cast<int32_t *>(context_->allocator->Malloc(size * sizeof(int32_t))); |
|
|
|
if (buffer == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Malloc data failed."; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -199,7 +200,7 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() { |
|
|
|
} |
|
|
|
|
|
|
|
auto input = in_tensors_.at(0); |
|
|
|
begin_src_data_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * input->ElementsNum())); |
|
|
|
begin_src_data_ = reinterpret_cast<int32_t *>(context_->allocator->Malloc(sizeof(int32_t) * input->ElementsNum())); |
|
|
|
if (begin_src_data_ == nullptr) { |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
@@ -210,6 +211,32 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void ReduceInt8CPUKernel::FreeTmpBuffer() { |
|
|
|
for (auto buffer : data_buffers_) { |
|
|
|
if (buffer != nullptr) { |
|
|
|
MS_ASSERT(context_->allocator != nullptr); |
|
|
|
context_->allocator->Free(buffer); |
|
|
|
buffer = nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
data_buffers_.clear(); |
|
|
|
|
|
|
|
if (begin_src_data_ != nullptr) { |
|
|
|
MS_ASSERT(context_->allocator != nullptr); |
|
|
|
context_->allocator->Free(begin_src_data_); |
|
|
|
begin_src_data_ = nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int ReduceInt8CPUKernel::ReSize() { |
|
|
|
FreeTmpBuffer(); |
|
|
|
auto ret = MallocTmpBuffer(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
FreeTmpBuffer(); |
|
|
|
} |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
int ReduceInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { |
|
|
|
auto reduce = reinterpret_cast<ReduceInt8CPUKernel *>(cdata); |
|
|
|
auto error_code = reduce->CallReduceUnit(task_id); |
|
|
|
@@ -261,6 +288,7 @@ int ReduceInt8CPUKernel::Run() { |
|
|
|
axis_size_ = tmp_shape_[axis]; |
|
|
|
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
FreeTmpBuffer(); |
|
|
|
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
@@ -298,14 +326,11 @@ int ReduceInt8CPUKernel::Run() { |
|
|
|
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; |
|
|
|
FreeTmpBuffer(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
if (begin_src_data_ != nullptr) { |
|
|
|
free(begin_src_data_); |
|
|
|
begin_src_data_ = nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
FreeTmpBuffer(); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
|