| @@ -15,9 +15,7 @@ | |||
| */ | |||
| #include "src/runtime/kernel/arm/int8/deconvolution_int8.h" | |||
| #include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/nnacl/optimized_kernel.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| @@ -89,9 +87,8 @@ int DeConvInt8CPUKernel::Init() { | |||
| } | |||
| void DeConvInt8CPUKernel::CheckSupportOptimize() { | |||
| matmul_func_ = nullptr; | |||
| support_optimize_ = true; | |||
| matmul_func_ = MatMulInt8_16x4; | |||
| #ifdef ENABLE_ARM64 | |||
| void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; | |||
| if (optimize_op_handler != nullptr) { | |||
| @@ -102,12 +99,15 @@ void DeConvInt8CPUKernel::CheckSupportOptimize() { | |||
| MS_LOG(ERROR) << "load matmul func failed! " << dlopen_error << "."; | |||
| support_optimize_ = false; | |||
| matmul_func_ = nullptr; | |||
| } else { | |||
| support_optimize_ = true; | |||
| } | |||
| } else { | |||
| support_optimize_ = false; | |||
| matmul_func_ = nullptr; | |||
| } | |||
| #endif | |||
| return; | |||
| } | |||
| int DeConvInt8CPUKernel::InitParam() { | |||
| @@ -120,6 +120,7 @@ int DeConvInt8CPUKernel::InitParam() { | |||
| matmul_param_->deep_ = conv_param_->input_channel_; | |||
| matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; | |||
| /* optimize normal -> same data layout */ | |||
| input_trans_func_ = RowMajor2Row16x4MajorInt8; | |||
| size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM); | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, oc4); | |||
| @@ -0,0 +1,246 @@ | |||
| #ifdef __aarch64__ | |||
| .text | |||
| .align 5 | |||
| //.p2align 5,,15 | |||
| .global PostFuncInt8C4Neon64 | |||
| #ifndef __APPLE__ | |||
| .type PostFuncInt8C4Neon64, %function | |||
| #endif | |||
| //void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, | |||
| // size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, | |||
| // int32_t zp, int32_t mini, int32_t maxi); | |||
| // x0 in | |||
| // x1 bias | |||
| // x2 out | |||
| // x3 oc4div | |||
| // x4 oc4res | |||
| // x5 plane | |||
| // x6 stride | |||
| // x7 multiplier | |||
| // x8 left_shift | |||
| // x9 right_shift | |||
| // x10 zp | |||
| // x11 mini | |||
| // x12 maxi | |||
| // v0 ~ v15 value | |||
| // x24 x25 write loop tmp buf | |||
| // v16 bias data | |||
| // v26 multiplier | |||
| // v27 left_shift | |||
| // v28 right_shift | |||
| // v29 zp | |||
| // v30 min | |||
| // v31 max | |||
| // w15 oc4 loop control | |||
| // w16 hw loop control | |||
| PostFuncInt8C4Neon64: | |||
| ldr w8, [sp] | |||
| ldr w9, [sp, #8] | |||
| ldr w10, [sp, #16] | |||
| ldr w11, [sp, #24] | |||
| ldr w12, [sp, #32] | |||
| ldr w13, [sp, #40] | |||
| dup v26.4s, w7 | |||
| dup v27.4s, w8 | |||
| dup v28.4s, w9 | |||
| dup v29.4s, w10 | |||
| dup v30.4s, w11 | |||
| dup v31.4s, w12 | |||
| mov w15, #0 | |||
| Loop_C4: | |||
| cmp w15, w3 | |||
| beq Loop_C1 | |||
| mov x25, #4 | |||
| mul x24, x15, x25 | |||
| add x25, x2, x24 | |||
| add w15, w15, #4 | |||
| mov w16, w5 | |||
| ld1 {v16.4s}, [x1], #16 | |||
| Loop_4x4: | |||
| cmp w16, #4 | |||
| blt Loop_1x4 | |||
| sub w16, w16, #4 | |||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | |||
| add v0.4s, v0.4s, v16.4s | |||
| add v1.4s, v1.4s, v16.4s | |||
| add v2.4s, v2.4s, v16.4s | |||
| add v3.4s, v3.4s, v16.4s | |||
| sqshl v0.4s, v0.4s, v27.4s | |||
| sqshl v1.4s, v1.4s, v27.4s | |||
| sqshl v2.4s, v2.4s, v27.4s | |||
| sqshl v3.4s, v3.4s, v27.4s | |||
| sqrdmulh v0.4s, v0.4s, v26.4s | |||
| sqrdmulh v1.4s, v1.4s, v26.4s | |||
| sqrdmulh v2.4s, v2.4s, v26.4s | |||
| sqrdmulh v3.4s, v3.4s, v26.4s | |||
| and v4.16b, v28.16b, v0.16b | |||
| and v5.16b, v28.16b, v1.16b | |||
| and v6.16b, v28.16b, v2.16b | |||
| and v7.16b, v28.16b, v3.16b | |||
| sshr v4.4s, v4.4s, #31 | |||
| sshr v5.4s, v5.4s, #31 | |||
| sshr v6.4s, v6.4s, #31 | |||
| sshr v7.4s, v7.4s, #31 | |||
| sqadd v0.4s, v0.4s, v4.4s | |||
| sqadd v1.4s, v1.4s, v5.4s | |||
| sqadd v2.4s, v2.4s, v6.4s | |||
| sqadd v3.4s, v3.4s, v7.4s | |||
| srshl v0.4s, v0.4s, v28.4s | |||
| srshl v1.4s, v1.4s, v28.4s | |||
| srshl v2.4s, v2.4s, v28.4s | |||
| srshl v3.4s, v3.4s, v28.4s | |||
| add v0.4s, v0.4s, v29.4s | |||
| add v1.4s, v1.4s, v29.4s | |||
| add v2.4s, v2.4s, v29.4s | |||
| add v3.4s, v3.4s, v29.4s | |||
| smax v0.4s, v0.4s, v30.4s | |||
| smax v1.4s, v1.4s, v30.4s | |||
| smax v2.4s, v2.4s, v30.4s | |||
| smax v3.4s, v3.4s, v30.4s | |||
| smin v0.4s, v0.4s, v31.4s | |||
| smin v1.4s, v1.4s, v31.4s | |||
| smin v2.4s, v2.4s, v31.4s | |||
| smin v3.4s, v3.4s, v31.4s | |||
| sqxtn v4.4h, v0.4s | |||
| sqxtn v5.4h, v1.4s | |||
| sqxtn v6.4h, v2.4s | |||
| sqxtn v7.4h, v3.4s | |||
| sqxtn v0.8b, v4.8h | |||
| sqxtn v1.8b, v5.8h | |||
| sqxtn v2.8b, v6.8h | |||
| sqxtn v3.8b, v7.8h | |||
| st1 {v0.s}[0], [x2], x6 | |||
| st1 {v1.s}[0], [x2], x6 | |||
| st1 {v2.s}[0], [x2], x6 | |||
| st1 {v3.s}[0], [x2], x6 | |||
| b Loop_4x4 | |||
| Loop_1x4: | |||
| cmp w16, #0 | |||
| beq Loop_C4 | |||
| sub w16, w16, #1 | |||
| ld1 {v0.4s}, [x0], #16 | |||
| add v0.4s, v0.4s, v16.4s | |||
| sqshl v0.4s, v0.4s, v27.4s | |||
| sqrdmulh v0.4s, v0.4s, v26.4s | |||
| and v2.16b, v28.16b, v0.16b | |||
| sshr v2.4s, v2.4s, #31 | |||
| sqadd v0.4s, v0.4s, v2.4s | |||
| srshl v0.4s, v0.4s, v28.4s | |||
| add v0.4s, v0.4s, v29.4s | |||
| smax v0.4s, v0.4s, v30.4s | |||
| smin v0.4s, v0.4s, v31.4s | |||
| sqxtn v1.4h, v0.4s | |||
| sqxtn v0.8b, v1.8h | |||
| st1 {v0.s}[0], [x2], x6 | |||
| b Loop_1x4 | |||
| Loop_C1: | |||
| cmp x4, #0 | |||
| beq End | |||
| mov w16, w5 | |||
| ld1 {v16.4s}, [x1], #16 | |||
| mov x25, #4 | |||
| mul x24, x15, x25 | |||
| add x25, x2, x24 | |||
| add x24, x25, #2 | |||
| cmp x4, #1 | |||
| beq Loop_C1_1 | |||
| cmp x4, #2 | |||
| beq Loop_C1_2 | |||
| cmp x4, #3 | |||
| beq Loop_C1_3 | |||
| Loop_C1_1: | |||
| cmp w16, #0 | |||
| beq End | |||
| sub w16, w16, #1 | |||
| ld1 {v0.4s}, [x0], #16 | |||
| add v0.4s, v0.4s, v16.4s | |||
| sqshl v0.4s, v0.4s, v27.4s | |||
| sqrdmulh v0.4s, v0.4s, v26.4s | |||
| and v2.16b, v28.16b, v0.16b | |||
| sshr v2.4s, v2.4s, #31 | |||
| sqadd v0.4s, v0.4s, v2.4s | |||
| srshl v0.4s, v0.4s, v28.4s | |||
| add v0.4s, v0.4s, v29.4s | |||
| smax v0.4s, v0.4s, v30.4s | |||
| smin v0.4s, v0.4s, v31.4s | |||
| sqxtn v1.4h, v0.4s | |||
| sqxtn v0.8b, v1.8h | |||
| st1 {v0.b}[0], [x25], x6 | |||
| b Loop_C1_1 | |||
| Loop_C1_2: | |||
| cmp w16, #0 | |||
| beq End | |||
| sub w16, w16, #1 | |||
| ld1 {v0.4s}, [x0], #16 | |||
| add v0.4s, v0.4s, v16.4s | |||
| sqshl v0.4s, v0.4s, v27.4s | |||
| sqrdmulh v0.4s, v0.4s, v26.4s | |||
| and v2.16b, v28.16b, v0.16b | |||
| sshr v2.4s, v2.4s, #31 | |||
| sqadd v0.4s, v0.4s, v2.4s | |||
| srshl v0.4s, v0.4s, v28.4s | |||
| add v0.4s, v0.4s, v29.4s | |||
| smax v0.4s, v0.4s, v30.4s | |||
| smin v0.4s, v0.4s, v31.4s | |||
| sqxtn v1.4h, v0.4s | |||
| sqxtn v0.8b, v1.8h | |||
| st1 {v0.h}[0], [x25], x6 | |||
| b Loop_C1_2 | |||
| Loop_C1_3: | |||
| cmp w16, #0 | |||
| beq End | |||
| sub w16, w16, #1 | |||
| ld1 {v0.4s}, [x0], #16 | |||
| add v0.4s, v0.4s, v16.4s | |||
| sqshl v0.4s, v0.4s, v27.4s | |||
| sqrdmulh v0.4s, v0.4s, v26.4s | |||
| and v2.16b, v28.16b, v0.16b | |||
| sshr v2.4s, v2.4s, #31 | |||
| sqadd v0.4s, v0.4s, v2.4s | |||
| srshl v0.4s, v0.4s, v28.4s | |||
| add v0.4s, v0.4s, v29.4s | |||
| smax v0.4s, v0.4s, v30.4s | |||
| smin v0.4s, v0.4s, v31.4s | |||
| sqxtn v1.4h, v0.4s | |||
| sqxtn v0.8b, v1.8h | |||
| st1 {v0.h}[0], [x25], x6 | |||
| st1 {v0.b}[2], [x24], x6 | |||
| b Loop_C1_3 | |||
| End: | |||
| ret | |||
| #endif | |||
| @@ -15,9 +15,10 @@ | |||
| */ | |||
| #include "nnacl/int8/common_func.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane, | |||
| size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int8_t mini, int8_t maxi, | |||
| size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int size) { | |||
| if (size == 0) { | |||
| return; | |||
| @@ -40,18 +41,26 @@ void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, s | |||
| return; | |||
| } | |||
| void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) { | |||
| void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi) { | |||
| /* ((int32_t)row8x8-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ | |||
| PostConvFuncCommInt8(in, out, bias, oc, plane, oc, UP_ROUND(plane, C8NUM) * C8NUM, multiplier, mini, maxi, left_shift, | |||
| right_shift, zp, C8NUM); | |||
| return; | |||
| } | |||
| void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) { | |||
| /* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ | |||
| void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, | |||
| int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, | |||
| int32_t maxi) { | |||
| /* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ | |||
| #ifndef ENABLE_ARM64 | |||
| PostConvFuncCommInt8(in, out, bias, oc, plane, stride, UP_ROUND(plane, C4NUM) * C4NUM, multiplier, mini, maxi, | |||
| left_shift, right_shift, zp, C4NUM); | |||
| #else | |||
| size_t oc4div = oc / C4NUM * C4NUM; | |||
| size_t oc4res = oc % C4NUM; | |||
| PostFuncInt8C4Neon64(in, bias, out, oc4div, oc4res, plane, stride * sizeof(int8_t), multiplier, left_shift, | |||
| right_shift, zp, mini, maxi); | |||
| #endif | |||
| return; | |||
| } | |||
| @@ -27,30 +27,21 @@ | |||
| extern "C" { | |||
| #endif | |||
| void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi); | |||
| void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi); | |||
| #ifdef ENABLE_ARM | |||
| void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi); | |||
| void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, | |||
| int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, | |||
| int32_t maxi); | |||
| #ifdef ENABLE_ARM64 | |||
| void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, | |||
| size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, | |||
| int32_t zp, int32_t mini, int32_t maxi); | |||
| void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, | |||
| size_t oc4, size_t offset); | |||
| #ifdef ENABLE_ARM64 | |||
| void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, | |||
| size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, | |||
| size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, | |||
| size_t shift_after); | |||
| // #elif defined(ENABLE_ARM32) | |||
| // void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, | |||
| // size_t ksize, | |||
| // size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, | |||
| // size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, | |||
| // size_t shift_after); | |||
| #endif | |||
| #endif | |||
| #ifdef ENABLE_ARM | |||
| void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, | |||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | |||
| size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); | |||
| @@ -136,60 +136,109 @@ int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8 | |||
| void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, | |||
| bool support_optimize_) { | |||
| if (support_optimize_) { | |||
| int ic16 = UP_ROUND(input_channel, C16NUM); | |||
| int oc4 = UP_ROUND(output_channel, C4NUM); | |||
| for (int ic = 0; ic < input_channel; ic++) { | |||
| int ic16div = ic / C16NUM, ic16mod = ic % C16NUM; | |||
| for (int oc = 0; oc < output_channel; oc++) { | |||
| int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; | |||
| for (int hw = 0; hw < plane; hw++) { | |||
| int src_index = ic * output_channel * plane + hw * output_channel + oc; | |||
| int dst_index = | |||
| hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod; | |||
| dst[dst_index] = src[src_index]; | |||
| } | |||
| /* optimize normal -> same layout */ | |||
| int ic16 = UP_ROUND(input_channel, C16NUM); | |||
| int oc4 = UP_ROUND(output_channel, C4NUM); | |||
| for (int ic = 0; ic < input_channel; ic++) { | |||
| int ic16div = ic / C16NUM, ic16mod = ic % C16NUM; | |||
| for (int oc = 0; oc < output_channel; oc++) { | |||
| int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; | |||
| for (int hw = 0; hw < plane; hw++) { | |||
| int src_index = ic * output_channel * plane + hw * output_channel + oc; | |||
| int dst_index = hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod; | |||
| dst[dst_index] = src[src_index]; | |||
| } | |||
| } | |||
| } else { | |||
| /* normal int8 deconv */ | |||
| } | |||
| return; | |||
| } | |||
| void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4, | |||
| bool suppport_opt) { | |||
| if (suppport_opt) { | |||
| for (int c = 0; c < col4; c++) { | |||
| int c4div = c / C4NUM, c4mod = c % C4NUM; | |||
| int32_t value = 0; | |||
| for (int r = 0; r < deep16; r++) { | |||
| int r16div = r / 16, r16mod = r % 16; | |||
| int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod; | |||
| value += weight[src_index]; | |||
| } | |||
| weight_sum[c] = filter_zp * input_zp * deep16 - value * input_zp; | |||
| /* optimize normal -> same layout */ | |||
| for (int c = 0; c < col4; c++) { | |||
| int c4div = c / C4NUM, c4mod = c % C4NUM; | |||
| int32_t value = 0; | |||
| for (int r = 0; r < deep16; r++) { | |||
| int r16div = r / C16NUM, r16mod = r % C16NUM; | |||
| int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod; | |||
| value += weight[src_index]; | |||
| } | |||
| } else { | |||
| /* normal int8 deconv */ | |||
| weight_sum[c] = filter_zp * input_zp * deep16 - value * input_zp; | |||
| } | |||
| return; | |||
| } | |||
| void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt) { | |||
| if (suppport_opt) { | |||
| for (int r = 0; r < row4; r++) { | |||
| int32_t tmp_value = 0; | |||
| for (int c = 0; c < col16; c++) { | |||
| int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; | |||
| int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; | |||
| tmp_value += src[src_index]; | |||
| } | |||
| dst[r] = tmp_value * filter_zp; | |||
| void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, | |||
| bool suppport_opt) { | |||
| /* optimize normal -> same layout */ | |||
| #ifdef ENABLE_ARM64 | |||
| asm volatile( | |||
| "mov x10, %[src] \n" | |||
| "mov x11, %[dst] \n" | |||
| "dup v15.4s, %w[filter_zp] \n" | |||
| "mov x0, #0 \n" | |||
| "1: \n" | |||
| "cmp x0, %[row4] \n" | |||
| "beq 4f \n" | |||
| "add x0, x0, #4\n" | |||
| "dup v10.4s, wzr \n" | |||
| "mov x2, #0 \n" | |||
| "2: \n" | |||
| "cmp x2, %[col16] \n" | |||
| "beq 3f \n" | |||
| "add x2, x2, #16\n" | |||
| "ld1 {v0.16b}, [x10], #16\n" | |||
| "ld1 {v1.16b}, [x10], #16\n" | |||
| "ld1 {v2.16b}, [x10], #16\n" | |||
| "ld1 {v3.16b}, [x10], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v6.8h, v2.16b \n" | |||
| "saddlp v7.8h, v3.16b \n" | |||
| "saddlp v0.4S, v4.8h \n" | |||
| "saddlp v1.4S, v5.8h \n" | |||
| "saddlp v2.4S, v6.8h \n" | |||
| "saddlp v3.4S, v7.8h \n" | |||
| "addv s4, v0.4S \n" | |||
| "addv s5, v1.4S \n" | |||
| "addv s6, v2.4S \n" | |||
| "addv s7, v3.4S \n" | |||
| "mov v0.s[0], v4.s[0] \n" | |||
| "mov v0.s[1], v5.s[0] \n" | |||
| "mov v0.s[2], v6.s[0] \n" | |||
| "mov v0.s[3], v7.s[0] \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "b 2b\n" | |||
| "3: \n" | |||
| "mul v10.4s, v10.4s, v15.4s \n" | |||
| "st1 {v10.4s}, [x11], #16 \n" | |||
| "beq 1b \n" | |||
| "4: \n" | |||
| : | |||
| : [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp) | |||
| : "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15"); | |||
| #else | |||
| for (int r = 0; r < row4; r++) { | |||
| int32_t tmp_value = 0; | |||
| for (int c = 0; c < col16; c++) { | |||
| int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; | |||
| int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; | |||
| tmp_value += src[src_index]; | |||
| } | |||
| } else { | |||
| /* normal int8 deconv */ | |||
| } | |||
| #endif | |||
| return; | |||
| } | |||
| @@ -199,18 +248,14 @@ int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32 | |||
| if (matmul_func != NULL) { | |||
| matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); | |||
| } else { | |||
| /* normal int8 deconv */ | |||
| MatMulInt8_16x4(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, | |||
| ConvParameter *conv_param, bool support_optimize) { | |||
| int error_code = NNACL_OK; | |||
| if (support_optimize) { | |||
| error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param); | |||
| } else { | |||
| /* normal int8 deconv post */ | |||
| } | |||
| /* optimize normal -> same layout (C4) */ | |||
| int error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param); | |||
| return error_code; | |||
| } | |||
| @@ -29,7 +29,8 @@ extern "C" { | |||
| #endif | |||
| void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4, | |||
| bool suppport_opt); | |||
| void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt); | |||
| void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, | |||
| bool suppport_opt); | |||
| void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, | |||
| bool support_optimize_); | |||
| @@ -28,18 +28,66 @@ void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) | |||
| } | |||
| } | |||
| void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stride) { | |||
| for (int r = 0; r < row; r++) { | |||
| int8_t *src_r = src + r * stride; | |||
| int8_t *dst_r = dst + r * C16NUM; | |||
| memcpy(dst_r, src_r, col * sizeof(int8_t)); | |||
| } | |||
| return; | |||
| } | |||
| void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { | |||
| /* Row-major to row16x4-major (block row-major) */ | |||
| int col16 = UP_ROUND(col, C16NUM); | |||
| for (int r = 0; r < row; r++) { | |||
| int r4div = r / C4NUM; | |||
| int r4mod = r % C4NUM; | |||
| for (int c = 0; c < col; c++) { | |||
| int c16div = c / C16NUM; | |||
| int c16mod = c % C16NUM; | |||
| int src_index = r * col + c; | |||
| int dst_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; | |||
| ((int8_t *)dst_ptr)[dst_index] = ((int8_t *)src_ptr)[src_index]; | |||
| size_t row_4div = row / C4NUM * C4NUM; | |||
| size_t row_4res = row - row_4div; | |||
| size_t col_16div = col / C16NUM * C16NUM; | |||
| size_t col_16res = col - col_16div; | |||
| int8_t *src_r = (int8_t *)src_ptr; | |||
| int8_t *dst_r = (int8_t *)dst_ptr; | |||
| for (int ri = 0; ri < row_4div; ri += C4NUM) { | |||
| for (int ci = 0; ci < col_16div; ci += C16NUM) { | |||
| #ifdef ENABLE_ARM64 | |||
| int8_t *src_c = src_r + ci; | |||
| int8_t *dst_c = dst_r + ci * C4NUM; | |||
| asm volatile( | |||
| "mov x10, %[src_c] \n" | |||
| "mov x11, %[dst_c] \n" | |||
| "ld1 {v0.16b}, [x10], %[col]\n" | |||
| "ld1 {v1.16b}, [x10], %[col]\n" | |||
| "ld1 {v2.16b}, [x10], %[col]\n" | |||
| "ld1 {v3.16b}, [x10], %[col]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "st1 {v2.16b}, [x11], #16\n" | |||
| "st1 {v3.16b}, [x11], #16\n" | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col ] "r"(col) | |||
| : "x10", "x11", "v0", "v1", "v2", "v3"); | |||
| #else | |||
| MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, C4NUM, C16NUM, col); | |||
| #endif | |||
| } | |||
| if (col != col_16div) { | |||
| MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col); | |||
| } | |||
| src_r += C4NUM * col; | |||
| dst_r += C4NUM * col16; | |||
| } | |||
| if (row != row_4div) { | |||
| for (int ci = 0; ci < col_16div; ci += C16NUM) { | |||
| MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col); | |||
| } | |||
| if (col != col_16div) { | |||
| MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, row_4res, col_16res, col); | |||
| } | |||
| } | |||
| return; | |||
| @@ -74,7 +122,7 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co | |||
| } | |||
| } | |||
| void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | |||
| void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | |||
| const int *input_sum, const int *bias) { | |||
| /* row4x16-major * row16x4-major => row4x4-major */ | |||
| for (int r = 0; r < row_4; r++) { | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_ | |||
| #include <string.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| @@ -25,7 +26,7 @@ extern "C" { | |||
| #endif | |||
| void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep, | |||
| const int a_zp, const int b_zp); | |||
| void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | |||
| void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | |||
| const int *input_sum, const int *bias); | |||
| void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||
| void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||
| @@ -107,6 +107,32 @@ TEST_F(TestDeconvInt8, PackWeight2) { | |||
| CompareOutputData(dst, co, 528, 1); | |||
| } | |||
| TEST_F(TestDeconvInt8, PackInputTest1) { | |||
| /* 6 x 20 */ | |||
| int8_t in[] = {40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103, | |||
| -22, 32, 26, 112, -92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 121, -62, 119, -93, | |||
| 21, -92, -1, -82, -71, -54, 63, -93, 92, -93, 99, 122, -104, -16, -8, -32, 90, -126, 51, 91, | |||
| 4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9, -97, 99, 88, -15, 54, 26, 77, -25, | |||
| 113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13, 80, -115, -98, -8, -17, 31, 88, 65, | |||
| -1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43, 92, -34, -17, -52}; | |||
| int8_t co[] = {40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -22, 32, 26, 112, | |||
| -92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 21, -92, -1, -82, -71, -54, 63, -93, | |||
| 92, -93, 99, 122, -104, -16, -8, -32, 4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9, | |||
| -97, 99, 88, -15, -37, 114, -8, 103, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 121, -62, 119, -93, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, -126, 51, 91, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 54, 26, 77, -25, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13, | |||
| 80, -115, -98, -8, -1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -17, 31, 88, 65, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 92, -34, -17, -52, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; | |||
| int8_t dst[8 * 32] = {0}; | |||
| RowMajor2Row16x4MajorInt8(in, dst, 6, 20); | |||
| CompareOutputData(dst, co, 8 * 32, 1); | |||
| } | |||
| TEST_F(TestDeconvInt8, MatMulTest1) { | |||
| int8_t a_row_major_10_12[] = { | |||
| -6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, 88, 105, | |||
| @@ -155,6 +181,30 @@ TEST_F(TestDeconvInt8, MatMulTest1) { | |||
| CompareOutputData(out_row_major, co_row_major_10_18, 180, 1); | |||
| } | |||
| TEST_F(TestDeconvInt8, InputSumTest1) { | |||
| int8_t packed_a[] = { | |||
| -6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, 15, 15, 15, 15, -41, 117, 62, -76, -77, -111, | |||
| 88, 105, 68, 105, -74, 13, 15, 15, 15, 15, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, | |||
| 15, 15, 15, 15, 57, -41, -51, 77, 1, 9, 73, -19, -36, 57, 81, -24, 15, 15, 15, 15, 40, 103, | |||
| 112, 109, -41, -68, 57, 61, 55, -20, 3, 2, 15, 15, 15, 15, 17, -16, -31, 58, -4, 67, -4, -95, | |||
| -5, -72, 81, 15, 15, 15, 15, 15, -7, -16, -47, 112, 114, -26, -98, 53, 15, -49, 26, 19, 15, 15, | |||
| 15, 15, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 15, 15, 15, 15, 124, 113, -5, 15, | |||
| -8, 107, -65, -88, 50, -47, -80, -84, 15, 15, 15, 15, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, | |||
| 55, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, | |||
| 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15}; | |||
| int32_t filter_zp = -20; | |||
| int32_t input_sum[12] = {0}; | |||
| int32_t correct_input_sum[] = {-7100, -4780, 580, -4880, -9460, -1420, -3120, -3260, -1840, -6960, -4800, -4800}; | |||
| DeConvPackInputSum(packed_a, input_sum, filter_zp, 12, 16, true); | |||
| CompareOutputData(input_sum, correct_input_sum, 12, 0); | |||
| int32_t input_sum_4[4] = {0}; | |||
| int32_t correct_input_sum_4[] = {-18400, -13160, -7340, -12940}; | |||
| DeConvPackInputSum(packed_a, input_sum_4, filter_zp, 4, 16 * 3, true); | |||
| CompareOutputData(input_sum_4, correct_input_sum_4, 4, 0); | |||
| } | |||
| TEST_F(TestDeconvInt8, MatMulOptTest1) { | |||
| int8_t a_src_ptr[] = {-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, | |||
| 88, 105, 68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, | |||
| @@ -191,8 +241,7 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) { | |||
| 15, 15, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 15, 15, 15, 15, 124, 113, -5, 15, | |||
| -8, 107, -65, -88, 50, -47, -80, -84, 15, 15, 15, 15, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, | |||
| 55, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, | |||
| 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, | |||
| }; | |||
| 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15}; | |||
| RowMajor2Row16x4MajorInt8(a_src_ptr, packed_a, 10, 12); | |||
| CompareOutputData(packed_a, correct_packed_a, 16 * 12, 0); | |||
| @@ -231,12 +280,6 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) { | |||
| DeConvPackInputSum(packed_a, input_sum, filter_zp, 12, 16, true); | |||
| CompareOutputData(input_sum, correct_input_sum, 12, 0); | |||
| for (int i = 0; i < 12; i++) { | |||
| if (input_sum[i] != correct_input_sum[i]) { | |||
| printf("%d %d %d\n", i, input_sum[i], correct_input_sum[i]); | |||
| } | |||
| } | |||
| /* | |||
| * ---------------------- calculate weight_sum ------------------------- */ | |||
| int32_t weight_sum[3 * 8] = {0}; | |||
| @@ -270,7 +313,8 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) { | |||
| 7894, -51, 0, 0, -4775, -29785, 0, 0, -12597, 4088, 0, 0, -17420, 1815, | |||
| 0, 0, 15796, 3101, 0, 0, -37969, -10818, 0, 0, 12714, -7827, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0}; | |||
| MatMulOptR4Int8(packed_a, packed_b, tmp_output, 12, 24, 16, input_sum, weight_sum); | |||
| MatMulInt8_16x4(packed_a, packed_b, tmp_output, 12, 24, 16, input_sum, weight_sum); | |||
| CompareOutputData(tmp_output, correct_tmp_output, 12 * 3 * 8, 0); | |||
| } | |||