diff --git a/mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC8.S b/mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC8.S index 07f62473a7..02c125de07 100644 --- a/mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC8.S +++ b/mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC8.S @@ -41,7 +41,7 @@ Loop_C8: mov w13, w5 ld1 {v16.4s, v17.4s}, [x2], #32 -Loop8x8: +Loop_8x8: cmp w13, #8 blt Loop_4x8 sub w13, w13, #8 @@ -67,9 +67,9 @@ Loop8x8: fadd v14.4s, v14.4s, v16.4s fadd v15.4s, v15.4s, v17.4s - cmp w7, #3 + cmp x7, #3 beq Relu6_8x8 - cmp w7, #1 + cmp x7, #1 beq Relu_8x8 b Write_8x8 Relu6_8x8: @@ -115,7 +115,7 @@ Write_8x8: st1 {v10.4s, v11.4s}, [x15], x6 st1 {v12.4s, v13.4s}, [x15], x6 st1 {v14.4s, v15.4s}, [x15], x6 - b Loop8x8 + b Loop_8x8 Loop_4x8: cmp w13, #4 @@ -133,9 +133,9 @@ Loop_4x8: fadd v6.4s, v6.4s, v16.4s fadd v7.4s, v7.4s, v17.4s - cmp w7, #2 + cmp x7, #3 beq Relu6_4x8 - cmp w7, #1 + cmp x7, #1 beq Relu_4x8 b Write_4x8 Relu6_4x8: @@ -163,9 +163,9 @@ Write_4x8: st1 {v6.4s, v7.4s}, [x15], x6 Loop_1x8: - cmp w7, #2 + cmp x7, #3 beq Relu6_1x8 - cmp w7, #1 + cmp x7, #1 beq Relu_1x8 b Write_1x8 Relu6_1x8: @@ -228,9 +228,9 @@ Loop_C1: beq Loop_C1_7 Loop_C1_1: - cmp w7, #2 + cmp x7, #3 beq Loop_C1_1_Relu6 - cmp w7, #1 + cmp x7, #1 beq Loop_C1_1_Relu b Loop_C1_1_Write Loop_C1_1_Relu6: @@ -265,9 +265,9 @@ Loop_C1_1_Write: b Loop_C1_1_Write Loop_C1_2: - cmp w7, #2 + cmp x7, #3 beq Loop_C1_2_Relu6 - cmp w7, #1 + cmp x7, #1 beq Loop_C1_2_Relu b Loop_C1_2_Write Loop_C1_2_Relu6: @@ -307,9 +307,9 @@ Loop_C1_2_Write: Loop_C1_3: add x15, x0, #8 - cmp w7, #2 + cmp x7, #3 beq Loop_C1_3_Relu6 - cmp w7, #1 + cmp x7, #1 beq Loop_C1_3_Relu b Loop_C1_3_Write Loop_C1_3_Relu6: @@ -350,9 +350,9 @@ Loop_C1_3_Write: b Loop_C1_3_Write Loop_C1_4: - cmp w7, #2 + cmp x7, #3 beq Loop_C1_4_Relu6 - cmp w7, #1 + cmp x7, #1 beq Loop_C1_4_Relu b Loop_C1_4_Write Loop_C1_4_Relu6: @@ -385,9 +385,9 @@ Loop_C1_4_Write: Loop_C1_5: add x15, x0, #16 - cmp w7, #2 + cmp x7, #3 beq Loop_C1_5_Relu6 - cmp w7, #1 + cmp x7, #1 beq Loop_C1_5_Relu b Loop_C1_5_Write Loop_C1_5_Relu6: @@ -432,9 +432,9 @@ Loop_C1_5_Write: Loop_C1_6: add x15, x0, #16 - cmp w7, #2 + cmp x7, #3 beq Loop_C1_6_Relu6 - cmp w7, #1 + cmp x7, #1 beq Loop_C1_6_Relu b Loop_C1_6_Write Loop_C1_6_Relu6: @@ -483,9 +483,9 @@ Loop_C1_6_Write: Loop_C1_7: add x15, x0, #16 add x14, x0, #24 - cmp w7, #2 + cmp x7, #3 beq Loop_C1_7_Relu6 - cmp w7, #1 + cmp x7, #1 beq Loop_C1_7_Relu b Loop_C1_7_Write Loop_C1_7_Relu6: diff --git a/mindspore/lite/nnacl/fp32/common_func.c b/mindspore/lite/nnacl/fp32/common_func.c index fe6ed6e9b2..cfad0d189b 100644 --- a/mindspore/lite/nnacl/fp32/common_func.c +++ b/mindspore/lite/nnacl/fp32/common_func.c @@ -17,7 +17,7 @@ #include "nnacl/fp32/common_func.h" void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel, - size_t plane_size, size_t plane_stride, size_t oc_stride, bool is_relu, bool is_relu6, int size) { + size_t plane_size, size_t plane_stride, size_t oc_stride, ActType relu_type, int size) { int oc_div = 0, oc_mod = 0; for (int oc = 0; oc < output_channel; oc++) { if (size != 0) { @@ -33,8 +33,8 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p if (bias_ptr != NULL) { value = value + bias_ptr[oc]; } - value = (is_relu || is_relu6) ? (MSMAX(0.f, value)) : (value); - value = (is_relu6) ? (MSMIN(6.f, value)) : (value); + value = (relu_type == ActType_Relu || relu_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value); + value = (relu_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value); out_ptr[dst_index] = value; } } @@ -42,25 +42,22 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p } void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, - size_t plane_size, size_t stride, bool is_relu, bool is_relu6) { + size_t plane_size, size_t stride, size_t relu_type) { #ifndef ENABLE_ARM - PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, is_relu, is_relu6, - C8NUM); + PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM); #else size_t oc8mod = output_channel % C8NUM; size_t oc8div = output_channel - oc8mod; size_t stride_size = stride * sizeof(float); - size_t relu_type = is_relu ? 1 : 0; - relu_type = is_relu6 ? 3 : relu_type; PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); #endif return; } void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, - size_t plane_size, size_t plane_stride, bool is_relu, bool is_relu6) { - PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, is_relu, - is_relu6, C4NUM); + size_t plane_size, size_t plane_stride, size_t relu_type) { + PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type, + C4NUM); return; } diff --git a/mindspore/lite/nnacl/fp32/common_func.h b/mindspore/lite/nnacl/fp32/common_func.h index ab40622391..157aaecc57 100644 --- a/mindspore/lite/nnacl/fp32/common_func.h +++ b/mindspore/lite/nnacl/fp32/common_func.h @@ -28,9 +28,9 @@ extern "C" { #endif void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, - size_t plane_size, size_t stride, bool is_relu, bool is_relu6); + size_t plane_size, size_t stride, size_t relu_type); void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, - size_t plane_size, size_t plane_stride, bool is_relu, bool is_relu6); + size_t plane_size, size_t plane_stride, size_t relu_type); void WinogradMatrixProductLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); void WinogradMatrixProductRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); diff --git a/mindspore/lite/nnacl/fp32/deconv.c b/mindspore/lite/nnacl/fp32/deconv.c index a6f41bd424..0865118cde 100644 --- a/mindspore/lite/nnacl/fp32/deconv.c +++ b/mindspore/lite/nnacl/fp32/deconv.c @@ -103,7 +103,6 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds } /*ih*/ } /*oc8*/ - PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, - conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6); + PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->act_type_); return; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc index da47475d34..48502a385f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc @@ -334,8 +334,7 @@ int DeConvolutionWinogradCPUKernel::DeDeconvPost(int task_id) { PostConvFuncFp32C4(nc4hw4_output_ + task_id * thread_stride_hw_ * C4NUM, nhwc_output_ + task_id * thread_stride_hw_ * conv_param_->output_channel_, reinterpret_cast(bias_data_), conv_param_->output_channel_, current_plane, - deconv_param_->output_plane_, conv_param_->act_type_ == ActType_Relu, - conv_param_->act_type_ == ActType_Relu6); + deconv_param_->output_plane_, conv_param_->act_type_); return RET_OK; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc index 289d8fd01e..f1929e1abd 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc @@ -20,6 +20,7 @@ #include "src/common/file_utils.h" #include "mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h" #include "mindspore/lite/nnacl/fp32/deconv.h" +#include "mindspore/lite/nnacl/op_base.h" namespace mindspore { class TestDeConvolutionFp32 : public mindspore::CommonTest { @@ -106,15 +107,15 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) { float out[8] = {0}; float no[] = {-8.646674, -4.7133026, -0.11849791, -4.530405, -5.419181, 14.387108, 2.8319538, -8.511095}; - PostConvFuncFp32C8(in, out, bias, 1, 8, 1, false, false); + PostConvFuncFp32C8(in, out, bias, 1, 8, 1, ActType_No); CompareOutputData(out, no, 8, 0.0001); float relu[] = {0, 0, 0, 0, 0, 14.387108, 2.8319538, 0}; - PostConvFuncFp32C8(in, out, bias, 1, 8, 1, true, false); + PostConvFuncFp32C8(in, out, bias, 1, 8, 1, ActType_Relu); CompareOutputData(out, relu, 8, 0.0001); float corr_relu6[] = {0, 0, 0, 0, 0, 6, 2.8319538, 0}; - PostConvFuncFp32C8(in, out, bias, 1, 8, 1, false, true); + PostConvFuncFp32C8(in, out, bias, 1, 8, 1, ActType_Relu6); CompareOutputData(out, corr_relu6, 8, 0.0001); } @@ -132,15 +133,15 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test2) { float no[] = {-8.646674, 0, -4.7133026, 0, -0.11849791, 0, -4.530405, 0, -5.419181, 0, 14.387108, 0, 2.8319538, 0, -8.511095, 0}; - PostConvFuncFp32C8(in, out, bias, 1, 8, 2, false, false); + PostConvFuncFp32C8(in, out, bias, 1, 8, 2, ActType_No); CompareOutputData(out, no, 16, 0.0001); float relu[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14.387108, 0, 2.8319538, 0, 0, 0}; - PostConvFuncFp32C8(in, out, bias, 1, 8, 2, true, false); + PostConvFuncFp32C8(in, out, bias, 1, 8, 2, ActType_Relu); CompareOutputData(out, relu, 16, 0.0001); float corr_relu6[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 2.8319538, 0, 0, 0}; - PostConvFuncFp32C8(in, out, bias, 1, 8, 2, false, true); + PostConvFuncFp32C8(in, out, bias, 1, 8, 2, ActType_Relu6); CompareOutputData(out, corr_relu6, 16, 0.0001); } @@ -159,7 +160,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test3) { float no[] = {-8.646674, -5.3524485, 8.56133, -4.7133026, 1.2270198, 17.954533, -0.11849791, -3.9182835, 11.90631, -4.530405, -0.47735345, -3.7422307, -5.419181, -0.14518678, -8.15199, 14.387108, 8.693133, 8.080041, 2.8319538, 7.177942, -4.409286, -8.511095, -5.110127, -4.992582}; - PostConvFuncFp32C8(in, out, bias, 3, 8, 3, false, false); + PostConvFuncFp32C8(in, out, bias, 3, 8, 3, ActType_No); CompareOutputData(out, no, 24, 0.0001); } @@ -177,12 +178,12 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test4) { float co32[] = {0, 0, 0, 0, 0, 1.2270198, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14.387108, 8.693133, 0, 0, 2.8319538, 7.177942, 0, 0, 0, 0, 0, 0}; - PostConvFuncFp32C8(in, out, bias, 2, 8, 4, true, false); + PostConvFuncFp32C8(in, out, bias, 2, 8, 4, ActType_Relu); CompareOutputData(out, co32, 32, 0.0001); float co32_relu6[] = {0, 0, 6, 0, 0, 1.2270198, 6, 6, 0, 0, 6, 0.3088621, 0, 0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0}; - PostConvFuncFp32C8(in, out, bias, 4, 8, 4, false, true); + PostConvFuncFp32C8(in, out, bias, 4, 8, 4, ActType_Relu6); CompareOutputData(out, co32_relu6, 32, 0.0001); } @@ -203,19 +204,19 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test5) { -0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027, -8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942, -4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402}; - PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, false); + PostConvFuncFp32C8(in, out, bias, 5, 8, 5, ActType_No); CompareOutputData(out, no, 40, 0.0001); float relu[] = {0, 0, 8.56133, 0, 0, 0, 1.2270198, 17.954533, 11.086085, 0, 0, 0, 11.90631, 0.3088621, 11.196218, 0, 0, 0, 0, 0, 0, 0, 0, 9.464027, 0, 14.387108, 8.693133, 8.080041, 0, 0, 2.8319538, 7.177942, 0, 12.194644, 0, 0, 0, 0, 0, 0}; - PostConvFuncFp32C8(in, out, bias, 5, 8, 5, true, false); + PostConvFuncFp32C8(in, out, bias, 5, 8, 5, ActType_Relu); CompareOutputData(out, relu, 40, 0.0001); float corr_relu6[] = {0, 0, 6, 0, 0, 0, 1.2270198, 6, 6, 0, 0, 0, 6, 0.3088621, 6, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 6, 6, 6, 0, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0, 0, 0}; - PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, true); + PostConvFuncFp32C8(in, out, bias, 5, 8, 5, ActType_Relu6); CompareOutputData(out, corr_relu6, 40, 0.0001); } @@ -229,13 +230,13 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test6) { float no_3[] = {-9.389655, -5.83877, 7.5724425, 0, 0, 0, -0.8614793, -4.404605, 10.917422, 0, 0, 0, -6.1621623, -0.6315082, -9.140878, 0, 0, 0, 2.0889723, 6.6916203, -5.3981733, 0, 0, 0}; - PostConvFuncFp32C8(in, out, bias, 3, 4, 6, false, false); + PostConvFuncFp32C8(in, out, bias, 3, 4, 6, ActType_No); CompareOutputData(out, no_3, 24, 0.0001); float no_6[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484}; - PostConvFuncFp32C8(in, out, bias, 6, 4, 6, false, false); + PostConvFuncFp32C8(in, out, bias, 6, 4, 6, ActType_No); CompareOutputData(out, no_6, 24, 0.0001); } @@ -251,7 +252,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test7) { -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469}; - PostConvFuncFp32C8(in, out, bias, 7, 4, 7, false, false); + PostConvFuncFp32C8(in, out, bias, 7, 4, 7, ActType_No); CompareOutputData(out, no, 28, 0.0001); } @@ -267,7 +268,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_2) { -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; - PostConvFuncFp32C8(in, out, bias, 16, 2, 16, false, false); + PostConvFuncFp32C8(in, out, bias, 16, 2, 16, ActType_No); CompareOutputData(out, no, 28, 0.0001); } @@ -291,7 +292,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_4) { -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; - PostConvFuncFp32C8(in, out, bias, 16, 4, 16, false, false); + PostConvFuncFp32C8(in, out, bias, 16, 4, 16, ActType_No); CompareOutputData(out, no, 64, 0.0001); } @@ -315,7 +316,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_8) { -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; - PostConvFuncFp32C8(in, out, bias, 8, 8, 8, false, false); + PostConvFuncFp32C8(in, out, bias, 8, 8, 8, ActType_No); CompareOutputData(out, no, 64, 0.0001); }