diff --git a/mindspore/lite/nnacl/fp16/conv_fp16.c b/mindspore/lite/nnacl/fp16/conv_fp16.c index fad112f77d..e7a5291c48 100644 --- a/mindspore/lite/nnacl/fp16/conv_fp16.c +++ b/mindspore/lite/nnacl/fp16/conv_fp16.c @@ -17,6 +17,7 @@ #include <string.h> #include "nnacl/fp16/pack_fp16.h" #include "nnacl/fp16/winograd_transform_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" #ifdef __cplusplus extern "C" { @@ -122,7 +123,8 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, - float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) { + float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { + const int tile_n = 16; int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; int in_batch = conv_param->input_batch_; @@ -132,203 +134,29 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ int out_h = conv_param->output_h_; int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; - bool relu = conv_param->act_type_ == ActType_Relu; - bool relu6 = conv_param->act_type_ == ActType_Relu6; int thread_count = conv_param->thread_num_; - const int tile_n = 16; int output_count = out_h * out_w; int output_tile_count = UP_DIV(output_count, tile_n); - - int channel_block = UP_DIV(in_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; - int unit_size = kernel_plane * channel_block * C4NUM; - - // we accumulate 4 channels per time for input blocks - int ic4 = UP_DIV(in_channel, C4NUM); - int conv_depth = kernel_h * kernel_w; - // bytes from one output's i-th channel to the next output's i-th channel - // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward + int deep = kernel_plane * in_channel; for (int b = 0; b < in_batch; b++) { - int in_batch_offset = b * ic4 * C4NUM * in_h * in_w; + int in_batch_offset = b * in_channel * in_h * in_w; int out_batch_offset = b * out_channel * out_h * out_w; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { int start_index = thread_id * tile_n; int real_cal_num = (output_count - start_index) < tile_n ? 
(output_count - start_index) : tile_n; - float16_t *gemm_input = (float16_t *)(packed_input + task_id * unit_size * tile_n); + float16_t *gemm_input = packed_input + task_id * deep * tile_n; + float16_t *col_major_gemm_input = col_major_input + task_id * deep * tile_n; + size_t packed_input_size = deep * tile_n * sizeof(float16_t); + memset(gemm_input, 0, packed_input_size); + memset(col_major_gemm_input, 0, packed_input_size); Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); int out_offset = thread_id * tile_n * out_channel + out_batch_offset; - if (real_cal_num == tile_n) { - float16_t *gemm_output = output_data + out_offset; - IndirectGemmFp16_16x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, - out_channel * sizeof(float16_t), 0, 0, relu, relu6); - } else { - // res part - float16_t *tmp_out_ptr = tmp_out_block + task_id * tile_n * out_channel; - IndirectGemmFp16_16x8(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, - out_channel * sizeof(float16_t), 0, 0, relu, relu6); - memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float16_t)); - } - } - } -} - -// fp16 conv3x3 -void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, - float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, - int task_id, ConvParameter *conv_param) { - int thread_count = conv_param->thread_num_; - const int tile_num = 16; - const int output_unit = 4; - const int k_plane = 36; - int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); - int ic4 = ic8 * 2; - int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); - - int out_w_block = UP_DIV(conv_param->output_w_, C4NUM); - int out_h_block = UP_DIV(conv_param->output_h_, C4NUM); - int output_count = out_w_block * out_h_block; - int output_tile_count = UP_DIV(output_count, tile_num); - int tile_buffer_offset = tile_num * k_plane * ic4 * C4NUM; - int block_unit_buffer_offset = k_plane * C8NUM; - int tmp_dst_buffer_offset = tile_num * k_plane * oc8 * C8NUM; - - int input_batch = conv_param->input_batch_; - for (int batch = 0; batch < input_batch; batch++) { - int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { - int start_index = thread_id * tile_num; - int real_cal_num = (output_count - start_index) < tile_num ? 
(output_count - start_index) : tile_num; - - Conv3x3Fp16InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, - block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, - out_w_block, conv_param); - - IndirectGemmFp16_16x8(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, - tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM, - oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0); - - Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, - bias_data, start_index, real_cal_num, out_w_block, conv_param); - } - } -} - -void UnPack3x3OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) { - int out_w_block = UP_DIV(width, C4NUM); - int out_h_block = UP_DIV(height, C4NUM); - int oc8 = UP_DIV(channel, C8NUM); - - for (int b = 0; b < batch; b++) { - int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; - int ro_batch_size = b * channel * height * width; - const float16_t *batch_tmp_out = src + tmp_out_batch_offset; - float16_t *batch_out = dst + ro_batch_size; - for (int h = 0; h < height; h++) { - int src_h_offset = h * out_w_block * C4NUM * C8NUM; - const int dst_h_offset = h * width * channel; - for (int w = 0; w < width; w++) { - int src_w_offset = src_h_offset + w * C8NUM; - int dst_w_offset = dst_h_offset + w * channel; - for (int c = 0; c < oc8 - 1; ++c) { - int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset; - int dst_offset = dst_w_offset + c * C8NUM; - vst1q_f16(batch_out + dst_offset, vld1q_f16(batch_tmp_out + src_offset)); - } - - int c_res = channel - (oc8 - 1) * C8NUM; - int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; - int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM; - for (int c = 0; c < c_res; c++) { - int src_offset = src_c_res_offset + c; - int dst_offset = dst_c_res_offset + c; - (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0]; - } - } - } - } -} - -void UnPack3x3ReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) { - int out_w_block = UP_DIV(width, C4NUM); - int out_h_block = UP_DIV(height, C4NUM); - int oc8 = UP_DIV(channel, C8NUM); - - for (int b = 0; b < batch; b++) { - int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; - int ro_batch_size = b * channel * height * width; - const float16_t *batch_tmp_out = src + tmp_out_batch_offset; - float16_t *batch_out = dst + ro_batch_size; - for (int h = 0; h < height; h++) { - int src_h_offset = h * out_w_block * C4NUM * C8NUM; - const int dst_h_offset = h * width * channel; - for (int w = 0; w < width; w++) { - int src_w_offset = src_h_offset + w * C8NUM; - int dst_w_offset = dst_h_offset + w * channel; - for (int c = 0; c < oc8 - 1; ++c) { - int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset; - int dst_offset = dst_w_offset + c * C8NUM; - float16x8_t input_ptr = vld1q_f16(batch_tmp_out + src_offset); - float16x8_t zero = vdupq_n_f16(0); - input_ptr = vmaxq_f16(zero, input_ptr); - vst1q_f16(batch_out + dst_offset, input_ptr); - } - - int c_res = channel - (oc8 - 1) * C8NUM; - int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; - int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM; - for (int c = 0; c < c_res; c++) { - int src_offset = src_c_res_offset + 
c; - int dst_offset = dst_c_res_offset + c; - float16_t input_data = (batch_tmp_out + src_offset)[0]; - input_data = input_data < 0 ? 0 : input_data; - (batch_out + dst_offset)[0] = input_data; - } - } - } - } -} - -void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) { - int out_w_block = UP_DIV(width, C4NUM); - int out_h_block = UP_DIV(height, C4NUM); - int oc8 = UP_DIV(channel, C8NUM); - - for (int b = 0; b < batch; b++) { - int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; - int ro_batch_size = b * channel * height * width; - const float16_t *batch_tmp_out = src + tmp_out_batch_offset; - float16_t *batch_out = dst + ro_batch_size; - for (int h = 0; h < height; h++) { - int src_h_offset = h * out_w_block * C4NUM * C8NUM; - const int dst_h_offset = h * width * channel; - for (int w = 0; w < width; w++) { - int src_w_offset = src_h_offset + w * C8NUM; - int dst_w_offset = dst_h_offset + w * channel; - for (int c = 0; c < oc8 - 1; ++c) { - int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset; - int dst_offset = dst_w_offset + c * C8NUM; - float16x8_t input_ptr = vld1q_f16(batch_tmp_out + src_offset); - float16x8_t zero = vdupq_n_f16(0); - float16x8_t six = vdupq_n_f16(6); - input_ptr = vmaxq_f16(zero, input_ptr); - input_ptr = vminq_f16(six, input_ptr); - vst1q_f16(batch_out + dst_offset, input_ptr); - } - - int c_res = channel - (oc8 - 1) * C8NUM; - int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM; - int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM; - for (int c = 0; c < c_res; c++) { - int src_offset = src_c_res_offset + c; - int dst_offset = dst_c_res_offset + c; - float16_t input_data = (batch_tmp_out + src_offset)[0]; - input_data = input_data < 0 ? 0 : input_data; - input_data = input_data > 6 ? 
6 : input_data; - (batch_out + dst_offset)[0] = input_data; - } - } + RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); + MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, + real_cal_num, out_channel, out_channel, OutType_Nhwc); } } } diff --git a/mindspore/lite/nnacl/fp16/conv_fp16.h b/mindspore/lite/nnacl/fp16/conv_fp16.h index 0064a553d5..b38b2854ea 100644 --- a/mindspore/lite/nnacl/fp16/conv_fp16.h +++ b/mindspore/lite/nnacl/fp16/conv_fp16.h @@ -43,18 +43,7 @@ extern "C" { // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, - float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param); - -// fp16 conv3x3 -void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, - float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, - int task_id, ConvParameter *conv_param); - -void UnPack3x3OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel); - -void UnPack3x3ReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel); - -void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel); + float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param); // fp16 convolution winograd void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, diff --git a/mindspore/lite/nnacl/fp16/pack_fp16.c b/mindspore/lite/nnacl/fp16/pack_fp16.c index b08f258415..27f98cbfdf 100644 --- a/mindspore/lite/nnacl/fp16/pack_fp16.c +++ b/mindspore/lite/nnacl/fp16/pack_fp16.c @@ -23,6 +23,7 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1 // input format : nhwc int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; + int kernel_plane = kernel_h * kernel_w; int stride_h = conv_param->stride_h_; int stride_w = conv_param->stride_w_; int pad_h = conv_param->pad_u_; @@ -33,9 +34,6 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1 int in_h = conv_param->input_h_; int in_w = conv_param->input_w_; int out_w = conv_param->output_w_; - int ic4 = UP_DIV(in_channel, 4); - int ic4_minus = in_channel / 4; - memset(packed_input, 0, kernel_w * kernel_h * ic4 * C4NUM * 16 * sizeof(float16_t)); for (int i = 0; i < real_cal_num; i++) { int block_start = block_index + i; @@ -46,74 +44,25 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1 int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); - for (int j = kh_s; j < kh_e; j++) { - int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; - for (int n = kw_s; n < kw_e; n++) { - int input_x_stride = input_y_stride + n * dilation_w * in_channel; - int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * ic4 + i * C4NUM; - for (int m = 0; m < ic4_minus; m++) { - int channel_block_stride = input_x_stride + m * C4NUM; - int channel_block_offset = input_plane_offset + m * 16 * C4NUM; -#ifdef ENABLE_ARM64 - vst1_f16(packed_input + channel_block_offset, vld1_f16(input_data + channel_block_stride)); -#else - for (int l = 0; l < 
C4NUM; ++l) { - (packed_input + channel_block_offset)[l] = (input_data + channel_block_stride)[l]; - } -#endif - } // channel_block loop - int ic_res = in_channel - ic4_minus * C4NUM; - for (int l = 0; l < ic_res; ++l) { - int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l; - int channel_block_offset = input_plane_offset + ic4_minus * 16 * C4NUM + l; - packed_input[channel_block_offset] = input_data[channel_block_stride]; - } - } // kernel_w loop - } // kernel_h loop - } // tile num loop -} - -void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) { - // original weight format : ohwi - const int tile_num = 8; - const int inchannel_block = 4; - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int in_channel = conv_param->input_channel_; - int out_channel = conv_param->output_channel_; - int kernel_block = UP_DIV(out_channel, tile_num); - int channel_block = UP_DIV(in_channel, inchannel_block); - int kernel_plane = kernel_h * kernel_w; - int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane; - - int unit_size = tile_num * inchannel_block; - int block_size = pack_weight_size / kernel_block; - - for (int m = 0; m < kernel_plane; m++) { - int kernel_plane_stride = m * in_channel; - int packed_kernel_plane_stride = m * unit_size * channel_block; - for (int i = 0; i < channel_block; i++) { - int channel_block_stride = kernel_plane_stride + i * inchannel_block; - int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; - int ic_remainder = in_channel - i * inchannel_block; - int real_ic_num = ic_remainder < inchannel_block ? ic_remainder : inchannel_block; - for (int h = 0; h < real_ic_num; h++) { - int block_stride = channel_block_stride + h; - int packed_block_stride = packed_channel_block_size + h * tile_num; - for (int j = 0; j < kernel_block; j++) { - int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel; - int packed_kernel_block_size = packed_block_stride + j * block_size; - int oc_remainder = out_channel - j * tile_num; - int real_oc_num = oc_remainder < tile_num ? 
oc_remainder : tile_num; - for (int k = 0; k < real_oc_num; k++) { - float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; - float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k; - *packed_data_ptr = *origin_data_ptr; - } - } // kernel block loop - } // inchannel block loop - } // channel block loop - } // kernel plane loop + if (dilation_h == 1 && dilation_w == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float16_t)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int n = kw_s; n < kw_e; n++) { + int input_x_stride = input_y_stride + n * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + n) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float16_t)); + } // kernel_w loop + } // kernel_h loop + } + } // tile num loop } void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { diff --git a/mindspore/lite/nnacl/fp16/pack_fp16.h b/mindspore/lite/nnacl/fp16/pack_fp16.h index 759a0e04a7..2f1ad6eebe 100644 --- a/mindspore/lite/nnacl/fp16/pack_fp16.h +++ b/mindspore/lite/nnacl/fp16/pack_fp16.h @@ -29,8 +29,6 @@ extern "C" { void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, int block_index); -void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight); - void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc deleted file mode 100644 index 9bd75a49a7..0000000000 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc +++ /dev/null @@ -1,261 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" -#include "nnacl/fp16/conv_fp16.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/winograd_transform_fp16.h" -#include "nnacl/fp16/pack_fp16.h" -#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" -#include "schema/model_generated.h" -#include "src/kernel_registry.h" -#include "include/errorcode.h" -#include "src/runtime/runtime_api.h" - -using mindspore::kernel::KERNEL_ARCH::kCPU; -using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_Conv2D; - -namespace mindspore::kernel { -void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param) { - auto input_channel = conv_param->input_channel_; - auto output_channel = conv_param->output_channel_; - auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; - int iC8 = UP_DIV(input_channel, C8NUM); - int oC8 = UP_DIV(output_channel, C8NUM); - - size_t tmp_size = oC8 * C8NUM * iC8 * C8NUM * kernel_plane * sizeof(float16_t); - auto tmp_addr = reinterpret_cast(malloc(tmp_size)); - if (tmp_addr == nullptr) { - MS_LOG(ERROR) << "malloc tmp_addr failed."; - return; - } - memset(tmp_addr, 0, tmp_size); - - PackWeightToC4Fp16(origin_weight, tmp_addr, conv_param); - Conv3x3Fp16FilterTransform(tmp_addr, dst_weight, iC8 * 2, output_channel, kernel_plane); - - free(tmp_addr); -} - -int Convolution3x3FP16CPUKernel::InitWeightBias() { - auto filter_tensor = in_tensors_.at(kWeightIndex); - auto input_channel = filter_tensor->Channel(); - auto output_channel = filter_tensor->Batch(); - conv_param_->input_channel_ = input_channel; - conv_param_->output_channel_ = output_channel; - int iC8 = UP_DIV(input_channel, C8NUM); - int oC8 = UP_DIV(output_channel, C8NUM); - - size_t transformed_size = iC8 * C8NUM * oC8 * C8NUM * 36 * sizeof(float16_t); - transformed_filter_addr_ = reinterpret_cast(malloc(transformed_size)); - if (transformed_filter_addr_ == nullptr) { - MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed."; - return RET_ERROR; - } - memset(transformed_filter_addr_, 0, transformed_size); - auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Get Execute filter failed."; - return ret; - } - ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_); - - size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t); - bias_data_ = malloc(new_bias_size); - if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "malloc bias_data_ failed."; - return RET_ERROR; - } - memset(bias_data_, 0, new_bias_size); - auto fp16_bias_data = reinterpret_cast(bias_data_); - if (in_tensors_.size() == kInputSize2) { - auto ori_bias_addr = reinterpret_cast(in_tensors_.at(kBiasIndex)->MutableData()); - for (int i = 0; i < output_channel; ++i) { - fp16_bias_data[i] = (float16_t)ori_bias_addr[i]; - } - } else { - MS_ASSERT(inputs_.size() == kInputSize1); - } - return RET_OK; -} - -int Convolution3x3FP16CPUKernel::InitTmpBuffer() { - const int tile_num = 16; - const int k_plane = 36; - int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM); - int iC8 = UP_DIV(conv_param_->input_channel_, C8NUM); - MS_ASSERT(ctx_->allocator != nullptr); - - size_t nhwc8_input_size = - iC8 * C8NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); - nhwc4_input_ = ctx_->allocator->Malloc(nhwc8_input_size); - if (nhwc4_input_ == nullptr) { - MS_LOG(ERROR) << "malloc nhwc4_input_ 
failed."; - return RET_ERROR; - } - - size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC8 * C8NUM * sizeof(float16_t); - tile_buffer_ = reinterpret_cast(ctx_->allocator->Malloc(tile_buffer_size)); - if (tile_buffer_ == nullptr) { - MS_LOG(ERROR) << "malloc tile_buffer_ failed."; - return RET_ERROR; - } - - size_t block_unit_buffer_size = thread_count_ * k_plane * C8NUM * sizeof(float16_t); - block_unit_buffer_ = reinterpret_cast(ctx_->allocator->Malloc(block_unit_buffer_size)); - if (block_unit_buffer_ == nullptr) { - MS_LOG(ERROR) << "malloc block_unit_buffer_ failed."; - return RET_ERROR; - } - - size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float16_t); - tmp_dst_buffer_ = reinterpret_cast(ctx_->allocator->Malloc(tmp_dst_buffer_size)); - if (tmp_dst_buffer_ == nullptr) { - MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; - return RET_ERROR; - } - - int new_out_plane = UP_DIV(conv_param_->output_h_, C4NUM) * UP_DIV(conv_param_->output_w_, C4NUM) * C4NUM * C4NUM; - size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * new_out_plane * sizeof(float16_t); - tmp_out_ = reinterpret_cast(ctx_->allocator->Malloc(tmp_out_size)); - if (tmp_out_ == nullptr) { - MS_LOG(ERROR) << "malloc tmp_out_ failed."; - return RET_ERROR; - } - - return RET_OK; -} - -void Convolution3x3FP16CPUKernel::ConfigInputOutput() { - auto input_tensor = in_tensors_.at(kInputIndex); - auto input_format = input_tensor->GetFormat(); - schema::Format execute_format = schema::Format::Format_NHWC4; - convert_func_ = LayoutTransformFp16(input_format, execute_format); - if (convert_func_ == nullptr) { - MS_LOG(ERROR) << "layout convert func is nullptr."; - return; - } -} - -int Convolution3x3FP16CPUKernel::Init() { - auto ret = InitWeightBias(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init weight bias failed."; - return RET_ERROR; - } - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); -} - -int Convolution3x3FP16CPUKernel::ReSize() { - auto ret = ConvolutionBaseCPUKernel::CheckResizeValid(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Resize is invalid."; - return ret; - } - - ret = ConvolutionBaseCPUKernel::Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ConvolutionBase init failed."; - return ret; - } - return RET_OK; -} - -int Convolution3x3FP16CPUKernel::RunImpl(int task_id) { - Conv3x3Fp16(reinterpret_cast(nhwc4_input_), transformed_filter_addr_, - reinterpret_cast(bias_data_), execute_output_, tile_buffer_, block_unit_buffer_, - tmp_dst_buffer_, tmp_out_, task_id, conv_param_); - return RET_OK; -} - -static int Convolution3x3Fp16Impl(void *cdata, int task_id) { - auto conv = reinterpret_cast(cdata); - auto error_code = conv->RunImpl(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "Convolution3x3 Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; - } - return RET_OK; -} - -int Convolution3x3FP16CPUKernel::PostProcess() { - auto act_type = conv_param_->act_type_; - switch (act_type) { - case ActType_No: - UnPack3x3OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, - conv_param_->output_w_, conv_param_->output_channel_); - break; - case ActType_Relu: - UnPack3x3ReluOutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, - conv_param_->output_w_, conv_param_->output_channel_); - break; - case ActType_Relu6: - UnPack3x3Relu6OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, - 
conv_param_->output_w_, conv_param_->output_channel_); - break; - default: - MS_LOG(ERROR) << "Unsupport activation type."; - return RET_ERROR; - } - return RET_OK; -} - -int Convolution3x3FP16CPUKernel::Run() { - auto ret = Prepare(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Prepare failed."; - return RET_ERROR; - } - ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Get execute tensor failed."; - return ret; - } - ret = InitTmpBuffer(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init tmp buffer failed."; - return RET_ERROR; - } - int in_batch = conv_param_->input_batch_; - int in_h = conv_param_->input_h_; - int in_w = conv_param_->input_w_; - int in_channel = conv_param_->input_channel_; - PackNHWCToNHWC8Fp16(reinterpret_cast(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); - - int error_code = ParallelLaunch(this->context_->thread_pool_, Convolution3x3Fp16Impl, this, thread_count_); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "conv3x3 fp16 error error_code[" << error_code << "]"; - FreeTmpBuffer(); - return RET_ERROR; - } - - ret = PostProcess(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Post process failed."; - return ret; - } - ConvolutionBaseFP16CPUKernel::IfCastOutput(); - ConvolutionBaseFP16CPUKernel::FreeTmpBuffer(); - FreeTmpBuffer(); - return RET_OK; -} -} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h deleted file mode 100644 index fbd8748a42..0000000000 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ - -#include <arm_neon.h> -#include <vector> -#include "src/lite_kernel.h" -#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h" -#include "nnacl/optimized_kernel.h" - -namespace mindspore::kernel { -class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { - public: - Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, - const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive) - : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {} - ~Convolution3x3FP16CPUKernel() override { - if (fp16_weight_ != nullptr) { - free(fp16_weight_); - fp16_weight_ = nullptr; - } - if (transformed_filter_addr_ != nullptr) { - free(transformed_filter_addr_); - transformed_filter_addr_ = nullptr; - } - } - - int Init() override; - int ReSize() override; - int Run() override; - int RunImpl(int task_id); - int InitWeightBias(); - int InitTmpBuffer(); - void ConfigInputOutput(); - int PostProcess(); - - private: - void FreeTmpBuffer() { - if (nhwc4_input_ != nullptr) { - ctx_->allocator->Free(nhwc4_input_); - nhwc4_input_ = nullptr; - } - if (tile_buffer_ != nullptr) { - ctx_->allocator->Free(tile_buffer_); - tile_buffer_ = nullptr; - } - if (block_unit_buffer_ != nullptr) { - ctx_->allocator->Free(block_unit_buffer_); - block_unit_buffer_ = nullptr; - } - if (tmp_dst_buffer_ != nullptr) { - ctx_->allocator->Free(tmp_dst_buffer_); - tmp_dst_buffer_ = nullptr; - } - if (tmp_out_ != nullptr) { - ctx_->allocator->Free(tmp_out_); - tmp_out_ = nullptr; - } - } - float16_t *transformed_filter_addr_ = nullptr; - float16_t *tile_buffer_ = nullptr; - float16_t *block_unit_buffer_ = nullptr; - float16_t *tmp_dst_buffer_ = nullptr; - float16_t *tmp_out_ = nullptr; -}; -void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param); -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index ca61d26254..2f68e63b81 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -17,7 +17,6 @@ #include "src/runtime/kernel/arm/fp16/convolution_fp16.h" #include <vector> #include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h" -#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" #include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h" #include "nnacl/fp16/conv_fp16.h" #include "nnacl/fp16/cast_fp16.h" @@ -45,9 +44,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { conv_param_->input_channel_ = in_channel; conv_param_->output_channel_ = out_channel; int oc8 = UP_DIV(out_channel, C8NUM); - int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; - int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; + int pack_weight_size = oc8 * C8NUM * in_channel * kernel_plane; // init weight auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); @@ -61,7 +59,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { return RET_ERROR; } memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); - PackWeightFp16(execute_weight_, conv_param_, packed_weight_); + RowMajor2Col8MajorFp16(execute_weight_, packed_weight_, 
out_channel, in_channel * kernel_plane, false); // init bias bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t)); @@ -83,24 +81,20 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { } int ConvolutionFP16CPUKernel::InitTmpBuffer() { + const int cal_num = 16; int in_channel = conv_param_->input_channel_; - int out_channel = conv_param_->output_channel_; - int channel_block = UP_DIV(in_channel, C4NUM); - int cal_num = 16; int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_; - int unit_size = kernel_plane * channel_block * C4NUM; - int packed_input_size = thread_count_ * cal_num * unit_size; + int unit_size = kernel_plane * in_channel * cal_num * thread_count_; - packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(packed_input_size * sizeof(float16_t))); + packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t))); if (packed_input_ == nullptr) { MS_LOG(ERROR) << "malloc packed_input_ failed."; return RET_ERROR; } - tmp_output_block_ = - reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(thread_count_ * cal_num * out_channel * sizeof(float16_t))); - if (tmp_output_block_ == nullptr) { - MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; + col_major_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t))); + if (col_major_input_ == nullptr) { + MS_LOG(ERROR) << "malloc col_major_input_ failed."; return RET_ERROR; } return RET_OK; @@ -134,7 +128,7 @@ int ConvolutionFP16CPUKernel::ReSize() { } int ConvolutionFP16CPUKernel::RunImpl(int task_id) { - ConvFp16(execute_input_, packed_input_, packed_weight_, reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, + ConvFp16(execute_input_, packed_input_, packed_weight_, reinterpret_cast<float16_t *>(bias_data_), col_major_input_, execute_output_, task_id, conv_param_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h index 36082117cb..8b13f1578f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h @@ -53,14 +53,14 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { ctx_->allocator->Free(packed_input_); packed_input_ = nullptr; } - if (tmp_output_block_ != nullptr) { - ctx_->allocator->Free(tmp_output_block_); - tmp_output_block_ = nullptr; + if (col_major_input_ != nullptr) { + ctx_->allocator->Free(col_major_input_); + col_major_input_ = nullptr; } } float16_t *packed_input_ = nullptr; float16_t *packed_weight_ = nullptr; - float16_t *tmp_output_block_ = nullptr; + float16_t *col_major_input_ = nullptr; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc deleted file mode 100644 index 684d2c9d07..0000000000 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc +++ /dev/null @@ -1,591 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "src/common/log_adapter.h" -#include "common/common_test.h" -#include "mindspore/lite/src/common/utils.h" -#include "src/common/file_utils.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" -#include "nnacl/fp16/conv_fp16.h" - -namespace mindspore { -class TestConvolutionFp16 : public mindspore::CommonTest { - public: - TestConvolutionFp16() {} -}; - -void InitConvParamGroup1Fp16(ConvParameter *conv_param) { - conv_param->input_batch_ = 1; - conv_param->input_h_ = 28; - conv_param->input_w_ = 28; - conv_param->input_channel_ = 3; - - conv_param->output_batch_ = 1; - conv_param->output_h_ = 28; - conv_param->output_w_ = 28; - conv_param->output_channel_ = 32; - - conv_param->kernel_h_ = 3; - conv_param->kernel_w_ = 3; - - conv_param->stride_h_ = 1; - conv_param->stride_w_ = 1; - - conv_param->dilation_h_ = 1; - conv_param->dilation_w_ = 1; - - conv_param->pad_u_ = 1; - conv_param->pad_l_ = 1; - conv_param->thread_num_ = 1; -} - -void InitConvParamGroup2Fp16(ConvParameter *conv_param) { - conv_param->input_batch_ = 1; - conv_param->input_h_ = 128; - conv_param->input_w_ = 128; - conv_param->input_channel_ = 32; - - conv_param->output_batch_ = 1; - conv_param->output_h_ = 128; - conv_param->output_w_ = 128; - conv_param->output_channel_ = 32; - - conv_param->kernel_h_ = 3; - conv_param->kernel_w_ = 3; - - conv_param->stride_h_ = 1; - conv_param->stride_w_ = 1; - - conv_param->dilation_h_ = 1; - conv_param->dilation_w_ = 1; - - conv_param->pad_u_ = 1; - conv_param->pad_l_ = 1; - conv_param->thread_num_ = 1; -} - -TEST_F(TestConvolutionFp16, ConvTest1) { - // prepare stage - auto conv_param = new ConvParameter(); - InitConvParamGroup1Fp16(conv_param); - - int tile_num = 16; - int k_h = conv_param->kernel_h_; - int k_w = conv_param->kernel_w_; - int kernel_plane = k_h * k_w; - int in_batch = conv_param->input_batch_; - int in_channel = conv_param->input_channel_; - int i_h = conv_param->input_h_; - int i_w = conv_param->input_w_; - int out_channel = conv_param->output_channel_; - int ic4 = UP_DIV(in_channel, C4NUM); - int oc8 = UP_DIV(out_channel, C8NUM); - - size_t weight_size; - std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin"; - auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); - std::cout << "==============fp32 weight data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << weight_data[i] << ", "; - } - std::cout << std::endl; - - std::cout << "weight data size: " << weight_size / sizeof(float) << std::endl; - - int weight_ele_size = weight_size / sizeof(float); - auto fp16_weight_data = new float16_t[weight_ele_size]; - for (int i = 0; i < weight_ele_size; i++) { - fp16_weight_data[i] = static_cast(weight_data[i]); - } - - std::cout << "==============fp16 weight data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << fp16_weight_data[i] << ", "; - } - std::cout << std::endl; - - auto packed_weight = reinterpret_cast(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float16_t))); - PackWeightFp16(fp16_weight_data, conv_param, packed_weight); - - std::cout << "==============fp16 packed weight data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << packed_weight[i] << ", "; - } - std::cout << std::endl; - - 
size_t input_size; - std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin"; - auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); - std::cout << "==============fp32 input data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << input_data[i] << ", "; - } - std::cout << std::endl; - - int input_ele_size = input_size / sizeof(float); - auto fp16_input_data = new float16_t[input_ele_size]; - for (int i = 0; i < input_ele_size; i++) { - fp16_input_data[i] = static_cast(input_data[i]); - } - - auto nhwc4_input_data = reinterpret_cast(malloc(i_h * i_w * ic4 * C4NUM * sizeof(float16_t))); - PackNHWCToNHWC4Fp32(fp16_input_data, nhwc4_input_data, 1, i_h * i_w, in_channel); - - std::cout << "==============fp16 input data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << fp16_input_data[i] << ", "; - } - std::cout << std::endl; - - int output_count = conv_param->output_h_ * conv_param->output_w_; - int output_tile_count = UP_DIV(output_count, tile_num); - int unit_size = kernel_plane * ic4 * C4NUM; - int packed_input_size = output_tile_count * tile_num * unit_size; - auto packed_input = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float16_t))); - memset(packed_input, 0, in_batch * packed_input_size * sizeof(float16_t)); - - auto bias_data = reinterpret_cast(malloc(conv_param->output_channel_ * sizeof(float16_t))); - memset(bias_data, 0, conv_param->output_channel_ * sizeof(float16_t)); - - size_t output_data_size = - conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; - auto output_data = new float16_t[output_data_size]; - auto tmp_output_block = reinterpret_cast(malloc(tile_num * out_channel * sizeof(float16_t))); - - // runtime part - printf("Calculating runtime cost...\n"); - uint64_t time_avg = 0; - // warmup - for (int i = 0; i < 3; i++) { - ConvFp16(nhwc4_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); - } - - int loop_count = 100; - auto time_start = mindspore::lite::GetTimeUs(); - for (int i = 0; i < loop_count; i++) { - ConvFp16(nhwc4_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); - } - auto time_end = mindspore::lite::GetTimeUs(); - auto cost = time_end - time_start; - time_avg = cost / loop_count; - printf("single thread running time : %f ms\n", time_avg / 1000.0f); - - std::cout << "==============fp16 output data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << output_data[i] << ", "; - } - std::cout << std::endl; - - auto fp32_output_data = new float[output_data_size]; - for (int i = 0; i < output_data_size; i++) { - fp32_output_data[i] = static_cast(output_data[i]); - } - printf("==================output data=================\n"); - for (int i = 0; i < 20; i++) { - std::cout << fp32_output_data[i] << " ,"; - } - std::cout << std::endl; - - std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin"; - lite::CompareOutput(fp32_output_data, output_path); - - free(nhwc4_input_data); - free(packed_input); - free(bias_data); - free(packed_weight); - free(tmp_output_block); - delete conv_param; - delete input_data; - delete weight_data; - delete[] fp16_weight_data; - delete[] fp16_input_data; - delete[] fp32_output_data; - delete[] output_data; - MS_LOG(INFO) << "TestConvolutionFp16 passed"; -} - -TEST_F(TestConvolutionFp16, ConvTest2) { - // prepare stage - auto conv_param = 
new ConvParameter(); - InitConvParamGroup2Fp16(conv_param); - - // parameter - int tile_num = 16; - int k_h = conv_param->kernel_h_; - int k_w = conv_param->kernel_w_; - int kernel_plane = k_h * k_w; - int in_batch = conv_param->input_batch_; - int in_channel = conv_param->input_channel_; - int out_channel = conv_param->output_channel_; - int ic4 = UP_DIV(in_channel, C4NUM); - int oc8 = UP_DIV(out_channel, C8NUM); - - // weight - size_t weight_size; - std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_32.bin"; - auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); - int weight_ele_size = weight_size / sizeof(float); - auto fp16_weight_data = new float16_t[weight_ele_size]; - for (int i = 0; i < weight_ele_size; i++) { - fp16_weight_data[i] = static_cast(weight_data[i]); - } - auto packed_weight = reinterpret_cast(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float16_t))); - PackWeightFp16(fp16_weight_data, conv_param, packed_weight); - - // input - size_t input_size; - std::string input_path = "./test_data/conv/convfp32_input_1_128_128_32.bin"; - auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); - int input_ele_size = input_size / sizeof(float); - auto fp16_input_data = new float16_t[input_ele_size]; - for (int i = 0; i < input_ele_size; i++) { - fp16_input_data[i] = static_cast(input_data[i]); - } - int output_count = conv_param->output_h_ * conv_param->output_w_; - int output_tile_count = UP_DIV(output_count, tile_num); - int unit_size = kernel_plane * ic4 * C4NUM; - int packed_input_size = output_tile_count * tile_num * unit_size; - auto packed_input = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float16_t))); - memset(packed_input, 0, in_batch * packed_input_size * sizeof(float16_t)); - - // bias - auto bias_data = reinterpret_cast(malloc(conv_param->output_channel_ * sizeof(float16_t))); - memset(bias_data, 0, conv_param->output_channel_ * sizeof(float16_t)); - - // output - auto tmp_output_block = reinterpret_cast(malloc(tile_num * out_channel * sizeof(float16_t))); - size_t output_data_size = - conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; - auto output_data = new float16_t[output_data_size]; - - // runtime part - printf("Calculating runtime cost...\n"); - uint64_t time_avg = 0; - // warmup - for (int i = 0; i < 3; i++) { - ConvFp16(fp16_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); - } - - int loop_count = 100; - auto time_start = mindspore::lite::GetTimeUs(); - for (int i = 0; i < loop_count; i++) { - ConvFp16(fp16_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); - } - auto time_end = mindspore::lite::GetTimeUs(); - auto cost = time_end - time_start; - time_avg = cost / loop_count; - printf("single thread running time : %f ms\n", time_avg / 1000.0f); - - std::cout << "==============fp16 output data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << output_data[i] << ", "; - } - std::cout << std::endl; - - auto fp32_output_data = new float[output_data_size]; - for (int i = 0; i < output_data_size; i++) { - fp32_output_data[i] = static_cast(output_data[i]); - } - printf("==================output data=================\n"); - for (int i = 0; i < 20; i++) { - std::cout << fp32_output_data[i] << " ,"; - } - std::cout << std::endl; - - std::string output_path = 
"./test_data/conv/convfp32_out_1_128_128_32.bin"; - lite::CompareOutput(fp32_output_data, output_path); - - free(packed_input); - free(bias_data); - free(packed_weight); - free(tmp_output_block); - delete conv_param; - delete input_data; - delete weight_data; - delete[] fp16_weight_data; - delete[] fp16_input_data; - delete[] fp32_output_data; - delete[] output_data; - MS_LOG(INFO) << "TestConvolutionFp16 passed"; -} - -TEST_F(TestConvolutionFp16, Conv3x3Test1) { - auto conv_param = new ConvParameter(); - InitConvParamGroup1Fp16(conv_param); - int thread_count = 1; - int tile_num = 16; - int output_batch = conv_param->output_batch_; - int output_h = conv_param->output_h_; - int output_w = conv_param->output_w_; - int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); - int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); - - // tmp buffer - int k_plane = 36; - size_t tile_buffer_size = thread_count * tile_num * k_plane * ic4 * C4NUM * sizeof(float16_t); - float16_t *tile_buffer = reinterpret_cast(malloc(tile_buffer_size)); - memset(tile_buffer, 0, tile_buffer_size); - - size_t block_unit_buffer_size = thread_count * k_plane * C4NUM * sizeof(float16_t); - float16_t *block_unit_buffer = reinterpret_cast(malloc(block_unit_buffer_size)); - memset(block_unit_buffer, 0, block_unit_buffer_size); - - size_t tmp_dst_buffer_size = thread_count * tile_num * k_plane * oc8 * C8NUM * sizeof(float16_t); - float16_t *tmp_dst_buffer = reinterpret_cast(malloc(tmp_dst_buffer_size)); - memset(tmp_dst_buffer, 0, tmp_dst_buffer_size); - - size_t tmp_out_size = oc8 * C8NUM * output_batch * output_h * output_w * tile_num * sizeof(float16_t); - float16_t *tmp_out = reinterpret_cast(malloc(tmp_out_size)); - memset(tmp_out, 0, tmp_out_size); - - // weight - size_t weight_size; - std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin"; - auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); - std::cout << "==============fp32 weight data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << weight_data[i] << ", "; - } - std::cout << std::endl; - - std::cout << "weight data size: " << weight_size / sizeof(float) << std::endl; - - int weight_ele_size = weight_size / sizeof(float); - auto fp16_weight_data = new float16_t[weight_ele_size]; - for (int i = 0; i < weight_ele_size; i++) { - fp16_weight_data[i] = (float16_t)weight_data[i]; - } - - std::cout << "==============fp16 weight data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << fp16_weight_data[i] << ", "; - } - std::cout << std::endl; - - size_t transformed_size = ic4 * C4NUM * oc8 * C8NUM * 36; - auto transformed_weight_data = new float16_t[transformed_size]; - memset(transformed_weight_data, 0, transformed_size * sizeof(float16_t)); - kernel::ProcessFilterFp16(fp16_weight_data, transformed_weight_data, conv_param); - - // bias - auto bias_data = - reinterpret_cast(malloc(UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t))); - memset(bias_data, 0, UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t)); - - // input - size_t input_size; - std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin"; - auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); - std::cout << "==============fp32 input data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << input_data[i] << ", "; - } - std::cout << std::endl; - - int input_ele_size = input_size / sizeof(float); - auto 
fp16_input_data = new float16_t[input_ele_size]; - for (int i = 0; i < input_ele_size; i++) { - fp16_input_data[i] = static_cast(input_data[i]); - } - - std::cout << "==============fp16 input data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << fp16_input_data[i] << ", "; - } - std::cout << std::endl; - - // output - size_t output_data_size = - conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; - auto output_data = new float16_t[output_data_size]; - - // runtime part - printf("Calculating runtime cost...\n"); - uint64_t time_avg = 0; - // warmup - for (int i = 0; i < 3; i++) { - Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, - tmp_dst_buffer, tmp_out, 0, conv_param); - } - - int loop_count = 100; - auto time_start = mindspore::lite::GetTimeUs(); - for (int i = 0; i < loop_count; i++) { - Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, - tmp_dst_buffer, tmp_out, 0, conv_param); - } - auto time_end = mindspore::lite::GetTimeUs(); - auto cost = time_end - time_start; - time_avg = cost / loop_count; - printf("single thread running time : %f ms\n", time_avg / 1000.0f); - - std::cout << "==============fp16 output data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << output_data[i] << ", "; - } - std::cout << std::endl; - - auto fp32_output_data = new float[output_data_size]; - for (int i = 0; i < output_data_size; i++) { - fp32_output_data[i] = static_cast(output_data[i]); - } - printf("==================output data=================\n"); - for (int i = 0; i < 20; i++) { - std::cout << fp32_output_data[i] << " ,"; - } - std::cout << std::endl; - - std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin"; - lite::CompareOutput(fp32_output_data, output_path); - - free(bias_data); - free(tile_buffer); - free(block_unit_buffer); - free(tmp_dst_buffer); - free(tmp_out); - delete input_data; - delete weight_data; - delete conv_param; - delete[] fp16_weight_data; - delete[] fp16_input_data; - delete[] fp32_output_data; - delete[] output_data; - delete[] transformed_weight_data; - MS_LOG(INFO) << "TestConvolutionFp16 Conv3x3 passed"; -} - -TEST_F(TestConvolutionFp16, Conv3x3Test2) { - auto conv_param = new ConvParameter(); - InitConvParamGroup2Fp16(conv_param); - int thread_count = 1; - int tile_num = 16; - int output_batch = conv_param->output_batch_; - int output_h = conv_param->output_h_; - int output_w = conv_param->output_w_; - int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); - int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); - - // tmp buffer - int k_plane = 36; - size_t tile_buffer_size = thread_count * tile_num * k_plane * ic4 * C4NUM * sizeof(float16_t); - float16_t *tile_buffer = reinterpret_cast(malloc(tile_buffer_size)); - memset(tile_buffer, 0, tile_buffer_size); - - size_t block_unit_buffer_size = thread_count * k_plane * C4NUM * sizeof(float16_t); - float16_t *block_unit_buffer = reinterpret_cast(malloc(block_unit_buffer_size)); - memset(block_unit_buffer, 0, block_unit_buffer_size); - - size_t tmp_dst_buffer_size = thread_count * tile_num * k_plane * oc8 * C8NUM * sizeof(float16_t); - float16_t *tmp_dst_buffer = reinterpret_cast(malloc(tmp_dst_buffer_size)); - memset(tmp_dst_buffer, 0, tmp_dst_buffer_size); - - size_t tmp_out_size = oc8 * C8NUM * output_batch * output_h * output_w * tile_num * sizeof(float16_t); - float16_t *tmp_out = 
reinterpret_cast(malloc(tmp_out_size)); - memset(tmp_out, 0, tmp_out_size); - - // weight - size_t weight_size; - std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_32.bin"; - auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); - int weight_ele_size = weight_size / sizeof(float); - auto fp16_weight_data = new float16_t[weight_ele_size]; - for (int i = 0; i < weight_ele_size; i++) { - fp16_weight_data[i] = static_cast(weight_data[i]); - } - size_t transformed_size = ic4 * C4NUM * oc8 * C8NUM * 36; - auto transformed_weight_data = new float16_t[transformed_size]; - memset(transformed_weight_data, 0, transformed_size * sizeof(float16_t)); - kernel::ProcessFilterFp16(fp16_weight_data, transformed_weight_data, conv_param); - - // bias - auto bias_data = - reinterpret_cast(malloc(UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t))); - memset(bias_data, 0, UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t)); - - // input - size_t input_size; - std::string input_path = "./test_data/conv/convfp32_input_1_128_128_32.bin"; - auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); - int input_ele_size = input_size / sizeof(float); - auto fp16_input_data = new float16_t[input_ele_size]; - for (int i = 0; i < input_ele_size; i++) { - fp16_input_data[i] = static_cast(input_data[i]); - } - - // output - size_t output_data_size = - conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; - auto output_data = new float16_t[output_data_size]; - - // runtime part - printf("Calculating runtime cost...\n"); - uint64_t time_avg = 0; - // warmup - for (int i = 0; i < 3; i++) { - Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, - tmp_dst_buffer, tmp_out, 0, conv_param); - } - - int loop_count = 100; - auto time_start = mindspore::lite::GetTimeUs(); - for (int i = 0; i < loop_count; i++) { - Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, - tmp_dst_buffer, tmp_out, 0, conv_param); - } - auto time_end = mindspore::lite::GetTimeUs(); - auto cost = time_end - time_start; - time_avg = cost / loop_count; - printf("single thread running time : %f ms\n", time_avg / 1000.0f); - - std::cout << "==============fp16 output data===========" << std::endl; - for (int i = 0; i < 20; i++) { - std::cout << output_data[i] << ", "; - } - std::cout << std::endl; - - auto fp32_output_data = new float[output_data_size]; - for (int i = 0; i < output_data_size; i++) { - fp32_output_data[i] = static_cast(output_data[i]); - } - printf("==================output data=================\n"); - for (int i = 0; i < 20; i++) { - std::cout << fp32_output_data[i] << " ,"; - } - std::cout << std::endl; - - std::string output_path = "./test_data/conv/convfp32_out_1_128_128_32.bin"; - lite::CompareOutput(fp32_output_data, output_path); - - free(bias_data); - free(tile_buffer); - free(block_unit_buffer); - free(tmp_dst_buffer); - free(tmp_out); - delete input_data; - delete weight_data; - delete conv_param; - delete[] fp16_weight_data; - delete[] fp16_input_data; - delete[] fp32_output_data; - delete[] output_data; - delete[] transformed_weight_data; - MS_LOG(INFO) << "TestConvolutionFp16 Conv3x3 passed"; -} - -} // namespace mindspore
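
With the specialized 3x3 kernel and its output-unpack helpers removed, the common fp16 convolution is a straight im2col+GEMM pipeline: each 16-pixel output tile is im2col-packed into a row-major [tile_n x deep] buffer (deep = kernel_h * kernel_w * in_channel), re-tiled into column-major 16-row blocks, and handed to MatMulFp16, which applies bias and activation and stores NHWC output directly, so the old tmp_out_block staging copy disappears. Below is a minimal, standalone C sketch of the tile and thread bookkeeping the new driver loop uses; the 26x26 output plane and 2-thread count are hypothetical example values, not taken from the patch.

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  const int tile_n = 16;      /* output pixels packed per GEMM call */
  int out_h = 26, out_w = 26; /* example output plane (hypothetical) */
  int thread_num = 2;         /* example thread count (hypothetical) */
  int output_count = out_h * out_w;                     /* 676 pixels */
  int output_tile_count = UP_DIV(output_count, tile_n); /* 43 tiles  */

  for (int task_id = 0; task_id < thread_num; task_id++) {
    /* Threads stride over tiles exactly like the thread_id += thread_count
     * loop in ConvFp16: thread 0 takes tiles 0, 2, 4, ..., thread 1 takes
     * tiles 1, 3, 5, ... */
    for (int tile = task_id; tile < output_tile_count; tile += thread_num) {
      int start_index = tile * tile_n;
      int real_cal_num =
          (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
      if (real_cal_num < tile_n) { /* only the last tile can be short: 676 - 672 = 4 */
        printf("thread %d: tail tile %d covers %d pixels\n", task_id, tile, real_cal_num);
      }
    }
  }
  return 0;
}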
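
The reworked Im2ColPackUnitFp16 packs plain NHWC instead of the old NHWC4 channel-blocked layout, and that is what makes the new dilation==1 fast path possible: within one kernel row, the valid patch columns kw_s..kw_e of an NHWC image form one contiguous run of (kw_e - kw_s) * in_channel values, so the per-column and per-channel loops collapse into a single memcpy. A small sketch of that observation, using float in place of float16_t so it builds without ARM fp16 support; the helper name and parameters are illustrative only.

#include <string.h>

/* Copy the valid columns [kw_s, kw_e) of one kernel row for one output pixel.
 * input points at the start of the (batch-offset) NHWC image; input_h_idx and
 * input_w_start locate the top-left corner of the receptive field. */
void PackOneKernelRow(const float *input, float *packed, int in_w, int in_channel,
                      int input_h_idx, int input_w_start, int kw_s, int kw_e) {
  const float *src = input + (input_h_idx * in_w + input_w_start + kw_s) * in_channel;
  /* dilation_h == dilation_w == 1: the whole row of the patch is contiguous */
  memcpy(packed, src, (size_t)(kw_e - kw_s) * in_channel * sizeof(float));
}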
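
RowMajor2Col16MajorFp16Opt then re-tiles the im2col output so the GEMM kernel can load 16 rows of one depth step with a single contiguous read. A scalar float reference of the layout assumed here (the real routine is an optimized fp16 implementation; its exact micro-layout should be checked in nnacl/fp16/matmul_fp16.c):

#define C16NUM 16

/* Re-tile a row-major [row x col] matrix into blocks of 16 rows; inside each
 * block the data is column-major, so the 16 values of one column (one "deep"
 * step of the GEMM) are contiguous. */
void RowMajor2Col16MajorRef(const float *src, float *dst, int row, int col) {
  for (int r = 0; r < row; r++) {
    int block = r / C16NUM; /* which 16-row block */
    int rb = r % C16NUM;    /* row position inside the block */
    for (int c = 0; c < col; c++) {
      dst[block * C16NUM * col + c * C16NUM + rb] = src[r * col + c];
    }
  }
}

Note that the call site always transforms a full tile_n rows, even for a tail tile: both per-thread tile buffers are memset to zero beforehand, so the padding rows are zeros and MatMulFp16 only stores real_cal_num rows of output.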
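
Finally, a scalar reference of the contraction MatMulFp16 is asked to perform at this call site, with A in the 16-row-blocked layout above, B (the weights) in the analogous 8-column-blocked layout produced by RowMajor2Col8MajorFp16, and dense NHWC write-back through the stride argument. The operand layouts and the ActType values here are assumptions inferred from this patch, not an authoritative restatement of the library API.

#include <stddef.h>

typedef enum { ACT_NO = 0, ACT_RELU = 1, ACT_RELU6 = 3 } ActTypeRef; /* assumed nnacl ActType values */

void MatMulRef(const float *a, const float *b, float *c, const float *bias,
               int act_type, int deep, int row, int col, int stride) {
  for (int r = 0; r < row; r++) {   /* row == real_cal_num output pixels */
    for (int j = 0; j < col; j++) { /* col == out_channel */
      float acc = (bias != NULL) ? bias[j] : 0.0f;
      for (int d = 0; d < deep; d++) {
        float av = a[(r / 16) * deep * 16 + d * 16 + (r % 16)]; /* 16-row-blocked A */
        float bv = b[(j / 8) * deep * 8 + d * 8 + (j % 8)];     /* 8-col-blocked B */
        acc += av * bv;
      }
      if (act_type == ACT_RELU) acc = acc < 0 ? 0 : acc;
      if (act_type == ACT_RELU6) acc = acc < 0 ? 0 : (acc > 6 ? 6 : acc);
      c[r * stride + j] = acc; /* stride == out_channel yields NHWC output */
    }
  }
}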