From d27df64c9b9347ec05819d7beee5c923475a41f1 Mon Sep 17 00:00:00 2001
From: wangyanling
Date: Wed, 11 Nov 2020 11:51:17 +0800
Subject: [PATCH] add ConvDwFp32Center, ConvDwFp32Border, DeconvDwFp32Center
 and PackNHWCToNCHWFp32 x86 SSE intrinsic operators

---
 mindspore/lite/nnacl/fp32/common_func.h       |  22 +-
 mindspore/lite/nnacl/fp32/conv_depthwise.c    |   6 +-
 mindspore/lite/nnacl/pack.c                   |   2 +
 .../lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c | 376 ++++++++++++++++++
 .../nnacl/x86_64_sse/PackNHWCToNCHWFp32.c     | 140 +++++++
 5 files changed, 530 insertions(+), 16 deletions(-)
 create mode 100644 mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c
 create mode 100644 mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c

diff --git a/mindspore/lite/nnacl/fp32/common_func.h b/mindspore/lite/nnacl/fp32/common_func.h
index aeca3bacee..8a9d818c55 100644
--- a/mindspore/lite/nnacl/fp32/common_func.h
+++ b/mindspore/lite/nnacl/fp32/common_func.h
@@ -39,27 +39,23 @@ float ShortToFloat32(uint16_t src_value);
 
 uint16_t Float32ToShort(float src_value);
 
-#ifdef ENABLE_X86_64_SSE
-void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
-                        size_t plane_size, size_t stride, size_t relu_type);
-void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
-                        size_t plane_size, size_t plane_stride, size_t relu_type);
-#endif
-
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
 void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
                       size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                       size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6);
+void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
+                      size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
 void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
                         size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
                         size_t in_kh_step, size_t in_kw_step);
-void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
-                   size_t output_channel, size_t input_step);
-
-void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
-                      size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
 void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
                         size_t plane_size, size_t stride, size_t relu_type);
+#endif
+
+#ifdef ENABLE_ARM
+
+void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
+                   size_t output_channel, size_t input_step);
 void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
                         size_t plane_size, size_t plane_stride, size_t relu_type);
 #endif
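With this header change the same prototypes now cover both backends: on ARM the symbols come from the NEON assembly files, while under ENABLE_X86_64_SSE they come from the new C intrinsics files added below; only ConvDwFp32Row and PostFuncBiasReluC4 remain ARM-only. A minimal caller-side sketch of the resulting dispatch — the macro names and the guarded call are from the patch, the scalar fallback name is hypothetical:

#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
  /* resolved to NEON assembly on ARM, or to DepthwiseFp32_Sse.c on x86_64 */
  ConvDwFp32Center(dst, src, weight, bias, height, width, kernel_h, kernel_w, out_h_step,
                   block_channel, in_sh_step, in_sw_step, in_kh_step, in_kw_step, relu, relu6);
#else
  /* hypothetical portable per-element loop for other targets */
  ConvDwCenterScalar(dst, src, weight, bias, height, width, kernel_h, kernel_w);
#endif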
diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise.c b/mindspore/lite/nnacl/fp32/conv_depthwise.c
index 5f86343d61..6bf11d2dbc 100644
--- a/mindspore/lite/nnacl/fp32/conv_depthwise.c
+++ b/mindspore/lite/nnacl/fp32/conv_depthwise.c
@@ -200,7 +200,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float
       const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
       const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
       ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                        sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                        conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
@@ -283,7 +283,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
     int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
     const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
     float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
     ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                      conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                      sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
@@ -437,7 +437,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we
     float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
     const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
 
-#ifdef ENABLE_ARM
+#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
     DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                        conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                        sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c
index f80fc5cb84..877efe2324 100644
--- a/mindspore/lite/nnacl/pack.c
+++ b/mindspore/lite/nnacl/pack.c
@@ -743,6 +743,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch
   return;
 }
 
+#ifndef ENABLE_X86_64_SSE
 void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
   int hw8 = plane / C8NUM * C8NUM;
   int c8 = channel / C8NUM * C8NUM;
@@ -928,6 +929,7 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
   }
   return;
 }
+#endif
 
 void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) {
   int hw8 = plane / C8NUM * C8NUM;
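The #ifndef guard compiles the scalar PackNHWCToNCHWFp32 out of pack.c whenever the SSE file below provides the symbol, so exactly one definition gets linked. For reference, the function's contract is a per-batch plane x channel transpose; a scalar sketch of the same contract, as a hypothetical helper that is handy for checking the SSE version (not part of the patch):

#include <stdio.h>

static void pack_nhwc_to_nchw_ref(const float *src, float *dst, int batches, int plane, int channel) {
  for (int n = 0; n < batches; n++) {
    for (int hw = 0; hw < plane; hw++) {
      for (int c = 0; c < channel; c++) {
        /* NHWC element (n, hw, c) lands at NCHW element (n, c, hw) */
        dst[(n * channel + c) * plane + hw] = src[(n * plane + hw) * channel + c];
      }
    }
  }
}

int main(void) {
  float src[6] = {0, 1, 2, 3, 4, 5}; /* batches=1, plane=2, channel=3, NHWC order */
  float dst[6];
  pack_nhwc_to_nchw_ref(src, dst, 1, 2, 3);
  for (int i = 0; i < 6; i++) printf("%.0f ", dst[i]); /* prints: 0 3 1 4 2 5 */
  printf("\n");
  return 0;
}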
diff --git a/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c b/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c
new file mode 100644
index 0000000000..e4cc2fb689
--- /dev/null
+++ b/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c
@@ -0,0 +1,376 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_X86_64_SSE
+#include <x86intrin.h>
+#include "nnacl/fp32/conv_depthwise.h"
+
+void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
+                      size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) {
+  in_kh_step /= sizeof(float);
+  in_kw_step /= sizeof(float);
+  kernel_w_step /= sizeof(float);
+
+  const float *src_kh = src;
+  const float *weight_kh = weight;
+  __m128 dst_ma = _mm_setzero_ps();
+
+  for (int kh = 0; kh < height; kh++) {
+    const float *src_kw = src_kh;
+    const float *weight_kw = weight_kh;
+
+    int c1 = 0;
+    int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
+    int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
+    // c4 loop
+    for (; c1 < c4; c1 += C4NUM) {
+      __m128 src_ma1 = _mm_loadu_ps(src_kw);
+      __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
+      __m128 src_ma3 = _mm_loadu_ps(src_kw + 2 * in_kw_step);
+      __m128 src_ma4 = _mm_loadu_ps(src_kw + 3 * in_kw_step);
+
+      __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
+      __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+      __m128 weight_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
+      __m128 weight_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
+
+      __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
+      __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
+      __m128 mul_ma3 = _mm_mul_ps(src_ma3, weight_ma3);
+      __m128 mul_ma4 = _mm_mul_ps(src_ma4, weight_ma4);
+
+      dst_ma = _mm_add_ps(dst_ma, mul_ma1);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma2);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma3);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma4);
+
+      src_kw += in_kw_step * 4;
+      weight_kw += C4NUM * 4;
+    }
+
+    // c2 loop
+    for (; c1 < c2; c1 += C2NUM) {
+      __m128 src_ma1 = _mm_loadu_ps(src_kw);
+      __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
+      __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
+      __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+      __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
+      __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma1);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma2);
+
+      src_kw += in_kw_step * 2;
+      weight_kw += C4NUM * 2;
+    }
+
+    // remaining
+    for (; c1 < width; ++c1) {
+      __m128 src_ma1 = _mm_loadu_ps(src_kw);
+      __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
+      __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
+      dst_ma = _mm_add_ps(dst_ma, mul_ma1);
+
+      src_kw += in_kw_step;
+      weight_kw += C4NUM;
+    }
+
+    src_kh += in_kh_step;
+    weight_kh += kernel_w_step;
+  }
+
+  __m128 bias_ma = _mm_loadu_ps(bias);
+  dst_ma = _mm_add_ps(dst_ma, bias_ma);
+  __m128 zero_ma = _mm_setzero_ps();
+  if (relu || relu6) {
+    dst_ma = _mm_max_ps(zero_ma, dst_ma);
+    if (relu6) {
+      __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+      dst_ma = _mm_min_ps(const_ma, dst_ma);
+    }
+  }
+  _mm_storeu_ps(dst, dst_ma);
+}
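Note the calling convention: every *_step parameter is a byte stride — the function divides each one by sizeof(float) on entry — matching the ARM assembly interface and the `* sizeof(float)` arguments built in conv_depthwise.c above. A hypothetical standalone call computing one C4 output from a 1x2 border window (assumes this file is compiled and linked with ENABLE_X86_64_SSE defined):

#include <stdio.h>
#include "nnacl/fp32/common_func.h"

int main(void) {
  /* one kernel row (height=1), two kernel taps (width=2), one C4 channel block */
  float src[8] = {1, 1, 1, 1, 2, 2, 2, 2};          /* two C4 input pixels */
  float weight[8] = {0.5f, 0.5f, 0.5f, 0.5f, 1, 1, 1, 1};
  float bias[4] = {0, 0, 0, 0};
  float dst[4];
  ConvDwFp32Border(dst, src, weight, bias, 1, 2,
                   /*in_kh_step=*/8 * sizeof(float), /*in_kw_step=*/4 * sizeof(float),
                   /*kernel_w_step=*/8 * sizeof(float), /*relu=*/0, /*relu6=*/0);
  printf("%.1f\n", dst[0]); /* each lane: 1*0.5 + 2*1 = 2.5 */
  return 0;
}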
+
+void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
+                      size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
+                      size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) {
+  out_h_step /= sizeof(float);
+  block_channel /= sizeof(float);
+  in_sh_step /= sizeof(float);
+  in_sw_step /= sizeof(float);
+  in_kh_step /= sizeof(float);
+  in_kw_step /= sizeof(float);
+
+  float *dst_h = dst;
+  const float *src_h = src;
+  for (int oh = 0; oh < height; oh++) {
+    float *dst_w = dst_h;
+    const float *src_w = src_h;
+    int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
+    int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
+    int c1 = 0;
+    // c4 loop
+    for (; c1 < c4; c1 += C4NUM) {
+      const float *src_kh = src_w;
+      const float *weight_kh = weight;
+      __m128 dst_w_ma1 = _mm_setzero_ps();
+      __m128 dst_w_ma2 = _mm_setzero_ps();
+      __m128 dst_w_ma3 = _mm_setzero_ps();
+      __m128 dst_w_ma4 = _mm_setzero_ps();
+
+      for (int kh = 0; kh < kernel_h; kh++) {
+        const float *src_kw = src_kh;
+        const float *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++) {
+          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+
+          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+
+          __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step);
+          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3);
+          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
+
+          __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step);
+          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4);
+          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
+
+          src_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        src_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      // add bias relu
+      __m128 bias_ma = _mm_loadu_ps(bias);
+      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
+      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
+      dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma);
+      dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma);
+
+      __m128 zero_ma = _mm_setzero_ps();
+      if (relu || relu6) {
+        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
+        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
+        dst_w_ma3 = _mm_max_ps(zero_ma, dst_w_ma3);
+        dst_w_ma4 = _mm_max_ps(zero_ma, dst_w_ma4);
+        if (relu6) {
+          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
+          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
+          dst_w_ma3 = _mm_min_ps(const_ma, dst_w_ma3);
+          dst_w_ma4 = _mm_min_ps(const_ma, dst_w_ma4);
+        }
+      }
+      _mm_storeu_ps(dst_w, dst_w_ma1);
+      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
+      _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3);
+      _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4);
+
+      dst_w += C4NUM * block_channel;
+      src_w += C4NUM * in_sw_step;
+    }  // dst_width loop
+    // c2 loop
+    for (; c1 < c2; c1 += C2NUM) {
+      const float *src_kh = src_w;
+      const float *weight_kh = weight;
+      __m128 dst_w_ma1 = _mm_setzero_ps();
+      __m128 dst_w_ma2 = _mm_setzero_ps();
+
+      for (int kh = 0; kh < kernel_h; kh++) {
+        const float *src_kw = src_kh;
+        const float *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++) {
+          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+
+          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+
+          src_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        src_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      // add bias relu
+      __m128 bias_ma = _mm_loadu_ps(bias);
+      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
+      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
+      __m128 zero_ma = _mm_setzero_ps();
+      if (relu || relu6) {
+        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
+        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
+        if (relu6) {
+          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
+          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
+        }
+      }
+      _mm_storeu_ps(dst_w, dst_w_ma1);
+      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
+
+      dst_w += C2NUM * block_channel;
+      src_w += C2NUM * in_sw_step;
+    }
+    // remaining
+    for (; c1 < width; c1++) {
+      const float *src_kh = src_w;
+      const float *weight_kh = weight;
+      __m128 dst_w_ma1 = _mm_setzero_ps();
+      for (int kh = 0; kh < kernel_h; kh++) {
+        const float *src_kw = src_kh;
+        const float *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++) {
+          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+
+          src_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        src_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      // add bias relu
+      __m128 bias_ma = _mm_loadu_ps(bias);
+      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
+      __m128 zero_ma = _mm_setzero_ps();
+      if (relu || relu6) {
+        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
+        if (relu6) {
+          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
+          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
+        }
+      }
+      _mm_storeu_ps(dst_w, dst_w_ma1);
+
+      dst_w += block_channel;
+      src_w += in_sw_step;
+    }
+    dst_h += out_h_step;
+    src_h += in_sh_step;
+  }  // dst_height loop
+}
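One observation on the 4-wide block above: weight_kw_ma1 through weight_kw_ma4 all load the same address (weight_kw), because a single depthwise weight vector serves four adjacent output columns. A possible micro-cleanup, shown only as a sketch that would replace the inner kw body under the same surrounding declarations, hoists that load:

/* sketch: one weight load reused across the four column accumulators */
__m128 w = _mm_loadu_ps(weight_kw);
dst_w_ma1 = _mm_add_ps(dst_w_ma1, _mm_mul_ps(_mm_loadu_ps(src_kw), w));
dst_w_ma2 = _mm_add_ps(dst_w_ma2, _mm_mul_ps(_mm_loadu_ps(src_kw + in_sw_step), w));
dst_w_ma3 = _mm_add_ps(dst_w_ma3, _mm_mul_ps(_mm_loadu_ps(src_kw + 2 * in_sw_step), w));
dst_w_ma4 = _mm_add_ps(dst_w_ma4, _mm_mul_ps(_mm_loadu_ps(src_kw + 3 * in_sw_step), w));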
+
+void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
+                        size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
+                        size_t in_kh_step, size_t in_kw_step) {
+  out_h_step /= sizeof(float);
+  block_channel /= sizeof(float);
+  in_sh_step /= sizeof(float);
+  in_sw_step /= sizeof(float);
+  in_kh_step /= sizeof(float);
+  in_kw_step /= sizeof(float);
+
+  float *dst_h = dst;
+  const float *src_h = src;
+  for (int oh = 0; oh < height; oh++) {
+    float *dst_w = dst_h;
+    const float *src_w = src_h;
+    for (int ow = 0; ow < width; ow++) {
+      float *dst_kh = dst_w;
+      const float *weight_kh = weight;
+      __m128 src_w_ma = _mm_loadu_ps(src_w);
+      for (int kh = 0; kh < kernel_h; kh++) {
+        float *dst_kw = dst_kh;
+        const float *weight_kw = weight_kh;
+
+        int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM;
+        int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM;
+        int c1 = 0;
+        // c4 loop
+        for (; c1 < c4; c1 += C4NUM) {
+          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          _mm_storeu_ps(dst_kw, dst_w_ma1);
+
+          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);
+
+          __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step);
+          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
+          __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3);
+          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
+          _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3);
+
+          __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step);
+          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
+          __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4);
+          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
+          _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4);
+
+          dst_kw += 4 * in_kw_step;
+          weight_kw += 4 * C4NUM;
+        }
+        // c2 loop
+        for (; c1 < c2; c1 += C2NUM) {
+          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          _mm_storeu_ps(dst_kw, dst_w_ma1);
+
+          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
+          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
+          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
+          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
+          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);
+
+          dst_kw += 2 * in_kw_step;
+          weight_kw += 2 * C4NUM;
+        }
+        // remaining
+        for (; c1 < kernel_w; ++c1) {
+          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
+          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
+          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
+          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
+          _mm_storeu_ps(dst_kw, dst_w_ma1);
+
+          dst_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+
+        dst_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      dst_w += in_sw_step;
+      src_w += block_channel;
+    }  // dst_width loop
+    dst_h += in_sh_step;
+    src_h += out_h_step;
+  }  // dst_height loop
+}
+
+#endif
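Unlike the two convolution kernels, DeconvDwFp32Center scatters: each input pixel is multiplied into a kernel_h x kernel_w window of the output, which is why every tap does a load-add-store on dst and why no bias/relu is applied here (presumably handled by the caller's post step). A scalar reference of the same arithmetic, as a hypothetical helper for unit comparison; its steps are element strides, i.e. the byte strides divided by sizeof(float):

static void deconv_dw_center_ref(float *dst, const float *src, const float *weight, int height, int width,
                                 int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step,
                                 int in_sw_step, int in_kh_step, int in_kw_step) {
  for (int oh = 0; oh < height; oh++) {
    for (int ow = 0; ow < width; ow++) {
      const float *s = src + oh * out_h_step + ow * block_channel; /* one C4 input pixel */
      float *d = dst + oh * in_sh_step + ow * in_sw_step;
      for (int kh = 0; kh < kernel_h; kh++) {
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < 4; c++) { /* C4NUM lanes */
            d[kh * in_kh_step + kw * in_kw_step + c] += weight[(kh * kernel_w + kw) * 4 + c] * s[c];
          }
        }
      }
    }
  }
}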
diff --git a/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c b/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c
new file mode 100644
index 0000000000..ea9bd43ba2
--- /dev/null
+++ b/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c
@@ -0,0 +1,140 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_X86_64_SSE
+#include <x86intrin.h>
+#include "nnacl/pack.h"
+#include "nnacl/int8/conv_int8.h"
+
+void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
+  int hw8 = plane / C8NUM * C8NUM;
+  int c8 = channel / C8NUM * C8NUM;
+  int batch = plane * channel;
+  for (int n = 0; n < batches; n++) {
+    const float *src_batch = (const float *)src + n * batch;
+    float *dst_batch = (float *)dst + n * batch;
+    int hw = 0;
+    for (; hw < hw8; hw += C8NUM) {
+      int c = 0;
+      for (; c < c8; c += C8NUM) {
+        const float *src_ptr = src_batch + hw * channel + c;
+        float *dst_ptr = dst_batch + c * plane + hw;
+
+        // rows 0-3, cols 0-3 of the 8x8 tile
+        __m128 v0_ma = _mm_loadu_ps(src_ptr);
+        __m128 v1_ma = _mm_loadu_ps(src_ptr + channel);
+        __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * channel);
+        __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * channel);
+
+        __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+
+        __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+
+        _mm_storeu_ps(dst_ptr, v8_ma);
+        _mm_storeu_ps(dst_ptr + plane, v9_ma);
+        _mm_storeu_ps(dst_ptr + 2 * plane, v10_ma);
+        _mm_storeu_ps(dst_ptr + 3 * plane, v11_ma);
+
+        // rows 0-3, cols 4-7
+        v0_ma = _mm_loadu_ps(src_ptr + C4NUM);
+        v1_ma = _mm_loadu_ps(src_ptr + channel + C4NUM);
+        v2_ma = _mm_loadu_ps(src_ptr + 2 * channel + C4NUM);
+        v3_ma = _mm_loadu_ps(src_ptr + 3 * channel + C4NUM);
+
+        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+
+        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+
+        _mm_storeu_ps(dst_ptr + C4NUM * plane, v8_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 1) * plane, v9_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 2) * plane, v10_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 3) * plane, v11_ma);
+
+        // rows 4-7, cols 0-3
+        v0_ma = _mm_loadu_ps(src_ptr + C4NUM * channel);
+        v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * channel);
+        v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * channel);
+        v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * channel);
+
+        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+
+        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+
+        _mm_storeu_ps(dst_ptr + C4NUM, v8_ma);
+        _mm_storeu_ps(dst_ptr + plane + C4NUM, v9_ma);
+        _mm_storeu_ps(dst_ptr + 2 * plane + C4NUM, v10_ma);
+        _mm_storeu_ps(dst_ptr + 3 * plane + C4NUM, v11_ma);
+
+        // rows 4-7, cols 4-7
+        v0_ma = _mm_loadu_ps(src_ptr + C4NUM * channel + C4NUM);
+        v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * channel + C4NUM);
+        v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * channel + C4NUM);
+        v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * channel + C4NUM);
+
+        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
+        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
+        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
+        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
+
+        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
+        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
+        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
+        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
+
+        _mm_storeu_ps(dst_ptr + C4NUM * plane + C4NUM, v8_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 1) * plane + C4NUM, v9_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 2) * plane + C4NUM, v10_ma);
+        _mm_storeu_ps(dst_ptr + (C4NUM + 3) * plane + C4NUM, v11_ma);
+      }
+
+      for (; c < channel; c++) {
+        const float *src_ptr = src_batch + hw * channel + c;
+        float *dst_ptr = dst_batch + c * plane + hw;
+        for (size_t i = 0; i < C8NUM; i++) {
+          dst_ptr[i] = src_ptr[i * channel];
+        }
+      }
+    }
+    for (; hw < plane; hw++) {
+      const float *src_ptr = src_batch + hw * channel;
+      float *dst_ptr = dst_batch + hw;
+      for (size_t i = 0; i < channel; i++) {
+        dst_ptr[i * plane] = src_ptr[i];
+      }
+    }
+  }
+  return;
+}
+
+#endif
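The 4x4 core repeated four times above (unpacklo/unpackhi followed by movelh/movehl) is the classic SSE in-register transpose — the same shuffle sequence the _MM_TRANSPOSE4_PS macro from <xmmintrin.h> expands to. A standalone demo of one 4x4 tile, illustrative only:

#include <stdio.h>
#include <xmmintrin.h>

int main(void) {
  float m[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  __m128 r0 = _mm_loadu_ps(m + 0), r1 = _mm_loadu_ps(m + 4);
  __m128 r2 = _mm_loadu_ps(m + 8), r3 = _mm_loadu_ps(m + 12);
  /* interleave pairs of rows... */
  __m128 t0 = _mm_unpacklo_ps(r0, r1), t1 = _mm_unpackhi_ps(r0, r1);
  __m128 t2 = _mm_unpacklo_ps(r2, r3), t3 = _mm_unpackhi_ps(r2, r3);
  /* ...then recombine halves into transposed rows */
  _mm_storeu_ps(m + 0, _mm_movelh_ps(t0, t2));  /* 0 4 8 12 */
  _mm_storeu_ps(m + 4, _mm_movehl_ps(t2, t0));  /* 1 5 9 13 */
  _mm_storeu_ps(m + 8, _mm_movelh_ps(t1, t3));  /* 2 6 10 14 */
  _mm_storeu_ps(m + 12, _mm_movehl_ps(t3, t1)); /* 3 7 11 15 */
  for (int i = 0; i < 16; i++) printf("%.0f%c", m[i], i % 4 == 3 ? '\n' : ' ');
  return 0;
}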