diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc index 6ab1b8190e..321a1d19bc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc @@ -64,16 +64,14 @@ int DeConvolutionCPUKernel::ReSize() { } int DeConvolutionCPUKernel::InitWeightBias() { + bias_data_ = malloc(UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float)); /* padded to C8NUM: the C8 post function loads bias eight floats per channel group */ + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "deconv malloc bias_data_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float)); /* bias stays all-zero when the op has no bias tensor */ if (in_tensors_.size() == 3) { - bias_data_ = malloc(UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(float)); - if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "deconv malloc bias_data_ error!"; - return RET_ERROR; - } - memset(bias_data_, 0, UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(float)); memcpy(bias_data_, in_tensors_[2]->Data(), conv_param_->output_channel_ * sizeof(float)); - } else { - bias_data_ = nullptr; } size_t weight_pack_size = conv_param_->input_channel_ * conv_param_->kernel_w_ * conv_param_->kernel_h_ * @@ -134,41 +132,21 @@ int DeConvFp32Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { return RET_OK; } -int DeConvFp32PostRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { - auto deconv = reinterpret_cast<DeConvolutionCPUKernel *>(cdata); - auto error_code = deconv->DoPostFunc(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "DeConvFp32PostRun error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; - } - return RET_OK; -} - int DeConvolutionCPUKernel::DoDeconv(int task_id) { int oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); - if (oc <= 0) { + int oc_res = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM); /* oc counts C8 blocks for this task, oc_res counts its real (unpadded) channels */ + if (oc
<= 0 || oc_res <= 0) { return RET_OK; } - MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, - tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No, - matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, matmul_param_->col_, false); + auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_; /* this task's slice of the matmul result, reused as the post-function input */ + MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, tmp_buffer, + nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, + matmul_param_->col_, false); - return RET_OK; -} - -int DeConvolutionCPUKernel::DoPostFunc(int task_id) { - int oc = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM); - if (oc <= 0) { - return RET_OK; - } - - float *bias = - (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM; - - DeConvPostFp32C8x8(tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, - pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, bias, - output_ptr_ + task_id * thread_stride_ * C8NUM, oc, conv_param_); + DeConvPostFp32C8x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, + reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM, + output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_); return RET_OK; } @@ -213,12 +191,6 @@ int DeConvolutionCPUKernel::Run() { MS_LOG(ERROR) << "deconv fp32 run error! error_code[" << error_code << "]"; return RET_ERROR; } - - error_code = LiteBackendParallelLaunch(DeConvFp32PostRun, this, thread_count_); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "deconv fp32 postrun error!
error_code[" << error_code << "]"; - return RET_ERROR; - } } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h index fb29ff9e5e..eafbb4c6a7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h @@ -43,7 +43,6 @@ class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel { public: int DoDeconv(int task_id); - int DoPostFunc(int task_id); private: int InitParam(); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/PostFuncBiasReluC8.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/PostFuncBiasReluC8.S new file mode 100644 index 0000000000..f07e05f87f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/PostFuncBiasReluC8.S @@ -0,0 +1,532 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global PostFuncBiasReluC8 +#ifndef __APPLE__ + .type PostFuncBiasReluC8, %function +#endif + +//void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, +// size_t plane_size, size_t stride, size_t relu_type); +// x0 dst x1 src x2 bias +// x3 oc8div x4 oc8mod x5 plane_size +// x6 stride (bytes) x7 relu_type (0 none, 1 relu, 2 relu6) + +// v0 ~ v15 value (NOTE(review): v8-v15 are callee-saved (low 64 bits) under AAPCS64 and are not saved here - confirm) +// v16 v17 bias data +// x24 x25 write loop tmp buf (NOTE(review): x19-x28 are callee-saved under AAPCS64 and are not saved here - confirm) +// v26 relu6 #6; v27 relu #0 +// w10 oc8 loop control +// w13 hw loop control + +PostFuncBiasReluC8: + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + mov w10, #0 + +Loop_C8: + cmp w10, w3 + beq Loop_C1 + mov x25, x0 + add x0, x0, #32 + // x0 now points at the next 8-channel group, so the oc8mod tail in Loop_C1 is written after the full groups + add w10, w10, #8 + mov w13, w5 + ld1 {v16.4s, v17.4s}, [x2], #32 + +Loop8x8: + cmp w13, #8 + blt Loop_4x8 + sub w13, w13, #8 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd
v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v17.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v17.4s + fadd v8.4s, v8.4s, v16.4s + fadd v9.4s, v9.4s, v17.4s + fadd v10.4s, v10.4s, v16.4s + fadd v11.4s, v11.4s, v17.4s + fadd v12.4s, v12.4s, v16.4s + fadd v13.4s, v13.4s, v17.4s + fadd v14.4s, v14.4s, v16.4s + fadd v15.4s, v15.4s, v17.4s + + cmp w7, #2 + beq Relu6_8x8 + cmp w7, #1 + beq Relu_8x8 + b Write_8x8 +Relu6_8x8: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s +Relu_8x8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s +Write_8x8: + st1 {v0.4s, v1.4s}, [x25], x6 + st1 {v2.4s, v3.4s}, [x25], x6 + st1 {v4.4s, v5.4s}, [x25], x6 + st1 {v6.4s, v7.4s}, [x25], x6 + st1 {v8.4s, v9.4s}, [x25], x6 + st1 {v10.4s, v11.4s}, [x25], x6 + st1 {v12.4s, v13.4s}, [x25], x6 + st1 {v14.4s, v15.4s}, [x25], x6 + b Loop8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v17.4s + 
fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v17.4s + + cmp w7, #2 + beq Relu6_4x8 + cmp w7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s +Relu_4x8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_4x8: + st1 {v0.4s, v1.4s}, [x25], x6 + st1 {v2.4s, v3.4s}, [x25], x6 + st1 {v4.4s, v5.4s}, [x25], x6 + st1 {v6.4s, v7.4s}, [x25], x6 + +Loop_1x8: + cmp w7, #2 + beq Relu6_1x8 + cmp w7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x25], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x25], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s, v1.4s}, [x25], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s, v17.4s}, [x2], #32 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp w7, #2 + beq Loop_C1_1_Relu6 + cmp w7, #1 + beq Loop_C1_1_Relu + b 
Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Write + +Loop_C1_2: + cmp w7, #2 + beq Loop_C1_2_Relu6 + cmp w7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Write + + +Loop_C1_3: + add x25, x0, #8 + cmp w7, #2 + beq Loop_C1_3_Relu6 + cmp w7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x25], x6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 
{v0.s}[2], [x25], x6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x25], x6 + b Loop_C1_3_Write + +Loop_C1_4: // tail of 4 channels: dispatch on relu_type in w7 + cmp w7, #2 + beq Loop_C1_4_Relu6 + cmp w7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Relu // stay in the ReLU-only loop so later planes are not clamped to 6 +Loop_C1_4_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Write + +Loop_C1_5: // tail of 5 channels: x0 writes lanes 0-3, x25 writes lane 4 + add x25, x0, #16 + cmp w7, #2 + beq Loop_C1_5_Relu6 + cmp w7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + str s1, [x25] + add x25, x25, x6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + str s1, [x25] + add x25, x25, x6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + str s1, [x25] + add x25, x25, x6 + b Loop_C1_5_Write + +Loop_C1_6: + add x25, x0, #16 + cmp w7, #2 + beq Loop_C1_6_Relu6 + cmp w7, #1 +
beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x25] + add x25, x25, x6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x25] + add x25, x25, x6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x25] + add x25, x25, x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x25, x0, #16 + add x24, x0, #24 + cmp w7, #2 + beq Loop_C1_7_Relu6 + cmp w7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x25] + add x25, x25, x6 + st1 {v1.s}[2], [x24], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x25] + add x25, x25, x6 + st1 {v1.s}[2], [x24], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, 
v1.s[1] + stp s1, s0, [x25] + add x25, x25, x6 + st1 {v1.s}[2], [x24], x6 + b Loop_C1_7_Write + +End: + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c index 2e77ce0df5..ecce94ca1d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c @@ -113,6 +113,15 @@ void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bi void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6) { +#ifndef ENABLE_ARM64 PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM); +#else + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = stride * sizeof(float); + size_t relu_type = is_relu ? 1 : 0; + relu_type = is_relu6 ? 
2 : relu_type; + PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#endif return; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h index 90b0d76215..c254219003 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h @@ -61,6 +61,8 @@ void C4Relu6(float *dst, const float *input, size_t oc, size_t plane_size, size_ void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); #endif #ifdef __cplusplus diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.c index 769802e186..5f281c2152 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.c @@ -33,24 +33,27 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in return; } -int DeConvFp32(const float *input, const float *weight, float *output, float *tmp_buffer, - StrassenMatMulParameter matmul_param) { - return StrassenMatmul(input, weight, output, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_buffer); -} - int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel, ConvParameter *conv_param) { /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ size_t input_plane = conv_param->input_w_ * conv_param->input_h_; size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; size_t output_plane = conv_param->output_w_ * conv_param->output_h_; - int 
oc8 = UP_DIV(output_channel, C8NUM); + int oc8 = UP_ROUND(output_channel, C8NUM); int in_plane8 = UP_ROUND(input_plane, C8NUM); + int src_iw_stride = C8NUM; + int src_ih_stride = conv_param->input_w_ * C8NUM; + int src_kw_stride = in_plane8 * C8NUM; + int src_kh_stride = in_plane8 * conv_param->kernel_w_ * C8NUM; + int dst_oh_stride = conv_param->output_w_ * C8NUM; + int dst_ow_stride = C8NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM; + int dst_kw_stride = conv_param->dilation_w_ * C8NUM; - for (int c = 0; c < oc8; c++) { - float *dst_ptr = tmp + c * output_plane * C8NUM; - const float *src_ptr = src + c * in_plane8 * kernel_plane * C8NUM; - memset(dst_ptr, 0, output_plane * C8NUM * sizeof(int32_t)); + for (int c = 0; c < oc8; c += C8NUM) { /* c is a channel index stepping one C8 block at a time */ + float *dst_ptr = tmp + c * output_plane; + const float *src_ptr = src + c * in_plane8 * kernel_plane; + memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float)); for (int ih = 0; ih < conv_param->input_h_; ih++) { for (int iw = 0; iw < conv_param->input_w_; iw++) { @@ -63,14 +66,31 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); for (int kh = kh_start; kh < kh_end; kh++) { for (int kw = kw_start; kw < kw_end; kw++) { - int src_index = ih * conv_param->input_w_ * C8NUM + iw * C8NUM + - kh * in_plane8 * conv_param->kernel_w_ * C8NUM + kw * in_plane8 * C8NUM; - int dst_index = oh * conv_param->output_w_ * C8NUM + ow * C8NUM + - kh * conv_param->dilation_h_ * conv_param->output_w_ * C8NUM + - kw * conv_param->dilation_w_ * C8NUM; + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + float *tmp_dst = dst_ptr + dst_index; + const float *tmp_src = src_ptr + src_index; /* read-only: keeps the const qualifier of src */ +#ifdef ENABLE_ARM64 + asm volatile( + "mov x0,
%[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s, v1.4s}, [x0] \n" + "ld1 {v2.4s, v3.4s}, [x1] \n" + + "fadd v0.4s, v0.4s, v2.4s \n" + "fadd v1.4s, v1.4s, v3.4s \n" + + "st1 {v0.4s, v1.4s}, [x1] \n" + + : + : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) + : "x0", "x1", "v0", "v1", "v2", "v3", "memory"); /* "memory": the asm reads src and reads/writes dst behind the compiler's back */ +#else for (int i = 0; i < C8NUM; i++) { - dst_ptr[dst_index + i] += src_ptr[src_index + i]; + tmp_dst[i] += tmp_src[i]; } +#endif } /*kw*/ } /*kh*/ } /*iw*/ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h index 8853917598..75d03f91ea 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h @@ -26,9 +26,6 @@ extern "C" { #endif void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane); -int DeConvFp32(const float *input, const float *weight, float *output, float *tmp_buffer, - StrassenMatMulParameter matmul_param); - int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel, int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param); int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel, diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc index 317654aa56..465cc28df3 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -370,35 +370,26 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) { conv1x1->Run(); CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001); - auto ptr = reinterpret_cast<float *>(outputs_[0]->Data()); - bool first = true; - for (int i = 0; i < total_size; i++) { - if (fabs(ptr[i] -
correct[i]) > 0.001 && first) { - printf("%d %f %f\n", i, ptr[i], correct[i]); - first = false; - } + /* running warm up */ + for (int i = 0; i < 0; i++) { + conv1x1->Run(); } - // /* running warm up */ - // for (int i = 0; i < 0; i++) { - // conv1x1->Run(); - // } - // - // /* running time cost */ - // int loop_count = 1; - // auto time_start = mindspore::lite::GetTimeUs(); - // for (int i = 0; i < loop_count; i++) { - // conv1x1->Run(); - // } - // auto time_end = mindspore::lite::GetTimeUs(); - // auto cost = time_end - time_start; - // uint64_t time_avg = cost / loop_count; - // printf("1x1 average time : %f ms\n", time_avg / 1000.0f); - // - // delete conv_param; - // delete conv1x1; - // for (auto t : inputs_) delete t; - // for (auto t : outputs_) delete t; - // free(correct); + /* running time cost */ + int loop_count = 1; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + conv1x1->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + uint64_t time_avg = cost / loop_count; + printf("1x1 average time : %f ms\n", time_avg / 1000.0f); + + delete conv_param; + delete conv1x1; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc index df7c0b19d3..a51901ac51 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc @@ -95,6 +95,99 @@ TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) { } TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -2.6300175, 0, 0, 0, + -5.456284, 0.7406984, 16.965645, 10.888806, -7.2690716, 0, 0, 0, + -0.8614793, -4.404605, 10.917422, 
0.11158327, 11.1863365, 0, 0, 0, + -5.2733865, -0.96367484, -4.731118, -7.576815, -3.4595785, 0, 0, 0, + -6.1621623, -0.6315082, -9.140878, 9.266748, -8.344107, 0, 0, 0, + 13.644127, 8.206812, 7.091153, -0.50162584, -3.792715, 0, 0, 0, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -7.0394287, 0, 0, 0, + -9.254076, -5.5964484, -5.981469, -0.51114964, -2.7693212, 0, 0, 0}; + float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0}; + float out[8] = {0}; + + float no[] = {-8.646674, -4.7133026, -0.11849791, -4.530405, -5.419181, 14.387108, 2.8319538, -8.511095}; + PostConvFuncFp32C8(in, out, bias, 1, 8, 1, false, false); + CompareOutputData(out, no, 8, 0.0001); + + float relu[] = {0, 0, 0, 0, 0, 14.387108, 2.8319538, 0}; + PostConvFuncFp32C8(in, out, bias, 1, 8, 1, true, false); + CompareOutputData(out, relu, 8, 0.0001); + + float corr_relu6[] = {0, 0, 0, 0, 0, 6, 2.8319538, 0}; + PostConvFuncFp32C8(in, out, bias, 1, 8, 1, false, true); + CompareOutputData(out, corr_relu6, 8, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test2) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -2.6300175, 0, 0, 0, + -5.456284, 0.7406984, 16.965645, 10.888806, -7.2690716, 0, 0, 0, + -0.8614793, -4.404605, 10.917422, 0.11158327, 11.1863365, 0, 0, 0, + -5.2733865, -0.96367484, -4.731118, -7.576815, -3.4595785, 0, 0, 0, + -6.1621623, -0.6315082, -9.140878, 9.266748, -8.344107, 0, 0, 0, + 13.644127, 8.206812, 7.091153, -0.50162584, -3.792715, 0, 0, 0, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -7.0394287, 0, 0, 0, + -9.254076, -5.5964484, -5.981469, -0.51114964, -2.7693212, 0, 0, 0}; + float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0}; + float out[16] = {0}; + + float no[] = {-8.646674, 0, -4.7133026, 0, -0.11849791, 0, -4.530405, 0, + -5.419181, 0, 14.387108, 0, 2.8319538, 0, -8.511095, 0}; + PostConvFuncFp32C8(in, out, bias, 1, 8, 2, false, false); + CompareOutputData(out, no, 16, 0.0001); + + float 
relu[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14.387108, 0, 2.8319538, 0, 0, 0}; + PostConvFuncFp32C8(in, out, bias, 1, 8, 2, true, false); + CompareOutputData(out, relu, 16, 0.0001); + + float corr_relu6[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 2.8319538, 0, 0, 0}; + PostConvFuncFp32C8(in, out, bias, 1, 8, 2, false, true); + CompareOutputData(out, corr_relu6, 16, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test3) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -2.6300175, 0, 0, 0, + -5.456284, 0.7406984, 16.965645, 10.888806, -7.2690716, 0, 0, 0, + -0.8614793, -4.404605, 10.917422, 0.11158327, 11.1863365, 0, 0, 0, + -5.2733865, -0.96367484, -4.731118, -7.576815, -3.4595785, 0, 0, 0, + -6.1621623, -0.6315082, -9.140878, 9.266748, -8.344107, 0, 0, 0, + 13.644127, 8.206812, 7.091153, -0.50162584, -3.792715, 0, 0, 0, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -7.0394287, 0, 0, 0, + -9.254076, -5.5964484, -5.981469, -0.51114964, -2.7693212, 0, 0, 0}; + float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0}; + float out[24] = {0}; + + float no[] = {-8.646674, -5.3524485, 8.56133, -4.7133026, 1.2270198, 17.954533, -0.11849791, -3.9182835, + 11.90631, -4.530405, -0.47735345, -3.7422307, -5.419181, -0.14518678, -8.15199, 14.387108, + 8.693133, 8.080041, 2.8319538, 7.177942, -4.409286, -8.511095, -5.110127, -4.992582}; + PostConvFuncFp32C8(in, out, bias, 3, 8, 3, false, false); + CompareOutputData(out, no, 24, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test4) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -2.6300175, 0, 0, 0, + -5.456284, 0.7406984, 16.965645, 10.888806, -7.2690716, 0, 0, 0, + -0.8614793, -4.404605, 10.917422, 0.11158327, 11.1863365, 0, 0, 0, + -5.2733865, -0.96367484, -4.731118, -7.576815, -3.4595785, 0, 0, 0, + -6.1621623, -0.6315082, -9.140878, 9.266748, -8.344107, 0, 0, 0, + 13.644127, 8.206812, 7.091153, -0.50162584, -3.792715, 0, 0, 0, + 2.0889723, 6.6916203, 
-5.3981733, 11.997365, -7.0394287, 0, 0, 0, + -9.254076, -5.5964484, -5.981469, -0.51114964, -2.7693212, 0, 0, 0}; + float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0}; + float out[32] = {0}; + + float co32[] = {0, 0, 0, 0, 0, 1.2270198, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 14.387108, 8.693133, 0, 0, 2.8319538, 7.177942, 0, 0, 0, 0, 0, 0}; + PostConvFuncFp32C8(in, out, bias, 2, 8, 4, true, false); + CompareOutputData(out, co32, 32, 0.0001); + + float co32_relu6[] = {0, 0, 6, 0, 0, 1.2270198, 6, 6, 0, 0, 6, 0.3088621, 0, 0, 0, 0, + 0, 0, 0, 6, 6, 6, 6, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0}; + PostConvFuncFp32C8(in, out, bias, 4, 8, 4, false, true); + CompareOutputData(out, co32_relu6, 32, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test5) { float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -2.6300175, 0, 0, 0, -5.456284, 0.7406984, 16.965645, 10.888806, -7.2690716, 0, 0, 0, -0.8614793, -4.404605, 10.917422, 0.11158327, 11.1863365, 0, 0, 0, @@ -125,14 +218,106 @@ TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) { 0, 0, 0, 6, 0, 6, 6, 6, 0, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0, 0, 0}; PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, true); CompareOutputData(out, corr_relu6, 40, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test6) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + float bias[] = {0, 0, 0, 0, 0, 0, 0, 0}; + float out[24] = {0}; + + float no_3[] = {-9.389655, -5.83877, 7.5724425, 0, 0, 0, -0.8614793, -4.404605, 10.917422, 0, 0, 0, + -6.1621623, -0.6315082, -9.140878, 0, 0, 0, 2.0889723, 6.6916203, -5.3981733, 0, 0, 0}; + PostConvFuncFp32C8(in, out, 
bias, 3, 4, 6, false, false); + CompareOutputData(out, no_3, 24, 0.0001); + + float no_6[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, -0.8614793, -4.404605, + 10.917422, 0.11158327, -5.2733865, -0.96367484, -6.1621623, -0.6315082, -9.140878, 9.266748, + 13.644127, 8.206812, 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484}; + PostConvFuncFp32C8(in, out, bias, 6, 4, 6, false, false); + CompareOutputData(out, no_6, 24, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test7) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + float bias[] = {0, 0, 0, 0, 0, 0, 0, 0}; + float out[28] = {0}; + + float no[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469}; + PostConvFuncFp32C8(in, out, bias, 7, 4, 7, false, false); + CompareOutputData(out, no, 28, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_2) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + float bias[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + float out[32] = {0}; /* the call below writes plane_size(2) x stride(16) = 32 floats */ + + float no[] = {-9.389655, -5.83877,
7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + PostConvFuncFp32C8(in, out, bias, 16, 2, 16, false, false); + CompareOutputData(out, no, 28, 0.0001); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_4) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964, + -9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + float bias[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + float out[64] = {0}; + + float no[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, 
-5.981469, -0.51114964, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + PostConvFuncFp32C8(in, out, bias, 16, 4, 16, false, false); + CompareOutputData(out, no, 64, 0.0001); +} - float nob_relu[] = {0, 0, 7.5724425, 0, 0, 0, 0.7406984, 16.965645, - 10.888806, 0, 0, 0, 10.917422, 0.11158327, 11.1863365, 0, - 0, 0, 0, 0, 0, 0, 0, 9.266748, - 0, 13.644127, 8.206812, 7.091153, 0, 0, 2.0889723, 6.6916203, - 0, 11.997365, 0, 0, 0, 0, 0, 0}; - PostConvFuncFp32C8(in, out, nullptr, 5, 8, 5, true, false); - CompareOutputData(out, nob_relu, 40, 0.0001); +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test8_8) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964, + -9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + float bias[] = {0, 0, 0, 0, 0, 0, 0, 0}; + float out[64] = {0}; + + float no[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964, + -9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, 
-5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964}; + PostConvFuncFp32C8(in, out, bias, 8, 8, 8, false, false); + CompareOutputData(out, no, 64, 0.0001); } int DeConvTestInit1(std::vector *inputs_, std::vector *outputs_,