| @@ -125,25 +125,15 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we | |||
| void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, | |||
| float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { | |||
| const int tile_n = 16; | |||
| int kernel_h = conv_param->kernel_h_; | |||
| int kernel_w = conv_param->kernel_w_; | |||
| int in_batch = conv_param->input_batch_; | |||
| int in_channel = conv_param->input_channel_; | |||
| int in_h = conv_param->input_h_; | |||
| int in_w = conv_param->input_w_; | |||
| int out_h = conv_param->output_h_; | |||
| int out_w = conv_param->output_w_; | |||
| int out_channel = conv_param->output_channel_; | |||
| int thread_count = conv_param->thread_num_; | |||
| int output_count = out_h * out_w; | |||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | |||
| int output_tile_count = UP_DIV(output_count, tile_n); | |||
| int kernel_plane = kernel_h * kernel_w; | |||
| int deep = kernel_plane * in_channel; | |||
| int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| for (int b = 0; b < in_batch; b++) { | |||
| int in_batch_offset = b * in_channel * in_h * in_w; | |||
| int out_batch_offset = b * out_channel * out_h * out_w; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| for (int b = 0; b < conv_param->input_batch_; b++) { | |||
| int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; | |||
| int out_batch_offset = b * out_channel * output_count; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||
| int start_index = thread_id * tile_n; | |||
| int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; | |||
| float16_t *gemm_input = packed_input + task_id * deep * tile_n; | |||
| @@ -166,18 +156,13 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa | |||
| float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | |||
| InputTransFp16Func in_func, OutputTransFp16Func out_func) { | |||
| const int tile_num = 16; | |||
| int thread_num = conv_param->thread_num_; | |||
| int input_unit = conv_param->input_unit_; | |||
| int in_batch = conv_param->input_batch_; | |||
| int in_channel = conv_param->input_channel_; | |||
| int out_unit = conv_param->output_unit_; | |||
| int out_w_block = UP_DIV(conv_param->output_w_, out_unit); | |||
| int out_h_block = UP_DIV(conv_param->output_h_, out_unit); | |||
| int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); | |||
| int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); | |||
| int output_count = out_w_block * out_h_block; | |||
| int output_tile_count = UP_DIV(output_count, tile_num); | |||
| int out_channel = conv_param->output_channel_; | |||
| int oc8 = UP_DIV(out_channel, C8NUM); | |||
| int input_unit_square = input_unit * input_unit; | |||
| int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); | |||
| int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_; | |||
| float16_t *trans_input = buffer_list[0]; | |||
| float16_t *gemm_out = buffer_list[1]; | |||
| @@ -189,10 +174,10 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa | |||
| int col_buffer_offset = tile_num * in_channel; | |||
| // step 1 : filter transform (pre-processed offline) | |||
| // step 2 : input transform (online) | |||
| for (int b = 0; b < in_batch; b++) { | |||
| for (int b = 0; b < conv_param->input_batch_; b++) { | |||
| int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; | |||
| int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | |||
| int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||
| int out_tile_index = thread_id * tile_num; | |||
| int cal_num = output_count - thread_id * tile_num; | |||
| cal_num = cal_num > tile_num ? tile_num : cal_num; | |||
| @@ -23,30 +23,20 @@ | |||
| // fp32 conv common | |||
| void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, | |||
| float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) { | |||
| int kernel_h = conv_param->kernel_h_; | |||
| int kernel_w = conv_param->kernel_w_; | |||
| int in_batch = conv_param->input_batch_; | |||
| int in_channel = conv_param->input_channel_; | |||
| int in_h = conv_param->input_h_; | |||
| int in_w = conv_param->input_w_; | |||
| int out_h = conv_param->output_h_; | |||
| int out_w = conv_param->output_w_; | |||
| int out_channel = conv_param->output_channel_; | |||
| int thread_count = conv_param->thread_num_; | |||
| int output_count = out_h * out_w; | |||
| int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| const int cal_num = C4NUM; | |||
| #else | |||
| const int cal_num = C12NUM; | |||
| #endif | |||
| int output_tile_count = UP_DIV(output_count, cal_num); | |||
| int kernel_plane = kernel_h * kernel_w; | |||
| int deep = kernel_plane * in_channel; | |||
| for (int b = 0; b < in_batch; b++) { | |||
| int in_batch_offset = b * in_channel * in_h * in_w; | |||
| int out_batch_offset = b * out_channel * out_h * out_w; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| for (int b = 0; b < conv_param->input_batch_; b++) { | |||
| int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; | |||
| int out_batch_offset = b * out_channel * output_count; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||
| int start_index = thread_id * cal_num; | |||
| int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num; | |||
| float *gemm_input = packed_input + task_id * deep * cal_num; | |||
| @@ -73,19 +63,14 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ | |||
| void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, | |||
| TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func, | |||
| OutputTransFunc out_func) { | |||
| int thread_num = conv_param->thread_num_; | |||
| int input_unit = conv_param->input_unit_; | |||
| int in_batch = conv_param->input_batch_; | |||
| int in_channel = conv_param->input_channel_; | |||
| int out_unit = conv_param->output_unit_; | |||
| int out_w_block = UP_DIV(conv_param->output_w_, out_unit); | |||
| int out_h_block = UP_DIV(conv_param->output_h_, out_unit); | |||
| int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); | |||
| int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); | |||
| int output_count = out_w_block * out_h_block; | |||
| const int tile_num = C12NUM; | |||
| int output_tile_count = UP_DIV(output_count, tile_num); | |||
| int out_channel = conv_param->output_channel_; | |||
| int oc8 = UP_DIV(out_channel, C8NUM); | |||
| int input_unit_square = input_unit * input_unit; | |||
| int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); | |||
| int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_; | |||
| float *trans_input = buffer_list[0]; | |||
| float *gemm_out = buffer_list[1]; | |||
| @@ -97,10 +82,10 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const | |||
| int col_buffer_offset = tile_num * in_channel; | |||
| // step 1 : filter transform (pre-processed offline) | |||
| // step 2 : input transform (online) | |||
| for (int b = 0; b < in_batch; b++) { | |||
| for (int b = 0; b < conv_param->input_batch_; b++) { | |||
| int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; | |||
| int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | |||
| int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||
| int out_tile_index = thread_id * tile_num; | |||
| int cal_num = output_count - out_tile_index; | |||
| cal_num = cal_num > tile_num ? tile_num : cal_num; | |||
| @@ -20,10 +20,6 @@ | |||
| int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id, float minf, | |||
| float maxf) { | |||
| int stride_w = pooling_param->stride_w_; | |||
| int stride_h = pooling_param->stride_h_; | |||
| int pad_w = pooling_param->pad_l_; | |||
| int pad_h = pooling_param->pad_u_; | |||
| int win_w = pooling_param->window_w_; | |||
| int win_h = pooling_param->window_h_; | |||
| int channel = pooling_param->input_channel_; | |||
| @@ -32,10 +28,8 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool | |||
| int in_h = pooling_param->input_h_; | |||
| int output_w = pooling_param->output_w_; | |||
| int output_h = pooling_param->output_h_; | |||
| int output_batch = pooling_param->output_batch_; | |||
| int out_plane = output_w * output_h; | |||
| int out_tile_count = UP_DIV(out_plane, TILE_NUM); | |||
| int thread_num = pooling_param->thread_num_; | |||
| int window = win_w * win_h; | |||
| #ifdef ENABLE_NEON | |||
| @@ -43,18 +37,18 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool | |||
| float32x4_t max_value = vdupq_n_f32(maxf); | |||
| #endif | |||
| for (int batch = 0; batch < output_batch; batch++) { | |||
| for (int batch = 0; batch < pooling_param->output_batch_; batch++) { | |||
| const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; | |||
| float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; | |||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { | |||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) { | |||
| int cal_start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); | |||
| for (int i = 0; i < real_cal_num; i++) { | |||
| int index = cal_start_index + i; | |||
| int out_w_index = index % output_w; | |||
| int out_h_index = index / output_w; | |||
| int in_w_index = out_w_index * stride_w - pad_w; | |||
| int in_h_index = out_h_index * stride_h - pad_h; | |||
| int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; | |||
| int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; | |||
| const float *src_plane_ptr = src_b_ptr; | |||
| float *dst_plane_ptr = dst_b_ptr + index * channel; | |||
| @@ -152,10 +146,6 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool | |||
| void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id, float minf, | |||
| float maxf) { | |||
| int stride_w = pooling_param->stride_w_; | |||
| int stride_h = pooling_param->stride_h_; | |||
| int pad_w = pooling_param->pad_l_; | |||
| int pad_h = pooling_param->pad_u_; | |||
| int win_w = pooling_param->window_w_; | |||
| int win_h = pooling_param->window_h_; | |||
| int channel = pooling_param->input_channel_; | |||
| @@ -166,7 +156,6 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo | |||
| int output_batch = pooling_param->output_batch_; | |||
| int out_plane = output_w * output_h; | |||
| int out_tile_count = UP_DIV(out_plane, TILE_NUM); | |||
| int thread_num = pooling_param->thread_num_; | |||
| int c4 = channel / C4NUM; /* oc && ic */ | |||
| #ifdef ENABLE_NEON | |||
| @@ -177,15 +166,15 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo | |||
| for (int batch = 0; batch < output_batch; batch++) { | |||
| const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; | |||
| float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; | |||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { | |||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) { | |||
| int cal_start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); | |||
| for (int i = 0; i < real_cal_num; i++) { | |||
| int index = cal_start_index + i; | |||
| int out_w_index = index % output_w; | |||
| int out_h_index = index / output_w; | |||
| int in_w_index = out_w_index * stride_w - pad_w; | |||
| int in_h_index = out_h_index * stride_h - pad_h; | |||
| int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; | |||
| int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; | |||
| const float *src_plane_ptr = src_b_ptr; | |||
| float *dst_plane_ptr = dst_b_ptr + index * channel; | |||
| @@ -65,20 +65,12 @@ void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, in | |||
| void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, | |||
| const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, | |||
| ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) { | |||
| int kernel_h = conv_param->kernel_h_; | |||
| int kernel_w = conv_param->kernel_w_; | |||
| int in_batch = conv_param->input_batch_; | |||
| int in_channel = conv_param->input_channel_; | |||
| int in_h = conv_param->input_h_; | |||
| int in_w = conv_param->input_w_; | |||
| int out_h = conv_param->output_h_; | |||
| int out_w = conv_param->output_w_; | |||
| int out_channel = conv_param->output_channel_; | |||
| int tile_n = conv_param->tile_num_; | |||
| int thread_count = conv_param->thread_num_; | |||
| int output_count = out_h * out_w; | |||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | |||
| int output_tile_count = UP_DIV(output_count, tile_n); | |||
| int kernel_plane = kernel_h * kernel_w; | |||
| int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; | |||
| int unit_size; | |||
| int input_sum_offset; | |||
| int up_round_oc; | |||
| @@ -103,10 +95,10 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in | |||
| per_channel = false; | |||
| } | |||
| for (int b = 0; b < in_batch; b++) { | |||
| int in_batch_offset = b * in_channel * in_h * in_w; | |||
| int out_batch_offset = b * out_channel * out_h * out_w; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| for (int b = 0; b < conv_param->input_batch_; b++) { | |||
| int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; | |||
| int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||
| int start_index = thread_id * tile_n; | |||
| int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; | |||
| int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset; | |||
| @@ -858,23 +850,20 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t | |||
| void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, | |||
| int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, | |||
| int task_id, ConvParameter *conv_param) { | |||
| int thread_count = conv_param->thread_num_; | |||
| int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); | |||
| int output_channel = conv_param->output_channel_; | |||
| int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); | |||
| int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); | |||
| int output_count = out_w_block * out_h_block; | |||
| int output_tile_count = UP_DIV(output_count, TILE_NUM); | |||
| int oc4 = UP_DIV(output_channel, C4NUM); | |||
| int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); | |||
| int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; | |||
| const int block_unit_buffer_offset = 16 * C8NUM; | |||
| int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; | |||
| int input_batch = conv_param->input_batch_; | |||
| for (int batch = 0; batch < input_batch; batch++) { | |||
| for (int batch = 0; batch < conv_param->input_batch_; batch++) { | |||
| int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; | |||
| int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||
| int start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; | |||
| @@ -883,7 +872,7 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi | |||
| out_w_block, conv_param); | |||
| Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, | |||
| transed_weight, output_channel, ic8, real_cal_num); | |||
| transed_weight, conv_param->output_channel_, ic8, real_cal_num); | |||
| Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, | |||
| bias_data, start_index, real_cal_num, out_w_block, conv_param); | |||
| @@ -24,15 +24,13 @@ | |||
| #ifdef ENABLE_NEON | |||
| int16x4_t ClacSumHalfWordMul(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, | |||
| int32x4_t output_multiplier_vec, MulQuantArg para) { | |||
| int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); | |||
| int32x4_t raw_sum = RoundingDivideByPOTInt32x4( | |||
| SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), | |||
| para.shift_right_); | |||
| raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para.out_quant_arg_.zp_)); | |||
| raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); | |||
| raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); | |||
| int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec, | |||
| int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) { | |||
| int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1); | |||
| int32x4_t raw_sum = vqrdmulhq_s32(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec); | |||
| const int32x4_t fixup = vshrq_n_s32(vandq_s32(raw_sum, right_shift_out_vec), 31); | |||
| const int32x4_t fixed_up_x = vqaddq_s32(raw_sum, fixup); | |||
| raw_sum = vrshlq_s32(fixed_up_x, right_shift_out_vec); | |||
| return vqmovn_s32(raw_sum); | |||
| } | |||
| @@ -40,27 +38,189 @@ void MulInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, | |||
| MulQuantArg para, int *index) { | |||
| int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); | |||
| int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_); | |||
| int32x4_t right_shift_out_vec = vdupq_n_s32(-para.shift_right_); | |||
| int16x8_t out_zp_vec = vdupq_n_s16(para.out_quant_arg_.zp_); | |||
| int8x16_t out_min_vec = vdupq_n_s8(para.output_activation_min_); | |||
| int8x16_t out_max_vec = vdupq_n_s8(para.output_activation_max_); | |||
| int8x8_t out_min_vec_s8 = vdup_n_s8(para.output_activation_min_); | |||
| int8x8_t out_max_vec_s8 = vdup_n_s8(para.output_activation_max_); | |||
| for (; (*index) <= real_dst_count - 16; (*index) += 16) { | |||
| int16x8_t zp1_vec = vdupq_n_s16(para.in_quant_args_[0].zp_); | |||
| int16x8_t zp2_vec = vdupq_n_s16(para.in_quant_args_[1].zp_); | |||
| int8x16_t input0_vec = vld1q_s8(input0_data + *index); | |||
| int8x16_t input1_vec = vld1q_s8(input1_data + *index); | |||
| int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); | |||
| int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); | |||
| int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); | |||
| int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); | |||
| input0_low = vaddq_s16(input0_low, zp1_vec); | |||
| input0_high = vaddq_s16(input0_high, zp1_vec); | |||
| input1_low = vaddq_s16(input1_low, zp2_vec); | |||
| input1_high = vaddq_s16(input1_high, zp2_vec); | |||
| int16x4_t input0_low_low = vget_low_s16(input0_low); | |||
| int16x4_t input0_low_high = vget_high_s16(input0_low); | |||
| int16x4_t input0_high_low = vget_low_s16(input0_high); | |||
| int16x4_t input0_high_high = vget_high_s16(input0_high); | |||
| int16x4_t input1_low_low = vget_low_s16(input1_low); | |||
| int16x4_t input1_low_high = vget_high_s16(input1_low); | |||
| int16x4_t input1_high_low = vget_low_s16(input1_high); | |||
| int16x4_t input1_high_high = vget_high_s16(input1_high); | |||
| int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, right_shift_out_vec, | |||
| output_multiplier_vec); | |||
| int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, | |||
| right_shift_out_vec, output_multiplier_vec); | |||
| int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, | |||
| right_shift_out_vec, output_multiplier_vec); | |||
| int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, | |||
| right_shift_out_vec, output_multiplier_vec); | |||
| int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); | |||
| int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); | |||
| int8x8_t res_u8_n0 = vqmovn_s16(res_s16); | |||
| int8x8_t res_u8_n1 = vqmovn_s16(res_s162); | |||
| int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); | |||
| res_s8 = vminq_s8(res_s8, out_max_vec); | |||
| res_s8 = vmaxq_s8(res_s8, out_min_vec); | |||
| vst1q_s8(output_data, res_s8); | |||
| output_data += 16; | |||
| } | |||
| for (; (*index) <= real_dst_count - 8; (*index) += 8) { | |||
| int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para.in_quant_args_[0].zp_); | |||
| int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para.in_quant_args_[1].zp_); | |||
| int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); | |||
| int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); | |||
| int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); | |||
| int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); | |||
| int16x4_t input0_low = vget_low_s16(input0_val); | |||
| int16x4_t input0_high = vget_high_s16(input0_val); | |||
| int16x4_t input1_low = vget_low_s16(input1_val); | |||
| int16x4_t input1_high = vget_high_s16(input1_val); | |||
| int16x4_t sum_low = ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, para); | |||
| int16x4_t sum_high = ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, para); | |||
| int16x4_t sum_low = | |||
| ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); | |||
| int16x4_t sum_high = | |||
| ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); | |||
| int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); | |||
| int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); | |||
| int8x8_t res_u8_n0 = vqmovn_s16(res_s16); | |||
| res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); | |||
| res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); | |||
| vst1_s8(output_data, res_u8_n0); | |||
| output_data += 8; | |||
| } | |||
| } | |||
| #endif | |||
| void FastMul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int depth, int64_t real_dst_count, | |||
| bool input1_broad, MulQuantArg para) { | |||
| // input0 need broadcast | |||
| int32_t zp1 = para.in_quant_args_[0].zp_; | |||
| int32_t zp2 = para.in_quant_args_[1].zp_; | |||
| if (input1_broad) { | |||
| zp1 = para.in_quant_args_[1].zp_; | |||
| zp2 = para.in_quant_args_[0].zp_; | |||
| } | |||
| #ifdef ENABLE_ARM | |||
| int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); | |||
| int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_); | |||
| int32x4_t right_shift_out_vec = vdupq_n_s32(-para.shift_right_); | |||
| int16x8_t out_zp_vec = vdupq_n_s16(para.out_quant_arg_.zp_); | |||
| int8x16_t out_min_vec = vdupq_n_s8(para.output_activation_min_); | |||
| int8x16_t out_max_vec = vdupq_n_s8(para.output_activation_max_); | |||
| int8x8_t out_min_vec_s8 = vdup_n_s8(para.output_activation_min_); | |||
| int8x8_t out_max_vec_s8 = vdup_n_s8(para.output_activation_max_); | |||
| int16x8_t zp1_vec = vdupq_n_s16(zp1); | |||
| int16x8_t zp2_vec = vdupq_n_s16(zp2); | |||
| #endif | |||
| for (int index = 0; index < real_dst_count; ++index) { | |||
| int j = 0; | |||
| #ifdef ENABLE_ARM | |||
| for (; j <= depth - 16; j += 16) { | |||
| int8x16_t input0_vec = vld1q_s8(input0_data + j); | |||
| int8x16_t input1_vec = vld1q_s8(input1_data); | |||
| int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); | |||
| int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); | |||
| int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); | |||
| int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); | |||
| input0_low = vaddq_s16(input0_low, zp1_vec); | |||
| input0_high = vaddq_s16(input0_high, zp1_vec); | |||
| input1_low = vaddq_s16(input1_low, zp2_vec); | |||
| input1_high = vaddq_s16(input1_high, zp2_vec); | |||
| int16x4_t input0_low_low = vget_low_s16(input0_low); | |||
| int16x4_t input0_low_high = vget_high_s16(input0_low); | |||
| int16x4_t input0_high_low = vget_low_s16(input0_high); | |||
| int16x4_t input0_high_high = vget_high_s16(input0_high); | |||
| int16x4_t input1_low_low = vget_low_s16(input1_low); | |||
| int16x4_t input1_low_high = vget_high_s16(input1_low); | |||
| int16x4_t input1_high_low = vget_low_s16(input1_high); | |||
| int16x4_t input1_high_high = vget_high_s16(input1_high); | |||
| int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, | |||
| right_shift_out_vec, output_multiplier_vec); | |||
| int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, | |||
| right_shift_out_vec, output_multiplier_vec); | |||
| int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, | |||
| right_shift_out_vec, output_multiplier_vec); | |||
| int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, | |||
| right_shift_out_vec, output_multiplier_vec); | |||
| int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); | |||
| int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); | |||
| int8x8_t res_u8_n0 = vqmovn_s16(res_s16); | |||
| int8x8_t res_u8_n1 = vqmovn_s16(res_s162); | |||
| int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); | |||
| res_s8 = vminq_s8(res_s8, out_max_vec); | |||
| res_s8 = vmaxq_s8(res_s8, out_min_vec); | |||
| vst1q_s8(output_data, res_s8); | |||
| input1_data += 16; | |||
| output_data += 16; | |||
| } | |||
| for (; j <= depth - 8; j += 8) { | |||
| int8x8_t input0_vec = vld1_s8(input0_data + j); | |||
| int8x8_t input1_vec = vld1_s8(input1_data); | |||
| int16x8_t input0_val = vmovl_s8(input0_vec); | |||
| int16x8_t input1_val = vmovl_s8(input1_vec); | |||
| input0_val = vaddq_s16(input0_val, zp1_vec); | |||
| input1_val = vaddq_s16(input1_val, zp2_vec); | |||
| int16x4_t input0_low = vget_low_s16(input0_val); | |||
| int16x4_t input0_high = vget_high_s16(input0_val); | |||
| int16x4_t input1_low = vget_low_s16(input1_val); | |||
| int16x4_t input1_high = vget_high_s16(input1_val); | |||
| int16x4_t sum_low = | |||
| ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); | |||
| int16x4_t sum_high = | |||
| ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); | |||
| int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); | |||
| int8x8_t res_u8_n0 = vqmovn_s16(res_s16); | |||
| res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); | |||
| res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); | |||
| vst1_s8(output_data, res_u8_n0); | |||
| input1_data += 8; | |||
| output_data += 8; | |||
| } | |||
| #endif | |||
| for (; j < depth; ++j) { | |||
| const int32_t input0_val = zp1 + input0_data[j]; | |||
| const int32_t input1_val = zp2 + input1_data[0]; | |||
| int32_t mul_result = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << para.shift_left_), para.output_multiplier_), | |||
| para.shift_right_); | |||
| mul_result += para.out_quant_arg_.zp_; | |||
| mul_result = mul_result < para.output_activation_max_ ? mul_result : para.output_activation_max_; | |||
| mul_result = mul_result > para.output_activation_min_ ? mul_result : para.output_activation_min_; | |||
| output_data[0] = (int8_t)mul_result; | |||
| input1_data++; | |||
| output_data++; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para) { | |||
| int index = 0; | |||
| #ifdef ENABLE_NEON | |||
| @@ -74,14 +234,10 @@ void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t | |||
| para.shift_right_); | |||
| mul_result += para.out_quant_arg_.zp_; | |||
| if (mul_result > para.output_activation_max_) { | |||
| output_data[index] = para.output_activation_max_; | |||
| } else if (mul_result < para.output_activation_min_) { | |||
| output_data[index] = para.output_activation_min_; | |||
| } else { | |||
| output_data[index] = (int8_t)mul_result; | |||
| } | |||
| mul_result = mul_result < para.output_activation_max_ ? mul_result : para.output_activation_max_; | |||
| mul_result = mul_result > para.output_activation_min_ ? mul_result : para.output_activation_min_; | |||
| output_data[0] = (int8_t)mul_result; | |||
| output_data++; | |||
| } | |||
| return; | |||
| } | |||
| @@ -24,6 +24,8 @@ | |||
| extern "C" { | |||
| #endif | |||
| void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para); | |||
| void FastMul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int depth, int64_t real_dst_count, | |||
| bool input1_broad, MulQuantArg para); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -80,33 +80,24 @@ int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter | |||
| } | |||
| int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { | |||
| int stride_w = pooling_param->stride_w_; | |||
| int stride_h = pooling_param->stride_h_; | |||
| int pad_w = pooling_param->pad_l_; | |||
| int pad_h = pooling_param->pad_u_; | |||
| int win_w = pooling_param->window_w_; | |||
| int win_h = pooling_param->window_h_; | |||
| int channel = pooling_param->input_channel_; | |||
| int c16 = channel / C16NUM; | |||
| int in_w = pooling_param->input_w_; | |||
| int in_h = pooling_param->input_h_; | |||
| int output_w = pooling_param->output_w_; | |||
| int output_h = pooling_param->output_h_; | |||
| int output_batch = pooling_param->output_batch_; | |||
| int out_plane = output_w * output_h; | |||
| int out_plane = output_w * pooling_param->output_h_; | |||
| int out_tile_count = UP_DIV(out_plane, TILE_NUM); | |||
| int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_; | |||
| float input_scale = pooling_param->quant_args_[0][0].scale_; | |||
| int input_zp = pooling_param->quant_args_[0][0].zp_; | |||
| float output_scale = pooling_param->quant_args_[1][0].scale_; | |||
| int output_zp = pooling_param->quant_args_[1][0].zp_; | |||
| double real_multiplier = input_scale / output_scale; | |||
| int c16 = channel / C16NUM; | |||
| double real_multiplier = pooling_param->quant_args_[0][0].scale_ / pooling_param->quant_args_[1][0].scale_; | |||
| const int8_t out_min = INT8_MIN; | |||
| const int8_t out_max = INT8_MAX; | |||
| for (int batch = 0; batch < output_batch; batch++) { | |||
| int in_batch_offset = batch * in_h * in_w * channel; | |||
| int out_batch_offset = batch * output_h * output_w * channel; | |||
| for (int batch = 0; batch < pooling_param->output_batch_; batch++) { | |||
| int in_batch_offset = batch * pooling_param->input_h_ * in_w * channel; | |||
| int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel; | |||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { | |||
| int cal_start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); | |||
| @@ -114,14 +105,14 @@ int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParame | |||
| int index = cal_start_index + i; | |||
| int out_w_index = index % output_w; | |||
| int out_h_index = index / output_w; | |||
| int in_w_index = out_w_index * stride_w - pad_w; | |||
| int in_h_index = out_h_index * stride_h - pad_h; | |||
| int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; | |||
| int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; | |||
| int out_plane_offset = out_batch_offset + index * channel; | |||
| int input_stride = (in_h_index * in_w + in_w_index) * channel; | |||
| int kw_s = MSMAX(0, -in_w_index); | |||
| int kw_e = MSMIN(win_w, in_w - in_w_index); | |||
| int kh_s = MSMAX(0, -in_h_index); | |||
| int kh_e = MSMIN(win_h, in_h - in_h_index); | |||
| int kh_e = MSMIN(win_h, pooling_param->input_h_ - in_h_index); | |||
| int real_count = (kw_e - kw_s) * (kh_e - kh_s); | |||
| if (real_count == 0) { | |||
| return NNACL_ERR; | |||
| @@ -335,19 +326,11 @@ void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete | |||
| void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, | |||
| int task_id) { | |||
| int stride_w = pooling_param->stride_w_; | |||
| int stride_h = pooling_param->stride_h_; | |||
| int pad_w = pooling_param->pad_l_; | |||
| int pad_h = pooling_param->pad_u_; | |||
| int win_w = pooling_param->window_w_; | |||
| int win_h = pooling_param->window_h_; | |||
| int channel = pooling_param->input_channel_; | |||
| int in_w = pooling_param->input_w_; | |||
| int in_h = pooling_param->input_h_; | |||
| int output_w = pooling_param->output_w_; | |||
| int output_h = pooling_param->output_h_; | |||
| int output_batch = pooling_param->output_batch_; | |||
| int out_plane = output_w * output_h; | |||
| int out_plane = output_w * pooling_param->output_h_; | |||
| int out_tile_count = UP_DIV(out_plane, TILE_NUM); | |||
| int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_; | |||
| int c16 = UP_DIV(channel, 16); | |||
| @@ -358,9 +341,9 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin | |||
| int output_zp = pooling_param->quant_args_[1][0].zp_; | |||
| double real_multiplier = input_scale / output_scale; | |||
| for (int batch = 0; batch < output_batch; batch++) { | |||
| for (int batch = 0; batch < pooling_param->output_batch_; batch++) { | |||
| int in_batch_offset = batch * in_h * in_w * channel; | |||
| int out_batch_offset = batch * output_h * output_w * channel; | |||
| int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel; | |||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { | |||
| int cal_start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); | |||
| @@ -368,8 +351,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin | |||
| int index = cal_start_index + i; | |||
| int out_w_index = index % output_w; | |||
| int out_h_index = index / output_w; | |||
| int in_w_index = out_w_index * stride_w - pad_w; | |||
| int in_h_index = out_h_index * stride_h - pad_h; | |||
| int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; | |||
| int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; | |||
| int out_plane_offset = out_batch_offset + index * channel; | |||
| for (int j = 0; j < c16 - 1; j++) { | |||
| int in_channel_offset = in_batch_offset + j * 16; | |||
| @@ -382,8 +365,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin | |||
| tmp_max[m] = INT8_MIN; | |||
| } | |||
| #endif | |||
| for (int h = 0; h < win_h; h++) { | |||
| for (int w = 0; w < win_w; w++) { | |||
| for (int h = 0; h < pooling_param->window_h_; h++) { | |||
| for (int w = 0; w < pooling_param->window_w_; w++) { | |||
| if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || | |||
| (in_w_index + w) >= in_w) { | |||
| continue; | |||
| @@ -418,8 +401,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin | |||
| int in_channel_offset = in_batch_offset + k; | |||
| int out_channel_offset = out_plane_offset + k; | |||
| int8_t tmp_max = INT8_MIN; | |||
| for (int h = 0; h < win_h; h++) { | |||
| for (int w = 0; w < win_w; w++) { | |||
| for (int h = 0; h < pooling_param->window_h_; h++) { | |||
| for (int w = 0; w < pooling_param->window_w_; w++) { | |||
| if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || | |||
| (in_w_index + w) >= in_w) { | |||
| continue; | |||
| @@ -437,26 +420,17 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin | |||
| } | |||
| void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { | |||
| int stride_w = pooling_param->stride_w_; | |||
| int stride_h = pooling_param->stride_h_; | |||
| int pad_w = pooling_param->pad_l_; | |||
| int pad_h = pooling_param->pad_u_; | |||
| int win_w = pooling_param->window_w_; | |||
| int win_h = pooling_param->window_h_; | |||
| int channel = pooling_param->input_channel_; | |||
| int in_w = pooling_param->input_w_; | |||
| int in_h = pooling_param->input_h_; | |||
| int output_w = pooling_param->output_w_; | |||
| int output_h = pooling_param->output_h_; | |||
| int output_batch = pooling_param->output_batch_; | |||
| int out_plane = output_w * output_h; | |||
| int out_plane = output_w * pooling_param->output_h_; | |||
| int out_tile_count = UP_DIV(out_plane, TILE_NUM); | |||
| int thread_num = MSMIN(out_tile_count, pooling_param->thread_num_); | |||
| int8_t out_array[MAX_MAXPOOL_SIZE]; | |||
| for (int batch = 0; batch < output_batch; batch++) { | |||
| int in_batch_offset = batch * in_h * in_w * channel; | |||
| int out_batch_offset = batch * output_h * output_w * channel; | |||
| for (int batch = 0; batch < pooling_param->output_batch_; batch++) { | |||
| int in_batch_offset = batch * pooling_param->input_h_ * in_w * channel; | |||
| int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel; | |||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { | |||
| int cal_start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = out_plane - cal_start_index; | |||
| @@ -465,12 +439,12 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam | |||
| int index = cal_start_index + i; | |||
| int out_w_index = index % output_w; | |||
| int out_h_index = index / output_w; | |||
| int in_w_index = out_w_index * stride_w - pad_w; | |||
| int in_h_index = out_h_index * stride_h - pad_h; | |||
| int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; | |||
| int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; | |||
| const int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index); | |||
| int ky_e = MSMIN(win_h, in_h - in_h_index); | |||
| int ky_e = MSMIN(pooling_param->window_h_, pooling_param->input_h_ - in_h_index); | |||
| const int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index); | |||
| int kx_e = MSMIN(win_w, in_w - in_w_index); | |||
| int kx_e = MSMIN(pooling_param->window_w_, in_w - in_w_index); | |||
| int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset; | |||
| int out_plane_offset = out_batch_offset + index * channel; | |||
| @@ -47,20 +47,17 @@ void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int *bloc | |||
| } | |||
| void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp) { | |||
| int *in_shape = param->input_shape_; | |||
| int *out_shape = param->output_shape_; | |||
| int *paddings = param->paddings_; | |||
| int block_shape_h = param->block_sizes_[0]; | |||
| int block_shape_w = param->m_ == 2 ? param->block_sizes_[1] : 1; | |||
| int in_b = in_shape[0]; | |||
| int in_h = in_shape[1]; | |||
| int in_w = in_shape[2]; | |||
| int channel = in_shape[3]; | |||
| int out_h = out_shape[1]; | |||
| int out_w = out_shape[2]; | |||
| int pad_t = paddings[0]; | |||
| int pad_l = param->m_ == 2 ? paddings[2] : 0; | |||
| for (int i = 0; i < out_shape[0]; ++i) { | |||
| int in_b = param->input_shape_[0]; | |||
| int in_h = param->input_shape_[1]; | |||
| int in_w = param->input_shape_[2]; | |||
| int channel = param->input_shape_[3]; | |||
| int out_h = param->output_shape_[1]; | |||
| int out_w = param->output_shape_[2]; | |||
| int pad_t = param->paddings_[0]; | |||
| int pad_l = param->m_ == 2 ? param->paddings_[2] : 0; | |||
| for (int i = 0; i < param->output_shape_[0]; ++i) { | |||
| int in_batch = i % in_b; | |||
| int offset_w = (i / in_b) % block_shape_w; | |||
| int offset_h = (i / in_b) / block_shape_w; | |||
| @@ -219,24 +219,19 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa | |||
| int kernel_h = conv_param->kernel_h_; | |||
| int kernel_w = conv_param->kernel_w_; | |||
| int kernel_plane = kernel_h * kernel_w; | |||
| int stride_h = conv_param->stride_h_; | |||
| int stride_w = conv_param->stride_w_; | |||
| int pad_h = conv_param->pad_u_; | |||
| int pad_w = conv_param->pad_l_; | |||
| int dilation_h = conv_param->dilation_h_; | |||
| int dilation_w = conv_param->dilation_w_; | |||
| int in_channel = conv_param->input_channel_; | |||
| int in_h = conv_param->input_h_; | |||
| int in_w = conv_param->input_w_; | |||
| int out_w = conv_param->output_w_; | |||
| for (int i = 0; i < real_cal_num; i++) { | |||
| int block_start = block_index + i; | |||
| int input_h = block_start / out_w * stride_h - pad_h; | |||
| int input_w = block_start % out_w * stride_w - pad_w; | |||
| int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; | |||
| int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; | |||
| int input_stride = (input_h * in_w + input_w) * in_channel; | |||
| int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); | |||
| int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); | |||
| int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); | |||
| int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); | |||
| int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); | |||
| if (dilation_w == 1 && dilation_h == 1) { | |||
| @@ -62,11 +62,46 @@ int MulInt8CPUKernel::Init() { | |||
| return ReSize(); | |||
| } | |||
| void MulInt8CPUKernel::CheckSameShapeSize(std::vector<int> in_tensor0_shape, std::vector<int> in_tensor1_shape) { | |||
| bool condition1 = in_tensor0_shape[0] == in_tensor1_shape[0]; | |||
| bool condition2 = in_tensor0_shape[1] == 1; | |||
| bool condition3 = in_tensor0_shape[2] == 1; | |||
| bool condition4 = in_tensor0_shape[3] == in_tensor1_shape[3]; | |||
| bool condition5 = in_tensor1_shape[1] == 1; | |||
| bool condition6 = in_tensor1_shape[2] == 1; | |||
| if (condition1 && condition2 && condition3 && condition4) { | |||
| fast_hw_broadcast_ = true; | |||
| } else if (condition1 && condition4 && condition5 && condition6) { | |||
| fast_hw_broadcast_ = true; | |||
| input1_hw_broadcast_ = true; | |||
| } | |||
| } | |||
| void MulInt8CPUKernel::CheckIfFastImpl() { | |||
| auto in_tensor0 = in_tensors_.at(0); | |||
| auto in_tensor1 = in_tensors_.at(1); | |||
| if (in_tensor0->ElementsNum() != in_tensor1->ElementsNum()) { | |||
| if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 4) { | |||
| CheckSameShapeSize(in_tensor0->shape(), in_tensor1->shape()); | |||
| } else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == 4) { | |||
| if (in_tensor0->ElementsNum() == in_tensor1->shape()[3]) { | |||
| fast_hw_broadcast_ = true; | |||
| } | |||
| } else if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 1) { | |||
| if (in_tensor1->ElementsNum() == in_tensor0->shape()[3]) { | |||
| fast_hw_broadcast_ = true; | |||
| input1_hw_broadcast_ = true; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| int MulInt8CPUKernel::ReSize() { | |||
| size_t input0_size = in_tensors_.at(0)->shape().size(); | |||
| size_t input1_size = in_tensors_.at(1)->shape().size(); | |||
| size_t output_size = out_tensors_.at(0)->shape().size(); | |||
| tile_para->ndim_ = output_size; | |||
| if (input0_size == input1_size) { | |||
| for (size_t i = 0; i < output_size; i++) { | |||
| tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i); | |||
| @@ -106,6 +141,14 @@ int MulInt8CPUKernel::Run() { | |||
| input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData()); | |||
| output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData()); | |||
| CheckIfFastImpl(); | |||
| // can implement fast broadcast mul | |||
| if (fast_hw_broadcast_) { | |||
| elements_num_ = out_tensors_.front()->Batch() * out_tensors_.front()->Height() * out_tensors_.front()->Width(); | |||
| count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; | |||
| return ParallelLaunch(this->context_->thread_pool_, FastHWBroadcatMulInt8Run, this, thread_count_); | |||
| } | |||
| elements_num_ = out_tensors_.at(0)->ElementsNum(); | |||
| count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; | |||
| if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) { | |||
| @@ -132,12 +175,36 @@ int MulInt8CPUKernel::Run() { | |||
| return ret; | |||
| } | |||
| int FastHWBroadcatMulInt8Run(void *cdata, int task_id) { | |||
| auto mul = reinterpret_cast<MulInt8CPUKernel *>(cdata); | |||
| mul->FastDoExecute(task_id); | |||
| return lite::RET_OK; | |||
| } | |||
| int MulInt8Run(void *cdata, int task_id) { | |||
| auto mul = reinterpret_cast<MulInt8CPUKernel *>(cdata); | |||
| mul->DoExecute(task_id); | |||
| return lite::RET_OK; | |||
| } | |||
| int MulInt8CPUKernel::FastDoExecute(int task_id) { | |||
| int depth = out_tensors_.front()->Channel(); | |||
| int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_); | |||
| if (real_dst_count <= 0) { | |||
| return lite::RET_OK; | |||
| } | |||
| int8_t *cur_input0_data = input0_data_; | |||
| int8_t *cur_input1_data = input1_data_ + task_id * count_unit_ * depth; | |||
| int8_t *cur_output_data = output_data_ + task_id * count_unit_ * depth; | |||
| if (input1_hw_broadcast_) { | |||
| cur_input0_data = input1_data_; | |||
| cur_input1_data = input0_data_ + task_id * count_unit_ * depth; | |||
| } | |||
| FastMul(cur_input0_data, cur_input1_data, cur_output_data, depth, real_dst_count, input1_hw_broadcast_, | |||
| para_.mul_quant_arg_); | |||
| return RET_OK; | |||
| } | |||
| int MulInt8CPUKernel::DoExecute(int task_id) { | |||
| int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_); | |||
| if (real_dst_count <= 0) { | |||
| @@ -35,13 +35,18 @@ class MulInt8CPUKernel : public LiteKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| void CheckSameShapeSize(std::vector<int> in_tensor0_shape, std::vector<int> in_tensor1_shape); | |||
| void CheckIfFastImpl(); | |||
| int Run() override; | |||
| int DoExecute(int task_id); | |||
| int FastDoExecute(int task_id); | |||
| private: | |||
| const lite::InnerContext *ctx_ = nullptr; | |||
| ArithmeticParameter *tile_para = nullptr; | |||
| MulParameter para_; | |||
| bool fast_hw_broadcast_ = false; | |||
| bool input1_hw_broadcast_ = false; | |||
| int thread_count_ = 1; | |||
| int64_t elements_num_ = 0; | |||
| int64_t count_unit_ = 0; | |||
| @@ -51,6 +56,7 @@ class MulInt8CPUKernel : public LiteKernel { | |||
| }; | |||
| int MulInt8Run(void *cdata, int task_id); | |||
| int FastHWBroadcatMulInt8Run(void *cdata, int task_id); | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_ | |||