Browse Source

!4912 modify arm cpu fp16 & fp32 op: arithmetic

Merge pull request !4912 from 陶云浩/tmp10
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
11e3b1dfff
7 changed files with 842 additions and 304 deletions
  1. +236
    -142
      mindspore/lite/nnacl/fp16/arithmetic_fp16.c
  2. +421
    -21
      mindspore/lite/nnacl/fp32/arithmetic.c
  3. +6
    -0
      mindspore/lite/nnacl/fp32/arithmetic.h
  4. +75
    -80
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
  5. +3
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
  6. +95
    -40
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
  7. +6
    -20
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h

+ 236
- 142
mindspore/lite/nnacl/fp16/arithmetic_fp16.c View File

@@ -74,33 +74,48 @@ int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif #endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = in0 * in1;
for (int i = 0; i < C8NUM; ++i) {
output[i] = in0_opt * input1[i];
}
#endif
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt * input1[index];
} }
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmulq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = input0[i] * in1_opt;
}
#endif #endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = in0 * in1;
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] * in1_opt;
}
} }


return NNACL_OK; return NNACL_OK;
@@ -113,7 +128,6 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif

for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin0 = vld1q_f16(input0);
@@ -143,39 +157,58 @@ int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif


for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else #else
float16_t res;
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
res = in0 * in1;
output[i] = res > 0 ? res : 0;
float16_t res;
for (int i = 0; i < C8NUM; ++i) {
res = in0_opt * input1[i];
output[i] = res > 0 ? res : 0;
}
#endif
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = in0_opt * input1[index];
output[index] = res > 0 ? res : 0;
} }
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else
float16_t res;
for (int i = 0; i < C8NUM; ++i) {
res = input0[i] * in1_opt;
output[i] = res > 0 ? res : 0;
}
#endif #endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
float16_t res = in0 * in1;
output[index] = res > 0 ? res : 0;
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] * in1_opt;
output[index] = res > 0 ? res : 0;
}
} }


return NNACL_OK; return NNACL_OK;
@@ -216,37 +249,52 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif #endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMIN(MSMAX(in0 * in1, 0), 6);
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
}
#endif
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
} }
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
}
#endif #endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = MSMIN(MSMAX(in0 * in1, 0), 6);
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
}
} }


return NNACL_OK; return NNACL_OK;
@@ -255,7 +303,6 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;

for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin0 = vld1q_f16(input0);
@@ -280,34 +327,50 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif #endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = in0 + in1;
for (int i = 0; i < C8NUM; ++i) {
output[i] = in0_opt + input1[i];
}
#endif
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt + input1[index];
} }
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vaddq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = input0[i] + in1_opt;
}
#endif #endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = in0 + in1;
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] + in1_opt;
}
} }

return NNACL_OK; return NNACL_OK;
} }


@@ -345,37 +408,54 @@ int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif


for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMAX(in0 + in1, 0);
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(in0_opt + input1[i], 0);
}
#endif
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = in0_opt + input1[index];
output[index] = res > 0 ? res : 0;
} }
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(input0[i] + in1_opt, 0);
}
#endif #endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
float16_t res = in0 + in1;
output[index] = res > 0 ? res : 0;
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] + in1_opt;
output[index] = res > 0 ? res : 0;
}
} }
return NNACL_OK; return NNACL_OK;
} }
@@ -415,39 +495,54 @@ int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif #endif


for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMIN(MSMAX(in0 + in1, 0), 6);
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt + input1[i], 0), 6);
}
#endif
input1 += C8NUM;
output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt + input1[index], 0), 6);
}
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] + in1_opt, 0), 6);
}
#endif #endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = MSMIN(MSMAX(in0 + in1, 0), 6);
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] + in1_opt, 0), 6);
}
} }

return NNACL_OK; return NNACL_OK;
} }


@@ -479,11 +574,11 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif #endif
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
@@ -542,11 +637,11 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
@@ -609,11 +704,11 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif #endif
@@ -680,11 +775,11 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif #endif
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num1_ == 1) { if (param->in_elements_num1_ == 1) {
@@ -765,12 +860,11 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;

float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
@@ -855,11 +949,11 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;


float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif #endif


+ 421
- 21
mindspore/lite/nnacl/fp32/arithmetic.c View File

@@ -20,55 +20,455 @@
#define ACCURACY_DATA 0.00000001 #define ACCURACY_DATA 0.00000001


int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
#endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmulq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = in0_opt * input1[i];
}
#endif
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt * input1[index];
}
} else {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vmulq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = input0[i] * in1_opt;
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] * in1_opt;
}
}

return NNACL_OK;
}
/*
 * Broadcast multiply followed by ReLU: output[i] = max(a * b, 0).
 *
 * One operand is treated as a single value that is broadcast against every
 * element of the other operand:
 *   - param->in_elements_num0_ == 1 : input0[0] is multiplied with
 *     input1[0 .. element_size).
 *   - otherwise                     : input0[0 .. element_size) is multiplied
 *     with input1[0].
 *
 * input0[0] and input1[0] are both read before the branch is chosen, so each
 * pointer must reference at least one readable element.
 *
 * Elements are handled in groups of C4NUM (vectorised when ENABLE_NEON is
 * defined), then the remaining element_size % C4NUM elements one by one.
 *
 * Always returns NNACL_OK.
 */
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  const int tail_count = element_size % C4NUM;
  const int body_count = element_size - tail_count;
  const float scalar0 = input0[0];
  const float scalar1 = input1[0];
#ifdef ENABLE_NEON
  const float32x4_t scalar0_x4 = vdupq_n_f32(scalar0);
  const float32x4_t scalar1_x4 = vdupq_n_f32(scalar1);
  const float32x4_t zero_x4 = vdupq_n_f32(0);
#endif
  float *dst = output;
  int done = 0;

  if (param->in_elements_num0_ == 1) {
    /* input0 is the broadcast value; walk through input1. */
    const float *src = input1;
    while (done < body_count) {
#ifdef ENABLE_NEON
      const float32x4_t product = vmulq_f32(scalar0_x4, vld1q_f32(src));
      vst1q_f32(dst, vmaxq_f32(product, zero_x4));
#else
      for (int lane = 0; lane < C4NUM; ++lane) {
        dst[lane] = MSMAX(scalar0 * src[lane], 0);
      }
#endif
      src += C4NUM;
      dst += C4NUM;
      done += C4NUM;
    }
    /* Leftover elements that do not fill a whole group. */
    for (int k = 0; k < tail_count; ++k) {
      dst[k] = MSMAX(scalar0 * src[k], 0);
    }
  } else {
    /* input1 is the broadcast value; walk through input0. */
    const float *src = input0;
    while (done < body_count) {
#ifdef ENABLE_NEON
      const float32x4_t product = vmulq_f32(vld1q_f32(src), scalar1_x4);
      vst1q_f32(dst, vmaxq_f32(product, zero_x4));
#else
      for (int lane = 0; lane < C4NUM; ++lane) {
        dst[lane] = MSMAX(src[lane] * scalar1, 0);
      }
#endif
      src += C4NUM;
      dst += C4NUM;
      done += C4NUM;
    }
    /* Leftover elements that do not fill a whole group. */
    for (int k = 0; k < tail_count; ++k) {
      dst[k] = MSMAX(src[k] * scalar1, 0);
    }
  }

  return NNACL_OK;
}
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
#endif
if (param->in_elements_num0_ == 1) { if (param->in_elements_num0_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[0] * input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
}
#endif
input1 += C4NUM;
output += C4NUM;
} }
} else if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] * input1[0];
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
} }
} else { } else {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] * input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
} }
} }

return NNACL_OK; return NNACL_OK;
} }


int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
#endif
if (param->in_elements_num0_ == 1) { if (param->in_elements_num0_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[0] - input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vsubq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = in0_opt - input1[i];
}
#endif
input1 += C4NUM;
output += C4NUM;
} }
} else if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] - input1[0];
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt - input1[index];
} }
} else { } else {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] - input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vsubq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = input0[i] - in1_opt;
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] - in1_opt;
} }
} }
return NNACL_OK; return NNACL_OK;
} }
// Fused subtraction + ReLU for the case where one operand is a single value applied to every
// element of the other: output[k] = max(lhs - rhs, 0).
//   input0, input1 : operands; when param->in_elements_num0_ == 1 only input0[0] is used as the
//                    left-hand side, otherwise only input1[0] is used as the right-hand side.
//   output         : receives element_size results.
//   element_size   : number of output elements to produce.
//   param          : only in_elements_num0_ is consulted, to pick which operand is the scalar.
// Always returns NNACL_OK.
int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  // Split the range into whole groups of C4NUM elements plus a short remainder.
  const int tail = element_size % C4NUM;
  const int vec_end = element_size - tail;
  const float scalar0 = input0[0];
  const float scalar1 = input1[0];
#ifdef ENABLE_NEON
  const float32x4_t vscalar0 = vdupq_n_f32(scalar0);
  const float32x4_t vscalar1 = vdupq_n_f32(scalar1);
  const float32x4_t vzero = vdupq_n_f32(0.0f);
#endif

  if (param->in_elements_num0_ == 1) {
    // input0 is the scalar: scalar0 - input1[k].
    int done = 0;
    while (done < vec_end) {
#ifdef ENABLE_NEON
      vst1q_f32(output, vmaxq_f32(vsubq_f32(vscalar0, vld1q_f32(input1)), vzero));
#else
      for (int lane = 0; lane < C4NUM; lane++) {
        output[lane] = MSMAX(scalar0 - input1[lane], 0);
      }
#endif
      input1 += C4NUM;
      output += C4NUM;
      done += C4NUM;
    }
    // Pointers now sit at the first unprocessed element, so the remainder is indexed from 0.
    for (int k = 0; k < tail; k++) {
      output[k] = MSMAX(scalar0 - input1[k], 0);
    }
    return NNACL_OK;
  }

  // Otherwise input1 is treated as the scalar: input0[k] - scalar1.
  int done = 0;
  while (done < vec_end) {
#ifdef ENABLE_NEON
    vst1q_f32(output, vmaxq_f32(vsubq_f32(vld1q_f32(input0), vscalar1), vzero));
#else
    for (int lane = 0; lane < C4NUM; lane++) {
      output[lane] = MSMAX(input0[lane] - scalar1, 0);
    }
#endif
    input0 += C4NUM;
    output += C4NUM;
    done += C4NUM;
  }
  for (int k = 0; k < tail; k++) {
    output[k] = MSMAX(input0[k] - scalar1, 0);
  }
  return NNACL_OK;
}
// Fused subtraction + ReLU6 for the case where one operand is a single value applied to every
// element of the other: output[k] = min(max(lhs - rhs, 0), 6).
//   input0, input1 : operands; when param->in_elements_num0_ == 1 only input0[0] is used as the
//                    left-hand side, otherwise only input1[0] is used as the right-hand side.
//   output         : receives element_size results.
//   element_size   : number of output elements to produce.
//   param          : only in_elements_num0_ is consulted, to pick which operand is the scalar.
// Always returns NNACL_OK.
int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  // Split the range into whole groups of C4NUM elements plus a short remainder.
  const int tail = element_size % C4NUM;
  const int vec_end = element_size - tail;
  const float scalar0 = input0[0];
  const float scalar1 = input1[0];
#ifdef ENABLE_NEON
  const float32x4_t vscalar0 = vdupq_n_f32(scalar0);
  const float32x4_t vscalar1 = vdupq_n_f32(scalar1);
  const float32x4_t vzero = vdupq_n_f32(0.0f);
  const float32x4_t vsix = vdupq_n_f32(6.0f);
#endif

  if (param->in_elements_num0_ == 1) {
    // input0 is the scalar: scalar0 - input1[k].
    int done = 0;
    while (done < vec_end) {
#ifdef ENABLE_NEON
      const float32x4_t diff = vsubq_f32(vscalar0, vld1q_f32(input1));
      vst1q_f32(output, vminq_f32(vmaxq_f32(diff, vzero), vsix));
#else
      for (int lane = 0; lane < C4NUM; lane++) {
        output[lane] = MSMIN(MSMAX(scalar0 - input1[lane], 0), 6);
      }
#endif
      input1 += C4NUM;
      output += C4NUM;
      done += C4NUM;
    }
    // Pointers now sit at the first unprocessed element, so the remainder is indexed from 0.
    for (int k = 0; k < tail; k++) {
      output[k] = MSMIN(MSMAX(scalar0 - input1[k], 0), 6);
    }
    return NNACL_OK;
  }

  // Otherwise input1 is treated as the scalar: input0[k] - scalar1.
  int done = 0;
  while (done < vec_end) {
#ifdef ENABLE_NEON
    const float32x4_t diff = vsubq_f32(vld1q_f32(input0), vscalar1);
    vst1q_f32(output, vminq_f32(vmaxq_f32(diff, vzero), vsix));
#else
    for (int lane = 0; lane < C4NUM; lane++) {
      output[lane] = MSMIN(MSMAX(input0[lane] - scalar1, 0), 6);
    }
#endif
    input0 += C4NUM;
    output += C4NUM;
    done += C4NUM;
  }
  for (int k = 0; k < tail; k++) {
    output[k] = MSMIN(MSMAX(input0[k] - scalar1, 0), 6);
  }
  return NNACL_OK;
}


int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
#endif
if (param->in_elements_num0_ == 1) { if (param->in_elements_num0_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[0] + input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vaddq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = in0_opt + input1[i];
}
#endif
input1 += C4NUM;
output += C4NUM;
} }
} else if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] + input1[0];
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt + input1[index];
} }
} else { } else {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] + input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vaddq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = input0[i] + in1_opt;
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] + in1_opt;
} }
} }
return NNACL_OK; return NNACL_OK;
} }
// Fused addition + ReLU for the case where one operand is a single value applied to every
// element of the other: output[k] = max(lhs + rhs, 0).
//   input0, input1 : operands; when param->in_elements_num0_ == 1 only input0[0] is used,
//                    otherwise only input1[0] is used.
//   output         : receives element_size results.
//   element_size   : number of output elements to produce.
//   param          : only in_elements_num0_ is consulted, to pick which operand is the scalar.
// Always returns NNACL_OK.
int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  // Split the range into whole groups of C4NUM elements plus a short remainder.
  const int tail = element_size % C4NUM;
  const int vec_end = element_size - tail;
  const float scalar0 = input0[0];
  const float scalar1 = input1[0];
#ifdef ENABLE_NEON
  const float32x4_t vscalar0 = vdupq_n_f32(scalar0);
  const float32x4_t vscalar1 = vdupq_n_f32(scalar1);
  const float32x4_t vzero = vdupq_n_f32(0.0f);
#endif

  if (param->in_elements_num0_ == 1) {
    // input0 is the scalar: scalar0 + input1[k].
    int done = 0;
    while (done < vec_end) {
#ifdef ENABLE_NEON
      vst1q_f32(output, vmaxq_f32(vaddq_f32(vscalar0, vld1q_f32(input1)), vzero));
#else
      for (int lane = 0; lane < C4NUM; lane++) {
        output[lane] = MSMAX(scalar0 + input1[lane], 0);
      }
#endif
      input1 += C4NUM;
      output += C4NUM;
      done += C4NUM;
    }
    // Pointers now sit at the first unprocessed element, so the remainder is indexed from 0.
    for (int k = 0; k < tail; k++) {
      output[k] = MSMAX(scalar0 + input1[k], 0);
    }
    return NNACL_OK;
  }

  // Otherwise input1 is treated as the scalar: input0[k] + scalar1.
  int done = 0;
  while (done < vec_end) {
#ifdef ENABLE_NEON
    vst1q_f32(output, vmaxq_f32(vaddq_f32(vld1q_f32(input0), vscalar1), vzero));
#else
    for (int lane = 0; lane < C4NUM; lane++) {
      output[lane] = MSMAX(input0[lane] + scalar1, 0);
    }
#endif
    input0 += C4NUM;
    output += C4NUM;
    done += C4NUM;
  }
  for (int k = 0; k < tail; k++) {
    output[k] = MSMAX(input0[k] + scalar1, 0);
  }
  return NNACL_OK;
}
// Fused addition + ReLU6 for the case where one operand is a single value applied to every
// element of the other: output[k] = min(max(lhs + rhs, 0), 6).
//   input0, input1 : operands; when param->in_elements_num0_ == 1 only input0[0] is used,
//                    otherwise only input1[0] is used.
//   output         : receives element_size results.
//   element_size   : number of output elements to produce.
//   param          : only in_elements_num0_ is consulted, to pick which operand is the scalar.
// Always returns NNACL_OK.
int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  // Split the range into whole groups of C4NUM elements plus a short remainder.
  const int tail = element_size % C4NUM;
  const int vec_end = element_size - tail;
  const float scalar0 = input0[0];
  const float scalar1 = input1[0];
#ifdef ENABLE_NEON
  const float32x4_t vscalar0 = vdupq_n_f32(scalar0);
  const float32x4_t vscalar1 = vdupq_n_f32(scalar1);
  const float32x4_t vzero = vdupq_n_f32(0.0f);
  const float32x4_t vsix = vdupq_n_f32(6.0f);
#endif

  if (param->in_elements_num0_ == 1) {
    // input0 is the scalar: scalar0 + input1[k].
    int done = 0;
    while (done < vec_end) {
#ifdef ENABLE_NEON
      const float32x4_t sum = vaddq_f32(vscalar0, vld1q_f32(input1));
      vst1q_f32(output, vminq_f32(vmaxq_f32(sum, vzero), vsix));
#else
      for (int lane = 0; lane < C4NUM; lane++) {
        output[lane] = MSMIN(MSMAX(scalar0 + input1[lane], 0), 6);
      }
#endif
      input1 += C4NUM;
      output += C4NUM;
      done += C4NUM;
    }
    // Pointers now sit at the first unprocessed element, so the remainder is indexed from 0.
    for (int k = 0; k < tail; k++) {
      output[k] = MSMIN(MSMAX(scalar0 + input1[k], 0), 6);
    }
    return NNACL_OK;
  }

  // Otherwise input1 is treated as the scalar: input0[k] + scalar1.
  int done = 0;
  while (done < vec_end) {
#ifdef ENABLE_NEON
    const float32x4_t sum = vaddq_f32(vld1q_f32(input0), vscalar1);
    vst1q_f32(output, vminq_f32(vmaxq_f32(sum, vzero), vsix));
#else
    for (int lane = 0; lane < C4NUM; lane++) {
      output[lane] = MSMIN(MSMAX(input0[lane] + scalar1, 0), 6);
    }
#endif
    input0 += C4NUM;
    output += C4NUM;
    done += C4NUM;
  }
  for (int k = 0; k < tail; k++) {
    output[k] = MSMIN(MSMAX(input0[k] + scalar1, 0), 6);
  }
  return NNACL_OK;
}


int ElementMul(float *input0, float *input1, float *output, int element_size) { int ElementMul(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM; int block_mod = element_size % C4NUM;


+ 6
- 0
mindspore/lite/nnacl/fp32/arithmetic.h View File

@@ -27,8 +27,14 @@
extern "C" { extern "C" {
#endif #endif
int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementMul(float *input0, float *input1, float *output, int element_size); int ElementMul(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu(float *input0, float *input1, float *output, int element_size); int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size); int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);


+ 75
- 80
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc View File

@@ -162,6 +162,7 @@ int ArithmeticFP16CPUKernel::Init() {
} }


int ArithmeticFP16CPUKernel::ReSize() { int ArithmeticFP16CPUKernel::ReSize() {
FreeTmpBuffer();
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
@@ -175,10 +176,10 @@ int ArithmeticFP16CPUKernel::ReSize() {
arithmetic_opt_run_ = ElementOptMulReluFp16; arithmetic_opt_run_ = ElementOptMulReluFp16;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptDivRelu6Fp16;
arithmetic_opt_run_ = ElementOptMulRelu6Fp16;
break; break;
default: default:
arithmetic_opt_run_ = ElementOptDivFp16;
arithmetic_opt_run_ = ElementOptMulFp16;
break; break;
} }
break; break;
@@ -267,20 +268,46 @@ int ArithmeticFP16CPUKernel::ReSize() {
break; break;
} }
} }

if (arithmeticParameter_->broadcasting_) {
outside_ = 1;
for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
break_pos_ = i;
break;
}
outside_ *= arithmeticParameter_->out_shape_[i];
}
ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
}
return RET_OK; return RET_OK;
} }


int ArithmeticFP16CPUKernel::broadcast_run_(float16_t *input0, float16_t *input1, float16_t *output, int dim) {
int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim,
int out_count, int out_thread_stride) {
if (dim > break_pos_) { if (dim > break_pos_) {
return arithmetic_run_(input0 + out_thread_stride_, input1 + out_thread_stride_, output + out_thread_stride_,
out_count_);
int error_code =
arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count);
if (output_fp16_ != nullptr) {
auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
int bias = output - output_fp16_;
output_fp32 += bias;
Float16ToFloat32(output + out_thread_stride, output_fp32 + out_thread_stride, out_count);
}
return error_code;
} }
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) { for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmeticParameter_->in_shape0_[0] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[0] == 1 ? 0 : i;
return broadcast_run_(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
output + i * arithmeticParameter_->out_strides_[dim], dim + 1);
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
int error_code =
BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
if (error_code != RET_OK) {
return RET_ERROR;
}
} }
return RET_OK; return RET_OK;
} }
@@ -300,13 +327,16 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {


if (arithmetic_run_ == nullptr) { if (arithmetic_run_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
FreeTmpBuffer();
return RET_ERROR; return RET_ERROR;
} }


int error_code = RET_OK; int error_code = RET_OK;
if (arithmeticParameter_->broadcasting_) { if (arithmeticParameter_->broadcasting_) {
error_code =
arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, output_data + thread_stride, count);
stride = UP_DIV(outside_, context_->thread_num_);
out_count_ = MSMIN(stride, outside_ - stride * task_id);
out_thread_stride_ = stride * task_id;
error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_);
} else if (arithmetic_opt_run_ != nullptr) { } else if (arithmetic_opt_run_ != nullptr) {
if (arithmeticParameter_->in_elements_num0_ == 1) { if (arithmeticParameter_->in_elements_num0_ == 1) {
error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count, error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
@@ -323,17 +353,16 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count); arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count);
} }
if (error_code != RET_OK) { if (error_code != RET_OK) {
FreeTmpBuffer();
return RET_ERROR; return RET_ERROR;
} }
if (output_fp16_ != nullptr) {
if (output_fp16_ != nullptr && !arithmeticParameter_->broadcasting_) {
auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data()); auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count); Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count);
} }
return RET_OK; return RET_OK;
} }


static int ArithmeticsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
static int ArithmeticsRun_Fp16(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto arithmetic_kernel = reinterpret_cast<ArithmeticFP16CPUKernel *>(cdata); auto arithmetic_kernel = reinterpret_cast<ArithmeticFP16CPUKernel *>(cdata);
auto error_code = arithmetic_kernel->DoArithmetic(task_id); auto error_code = arithmetic_kernel->DoArithmetic(task_id);
if (error_code != RET_OK) { if (error_code != RET_OK) {
@@ -353,24 +382,6 @@ int ArithmeticFP16CPUKernel::Run() {
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
input0_fp16_ = reinterpret_cast<float16_t *>(
context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
if (input0_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
FreeTmpBuffer();
return RET_ERROR;
}
}
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
input1_fp16_ = reinterpret_cast<float16_t *>(
context_->allocator->Malloc(arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
if (input0_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
FreeTmpBuffer();
return RET_ERROR;
}
}
if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) { if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
output_fp16_ = reinterpret_cast<float16_t *>( output_fp16_ = reinterpret_cast<float16_t *>(
context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t))); context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
@@ -380,46 +391,30 @@ int ArithmeticFP16CPUKernel::Run() {
return RET_ERROR; return RET_ERROR;
} }
} }

if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) { if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
arithmeticParameter_->in_elements_num0_);
}
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
arithmeticParameter_->in_elements_num1_);
}

if (arithmeticParameter_->broadcasting_) {
auto tile_size = arithmeticParameter_->out_elements_num_ * sizeof(float16_t);
tile_data0_ = reinterpret_cast<float16_t *>(malloc(tile_size));
if (tile_data0_ == nullptr) {
input0_fp16_ = reinterpret_cast<float16_t *>(
context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
if (input0_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!"; MS_LOG(ERROR) << "malloc data fail!";
FreeTmpBuffer(); FreeTmpBuffer();
return RET_ERROR; return RET_ERROR;
} }
tile_data1_ = reinterpret_cast<float16_t *>(malloc(tile_size));
if (tile_data1_ == nullptr) {
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
arithmeticParameter_->in_elements_num0_);
}
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
input1_fp16_ = reinterpret_cast<float16_t *>(
context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
if (input1_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!"; MS_LOG(ERROR) << "malloc data fail!";
FreeTmpBuffer(); FreeTmpBuffer();
return RET_ERROR; return RET_ERROR;
} }
auto input0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
auto input1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());

float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_;
float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_;

TileDimensionsFp16(input0_data, input1_data1, tile_data0_, tile_data1_, arithmeticParameter_);
}

ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Arithmetic function fail!ret: " << ret;
FreeTmpBuffer();
return ret;
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
arithmeticParameter_->in_elements_num1_);
} }
return RET_OK;
ret = LiteBackendParallelLaunch(ArithmeticsRun_Fp16, this, context_->thread_num_);
return ret;
} }


kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
@@ -446,21 +441,21 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tenso
return kernel; return kernel;
} }


// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel

+ 3
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h View File

@@ -41,10 +41,12 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int DoArithmetic(int task_id); int DoArithmetic(int task_id);
int broadcast_run_(float16_t *input0, float16_t *input1, float16_t *output, int dim);
int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
int out_thread_stride);


private: private:
void FreeTmpBuffer(); void FreeTmpBuffer();
int outside_;
int break_pos_; int break_pos_;
int out_thread_stride_; int out_thread_stride_;
int out_count_; int out_count_;


+ 95
- 40
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc View File

@@ -29,6 +29,9 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Eltwise; using mindspore::schema::PrimitiveType_Eltwise;


namespace mindspore::kernel { namespace mindspore::kernel {

ArithmeticCPUKernel::~ArithmeticCPUKernel() {}

int ArithmeticCPUKernel::Init() { int ArithmeticCPUKernel::Init() {
if (!InferShapeDone()) { if (!InferShapeDone()) {
return RET_OK; return RET_OK;
@@ -42,23 +45,77 @@ int ArithmeticCPUKernel::ReSize() {
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();


if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) {
switch (arithmeticParameter_->op_parameter_.type_) {
case PrimitiveType_Mul:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMul;
break;
case PrimitiveType_Add:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAdd;
break;
case PrimitiveType_Sub:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSub;
break;
default:
break;
}
switch (arithmeticParameter_->op_parameter_.type_) {
case PrimitiveType_Mul:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMulRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMulRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMul;
break;
}
break;
case PrimitiveType_Add:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAddRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAddRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAdd;
break;
}
break;
case PrimitiveType_Sub:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSubRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSubRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSub;
break;
}
break;
default:
break;
}
}
return RET_OK;
}

int ArithmeticCPUKernel::BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count,
int out_thread_stride) {
if (dim > break_pos_) {
return arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride,
out_count);
}
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
int error_code =
BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
if (error_code != RET_OK) {
return error_code;
} }
} }
return RET_OK; return RET_OK;
@@ -81,8 +138,10 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {


int error_code = RET_OK; int error_code = RET_OK;
if (arithmeticParameter_->broadcasting_) { if (arithmeticParameter_->broadcasting_) {
error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
output_data + stride * task_id, count);
stride = UP_DIV(outside_, thread_count_);
out_count_ = MSMIN(stride, outside_ - stride * task_id);
out_thread_stride_ = stride * task_id;
error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_);
} else if (arithmetic_opt_run_ != nullptr) { } else if (arithmetic_opt_run_ != nullptr) {
if (arithmeticParameter_->in_elements_num0_ == 1) { if (arithmeticParameter_->in_elements_num0_ == 1) {
error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id, error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id,
@@ -120,31 +179,27 @@ int ArithmeticCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret; MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret; return ret;
} }

if (arithmeticParameter_->broadcasting_) { if (arithmeticParameter_->broadcasting_) {
auto input_data0 = reinterpret_cast<float *>(in_tensors_[0]->Data());
auto input_data1 = reinterpret_cast<float *>(in_tensors_[1]->Data());
auto length = arithmeticParameter_->out_elements_num_ * sizeof(float);
MS_ASSERT(context_->allocator != nullptr);
tile_data0_ = reinterpret_cast<float *>(context_->allocator->Malloc(length));
tile_data1_ = reinterpret_cast<float *>(context_->allocator->Malloc(length));
if (tile_data0_ == nullptr || tile_data1_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
context_->allocator->Free(tile_data0_);
context_->allocator->Free(tile_data1_);
return RET_ERROR;
outside_ = 1;
for (auto i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
break_pos_ = i;
break;
}
outside_ *= arithmeticParameter_->out_shape_[i];
} }
TileDimensions(input_data0, input_data1, tile_data0_, tile_data1_, arithmeticParameter_);
ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
} }
ret = LiteBackendParallelLaunch(ArithmeticsRun, this, thread_count_);
if (arithmeticParameter_->broadcasting_) {
context_->allocator->Free(tile_data0_);
context_->allocator->Free(tile_data1_);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "Arithmetic function error error_code[" << ret << "]";

int error_code = LiteBackendParallelLaunch(ArithmeticsRun, this, thread_count_);

if (error_code != RET_OK) {
MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]";
return RET_ERROR;
} }
return ret;
return RET_OK;
} }


kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,


+ 6
- 20
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h View File

@@ -45,8 +45,6 @@ class ArithmeticCPUKernel : public LiteKernel {
typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size); typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size);
typedef int (*ArithmeticOptRun)(float *input0, float *input1, float *output, int element_size, typedef int (*ArithmeticOptRun)(float *input0, float *input1, float *output, int element_size,
ArithmeticParameter *param); ArithmeticParameter *param);
typedef int (*ArithmeticBroadcastRun)(float *input0, float *input1, float *tile_input0, float *tile_input1,
float *output, int element_size, ArithmeticParameter *param);


public: public:
ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
@@ -109,64 +107,50 @@ class ArithmeticCPUKernel : public LiteKernel {
break; break;
case PrimitiveType_LogicalAnd: case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAnd; arithmetic_run_ = ElementLogicalAnd;
arithmetic_broadcast_run_ = BroadcastLogicalAnd;
break; break;
case PrimitiveType_LogicalOr: case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOr; arithmetic_run_ = ElementLogicalOr;
arithmetic_broadcast_run_ = BroadcastLogicalOr;
break; break;
case PrimitiveType_Maximum: case PrimitiveType_Maximum:
arithmetic_run_ = ElementMaximum; arithmetic_run_ = ElementMaximum;
arithmetic_broadcast_run_ = BroadcastMaximum;
break; break;
case PrimitiveType_Minimum: case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimum; arithmetic_run_ = ElementMinimum;
arithmetic_broadcast_run_ = BroadcastMinimum;
break; break;
case PrimitiveType_FloorDiv: case PrimitiveType_FloorDiv:
arithmetic_run_ = ElementFloorDiv; arithmetic_run_ = ElementFloorDiv;
arithmetic_broadcast_run_ = BroadcastFloorDiv;
break; break;
case PrimitiveType_FloorMod: case PrimitiveType_FloorMod:
arithmetic_run_ = ElementFloorMod; arithmetic_run_ = ElementFloorMod;
arithmetic_broadcast_run_ = BroadcastFloorMod;
break; break;
case PrimitiveType_Equal: case PrimitiveType_Equal:
arithmetic_run_ = ElementEqual; arithmetic_run_ = ElementEqual;
arithmetic_broadcast_run_ = BroadcastEqual;
break; break;
case PrimitiveType_NotEqual: case PrimitiveType_NotEqual:
arithmetic_run_ = ElementNotEqual; arithmetic_run_ = ElementNotEqual;
arithmetic_broadcast_run_ = BroadcastNotEqual;
break; break;
case PrimitiveType_Less: case PrimitiveType_Less:
arithmetic_run_ = ElementLess; arithmetic_run_ = ElementLess;
arithmetic_broadcast_run_ = BroadcastLess;
break; break;
case PrimitiveType_LessEqual: case PrimitiveType_LessEqual:
arithmetic_run_ = ElementLessEqual; arithmetic_run_ = ElementLessEqual;
arithmetic_broadcast_run_ = BroadcastLessEqual;
break; break;
case PrimitiveType_Greater: case PrimitiveType_Greater:
arithmetic_run_ = ElementGreater; arithmetic_run_ = ElementGreater;
arithmetic_broadcast_run_ = BroadcastGreater;
break; break;
case PrimitiveType_GreaterEqual: case PrimitiveType_GreaterEqual:
arithmetic_run_ = ElementGreaterEqual; arithmetic_run_ = ElementGreaterEqual;
arithmetic_broadcast_run_ = BroadcastGreaterEqual;
break; break;
case PrimitiveType_SquaredDifference: case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference; arithmetic_run_ = ElementSquaredDifference;
arithmetic_broadcast_run_ = BroadcastSquaredDifference;
break; break;
default: default:
MS_LOG(ERROR) << "Error Operator type " << parameter->type_; MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
arithmetic_run_ = nullptr; arithmetic_run_ = nullptr;
arithmetic_broadcast_run_ = nullptr;
break; break;
} }
} }
~ArithmeticCPUKernel() = default;
~ArithmeticCPUKernel() override;


int Init() override; int Init() override;
int ReSize() override; int ReSize() override;
@@ -174,12 +158,14 @@ class ArithmeticCPUKernel : public LiteKernel {
int DoArithmetic(int task_id); int DoArithmetic(int task_id);


private: private:
int BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count, int out_thread_stride);
int break_pos_;
int outside_;
int out_thread_stride_;
int out_count_;
int thread_count_; int thread_count_;
float *tile_data0_ = nullptr;
float *tile_data1_ = nullptr;
ArithmeticParameter *arithmeticParameter_; ArithmeticParameter *arithmeticParameter_;
ArithmeticRun arithmetic_run_ = nullptr; ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticBroadcastRun arithmetic_broadcast_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel


Loading…
Cancel
Save