Browse Source

!5836 reduce int8 move assign data from malloc to run

Merge pull request !5836 from zhaozhenlong/lite/issue/reduce_int8_malloc
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
96eb284f40
13 changed files with 193 additions and 228 deletions
  1. +2
    -2
      mindspore/lite/nnacl/fp16/reduce_fp16.c
  2. +1
    -1
      mindspore/lite/nnacl/fp16/reduce_fp16.h
  3. +17
    -17
      mindspore/lite/nnacl/fp32/reduce.c
  4. +11
    -11
      mindspore/lite/nnacl/fp32/reduce.h
  5. +48
    -1
      mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc
  6. +6
    -1
      mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h
  7. +7
    -24
      mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc
  8. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.h
  9. +31
    -58
      mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
  10. +3
    -8
      mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h
  11. +46
    -86
      mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc
  12. +2
    -0
      mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h
  13. +18
    -18
      mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc

+ 2
- 2
mindspore/lite/nnacl/fp16/reduce_fp16.c View File

@@ -19,8 +19,8 @@
#include "nnacl/errorcode.h"

int ReduceMeanFp16(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data,
const int *src_shape, float16_t *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
float16_t *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;


+ 1
- 1
mindspore/lite/nnacl/fp16/reduce_fp16.h View File

@@ -26,7 +26,7 @@
extern "C" {
#endif
int ReduceMeanFp16(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data,
const int *src_shape, float16_t *dst_data, const int tid, const int thread_num);
float16_t *dst_data, const int tid, const int thread_num);

#ifdef __cplusplus
}


+ 17
- 17
mindspore/lite/nnacl/fp32/reduce.c View File

@@ -18,9 +18,9 @@
#include "nnacl/fp32/reduce.h"
#include "nnacl/errorcode.h"

int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;
@@ -39,9 +39,9 @@ int ReduceMean(const int outer_size, const int inner_size, const int axis_size,
}
return NNACL_OK;
}
int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;
@@ -60,9 +60,9 @@ int ReduceSum(const int outer_size, const int inner_size, const int axis_size, c
}
return NNACL_OK;
}
int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;
@@ -81,9 +81,9 @@ int ReduceMax(const int outer_size, const int inner_size, const int axis_size, c
}
return NNACL_OK;
}
int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;
@@ -102,9 +102,9 @@ int ReduceMin(const int outer_size, const int inner_size, const int axis_size, c
}
return NNACL_OK;
}
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;
@@ -124,8 +124,8 @@ int ReduceProd(const int outer_size, const int inner_size, const int axis_size,
return NNACL_OK;
}
int ReduceSumSquare(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
float *dst_data, const int tid, const int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;


+ 11
- 11
mindspore/lite/nnacl/fp32/reduce.h View File

@@ -22,18 +22,18 @@
#ifdef __cplusplus
extern "C" {
#endif
int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num);
int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num);
int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num);
int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num);
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num);
int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num);
int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num);
int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num);
int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num);
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
const int tid, const int thread_num);
int ReduceSumSquare(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num);
float *dst_data, const int tid, const int thread_num);
#ifdef __cplusplus
}
#endif


+ 48
- 1
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc View File

@@ -120,7 +120,54 @@ int ReduceBaseCPUKernel::Init() {
return RET_OK;
}

int ReduceBaseCPUKernel::ReSize() { return CheckParameters(); }
void ReduceBaseCPUKernel::CalculateInnerOuterSize() {
outer_sizes_.clear();
inner_sizes_.clear();
axis_sizes_.clear();
auto tmp_shape = in_tensors_.at(0)->shape();
for (auto i = 0; i < num_axes_; ++i) {
int axis = axes_[i];
auto outer_size = 1;
for (int j = 0; j < axis; j++) {
outer_size *= tmp_shape[j];
}
outer_sizes_.emplace_back(outer_size);
auto inner_size = 1;
for (int k = axis + 1; k < static_cast<int>(tmp_shape.size()); k++) {
inner_size *= tmp_shape[k];
}
inner_sizes_.emplace_back(inner_size);
axis_sizes_.emplace_back(tmp_shape[axis]);
tmp_shape[axis] = 1;
}
}

void ReduceBaseCPUKernel::CalculateTmpBufferSize() {
buffer_sizes_.clear();
auto input_shape = in_tensors_.at(0)->shape();
for (auto i = 0; i < num_axes_; i++) {
int axis = axes_[i];
size_t size = 1;
for (size_t j = 0; j < input_shape.size(); j++) {
if (axis != static_cast<int>(j)) {
size *= input_shape[j];
}
}
MS_ASSERT(context_->allocator != nullptr);
buffer_sizes_.emplace_back(size);
input_shape[axis] = 1;
}
}

int ReduceBaseCPUKernel::ReSize() {
auto ret = CheckParameters();
if (ret != RET_OK) {
return ret;
}
CalculateTmpBufferSize();
CalculateInnerOuterSize();
return RET_OK;
}

kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,


+ 6
- 1
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h View File

@@ -45,10 +45,15 @@ class ReduceBaseCPUKernel : public LiteKernel {
bool reduce_to_end_;

protected:
void CalculateTmpBufferSize();
void CalculateInnerOuterSize();
std::vector<size_t> buffer_sizes_;
std::vector<int> outer_sizes_;
std::vector<int> inner_sizes_;
std::vector<int> axis_sizes_;
int outer_size_;
int inner_size_;
int axis_size_;
std::vector<int> tmp_shape_;
};
} // namespace mindspore::kernel



+ 7
- 24
mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc View File

@@ -60,8 +60,8 @@ int ReduceFp16CPUKernel::Init() {
int ReduceFp16CPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }

int ReduceFp16CPUKernel::CallReduceUnit(int task_id) {
auto ret = reducer_(outer_size_, inner_size_, axis_size_, fp16_src_data_, tmp_shape_.data(), fp16_dst_data_, task_id,
context_->thread_num_);
auto ret =
reducer_(outer_size_, inner_size_, axis_size_, fp16_src_data_, fp16_dst_data_, task_id, context_->thread_num_);
return ret;
}

@@ -88,7 +88,6 @@ int ReduceFp16CPUKernel::Run() {
return ret;
}

tmp_shape_ = in_tensors_.at(0)->shape();
auto in_tensor = in_tensors_.at(0);
if (in_tensor->data_type() == kNumberTypeFloat32 || in_tensor->data_type() == kNumberTypeFloat) {
auto input_data = reinterpret_cast<float *>(in_tensor->MutableData());
@@ -100,23 +99,15 @@ int ReduceFp16CPUKernel::Run() {
fp16_src_data_ = fp16_input_;
for (int i = 0; i < data_buffers_.size(); ++i) {
fp16_dst_data_ = data_buffers_[i];
int axis = axes_[i];
outer_size_ = 1;
for (int j = 0; j < axis; j++) {
outer_size_ *= tmp_shape_[j];
}
inner_size_ = 1;
for (int k = axis + 1; k < static_cast<int>(tmp_shape_.size()); k++) {
inner_size_ *= tmp_shape_[k];
}
axis_size_ = tmp_shape_[axis];
outer_size_ = outer_sizes_[i];
inner_size_ = inner_sizes_[i];
axis_size_ = axis_sizes_[i];
auto error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ReduceImpl, this, context_->thread_num_);
if (error_code != RET_OK) {
FreeTmpBuffer();
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
return RET_ERROR;
}
tmp_shape_[axis] = 1;
fp16_src_data_ = fp16_dst_data_;
}

@@ -151,22 +142,14 @@ void ReduceFp16CPUKernel::FreeTmpBuffer() {
}

int ReduceFp16CPUKernel::MallocTmpBuffer() {
auto input_shape = in_tensors_.at(0)->shape();
for (auto i = 0; i < num_axes_; i++) {
int axis = axes_[i];
size_t size = 1;
for (auto j = 0; j < input_shape.size(); j++) {
if (static_cast<size_t>(axis) != j) {
size *= input_shape[j];
}
}
data_buffers_.clear();
for (auto size : buffer_sizes_) {
float16_t *buffer = reinterpret_cast<float16_t *>(context_->allocator->Malloc(size * sizeof(float16_t)));
if (buffer == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
return RET_ERROR;
}
data_buffers_.emplace_back(buffer);
input_shape[axis] = 1;
}

auto in_tensor = in_tensors_.front();


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.h View File

@@ -27,7 +27,7 @@ using mindspore::schema::ReduceMode;
namespace mindspore::kernel {
class ReduceFp16CPUKernel : public ReduceBaseCPUKernel {
typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data,
const int *src_shape, float16_t *dst_data, const int tid, const int thread_num);
float16_t *dst_data, const int tid, const int thread_num);

public:
ReduceFp16CPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,


+ 31
- 58
mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc View File

@@ -81,17 +81,10 @@ int ReduceCPUKernel::Init() {
return ReSize();
}

int ReduceCPUKernel::ReSize() {
auto ret = ReduceBaseCPUKernel::ReSize();
if (ret != RET_OK) {
return ret;
}
return MallocTmpBuffer();
}
int ReduceCPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }

int ReduceCPUKernel::CallReduceUnit(int task_id) {
auto ret = reducer_(outer_size_, inner_size_, axis_size_, src_data_, tmp_shape_.data(), dst_data_, task_id,
context_->thread_num_);
auto ret = reducer_(outer_size_, inner_size_, axis_size_, src_data_, dst_data_, task_id, context_->thread_num_);
return ret;
}

@@ -111,75 +104,55 @@ int ReduceCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
tmp_shape_ = in_tensors_.at(0)->shape();
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}

src_data_ = static_cast<float *>(in_tensors_.at(0)->MutableData());
for (size_t i = 0; i < data_buffers_.size(); ++i) {
dst_data_ = data_buffers_[i];
int axis = axes_[i];
outer_size_ = 1;
for (int j = 0; j < axis; j++) {
outer_size_ *= tmp_shape_[j];
for (size_t i = 0; i < static_cast<size_t>(num_axes_); ++i) {
if (i != static_cast<size_t>(num_axes_ - 1)) {
dst_data_ = data_buffers_[i];
} else {
dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
}
inner_size_ = 1;
for (int k = axis + 1; k < static_cast<int>(tmp_shape_.size()); k++) {
inner_size_ *= tmp_shape_[k];
}
axis_size_ = tmp_shape_[axis];
outer_size_ = outer_sizes_[i];
inner_size_ = inner_sizes_[i];
axis_size_ = axis_sizes_[i];
auto error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ReduceImpl, this, context_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}
tmp_shape_[axis] = 1;
src_data_ = dst_data_;
}

int last_reduce_axis = axes_[num_axes_ - 1];
outer_size_ = 1;
for (int i = 0; i < last_reduce_axis; i++) {
outer_size_ *= tmp_shape_[i];
}
inner_size_ = 1;
for (int i = last_reduce_axis + 1; i < static_cast<int>(tmp_shape_.size()); i++) {
inner_size_ *= tmp_shape_[i];
}
axis_size_ = tmp_shape_[last_reduce_axis];
dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
auto error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ReduceImpl, this, context_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
return RET_ERROR;
}

FreeTmpBuffer();
return RET_OK;
}

int ReduceCPUKernel::MallocTmpBuffer() {
for (auto buffer : data_buffers_) {
if (buffer != nullptr) {
free(buffer);
buffer = nullptr;
}
}
data_buffers_.clear();

auto input_shape = in_tensors_.at(0)->shape();
for (auto i = 0; i < num_axes_ - 1; i++) {
int axis = axes_[i];
size_t size = 1;
for (size_t j = 0; j < input_shape.size(); j++) {
if (axis != static_cast<int>(j)) {
size *= input_shape[j];
}
}
float *buffer = reinterpret_cast<float *>(malloc(size * sizeof(float)));
for (auto size : buffer_sizes_) {
float *buffer = reinterpret_cast<float *>(context_->allocator->Malloc(size * sizeof(float)));
if (buffer == nullptr) {
MS_LOG(ERROR) << "Malloc data failed.";
return RET_ERROR;
}
data_buffers_.emplace_back(buffer);
input_shape[axis] = 1;
}
return RET_OK;
}

void ReduceCPUKernel::FreeTmpBuffer() {
for (size_t i = 0; i < data_buffers_.size(); i++) {
float *buffer = data_buffers_[i];
if (buffer != nullptr) {
context_->allocator->Free(buffer);
buffer = nullptr;
}
}
data_buffers_.clear();
}
} // namespace mindspore::kernel

+ 3
- 8
mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h View File

@@ -28,7 +28,7 @@ using mindspore::schema::ReduceMode;
namespace mindspore::kernel {
class ReduceCPUKernel : public ReduceBaseCPUKernel {
typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
const int *src_shape, float *dst_data, const int tid, const int thread_num);
float *dst_data, const int tid, const int thread_num);

public:
ReduceCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
@@ -36,13 +36,7 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
const mindspore::lite::PrimitiveC *primitive)
: ReduceBaseCPUKernel(param, inputs, outputs, ctx, primitive) {}
~ReduceCPUKernel() {
for (size_t i = 0; i < data_buffers_.size(); i++) {
float *buffer = data_buffers_[i];
if (buffer != nullptr) {
free(buffer);
buffer = nullptr;
}
}
FreeTmpBuffer();
src_data_ = nullptr;
dst_data_ = nullptr;
}
@@ -60,6 +54,7 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {

private:
int MallocTmpBuffer();
void FreeTmpBuffer();
};
} // namespace mindspore::kernel



+ 46
- 86
mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc View File

@@ -39,10 +39,6 @@ int ReduceInt8CPUKernel::Init() {
if (ret != RET_OK) {
return ret;
}
ret = MallocTmpBuffer();
if (ret != RET_OK) {
return ret;
}
ret = CalculateQuantArgs();
if (ret != RET_OK) {
return ret;
@@ -179,23 +175,15 @@ int ReduceInt8CPUKernel::CalculateQuantArgs() {
}

int ReduceInt8CPUKernel::MallocTmpBuffer() {
auto input_shape = in_tensors_.at(0)->shape();
for (auto i = 0; i < num_axes_ - 1; i++) {
int axis = axes_[i];
size_t size = 1;
for (size_t j = 0; j < input_shape.size(); j++) {
if (axis != static_cast<int>(j)) {
size *= input_shape[j];
}
}
MS_ASSERT(context_->allocator != nullptr);
int32_t *buffer = reinterpret_cast<int32_t *>(context_->allocator->Malloc(size * sizeof(int32_t)));
data_buffers_.clear();
MS_ASSERT(static_cast<int>(buffer_sizes_.size()) == num_axes_ - 1);
for (auto buffer_size : buffer_sizes_) {
int32_t *buffer = reinterpret_cast<int32_t *>(context_->allocator->Malloc(buffer_size * sizeof(int32_t)));
if (buffer == nullptr) {
MS_LOG(ERROR) << "Malloc data failed.";
return RET_ERROR;
}
data_buffers_.emplace_back(buffer);
input_shape[axis] = 1;
}

auto input = in_tensors_.at(0);
@@ -203,17 +191,13 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
if (begin_src_data_ == nullptr) {
return RET_NULL_PTR;
}
auto input_data = reinterpret_cast<int8_t *>(input->MutableData());
for (auto i = 0; i < input->ElementsNum(); i++) {
begin_src_data_[i] = static_cast<int32_t>(input_data[i]);
}

return RET_OK;
}

void ReduceInt8CPUKernel::FreeTmpBuffer() {
for (auto buffer : data_buffers_) {
if (buffer != nullptr) {
MS_ASSERT(context_->allocator != nullptr);
context_->allocator->Free(buffer);
buffer = nullptr;
}
@@ -221,20 +205,12 @@ void ReduceInt8CPUKernel::FreeTmpBuffer() {
data_buffers_.clear();

if (begin_src_data_ != nullptr) {
MS_ASSERT(context_->allocator != nullptr);
context_->allocator->Free(begin_src_data_);
begin_src_data_ = nullptr;
}
}

int ReduceInt8CPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
}
return ret;
}
int ReduceInt8CPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }

int ReduceInt8Impl(void *cdata, int task_id) {
auto reduce = reinterpret_cast<ReduceInt8CPUKernel *>(cdata);
@@ -246,80 +222,65 @@ int ReduceInt8Impl(void *cdata, int task_id) {
return RET_OK;
}

void ReduceInt8CPUKernel::GetQuantArgs(size_t i) {
MS_ASSERT(i < static_cast<size_t>(num_axis_));
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceMean)) {
quant_arg_.mean_multiplier_ = mean_multipliers_[i]->multiplier_;
quant_arg_.mean_left_shift_ = mean_multipliers_[i]->left_shift_;
quant_arg_.mean_right_shift_ = mean_multipliers_[i]->right_shift_;
}

if (mode_ == static_cast<int>(schema::ReduceMode_ReduceProd)) {
quant_arg_.prod_multiplier_ = prod_multipliers_[i]->multiplier_;
quant_arg_.prod_left_shift_ = prod_multipliers_[i]->left_shift_;
quant_arg_.prod_right_shift_ = prod_multipliers_[i]->right_shift_;
}
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceSumSquare)) {
quant_arg_.sum_square_multiplier_ = sum_square_multipliers_[i]->multiplier_;
quant_arg_.sum_square_left_shift_ = sum_square_multipliers_[i]->left_shift_;
quant_arg_.sum_square_right_shift_ = sum_square_multipliers_[i]->right_shift_;
}
}

int ReduceInt8CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}

is_last_axis_ = false;
tmp_shape_ = in_tensors_.at(0)->shape();
src_data_ = begin_src_data_;

for (size_t i = 0; i < data_buffers_.size(); ++i) {
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceMean)) {
quant_arg_.mean_multiplier_ = mean_multipliers_[i]->multiplier_;
quant_arg_.mean_left_shift_ = mean_multipliers_[i]->left_shift_;
quant_arg_.mean_right_shift_ = mean_multipliers_[i]->right_shift_;
}

if (mode_ == static_cast<int>(schema::ReduceMode_ReduceProd)) {
quant_arg_.prod_multiplier_ = prod_multipliers_[i]->multiplier_;
quant_arg_.prod_left_shift_ = prod_multipliers_[i]->left_shift_;
quant_arg_.prod_right_shift_ = prod_multipliers_[i]->right_shift_;
}
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceSumSquare)) {
quant_arg_.sum_square_multiplier_ = sum_square_multipliers_[i]->multiplier_;
quant_arg_.sum_square_left_shift_ = sum_square_multipliers_[i]->left_shift_;
quant_arg_.sum_square_right_shift_ = sum_square_multipliers_[i]->right_shift_;
}
auto input = in_tensors().at(0);
auto input_data = reinterpret_cast<int8_t *>(input->MutableData());
for (auto i = 0; i < input->ElementsNum(); i++) {
begin_src_data_[i] = static_cast<int32_t>(input_data[i]);
}
src_data_ = begin_src_data_;
for (size_t i = 0; i < data_buffers_.size() - 1; ++i) {
GetQuantArgs(i);
dst_data_ = data_buffers_[i];
int axis = axes_[i];
outer_size_ = 1;
for (int j = 0; j < axis; j++) {
outer_size_ *= tmp_shape_[j];
}
inner_size_ = 1;
for (int k = axis + 1; k < static_cast<int>(tmp_shape_.size()); k++) {
inner_size_ *= tmp_shape_[k];
}
axis_size_ = tmp_shape_[axis];
outer_size_ = outer_sizes_[i];
inner_size_ = inner_sizes_[i];
axis_size_ = axis_sizes_[i];
auto error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ReduceInt8Impl, this, context_->thread_num_);
if (error_code != RET_OK) {
FreeTmpBuffer();
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
return RET_ERROR;
}
tmp_shape_[axis] = 1;
src_data_ = dst_data_;
}

if (mode_ == static_cast<int>(schema::ReduceMode_ReduceMean)) {
quant_arg_.mean_multiplier_ = mean_multipliers_.back()->multiplier_;
quant_arg_.mean_left_shift_ = mean_multipliers_.back()->left_shift_;
quant_arg_.mean_right_shift_ = mean_multipliers_.back()->right_shift_;
}
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceProd)) {
quant_arg_.prod_multiplier_ = prod_multipliers_.back()->multiplier_;
quant_arg_.prod_left_shift_ = prod_multipliers_.back()->left_shift_;
quant_arg_.prod_right_shift_ = prod_multipliers_.back()->right_shift_;
}
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceSumSquare)) {
quant_arg_.sum_square_multiplier_ = sum_square_multipliers_.back()->multiplier_;
quant_arg_.sum_square_left_shift_ = sum_square_multipliers_.back()->left_shift_;
quant_arg_.sum_square_right_shift_ = sum_square_multipliers_.back()->right_shift_;
}
int last_reduce_axis = axes_[num_axes_ - 1];
outer_size_ = 1;
for (int i = 0; i < last_reduce_axis; i++) {
outer_size_ *= tmp_shape_[i];
}
inner_size_ = 1;
for (int i = last_reduce_axis + 1; i < static_cast<int>(tmp_shape_.size()); i++) {
inner_size_ *= tmp_shape_[i];
}
axis_size_ = tmp_shape_[last_reduce_axis];
GetQuantArgs(static_cast<size_t>(num_axes_ - 1));
outer_size_ = outer_sizes_.back();
inner_size_ = inner_sizes_.back();
axis_size_ = axis_sizes_.back();
last_dst_data_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
is_last_axis_ = true;
auto error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ReduceInt8Impl, this, context_->thread_num_);
@@ -328,7 +289,6 @@ int ReduceInt8CPUKernel::Run() {
FreeTmpBuffer();
return RET_ERROR;
}

FreeTmpBuffer();
return RET_OK;
}


+ 2
- 0
mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h View File

@@ -68,7 +68,9 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
private:
int MallocTmpBuffer();
void FreeTmpBuffer();

int CalculateQuantArgs();
void GetQuantArgs(size_t i);

private:
ReduceParameter *param_ = nullptr;


+ 18
- 18
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc View File

@@ -46,7 +46,7 @@ TEST_F(TestReduceFp32, Mean) {
int outer_size = 2;
int inner_size = 12;
int axis_size = 4;
(void)ReduceMean(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceMean(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);
@@ -72,9 +72,9 @@ TEST_F(TestReduceFp32, Mean2Thread) {
int axis_size = 4;
thread_num = 2;
tid = 0;
(void)ReduceMean(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceMean(outer_size, inner_size, axis_size, in, out, tid, thread_num);
tid = 1;
(void)ReduceMean(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceMean(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);
@@ -98,7 +98,7 @@ TEST_F(TestReduceFp32, MeanAllAxis) {
float *src = in;
float dst1[48] = {0};
MS_ASSERT(dst != nullptr);
(void)ReduceMean(outer_size, inner_size, axis_size, src, input_shape, dst1, tid, thread_num);
(void)ReduceMean(outer_size, inner_size, axis_size, src, dst1, tid, thread_num);

input_shape[0] = 1; // 1 4 4 3
outer_size = 1;
@@ -106,7 +106,7 @@ TEST_F(TestReduceFp32, MeanAllAxis) {
axis_size = 4;
src = dst1;
float dst2[12] = {0};
(void)ReduceMean(outer_size, inner_size, axis_size, src, input_shape, dst2, tid, thread_num);
(void)ReduceMean(outer_size, inner_size, axis_size, src, dst2, tid, thread_num);

input_shape[1] = 1; // 1 1 4 3
outer_size = 1;
@@ -114,14 +114,14 @@ TEST_F(TestReduceFp32, MeanAllAxis) {
axis_size = 4;
src = dst2;
float dst3[3] = {0};
(void)ReduceMean(outer_size, inner_size, axis_size, src, input_shape, dst3, tid, thread_num);
(void)ReduceMean(outer_size, inner_size, axis_size, src, dst3, tid, thread_num);

input_shape[2] = 1; // 1 1 1 3
outer_size = 1;
inner_size = 1;
axis_size = 3;
src = dst3;
(void)ReduceMean(outer_size, inner_size, axis_size, src, input_shape, out, tid, thread_num);
(void)ReduceMean(outer_size, inner_size, axis_size, src, out, tid, thread_num);

int output_size = 1;
CompareOutputData(out, correct, output_size, err_tol);
@@ -145,7 +145,7 @@ TEST_F(TestReduceFp32, Sum) {
int outer_size = 2;
int inner_size = 12;
int axis_size = 4;
(void)ReduceSum(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceSum(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);
@@ -171,9 +171,9 @@ TEST_F(TestReduceFp32, Sum2Thread) {
int axis_size = 4;
thread_num = 2;
tid = 0;
(void)ReduceSum(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceSum(outer_size, inner_size, axis_size, in, out, tid, thread_num);
tid = 1;
(void)ReduceSum(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceSum(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);
@@ -197,7 +197,7 @@ TEST_F(TestReduceFp32, SumAllAxis) {
float *src = in;
float dst1[48] = {0};
MS_ASSERT(dst != nullptr);
(void)ReduceSum(outer_size, inner_size, axis_size, src, input_shape, dst1, tid, thread_num);
(void)ReduceSum(outer_size, inner_size, axis_size, src, dst1, tid, thread_num);

input_shape[0] = 1; // 1 4 4 3
outer_size = 1;
@@ -205,7 +205,7 @@ TEST_F(TestReduceFp32, SumAllAxis) {
axis_size = 4;
src = dst1;
float dst2[12] = {0};
(void)ReduceSum(outer_size, inner_size, axis_size, src, input_shape, dst2, tid, thread_num);
(void)ReduceSum(outer_size, inner_size, axis_size, src, dst2, tid, thread_num);

input_shape[1] = 1; // 1 1 4 3
outer_size = 1;
@@ -213,14 +213,14 @@ TEST_F(TestReduceFp32, SumAllAxis) {
axis_size = 4;
src = dst2;
float dst3[3] = {0};
(void)ReduceSum(outer_size, inner_size, axis_size, src, input_shape, dst3, tid, thread_num);
(void)ReduceSum(outer_size, inner_size, axis_size, src, dst3, tid, thread_num);

input_shape[2] = 1; // 1 1 1 3
outer_size = 1;
inner_size = 1;
axis_size = 3;
src = dst3;
(void)ReduceSum(outer_size, inner_size, axis_size, src, input_shape, out, tid, thread_num);
(void)ReduceSum(outer_size, inner_size, axis_size, src, out, tid, thread_num);

int output_size = 1;
CompareOutputData(out, correct, output_size, err_tol);
@@ -244,7 +244,7 @@ TEST_F(TestReduceFp32, Max) {
int outer_size = 2;
int inner_size = 12;
int axis_size = 4;
(void)ReduceMax(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceMax(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);
@@ -268,7 +268,7 @@ TEST_F(TestReduceFp32, Min) {
int outer_size = 2;
int inner_size = 12;
int axis_size = 4;
(void)ReduceMin(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceMin(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);
@@ -293,7 +293,7 @@ TEST_F(TestReduceFp32, Prod) {
int outer_size = 2;
int inner_size = 12;
int axis_size = 4;
(void)ReduceProd(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceProd(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);
@@ -318,7 +318,7 @@ TEST_F(TestReduceFp32, SumSquare) {
int outer_size = 2;
int inner_size = 12;
int axis_size = 4;
(void)ReduceSumSquare(outer_size, inner_size, axis_size, in, input_shape, out, tid, thread_num);
(void)ReduceSumSquare(outer_size, inner_size, axis_size, in, out, tid, thread_num);

int output_size = 24;
CompareOutputData(out, correct, output_size, err_tol);


Loading…
Cancel
Save