Browse Source

!7436 [MS][LITE][Develop]fix fp32 deconv relu_type bug

Merge pull request !7436 from lixian/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
89e11f6b2b
6 changed files with 53 additions and 57 deletions
  1. +22
    -22
      mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC8.S
  2. +8
    -11
      mindspore/lite/nnacl/fp32/common_func.c
  3. +2
    -2
      mindspore/lite/nnacl/fp32/common_func.h
  4. +1
    -2
      mindspore/lite/nnacl/fp32/deconv.c
  5. +1
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc
  6. +19
    -18
      mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc

+ 22
- 22
mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC8.S View File

@@ -41,7 +41,7 @@ Loop_C8:
mov w13, w5
ld1 {v16.4s, v17.4s}, [x2], #32

Loop8x8:
Loop_8x8:
cmp w13, #8
blt Loop_4x8
sub w13, w13, #8
@@ -67,9 +67,9 @@ Loop8x8:
fadd v14.4s, v14.4s, v16.4s
fadd v15.4s, v15.4s, v17.4s

cmp w7, #3
cmp x7, #3
beq Relu6_8x8
cmp w7, #1
cmp x7, #1
beq Relu_8x8
b Write_8x8
Relu6_8x8:
@@ -115,7 +115,7 @@ Write_8x8:
st1 {v10.4s, v11.4s}, [x15], x6
st1 {v12.4s, v13.4s}, [x15], x6
st1 {v14.4s, v15.4s}, [x15], x6
b Loop8x8
b Loop_8x8

Loop_4x8:
cmp w13, #4
@@ -133,9 +133,9 @@ Loop_4x8:
fadd v6.4s, v6.4s, v16.4s
fadd v7.4s, v7.4s, v17.4s

cmp w7, #2
cmp x7, #3
beq Relu6_4x8
cmp w7, #1
cmp x7, #1
beq Relu_4x8
b Write_4x8
Relu6_4x8:
@@ -163,9 +163,9 @@ Write_4x8:
st1 {v6.4s, v7.4s}, [x15], x6

Loop_1x8:
cmp w7, #2
cmp x7, #3
beq Relu6_1x8
cmp w7, #1
cmp x7, #1
beq Relu_1x8
b Write_1x8
Relu6_1x8:
@@ -228,9 +228,9 @@ Loop_C1:
beq Loop_C1_7

Loop_C1_1:
cmp w7, #2
cmp x7, #3
beq Loop_C1_1_Relu6
cmp w7, #1
cmp x7, #1
beq Loop_C1_1_Relu
b Loop_C1_1_Write
Loop_C1_1_Relu6:
@@ -265,9 +265,9 @@ Loop_C1_1_Write:
b Loop_C1_1_Write

Loop_C1_2:
cmp w7, #2
cmp x7, #3
beq Loop_C1_2_Relu6
cmp w7, #1
cmp x7, #1
beq Loop_C1_2_Relu
b Loop_C1_2_Write
Loop_C1_2_Relu6:
@@ -307,9 +307,9 @@ Loop_C1_2_Write:

Loop_C1_3:
add x15, x0, #8
cmp w7, #2
cmp x7, #3
beq Loop_C1_3_Relu6
cmp w7, #1
cmp x7, #1
beq Loop_C1_3_Relu
b Loop_C1_3_Write
Loop_C1_3_Relu6:
@@ -350,9 +350,9 @@ Loop_C1_3_Write:
b Loop_C1_3_Write

Loop_C1_4:
cmp w7, #2
cmp x7, #3
beq Loop_C1_4_Relu6
cmp w7, #1
cmp x7, #1
beq Loop_C1_4_Relu
b Loop_C1_4_Write
Loop_C1_4_Relu6:
@@ -385,9 +385,9 @@ Loop_C1_4_Write:

Loop_C1_5:
add x15, x0, #16
cmp w7, #2
cmp x7, #3
beq Loop_C1_5_Relu6
cmp w7, #1
cmp x7, #1
beq Loop_C1_5_Relu
b Loop_C1_5_Write
Loop_C1_5_Relu6:
@@ -432,9 +432,9 @@ Loop_C1_5_Write:

Loop_C1_6:
add x15, x0, #16
cmp w7, #2
cmp x7, #3
beq Loop_C1_6_Relu6
cmp w7, #1
cmp x7, #1
beq Loop_C1_6_Relu
b Loop_C1_6_Write
Loop_C1_6_Relu6:
@@ -483,9 +483,9 @@ Loop_C1_6_Write:
Loop_C1_7:
add x15, x0, #16
add x14, x0, #24
cmp w7, #2
cmp x7, #3
beq Loop_C1_7_Relu6
cmp w7, #1
cmp x7, #1
beq Loop_C1_7_Relu
b Loop_C1_7_Write
Loop_C1_7_Relu6:


+ 8
- 11
mindspore/lite/nnacl/fp32/common_func.c View File

@@ -17,7 +17,7 @@
#include "nnacl/fp32/common_func.h"

void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t plane_stride, size_t oc_stride, bool is_relu, bool is_relu6, int size) {
size_t plane_size, size_t plane_stride, size_t oc_stride, ActType relu_type, int size) {
int oc_div = 0, oc_mod = 0;
for (int oc = 0; oc < output_channel; oc++) {
if (size != 0) {
@@ -33,8 +33,8 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p
if (bias_ptr != NULL) {
value = value + bias_ptr[oc];
}
value = (is_relu || is_relu6) ? (MSMAX(0.f, value)) : (value);
value = (is_relu6) ? (MSMIN(6.f, value)) : (value);
value = (relu_type == ActType_Relu || relu_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value);
value = (relu_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value);
out_ptr[dst_index] = value;
}
}
@@ -42,25 +42,22 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p
}

void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6) {
size_t plane_size, size_t stride, size_t relu_type) {
#ifndef ENABLE_ARM
PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, is_relu, is_relu6,
C8NUM);
PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM);
#else
size_t oc8mod = output_channel % C8NUM;
size_t oc8div = output_channel - oc8mod;
size_t stride_size = stride * sizeof(float);
size_t relu_type = is_relu ? 1 : 0;
relu_type = is_relu6 ? 3 : relu_type;
PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type);
#endif
return;
}

void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t plane_stride, bool is_relu, bool is_relu6) {
PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, is_relu,
is_relu6, C4NUM);
size_t plane_size, size_t plane_stride, size_t relu_type) {
PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type,
C4NUM);
return;
}



+ 2
- 2
mindspore/lite/nnacl/fp32/common_func.h View File

@@ -28,9 +28,9 @@ extern "C" {
#endif

void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
size_t plane_size, size_t stride, size_t relu_type);
void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t plane_stride, bool is_relu, bool is_relu6);
size_t plane_size, size_t plane_stride, size_t relu_type);

void WinogradMatrixProductLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length);
void WinogradMatrixProductRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length);


+ 1
- 2
mindspore/lite/nnacl/fp32/deconv.c View File

@@ -103,7 +103,6 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds
} /*ih*/
} /*oc8*/

PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->act_type_);
return;
}

+ 1
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc View File

@@ -334,8 +334,7 @@ int DeConvolutionWinogradCPUKernel::DeDeconvPost(int task_id) {
PostConvFuncFp32C4(nc4hw4_output_ + task_id * thread_stride_hw_ * C4NUM,
nhwc_output_ + task_id * thread_stride_hw_ * conv_param_->output_channel_,
reinterpret_cast<float *>(bias_data_), conv_param_->output_channel_, current_plane,
deconv_param_->output_plane_, conv_param_->act_type_ == ActType_Relu,
conv_param_->act_type_ == ActType_Relu6);
deconv_param_->output_plane_, conv_param_->act_type_);
return RET_OK;
}



+ 19
- 18
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc View File

@@ -20,6 +20,7 @@
#include "src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h"
#include "mindspore/lite/nnacl/fp32/deconv.h"
#include "mindspore/lite/nnacl/op_base.h"

namespace mindspore {
class TestDeConvolutionFp32 : public mindspore::CommonTest {
@@ -106,15 +107,15 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) {
float out[8] = {0};

float no[] = {-8.646674, -4.7133026, -0.11849791, -4.530405, -5.419181, 14.387108, 2.8319538, -8.511095};
PostConvFuncFp32C8(in, out, bias, 1, 8, 1, false, false);
PostConvFuncFp32C8(in, out, bias, 1, 8, 1, ActType_No);
CompareOutputData(out, no, 8, 0.0001);

float relu[] = {0, 0, 0, 0, 0, 14.387108, 2.8319538, 0};
PostConvFuncFp32C8(in, out, bias, 1, 8, 1, true, false);
PostConvFuncFp32C8(in, out, bias, 1, 8, 1, ActType_Relu);
CompareOutputData(out, relu, 8, 0.0001);

float corr_relu6[] = {0, 0, 0, 0, 0, 6, 2.8319538, 0};
PostConvFuncFp32C8(in, out, bias, 1, 8, 1, false, true);
PostConvFuncFp32C8(in, out, bias, 1, 8, 1, ActType_Relu6);
CompareOutputData(out, corr_relu6, 8, 0.0001);
}

@@ -132,15 +133,15 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test2) {

float no[] = {-8.646674, 0, -4.7133026, 0, -0.11849791, 0, -4.530405, 0,
-5.419181, 0, 14.387108, 0, 2.8319538, 0, -8.511095, 0};
PostConvFuncFp32C8(in, out, bias, 1, 8, 2, false, false);
PostConvFuncFp32C8(in, out, bias, 1, 8, 2, ActType_No);
CompareOutputData(out, no, 16, 0.0001);

float relu[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14.387108, 0, 2.8319538, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 1, 8, 2, true, false);
PostConvFuncFp32C8(in, out, bias, 1, 8, 2, ActType_Relu);
CompareOutputData(out, relu, 16, 0.0001);

float corr_relu6[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 2.8319538, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 1, 8, 2, false, true);
PostConvFuncFp32C8(in, out, bias, 1, 8, 2, ActType_Relu6);
CompareOutputData(out, corr_relu6, 16, 0.0001);
}

@@ -159,7 +160,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test3) {
float no[] = {-8.646674, -5.3524485, 8.56133, -4.7133026, 1.2270198, 17.954533, -0.11849791, -3.9182835,
11.90631, -4.530405, -0.47735345, -3.7422307, -5.419181, -0.14518678, -8.15199, 14.387108,
8.693133, 8.080041, 2.8319538, 7.177942, -4.409286, -8.511095, -5.110127, -4.992582};
PostConvFuncFp32C8(in, out, bias, 3, 8, 3, false, false);
PostConvFuncFp32C8(in, out, bias, 3, 8, 3, ActType_No);
CompareOutputData(out, no, 24, 0.0001);
}

@@ -177,12 +178,12 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test4) {

float co32[] = {0, 0, 0, 0, 0, 1.2270198, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 14.387108, 8.693133, 0, 0, 2.8319538, 7.177942, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 2, 8, 4, true, false);
PostConvFuncFp32C8(in, out, bias, 2, 8, 4, ActType_Relu);
CompareOutputData(out, co32, 32, 0.0001);

float co32_relu6[] = {0, 0, 6, 0, 0, 1.2270198, 6, 6, 0, 0, 6, 0.3088621, 0, 0, 0, 0,
0, 0, 0, 6, 6, 6, 6, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 4, 8, 4, false, true);
PostConvFuncFp32C8(in, out, bias, 4, 8, 4, ActType_Relu6);
CompareOutputData(out, co32_relu6, 32, 0.0001);
}

@@ -203,19 +204,19 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test5) {
-0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027,
-8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942,
-4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402};
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, false);
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, ActType_No);
CompareOutputData(out, no, 40, 0.0001);

float relu[] = {0, 0, 8.56133, 0, 0, 0, 1.2270198, 17.954533, 11.086085, 0,
0, 0, 11.90631, 0.3088621, 11.196218, 0, 0, 0, 0, 0,
0, 0, 0, 9.464027, 0, 14.387108, 8.693133, 8.080041, 0, 0,
2.8319538, 7.177942, 0, 12.194644, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, true, false);
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, ActType_Relu);
CompareOutputData(out, relu, 40, 0.0001);

float corr_relu6[] = {0, 0, 6, 0, 0, 0, 1.2270198, 6, 6, 0, 0, 0, 6, 0.3088621, 6, 0, 0, 0, 0, 0,
0, 0, 0, 6, 0, 6, 6, 6, 0, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, true);
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, ActType_Relu6);
CompareOutputData(out, corr_relu6, 40, 0.0001);
}

@@ -229,13 +230,13 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test6) {

float no_3[] = {-9.389655, -5.83877, 7.5724425, 0, 0, 0, -0.8614793, -4.404605, 10.917422, 0, 0, 0,
-6.1621623, -0.6315082, -9.140878, 0, 0, 0, 2.0889723, 6.6916203, -5.3981733, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 3, 4, 6, false, false);
PostConvFuncFp32C8(in, out, bias, 3, 4, 6, ActType_No);
CompareOutputData(out, no_3, 24, 0.0001);

float no_6[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, -0.8614793, -4.404605,
10.917422, 0.11158327, -5.2733865, -0.96367484, -6.1621623, -0.6315082, -9.140878, 9.266748,
13.644127, 8.206812, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484};
PostConvFuncFp32C8(in, out, bias, 6, 4, 6, false, false);
PostConvFuncFp32C8(in, out, bias, 6, 4, 6, ActType_No);
CompareOutputData(out, no_6, 24, 0.0001);
}

@@ -251,7 +252,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test7) {
-0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118,
-6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153,
2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469};
PostConvFuncFp32C8(in, out, bias, 7, 4, 7, false, false);
PostConvFuncFp32C8(in, out, bias, 7, 4, 7, ActType_No);
CompareOutputData(out, no, 28, 0.0001);
}

@@ -267,7 +268,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_2) {
-6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584,
-0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815,
2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964};
PostConvFuncFp32C8(in, out, bias, 16, 2, 16, false, false);
PostConvFuncFp32C8(in, out, bias, 16, 2, 16, ActType_No);
CompareOutputData(out, no, 28, 0.0001);
}

@@ -291,7 +292,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_4) {
-6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584,
2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964,
2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964};
PostConvFuncFp32C8(in, out, bias, 16, 4, 16, false, false);
PostConvFuncFp32C8(in, out, bias, 16, 4, 16, ActType_No);
CompareOutputData(out, no, 64, 0.0001);
}

@@ -315,7 +316,7 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_8) {
-0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815,
-6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584,
2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964};
PostConvFuncFp32C8(in, out, bias, 8, 8, 8, false, false);
PostConvFuncFp32C8(in, out, bias, 8, 8, 8, ActType_No);
CompareOutputData(out, no, 64, 0.0001);
}



Loading…
Cancel
Save