| @@ -235,10 +235,13 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten | |||||
| conv_param->input_w_ = inputs.front()->Width(); | conv_param->input_w_ = inputs.front()->Width(); | ||||
| conv_param->output_h_ = outputs.front()->Height(); | conv_param->output_h_ = outputs.front()->Height(); | ||||
| conv_param->output_w_ = outputs.front()->Width(); | conv_param->output_w_ = outputs.front()->Width(); | ||||
| bool prefer_flag = false; | |||||
| if (conv_param->output_h_ * conv_param->output_w_ > 64) { | |||||
| prefer_flag = true; | |||||
| } | |||||
| kernel::LiteKernel *kernel = nullptr; | kernel::LiteKernel *kernel = nullptr; | ||||
| if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { | |||||
| kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } else if (kernel_h == 1 && kernel_w == 1) { | |||||
| if (kernel_h == 1 && kernel_w == 1) { | |||||
| kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| } else { | } else { | ||||
| bool use_winograd = false; | bool use_winograd = false; | ||||
| @@ -249,6 +252,9 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten | |||||
| if (use_winograd) { | if (use_winograd) { | ||||
| kernel = new (std::nothrow) | kernel = new (std::nothrow) | ||||
| kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit); | kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit); | ||||
| } else if (prefer_flag && kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && | |||||
| dilation_w == 1) { | |||||
| kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | } | ||||
| if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { | if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { | ||||
| kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| @@ -115,7 +115,7 @@ int ConvolutionInt8CPUKernel::InitWeightBias() { | |||||
| size_t input_sum_size; | size_t input_sum_size; | ||||
| if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { | if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { | ||||
| input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t); | |||||
| input_sum_size = oc4 * C4NUM * tile_num_ * thread_count_ * sizeof(int32_t); | |||||
| } else { | } else { | ||||
| input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t); | input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t); | ||||
| } | } | ||||
| @@ -202,7 +202,7 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { | |||||
| size_t input_sum_size; | size_t input_sum_size; | ||||
| if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { | if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { | ||||
| input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t); | |||||
| input_sum_size = oc4 * C4NUM * tile_num_ * thread_count_ * sizeof(int32_t); | |||||
| } else { | } else { | ||||
| input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t); | input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t); | ||||
| } | } | ||||
| @@ -169,9 +169,9 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w | |||||
| const int32_t *bias = bias_data + oc * C4NUM; | const int32_t *bias = bias_data + oc * C4NUM; | ||||
| if (per_channel) { | if (per_channel) { | ||||
| out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc; | |||||
| left_shift = conv_param->conv_quant_arg_.left_shift_ + oc; | |||||
| right_shift = conv_param->conv_quant_arg_.right_shift_ + oc; | |||||
| out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C4NUM; | |||||
| left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C4NUM; | |||||
| right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C4NUM; | |||||
| } | } | ||||
| DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, | DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, | ||||
| @@ -28,6 +28,7 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in | |||||
| int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | ||||
| int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; | int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; | ||||
| int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; | int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; | ||||
| int oc4 = UP_DIV(output_channel, C4NUM); | |||||
| #ifdef __aarch64__ | #ifdef __aarch64__ | ||||
| IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel, | IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel, | ||||
| output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier, | output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier, | ||||
| @@ -96,7 +97,7 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in | |||||
| dst[dst_tile_offset] = (int8_t)result; | dst[dst_tile_offset] = (int8_t)result; | ||||
| } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | ||||
| (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | ||||
| tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc]; | |||||
| tmp_dst[dst_tile_offset] -= input_sum[n * oc4 * C4NUM + oc]; | |||||
| int result = tmp_dst[dst_tile_offset] + bias[oc]; | int result = tmp_dst[dst_tile_offset] + bias[oc]; | ||||
| result = RoundingDivideByPOT( | result = RoundingDivideByPOT( | ||||
| SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), | SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), | ||||
| @@ -120,6 +121,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const | |||||
| int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | ||||
| int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; | int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; | ||||
| int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; | int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; | ||||
| int oc4 = UP_DIV(output_channel, C4NUM); | |||||
| if (gemm_func != NULL) { | if (gemm_func != NULL) { | ||||
| #ifdef __aarch64__ | #ifdef __aarch64__ | ||||
| gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum, | gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum, | ||||
| @@ -181,7 +183,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const | |||||
| dst[dst_tile_offset] = (int8_t)result; | dst[dst_tile_offset] = (int8_t)result; | ||||
| } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | ||||
| (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | ||||
| tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc]; | |||||
| tmp_dst[dst_tile_offset] -= input_sum[n * oc4 * C4NUM + oc]; | |||||
| int result = tmp_dst[dst_tile_offset] + bias[oc]; | int result = tmp_dst[dst_tile_offset] + bias[oc]; | ||||
| result = RoundingDivideByPOT( | result = RoundingDivideByPOT( | ||||
| SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), | SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]), | ||||
| @@ -252,6 +254,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c | |||||
| int out_h = conv_param->output_h_; | int out_h = conv_param->output_h_; | ||||
| int out_w = conv_param->output_w_; | int out_w = conv_param->output_w_; | ||||
| int out_channel = conv_param->output_channel_; | int out_channel = conv_param->output_channel_; | ||||
| int oc4 = UP_DIV(out_channel, C4NUM); | |||||
| int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; | int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; | ||||
| int tile_n = conv_param->tile_num_; | int tile_n = conv_param->tile_num_; | ||||
| @@ -264,7 +267,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c | |||||
| int packed_input_size = output_tile_count * tile_n * unit_size; | int packed_input_size = output_tile_count * tile_n * unit_size; | ||||
| int input_sum_offset; | int input_sum_offset; | ||||
| if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { | if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { | ||||
| input_sum_offset = tile_n * out_channel; | |||||
| input_sum_offset = tile_n * oc4 * C4NUM; | |||||
| } else { | } else { | ||||
| input_sum_offset = tile_n; | input_sum_offset = tile_n; | ||||
| } | } | ||||
| @@ -314,6 +317,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight | |||||
| int out_h = conv_param->output_h_; | int out_h = conv_param->output_h_; | ||||
| int out_w = conv_param->output_w_; | int out_w = conv_param->output_w_; | ||||
| int out_channel = conv_param->output_channel_; | int out_channel = conv_param->output_channel_; | ||||
| int oc4 = UP_DIV(out_channel, C4NUM); | |||||
| int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; | int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; | ||||
| int tile_n = conv_param->tile_num_; | int tile_n = conv_param->tile_num_; | ||||
| int thread_count = conv_param->thread_num_; | int thread_count = conv_param->thread_num_; | ||||
| @@ -325,7 +329,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight | |||||
| int packed_input_size = output_tile_count * tile_n * unit_size; | int packed_input_size = output_tile_count * tile_n * unit_size; | ||||
| int input_sum_offset; | int input_sum_offset; | ||||
| if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { | if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { | ||||
| input_sum_offset = tile_n * out_channel; | |||||
| input_sum_offset = tile_n * oc4 * C4NUM; | |||||
| } else { | } else { | ||||
| input_sum_offset = tile_n; | input_sum_offset = tile_n; | ||||
| } | } | ||||
| @@ -255,6 +255,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real | |||||
| int in_h = conv_param->input_h_; | int in_h = conv_param->input_h_; | ||||
| int in_w = conv_param->input_w_; | int in_w = conv_param->input_w_; | ||||
| int ic4 = UP_DIV(in_channel, C4NUM); | int ic4 = UP_DIV(in_channel, C4NUM); | ||||
| int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); | |||||
| int out_w = conv_param->output_w_; | int out_w = conv_param->output_w_; | ||||
| for (int i = 0; i < real_cal_num; i++) { | for (int i = 0; i < real_cal_num; i++) { | ||||
| @@ -297,7 +298,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real | |||||
| continue; | continue; | ||||
| } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | ||||
| (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | ||||
| int cal_num_offset = i * conv_param->output_channel_; | |||||
| int cal_num_offset = i * oc4 * C4NUM; | |||||
| for (int l = 0; l < conv_param->output_channel_; ++l) { | for (int l = 0; l < conv_param->output_channel_; ++l) { | ||||
| input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_; | input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_; | ||||
| } | } | ||||
| @@ -325,6 +326,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r | |||||
| int in_h = conv_param->input_h_; | int in_h = conv_param->input_h_; | ||||
| int in_w = conv_param->input_w_; | int in_w = conv_param->input_w_; | ||||
| int ic4 = UP_DIV(in_channel, C4NUM); | int ic4 = UP_DIV(in_channel, C4NUM); | ||||
| int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); | |||||
| int out_w = conv_param->output_w_; | int out_w = conv_param->output_w_; | ||||
| int block_size = kernel_h * kernel_w; | int block_size = kernel_h * kernel_w; | ||||
| @@ -368,7 +370,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r | |||||
| continue; | continue; | ||||
| } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) && | ||||
| (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | ||||
| int cal_num_offset = i * conv_param->output_channel_; | |||||
| int cal_num_offset = i * oc4 * C4NUM; | |||||
| for (int l = 0; l < conv_param->output_channel_; ++l) { | for (int l = 0; l < conv_param->output_channel_; ++l) { | ||||
| input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_; | input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_; | ||||
| } | } | ||||