Browse Source

fp32 optimize

tags/v1.2.0-rc1
lzk 5 years ago
parent
commit
37a30fb0fb
8 changed files with 1034 additions and 565 deletions
  1. +2
    -0
      mindspore/lite/CMakeLists.txt
  2. +102
    -47
      mindspore/lite/nnacl/fp32/activation_fp32.c
  3. +194
    -83
      mindspore/lite/nnacl/fp32/add_fp32.c
  4. +38
    -15
      mindspore/lite/nnacl/fp32/exp_fp32.h
  5. +293
    -135
      mindspore/lite/nnacl/fp32/mul_fp32.c
  6. +323
    -280
      mindspore/lite/nnacl/fp32/winograd_utils.c
  7. +81
    -4
      mindspore/lite/nnacl/op_base.h
  8. +1
    -1
      mindspore/lite/src/ops/tensorlist_getitem.cc

+ 2
- 0
mindspore/lite/CMakeLists.txt View File

@@ -231,6 +231,8 @@ endif()
if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
if("${X86_64_SIMD}" STREQUAL "sse")
add_compile_definitions(ENABLE_SSE)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1")
endif()
if("${X86_64_SIMD}" STREQUAL "avx")
add_compile_definitions(ENABLE_SSE)


+ 102
- 47
mindspore/lite/nnacl/fp32/activation_fp32.c View File

@@ -20,10 +20,17 @@

int Fp32Relu(const float *src, int length, float *dst) {
int i = 0;
#ifdef ENABLE_ARM
float32x4_t zero_4 = vdupq_n_f32(0.0f);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
for (; i < length - 8; i += 8) {
MS_ST256_F32(dst + i, MS_MAX256_F32(MS_LD256_F32(src + i), zero_8));
}
#endif

#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
for (; i < length - 4; i += 4) {
vst1q_f32(dst + i, vmaxq_f32(vld1q_f32(src + i), zero_4));
MS_STQ_F32(dst + i, MS_MAXQ_F32(MS_LDQ_F32(src + i), zero));
}
#endif
for (; i < length; ++i) {
@@ -34,13 +41,24 @@ int Fp32Relu(const float *src, int length, float *dst) {

int Fp32Relu6(const float *src, int length, float *dst) {
int i = 0;
#ifdef ENABLE_ARM
float32x4_t zero_4 = vdupq_n_f32(0.0f);
float32x4_t six_4 = vdupq_n_f32(6.0f);

#if defined(ENABLE_AVX)
MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f);
for (; i < length - 8; i += 8) {
MS_FLOAT32X8 dst_tmp = MS_MAX256_F32(MS_LD256_F32(src + i), zero_8);
dst_tmp = MS_MIN256_F32(dst_tmp, six_8);
MS_ST256_F32(dst + i, dst_tmp);
}
#endif

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 six = MS_MOVQ_F32(6.0f);
for (; i < length - 4; i += 4) {
float32x4_t dst_4 = vmaxq_f32(vld1q_f32(src + i), zero_4);
dst_4 = vminq_f32(dst_4, six_4);
vst1q_f32(dst + i, dst_4);
MS_FLOAT32X4 dst_tmp = MS_MAXQ_F32(MS_LDQ_F32(src + i), zero);
dst_tmp = MS_MINQ_F32(dst_tmp, six);
MS_STQ_F32(dst + i, dst_tmp);
}
#endif
for (; i < length; ++i) {
@@ -55,14 +73,21 @@ int Fp32Relu6(const float *src, int length, float *dst) {

int LRelu(const float *src, int length, float *dst, float alpha) {
int i = 0;
#ifdef ENABLE_ARM64
float32x4_t alpha_4 = vdupq_n_f32(alpha);
#if defined(ENABLE_AVX)
for (; i < length - 8; i += 8) {
MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i);
MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha);
MS_FLOAT32X8 mask = MS_CMP256_PS(src_tmp, MS_MOV256_F32(0.0f), 30);
MS_ST256_F32(dst + i, MS_BLEND256_PS(mul_tmp, src_tmp, mask));
}
#endif

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
for (; i < length - 4; i += 4) {
float32x4_t src_4 = vld1q_f32(src + i);
float32x4_t mul_4 = vmulq_f32(src_4, alpha_4);
uint32x4_t flag = vclezq_f32(src_4);
float32x4_t dst_4 = vbslq_f32(flag, mul_4, src_4);
vst1q_f32(dst + i, dst_4);
MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha);
MS_FLOAT32X4 mask = MS_CMPGTQ_PS(src_tmp, MS_MOVQ_F32(0.0f));
MS_STQ_F32(dst + i, MS_BLENDQ_PS(mul_tmp, src_tmp, mask));
}
#endif
for (; i < length; ++i) {
@@ -73,11 +98,18 @@ int LRelu(const float *src, int length, float *dst, float alpha) {

int Sigmoid(const float *src, int length, float *dst) {
int i = 0;
#ifdef ENABLE_ARM64
int count = (length / C4NUM) * C4NUM;
for (; i < count; i += C4NUM) {
simd_exp(vnegq_f32(vld1q_f32(src + i)), dst + i);
vst1q_f32(dst + i, vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), vld1q_f32(dst + i))));
#if defined(ENABLE_AVX)
for (; i < length - 8; i += 8) {
simd_exp_avx(-(MS_LD256_F32(src + i)), dst + i);
MS_ST256_F32(dst + i,
MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
}
#endif

#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
for (; i < length - 4; i += 4) {
simd_exp(-(MS_LDQ_F32(src + i)), dst + i);
MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i))));
}
#endif
for (; i < length; ++i) {
@@ -102,26 +134,40 @@ float TanhOpt(float src) {

int Tanh(const float *src, int length, float *dst) {
int i = 0;
#ifdef ENABLE_ARM64
static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f},
{17325.0f, 17325.0f, 17325.0f, 17325.0f},
{135135.0f, 135135.0f, 135135.0f, 135135.0f},
{28.0f, 28.0f, 28.0f, 28.0f},
{3150.0f, 3150.0f, 3150.0f, 3150.0f},
{62370.0f, 62370.0f, 62370.0f, 62370.0f}};
float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
int count = (length / C4NUM) * C4NUM;
for (; i < count; i += C4NUM) {
float32x4_t input = vld1q_f32(src + i);
float32x4_t square = vmulq_f32(input, input);
float32x4_t a = vmulq_f32(
vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]),
input);
float32x4_t b = vaddq_f32(
vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
paramv[2]);
vst1q_f32(dst + i, vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one));
#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX)
const int cnt = 6;
float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
#endif

#if defined(ENABLE_AVX)
MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X8 param256[6];
for (int j = 0; j < cnt; ++j) {
param256[j] = MS_MOV256_F32(data[j]);
}
for (; i < length - 8; i += 8) {
MS_FLOAT32X8 input = MS_LD256_F32(src + i);
MS_FLOAT32X8 square = input * input;
MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input;
MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2];
MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8));
}
#endif

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_FLOAT32X4 param[6];
MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
for (int j = 0; j < cnt; ++j) {
param[j] = MS_MOVQ_F32(data[j]);
}
for (; i < length - 4; i += 4) {
MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
MS_FLOAT32X4 square = input * input;
MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input;
MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2];
MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one));
}
#endif
for (; i < length; ++i) {
@@ -142,12 +188,21 @@ int Swish(const float *src, int length, float *dst) {
return NNACL_ERR;
}
int index = 0;
#ifdef ENABLE_NEON
for (; index <= length - C4NUM; index += C4NUM) {
float32x4_t src_value = vld1q_f32(src + index);
float32x4_t sigmoid_value = vld1q_f32(dst + index);
float32x4_t result = vmulq_f32(src_value, sigmoid_value);
vst1q_f32(dst + index, result);
#if defined(ENABLE_AVX)
for (; index <= length - 8; index += 8) {
MS_FLOAT32X8 src_value = MS_LD256_F32(src + index);
MS_FLOAT32X8 sigmoid_value = MS_LD256_F32(dst + index);
MS_FLOAT32X8 result = MS_MUL256_F32(src_value, sigmoid_value);
MS_ST256_F32(dst + index, result);
}
#endif

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
for (; index <= length - 4; index += 4) {
MS_FLOAT32X4 src_value = MS_LDQ_F32(src + index);
MS_FLOAT32X4 sigmoid_value = MS_LDQ_F32(dst + index);
MS_FLOAT32X4 result = MS_MULQ_F32(src_value, sigmoid_value);
MS_STQ_F32(dst + index, result);
}
#endif
for (; index < length; index++) {


+ 194
- 83
mindspore/lite/nnacl/fp32/add_fp32.c View File

@@ -18,28 +18,46 @@
#include "nnacl/fp32/arithmetic_fp32.h"

int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float32x4_t vin0_opt = vdupq_n_f32(in0[0]);
float32x4_t vin1_opt = vdupq_n_f32(in1[0]);
#ifdef ENABLE_AVX
MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]);
MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]);
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]);
MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vaddq_f32(vin0_opt, vin1);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_ADD256_F32(vin0_opt_8, vin1);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0_opt, vin1);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = in0[0] + in1[index];
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vout = vaddq_f32(vin0, vin1_opt);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1_opt_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1_opt);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -50,28 +68,46 @@ int ElementOptAdd(const float *in0, const float *in1, float *out, int size, cons
}

int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
int32x4_t vin0_opt = vdupq_n_s32(in0[0]);
int32x4_t vin1_opt = vdupq_n_s32(in1[0]);
#ifdef ENABLE_AVX
MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]);
MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]);
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]);
MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vaddq_s32(vin0_opt, vin1);
vst1q_s32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_ADD256_EPI32(vin0_opt_8, vin1);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_ADDQ_EPI32(vin0_opt, vin1);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = in0[0] + in1[index];
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vout = vaddq_s32(vin0, vin1_opt);
vst1q_s32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vout = MS_ADD256_EPI32(vin0, vin1_opt_8);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vout = MS_ADDQ_EPI32(vin0, vin1_opt);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -82,29 +118,48 @@ int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const A
}

int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float32x4_t vin0_opt = vdupq_n_f32(in0[0]);
float32x4_t vin1_opt = vdupq_n_f32(in1[0]);
float32x4_t zeros = vdupq_n_f32(0.0f);
#ifdef ENABLE_AVX
MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]);
MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]);
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]);
MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]);
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vmaxq_f32(vaddq_f32(vin0_opt, vin1), zeros);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MAX256_F32(MS_ADD256_F32(vin0_opt_8, vin1), zeros_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_ADDQ_F32(vin0_opt, vin1), zeros);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = MSMAX(in0[0] + in1[index], 0);
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1_opt), zeros);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vout = MS_MAX256_F32(MS_ADD256_F32(vin0, vin1_opt_8), zeros_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1_opt), zeros);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -115,30 +170,50 @@ int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size,
}

int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float32x4_t vin0_opt = vdupq_n_f32(in0[0]);
float32x4_t vin1_opt = vdupq_n_f32(in1[0]);
float32x4_t zeros = vdupq_n_f32(0.0f);
float32x4_t bounds = vdupq_n_f32(6.0f);
#ifdef ENABLE_AVX
MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]);
MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]);
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f);
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]);
MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]);
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0_opt, vin1), zeros), bounds);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0_opt_8, vin1), zeros_8), bounds_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0_opt, vin1), zeros), bounds);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = MSMIN(MSMAX(in0[0] + in1[index], 0), 6);
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1_opt), zeros), bounds);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0, vin1_opt_8), zeros_8), bounds_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1_opt), zeros), bounds);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -157,12 +232,20 @@ int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *til

int ElementAdd(const float *in0, const float *in1, float *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vaddq_f32(vin0, vin1);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -173,14 +256,24 @@ int ElementAdd(const float *in0, const float *in1, float *out, int size) {

int ElementAddRelu(const float *in0, const float *in1, float *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0.0f);
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vaddq_f32(vin0, vin1);
vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1);
vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30));
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1);
vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros));
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -192,14 +285,24 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) {

int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0.0f);
float32x4_t bounds = vdupq_n_f32(6.0f);
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(out + index, vout);
#ifdef ENABLE_AVX
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0, vin1), zeros_8), bounds_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1), zeros), bounds);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -210,12 +313,20 @@ int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) {

int ElementAddInt(const int *in0, const int *in1, int *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vaddq_s32(vin0, vin1);
vst1q_s32(out + index, vout);
#ifdef ENABLE_AVX
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_ADD256_EPI32(vin0, vin1);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_ADDQ_EPI32(vin0, vin1);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {


+ 38
- 15
mindspore/lite/nnacl/fp32/exp_fp32.h View File

@@ -38,26 +38,49 @@ extern "C" {
int Exp(const float *input_data, float *output_data, const ExpParameter *parameter, int task_id);
void ExpFp32(const float *src, float *dst, int num);

#ifdef ENABLE_ARM64
static inline void simd_exp(float32x4_t input4, float *dst) {
static float32x4_t maxv = {88.0f, 88.0f, 88.0f, 88.0f};
static float32x4_t minv = {-88.0f, -88.0f, -88.0f, -88.0f};
static float32x4_t paramv[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f},
#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
static inline void simd_exp(MS_FLOAT32X4 input, float *dst) {
static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f};
static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f};
static MS_FLOAT32X4 param[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f},
{1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120},
{1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24},
{1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6},
{0.5f, 0.5f, 0.5f, 0.5f},
{1.0f, 1.0f, 1.0f, 1.0f}};
input4 = vmaxq_f32(minv, vminq_f32(maxv, input4));
int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, paramv[0]));
float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), paramv[0]));
int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23));
vst1q_f32(dst, vld1q_f32((float32_t *)(&int_exp4)));
float32x4_t decimal_exp4 = vaddq_f32(paramv[2], vmulq_f32(decimal4, paramv[1]));
decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(paramv[3], vmulq_f32(decimal4, decimal_exp4)));
decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, vaddq_f32(paramv[4], decimal_exp4)));
decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, decimal_exp4));
vst1q_f32(dst, vmulq_f32(vld1q_f32(dst), decimal_exp4));

input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv));
MS_INT32X4 integer = MS_CVTQPS_EPI32(input / param[0]);
MS_FLOAT32X4 decimal = input - MS_CVTQEPI32_PS(integer) * param[0];
MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(127)), 23);
memcpy(dst, &int_exp, sizeof(int32_t) * 4);
MS_FLOAT32X4 decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
MS_STQ_F32(dst, decimal_exp * MS_LDQ_F32(dst));
}
#endif

#if defined(ENABLE_AVX)
static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) {
static MS_FLOAT32X8 maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f};
static MS_FLOAT32X8 minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f};
static MS_FLOAT32X8 param[] = {
{0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f},
{1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120},
{1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24},
{1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6},
{0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}};
input = MS_MAX256_F32(minv, MS_MIN256_F32(input, maxv));
MS_INT32X8 integer = MS_CVT256PS_EPI32(input / param[0]);
MS_FLOAT32X8 decimal = input - MS_CVT256EPI32_PS(integer) * param[0];
MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(127)), 23);
memcpy(dst, &int_exp, sizeof(int32_t) * 8);
MS_FLOAT32X8 decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
MS_ST256_F32(dst, decimal_exp * MS_LD256_F32(dst));
}
#endif



+ 293
- 135
mindspore/lite/nnacl/fp32/mul_fp32.c View File

@@ -24,12 +24,20 @@ int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *til

int ElementMul(const float *in0, const float *in1, float *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vmulq_f32(vin0, vin1);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -40,14 +48,24 @@ int ElementMul(const float *in0, const float *in1, float *out, int size) {

int ElementMulRelu(const float *in0, const float *in1, float *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0.0f);
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vmulq_f32(vin0, vin1);
vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1);
vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30));
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1);
vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros));
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -59,14 +77,24 @@ int ElementMulRelu(const float *in0, const float *in1, float *out, int size) {

int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0.0f);
float32x4_t bounds = vdupq_n_f32(6.0f);
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0, vin1), zeros_8), bounds_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1), zeros), bounds);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -77,12 +105,20 @@ int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) {

int ElementMulInt(const int *in0, const int *in1, int *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vmulq_s32(vin0, vin1);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -93,18 +129,28 @@ int ElementMulInt(const int *in0, const int *in1, int *out, int size) {

int ElementMulReluInt(const int *in0, const int *in1, int *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
int32x4_t zeros = vdupq_n_s32(0);
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vmulq_s32(vin0, vin1);
vout = vbslq_s32(vcgtq_s32(vout, zeros), vout, zeros);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1);
vout = MS_BLEND256_EPI32(zeros_8, vout, MS_CMPGT256_EPI32(vout, zeros_8));
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 zeros = MS_MOVQ_EPI32(0);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1);
vout = MS_BLENDQ_EPI32(zeros, vout, MS_CMPGTQ_EPI32(vout, zeros));
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
float res = in0[index] * in1[index];
int res = in0[index] * in1[index];
out[index] = res > 0 ? res : 0;
}
return NNACL_OK;
@@ -112,14 +158,24 @@ int ElementMulReluInt(const int *in0, const int *in1, int *out, int size) {

int ElementMulRelu6Int(const int *in0, const int *in1, int *out, int size) {
int index = 0;
#ifdef ENABLE_NEON
int32x4_t zeros = vdupq_n_s32(0);
int32x4_t bounds = vdupq_n_s32(6);
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0);
MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1), zeros_8), bounds_8);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 zeros = MS_MOVQ_EPI32(0);
MS_INT32X4 bounds = MS_MOVQ_EPI32(6);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1), zeros), bounds);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -129,28 +185,42 @@ int ElementMulRelu6Int(const int *in0, const int *in1, int *out, int size) {
}

int ElementOptMul(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float32x4_t vin0_opt = vdupq_n_f32(in0[0]);
float32x4_t vin1_opt = vdupq_n_f32(in1[0]);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vmulq_f32(vin0_opt, vin1);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MUL256_F32(vin0_opt_8, vin1);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MULQ_F32(vin0_opt, vin1);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = in0[0] * in1[index];
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vout = vmulq_f32(vin0, vin1_opt);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1_opt_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1_opt);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -161,29 +231,46 @@ int ElementOptMul(const float *in0, const float *in1, float *out, int size, cons
}

int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float32x4_t vin0_opt = vdupq_n_f32(in0[0]);
float32x4_t vin1_opt = vdupq_n_f32(in1[0]);
float32x4_t zeros = vdupq_n_f32(0.0f);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vmaxq_f32(vmulq_f32(vin0_opt, vin1), zeros);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]);
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MAX256_F32(MS_MUL256_F32(vin0_opt_8, vin1), zeros_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]);
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_MULQ_F32(vin0_opt, vin1), zeros);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = MSMAX(in0[0] * in1[index], 0);
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1_opt), zeros);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]);
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vout = MS_MAX256_F32(MS_MUL256_F32(vin0, vin1_opt_8), zeros_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]);
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1_opt), zeros);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -194,30 +281,50 @@ int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size,
}

int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float32x4_t vin0_opt = vdupq_n_f32(in0[0]);
float32x4_t vin1_opt = vdupq_n_f32(in1[0]);
float32x4_t zeros = vdupq_n_f32(0.0f);
float32x4_t bounds = vdupq_n_f32(6.0f);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin1 = vld1q_f32(in1 + index);
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0_opt, vin1), zeros), bounds);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]);
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0_opt_8, vin1), zeros_8), bounds_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]);
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0_opt, vin1), zeros), bounds);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6);
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(in0 + index);
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1_opt), zeros), bounds);
vst1q_f32(out + index, vout);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]);
MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f);
for (; index <= size - C8NUM; index += C8NUM) {
MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0, vin1_opt_8), zeros_8), bounds_8);
MS_ST256_F32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]);
MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f);
for (; index <= size - C4NUM; index += C4NUM) {
MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1_opt), zeros), bounds);
MS_STQ_F32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -228,28 +335,42 @@ int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size,
}

int ElementOptMulInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
int32x4_t vin0_opt = vdupq_n_s32(in0[0]);
int32x4_t vin1_opt = vdupq_n_s32(in1[0]);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vmulq_s32(vin0_opt, vin1);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_MUL256_EPI32(vin0_opt_8, vin1);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_MULQ_EPI32(vin0_opt, vin1);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = in0[0] * in1[index];
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vout = vmulq_s32(vin0, vin1_opt);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1_opt_8);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1_opt);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -260,29 +381,46 @@ int ElementOptMulInt(const int *in0, const int *in1, int *out, int size, const A
}

int ElementOptMulReluInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
int32x4_t vin0_opt = vdupq_n_s32(in0[0]);
int32x4_t vin1_opt = vdupq_n_s32(in1[0]);
int32x4_t zeros = vdupq_n_s32(0);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vmaxq_s32(vmulq_s32(vin0_opt, vin1), zeros);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]);
MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_MAX256_EPI32(MS_MUL256_EPI32(vin0_opt_8, vin1), zeros_8);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]);
MS_INT32X4 zeros = MS_MOVQ_EPI32(0);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0_opt, vin1), zeros);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = MSMAX(in0[0] * in1[index], 0);
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vout = vmaxq_s32(vmulq_s32(vin0, vin1_opt), zeros);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]);
MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vout = MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1_opt_8), zeros_8);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]);
MS_INT32X4 zeros = MS_MOVQ_EPI32(0);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vout = MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1_opt), zeros);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
@@ -293,30 +431,50 @@ int ElementOptMulReluInt(const int *in0, const int *in1, int *out, int size, con
}

int ElementOptMulRelu6Int(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
int32x4_t vin0_opt = vdupq_n_s32(in0[0]);
int32x4_t vin1_opt = vdupq_n_s32(in1[0]);
int32x4_t zeros = vdupq_n_s32(0);
int32x4_t bounds = vdupq_n_s32(6);
#endif
int index = 0;
if (param->in_elements_num0_ == 1) {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin1 = vld1q_s32(in1 + index);
int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0_opt, vin1), zeros), bounds);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]);
MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0);
MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index);
MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0_opt_8, vin1), zeros_8), bounds_8);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]);
MS_INT32X4 zeros = MS_MOVQ_EPI32(0);
MS_INT32X4 bounds = MS_MOVQ_EPI32(6);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index);
MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0_opt, vin1), zeros), bounds);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {
out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6);
}
} else {
#ifdef ENABLE_NEON
for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(in0 + index);
int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1_opt), zeros), bounds);
vst1q_s32(out + index, vout);
#if defined(ENABLE_AVX)
MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]);
MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0);
MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6);
for (; index <= size - C8NUM; index += C8NUM) {
MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index);
MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1_opt_8), zeros_8), bounds_8);
MS_ST256_EPI32(out + index, vout);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]);
MS_INT32X4 zeros = MS_MOVQ_EPI32(0);
MS_INT32X4 bounds = MS_MOVQ_EPI32(6);
for (; index <= size - C4NUM; index += C4NUM) {
MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index);
MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1_opt), zeros), bounds);
MS_STQ_EPI32(out + index, vout);
}
#endif
for (; index < size; index++) {


+ 323
- 280
mindspore/lite/nnacl/fp32/winograd_utils.c
File diff suppressed because it is too large
View File


+ 81
- 4
mindspore/lite/nnacl/op_base.h View File

@@ -107,28 +107,105 @@ typedef enum CalFixedMultiplierMode {

#ifdef ENABLE_ARM
#define MS_FLOAT32X4 float32x4_t
#define MS_INT32X4 int32x4_t
#define MS_LDQ_F32 vld1q_f32
#define MS_LDQ_EPI32 vld1q_s32
#define MS_ADDQ_F32 vaddq_f32
#define MS_ADDQ_EPI32 vaddq_s32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_MOVQ_EPI32 vmovq_n_s32
#define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32.
#define MS_SUBQ_F32 vsubq_f32
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_STQ_F32 vst1q_f32
#define MS_STQ_EPI32 vst1q_s32
#define MS_MAXQ_F32 vmaxq_f32
#define MS_MAXQ_EPI32 vmaxq_s32
#define MS_MINQ_F32 vminq_f32
#define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2)
#elif defined(ENABLE_SSE)
#define MS_MINQ_EPI32 vminq_s32
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
#ifdef ENABLE_ARM64
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#else
#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecpeq_f32(src2))
#endif
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
#define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2)
#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
#define MS_CMPGTQ_PS(src1, src2) vcgtq_f32(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_ps
#define MS_BLENDQ_PS(src1, src2, src3) vbslq_f32(src3, src2, src1)
#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)

#elif defined(ENABLE_AVX)
#define MS_FLOAT32X8 __m256
#define MS_INT32X8 __m256i
#define MS_LD256_F32 _mm256_loadu_ps
#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src))
#define MS_ADD256_F32 _mm256_add_ps
#define MS_ADD256_EPI32 _mm256_add_epi32
#define MS_MOV256_F32 _mm256_set1_ps
#define MS_MOV256_EPI32 _mm256_set1_epi32
#define MS_DUP256_F32 _mm256_load_ps1 // It is recommended to replace with MS_MOV256_F32.
#define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3))
#define MS_ST256_F32 _mm256_storeu_ps
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
#define MS_SUB256_F32 _mm256_sub_ps
#define MS_MAX256_F32 _mm256_max_ps
#define MS_MAX256_EPI32 _mm256_max_epi32
#define MS_MIN256_F32 _mm256_min_ps
#define MS_MIN256_EPI32 _mm256_min_epi32
#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2)
#define MS_MUL256_EPI32(src1, src2) _mm256_mul_epi32(src1, src2)
#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2)
#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2))
#define MS_MUL256_N_EPI32(src1, src2) _mm256_mul_epi32(src1, _mm256_set1_epi32(src2))
#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2))
#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2)
#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src)
#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src) // truncate float to int
#define MS_CMP256_PS(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
#define MS_BLEND256_PS(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
#endif

#if defined(ENABLE_SSE)
#define MS_FLOAT32X4 __m128
#define MS_INT32X4 __m128i
#define MS_LDQ_F32 _mm_loadu_ps
#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
#define MS_ADDQ_F32 _mm_add_ps
#define MS_MOVQ_F32 _mm_set_ps1
#define MS_ADDQ_EPI32 _mm_add_epi32
#define MS_MOVQ_F32 _mm_set1_ps
#define MS_MOVQ_EPI32 _mm_set1_epi32
#define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32.
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
#define MS_STQ_F32 _mm_storeu_ps
#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
#define MS_SUBQ_F32 _mm_sub_ps
#define MS_MAXQ_F32 _mm_max_ps
#define MS_MAXQ_EPI32 _mm_max_epi32
#define MS_MINQ_F32 _mm_min_ps
#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, _mm_set_ps1(src2))
#define MS_MINQ_EPI32 _mm_min_epi32
#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2)
#define MS_MULQ_EPI32(src1, src2) _mm_mul_epi32(src1, src2)
#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2)
#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2))
#define MS_MULQ_N_EPI32(src1, src2) _mm_mul_epi32(src1, _mm_set1_epi32(src2))
#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2))
#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int
#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)
#define MS_CMPGTQ_PS(src1, src2) _mm_cmpgt_ps(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
#define MS_BLENDQ_PS(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
#endif

#endif // MINDSPORE_LITE_NNACL_OP_BASE_H_

+ 1
- 1
mindspore/lite/src/ops/tensorlist_getitem.cc View File

@@ -179,7 +179,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_LOG(ERROR) << "element_shape_ is not fullyDefined!";
return RET_ERROR;
}
output->set_data_type(input0->data_type());
output->set_data_type(input0->tensors_data_type());
output->set_shape(element_shape_);
}
output->set_format(input0->GetTensor(index_)->format());


Loading…
Cancel
Save