diff --git a/mindspore/lite/micro/coder/opcoders/base/conv2d_base_coder.cc b/mindspore/lite/micro/coder/opcoders/base/conv2d_base_coder.cc index 3aac24a8a6..a1aa621a85 100644 --- a/mindspore/lite/micro/coder/opcoders/base/conv2d_base_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/base/conv2d_base_coder.cc @@ -17,7 +17,7 @@ #include "micro/coder/opcoders/base/conv2d_base_coder.h" #include #include -#include "nnacl/winograd_utils.h" +#include "nnacl/fp32/winograd_utils.h" #include "nnacl/int8/quantize.h" #include "micro/coder/log.h" diff --git a/mindspore/lite/nnacl/CMakeLists.txt b/mindspore/lite/nnacl/CMakeLists.txt index eb926daeee..fcc6558ed7 100644 --- a/mindspore/lite/nnacl/CMakeLists.txt +++ b/mindspore/lite/nnacl/CMakeLists.txt @@ -5,8 +5,10 @@ include_directories(NNACL_DIR) if(PLATFORM_ARM32 OR PLATFORM_ARM64) if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections -fdata-sections -ffast-math") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections -fdata-sections -ffast-math") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math") endif() endif() if("${X86_64_SIMD}" STREQUAL "avx") @@ -37,14 +39,14 @@ if(PLATFORM_ARM32) endif() if("${X86_64_SIMD}" STREQUAL "sse") - file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/intrinsics/sse/*.c) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) endif() if("${X86_64_SIMD}" STREQUAL "avx") - file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c - ${NNACL_DIR}/x86_64_avx/*.c - ${NNACL_DIR}/assembly/avx/*.S) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/intrinsics/sse/*.c + ${NNACL_DIR}/intrinsics/avx/*.c + ${NNACL_DIR}/assembly/avx/*.S) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) endif() diff --git a/mindspore/lite/nnacl/minimal_filtering_generator.c b/mindspore/lite/nnacl/base/minimal_filtering_generator.c similarity index 99% rename from mindspore/lite/nnacl/minimal_filtering_generator.c rename to mindspore/lite/nnacl/base/minimal_filtering_generator.c index acb8dbdce4..b17000d357 100644 --- a/mindspore/lite/nnacl/minimal_filtering_generator.c +++ b/mindspore/lite/nnacl/base/minimal_filtering_generator.c @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/base/minimal_filtering_generator.h" #include #include -#include "nnacl/winograd_utils.h" +#include "nnacl/fp32/winograd_utils.h" #include "nnacl/errorcode.h" void Polynomial(const float *interval, float *m, int degree) { diff --git a/mindspore/lite/nnacl/minimal_filtering_generator.h b/mindspore/lite/nnacl/base/minimal_filtering_generator.h similarity index 100% rename from mindspore/lite/nnacl/minimal_filtering_generator.h rename to mindspore/lite/nnacl/base/minimal_filtering_generator.h diff --git a/mindspore/lite/nnacl/conv_parameter.h b/mindspore/lite/nnacl/conv_parameter.h index 4d1adc6a82..174240b894 100644 --- a/mindspore/lite/nnacl/conv_parameter.h +++ b/mindspore/lite/nnacl/conv_parameter.h @@ -72,6 +72,7 @@ typedef struct SlidingWindowParam { int kernel_step_; } SlidingWindowParam; +#define OUPUT_UNIT 2 #define DECONV_WINOGRAD_DEFAULT_UNIT 3 #define DECONV_WINOGRAD_DEFAULT_TILE 8 #define DECONV_WINOGRAD_BUFFER_COUNT 8 diff --git a/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c b/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c index 64dd317d4d..98db40619e 100644 --- a/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c +++ b/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c @@ -15,7 +15,7 @@ */ #include "nnacl/fp16/deconv_winograd_fp16.h" -#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/base/minimal_filtering_generator.h" void DeConvWgInputPackFp16(float16_t *src_ptr, float16_t *dst_ptr, int channel, int stride) { int ic4div = channel / C4NUM; @@ -111,16 +111,16 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, } void DeConvWgCalWgFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weight_buf, float16_t *tmp_buf, - float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transfered, + float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transferred, float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) { int winograd_plane = unit_size * unit_size; - if (!transfered[unit_size]) { + if (!transferred[unit_size]) { WinogradTransLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); WinogradTransRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); - transfered[unit_size] = true; + transferred[unit_size] = true; } for (int index = 0; index < winograd_plane; index++) { @@ -311,7 +311,7 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou } /* compute */ - bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; + bool transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; for (int i = 0; i < deconv_param->compute_size_; i++) { DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; if (unit->use_winograd_) { @@ -328,7 +328,7 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up4_; DeConvWgCalWgFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, - transfered, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, + transferred, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, conv_param, deconv_param); } else { float16_t *tmp_buf = (float16_t 
*)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c index 3e088002cb..7b9ff25553 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c @@ -17,7 +17,7 @@ #include "nnacl/fp32/conv_depthwise_fp32.h" #include "nnacl/common_func.h" #include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/winograd_transform.h" +#include "nnacl/fp32/winograd_transform.h" #ifdef ENABLE_ARM64 #include #endif diff --git a/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c b/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c index d8370aa468..3cdf63d731 100644 --- a/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c @@ -17,7 +17,7 @@ #include "nnacl/fp32/conv_winograd_fp32.h" #include #include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/winograd_transform.h" +#include "nnacl/fp32/winograd_transform.h" #include "nnacl/fp32/matmul_fp32.h" // fp32 conv winograd diff --git a/mindspore/lite/nnacl/fp32/conv_winograd_fp32.h b/mindspore/lite/nnacl/fp32/conv_winograd_fp32.h index aaa8402985..c4cee273f0 100644 --- a/mindspore/lite/nnacl/fp32/conv_winograd_fp32.h +++ b/mindspore/lite/nnacl/fp32/conv_winograd_fp32.h @@ -24,7 +24,7 @@ #include "nnacl/op_base.h" #include "nnacl/common_func.h" #include "nnacl/conv_parameter.h" -#include "nnacl/winograd_utils.h" +#include "nnacl/fp32/winograd_utils.h" #include "nnacl/fp32/conv_depthwise_fp32.h" typedef float *TmpBufferAddress; diff --git a/mindspore/lite/nnacl/fp32/deconv_fp32.h b/mindspore/lite/nnacl/fp32/deconv_fp32.h index ce53163484..2cd7ea3b02 100644 --- a/mindspore/lite/nnacl/fp32/deconv_fp32.h +++ b/mindspore/lite/nnacl/fp32/deconv_fp32.h @@ -22,7 +22,7 @@ #include "nnacl/conv_parameter.h" #include "nnacl/errorcode.h" #include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/base/minimal_filtering_generator.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/fp32/deconv_winograd_fp32.h b/mindspore/lite/nnacl/fp32/deconv_winograd_fp32.h index c3e9750c10..a2056dd36d 100644 --- a/mindspore/lite/nnacl/fp32/deconv_winograd_fp32.h +++ b/mindspore/lite/nnacl/fp32/deconv_winograd_fp32.h @@ -22,7 +22,7 @@ #include "nnacl/conv_parameter.h" #include "nnacl/errorcode.h" #include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/base/minimal_filtering_generator.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/winograd_transform.c b/mindspore/lite/nnacl/fp32/winograd_transform.c similarity index 99% rename from mindspore/lite/nnacl/winograd_transform.c rename to mindspore/lite/nnacl/fp32/winograd_transform.c index a14a026600..ff32a319f7 100644 --- a/mindspore/lite/nnacl/winograd_transform.c +++ b/mindspore/lite/nnacl/fp32/winograd_transform.c @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "nnacl/winograd_transform.h" +#include "nnacl/fp32/winograd_transform.h" #include "nnacl/op_base.h" // fp32 conv winograd diff --git a/mindspore/lite/nnacl/winograd_transform.h b/mindspore/lite/nnacl/fp32/winograd_transform.h similarity index 93% rename from mindspore/lite/nnacl/winograd_transform.h rename to mindspore/lite/nnacl/fp32/winograd_transform.h index 39b4961e42..58da682215 100644 --- a/mindspore/lite/nnacl/winograd_transform.h +++ b/mindspore/lite/nnacl/fp32/winograd_transform.h @@ -22,10 +22,7 @@ #endif #include #include "nnacl/pack.h" -#include "nnacl/winograd_utils.h" -#include "mindspore/lite/nnacl/int8/fixed_point.h" - -#define OUPUT_UNIT 2 +#include "nnacl/fp32/winograd_utils.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/winograd_utils.c b/mindspore/lite/nnacl/fp32/winograd_utils.c similarity index 96% rename from mindspore/lite/nnacl/winograd_utils.c rename to mindspore/lite/nnacl/fp32/winograd_utils.c index 06231970bf..941920a852 100644 --- a/mindspore/lite/nnacl/winograd_utils.c +++ b/mindspore/lite/nnacl/fp32/winograd_utils.c @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "nnacl/winograd_utils.h" -#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/fp32/winograd_utils.h" +#include "nnacl/base/minimal_filtering_generator.h" #define MIN_UNIT 2 #define MAX_UNIT 8 @@ -303,65 +303,66 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, #endif } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[64]; + MS_FLOAT32X4 m[64]; + Load64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(src[offset], 0.5625), MS_MULQ_F32(src[2 + offset], 3.0625)), + MS_MULQ_F32(src[4 + offset], 3.5)), + src[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 1.125), MS_MULQ_F32(src[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 2.25), MS_MULQ_F32(src[4 + offset], 3.25)); + t[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.5625), MS_MULQ_F32(src[4 + offset], 2.5)); + t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.375), MS_MULQ_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.25), MS_MULQ_F32(src[4 + offset], 1.25)); + t[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], -0.5625), MS_MULQ_F32(src[3 + offset], 3.0625)), + MS_MULQ_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(t[offset], 0.5625), 
MS_MULQ_F32(t[2 + offset], 3.0625)), + MS_MULQ_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 1.125), MS_MULQ_F32(t[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 2.25), MS_MULQ_F32(t[4 + offset], 3.25)); + m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.5625), MS_MULQ_F32(t[4 + offset], 2.5)); + m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.375), MS_MULQ_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.25), MS_MULQ_F32(t[4 + offset], 1.25)); + m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], -0.5625), MS_MULQ_F32(t[3 + offset], 3.0625)), + MS_MULQ_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_STQ_F32(dst_data + i * dst_step, m[i]); + } +} +#endif + void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { #if defined(ENABLE_ARM) || defined(ENABLE_SSE) if (real_c == 4) { - MS_FLOAT32X4 src[64]; - MS_FLOAT32X4 t[64]; - MS_FLOAT32X4 m[64]; - Load64Data; - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - t[l] = - MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(src[offset], 0.5625), MS_MULQ_F32(src[2 + offset], 3.0625)), - MS_MULQ_F32(src[4 + offset], 3.5)), - src[6 + offset]); - MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 1.125), MS_MULQ_F32(src[5 + offset], 0.5)); - MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 2.25), MS_MULQ_F32(src[4 + offset], 3.25)); - t[8 + l] = - MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); - t[16 + l] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.5625), src[5 + offset]); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.5625), MS_MULQ_F32(src[4 + offset], 2.5)); - t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); - t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.375), MS_MULQ_F32(src[5 + offset], 1.5)); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.25), MS_MULQ_F32(src[4 + offset], 1.25)); - t[40 + l] = - MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); - t[48 + l] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); - t[56 + l] = MS_ADDQ_F32( - MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], -0.5625), MS_MULQ_F32(src[3 + offset], 3.0625)), - MS_MULQ_F32(src[5 + offset], 3.5)), 
- src[7 + offset]); - } - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(t[offset], 0.5625), MS_MULQ_F32(t[2 + offset], 3.0625)), - MS_MULQ_F32(t[4 + offset], 3.5)), - t[6 + offset]); - MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 1.125), MS_MULQ_F32(t[5 + offset], 0.5)); - MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 2.25), MS_MULQ_F32(t[4 + offset], 3.25)); - m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); - m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.5625), t[5 + offset]); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.5625), MS_MULQ_F32(t[4 + offset], 2.5)); - m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); - m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.375), MS_MULQ_F32(t[5 + offset], 1.5)); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.25), MS_MULQ_F32(t[4 + offset], 1.25)); - m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); - m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); - m[56 + l] = - MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], -0.5625), MS_MULQ_F32(t[3 + offset], 3.0625)), - MS_MULQ_F32(t[5 + offset], 3.5)), - t[7 + offset]); - } - for (int i = 0; i < 64; i++) { - MS_STQ_F32(dst_data + i * dst_step, m[i]); - } + InputTransform8x8Unit_block4(src_data, dst_data, src_step, dst_step); } else { #endif float src[64]; @@ -2778,9 +2779,9 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float #endif } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[40]; MS_FLOAT32X4 m[25]; @@ -2837,7 +2838,10 @@ void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const fl } } } +} #else +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[40]; float m[25]; @@ -2882,12 +2886,12 @@ void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const fl } } } -#endif } +#endif +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[40]; MS_FLOAT32X4 m[25]; @@ -2950,7 +2954,10 @@ void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const f } } } +} #else +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[40]; float m[25]; @@ -2996,12 +3003,12 @@ void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const f } } } -#endif } +#endif +#if defined(ENABLE_ARM) || 
defined(ENABLE_SSE) void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[48]; MS_FLOAT32X4 m[36]; @@ -3065,7 +3072,10 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float } } } +} #else +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[48]; float m[36]; @@ -3112,12 +3122,12 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float } } } -#endif } +#endif +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[48]; MS_FLOAT32X4 m[36]; @@ -3188,7 +3198,10 @@ void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const fl } } } +} #else +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[48]; float m[36]; @@ -3237,12 +3250,12 @@ void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const fl } } } -#endif } +#endif +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[48]; MS_FLOAT32X4 m[36]; @@ -3320,7 +3333,10 @@ void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const f } } } +} #else +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[48]; float m[36]; @@ -3370,12 +3386,12 @@ void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const f } } } -#endif } +#endif +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[56]; MS_FLOAT32X4 m[49]; @@ -3443,7 +3459,10 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float } } } +} #else +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[56]; float m[49]; @@ -3494,12 +3513,12 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float } } } -#endif } +#endif +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[56]; MS_FLOAT32X4 m[49]; @@ -3575,7 +3594,10 @@ void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const fl } } } +} #else +void OutputTransform8x7ReluUnit(const float 
*src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[56]; float m[49]; @@ -3628,12 +3650,12 @@ void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const fl } } } -#endif } +#endif +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[56]; MS_FLOAT32X4 m[49]; @@ -3717,7 +3739,10 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f } } } +} #else +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[56]; float m[49]; @@ -3771,8 +3796,8 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f } } } -#endif } +#endif // Reference to the paper "Fast Algorithms for Convolutional Neural Networks" // Utilize cost model to compute performance gain. diff --git a/mindspore/lite/nnacl/winograd_utils.h b/mindspore/lite/nnacl/fp32/winograd_utils.h similarity index 100% rename from mindspore/lite/nnacl/winograd_utils.h rename to mindspore/lite/nnacl/fp32/winograd_utils.h diff --git a/mindspore/lite/nnacl/int8/add_int8.c b/mindspore/lite/nnacl/int8/add_int8.c index 662b47ab7c..6cba727af7 100644 --- a/mindspore/lite/nnacl/int8/add_int8.c +++ b/mindspore/lite/nnacl/int8/add_int8.c @@ -20,10 +20,64 @@ #endif #ifdef ENABLE_AVX #include -#include "nnacl/x86_64_avx/common_utils.h" +#include "nnacl/intrinsics/avx/common_utils.h" #endif #include "nnacl/int8/fixed_point.h" +#ifdef ENABLE_ARM +void AddInt8InputRounding(int32x4_t *in1, int32x4_t *in2, int32x4_t *in3, int32x4_t *in4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *in1 = vmulq_s32(*in1, left_vec); + *in2 = vmulq_s32(*in2, left_vec); + *in3 = vmulq_s32(*in3, left_vec); + *in4 = vmulq_s32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. + *in1 = vqrdmulhq_n_s32(*in1, multiplier); + *in2 = vqrdmulhq_n_s32(*in2, multiplier); + *in3 = vqrdmulhq_n_s32(*in3, multiplier); + *in4 = vqrdmulhq_n_s32(*in4, multiplier); + + // Apply right shift + *in1 = vqaddq_s32(*in1, vshrq_n_s32(vandq_s32(*in1, right_vec), 31)); + *in2 = vqaddq_s32(*in2, vshrq_n_s32(vandq_s32(*in2, right_vec), 31)); + *in3 = vqaddq_s32(*in3, vshrq_n_s32(vandq_s32(*in3, right_vec), 31)); + *in4 = vqaddq_s32(*in4, vshrq_n_s32(vandq_s32(*in4, right_vec), 31)); + + *in1 = vrshlq_s32(*in1, right_vec); + *in2 = vrshlq_s32(*in2, right_vec); + *in3 = vrshlq_s32(*in3, right_vec); + *in4 = vrshlq_s32(*in4, right_vec); +} + +void AddInt8OutputRounding(int32x4_t *out1, int32x4_t *out2, int32x4_t *out3, int32x4_t *out4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *out1 = vshlq_s32(*out1, left_vec); + *out2 = vshlq_s32(*out2, left_vec); + *out3 = vshlq_s32(*out3, left_vec); + *out4 = vshlq_s32(*out4, left_vec); + + // Apply the fixed-point part of the multiplier. 
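+  // vqrdmulhq_n_s32 is a saturating rounding doubling multiply returning the
+  // high 32 bits, roughly (x * multiplier * 2 + (1 << 31)) >> 32; combined with
+  // the rounding right shift below this gives gemmlowp-style requantization.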
+ *out1 = vqrdmulhq_n_s32(*out1, multiplier); + *out2 = vqrdmulhq_n_s32(*out2, multiplier); + *out3 = vqrdmulhq_n_s32(*out3, multiplier); + *out4 = vqrdmulhq_n_s32(*out4, multiplier); + + // Apply right shift + *out1 = vqaddq_s32(*out1, vshrq_n_s32(vandq_s32(*out1, right_vec), 31)); + *out2 = vqaddq_s32(*out2, vshrq_n_s32(vandq_s32(*out2, right_vec), 31)); + *out3 = vqaddq_s32(*out3, vshrq_n_s32(vandq_s32(*out3, right_vec), 31)); + *out4 = vqaddq_s32(*out4, vshrq_n_s32(vandq_s32(*out4, right_vec), 31)); + + *out1 = vrshlq_s32(*out1, right_vec); + *out2 = vrshlq_s32(*out2, right_vec); + *out3 = vrshlq_s32(*out3, right_vec); + *out4 = vrshlq_s32(*out4, right_vec); +} +#endif + void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); @@ -68,44 +122,8 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high)); int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high)); - // Apply left shift - in0_1 = vmulq_s32(in0_1, in0_left_vec); - in0_2 = vmulq_s32(in0_2, in0_left_vec); - in0_3 = vmulq_s32(in0_3, in0_left_vec); - in0_4 = vmulq_s32(in0_4, in0_left_vec); - in1_1 = vmulq_s32(in1_1, in1_left_vec); - in1_2 = vmulq_s32(in1_2, in1_left_vec); - in1_3 = vmulq_s32(in1_3, in1_left_vec); - in1_4 = vmulq_s32(in1_4, in1_left_vec); - - // Apply the fixed-point part of the multiplier. - in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_args_.multiplier_); - in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_args_.multiplier_); - in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_args_.multiplier_); - in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_args_.multiplier_); - in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_args_.multiplier_); - in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_args_.multiplier_); - in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_args_.multiplier_); - in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_args_.multiplier_); - - // Apply right shift - in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31)); - in0_2 = vqaddq_s32(in0_2, vshrq_n_s32(vandq_s32(in0_2, in0_right_vec), 31)); - in0_3 = vqaddq_s32(in0_3, vshrq_n_s32(vandq_s32(in0_3, in0_right_vec), 31)); - in0_4 = vqaddq_s32(in0_4, vshrq_n_s32(vandq_s32(in0_4, in0_right_vec), 31)); - in1_1 = vqaddq_s32(in1_1, vshrq_n_s32(vandq_s32(in1_1, in1_right_vec), 31)); - in1_2 = vqaddq_s32(in1_2, vshrq_n_s32(vandq_s32(in1_2, in1_right_vec), 31)); - in1_3 = vqaddq_s32(in1_3, vshrq_n_s32(vandq_s32(in1_3, in1_right_vec), 31)); - in1_4 = vqaddq_s32(in1_4, vshrq_n_s32(vandq_s32(in1_4, in1_right_vec), 31)); - - in0_1 = vrshlq_s32(in0_1, in0_right_vec); - in0_2 = vrshlq_s32(in0_2, in0_right_vec); - in0_3 = vrshlq_s32(in0_3, in0_right_vec); - in0_4 = vrshlq_s32(in0_4, in0_right_vec); - in1_1 = vrshlq_s32(in1_1, in1_right_vec); - in1_2 = vrshlq_s32(in1_2, in1_right_vec); - in1_3 = vrshlq_s32(in1_3, in1_right_vec); - in1_4 = vrshlq_s32(in1_4, in1_right_vec); + AddInt8InputRounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, in0_right_vec, params->in0_args_.multiplier_); + AddInt8InputRounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, in1_right_vec, params->in1_args_.multiplier_); /* calculate output */ int32x4_t out1 = vaddq_s32(in0_1, in1_1); @@ -113,28 +131,7 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz int32x4_t out3 = 
vaddq_s32(in0_3, in1_3); int32x4_t out4 = vaddq_s32(in0_4, in1_4); - // Apply left shift - out1 = vshlq_s32(out1, out_left_vec); - out2 = vshlq_s32(out2, out_left_vec); - out3 = vshlq_s32(out3, out_left_vec); - out4 = vshlq_s32(out4, out_left_vec); - - // Apply the fixed-point part of the multiplier. - out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_); - out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_); - out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_); - out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_); - - // Apply right shift - out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31)); - out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31)); - out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31)); - out4 = vqaddq_s32(out4, vshrq_n_s32(vandq_s32(out4, out_right_vec), 31)); - - out1 = vrshlq_s32(out1, out_right_vec); - out2 = vrshlq_s32(out2, out_right_vec); - out3 = vrshlq_s32(out3, out_right_vec); - out4 = vrshlq_s32(out4, out_right_vec); + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); const int16x4_t out1_s16 = vmovn_s32(out1); const int16x4_t out2_s16 = vmovn_s32(out2); @@ -200,25 +197,8 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low)); int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high)); int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high)); - // Apply left shift - ele1 = vmulq_s32(ele1, ele_left_vec); - ele2 = vmulq_s32(ele2, ele_left_vec); - ele3 = vmulq_s32(ele3, ele_left_vec); - ele4 = vmulq_s32(ele4, ele_left_vec); - // Apply the fixed-point part of the multiplier. - ele1 = vqrdmulhq_n_s32(ele1, ele_args->multiplier_); - ele2 = vqrdmulhq_n_s32(ele2, ele_args->multiplier_); - ele3 = vqrdmulhq_n_s32(ele3, ele_args->multiplier_); - ele4 = vqrdmulhq_n_s32(ele4, ele_args->multiplier_); - // Apply right shift - ele1 = vqaddq_s32(ele1, vshrq_n_s32(vandq_s32(ele1, ele_right_vec), 31)); - ele2 = vqaddq_s32(ele2, vshrq_n_s32(vandq_s32(ele2, ele_right_vec), 31)); - ele3 = vqaddq_s32(ele3, vshrq_n_s32(vandq_s32(ele3, ele_right_vec), 31)); - ele4 = vqaddq_s32(ele4, vshrq_n_s32(vandq_s32(ele4, ele_right_vec), 31)); - ele1 = vrshlq_s32(ele1, ele_right_vec); - ele2 = vrshlq_s32(ele2, ele_right_vec); - ele3 = vrshlq_s32(ele3, ele_right_vec); - ele4 = vrshlq_s32(ele4, ele_right_vec); + + AddInt8InputRounding(&ele1, &ele2, &ele3, &ele4, ele_left_vec, ele_right_vec, ele_args->multiplier_); for (; index <= size - 16; index += 16) { const int8x16_t ptr_src = vld1q_s8(ptr_in + index); @@ -234,28 +214,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high)); int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high)); - // Apply left shift - ptr1 = vmulq_s32(ptr1, ptr_left_vec); - ptr2 = vmulq_s32(ptr2, ptr_left_vec); - ptr3 = vmulq_s32(ptr3, ptr_left_vec); - ptr4 = vmulq_s32(ptr4, ptr_left_vec); - - // Apply the fixed-point part of the multiplier. 
- ptr1 = vqrdmulhq_n_s32(ptr1, ptr_args->multiplier_); - ptr2 = vqrdmulhq_n_s32(ptr2, ptr_args->multiplier_); - ptr3 = vqrdmulhq_n_s32(ptr3, ptr_args->multiplier_); - ptr4 = vqrdmulhq_n_s32(ptr4, ptr_args->multiplier_); - - // Apply right shift - ptr1 = vqaddq_s32(ptr1, vshrq_n_s32(vandq_s32(ptr1, ptr_right_vec), 31)); - ptr2 = vqaddq_s32(ptr2, vshrq_n_s32(vandq_s32(ptr2, ptr_right_vec), 31)); - ptr3 = vqaddq_s32(ptr3, vshrq_n_s32(vandq_s32(ptr3, ptr_right_vec), 31)); - ptr4 = vqaddq_s32(ptr4, vshrq_n_s32(vandq_s32(ptr4, ptr_right_vec), 31)); - - ptr1 = vrshlq_s32(ptr1, ptr_right_vec); - ptr2 = vrshlq_s32(ptr2, ptr_right_vec); - ptr3 = vrshlq_s32(ptr3, ptr_right_vec); - ptr4 = vrshlq_s32(ptr4, ptr_right_vec); + AddInt8InputRounding(&ptr1, &ptr2, &ptr3, &ptr4, ptr_left_vec, ptr_right_vec, ptr_args->multiplier_); /* calculate output */ int32x4_t out1 = vaddq_s32(ptr1, ele1); @@ -263,28 +222,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i int32x4_t out3 = vaddq_s32(ptr3, ele3); int32x4_t out4 = vaddq_s32(ptr4, ele4); - // Apply output left shift - out1 = vshlq_s32(out1, out_left_vec); - out2 = vshlq_s32(out2, out_left_vec); - out3 = vshlq_s32(out3, out_left_vec); - out4 = vshlq_s32(out4, out_left_vec); - - // Apply output fixed-point part of the multiplier. - out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_); - out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_); - out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_); - out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_); - - // Apply output right shift - out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31)); - out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31)); - out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31)); - out4 = vqaddq_s32(out4, vshrq_n_s32(vandq_s32(out4, out_right_vec), 31)); - - out1 = vrshlq_s32(out1, out_right_vec); - out2 = vrshlq_s32(out2, out_right_vec); - out3 = vrshlq_s32(out3, out_right_vec); - out4 = vrshlq_s32(out4, out_right_vec); + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); const int16x4_t out1_s16 = vmovn_s32(out1); const int16x4_t out2_s16 = vmovn_s32(out2); @@ -300,7 +238,6 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i vst1q_s8(output + index, int8_out); } #endif - for (; index < size; index++) { const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift; const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift; @@ -329,6 +266,43 @@ int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int } #ifdef ENABLE_AVX +void AddInt8Rounding(__m128i *in1, __m128i *in2, __m128i *in3, __m128i *in4, const __m128i left_vec, + const int32_t right_shift, const __m128i multiplier) { + // Apply left shift + *in1 = _mm_mullo_epi32(*in1, left_vec); + *in2 = _mm_mullo_epi32(*in2, left_vec); + *in3 = _mm_mullo_epi32(*in3, left_vec); + *in4 = _mm_mullo_epi32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. 
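+  // _mm_qrdmulh_epi32 (and _mm_rshr_epi32 below) are project-defined helpers,
+  // not hardware intrinsics; they emulate the NEON saturating rounding doubling
+  // multiply-high and rounding shift used by the ARM path.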
+ *in1 = _mm_qrdmulh_epi32(*in1, multiplier); + *in2 = _mm_qrdmulh_epi32(*in2, multiplier); + *in3 = _mm_qrdmulh_epi32(*in3, multiplier); + *in4 = _mm_qrdmulh_epi32(*in4, multiplier); + + // Apply right shift + int32_t in1_remainder_mask = (1ll << (right_shift)) - 1; + int32_t in1_remainder_threshold = in1_remainder_mask >> 1; + const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); + const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); + + const __m128i in1_remainder = + _mm_add_epi32(_mm_and_si128(*in1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in1)); + *in1 = _mm_sub_epi32(_mm_rshr_epi32(*in1, right_shift), _mm_cmpgt_epi32(in1_remainder, vin1_remainder_threshold)); + + const __m128i in2_remainder = + _mm_add_epi32(_mm_and_si128(*in2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in2)); + *in2 = _mm_sub_epi32(_mm_rshr_epi32(*in2, right_shift), _mm_cmpgt_epi32(in2_remainder, vin1_remainder_threshold)); + + const __m128i in3_remainder = + _mm_add_epi32(_mm_and_si128(*in3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in3)); + *in3 = _mm_sub_epi32(_mm_rshr_epi32(*in3, right_shift), _mm_cmpgt_epi32(in3_remainder, vin1_remainder_threshold)); + + const __m128i in4_remainder = + _mm_add_epi32(_mm_and_si128(*in4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in4)); + *in4 = _mm_sub_epi32(_mm_rshr_epi32(*in4, right_shift), _mm_cmpgt_epi32(in4_remainder, vin1_remainder_threshold)); +} + void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { const int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); const int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); @@ -372,68 +346,8 @@ void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, in __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); - // Apply left shift - in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec); - in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec); - in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec); - in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec); - in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec); - in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec); - in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec); - in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec); - - // Apply the fixed-point part of the multiplier. 
- in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier); - in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier); - in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier); - in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier); - in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier); - in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier); - in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier); - in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier); - - // Apply right shift - int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1; - int32_t in0_remainder_threshold = in0_remainder_mask >> 1; - const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask); - const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold); - const __m128i in0_1_remainder = - _mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1)); - in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold)); - const __m128i in0_2_remainder = - _mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2)); - in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold)); - const __m128i in0_3_remainder = - _mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3)); - in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold)); - const __m128i in0_4_remainder = - _mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4)); - in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold)); - - int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1; - int32_t in1_remainder_threshold = in1_remainder_mask >> 1; - const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); - const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); - const __m128i in1_1_remainder = - _mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1)); - in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold)); - const __m128i in1_2_remainder = - _mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2)); - in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold)); - const __m128i in1_3_remainder = - _mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3)); - in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold)); - const __m128i in1_4_remainder = - _mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4)); - in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold)); + AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); + AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); /* calculate output */ __m128i out1 = 
_mm_add_epi32(in0_1, in1_1); @@ -533,39 +447,7 @@ void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *outp __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); - // Apply left shift - in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec); - in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec); - in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec); - in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec); - - // Apply the fixed-point part of the multiplier. - in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier); - in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier); - in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier); - in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier); - - // Apply right shift - int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1; - int32_t in1_remainder_threshold = in1_remainder_mask >> 1; - const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); - const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); - const __m128i in1_1_remainder = - _mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1)); - in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold)); - const __m128i in1_2_remainder = - _mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2)); - in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold)); - const __m128i in1_3_remainder = - _mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3)); - in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold)); - const __m128i in1_4_remainder = - _mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4)); - in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_), - _mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold)); + AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); int index = 0; for (; index <= size - 16; index += 16) { @@ -583,39 +465,7 @@ void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *outp __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); - // Apply left shift - in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec); - in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec); - in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec); - in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec); - - // Apply the fixed-point part of the multiplier. 
- in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier); - in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier); - in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier); - in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier); - - // Apply right shift - int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1; - int32_t in0_remainder_threshold = in0_remainder_mask >> 1; - const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask); - const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold); - const __m128i in0_1_remainder = - _mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1)); - in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold)); - const __m128i in0_2_remainder = - _mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2)); - in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold)); - const __m128i in0_3_remainder = - _mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3)); - in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold)); - const __m128i in0_4_remainder = - _mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4)); - in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_), - _mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold)); + AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); /* calculate output */ __m128i out1 = _mm_add_epi32(in0_1, in1_1); diff --git a/mindspore/lite/nnacl/int8/conv3x3_int8.h b/mindspore/lite/nnacl/int8/conv3x3_int8.h index 6111b46ef4..5c1c9818a2 100644 --- a/mindspore/lite/nnacl/int8/conv3x3_int8.h +++ b/mindspore/lite/nnacl/int8/conv3x3_int8.h @@ -24,11 +24,10 @@ #include "nnacl/op_base.h" #include "nnacl/common_func.h" #include "nnacl/conv_parameter.h" -#include "nnacl/winograd_utils.h" +#include "nnacl/int8/fixed_point.h" #include "nnacl/int8/quantize.h" #include "nnacl/matmul_parameter.h" #include "nnacl/int8/matmul_int8.h" -#include "nnacl/winograd_transform.h" #include "nnacl/int8/common_func_int8.h" #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/int8/conv_int8.h b/mindspore/lite/nnacl/int8/conv_int8.h index 208b8dbef8..bc0aab94db 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.h +++ b/mindspore/lite/nnacl/int8/conv_int8.h @@ -24,7 +24,6 @@ #include "nnacl/op_base.h" #include "nnacl/common_func.h" #include "nnacl/conv_parameter.h" -#include "nnacl/winograd_utils.h" #include "nnacl/int8/quantize.h" #include "nnacl/matmul_parameter.h" #include "nnacl/int8/matmul_int8.h" diff --git a/mindspore/lite/nnacl/x86_64_avx/common_utils.c b/mindspore/lite/nnacl/intrinsics/avx/common_utils.c similarity index 98% rename from mindspore/lite/nnacl/x86_64_avx/common_utils.c rename to mindspore/lite/nnacl/intrinsics/avx/common_utils.c index 4a10a3573d..3152b30f97 100644 --- a/mindspore/lite/nnacl/x86_64_avx/common_utils.c +++ b/mindspore/lite/nnacl/intrinsics/avx/common_utils.c @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "nnacl/x86_64_avx/common_utils.h" +#include "nnacl/intrinsics/avx/common_utils.h" #ifdef WIN32 #ifdef ENABLE_AVX #include diff --git a/mindspore/lite/nnacl/x86_64_avx/common_utils.h b/mindspore/lite/nnacl/intrinsics/avx/common_utils.h similarity index 100% rename from mindspore/lite/nnacl/x86_64_avx/common_utils.h rename to mindspore/lite/nnacl/intrinsics/avx/common_utils.h diff --git a/mindspore/lite/nnacl/x86_64_sse/ConvDwFp32IndirectRow.c b/mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32IndirectRow.c similarity index 100% rename from mindspore/lite/nnacl/x86_64_sse/ConvDwFp32IndirectRow.c rename to mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32IndirectRow.c diff --git a/mindspore/lite/nnacl/x86_64_sse/ConvDwFp32Row_sse.c b/mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c similarity index 100% rename from mindspore/lite/nnacl/x86_64_sse/ConvDwFp32Row_sse.c rename to mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c diff --git a/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c b/mindspore/lite/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c similarity index 80% rename from mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c rename to mindspore/lite/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c index 486f1bd87b..587bd676fa 100644 --- a/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c +++ b/mindspore/lite/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c @@ -17,6 +17,7 @@ #ifdef ENABLE_SSE #include #include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/intrinsics/sse/sse_common.h" void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) { @@ -123,18 +124,16 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f int c2 = DOWN_DIV(width, C2NUM) * C2NUM; int c1 = 0; // c4 loop - for (; c1 < c4; c1 += C4NUM) { - const float *src_kh = src_w; - const float *weight_kh = weight; + for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; __m128 dst_w_ma1 = _mm_setzero_ps(); __m128 dst_w_ma2 = _mm_setzero_ps(); __m128 dst_w_ma3 = _mm_setzero_ps(); __m128 dst_w_ma4 = _mm_setzero_ps(); - for (int kh = 0; kh < kernel_h; kh++) { - const float *src_kw = src_kh; - const float *weight_kw = weight_kh; - for (int kw = 0; kw < kernel_w; kw++) { + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); @@ -154,13 +153,9 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw); __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4); dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); - - src_kw += in_kw_step; - weight_kw += C4NUM; } // kernel_w loop - src_kh += in_kh_step; - weight_kh += kernel_w * C4NUM; - } // kernel_h loop + } // kernel_h loop + // add bias relu __m128 bias_ma = _mm_loadu_ps(bias); dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); @@ -168,39 +163,23 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma); dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma); - __m128 zero_ma = 
_mm_setzero_ps(); - if (relu || relu6) { - dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1); - dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2); - dst_w_ma3 = _mm_max_ps(zero_ma, dst_w_ma3); - dst_w_ma4 = _mm_max_ps(zero_ma, dst_w_ma4); - if (relu6) { - __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); - dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1); - dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2); - dst_w_ma3 = _mm_min_ps(const_ma, dst_w_ma3); - dst_w_ma4 = _mm_min_ps(const_ma, dst_w_ma4); - } - } + ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6); + _mm_storeu_ps(dst_w, dst_w_ma1); _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3); _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4); - - dst_w += C4NUM * block_channel; - src_w += C4NUM * in_sw_step; } // dst_width loop + // c2 loop - for (; c1 < c2; c1 += C2NUM) { - const float *src_kh = src_w; - const float *weight_kh = weight; + for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; __m128 dst_w_ma1 = _mm_setzero_ps(); __m128 dst_w_ma2 = _mm_setzero_ps(); - for (int kh = 0; kh < kernel_h; kh++) { - const float *src_kw = src_kh; - const float *weight_kw = weight_kh; - for (int kw = 0; kw < kernel_w; kw++) { + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); @@ -210,68 +189,38 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw); __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); - - src_kw += in_kw_step; - weight_kw += C4NUM; } // kernel_w loop - src_kh += in_kh_step; - weight_kh += kernel_w * C4NUM; - } // kernel_h loop + } // kernel_h loop // add bias relu __m128 bias_ma = _mm_loadu_ps(bias); dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); - __m128 zero_ma = _mm_setzero_ps(); - if (relu || relu6) { - dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1); - dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2); - if (relu6) { - __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); - dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1); - dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2); - } - } + + ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6); + _mm_storeu_ps(dst_w, dst_w_ma1); _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); - - dst_w += C2NUM * block_channel; - src_w += C2NUM * in_sw_step; } + // remaining - for (; c1 < width; c1++) { - const float *src_kh = src_w; - const float *weight_kh = weight; + for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; __m128 dst_w_ma1 = _mm_setzero_ps(); - for (int kh = 0; kh < kernel_h; kh++) { - const float *src_kw = src_kh; - const float *weight_kw = weight_kh; - for (int kw = 0; kw < kernel_w; kw++) { + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); 
__m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); - - src_kw += in_kw_step; - weight_kw += C4NUM; } // kernel_w loop - src_kh += in_kh_step; - weight_kh += kernel_w * C4NUM; - } // kernel_h loop + } // kernel_h loop + // add bias relu __m128 bias_ma = _mm_loadu_ps(bias); dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); - __m128 zero_ma = _mm_setzero_ps(); - if (relu || relu6) { - dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1); - if (relu6) { - __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); - dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1); - } - } + ActBlock1(&dst_w_ma1, relu, relu6); _mm_storeu_ps(dst_w, dst_w_ma1); - - dst_w += block_channel; - src_w += in_sw_step; } dst_h += out_h_step; src_h += in_sh_step; diff --git a/mindspore/lite/nnacl/intrinsics/sse/MatMul_Sse.c b/mindspore/lite/nnacl/intrinsics/sse/MatMul_Sse.c new file mode 100644 index 0000000000..17981ccf5e --- /dev/null +++ b/mindspore/lite/nnacl/intrinsics/sse/MatMul_Sse.c @@ -0,0 +1,293 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_SSE +#include +#include "nnacl/op_base.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/intrinsics/sse/sse_common.h" +#include "nnacl/base/minimal_filtering_generator.h" + +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + const float *src1 = matix_a; + int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; + int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; + for (int i = 0; i < m; ++i) { + const float *src1_n = src1; + const float *src2_n = matrix_b; + for (int j = 0; j < n; ++j) { + const float *src1_j = src1_n; + int y = 0; + // 16 channel + for (; y < c16; y += C16NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(); + __m128 dst4 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + __m128 ma3 = _mm_loadu_ps(src1_j + 8); + __m128 ma4 = _mm_loadu_ps(src1_j + 12); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + __m128 tmp3 = _mm_mul_ps(ma3, mb); + __m128 tmp4 = _mm_mul_ps(ma4, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + dst3 = _mm_add_ps(dst3, tmp3); + dst4 = _mm_add_ps(dst4, tmp4); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + _mm_storeu_ps(matrix_c + 8, dst3); + _mm_storeu_ps(matrix_c + 12, dst4); + src1_j -= in_channel * k; + src1_j += C16NUM; + matrix_c += C16NUM; + } + // 8 channel + for (; y < c8; y += C8NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + + __m128 mb = 
_mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + src1_j -= in_channel * k; + src1_j += C8NUM; + matrix_c += C8NUM; + } + // remain channel + for (; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + *matrix_c++ = tmp; + } + src2_n += 1; + } + src1 += k * in_channel; + } +} + +void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, int write_mode) { + int C8Steps = row * C8NUM, WinoSteps1 = stride * col, WinoSteps2 = stride * C8NUM; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b, *bias_d = bias; + float *dst = NULL; + for (int cc = col; cc > 0; cc -= C8NUM) { + if (write_mode != 0) { // writec8 + dst = c; + } + const float *srca_d = a; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(), dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(), dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = depth; d > 0; --d) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + + if (bias != NULL) { + DoBiasBlock8(bias_d, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); + bias_d += C8NUM; + } + + ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); + + if (write_mode == OutType_TileC8) { // WriteWino + c = dst + WinoSteps2; + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); + } else if (write_mode == OutType_C8) { // WriteC8 + _mm_storeu_ps(c, dst1), _mm_storeu_ps(c + 4, dst2); + _mm_storeu_ps(c + 8, dst3), _mm_storeu_ps(c + 12, dst4); + _mm_storeu_ps(c + 16, dst5), _mm_storeu_ps(c + 20, dst6); + _mm_storeu_ps(c + 24, dst7), _mm_storeu_ps(c + 28, dst8); + c += C8Steps; + } else { + switch (cc) { + case 1: // write1 + c = dst + 1; + WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 1, r); + break; + case 2: // write2 + c = dst + 2; + WriteCol2Opt(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); + break; + case 3: // write3 + c = dst + 3; + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst1); + WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6,
&dst7, &dst8, stride, 3, r); + break; + case 4: // write4 + c = dst + 4; + WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 4, r); + break; + case 5: // write5 + c = dst + 5; + WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 5, r); + break; + case 6: // write6 + c = dst + 6; + WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 6, r); + break; + case 7: // write7 + c = dst + 7; + WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 7, r); + break; + default: // write8 + c = dst + C8NUM; + WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 8, r); + break; + } + } + if (cc <= C8NUM) break; // write end + } + a += C4NUM * depth; + if (write_mode == OutType_C8) c += 32; + if (write_mode == OutType_TileC8) c = dst + WinoSteps2; + if (write_mode == OutType_Nhwc) c = dst - col; + if (r <= C4NUM) break; + } +} + +void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, size_t writeNhwc, size_t WriteWino) { + size_t DstWinoSteps = stride * C8NUM, WriteWinoSteps = stride * col; + for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { + const float *srca_d = a; + float *dst = c; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(); + __m128 dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = 0; d < depth; d++) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + + if (bias != NULL) { + DoBiasBlock8(bias, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); + } + + ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); + + if (WriteWino != 0) { // WriteWino + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + dst += WriteWinoSteps; + _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); + dst += WriteWinoSteps; + _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); + dst += WriteWinoSteps; + _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); + dst += WriteWinoSteps; + } else if (writeNhwc == 0) { // WriteC8 + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8); + dst += 32; + c = dst; + } else { + switch (col) { + case 1: // write1 + WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); + case 2: // write2 + WriteCol2(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); + case 3: // write3 + WriteCol3(&dst, &dst1, 
&dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); + case 4: // write4 + WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); + case 5: // write5 + WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); + case 6: // write6 + WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); + case 7: // write7 + WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); + default: // write8 + WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); + } + } + if (r <= C4NUM) { // WriteEnd + break; + } + } + b += depth * C8NUM; + bias += (bias != NULL) ? C8NUM : 0; + if (WriteWino != 0) { + c += DstWinoSteps; + } else if (writeNhwc != 0) { + c += C8NUM; + } + if (col_tmp <= C8NUM) { + break; + } + } +} +#endif diff --git a/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c b/mindspore/lite/nnacl/intrinsics/sse/PackNHWCToNCHWFp32.c similarity index 100% rename from mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c rename to mindspore/lite/nnacl/intrinsics/sse/PackNHWCToNCHWFp32.c diff --git a/mindspore/lite/nnacl/x86_64_sse/PostFuncBiasReluC4.c b/mindspore/lite/nnacl/intrinsics/sse/PostFuncBiasReluC4.c similarity index 78% rename from mindspore/lite/nnacl/x86_64_sse/PostFuncBiasReluC4.c rename to mindspore/lite/nnacl/intrinsics/sse/PostFuncBiasReluC4.c index 382a2d2cb4..c049240a03 100644 --- a/mindspore/lite/nnacl/x86_64_sse/PostFuncBiasReluC4.c +++ b/mindspore/lite/nnacl/intrinsics/sse/PostFuncBiasReluC4.c @@ -17,11 +17,10 @@ #ifdef ENABLE_SSE #include #include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/sse/sse_common.h" void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, size_t plane_size, size_t plane_stride, size_t relu_type) { - __m128 relu6 = _mm_set_ps1(6.0); - __m128 zero = _mm_setzero_ps(); size_t stride = oc4div + oc4mod; plane_stride /= sizeof(float); for (size_t loop_c4 = 0; loop_c4 < oc4div; loop_c4 += C4NUM) { @@ -42,19 +41,9 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t src2 = _mm_add_ps(src2, bias1); src3 = _mm_add_ps(src3, bias1); src4 = _mm_add_ps(src4, bias1); - switch (relu_type) { - case 3: - src1 = _mm_min_ps(src1, relu6); - src2 = _mm_min_ps(src2, relu6); - src3 = _mm_min_ps(src3, relu6); - src4 = _mm_min_ps(src4, relu6); - case 1: - src1 = _mm_max_ps(src1, zero); - src2 = _mm_max_ps(src2, zero); - src3 = _mm_max_ps(src3, zero); - src4 = _mm_max_ps(src4, zero); - break; - } + + ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == 3); + _mm_storeu_ps(dst_c4, src1); dst_c4 += stride; _mm_storeu_ps(dst_c4, src2); @@ -67,20 +56,15 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t for (; plane_size_tmp > 0; plane_size_tmp -= 1) { __m128 src1 = _mm_loadu_ps(src); src1 = _mm_add_ps(src1, bias1); - switch (relu_type) { - case 3: - src1 = _mm_min_ps(src1, relu6); - case 1: - src1 = _mm_max_ps(src1, zero); - break; - } + + ActBlock1(&src1, relu_type == 1, relu_type == 3); + _mm_storeu_ps(dst_c4, src1); dst_c4 += stride; src += 4; } src += plane_stride; } - if (oc4mod == 0) { return; } @@ -94,13 +78,9 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t __m128 src1 = _mm_loadu_ps(src); src += 4; src1 = _mm_add_ps(src1, bias1); - switch (relu_type) { - case 3: - src1 = _mm_min_ps(src1, relu6); - case 1: - src1 = _mm_max_ps(src1, zero); -
break; - } + + ActBlock1(&src1, relu_type == 1, relu_type == 3); + switch (oc4mod) { case 1: _mm_store_ss(dst_c1, src1); diff --git a/mindspore/lite/nnacl/x86_64_sse/PostFuncBiasReluC8.c b/mindspore/lite/nnacl/intrinsics/sse/PostFuncBiasReluC8.c similarity index 67% rename from mindspore/lite/nnacl/x86_64_sse/PostFuncBiasReluC8.c rename to mindspore/lite/nnacl/intrinsics/sse/PostFuncBiasReluC8.c index 3fbded1ff7..fb72a81467 100644 --- a/mindspore/lite/nnacl/x86_64_sse/PostFuncBiasReluC8.c +++ b/mindspore/lite/nnacl/intrinsics/sse/PostFuncBiasReluC8.c @@ -17,23 +17,21 @@ #ifdef ENABLE_SSE #include #include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/sse/sse_common.h" void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, size_t plane_size, size_t stride, size_t relu_type) { - __m128 relu6 = _mm_set_ps1(6.0); - __m128 zero = _mm_setzero_ps(); stride /= sizeof(float); for (int loop_c8 = 0; !(loop_c8 == oc8div); loop_c8 += C8NUM) { size_t plane_size_tmp = plane_size; float *dst_c8 = dst + loop_c8; - __m128 bias1 = _mm_setzero_ps(); - __m128 bias2 = _mm_setzero_ps(); + __m128 bias1 = _mm_setzero_ps(), bias2 = _mm_setzero_ps(); if (bias != NULL) { bias1 = _mm_loadu_ps(bias); bias2 = _mm_loadu_ps(bias + 4); bias += 8; } - for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM, src += 32) { __m128 src1 = _mm_loadu_ps(src); __m128 src2 = _mm_loadu_ps(src + 4); __m128 src3 = _mm_loadu_ps(src + 8); @@ -42,7 +40,6 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t __m128 src6 = _mm_loadu_ps(src + 20); __m128 src7 = _mm_loadu_ps(src + 24); __m128 src8 = _mm_loadu_ps(src + 28); - src += 32; src1 = _mm_add_ps(src1, bias1); src2 = _mm_add_ps(src2, bias2); src3 = _mm_add_ps(src3, bias1); @@ -51,27 +48,9 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t src6 = _mm_add_ps(src6, bias2); src7 = _mm_add_ps(src7, bias1); src8 = _mm_add_ps(src8, bias2); - switch (relu_type) { - case 3: - src1 = _mm_min_ps(src1, relu6); - src2 = _mm_min_ps(src2, relu6); - src3 = _mm_min_ps(src3, relu6); - src4 = _mm_min_ps(src4, relu6); - src5 = _mm_min_ps(src5, relu6); - src6 = _mm_min_ps(src6, relu6); - src7 = _mm_min_ps(src7, relu6); - src8 = _mm_min_ps(src8, relu6); - case 1: - src1 = _mm_max_ps(src1, zero); - src2 = _mm_max_ps(src2, zero); - src3 = _mm_max_ps(src3, zero); - src4 = _mm_max_ps(src4, zero); - src5 = _mm_max_ps(src5, zero); - src6 = _mm_max_ps(src6, zero); - src7 = _mm_max_ps(src7, zero); - src8 = _mm_max_ps(src8, zero); - break; - } + + ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + _mm_storeu_ps(dst_c8, src1); _mm_storeu_ps(dst_c8 + 4, src2); dst_c8 += stride; @@ -85,29 +64,21 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t _mm_storeu_ps(dst_c8 + 4, src8); dst_c8 += stride; } - for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + for (; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c8 += stride) { __m128 src1 = _mm_loadu_ps(src); __m128 src2 = _mm_loadu_ps(src + 4); src1 = _mm_add_ps(src1, bias1); src2 = _mm_add_ps(src2, bias2); - switch (relu_type) { - case 3: - src1 = _mm_min_ps(src1, relu6); - src2 = _mm_min_ps(src2, relu6); - case 1: - src1 = _mm_max_ps(src1, zero); - src2 = _mm_max_ps(src2, zero); - break; - } + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + _mm_storeu_ps(dst_c8, src1); _mm_storeu_ps(dst_c8 + 4, src2); - dst_c8 += 
stride; - src += 8; } } - if (oc8mod == 0) { - return; - } + + if ((oc8mod == 0)) return; + __m128 bias1 = _mm_setzero_ps(); __m128 bias2 = _mm_setzero_ps(); if (bias != NULL) { @@ -116,56 +87,42 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t bias += 8; } float *dst_c1 = dst + oc8div; - for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) { + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c1 += stride) { __m128 src1 = _mm_loadu_ps(src); __m128 src2 = _mm_loadu_ps(src + 4); - src += 8; src1 = _mm_add_ps(src1, bias1); src2 = _mm_add_ps(src2, bias2); - switch (relu_type) { - case 3: - src1 = _mm_min_ps(src1, relu6); - src2 = _mm_min_ps(src2, relu6); - case 1: - src1 = _mm_max_ps(src1, zero); - src2 = _mm_max_ps(src2, zero); - break; - } + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + switch (oc8mod) { case 1: _mm_store_ss(dst_c1, src1); - dst_c1 += stride; break; case 2: _mm_storel_pi((__m64 *)(dst_c1), src1); - dst_c1 += stride; break; case 3: _mm_storel_pi((__m64 *)(dst_c1), src1); src1 = _mm_unpackhi_ps(src1, src1); _mm_store_ss(dst_c1 + 2, src1); - dst_c1 += stride; break; case 4: _mm_storeu_ps(dst_c1, src1); - dst_c1 += stride; break; case 5: _mm_storeu_ps(dst_c1, src1); _mm_store_ss(dst_c1 + 4, src2); - dst_c1 += stride; break; case 6: _mm_storeu_ps(dst_c1, src1); _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); - dst_c1 += stride; break; case 7: _mm_storeu_ps(dst_c1, src1); _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); src2 = _mm_unpackhi_ps(src2, src2); _mm_store_ss(dst_c1 + 6, src2); - dst_c1 += stride; break; } } diff --git a/mindspore/lite/nnacl/intrinsics/sse/TiledC4MatMulFp32.c b/mindspore/lite/nnacl/intrinsics/sse/TiledC4MatMulFp32.c new file mode 100644 index 0000000000..59e4f758b8 --- /dev/null +++ b/mindspore/lite/nnacl/intrinsics/sse/TiledC4MatMulFp32.c @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_SSE +#include +#include "nnacl/fp32/common_func_fp32.h" + +void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, const __m128 weight, + const float v1, const float v2, const float v3, const float v4) { + *dst1 = _mm_add_ps(*dst1, _mm_mul_ps(weight, _mm_set_ps1(v1))); + *dst2 = _mm_add_ps(*dst2, _mm_mul_ps(weight, _mm_set_ps1(v2))); + *dst3 = _mm_add_ps(*dst3, _mm_mul_ps(weight, _mm_set_ps1(v3))); + *dst4 = _mm_add_ps(*dst4, _mm_mul_ps(weight, _mm_set_ps1(v4))); +} + +void TiledC4MatmulFp32_LoadData(__m128 *src1, __m128 *src2, __m128 *src3, __m128 *src4, const float *src) { + *src1 = _mm_loadu_ps(src); + *src2 = _mm_loadu_ps(src + 4); + *src3 = _mm_loadu_ps(src + 8); + *src4 = _mm_loadu_ps(src + 12); +} +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { + const float *src_tmp = src; + for (int i = 0; i < oc4; ++i) { + float *dst_tmp = dst; + src = src_tmp; + size_t ic4_tmp = ic4 - 1; + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + src += 16; + __m128 weight_data[4]; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight_data[2] = _mm_loadu_ps(weight + 8); + weight_data[3] = _mm_loadu_ps(weight + 12); + weight += 16; + __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); + __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); + __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); + __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], src1[j], src2[j], src3[j], src4[j]); + } + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); + __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); + __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); + __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], src1[j], src2[j], src3[j], src4[j]); + } + if (ic4_tmp != 0) { + ic4_tmp -= 1; + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); + for (; ic4_tmp != 0; ic4_tmp -= 1) { + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], src1[1], src2[1], src3[1], src4[1]); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], src1[2], src2[2], src3[2], src4[2]); + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3]))); + src1 = _mm_loadu_ps(src); + src2 = _mm_loadu_ps(src + 4); + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], 
_mm_set_ps1(src4[3]))); + src3 = _mm_loadu_ps(src + 8); + src4 = _mm_loadu_ps(src + 12); + src += 16; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], src1[0], src2[0], src3[0], src4[0]); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], src1[1], src2[1], src3[1], src4[1]); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], src1[2], src2[2], src3[2], src4[2]); + + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], src1[3], src2[3], src3[3], src4[3]); + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); + } + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], src1[1], src2[1], src3[1], src4[1]); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], src1[2], src2[2], src3[2], src4[2]); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], src1[3], src2[3], src3[3], src4[3]); + + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + for (int j = 0; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], src1[j], src2[j], src3[j], src4[j]); + } + } + _mm_storeu_ps(dst, dst1); + _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3); + _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5); + _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7); + _mm_storeu_ps(dst + 28, dst8); + dst = dst_tmp + cal_num; + } +} +#endif diff --git a/mindspore/lite/nnacl/x86_64_sse/WinogradTrans.c b/mindspore/lite/nnacl/intrinsics/sse/WinogradTrans.c similarity index 80% rename from mindspore/lite/nnacl/x86_64_sse/WinogradTrans.c rename to mindspore/lite/nnacl/intrinsics/sse/WinogradTrans.c index 04e4f2333c..a1821af201 100644 --- a/mindspore/lite/nnacl/x86_64_sse/WinogradTrans.c +++ b/mindspore/lite/nnacl/intrinsics/sse/WinogradTrans.c @@ -36,7 +36,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ __m128 k6 = _mm_load_ps1(BK + 5 * h); __m128 k7 = _mm_load_ps1(BK + 6 * h); BK += 7 * h; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); @@ -54,8 +54,6 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); M1 = _mm_add_ps(M1, s1); _mm_storeu_ps(M, M1); - M += 4; - SK += 4; } M -= len_c4; SK += 7 * S_step - len_c4; @@ -66,7 +64,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ __m128 k3 = _mm_load_ps1(BK + 2 * h); __m128 k4 = _mm_load_ps1(BK + 3 * h); BK += 4 * h; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); @@ -78,8 +76,6 @@ void WinogradTransLeft(const float *S, 
const float *B, float *M, size_t w, size_ s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); M1 = _mm_add_ps(M1, s1); _mm_storeu_ps(M, M1); - SK += 4; - M += 4; } M -= len_c4; SK += 4 * S_step - len_c4; @@ -89,7 +85,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ __m128 k2 = _mm_load_ps1(BK + h); __m128 k3 = _mm_load_ps1(BK + 2 * h); BK += 3 * h; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); @@ -99,8 +95,6 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); M1 = _mm_add_ps(M1, s1); _mm_storeu_ps(M, M1); - SK += 4; - M += 4; } M -= len_c4; SK += 3 * S_step - len_c4; @@ -108,13 +102,11 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ for (; k_tmp > 0; k_tmp -= 1) { __m128 k1 = _mm_load_ps1(BK); BK += h; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); _mm_storeu_ps(M, M1); - SK += 4; - M += 4; } M -= len_c4; SK += S_step - len_c4; @@ -127,16 +119,14 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ } void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { - size_t len_c4 = length * 4; - size_t k_step = len_c4 * k; - for (int h1 = 0; h1 < h; ++h1) { + size_t len_c4 = length * 4, k_step = len_c4 * k; + for (int h1 = 0; h1 < h; ++h1, S += k_step) { const float *BW = B; - for (int ww = 0; ww < w; ++ww) { - const float *SK = S; // r0 - const float *BK = BW; // r1 + for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c4) { + const float *SK = S, *BK = BW; memset(M, 0, len_c4 * sizeof(float)); int k_tmp = k; - for (; k_tmp >= 7; k_tmp -= 7) { + for (; k_tmp >= 7; k_tmp -= 7, M -= len_c4) { __m128 k1 = _mm_load_ps1(BK); __m128 k2 = _mm_load_ps1(BK + h); __m128 k3 = _mm_load_ps1(BK + 2 * h); @@ -145,13 +135,11 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size __m128 k6 = _mm_load_ps1(BK + 5 * h); __m128 k7 = _mm_load_ps1(BK + 6 * h); BK += 7 * h; - const float *S2 = SK + len_c4; - const float *S3 = S2 + len_c4; - const float *S4 = S3 + len_c4; - const float *S5 = S4 + len_c4; - const float *S6 = S5 + len_c4; - const float *S7 = S6 + len_c4; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + const float *S2 = SK + len_c4, *S3 = S2 + len_c4; + const float *S4 = S3 + len_c4, *S5 = S4 + len_c4; + const float *S6 = S5 + len_c4, *S7 = S6 + len_c4; + for (int len_tmp = length; len_tmp > 0; + --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4, S5 += 4, S6 += 4, S7 += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); @@ -169,19 +157,10 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); M1 = _mm_add_ps(M1, s1); _mm_storeu_ps(M, M1); - M += 4; - SK += 4; - S2 += 4; - S3 += 4; - S4 += 4; - S5 += 4; - S6 += 4; - S7 += 4; } - M -= len_c4; SK = S7; } - for (; k_tmp >= 4; k_tmp -= 4) { + for (; k_tmp >= 4; k_tmp -= 4, M -= len_c4) { __m128 k1 = _mm_load_ps1(BK); __m128 k2 = _mm_load_ps1(BK + h); __m128 k3 = _mm_load_ps1(BK + 2 * h); @@ -190,7 +169,7 @@ void 
WinogradTransRight(const float *S, const float *B, float *M, size_t w, size const float *S2 = SK + len_c4; const float *S3 = S2 + len_c4; const float *S4 = S3 + len_c4; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); @@ -202,23 +181,17 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); M1 = _mm_add_ps(M1, s1); _mm_storeu_ps(M, M1); - M += 4; - SK += 4; - S2 += 4; - S3 += 4; - S4 += 4; } - M -= len_c4; SK = S4; } - for (; k_tmp >= 3; k_tmp -= 3) { + for (; k_tmp >= 3; k_tmp -= 3, M -= len_c4) { __m128 k1 = _mm_load_ps1(BK); __m128 k2 = _mm_load_ps1(BK + h); __m128 k3 = _mm_load_ps1(BK + 2 * h); BK += 3 * h; const float *S2 = SK + len_c4; const float *S3 = S2 + len_c4; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); @@ -228,31 +201,20 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); M1 = _mm_add_ps(M1, s1); _mm_storeu_ps(M, M1); - M += 4; - SK += 4; - S2 += 4; - S3 += 4; } - M -= len_c4; SK = S3; } - for (; k_tmp >= 1; k_tmp -= 1) { + for (; k_tmp >= 1; k_tmp -= 1, M -= len_c4) { __m128 k1 = _mm_load_ps1(BK); BK += h; - for (int len_tmp = length; len_tmp > 0; --len_tmp) { + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { __m128 M1 = _mm_loadu_ps(M); __m128 s0 = _mm_loadu_ps(SK); M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); _mm_storeu_ps(M, M1); - M += 4; - SK += 4; } - M -= len_c4; } - BW += 1; - M += len_c4; } - S += k_step; } } #endif diff --git a/mindspore/lite/nnacl/intrinsics/sse/sse_common.c b/mindspore/lite/nnacl/intrinsics/sse/sse_common.c new file mode 100644 index 0000000000..6a358cac42 --- /dev/null +++ b/mindspore/lite/nnacl/intrinsics/sse/sse_common.c @@ -0,0 +1,336 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_SSE +#include +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/sse/sse_common.h" + +void ActBlock1(__m128 *v1, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + } +} + +void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + } +} + +void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + *v3 = _mm_max_ps(zero_ma, *v3); + *v4 = _mm_max_ps(zero_ma, *v4); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + *v3 = _mm_min_ps(relu6_ma, *v3); + *v4 = _mm_min_ps(relu6_ma, *v4); + } +} + +void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, __m128 *v8, + size_t relu_type) { + __m128 relu6 = _mm_set_ps1(6.0); + __m128 zero = _mm_setzero_ps(); + switch (relu_type) { + case 3: + *v1 = _mm_min_ps(*v1, relu6); + *v2 = _mm_min_ps(*v2, relu6); + *v3 = _mm_min_ps(*v3, relu6); + *v4 = _mm_min_ps(*v4, relu6); + *v5 = _mm_min_ps(*v5, relu6); + *v6 = _mm_min_ps(*v6, relu6); + *v7 = _mm_min_ps(*v7, relu6); + *v8 = _mm_min_ps(*v8, relu6); + case 1: + *v1 = _mm_max_ps(*v1, zero); + *v2 = _mm_max_ps(*v2, zero); + *v3 = _mm_max_ps(*v3, zero); + *v4 = _mm_max_ps(*v4, zero); + *v5 = _mm_max_ps(*v5, zero); + *v6 = _mm_max_ps(*v6, zero); + *v7 = _mm_max_ps(*v7, zero); + *v8 = _mm_max_ps(*v8, zero); + break; + } +} + +void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_store_ss(*dst, *dst1); + if (r > 1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + } + if (r > 2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + } + if (r > 3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst, *dst1); + if (r > 1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst, *dst3); + } + if (r > 2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst, *dst5); + } + if (r > 3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst, *dst7); + } +} + +void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 1, *dst1); + if (r > 1) { + 
*dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 1, *dst3); + } + if (r > 2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 1, *dst5); + } + if (r > 3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 1, *dst7); + *dst += stride; + *dst += 2; + } +} + +void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + if (r > 1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 1, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 2, *dst3); + } + if (r > 2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 1, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 2, *dst5); + } + if (r > 3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 1, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 2, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + if (r > 1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + } + if (r > 2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + } + if (r > 3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} +void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + 4, *dst2); + if (r > 1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + 4, *dst4); + } + if (r > 2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + 4, *dst6); + } + if (r > 3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + 4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} +void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + 4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst2); + if (r > 1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + 4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst4); + } + if (r > 2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + 4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst6); + } + if (r > 3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + 4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst8); + *dst += stride; + *dst += extra_stride; + } 
+} +void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + 4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 6, *dst2); + if (r > 1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + 4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 6, *dst4); + } + if (r > 2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + 4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 6, *dst6); + } + if (r > 3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + 4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 5, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(*dst + 6, *dst8); + *dst += stride; + *dst += extra_stride; + } +} +void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_storeu_ps(*dst + 4, *dst2); + if (r > 1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_storeu_ps(*dst + 4, *dst4); + } + if (r > 2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_storeu_ps(*dst + 4, *dst6); + } + if (r > 3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_storeu_ps(*dst + 4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8) { + __m128 bias1 = _mm_loadu_ps(bias_ptr); + __m128 bias2 = _mm_loadu_ps(bias_ptr + C4NUM); + *dst1 = _mm_add_ps(*dst1, bias1); + *dst2 = _mm_add_ps(*dst2, bias2); + *dst3 = _mm_add_ps(*dst3, bias1); + *dst4 = _mm_add_ps(*dst4, bias2); + *dst5 = _mm_add_ps(*dst5, bias1); + *dst6 = _mm_add_ps(*dst6, bias2); + *dst7 = _mm_add_ps(*dst7, bias1); + *dst8 = _mm_add_ps(*dst8, bias2); +} + +#endif diff --git a/mindspore/lite/nnacl/intrinsics/sse/sse_common.h b/mindspore/lite/nnacl/intrinsics/sse/sse_common.h new file mode 100644 index 0000000000..fd48d184c8 --- /dev/null +++ b/mindspore/lite/nnacl/intrinsics/sse/sse_common.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ +#define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +void ActBlock1(__m128 *v1, size_t relu, size_t relu6); +void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6); +void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6); +void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, __m128 *v8, + size_t relu_type); + +void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); +void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int r); +void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int r); +void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); +void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); +void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); +void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); +void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); +void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, + __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); + +void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ diff --git a/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c b/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c deleted file mode 100644 index 75d5a563d7..0000000000 --- a/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c +++ /dev/null @@ -1,747 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifdef ENABLE_SSE -#include -#include "nnacl/minimal_filtering_generator.h" -#include "nnacl/op_base.h" - -void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, - int in_channel, int c4_channel) { - const float *src1 = matix_a; - int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; - int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; - for (int i = 0; i < m; ++i) { - const float *src1_n = src1; - const float *src2_n = matrix_b; - for (int j = 0; j < n; ++j) { - const float *src1_j = src1_n; - int y = 0; - // 16 channel - for (; y < c16; y += C16NUM) { - __m128 dst1 = _mm_setzero_ps(); - __m128 dst2 = _mm_setzero_ps(); - __m128 dst3 = _mm_setzero_ps(); - __m128 dst4 = _mm_setzero_ps(); - const float *src2_y = src2_n; - for (int z = 0; z < k; ++z) { - __m128 ma1 = _mm_loadu_ps(src1_j); - __m128 ma2 = _mm_loadu_ps(src1_j + 4); - __m128 ma3 = _mm_loadu_ps(src1_j + 8); - __m128 ma4 = _mm_loadu_ps(src1_j + 12); - - __m128 mb = _mm_load_ps1(src2_y); - __m128 tmp1 = _mm_mul_ps(ma1, mb); - __m128 tmp2 = _mm_mul_ps(ma2, mb); - __m128 tmp3 = _mm_mul_ps(ma3, mb); - __m128 tmp4 = _mm_mul_ps(ma4, mb); - dst1 = _mm_add_ps(dst1, tmp1); - dst2 = _mm_add_ps(dst2, tmp2); - dst3 = _mm_add_ps(dst3, tmp3); - dst4 = _mm_add_ps(dst4, tmp4); - src1_j += in_channel; - src2_y += n; - } - _mm_storeu_ps(matrix_c, dst1); - _mm_storeu_ps(matrix_c + 4, dst2); - _mm_storeu_ps(matrix_c + 8, dst3); - _mm_storeu_ps(matrix_c + 12, dst4); - src1_j -= in_channel * k; - src1_j += C16NUM; - matrix_c += C16NUM; - } - // 8 channel - for (; y < c8; y += C8NUM) { - __m128 dst1 = _mm_setzero_ps(); - __m128 dst2 = _mm_setzero_ps(); - const float *src2_y = src2_n; - for (int z = 0; z < k; ++z) { - __m128 ma1 = _mm_loadu_ps(src1_j); - __m128 ma2 = _mm_loadu_ps(src1_j + 4); - - __m128 mb = _mm_load_ps1(src2_y); - __m128 tmp1 = _mm_mul_ps(ma1, mb); - __m128 tmp2 = _mm_mul_ps(ma2, mb); - dst1 = _mm_add_ps(dst1, tmp1); - dst2 = _mm_add_ps(dst2, tmp2); - src1_j += in_channel; - src2_y += n; - } - _mm_storeu_ps(matrix_c, dst1); - _mm_storeu_ps(matrix_c + 4, dst2); - src1_j -= in_channel * k; - src1_j += C8NUM; - matrix_c += C8NUM; - } - // remain chann - for (; y < in_channel; ++y) { - float tmp = 0; - for (int z = 0; z < k; ++z) { - tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; - } - *matrix_c++ = tmp; - } - src2_n += 1; - } - src1 += k * in_channel; - } -} - -void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, int stride, int write_mode) { - int C8Steps = row * C8NUM; - int WinoSteps1 = stride * col; - int WinoSteps2 = stride * C8NUM; - for (int r = row; r > 0; r -= C4NUM) { - const float *srcb_d = b; - const float *bias_d = bias; - float *dst = NULL; - for (int cc = col; cc > 0; cc -= C8NUM) { - if (write_mode != 0) { // writec8 - dst = c; - } - const float *srca_d = a; - __m128 dst1 = _mm_setzero_ps(); - __m128 dst2 = _mm_setzero_ps(); - __m128 dst3 = _mm_setzero_ps(); - __m128 dst4 = _mm_setzero_ps(); - __m128 dst5 = _mm_setzero_ps(); - __m128 dst6 = _mm_setzero_ps(); - __m128 dst7 = _mm_setzero_ps(); - __m128 dst8 = _mm_setzero_ps(); - for (int d = depth; d > 0; --d) { - __m128 b1 = _mm_loadu_ps(srcb_d); - __m128 b2 = _mm_loadu_ps(srcb_d + 4); - __m128 a1 = _mm_load_ps1(srca_d); - __m128 a2 = _mm_load_ps1(srca_d + 1); - __m128 tmp1 = _mm_mul_ps(b1, a1); - __m128 tmp2 = _mm_mul_ps(b2, a1); - __m128 tmp3 = _mm_mul_ps(b1, a2); - __m128 tmp4 = _mm_mul_ps(b2, a2); - a1 = 
_mm_load_ps1(srca_d + 2); - dst1 = _mm_add_ps(dst1, tmp1); - dst2 = _mm_add_ps(dst2, tmp2); - a2 = _mm_load_ps1(srca_d + 3); - dst3 = _mm_add_ps(dst3, tmp3); - dst4 = _mm_add_ps(dst4, tmp4); - tmp1 = _mm_mul_ps(b1, a1); - tmp2 = _mm_mul_ps(b2, a1); - tmp3 = _mm_mul_ps(b1, a2); - tmp4 = _mm_mul_ps(b2, a2); - dst5 = _mm_add_ps(dst5, tmp1); - dst6 = _mm_add_ps(dst6, tmp2); - dst7 = _mm_add_ps(dst7, tmp3); - dst8 = _mm_add_ps(dst8, tmp4); - srcb_d += C8NUM; - srca_d += C4NUM; - } - if (bias != NULL) { - __m128 bias1 = _mm_loadu_ps(bias_d); - __m128 bias2 = _mm_loadu_ps(bias_d + C4NUM); - dst1 = _mm_add_ps(dst1, bias1); - dst2 = _mm_add_ps(dst2, bias2); - dst3 = _mm_add_ps(dst3, bias1); - dst4 = _mm_add_ps(dst4, bias2); - dst5 = _mm_add_ps(dst5, bias1); - dst6 = _mm_add_ps(dst6, bias2); - dst7 = _mm_add_ps(dst7, bias1); - dst8 = _mm_add_ps(dst8, bias2); - bias_d += C8NUM; - } - if (act_type == 3) { - __m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0); - dst1 = _mm_min_ps(dst1, relu6); - dst2 = _mm_min_ps(dst2, relu6); - dst3 = _mm_min_ps(dst3, relu6); - dst4 = _mm_min_ps(dst4, relu6); - dst5 = _mm_min_ps(dst5, relu6); - dst6 = _mm_min_ps(dst6, relu6); - dst7 = _mm_min_ps(dst7, relu6); - dst8 = _mm_min_ps(dst8, relu6); - } - if (act_type == 1 || act_type == 3) { - __m128 zero = _mm_setzero_ps(); - dst1 = _mm_max_ps(dst1, zero); - dst2 = _mm_max_ps(dst2, zero); - dst3 = _mm_max_ps(dst3, zero); - dst4 = _mm_max_ps(dst4, zero); - dst5 = _mm_max_ps(dst5, zero); - dst6 = _mm_max_ps(dst6, zero); - dst7 = _mm_max_ps(dst7, zero); - dst8 = _mm_max_ps(dst8, zero); - } - if (write_mode == 2) { // WriteWino - c = dst + WinoSteps2; - _mm_storeu_ps(dst, dst1); - _mm_storeu_ps(dst + 4, dst2); - dst += WinoSteps1; - _mm_storeu_ps(dst, dst3); - _mm_storeu_ps(dst + 4, dst4); - dst += WinoSteps1; - _mm_storeu_ps(dst, dst5); - _mm_storeu_ps(dst + 4, dst6); - dst += WinoSteps1; - _mm_storeu_ps(dst, dst7); - _mm_storeu_ps(dst + 4, dst8); - } else if (write_mode == 0) { // WriteC8 - _mm_storeu_ps(c, dst1); - _mm_storeu_ps(c + 4, dst2); - _mm_storeu_ps(c + 8, dst3); - _mm_storeu_ps(c + 12, dst4); - _mm_storeu_ps(c + 16, dst5); - _mm_storeu_ps(c + 20, dst6); - _mm_storeu_ps(c + 24, dst7); - _mm_storeu_ps(c + 28, dst8); - c += C8Steps; - } else { - switch (cc) { - case 1: // write1 - c = dst + 1; - _mm_store_ss(dst, dst1); - if (r > 1) { - dst += stride; - _mm_store_ss(dst, dst3); - } - if (r > 2) { - dst += stride; - _mm_store_ss(dst, dst5); - } - if (r > 3) { - dst += stride; - _mm_store_ss(dst, dst7); - dst += stride; - dst += 1; - } - break; - case 2: // write2 - c = dst + 2; - _mm_store_ss(dst, dst1); - dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst1); - if (r > 1) { - dst += stride; - _mm_store_ss(dst, dst3); - dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst3); - } - if (r > 2) { - dst += stride; - _mm_store_ss(dst, dst5); - dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst5); - } - if (r > 3) { - dst += stride; - _mm_store_ss(dst, dst7); - dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst7); - dst += stride; - dst += 2; - } - break; - case 3: // write3 - c = dst + 3; - _mm_store_ss(dst, dst1); - dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst1); - dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 2, dst1); - if (r > 1) { - dst += stride; - _mm_store_ss(dst, dst3); - dst3 = _mm_shuffle_ps(dst3, 
dst3, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst3); - dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 2, dst3); - } - if (r > 2) { - dst += stride; - _mm_store_ss(dst, dst5); - dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst5); - dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 2, dst5); - } - if (r > 3) { - dst += stride; - _mm_store_ss(dst, dst7); - dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst7); - dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 2, dst7); - dst += stride; - dst += 3; - } - break; - case 4: // write4 - c = dst + 4; - _mm_storeu_ps(dst, dst1); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - dst += stride; - dst += 4; - } - break; - case 5: // write5 - c = dst + 5; - _mm_storeu_ps(dst, dst1); - _mm_store_ss(dst + 4, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_store_ss(dst + 4, dst4); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_store_ss(dst + 4, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_store_ss(dst + 4, dst8); - dst += stride; - dst += 5; - } - break; - case 6: // write6 - c = dst + 6; - _mm_storeu_ps(dst, dst1); - _mm_store_ss(dst + 4, dst2); - dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_store_ss(dst + 4, dst4); - dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst4); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_store_ss(dst + 4, dst6); - dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_store_ss(dst + 4, dst8); - dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst8); - dst += stride; - dst += 6; - } - break; - case 7: // write7 - c = dst + 7; - _mm_storeu_ps(dst, dst1); - _mm_store_ss(dst + 4, dst2); - dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst2); - dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_store_ss(dst + 4, dst4); - dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst4); - dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst4); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_store_ss(dst + 4, dst6); - dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst6); - dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_store_ss(dst + 4, dst8); - dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst8); - dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst8); - dst += stride; - dst += 7; - } - break; - default: // write8 - c = dst + C8NUM; - _mm_storeu_ps(dst, dst1); - _mm_storeu_ps(dst + 4, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_storeu_ps(dst + 4, dst4); - } - 
if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_storeu_ps(dst + 4, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_storeu_ps(dst + 4, dst8); - dst += stride; - dst += C8NUM; - } - break; - } - } - if (cc <= C8NUM) { // write end - break; - } - } // col end - a += C4NUM * depth; - switch (write_mode) { - case 0: // C8DstStep - c += 32; - break; - case 2: - c = dst + WinoSteps2; - break; - default: - c = dst - col; - break; - } - if (r <= C4NUM) { - break; - } - } -} - -void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, int stride, size_t writeNhwc, size_t WriteWino) { - size_t DstWinoSteps = stride * C8NUM; - size_t WriteWinoSteps = stride * col; - for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { - const float *srca_d = a; - float *dst = c; - for (int r = row; r > 0; r -= C4NUM) { - const float *srcb_d = b; - __m128 dst1 = _mm_setzero_ps(); - __m128 dst2 = _mm_setzero_ps(); - __m128 dst3 = _mm_setzero_ps(); - __m128 dst4 = _mm_setzero_ps(); - __m128 dst5 = _mm_setzero_ps(); - __m128 dst6 = _mm_setzero_ps(); - __m128 dst7 = _mm_setzero_ps(); - __m128 dst8 = _mm_setzero_ps(); - for (int d = 0; d < depth; d++) { - __m128 b1 = _mm_loadu_ps(srcb_d); - __m128 b2 = _mm_loadu_ps(srcb_d + 4); - __m128 a1 = _mm_load_ps1(srca_d); - __m128 a2 = _mm_load_ps1(srca_d + 1); - __m128 tmp1 = _mm_mul_ps(b1, a1); - __m128 tmp2 = _mm_mul_ps(b2, a1); - __m128 tmp3 = _mm_mul_ps(b1, a2); - __m128 tmp4 = _mm_mul_ps(b2, a2); - a1 = _mm_load_ps1(srca_d + 2); - dst1 = _mm_add_ps(dst1, tmp1); - dst2 = _mm_add_ps(dst2, tmp2); - a2 = _mm_load_ps1(srca_d + 3); - dst3 = _mm_add_ps(dst3, tmp3); - dst4 = _mm_add_ps(dst4, tmp4); - tmp1 = _mm_mul_ps(b1, a1); - tmp2 = _mm_mul_ps(b2, a1); - tmp3 = _mm_mul_ps(b1, a2); - tmp4 = _mm_mul_ps(b2, a2); - dst5 = _mm_add_ps(dst5, tmp1); - dst6 = _mm_add_ps(dst6, tmp2); - dst7 = _mm_add_ps(dst7, tmp3); - dst8 = _mm_add_ps(dst8, tmp4); - srcb_d += C8NUM; - srca_d += C4NUM; - } - if (bias != NULL) { - __m128 bias1 = _mm_loadu_ps(bias); - __m128 bias2 = _mm_loadu_ps(bias + C4NUM); - dst1 = _mm_add_ps(dst1, bias1); - dst2 = _mm_add_ps(dst2, bias2); - dst3 = _mm_add_ps(dst3, bias1); - dst4 = _mm_add_ps(dst4, bias2); - dst5 = _mm_add_ps(dst5, bias1); - dst6 = _mm_add_ps(dst6, bias2); - dst7 = _mm_add_ps(dst7, bias1); - dst8 = _mm_add_ps(dst8, bias2); - } - if (act_type == 3) { - __m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0); - dst1 = _mm_min_ps(dst1, relu6); - dst2 = _mm_min_ps(dst2, relu6); - dst3 = _mm_min_ps(dst3, relu6); - dst4 = _mm_min_ps(dst4, relu6); - dst5 = _mm_min_ps(dst5, relu6); - dst6 = _mm_min_ps(dst6, relu6); - dst7 = _mm_min_ps(dst7, relu6); - dst8 = _mm_min_ps(dst8, relu6); - } - if (act_type == 1 || act_type == 3) { - __m128 zero = _mm_setzero_ps(); - dst1 = _mm_max_ps(dst1, zero); - dst2 = _mm_max_ps(dst2, zero); - dst3 = _mm_max_ps(dst3, zero); - dst4 = _mm_max_ps(dst4, zero); - dst5 = _mm_max_ps(dst5, zero); - dst6 = _mm_max_ps(dst6, zero); - dst7 = _mm_max_ps(dst7, zero); - dst8 = _mm_max_ps(dst8, zero); - } - if (WriteWino != 0) { // WriteWino - _mm_storeu_ps(dst, dst1); - _mm_storeu_ps(dst + 4, dst2); - dst += WriteWinoSteps; - _mm_storeu_ps(dst, dst3); - _mm_storeu_ps(dst + 4, dst4); - dst += WriteWinoSteps; - _mm_storeu_ps(dst, dst5); - _mm_storeu_ps(dst + 4, dst6); - dst += WriteWinoSteps; - _mm_storeu_ps(dst, dst7); - _mm_storeu_ps(dst + 4, dst8); - dst += WriteWinoSteps; - } else if (writeNhwc == 0) { // WriteC8 - _mm_storeu_ps(dst, 
dst1); - _mm_storeu_ps(dst + 4, dst2); - _mm_storeu_ps(dst + 8, dst3); - _mm_storeu_ps(dst + 12, dst4); - _mm_storeu_ps(dst + 16, dst5); - _mm_storeu_ps(dst + 20, dst6); - _mm_storeu_ps(dst + 24, dst7); - _mm_storeu_ps(dst + 28, dst8); - dst += 32; - c = dst; - } else { - switch (col) { - case 1: // write1 - _mm_store_ss(dst, dst1); - if (r > 1) { - dst += stride; - _mm_store_ss(dst, dst3); - } - if (r > 2) { - dst += stride; - _mm_store_ss(dst, dst5); - } - if (r > 3) { - dst += stride; - _mm_store_ss(dst, dst7); - dst += stride; - } - case 2: // write2 - _mm_store_ss(dst, dst1); - dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst, dst1); - if (r > 1) { - dst += stride; - _mm_store_ss(dst, dst3); - dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst, dst3); - } - if (r > 2) { - dst += stride; - _mm_store_ss(dst, dst5); - dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst, dst5); - } - if (r > 3) { - dst += stride; - _mm_store_ss(dst, dst7); - dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst, dst7); - } - case 3: // write3 - _mm_store_ss(dst, dst1); - dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst1); - dst1 = _mm_shuffle_ps(dst1 + 2, dst1, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst, dst1); - if (r > 1) { - dst += stride; - _mm_store_ss(dst, dst3); - dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst3); - dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 2, dst3); - } - if (r > 2) { - dst += stride; - _mm_store_ss(dst, dst5); - dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst5); - dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 2, dst5); - } - if (r > 3) { - dst += stride; - _mm_store_ss(dst, dst7); - dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 1, dst7); - dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 2, dst7); - dst += stride; - } - case 4: // write4 - _mm_storeu_ps(dst, dst1); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - dst += stride; - } - case 5: // // write5 - _mm_storeu_ps(dst, dst1); - _mm_store_ss(dst + 4, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_store_ss(dst + 4, dst4); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_store_ss(dst + 4, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_store_ss(dst + 4, dst8); - dst += stride; - } - case 6: // write6 - _mm_storeu_ps(dst, dst1); - _mm_store_ss(dst + 4, dst2); - dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_store_ss(dst + 4, dst4); - dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst4); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_store_ss(dst + 4, dst6); - dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_store_ss(dst + 4, dst8); - dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst8); - dst += stride; - } - case 7: // write7 - 
_mm_storeu_ps(dst, dst1); - _mm_store_ss(dst + 4, dst2); - dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst2); - dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_store_ss(dst + 4, dst4); - dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst4); - dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst4); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_store_ss(dst + 4, dst6); - dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst6); - dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_store_ss(dst + 4, dst8); - dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 5, dst8); - dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); - _mm_store_ss(dst + 6, dst8); - dst += stride; - } - default: // write8 - _mm_storeu_ps(dst, dst1); - _mm_storeu_ps(dst + 4, dst2); - if (r > 1) { - dst += stride; - _mm_storeu_ps(dst, dst3); - _mm_storeu_ps(dst + 4, dst4); - } - if (r > 2) { - dst += stride; - _mm_storeu_ps(dst, dst5); - _mm_storeu_ps(dst + 4, dst6); - } - if (r > 3) { - dst += stride; - _mm_storeu_ps(dst, dst7); - _mm_storeu_ps(dst + 4, dst8); - dst += stride; - } - } - } - if (r <= C4NUM) { // WriteEnd - break; - } - } - b += depth * C8NUM; - bias += (bias != NULL) ? C8NUM : 0; - if (WriteWino != 0) { - c += DstWinoSteps; - } else if (writeNhwc != 0) { - c += C8NUM; - } - if (col_tmp <= C8NUM) { - break; - } - } -} -#endif diff --git a/mindspore/lite/nnacl/x86_64_sse/TiledC4MatMulFp32.c b/mindspore/lite/nnacl/x86_64_sse/TiledC4MatMulFp32.c deleted file mode 100644 index 2db1768ce9..0000000000 --- a/mindspore/lite/nnacl/x86_64_sse/TiledC4MatMulFp32.c +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifdef ENABLE_SSE -#include -#include "nnacl/fp32/common_func_fp32.h" -void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { - const float *src_tmp = src; - for (int i = 0; i < oc4; ++i) { - float *dst_tmp = dst; - src = src_tmp; - size_t ic4_tmp = ic4 - 1; - __m128 src1 = _mm_loadu_ps(src); - __m128 src2 = _mm_loadu_ps(src + 4); - __m128 src3 = _mm_loadu_ps(src + 8); - __m128 src4 = _mm_loadu_ps(src + 12); - src += 16; - __m128 weight_data[4]; - weight_data[0] = _mm_loadu_ps(weight); - weight_data[1] = _mm_loadu_ps(weight + 4); - weight_data[2] = _mm_loadu_ps(weight + 8); - weight_data[3] = _mm_loadu_ps(weight + 12); - weight += 16; - __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); - __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); - __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); - __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); - for (int j = 1; j < 4; ++j) { - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[j], _mm_set_ps1(src1[j]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[j], _mm_set_ps1(src2[j]))); - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[j], _mm_set_ps1(src3[j]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[j], _mm_set_ps1(src4[j]))); - } - src1 = _mm_loadu_ps(src); - src2 = _mm_loadu_ps(src + 4); - src3 = _mm_loadu_ps(src + 8); - src4 = _mm_loadu_ps(src + 12); - src += 16; - __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); - __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); - __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); - __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); - for (int j = 1; j < 4; ++j) { - dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[j], _mm_set_ps1(src1[j]))); - dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[j], _mm_set_ps1(src2[j]))); - dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[j], _mm_set_ps1(src3[j]))); - dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[j], _mm_set_ps1(src4[j]))); - } - if (ic4_tmp != 0) { - ic4_tmp -= 1; - src1 = _mm_loadu_ps(src); - src2 = _mm_loadu_ps(src + 4); - src3 = _mm_loadu_ps(src + 8); - src4 = _mm_loadu_ps(src + 12); - src += 16; - weight_data[0] = _mm_loadu_ps(weight); - weight_data[1] = _mm_loadu_ps(weight + 4); - weight += 8; - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); - for (; ic4_tmp != 0; ic4_tmp -= 1) { - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[1], _mm_set_ps1(src1[1]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[1], _mm_set_ps1(src2[1]))); - weight_data[2] = _mm_loadu_ps(weight); - weight_data[3] = _mm_loadu_ps(weight + 4); - weight += 8; - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[1], _mm_set_ps1(src3[1]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[1], _mm_set_ps1(src4[1]))); - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[2], _mm_set_ps1(src1[2]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[2], _mm_set_ps1(src2[2]))); - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[2], _mm_set_ps1(src3[2]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[2], _mm_set_ps1(src4[2]))); - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], 
_mm_set_ps1(src2[3]))); - src1 = _mm_loadu_ps(src); - src2 = _mm_loadu_ps(src + 4); - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3]))); - src3 = _mm_loadu_ps(src + 8); - src4 = _mm_loadu_ps(src + 12); - src += 16; - - dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); - dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); - dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); - dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); - - dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[1], _mm_set_ps1(src1[1]))); - dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[1], _mm_set_ps1(src2[1]))); - dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[1], _mm_set_ps1(src3[1]))); - dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[1], _mm_set_ps1(src4[1]))); - - dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[2], _mm_set_ps1(src1[2]))); - dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[2], _mm_set_ps1(src2[2]))); - dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[2], _mm_set_ps1(src3[2]))); - weight_data[0] = _mm_loadu_ps(weight); - weight_data[1] = _mm_loadu_ps(weight + 4); - weight += 8; - dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[2], _mm_set_ps1(src4[2]))); - - dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); - dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3]))); - dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); - src1 = _mm_loadu_ps(src); - src2 = _mm_loadu_ps(src + 4); - dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3]))); - src3 = _mm_loadu_ps(src + 8); - src4 = _mm_loadu_ps(src + 12); - src += 16; - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); - } - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[1], _mm_set_ps1(src1[1]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[1], _mm_set_ps1(src2[1]))); - weight_data[2] = _mm_loadu_ps(weight); - weight_data[3] = _mm_loadu_ps(weight + 4); - weight += 8; - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[1], _mm_set_ps1(src3[1]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[1], _mm_set_ps1(src4[1]))); - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[2], _mm_set_ps1(src1[2]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[2], _mm_set_ps1(src2[2]))); - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[2], _mm_set_ps1(src3[2]))); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[2], _mm_set_ps1(src4[2]))); - - dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); - dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3]))); - dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); - src1 = _mm_loadu_ps(src); - src2 = _mm_loadu_ps(src + 4); - dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3]))); - src3 = _mm_loadu_ps(src + 8); - src4 = _mm_loadu_ps(src + 12); - src += 16; - for (int j = 0; j < 4; ++j) { - dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[j], _mm_set_ps1(src1[j]))); - dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[j], _mm_set_ps1(src2[j]))); - dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[j], _mm_set_ps1(src3[j]))); - dst8 = 
_mm_add_ps(dst8, _mm_mul_ps(weight_data[j], _mm_set_ps1(src4[j]))); - } - } - _mm_storeu_ps(dst, dst1); - _mm_storeu_ps(dst + 4, dst2); - _mm_storeu_ps(dst + 8, dst3); - _mm_storeu_ps(dst + 12, dst4); - _mm_storeu_ps(dst + 16, dst5); - _mm_storeu_ps(dst + 20, dst6); - _mm_storeu_ps(dst + 24, dst7); - _mm_storeu_ps(dst + 28, dst8); - dst = dst_tmp + cal_num; - } -} -#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc index a43f0145b4..19efba6976 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc @@ -30,7 +30,6 @@ namespace mindspore::kernel { namespace { constexpr int kMaxInputNum = 2; constexpr int kOutputNum = 1; -constexpr int kRank = 4; } // namespace int ResizeBaseCPUKernel::CheckParameters() { @@ -113,7 +112,7 @@ int ResizeBaseCPUKernel::Init() { auto input = in_tensors_.at(0); auto input_shape = input->shape(); - if (!input_shape.empty() && input_shape.size() != kRank) { + if (!input_shape.empty() && input_shape.size() != COMM_SHAPE_SIZE) { MS_LOG(ERROR) << "Resize op support input rank 4, got " << input_shape.size(); return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h index c1a4b3cfc4..3fe11c848d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h @@ -24,7 +24,7 @@ #include "nnacl/fp16/conv_fp16.h" #include "nnacl/fp16/winograd_utils_fp16.h" #include "src/common/utils.h" -#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/base/minimal_filtering_generator.h" namespace mindspore::kernel { class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc index bc9c87a4ff..dfc6834aeb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc @@ -49,7 +49,7 @@ int TransposeFp16CPUKernel::Run() { for (int i = 0; i < input_perm->ElementsNum(); ++i) { param->perm_[i] = perm_data[i]; } - for (int i = input_perm->ElementsNum(); i < 8; ++i) { + for (int i = input_perm->ElementsNum(); i < MAX_SHAPE_SIZE; ++i) { param->perm_[i] = 0; } param->num_axes_ = input_perm->ElementsNum(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc index b61e6c5acd..d9c823e57c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc @@ -54,7 +54,7 @@ int BatchToSpaceCPUKernel::Init() { } int BatchToSpaceCPUKernel::ReSize() { - MS_ASSERT(in_tensors_.at(0)->shape().size() == 4); + MS_ASSERT(in_tensors_.at(0)->shape().size() == COMM_SHAPE_SIZE); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h index f2cb63b4e5..9efff6cda2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h @@ -19,8 +19,8 @@ #include #include "src/lite_kernel.h" -#include 
"nnacl/winograd_transform.h" -#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/fp32/winograd_transform.h" +#include "nnacl/base/minimal_filtering_generator.h" #include "nnacl/fp32/conv_winograd_fp32.h" #include "src/runtime/kernel/arm/base/convolution_base.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc index 16b133989b..e35a3864bf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc @@ -28,7 +28,6 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Reverse; namespace mindspore::kernel { - int ReverseCPUKernel::Stride(int index) { int stride = 1; for (size_t i = index + 1; i < in_tensors_.at(0)->shape().size(); ++i) { @@ -39,7 +38,7 @@ int ReverseCPUKernel::Stride(int index) { int ReverseCPUKernel::ReSize() { data_size_ = in_tensors_.at(0)->ElementsNum(); - thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_count_ = MSMIN(op_parameter_->thread_num_, data_size_); thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); auto *param = reinterpret_cast(op_parameter_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.h index 8205a8a8b4..c2fd59828e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.h @@ -18,11 +18,8 @@ #include #include "src/lite_kernel.h" - #include "include/context.h" -#define REVERSE_STRIDE_MAX_SIZE 4 - using mindspore::lite::InnerContext; namespace mindspore::kernel { @@ -31,7 +28,7 @@ class ReverseCPUKernel : public LiteKernel { ReverseCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~ReverseCPUKernel() { if (tmp_ != nullptr) { free(tmp_); @@ -49,10 +46,9 @@ class ReverseCPUKernel : public LiteKernel { int thread_sz_count_ = 0; int thread_sz_stride_ = 0; int data_size_ = 0; - int strides_[REVERSE_STRIDE_MAX_SIZE] = {0}; - int inCount_[REVERSE_STRIDE_MAX_SIZE] = {0}; - int outCount_[REVERSE_STRIDE_MAX_SIZE] = {0}; - int thread_count_ = 1; + int strides_[COMM_SHAPE_SIZE] = {0}; + int inCount_[COMM_SHAPE_SIZE] = {0}; + int outCount_[COMM_SHAPE_SIZE] = {0}; int *tmp_ = nullptr; float *in_ptr_ = nullptr; float *out_ptr_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc index f5edc39d82..c5a2bc76ea 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc @@ -128,7 +128,7 @@ int TransposeCPUKernel::Run() { for (int i = 0; i < input_perm->ElementsNum(); ++i) { param->perm_[i] = perm_data[i]; } - for (int i = input_perm->ElementsNum(); i < 8; ++i) { + for (int i = input_perm->ElementsNum(); i < MAX_SHAPE_SIZE; ++i) { param->perm_[i] = 0; } } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h index bcdc4b8044..da04ba54a7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h +++ 
b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h
@@ -19,8 +19,7 @@
 #include 
 #include "src/lite_kernel.h"
 
-
-#include "nnacl/winograd_transform.h"
+#include "nnacl/fp32/winograd_transform.h"
 #include "src/runtime/kernel/arm/base/convolution_base.h"
 
 namespace mindspore::kernel {
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc
index d2df7b3753..789a9c50ba 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc
@@ -77,13 +77,13 @@ void MulInt8CPUKernel::CheckIfFastImpl() {
   auto in_tensor0 = in_tensors_.at(0);
   auto in_tensor1 = in_tensors_.at(1);
   if (in_tensor0->ElementsNum() != in_tensor1->ElementsNum()) {
-    if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 4) {
+    if (in_tensor0->shape().size() == COMM_SHAPE_SIZE && in_tensor1->shape().size() == COMM_SHAPE_SIZE) {
       CheckSameShapeSize(in_tensor0->shape(), in_tensor1->shape());
-    } else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == 4) {
+    } else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == COMM_SHAPE_SIZE) {
       if (in_tensor0->ElementsNum() == in_tensor1->shape()[3]) {
         fast_hw_broadcast_ = true;
       }
-    } else if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 1) {
+    } else if (in_tensor0->shape().size() == COMM_SHAPE_SIZE && in_tensor1->shape().size() == 1) {
       if (in_tensor1->ElementsNum() == in_tensor0->shape()[3]) {
         fast_hw_broadcast_ = true;
         input1_hw_broadcast_ = true;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc
index eee84c75b7..0d098d6ada 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc
@@ -43,7 +43,7 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst,
                                   size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                                   int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
                                   size_t per_channel) {
-  return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
+  return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, C8NUM), UP_ROUND(col, C8NUM), deep_4, input_sum, bias, mini, maxi,
                             output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel);
 }
 void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 74a5e2265b..1ab737ccf7 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -67,7 +67,7 @@ if(PLATFORM_ARM32)
 endif()
 
 if("${X86_64_SIMD}" STREQUAL "sse")
-    file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c)
+    file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/intrinsics/sse/*.c)
     set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
     set(KERNEL_OP_SRC
         ${KERNEL_OP_SRC}
@@ -78,8 +78,8 @@ endif()
 
 if("${X86_64_SIMD}" STREQUAL "avx")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2")
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2")
-    file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c
-        ${LITE_DIR}/nnacl/x86_64_avx/*.c
+    file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/intrinsics/sse/*.c
+        ${LITE_DIR}/nnacl/intrinsics/avx/*.c
         ${LITE_DIR}/nnacl/assembly/avx/*.S)
     set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
     set(KERNEL_OP_SRC
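
The write1..write7 cases in the deleted SSE matmul code above store a leftover column count (fewer than eight outputs per row) by writing the low lane with `_mm_store_ss` and rotating the register with `_mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 3, 2, 1))` before each further store. A minimal sketch of that pattern; the helper name `StoreLanes` is illustrative and not part of nnacl:

```c
/* Sketch (not the original nnacl code): store the first n (0..4) lanes of an
 * SSE register to dst using the store_ss + lane-rotate pattern from the
 * deleted write1..write7 cases. */
#include <stdio.h>
#include <xmmintrin.h>

static void StoreLanes(float *dst, __m128 v, int n) {
  for (int i = 0; i < n; ++i) {
    _mm_store_ss(dst + i, v);                          /* write lowest lane  */
    v = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 3, 2, 1)); /* rotate lanes down  */
  }
}

int main(void) {
  float out[4] = {0};
  __m128 v = _mm_setr_ps(1.0f, 2.0f, 3.0f, 4.0f);
  StoreLanes(out, v, 3); /* writes 1, 2, 3 and leaves out[3] untouched */
  printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}
```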
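The core of the deleted MatmulFloatSse64 is a broadcast-multiply-accumulate tile: per depth step it loads one packed row of B, broadcasts each packed A value with `_mm_load_ps1`, accumulates into per-row registers, then adds bias once per tile and clamps with `_mm_min_ps`/`_mm_max_ps` for act_type 3 (ReLU6) and act_type 1 (ReLU). Below is a simplified 4-row by 4-column sketch under those assumptions; the real kernel keeps a 4x8 tile and several write-out modes, and the names here are illustrative:

```c
/* Simplified sketch of the broadcast-multiply-accumulate pattern from the
 * deleted MatmulFloatSse64: one 4-row x 4-column output tile. A is packed as
 * 4 row values per depth step, B as 4 column values per depth step. */
#include <stddef.h>
#include <xmmintrin.h>

static void MicroKernel4x4(const float *a, const float *b, float *c,
                           const float *bias, int act_type, size_t depth) {
  __m128 acc0 = _mm_setzero_ps();
  __m128 acc1 = _mm_setzero_ps();
  __m128 acc2 = _mm_setzero_ps();
  __m128 acc3 = _mm_setzero_ps();
  for (size_t d = 0; d < depth; ++d) {
    __m128 bv = _mm_loadu_ps(b + 4 * d);                 /* 4 output columns */
    acc0 = _mm_add_ps(acc0, _mm_mul_ps(bv, _mm_load_ps1(a + 4 * d + 0)));
    acc1 = _mm_add_ps(acc1, _mm_mul_ps(bv, _mm_load_ps1(a + 4 * d + 1)));
    acc2 = _mm_add_ps(acc2, _mm_mul_ps(bv, _mm_load_ps1(a + 4 * d + 2)));
    acc3 = _mm_add_ps(acc3, _mm_mul_ps(bv, _mm_load_ps1(a + 4 * d + 3)));
  }
  if (bias != NULL) {                                    /* one bias add per tile */
    __m128 bb = _mm_loadu_ps(bias);
    acc0 = _mm_add_ps(acc0, bb);
    acc1 = _mm_add_ps(acc1, bb);
    acc2 = _mm_add_ps(acc2, bb);
    acc3 = _mm_add_ps(acc3, bb);
  }
  if (act_type == 3) {                                   /* ReLU6: clamp above */
    __m128 six = _mm_set_ps1(6.0f);
    acc0 = _mm_min_ps(acc0, six);
    acc1 = _mm_min_ps(acc1, six);
    acc2 = _mm_min_ps(acc2, six);
    acc3 = _mm_min_ps(acc3, six);
  }
  if (act_type == 1 || act_type == 3) {                  /* ReLU: clamp below */
    __m128 zero = _mm_setzero_ps();
    acc0 = _mm_max_ps(acc0, zero);
    acc1 = _mm_max_ps(acc1, zero);
    acc2 = _mm_max_ps(acc2, zero);
    acc3 = _mm_max_ps(acc3, zero);
  }
  _mm_storeu_ps(c + 0, acc0);
  _mm_storeu_ps(c + 4, acc1);
  _mm_storeu_ps(c + 8, acc2);
  _mm_storeu_ps(c + 12, acc3);
}
```

The separate `_mm_mul_ps`/`_mm_add_ps` pairs reflect that plain SSE has no fused multiply-add; the AVX path kept alongside these files can use wider registers and FMA where available.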
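The deleted TiledC4MatmulFp32 multiplies C4-packed tiles by 4x4 weight blocks: for each block of four output channels, the four packed input channels of every tile are broadcast, multiplied by the matching weight row, and summed over the ic4 blocks; the SSE version unrolls eight tiles per call (dst1..dst8) and advances dst by cal_num floats per output block. The plain-C reference below restates that computation; the C4 layouts (src as [ic4][tile][4], weight as [oc4][ic4][4 in][4 out]) are inferred from the SSE code and should be treated as an assumption, not the authoritative nnacl contract:

```c
/* Scalar reference sketch of the deleted TiledC4MatmulFp32 (layouts assumed). */
#include <stddef.h>

void TiledC4MatmulRef(float *dst, const float *src, const float *weight,
                      size_t cal_num, size_t ic4, size_t oc4) {
  for (size_t oc = 0; oc < oc4; ++oc) {
    const float *w_oc = weight + oc * ic4 * 16;  /* ic4 blocks of 4x4 weights */
    float *dst_oc = dst + oc * cal_num;          /* next C4 output block      */
    for (int t = 0; t < 8; ++t) {                /* eight tiles per call      */
      for (int k = 0; k < 4; ++k) {              /* output channel in block   */
        float acc = 0.0f;
        for (size_t ic = 0; ic < ic4; ++ic) {
          const float *s = src + (ic * 8 + t) * 4;   /* src[ic4][tile][4]     */
          const float *wb = w_oc + ic * 16;          /* 4 (in) x 4 (out)      */
          for (int j = 0; j < 4; ++j) {              /* input channel         */
            acc += wb[j * 4 + k] * s[j];
          }
        }
        dst_oc[t * 4 + k] = acc;
      }
    }
  }
}
```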
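Several hunks above replace the literal tensor rank 4 and the literal perm length 8 with the shared nnacl constants COMM_SHAPE_SIZE and MAX_SHAPE_SIZE (resize_base.cc, batch_to_space_fp32.cc, mul_int8.cc, reverse_fp32.h, transpose_fp32.cc/fp16.cc). The constant values and the stand-in struct below are assumptions made only so the snippet is self-contained; the real definitions live in the nnacl headers:

```c
/* Assumed values: 4 = rank of a standard NHWC tensor, 8 = max supported rank. */
#include <stdio.h>

#define COMM_SHAPE_SIZE 4
#define MAX_SHAPE_SIZE 8

typedef struct {
  int perm_[MAX_SHAPE_SIZE];
  int num_axes_;
} TransposeParamSketch;  /* illustrative stand-in, not the real nnacl struct */

static void FillPerm(TransposeParamSketch *param, const int *perm, int n) {
  for (int i = 0; i < n; ++i) {
    param->perm_[i] = perm[i];
  }
  /* Same padding as transpose_fp32.cc/transpose_fp16.cc after the patch:
   * unused axes are zeroed up to MAX_SHAPE_SIZE instead of a literal 8. */
  for (int i = n; i < MAX_SHAPE_SIZE; ++i) {
    param->perm_[i] = 0;
  }
  param->num_axes_ = n;
}

int main(void) {
  int perm[COMM_SHAPE_SIZE] = {0, 3, 1, 2};  /* NHWC -> NCHW */
  TransposeParamSketch p;
  FillPerm(&p, perm, COMM_SHAPE_SIZE);
  for (int i = 0; i < MAX_SHAPE_SIZE; ++i) {
    printf("%d ", p.perm_[i]);
  }
  printf("\n");
  return 0;
}
```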
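Likewise, opt_op_handler.cc now pads row and col with UP_ROUND(x, C8NUM) instead of a literal 8, and ReverseCPUKernel::ReSize splits the element count over at most op_parameter_->thread_num_ workers with MSMIN and UP_DIV rather than a kernel-local thread_count_. The macro definitions below are the typical nnacl forms, restated here as assumptions so the example compiles on its own:

```c
/* Assumed macro definitions (typical nnacl/op_base.h forms; may differ). */
#include <stdio.h>

#define C8NUM 8
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))   /* ceiling division        */
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))    /* round up to a multiple  */
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  /* MatMulRInt8_optimize_handler pads both matmul dimensions to multiples of C8NUM. */
  int row = 13, col = 21;
  printf("row %d -> %d, col %d -> %d\n",
         row, UP_ROUND(row, C8NUM), col, UP_ROUND(col, C8NUM));

  /* ReverseCPUKernel::ReSize: partition data_size_ across at most thread_num workers. */
  int thread_num = 4, data_size = 10;
  int thread_sz_count = MSMIN(thread_num, data_size);
  int thread_sz_stride = UP_DIV(data_size, thread_sz_count);
  printf("workers %d, elements per worker %d\n", thread_sz_count, thread_sz_stride);
  return 0;
}
```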