| @@ -231,6 +231,8 @@ endif() | |||
| if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64) | |||
| if("${X86_64_SIMD}" STREQUAL "sse") | |||
| add_compile_definitions(ENABLE_SSE) | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1") | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "avx") | |||
| add_compile_definitions(ENABLE_SSE) | |||
| @@ -20,10 +20,17 @@ | |||
| int Fp32Relu(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #ifdef ENABLE_ARM | |||
| float32x4_t zero_4 = vdupq_n_f32(0.0f); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f); | |||
| for (; i < length - 8; i += 8) { | |||
| MS_ST256_F32(dst + i, MS_MAX256_F32(MS_LD256_F32(src + i), zero_8)); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_SSE) || defined(ENABLE_ARM) | |||
| MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); | |||
| for (; i < length - 4; i += 4) { | |||
| vst1q_f32(dst + i, vmaxq_f32(vld1q_f32(src + i), zero_4)); | |||
| MS_STQ_F32(dst + i, MS_MAXQ_F32(MS_LDQ_F32(src + i), zero)); | |||
| } | |||
| #endif | |||
| for (; i < length; ++i) { | |||
| @@ -34,13 +41,24 @@ int Fp32Relu(const float *src, int length, float *dst) { | |||
| int Fp32Relu6(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #ifdef ENABLE_ARM | |||
| float32x4_t zero_4 = vdupq_n_f32(0.0f); | |||
| float32x4_t six_4 = vdupq_n_f32(6.0f); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f); | |||
| for (; i < length - 8; i += 8) { | |||
| MS_FLOAT32X8 dst_tmp = MS_MAX256_F32(MS_LD256_F32(src + i), zero_8); | |||
| dst_tmp = MS_MIN256_F32(dst_tmp, six_8); | |||
| MS_ST256_F32(dst + i, dst_tmp); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); | |||
| MS_FLOAT32X4 six = MS_MOVQ_F32(6.0f); | |||
| for (; i < length - 4; i += 4) { | |||
| float32x4_t dst_4 = vmaxq_f32(vld1q_f32(src + i), zero_4); | |||
| dst_4 = vminq_f32(dst_4, six_4); | |||
| vst1q_f32(dst + i, dst_4); | |||
| MS_FLOAT32X4 dst_tmp = MS_MAXQ_F32(MS_LDQ_F32(src + i), zero); | |||
| dst_tmp = MS_MINQ_F32(dst_tmp, six); | |||
| MS_STQ_F32(dst + i, dst_tmp); | |||
| } | |||
| #endif | |||
| for (; i < length; ++i) { | |||
| @@ -55,14 +73,21 @@ int Fp32Relu6(const float *src, int length, float *dst) { | |||
| int LRelu(const float *src, int length, float *dst, float alpha) { | |||
| int i = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| float32x4_t alpha_4 = vdupq_n_f32(alpha); | |||
| #if defined(ENABLE_AVX) | |||
| for (; i < length - 8; i += 8) { | |||
| MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i); | |||
| MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha); | |||
| MS_FLOAT32X8 mask = MS_CMP256_PS(src_tmp, MS_MOV256_F32(0.0f), 30); | |||
| MS_ST256_F32(dst + i, MS_BLEND256_PS(mul_tmp, src_tmp, mask)); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| for (; i < length - 4; i += 4) { | |||
| float32x4_t src_4 = vld1q_f32(src + i); | |||
| float32x4_t mul_4 = vmulq_f32(src_4, alpha_4); | |||
| uint32x4_t flag = vclezq_f32(src_4); | |||
| float32x4_t dst_4 = vbslq_f32(flag, mul_4, src_4); | |||
| vst1q_f32(dst + i, dst_4); | |||
| MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i); | |||
| MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha); | |||
| MS_FLOAT32X4 mask = MS_CMPGTQ_PS(src_tmp, MS_MOVQ_F32(0.0f)); | |||
| MS_STQ_F32(dst + i, MS_BLENDQ_PS(mul_tmp, src_tmp, mask)); | |||
| } | |||
| #endif | |||
| for (; i < length; ++i) { | |||
| @@ -73,11 +98,18 @@ int LRelu(const float *src, int length, float *dst, float alpha) { | |||
| int Sigmoid(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| int count = (length / C4NUM) * C4NUM; | |||
| for (; i < count; i += C4NUM) { | |||
| simd_exp(vnegq_f32(vld1q_f32(src + i)), dst + i); | |||
| vst1q_f32(dst + i, vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), vld1q_f32(dst + i)))); | |||
| #if defined(ENABLE_AVX) | |||
| for (; i < length - 8; i += 8) { | |||
| simd_exp_avx(-(MS_LD256_F32(src + i)), dst + i); | |||
| MS_ST256_F32(dst + i, | |||
| MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i)))); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM64) || defined(ENABLE_SSE) | |||
| for (; i < length - 4; i += 4) { | |||
| simd_exp(-(MS_LDQ_F32(src + i)), dst + i); | |||
| MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i)))); | |||
| } | |||
| #endif | |||
| for (; i < length; ++i) { | |||
| @@ -102,26 +134,40 @@ float TanhOpt(float src) { | |||
| int Tanh(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, | |||
| {17325.0f, 17325.0f, 17325.0f, 17325.0f}, | |||
| {135135.0f, 135135.0f, 135135.0f, 135135.0f}, | |||
| {28.0f, 28.0f, 28.0f, 28.0f}, | |||
| {3150.0f, 3150.0f, 3150.0f, 3150.0f}, | |||
| {62370.0f, 62370.0f, 62370.0f, 62370.0f}}; | |||
| float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; | |||
| float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; | |||
| int count = (length / C4NUM) * C4NUM; | |||
| for (; i < count; i += C4NUM) { | |||
| float32x4_t input = vld1q_f32(src + i); | |||
| float32x4_t square = vmulq_f32(input, input); | |||
| float32x4_t a = vmulq_f32( | |||
| vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]), | |||
| input); | |||
| float32x4_t b = vaddq_f32( | |||
| vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), | |||
| paramv[2]); | |||
| vst1q_f32(dst + i, vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one)); | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX) | |||
| const int cnt = 6; | |||
| float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f}; | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; | |||
| MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; | |||
| MS_FLOAT32X8 param256[6]; | |||
| for (int j = 0; j < cnt; ++j) { | |||
| param256[j] = MS_MOV256_F32(data[j]); | |||
| } | |||
| for (; i < length - 8; i += 8) { | |||
| MS_FLOAT32X8 input = MS_LD256_F32(src + i); | |||
| MS_FLOAT32X8 square = input * input; | |||
| MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input; | |||
| MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2]; | |||
| MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8)); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 param[6]; | |||
| MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; | |||
| MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; | |||
| for (int j = 0; j < cnt; ++j) { | |||
| param[j] = MS_MOVQ_F32(data[j]); | |||
| } | |||
| for (; i < length - 4; i += 4) { | |||
| MS_FLOAT32X4 input = MS_LDQ_F32(src + i); | |||
| MS_FLOAT32X4 square = input * input; | |||
| MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input; | |||
| MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2]; | |||
| MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one)); | |||
| } | |||
| #endif | |||
| for (; i < length; ++i) { | |||
| @@ -142,12 +188,21 @@ int Swish(const float *src, int length, float *dst) { | |||
| return NNACL_ERR; | |||
| } | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= length - C4NUM; index += C4NUM) { | |||
| float32x4_t src_value = vld1q_f32(src + index); | |||
| float32x4_t sigmoid_value = vld1q_f32(dst + index); | |||
| float32x4_t result = vmulq_f32(src_value, sigmoid_value); | |||
| vst1q_f32(dst + index, result); | |||
| #if defined(ENABLE_AVX) | |||
| for (; index <= length - 8; index += 8) { | |||
| MS_FLOAT32X8 src_value = MS_LD256_F32(src + index); | |||
| MS_FLOAT32X8 sigmoid_value = MS_LD256_F32(dst + index); | |||
| MS_FLOAT32X8 result = MS_MUL256_F32(src_value, sigmoid_value); | |||
| MS_ST256_F32(dst + index, result); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| for (; index <= length - 4; index += 4) { | |||
| MS_FLOAT32X4 src_value = MS_LDQ_F32(src + index); | |||
| MS_FLOAT32X4 sigmoid_value = MS_LDQ_F32(dst + index); | |||
| MS_FLOAT32X4 result = MS_MULQ_F32(src_value, sigmoid_value); | |||
| MS_STQ_F32(dst + index, result); | |||
| } | |||
| #endif | |||
| for (; index < length; index++) { | |||
| @@ -18,28 +18,46 @@ | |||
| #include "nnacl/fp32/arithmetic_fp32.h" | |||
| int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t vin0_opt = vdupq_n_f32(in0[0]); | |||
| float32x4_t vin1_opt = vdupq_n_f32(in1[0]); | |||
| #ifdef ENABLE_AVX | |||
| MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); | |||
| MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); | |||
| MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vaddq_f32(vin0_opt, vin1); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_ADD256_F32(vin0_opt_8, vin1); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0_opt, vin1); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = in0[0] + in1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vout = vaddq_f32(vin0, vin1_opt); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1_opt_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1_opt); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -50,28 +68,46 @@ int ElementOptAdd(const float *in0, const float *in1, float *out, int size, cons | |||
| } | |||
| int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| int32x4_t vin0_opt = vdupq_n_s32(in0[0]); | |||
| int32x4_t vin1_opt = vdupq_n_s32(in1[0]); | |||
| #ifdef ENABLE_AVX | |||
| MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); | |||
| MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); | |||
| MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vaddq_s32(vin0_opt, vin1); | |||
| vst1q_s32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_ADD256_EPI32(vin0_opt_8, vin1); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_ADDQ_EPI32(vin0_opt, vin1); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = in0[0] + in1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vout = vaddq_s32(vin0, vin1_opt); | |||
| vst1q_s32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vout = MS_ADD256_EPI32(vin0, vin1_opt_8); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vout = MS_ADDQ_EPI32(vin0, vin1_opt); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -82,29 +118,48 @@ int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const A | |||
| } | |||
| int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t vin0_opt = vdupq_n_f32(in0[0]); | |||
| float32x4_t vin1_opt = vdupq_n_f32(in1[0]); | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| #ifdef ENABLE_AVX | |||
| MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); | |||
| MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); | |||
| MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vmaxq_f32(vaddq_f32(vin0_opt, vin1), zeros); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MAX256_F32(MS_ADD256_F32(vin0_opt_8, vin1), zeros_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_ADDQ_F32(vin0_opt, vin1), zeros); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = MSMAX(in0[0] + in1[index], 0); | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1_opt), zeros); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vout = MS_MAX256_F32(MS_ADD256_F32(vin0, vin1_opt_8), zeros_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1_opt), zeros); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -115,30 +170,50 @@ int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, | |||
| } | |||
| int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t vin0_opt = vdupq_n_f32(in0[0]); | |||
| float32x4_t vin1_opt = vdupq_n_f32(in1[0]); | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| float32x4_t bounds = vdupq_n_f32(6.0f); | |||
| #ifdef ENABLE_AVX | |||
| MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); | |||
| MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); | |||
| MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0_opt, vin1), zeros), bounds); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0_opt_8, vin1), zeros_8), bounds_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0_opt, vin1), zeros), bounds); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = MSMIN(MSMAX(in0[0] + in1[index], 0), 6); | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1_opt), zeros), bounds); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0, vin1_opt_8), zeros_8), bounds_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1_opt), zeros), bounds); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -157,12 +232,20 @@ int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *til | |||
| int ElementAdd(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vaddq_f32(vin0, vin1); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -173,14 +256,24 @@ int ElementAdd(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vaddq_f32(vin0, vin1); | |||
| vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1); | |||
| vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30)); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1); | |||
| vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros)); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -192,14 +285,24 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| float32x4_t bounds = vdupq_n_f32(6.0f); | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds); | |||
| vst1q_f32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0, vin1), zeros_8), bounds_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1), zeros), bounds); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -210,12 +313,20 @@ int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementAddInt(const int *in0, const int *in1, int *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vaddq_s32(vin0, vin1); | |||
| vst1q_s32(out + index, vout); | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_ADD256_EPI32(vin0, vin1); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_ADDQ_EPI32(vin0, vin1); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -38,26 +38,49 @@ extern "C" { | |||
| int Exp(const float *input_data, float *output_data, const ExpParameter *parameter, int task_id); | |||
| void ExpFp32(const float *src, float *dst, int num); | |||
| #ifdef ENABLE_ARM64 | |||
| static inline void simd_exp(float32x4_t input4, float *dst) { | |||
| static float32x4_t maxv = {88.0f, 88.0f, 88.0f, 88.0f}; | |||
| static float32x4_t minv = {-88.0f, -88.0f, -88.0f, -88.0f}; | |||
| static float32x4_t paramv[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f}, | |||
| #if defined(ENABLE_ARM64) || defined(ENABLE_SSE) | |||
| static inline void simd_exp(MS_FLOAT32X4 input, float *dst) { | |||
| static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f}; | |||
| static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f}; | |||
| static MS_FLOAT32X4 param[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f}, | |||
| {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, | |||
| {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, | |||
| {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, | |||
| {0.5f, 0.5f, 0.5f, 0.5f}, | |||
| {1.0f, 1.0f, 1.0f, 1.0f}}; | |||
| input4 = vmaxq_f32(minv, vminq_f32(maxv, input4)); | |||
| int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, paramv[0])); | |||
| float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), paramv[0])); | |||
| int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23)); | |||
| vst1q_f32(dst, vld1q_f32((float32_t *)(&int_exp4))); | |||
| float32x4_t decimal_exp4 = vaddq_f32(paramv[2], vmulq_f32(decimal4, paramv[1])); | |||
| decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(paramv[3], vmulq_f32(decimal4, decimal_exp4))); | |||
| decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, vaddq_f32(paramv[4], decimal_exp4))); | |||
| decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, decimal_exp4)); | |||
| vst1q_f32(dst, vmulq_f32(vld1q_f32(dst), decimal_exp4)); | |||
| input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); | |||
| MS_INT32X4 integer = MS_CVTQPS_EPI32(input / param[0]); | |||
| MS_FLOAT32X4 decimal = input - MS_CVTQEPI32_PS(integer) * param[0]; | |||
| MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(127)), 23); | |||
| memcpy(dst, &int_exp, sizeof(int32_t) * 4); | |||
| MS_FLOAT32X4 decimal_exp = | |||
| param[5] + | |||
| decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); | |||
| MS_STQ_F32(dst, decimal_exp * MS_LDQ_F32(dst)); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) { | |||
| static MS_FLOAT32X8 maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f}; | |||
| static MS_FLOAT32X8 minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f}; | |||
| static MS_FLOAT32X8 param[] = { | |||
| {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, | |||
| {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, | |||
| {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, | |||
| {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, | |||
| {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, | |||
| {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}}; | |||
| input = MS_MAX256_F32(minv, MS_MIN256_F32(input, maxv)); | |||
| MS_INT32X8 integer = MS_CVT256PS_EPI32(input / param[0]); | |||
| MS_FLOAT32X8 decimal = input - MS_CVT256EPI32_PS(integer) * param[0]; | |||
| MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(127)), 23); | |||
| memcpy(dst, &int_exp, sizeof(int32_t) * 8); | |||
| MS_FLOAT32X8 decimal_exp = | |||
| param[5] + | |||
| decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); | |||
| MS_ST256_F32(dst, decimal_exp * MS_LD256_F32(dst)); | |||
| } | |||
| #endif | |||
| @@ -24,12 +24,20 @@ int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *til | |||
| int ElementMul(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vmulq_f32(vin0, vin1); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -40,14 +48,24 @@ int ElementMul(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vmulq_f32(vin0, vin1); | |||
| vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1); | |||
| vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30)); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1); | |||
| vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros)); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -59,14 +77,24 @@ int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| float32x4_t bounds = vdupq_n_f32(6.0f); | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0, vin1), zeros_8), bounds_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1), zeros), bounds); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -77,12 +105,20 @@ int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementMulInt(const int *in0, const int *in1, int *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vmulq_s32(vin0, vin1); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -93,18 +129,28 @@ int ElementMulInt(const int *in0, const int *in1, int *out, int size) { | |||
| int ElementMulReluInt(const int *in0, const int *in1, int *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| int32x4_t zeros = vdupq_n_s32(0); | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vmulq_s32(vin0, vin1); | |||
| vout = vbslq_s32(vcgtq_s32(vout, zeros), vout, zeros); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1); | |||
| vout = MS_BLEND256_EPI32(zeros_8, vout, MS_CMPGT256_EPI32(vout, zeros_8)); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 zeros = MS_MOVQ_EPI32(0); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1); | |||
| vout = MS_BLENDQ_EPI32(zeros, vout, MS_CMPGTQ_EPI32(vout, zeros)); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| float res = in0[index] * in1[index]; | |||
| int res = in0[index] * in1[index]; | |||
| out[index] = res > 0 ? res : 0; | |||
| } | |||
| return NNACL_OK; | |||
| @@ -112,14 +158,24 @@ int ElementMulReluInt(const int *in0, const int *in1, int *out, int size) { | |||
| int ElementMulRelu6Int(const int *in0, const int *in1, int *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| int32x4_t zeros = vdupq_n_s32(0); | |||
| int32x4_t bounds = vdupq_n_s32(6); | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); | |||
| MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1), zeros_8), bounds_8); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 zeros = MS_MOVQ_EPI32(0); | |||
| MS_INT32X4 bounds = MS_MOVQ_EPI32(6); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1), zeros), bounds); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -129,28 +185,42 @@ int ElementMulRelu6Int(const int *in0, const int *in1, int *out, int size) { | |||
| } | |||
| int ElementOptMul(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t vin0_opt = vdupq_n_f32(in0[0]); | |||
| float32x4_t vin1_opt = vdupq_n_f32(in1[0]); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vmulq_f32(vin0_opt, vin1); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MUL256_F32(vin0_opt_8, vin1); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MULQ_F32(vin0_opt, vin1); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = in0[0] * in1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vout = vmulq_f32(vin0, vin1_opt); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1_opt_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1_opt); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -161,29 +231,46 @@ int ElementOptMul(const float *in0, const float *in1, float *out, int size, cons | |||
| } | |||
| int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t vin0_opt = vdupq_n_f32(in0[0]); | |||
| float32x4_t vin1_opt = vdupq_n_f32(in1[0]); | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vmaxq_f32(vmulq_f32(vin0_opt, vin1), zeros); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MAX256_F32(MS_MUL256_F32(vin0_opt_8, vin1), zeros_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_MULQ_F32(vin0_opt, vin1), zeros); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = MSMAX(in0[0] * in1[index], 0); | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1_opt), zeros); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vout = MS_MAX256_F32(MS_MUL256_F32(vin0, vin1_opt_8), zeros_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1_opt), zeros); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -194,30 +281,50 @@ int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, | |||
| } | |||
| int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| float32x4_t vin0_opt = vdupq_n_f32(in0[0]); | |||
| float32x4_t vin1_opt = vdupq_n_f32(in1[0]); | |||
| float32x4_t zeros = vdupq_n_f32(0.0f); | |||
| float32x4_t bounds = vdupq_n_f32(6.0f); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin1 = vld1q_f32(in1 + index); | |||
| float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0_opt, vin1), zeros), bounds); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); | |||
| MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0_opt_8, vin1), zeros_8), bounds_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); | |||
| MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0_opt, vin1), zeros), bounds); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| float32x4_t vin0 = vld1q_f32(in0 + index); | |||
| float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1_opt), zeros), bounds); | |||
| vst1q_f32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0, vin1_opt_8), zeros_8), bounds_8); | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); | |||
| MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); | |||
| MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1_opt), zeros), bounds); | |||
| MS_STQ_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -228,28 +335,42 @@ int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, | |||
| } | |||
| int ElementOptMulInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| int32x4_t vin0_opt = vdupq_n_s32(in0[0]); | |||
| int32x4_t vin1_opt = vdupq_n_s32(in1[0]); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vmulq_s32(vin0_opt, vin1); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_MUL256_EPI32(vin0_opt_8, vin1); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_MULQ_EPI32(vin0_opt, vin1); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = in0[0] * in1[index]; | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vout = vmulq_s32(vin0, vin1_opt); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1_opt_8); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1_opt); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -260,29 +381,46 @@ int ElementOptMulInt(const int *in0, const int *in1, int *out, int size, const A | |||
| } | |||
| int ElementOptMulReluInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| int32x4_t vin0_opt = vdupq_n_s32(in0[0]); | |||
| int32x4_t vin1_opt = vdupq_n_s32(in1[0]); | |||
| int32x4_t zeros = vdupq_n_s32(0); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vmaxq_s32(vmulq_s32(vin0_opt, vin1), zeros); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); | |||
| MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_MAX256_EPI32(MS_MUL256_EPI32(vin0_opt_8, vin1), zeros_8); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); | |||
| MS_INT32X4 zeros = MS_MOVQ_EPI32(0); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0_opt, vin1), zeros); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = MSMAX(in0[0] * in1[index], 0); | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vout = vmaxq_s32(vmulq_s32(vin0, vin1_opt), zeros); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); | |||
| MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vout = MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1_opt_8), zeros_8); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); | |||
| MS_INT32X4 zeros = MS_MOVQ_EPI32(0); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vout = MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1_opt), zeros); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -293,30 +431,50 @@ int ElementOptMulReluInt(const int *in0, const int *in1, int *out, int size, con | |||
| } | |||
| int ElementOptMulRelu6Int(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { | |||
| #ifdef ENABLE_NEON | |||
| int32x4_t vin0_opt = vdupq_n_s32(in0[0]); | |||
| int32x4_t vin1_opt = vdupq_n_s32(in1[0]); | |||
| int32x4_t zeros = vdupq_n_s32(0); | |||
| int32x4_t bounds = vdupq_n_s32(6); | |||
| #endif | |||
| int index = 0; | |||
| if (param->in_elements_num0_ == 1) { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin1 = vld1q_s32(in1 + index); | |||
| int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0_opt, vin1), zeros), bounds); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); | |||
| MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); | |||
| MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); | |||
| MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0_opt_8, vin1), zeros_8), bounds_8); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); | |||
| MS_INT32X4 zeros = MS_MOVQ_EPI32(0); | |||
| MS_INT32X4 bounds = MS_MOVQ_EPI32(6); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); | |||
| MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0_opt, vin1), zeros), bounds); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= size - 4; index += C4NUM) { | |||
| int32x4_t vin0 = vld1q_s32(in0 + index); | |||
| int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1_opt), zeros), bounds); | |||
| vst1q_s32(out + index, vout); | |||
| #if defined(ENABLE_AVX) | |||
| MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); | |||
| MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); | |||
| MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); | |||
| MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1_opt_8), zeros_8), bounds_8); | |||
| MS_ST256_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); | |||
| MS_INT32X4 zeros = MS_MOVQ_EPI32(0); | |||
| MS_INT32X4 bounds = MS_MOVQ_EPI32(6); | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); | |||
| MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1_opt), zeros), bounds); | |||
| MS_STQ_EPI32(out + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| @@ -107,28 +107,105 @@ typedef enum CalFixedMultiplierMode { | |||
| #ifdef ENABLE_ARM | |||
| #define MS_FLOAT32X4 float32x4_t | |||
| #define MS_INT32X4 int32x4_t | |||
| #define MS_LDQ_F32 vld1q_f32 | |||
| #define MS_LDQ_EPI32 vld1q_s32 | |||
| #define MS_ADDQ_F32 vaddq_f32 | |||
| #define MS_ADDQ_EPI32 vaddq_s32 | |||
| #define MS_MOVQ_F32 vmovq_n_f32 | |||
| #define MS_MOVQ_EPI32 vmovq_n_s32 | |||
| #define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32. | |||
| #define MS_SUBQ_F32 vsubq_f32 | |||
| #define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3) | |||
| #define MS_STQ_F32 vst1q_f32 | |||
| #define MS_STQ_EPI32 vst1q_s32 | |||
| #define MS_MAXQ_F32 vmaxq_f32 | |||
| #define MS_MAXQ_EPI32 vmaxq_s32 | |||
| #define MS_MINQ_F32 vminq_f32 | |||
| #define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2) | |||
| #elif defined(ENABLE_SSE) | |||
| #define MS_MINQ_EPI32 vminq_s32 | |||
| #define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2) | |||
| #define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2) | |||
| #ifdef ENABLE_ARM64 | |||
| #define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2) | |||
| #else | |||
| #define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecpeq_f32(src2)) | |||
| #endif | |||
| #define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2) | |||
| #define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2) | |||
| #define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2) | |||
| #define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2)) | |||
| #define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src) | |||
| #define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src) | |||
| #define MS_CMPGTQ_PS(src1, src2) vcgtq_f32(src1, src2) | |||
| #define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2) | |||
| // Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_ps | |||
| #define MS_BLENDQ_PS(src1, src2, src3) vbslq_f32(src3, src2, src1) | |||
| #define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1) | |||
| #elif defined(ENABLE_AVX) | |||
| #define MS_FLOAT32X8 __m256 | |||
| #define MS_INT32X8 __m256i | |||
| #define MS_LD256_F32 _mm256_loadu_ps | |||
| #define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) | |||
| #define MS_ADD256_F32 _mm256_add_ps | |||
| #define MS_ADD256_EPI32 _mm256_add_epi32 | |||
| #define MS_MOV256_F32 _mm256_set1_ps | |||
| #define MS_MOV256_EPI32 _mm256_set1_epi32 | |||
| #define MS_DUP256_F32 _mm256_load_ps1 // It is recommended to replace with MS_MOV256_F32. | |||
| #define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3)) | |||
| #define MS_ST256_F32 _mm256_storeu_ps | |||
| #define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) | |||
| #define MS_SUB256_F32 _mm256_sub_ps | |||
| #define MS_MAX256_F32 _mm256_max_ps | |||
| #define MS_MAX256_EPI32 _mm256_max_epi32 | |||
| #define MS_MIN256_F32 _mm256_min_ps | |||
| #define MS_MIN256_EPI32 _mm256_min_epi32 | |||
| #define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2) | |||
| #define MS_MUL256_EPI32(src1, src2) _mm256_mul_epi32(src1, src2) | |||
| #define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2) | |||
| #define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2)) | |||
| #define MS_MUL256_N_EPI32(src1, src2) _mm256_mul_epi32(src1, _mm256_set1_epi32(src2)) | |||
| #define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2)) | |||
| #define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2) | |||
| #define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src) | |||
| #define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src) // truncate float to int | |||
| #define MS_CMP256_PS(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3) | |||
| #define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2) | |||
| #define MS_BLEND256_PS(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3) | |||
| #define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3) | |||
| #endif | |||
| #if defined(ENABLE_SSE) | |||
| #define MS_FLOAT32X4 __m128 | |||
| #define MS_INT32X4 __m128i | |||
| #define MS_LDQ_F32 _mm_loadu_ps | |||
| #define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src)) | |||
| #define MS_ADDQ_F32 _mm_add_ps | |||
| #define MS_MOVQ_F32 _mm_set_ps1 | |||
| #define MS_ADDQ_EPI32 _mm_add_epi32 | |||
| #define MS_MOVQ_F32 _mm_set1_ps | |||
| #define MS_MOVQ_EPI32 _mm_set1_epi32 | |||
| #define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32. | |||
| #define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3)) | |||
| #define MS_STQ_F32 _mm_storeu_ps | |||
| #define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2) | |||
| #define MS_SUBQ_F32 _mm_sub_ps | |||
| #define MS_MAXQ_F32 _mm_max_ps | |||
| #define MS_MAXQ_EPI32 _mm_max_epi32 | |||
| #define MS_MINQ_F32 _mm_min_ps | |||
| #define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, _mm_set_ps1(src2)) | |||
| #define MS_MINQ_EPI32 _mm_min_epi32 | |||
| #define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2) | |||
| #define MS_MULQ_EPI32(src1, src2) _mm_mul_epi32(src1, src2) | |||
| #define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2) | |||
| #define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2)) | |||
| #define MS_MULQ_N_EPI32(src1, src2) _mm_mul_epi32(src1, _mm_set1_epi32(src2)) | |||
| #define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2)) | |||
| #define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2) | |||
| #define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int | |||
| #define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src) | |||
| #define MS_CMPGTQ_PS(src1, src2) _mm_cmpgt_ps(src1, src2) | |||
| #define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2) | |||
| #define MS_BLENDQ_PS(src1, src2, src3) _mm_blendv_ps(src1, src2, src3) | |||
| #define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3) | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_OP_BASE_H_ | |||
| @@ -179,7 +179,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| MS_LOG(ERROR) << "element_shape_ is not fullyDefined!"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data_type(input0->data_type()); | |||
| output->set_data_type(input0->tensors_data_type()); | |||
| output->set_shape(element_shape_); | |||
| } | |||
| output->set_format(input0->GetTensor(index_)->format()); | |||