From fdaff630940834d918191b4a31eb1f6ddd8e16b3 Mon Sep 17 00:00:00 2001 From: lzk Date: Wed, 21 Apr 2021 20:00:25 -0700 Subject: [PATCH] fp1611 --- .../cpu/nnacl/fp16/conv_depthwise_fp16.c | 2 +- .../cpu/nnacl/fp16/conv_fp16.c | 102 -------- .../cpu/nnacl/fp16/conv_fp16.h | 13 - .../kernel_compiler/cpu/nnacl/fp16/exp_fp16.c | 2 +- .../kernel_compiler/cpu/nnacl/fp16/exp_fp16.h | 4 +- .../cpu/nnacl/fp16/instance_norm_fp16.c | 4 +- .../cpu/nnacl/fp16/matmul_fp16.c | 10 +- .../cpu/nnacl/fp16/power_fp16.c | 12 +- .../cpu/nnacl/fp16/power_fp16.h | 4 +- .../cpu/nnacl/fp16/scale_fp16.c | 20 +- .../cpu/nnacl/fp16/softmax_fp16.c | 10 +- .../cpu/nnacl/fp32/prelu_fp32.c | 2 +- .../intrinsics/ms_simd_instructions_fp16.h | 32 ++- mindspore/lite/CMakeLists.txt | 2 +- .../kernel/arm/fp16/fullconnection_fp16.cc | 5 + .../kernel/arm/fp16/matmul_base_fp16.cc | 7 +- .../kernel/arm/fp16/matmul_base_fp16.h | 1 + .../runtime/kernel/arm/fp16/matmul_fp16.cc | 7 +- mindspore/lite/test/models_caffe_fp16.cfg | 18 +- mindspore/lite/test/models_onnx_fp16.cfg | 8 +- mindspore/lite/test/models_tf_fp16.cfg | 5 + mindspore/lite/test/models_tflite_fp16.cfg | 5 + .../test/models_with_multiple_inputs_fp16.cfg | 9 +- mindspore/lite/test/run_benchmark_nets.sh | 232 ++++++++++++++++-- 24 files changed, 324 insertions(+), 192 deletions(-) mode change 100755 => 100644 mindspore/lite/test/run_benchmark_nets.sh diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c index f7176e51d6..c7738c62b6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c @@ -322,7 +322,7 @@ void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float float16_t *dst_kw = dst_kh; const float16_t *weight_kw = weight_kh; for (int kw = 0; kw < kernel_w; kw++) { -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON float16x8_t src_8 = vld1q_f16(src_w); float16x8_t weight_8 = vld1q_f16(weight_kw); float16x8_t dst_8 = vld1q_f16(dst_kw); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c index df9524c5c9..e590c18f21 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c @@ -19,108 +19,6 @@ #include "nnacl/fp16/winograd_transform_fp16.h" #include "nnacl/fp16/matmul_fp16.h" -#ifdef __cplusplus -extern "C" { -#endif -#ifdef ENABLE_ARM64 -void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, - size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, - size_t relu6); -#endif - -#ifdef __cplusplus -} -#endif -#ifndef ENABLE_ARM64 -void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, - size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, - size_t relu6) { - if (!(mode && writeC8)) { - IndirectGemmFp16_16x8_common(output, input, weight, bias, step, ic4, out_channel, offset, relu, relu6); - } else { - IndirectGemmFp16_16x8_c8(output, input, weight, bias, step, ic4, out_channel, offset, mode, writeC8, relu, relu6); - } -} - -void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, - size_t ic4, size_t out_channel, size_t offset, size_t relu, size_t relu6) { - const int tile_n = 16; - for (int i = 0; i < out_channel; i++) { - int oc8_block = i / C8NUM; - int oc8_res = i % C8NUM; - int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; - for (int k = 0; k < tile_n; k++) { - int input_tile_offset = k * C4NUM; - int out_tile_offset = i + k * out_channel; - - float16_t tmp_out = 0; - for (int n = 0; n < step; n++) { - int input_kw_offset = input_tile_offset + n * tile_n * ic4 * C4NUM; - int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; - for (int j = 0; j < ic4; j++) { - int input_ic4_offset = input_kw_offset + j * tile_n * C4NUM; - int weight_ic4_offset = weight_kw_offset + j * C4NUM * C8NUM; - for (int m = 0; m < C4NUM; m++) { - int input_c4_offset = input_ic4_offset + m; - int weight_c4_offset = weight_ic4_offset + m * C8NUM; - tmp_out += (input + input_c4_offset)[0] * (weight + weight_c4_offset)[0]; - } - } - } - - float16_t *tmp = output + out_tile_offset; - if (bias != NULL) { - tmp[0] = tmp_out + bias[i]; - } - if (relu) { - tmp[0] = tmp[0] < 0 ? 0 : tmp[0]; - } else if (relu6) { - tmp[0] = tmp[0] < 0 ? 0 : tmp[0]; - tmp[0] = tmp[0] > 6 ? 6 : tmp[0]; - } - } - } -} - -void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, - size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC8, - size_t relu, size_t relu6) { - const int tile_num = 16; - if (mode && writeC8) { - for (int i = 0; i < tile_num; i++) { - int input_tile_offset = i * C4NUM; - int output_tile_offset = i * output_channel * step; - for (int j = 0; j < output_channel; j++) { - int oc8_block = j / C8NUM; - int oc8_res = j % C8NUM; - int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; - int out_oc_offset = output_tile_offset + oc8_block * step * C8NUM + oc8_res; - - for (int n = 0; n < step; n++) { - int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num; - int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; - int output_kw_offset = out_oc_offset + n * C8NUM; - float16_t acc = 0; - - for (int k = 0; k < ic4; k++) { - int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM; - int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; - for (int m = 0; m < C4NUM; m++) { - int input_ic_offset = input_ic4_offset + m; - int weight_ic_offset = weight_ic4_offset + m * C8NUM; - acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; - } - } - - (output + output_kw_offset)[0] = acc; - } - } - } - } else { - } -} -#endif - // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h index 971044769a..34d97fb75a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h @@ -24,19 +24,6 @@ typedef float16_t *TmpBufferAddressFp16; typedef float16_t *MatricesFp16; -#ifndef ENABLE_ARM64 -void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, - size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, - size_t relu6); - -void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, - size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6); - -void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, - size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, - size_t relu6); -#endif - #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.c index 9b6dc15ee9..aec77bd79f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.c @@ -21,7 +21,7 @@ void ExpFp16(const float16_t *src, float16_t *dst, int num) { int i = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON int count = (num / C8NUM) * C8NUM; for (; i < count; i += C8NUM) { simd_exp_fp16(vld1q_f16(src + i), dst + i); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h index 9607239bda..0738504d96 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h @@ -25,7 +25,7 @@ extern "C" { #endif void ExpFp16(const float16_t *src, float16_t *dst, int num); -#if defined(ENABLE_ARM64) +#if defined(ENABLE_NEON) static inline float32x4_t exp_fp32(float32x4_t input) { static float32x4_t param[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f}, {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, @@ -49,7 +49,7 @@ static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) { input = vmaxq_f16(minv, vminq_f16(input, maxv)); float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input)); - float32x4_t input_high = vcvt_high_f32_f16(input); + float32x4_t input_high = vcvt_f32_f16(vget_high_f16(input)); vst1q_f16(dst, vcombine_f16(vcvt_f16_f32(exp_fp32(input_low)), vcvt_f16_f32(exp_fp32(input_high)))); } #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c index 75ca35b3d0..c4a89b7602 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c @@ -43,11 +43,11 @@ int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv)); float32x4_t sum_f32 = vcvt_f32_f16(sum2); - mean += vaddvq_f32(sum_f32); + mean += MS_ADDVQ_F32(sum_f32); float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev)); float32x4_t square_f32 = vcvt_f32_f16(square2); - square_mean += vaddvq_f32(square_f32); + square_mean += MS_ADDVQ_F32(square_f32); } for (; index < param->inner_size_; index++) { mean += src[index]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c index 0cd5b56e95..2e4bdd31b8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c @@ -290,7 +290,7 @@ void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, int deep, int row, int col, int stride, int write_mode) { - if (write_mode == OutType_Nhwc) { + if (write_mode == OutType_Nhwc) { // common conv and matmul for (int r = 0; r < row; r++) { for (int c = 0; c < col; c++) { int r12div = r / 12, r12mod = r % 12; @@ -308,7 +308,7 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons dst[ci] = value; } } - } else if (write_mode == OutType_C8) { + } else if (write_mode == OutType_C8) { // common deconv int col_8 = UP_ROUND(col, C8NUM); int row_12 = UP_ROUND(row, C12NUM); for (int r = 0; r < row_12; r++) { @@ -328,7 +328,7 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons dst[ci] = value; } } - } else { + } else { // winograd conv for (int i = 0; i < row; ++i) { int src_r_offset = i; int dst_r_offset = i * col * stride; @@ -353,12 +353,14 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, int depth, int row, int col, int stride, int out_type) { if (out_type == OutType_C8) { + // common deconv #ifdef ENABLE_ARM64 MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); #else - MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); + MatMul12x8Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); #endif } else { + // winograd conv(OntType_TileC8) and common conv(OutType_Nhwc) and matmul(OutType_Nhwc) #ifdef ENABLE_ARM64 MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); #else diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c index 471221e81b..5d23c75031 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c @@ -17,7 +17,7 @@ #include "nnacl/fp16/power_fp16.h" #include "nnacl/errorcode.h" -#if defined(ENABLE_ARM64) +#if defined(ENABLE_NEON) float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { int tmp = (int)(*(float16_t *)exponent); int exp = abs(tmp); @@ -53,23 +53,23 @@ float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) { void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift) { PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; -#if defined(ENABLE_ARM64) +#if defined(ENABLE_NEON) PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; #endif if (CheckInteger(*exponent)) { -#if defined(ENABLE_ARM64) +#if defined(ENABLE_NEON) PowerSimdFunFp16_ = OptimizedPowerSimdFp16; #endif PowerScalarFunFp16_ = OptimizedPowerScalarFp16; } else { -#if defined(ENABLE_ARM64) +#if defined(ENABLE_NEON) PowerSimdFunFp16_ = StdPowerSimdFp16; #endif PowerScalarFunFp16_ = StdPowerScalarFp16; } int i = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON int len_c8 = UP_ROUND(len, C8NUM); float16x8_t scale_8 = vmovq_n_f16(scale); float16x8_t shift_8 = vmovq_n_f16(shift); @@ -87,7 +87,7 @@ void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_ float shift) { int i = 0; PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON int len_c8 = UP_ROUND(len, C8NUM); float16x8_t scale_8 = vmovq_n_f16(scale); float16x8_t shift_8 = vmovq_n_f16(shift); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h index 206c68afd8..b49d27a422 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h @@ -22,7 +22,7 @@ #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/power_parameter.h" -#if defined(ENABLE_ARM64) +#if defined(ENABLE_NEON) typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent); #endif typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent); @@ -37,7 +37,7 @@ static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) { return powf(x, *(float16_t *)exponent); } -#if defined(ENABLE_ARM64) +#if defined(ENABLE_NEON) static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) { float16x8_t result; result[0] = powf(x[0], *(float16_t *)exponent); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.c index aea928149d..954540de6b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.c @@ -23,7 +23,7 @@ void Fp16ScaleInner(float16_t *in_data, float16_t *out_data, float16_t *scale, f for (int i = 0; i < axis_size; i++) { int axis_offset = out_offset + i * inner_size; int in_index = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON for (; in_index < inner_size - 8; in_index += 8) { int in_offset = axis_offset + in_index; float16x8_t data = vld1q_f16(in_data + in_offset); @@ -47,7 +47,7 @@ void Fp16ScaleAxis(float16_t *in_data, float16_t *out_data, float16_t *scale, fl for (int out = outer_start; out < outer_end; out++) { int out_offset = out * axis_size; int index = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON for (; index < axis_size - 8; index += 8) { int in_offset = out_offset + index; float16x8_t data = vld1q_f16(in_data + in_offset); @@ -80,7 +80,7 @@ void DoScaleFp16(float16_t *in_data, float16_t *out_data, float16_t *scale, floa void Fp16ScaleInnerRelu(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start, int outer_end, int axis_size, int inner_size) { -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif for (int out = outer_start; out < outer_end; out++) { @@ -88,7 +88,7 @@ void Fp16ScaleInnerRelu(float16_t *in_data, float16_t *out_data, float16_t *scal for (int i = 0; i < axis_size; i++) { int axis_offset = out_offset + i * inner_size; int in_index = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON for (; in_index < inner_size - 8; in_index += 8) { int in_offset = axis_offset + in_index; float16x8_t data = vld1q_f16(in_data + in_offset); @@ -110,13 +110,13 @@ void Fp16ScaleInnerRelu(float16_t *in_data, float16_t *out_data, float16_t *scal void Fp16ScaleAxisRelu(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start, int outer_end, int axis_size) { -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif for (int out = outer_start; out < outer_end; out++) { int out_offset = out * axis_size; int index = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON for (; index < axis_size - 8; index += 8) { int in_offset = out_offset + index; float16x8_t data = vld1q_f16(in_data + in_offset); @@ -151,7 +151,7 @@ void Fp16DoScaleRelu(float16_t *in_data, float16_t *out_data, float16_t *scale, void Fp16ScaleInnerRelu6(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start, int outer_end, int axis_size, int inner_size) { -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif @@ -160,7 +160,7 @@ void Fp16ScaleInnerRelu6(float16_t *in_data, float16_t *out_data, float16_t *sca for (int i = 0; i < axis_size; i++) { int axis_offset = out_offset + i * inner_size; int in_index = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON for (; in_index < inner_size - 8; in_index += 8) { int in_offset = axis_offset + in_index; float16x8_t data = vld1q_f16(in_data + in_offset); @@ -182,14 +182,14 @@ void Fp16ScaleInnerRelu6(float16_t *in_data, float16_t *out_data, float16_t *sca void Fp16ScaleAxisRelu6(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start, int outer_end, int axis_size) { -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif for (int out = outer_start; out < outer_end; out++) { int out_offset = out * axis_size; int index = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON for (; index < axis_size - 8; index += 8) { int in_offset = out_offset + index; float16x8_t data = vld1q_f16(in_data + in_offset); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c index dcc64d9199..ee6432a1ff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c @@ -22,14 +22,14 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe int cur_batch_offset = 0; for (int i = 0; i < batch; i++, cur_batch_offset += channel) { int j = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON float16x8_t max_8 = vdupq_n_f16(-FLT16_MAX); int count = (channel / C8NUM) * C8NUM; for (; j < count; j += C8NUM) { float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + j); max_8 = vmaxq_f16(max_8, input_8); } - float16_t max = vmaxvq_f16(max_8); + float16_t max = MS_MAXVQ_F16(max_8); #else float16_t max = -FLT_MAX; #endif @@ -40,7 +40,7 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe } } int k = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON int count2 = (channel / C8NUM) * C8NUM; for (; k < count2; k += C8NUM) { float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k); @@ -60,7 +60,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) for (int i = 0; i < batch; i++, cur_batch_offset += channel) { float16_t sum = 0.0f; int j = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON float16x8_t sum8 = vdupq_n_f16(0); int count = (channel / C8NUM) * C8NUM; for (; j < count; j += C8NUM) { @@ -72,7 +72,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) sum += src[cur_batch_offset + j]; } int k = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON const float16_t div = 1.0f / sum; for (; k < count; k += C8NUM) { vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div)); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c index e2caa6a6cd..74bd2b924f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c @@ -113,7 +113,7 @@ void PRelu(const float *input, float *output, float *slope, int start, int end, const float *cur_in = input + i * channel; float *cur_out = output + i * channel; int j = 0; -#if defined(ENABLE_ARM) +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) for (; j < channel - 3; j += 4) { MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j); MS_FLOAT32X4 s = MS_LDQ_F32(slope + j); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h index cbe3a3ab12..2758790032 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h @@ -19,7 +19,7 @@ #include "nnacl/intrinsics/ms_simd_instructions.h" #if defined(ENABLE_ARM82_A32) -static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) { +static inline float16x8_t ms_vdivq_f16(float16x8_t in1, float16x8_t in2) { float16x8_t dst; asm volatile( "vrecpe.f16 q14, %3\n" @@ -34,7 +34,7 @@ static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) { return dst; } -static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) { +static inline float16x4_t ms_vdiv_f16(float16x4_t in1, float16x4_t in2) { float16x4_t dst; asm volatile( "vrecpe.f16 d14, %3\n" @@ -49,33 +49,47 @@ static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) { return dst; } -static inline float vaddvq_f32(float32x4_t in) { // is not support in arm82 aarch32 +static inline float ms_vaddvq_f32(float32x4_t in) { + // is not support in arm82 aarch32 and there is no assembly instruction to process all the data return in[0] + in[1] + in[2] + in[3]; } -static inline float32x4_t cvt_f32_f16(float16x4_t in) { +static inline float16_t ms_vmaxvq_f16(float16x8_t in) { + // is not support in arm82 aarch32 and there is no assembly instruction to process all the data + float16_t dst = in[0]; + for (int i = 1; i < 8; ++i) { + dst = dst > in[i] ? dst : in[i]; + } + return dst; +} + +static inline float32x4_t ms_vcvt_f32_f16(float16x4_t in) { float32x4_t dst; asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); return dst; } -static inline float16x4_t cvt_f16_f32(float32x4_t in) { +static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) { float16x4_t dst; asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); return dst; } -#define MS_CVT_F32_F16(src) cvt_f32_f16(src) -#define MS_CVT_F16_F32(src) cvt_f16_f32(src) -#define MS_DIV_F16(src1, src2) div_f16(src1, src2) -#define MS_DIVQ_F16(src1, src2) divq_f16(src1, src2) +#define MS_CVT_F32_F16(src) ms_vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) ms_vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) ms_vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) ms_vdivq_f16(src1, src2) #define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3)) +#define MS_MAXVQ_F16(src) ms_vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) ms_vaddvq_f32(src) #else #define MS_CVT_F32_F16(src) vcvt_f32_f16(src) #define MS_CVT_F16_F32(src) vcvt_f16_f32(src) #define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2) #define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2) #define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3) +#define MS_MAXVQ_F16(src) vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) vaddvq_f32(src) #endif static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 9ce4ddb182..679c1db108 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -7,7 +7,7 @@ endif() if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0) set(ENABLE_FP16 "off") - message(WARNING "If you want to build fp16 in arm82_a32, \ + message(STATUS "If you want to build fp16 in arm82_a32, \ your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!") endif() diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc index b1185b1c86..69d0305e0f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc @@ -42,6 +42,11 @@ int FullconnectionFP16CPUKernel::ReSize() { } int FullconnectionFP16CPUKernel::Init() { +#ifdef ENABLE_ARM64 + row_tile_ = C16NUM; +#else + row_tile_ = C12NUM; +#endif params_->batch = 1; params_->a_transpose_ = false; params_->b_transpose_ = true; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc index c50a081490..5b0e13be2a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc @@ -114,12 +114,7 @@ void MatmulBaseFP16CPUKernel::ResizeParameter() { params_->row_align_ = 1; params_->col_align_ = params_->col_; } else { -#ifdef ENABLE_ARM64 - int row_tile = C16NUM; -#else - int row_tile = C12NUM; -#endif - params_->row_align_ = UP_ROUND(params_->row_, row_tile); + params_->row_align_ = UP_ROUND(params_->row_, row_tile_); params_->col_align_ = UP_ROUND(params_->col_, C8NUM); } return; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.h index f2538a31e1..df5240dd0e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.h @@ -55,6 +55,7 @@ class MatmulBaseFP16CPUKernel : public LiteKernel { protected: MatMulParameter *params_ = nullptr; + int row_tile_ = 0; private: int thread_stride_ = 0; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc index f2e9c4ca20..9c5bab0d3b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc @@ -36,7 +36,7 @@ void MatmulFP16CPUKernel::InitAShape() { params_->batch = batch; params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2]; params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; - params_->row_16_ = UP_ROUND(params_->row_, C16NUM); + params_->row_16_ = UP_ROUND(params_->row_, row_tile_); } void MatmulFP16CPUKernel::InitBShape() { @@ -55,6 +55,11 @@ void MatmulFP16CPUKernel::InitBShape() { } int MatmulFP16CPUKernel::Init() { +#ifdef ENABLE_ARM64 + row_tile_ = C16NUM; +#else + row_tile_ = C12NUM; +#endif MatmulBaseFP16CPUKernel::InitParameter(); if (params_->a_const_) { diff --git a/mindspore/lite/test/models_caffe_fp16.cfg b/mindspore/lite/test/models_caffe_fp16.cfg index 352bd3bf98..cce066f536 100644 --- a/mindspore/lite/test/models_caffe_fp16.cfg +++ b/mindspore/lite/test/models_caffe_fp16.cfg @@ -1,3 +1,8 @@ +# [first column]:model_name, If you need input shape, please connect it through ';' after the model name. +# [second column]:accuracy limit in arm64 +# [third column]:accuracy limit in armv82_a32 +# Each column is separated by a space and comment on a single line! +# The missing third column indicates that armv82_a32 does not need to maintain this model. age_medium 6 beard 2 emotion 60 @@ -68,7 +73,8 @@ PoseNet_dla_17_x512_tmp 5 ml_location_scene_division 8 ml_tabel_recog 0.1 ml_text_division 12 -ml_video_edit_Mnet 11 # Further analysis in the future +# Further analysis in the future to model ml_video_edit_Mnet +ml_video_edit_Mnet 11 ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5 hdc_age_medium 6 hdc_contour_pose_128 0.5 @@ -100,13 +106,13 @@ ml_face_glasses 2.5 # ml_segmentation_matting 26 # output value unstable ml_segmentation_atlanta_10 5 # ml_bodymask: The difference of output node divided by a very small value leads to a large error -ml_bodymask 14 -ml_Hand_deploy 4 +ml_bodymask 14 13 +ml_Hand_deploy 4 4 # ml_hand_3d_detection: The difference of output node divided by a very small value leads to a large error -ml_hand_3d_detection 12 -ml_hand_3d_regression 3 +ml_hand_3d_detection 12 10 +ml_hand_3d_regression 3 4 # ml_ARengine23_bodypose: The difference of output node divided by a very small value leads to a large error -ml_ARengine23_bodypose 56 +ml_ARengine23_bodypose 56 58 ml_ocr_bank_card_detection_inception_tmp 20 ml_ocr_bank_card_recognition_fcny 0.5 hiai_cv_aestheticsEngineModel_osp 1.5 diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg index d20fe7943b..224f452cc8 100644 --- a/mindspore/lite/test/models_onnx_fp16.cfg +++ b/mindspore/lite/test/models_onnx_fp16.cfg @@ -1,3 +1,8 @@ +# [first column]:model_name, If you need input shape, please connect it through ';' after the model name. +# [second column]:accuracy limit in arm64 +# [third column]:accuracy limit in armv82_a32 +# Each column is separated by a space and comment on a single line! +# The missing third column indicates that armv82_a32 does not need to maintain this model. mtk_detect-mbv2-shortcut-400-400-simplified.onnx 4 mtk_face_features_v3.onnx 20 emotion-ferplus-8.onnx 1 @@ -37,7 +42,8 @@ residual_distill_res34_cifar10_bs_1_update.onnx 2 residual_distill_res50_cifar10_bs_1_update.onnx 2 #ml_voice_detect.onnx #out of float16 range because power op hdc_ocr_attention.onnx 1.6 -hdc_ocr_detect_tmp.onnx 30 #one of the output has small values +#one of the output has small values in model hdc_ocr_detect_tmp.onnx +hdc_ocr_detect_tmp.onnx 30 ml_edu_kit_hand_detection.onnx 2 ml_edu_kit_hand_key_position.onnx 2 ml_video_edit_judge.onnx 12 diff --git a/mindspore/lite/test/models_tf_fp16.cfg b/mindspore/lite/test/models_tf_fp16.cfg index 2c5e9d73fb..1709130c7a 100644 --- a/mindspore/lite/test/models_tf_fp16.cfg +++ b/mindspore/lite/test/models_tf_fp16.cfg @@ -1,3 +1,8 @@ +# [first column]:model_name, If you need input shape, please connect it through ';' after the model name. +# [second column]:accuracy limit in arm64 +# [third column]:accuracy limit in armv82_a32 +# Each column is separated by a space and comment on a single line! +# The missing third column indicates that armv82_a32 does not need to maintain this model. ml_vision_guide_detection1.pb 0.5 ml_vision_guide_detection3.pb 0.5 ml_video_edit_generate_filter.pb 2 diff --git a/mindspore/lite/test/models_tflite_fp16.cfg b/mindspore/lite/test/models_tflite_fp16.cfg index 900b2cc982..b06d681105 100644 --- a/mindspore/lite/test/models_tflite_fp16.cfg +++ b/mindspore/lite/test/models_tflite_fp16.cfg @@ -1,3 +1,8 @@ +# [first column]:model_name, If you need input shape, please connect it through ';' after the model name. +# [second column]:accuracy limit in arm64 +# [third column]:accuracy limit in armv82_a32 +# Each column is separated by a space and comment on a single line! +# The missing third column indicates that armv82_a32 does not need to maintain this model. hiai_model_0909_kd_rot_ps_softmax.tflite 10 hiai_chinese_english_recognize_model_float32.tflite 13 hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite 10 diff --git a/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg b/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg index 5dcd399c6a..71e2ab3036 100644 --- a/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg +++ b/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg @@ -1,3 +1,8 @@ +# [first column]:model_name;input_bin_number;input_shape (input_bin_number and input_shape maybe do not need.) +# [second column]:accuracy limit in arm64 +# [third column]:accuracy limit in armv82_a32 +# Each column is separated by a space. +# The missing third column indicates that armv82_a32 does not need to maintain this model. ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 11 ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 11 ml_video_edit_img_segment_adaptise.pb;2 40 @@ -19,5 +24,5 @@ ml_tts_decoder.pb;5 2.5 hiai_cv_labelDetectorModel_v3.tflite;2 2 ml_tts_vocoder.pb;66 53 # The outputs of two Heatmap_depth models have small value -ml_Heatmap_depth_240180;2 102 -ml_Heatmap_depth_180240;2 101 \ No newline at end of file +ml_Heatmap_depth_240180;2 10 16 +ml_Heatmap_depth_180240;2 7 7 diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh old mode 100755 new mode 100644 index 405a186a3e..9b734b6ea4 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -1900,6 +1900,181 @@ function Run_arm64_fp16() { fi done < ${models_multiple_inputs_fp16_config} } + +# Run on armv8.2-a32-fp16 platform: +function Run_armv82_a32_fp16() { + cd ${armv82_path} || exit 1 + tar -zxf mindspore-lite-${version}-inference-android-aarch32.tar.gz || exit 1 + + # If build with minddata, copy the minddata related libs + cd ${benchmark_test_path} || exit 1 + if [ -f ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/minddata/lib/libminddata-lite.so ]; then + cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/minddata/lib/libminddata-lite.so ${benchmark_test_path}/libminddata-lite.so || exit 1 + fi + + cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1 + cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1 + + # adb push all needed files to the phone + adb -s ${device_id} push ${benchmark_test_path} /data/local/tmp/ > adb_push_log.txt + + # run adb ,run session ,check the result: + echo 'cd /data/local/tmp/benchmark_test' > adb_cmd.txt + echo 'cp /data/local/tmp/arm32/libc++_shared.so ./' >> adb_cmd.txt + echo 'chmod 777 benchmark' >> adb_cmd.txt + + adb -s ${device_id} shell < adb_cmd.txt + + # Run fp16 converted models: + while read line; do + fp16_line_info=${line} + column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'` + if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then + continue + fi + model_info=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'` + model_name=${model_info%%;*} + length=${#model_name} + input_shapes=${model_info:length+1} + echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}" + echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}" + + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt + if [[ $accuracy_limit == "-1" ]]; then + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt + else + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt + fi + cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + if [ $? = 0 ]; then + run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} + else + run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 + fi + done < ${models_onnx_fp16_config} + + while read line; do + fp16_line_info=${line} + column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'` + if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then + continue + fi + model_name=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'` + echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}" + echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}" + + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} >> adb_run_cmd.txt + + cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + if [ $? = 0 ]; then + run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} + else + run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 + fi + done < ${models_caffe_fp16_config} + + while read line; do + fp16_line_info=${line} + column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'` + if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then + continue + fi + model_name=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'` + echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}" + echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}" + + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} >> adb_run_cmd.txt + + cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + if [ $? = 0 ]; then + run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} + else + run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 + fi + done < ${models_tflite_fp16_config} + + # Run fp16 converted models: + while read line; do + fp16_line_info=${line} + column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'` + if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then + continue + fi + model_info=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'` + model_name=${model_info%%;*} + length=${#model_name} + input_shapes=${model_info:length+1} + echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}" + echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}" + + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt + if [[ $accuracy_limit == "-1" ]]; then + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt + else + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt + fi + cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + if [ $? = 0 ]; then + run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} + else + run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 + fi + done < ${models_tf_fp16_config} + + # Run converted models which has multiple inputs in fp16 mode: + while read line; do + fp16_line_info=${line} + column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'` + if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then + continue + fi + model_info=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'` + model_name=`echo ${model_info}|awk -F ';' '{print $1}'` + input_num=`echo ${model_info} | awk -F ';' '{print $2}'` + input_shapes=`echo ${model_info} | awk -F ';' '{print $3}'` + input_files='' + output_file='' + data_path="/data/local/tmp/input_output/" + for i in $(seq 1 $input_num) + do + input_files=$input_files${data_path}'input/'$model_name'.ms.bin_'$i',' + done + output_file=${data_path}'output/'${model_name}'.ms.out' + if [[ ${model_name##*.} == "caffemodel" ]]; then + model_name=${model_name%.*} + fi + echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}" + echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}" + + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file} '--enableFp16=true --accuracyThreshold='${accuracy_limit} >> adb_run_cmd.txt + + cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}" + if [ $? = 0 ]; then + run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} + else + run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 + fi + done < ${models_multiple_inputs_fp16_config} +} + # Run on gpu platform: function Run_gpu() { cd ${arm64_path} || exit 1 @@ -2249,7 +2424,7 @@ fi # Write benchmark result to temp file run_benchmark_result_file=${basepath}/run_benchmark_result.txt echo ' ' > ${run_benchmark_result_file} -run_x86_log_file + run_x86_log_file=${basepath}/run_x86_log.txt echo 'run x86 logs: ' > ${run_x86_log_file} @@ -2271,6 +2446,9 @@ echo 'run arm64_fp32 logs: ' > ${run_arm64_fp32_log_file} run_arm64_fp16_log_file=${basepath}/run_arm64_fp16_log.txt echo 'run arm64_fp16 logs: ' > ${run_arm64_fp16_log_file} +run_armv82_a32_fp16_log_file=${basepath}/run_armv82_a32_fp16_log.txt +echo 'run arm82_a32_fp16 logs: ' > ${run_armv82_a32_fp16_log_file} + run_arm32_log_file=${basepath}/run_arm32_log.txt echo 'run arm32 logs: ' > ${run_arm32_log_file} @@ -2331,6 +2509,33 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86-codegen" ]] sleep 1 fi +if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp16" ]]; then + # Run on armv82-a32-fp16 + armv82_path=${release_path}/android_aarch32 + file_name=$(ls ${armv82_path}/*inference-android-aarch32.tar.gz) + IFS="-" read -r -a file_name_array <<< "$file_name" + version=${file_name_array[2]} + + echo "start Run armv82-a32-fp16 ..." + Run_armv82_a32_fp16 + Run_armv82_a32_fp16_status=$? + sleep 1 +fi + +if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp32" ]]; then + # Run on arm32 + arm32_path=${release_path}/android_aarch32 + # mv ${arm32_path}/*train-android-aarch32* ./train + file_name=$(ls ${arm32_path}/*inference-android-aarch32.tar.gz) + IFS="-" read -r -a file_name_array <<< "$file_name" + version=${file_name_array[2]} + + echo "start Run arm32 ..." + Run_arm32 + Run_arm32_status=$? + sleep 1 +fi + if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp32" ]]; then # Run on arm64 arm64_path=${release_path}/android_aarch64 @@ -2359,20 +2564,6 @@ if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp16" ]]; sleep 1 fi -if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32" ]]; then - # Run on arm32 - arm32_path=${release_path}/android_aarch32 - # mv ${arm32_path}/*train-android-aarch32* ./train - file_name=$(ls ${arm32_path}/*inference-android-aarch32.tar.gz) - IFS="-" read -r -a file_name_array <<< "$file_name" - version=${file_name_array[2]} - - echo "start Run arm32 ..." - Run_arm32 - Run_arm32_status=$? - sleep 1 -fi - if [[ $backend == "all" || $backend == "gpu_npu" || $backend == "gpu" ]]; then # Run on gpu arm64_path=${release_path}/android_aarch64 @@ -2468,7 +2659,14 @@ if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp16" ]]; isFailed=1 fi fi -if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32" ]]; then +if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp16" ]]; then + if [[ ${Run_armv82_a32_fp16_status} != 0 ]];then + echo "Run_armv82_a32_fp16 failed" + cat ${run_armv82_a32_fp16_log_file} + isFailed=1 + fi +fi +if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp32" ]]; then if [[ ${Run_arm32_status} != 0 ]];then echo "Run_arm32 failed" cat ${run_arm32_log_file} @@ -2490,7 +2688,7 @@ if [[ $backend == "all" || $backend == "gpu_npu" || $backend == "npu" ]]; then fi fi -echo "Run_x86 and Run_x86_sse and Run_arm64_fp32 and Run_arm64_fp16 and Run_arm32 and Run_gpu and Run_npu is ended" +echo "Run_x86 and Run_x86_sse and Run_x86_avx and Run_arm64_fp32 and Run_arm64_fp16 and Run_arm32_fp32 and Run_armv82_a32_fp16 and Run_gpu and Run_npu and is ended" Print_Benchmark_Result if [[ $isFailed == 1 ]]; then exit 1