| @@ -369,9 +369,10 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||||
| out_channel * sizeof(float16_t), 0, 0, relu, relu6); | out_channel * sizeof(float16_t), 0, 0, relu, relu6); | ||||
| } else { | } else { | ||||
| // res part | // res part | ||||
| IndirectGemmFp16_16x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, | |||||
| float16_t *tmp_out_ptr = tmp_out_block + task_id * tile_n * out_channel; | |||||
| IndirectGemmFp16_16x8(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, | |||||
| out_channel * sizeof(float16_t), 0, 0, relu, relu6); | out_channel * sizeof(float16_t), 0, 0, relu, relu6); | ||||
| memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float16_t)); | |||||
| memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float16_t)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -248,9 +248,10 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons | |||||
| relu, relu6); | relu, relu6); | ||||
| } else { | } else { | ||||
| // res part | // res part | ||||
| gemm_func(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, | |||||
| 0, relu, relu6); | |||||
| memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float)); | |||||
| float *tmp_out_ptr = tmp_out_block + task_id * TILE_NUM * out_channel; | |||||
| gemm_func(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0, | |||||
| relu, relu6); | |||||
| memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -199,7 +199,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const | |||||
| } | } | ||||
| } | } | ||||
| void Conv3x3Uint8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { | |||||
| void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { | |||||
| int oc4 = UP_DIV(oc, C4NUM); | int oc4 = UP_DIV(oc, C4NUM); | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); | IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); | ||||
| @@ -298,9 +298,10 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c | |||||
| out_channel, tmp_input_sum, conv_param); | out_channel, tmp_input_sum, conv_param); | ||||
| } else { | } else { | ||||
| // res part | // res part | ||||
| IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, | |||||
| int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel; | |||||
| IndirectGemmInt8(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, | |||||
| out_channel, tmp_input_sum, conv_param); | out_channel, tmp_input_sum, conv_param); | ||||
| memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); | |||||
| memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -360,9 +361,10 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight | |||||
| kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func); | kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func); | ||||
| } else { | } else { | ||||
| // res part | // res part | ||||
| IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, | |||||
| out_channel, tmp_input_sum, conv_param, gemm_func); | |||||
| memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); | |||||
| int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel; | |||||
| IndirectGemmInt8Opt(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, | |||||
| kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func); | |||||
| memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -412,15 +414,15 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi | |||||
| int start_index = thread_id * TILE_NUM; | int start_index = thread_id * TILE_NUM; | ||||
| int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; | int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; | ||||
| Conv3x3Uint8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, | |||||
| block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, | |||||
| out_w_block, conv_param); | |||||
| Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, | |||||
| block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, | |||||
| out_w_block, conv_param); | |||||
| Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, | |||||
| transed_weight, output_channel, ic8, real_cal_num); | |||||
| Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, | |||||
| transed_weight, output_channel, ic8, real_cal_num); | |||||
| Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, | |||||
| bias_data, start_index, real_cal_num, out_w_block, conv_param); | |||||
| Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, | |||||
| bias_data, start_index, real_cal_num, out_w_block, conv_param); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -695,7 +695,7 @@ void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const fl | |||||
| } | } | ||||
| // int8 conv3x3 | // int8 conv3x3 | ||||
| void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { | |||||
| void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { | |||||
| #ifdef ENABLE_ARM | #ifdef ENABLE_ARM | ||||
| int16x8_t zp = vdupq_n_s16(input_zp); | int16x8_t zp = vdupq_n_s16(input_zp); | ||||
| @@ -864,7 +864,7 @@ void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t | |||||
| #endif | #endif | ||||
| } | } | ||||
| void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, | |||||
| void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, | |||||
| int real_cal_num, int out_w_block, ConvParameter *conv_param) { | int real_cal_num, int out_w_block, ConvParameter *conv_param) { | ||||
| // input data format : nhwc | // input data format : nhwc | ||||
| int input_channel = conv_param->input_channel_; | int input_channel = conv_param->input_channel_; | ||||
| @@ -904,7 +904,7 @@ void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, | |||||
| int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; | int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; | ||||
| size_t dst_step = ic8 * C8NUM * TILE_NUM; | size_t dst_step = ic8 * C8NUM * TILE_NUM; | ||||
| int16_t *trans_input_ptr = trans_input + dst_ic8_offset; | int16_t *trans_input_ptr = trans_input + dst_ic8_offset; | ||||
| Conv3x3Uint8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); | |||||
| Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1175,7 +1175,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh | |||||
| } | } | ||||
| } | } | ||||
| void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, | |||||
| void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, | |||||
| bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) { | bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) { | ||||
| int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; | int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; | ||||
| int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; | int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; | ||||
| @@ -1267,27 +1267,27 @@ void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, i | |||||
| d11 = vmaxq_s32(d11, output_min); | d11 = vmaxq_s32(d11, output_min); | ||||
| d11 = vminq_s32(d11, output_max); | d11 = vminq_s32(d11, output_max); | ||||
| (output_data)[0] = (uint8_t)d00[0]; | |||||
| (output_data + 1)[0] = (uint8_t)d00[1]; | |||||
| (output_data + 2)[0] = (uint8_t)d00[2]; | |||||
| (output_data + 3)[0] = (uint8_t)d00[3]; | |||||
| (output_data)[0] = (int8_t)d00[0]; | |||||
| (output_data + 1)[0] = (int8_t)d00[1]; | |||||
| (output_data + 2)[0] = (int8_t)d00[2]; | |||||
| (output_data + 3)[0] = (int8_t)d00[3]; | |||||
| if (w_not_bound) { | if (w_not_bound) { | ||||
| *(output_data + 4) = (uint8_t)d01[0]; | |||||
| *(output_data + 5) = (uint8_t)d01[1]; | |||||
| *(output_data + 6) = (uint8_t)d01[2]; | |||||
| *(output_data + 7) = (uint8_t)d01[3]; | |||||
| *(output_data + 4) = (int8_t)d01[0]; | |||||
| *(output_data + 5) = (int8_t)d01[1]; | |||||
| *(output_data + 6) = (int8_t)d01[2]; | |||||
| *(output_data + 7) = (int8_t)d01[3]; | |||||
| } | } | ||||
| if (h_not_bound) { | if (h_not_bound) { | ||||
| *(output_data + output_w * 4) = (uint8_t)d10[0]; | |||||
| *(output_data + output_w * 4 + 1) = (uint8_t)d10[1]; | |||||
| *(output_data + output_w * 4 + 2) = (uint8_t)d10[2]; | |||||
| *(output_data + output_w * 4 + 3) = (uint8_t)d10[3]; | |||||
| *(output_data + output_w * 4) = (int8_t)d10[0]; | |||||
| *(output_data + output_w * 4 + 1) = (int8_t)d10[1]; | |||||
| *(output_data + output_w * 4 + 2) = (int8_t)d10[2]; | |||||
| *(output_data + output_w * 4 + 3) = (int8_t)d10[3]; | |||||
| if (w_not_bound) { | if (w_not_bound) { | ||||
| *(output_data + output_w * 4 + 4) = (uint8_t)d11[0]; | |||||
| *(output_data + output_w * 4 + 5) = (uint8_t)d11[1]; | |||||
| *(output_data + output_w * 4 + 6) = (uint8_t)d11[2]; | |||||
| *(output_data + output_w * 4 + 7) = (uint8_t)d11[3]; | |||||
| *(output_data + output_w * 4 + 4) = (int8_t)d11[0]; | |||||
| *(output_data + output_w * 4 + 5) = (int8_t)d11[1]; | |||||
| *(output_data + output_w * 4 + 6) = (int8_t)d11[2]; | |||||
| *(output_data + output_w * 4 + 7) = (int8_t)d11[3]; | |||||
| } | } | ||||
| } | } | ||||
| #else | #else | ||||
| @@ -1456,7 +1456,7 @@ void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, i | |||||
| #endif | #endif | ||||
| } | } | ||||
| void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, | |||||
| void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, | |||||
| int real_cal_num, int out_w_block, ConvParameter *conv_param) { | int real_cal_num, int out_w_block, ConvParameter *conv_param) { | ||||
| int output_channel = conv_param->output_channel_; | int output_channel = conv_param->output_channel_; | ||||
| int output_w = conv_param->output_w_; | int output_w = conv_param->output_w_; | ||||
| @@ -1483,7 +1483,7 @@ void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, cons | |||||
| int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; | int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; | ||||
| bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; | bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; | ||||
| bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; | bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; | ||||
| Conv3x3Uint8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, | |||||
| Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, | |||||
| conv_param); | conv_param); | ||||
| } | } | ||||
| } | } | ||||
| @@ -56,18 +56,18 @@ void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const fl | |||||
| int real_cal_num, int out_w_block, ConvParameter *conv_param); | int real_cal_num, int out_w_block, ConvParameter *conv_param); | ||||
| // for int8 convolution 3x3 filter/input/output transform | // for int8 convolution 3x3 filter/input/output transform | ||||
| void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp); | |||||
| void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp); | |||||
| void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, | |||||
| void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, | |||||
| int real_cal_num, int out_w_block, ConvParameter *conv_param); | int real_cal_num, int out_w_block, ConvParameter *conv_param); | ||||
| void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, | void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, | ||||
| int kernel_plane); | int kernel_plane); | ||||
| void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, | |||||
| void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, | |||||
| bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param); | bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param); | ||||
| void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, | |||||
| void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, | |||||
| int real_cal_num, int out_w_block, ConvParameter *conv_param); | int real_cal_num, int out_w_block, ConvParameter *conv_param); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -112,7 +112,7 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| memset(nhwc4_input_, 0, nhwc4_input_size); | memset(nhwc4_input_, 0, nhwc4_input_size); | ||||
| tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(cal_num * out_channel * sizeof(float16_t))); | |||||
| tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(thread_count_ * cal_num * out_channel * sizeof(float16_t))); | |||||
| if (tmp_output_block_ == nullptr) { | if (tmp_output_block_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; | MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -83,7 +83,8 @@ int ConvolutionCPUKernel::InitTmpBuffer() { | |||||
| int out_channel = conv_param_->output_channel_; | int out_channel = conv_param_->output_channel_; | ||||
| MS_ASSERT(ctx_->allocator != nullptr); | MS_ASSERT(ctx_->allocator != nullptr); | ||||
| tmp_output_block_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(TILE_NUM * out_channel * sizeof(float))); | |||||
| tmp_output_block_ = | |||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * TILE_NUM * out_channel * sizeof(float))); | |||||
| if (tmp_output_block_ == nullptr) { | if (tmp_output_block_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp output block failed."; | MS_LOG(ERROR) << "malloc tmp output block failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -338,7 +338,7 @@ int ConvolutionInt8CPUKernel::RunImpl(int task_id) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConvolutionInt8Impl(int task_id, LiteParallelGroupEnv *mpenv, void *cdata) { | |||||
| int ConvolutionInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||||
| auto conv = reinterpret_cast<ConvolutionInt8CPUKernel *>(cdata); | auto conv = reinterpret_cast<ConvolutionInt8CPUKernel *>(cdata); | ||||
| auto error_code = conv->RunImpl(task_id); | auto error_code = conv->RunImpl(task_id); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||