| @@ -985,319 +985,295 @@ int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *outpu | |||
| return NNACL_OK; | |||
| } | |||
| int ElementNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { | |||
| int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] != input1[index]); | |||
| output[index] = input0[index] != input1[index]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vin0_opt = vdupq_n_f16(input0[0]); | |||
| float16x8_t vin1_opt = vdupq_n_f16(input1[0]); | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vceqq_f16(vin0_opt, vin1), vfalse, vtrue); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[0] != input1[index]); | |||
| output[index] = input0[0] != input1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1_opt), vfalse, vtrue); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] != input1[0]); | |||
| output[index] = input0[index] != input1[0]; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { | |||
| int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] == input1[index]); | |||
| output[index] = input0[index] == input1[index]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vin0_opt = vdupq_n_f16(input0[0]); | |||
| float16x8_t vin1_opt = vdupq_n_f16(input1[0]); | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vceqq_f16(vin0_opt, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[0] == input1[index]); | |||
| output[index] = input0[0] == input1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1_opt), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] == input1[0]); | |||
| output[index] = input0[index] == input1[0]; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { | |||
| int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] < input1[index]); | |||
| output[index] = input0[index] < input1[index]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vin0_opt = vdupq_n_f16(input0[0]); | |||
| float16x8_t vin1_opt = vdupq_n_f16(input1[0]); | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcltq_f16(vin0_opt, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcltq_f16(vin0_opt, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[0] < input1[index]); | |||
| output[index] = input0[0] < input1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1_opt), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1_opt)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] < input1[0]); | |||
| output[index] = input0[index] < input1[0]; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { | |||
| int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] <= input1[index]); | |||
| output[index] = input0[index] <= input1[index]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vin0_opt = vdupq_n_f16(input0[0]); | |||
| float16x8_t vin1_opt = vdupq_n_f16(input1[0]); | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcleq_f16(vin0_opt, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcleq_f16(vin0_opt, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[0] <= input1[index]); | |||
| output[index] = input0[0] <= input1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1_opt), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1_opt)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] <= input1[0]); | |||
| output[index] = input0[index] <= input1[0]; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { | |||
| int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] > input1[index]); | |||
| output[index] = input0[index] > input1[index]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vin0_opt = vdupq_n_f16(input0[0]); | |||
| float16x8_t vin1_opt = vdupq_n_f16(input1[0]); | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcgtq_f16(vin0_opt, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0_opt, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[0] > input1[index]); | |||
| output[index] = input0[0] > input1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1_opt), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1_opt)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] > input1[0]); | |||
| output[index] = input0[index] > input1[0]; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { | |||
| int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] >= input1[index]); | |||
| output[index] = input0[index] >= input1[index]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float16x8_t vin0_opt = vdupq_n_f16(input0[0]); | |||
| float16x8_t vin1_opt = vdupq_n_f16(input1[0]); | |||
| float16x8_t vtrue = vdupq_n_f16(1); | |||
| float16x8_t vfalse = vdupq_n_f16(0); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vbslq_f16(vcgeq_f16(vin0_opt, vin1), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0_opt, vin1)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[0] >= input1[index]); | |||
| output[index] = input0[0] >= input1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1_opt), vtrue, vfalse); | |||
| vst1q_f16(output + index, vout); | |||
| uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1_opt)); | |||
| vst1_u8(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = (float16_t)(input0[index] >= input1[0]); | |||
| output[index] = input0[index] >= input1[0]; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| @@ -64,17 +64,17 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu | |||
| ArithmeticParameter *param); | |||
| int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| @@ -104,12 +104,12 @@ int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t | |||
| int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementLessEqual(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| int ElementLessEqual(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, | |||
| ArithmeticParameter *param); | |||
| @@ -14,6 +14,19 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| void BoolToFloat16(const bool *input, float16_t *output, int number) { | |||
| for (int i = 0; i < number; ++i) { | |||
| output[i] = (float16_t)input[i]; | |||
| } | |||
| } | |||
| void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) { | |||
| for (int i = 0; i < number; ++i) { | |||
| output[i] = (float16_t)input[i]; | |||
| } | |||
| } | |||
| #ifndef ENABLE_ARM64 | |||
| void Float32ToFloat16(const float *input, float16_t *output, int number) { | |||
| for (int i = 0; i < number; ++i) { | |||
| @@ -22,6 +22,8 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void BoolToFloat16(const bool *input, float16_t *output, int number); | |||
| void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number); | |||
| void Float32ToFloat16(const float *input, float16_t *output, int number); | |||
| void Float16ToFloat32(const float16_t *input, float *output, int number); | |||
| #ifdef __cplusplus | |||
| @@ -0,0 +1,109 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string.h> | |||
| #include <math.h> | |||
| #include "nnacl/fp32/arithmetic_compare.h" | |||
| // equal: | |||
| int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] == input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] == input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| // not equal | |||
| int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] != input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] != input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| // less | |||
| int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] < input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] < input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| // less equal | |||
| int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] <= input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] <= input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| // greater | |||
| int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] > input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] > input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| // greater equal | |||
| int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] >= input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = input0[i] >= input1[i]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_ | |||
| #define MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); | |||
| int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); | |||
| int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); | |||
| int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); | |||
| int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size); | |||
| int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); | |||
| int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); | |||
| int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); | |||
| int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size); | |||
| int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); | |||
| int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); | |||
| int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_ | |||
| @@ -22,101 +22,87 @@ | |||
| #define ACCURACY_DATA 0.00000001 | |||
| int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float minus_inputs = in0_real - in1_real; | |||
| float out_real = (float)true; | |||
| bool out_real = true; | |||
| if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { | |||
| out_real = (float)false; | |||
| out_real = false; | |||
| } | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| output[index] = (uint8_t)out_real; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { | |||
| int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float minus_inputs = in0_real - in1_real; | |||
| float out_real = (float)false; | |||
| bool out_real = false; | |||
| if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { | |||
| out_real = (float)true; | |||
| out_real = true; | |||
| } | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| output[index] = (uint8_t)out_real; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { | |||
| int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real < in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| bool out_real = in0_real < in1_real; | |||
| output[index] = (uint8_t)out_real; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real <= in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| bool out_real = in0_real <= in1_real; | |||
| output[index] = (uint8_t)out_real; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real > in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| bool out_real = in0_real > in1_real; | |||
| output[index] = (uint8_t)out_real; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real >= in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| bool out_real = in0_real >= in1_real; | |||
| output[index] = (uint8_t)out_real; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -22,19 +22,20 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); | |||
| int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); | |||
| int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); | |||
| int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); | |||
| int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); | |||
| int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| #ifdef __cplusplus | |||
| @@ -0,0 +1,216 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/common_fp16.h" | |||
| #include "nnacl/fp16/arithmetic_fp16.h" | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Equal; | |||
| using mindspore::schema::PrimitiveType_Greater; | |||
| using mindspore::schema::PrimitiveType_GreaterEqual; | |||
| using mindspore::schema::PrimitiveType_Less; | |||
| using mindspore::schema::PrimitiveType_LessEqual; | |||
| using mindspore::schema::PrimitiveType_NotEqual; | |||
| namespace mindspore::kernel { | |||
| ARITHMETIC_COMPARE_FUNC_INFO_FP16 arithmetic_cp_fun_table_fp16[] = { | |||
| {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16}, | |||
| {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16}, | |||
| {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16}, | |||
| {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16}, | |||
| {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16}, | |||
| {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16, | |||
| ElementOptGreaterEqualFp16}}; | |||
| ArithmeticCompareFuncFp16 GetArithmeticCompareFun(int primitive_type, int activation_type) { | |||
| for (size_t i = 0; i < sizeof(arithmetic_cp_fun_table_fp16); i++) { | |||
| if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == primitive_type && | |||
| arithmetic_cp_fun_table_fp16[i].activation_type_ == activation_type) { | |||
| return arithmetic_cp_fun_table_fp16[i].func_; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| ArithmeticCompareOptFuncFp16 GetOptimizedArithmeticCompareFun(int primitive_type, int activation_type) { | |||
| for (size_t i = 0; i < sizeof(arithmetic_cp_fun_table_fp16); i++) { | |||
| if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == primitive_type && | |||
| arithmetic_cp_fun_table_fp16[i].activation_type_ == activation_type) { | |||
| return arithmetic_cp_fun_table_fp16[i].opt_func_; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| int ArithmeticCompareFP16CPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
// Recomputes element counts, picks the compare kernel, and — in the broadcast
// case — derives the stride bookkeeping used by BroadcastRun.
int ArithmeticCompareFP16CPUKernel::ReSize() {
  param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
  param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
  param_->out_elements_num_ = out_tensors_[0]->ElementsNum();
  // A scalar operand is handled by the optimized ("opt") kernels, so generic
  // broadcasting is disabled in that case.
  if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
    param_->broadcasting_ = false;
    arithmetic_opt_func_ = GetOptimizedArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_);
  } else {
    arithmetic_func_ = GetArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_);
  }
  if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) {
    MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
    return RET_ERROR;
  }
  // param_->broadcasting_ is presumably set during shape inference — TODO confirm.
  if (param_->broadcasting_) {
    // Walk dims from innermost outward; outside_ accumulates the product of
    // trailing output dims where the input shapes agree, and break_pos_
    // records the first (innermost) dim where they differ.
    outside_ = 1;
    for (int i = param_->ndim_ - 1; i >= 0; --i) {
      if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
        break_pos_ = i;
        break;
      }
      outside_ *= param_->out_shape_[i];
    }
    // Element strides for both inputs and the output, used by BroadcastRun.
    ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_);
    ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_);
    ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_);
  }
  return RET_OK;
}
// Recursively walks the output dims up to break_pos_, resolving per-dimension
// broadcast (a size-1 input dim always reuses index 0), then runs the
// element-wise compare on the contiguous tail.
// NOTE(review): cur_offset is applied identically to both inputs and the
// output at the leaf — this assumes the thread-partitioned region is
// contiguous and congruent in all three buffers; confirm against
// DoArithmetic's slicing.
int ArithmeticCompareFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim,
                                                 int out_count, int cur_offset) {
  if (dim > break_pos_) {
    // Innermost contiguous span: dispatch to the scalar compare kernel.
    return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count);
  }
  for (int i = 0; i < param_->out_shape_[dim]; ++i) {
    // Broadcast rule: a size-1 input dimension contributes offset 0.
    int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i;
    int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i;
    int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim],
                           output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset);
    if (ret != RET_OK) {
      return RET_ERROR;
    }
  }
  return RET_OK;
}
// Runs task_id's share of the comparison. In the broadcast case work is
// partitioned over `outside_` units; otherwise over flat output elements.
int ArithmeticCompareFP16CPUKernel::DoArithmetic(int task_id) {
  int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_);
  int cur_offset = stride_per_thread * task_id;
  int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
                                        : MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
  // NOTE(review): cur_count can be <= 0 for trailing threads when
  // thread_num_ does not divide the work; assumes the kernels tolerate a
  // non-positive count — TODO confirm.
  int ret = RET_OK;
  if (param_->broadcasting_) {
    ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset);
  } else if (param_->in_elements_num0_ == 1) {
    // input0 is a scalar: optimized kernel broadcasts it over input1's slice.
    ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_);
  } else if (param_->in_elements_num1_ == 1) {
    // input1 is the scalar operand.
    ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_);
  } else {
    ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count);
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret;
  }
  return ret;
}
| static int ArithmeticsRunFp16(void *cdata, int task_id) { | |||
| auto arithmetic_kernel = reinterpret_cast<ArithmeticCompareFP16CPUKernel *>(cdata); | |||
| auto ret = arithmetic_kernel->DoArithmetic(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]"; | |||
| } | |||
| return ret; | |||
| } | |||
// Prepares fp16 views of both inputs (converting from fp32 when needed),
// grabs the output buffer, and launches DoArithmetic across threads.
int ArithmeticCompareFP16CPUKernel::Run() {
  auto output_tensor = out_tensors_.at(0);
  // Record which inputs required conversion so FreeTmpBuffer() releases only
  // the staging buffers this kernel allocated.
  is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
  is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32;
  input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
  input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_);
  // The comparison output is a uint8 boolean tensor despite the member name.
  output_fp16_ = reinterpret_cast<uint8_t *>(output_tensor->MutableData());
  if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) {
    MS_LOG(ERROR) << "Memory allocation failed";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRunFp16, this, context_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]";
  }
  FreeTmpBuffer();
  return ret;
}
| void ArithmeticCompareFP16CPUKernel::FreeTmpBuffer() { | |||
| if (is_input0_fp32_) { | |||
| context_->allocator->Free(input0_fp16_); | |||
| input0_fp16_ = nullptr; | |||
| } | |||
| if (is_input1_fp32_) { | |||
| context_->allocator->Free(input1_fp16_); | |||
| input1_fp16_ = nullptr; | |||
| } | |||
| } | |||
| kernel::LiteKernel *CpuArithmeticCompareFp16KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *parameter, const lite::InnerContext *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "input parameter is null!"; | |||
| return nullptr; | |||
| } | |||
| auto kernel = new (std::nothrow) ArithmeticCompareFP16CPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
// Register all six fp16 comparison primitives with the shared creator.
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp16KernelCreator)
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp16/arithmetic_fp16.h" | |||
| #include "schema/model_generated.h" | |||
namespace mindspore::kernel {
// Element-wise fp16 comparison kernel signature: float16 inputs, uint8 (0/1)
// boolean output.
typedef int (*ArithmeticCompareFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
// Variant used when one operand is a scalar (element count == 1).
typedef int (*ArithmeticCompareOptFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                                            ArithmeticParameter *param);
// One row of the (primitive, activation) -> kernel dispatch table.
typedef struct {
  int primitive_type_;   // schema::PrimitiveType value
  int activation_type_;  // schema::ActivationType value
  ArithmeticCompareFuncFp16 func_;         // generic element-wise kernel
  ArithmeticCompareOptFuncFp16 opt_func_;  // scalar-operand kernel
} ARITHMETIC_COMPARE_FUNC_INFO_FP16;
// CPU kernel for fp16 comparison ops, with fp32->fp16 input conversion and
// shape broadcasting handled in the .cc.
class ArithmeticCompareFP16CPUKernel : public LiteKernel {
 public:
  ArithmeticCompareFP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                 const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
  }
  ~ArithmeticCompareFP16CPUKernel() = default;
  int Init() override;
  int ReSize() override;
  int Run() override;
  // Runs one thread's slice of the comparison.
  int DoArithmetic(int task_id);
  // Recursive broadcast walk over the output dims.
  // NOTE(review): the final parameter is named out_thread_stride here but
  // cur_offset in the implementation — consider unifying the names.
  int BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim, int out_count,
                   int out_thread_stride);
 private:
  void FreeTmpBuffer();
  int outside_;    // product of trailing dims matching between inputs (broadcast case)
  int break_pos_;  // innermost dim where the input shapes differ
  bool is_input0_fp32_ = false;  // input0 was converted from fp32 (buffer owned here)
  bool is_input1_fp32_ = false;  // input1 was converted from fp32 (buffer owned here)
  float16_t *input0_fp16_ = nullptr;
  float16_t *input1_fp16_ = nullptr;
  uint8_t *output_fp16_ = nullptr;  // NOTE(review): uint8 bool output despite the _fp16_ name
  ArithmeticParameter *param_ = nullptr;
  ArithmeticCompareFuncFp16 arithmetic_func_ = nullptr;
  ArithmeticCompareOptFuncFp16 arithmetic_opt_func_ = nullptr;
};
}  // namespace mindspore::kernel
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_ | |||
| @@ -68,15 +68,7 @@ ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = { | |||
| {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16, | |||
| ElementOptSquaredDifferenceFp16}, | |||
| {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16}, | |||
| {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}, | |||
| {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16}, | |||
| {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16}, | |||
| {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16}, | |||
| {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16}, | |||
| {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16}, | |||
| {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16, | |||
| ElementOptGreaterEqualFp16}, | |||
| }; | |||
| {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}}; | |||
| ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) { | |||
| for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16); i++) { | |||
| @@ -67,6 +67,12 @@ int CastFp16CPUKernel::DoCast(int thread_id) { | |||
| auto offset = thread_id * stride_; | |||
| auto output_data = out_tensors_.at(0)->MutableData(); | |||
| switch (input->data_type()) { | |||
| case kNumberTypeBool: | |||
| BoolToFloat16(reinterpret_cast<bool *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float16_t *>(output_data) + offset, data_num); | |||
| case kNumberTypeUInt8: | |||
| Uint8ToFloat16(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float16_t *>(output_data) + offset, data_num); | |||
| case kNumberTypeFloat32: | |||
| Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset, | |||
| reinterpret_cast<float16_t *>(output_data) + offset, data_num); | |||
| @@ -339,12 +339,6 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, CpuArithmeticFp32Ke | |||
| REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,127 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/arithmetic_compare.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "nnacl/fp32/arithmetic_compare.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Equal; | |||
| using mindspore::schema::PrimitiveType_Greater; | |||
| using mindspore::schema::PrimitiveType_GreaterEqual; | |||
| using mindspore::schema::PrimitiveType_Less; | |||
| using mindspore::schema::PrimitiveType_LessEqual; | |||
| using mindspore::schema::PrimitiveType_NotEqual; | |||
| namespace mindspore::kernel { | |||
namespace {
// Row of the primitive-type -> fp32 compare-function dispatch table
// (file-local; used only by GetArithmeticCompareFun below).
typedef struct {
  int primitive_type_;              // schema::PrimitiveType value
  ArithmeticCompareFp32Func func_;  // element-wise implementation
} TYPE_FUNC_INFO;
}  // namespace
| ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) { | |||
| TYPE_FUNC_INFO type_func_table[] = { | |||
| {PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32}, | |||
| {PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32}, | |||
| {PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}}; | |||
| for (size_t i = 0; i < sizeof(type_func_table); i++) { | |||
| if (type_func_table[i].primitive_type_ == primitive_type) { | |||
| return type_func_table[i].func_; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| int ArithmeticCompareCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; } | |||
// Runs task_id's contiguous slice of the fp32 element-wise comparison.
int ArithmeticCompareCPUKernel::DoExecute(int task_id) {
  int elements_num = in_tensors_.at(0)->ElementsNum();
  // Split the flat element range evenly across threads.
  int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
  int offset = task_id * stride;
  int count = MSMIN(stride, elements_num - offset);
  if (count <= 0) {
    // More threads than strides: this thread has no work.
    return RET_OK;
  }
  if (func_ == nullptr) {
    MS_LOG(ERROR) << "Run function is null! ";
    return RET_ERROR;
  }
  // two inputs have the same shape, support broadcast later
  auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
  auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
  // Output is a uint8 boolean (0/1) tensor.
  auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
  auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Run failed, illegal input! ";
  }
  return ret;
}
| int ArithmeticCompareRun(void *cdata, int task_id) { | |||
| auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata); | |||
| auto ret = kernel->DoExecute(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| } | |||
| return ret; | |||
| } | |||
| int ArithmeticCompareCPUKernel::Run() { | |||
| auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; | |||
| } | |||
| return ret; | |||
| } | |||
| kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *parameter, const lite::InnerContext *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *kernel = new (std::nothrow) ArithmeticCompareCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new ArithmeticSelfCPUKernel fail!"; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
// Register all six fp32 comparison primitives with the shared creator.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp32KernelCreator)
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/fp32/arithmetic.h" | |||
namespace mindspore::kernel {
// fp32 element-wise comparison: float inputs, uint8 (0/1) boolean output.
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
// CPU kernel for fp32 comparison primitives (Equal/NotEqual/Less/LessEqual/
// Greater/GreaterEqual). Inputs must currently have identical shapes — see
// the broadcast TODO in DoExecute.
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
 public:
  // NOTE(review): parameter->type_ is dereferenced here, so `parameter` must
  // be non-null — the creator should guarantee that.
  explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                      const mindspore::lite::PrimitiveC *primitive)
      : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
    func_ = GetArithmeticCompareFun(parameter->type_);
  }
  ~ArithmeticCompareCPUKernel() override = default;
  int Init() override;
  int ReSize() override;
  int Run() override;
  // Executes one thread's contiguous slice of the comparison.
  virtual int DoExecute(int task_id);
 private:
  // Resolves the compare function for a primitive type; nullptr if unsupported.
  ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type);
  ArithmeticCompareFp32Func func_;
};
// ParallelLaunch entry point defined in the .cc.
int ArithmeticCompareRun(void *cdata, int task_id);
}  // namespace mindspore::kernel
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_ | |||
| @@ -211,7 +211,6 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||
| filter_shape = {new_out_channel, kernel_h, kernel_w, new_in_channel}; | |||
| bias_shape = {new_out_channel}; | |||
| auto *origin_weight = reinterpret_cast<float *>(inputs.at(kWeightIndex)->data_c()); | |||
| auto *origin_bias = reinterpret_cast<float *>(inputs.at(kBiasIndex)->data_c()); | |||
| for (int i = 0; i < group; ++i) { | |||
| std::vector<lite::Tensor *> new_inputs; | |||
| @@ -234,6 +233,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||
| // if has bias, set new bias | |||
| if (has_bias) { | |||
| auto *origin_bias = reinterpret_cast<float *>(inputs.at(kBiasIndex)->data_c()); | |||
| auto bias_tensor = new (std::nothrow) | |||
| lite::Tensor(inputs.at(kBiasIndex)->data_type(), bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | |||
| bias_tensor->MallocData(); | |||
| @@ -104,7 +104,7 @@ int ArithmeticInt8CPUKernel::ReSize() { return RET_OK; } | |||
| int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) { | |||
| auto input0_data = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData()); | |||
| auto input1_data1 = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData()); | |||
| auto output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData()); | |||
| auto output_data = reinterpret_cast<uint8_t *>(out_tensors_[0]->MutableData()); | |||
| auto element_num = out_tensors_[0]->ElementsNum(); | |||
| auto param = reinterpret_cast<ArithmeticParameter *>(op_parameter_); | |||
| if (param->broadcasting_ && arithmetic_run_ != nullptr) { | |||
| @@ -24,7 +24,7 @@ | |||
| namespace mindspore::kernel { | |||
| class ArithmeticInt8CPUKernel : public LiteKernel { | |||
| typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| public: | |||