From d8467311f46840ce0f93baa7312dddad4e854c04 Mon Sep 17 00:00:00 2001 From: fuzhiye Date: Thu, 12 Nov 2020 20:55:59 +0800 Subject: [PATCH] optimize int8 mul op && delete useless code --- mindspore/lite/nnacl/fp16/conv_fp16.c | 41 ++-- mindspore/lite/nnacl/fp32/conv.c | 41 ++-- mindspore/lite/nnacl/fp32/pooling.c | 25 +-- mindspore/lite/nnacl/int8/conv_int8.c | 31 +-- mindspore/lite/nnacl/int8/mul_int8.c | 204 +++++++++++++++--- mindspore/lite/nnacl/int8/mul_int8.h | 2 + mindspore/lite/nnacl/int8/pooling_int8.c | 78 +++---- .../lite/nnacl/int8/space_to_batch_int8.c | 21 +- mindspore/lite/nnacl/pack.c | 11 +- .../src/runtime/kernel/arm/int8/mul_int8.cc | 67 ++++++ .../src/runtime/kernel/arm/int8/mul_int8.h | 6 + 11 files changed, 336 insertions(+), 191 deletions(-) diff --git a/mindspore/lite/nnacl/fp16/conv_fp16.c b/mindspore/lite/nnacl/fp16/conv_fp16.c index b48cb5c656..dd017aacc7 100644 --- a/mindspore/lite/nnacl/fp16/conv_fp16.c +++ b/mindspore/lite/nnacl/fp16/conv_fp16.c @@ -125,25 +125,15 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { const int tile_n = 16; - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int in_batch = conv_param->input_batch_; - int in_channel = conv_param->input_channel_; - int in_h = conv_param->input_h_; - int in_w = conv_param->input_w_; - int out_h = conv_param->output_h_; - int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; - int thread_count = conv_param->thread_num_; - int output_count = out_h * out_w; + int output_count = conv_param->output_h_ * conv_param->output_w_; int output_tile_count = UP_DIV(output_count, tile_n); - int kernel_plane = kernel_h * kernel_w; - int deep = kernel_plane * in_channel; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; - for (int b = 0; b < in_batch; b++) { - int in_batch_offset = b * in_channel * in_h * in_w; - int out_batch_offset = b * out_channel * out_h * out_w; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * output_count; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { int start_index = thread_id * tile_n; int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; float16_t *gemm_input = packed_input + task_id * deep * tile_n; @@ -166,18 +156,13 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func) { const int tile_num = 16; - int thread_num = conv_param->thread_num_; - int input_unit = conv_param->input_unit_; - int in_batch = conv_param->input_batch_; int in_channel = conv_param->input_channel_; - int out_unit = conv_param->output_unit_; - int out_w_block = UP_DIV(conv_param->output_w_, out_unit); - int out_h_block = UP_DIV(conv_param->output_h_, out_unit); + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); int output_count = out_w_block * out_h_block; int output_tile_count = UP_DIV(output_count, tile_num); - int out_channel = conv_param->output_channel_; - int oc8 = UP_DIV(out_channel, C8NUM); - int input_unit_square = input_unit * input_unit; + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_; float16_t *trans_input = buffer_list[0]; float16_t *gemm_out = buffer_list[1]; @@ -189,10 +174,10 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa int col_buffer_offset = tile_num * in_channel; // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) - for (int b = 0; b < in_batch; b++) { + for (int b = 0; b < conv_param->input_batch_; b++) { int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; - int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { int out_tile_index = thread_id * tile_num; int cal_num = output_count - thread_id * tile_num; cal_num = cal_num > tile_num ? tile_num : cal_num; diff --git a/mindspore/lite/nnacl/fp32/conv.c b/mindspore/lite/nnacl/fp32/conv.c index 7bc0f3fd89..3bdfd40636 100644 --- a/mindspore/lite/nnacl/fp32/conv.c +++ b/mindspore/lite/nnacl/fp32/conv.c @@ -23,30 +23,20 @@ // fp32 conv common void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) { - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int in_batch = conv_param->input_batch_; - int in_channel = conv_param->input_channel_; - int in_h = conv_param->input_h_; - int in_w = conv_param->input_w_; - int out_h = conv_param->output_h_; - int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; - int thread_count = conv_param->thread_num_; - int output_count = out_h * out_w; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + int output_count = conv_param->output_h_ * conv_param->output_w_; #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) const int cal_num = C4NUM; #else const int cal_num = C12NUM; #endif int output_tile_count = UP_DIV(output_count, cal_num); - int kernel_plane = kernel_h * kernel_w; - int deep = kernel_plane * in_channel; - for (int b = 0; b < in_batch; b++) { - int in_batch_offset = b * in_channel * in_h * in_w; - int out_batch_offset = b * out_channel * out_h * out_w; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * output_count; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { int start_index = thread_id * cal_num; int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num; float *gemm_input = packed_input + task_id * deep * cal_num; @@ -73,19 +63,14 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func) { - int thread_num = conv_param->thread_num_; - int input_unit = conv_param->input_unit_; - int in_batch = conv_param->input_batch_; int in_channel = conv_param->input_channel_; - int out_unit = conv_param->output_unit_; - int out_w_block = UP_DIV(conv_param->output_w_, out_unit); - int out_h_block = UP_DIV(conv_param->output_h_, out_unit); + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); int output_count = out_w_block * out_h_block; const int tile_num = C12NUM; int output_tile_count = UP_DIV(output_count, tile_num); - int out_channel = conv_param->output_channel_; - int oc8 = UP_DIV(out_channel, C8NUM); - int input_unit_square = input_unit * input_unit; + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_; float *trans_input = buffer_list[0]; float *gemm_out = buffer_list[1]; @@ -97,10 +82,10 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const int col_buffer_offset = tile_num * in_channel; // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) - for (int b = 0; b < in_batch; b++) { + for (int b = 0; b < conv_param->input_batch_; b++) { int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; - int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { int out_tile_index = thread_id * tile_num; int cal_num = output_count - out_tile_index; cal_num = cal_num > tile_num ? tile_num : cal_num; diff --git a/mindspore/lite/nnacl/fp32/pooling.c b/mindspore/lite/nnacl/fp32/pooling.c index 231d500db4..8aab662d52 100644 --- a/mindspore/lite/nnacl/fp32/pooling.c +++ b/mindspore/lite/nnacl/fp32/pooling.c @@ -20,10 +20,6 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id, float minf, float maxf) { - int stride_w = pooling_param->stride_w_; - int stride_h = pooling_param->stride_h_; - int pad_w = pooling_param->pad_l_; - int pad_h = pooling_param->pad_u_; int win_w = pooling_param->window_w_; int win_h = pooling_param->window_h_; int channel = pooling_param->input_channel_; @@ -32,10 +28,8 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool int in_h = pooling_param->input_h_; int output_w = pooling_param->output_w_; int output_h = pooling_param->output_h_; - int output_batch = pooling_param->output_batch_; int out_plane = output_w * output_h; int out_tile_count = UP_DIV(out_plane, TILE_NUM); - int thread_num = pooling_param->thread_num_; int window = win_w * win_h; #ifdef ENABLE_NEON @@ -43,18 +37,18 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool float32x4_t max_value = vdupq_n_f32(maxf); #endif - for (int batch = 0; batch < output_batch; batch++) { + for (int batch = 0; batch < pooling_param->output_batch_; batch++) { const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; - for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) { int cal_start_index = thread_id * TILE_NUM; int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); for (int i = 0; i < real_cal_num; i++) { int index = cal_start_index + i; int out_w_index = index % output_w; int out_h_index = index / output_w; - int in_w_index = out_w_index * stride_w - pad_w; - int in_h_index = out_h_index * stride_h - pad_h; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; const float *src_plane_ptr = src_b_ptr; float *dst_plane_ptr = dst_b_ptr + index * channel; @@ -152,10 +146,6 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id, float minf, float maxf) { - int stride_w = pooling_param->stride_w_; - int stride_h = pooling_param->stride_h_; - int pad_w = pooling_param->pad_l_; - int pad_h = pooling_param->pad_u_; int win_w = pooling_param->window_w_; int win_h = pooling_param->window_h_; int channel = pooling_param->input_channel_; @@ -166,7 +156,6 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo int output_batch = pooling_param->output_batch_; int out_plane = output_w * output_h; int out_tile_count = UP_DIV(out_plane, TILE_NUM); - int thread_num = pooling_param->thread_num_; int c4 = channel / C4NUM; /* oc && ic */ #ifdef ENABLE_NEON @@ -177,15 +166,15 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo for (int batch = 0; batch < output_batch; batch++) { const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; - for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) { int cal_start_index = thread_id * TILE_NUM; int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); for (int i = 0; i < real_cal_num; i++) { int index = cal_start_index + i; int out_w_index = index % output_w; int out_h_index = index / output_w; - int in_w_index = out_w_index * stride_w - pad_w; - int in_h_index = out_h_index * stride_h - pad_h; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; const float *src_plane_ptr = src_b_ptr; float *dst_plane_ptr = dst_b_ptr + index * channel; diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index af39334282..2e99904c89 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -65,20 +65,12 @@ void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, in void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) { - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int in_batch = conv_param->input_batch_; int in_channel = conv_param->input_channel_; - int in_h = conv_param->input_h_; - int in_w = conv_param->input_w_; - int out_h = conv_param->output_h_; - int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; int tile_n = conv_param->tile_num_; - int thread_count = conv_param->thread_num_; - int output_count = out_h * out_w; + int output_count = conv_param->output_h_ * conv_param->output_w_; int output_tile_count = UP_DIV(output_count, tile_n); - int kernel_plane = kernel_h * kernel_w; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; int unit_size; int input_sum_offset; int up_round_oc; @@ -103,10 +95,10 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in per_channel = false; } - for (int b = 0; b < in_batch; b++) { - int in_batch_offset = b * in_channel * in_h * in_w; - int out_batch_offset = b * out_channel * out_h * out_w; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { int start_index = thread_id * tile_n; int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset; @@ -858,23 +850,20 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, int task_id, ConvParameter *conv_param) { - int thread_count = conv_param->thread_num_; int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); - int output_channel = conv_param->output_channel_; int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); int output_count = out_w_block * out_h_block; int output_tile_count = UP_DIV(output_count, TILE_NUM); - int oc4 = UP_DIV(output_channel, C4NUM); + int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; const int block_unit_buffer_offset = 16 * C8NUM; int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; - int input_batch = conv_param->input_batch_; - for (int batch = 0; batch < input_batch; batch++) { + for (int batch = 0; batch < conv_param->input_batch_; batch++) { int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { int start_index = thread_id * TILE_NUM; int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; @@ -883,7 +872,7 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi out_w_block, conv_param); Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, - transed_weight, output_channel, ic8, real_cal_num); + transed_weight, conv_param->output_channel_, ic8, real_cal_num); Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, bias_data, start_index, real_cal_num, out_w_block, conv_param); diff --git a/mindspore/lite/nnacl/int8/mul_int8.c b/mindspore/lite/nnacl/int8/mul_int8.c index 35c530b244..98a1568916 100644 --- a/mindspore/lite/nnacl/int8/mul_int8.c +++ b/mindspore/lite/nnacl/int8/mul_int8.c @@ -24,15 +24,13 @@ #ifdef ENABLE_NEON -int16x4_t ClacSumHalfWordMul(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, - int32x4_t output_multiplier_vec, MulQuantArg para) { - int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); - int32x4_t raw_sum = RoundingDivideByPOTInt32x4( - SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), - para.shift_right_); - raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para.out_quant_arg_.zp_)); - raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); - raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); +int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) { + int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1); + int32x4_t raw_sum = vqrdmulhq_s32(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(raw_sum, right_shift_out_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(raw_sum, fixup); + raw_sum = vrshlq_s32(fixed_up_x, right_shift_out_vec); return vqmovn_s32(raw_sum); } @@ -40,27 +38,189 @@ void MulInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, MulQuantArg para, int *index) { int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_); + int32x4_t right_shift_out_vec = vdupq_n_s32(-para.shift_right_); + int16x8_t out_zp_vec = vdupq_n_s16(para.out_quant_arg_.zp_); + int8x16_t out_min_vec = vdupq_n_s8(para.output_activation_min_); + int8x16_t out_max_vec = vdupq_n_s8(para.output_activation_max_); + int8x8_t out_min_vec_s8 = vdup_n_s8(para.output_activation_min_); + int8x8_t out_max_vec_s8 = vdup_n_s8(para.output_activation_max_); + for (; (*index) <= real_dst_count - 16; (*index) += 16) { + int16x8_t zp1_vec = vdupq_n_s16(para.in_quant_args_[0].zp_); + int16x8_t zp2_vec = vdupq_n_s16(para.in_quant_args_[1].zp_); + int8x16_t input0_vec = vld1q_s8(input0_data + *index); + int8x16_t input1_vec = vld1q_s8(input1_data + *index); + int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); + int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); + int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); + int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); + input0_low = vaddq_s16(input0_low, zp1_vec); + input0_high = vaddq_s16(input0_high, zp1_vec); + input1_low = vaddq_s16(input1_low, zp2_vec); + input1_high = vaddq_s16(input1_high, zp2_vec); + + int16x4_t input0_low_low = vget_low_s16(input0_low); + int16x4_t input0_low_high = vget_high_s16(input0_low); + int16x4_t input0_high_low = vget_low_s16(input0_high); + int16x4_t input0_high_high = vget_high_s16(input0_high); + int16x4_t input1_low_low = vget_low_s16(input1_low); + int16x4_t input1_low_high = vget_high_s16(input1_low); + int16x4_t input1_high_low = vget_low_s16(input1_high); + int16x4_t input1_high_high = vget_high_s16(input1_high); + + int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, right_shift_out_vec, + output_multiplier_vec); + int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = vminq_s8(res_s8, out_max_vec); + res_s8 = vmaxq_s8(res_s8, out_min_vec); + vst1q_s8(output_data, res_s8); + output_data += 16; + } for (; (*index) <= real_dst_count - 8; (*index) += 8) { int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para.in_quant_args_[0].zp_); int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para.in_quant_args_[1].zp_); - int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); - int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); - int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); - int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + int16x4_t input0_low = vget_low_s16(input0_val); + int16x4_t input0_high = vget_high_s16(input0_val); + int16x4_t input1_low = vget_low_s16(input1_val); + int16x4_t input1_high = vget_high_s16(input1_val); - int16x4_t sum_low = ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, para); - int16x4_t sum_high = ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_low = + ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high = + ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); - int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); + res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); vst1_s8(output_data, res_u8_n0); output_data += 8; } } #endif +void FastMul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int depth, int64_t real_dst_count, + bool input1_broad, MulQuantArg para) { + // input0 need broadcast + int32_t zp1 = para.in_quant_args_[0].zp_; + int32_t zp2 = para.in_quant_args_[1].zp_; + if (input1_broad) { + zp1 = para.in_quant_args_[1].zp_; + zp2 = para.in_quant_args_[0].zp_; + } +#ifdef ENABLE_ARM + int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_); + int32x4_t right_shift_out_vec = vdupq_n_s32(-para.shift_right_); + int16x8_t out_zp_vec = vdupq_n_s16(para.out_quant_arg_.zp_); + int8x16_t out_min_vec = vdupq_n_s8(para.output_activation_min_); + int8x16_t out_max_vec = vdupq_n_s8(para.output_activation_max_); + int8x8_t out_min_vec_s8 = vdup_n_s8(para.output_activation_min_); + int8x8_t out_max_vec_s8 = vdup_n_s8(para.output_activation_max_); + int16x8_t zp1_vec = vdupq_n_s16(zp1); + int16x8_t zp2_vec = vdupq_n_s16(zp2); +#endif + for (int index = 0; index < real_dst_count; ++index) { + int j = 0; +#ifdef ENABLE_ARM + for (; j <= depth - 16; j += 16) { + int8x16_t input0_vec = vld1q_s8(input0_data + j); + int8x16_t input1_vec = vld1q_s8(input1_data); + int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); + int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); + int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); + int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); + input0_low = vaddq_s16(input0_low, zp1_vec); + input0_high = vaddq_s16(input0_high, zp1_vec); + input1_low = vaddq_s16(input1_low, zp2_vec); + input1_high = vaddq_s16(input1_high, zp2_vec); + + int16x4_t input0_low_low = vget_low_s16(input0_low); + int16x4_t input0_low_high = vget_high_s16(input0_low); + int16x4_t input0_high_low = vget_low_s16(input0_high); + int16x4_t input0_high_high = vget_high_s16(input0_high); + int16x4_t input1_low_low = vget_low_s16(input1_low); + int16x4_t input1_low_high = vget_high_s16(input1_low); + int16x4_t input1_high_low = vget_low_s16(input1_high); + int16x4_t input1_high_high = vget_high_s16(input1_high); + + int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = vminq_s8(res_s8, out_max_vec); + res_s8 = vmaxq_s8(res_s8, out_min_vec); + vst1q_s8(output_data, res_s8); + input1_data += 16; + output_data += 16; + } + for (; j <= depth - 8; j += 8) { + int8x8_t input0_vec = vld1_s8(input0_data + j); + int8x8_t input1_vec = vld1_s8(input1_data); + int16x8_t input0_val = vmovl_s8(input0_vec); + int16x8_t input1_val = vmovl_s8(input1_vec); + input0_val = vaddq_s16(input0_val, zp1_vec); + input1_val = vaddq_s16(input1_val, zp2_vec); + + int16x4_t input0_low = vget_low_s16(input0_val); + int16x4_t input0_high = vget_high_s16(input0_val); + int16x4_t input1_low = vget_low_s16(input1_val); + int16x4_t input1_high = vget_high_s16(input1_val); + + int16x4_t sum_low = + ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high = + ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); + res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); + vst1_s8(output_data, res_u8_n0); + input1_data += 8; + output_data += 8; + } +#endif + for (; j < depth; ++j) { + const int32_t input0_val = zp1 + input0_data[j]; + const int32_t input1_val = zp2 + input1_data[0]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << para.shift_left_), para.output_multiplier_), + para.shift_right_); + + mul_result += para.out_quant_arg_.zp_; + mul_result = mul_result < para.output_activation_max_ ? mul_result : para.output_activation_max_; + mul_result = mul_result > para.output_activation_min_ ? mul_result : para.output_activation_min_; + output_data[0] = (int8_t)mul_result; + input1_data++; + output_data++; + } + } + return; +} + void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para) { int index = 0; #ifdef ENABLE_NEON @@ -74,14 +234,10 @@ void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t para.shift_right_); mul_result += para.out_quant_arg_.zp_; - - if (mul_result > para.output_activation_max_) { - output_data[index] = para.output_activation_max_; - } else if (mul_result < para.output_activation_min_) { - output_data[index] = para.output_activation_min_; - } else { - output_data[index] = (int8_t)mul_result; - } + mul_result = mul_result < para.output_activation_max_ ? mul_result : para.output_activation_max_; + mul_result = mul_result > para.output_activation_min_ ? mul_result : para.output_activation_min_; + output_data[0] = (int8_t)mul_result; + output_data++; } return; } diff --git a/mindspore/lite/nnacl/int8/mul_int8.h b/mindspore/lite/nnacl/int8/mul_int8.h index f4e80243ee..cb198f7fb0 100644 --- a/mindspore/lite/nnacl/int8/mul_int8.h +++ b/mindspore/lite/nnacl/int8/mul_int8.h @@ -24,6 +24,8 @@ extern "C" { #endif void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para); +void FastMul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int depth, int64_t real_dst_count, + bool input1_broad, MulQuantArg para); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/int8/pooling_int8.c b/mindspore/lite/nnacl/int8/pooling_int8.c index 4b1ebcc3e1..8621702f7a 100644 --- a/mindspore/lite/nnacl/int8/pooling_int8.c +++ b/mindspore/lite/nnacl/int8/pooling_int8.c @@ -80,33 +80,24 @@ int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter } int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { - int stride_w = pooling_param->stride_w_; - int stride_h = pooling_param->stride_h_; - int pad_w = pooling_param->pad_l_; - int pad_h = pooling_param->pad_u_; int win_w = pooling_param->window_w_; int win_h = pooling_param->window_h_; int channel = pooling_param->input_channel_; + int c16 = channel / C16NUM; int in_w = pooling_param->input_w_; - int in_h = pooling_param->input_h_; int output_w = pooling_param->output_w_; - int output_h = pooling_param->output_h_; - int output_batch = pooling_param->output_batch_; - int out_plane = output_w * output_h; + int out_plane = output_w * pooling_param->output_h_; int out_tile_count = UP_DIV(out_plane, TILE_NUM); int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_; - float input_scale = pooling_param->quant_args_[0][0].scale_; int input_zp = pooling_param->quant_args_[0][0].zp_; - float output_scale = pooling_param->quant_args_[1][0].scale_; int output_zp = pooling_param->quant_args_[1][0].zp_; - double real_multiplier = input_scale / output_scale; - int c16 = channel / C16NUM; + double real_multiplier = pooling_param->quant_args_[0][0].scale_ / pooling_param->quant_args_[1][0].scale_; const int8_t out_min = INT8_MIN; const int8_t out_max = INT8_MAX; - for (int batch = 0; batch < output_batch; batch++) { - int in_batch_offset = batch * in_h * in_w * channel; - int out_batch_offset = batch * output_h * output_w * channel; + for (int batch = 0; batch < pooling_param->output_batch_; batch++) { + int in_batch_offset = batch * pooling_param->input_h_ * in_w * channel; + int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel; for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { int cal_start_index = thread_id * TILE_NUM; int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); @@ -114,14 +105,14 @@ int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParame int index = cal_start_index + i; int out_w_index = index % output_w; int out_h_index = index / output_w; - int in_w_index = out_w_index * stride_w - pad_w; - int in_h_index = out_h_index * stride_h - pad_h; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; int out_plane_offset = out_batch_offset + index * channel; int input_stride = (in_h_index * in_w + in_w_index) * channel; int kw_s = MSMAX(0, -in_w_index); int kw_e = MSMIN(win_w, in_w - in_w_index); int kh_s = MSMAX(0, -in_h_index); - int kh_e = MSMIN(win_h, in_h - in_h_index); + int kh_e = MSMIN(win_h, pooling_param->input_h_ - in_h_index); int real_count = (kw_e - kw_s) * (kh_e - kh_s); if (real_count == 0) { return NNACL_ERR; @@ -335,19 +326,11 @@ void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { - int stride_w = pooling_param->stride_w_; - int stride_h = pooling_param->stride_h_; - int pad_w = pooling_param->pad_l_; - int pad_h = pooling_param->pad_u_; - int win_w = pooling_param->window_w_; - int win_h = pooling_param->window_h_; int channel = pooling_param->input_channel_; int in_w = pooling_param->input_w_; int in_h = pooling_param->input_h_; int output_w = pooling_param->output_w_; - int output_h = pooling_param->output_h_; - int output_batch = pooling_param->output_batch_; - int out_plane = output_w * output_h; + int out_plane = output_w * pooling_param->output_h_; int out_tile_count = UP_DIV(out_plane, TILE_NUM); int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_; int c16 = UP_DIV(channel, 16); @@ -358,9 +341,9 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin int output_zp = pooling_param->quant_args_[1][0].zp_; double real_multiplier = input_scale / output_scale; - for (int batch = 0; batch < output_batch; batch++) { + for (int batch = 0; batch < pooling_param->output_batch_; batch++) { int in_batch_offset = batch * in_h * in_w * channel; - int out_batch_offset = batch * output_h * output_w * channel; + int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel; for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { int cal_start_index = thread_id * TILE_NUM; int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); @@ -368,8 +351,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin int index = cal_start_index + i; int out_w_index = index % output_w; int out_h_index = index / output_w; - int in_w_index = out_w_index * stride_w - pad_w; - int in_h_index = out_h_index * stride_h - pad_h; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; int out_plane_offset = out_batch_offset + index * channel; for (int j = 0; j < c16 - 1; j++) { int in_channel_offset = in_batch_offset + j * 16; @@ -382,8 +365,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin tmp_max[m] = INT8_MIN; } #endif - for (int h = 0; h < win_h; h++) { - for (int w = 0; w < win_w; w++) { + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { continue; @@ -418,8 +401,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin int in_channel_offset = in_batch_offset + k; int out_channel_offset = out_plane_offset + k; int8_t tmp_max = INT8_MIN; - for (int h = 0; h < win_h; h++) { - for (int w = 0; w < win_w; w++) { + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { continue; @@ -437,26 +420,17 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin } void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { - int stride_w = pooling_param->stride_w_; - int stride_h = pooling_param->stride_h_; - int pad_w = pooling_param->pad_l_; - int pad_h = pooling_param->pad_u_; - int win_w = pooling_param->window_w_; - int win_h = pooling_param->window_h_; int channel = pooling_param->input_channel_; int in_w = pooling_param->input_w_; - int in_h = pooling_param->input_h_; int output_w = pooling_param->output_w_; - int output_h = pooling_param->output_h_; - int output_batch = pooling_param->output_batch_; - int out_plane = output_w * output_h; + int out_plane = output_w * pooling_param->output_h_; int out_tile_count = UP_DIV(out_plane, TILE_NUM); int thread_num = MSMIN(out_tile_count, pooling_param->thread_num_); int8_t out_array[MAX_MAXPOOL_SIZE]; - for (int batch = 0; batch < output_batch; batch++) { - int in_batch_offset = batch * in_h * in_w * channel; - int out_batch_offset = batch * output_h * output_w * channel; + for (int batch = 0; batch < pooling_param->output_batch_; batch++) { + int in_batch_offset = batch * pooling_param->input_h_ * in_w * channel; + int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel; for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { int cal_start_index = thread_id * TILE_NUM; int real_cal_num = out_plane - cal_start_index; @@ -465,12 +439,12 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam int index = cal_start_index + i; int out_w_index = index % output_w; int out_h_index = index / output_w; - int in_w_index = out_w_index * stride_w - pad_w; - int in_h_index = out_h_index * stride_h - pad_h; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; const int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index); - int ky_e = MSMIN(win_h, in_h - in_h_index); + int ky_e = MSMIN(pooling_param->window_h_, pooling_param->input_h_ - in_h_index); const int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index); - int kx_e = MSMIN(win_w, in_w - in_w_index); + int kx_e = MSMIN(pooling_param->window_w_, in_w - in_w_index); int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset; int out_plane_offset = out_batch_offset + index * channel; diff --git a/mindspore/lite/nnacl/int8/space_to_batch_int8.c b/mindspore/lite/nnacl/int8/space_to_batch_int8.c index 76613f0bb6..df3aa2cfc6 100644 --- a/mindspore/lite/nnacl/int8/space_to_batch_int8.c +++ b/mindspore/lite/nnacl/int8/space_to_batch_int8.c @@ -47,20 +47,17 @@ void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int *bloc } void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp) { - int *in_shape = param->input_shape_; - int *out_shape = param->output_shape_; - int *paddings = param->paddings_; int block_shape_h = param->block_sizes_[0]; int block_shape_w = param->m_ == 2 ? param->block_sizes_[1] : 1; - int in_b = in_shape[0]; - int in_h = in_shape[1]; - int in_w = in_shape[2]; - int channel = in_shape[3]; - int out_h = out_shape[1]; - int out_w = out_shape[2]; - int pad_t = paddings[0]; - int pad_l = param->m_ == 2 ? paddings[2] : 0; - for (int i = 0; i < out_shape[0]; ++i) { + int in_b = param->input_shape_[0]; + int in_h = param->input_shape_[1]; + int in_w = param->input_shape_[2]; + int channel = param->input_shape_[3]; + int out_h = param->output_shape_[1]; + int out_w = param->output_shape_[2]; + int pad_t = param->paddings_[0]; + int pad_l = param->m_ == 2 ? param->paddings_[2] : 0; + for (int i = 0; i < param->output_shape_[0]; ++i) { int in_batch = i % in_b; int offset_w = (i / in_b) % block_shape_w; int offset_h = (i / in_b) / block_shape_w; diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index 877efe2324..fde87508eb 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -219,24 +219,19 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; int kernel_plane = kernel_h * kernel_w; - int stride_h = conv_param->stride_h_; - int stride_w = conv_param->stride_w_; - int pad_h = conv_param->pad_u_; - int pad_w = conv_param->pad_l_; int dilation_h = conv_param->dilation_h_; int dilation_w = conv_param->dilation_w_; int in_channel = conv_param->input_channel_; - int in_h = conv_param->input_h_; int in_w = conv_param->input_w_; int out_w = conv_param->output_w_; for (int i = 0; i < real_cal_num; i++) { int block_start = block_index + i; - int input_h = block_start / out_w * stride_h - pad_h; - int input_w = block_start % out_w * stride_w - pad_w; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; int input_stride = (input_h * in_w + input_w) * in_channel; int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); - int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); if (dilation_w == 1 && dilation_h == 1) { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc index 68cd32f59b..63176af092 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc @@ -62,11 +62,46 @@ int MulInt8CPUKernel::Init() { return ReSize(); } +void MulInt8CPUKernel::CheckSameShapeSize(std::vector in_tensor0_shape, std::vector in_tensor1_shape) { + bool condition1 = in_tensor0_shape[0] == in_tensor1_shape[0]; + bool condition2 = in_tensor0_shape[1] == 1; + bool condition3 = in_tensor0_shape[2] == 1; + bool condition4 = in_tensor0_shape[3] == in_tensor1_shape[3]; + bool condition5 = in_tensor1_shape[1] == 1; + bool condition6 = in_tensor1_shape[2] == 1; + if (condition1 && condition2 && condition3 && condition4) { + fast_hw_broadcast_ = true; + } else if (condition1 && condition4 && condition5 && condition6) { + fast_hw_broadcast_ = true; + input1_hw_broadcast_ = true; + } +} + +void MulInt8CPUKernel::CheckIfFastImpl() { + auto in_tensor0 = in_tensors_.at(0); + auto in_tensor1 = in_tensors_.at(1); + if (in_tensor0->ElementsNum() != in_tensor1->ElementsNum()) { + if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 4) { + CheckSameShapeSize(in_tensor0->shape(), in_tensor1->shape()); + } else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == 4) { + if (in_tensor0->ElementsNum() == in_tensor1->shape()[3]) { + fast_hw_broadcast_ = true; + } + } else if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 1) { + if (in_tensor1->ElementsNum() == in_tensor0->shape()[3]) { + fast_hw_broadcast_ = true; + input1_hw_broadcast_ = true; + } + } + } +} + int MulInt8CPUKernel::ReSize() { size_t input0_size = in_tensors_.at(0)->shape().size(); size_t input1_size = in_tensors_.at(1)->shape().size(); size_t output_size = out_tensors_.at(0)->shape().size(); tile_para->ndim_ = output_size; + if (input0_size == input1_size) { for (size_t i = 0; i < output_size; i++) { tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i); @@ -106,6 +141,14 @@ int MulInt8CPUKernel::Run() { input1_data_ = static_cast(in_tensors_.at(1)->MutableData()); output_data_ = static_cast(out_tensors_.at(0)->MutableData()); + CheckIfFastImpl(); + // can implement fast broadcast mul + if (fast_hw_broadcast_) { + elements_num_ = out_tensors_.front()->Batch() * out_tensors_.front()->Height() * out_tensors_.front()->Width(); + count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; + return ParallelLaunch(this->context_->thread_pool_, FastHWBroadcatMulInt8Run, this, thread_count_); + } + elements_num_ = out_tensors_.at(0)->ElementsNum(); count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) { @@ -132,12 +175,36 @@ int MulInt8CPUKernel::Run() { return ret; } +int FastHWBroadcatMulInt8Run(void *cdata, int task_id) { + auto mul = reinterpret_cast(cdata); + mul->FastDoExecute(task_id); + return lite::RET_OK; +} + int MulInt8Run(void *cdata, int task_id) { auto mul = reinterpret_cast(cdata); mul->DoExecute(task_id); return lite::RET_OK; } +int MulInt8CPUKernel::FastDoExecute(int task_id) { + int depth = out_tensors_.front()->Channel(); + int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_); + if (real_dst_count <= 0) { + return lite::RET_OK; + } + int8_t *cur_input0_data = input0_data_; + int8_t *cur_input1_data = input1_data_ + task_id * count_unit_ * depth; + int8_t *cur_output_data = output_data_ + task_id * count_unit_ * depth; + if (input1_hw_broadcast_) { + cur_input0_data = input1_data_; + cur_input1_data = input0_data_ + task_id * count_unit_ * depth; + } + FastMul(cur_input0_data, cur_input1_data, cur_output_data, depth, real_dst_count, input1_hw_broadcast_, + para_.mul_quant_arg_); + return RET_OK; +} + int MulInt8CPUKernel::DoExecute(int task_id) { int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_); if (real_dst_count <= 0) { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h index d2657793d4..1c3dc18ae7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h @@ -35,13 +35,18 @@ class MulInt8CPUKernel : public LiteKernel { int Init() override; int ReSize() override; + void CheckSameShapeSize(std::vector in_tensor0_shape, std::vector in_tensor1_shape); + void CheckIfFastImpl(); int Run() override; int DoExecute(int task_id); + int FastDoExecute(int task_id); private: const lite::InnerContext *ctx_ = nullptr; ArithmeticParameter *tile_para = nullptr; MulParameter para_; + bool fast_hw_broadcast_ = false; + bool input1_hw_broadcast_ = false; int thread_count_ = 1; int64_t elements_num_ = 0; int64_t count_unit_ = 0; @@ -51,6 +56,7 @@ class MulInt8CPUKernel : public LiteKernel { }; int MulInt8Run(void *cdata, int task_id); +int FastHWBroadcatMulInt8Run(void *cdata, int task_id); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_