From 3f0dbfc357622717ffa37ae06dbf5043d8616f89 Mon Sep 17 00:00:00 2001 From: fuzhiye Date: Mon, 26 Oct 2020 15:55:58 +0800 Subject: [PATCH] fix arithmetic op bug --- mindspore/lite/nnacl/fp16/arithmetic_fp16.c | 156 ++++++------- mindspore/lite/nnacl/fp16/arithmetic_fp16.h | 24 +- mindspore/lite/nnacl/fp16/cast_fp16.c | 13 ++ mindspore/lite/nnacl/fp16/cast_fp16.h | 2 + .../lite/nnacl/fp32/arithmetic_compare.c | 109 +++++++++ .../lite/nnacl/fp32/arithmetic_compare.h | 50 ++++ mindspore/lite/nnacl/int8/arithmetic_int8.c | 54 ++--- mindspore/lite/nnacl/int8/arithmetic_int8.h | 13 +- .../arm/fp16/arithmetic_compare_fp16.cc | 216 ++++++++++++++++++ .../kernel/arm/fp16/arithmetic_compare_fp16.h | 67 ++++++ .../kernel/arm/fp16/arithmetic_fp16.cc | 10 +- .../src/runtime/kernel/arm/fp16/cast_fp16.cc | 6 + .../src/runtime/kernel/arm/fp32/arithmetic.cc | 6 - .../kernel/arm/fp32/arithmetic_compare.cc | 127 ++++++++++ .../kernel/arm/fp32/arithmetic_compare.h | 46 ++++ .../runtime/kernel/arm/fp32/convolution.cc | 2 +- .../kernel/arm/int8/arithmetic_int8.cc | 2 +- .../runtime/kernel/arm/int8/arithmetic_int8.h | 2 +- 18 files changed, 745 insertions(+), 160 deletions(-) create mode 100644 mindspore/lite/nnacl/fp32/arithmetic_compare.c create mode 100644 mindspore/lite/nnacl/fp32/arithmetic_compare.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.h diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c index d493358437..af1a778ad0 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c @@ -985,319 +985,295 @@ int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *outpu return 
NNACL_OK; } -int ElementNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] != input1[index]); + output[index] = input0[index] != input1[index]; } return NNACL_OK; } -int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); float16x8_t vin1_opt = vdupq_n_f16(input1[0]); - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); #endif int index = 0; if (param->in_elements_num0_ == 1) { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0_opt, vin1), vfalse, vtrue); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[0] != input1[index]); + output[index] = input0[0] != input1[index]; } } else { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1_opt), vfalse, vtrue); - 
vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] != input1[0]); + output[index] = input0[index] != input1[0]; } } return NNACL_OK; } -int ElementEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] == input1[index]); + output[index] = input0[index] == input1[index]; } return NNACL_OK; } -int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); float16x8_t vin1_opt = vdupq_n_f16(input1[0]); - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); #endif int index = 0; if (param->in_elements_num0_ == 1) { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0_opt, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = 
(float16_t)(input0[0] == input1[index]); + output[index] = input0[0] == input1[index]; } } else { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1_opt), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] == input1[0]); + output[index] = input0[index] == input1[0]; } } return NNACL_OK; } -int ElementLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] < input1[index]); + output[index] = input0[index] < input1[index]; } return NNACL_OK; } -int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); float16x8_t vin1_opt = vdupq_n_f16(input1[0]); - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); #endif int index = 0; if (param->in_elements_num0_ == 1) { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + 
index); - float16x8_t vout = vbslq_f16(vcltq_f16(vin0_opt, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[0] < input1[index]); + output[index] = input0[0] < input1[index]; } } else { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1_opt), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] < input1[0]); + output[index] = input0[index] < input1[0]; } } return NNACL_OK; } -int ElementLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] <= input1[index]); + output[index] = input0[index] <= input1[index]; } return NNACL_OK; } -int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); 
float16x8_t vin1_opt = vdupq_n_f16(input1[0]); - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); #endif int index = 0; if (param->in_elements_num0_ == 1) { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vcleq_f16(vin0_opt, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[0] <= input1[index]); + output[index] = input0[0] <= input1[index]; } } else { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1_opt), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] <= input1[0]); + output[index] = input0[index] <= input1[0]; } } return NNACL_OK; } -int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] > input1[index]); + output[index] = input0[index] > input1[index]; } return NNACL_OK; } -int 
ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); float16x8_t vin1_opt = vdupq_n_f16(input1[0]); - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); #endif int index = 0; if (param->in_elements_num0_ == 1) { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vcgtq_f16(vin0_opt, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[0] > input1[index]); + output[index] = input0[0] > input1[index]; } } else { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1_opt), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] > input1[0]); + output[index] = input0[index] > input1[0]; } } return NNACL_OK; } -int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output 
+ index, vout); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] >= input1[index]); + output[index] = input0[index] >= input1[index]; } return NNACL_OK; } -int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); float16x8_t vin1_opt = vdupq_n_f16(input1[0]); - float16x8_t vtrue = vdupq_n_f16(1); - float16x8_t vfalse = vdupq_n_f16(0); #endif int index = 0; if (param->in_elements_num0_ == 1) { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vbslq_f16(vcgeq_f16(vin0_opt, vin1), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[0] >= input1[index]); + output[index] = input0[0] >= input1[index]; } } else { #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1_opt), vtrue, vfalse); - vst1q_f16(output + index, vout); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); } #endif for (; index < element_size; index++) { - output[index] = (float16_t)(input0[index] >= input1[0]); + output[index] = input0[index] >= input1[0]; } } return NNACL_OK; diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h index 17e712a7f8..f27b9d25b5 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h @@ -64,17 
+64,17 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param); int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); @@ -104,12 +104,12 @@ int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); int ElementMinimumFp16(float16_t *input0, float16_t *input1, 
float16_t *output, int element_size); -int ElementNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementLessEqual(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); +int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); +int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); +int ElementLessEqual(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, ArithmeticParameter *param); diff --git a/mindspore/lite/nnacl/fp16/cast_fp16.c b/mindspore/lite/nnacl/fp16/cast_fp16.c index ee870324f8..235a7b3be2 100644 --- a/mindspore/lite/nnacl/fp16/cast_fp16.c +++ b/mindspore/lite/nnacl/fp16/cast_fp16.c @@ -14,6 +14,19 @@ * limitations under the License. 
*/ #include "nnacl/fp16/cast_fp16.h" + +void BoolToFloat16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + #ifndef ENABLE_ARM64 void Float32ToFloat16(const float *input, float16_t *output, int number) { for (int i = 0; i < number; ++i) { diff --git a/mindspore/lite/nnacl/fp16/cast_fp16.h b/mindspore/lite/nnacl/fp16/cast_fp16.h index d3571503db..be942074a8 100644 --- a/mindspore/lite/nnacl/fp16/cast_fp16.h +++ b/mindspore/lite/nnacl/fp16/cast_fp16.h @@ -22,6 +22,8 @@ #ifdef __cplusplus extern "C" { #endif +void BoolToFloat16(const bool *input, float16_t *output, int number); +void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number); void Float32ToFloat16(const float *input, float16_t *output, int number); void Float16ToFloat32(const float16_t *input, float *output, int number); #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/fp32/arithmetic_compare.c b/mindspore/lite/nnacl/fp32/arithmetic_compare.c new file mode 100644 index 0000000000..f8ed586986 --- /dev/null +++ b/mindspore/lite/nnacl/fp32/arithmetic_compare.c @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <math.h> +#include <stdint.h> +#include "nnacl/fp32/arithmetic_compare.h" + +// equal: +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] == input1[i]; + } + return NNACL_OK; +} + +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] == input1[i]; + } + return NNACL_OK; +} + +// not equal +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] != input1[i]; + } + return NNACL_OK; +} + +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] != input1[i]; + } + return NNACL_OK; +} + +// less +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] < input1[i]; + } + return NNACL_OK; +} + +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] < input1[i]; + } + return NNACL_OK; +} + +// less equal +int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] <= input1[i]; + } + return NNACL_OK; +} + +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] <= input1[i]; + } + return NNACL_OK; +} + +// greater +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + 
output[i] = input0[i] > input1[i]; + } + return NNACL_OK; +} + +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] > input1[i]; + } + return NNACL_OK; +} + +// greater equal +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] >= input1[i]; + } + return NNACL_OK; +} + +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] >= input1[i]; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32/arithmetic_compare.h b/mindspore/lite/nnacl/fp32/arithmetic_compare.h new file mode 100644 index 0000000000..5f0ed2d58b --- /dev/null +++ b/mindspore/lite/nnacl/fp32/arithmetic_compare.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_ +#define MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus extern "C" { +#endif +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); + +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); + +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); + +int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); + +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); + +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_ diff --git a/mindspore/lite/nnacl/int8/arithmetic_int8.c b/mindspore/lite/nnacl/int8/arithmetic_int8.c index 9fddb2adbc..3685d61b84 100644 --- a/mindspore/lite/nnacl/int8/arithmetic_int8.c +++ b/mindspore/lite/nnacl/int8/arithmetic_int8.c @@ -22,101 +22,87 @@ #define ACCURACY_DATA 0.00000001 -int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int 
ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; - const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; - float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; float minus_inputs = in0_real - in1_real; - float out_real = (float)true; + bool out_real = true; if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { - out_real = (float)false; + out_real = false; } - output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); + output[index] = (uint8_t)out_real; } return NNACL_OK; } -int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { +int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; - const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; - float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; float minus_inputs = in0_real - in1_real; - float out_real = (float)false; + bool out_real = false; if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { - out_real = (float)true; + out_real = true; } - output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); + output[index] = (uint8_t)out_real; } return NNACL_OK; } -int 
ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { +int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; - const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; - float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; - float out_real = (float)(in0_real < in1_real); - output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); + bool out_real = in0_real < in1_real; + output[index] = (uint8_t)out_real; } return NNACL_OK; } -int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; - const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; - float out_zp = quant_arg->out_args_.zp_; - for (int index = 0; index < element_size; ++index) { float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; - float out_real = (float)(in0_real <= in1_real); - output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); + bool out_real = in0_real <= in1_real; + output[index] = (uint8_t)out_real; } return NNACL_OK; } -int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, 
ArithmeticQuantArg *quant_arg) { float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; - const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; - float out_zp = quant_arg->out_args_.zp_; - for (int index = 0; index < element_size; ++index) { float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; - float out_real = (float)(in0_real > in1_real); - output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); + bool out_real = in0_real > in1_real; + output[index] = (uint8_t)out_real; } return NNACL_OK; } -int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; - const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; - float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; - float out_real = (float)(in0_real >= in1_real); - output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); + bool out_real = in0_real >= in1_real; + output[index] = (uint8_t)out_real; } return NNACL_OK; } diff --git a/mindspore/lite/nnacl/int8/arithmetic_int8.h b/mindspore/lite/nnacl/int8/arithmetic_int8.h index b442092200..3c1cf6e5b6 100644 --- a/mindspore/lite/nnacl/int8/arithmetic_int8.h +++ b/mindspore/lite/nnacl/int8/arithmetic_int8.h @@ -22,19 +22,20 @@ #ifdef __cplusplus extern "C" { #endif -int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int 
element_size, +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); -int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); +int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); -int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); +int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); -int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); -int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); -int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); #ifdef __cplusplus diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc new file mode 100644 index 0000000000..820876f1bb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc @@ -0,0 +1,216 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h" +#include "src/runtime/kernel/arm/fp16/common_fp16.h" +#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Equal; +using mindspore::schema::PrimitiveType_Greater; +using mindspore::schema::PrimitiveType_GreaterEqual; +using mindspore::schema::PrimitiveType_Less; +using mindspore::schema::PrimitiveType_LessEqual; +using mindspore::schema::PrimitiveType_NotEqual; + +namespace mindspore::kernel { +ARITHMETIC_COMPARE_FUNC_INFO_FP16 arithmetic_cp_fun_table_fp16[] = { + {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16}, + {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16}, + {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16}, + {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16}, + {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16}, + {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16, + ElementOptGreaterEqualFp16}}; + 
+ArithmeticCompareFuncFp16 GetArithmeticCompareFun(int primitive_type, int activation_type) { + for (size_t i = 0; i < sizeof(arithmetic_cp_fun_table_fp16); i++) { + if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == primitive_type && + arithmetic_cp_fun_table_fp16[i].activation_type_ == activation_type) { + return arithmetic_cp_fun_table_fp16[i].func_; + } + } + return nullptr; +} + +ArithmeticCompareOptFuncFp16 GetOptimizedArithmeticCompareFun(int primitive_type, int activation_type) { + for (size_t i = 0; i < sizeof(arithmetic_cp_fun_table_fp16); i++) { + if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == primitive_type && + arithmetic_cp_fun_table_fp16[i].activation_type_ == activation_type) { + return arithmetic_cp_fun_table_fp16[i].opt_func_; + } + } + return nullptr; +} + +int ArithmeticCompareFP16CPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ArithmeticCompareFP16CPUKernel::ReSize() { + param_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); + param_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); + param_->out_elements_num_ = out_tensors_[0]->ElementsNum(); + + if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { + param_->broadcasting_ = false; + arithmetic_opt_func_ = GetOptimizedArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_); + } else { + arithmetic_func_ = GetArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_); + } + if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) { + MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!"; + return RET_ERROR; + } + if (param_->broadcasting_) { + outside_ = 1; + for (int i = param_->ndim_ - 1; i >= 0; --i) { + if (param_->in_shape0_[i] != param_->in_shape1_[i]) { + break_pos_ = i; + break; + } + outside_ *= param_->out_shape_[i]; + } + ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_); + 
ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_); + ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_); + } + return RET_OK; +} + +int ArithmeticCompareFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim, + int out_count, int cur_offset) { + if (dim > break_pos_) { + return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count); + } + for (int i = 0; i < param_->out_shape_[dim]; ++i) { + int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i; + int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i; + int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim], + output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset); + if (ret != RET_OK) { + return RET_ERROR; + } + } + return RET_OK; +} + +int ArithmeticCompareFP16CPUKernel::DoArithmetic(int task_id) { + int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_); + int cur_offset = stride_per_thread * task_id; + int cur_count = param_->broadcasting_ ? 
MSMIN(stride_per_thread, outside_ - cur_offset) + : MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset); + + int ret = RET_OK; + if (param_->broadcasting_) { + ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset); + } else if (param_->in_elements_num0_ == 1) { + ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_); + } else if (param_->in_elements_num1_ == 1) { + ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_); + } else { + ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count); + } + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret; + } + return ret; +} + +static int ArithmeticsRunFp16(void *cdata, int task_id) { + auto arithmetic_kernel = reinterpret_cast(cdata); + auto ret = arithmetic_kernel->DoArithmetic(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]"; + } + return ret; +} + +int ArithmeticCompareFP16CPUKernel::Run() { + auto output_tensor = out_tensors_.at(0); + is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32; + is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32; + + input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); + input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_); + output_fp16_ = reinterpret_cast(output_tensor->MutableData()); + if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) { + MS_LOG(ERROR) << "Memory allocation failed"; + FreeTmpBuffer(); + return RET_ERROR; + } + auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRunFp16, this, context_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]"; + } + FreeTmpBuffer(); + return ret; +} + 
+void ArithmeticCompareFP16CPUKernel::FreeTmpBuffer() { + if (is_input0_fp32_) { + context_->allocator->Free(input0_fp16_); + input0_fp16_ = nullptr; + } + if (is_input1_fp32_) { + context_->allocator->Free(input1_fp16_); + input1_fp16_ = nullptr; + } +} + +kernel::LiteKernel *CpuArithmeticCompareFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::InnerContext *ctx, + const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + if (parameter == nullptr) { + MS_LOG(ERROR) << "input parameter is null!"; + return nullptr; + } + auto kernel = new (std::nothrow) ArithmeticCompareFP16CPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + free(parameter); + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticCompareFp16KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticCompareFp16KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticCompareFp16KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticCompareFp16KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticCompareFp16KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h new file mode 100644 index 0000000000..168678f079 --- /dev/null +++ 
b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_ + +#include +#include "src/lite_kernel.h" +#include "nnacl/fp16/arithmetic_fp16.h" +#include "schema/model_generated.h" + +namespace mindspore::kernel { +typedef int (*ArithmeticCompareFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); +typedef int (*ArithmeticCompareOptFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, + ArithmeticParameter *param); +typedef struct { + int primitive_type_; + int activation_type_; + ArithmeticCompareFuncFp16 func_; + ArithmeticCompareOptFuncFp16 opt_func_; +} ARITHMETIC_COMPARE_FUNC_INFO_FP16; + +class ArithmeticCompareFP16CPUKernel : public LiteKernel { + public: + ArithmeticCompareFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param_ = reinterpret_cast(parameter); + } + ~ArithmeticCompareFP16CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoArithmetic(int task_id); + int 
BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim, int out_count, + int out_thread_stride); + + private: + void FreeTmpBuffer(); + int outside_; + int break_pos_; + bool is_input0_fp32_ = false; + bool is_input1_fp32_ = false; + float16_t *input0_fp16_ = nullptr; + float16_t *input1_fp16_ = nullptr; + uint8_t *output_fp16_ = nullptr; + ArithmeticParameter *param_ = nullptr; + ArithmeticCompareFuncFp16 arithmetic_func_ = nullptr; + ArithmeticCompareOptFuncFp16 arithmetic_opt_func_ = nullptr; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index 0ff67245dd..b6edbc2838 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -68,15 +68,7 @@ ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = { {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16, ElementOptSquaredDifferenceFp16}, {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16}, - {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}, - {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16}, - {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16}, - {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16}, - {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16}, - {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16}, - {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, 
ElementGreaterEqualFp16, - ElementOptGreaterEqualFp16}, -}; + {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}}; ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) { for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16); i++) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc index 52c4e6dfe6..3811b3e14c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc @@ -67,6 +67,12 @@ int CastFp16CPUKernel::DoCast(int thread_id) { auto offset = thread_id * stride_; auto output_data = out_tensors_.at(0)->MutableData(); switch (input->data_type()) { + case kNumberTypeBool: + BoolToFloat16(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); + case kNumberTypeUInt8: + Uint8ToFloat16(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); case kNumberTypeFloat32: Float32ToFloat16(reinterpret_cast(input->MutableData()) + offset, reinterpret_cast(output_data) + offset, data_num); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index 410fdcfa08..3eeace0f20 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -339,12 +339,6 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, CpuArithmeticFp32Ke REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticFp32KernelCreator) -REG_KERNEL(kCPU, 
kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticFp32KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticFp32KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticFp32KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticFp32KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.cc new file mode 100644 index 0000000000..dea58b171e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "src/runtime/kernel/arm/fp32/arithmetic_compare.h" +#include "src/kernel_registry.h" +#include "nnacl/fp32/arithmetic_compare.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Equal; +using mindspore::schema::PrimitiveType_Greater; +using mindspore::schema::PrimitiveType_GreaterEqual; +using mindspore::schema::PrimitiveType_Less; +using mindspore::schema::PrimitiveType_LessEqual; +using mindspore::schema::PrimitiveType_NotEqual; + +namespace mindspore::kernel { +namespace { +typedef struct { + int primitive_type_; + ArithmeticCompareFp32Func func_; +} TYPE_FUNC_INFO; +} // namespace + +ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) { + TYPE_FUNC_INFO type_func_table[] = { + {PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32}, + {PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32}, + {PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}}; + for (size_t i = 0; i < sizeof(type_func_table); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + return type_func_table[i].func_; + } + } + return nullptr; +} + +int ArithmeticCompareCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; } + +int ArithmeticCompareCPUKernel::DoExecute(int task_id) { + int elements_num = in_tensors_.at(0)->ElementsNum(); + int stride = UP_DIV(elements_num, op_parameter_->thread_num_); + int offset = task_id * stride; + int count = MSMIN(stride, elements_num - offset); + if (count <= 0) { + return RET_OK; + } + if (func_ == nullptr) { + MS_LOG(ERROR) << "Run function is null! 
"; + return RET_ERROR; + } + // two inputs have the same shape, support broadcast later + auto *input0_ptr = reinterpret_cast(in_tensors_.at(0)->MutableData()); + auto *input1_ptr = reinterpret_cast(in_tensors_.at(1)->MutableData()); + auto *output_ptr = reinterpret_cast(out_tensors_.at(0)->MutableData()); + auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run failed, illegal input! "; + } + return ret; +} + +int ArithmeticCompareRun(void *cdata, int task_id) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->DoExecute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; + } + return ret; +} + +int ArithmeticCompareCPUKernel::Run() { + auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; + } + return ret; +} + +kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::InnerContext *ctx, + const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) ArithmeticCompareCPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ArithmeticSelfCPUKernel fail!"; + free(parameter); + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticCompareFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticCompareFp32KernelCreator) 
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticCompareFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticCompareFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.h new file mode 100644 index 0000000000..2d6c39dcf7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_ + +#include +#include "src/runtime/kernel/arm/fp32/arithmetic.h" + +namespace mindspore::kernel { +typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size); +class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel { + public: + explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) { + func_ = GetArithmeticCompareFun(parameter->type_); + } + ~ArithmeticCompareCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + virtual int DoExecute(int task_id); + + private: + ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type); + ArithmeticCompareFp32Func func_; +}; +int ArithmeticCompareRun(void *cdata, int task_id); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index 0acaa7ec90..711e755470 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -211,7 +211,6 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector(inputs.at(kWeightIndex)->data_c()); - auto *origin_bias = reinterpret_cast(inputs.at(kBiasIndex)->data_c()); for (int i = 0; i < group; ++i) { std::vector new_inputs; @@ -234,6 +233,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector(inputs.at(kBiasIndex)->data_c()); auto bias_tensor = new (std::nothrow) lite::Tensor(inputs.at(kBiasIndex)->data_type(), bias_shape, 
Format_NHWC, lite::Tensor::Category::CONST_TENSOR); bias_tensor->MallocData(); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc index 8ea73859ed..c2690ab170 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc @@ -104,7 +104,7 @@ int ArithmeticInt8CPUKernel::ReSize() { return RET_OK; } int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) { auto input0_data = reinterpret_cast(in_tensors_[0]->MutableData()); auto input1_data1 = reinterpret_cast(in_tensors_[1]->MutableData()); - auto output_data = reinterpret_cast(out_tensors_[0]->MutableData()); + auto output_data = reinterpret_cast(out_tensors_[0]->MutableData()); auto element_num = out_tensors_[0]->ElementsNum(); auto param = reinterpret_cast(op_parameter_); if (param->broadcasting_ && arithmetic_run_ != nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h index 5dec016406..ceb082b79e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h @@ -24,7 +24,7 @@ namespace mindspore::kernel { class ArithmeticInt8CPUKernel : public LiteKernel { - typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); public: