| @@ -107,15 +107,26 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ | |||
| } | |||
| } | |||
| #ifdef ENABLE_ARM64 | |||
| #if defined(ENABLE_AVX) || defined(ENABLE_ARM64) | |||
| void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, | |||
| float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) { | |||
| if (conv_param->thread_num_ == 0) { | |||
| return; | |||
| } | |||
| int output_hw = conv_param->output_h_ * conv_param->output_w_; | |||
| int out_channel = conv_param->output_channel_; | |||
| int input_hw = conv_param->input_h_ * conv_param->input_w_; | |||
| int in_channel = conv_param->input_channel_; | |||
| Row2ColMajorFuncPtr Row2ColMajor = NULL; | |||
| int cal_num = 0; | |||
| int out_tile = 0; | |||
| #ifdef ENABLE_AVX | |||
| cal_num = C6NUM; | |||
| out_tile = C8NUM; | |||
| Row2ColMajor = RowMajor2Col6Major; | |||
| int align_channel = UP_DIV(out_channel, C16NUM) * C16NUM; | |||
| #else | |||
| out_tile = C4NUM; | |||
| MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; | |||
| if (output_hw <= C4NUM) { | |||
| cal_num = C4NUM; | |||
| @@ -130,7 +141,7 @@ void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float | |||
| Row2ColMajor = RowMajor2Col12Major; | |||
| MatmulFloatOpt = MatmulFloatNeon64OptRow12; | |||
| } | |||
| #endif | |||
| int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); | |||
| int start_block = block_per_thread * task_id; | |||
| int start_hw = start_block * cal_num; | |||
| @@ -138,31 +149,160 @@ void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float | |||
| if (start_hw >= end_hw) { | |||
| return; | |||
| } | |||
| int out_stride = MSMIN(conv_param->output_channel_, C4NUM) * cal_num; | |||
| #ifdef ENABLE_AVX | |||
| int act_type = 0; | |||
| if (conv_param->act_type_ == ActType_Relu6) { | |||
| act_type += 1; | |||
| } | |||
| if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { | |||
| act_type += 2; | |||
| } | |||
| int out_stride = out_tile * cal_num; | |||
| int out_block_stride = output_hw * C8NUM; | |||
| #else | |||
| int out_stride = MSMIN(out_channel, out_tile) * cal_num; | |||
| #endif | |||
| int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| packed_input += task_id * deep * cal_num; | |||
| col_major_input += task_id * deep * cal_num; | |||
| size_t input_size = deep * cal_num * sizeof(float); | |||
| for (int b = 0; b < conv_param->input_batch_; b++) { | |||
| int out_channel = conv_param->output_channel_; | |||
| int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; | |||
| int out_offset = b * out_channel * output_hw + start_hw * MSMIN(out_channel, C4NUM); | |||
| int in_offset = b * in_channel * input_hw; | |||
| #ifdef ENABLE_AVX | |||
| int out_offset = b * align_channel * output_hw + start_hw * out_tile; | |||
| #else | |||
| int out_offset = b * out_channel * output_hw + start_hw * MSMIN(out_channel, out_tile); | |||
| #endif | |||
| for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { | |||
| int real_cal_row = MSMIN(output_hw - i, cal_num); | |||
| memset(packed_input, 0, input_size); | |||
| Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); | |||
| Row2ColMajor(packed_input, col_major_input, cal_num, deep); | |||
| float *gemm_output = output_data + out_offset; | |||
| #ifdef ENABLE_AVX | |||
| for (int oc = 0; oc < out_channel; oc += C16NUM) { | |||
| CommonConv6x16Kernel(gemm_output + oc * output_hw, col_major_input, packed_weight + oc * deep, bias_data + oc, | |||
| deep, out_block_stride, act_type, real_cal_row); | |||
| } | |||
| #else | |||
| MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, | |||
| out_channel, output_hw, OutType_NC4HW4); | |||
| #endif | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t depth, | |||
| const size_t out_step, const size_t act_flag, const size_t real_cal_row) { | |||
| #define Store1 \ | |||
| _mm256_storeu_ps(dst, out[0]); \ | |||
| _mm256_storeu_ps(dst + out_step, out[1]); | |||
| #define Store2 \ | |||
| Store1 _mm256_storeu_ps(dst + C8NUM, out[2]); \ | |||
| _mm256_storeu_ps(dst + out_step + C8NUM, out[3]); | |||
| #define Store3 \ | |||
| Store2 _mm256_storeu_ps(dst + C16NUM, out[4]); \ | |||
| _mm256_storeu_ps(dst + out_step + C16NUM, out[5]); | |||
| #define Store4 \ | |||
| Store3 _mm256_storeu_ps(dst + C24NUM, out[6]); \ | |||
| _mm256_storeu_ps(dst + out_step + C24NUM, out[7]); | |||
| #define Store5 \ | |||
| Store4 _mm256_storeu_ps(dst + C32NUM, out[8]); \ | |||
| _mm256_storeu_ps(dst + out_step + C32NUM, out[9]); | |||
| #define Store6 \ | |||
| Store5 _mm256_storeu_ps(dst + C40NUM, out[10]); \ | |||
| _mm256_storeu_ps(dst + out_step + C40NUM, out[11]); | |||
| __m256 out[12]; | |||
| if (bias != NULL) { | |||
| out[0] = _mm256_loadu_ps(bias); | |||
| out[1] = _mm256_loadu_ps(bias + C8NUM); | |||
| } else { | |||
| out[0] = _mm256_set1_ps(0.0f); | |||
| out[1] = _mm256_set1_ps(0.0f); | |||
| } | |||
| out[2] = out[0]; | |||
| out[3] = out[1]; | |||
| out[4] = out[0]; | |||
| out[5] = out[1]; | |||
| out[6] = out[0]; | |||
| out[7] = out[1]; | |||
| out[8] = out[0]; | |||
| out[9] = out[1]; | |||
| out[10] = out[0]; | |||
| out[11] = out[1]; | |||
| for (int d = 0; d < depth; ++d) { | |||
| __m256 w1 = _mm256_loadu_ps(weight); | |||
| __m256 w2 = _mm256_loadu_ps(weight + C8NUM); | |||
| __m256 s1 = _mm256_set1_ps(*src); | |||
| __m256 s2 = _mm256_set1_ps(*(src + 1)); | |||
| out[0] = _mm256_fmadd_ps(s1, w1, out[0]); | |||
| out[1] = _mm256_fmadd_ps(s1, w2, out[1]); | |||
| out[2] = _mm256_fmadd_ps(s2, w1, out[2]); | |||
| out[3] = _mm256_fmadd_ps(s2, w2, out[3]); | |||
| s1 = _mm256_set1_ps(*(src + 2)); | |||
| s2 = _mm256_set1_ps(*(src + 3)); | |||
| out[4] = _mm256_fmadd_ps(s1, w1, out[4]); | |||
| out[5] = _mm256_fmadd_ps(s1, w2, out[5]); | |||
| out[6] = _mm256_fmadd_ps(s2, w1, out[6]); | |||
| out[7] = _mm256_fmadd_ps(s2, w2, out[7]); | |||
| s1 = _mm256_set1_ps(*(src + 4)); | |||
| s2 = _mm256_set1_ps(*(src + 5)); | |||
| out[8] = _mm256_fmadd_ps(s1, w1, out[8]); | |||
| out[9] = _mm256_fmadd_ps(s1, w2, out[9]); | |||
| out[10] = _mm256_fmadd_ps(s2, w1, out[10]); | |||
| out[11] = _mm256_fmadd_ps(s2, w2, out[11]); | |||
| weight += C16NUM; | |||
| src += C6NUM; | |||
| } | |||
| __m256 six = _mm256_set1_ps(6.0f); | |||
| __m256 zero = _mm256_set1_ps(0.0f); | |||
| if (0x1 & act_flag) { // relu6 | |||
| out[0] = _mm256_min_ps(out[0], six); | |||
| out[1] = _mm256_min_ps(out[1], six); | |||
| out[2] = _mm256_min_ps(out[2], six); | |||
| out[3] = _mm256_min_ps(out[3], six); | |||
| out[4] = _mm256_min_ps(out[4], six); | |||
| out[5] = _mm256_min_ps(out[5], six); | |||
| out[6] = _mm256_min_ps(out[6], six); | |||
| out[7] = _mm256_min_ps(out[7], six); | |||
| out[8] = _mm256_min_ps(out[8], six); | |||
| out[9] = _mm256_min_ps(out[9], six); | |||
| out[10] = _mm256_min_ps(out[10], six); | |||
| out[11] = _mm256_min_ps(out[11], six); | |||
| } | |||
| if (0x2 & act_flag) { // relu | |||
| out[0] = _mm256_max_ps(out[0], zero); | |||
| out[1] = _mm256_max_ps(out[1], zero); | |||
| out[2] = _mm256_max_ps(out[2], zero); | |||
| out[3] = _mm256_max_ps(out[3], zero); | |||
| out[4] = _mm256_max_ps(out[4], zero); | |||
| out[5] = _mm256_max_ps(out[5], zero); | |||
| out[6] = _mm256_max_ps(out[6], zero); | |||
| out[7] = _mm256_max_ps(out[7], zero); | |||
| out[8] = _mm256_max_ps(out[8], zero); | |||
| out[9] = _mm256_max_ps(out[9], zero); | |||
| out[10] = _mm256_max_ps(out[10], zero); | |||
| out[11] = _mm256_max_ps(out[11], zero); | |||
| } | |||
| if (real_cal_row == C6NUM) { | |||
| Store6 | |||
| } else if (real_cal_row == C5NUM) { | |||
| Store5 | |||
| } else if (real_cal_row == C4NUM) { | |||
| Store4 | |||
| } else if (real_cal_row == C3NUM) { | |||
| Store3 | |||
| } else if (real_cal_row == C2NUM) { | |||
| Store2 | |||
| } else if (real_cal_row == C1NUM) { | |||
| Store1 | |||
| } | |||
| } | |||
| void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left, | |||
| int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, const SWConvKernel kernel, | |||
| int act_type, int ow_bock, int oc_block, size_t write_mode) { | |||
| @@ -29,17 +29,20 @@ typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, int ro | |||
| #ifdef ENABLE_ARM64 | |||
| typedef void (*MatmulFloatOptFuncPtr)(const float *a, const float *b, float *c, const float *bias, int act_type, | |||
| int depth, int row, int col, size_t stride, size_t write_mode); | |||
| // common convolution output C4HW4, if out_channel mod 4 remains, just output real channel, no zeros padded. | |||
| void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, | |||
| float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); | |||
| #endif | |||
| // fp32 convolution common (im2col+gemm) | |||
| void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, | |||
| float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); | |||
| // Common convolution with C4HW4 output: if out_channel is not a multiple of 4, only the real channels are written (no zero padding). | |||
| void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, | |||
| float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); | |||
| #ifdef ENABLE_AVX | |||
| void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t depth, | |||
| size_t out_step, size_t act_flag, size_t real_cal_row); | |||
| typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, | |||
| size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, | |||
| size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, | |||
| @@ -384,6 +384,27 @@ void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int | |||
| } | |||
| } | |||
| void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane, | |||
| const int channel) { | |||
| int down_channel_8 = DOWN_ROUND(channel, C8NUM); | |||
| int up_channel_16 = UP_ROUND(channel, C16NUM); | |||
| size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float); | |||
| size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float); | |||
| size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float); | |||
| size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float); | |||
| size_t src_p_offset = C8NUM * sizeof(float); | |||
| for (size_t b = 0; b < (size_t)(batch); ++b) { | |||
| const char *src_batch = (char *)(src) + b * src_batch_offset; | |||
| char *dst_bacth = (char *)(dst) + b * dst_batch_offset; | |||
| memcpy(dst_bacth, src_batch, aligned_channel_size); | |||
| src_batch += aligned_channel_size; | |||
| dst_bacth += aligned_channel_size; | |||
| for (int p = 0; p < plane; ++p) { | |||
| memcpy(dst_bacth + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size); | |||
| } | |||
| } | |||
| } | |||
| void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { | |||
| for (int n = 0; n < batch; n++) { | |||
| for (int hw = 0; hw < plane; hw++) { | |||
| @@ -38,6 +38,7 @@ void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int c | |||
| void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel); | |||
| @@ -30,12 +30,15 @@ | |||
| #define C2NUM 2 | |||
| #define C3NUM 3 | |||
| #define C4NUM 4 | |||
| #define C5NUM 5 | |||
| #define C6NUM 6 | |||
| #define C8NUM 8 | |||
| #define C12NUM 12 | |||
| #define C16NUM 16 | |||
| #define C20NUM 20 | |||
| #define C24NUM 24 | |||
| #define C32NUM 32 | |||
| #define C40NUM 40 | |||
| #define C64NUM 64 | |||
| #define TILE_NUM 8 | |||
| @@ -36,7 +36,8 @@ namespace mindspore::kernel { | |||
| #endif | |||
| int ConvolutionCPUKernel::InitTmpBuffer() { | |||
| MS_ASSERT(ctx_->allocator != nullptr); | |||
| CHECK_NULL_RETURN(out_tensors_[0]); | |||
| CHECK_NULL_RETURN(out_tensors_[0]->MutableData()); | |||
| #ifdef ENABLE_AVX | |||
| int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_; | |||
| #elif defined(ENABLE_SSE) | |||
| @@ -56,6 +57,20 @@ int ConvolutionCPUKernel::InitTmpBuffer() { | |||
| MS_LOG(ERROR) << "malloc col_major_input_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| if (conv_param_->output_channel_ % OC_BLOCK != 0 && out_tensors_[0]->format() == NC4HW4) { | |||
| output_need_align_ = true; | |||
| int oc_algin = UP_DIV(conv_param_->output_channel_, OC_BLOCK); | |||
| int pack_output_size = | |||
| conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * OC_BLOCK * oc_algin; | |||
| tmp_output_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_output_size * sizeof(float))); | |||
| if (tmp_output_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc tmp_output_ buffer is failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| @@ -96,14 +111,13 @@ int ConvolutionCPUKernel::ReSize() { | |||
| int ConvolutionCPUKernel::RunImpl(int task_id) { | |||
| auto ori_input_data = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->data_c()); | |||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c()); | |||
| if (out_tensors()[0]->format() != NC4HW4) { | |||
| if (out_tensors_[0]->format() != NC4HW4) { | |||
| ConvFp32(ori_input_data, packed_input_, reinterpret_cast<float *>(packed_weight_), | |||
| reinterpret_cast<float *>(bias_data_), col_major_input_, output_addr, task_id, conv_param_); | |||
| reinterpret_cast<float *>(bias_data_), col_major_input_, tmp_output_, task_id, conv_param_); | |||
| } else { | |||
| #if ENABLE_ARM64 | |||
| #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) | |||
| ConvFp32OutNC4HW4(ori_input_data, packed_input_, reinterpret_cast<float *>(packed_weight_), | |||
| reinterpret_cast<float *>(bias_data_), col_major_input_, output_addr, task_id, conv_param_); | |||
| reinterpret_cast<float *>(bias_data_), col_major_input_, tmp_output_, task_id, conv_param_); | |||
| #else | |||
| MS_LOG(ERROR) << "ConvFp32OutNC4HW4 not implemented."; | |||
| return RET_ERROR; | |||
| @@ -129,8 +143,12 @@ int ConvolutionCPUKernel::Run() { | |||
| FreeTmpBuffer(); | |||
| return RET_ERROR; | |||
| } | |||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | |||
| if (!output_need_align_) { | |||
| tmp_output_ = output_addr; | |||
| } | |||
| if (RepackWeight() != RET_OK) { | |||
| FreeTmpBuffer(); | |||
| MS_LOG(ERROR) << "Repack weight failed."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -138,6 +156,13 @@ int ConvolutionCPUKernel::Run() { | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "conv error error_code[" << ret << "]"; | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| if (output_need_align_) { | |||
| PackNC8HW8AlignedToNC8HW8NotAlignedFp32(tmp_output_, output_addr, conv_param_->output_batch_, | |||
| conv_param_->output_h_ * conv_param_->output_w_, | |||
| conv_param_->output_channel_); | |||
| } | |||
| #endif | |||
| FreeTmpBuffer(); | |||
| return ret; | |||
| } | |||
| @@ -49,11 +49,18 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { | |||
| ctx_->allocator->Free(col_major_input_); | |||
| col_major_input_ = nullptr; | |||
| } | |||
| if (output_need_align_ && tmp_output_ != nullptr) { | |||
| ctx_->allocator->Free(tmp_output_); | |||
| tmp_output_ = nullptr; | |||
| output_need_align_ = false; | |||
| } | |||
| } | |||
| protected: | |||
| float *tmp_output_ = nullptr; | |||
| float *packed_input_ = nullptr; | |||
| float *col_major_input_ = nullptr; | |||
| bool output_need_align_ = false; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -145,7 +145,7 @@ bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKerne | |||
| } | |||
| } | |||
| #ifdef ENABLE_ARM64 | |||
| #if defined(ENABLE_ARM64) | |||
| return true; | |||
| #endif | |||