Browse Source

add ConvDwFp32Center, ConvDwFp32Border, DeconvDwFp32Center and PackNHWCToNCHWFp32 x86 SSE operators

tags/v1.1.0
wangyanling 5 years ago
parent
commit
d27df64c9b
5 changed files with 530 additions and 16 deletions
  1. +9
    -13
      mindspore/lite/nnacl/fp32/common_func.h
  2. +3
    -3
      mindspore/lite/nnacl/fp32/conv_depthwise.c
  3. +2
    -0
      mindspore/lite/nnacl/pack.c
  4. +376
    -0
      mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c
  5. +140
    -0
      mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c

+ 9
- 13
mindspore/lite/nnacl/fp32/common_func.h View File

@@ -39,27 +39,23 @@ float ShortToFloat32(uint16_t src_value);


uint16_t Float32ToShort(float src_value); uint16_t Float32ToShort(float src_value);


#ifdef ENABLE_X86_64_SSE
void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
size_t plane_size, size_t stride, size_t relu_type);
void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
size_t plane_size, size_t plane_stride, size_t relu_type);
#endif

#ifdef ENABLE_ARM
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6);
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
size_t in_kh_step, size_t in_kw_step); size_t in_kh_step, size_t in_kw_step);
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
size_t output_channel, size_t input_step);

void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
size_t plane_size, size_t stride, size_t relu_type); size_t plane_size, size_t stride, size_t relu_type);
#endif

#ifdef ENABLE_ARM

void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
size_t output_channel, size_t input_step);
void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
size_t plane_size, size_t plane_stride, size_t relu_type); size_t plane_size, size_t plane_stride, size_t relu_type);
#endif #endif


+ 3
- 3
mindspore/lite/nnacl/fp32/conv_depthwise.c View File

@@ -200,7 +200,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float
const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;


#ifdef ENABLE_ARM
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
@@ -283,7 +283,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
@@ -437,7 +437,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we
float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;


#ifdef ENABLE_ARM
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),


+ 2
- 0
mindspore/lite/nnacl/pack.c View File

@@ -743,6 +743,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch
return; return;
} }


#ifndef ENABLE_X86_64_SSE
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
int hw8 = plane / C8NUM * C8NUM; int hw8 = plane / C8NUM * C8NUM;
int c8 = channel / C8NUM * C8NUM; int c8 = channel / C8NUM * C8NUM;
@@ -928,6 +929,7 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
} }
return; return;
} }
#endif


void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) { void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) {
int hw8 = plane / C8NUM * C8NUM; int hw8 = plane / C8NUM * C8NUM;


+ 376
- 0
mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c View File

@@ -0,0 +1,376 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#include "nnacl/fp32/conv_depthwise.h"

// Depth-wise fp32 convolution for a single border output pixel: accumulates
// one C4 channel tile over a (height x width) sub-window of the kernel, then
// adds bias and applies the optional fused relu / relu6 activation.
// All *_step arguments arrive in BYTES and are converted to float counts here.
// The kernel-width loop is unrolled by 4 and by 2 before the scalar tail.
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
                      size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) {
  in_kh_step /= sizeof(float);
  in_kw_step /= sizeof(float);
  kernel_w_step /= sizeof(float);

  __m128 acc = _mm_setzero_ps();
  const float *row_src = src;
  const float *row_weight = weight;
  // Unroll bounds depend only on `width`, so they are computed once up front.
  int kw4 = DOWN_DIV(width, C4NUM) * C4NUM;
  int kw2 = DOWN_DIV(width, C2NUM) * C2NUM;
  for (int kh = 0; kh < height; kh++) {
    const float *pix = row_src;
    const float *wgt = row_weight;
    int kw = 0;
    // Four kernel columns per iteration (same accumulation order as scalar).
    for (; kw < kw4; kw += C4NUM) {
      acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(pix), _mm_loadu_ps(wgt)));
      acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(pix + in_kw_step), _mm_loadu_ps(wgt + C4NUM)));
      acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(pix + 2 * in_kw_step), _mm_loadu_ps(wgt + 2 * C4NUM)));
      acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(pix + 3 * in_kw_step), _mm_loadu_ps(wgt + 3 * C4NUM)));
      pix += in_kw_step * 4;
      wgt += C4NUM * 4;
    }
    // Two kernel columns per iteration.
    for (; kw < kw2; kw += C2NUM) {
      acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(pix), _mm_loadu_ps(wgt)));
      acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(pix + in_kw_step), _mm_loadu_ps(wgt + C4NUM)));
      pix += in_kw_step * 2;
      wgt += C4NUM * 2;
    }
    // Scalar tail: one kernel column at a time.
    for (; kw < width; ++kw) {
      acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(pix), _mm_loadu_ps(wgt)));
      pix += in_kw_step;
      wgt += C4NUM;
    }
    row_src += in_kh_step;
    row_weight += kernel_w_step;
  }

  // Bias + fused activation, then store the C4 result tile.
  acc = _mm_add_ps(acc, _mm_loadu_ps(bias));
  if (relu || relu6) {
    acc = _mm_max_ps(_mm_setzero_ps(), acc);
    if (relu6) {
      acc = _mm_min_ps(_mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f), acc);
    }
  }
  _mm_storeu_ps(dst, acc);
}

// Depth-wise fp32 convolution over the "center" output region, where the
// whole kernel window is guaranteed to read valid (unpadded) input.
// Processes one C4 channel tile; the output-width loop is unrolled by 4 and
// by 2 (the 2/4 parallel output pixels share the same kernel weights).
// All stride arguments (out_h_step .. in_kw_step) arrive in BYTES and are
// converted to float element counts below. `relu`/`relu6` select the fused
// activation applied after the bias add.
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
                      size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                      size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) {
  out_h_step /= sizeof(float);
  block_channel /= sizeof(float);
  in_sh_step /= sizeof(float);
  in_sw_step /= sizeof(float);
  in_kh_step /= sizeof(float);
  in_kw_step /= sizeof(float);

  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    // Unroll bounds for the output-width loop below.
    int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
    int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
    int c1 = 0;
    // c4 loop: compute 4 adjacent output pixels at once; they use the same
    // weights but input windows shifted by in_sw_step.
    for (; c1 < c4; c1 += C4NUM) {
      const float *src_kh = src_w;
      const float *weight_kh = weight;
      __m128 dst_w_ma1 = _mm_setzero_ps();
      __m128 dst_w_ma2 = _mm_setzero_ps();
      __m128 dst_w_ma3 = _mm_setzero_ps();
      __m128 dst_w_ma4 = _mm_setzero_ps();

      for (int kh = 0; kh < kernel_h; kh++) {
        const float *src_kw = src_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          // The four pixels share weight_kw; only the source offset differs.
          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);

          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);

          __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step);
          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3);
          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);

          __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step);
          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4);
          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);

          src_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      // add bias relu
      __m128 bias_ma = _mm_loadu_ps(bias);
      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
      dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma);
      dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma);

      __m128 zero_ma = _mm_setzero_ps();
      if (relu || relu6) {
        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
        dst_w_ma3 = _mm_max_ps(zero_ma, dst_w_ma3);
        dst_w_ma4 = _mm_max_ps(zero_ma, dst_w_ma4);
        if (relu6) {
          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
          dst_w_ma3 = _mm_min_ps(const_ma, dst_w_ma3);
          dst_w_ma4 = _mm_min_ps(const_ma, dst_w_ma4);
        }
      }
      _mm_storeu_ps(dst_w, dst_w_ma1);
      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
      _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3);
      _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4);

      dst_w += C4NUM * block_channel;
      src_w += C4NUM * in_sw_step;
    }  // dst_width loop
    // c2 loop: same scheme, 2 output pixels at a time.
    for (; c1 < c2; c1 += C2NUM) {
      const float *src_kh = src_w;
      const float *weight_kh = weight;
      __m128 dst_w_ma1 = _mm_setzero_ps();
      __m128 dst_w_ma2 = _mm_setzero_ps();

      for (int kh = 0; kh < kernel_h; kh++) {
        const float *src_kw = src_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);

          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);

          src_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      // add bias relu
      __m128 bias_ma = _mm_loadu_ps(bias);
      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
      __m128 zero_ma = _mm_setzero_ps();
      if (relu || relu6) {
        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
        if (relu6) {
          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
        }
      }
      _mm_storeu_ps(dst_w, dst_w_ma1);
      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);

      dst_w += C2NUM * block_channel;
      src_w += C2NUM * in_sw_step;
    }
    // remaining: one output pixel at a time.
    for (; c1 < width; c1++) {
      const float *src_kh = src_w;
      const float *weight_kh = weight;
      __m128 dst_w_ma1 = _mm_setzero_ps();
      for (int kh = 0; kh < kernel_h; kh++) {
        const float *src_kw = src_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);

          src_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      // add bias relu
      __m128 bias_ma = _mm_loadu_ps(bias);
      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
      __m128 zero_ma = _mm_setzero_ps();
      if (relu || relu6) {
        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
        if (relu6) {
          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
        }
      }
      _mm_storeu_ps(dst_w, dst_w_ma1);

      dst_w += block_channel;
      src_w += in_sw_step;
    }
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}

// Depth-wise fp32 transposed convolution ("deconvolution") over the center
// region. For each input pixel (one C4 channel tile, broadcast into src_w_ma)
// the whole kernel is multiplied in and accumulated into the overlapping
// output window with read-modify-write stores. The kernel-width loop is
// unrolled by 4 and by 2 before a scalar tail.
// All stride arguments arrive in BYTES and are converted to float counts.
// NOTE(review): assumes the destination window was zero-initialized by the
// caller before accumulation — confirm against DeconvDwSWFp32.
void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
                        size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
                        size_t in_kh_step, size_t in_kw_step) {
  out_h_step /= sizeof(float);
  block_channel /= sizeof(float);
  in_sh_step /= sizeof(float);
  in_sw_step /= sizeof(float);
  in_kh_step /= sizeof(float);
  in_kw_step /= sizeof(float);

  // Unroll bounds depend only on kernel_w; hoisted out of all loops
  // (previously recomputed for every kernel row of every pixel).
  int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM;
  int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM;

  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      float *dst_kh = dst_w;
      const float *weight_kh = weight;
      // One C4 input pixel, reused against every kernel tap.
      __m128 src_w_ma = _mm_loadu_ps(src_w);
      for (int kh = 0; kh < kernel_h; kh++) {
        float *dst_kw = dst_kh;
        const float *weight_kw = weight_kh;
        int c1 = 0;
        // c4 loop: four kernel columns per iteration.
        for (; c1 < c4; c1 += C4NUM) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);

          __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step);
          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
          __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3);
          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
          _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3);

          __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step);
          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
          __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4);
          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
          _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4);

          dst_kw += 4 * in_kw_step;
          weight_kw += 4 * C4NUM;
        }
        // c2 loop: two kernel columns per iteration.
        for (; c1 < c2; c1 += C2NUM) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);

          dst_kw += 2 * in_kw_step;
          weight_kw += 2 * C4NUM;
        }
        // remaining: one kernel column at a time.
        for (; c1 < kernel_w; ++c1) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          dst_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop

        dst_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      dst_w += in_sw_step;
      src_w += block_channel;
    }  // dst_width loop
    dst_h += in_sh_step;
    src_h += out_h_step;
  }  // dst_height loop
}

#endif

+ 140
- 0
mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c View File

@@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#include "nnacl/pack.h"
#include "nnacl/int8/conv_int8.h"

// Transposes fp32 data from NHWC to NCHW layout, one batch at a time.
// The plane x channel matrix is transposed in 8x8 tiles, each handled as four
// 4x4 SSE sub-transposes (unpacklo/unpackhi + movelh/movehl). Rows/columns
// that do not fill a tile fall back to scalar copies.
// `plane` is H*W; caller guarantees src/dst hold batches*plane*channel floats.
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
  int hw8 = plane / C8NUM * C8NUM;   // plane rounded down to a tile multiple
  int c8 = channel / C8NUM * C8NUM;  // channel rounded down to a tile multiple
  int batch = plane * channel;
  for (int n = 0; n < batches; n++) {
    const float *src_batch = (const float *)src + n * batch;
    float *dst_batch = (float *)dst + n * batch;
    int hw = 0;
    for (; hw < hw8; hw += C8NUM) {
      int c = 0;
      for (; c < c8; c += C8NUM) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;

        // top-left 4x4 sub-tile: src rows hw..hw+3, channels c..c+3
        __m128 v0_ma = _mm_loadu_ps(src_ptr);
        __m128 v1_ma = _mm_loadu_ps(src_ptr + channel);
        __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * channel);
        __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * channel);

        __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
        __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
        __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
        __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

        __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
        __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
        __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
        __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

        _mm_storeu_ps(dst_ptr, v8_ma);
        _mm_storeu_ps(dst_ptr + plane, v9_ma);
        _mm_storeu_ps(dst_ptr + 2 * plane, v10_ma);
        _mm_storeu_ps(dst_ptr + 3 * plane, v11_ma);

        // top-right sub-tile: src rows hw..hw+3, channels c+4..c+7
        // (lands in the bottom-left quadrant of the output tile)
        v0_ma = _mm_loadu_ps(src_ptr + C4NUM);
        v1_ma = _mm_loadu_ps(src_ptr + channel + C4NUM);
        v2_ma = _mm_loadu_ps(src_ptr + 2 * channel + C4NUM);
        v3_ma = _mm_loadu_ps(src_ptr + 3 * channel + C4NUM);

        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

        _mm_storeu_ps(dst_ptr + C4NUM * plane, v8_ma);
        _mm_storeu_ps(dst_ptr + (C4NUM + 1) * plane, v9_ma);
        _mm_storeu_ps(dst_ptr + (C4NUM + 2) * plane, v10_ma);
        _mm_storeu_ps(dst_ptr + (C4NUM + 3) * plane, v11_ma);

        // bottom-left sub-tile: src rows hw+4..hw+7, channels c..c+3
        // (lands in the top-right quadrant of the output tile)
        v0_ma = _mm_loadu_ps(src_ptr + C4NUM * channel);
        v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * channel);
        v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * channel);
        v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * channel);

        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

        _mm_storeu_ps(dst_ptr + C4NUM, v8_ma);
        _mm_storeu_ps(dst_ptr + plane + C4NUM, v9_ma);
        _mm_storeu_ps(dst_ptr + 2 * plane + C4NUM, v10_ma);
        _mm_storeu_ps(dst_ptr + 3 * plane + C4NUM, v11_ma);

        // bottom-right sub-tile: src rows hw+4..hw+7, channels c+4..c+7
        v0_ma = _mm_loadu_ps(src_ptr + C4NUM * channel + C4NUM);
        v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * channel + C4NUM);
        v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * channel + C4NUM);
        v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * channel + C4NUM);

        v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
        v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
        v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
        v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

        v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
        v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
        v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
        v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

        _mm_storeu_ps(dst_ptr + C4NUM * plane + C4NUM, v8_ma);
        _mm_storeu_ps(dst_ptr + (C4NUM + 1) * plane + C4NUM, v9_ma);
        _mm_storeu_ps(dst_ptr + (C4NUM + 2) * plane + C4NUM, v10_ma);
        _mm_storeu_ps(dst_ptr + (C4NUM + 3) * plane + C4NUM, v11_ma);
      }

      // leftover channels for this full row tile: scalar gather.
      // Loop counters are int (was size_t) to avoid signed/unsigned
      // comparison against the int bounds (-Wsign-compare).
      for (; c < channel; c++) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
        for (int i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    // leftover rows: scalar scatter of one NHWC row into NCHW columns.
    for (; hw < plane; hw++) {
      const float *src_ptr = src_batch + hw * channel;
      float *dst_ptr = dst_batch + hw;
      for (int i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
}

#endif

Loading…
Cancel
Save