@@ -39,27 +39,23 @@ float ShortToFloat32(uint16_t src_value);
 uint16_t Float32ToShort(float src_value);
-#ifdef ENABLE_X86_64_SSE
-void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
-                        size_t plane_size, size_t stride, size_t relu_type);
-void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
-                        size_t plane_size, size_t plane_stride, size_t relu_type);
-#endif
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
 void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
                       size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                       size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6);
-void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
-                      size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
 void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
                         size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
                         size_t in_kh_step, size_t in_kw_step);
-void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
-                   size_t output_channel, size_t input_step);
+void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
+                      size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
 void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
                         size_t plane_size, size_t stride, size_t relu_type);
+#endif
+#ifdef ENABLE_ARM
+void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
+                   size_t output_channel, size_t input_step);
 void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
                         size_t plane_size, size_t plane_stride, size_t relu_type);
 #endif
@@ -200,7 +200,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float
       const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
       const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
       ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                        sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                        conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
@@ -283,7 +283,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
       int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
       const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
       float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
       ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                        conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                        sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
@@ -437,7 +437,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we
       float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
       const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
       DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                          conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                          sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
@@ -743,6 +743,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch
   return;
 }
+#ifndef ENABLE_X86_64_SSE
 void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
   int hw8 = plane / C8NUM * C8NUM;
   int c8 = channel / C8NUM * C8NUM;
@@ -928,6 +929,7 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
   }
   return;
 }
+#endif
 void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) {
   int hw8 = plane / C8NUM * C8NUM;
@@ -0,0 +1,376 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_X86_64_SSE
+#include <nmmintrin.h>
+#include "nnacl/fp32/conv_depthwise.h"
+
+void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
+                      size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) {
+  in_kh_step /= sizeof(float);
+  in_kw_step /= sizeof(float);
+  kernel_w_step /= sizeof(float);
+  const float *src_kh = src;
+  const float *weight_kh = weight;
+  __m128 dst_ma = _mm_setzero_ps();
+  for (int kh = 0; kh < height; kh++) {
+    const float *src_kw = src_kh;
+    const float *weight_kw = weight_kh;
+    int c1 = 0;
+    int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
+    int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
+    // c4 loop
+    for (; c1 < c4; c1 += C4NUM) {
+      __m128 src_ma1 = _mm_loadu_ps(src_kw);
+      __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
+      __m128 src_ma3 = _mm_loadu_ps(src_kw + 2 * in_kw_step);
+      __m128 src_ma4 = _mm_loadu_ps(src_kw + 3 * in_kw_step);
+      __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
+      __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+      __m128 weight_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
+      __m128 weight_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
+      __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
+      __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
+      __m128 mul_ma3 = _mm_mul_ps(src_ma3, weight_ma3);
+      __m128 mul_ma4 = _mm_mul_ps(src_ma4, weight_ma4);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma1);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma2);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma3);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma4);
+      src_kw += in_kw_step * 4;
+      weight_kw += C4NUM * 4;
+    }
+    // c2 loop
+    for (; c1 < c2; c1 += C2NUM) {
+      __m128 src_ma1 = _mm_loadu_ps(src_kw);
+      __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
+      __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
+      __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+      __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
+      __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma1);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma2);
+      src_kw += in_kw_step * 2;
+      weight_kw += C4NUM * 2;
+    }
+    // remaining
+    for (; c1 < width; ++c1) {
+      __m128 src_ma1 = _mm_loadu_ps(src_kw);
+      __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
+      __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma1);
+      src_kw += in_kw_step;
+      weight_kw += C4NUM;
+    }
+    src_kh += in_kh_step;
+    weight_kh += kernel_w_step;
+  }
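+  // add bias, then clamp with the fused activation (relu / relu6)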
+  __m128 bias_ma = _mm_loadu_ps(bias);
+  dst_ma = _mm_add_ps(dst_ma, bias_ma);
+  __m128 zero_ma = _mm_setzero_ps();
+  if (relu || relu6) {
+    dst_ma = _mm_max_ps(zero_ma, dst_ma);
+    if (relu6) {
+      __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+      dst_ma = _mm_min_ps(const_ma, dst_ma);
+    }
+  }
+  _mm_storeu_ps(dst, dst_ma);
+}
+
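+// SSE sliding-window depthwise convolution over the interior (no-padding) output region.
+// Each step produces one C4NUM-channel output pixel; the output-width loop is unrolled by
+// 4 and by 2, and the unrolled pixels deliberately share the same weight vector per kernel
+// tap, since depthwise weights do not vary across output positions.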
+void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
+                      size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
+                      size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) {
+  out_h_step /= sizeof(float);
+  block_channel /= sizeof(float);
+  in_sh_step /= sizeof(float);
+  in_sw_step /= sizeof(float);
+  in_kh_step /= sizeof(float);
+  in_kw_step /= sizeof(float);
+  float *dst_h = dst;
+  const float *src_h = src;
+  for (int oh = 0; oh < height; oh++) {
+    float *dst_w = dst_h;
+    const float *src_w = src_h;
+    int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
+    int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
+    int c1 = 0;
+    // c4 loop
+    for (; c1 < c4; c1 += C4NUM) {
+      const float *src_kh = src_w;
+      const float *weight_kh = weight;
+      __m128 dst_w_ma1 = _mm_setzero_ps();
+      __m128 dst_w_ma2 = _mm_setzero_ps();
+      __m128 dst_w_ma3 = _mm_setzero_ps();
+      __m128 dst_w_ma4 = _mm_setzero_ps();
+      for (int kh = 0; kh < kernel_h; kh++) {
+        const float *src_kw = src_kh;
+        const float *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++) {
+          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+          __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step);
+          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3);
+          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
+          __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step);
+          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4);
+          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
+          src_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        src_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      // add bias relu
+      __m128 bias_ma = _mm_loadu_ps(bias);
+      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
+      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
+      dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma);
+      dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma);
+      __m128 zero_ma = _mm_setzero_ps();
+      if (relu || relu6) {
+        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
+        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
+        dst_w_ma3 = _mm_max_ps(zero_ma, dst_w_ma3);
+        dst_w_ma4 = _mm_max_ps(zero_ma, dst_w_ma4);
+        if (relu6) {
+          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
+          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
+          dst_w_ma3 = _mm_min_ps(const_ma, dst_w_ma3);
+          dst_w_ma4 = _mm_min_ps(const_ma, dst_w_ma4);
+        }
+      }
+      _mm_storeu_ps(dst_w, dst_w_ma1);
+      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
+      _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3);
+      _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4);
+      dst_w += C4NUM * block_channel;
+      src_w += C4NUM * in_sw_step;
+    }  // dst_width loop
+    // c2 loop
+    for (; c1 < c2; c1 += C2NUM) {
+      const float *src_kh = src_w;
+      const float *weight_kh = weight;
+      __m128 dst_w_ma1 = _mm_setzero_ps();
+      __m128 dst_w_ma2 = _mm_setzero_ps();
+      for (int kh = 0; kh < kernel_h; kh++) {
+        const float *src_kw = src_kh;
+        const float *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++) {
+          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+          src_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        src_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      // add bias relu
+      __m128 bias_ma = _mm_loadu_ps(bias);
+      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
+      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
+      __m128 zero_ma = _mm_setzero_ps();
+      if (relu || relu6) {
+        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
+        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
+        if (relu6) {
+          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
+          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
+        }
+      }
+      _mm_storeu_ps(dst_w, dst_w_ma1);
+      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
+      dst_w += C2NUM * block_channel;
+      src_w += C2NUM * in_sw_step;
+    }
+    // remaining
+    for (; c1 < width; c1++) {
+      const float *src_kh = src_w;
+      const float *weight_kh = weight;
+      __m128 dst_w_ma1 = _mm_setzero_ps();
+      for (int kh = 0; kh < kernel_h; kh++) {
+        const float *src_kw = src_kh;
+        const float *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++) {
+          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          src_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        src_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      // add bias relu
+      __m128 bias_ma = _mm_loadu_ps(bias);
+      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
+      __m128 zero_ma = _mm_setzero_ps();
+      if (relu || relu6) {
+        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
+        if (relu6) {
+          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
+        }
+      }
+      _mm_storeu_ps(dst_w, dst_w_ma1);
+      dst_w += block_channel;
+      src_w += in_sw_step;
+    }
+    dst_h += out_h_step;
+    src_h += in_sh_step;
+  }  // dst_height loop
+}
+
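+// SSE deconvolution (transposed convolution) center kernel: each C4NUM-channel input pixel
+// is multiplied by every kernel tap and accumulated into the kernel_h x kernel_w window of
+// output pixels it touches; the kernel-width loop is unrolled by 4 and by 2.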
+void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
+                        size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
+                        size_t in_kh_step, size_t in_kw_step) {
+  out_h_step /= sizeof(float);
+  block_channel /= sizeof(float);
+  in_sh_step /= sizeof(float);
+  in_sw_step /= sizeof(float);
+  in_kh_step /= sizeof(float);
+  in_kw_step /= sizeof(float);
+  float *dst_h = dst;
+  const float *src_h = src;
+  for (int oh = 0; oh < height; oh++) {
+    float *dst_w = dst_h;
+    const float *src_w = src_h;
+    for (int ow = 0; ow < width; ow++) {
+      float *dst_kh = dst_w;
+      const float *weight_kh = weight;
+      __m128 src_w_ma = _mm_loadu_ps(src_w);
+      for (int kh = 0; kh < kernel_h; kh++) {
+        float *dst_kw = dst_kh;
+        const float *weight_kw = weight_kh;
+        int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM;
+        int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM;
+        int c1 = 0;
+        // c4 loop
+        for (; c1 < c4; c1 += C4NUM) {
+          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          _mm_storeu_ps(dst_kw, dst_w_ma1);
+          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);
+          __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step);
+          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
+          __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3);
+          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
+          _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3);
+          __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step);
+          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
+          __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4);
+          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
+          _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4);
+          dst_kw += 4 * in_kw_step;
+          weight_kw += 4 * C4NUM;
+        }
+        // c2 loop
+        for (; c1 < c2; c1 += C2NUM) {
+          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          _mm_storeu_ps(dst_kw, dst_w_ma1);
+          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);
+          dst_kw += 2 * in_kw_step;
+          weight_kw += 2 * C4NUM;
+        }
+        // remaining
+        for (; c1 < kernel_w; ++c1) {
+          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          _mm_storeu_ps(dst_kw, dst_w_ma1);
+          dst_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        dst_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      dst_w += in_sw_step;
+      src_w += block_channel;
+    }  // dst_width loop
+    dst_h += in_sh_step;
+    src_h += out_h_step;
+  }  // dst_height loop
+}
+#endif
@@ -0,0 +1,140 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_X86_64_SSE
+#include <nmmintrin.h>
+#include "nnacl/pack.h"
+#include "nnacl/int8/conv_int8.h"
+
+void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
+  int hw8 = plane / C8NUM * C8NUM;
+  int c8 = channel / C8NUM * C8NUM;
+  int batch = plane * channel;
+  for (int n = 0; n < batches; n++) {
+    const float *src_batch = (const float *)src + n * batch;
+    float *dst_batch = (float *)dst + n * batch;
+    int hw = 0;
+    for (; hw < hw8; hw += C8NUM) {
+      int c = 0;
+      for (; c < c8; c += C8NUM) {
+        const float *src_ptr = src_batch + hw * channel + c;
+        float *dst_ptr = dst_batch + c * plane + hw;
+        // transpose 4x4 sub-block: src (hw 0-3, ch 0-3) -> dst (ch 0-3, hw 0-3)
+        __m128 v0_ma = _mm_loadu_ps(src_ptr);
+        __m128 v1_ma = _mm_loadu_ps(src_ptr + channel);
+        __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * channel);
+        __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * channel);
+        __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+        __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+        _mm_storeu_ps(dst_ptr, v8_ma);
+        _mm_storeu_ps(dst_ptr + plane, v9_ma);
+        _mm_storeu_ps(dst_ptr + 2 * plane, v10_ma);
+        _mm_storeu_ps(dst_ptr + 3 * plane, v11_ma);
+        // transpose 4x4 sub-block: src (hw 0-3, ch 4-7) -> dst (ch 4-7, hw 0-3)
+        v0_ma = _mm_loadu_ps(src_ptr + C4NUM);
+        v1_ma = _mm_loadu_ps(src_ptr + channel + C4NUM);
+        v2_ma = _mm_loadu_ps(src_ptr + 2 * channel + C4NUM);
+        v3_ma = _mm_loadu_ps(src_ptr + 3 * channel + C4NUM);
+        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+        _mm_storeu_ps(dst_ptr + C4NUM * plane, v8_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 1) * plane, v9_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 2) * plane, v10_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 3) * plane, v11_ma);
+        // transpose 4x4 sub-block: src (hw 4-7, ch 0-3) -> dst (ch 0-3, hw 4-7)
+        v0_ma = _mm_loadu_ps(src_ptr + C4NUM * channel);
+        v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * channel);
+        v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * channel);
+        v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * channel);
+        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+        _mm_storeu_ps(dst_ptr + C4NUM, v8_ma);
+        _mm_storeu_ps(dst_ptr + plane + C4NUM, v9_ma);
+        _mm_storeu_ps(dst_ptr + 2 * plane + C4NUM, v10_ma);
+        _mm_storeu_ps(dst_ptr + 3 * plane + C4NUM, v11_ma);
+        // transpose 4x4 sub-block: src (hw 4-7, ch 4-7) -> dst (ch 4-7, hw 4-7)
+        v0_ma = _mm_loadu_ps(src_ptr + C4NUM * channel + C4NUM);
+        v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * channel + C4NUM);
+        v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * channel + C4NUM);
+        v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * channel + C4NUM);
+        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+        _mm_storeu_ps(dst_ptr + C4NUM * plane + C4NUM, v8_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 1) * plane + C4NUM, v9_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 2) * plane + C4NUM, v10_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 3) * plane + C4NUM, v11_ma);
+      }
+      for (; c < channel; c++) {
+        const float *src_ptr = src_batch + hw * channel + c;
+        float *dst_ptr = dst_batch + c * plane + hw;
+        for (size_t i = 0; i < C8NUM; i++) {
+          dst_ptr[i] = src_ptr[i * channel];
+        }
+      }
+    }
+    for (; hw < plane; hw++) {
+      const float *src_ptr = src_batch + hw * channel;
+      float *dst_ptr = dst_batch + hw;
+      for (size_t i = 0; i < channel; i++) {
+        dst_ptr[i * plane] = src_ptr[i];
+      }
+    }
+  }
+  return;
+}
+#endif
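
As a sanity check on the new pack routine, the minimal harness below (not part of the patch; the file and shape choices are illustrative) verifies the repack invariant dst[c * plane + hw] == src[hw * channel + c]. The shape 10x9 is deliberately larger than one 8x8 tile so both the SSE path and the scalar tails are exercised; it assumes PackNHWCToNCHWFp32 is compiled with ENABLE_X86_64_SSE and linked in.

/* pack_nhwc_to_nchw_check.c -- standalone sketch, not part of the patch */
#include <assert.h>
#include <stdlib.h>

void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel);

int main(void) {
  const int batches = 1, plane = 10, channel = 9; /* 10x9 exceeds one 8x8 tile */
  float *src = malloc(sizeof(float) * plane * channel);
  float *dst = malloc(sizeof(float) * plane * channel);
  for (int i = 0; i < plane * channel; ++i) src[i] = (float)i;
  PackNHWCToNCHWFp32(src, dst, batches, plane, channel);
  /* NHWC element (hw, c) must land at NCHW offset c * plane + hw */
  for (int hw = 0; hw < plane; ++hw) {
    for (int c = 0; c < channel; ++c) {
      assert(dst[c * plane + hw] == src[hw * channel + c]);
    }
  }
  free(src);
  free(dst);
  return 0;
}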