mul int32

Commit 58f02e611c by zhaozhenlong, 5 years ago (tags/v1.1.0)
5 changed files with 652 additions and 28 deletions
  1. mindspore/lite/nnacl/fp32/arithmetic.c  (+242 −0)
  2. mindspore/lite/nnacl/fp32/arithmetic.h  (+6 −0)
  3. mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc  (+84 −25)
  4. mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h  (+10 −3)
  5. mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc  (+310 −0)

mindspore/lite/nnacl/fp32/arithmetic.c  (+242 −0)

@@ -169,6 +169,156 @@ int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_

  return NNACL_OK;
}
int ElementOptMulInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param) {
  int block_mod = element_size % C4NUM;
  int block_c4 = element_size - block_mod;  // largest C4NUM-aligned prefix, processed four elements at a time
  int in0_opt = input0[0];
  int in1_opt = input1[0];
#ifdef ENABLE_NEON
  int32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
  int32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
#endif
  if (param->in_elements_num0_ == 1) {  // input0 is the scalar operand
    for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
      int32x4_t vin0 = vin0_opt;
      int32x4_t vin1 = vld1q_s32(input1);
      int32x4_t vout = vmulq_s32(vin0, vin1);
      vst1q_s32(output, vout);
#else
      for (int i = 0; i < C4NUM; ++i) {
        output[i] = in0_opt * input1[i];
      }
#endif
      input1 += C4NUM;
      output += C4NUM;
    }
    for (int index = 0; index < block_mod; ++index) {  // scalar tail
      output[index] = in0_opt * input1[index];
    }
  } else {  // input1 is the scalar operand
    for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
      int32x4_t vin0 = vld1q_s32(input0);
      int32x4_t vin1 = vin1_opt;
      int32x4_t vout = vmulq_s32(vin0, vin1);
      vst1q_s32(output, vout);
#else
      for (int i = 0; i < C4NUM; ++i) {
        output[i] = input0[i] * in1_opt;
      }
#endif
      input0 += C4NUM;
      output += C4NUM;
    }
    for (int index = 0; index < block_mod; ++index) {
      output[index] = input0[index] * in1_opt;
    }
  }

  return NNACL_OK;
}
int ElementOptMulReluInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param) {
  int block_mod = element_size % C4NUM;
  int block_c4 = element_size - block_mod;
  int in0_opt = input0[0];
  int in1_opt = input1[0];
#ifdef ENABLE_NEON
  int32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
  int32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
  int32x4_t zeros = {0, 0, 0, 0};
#endif
  if (param->in_elements_num0_ == 1) {
    for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
      int32x4_t vin0 = vin0_opt;
      int32x4_t vin1 = vld1q_s32(input1);
      int32x4_t vout = vmaxq_s32(vmulq_s32(vin0, vin1), zeros);  // multiply, then ReLU
      vst1q_s32(output, vout);
#else
      for (int i = 0; i < C4NUM; ++i) {
        output[i] = MSMAX(in0_opt * input1[i], 0);
      }
#endif
      input1 += C4NUM;
      output += C4NUM;
    }
    for (int index = 0; index < block_mod; ++index) {
      output[index] = MSMAX(in0_opt * input1[index], 0);
    }
  } else {
    for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
      int32x4_t vin0 = vld1q_s32(input0);
      int32x4_t vin1 = vin1_opt;
      int32x4_t vout = vmaxq_s32(vmulq_s32(vin0, vin1), zeros);
      vst1q_s32(output, vout);
#else
      for (int i = 0; i < C4NUM; ++i) {
        output[i] = MSMAX(input0[i] * in1_opt, 0);
      }
#endif
      input0 += C4NUM;
      output += C4NUM;
    }
    for (int index = 0; index < block_mod; ++index) {
      output[index] = MSMAX(input0[index] * in1_opt, 0);
    }
  }

  return NNACL_OK;
}
int ElementOptMulRelu6Int(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param) {
  int block_mod = element_size % C4NUM;
  int block_c4 = element_size - block_mod;
  int in0_opt = input0[0];
  int in1_opt = input1[0];
#ifdef ENABLE_NEON
  int32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
  int32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
  int32x4_t zeros = {0, 0, 0, 0};
  int32x4_t bounds = {6, 6, 6, 6};
#endif
  if (param->in_elements_num0_ == 1) {
    for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
      int32x4_t vin0 = vin0_opt;
      int32x4_t vin1 = vld1q_s32(input1);
      int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds);  // multiply, then clamp to [0, 6]
      vst1q_s32(output, vout);
#else
      for (int i = 0; i < C4NUM; ++i) {
        output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
      }
#endif
      input1 += C4NUM;
      output += C4NUM;
    }
    for (int index = 0; index < block_mod; ++index) {
      output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
    }
  } else {
    for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
      int32x4_t vin0 = vld1q_s32(input0);
      int32x4_t vin1 = vin1_opt;
      int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds);
      vst1q_s32(output, vout);
#else
      for (int i = 0; i < C4NUM; ++i) {
        output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
      }
#endif
      input0 += C4NUM;
      output += C4NUM;
    }
    for (int index = 0; index < block_mod; ++index) {
      output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
    }
  }

  return NNACL_OK;
}

int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  int block_mod = element_size % C4NUM;
@@ -608,6 +758,98 @@ int ElementMulRelu6(float *input0, float *input1, float *output, int element_siz
  return NNACL_OK;
}

int ElementMulInt(int *input0, int *input1, int *output, int element_size) {
  int block_mod = element_size % C4NUM;
  int block_c4 = element_size - block_mod;

  for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
    int32x4_t vin0 = vld1q_s32(input0);
    int32x4_t vin1 = vld1q_s32(input1);
    int32x4_t vout = vmulq_s32(vin0, vin1);
    vst1q_s32(output, vout);
#else
    output[0] = input0[0] * input1[0];
    output[1] = input0[1] * input1[1];
    output[2] = input0[2] * input1[2];
    output[3] = input0[3] * input1[3];
#endif
    input0 += C4NUM;
    input1 += C4NUM;
    output += C4NUM;
  }
  for (int index = 0; index < block_mod; ++index) {
    output[index] = input0[index] * input1[index];
  }

  return NNACL_OK;
}
int ElementMulReluInt(int *input0, int *input1, int *output, int element_size) {
  int block_mod = element_size % C4NUM;
  int block_c4 = element_size - block_mod;

#ifdef ENABLE_NEON
  int32x4_t zeros = {0, 0, 0, 0};
#endif
  for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
    int32x4_t vin0 = vld1q_s32(input0);
    int32x4_t vin1 = vld1q_s32(input1);
    int32x4_t vout = vmulq_s32(vin0, vin1);
    vout = vbslq_s32(vcgtq_s32(vout, zeros), vout, zeros);  // keep positive lanes, zero the rest (ReLU)
    vst1q_s32(output, vout);
#else
    int res = input0[0] * input1[0];  // int, not float: these are the integer kernels
    output[0] = res > 0 ? res : 0;
    res = input0[1] * input1[1];
    output[1] = res > 0 ? res : 0;
    res = input0[2] * input1[2];
    output[2] = res > 0 ? res : 0;
    res = input0[3] * input1[3];
    output[3] = res > 0 ? res : 0;
#endif
    input0 += C4NUM;
    input1 += C4NUM;
    output += C4NUM;
  }
  for (int index = 0; index < block_mod; ++index) {
    int res = input0[index] * input1[index];
    output[index] = res > 0 ? res : 0;
  }

  return NNACL_OK;
}
int ElementMulRelu6Int(int *input0, int *input1, int *output, int element_size) {
  int block_mod = element_size % C4NUM;
  int block_c4 = element_size - block_mod;

#ifdef ENABLE_NEON
  int32x4_t zeros = {0, 0, 0, 0};
  int32x4_t bounds = {6, 6, 6, 6};
#endif
  for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
    int32x4_t vin0 = vld1q_s32(input0);
    int32x4_t vin1 = vld1q_s32(input1);
    int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds);
    vst1q_s32(output, vout);
#else
    output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6);
    output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6);
    output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6);
    output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6);
#endif
    input0 += C4NUM;
    input1 += C4NUM;
    output += C4NUM;
  }
  for (int index = 0; index < block_mod; ++index) {
    output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6);
  }

  return NNACL_OK;
}

int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                 ArithmeticParameter *param) {
  TileDimensions(input0, input1, tile_input0, tile_input1, param);
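
Note on the pattern: every ElementMul*Int / ElementOptMul*Int variant above uses the same blocked loop — process the largest C4NUM-aligned prefix four elements at a time (NEON intrinsics when ENABLE_NEON is set, a 4-wide scalar loop otherwise), then finish the remainder with a plain scalar tail. The stand-alone sketch below shows just that idiom for the ReLU6 variant; the file name and helper are hypothetical, and it assumes C4NUM == 4 as defined in the nnacl headers.

/* mul_relu6_int.c -- stand-alone sketch of the blocked-loop idiom used above.
 * Hypothetical names; assumes C4NUM == 4 as in the nnacl headers. */
#include <stdio.h>

#define C4NUM 4
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))
#define MSMAX(a, b) ((a) > (b) ? (a) : (b))

static void mul_relu6_int(const int *in0, const int *in1, int *out, int n) {
  int tail = n % C4NUM;
  int body = n - tail;
  for (int i = 0; i < body; i += C4NUM) {  /* 4-wide body: the slot NEON fills */
    for (int j = 0; j < C4NUM; ++j) {
      out[i + j] = MSMIN(MSMAX(in0[i + j] * in1[i + j], 0), 6);
    }
  }
  for (int i = body; i < n; ++i) {         /* scalar tail */
    out[i] = MSMIN(MSMAX(in0[i] * in1[i], 0), 6);
  }
}

int main(void) {
  int a[6] = {1, 1, 1, 1, 1, 1};
  int b[6] = {-2, 0, 3, 5, 7, 9};
  int c[6];
  mul_relu6_int(a, b, c, 6);
  for (int i = 0; i < 6; ++i) printf("%d ", c[i]);  /* prints: 0 0 3 5 6 6 */
  return 0;
}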


mindspore/lite/nnacl/fp32/arithmetic.h  (+6 −0)

@@ -35,12 +35,18 @@ int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
int ElementOptMulReluInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu6Int(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementMul(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
int ElementMulInt(int *input0, int *input1, int *output, int element_size);
int ElementMulReluInt(int *input0, int *input1, int *output, int element_size);
int ElementMulRelu6Int(int *input0, int *input1, int *output, int element_size);
int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                 ArithmeticParameter *param);
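
For orientation, here is a minimal caller of the new scalar-broadcast entry point declared above. This is a hypothetical usage sketch, not part of the commit: the include path is an assumption from the file layout, and zero-initializing ArithmeticParameter relies on ElementOptMulInt reading only in_elements_num0_, as the implementation above shows.

/* Hypothetical usage sketch; include path is an assumption from the file layout. */
#include "nnacl/fp32/arithmetic.h"
#include <stdio.h>

int main(void) {
  int in0[5] = {1, 2, 3, 4, 5};
  int scalar = 2;
  int out[5] = {0};
  ArithmeticParameter param = {0};
  param.in_elements_num0_ = 5;  /* input0 is the per-element operand */
  param.in_elements_num1_ = 1;  /* input1 is the scalar operand */
  ElementOptMulInt(in0, &scalar, out, 5, &param);
  for (int i = 0; i < 5; ++i) printf("%d ", out[i]);  /* expected: 2 4 6 8 10 */
  return 0;
}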



mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc  (+84 −25)

@@ -36,6 +36,11 @@ int ArithmeticCPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
    data_type_ = kDataTypeFloat;
  } else {
    data_type_ = kDataTypeInt;
  }
  return ReSize();
}

@@ -51,14 +56,17 @@ int ArithmeticCPUKernel::ReSize() {
      case schema::ActivationType_RELU:
        arithmeticParameter_->broadcasting_ = false;
        arithmetic_opt_run_ = ElementOptMulRelu;
        arithmetic_opt_run_int_ = ElementOptMulReluInt;
        break;
      case schema::ActivationType_RELU6:
        arithmeticParameter_->broadcasting_ = false;
        arithmetic_opt_run_ = ElementOptMulRelu6;
        arithmetic_opt_run_int_ = ElementOptMulRelu6Int;
        break;
      default:
        arithmeticParameter_->broadcasting_ = false;
        arithmetic_opt_run_ = ElementOptMul;
        arithmetic_opt_run_int_ = ElementOptMulInt;
        break;
    }
    break;
@@ -113,23 +121,40 @@ int ArithmeticCPUKernel::ReSize() {
      default:
        break;
    }
  } else {
    arithmetic_opt_run_ = nullptr;
    arithmetic_opt_run_int_ = nullptr;
  }
  return RET_OK;
}

int ArithmeticCPUKernel::BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count,
int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
                                      int out_thread_stride) {
  if (dim > break_pos_) {
    return arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride,
                           out_count);
    if (data_type_ == kDataTypeInt) {
      return arithmetic_run_int_(reinterpret_cast<int *>(input0) + out_thread_stride,
                                 reinterpret_cast<int *>(input1) + out_thread_stride,
                                 reinterpret_cast<int *>(output) + out_thread_stride, out_count);
    }
    return arithmetic_run_(reinterpret_cast<float *>(input0) + out_thread_stride,
                           reinterpret_cast<float *>(input1) + out_thread_stride,
                           reinterpret_cast<float *>(output) + out_thread_stride, out_count);
  }
  for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
    int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
    int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
    int error_code =
      BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
                   input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
                   output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
    int error_code;
    if (data_type_ == kDataTypeInt) {
      error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
                                reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
                                reinterpret_cast<int *>(output) + i * arithmeticParameter_->out_strides_[dim], dim + 1,
                                out_count, out_thread_stride);
    } else {
      error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
                                reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
                                reinterpret_cast<float *>(output) + i * arithmeticParameter_->out_strides_[dim],
                                dim + 1, out_count, out_thread_stride);
    }
    if (error_code != RET_OK) {
      return error_code;
    }
@@ -138,9 +163,6 @@ int ArithmeticCPUKernel::BroadcastRun(float *input0, float *input1, float *outpu
}

int ArithmeticCPUKernel::DoArithmetic(int task_id) {
  auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
  auto input1_data1 = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
  auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
  auto element_num = out_tensors_[0]->ElementsNum();

  MS_ASSERT(thread_count_ != 0);
@@ -152,26 +174,62 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
    return RET_ERROR;
  }

  int error_code = RET_OK;
  if (arithmeticParameter_->broadcasting_) {
  int error_code;
  if (arithmeticParameter_->broadcasting_) {  // broadcast needed
    stride = UP_DIV(outside_, thread_count_);
    out_count_ = MSMIN(stride, outside_ - stride * task_id);
    out_thread_stride_ = stride * task_id;
    error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_);
  } else if (arithmetic_opt_run_ != nullptr) {
    int out_count = MSMIN(stride, outside_ - stride * task_id);
    int out_thread_stride = stride * task_id;
    if (data_type_ == kDataTypeFloat) {
      error_code =
        BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->MutableData()),
                     reinterpret_cast<float *>(in_tensors_[1]->MutableData()),
                     reinterpret_cast<float *>(out_tensors_[0]->MutableData()), 0, out_count, out_thread_stride);
    } else {
      error_code = BroadcastRun(
        reinterpret_cast<int *>(in_tensors_[0]->MutableData()), reinterpret_cast<int *>(in_tensors_[1]->MutableData()),
        reinterpret_cast<int *>(out_tensors_[0]->MutableData()), 0, out_count, out_thread_stride);
    }
  } else if (arithmetic_opt_run_ != nullptr) {  // no broadcast; one of the inputs is a scalar
    if (arithmeticParameter_->in_elements_num0_ == 1) {
      error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id,
                                       count, arithmeticParameter_);
      if (data_type_ == kDataTypeFloat) {
        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->MutableData()),
                                         reinterpret_cast<float *>(in_tensors_[1]->MutableData()) + stride * task_id,
                                         reinterpret_cast<float *>(out_tensors_[0]->MutableData()) + stride * task_id,
                                         count, arithmeticParameter_);
      } else {
        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->MutableData()),
                                             reinterpret_cast<int *>(in_tensors_[1]->MutableData()) + stride * task_id,
                                             reinterpret_cast<int *>(out_tensors_[0]->MutableData()) + stride * task_id,
                                             count, arithmeticParameter_);
      }
    } else if (arithmeticParameter_->in_elements_num1_ == 1) {
      error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1, output_data + stride * task_id,
                                       count, arithmeticParameter_);
      if (data_type_ == kDataTypeFloat) {
        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->MutableData()) + stride * task_id,
                                         reinterpret_cast<float *>(in_tensors_[1]->MutableData()),
                                         reinterpret_cast<float *>(out_tensors_[0]->MutableData()) + stride * task_id,
                                         count, arithmeticParameter_);
      } else {
        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->MutableData()) + stride * task_id,
                                             reinterpret_cast<int *>(in_tensors_[1]->MutableData()),
                                             reinterpret_cast<int *>(out_tensors_[0]->MutableData()) + stride * task_id,
                                             count, arithmeticParameter_);
      }
    } else {
      error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
                                       output_data + stride * task_id, count, arithmeticParameter_);
      MS_LOG(ERROR) << "Arithmetic opt run: at least one of the inputs must be a scalar";
      return RET_ERROR;
    }
  } else {  // no broadcast and neither input is a scalar: both inputs have the same shape
    if (data_type_ == kDataTypeFloat) {
      error_code = arithmetic_run_(reinterpret_cast<float *>(in_tensors_[0]->MutableData()) + stride * task_id,
                                   reinterpret_cast<float *>(in_tensors_[1]->MutableData()) + stride * task_id,
                                   reinterpret_cast<float *>(out_tensors_[0]->MutableData()) + stride * task_id, count);
    } else {
      error_code =
        arithmetic_run_int_(reinterpret_cast<int *>(in_tensors_[0]->MutableData()) + stride * task_id,
                            reinterpret_cast<int *>(in_tensors_[1]->MutableData()) + stride * task_id,
                            reinterpret_cast<int *>(out_tensors_[0]->MutableData()) + stride * task_id, count);
    }
  } else {
    error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
                                 output_data + stride * task_id, count);
  }
  if (error_code != RET_OK) {
    return RET_ERROR;
@@ -239,6 +297,7 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::Tenso
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
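
The reworked BroadcastRun above keeps a single recursive stride walk and dispatches on data_type_ only at the leaves, which is why the int registration can reuse the same kernel creator. Below is a simplified, self-contained sketch of that walk (hypothetical names; it recurses all the way to the innermost dimension and hard-codes an int multiply, whereas the real kernel stops at break_pos_ and hands the contiguous tail to arithmetic_run_int_):

#include <stdio.h>

/* Sketch of the stride walk behind ArithmeticCPUKernel::BroadcastRun:
 * recurse over output dims, pinning broadcast axes (size 1) of either
 * input to offset 0. Hypothetical, simplified to int multiply. */
static void broadcast_mul_int(const int *in0, const int *in1, int *out,
                              const int *shape0, const int *shape1, const int *out_shape,
                              const int *stride0, const int *stride1, const int *out_stride,
                              int dim, int ndim) {
  for (int i = 0; i < out_shape[dim]; ++i) {
    int p0 = (shape0[dim] == 1) ? 0 : i;  /* broadcast axis: reuse index 0 */
    int p1 = (shape1[dim] == 1) ? 0 : i;
    if (dim == ndim - 1) {
      out[i] = in0[p0] * in1[p1];
    } else {
      broadcast_mul_int(in0 + p0 * stride0[dim], in1 + p1 * stride1[dim], out + i * out_stride[dim],
                        shape0, shape1, out_shape, stride0, stride1, out_stride, dim + 1, ndim);
    }
  }
}

int main(void) {
  int in0[6] = {0, 1, 2, 3, 4, 5};  /* shape {2, 3} */
  int in1[3] = {3, 2, 1};           /* shape {1, 3}, broadcast over dim 0 */
  int out[6] = {0};
  int shape0[2] = {2, 3}, shape1[2] = {1, 3}, out_shape[2] = {2, 3};
  int stride0[2] = {3, 1}, stride1[2] = {3, 1}, out_stride[2] = {3, 1};
  broadcast_mul_int(in0, in1, out, shape0, shape1, out_shape, stride0, stride1, out_stride, 0, 2);
  for (int i = 0; i < 6; ++i) printf("%d ", out[i]);  /* prints: 0 2 2 9 8 5 */
  return 0;
}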


mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h  (+10 −3)

@@ -45,6 +45,9 @@ class ArithmeticCPUKernel : public LiteKernel {
  typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size);
  typedef int (*ArithmeticOptRun)(float *input0, float *input1, float *output, int element_size,
                                  ArithmeticParameter *param);
  typedef int (*ArithmeticIntRun)(int *input0, int *input1, int *output, int element_size);
  typedef int (*ArithmeticOptIntRun)(int *input0, int *input1, int *output, int element_size,
                                     ArithmeticParameter *param);

 public:
  ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -57,12 +60,15 @@ class ArithmeticCPUKernel : public LiteKernel {
      switch (arithmeticParameter_->activation_type_) {
        case schema::ActivationType_RELU:
          arithmetic_run_ = ElementMulRelu;
          arithmetic_run_int_ = ElementMulReluInt;
          break;
        case schema::ActivationType_RELU6:
          arithmetic_run_ = ElementMulRelu6;
          arithmetic_run_int_ = ElementMulRelu6Int;
          break;
        default:
          arithmetic_run_ = ElementMul;
          arithmetic_run_int_ = ElementMulInt;
          break;
      }
      break;
@@ -158,15 +164,16 @@ class ArithmeticCPUKernel : public LiteKernel {
  int DoArithmetic(int task_id);

 private:
  int BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count, int out_thread_stride);
  int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
  int break_pos_;
  int outside_;
  int out_thread_stride_;
  int out_count_;
  int thread_count_;
  ArithmeticParameter *arithmeticParameter_;
  ArithmeticRun arithmetic_run_ = nullptr;
  ArithmeticOptRun arithmetic_opt_run_ = nullptr;
  ArithmeticIntRun arithmetic_run_int_ = nullptr;
  ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
  LiteDataType data_type_;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_

mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc  (+310 −0)

@@ -28,8 +28,66 @@ namespace mindspore {
class TestArithmeticTestFp32 : public mindspore::CommonTest {
 public:
  TestArithmeticTestFp32() {}
  void PrepareInt(const std::vector<int> &input0_shape, const std::vector<int> &input1_shape, bool broadcast,
                  const std::vector<int> &output_shape, int *input0_data, int *input1_data, int *output_data, int type,
                  int act_type, const int thread_num);
  void TearDown() override;

 public:
  float err_tol = 1e-5;
  lite::Tensor in_tensor_0_;
  lite::Tensor in_tensor_1_;
  lite::Tensor out_tensor_;
  std::vector<lite::Tensor *> inputs_{&in_tensor_0_, &in_tensor_1_};
  std::vector<lite::Tensor *> outputs_{&out_tensor_};
  ArithmeticParameter param_;
  kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt, schema::PrimitiveType_Eltwise};
  lite::InnerContext ctx_ = lite::InnerContext();
  kernel::KernelCreator creator_ = nullptr;
  kernel::LiteKernel *kernel_ = nullptr;
};

void TestArithmeticTestFp32::PrepareInt(const std::vector<int> &input0_shape, const std::vector<int> &input1_shape,
                                        bool broadcast, const std::vector<int> &output_shape, int *input0_data,
                                        int *input1_data, int *output_data, int type, int act_type,
                                        const int thread_num) {
  param_.op_parameter_.type_ = type;
  param_.ndim_ = input0_shape.size();
  param_.activation_type_ = act_type;
  param_.broadcasting_ = broadcast;
  for (size_t i = 0; i < input0_shape.size(); ++i) {
    param_.in_shape0_[i] = input0_shape[i];
  }
  for (size_t i = 0; i < input1_shape.size(); ++i) {
    param_.in_shape1_[i] = input1_shape[i];
  }
  for (size_t i = 0; i < output_shape.size(); ++i) {
    param_.out_shape_[i] = output_shape[i];
  }

  in_tensor_0_.set_data_type(kNumberTypeInt);
  in_tensor_0_.SetData(input0_data);
  in_tensor_0_.set_shape(input0_shape);
  in_tensor_1_.SetData(input1_data);
  in_tensor_1_.set_shape(input1_shape);
  out_tensor_.SetData(output_data);
  out_tensor_.set_shape(output_shape);

  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc_);
  ASSERT_NE(creator, nullptr);
  ctx_.thread_num_ = thread_num;
  ASSERT_EQ(lite::RET_OK, ctx_.Init());
  kernel_ = creator(inputs_, outputs_, reinterpret_cast<OpParameter *>(&param_), &ctx_, desc_, nullptr);
  ASSERT_NE(kernel_, nullptr);
}

void TestArithmeticTestFp32::TearDown() {
  in_tensor_0_.SetData(nullptr);
  in_tensor_1_.SetData(nullptr);
  out_tensor_.SetData(nullptr);
}

TEST_F(TestArithmeticTestFp32, AddTest) {
  auto add_param = new ArithmeticParameter();
  add_param->ndim_ = 4;
@@ -677,6 +735,258 @@ TEST_F(TestArithmeticTestFp32, MulRelu6Fp32) {
  output0_tensor.SetData(nullptr);
}

TEST_F(TestArithmeticTestFp32, MulInt0) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1, 1, 1, 3};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int in1_data[3] = {3, 2, 1};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_NO_ACTIVATION;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 2, 2, 9, 8, 5, 18, 14, 8, 27, 20, 11};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulInt1) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int in1_data[1] = {2};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_NO_ACTIVATION;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulInt2) {
  std::vector<int> input0_shape{1};
  std::vector<int> input1_shape{1, 2, 2, 3};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[1] = {2};
  int in1_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_NO_ACTIVATION;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulInt3) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1, 2, 2, 3};
  bool broadcast = false;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
  int in1_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_NO_ACTIVATION;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulReluInt0) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1, 1, 1, 3};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int in1_data[3] = {-1, 1, 1};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulReluInt1) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
  int in1_data[1] = {1};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 11};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulReluInt2) {
  std::vector<int> input0_shape{1};
  std::vector<int> input1_shape{1, 2, 2, 3};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[1] = {1};
  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 11};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulReluInt3) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1, 2, 2, 3};
  bool broadcast = false;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 11};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulRelu6Int0) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1, 1, 1, 3};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int in1_data[3] = {-1, 1, 1};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU6;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 1, 2, 0, 4, 5, 0, 6, 6, 0, 6, 6};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulRelu6Int1) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
  int in1_data[1] = {1};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU6;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulRelu6Int2) {
  std::vector<int> input0_shape{1};
  std::vector<int> input1_shape{1, 2, 2, 3};
  bool broadcast = true;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[1] = {1};
  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU6;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, MulRelu6Int3) {
  std::vector<int> input0_shape{1, 2, 2, 3};
  std::vector<int> input1_shape{1, 2, 2, 3};
  bool broadcast = false;
  std::vector<int> output_shape{1, 2, 2, 3};
  int in0_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  int in1_data[12] = {0, -1, -2, -3, -4, -5, 6, 7, 8, 9, 10, 11};
  int out_data[12] = {0};
  schema::PrimitiveType type = schema::PrimitiveType_Mul;
  int act_type = schema::ActivationType_RELU6;
  int thread_num = 2;
  desc_.type = type;
  PrepareInt(input0_shape, input1_shape, broadcast, output_shape, in0_data, in1_data, out_data, type, act_type,
             thread_num);
  kernel_->Run();

  int correct_data[12] = {0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6};

  CompareOutputData(out_data, correct_data, 12, err_tol);
}

TEST_F(TestArithmeticTestFp32, AddReluFp32) {
  std::vector<lite::Tensor *> inputs_tensor;
  std::vector<lite::Tensor *> outputs_tensor;

