@@ -130,8 +130,9 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
   memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size);
   /*=============================tmp_out_============================*/
-  size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ *
-                        tile_num * sizeof(float16_t);
+  int new_out_plane = UP_DIV(conv_param_->output_h_, C4NUM) * UP_DIV(conv_param_->output_w_, C4NUM) * C4NUM * C4NUM;
+  size_t tmp_out_size =
+    oC8 * C8NUM * conv_param_->output_batch_ * new_out_plane * sizeof(float16_t);
   tmp_out_ = reinterpret_cast<float16_t *>(malloc(tmp_out_size));
   if (tmp_out_ == nullptr) {
     MS_LOG(ERROR) << "malloc tmp_out_ failed.";
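The resized tmp_out_ above is keyed to the 4x4-tile-aligned output plane rather than the raw output_h_ * output_w_, because the output transform writes whole 4x4 units even at the right and bottom borders. A minimal sketch of the arithmetic, assuming UP_DIV is the usual round-up macro and using illustrative shapes:

```cpp
#include <cstddef>
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y)-1) / (y))

int main() {
  const int C4NUM = 4, C8NUM = 8;
  // Hypothetical shapes, not taken from the kernel.
  int output_h = 5, output_w = 7, output_batch = 1, oC8 = 2;
  // 5x7 output -> 2x2 grid of 4x4 tiles -> 64 padded cells vs 35 raw cells.
  int new_out_plane = UP_DIV(output_h, C4NUM) * UP_DIV(output_w, C4NUM) * C4NUM * C4NUM;
  size_t tmp_out_size = oC8 * C8NUM * output_batch * new_out_plane * sizeof(short);  // short stands in for float16_t
  printf("plane %d, bytes %zu\n", new_out_plane, tmp_out_size);
  return 0;
}
```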
@@ -278,7 +279,7 @@ int Convolution3x3FP16CPUKernel::Run() {
   auto out_tensor = outputs_.at(kOutputIndex);
   auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
   for (int j = 0; j < out_tensor->ElementsNum(); ++j) {
-    output_addr[j] = (reinterpret_cast<float *>(fp16_out_))[j];
+    output_addr[j] = static_cast<float>(fp16_out_[j]);
   }
   return RET_OK;
 }
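The cast change matters because fp16_out_ holds half-precision values: reinterpret_cast<float *> re-reads the raw 16-bit patterns as 32-bit floats (garbage values, and indexing then strides 4 bytes across a 2-byte-element array, overrunning the buffer), while static_cast converts each element numerically. A sketch of the corrected loop in isolation, assuming an ARM toolchain where float16_t is available via <arm_neon.h>:

```cpp
#include <arm_neon.h>

// Widen an fp16 buffer into an fp32 buffer, element by element.
void Fp16BufferToFp32(const float16_t *src, float *dst, int n) {
  for (int i = 0; i < n; ++i) {
    dst[i] = static_cast<float>(src[i]);  // numeric conversion, not a bit reinterpretation
  }
}
```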
@@ -77,7 +77,6 @@ void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *w
       int oc8_block = j / 8;
       int oc8_res = j % 8;
       int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res;
-      // todo nc4hw4 -> nhwc
       int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res;
       for (int n = 0; n < step; n++) {
@@ -169,6 +168,7 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
   int thread_count = conv_param->thread_num_;
   int tile_num = 16;
   int output_unit = 4;
+  int k_plane = 36;
   int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
   int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
@@ -181,6 +181,9 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
   int out_w_block = UP_DIV(conv_param->output_w_, C4NUM);
   int out_h_block = UP_DIV(conv_param->output_h_, C4NUM);
   int output_count = out_w_block * out_h_block;
   int output_tile_count = UP_DIV(output_count, tile_num);
+  int tile_buffer_offset = tile_num * k_plane * ic4 * C4NUM;
+  int block_unit_buffer_offset = k_plane * C4NUM;
+  int tmp_dst_buffer_offset = tile_num * k_plane * oc8 * C8NUM;
   int input_batch = conv_param->input_batch_;
   for (int batch = 0; batch < input_batch; batch++) {
@@ -188,14 +191,16 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
       int start_index = thread_id * tile_num;
       int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;
-      Conv3x3Fp16InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block,
-                                conv_param);
+      Conv3x3Fp16InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
+                                block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
+                                out_w_block, conv_param);
-      IndirectGemmFp16_16x8(tmp_dst_buffer, tile_buffer, transed_weight, NULL, 36, ic4, oc8 * C8NUM,
+      IndirectGemmFp16_16x8(tmp_dst_buffer + task_id * tmp_dst_buffer_offset,
+                            tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM,
                             oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0);
-      Conv3x3Fp16OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block,
-                                 conv_param);
+      Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index,
+                                 real_cal_num, out_w_block, conv_param);
     }
   }
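The pattern in this hunk is worth calling out: tile_buffer, block_unit_buffer, and tmp_dst_buffer are single shared allocations sized for thread_count workers, and each task now indexes a disjoint slice by task_id so no two workers touch the same scratch memory. A minimal sketch of the idea, with illustrative names rather than the kernel's own:

```cpp
#include <vector>

// One allocation, thread_count disjoint slices; slice k belongs to task k.
struct Scratch {
  std::vector<float> buf;
  int per_task;  // elements per task, e.g. tile_num * k_plane * ic4 * C4NUM
  float *Slice(int task_id) { return buf.data() + task_id * per_task; }
};

Scratch MakeScratch(int thread_count, int per_task_elems) {
  return Scratch{std::vector<float>(static_cast<size_t>(thread_count) * per_task_elems), per_task_elems};
}
```

Each worker then passes Slice(task_id) wherever the single shared pointer used to go, which is exactly the tile_buffer + task_id * tile_buffer_offset shape of the calls above.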
@@ -207,7 +207,7 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp
     int real_y_start = origin_y > 0 ? 0 : -origin_y;
     int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y);
-    int src_plane_offset = input_channel * (origin_y * input_width + origin_x);
+    int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x);
     int dst_plane_offset = cal_id * C4NUM;
     for (int ic = 0; ic < ic4; ic++) {
       // clear tmp buffer
@@ -216,10 +216,10 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp
       // get real input block with padding
       int src_ic4_offset = src_plane_offset + ic * C4NUM;
       for (int interval = real_y_start; interval < real_y_end; interval++) {
-        int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel;
+        int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic4 * C4NUM;
         int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM;
         for (int j = 0; j < (real_x_end - real_x_start); j++) {
-          int src_x_offset = src_y_offset + j * input_channel;
+          int src_x_offset = src_y_offset + j * ic4 * C4NUM;
           int dst_x_offset = dst_y_offset + j * C4NUM;
           float16_t *src_addr = (float16_t *)(input_data) + src_x_offset;
           float16_t *dst_addr = tmp_data + dst_x_offset;
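Both offset fixes in this hunk switch the per-pixel stride from the logical input_channel to ic4 * C4NUM, because the packed input pads channels up to a multiple of four; with, say, 3 input channels the physical stride is 4, and using 3 would skew every row. A small sketch of the addressing, with hypothetical arguments:

```cpp
#define UP_DIV(x, y) (((x) + (y)-1) / (y))

// Offset of pixel (y, x), channel group ic, in a channel-padded layout.
int SrcOffset(int y, int x, int ic, int input_width, int input_channel) {
  const int C4NUM = 4;
  int ic4 = UP_DIV(input_channel, C4NUM);  // 3 channels -> ic4 = 1
  int pixel_stride = ic4 * C4NUM;          // physical stride 4, not 3
  return (y * input_width + x) * pixel_stride + ic * C4NUM;
}
```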
@@ -511,7 +511,6 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data,
   int output_w = conv_param->output_w_;
   int output_h = conv_param->output_h_;
   int oc8 = UP_DIV(output_channel, C8NUM);
-  // todo outputw --> out_w_block * out_unit
   for (int i = 0; i < real_cal_num; i++) {
     int out_w_index = (start_index + i) % out_w_block;
     int out_h_index = (start_index + i) / out_w_block;
@@ -203,19 +203,20 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
       // clear tmp buffer before compute
       memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
       int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
-      // todo
-      size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t);
-      memset(tmp_dst, 0, tmp_dst_size);
+      size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t);
+      int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
+      memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
       Im2ColPackUnitInt8(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param);
       if (real_cal_num == tile_n) {
         int8_t *gemm_output = output_data + out_offset;
-        IndirectGemmInt8(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
-                         input_sum, conv_param);
+        IndirectGemmInt8(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
+                         out_channel, input_sum, conv_param);
       } else {
         // res part
-        IndirectGemmInt8(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
-                         input_sum, conv_param);
+        IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
+                         out_channel, input_sum, conv_param);
         memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
       }
     }
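Besides threading correctness, the memset change narrows the clear from the whole thread_count * tile_n * output_channel_ accumulator down to just this task's tile_n * output_channel_ slice, so workers no longer wipe each other's partial sums. A short sketch of that slice-clearing step, assuming int32 accumulators as above:

```cpp
#include <cstdint>
#include <cstring>

// Zero only the accumulator slice owned by task_id before it accumulates.
void ClearTaskSlice(int32_t *tmp_dst, int task_id, int tile_n, int out_channel) {
  int offset = task_id * tile_n * out_channel;
  memset(tmp_dst + offset, 0, sizeof(int32_t) * tile_n * out_channel);
}
```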
@@ -257,19 +258,20 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
       // clear tmp buffer before compute
       memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
       int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
-      // todo
-      size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t);
-      memset(tmp_dst, 0, tmp_dst_size);
+      size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t);
+      int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
+      memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
       Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param);
       if (real_cal_num == tile_n) {
         int8_t *gemm_output = output_data + out_offset;
-        IndirectGemmInt8Opt(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
-                            input_sum, conv_param, gemm_func);
+        IndirectGemmInt8Opt(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
+                            kernel_plane, out_channel, input_sum, conv_param, gemm_func);
       } else {
         // res part
-        IndirectGemmInt8Opt(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
-                            input_sum, conv_param, gemm_func);
+        IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
+                            out_channel, input_sum, conv_param, gemm_func);
         memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
       }
     }
@@ -290,6 +292,10 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
   int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
   int output_count = out_w_block * out_h_block;
   int output_tile_count = UP_DIV(output_count, TILE_NUM);
+  int oc4 = UP_DIV(output_channel, C4NUM);
+  int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM;
+  int block_unit_buffer_offset = 16 * C8NUM;
+  int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM;
   int input_batch = conv_param->input_batch_;
   for (int batch = 0; batch < input_batch; batch++) {
@@ -297,13 +303,15 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
       int start_index = thread_id * TILE_NUM;
       int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
-      Conv3x3Uint8InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block,
-                                 conv_param);
+      Conv3x3Uint8InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
+                                 block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
+                                 out_w_block, conv_param);
-      Conv3x3Uint8Gemm(tmp_dst_buffer, tile_buffer, transed_weight, output_channel, ic8, real_cal_num);
+      Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
+                       transed_weight, output_channel, ic8, real_cal_num);
-      Conv3x3Uint8OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block,
-                                  conv_param);
+      Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index,
+                                  real_cal_num, out_w_block, conv_param);
     }
   }
@@ -136,7 +136,7 @@ inline uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { retu
 inline int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
-inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, int32_t scale, int *mini,
+inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini,
                                               int *maxi) {
   int32_t min = std::numeric_limits<int8_t>::min();
   int32_t max = std::numeric_limits<int8_t>::max();
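The signature fix matters because quantization scales are typically small fractions; an int32_t scale parameter would truncate something like 0.047 to zero before any bound is computed. A plausible body for such a function under the usual affine scheme q = round(r / scale) + zp; this is a hedged sketch, not necessarily the kernel's exact implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

void ActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini, int *maxi) {
  int32_t min = std::numeric_limits<int8_t>::min();
  int32_t max = std::numeric_limits<int8_t>::max();
  int32_t qzero = static_cast<int32_t>(std::round(0.f / scale)) + zp;  // quantized 0.0
  int32_t qsix = static_cast<int32_t>(std::round(6.f / scale)) + zp;   // quantized 6.0
  if (is_relu) {
    min = std::max(min, qzero);
  } else if (is_relu6) {
    min = std::max(min, qzero);
    max = std::min(max, qsix);
  }
  *mini = min;
  *maxi = max;
}
```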
@@ -584,7 +584,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t01, t02), -0.3), vmulq_n_f32(vaddq_f32(t03, t04), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t05, t06), -0.533333333333));
   float32x4_t m04 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t03, t04), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t04, t03), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t05, t06), 0.533333333333));
   float32x4_t m05 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, 0.03333333), vmulq_n_f32(t02, 0.0222222)),
@@ -618,7 +618,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t11, t12), -0.3), vmulq_n_f32(vaddq_f32(t13, t14), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t15, t16), -0.533333333333));
   float32x4_t m14 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t13, t14), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t14, t13), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t15, t16), 0.533333333333));
   float32x4_t m15 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, 0.03333333), vmulq_n_f32(t12, 0.0222222)),
@@ -652,7 +652,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t21, t22), -0.3), vmulq_n_f32(vaddq_f32(t23, t24), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t25, t26), -0.533333333333));
   float32x4_t m24 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t23, t24), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t24, t23), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t25, t26), 0.533333333333));
   float32x4_t m25 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, 0.03333333), vmulq_n_f32(t22, 0.0222222)),
@@ -686,7 +686,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t31, t32), -0.3), vmulq_n_f32(vaddq_f32(t33, t34), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t35, t36), -0.533333333333));
   float32x4_t m34 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t33, t34), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t34, t33), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t35, t36), 0.533333333333));
   float32x4_t m35 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, 0.03333333), vmulq_n_f32(t32, 0.0222222)),
@@ -720,7 +720,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t41, t42), -0.3), vmulq_n_f32(vaddq_f32(t43, t44), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t45, t46), -0.533333333333));
   float32x4_t m44 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t43, t44), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t44, t43), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t45, t46), 0.533333333333));
   float32x4_t m45 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, 0.03333333), vmulq_n_f32(t42, 0.0222222)),
@@ -754,7 +754,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t51, t52), -0.3), vmulq_n_f32(vaddq_f32(t53, t54), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t55, t56), -0.533333333333));
   float32x4_t m54 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t53, t54), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t54, t53), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t55, t56), 0.533333333333));
   float32x4_t m55 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, 0.03333333), vmulq_n_f32(t52, 0.0222222)),
@@ -788,7 +788,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t61, t62), -0.3), vmulq_n_f32(vaddq_f32(t63, t64), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t65, t66), -0.533333333333));
   float32x4_t m64 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t63, t64), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t64, t63), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t65, t66), 0.533333333333));
   float32x4_t m65 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, 0.03333333), vmulq_n_f32(t62, 0.0222222)),
@@ -822,7 +822,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t71, t72), -0.3), vmulq_n_f32(vaddq_f32(t73, t74), 1.33333333333)),
               vmulq_n_f32(vaddq_f32(t75, t76), -0.533333333333));
   float32x4_t m74 =
-    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t73, t74), 1.33333333333)),
+    vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t74, t73), 1.33333333333)),
               vmulq_n_f32(vsubq_f32(t75, t76), 0.533333333333));
   float32x4_t m75 =
     vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, 0.03333333), vmulq_n_f32(t72, 0.0222222)),
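All eight hunks above apply the same one-character sign fix: in each m*4 row of the 8x8 input transform, the 4/3-weighted difference must be (t4 - t3), not (t3 - t4), so that the term's sign matches the rest of the transform row. A scalar restatement of the corrected row, with coefficients copied from the diff:

```cpp
// One lane of the corrected m*4 computation, without NEON intrinsics.
float M4Row(float t1, float t2, float t3, float t4, float t5, float t6) {
  return 0.3f * (t1 - t2) + 1.33333333333f * (t4 - t3) + 0.533333333333f * (t5 - t6);
}
```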