Browse Source

!4738 reduce int8 add resize

Merge pull request !4738 from zhaozhenlong/lite/issue/reduce_int8_resize
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
459b6c58d7
2 changed files with 34 additions and 15 deletions
  1. +32
    -7
      mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc
  2. +2
    -8
      mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h

+ 32
- 7
mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc View File

@@ -189,7 +189,8 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
size *= input_shape[j];
}
}
int32_t *buffer = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
MS_ASSERT(context_->allocator != nullptr);
int32_t *buffer = reinterpret_cast<int32_t *>(context_->allocator->Malloc(size * sizeof(int32_t)));
if (buffer == nullptr) {
MS_LOG(ERROR) << "Malloc data failed.";
return RET_ERROR;
@@ -199,7 +200,7 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
}

auto input = in_tensors_.at(0);
begin_src_data_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * input->ElementsNum()));
begin_src_data_ = reinterpret_cast<int32_t *>(context_->allocator->Malloc(sizeof(int32_t) * input->ElementsNum()));
if (begin_src_data_ == nullptr) {
return RET_NULL_PTR;
}
@@ -210,6 +211,32 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
return RET_OK;
}

void ReduceInt8CPUKernel::FreeTmpBuffer() {
for (auto buffer : data_buffers_) {
if (buffer != nullptr) {
MS_ASSERT(context_->allocator != nullptr);
context_->allocator->Free(buffer);
buffer = nullptr;
}
}
data_buffers_.clear();

if (begin_src_data_ != nullptr) {
MS_ASSERT(context_->allocator != nullptr);
context_->allocator->Free(begin_src_data_);
begin_src_data_ = nullptr;
}
}

int ReduceInt8CPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
}
return ret;
}

int ReduceInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto reduce = reinterpret_cast<ReduceInt8CPUKernel *>(cdata);
auto error_code = reduce->CallReduceUnit(task_id);
@@ -261,6 +288,7 @@ int ReduceInt8CPUKernel::Run() {
axis_size_ = tmp_shape_[axis];
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
if (error_code != RET_OK) {
FreeTmpBuffer();
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
return RET_ERROR;
}
@@ -298,14 +326,11 @@ int ReduceInt8CPUKernel::Run() {
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}

if (begin_src_data_ != nullptr) {
free(begin_src_data_);
begin_src_data_ = nullptr;
}

FreeTmpBuffer();
return RET_OK;
}



+ 2
- 8
mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h View File

@@ -40,13 +40,6 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
const mindspore::lite::PrimitiveC *primitive)
: ReduceBaseCPUKernel(param, inputs, outputs, ctx, primitive) {}
~ReduceInt8CPUKernel() {
for (auto i = 0; i < data_buffers_.size(); i++) {
int32_t *buffer = data_buffers_[i];
if (buffer != nullptr) {
free(buffer);
buffer = nullptr;
}
}
for (auto qm : mean_multipliers_) {
delete qm;
qm = nullptr;
@@ -64,7 +57,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
}

int Init() override;
int ReSize() override { return 0; };
int ReSize() override;
int Run() override;
int CallReduceUnit(int task_id);
int ReduceLastAxis(int task_id);
@@ -74,6 +67,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {

private:
int MallocTmpBuffer();
void FreeTmpBuffer();
int CalculateQuantArgs();

private:


Loading…
Cancel
Save