From 58f02e611cf17c5dbf6b4af480e93ca490d70c0c Mon Sep 17 00:00:00 2001
From: zhaozhenlong
Date: Thu, 24 Sep 2020 15:27:05 +0800
Subject: [PATCH] mul int32

---
 mindspore/lite/nnacl/fp32/arithmetic.c        | 242 ++++++++++++++
 mindspore/lite/nnacl/fp32/arithmetic.h        |   6 +
 .../src/runtime/kernel/arm/fp32/arithmetic.cc | 109 ++++--
 .../src/runtime/kernel/arm/fp32/arithmetic.h  |  13 +-
 .../kernel/arm/fp32/arithmetic_fp32_tests.cc  | 310 ++++++++++++++++++
 5 files changed, 652 insertions(+), 28 deletions(-)
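[Reviewer note, ignored by git am below the fold: the three activation variants added in nnacl reduce, element-wise, to multiply-then-optionally-clamp. A minimal standalone sketch of that contract for the relu6 case, assuming nothing beyond the signatures in the patch; names here are illustrative and not part of the change.]

#include <algorithm>
#include <cstdio>

// Scalar reference for the new int32 mul kernels: ElementMulInt applies no
// clamp, ElementMulReluInt clamps at 0, ElementMulRelu6Int clamps to [0, 6].
int ElementMulRelu6IntRef(const int *in0, const int *in1, int *out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = std::min(std::max(in0[i] * in1[i], 0), 6);
  }
  return 0;
}

int main() {
  int a[4] = {-2, 1, 3, 7};
  int b[4] = {1, 1, 2, 1};
  int out[4];
  ElementMulRelu6IntRef(a, b, out, 4);
  for (int v : out) printf("%d ", v);  // -> 0 1 6 6
  return 0;
}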
diff --git a/mindspore/lite/nnacl/fp32/arithmetic.c b/mindspore/lite/nnacl/fp32/arithmetic.c
index 65a733897c..00594f8de3 100644
--- a/mindspore/lite/nnacl/fp32/arithmetic.c
+++ b/mindspore/lite/nnacl/fp32/arithmetic.c
@@ -169,6 +169,156 @@ int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_
   return NNACL_OK;
 }
 
+int ElementOptMulInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param) {
+  int block_mod = element_size % C4NUM;
+  int block_c4 = element_size - block_mod;
+  int in0_opt = input0[0];
+  int in1_opt = input1[0];
+#ifdef ENABLE_NEON
+  int32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
+  int32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
+#endif
+  if (param->in_elements_num0_ == 1) {
+    for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+      int32x4_t vin0 = vin0_opt;
+      int32x4_t vin1 = vld1q_s32(input1);
+      int32x4_t vout = vmulq_s32(vin0, vin1);
+      vst1q_s32(output, vout);
+#else
+      for (int i = 0; i < C4NUM; ++i) {
+        output[i] = in0_opt * input1[i];
+      }
+#endif
+      input1 += C4NUM;
+      output += C4NUM;
+    }
+    for (int index = 0; index < block_mod; ++index) {
+      output[index] = in0_opt * input1[index];
+    }
+  } else {
+    for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+      int32x4_t vin0 = vld1q_s32(input0);
+      int32x4_t vin1 = vin1_opt;
+      int32x4_t vout = vmulq_s32(vin0, vin1);
+      vst1q_s32(output, vout);
+#else
+      for (int i = 0; i < C4NUM; ++i) {
+        output[i] = input0[i] * in1_opt;
+      }
+#endif
+      input0 += C4NUM;
+      output += C4NUM;
+    }
+    for (int index = 0; index < block_mod; ++index) {
+      output[index] = input0[index] * in1_opt;
+    }
+  }
+
+  return NNACL_OK;
+}
+int ElementOptMulReluInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param) {
+  int block_mod = element_size % C4NUM;
+  int block_c4 = element_size - block_mod;
+  int in0_opt = input0[0];
+  int in1_opt = input1[0];
+#ifdef ENABLE_NEON
+  int32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
+  int32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
+  int32x4_t zeros = {0, 0, 0, 0};
+#endif
+  if (param->in_elements_num0_ == 1) {
+    for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+      int32x4_t vin0 = vin0_opt;
+      int32x4_t vin1 = vld1q_s32(input1);
+      int32x4_t vout = vmaxq_s32(vmulq_s32(vin0, vin1), zeros);
+      vst1q_s32(output, vout);
+#else
+      for (int i = 0; i < C4NUM; ++i) {
+        output[i] = MSMAX(in0_opt * input1[i], 0);
+      }
+#endif
+      input1 += C4NUM;
+      output += C4NUM;
+    }
+    for (int index = 0; index < block_mod; ++index) {
+      output[index] = MSMAX(in0_opt * input1[index], 0);
+    }
+  } else {
+    for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+      int32x4_t vin0 = vld1q_s32(input0);
+      int32x4_t vin1 = vin1_opt;
+      int32x4_t vout = vmaxq_s32(vmulq_s32(vin0, vin1), zeros);
+      vst1q_s32(output, vout);
+#else
+      for (int i = 0; i < C4NUM; ++i) {
+        output[i] = MSMAX(input0[i] * in1_opt, 0);
+      }
+#endif
+      input0 += C4NUM;
+      output += C4NUM;
+    }
+    for (int index = 0; index < block_mod; ++index) {
+      output[index] = MSMAX(input0[index] * in1_opt, 0);
+    }
+  }
+
+  return NNACL_OK;
+}
+int ElementOptMulRelu6Int(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param) {
+  int block_mod = element_size % C4NUM;
+  int block_c4 = element_size - block_mod;
+  int in0_opt = input0[0];
+  int in1_opt = input1[0];
+#ifdef ENABLE_NEON
+  int32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
+  int32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
+  int32x4_t zeros = {0, 0, 0, 0};
+  int32x4_t bounds = {6, 6, 6, 6};
+#endif
+  if (param->in_elements_num0_ == 1) {
+    for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+      int32x4_t vin0 = vin0_opt;
+      int32x4_t vin1 = vld1q_s32(input1);
+      int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds);
+      vst1q_s32(output, vout);
+#else
+      for (int i = 0; i < C4NUM; ++i) {
+        output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
+      }
+#endif
+      input1 += C4NUM;
+      output += C4NUM;
+    }
+    for (int index = 0; index < block_mod; ++index) {
+      output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
+    }
+  } else {
+    for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+      int32x4_t vin0 = vld1q_s32(input0);
+      int32x4_t vin1 = vin1_opt;
+      int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds);
+      vst1q_s32(output, vout);
+#else
+      for (int i = 0; i < C4NUM; ++i) {
+        output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
+      }
+#endif
+      input0 += C4NUM;
+      output += C4NUM;
+    }
+    for (int index = 0; index < block_mod; ++index) {
+      output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
+    }
+  }
+
+  return NNACL_OK;
+}
 
 int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
   int block_mod = element_size % C4NUM;
@@ -608,6 +758,98 @@ int ElementMulRelu6(float *input0, float *input1, float *output, int element_siz
   return NNACL_OK;
 }
 
+int ElementMulInt(int *input0, int *input1, int *output, int element_size) {
+  int block_mod = element_size % C4NUM;
+  int block_c4 = element_size - block_mod;
+
+  for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+    int32x4_t vin0 = vld1q_s32(input0);
+    int32x4_t vin1 = vld1q_s32(input1);
+    int32x4_t vout = vmulq_s32(vin0, vin1);
+    vst1q_s32(output, vout);
+#else
+    output[0] = input0[0] * input1[0];
+    output[1] = input0[1] * input1[1];
+    output[2] = input0[2] * input1[2];
+    output[3] = input0[3] * input1[3];
+#endif
+    input0 += C4NUM;
+    input1 += C4NUM;
+    output += C4NUM;
+  }
+  for (int index = 0; index < block_mod; ++index) {
+    output[index] = input0[index] * input1[index];
+  }
+
+  return NNACL_OK;
+}
+int ElementMulReluInt(int *input0, int *input1, int *output, int element_size) {
+  int block_mod = element_size % C4NUM;
+  int block_c4 = element_size - block_mod;
+
+#ifdef ENABLE_NEON
+  int32x4_t zeros = {0, 0, 0, 0};
+#endif
+  for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+    int32x4_t vin0 = vld1q_s32(input0);
+    int32x4_t vin1 = vld1q_s32(input1);
+    int32x4_t vout = vmulq_s32(vin0, vin1);
+    vout = vbslq_s32(vcgtq_s32(vout, zeros), vout, zeros);
+    vst1q_s32(output, vout);
+#else
+    int res = input0[0] * input1[0];
+    output[0] = res > 0 ? res : 0;
+    res = input0[1] * input1[1];
+    output[1] = res > 0 ? res : 0;
+    res = input0[2] * input1[2];
+    output[2] = res > 0 ? res : 0;
+    res = input0[3] * input1[3];
+    output[3] = res > 0 ? res : 0;
+#endif
+    input0 += C4NUM;
+    input1 += C4NUM;
+    output += C4NUM;
+  }
+  for (int index = 0; index < block_mod; ++index) {
+    int res = input0[index] * input1[index];
+    output[index] = res > 0 ? res : 0;
+  }
+
+  return NNACL_OK;
+}
+int ElementMulRelu6Int(int *input0, int *input1, int *output, int element_size) {
+  int block_mod = element_size % C4NUM;
+  int block_c4 = element_size - block_mod;
+
+#ifdef ENABLE_NEON
+  int32x4_t zeros = {0, 0, 0, 0};
+  int32x4_t bounds = {6, 6, 6, 6};
+#endif
+  for (int index = 0; index < block_c4; index += C4NUM) {
+#ifdef ENABLE_NEON
+    int32x4_t vin0 = vld1q_s32(input0);
+    int32x4_t vin1 = vld1q_s32(input1);
+    int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds);
+    vst1q_s32(output, vout);
+#else
+    output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6);
+    output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6);
+    output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6);
+    output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6);
+#endif
+    input0 += C4NUM;
+    input1 += C4NUM;
+    output += C4NUM;
+  }
+  for (int index = 0; index < block_mod; ++index) {
+    output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6);
+  }
+
+  return NNACL_OK;
+}
+
 int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                  ArithmeticParameter *param) {
   TileDimensions(input0, input1, tile_input0, tile_input1, param);
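[Reviewer note: the ElementOpt*Int variants above cover the no-broadcast case where one operand is a single value; param->in_elements_num0_ == 1 selects which side is the scalar. An equivalent scalar reference, assuming only the signatures shown above (names here are illustrative):]

#include <cstdio>

// Reference behaviour of ElementOptMulInt: one side has exactly one element,
// which multiplies every element of the other side.
int ElementOptMulIntRef(const int *in0, const int *in1, int *out, int n, bool in0_is_scalar) {
  for (int i = 0; i < n; ++i) {
    out[i] = in0_is_scalar ? in0[0] * in1[i] : in0[i] * in1[0];
  }
  return 0;
}

int main() {
  int scalar = 2;
  int data[6] = {0, 1, 2, 3, 4, 5};
  int out[6];
  ElementOptMulIntRef(&scalar, data, out, 6, /*in0_is_scalar=*/true);
  for (int v : out) printf("%d ", v);  // -> 0 2 4 6 8 10
  return 0;
}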
diff --git a/mindspore/lite/nnacl/fp32/arithmetic.h b/mindspore/lite/nnacl/fp32/arithmetic.h
index 22c5d36c02..c9556a8642 100644
--- a/mindspore/lite/nnacl/fp32/arithmetic.h
+++ b/mindspore/lite/nnacl/fp32/arithmetic.h
@@ -35,12 +35,18 @@ int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_
 int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
 int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
 int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
+int ElementOptMulInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
+int ElementOptMulReluInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
+int ElementOptMulRelu6Int(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
 int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
 int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
 int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
 int ElementMul(float *input0, float *input1, float *output, int element_size);
 int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
 int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
+int ElementMulInt(int *input0, int *input1, int *output, int element_size);
+int ElementMulReluInt(int *input0, int *input1, int *output, int element_size);
+int ElementMulRelu6Int(int *input0, int *input1, int *output, int element_size);
 
 int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                  ArithmeticParameter *param);
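[Reviewer note: the kernel below stores these new entry points behind ArithmeticIntRun / ArithmeticOptIntRun function-pointer typedefs. The typedefs themselves sit outside the hunks shown in this patch; presumably they mirror the float ones, along the lines of this hypothetical reconstruction:]

// Hypothetical reconstruction, matching the signatures declared above; the
// real typedefs live next to ArithmeticRun/ArithmeticOptRun in the kernel header.
typedef int (*ArithmeticIntRun)(int *input0, int *input1, int *output, int element_size);
typedef int (*ArithmeticOptIntRun)(int *input0, int *input1, int *output, int element_size,
                                   ArithmeticParameter *param);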
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
index 95d87e747d..8c41ad946b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
@@ -36,6 +36,11 @@ int ArithmeticCPUKernel::Init() {
   if (!InferShapeDone()) {
     return RET_OK;
   }
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
+    data_type_ = kDataTypeFloat;
+  } else {
+    data_type_ = kDataTypeInt;
+  }
   return ReSize();
 }
 
@@ -51,14 +56,17 @@ int ArithmeticCPUKernel::ReSize() {
       case schema::ActivationType_RELU:
         arithmeticParameter_->broadcasting_ = false;
         arithmetic_opt_run_ = ElementOptMulRelu;
+        arithmetic_opt_run_int_ = ElementOptMulReluInt;
         break;
       case schema::ActivationType_RELU6:
         arithmeticParameter_->broadcasting_ = false;
         arithmetic_opt_run_ = ElementOptMulRelu6;
+        arithmetic_opt_run_int_ = ElementOptMulRelu6Int;
        break;
      default:
         arithmeticParameter_->broadcasting_ = false;
         arithmetic_opt_run_ = ElementOptMul;
+        arithmetic_opt_run_int_ = ElementOptMulInt;
         break;
     }
     break;
@@ -113,23 +121,40 @@ int ArithmeticCPUKernel::ReSize() {
       default:
         break;
     }
+  } else {
+    arithmetic_opt_run_ = nullptr;
+    arithmetic_opt_run_int_ = nullptr;
   }
   return RET_OK;
 }
 
-int ArithmeticCPUKernel::BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count,
+int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
                                       int out_thread_stride) {
   if (dim > break_pos_) {
-    return arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride,
-                           out_count);
+    if (data_type_ == kDataTypeInt) {
+      return arithmetic_run_int_(reinterpret_cast<int *>(input0) + out_thread_stride,
+                                 reinterpret_cast<int *>(input1) + out_thread_stride,
+                                 reinterpret_cast<int *>(output) + out_thread_stride, out_count);
+    }
+    return arithmetic_run_(reinterpret_cast<float *>(input0) + out_thread_stride,
+                           reinterpret_cast<float *>(input1) + out_thread_stride,
+                           reinterpret_cast<float *>(output) + out_thread_stride, out_count);
   }
   for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
     int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
     int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
-    int error_code =
-      BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
-                   input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
-                   output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
+    int error_code;
+    if (data_type_ == kDataTypeInt) {
+      error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
+                                reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
+                                reinterpret_cast<int *>(output) + i * arithmeticParameter_->out_strides_[dim], dim + 1,
+                                out_count, out_thread_stride);
+    } else {
+      error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
+                                reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
+                                reinterpret_cast<float *>(output) + i * arithmeticParameter_->out_strides_[dim],
+                                dim + 1, out_count, out_thread_stride);
+    }
     if (error_code != RET_OK) {
       return error_code;
     }
@@ -138,9 +163,6 @@ int ArithmeticCPUKernel::BroadcastRun(float *input0, float *input1, float *outpu
 }
 
 int ArithmeticCPUKernel::DoArithmetic(int task_id) {
-  auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
-  auto input1_data1 = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
-  auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
   auto element_num = out_tensors_[0]->ElementsNum();
 
   MS_ASSERT(thread_count_ != 0);
@@ -152,26 +174,62 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
     return RET_ERROR;
   }
 
-  int error_code = RET_OK;
-  if (arithmeticParameter_->broadcasting_) {
+  int error_code;
+  if (arithmeticParameter_->broadcasting_) {  // broadcast needed
     stride = UP_DIV(outside_, thread_count_);
-    out_count_ = MSMIN(stride, outside_ - stride * task_id);
-    out_thread_stride_ = stride * task_id;
-    error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_);
-  } else if (arithmetic_opt_run_ != nullptr) {
+    int out_count = MSMIN(stride, outside_ - stride * task_id);
+    int out_thread_stride = stride * task_id;
+    if (data_type_ == kDataTypeFloat) {
+      error_code =
+        BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->MutableData()),
+                     reinterpret_cast<float *>(in_tensors_[1]->MutableData()),
+                     reinterpret_cast<float *>(out_tensors_[0]->MutableData()), 0, out_count, out_thread_stride);
+    } else {
+      error_code = BroadcastRun(
+        reinterpret_cast<int *>(in_tensors_[0]->MutableData()), reinterpret_cast<int *>(in_tensors_[1]->MutableData()),
+        reinterpret_cast<int *>(out_tensors_[0]->MutableData()), 0, out_count, out_thread_stride);
+    }
+  } else if (arithmetic_opt_run_ != nullptr) {  // no broadcast, one of the inputs is a scalar
     if (arithmeticParameter_->in_elements_num0_ == 1) {
-      error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id,
-                                       count, arithmeticParameter_);
+      if (data_type_ == kDataTypeFloat) {
+        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->MutableData()),
+                                         reinterpret_cast<float *>(in_tensors_[1]->MutableData()) + stride * task_id,
+                                         reinterpret_cast<float *>(out_tensors_[0]->MutableData()) + stride * task_id,
+                                         count, arithmeticParameter_);
+      } else {
+        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->MutableData()),
+                                             reinterpret_cast<int *>(in_tensors_[1]->MutableData()) + stride * task_id,
+                                             reinterpret_cast<int *>(out_tensors_[0]->MutableData()) + stride * task_id,
+                                             count, arithmeticParameter_);
+      }
     } else if (arithmeticParameter_->in_elements_num1_ == 1) {
-      error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1, output_data + stride * task_id,
-                                       count, arithmeticParameter_);
+      if (data_type_ == kDataTypeFloat) {
+        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->MutableData()) + stride * task_id,
+                                         reinterpret_cast<float *>(in_tensors_[1]->MutableData()),
+                                         reinterpret_cast<float *>(out_tensors_[0]->MutableData()) + stride * task_id,
+                                         count, arithmeticParameter_);
+      } else {
+        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->MutableData()) + stride * task_id,
+                                             reinterpret_cast<int *>(in_tensors_[1]->MutableData()),
+                                             reinterpret_cast<int *>(out_tensors_[0]->MutableData()) + stride * task_id,
+                                             count, arithmeticParameter_);
+      }
     } else {
-      error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
-                                       output_data + stride * task_id, count, arithmeticParameter_);
+      MS_LOG(ERROR) << "Arithmetic opt run expects at least one scalar input";
+      return RET_ERROR;
+    }
+  } else {  // no broadcast, neither input is a scalar, same shapes
+    if (data_type_ == kDataTypeFloat) {
+      error_code = arithmetic_run_(reinterpret_cast<float *>(in_tensors_[0]->MutableData()) + stride * task_id,
+                                   reinterpret_cast<float *>(in_tensors_[1]->MutableData()) + stride * task_id,
+                                   reinterpret_cast<float *>(out_tensors_[0]->MutableData()) + stride * task_id, count);
+    } else {
+      error_code =
+        arithmetic_run_int_(reinterpret_cast<int *>(in_tensors_[0]->MutableData()) + stride * task_id,
+                            reinterpret_cast<int *>(in_tensors_[1]->MutableData()) + stride * task_id,
+                            reinterpret_cast<int *>(out_tensors_[0]->MutableData()) + stride * task_id, count);
     }
-  } else {
-    error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
-                                 output_data + stride * task_id, count);
   }
   if (error_code != RET_OK) {
     return RET_ERROR;
@@ -239,6 +297,7 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
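[Reviewer note: BroadcastRun recurses over the outer dimensions until it passes break_pos_, then hands a contiguous inner run to the element kernel; data_type_ only selects the pointer type at the leaves. A self-contained sketch of that recursion for int32, with simplified names (strides and shapes passed directly rather than through ArithmeticParameter):]

#include <cstdio>

// Simplified model of the kernel's BroadcastRun: recurse while a dimension may
// broadcast (size 1 on one input), then multiply the contiguous tail directly.
void BroadcastMulInt(const int *in0, const int *in1, int *out, const int *shape0, const int *shape1,
                     const int *out_shape, const int *stride0, const int *stride1, const int *out_stride,
                     int dim, int ndim, int inner) {
  if (dim == ndim) {
    for (int i = 0; i < inner; ++i) out[i] = in0[i] * in1[i];
    return;
  }
  for (int i = 0; i < out_shape[dim]; ++i) {
    int pos0 = shape0[dim] == 1 ? 0 : i;
    int pos1 = shape1[dim] == 1 ? 0 : i;
    BroadcastMulInt(in0 + pos0 * stride0[dim], in1 + pos1 * stride1[dim], out + i * out_stride[dim],
                    shape0, shape1, out_shape, stride0, stride1, out_stride, dim + 1, ndim, inner);
  }
}

int main() {
  // {1, 1, 3} broadcast against {2, 2, 3}: mirrors the MulInt0 test below.
  int in0[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int in1[3] = {3, 2, 1};
  int out[12] = {0};
  int shape0[3] = {2, 2, 3}, shape1[3] = {1, 1, 3}, out_shape[3] = {2, 2, 3};
  int stride0[3] = {6, 3, 1}, stride1[3] = {3, 3, 1}, out_stride[3] = {6, 3, 1};
  // Recurse over the two outer dims; the inner run of 3 is contiguous.
  BroadcastMulInt(in0, in1, out, shape0, shape1, out_shape, stride0, stride1, out_stride, 0, 2, 3);
  for (int v : out) printf("%d ", v);  // -> 0 2 2 9 8 5 18 14 8 27 20 11
  return 0;
}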
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
@@ -57,12 +60,15 @@ class ArithmeticCPUKernel : public LiteKernel {
       switch (arithmeticParameter_->activation_type_) {
         case schema::ActivationType_RELU:
           arithmetic_run_ = ElementMulRelu;
+          arithmetic_run_int_ = ElementMulReluInt;
          break;
        case schema::ActivationType_RELU6:
           arithmetic_run_ = ElementMulRelu6;
+          arithmetic_run_int_ = ElementMulRelu6Int;
           break;
         default:
           arithmetic_run_ = ElementMul;
+          arithmetic_run_int_ = ElementMulInt;
           break;
       }
       break;
@@ -158,15 +164,16 @@ class ArithmeticCPUKernel : public LiteKernel {
   int DoArithmetic(int task_id);
 
  private:
-  int BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count, int out_thread_stride);
+  int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
   int break_pos_;
   int outside_;
-  int out_thread_stride_;
-  int out_count_;
   int thread_count_;
   ArithmeticParameter *arithmeticParameter_;
   ArithmeticRun arithmetic_run_ = nullptr;
   ArithmeticOptRun arithmetic_opt_run_ = nullptr;
+  ArithmeticIntRun arithmetic_run_int_ = nullptr;
+  ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
+  LiteDataType data_type_;
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_
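[Reviewer note: data_type_ is a LiteDataType, set in Init() from the first input tensor's data type. The enum is defined elsewhere in lite; for reading the hunks above, a hypothetical sketch of its assumed shape:]

// Hypothetical reconstruction; the real definition lives in lite's common headers.
typedef enum LiteDataType {
  kDataTypeFloat,  // fp32 path: arithmetic_run_ / arithmetic_opt_run_
  kDataTypeInt     // int32 path: arithmetic_run_int_ / arithmetic_opt_run_int_
} LiteDataType;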
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc
index 1acdb7d60f..b83fde169c 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc
@@ -28,8 +28,66 @@ namespace mindspore {
 class TestArithmeticTestFp32 : public mindspore::CommonTest {
  public:
   TestArithmeticTestFp32() {}
+  void PrepareInt(const std::vector<int> &input0_shape, const std::vector<int> &input1_shape, bool broadcast,
+                  const std::vector<int> &output_shape, int *input0_data, int *input1_data, int *output_data, int type,
+                  int act_type, const int thread_num);
+  void TearDown() override;
+
+ public:
+  float err_tol = 1e-5;
+  lite::Tensor in_tensor_0_;
+  lite::Tensor in_tensor_1_;
+  lite::Tensor out_tensor_;
+  std::vector<lite::Tensor *> inputs_{&in_tensor_0_, &in_tensor_1_};
+  std::vector<lite::Tensor *> outputs_{&out_tensor_};
+  ArithmeticParameter param_;
+  kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt, schema::PrimitiveType_Eltwise};
+  lite::InnerContext ctx_ = lite::InnerContext();
+  kernel::KernelCreator creator_ = nullptr;
+  kernel::LiteKernel *kernel_ = nullptr;
 };
 
+void TestArithmeticTestFp32::PrepareInt(const std::vector<int> &input0_shape, const std::vector<int> &input1_shape,
+                                        bool broadcast, const std::vector<int> &output_shape, int *input0_data,
+                                        int *input1_data, int *output_data, int type, int act_type,
+                                        const int thread_num) {
+  param_.broadcasting_ = true;
+  param_.op_parameter_.type_ = type;
+  param_.ndim_ = input0_shape.size();
+  param_.activation_type_ = act_type;
+  param_.broadcasting_ = broadcast;
+  for (size_t i = 0; i < input0_shape.size(); ++i) {
+    param_.in_shape0_[i] = input0_shape[i];
+  }
+  for (size_t i = 0; i < input1_shape.size(); ++i) {
+    param_.in_shape1_[i] = input1_shape[i];
+  }
+  for (size_t i = 0; i < output_shape.size(); ++i) {
+    param_.out_shape_[i] = output_shape[i];
+  }
+
+  in_tensor_0_.set_data_type(kNumberTypeInt);
+  in_tensor_0_.SetData(input0_data);
+  in_tensor_0_.set_shape(input0_shape);
+  in_tensor_1_.SetData(input1_data);
+  in_tensor_1_.set_shape(input1_shape);
+  out_tensor_.SetData(output_data);
+  out_tensor_.set_shape(output_shape);
+
+  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc_);
+  ASSERT_NE(creator, nullptr);
+  ctx_.thread_num_ = thread_num;
+  ASSERT_EQ(lite::RET_OK, ctx_.Init());
+  kernel_ = creator(inputs_, outputs_, reinterpret_cast<OpParameter *>(&param_), &ctx_, desc_, nullptr);
+  ASSERT_NE(kernel_, nullptr);
+}
+
+void TestArithmeticTestFp32::TearDown() {
+  in_tensor_0_.SetData(nullptr);
+  in_tensor_1_.SetData(nullptr);
+  out_tensor_.SetData(nullptr);
+}
+
 TEST_F(TestArithmeticTestFp32, AddTest) {
   auto add_param = new ArithmeticParameter();
   add_param->ndim_ = 4;
@@ -677,6 +735,258 @@ TEST_F(TestArithmeticTestFp32, MulRelu6Fp32) {
   output0_tensor.SetData(nullptr);
 }
 
+TEST_F(TestArithmeticTestFp32, MulInt0) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1, 1, 1, 3};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  int in1_data[3] = {3, 2, 1};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_NO_ACTIVATION;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 2, 2, 9, 8, 5, 18, 14, 8, 27, 20, 11};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulInt1) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  int in1_data[1] = {2};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_NO_ACTIVATION;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulInt2) {
+  std::vector<int> input0_shape{1};
+  std::vector<int> input1_shape{1, 2, 2, 3};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[1] = {2};
+  int in1_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_NO_ACTIVATION;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulInt3) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1, 2, 2, 3};
+  bool broadcast = false;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
+  int in1_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_NO_ACTIVATION;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulReluInt0) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1, 1, 1, 3};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  int in1_data[3] = {-1, 1, 1};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulReluInt1) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
+  int in1_data[1] = {1};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 11};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulReluInt2) {
+  std::vector<int> input0_shape{1};
+  std::vector<int> input1_shape{1, 2, 2, 3};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[1] = {1};
+  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 11};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulReluInt3) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1, 2, 2, 3};
+  bool broadcast = false;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 11};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulRelu6Int0) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1, 1, 1, 3};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  int in1_data[3] = {-1, 1, 1};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU6;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 1, 2, 0, 4, 5, 0, 6, 6, 0, 6, 6};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulRelu6Int1) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
+  int in1_data[1] = {1};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU6;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulRelu6Int2) {
+  std::vector<int> input0_shape{1};
+  std::vector<int> input1_shape{1, 2, 2, 3};
+  bool broadcast = true;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[1] = {1};
+  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU6;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
+TEST_F(TestArithmeticTestFp32, MulRelu6Int3) {
+  std::vector<int> input0_shape{1, 2, 2, 3};
+  std::vector<int> input1_shape{1, 2, 2, 3};
+  bool broadcast = false;
+  std::vector<int> output_shape{1, 2, 2, 3};
+  int in0_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
+  int out_data[12] = {0};
+  schema::PrimitiveType type = schema::PrimitiveType_Mul;
+  int act_type = schema::ActivationType_RELU6;
+  int thread_num = 2;
+  desc_.type = type;
+  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
+             thread_num);
+  kernel_->Run();
+
+  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6};
+
+  CompareOutputData(out_data, correct_data, 12, err_tol);
+}
+
 TEST_F(TestArithmeticTestFp32, AddReluFp32) {
   std::vector<lite::Tensor *> inputs_tensor;
   std::vector<lite::Tensor *> outputs_tensor;
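[Reviewer note: the expected vectors in the relu6 tests can be re-derived with the scalar reference below; shown for MulRelu6Int0, where the {1, 2, 2, 3} input multiplies against {-1, 1, 1} tiled along the innermost axis:]

#include <algorithm>
#include <cstdio>

// Re-derives correct_data for MulRelu6Int0: in1 {-1, 1, 1} broadcasts along the
// innermost axis of the {1, 2, 2, 3} input, then the product clamps to [0, 6].
int main() {
  int in0[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int in1[3] = {-1, 1, 1};
  for (int i = 0; i < 12; ++i) {
    int v = std::min(std::max(in0[i] * in1[i % 3], 0), 6);
    printf("%d ", v);  // -> 0 1 2 0 4 5 0 6 6 0 6 6, matching the test above
  }
  return 0;
}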