From: @ling_qiao_min Reviewed-by: Signed-off-by: tags/v1.2.0-rc1
| @@ -17,7 +17,7 @@ | |||
| #include "micro/coder/opcoders/base/conv2d_base_coder.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "nnacl/fp32/winograd_utils.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "micro/coder/log.h" | |||
| @@ -5,8 +5,10 @@ include_directories(NNACL_DIR) | |||
| if(PLATFORM_ARM32 OR PLATFORM_ARM64) | |||
| if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections -fdata-sections -ffast-math") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections -fdata-sections -ffast-math") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing \ | |||
| -ffunction-sections -fdata-sections -ffast-math") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer -fstrict-aliasing \ | |||
| -ffunction-sections -fdata-sections -ffast-math") | |||
| endif() | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "avx") | |||
| @@ -37,14 +39,14 @@ if(PLATFORM_ARM32) | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "sse") | |||
| file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c) | |||
| file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/intrinsics/sse/*.c) | |||
| set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "avx") | |||
| file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c | |||
| ${NNACL_DIR}/x86_64_avx/*.c | |||
| ${NNACL_DIR}/assembly/avx/*.S) | |||
| file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/intrinsics/sse/*.c | |||
| ${NNACL_DIR}/intrinsics/avx/*.c | |||
| ${NNACL_DIR}/assembly/avx/*.S) | |||
| set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| endif() | |||
| @@ -13,10 +13,10 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| #include <string.h> | |||
| #include <math.h> | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "nnacl/fp32/winograd_utils.h" | |||
| #include "nnacl/errorcode.h" | |||
| void Polynomial(const float *interval, float *m, int degree) { | |||
| @@ -72,6 +72,7 @@ typedef struct SlidingWindowParam { | |||
| int kernel_step_; | |||
| } SlidingWindowParam; | |||
| #define OUPUT_UNIT 2 | |||
| #define DECONV_WINOGRAD_DEFAULT_UNIT 3 | |||
| #define DECONV_WINOGRAD_DEFAULT_TILE 8 | |||
| #define DECONV_WINOGRAD_BUFFER_COUNT 8 | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "nnacl/fp16/deconv_winograd_fp16.h" | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| void DeConvWgInputPackFp16(float16_t *src_ptr, float16_t *dst_ptr, int channel, int stride) { | |||
| int ic4div = channel / C4NUM; | |||
| @@ -111,16 +111,16 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, | |||
| } | |||
| void DeConvWgCalWgFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weight_buf, float16_t *tmp_buf, | |||
| float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transfered, | |||
| float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transferred, | |||
| float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, | |||
| ConvParameter *conv_param, DeConvParam *deconv_param) { | |||
| int winograd_plane = unit_size * unit_size; | |||
| if (!transfered[unit_size]) { | |||
| if (!transferred[unit_size]) { | |||
| WinogradTransLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, | |||
| DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | |||
| WinogradTransRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, | |||
| deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | |||
| transfered[unit_size] = true; | |||
| transferred[unit_size] = true; | |||
| } | |||
| for (int index = 0; index < winograd_plane; index++) { | |||
| @@ -311,7 +311,7 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou | |||
| } | |||
| /* compute */ | |||
| bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; | |||
| bool transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; | |||
| for (int i = 0; i < deconv_param->compute_size_; i++) { | |||
| DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; | |||
| if (unit->use_winograd_) { | |||
| @@ -328,7 +328,7 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou | |||
| DECONV_WINOGRAD_DEFAULT_TILE * | |||
| deconv_param->oc_up4_; | |||
| DeConvWgCalWgFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, | |||
| transfered, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, | |||
| transferred, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, | |||
| conv_param, deconv_param); | |||
| } else { | |||
| float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * | |||
| @@ -17,7 +17,7 @@ | |||
| #include "nnacl/fp32/conv_depthwise_fp32.h" | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| #include "nnacl/winograd_transform.h" | |||
| #include "nnacl/fp32/winograd_transform.h" | |||
| #ifdef ENABLE_ARM64 | |||
| #include <arm_neon.h> | |||
| #endif | |||
| @@ -17,7 +17,7 @@ | |||
| #include "nnacl/fp32/conv_winograd_fp32.h" | |||
| #include <string.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| #include "nnacl/winograd_transform.h" | |||
| #include "nnacl/fp32/winograd_transform.h" | |||
| #include "nnacl/fp32/matmul_fp32.h" | |||
| // fp32 conv winograd | |||
| @@ -24,7 +24,7 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "nnacl/fp32/winograd_utils.h" | |||
| #include "nnacl/fp32/conv_depthwise_fp32.h" | |||
| typedef float *TmpBufferAddress; | |||
| @@ -22,7 +22,7 @@ | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -22,7 +22,7 @@ | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/winograd_transform.h" | |||
| #include "nnacl/fp32/winograd_transform.h" | |||
| #include "nnacl/op_base.h" | |||
| // fp32 conv winograd | |||
| @@ -22,10 +22,7 @@ | |||
| #endif | |||
| #include <string.h> | |||
| #include "nnacl/pack.h" | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "mindspore/lite/nnacl/int8/fixed_point.h" | |||
| #define OUPUT_UNIT 2 | |||
| #include "nnacl/fp32/winograd_utils.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/fp32/winograd_utils.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| #define MIN_UNIT 2 | |||
| #define MAX_UNIT 8 | |||
| @@ -303,65 +303,66 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, | |||
| #endif | |||
| } | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) { | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[64]; | |||
| MS_FLOAT32X4 m[64]; | |||
| Load64Data; | |||
| for (int l = 0; l < 8; ++l) { | |||
| int offset = l * 8; | |||
| t[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(src[offset], 0.5625), MS_MULQ_F32(src[2 + offset], 3.0625)), | |||
| MS_MULQ_F32(src[4 + offset], 3.5)), | |||
| src[6 + offset]); | |||
| MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 1.125), MS_MULQ_F32(src[5 + offset], 0.5)); | |||
| MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 2.25), MS_MULQ_F32(src[4 + offset], 3.25)); | |||
| t[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); | |||
| t[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.5625), src[5 + offset]); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.5625), MS_MULQ_F32(src[4 + offset], 2.5)); | |||
| t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); | |||
| t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.375), MS_MULQ_F32(src[5 + offset], 1.5)); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.25), MS_MULQ_F32(src[4 + offset], 1.25)); | |||
| t[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); | |||
| t[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); | |||
| t[56 + l] = | |||
| MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], -0.5625), MS_MULQ_F32(src[3 + offset], 3.0625)), | |||
| MS_MULQ_F32(src[5 + offset], 3.5)), | |||
| src[7 + offset]); | |||
| } | |||
| for (int l = 0; l < 8; ++l) { | |||
| int offset = l * 8; | |||
| m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(t[offset], 0.5625), MS_MULQ_F32(t[2 + offset], 3.0625)), | |||
| MS_MULQ_F32(t[4 + offset], 3.5)), | |||
| t[6 + offset]); | |||
| MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 1.125), MS_MULQ_F32(t[5 + offset], 0.5)); | |||
| MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 2.25), MS_MULQ_F32(t[4 + offset], 3.25)); | |||
| m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); | |||
| m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.5625), t[5 + offset]); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.5625), MS_MULQ_F32(t[4 + offset], 2.5)); | |||
| m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); | |||
| m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.375), MS_MULQ_F32(t[5 + offset], 1.5)); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.25), MS_MULQ_F32(t[4 + offset], 1.25)); | |||
| m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); | |||
| m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); | |||
| m[56 + l] = | |||
| MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], -0.5625), MS_MULQ_F32(t[3 + offset], 3.0625)), | |||
| MS_MULQ_F32(t[5 + offset], 3.5)), | |||
| t[7 + offset]); | |||
| } | |||
| for (int i = 0; i < 64; i++) { | |||
| MS_STQ_F32(dst_data + i * dst_step, m[i]); | |||
| } | |||
| } | |||
| #endif | |||
| void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| if (real_c == 4) { | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[64]; | |||
| MS_FLOAT32X4 m[64]; | |||
| Load64Data; | |||
| for (int l = 0; l < 8; ++l) { | |||
| int offset = l * 8; | |||
| t[l] = | |||
| MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(src[offset], 0.5625), MS_MULQ_F32(src[2 + offset], 3.0625)), | |||
| MS_MULQ_F32(src[4 + offset], 3.5)), | |||
| src[6 + offset]); | |||
| MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 1.125), MS_MULQ_F32(src[5 + offset], 0.5)); | |||
| MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 2.25), MS_MULQ_F32(src[4 + offset], 3.25)); | |||
| t[8 + l] = | |||
| MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); | |||
| t[16 + l] = | |||
| MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.5625), src[5 + offset]); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.5625), MS_MULQ_F32(src[4 + offset], 2.5)); | |||
| t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); | |||
| t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.375), MS_MULQ_F32(src[5 + offset], 1.5)); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.25), MS_MULQ_F32(src[4 + offset], 1.25)); | |||
| t[40 + l] = | |||
| MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); | |||
| t[48 + l] = | |||
| MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); | |||
| t[56 + l] = MS_ADDQ_F32( | |||
| MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], -0.5625), MS_MULQ_F32(src[3 + offset], 3.0625)), | |||
| MS_MULQ_F32(src[5 + offset], 3.5)), | |||
| src[7 + offset]); | |||
| } | |||
| for (int l = 0; l < 8; ++l) { | |||
| int offset = l * 8; | |||
| m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(t[offset], 0.5625), MS_MULQ_F32(t[2 + offset], 3.0625)), | |||
| MS_MULQ_F32(t[4 + offset], 3.5)), | |||
| t[6 + offset]); | |||
| MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 1.125), MS_MULQ_F32(t[5 + offset], 0.5)); | |||
| MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 2.25), MS_MULQ_F32(t[4 + offset], 3.25)); | |||
| m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); | |||
| m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.5625), t[5 + offset]); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.5625), MS_MULQ_F32(t[4 + offset], 2.5)); | |||
| m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); | |||
| m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); | |||
| tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.375), MS_MULQ_F32(t[5 + offset], 1.5)); | |||
| tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.25), MS_MULQ_F32(t[4 + offset], 1.25)); | |||
| m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); | |||
| m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); | |||
| m[56 + l] = | |||
| MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], -0.5625), MS_MULQ_F32(t[3 + offset], 3.0625)), | |||
| MS_MULQ_F32(t[5 + offset], 3.5)), | |||
| t[7 + offset]); | |||
| } | |||
| for (int i = 0; i < 64; i++) { | |||
| MS_STQ_F32(dst_data + i * dst_step, m[i]); | |||
| } | |||
| InputTransform8x8Unit_block4(src_data, dst_data, src_step, dst_step); | |||
| } else { | |||
| #endif | |||
| float src[64]; | |||
| @@ -2778,9 +2779,9 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float | |||
| #endif | |||
| } | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[40]; | |||
| MS_FLOAT32X4 m[25]; | |||
| @@ -2837,7 +2838,10 @@ void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const fl | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[40]; | |||
| float m[25]; | |||
| @@ -2882,12 +2886,12 @@ void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const fl | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[40]; | |||
| MS_FLOAT32X4 m[25]; | |||
| @@ -2950,7 +2954,10 @@ void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const f | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[40]; | |||
| float m[25]; | |||
| @@ -2996,12 +3003,12 @@ void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const f | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, | |||
| int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[48]; | |||
| MS_FLOAT32X4 m[36]; | |||
| @@ -3065,7 +3072,10 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, | |||
| int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[48]; | |||
| float m[36]; | |||
| @@ -3112,12 +3122,12 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[48]; | |||
| MS_FLOAT32X4 m[36]; | |||
| @@ -3188,7 +3198,10 @@ void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const fl | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[48]; | |||
| float m[36]; | |||
| @@ -3237,12 +3250,12 @@ void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const fl | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[48]; | |||
| MS_FLOAT32X4 m[36]; | |||
| @@ -3320,7 +3333,10 @@ void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const f | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[48]; | |||
| float m[36]; | |||
| @@ -3370,12 +3386,12 @@ void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const f | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, | |||
| int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[56]; | |||
| MS_FLOAT32X4 m[49]; | |||
| @@ -3443,7 +3459,10 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, | |||
| int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[56]; | |||
| float m[49]; | |||
| @@ -3494,12 +3513,12 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[56]; | |||
| MS_FLOAT32X4 m[49]; | |||
| @@ -3575,7 +3594,10 @@ void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const fl | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[56]; | |||
| float m[49]; | |||
| @@ -3628,12 +3650,12 @@ void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const fl | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[64]; | |||
| MS_FLOAT32X4 t[56]; | |||
| MS_FLOAT32X4 m[49]; | |||
| @@ -3717,7 +3739,10 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #else | |||
| void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, | |||
| int dst_step, int out_c, int r_w, int r_h, int r_c) { | |||
| float src[64]; | |||
| float t[56]; | |||
| float m[49]; | |||
| @@ -3771,8 +3796,8 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| // Reference to the paper "Fast Algorithms for Convolutional Neural Networks" | |||
| // Utilize cost model to compute performance gain. | |||
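For context on the cost model referenced in the comment above: the cited Lavin & Gray paper computes an m x m output tile of an r x r convolution with (m + r - 1)^2 multiplications instead of m^2 * r^2. Below is a minimal sketch of that arithmetic-reduction estimate only; the actual unit-selection heuristic in winograd_utils.c is not shown in this diff, and the function name here is illustrative.

    // Rough multiplication-count gain of Winograd F(m x m, r x r) over direct
    // convolution, per output tile (transform overhead ignored).
    static inline float WinogradMulGain(int m, int r) {
      int winograd_muls = (m + r - 1) * (m + r - 1);  // one multiply per transformed element
      int direct_muls = m * m * r * r;                // direct convolution multiplies
      return (float)direct_muls / (float)winograd_muls;  // e.g. F(2x2, 3x3): 36 / 16 = 2.25
    }

For the 8x8 transforms above (m = 6, r = 3) the ratio is 324 / 64, roughly 5x, which is why a cost model has to weigh it against the extra input/output transform work.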
| @@ -20,10 +20,64 @@ | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| #include <x86intrin.h> | |||
| #include "nnacl/x86_64_avx/common_utils.h" | |||
| #include "nnacl/intrinsics/avx/common_utils.h" | |||
| #endif | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef ENABLE_ARM | |||
| void AddInt8InputRounding(int32x4_t *in1, int32x4_t *in2, int32x4_t *in3, int32x4_t *in4, const int32x4_t left_vec, | |||
| const int32x4_t right_vec, const int32_t multiplier) { | |||
| // Apply left shift | |||
| *in1 = vmulq_s32(*in1, left_vec); | |||
| *in2 = vmulq_s32(*in2, left_vec); | |||
| *in3 = vmulq_s32(*in3, left_vec); | |||
| *in4 = vmulq_s32(*in4, left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| *in1 = vqrdmulhq_n_s32(*in1, multiplier); | |||
| *in2 = vqrdmulhq_n_s32(*in2, multiplier); | |||
| *in3 = vqrdmulhq_n_s32(*in3, multiplier); | |||
| *in4 = vqrdmulhq_n_s32(*in4, multiplier); | |||
| // Apply right shift | |||
| *in1 = vqaddq_s32(*in1, vshrq_n_s32(vandq_s32(*in1, right_vec), 31)); | |||
| *in2 = vqaddq_s32(*in2, vshrq_n_s32(vandq_s32(*in2, right_vec), 31)); | |||
| *in3 = vqaddq_s32(*in3, vshrq_n_s32(vandq_s32(*in3, right_vec), 31)); | |||
| *in4 = vqaddq_s32(*in4, vshrq_n_s32(vandq_s32(*in4, right_vec), 31)); | |||
| *in1 = vrshlq_s32(*in1, right_vec); | |||
| *in2 = vrshlq_s32(*in2, right_vec); | |||
| *in3 = vrshlq_s32(*in3, right_vec); | |||
| *in4 = vrshlq_s32(*in4, right_vec); | |||
| } | |||
| void AddInt8OutputRounding(int32x4_t *out1, int32x4_t *out2, int32x4_t *out3, int32x4_t *out4, const int32x4_t left_vec, | |||
| const int32x4_t right_vec, const int32_t multiplier) { | |||
| // Apply left shift | |||
| *out1 = vshlq_s32(*out1, left_vec); | |||
| *out2 = vshlq_s32(*out2, left_vec); | |||
| *out3 = vshlq_s32(*out3, left_vec); | |||
| *out4 = vshlq_s32(*out4, left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| *out1 = vqrdmulhq_n_s32(*out1, multiplier); | |||
| *out2 = vqrdmulhq_n_s32(*out2, multiplier); | |||
| *out3 = vqrdmulhq_n_s32(*out3, multiplier); | |||
| *out4 = vqrdmulhq_n_s32(*out4, multiplier); | |||
| // Apply right shift | |||
| *out1 = vqaddq_s32(*out1, vshrq_n_s32(vandq_s32(*out1, right_vec), 31)); | |||
| *out2 = vqaddq_s32(*out2, vshrq_n_s32(vandq_s32(*out2, right_vec), 31)); | |||
| *out3 = vqaddq_s32(*out3, vshrq_n_s32(vandq_s32(*out3, right_vec), 31)); | |||
| *out4 = vqaddq_s32(*out4, vshrq_n_s32(vandq_s32(*out4, right_vec), 31)); | |||
| *out1 = vrshlq_s32(*out1, right_vec); | |||
| *out2 = vrshlq_s32(*out2, right_vec); | |||
| *out3 = vrshlq_s32(*out3, right_vec); | |||
| *out4 = vrshlq_s32(*out4, right_vec); | |||
| } | |||
| #endif | |||
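For readers unfamiliar with the intrinsic sequence in these new helpers: it is the gemmlowp-style requantization, a saturating rounding doubling high multiply followed by a rounding divide by a power of two. A scalar sketch of one lane follows, assuming left_vec holds the power-of-two left-shift factor and right_vec holds the negated right shift (so vrshlq_s32 performs a rounding right shift); names are illustrative.

    #include <stdint.h>

    /* (2*a*b + 2^31) >> 32 with saturation; only a == b == INT32_MIN overflows. */
    static int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {
      if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;
      int64_t ab = (int64_t)a * (int64_t)b;
      int64_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
      return (int32_t)((ab + nudge) / (1ll << 31));
    }

    /* Divide by 2^exponent, rounding to nearest with ties away from zero. */
    static int32_t RoundDivByPOT(int32_t x, int exponent) {
      int32_t mask = ((int32_t)1 << exponent) - 1;
      int32_t remainder = x & mask;
      int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
      return (x >> exponent) + (remainder > threshold ? 1 : 0);
    }

    /* Per-lane equivalent of AddInt8InputRounding / AddInt8OutputRounding. */
    static int32_t RequantizeLane(int32_t v, int32_t left_mul, int32_t multiplier, int right_shift) {
      v *= left_mul;                               /* vmulq_s32 / vshlq_s32 step */
      v = SatRoundDoublingHighMul(v, multiplier);  /* vqrdmulhq_n_s32 step */
      return RoundDivByPOT(v, right_shift);        /* vqaddq_s32 + vrshlq_s32 steps */
    }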
| void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { | |||
| int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); | |||
| int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); | |||
| @@ -68,44 +122,8 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz | |||
| int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high)); | |||
| int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high)); | |||
| // Apply left shift | |||
| in0_1 = vmulq_s32(in0_1, in0_left_vec); | |||
| in0_2 = vmulq_s32(in0_2, in0_left_vec); | |||
| in0_3 = vmulq_s32(in0_3, in0_left_vec); | |||
| in0_4 = vmulq_s32(in0_4, in0_left_vec); | |||
| in1_1 = vmulq_s32(in1_1, in1_left_vec); | |||
| in1_2 = vmulq_s32(in1_2, in1_left_vec); | |||
| in1_3 = vmulq_s32(in1_3, in1_left_vec); | |||
| in1_4 = vmulq_s32(in1_4, in1_left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_args_.multiplier_); | |||
| in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_args_.multiplier_); | |||
| in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_args_.multiplier_); | |||
| in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_args_.multiplier_); | |||
| in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_args_.multiplier_); | |||
| in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_args_.multiplier_); | |||
| in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_args_.multiplier_); | |||
| in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_args_.multiplier_); | |||
| // Apply right shift | |||
| in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31)); | |||
| in0_2 = vqaddq_s32(in0_2, vshrq_n_s32(vandq_s32(in0_2, in0_right_vec), 31)); | |||
| in0_3 = vqaddq_s32(in0_3, vshrq_n_s32(vandq_s32(in0_3, in0_right_vec), 31)); | |||
| in0_4 = vqaddq_s32(in0_4, vshrq_n_s32(vandq_s32(in0_4, in0_right_vec), 31)); | |||
| in1_1 = vqaddq_s32(in1_1, vshrq_n_s32(vandq_s32(in1_1, in1_right_vec), 31)); | |||
| in1_2 = vqaddq_s32(in1_2, vshrq_n_s32(vandq_s32(in1_2, in1_right_vec), 31)); | |||
| in1_3 = vqaddq_s32(in1_3, vshrq_n_s32(vandq_s32(in1_3, in1_right_vec), 31)); | |||
| in1_4 = vqaddq_s32(in1_4, vshrq_n_s32(vandq_s32(in1_4, in1_right_vec), 31)); | |||
| in0_1 = vrshlq_s32(in0_1, in0_right_vec); | |||
| in0_2 = vrshlq_s32(in0_2, in0_right_vec); | |||
| in0_3 = vrshlq_s32(in0_3, in0_right_vec); | |||
| in0_4 = vrshlq_s32(in0_4, in0_right_vec); | |||
| in1_1 = vrshlq_s32(in1_1, in1_right_vec); | |||
| in1_2 = vrshlq_s32(in1_2, in1_right_vec); | |||
| in1_3 = vrshlq_s32(in1_3, in1_right_vec); | |||
| in1_4 = vrshlq_s32(in1_4, in1_right_vec); | |||
| AddInt8InputRounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, in0_right_vec, params->in0_args_.multiplier_); | |||
| AddInt8InputRounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, in1_right_vec, params->in1_args_.multiplier_); | |||
| /* calculate output */ | |||
| int32x4_t out1 = vaddq_s32(in0_1, in1_1); | |||
| @@ -113,28 +131,7 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz | |||
| int32x4_t out3 = vaddq_s32(in0_3, in1_3); | |||
| int32x4_t out4 = vaddq_s32(in0_4, in1_4); | |||
| // Apply left shift | |||
| out1 = vshlq_s32(out1, out_left_vec); | |||
| out2 = vshlq_s32(out2, out_left_vec); | |||
| out3 = vshlq_s32(out3, out_left_vec); | |||
| out4 = vshlq_s32(out4, out_left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_); | |||
| out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_); | |||
| out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_); | |||
| out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_); | |||
| // Apply right shift | |||
| out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31)); | |||
| out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31)); | |||
| out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31)); | |||
| out4 = vqaddq_s32(out4, vshrq_n_s32(vandq_s32(out4, out_right_vec), 31)); | |||
| out1 = vrshlq_s32(out1, out_right_vec); | |||
| out2 = vrshlq_s32(out2, out_right_vec); | |||
| out3 = vrshlq_s32(out3, out_right_vec); | |||
| out4 = vrshlq_s32(out4, out_right_vec); | |||
| AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); | |||
| const int16x4_t out1_s16 = vmovn_s32(out1); | |||
| const int16x4_t out2_s16 = vmovn_s32(out2); | |||
| @@ -200,25 +197,8 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i | |||
| int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low)); | |||
| int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high)); | |||
| int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high)); | |||
| // Apply left shift | |||
| ele1 = vmulq_s32(ele1, ele_left_vec); | |||
| ele2 = vmulq_s32(ele2, ele_left_vec); | |||
| ele3 = vmulq_s32(ele3, ele_left_vec); | |||
| ele4 = vmulq_s32(ele4, ele_left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| ele1 = vqrdmulhq_n_s32(ele1, ele_args->multiplier_); | |||
| ele2 = vqrdmulhq_n_s32(ele2, ele_args->multiplier_); | |||
| ele3 = vqrdmulhq_n_s32(ele3, ele_args->multiplier_); | |||
| ele4 = vqrdmulhq_n_s32(ele4, ele_args->multiplier_); | |||
| // Apply right shift | |||
| ele1 = vqaddq_s32(ele1, vshrq_n_s32(vandq_s32(ele1, ele_right_vec), 31)); | |||
| ele2 = vqaddq_s32(ele2, vshrq_n_s32(vandq_s32(ele2, ele_right_vec), 31)); | |||
| ele3 = vqaddq_s32(ele3, vshrq_n_s32(vandq_s32(ele3, ele_right_vec), 31)); | |||
| ele4 = vqaddq_s32(ele4, vshrq_n_s32(vandq_s32(ele4, ele_right_vec), 31)); | |||
| ele1 = vrshlq_s32(ele1, ele_right_vec); | |||
| ele2 = vrshlq_s32(ele2, ele_right_vec); | |||
| ele3 = vrshlq_s32(ele3, ele_right_vec); | |||
| ele4 = vrshlq_s32(ele4, ele_right_vec); | |||
| AddInt8InputRounding(&ele1, &ele2, &ele3, &ele4, ele_left_vec, ele_right_vec, ele_args->multiplier_); | |||
| for (; index <= size - 16; index += 16) { | |||
| const int8x16_t ptr_src = vld1q_s8(ptr_in + index); | |||
| @@ -234,28 +214,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i | |||
| int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high)); | |||
| int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high)); | |||
| // Apply left shift | |||
| ptr1 = vmulq_s32(ptr1, ptr_left_vec); | |||
| ptr2 = vmulq_s32(ptr2, ptr_left_vec); | |||
| ptr3 = vmulq_s32(ptr3, ptr_left_vec); | |||
| ptr4 = vmulq_s32(ptr4, ptr_left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| ptr1 = vqrdmulhq_n_s32(ptr1, ptr_args->multiplier_); | |||
| ptr2 = vqrdmulhq_n_s32(ptr2, ptr_args->multiplier_); | |||
| ptr3 = vqrdmulhq_n_s32(ptr3, ptr_args->multiplier_); | |||
| ptr4 = vqrdmulhq_n_s32(ptr4, ptr_args->multiplier_); | |||
| // Apply right shift | |||
| ptr1 = vqaddq_s32(ptr1, vshrq_n_s32(vandq_s32(ptr1, ptr_right_vec), 31)); | |||
| ptr2 = vqaddq_s32(ptr2, vshrq_n_s32(vandq_s32(ptr2, ptr_right_vec), 31)); | |||
| ptr3 = vqaddq_s32(ptr3, vshrq_n_s32(vandq_s32(ptr3, ptr_right_vec), 31)); | |||
| ptr4 = vqaddq_s32(ptr4, vshrq_n_s32(vandq_s32(ptr4, ptr_right_vec), 31)); | |||
| ptr1 = vrshlq_s32(ptr1, ptr_right_vec); | |||
| ptr2 = vrshlq_s32(ptr2, ptr_right_vec); | |||
| ptr3 = vrshlq_s32(ptr3, ptr_right_vec); | |||
| ptr4 = vrshlq_s32(ptr4, ptr_right_vec); | |||
| AddInt8InputRounding(&ptr1, &ptr2, &ptr3, &ptr4, ptr_left_vec, ptr_right_vec, ptr_args->multiplier_); | |||
| /* calculate output */ | |||
| int32x4_t out1 = vaddq_s32(ptr1, ele1); | |||
| @@ -263,28 +222,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i | |||
| int32x4_t out3 = vaddq_s32(ptr3, ele3); | |||
| int32x4_t out4 = vaddq_s32(ptr4, ele4); | |||
| // Apply output left shift | |||
| out1 = vshlq_s32(out1, out_left_vec); | |||
| out2 = vshlq_s32(out2, out_left_vec); | |||
| out3 = vshlq_s32(out3, out_left_vec); | |||
| out4 = vshlq_s32(out4, out_left_vec); | |||
| // Apply output fixed-point part of the multiplier. | |||
| out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_); | |||
| out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_); | |||
| out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_); | |||
| out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_); | |||
| // Apply output right shift | |||
| out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31)); | |||
| out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31)); | |||
| out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31)); | |||
| out4 = vqaddq_s32(out4, vshrq_n_s32(vandq_s32(out4, out_right_vec), 31)); | |||
| out1 = vrshlq_s32(out1, out_right_vec); | |||
| out2 = vrshlq_s32(out2, out_right_vec); | |||
| out3 = vrshlq_s32(out3, out_right_vec); | |||
| out4 = vrshlq_s32(out4, out_right_vec); | |||
| AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); | |||
| const int16x4_t out1_s16 = vmovn_s32(out1); | |||
| const int16x4_t out2_s16 = vmovn_s32(out2); | |||
| @@ -300,7 +238,6 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i | |||
| vst1q_s8(output + index, int8_out); | |||
| } | |||
| #endif | |||
| for (; index < size; index++) { | |||
| const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift; | |||
| const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift; | |||
| @@ -329,6 +266,43 @@ int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| void AddInt8Rounding(__m128i *in1, __m128i *in2, __m128i *in3, __m128i *in4, const __m128i left_vec, | |||
| const int32_t right_shift, const __m128i multiplier) { | |||
| // Apply left shift | |||
| *in1 = _mm_mullo_epi32(*in1, left_vec); | |||
| *in2 = _mm_mullo_epi32(*in2, left_vec); | |||
| *in3 = _mm_mullo_epi32(*in3, left_vec); | |||
| *in4 = _mm_mullo_epi32(*in4, left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| *in1 = _mm_qrdmulh_epi32(*in1, multiplier); | |||
| *in2 = _mm_qrdmulh_epi32(*in2, multiplier); | |||
| *in3 = _mm_qrdmulh_epi32(*in3, multiplier); | |||
| *in4 = _mm_qrdmulh_epi32(*in4, multiplier); | |||
| // Apply right shift | |||
| int32_t in1_remainder_mask = (1ll << (right_shift)) - 1; | |||
| int32_t in1_remainder_threshold = in1_remainder_mask >> 1; | |||
| const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); | |||
| const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); | |||
| const __m128i in1_remainder = | |||
| _mm_add_epi32(_mm_and_si128(*in1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in1)); | |||
| *in1 = _mm_sub_epi32(_mm_rshr_epi32(*in1, right_shift), _mm_cmpgt_epi32(in1_remainder, vin1_remainder_threshold)); | |||
| const __m128i in2_remainder = | |||
| _mm_add_epi32(_mm_and_si128(*in2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in2)); | |||
| *in2 = _mm_sub_epi32(_mm_rshr_epi32(*in2, right_shift), _mm_cmpgt_epi32(in2_remainder, vin1_remainder_threshold)); | |||
| const __m128i in3_remainder = | |||
| _mm_add_epi32(_mm_and_si128(*in3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in3)); | |||
| *in3 = _mm_sub_epi32(_mm_rshr_epi32(*in3, right_shift), _mm_cmpgt_epi32(in3_remainder, vin1_remainder_threshold)); | |||
| const __m128i in4_remainder = | |||
| _mm_add_epi32(_mm_and_si128(*in4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in4)); | |||
| *in4 = _mm_sub_epi32(_mm_rshr_epi32(*in4, right_shift), _mm_cmpgt_epi32(in4_remainder, vin1_remainder_threshold)); | |||
| } | |||
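SSE has no rounding arithmetic right shift, so the helper above rebuilds one: shift, then add 1 whenever the discarded bits exceed half the divisor, with the threshold biased by one for negative inputs so ties round away from zero (matching vrshlq_s32 on the NEON path). A scalar view of that correction, with the intrinsic each step mirrors noted in the comments:

    /* Scalar view of the remainder-corrected shift in AddInt8Rounding. */
    static int32_t RoundingRightShift(int32_t x, int shift) {
      int32_t mask = ((int32_t)1 << shift) - 1;               /* vin1_remainder_mask */
      int32_t remainder = (x & mask) - (x < 0 ? 1 : 0);       /* _mm_and_si128 + _mm_cmpgt_epi32(0, x) */
      int32_t threshold = mask >> 1;                          /* vin1_remainder_threshold */
      return (x >> shift) + (remainder > threshold ? 1 : 0);  /* _mm_rshr_epi32 minus cmpgt mask */
    }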
| void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { | |||
| const int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); | |||
| const int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); | |||
| @@ -372,68 +346,8 @@ void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, in | |||
| __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); | |||
| __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); | |||
| // Apply left shift | |||
| in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec); | |||
| in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec); | |||
| in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec); | |||
| in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec); | |||
| in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec); | |||
| in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec); | |||
| in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec); | |||
| in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier); | |||
| in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier); | |||
| in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier); | |||
| in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier); | |||
| in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier); | |||
| in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier); | |||
| in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier); | |||
| in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier); | |||
| // Apply right shift | |||
| int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1; | |||
| int32_t in0_remainder_threshold = in0_remainder_mask >> 1; | |||
| const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask); | |||
| const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold); | |||
| const __m128i in0_1_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1)); | |||
| in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold)); | |||
| const __m128i in0_2_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2)); | |||
| in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold)); | |||
| const __m128i in0_3_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3)); | |||
| in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold)); | |||
| const __m128i in0_4_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4)); | |||
| in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold)); | |||
| int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1; | |||
| int32_t in1_remainder_threshold = in1_remainder_mask >> 1; | |||
| const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); | |||
| const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); | |||
| const __m128i in1_1_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1)); | |||
| in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold)); | |||
| const __m128i in1_2_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2)); | |||
| in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold)); | |||
| const __m128i in1_3_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3)); | |||
| in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold)); | |||
| const __m128i in1_4_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4)); | |||
| in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold)); | |||
| AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); | |||
| AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); | |||
| /* calculate output */ | |||
| __m128i out1 = _mm_add_epi32(in0_1, in1_1); | |||
| @@ -533,39 +447,7 @@ void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *outp | |||
| __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); | |||
| __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); | |||
| // Apply left shift | |||
| in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec); | |||
| in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec); | |||
| in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec); | |||
| in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier); | |||
| in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier); | |||
| in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier); | |||
| in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier); | |||
| // Apply right shift | |||
| int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1; | |||
| int32_t in1_remainder_threshold = in1_remainder_mask >> 1; | |||
| const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); | |||
| const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); | |||
| const __m128i in1_1_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1)); | |||
| in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold)); | |||
| const __m128i in1_2_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2)); | |||
| in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold)); | |||
| const __m128i in1_3_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3)); | |||
| in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold)); | |||
| const __m128i in1_4_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4)); | |||
| in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold)); | |||
| AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); | |||
| int index = 0; | |||
| for (; index <= size - 16; index += 16) { | |||
| @@ -583,39 +465,7 @@ void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *outp | |||
| __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); | |||
| __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); | |||
| // Apply left shift | |||
| in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec); | |||
| in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec); | |||
| in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec); | |||
| in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec); | |||
| // Apply the fixed-point part of the multiplier. | |||
| in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier); | |||
| in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier); | |||
| in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier); | |||
| in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier); | |||
| // Apply right shift | |||
| int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1; | |||
| int32_t in0_remainder_threshold = in0_remainder_mask >> 1; | |||
| const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask); | |||
| const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold); | |||
| const __m128i in0_1_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1)); | |||
| in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold)); | |||
| const __m128i in0_2_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2)); | |||
| in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold)); | |||
| const __m128i in0_3_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3)); | |||
| in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold)); | |||
| const __m128i in0_4_remainder = | |||
| _mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4)); | |||
| in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_), | |||
| _mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold)); | |||
| AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); | |||
| /* calculate output */ | |||
| __m128i out1 = _mm_add_epi32(in0_1, in1_1); | |||
| @@ -24,11 +24,10 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| #include "nnacl/int8/matmul_int8.h" | |||
| #include "nnacl/winograd_transform.h" | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| #ifdef __cplusplus | |||
| @@ -24,7 +24,6 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| #include "nnacl/int8/matmul_int8.h" | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/x86_64_avx/common_utils.h" | |||
| #include "nnacl/intrinsics/avx/common_utils.h" | |||
| #ifdef WIN32 | |||
| #ifdef ENABLE_AVX | |||
| #include <stdint.h> | |||
| @@ -17,6 +17,7 @@ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/conv_depthwise_fp32.h" | |||
| #include "nnacl/intrinsics/sse/sse_common.h" | |||
| void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | |||
| size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) { | |||
| @@ -123,18 +124,16 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f | |||
| int c2 = DOWN_DIV(width, C2NUM) * C2NUM; | |||
| int c1 = 0; | |||
| // c4 loop | |||
| for (; c1 < c4; c1 += C4NUM) { | |||
| const float *src_kh = src_w; | |||
| const float *weight_kh = weight; | |||
| for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) { | |||
| const float *src_kh = src_w, *weight_kh = weight; | |||
| __m128 dst_w_ma1 = _mm_setzero_ps(); | |||
| __m128 dst_w_ma2 = _mm_setzero_ps(); | |||
| __m128 dst_w_ma3 = _mm_setzero_ps(); | |||
| __m128 dst_w_ma4 = _mm_setzero_ps(); | |||
| for (int kh = 0; kh < kernel_h; kh++) { | |||
| const float *src_kw = src_kh; | |||
| const float *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < kernel_w; kw++) { | |||
| for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { | |||
| const float *src_kw = src_kh, *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { | |||
| __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); | |||
| __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); | |||
| __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); | |||
| @@ -154,13 +153,9 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f | |||
| __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw); | |||
| __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4); | |||
| dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); | |||
| src_kw += in_kw_step; | |||
| weight_kw += C4NUM; | |||
| } // kernel_w loop | |||
| src_kh += in_kh_step; | |||
| weight_kh += kernel_w * C4NUM; | |||
| } // kernel_h loop | |||
| } // kernel_h loop | |||
| // add bias relu | |||
| __m128 bias_ma = _mm_loadu_ps(bias); | |||
| dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); | |||
| @@ -168,39 +163,23 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f | |||
| dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma); | |||
| dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma); | |||
| __m128 zero_ma = _mm_setzero_ps(); | |||
| if (relu || relu6) { | |||
| dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1); | |||
| dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2); | |||
| dst_w_ma3 = _mm_max_ps(zero_ma, dst_w_ma3); | |||
| dst_w_ma4 = _mm_max_ps(zero_ma, dst_w_ma4); | |||
| if (relu6) { | |||
| __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); | |||
| dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1); | |||
| dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2); | |||
| dst_w_ma3 = _mm_min_ps(const_ma, dst_w_ma3); | |||
| dst_w_ma4 = _mm_min_ps(const_ma, dst_w_ma4); | |||
| } | |||
| } | |||
| ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6); | |||
| _mm_storeu_ps(dst_w, dst_w_ma1); | |||
| _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); | |||
| _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3); | |||
| _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4); | |||
| dst_w += C4NUM * block_channel; | |||
| src_w += C4NUM * in_sw_step; | |||
| } // dst_width loop | |||
| // c2 loop | |||
| for (; c1 < c2; c1 += C2NUM) { | |||
| const float *src_kh = src_w; | |||
| const float *weight_kh = weight; | |||
| for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) { | |||
| const float *src_kh = src_w, *weight_kh = weight; | |||
| __m128 dst_w_ma1 = _mm_setzero_ps(); | |||
| __m128 dst_w_ma2 = _mm_setzero_ps(); | |||
| for (int kh = 0; kh < kernel_h; kh++) { | |||
| const float *src_kw = src_kh; | |||
| const float *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < kernel_w; kw++) { | |||
| for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { | |||
| const float *src_kw = src_kh, *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { | |||
| __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); | |||
| __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); | |||
| __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); | |||
| @@ -210,68 +189,38 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f | |||
| __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw); | |||
| __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); | |||
| dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); | |||
| src_kw += in_kw_step; | |||
| weight_kw += C4NUM; | |||
| } // kernel_w loop | |||
| src_kh += in_kh_step; | |||
| weight_kh += kernel_w * C4NUM; | |||
| } // kernel_h loop | |||
| } // kernel_h loop | |||
| // add bias relu | |||
| __m128 bias_ma = _mm_loadu_ps(bias); | |||
| dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); | |||
| dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); | |||
| __m128 zero_ma = _mm_setzero_ps(); | |||
| if (relu || relu6) { | |||
| dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1); | |||
| dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2); | |||
| if (relu6) { | |||
| __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); | |||
| dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1); | |||
| dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2); | |||
| } | |||
| } | |||
| ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6); | |||
| _mm_storeu_ps(dst_w, dst_w_ma1); | |||
| _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); | |||
| dst_w += C2NUM * block_channel; | |||
| src_w += C2NUM * in_sw_step; | |||
| } | |||
| // remaining | |||
| for (; c1 < width; c1++) { | |||
| const float *src_kh = src_w; | |||
| const float *weight_kh = weight; | |||
| for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) { | |||
| const float *src_kh = src_w, *weight_kh = weight; | |||
| __m128 dst_w_ma1 = _mm_setzero_ps(); | |||
| for (int kh = 0; kh < kernel_h; kh++) { | |||
| const float *src_kw = src_kh; | |||
| const float *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < kernel_w; kw++) { | |||
| for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { | |||
| const float *src_kw = src_kh, *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { | |||
| __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); | |||
| __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); | |||
| __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); | |||
| dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); | |||
| src_kw += in_kw_step; | |||
| weight_kw += C4NUM; | |||
| } // kernel_w loop | |||
| src_kh += in_kh_step; | |||
| weight_kh += kernel_w * C4NUM; | |||
| } // kernel_h loop | |||
| } // kernel_h loop | |||
| // add bias relu | |||
| __m128 bias_ma = _mm_loadu_ps(bias); | |||
| dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); | |||
| __m128 zero_ma = _mm_setzero_ps(); | |||
| if (relu || relu6) { | |||
| dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1); | |||
| if (relu6) { | |||
| __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); | |||
| dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1); | |||
| } | |||
| } | |||
| ActBlock1(&dst_w_ma1, relu, relu6); | |||
| _mm_storeu_ps(dst_w, dst_w_ma1); | |||
| dst_w += block_channel; | |||
| src_w += in_sw_step; | |||
| } | |||
| dst_h += out_h_step; | |||
| src_h += in_sh_step; | |||
| @@ -0,0 +1,293 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| #include "nnacl/intrinsics/sse/sse_common.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, | |||
| int in_channel, int c4_channel) { | |||
| const float *src1 = matix_a; | |||
| int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; | |||
| int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; | |||
| for (int i = 0; i < m; ++i) { | |||
| const float *src1_n = src1; | |||
| const float *src2_n = matrix_b; | |||
| for (int j = 0; j < n; ++j) { | |||
| const float *src1_j = src1_n; | |||
| int y = 0; | |||
| // 16 channel | |||
| for (; y < c16; y += C16NUM) { | |||
| __m128 dst1 = _mm_setzero_ps(); | |||
| __m128 dst2 = _mm_setzero_ps(); | |||
| __m128 dst3 = _mm_setzero_ps(); | |||
| __m128 dst4 = _mm_setzero_ps(); | |||
| const float *src2_y = src2_n; | |||
| for (int z = 0; z < k; ++z) { | |||
| __m128 ma1 = _mm_loadu_ps(src1_j); | |||
| __m128 ma2 = _mm_loadu_ps(src1_j + 4); | |||
| __m128 ma3 = _mm_loadu_ps(src1_j + 8); | |||
| __m128 ma4 = _mm_loadu_ps(src1_j + 12); | |||
| __m128 mb = _mm_load_ps1(src2_y); | |||
| __m128 tmp1 = _mm_mul_ps(ma1, mb); | |||
| __m128 tmp2 = _mm_mul_ps(ma2, mb); | |||
| __m128 tmp3 = _mm_mul_ps(ma3, mb); | |||
| __m128 tmp4 = _mm_mul_ps(ma4, mb); | |||
| dst1 = _mm_add_ps(dst1, tmp1); | |||
| dst2 = _mm_add_ps(dst2, tmp2); | |||
| dst3 = _mm_add_ps(dst3, tmp3); | |||
| dst4 = _mm_add_ps(dst4, tmp4); | |||
| src1_j += in_channel; | |||
| src2_y += n; | |||
| } | |||
| _mm_storeu_ps(matrix_c, dst1); | |||
| _mm_storeu_ps(matrix_c + 4, dst2); | |||
| _mm_storeu_ps(matrix_c + 8, dst3); | |||
| _mm_storeu_ps(matrix_c + 12, dst4); | |||
| src1_j -= in_channel * k; | |||
| src1_j += C16NUM; | |||
| matrix_c += C16NUM; | |||
| } | |||
| // 8 channel | |||
| for (; y < c8; y += C8NUM) { | |||
| __m128 dst1 = _mm_setzero_ps(); | |||
| __m128 dst2 = _mm_setzero_ps(); | |||
| const float *src2_y = src2_n; | |||
| for (int z = 0; z < k; ++z) { | |||
| __m128 ma1 = _mm_loadu_ps(src1_j); | |||
| __m128 ma2 = _mm_loadu_ps(src1_j + 4); | |||
| __m128 mb = _mm_load_ps1(src2_y); | |||
| __m128 tmp1 = _mm_mul_ps(ma1, mb); | |||
| __m128 tmp2 = _mm_mul_ps(ma2, mb); | |||
| dst1 = _mm_add_ps(dst1, tmp1); | |||
| dst2 = _mm_add_ps(dst2, tmp2); | |||
| src1_j += in_channel; | |||
| src2_y += n; | |||
| } | |||
| _mm_storeu_ps(matrix_c, dst1); | |||
| _mm_storeu_ps(matrix_c + 4, dst2); | |||
| src1_j -= in_channel * k; | |||
| src1_j += C8NUM; | |||
| matrix_c += C8NUM; | |||
| } | |||
| // remaining channels | |||
| for (; y < in_channel; ++y) { | |||
| float tmp = 0; | |||
| for (int z = 0; z < k; ++z) { | |||
| tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; | |||
| } | |||
| *matrix_c++ = tmp; | |||
| } | |||
| src2_n += 1; | |||
| } | |||
| src1 += k * in_channel; | |||
| } | |||
| } | |||
| void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode) { | |||
| int C8Steps = row * C8NUM, WinoSteps1 = stride * col, WinoSteps2 = stride * C8NUM; | |||
| for (int r = row; r > 0; r -= C4NUM) { | |||
| const float *srcb_d = b, *bias_d = bias; | |||
| float *dst = NULL; | |||
| for (int cc = col; cc > 0; cc -= C8NUM) { | |||
| if (write_mode != 0) { // writec8 | |||
| dst = c; | |||
| } | |||
| const float *srca_d = a; | |||
| __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(), dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); | |||
| __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(), dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); | |||
| for (int d = depth; d > 0; --d) { | |||
| __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); | |||
| __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); | |||
| __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); | |||
| __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); | |||
| a1 = _mm_load_ps1(srca_d + 2); | |||
| dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); | |||
| a2 = _mm_load_ps1(srca_d + 3); | |||
| dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); | |||
| tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); | |||
| tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); | |||
| dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); | |||
| dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); | |||
| srcb_d += C8NUM, srca_d += C4NUM; | |||
| } | |||
| if (bias != NULL) { | |||
| DoBiasBlock8(bias_d, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); | |||
| bias_d += C8NUM; | |||
| } | |||
| ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); | |||
| if (write_mode == OutType_TileC8) { // WriteWino | |||
| c = dst + WinoSteps2; | |||
| _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); | |||
| dst += WinoSteps1; | |||
| _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); | |||
| dst += WinoSteps1; | |||
| _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); | |||
| dst += WinoSteps1; | |||
| _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); | |||
| } else if (write_mode == OutType_C8) { // WriteC8 | |||
| _mm_storeu_ps(c, dst1), _mm_storeu_ps(c + 4, dst2); | |||
| _mm_storeu_ps(c + 8, dst3), _mm_storeu_ps(c + 12, dst4); | |||
| _mm_storeu_ps(c + 16, dst5), _mm_storeu_ps(c + 20, dst6); | |||
| _mm_storeu_ps(c + 24, dst7), _mm_storeu_ps(c + 28, dst8); | |||
| c += C8Steps; | |||
| } else { | |||
| switch (cc) { | |||
| case 1: // write1 | |||
| c = dst + 1; | |||
| WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 1, r); | |||
| break; | |||
| case 2: // write2 | |||
| c = dst + 2; | |||
| WriteCol2Opt(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); | |||
| break; | |||
| case 3: // write3 | |||
| c = dst + 3; | |||
| _mm_store_ss(dst, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst1); | |||
| WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 3, r); | |||
| break; | |||
| case 4: // write4 | |||
| c = dst + 4; | |||
| WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 4, r); | |||
| break; | |||
| case 5: // write5 | |||
| c = dst + 5; | |||
| WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 5, r); | |||
| break; | |||
| case 6: // write6 | |||
| c = dst + 6; | |||
| WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 6, r); | |||
| break; | |||
| case 7: // write7 | |||
| c = dst + 7; | |||
| WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 7, r); | |||
| break; | |||
| default: // write8 | |||
| c = dst + C8NUM; | |||
| WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 8, r); | |||
| break; | |||
| } | |||
| } | |||
| if (cc <= C8NUM) break; // write end | |||
| } | |||
| a += C4NUM * depth; | |||
| if (write_mode == OutType_C8) c += 32; | |||
| if (write_mode == OutType_TileC8) c = dst + WinoSteps2; | |||
| if (write_mode == OutType_Nhwc) c = dst - col; | |||
| if (r <= C4NUM) break; | |||
| } | |||
| } | |||
| void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, size_t writeNhwc, size_t WriteWino) { | |||
| size_t DstWinoSteps = stride * C8NUM, WriteWinoSteps = stride * col; | |||
| for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { | |||
| const float *srca_d = a; | |||
| float *dst = c; | |||
| for (int r = row; r > 0; r -= C4NUM) { | |||
| const float *srcb_d = b; | |||
| __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(); | |||
| __m128 dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); | |||
| __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(); | |||
| __m128 dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); | |||
| for (int d = 0; d < depth; d++) { | |||
| __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); | |||
| __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); | |||
| __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); | |||
| __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); | |||
| a1 = _mm_load_ps1(srca_d + 2); | |||
| dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); | |||
| a2 = _mm_load_ps1(srca_d + 3); | |||
| dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); | |||
| tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); | |||
| tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); | |||
| dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); | |||
| dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); | |||
| srcb_d += C8NUM, srca_d += C4NUM; | |||
| } | |||
| if (bias != NULL) { | |||
| DoBiasBlock8(bias, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); | |||
| } | |||
| ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); | |||
| if (WriteWino != 0) { // WriteWino | |||
| _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); | |||
| dst += WriteWinoSteps; | |||
| _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); | |||
| dst += WriteWinoSteps; | |||
| _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); | |||
| dst += WriteWinoSteps; | |||
| _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); | |||
| dst += WriteWinoSteps; | |||
| } else if (writeNhwc == 0) { // WriteC8 | |||
| _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); | |||
| _mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4); | |||
| _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6); | |||
| _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8); | |||
| dst += 32; | |||
| c = dst; | |||
| } else { | |||
| switch (col) { | |||
| case 1: // write1 | |||
| WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); | |||
| break; | |||
| case 2: // write2 | |||
| WriteCol2(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); | |||
| break; | |||
| case 3: // write3 | |||
| WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); | |||
| break; | |||
| case 4: // write4 | |||
| WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); | |||
| break; | |||
| case 5: // write5 | |||
| WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); | |||
| break; | |||
| case 6: // write6 | |||
| WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); | |||
| break; | |||
| case 7: // write7 | |||
| WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); | |||
| break; | |||
| default: // write8 | |||
| WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); | |||
| break; | |||
| } | |||
| } | |||
| if (r <= C4NUM) { // WriteEnd | |||
| break; | |||
| } | |||
| } | |||
| b += depth * C8NUM; | |||
| bias += (bias != NULL) ? C8NUM : 0; | |||
| if (WriteWino != 0) { | |||
| c += DstWinoSteps; | |||
| } else if (writeNhwc != 0) { | |||
| c += C8NUM; | |||
| } | |||
| if (col_tmp <= C8NUM) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
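For orientation, MatmulFloatSse64Opt and MatmulFloatSse64 above both accumulate the same 4x8 output tile from packed inputs: a supplies four row values per depth step and b supplies eight column values per depth step. A minimal scalar sketch of that micro-kernel, assuming this packing (MicroKernel4x8Ref is an illustrative name, not part of nnacl):

static void MicroKernel4x8Ref(float acc[4][8], const float *a_pack, const float *b_pack, int depth) {
  // acc[i][j] accumulates row i of the tile against column j, as dst1..dst8 do above.
  for (int d = 0; d < depth; ++d) {
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 8; ++j) {
        acc[i][j] += a_pack[d * 4 + i] * b_pack[d * 8 + j];
      }
    }
  }
}

The SSE versions keep the tile in eight __m128 registers (two per row) and broadcast each a value with _mm_load_ps1, which is why the depth loop advances srca_d by C4NUM and srcb_d by C8NUM.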
| @@ -17,11 +17,10 @@ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| #include "nnacl/intrinsics/sse/sse_common.h" | |||
| void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, | |||
| size_t plane_size, size_t plane_stride, size_t relu_type) { | |||
| __m128 relu6 = _mm_set_ps1(6.0); | |||
| __m128 zero = _mm_setzero_ps(); | |||
| size_t stride = oc4div + oc4mod; | |||
| plane_stride /= sizeof(float); | |||
| for (size_t loop_c4 = 0; loop_c4 < oc4div; loop_c4 += C4NUM) { | |||
| @@ -42,19 +41,9 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t | |||
| src2 = _mm_add_ps(src2, bias1); | |||
| src3 = _mm_add_ps(src3, bias1); | |||
| src4 = _mm_add_ps(src4, bias1); | |||
| switch (relu_type) { | |||
| case 3: | |||
| src1 = _mm_min_ps(src1, relu6); | |||
| src2 = _mm_min_ps(src2, relu6); | |||
| src3 = _mm_min_ps(src3, relu6); | |||
| src4 = _mm_min_ps(src4, relu6); | |||
| case 1: | |||
| src1 = _mm_max_ps(src1, zero); | |||
| src2 = _mm_max_ps(src2, zero); | |||
| src3 = _mm_max_ps(src3, zero); | |||
| src4 = _mm_max_ps(src4, zero); | |||
| break; | |||
| } | |||
| ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == 3); | |||
| _mm_storeu_ps(dst_c4, src1); | |||
| dst_c4 += stride; | |||
| _mm_storeu_ps(dst_c4, src2); | |||
| @@ -67,20 +56,15 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t | |||
| for (; plane_size_tmp > 0; plane_size_tmp -= 1) { | |||
| __m128 src1 = _mm_loadu_ps(src); | |||
| src1 = _mm_add_ps(src1, bias1); | |||
| switch (relu_type) { | |||
| case 3: | |||
| src1 = _mm_min_ps(src1, relu6); | |||
| case 1: | |||
| src1 = _mm_max_ps(src1, zero); | |||
| break; | |||
| } | |||
| ActBlock1(&src1, relu_type == 1, relu_type == 3); | |||
| _mm_storeu_ps(dst_c4, src1); | |||
| dst_c4 += stride; | |||
| src += 4; | |||
| } | |||
| src += plane_stride; | |||
| } | |||
| if (oc4mod == 0) { | |||
| return; | |||
| } | |||
| @@ -94,13 +78,9 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t | |||
| __m128 src1 = _mm_loadu_ps(src); | |||
| src += 4; | |||
| src1 = _mm_add_ps(src1, bias1); | |||
| switch (relu_type) { | |||
| case 3: | |||
| src1 = _mm_min_ps(src1, relu6); | |||
| case 1: | |||
| src1 = _mm_max_ps(src1, zero); | |||
| break; | |||
| } | |||
| ActBlock1(&src1, relu_type == 1, relu_type == 3); | |||
| switch (oc4mod) { | |||
| case 1: | |||
| _mm_store_ss(dst_c1, src1); | |||
| @@ -17,23 +17,21 @@ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| #include "nnacl/intrinsics/sse/sse_common.h" | |||
| void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | |||
| size_t plane_size, size_t stride, size_t relu_type) { | |||
| __m128 relu6 = _mm_set_ps1(6.0); | |||
| __m128 zero = _mm_setzero_ps(); | |||
| stride /= sizeof(float); | |||
| for (int loop_c8 = 0; !(loop_c8 == oc8div); loop_c8 += C8NUM) { | |||
| size_t plane_size_tmp = plane_size; | |||
| float *dst_c8 = dst + loop_c8; | |||
| __m128 bias1 = _mm_setzero_ps(); | |||
| __m128 bias2 = _mm_setzero_ps(); | |||
| __m128 bias1 = _mm_setzero_ps(), bias2 = _mm_setzero_ps(); | |||
| if (bias != NULL) { | |||
| bias1 = _mm_loadu_ps(bias); | |||
| bias2 = _mm_loadu_ps(bias + 4); | |||
| bias += 8; | |||
| } | |||
| for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { | |||
| for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM, src += 32) { | |||
| __m128 src1 = _mm_loadu_ps(src); | |||
| __m128 src2 = _mm_loadu_ps(src + 4); | |||
| __m128 src3 = _mm_loadu_ps(src + 8); | |||
| @@ -42,7 +40,6 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t | |||
| __m128 src6 = _mm_loadu_ps(src + 20); | |||
| __m128 src7 = _mm_loadu_ps(src + 24); | |||
| __m128 src8 = _mm_loadu_ps(src + 28); | |||
| src += 32; | |||
| src1 = _mm_add_ps(src1, bias1); | |||
| src2 = _mm_add_ps(src2, bias2); | |||
| src3 = _mm_add_ps(src3, bias1); | |||
| @@ -51,27 +48,9 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t | |||
| src6 = _mm_add_ps(src6, bias2); | |||
| src7 = _mm_add_ps(src7, bias1); | |||
| src8 = _mm_add_ps(src8, bias2); | |||
| switch (relu_type) { | |||
| case 3: | |||
| src1 = _mm_min_ps(src1, relu6); | |||
| src2 = _mm_min_ps(src2, relu6); | |||
| src3 = _mm_min_ps(src3, relu6); | |||
| src4 = _mm_min_ps(src4, relu6); | |||
| src5 = _mm_min_ps(src5, relu6); | |||
| src6 = _mm_min_ps(src6, relu6); | |||
| src7 = _mm_min_ps(src7, relu6); | |||
| src8 = _mm_min_ps(src8, relu6); | |||
| case 1: | |||
| src1 = _mm_max_ps(src1, zero); | |||
| src2 = _mm_max_ps(src2, zero); | |||
| src3 = _mm_max_ps(src3, zero); | |||
| src4 = _mm_max_ps(src4, zero); | |||
| src5 = _mm_max_ps(src5, zero); | |||
| src6 = _mm_max_ps(src6, zero); | |||
| src7 = _mm_max_ps(src7, zero); | |||
| src8 = _mm_max_ps(src8, zero); | |||
| break; | |||
| } | |||
| ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); | |||
| _mm_storeu_ps(dst_c8, src1); | |||
| _mm_storeu_ps(dst_c8 + 4, src2); | |||
| dst_c8 += stride; | |||
| @@ -85,29 +64,21 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t | |||
| _mm_storeu_ps(dst_c8 + 4, src8); | |||
| dst_c8 += stride; | |||
| } | |||
| for (; plane_size_tmp > 0; plane_size_tmp -= 1) { | |||
| for (; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c8 += stride) { | |||
| __m128 src1 = _mm_loadu_ps(src); | |||
| __m128 src2 = _mm_loadu_ps(src + 4); | |||
| src1 = _mm_add_ps(src1, bias1); | |||
| src2 = _mm_add_ps(src2, bias2); | |||
| switch (relu_type) { | |||
| case 3: | |||
| src1 = _mm_min_ps(src1, relu6); | |||
| src2 = _mm_min_ps(src2, relu6); | |||
| case 1: | |||
| src1 = _mm_max_ps(src1, zero); | |||
| src2 = _mm_max_ps(src2, zero); | |||
| break; | |||
| } | |||
| ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); | |||
| _mm_storeu_ps(dst_c8, src1); | |||
| _mm_storeu_ps(dst_c8 + 4, src2); | |||
| dst_c8 += stride; | |||
| src += 8; | |||
| } | |||
| } | |||
| if (oc8mod == 0) { | |||
| return; | |||
| } | |||
| if (oc8mod == 0) return; | |||
| __m128 bias1 = _mm_setzero_ps(); | |||
| __m128 bias2 = _mm_setzero_ps(); | |||
| if (bias != NULL) { | |||
| @@ -116,56 +87,42 @@ void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t | |||
| bias += 8; | |||
| } | |||
| float *dst_c1 = dst + oc8div; | |||
| for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) { | |||
| for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c1 += stride) { | |||
| __m128 src1 = _mm_loadu_ps(src); | |||
| __m128 src2 = _mm_loadu_ps(src + 4); | |||
| src += 8; | |||
| src1 = _mm_add_ps(src1, bias1); | |||
| src2 = _mm_add_ps(src2, bias2); | |||
| switch (relu_type) { | |||
| case 3: | |||
| src1 = _mm_min_ps(src1, relu6); | |||
| src2 = _mm_min_ps(src2, relu6); | |||
| case 1: | |||
| src1 = _mm_max_ps(src1, zero); | |||
| src2 = _mm_max_ps(src2, zero); | |||
| break; | |||
| } | |||
| ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); | |||
| switch (oc8mod) { | |||
| case 1: | |||
| _mm_store_ss(dst_c1, src1); | |||
| dst_c1 += stride; | |||
| break; | |||
| case 2: | |||
| _mm_storel_pi((__m64 *)(dst_c1), src1); | |||
| dst_c1 += stride; | |||
| break; | |||
| case 3: | |||
| _mm_storel_pi((__m64 *)(dst_c1), src1); | |||
| src1 = _mm_unpackhi_ps(src1, src1); | |||
| _mm_store_ss(dst_c1 + 2, src1); | |||
| dst_c1 += stride; | |||
| break; | |||
| case 4: | |||
| _mm_storeu_ps(dst_c1, src1); | |||
| dst_c1 += stride; | |||
| break; | |||
| case 5: | |||
| _mm_storeu_ps(dst_c1, src1); | |||
| _mm_store_ss(dst_c1 + 4, src2); | |||
| dst_c1 += stride; | |||
| break; | |||
| case 6: | |||
| _mm_storeu_ps(dst_c1, src1); | |||
| _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); | |||
| dst_c1 += stride; | |||
| break; | |||
| case 7: | |||
| _mm_storeu_ps(dst_c1, src1); | |||
| _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); | |||
| src2 = _mm_unpackhi_ps(src2, src2); | |||
| _mm_store_ss(dst_c1 + 6, src2); | |||
| dst_c1 += stride; | |||
| break; | |||
| } | |||
| } | |||
| @@ -0,0 +1,146 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, const __m128 weight, | |||
| const float v1, const float v2, const float v3, const float v4) { | |||
| *dst1 = _mm_add_ps(*dst1, _mm_mul_ps(weight, _mm_set_ps1(v1))); | |||
| *dst2 = _mm_add_ps(*dst2, _mm_mul_ps(weight, _mm_set_ps1(v2))); | |||
| *dst3 = _mm_add_ps(*dst3, _mm_mul_ps(weight, _mm_set_ps1(v3))); | |||
| *dst4 = _mm_add_ps(*dst4, _mm_mul_ps(weight, _mm_set_ps1(v4))); | |||
| } | |||
| void TiledC4MatmulFp32_LoadData(__m128 *src1, __m128 *src2, __m128 *src3, __m128 *src4, const float *src) { | |||
| *src1 = _mm_loadu_ps(src); | |||
| *src2 = _mm_loadu_ps(src + 4); | |||
| *src3 = _mm_loadu_ps(src + 8); | |||
| *src4 = _mm_loadu_ps(src + 12); | |||
| } | |||
| void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { | |||
| const float *src_tmp = src; | |||
| for (int i = 0; i < oc4; ++i) { | |||
| float *dst_tmp = dst; | |||
| src = src_tmp; | |||
| size_t ic4_tmp = ic4 - 1; | |||
| __m128 src1 = _mm_loadu_ps(src); | |||
| __m128 src2 = _mm_loadu_ps(src + 4); | |||
| __m128 src3 = _mm_loadu_ps(src + 8); | |||
| __m128 src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| __m128 weight_data[4]; | |||
| weight_data[0] = _mm_loadu_ps(weight); | |||
| weight_data[1] = _mm_loadu_ps(weight + 4); | |||
| weight_data[2] = _mm_loadu_ps(weight + 8); | |||
| weight_data[3] = _mm_loadu_ps(weight + 12); | |||
| weight += 16; | |||
| __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); | |||
| __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); | |||
| __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); | |||
| __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); | |||
| for (int j = 1; j < 4; ++j) { | |||
| TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], src1[j], src2[j], src3[j], src4[j]); | |||
| } | |||
| TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); | |||
| src += 16; | |||
| __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); | |||
| __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); | |||
| __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); | |||
| __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); | |||
| for (int j = 1; j < 4; ++j) { | |||
| TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], src1[j], src2[j], src3[j], src4[j]); | |||
| } | |||
| if (ic4_tmp != 0) { | |||
| ic4_tmp -= 1; | |||
| TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); | |||
| src += 16; | |||
| weight_data[0] = _mm_loadu_ps(weight); | |||
| weight_data[1] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); | |||
| for (; ic4_tmp != 0; ic4_tmp -= 1) { | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); | |||
| TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], src1[1], src2[1], src3[1], src4[1]); | |||
| weight_data[2] = _mm_loadu_ps(weight); | |||
| weight_data[3] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], src1[2], src2[2], src3[2], src4[2]); | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3]))); | |||
| src1 = _mm_loadu_ps(src); | |||
| src2 = _mm_loadu_ps(src + 4); | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3]))); | |||
| src3 = _mm_loadu_ps(src + 8); | |||
| src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], src1[0], src2[0], src3[0], src4[0]); | |||
| TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], src1[1], src2[1], src3[1], src4[1]); | |||
| TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], src1[2], src2[2], src3[2], src4[2]); | |||
| weight_data[0] = _mm_loadu_ps(weight); | |||
| weight_data[1] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], src1[3], src2[3], src3[3], src4[3]); | |||
| TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); | |||
| src += 16; | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); | |||
| } | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); | |||
| TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], src1[1], src2[1], src3[1], src4[1]); | |||
| weight_data[2] = _mm_loadu_ps(weight); | |||
| weight_data[3] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], src1[2], src2[2], src3[2], src4[2]); | |||
| TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], src1[3], src2[3], src3[3], src4[3]); | |||
| TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); | |||
| src += 16; | |||
| for (int j = 0; j < 4; ++j) { | |||
| TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], src1[j], src2[j], src3[j], src4[j]); | |||
| } | |||
| } | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_storeu_ps(dst + 4, dst2); | |||
| _mm_storeu_ps(dst + 8, dst3); | |||
| _mm_storeu_ps(dst + 12, dst4); | |||
| _mm_storeu_ps(dst + 16, dst5); | |||
| _mm_storeu_ps(dst + 20, dst6); | |||
| _mm_storeu_ps(dst + 24, dst7); | |||
| _mm_storeu_ps(dst + 28, dst8); | |||
| dst = dst_tmp + cal_num; | |||
| } | |||
| } | |||
| #endif | |||
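The pointer shuffling in TiledC4MatmulFp32 obscures the underlying data layout, so here is a minimal scalar sketch of what it computes, with the indexing inferred from the loads and stores above; treat it as illustrative only (TiledC4MatmulRef is not part of nnacl):

#include <stddef.h>
/* Assumed layout: per ic4 block, src holds 8 points x 4 input channels; per (oc4, ic4) block,
 * weight holds a 4x4 (ci x co) matrix; per oc4 block, dst holds 8 points x 4 output channels,
 * with consecutive oc4 blocks cal_num floats apart. */
static void TiledC4MatmulRef(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4,
                             size_t oc4) {
  for (size_t o = 0; o < oc4; ++o) {
    for (int p = 0; p < 8; ++p) {
      for (int co = 0; co < 4; ++co) {
        float sum = 0.0f;
        for (size_t b = 0; b < ic4; ++b) {
          for (int ci = 0; ci < 4; ++ci) {
            sum += src[b * 32 + p * 4 + ci] * weight[(o * ic4 + b) * 16 + ci * 4 + co];
          }
        }
        dst[o * cal_num + p * 4 + co] = sum;
      }
    }
  }
}

The SSE version processes the 8 points of a tile as two groups of four __m128 accumulators and software-pipelines the weight loads across ic4 blocks, which is why its loop body looks so different from this reference.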
| @@ -36,7 +36,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| __m128 k6 = _mm_load_ps1(BK + 5 * h); | |||
| __m128 k7 = _mm_load_ps1(BK + 6 * h); | |||
| BK += 7 * h; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| @@ -54,8 +54,6 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); | |||
| M1 = _mm_add_ps(M1, s1); | |||
| _mm_storeu_ps(M, M1); | |||
| M += 4; | |||
| SK += 4; | |||
| } | |||
| M -= len_c4; | |||
| SK += 7 * S_step - len_c4; | |||
| @@ -66,7 +64,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| __m128 k3 = _mm_load_ps1(BK + 2 * h); | |||
| __m128 k4 = _mm_load_ps1(BK + 3 * h); | |||
| BK += 4 * h; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| @@ -78,8 +76,6 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); | |||
| M1 = _mm_add_ps(M1, s1); | |||
| _mm_storeu_ps(M, M1); | |||
| SK += 4; | |||
| M += 4; | |||
| } | |||
| M -= len_c4; | |||
| SK += 4 * S_step - len_c4; | |||
| @@ -89,7 +85,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| __m128 k2 = _mm_load_ps1(BK + h); | |||
| __m128 k3 = _mm_load_ps1(BK + 2 * h); | |||
| BK += 3 * h; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| @@ -99,8 +95,6 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); | |||
| M1 = _mm_add_ps(M1, s1); | |||
| _mm_storeu_ps(M, M1); | |||
| SK += 4; | |||
| M += 4; | |||
| } | |||
| M -= len_c4; | |||
| SK += 3 * S_step - len_c4; | |||
| @@ -108,13 +102,11 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| for (; k_tmp > 0; k_tmp -= 1) { | |||
| __m128 k1 = _mm_load_ps1(BK); | |||
| BK += h; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| _mm_storeu_ps(M, M1); | |||
| SK += 4; | |||
| M += 4; | |||
| } | |||
| M -= len_c4; | |||
| SK += S_step - len_c4; | |||
| @@ -127,16 +119,14 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_ | |||
| } | |||
| void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { | |||
| size_t len_c4 = length * 4; | |||
| size_t k_step = len_c4 * k; | |||
| for (int h1 = 0; h1 < h; ++h1) { | |||
| size_t len_c4 = length * 4, k_step = len_c4 * k; | |||
| for (int h1 = 0; h1 < h; ++h1, S += k_step) { | |||
| const float *BW = B; | |||
| for (int ww = 0; ww < w; ++ww) { | |||
| const float *SK = S; // r0 | |||
| const float *BK = BW; // r1 | |||
| for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c4) { | |||
| const float *SK = S, *BK = BW; | |||
| memset(M, 0, len_c4 * sizeof(float)); | |||
| int k_tmp = k; | |||
| for (; k_tmp >= 7; k_tmp -= 7) { | |||
| for (; k_tmp >= 7; k_tmp -= 7, M -= len_c4) { | |||
| __m128 k1 = _mm_load_ps1(BK); | |||
| __m128 k2 = _mm_load_ps1(BK + h); | |||
| __m128 k3 = _mm_load_ps1(BK + 2 * h); | |||
| @@ -145,13 +135,11 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size | |||
| __m128 k6 = _mm_load_ps1(BK + 5 * h); | |||
| __m128 k7 = _mm_load_ps1(BK + 6 * h); | |||
| BK += 7 * h; | |||
| const float *S2 = SK + len_c4; | |||
| const float *S3 = S2 + len_c4; | |||
| const float *S4 = S3 + len_c4; | |||
| const float *S5 = S4 + len_c4; | |||
| const float *S6 = S5 + len_c4; | |||
| const float *S7 = S6 + len_c4; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| const float *S2 = SK + len_c4, *S3 = S2 + len_c4; | |||
| const float *S4 = S3 + len_c4, *S5 = S4 + len_c4; | |||
| const float *S6 = S5 + len_c4, *S7 = S6 + len_c4; | |||
| for (int len_tmp = length; len_tmp > 0; | |||
| --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4, S5 += 4, S6 += 4, S7 += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| @@ -169,19 +157,10 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); | |||
| M1 = _mm_add_ps(M1, s1); | |||
| _mm_storeu_ps(M, M1); | |||
| M += 4; | |||
| SK += 4; | |||
| S2 += 4; | |||
| S3 += 4; | |||
| S4 += 4; | |||
| S5 += 4; | |||
| S6 += 4; | |||
| S7 += 4; | |||
| } | |||
| M -= len_c4; | |||
| SK = S7; | |||
| } | |||
| for (; k_tmp >= 4; k_tmp -= 4) { | |||
| for (; k_tmp >= 4; k_tmp -= 4, M -= len_c4) { | |||
| __m128 k1 = _mm_load_ps1(BK); | |||
| __m128 k2 = _mm_load_ps1(BK + h); | |||
| __m128 k3 = _mm_load_ps1(BK + 2 * h); | |||
| @@ -190,7 +169,7 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size | |||
| const float *S2 = SK + len_c4; | |||
| const float *S3 = S2 + len_c4; | |||
| const float *S4 = S3 + len_c4; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| @@ -202,23 +181,17 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size | |||
| s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); | |||
| M1 = _mm_add_ps(M1, s1); | |||
| _mm_storeu_ps(M, M1); | |||
| M += 4; | |||
| SK += 4; | |||
| S2 += 4; | |||
| S3 += 4; | |||
| S4 += 4; | |||
| } | |||
| M -= len_c4; | |||
| SK = S4; | |||
| } | |||
| for (; k_tmp >= 3; k_tmp -= 3) { | |||
| for (; k_tmp >= 3; k_tmp -= 3, M -= len_c4) { | |||
| __m128 k1 = _mm_load_ps1(BK); | |||
| __m128 k2 = _mm_load_ps1(BK + h); | |||
| __m128 k3 = _mm_load_ps1(BK + 2 * h); | |||
| BK += 3 * h; | |||
| const float *S2 = SK + len_c4; | |||
| const float *S3 = S2 + len_c4; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| @@ -228,31 +201,20 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); | |||
| M1 = _mm_add_ps(M1, s1); | |||
| _mm_storeu_ps(M, M1); | |||
| M += 4; | |||
| SK += 4; | |||
| S2 += 4; | |||
| S3 += 4; | |||
| } | |||
| M -= len_c4; | |||
| SK = S3; | |||
| } | |||
| for (; k_tmp >= 1; k_tmp -= 1) { | |||
| for (; k_tmp >= 1; k_tmp -= 1, M -= len_c4) { | |||
| __m128 k1 = _mm_load_ps1(BK); | |||
| BK += h; | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp) { | |||
| for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { | |||
| __m128 M1 = _mm_loadu_ps(M); | |||
| __m128 s0 = _mm_loadu_ps(SK); | |||
| M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); | |||
| _mm_storeu_ps(M, M1); | |||
| M += 4; | |||
| SK += 4; | |||
| } | |||
| M -= len_c4; | |||
| } | |||
| BW += 1; | |||
| M += len_c4; | |||
| } | |||
| S += k_step; | |||
| } | |||
| } | |||
| #endif | |||
| @@ -0,0 +1,336 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/sse/sse_common.h" | |||
| void ActBlock1(__m128 *v1, size_t relu, size_t relu6) { | |||
| __m128 zero_ma = _mm_setzero_ps(); | |||
| if (relu || relu6) { | |||
| *v1 = _mm_max_ps(zero_ma, *v1); | |||
| } | |||
| if (relu6) { | |||
| __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); | |||
| *v1 = _mm_min_ps(relu6_ma, *v1); | |||
| } | |||
| } | |||
| void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) { | |||
| __m128 zero_ma = _mm_setzero_ps(); | |||
| if (relu || relu6) { | |||
| *v1 = _mm_max_ps(zero_ma, *v1); | |||
| *v2 = _mm_max_ps(zero_ma, *v2); | |||
| } | |||
| if (relu6) { | |||
| __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); | |||
| *v1 = _mm_min_ps(relu6_ma, *v1); | |||
| *v2 = _mm_min_ps(relu6_ma, *v2); | |||
| } | |||
| } | |||
| void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) { | |||
| __m128 zero_ma = _mm_setzero_ps(); | |||
| if (relu || relu6) { | |||
| *v1 = _mm_max_ps(zero_ma, *v1); | |||
| *v2 = _mm_max_ps(zero_ma, *v2); | |||
| *v3 = _mm_max_ps(zero_ma, *v3); | |||
| *v4 = _mm_max_ps(zero_ma, *v4); | |||
| } | |||
| if (relu6) { | |||
| __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); | |||
| *v1 = _mm_min_ps(relu6_ma, *v1); | |||
| *v2 = _mm_min_ps(relu6_ma, *v2); | |||
| *v3 = _mm_min_ps(relu6_ma, *v3); | |||
| *v4 = _mm_min_ps(relu6_ma, *v4); | |||
| } | |||
| } | |||
| void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, __m128 *v8, | |||
| size_t relu_type) { | |||
| __m128 relu6 = _mm_set_ps1(6.0); | |||
| __m128 zero = _mm_setzero_ps(); | |||
| switch (relu_type) { | |||
| case 3: | |||
| *v1 = _mm_min_ps(*v1, relu6); | |||
| *v2 = _mm_min_ps(*v2, relu6); | |||
| *v3 = _mm_min_ps(*v3, relu6); | |||
| *v4 = _mm_min_ps(*v4, relu6); | |||
| *v5 = _mm_min_ps(*v5, relu6); | |||
| *v6 = _mm_min_ps(*v6, relu6); | |||
| *v7 = _mm_min_ps(*v7, relu6); | |||
| *v8 = _mm_min_ps(*v8, relu6); | |||
| // fall through: relu6 also needs the zero lower bound applied below | |||
| case 1: | |||
| *v1 = _mm_max_ps(*v1, zero); | |||
| *v2 = _mm_max_ps(*v2, zero); | |||
| *v3 = _mm_max_ps(*v3, zero); | |||
| *v4 = _mm_max_ps(*v4, zero); | |||
| *v5 = _mm_max_ps(*v5, zero); | |||
| *v6 = _mm_max_ps(*v6, zero); | |||
| *v7 = _mm_max_ps(*v7, zero); | |||
| *v8 = _mm_max_ps(*v8, zero); | |||
| break; | |||
| } | |||
| } | |||
| void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { | |||
| _mm_store_ss(*dst, *dst1); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst3); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst5); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst7); | |||
| *dst += stride; | |||
| *dst += extra_stride; | |||
| } | |||
| } | |||
| void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int r) { | |||
| _mm_store_ss(*dst, *dst1); | |||
| *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst1); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst3); | |||
| *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst3); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst5); | |||
| *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst5); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst7); | |||
| *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst7); | |||
| } | |||
| } | |||
| void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int r) { | |||
| _mm_store_ss(*dst, *dst1); | |||
| *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst1); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst3); | |||
| *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst3); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst5); | |||
| *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst5); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst7); | |||
| *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst7); | |||
| *dst += stride; | |||
| *dst += 2; | |||
| } | |||
| } | |||
| void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst3); | |||
| *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst3); | |||
| *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 2, *dst3); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst5); | |||
| *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst5); | |||
| *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 2, *dst5); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_store_ss(*dst, *dst7); | |||
| *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 1, *dst7); | |||
| *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 2, *dst7); | |||
| *dst += stride; | |||
| *dst += extra_stride; | |||
| } | |||
| } | |||
| void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { | |||
| _mm_storeu_ps(*dst, *dst1); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst3); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst5); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst7); | |||
| *dst += stride; | |||
| *dst += extra_stride; | |||
| } | |||
| } | |||
| void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { | |||
| _mm_storeu_ps(*dst, *dst1); | |||
| _mm_store_ss(*dst + 4, *dst2); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst3); | |||
| _mm_store_ss(*dst + 4, *dst4); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst5); | |||
| _mm_store_ss(*dst + 4, *dst6); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst7); | |||
| _mm_store_ss(*dst + 4, *dst8); | |||
| *dst += stride; | |||
| *dst += extra_stride; | |||
| } | |||
| } | |||
| void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { | |||
| _mm_storeu_ps(*dst, *dst1); | |||
| _mm_store_ss(*dst + 4, *dst2); | |||
| *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst2); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst3); | |||
| _mm_store_ss(*dst + 4, *dst4); | |||
| *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst4); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst5); | |||
| _mm_store_ss(*dst + 4, *dst6); | |||
| *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst6); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst7); | |||
| _mm_store_ss(*dst + 4, *dst8); | |||
| *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst8); | |||
| *dst += stride; | |||
| *dst += extra_stride; | |||
| } | |||
| } | |||
| void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { | |||
| _mm_storeu_ps(*dst, *dst1); | |||
| _mm_store_ss(*dst + 4, *dst2); | |||
| *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst2); | |||
| *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 6, *dst2); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst3); | |||
| _mm_store_ss(*dst + 4, *dst4); | |||
| *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst4); | |||
| *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 6, *dst4); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst5); | |||
| _mm_store_ss(*dst + 4, *dst6); | |||
| *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst6); | |||
| *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 6, *dst6); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst7); | |||
| _mm_store_ss(*dst + 4, *dst8); | |||
| *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 5, *dst8); | |||
| *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(*dst + 6, *dst8); | |||
| *dst += stride; | |||
| *dst += extra_stride; | |||
| } | |||
| } | |||
| void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { | |||
| _mm_storeu_ps(*dst, *dst1); | |||
| _mm_storeu_ps(*dst + 4, *dst2); | |||
| if (r > 1) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst3); | |||
| _mm_storeu_ps(*dst + 4, *dst4); | |||
| } | |||
| if (r > 2) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst5); | |||
| _mm_storeu_ps(*dst + 4, *dst6); | |||
| } | |||
| if (r > 3) { | |||
| *dst += stride; | |||
| _mm_storeu_ps(*dst, *dst7); | |||
| _mm_storeu_ps(*dst + 4, *dst8); | |||
| *dst += stride; | |||
| *dst += extra_stride; | |||
| } | |||
| } | |||
| void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, | |||
| __m128 *dst6, __m128 *dst7, __m128 *dst8) { | |||
| __m128 bias1 = _mm_loadu_ps(bias_ptr); | |||
| __m128 bias2 = _mm_loadu_ps(bias_ptr + C4NUM); | |||
| *dst1 = _mm_add_ps(*dst1, bias1); | |||
| *dst2 = _mm_add_ps(*dst2, bias2); | |||
| *dst3 = _mm_add_ps(*dst3, bias1); | |||
| *dst4 = _mm_add_ps(*dst4, bias2); | |||
| *dst5 = _mm_add_ps(*dst5, bias1); | |||
| *dst6 = _mm_add_ps(*dst6, bias2); | |||
| *dst7 = _mm_add_ps(*dst7, bias1); | |||
| *dst8 = _mm_add_ps(*dst8, bias2); | |||
| } | |||
| #endif | |||
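The relu_type convention used by ActBlock8 above (and, via the relu/relu6 flags, by ActBlock1/2/4) is ordinary ReLU and ReLU6 clamping. A scalar sketch for reference, assuming relu_type 1 means ReLU and 3 means ReLU6 as in the switch above (ActScalarRef is an illustrative name only):

#include <stddef.h>
static float ActScalarRef(float v, size_t relu_type) {
  // relu_type 1 (ReLU) clamps below at 0; relu_type 3 (ReLU6) clamps to [0, 6].
  if (relu_type == 1 || relu_type == 3) {
    v = v < 0.0f ? 0.0f : v;
  }
  if (relu_type == 3) {
    v = v > 6.0f ? 6.0f : v;
  }
  return v;
}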
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ | |||
| #define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void ActBlock1(__m128 *v1, size_t relu, size_t relu6); | |||
| void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6); | |||
| void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6); | |||
| void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, __m128 *v8, | |||
| size_t relu_type); | |||
| void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); | |||
| void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int r); | |||
| void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int r); | |||
| void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); | |||
| void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); | |||
| void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); | |||
| void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); | |||
| void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); | |||
| void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, __m128 *dst6, | |||
| __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r); | |||
| void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, | |||
| __m128 *dst6, __m128 *dst7, __m128 *dst8); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ | |||
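The WriteColN writers declared above store the tail of a 4x8 accumulator tile when fewer than eight output columns remain: up to r rows (at most four), each row spanning two __m128 registers, with consecutive rows stride floats apart. A scalar sketch of the stores they perform, assuming each row is viewed as eight contiguous lanes (WriteColTailRef is an illustrative name only):

static void WriteColTailRef(float *dst, const float rows[4][8], int col_remain, int r, int stride) {
  // Copy the first col_remain lanes of each of up to r (<= 4) rows, stepping stride floats per row.
  int row_count = r < 4 ? r : 4;
  for (int i = 0; i < row_count; ++i) {
    for (int j = 0; j < col_remain; ++j) {
      dst[i * stride + j] = rows[i][j];
    }
  }
}

This sketch omits two details of the real helpers: they also advance *dst as they go, and WriteCol3 leaves the first row to the caller (see the case 3 branch in MatmulFloatSse64Opt, which stores dst1's three lanes before calling it).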
| @@ -1,747 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/op_base.h" | |||
| void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, | |||
| int in_channel, int c4_channel) { | |||
| const float *src1 = matix_a; | |||
| int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; | |||
| int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; | |||
| for (int i = 0; i < m; ++i) { | |||
| const float *src1_n = src1; | |||
| const float *src2_n = matrix_b; | |||
| for (int j = 0; j < n; ++j) { | |||
| const float *src1_j = src1_n; | |||
| int y = 0; | |||
| // 16 channel | |||
| for (; y < c16; y += C16NUM) { | |||
| __m128 dst1 = _mm_setzero_ps(); | |||
| __m128 dst2 = _mm_setzero_ps(); | |||
| __m128 dst3 = _mm_setzero_ps(); | |||
| __m128 dst4 = _mm_setzero_ps(); | |||
| const float *src2_y = src2_n; | |||
| for (int z = 0; z < k; ++z) { | |||
| __m128 ma1 = _mm_loadu_ps(src1_j); | |||
| __m128 ma2 = _mm_loadu_ps(src1_j + 4); | |||
| __m128 ma3 = _mm_loadu_ps(src1_j + 8); | |||
| __m128 ma4 = _mm_loadu_ps(src1_j + 12); | |||
| __m128 mb = _mm_load_ps1(src2_y); | |||
| __m128 tmp1 = _mm_mul_ps(ma1, mb); | |||
| __m128 tmp2 = _mm_mul_ps(ma2, mb); | |||
| __m128 tmp3 = _mm_mul_ps(ma3, mb); | |||
| __m128 tmp4 = _mm_mul_ps(ma4, mb); | |||
| dst1 = _mm_add_ps(dst1, tmp1); | |||
| dst2 = _mm_add_ps(dst2, tmp2); | |||
| dst3 = _mm_add_ps(dst3, tmp3); | |||
| dst4 = _mm_add_ps(dst4, tmp4); | |||
| src1_j += in_channel; | |||
| src2_y += n; | |||
| } | |||
| _mm_storeu_ps(matrix_c, dst1); | |||
| _mm_storeu_ps(matrix_c + 4, dst2); | |||
| _mm_storeu_ps(matrix_c + 8, dst3); | |||
| _mm_storeu_ps(matrix_c + 12, dst4); | |||
| src1_j -= in_channel * k; | |||
| src1_j += C16NUM; | |||
| matrix_c += C16NUM; | |||
| } | |||
| // 8 channel | |||
| for (; y < c8; y += C8NUM) { | |||
| __m128 dst1 = _mm_setzero_ps(); | |||
| __m128 dst2 = _mm_setzero_ps(); | |||
| const float *src2_y = src2_n; | |||
| for (int z = 0; z < k; ++z) { | |||
| __m128 ma1 = _mm_loadu_ps(src1_j); | |||
| __m128 ma2 = _mm_loadu_ps(src1_j + 4); | |||
| __m128 mb = _mm_load_ps1(src2_y); | |||
| __m128 tmp1 = _mm_mul_ps(ma1, mb); | |||
| __m128 tmp2 = _mm_mul_ps(ma2, mb); | |||
| dst1 = _mm_add_ps(dst1, tmp1); | |||
| dst2 = _mm_add_ps(dst2, tmp2); | |||
| src1_j += in_channel; | |||
| src2_y += n; | |||
| } | |||
| _mm_storeu_ps(matrix_c, dst1); | |||
| _mm_storeu_ps(matrix_c + 4, dst2); | |||
| src1_j -= in_channel * k; | |||
| src1_j += C8NUM; | |||
| matrix_c += C8NUM; | |||
| } | |||
| // remain chann | |||
| for (; y < in_channel; ++y) { | |||
| float tmp = 0; | |||
| for (int z = 0; z < k; ++z) { | |||
| tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; | |||
| } | |||
| *matrix_c++ = tmp; | |||
| } | |||
| src2_n += 1; | |||
| } | |||
| src1 += k * in_channel; | |||
| } | |||
| } | |||
| void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode) { | |||
| int C8Steps = row * C8NUM; | |||
| int WinoSteps1 = stride * col; | |||
| int WinoSteps2 = stride * C8NUM; | |||
| for (int r = row; r > 0; r -= C4NUM) { | |||
| const float *srcb_d = b; | |||
| const float *bias_d = bias; | |||
| float *dst = NULL; | |||
| for (int cc = col; cc > 0; cc -= C8NUM) { | |||
| if (write_mode != 0) { // writec8 | |||
| dst = c; | |||
| } | |||
| const float *srca_d = a; | |||
| __m128 dst1 = _mm_setzero_ps(); | |||
| __m128 dst2 = _mm_setzero_ps(); | |||
| __m128 dst3 = _mm_setzero_ps(); | |||
| __m128 dst4 = _mm_setzero_ps(); | |||
| __m128 dst5 = _mm_setzero_ps(); | |||
| __m128 dst6 = _mm_setzero_ps(); | |||
| __m128 dst7 = _mm_setzero_ps(); | |||
| __m128 dst8 = _mm_setzero_ps(); | |||
| for (int d = depth; d > 0; --d) { | |||
| __m128 b1 = _mm_loadu_ps(srcb_d); | |||
| __m128 b2 = _mm_loadu_ps(srcb_d + 4); | |||
| __m128 a1 = _mm_load_ps1(srca_d); | |||
| __m128 a2 = _mm_load_ps1(srca_d + 1); | |||
| __m128 tmp1 = _mm_mul_ps(b1, a1); | |||
| __m128 tmp2 = _mm_mul_ps(b2, a1); | |||
| __m128 tmp3 = _mm_mul_ps(b1, a2); | |||
| __m128 tmp4 = _mm_mul_ps(b2, a2); | |||
| a1 = _mm_load_ps1(srca_d + 2); | |||
| dst1 = _mm_add_ps(dst1, tmp1); | |||
| dst2 = _mm_add_ps(dst2, tmp2); | |||
| a2 = _mm_load_ps1(srca_d + 3); | |||
| dst3 = _mm_add_ps(dst3, tmp3); | |||
| dst4 = _mm_add_ps(dst4, tmp4); | |||
| tmp1 = _mm_mul_ps(b1, a1); | |||
| tmp2 = _mm_mul_ps(b2, a1); | |||
| tmp3 = _mm_mul_ps(b1, a2); | |||
| tmp4 = _mm_mul_ps(b2, a2); | |||
| dst5 = _mm_add_ps(dst5, tmp1); | |||
| dst6 = _mm_add_ps(dst6, tmp2); | |||
| dst7 = _mm_add_ps(dst7, tmp3); | |||
| dst8 = _mm_add_ps(dst8, tmp4); | |||
| srcb_d += C8NUM; | |||
| srca_d += C4NUM; | |||
| } | |||
| if (bias != NULL) { | |||
| __m128 bias1 = _mm_loadu_ps(bias_d); | |||
| __m128 bias2 = _mm_loadu_ps(bias_d + C4NUM); | |||
| dst1 = _mm_add_ps(dst1, bias1); | |||
| dst2 = _mm_add_ps(dst2, bias2); | |||
| dst3 = _mm_add_ps(dst3, bias1); | |||
| dst4 = _mm_add_ps(dst4, bias2); | |||
| dst5 = _mm_add_ps(dst5, bias1); | |||
| dst6 = _mm_add_ps(dst6, bias2); | |||
| dst7 = _mm_add_ps(dst7, bias1); | |||
| dst8 = _mm_add_ps(dst8, bias2); | |||
| bias_d += C8NUM; | |||
| } | |||
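| // act_type 3 applies ReLU6 (clamp to [0, 6]); act_type 1 applies ReLU (clamp below at 0). | |||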
| if (act_type == 3) { | |||
| __m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0); | |||
| dst1 = _mm_min_ps(dst1, relu6); | |||
| dst2 = _mm_min_ps(dst2, relu6); | |||
| dst3 = _mm_min_ps(dst3, relu6); | |||
| dst4 = _mm_min_ps(dst4, relu6); | |||
| dst5 = _mm_min_ps(dst5, relu6); | |||
| dst6 = _mm_min_ps(dst6, relu6); | |||
| dst7 = _mm_min_ps(dst7, relu6); | |||
| dst8 = _mm_min_ps(dst8, relu6); | |||
| } | |||
| if (act_type == 1 || act_type == 3) { | |||
| __m128 zero = _mm_setzero_ps(); | |||
| dst1 = _mm_max_ps(dst1, zero); | |||
| dst2 = _mm_max_ps(dst2, zero); | |||
| dst3 = _mm_max_ps(dst3, zero); | |||
| dst4 = _mm_max_ps(dst4, zero); | |||
| dst5 = _mm_max_ps(dst5, zero); | |||
| dst6 = _mm_max_ps(dst6, zero); | |||
| dst7 = _mm_max_ps(dst7, zero); | |||
| dst8 = _mm_max_ps(dst8, zero); | |||
| } | |||
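| // Write the finished 4x8 tile: write_mode 2 uses the Winograd layout, 0 stores packed C8 blocks, | |||
| // and any other value stores NHWC, switching on the remaining column count (cc) for the tail. | |||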
| if (write_mode == 2) { // WriteWino | |||
| c = dst + WinoSteps2; | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_storeu_ps(dst + 4, dst2); | |||
| dst += WinoSteps1; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_storeu_ps(dst + 4, dst4); | |||
| dst += WinoSteps1; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_storeu_ps(dst + 4, dst6); | |||
| dst += WinoSteps1; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_storeu_ps(dst + 4, dst8); | |||
| } else if (write_mode == 0) { // WriteC8 | |||
| _mm_storeu_ps(c, dst1); | |||
| _mm_storeu_ps(c + 4, dst2); | |||
| _mm_storeu_ps(c + 8, dst3); | |||
| _mm_storeu_ps(c + 12, dst4); | |||
| _mm_storeu_ps(c + 16, dst5); | |||
| _mm_storeu_ps(c + 20, dst6); | |||
| _mm_storeu_ps(c + 24, dst7); | |||
| _mm_storeu_ps(c + 28, dst8); | |||
| c += C8Steps; | |||
| } else { | |||
| switch (cc) { | |||
| case 1: // write1 | |||
| c = dst + 1; | |||
| _mm_store_ss(dst, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst7); | |||
| dst += stride; | |||
| dst += 1; | |||
| } | |||
| break; | |||
| case 2: // write2 | |||
| c = dst + 2; | |||
| _mm_store_ss(dst, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst3); | |||
| dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst5); | |||
| dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst7); | |||
| dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst7); | |||
| dst += stride; | |||
| dst += 2; | |||
| } | |||
| break; | |||
| case 3: // write3 | |||
| c = dst + 3; | |||
| _mm_store_ss(dst, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst3); | |||
| dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst3); | |||
| dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst5); | |||
| dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst5); | |||
| dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst7); | |||
| dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst7); | |||
| dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst7); | |||
| dst += stride; | |||
| dst += 3; | |||
| } | |||
| break; | |||
| case 4: // write4 | |||
| c = dst + 4; | |||
| _mm_storeu_ps(dst, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| dst += stride; | |||
| dst += 4; | |||
| } | |||
| break; | |||
| case 5: // write5 | |||
| c = dst + 5; | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_store_ss(dst + 4, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_store_ss(dst + 4, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_store_ss(dst + 4, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_store_ss(dst + 4, dst8); | |||
| dst += stride; | |||
| dst += 5; | |||
| } | |||
| break; | |||
| case 6: // write6 | |||
| c = dst + 6; | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_store_ss(dst + 4, dst2); | |||
| dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_store_ss(dst + 4, dst4); | |||
| dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_store_ss(dst + 4, dst6); | |||
| dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_store_ss(dst + 4, dst8); | |||
| dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst8); | |||
| dst += stride; | |||
| dst += 6; | |||
| } | |||
| break; | |||
| case 7: // write7 | |||
| c = dst + 7; | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_store_ss(dst + 4, dst2); | |||
| dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst2); | |||
| dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_store_ss(dst + 4, dst4); | |||
| dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst4); | |||
| dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_store_ss(dst + 4, dst6); | |||
| dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst6); | |||
| dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_store_ss(dst + 4, dst8); | |||
| dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst8); | |||
| dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst8); | |||
| dst += stride; | |||
| dst += 7; | |||
| } | |||
| break; | |||
| default: // write8 | |||
| c = dst + C8NUM; | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_storeu_ps(dst + 4, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_storeu_ps(dst + 4, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_storeu_ps(dst + 4, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_storeu_ps(dst + 4, dst8); | |||
| dst += stride; | |||
| dst += C8NUM; | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| if (cc <= C8NUM) { // write end | |||
| break; | |||
| } | |||
| } // col end | |||
| a += C4NUM * depth; | |||
| switch (write_mode) { | |||
| case 0: // C8DstStep | |||
| c += 32; | |||
| break; | |||
| case 2: | |||
| c = dst + WinoSteps2; | |||
| break; | |||
| default: | |||
| c = dst - col; | |||
| break; | |||
| } | |||
| if (r <= C4NUM) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
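| // MatmulFloatSse64 is the non-opt variant: it walks 8-column tiles in the outer loop and 4-row tiles | |||
| // in the inner loop, and selects the output layout via the writeNhwc / WriteWino flags instead of | |||
| // a single write_mode argument. | |||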
| void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, size_t writeNhwc, size_t WriteWino) { | |||
| size_t DstWinoSteps = stride * C8NUM; | |||
| size_t WriteWinoSteps = stride * col; | |||
| for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { | |||
| const float *srca_d = a; | |||
| float *dst = c; | |||
| for (int r = row; r > 0; r -= C4NUM) { | |||
| const float *srcb_d = b; | |||
| __m128 dst1 = _mm_setzero_ps(); | |||
| __m128 dst2 = _mm_setzero_ps(); | |||
| __m128 dst3 = _mm_setzero_ps(); | |||
| __m128 dst4 = _mm_setzero_ps(); | |||
| __m128 dst5 = _mm_setzero_ps(); | |||
| __m128 dst6 = _mm_setzero_ps(); | |||
| __m128 dst7 = _mm_setzero_ps(); | |||
| __m128 dst8 = _mm_setzero_ps(); | |||
| for (int d = 0; d < depth; d++) { | |||
| __m128 b1 = _mm_loadu_ps(srcb_d); | |||
| __m128 b2 = _mm_loadu_ps(srcb_d + 4); | |||
| __m128 a1 = _mm_load_ps1(srca_d); | |||
| __m128 a2 = _mm_load_ps1(srca_d + 1); | |||
| __m128 tmp1 = _mm_mul_ps(b1, a1); | |||
| __m128 tmp2 = _mm_mul_ps(b2, a1); | |||
| __m128 tmp3 = _mm_mul_ps(b1, a2); | |||
| __m128 tmp4 = _mm_mul_ps(b2, a2); | |||
| a1 = _mm_load_ps1(srca_d + 2); | |||
| dst1 = _mm_add_ps(dst1, tmp1); | |||
| dst2 = _mm_add_ps(dst2, tmp2); | |||
| a2 = _mm_load_ps1(srca_d + 3); | |||
| dst3 = _mm_add_ps(dst3, tmp3); | |||
| dst4 = _mm_add_ps(dst4, tmp4); | |||
| tmp1 = _mm_mul_ps(b1, a1); | |||
| tmp2 = _mm_mul_ps(b2, a1); | |||
| tmp3 = _mm_mul_ps(b1, a2); | |||
| tmp4 = _mm_mul_ps(b2, a2); | |||
| dst5 = _mm_add_ps(dst5, tmp1); | |||
| dst6 = _mm_add_ps(dst6, tmp2); | |||
| dst7 = _mm_add_ps(dst7, tmp3); | |||
| dst8 = _mm_add_ps(dst8, tmp4); | |||
| srcb_d += C8NUM; | |||
| srca_d += C4NUM; | |||
| } | |||
| if (bias != NULL) { | |||
| __m128 bias1 = _mm_loadu_ps(bias); | |||
| __m128 bias2 = _mm_loadu_ps(bias + C4NUM); | |||
| dst1 = _mm_add_ps(dst1, bias1); | |||
| dst2 = _mm_add_ps(dst2, bias2); | |||
| dst3 = _mm_add_ps(dst3, bias1); | |||
| dst4 = _mm_add_ps(dst4, bias2); | |||
| dst5 = _mm_add_ps(dst5, bias1); | |||
| dst6 = _mm_add_ps(dst6, bias2); | |||
| dst7 = _mm_add_ps(dst7, bias1); | |||
| dst8 = _mm_add_ps(dst8, bias2); | |||
| } | |||
| if (act_type == 3) { | |||
| __m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0); | |||
| dst1 = _mm_min_ps(dst1, relu6); | |||
| dst2 = _mm_min_ps(dst2, relu6); | |||
| dst3 = _mm_min_ps(dst3, relu6); | |||
| dst4 = _mm_min_ps(dst4, relu6); | |||
| dst5 = _mm_min_ps(dst5, relu6); | |||
| dst6 = _mm_min_ps(dst6, relu6); | |||
| dst7 = _mm_min_ps(dst7, relu6); | |||
| dst8 = _mm_min_ps(dst8, relu6); | |||
| } | |||
| if (act_type == 1 || act_type == 3) { | |||
| __m128 zero = _mm_setzero_ps(); | |||
| dst1 = _mm_max_ps(dst1, zero); | |||
| dst2 = _mm_max_ps(dst2, zero); | |||
| dst3 = _mm_max_ps(dst3, zero); | |||
| dst4 = _mm_max_ps(dst4, zero); | |||
| dst5 = _mm_max_ps(dst5, zero); | |||
| dst6 = _mm_max_ps(dst6, zero); | |||
| dst7 = _mm_max_ps(dst7, zero); | |||
| dst8 = _mm_max_ps(dst8, zero); | |||
| } | |||
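| // Write back: WriteWino selects the Winograd layout, writeNhwc == 0 stores packed C8 blocks, | |||
| // otherwise the NHWC tail path is chosen by switching on col. | |||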
| if (WriteWino != 0) { // WriteWino | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_storeu_ps(dst + 4, dst2); | |||
| dst += WriteWinoSteps; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_storeu_ps(dst + 4, dst4); | |||
| dst += WriteWinoSteps; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_storeu_ps(dst + 4, dst6); | |||
| dst += WriteWinoSteps; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_storeu_ps(dst + 4, dst8); | |||
| dst += WriteWinoSteps; | |||
| } else if (writeNhwc == 0) { // WriteC8 | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_storeu_ps(dst + 4, dst2); | |||
| _mm_storeu_ps(dst + 8, dst3); | |||
| _mm_storeu_ps(dst + 12, dst4); | |||
| _mm_storeu_ps(dst + 16, dst5); | |||
| _mm_storeu_ps(dst + 20, dst6); | |||
| _mm_storeu_ps(dst + 24, dst7); | |||
| _mm_storeu_ps(dst + 28, dst8); | |||
| dst += 32; | |||
| c = dst; | |||
| } else { | |||
| switch (col) { | |||
| case 1: // write1 | |||
| _mm_store_ss(dst, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst7); | |||
| dst += stride; | |||
| } | |||
| break; | |||
| case 2: // write2 | |||
| _mm_store_ss(dst, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst3); | |||
| dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst5); | |||
| dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst7); | |||
| dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst7); | |||
| } | |||
| break; | |||
| case 3: // write3 | |||
| _mm_store_ss(dst, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst1); | |||
| dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst3); | |||
| dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst3); | |||
| dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst5); | |||
| dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst5); | |||
| dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_store_ss(dst, dst7); | |||
| dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 1, dst7); | |||
| dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 2, dst7); | |||
| dst += stride; | |||
| } | |||
| break; | |||
| case 4: // write4 | |||
| _mm_storeu_ps(dst, dst1); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| dst += stride; | |||
| } | |||
| break; | |||
| case 5: // write5 | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_store_ss(dst + 4, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_store_ss(dst + 4, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_store_ss(dst + 4, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_store_ss(dst + 4, dst8); | |||
| dst += stride; | |||
| } | |||
| break; | |||
| case 6: // write6 | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_store_ss(dst + 4, dst2); | |||
| dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_store_ss(dst + 4, dst4); | |||
| dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_store_ss(dst + 4, dst6); | |||
| dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_store_ss(dst + 4, dst8); | |||
| dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst8); | |||
| dst += stride; | |||
| } | |||
| break; | |||
| case 7: // write7 | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_store_ss(dst + 4, dst2); | |||
| dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst2); | |||
| dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_store_ss(dst + 4, dst4); | |||
| dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst4); | |||
| dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_store_ss(dst + 4, dst6); | |||
| dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst6); | |||
| dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_store_ss(dst + 4, dst8); | |||
| dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 5, dst8); | |||
| dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); | |||
| _mm_store_ss(dst + 6, dst8); | |||
| dst += stride; | |||
| } | |||
| break; | |||
| default: // write8 | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_storeu_ps(dst + 4, dst2); | |||
| if (r > 1) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst3); | |||
| _mm_storeu_ps(dst + 4, dst4); | |||
| } | |||
| if (r > 2) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst5); | |||
| _mm_storeu_ps(dst + 4, dst6); | |||
| } | |||
| if (r > 3) { | |||
| dst += stride; | |||
| _mm_storeu_ps(dst, dst7); | |||
| _mm_storeu_ps(dst + 4, dst8); | |||
| dst += stride; | |||
| } | |||
| } | |||
| } | |||
| if (r <= C4NUM) { // WriteEnd | |||
| break; | |||
| } | |||
| } | |||
| b += depth * C8NUM; | |||
| bias += (bias != NULL) ? C8NUM : 0; | |||
| if (WriteWino != 0) { | |||
| c += DstWinoSteps; | |||
| } else if (writeNhwc != 0) { | |||
| c += C8NUM; | |||
| } | |||
| if (col_tmp <= C8NUM) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| @@ -1,175 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { | |||
| const float *src_tmp = src; | |||
| for (int i = 0; i < oc4; ++i) { | |||
| float *dst_tmp = dst; | |||
| src = src_tmp; | |||
| size_t ic4_tmp = ic4 - 1; | |||
| __m128 src1 = _mm_loadu_ps(src); | |||
| __m128 src2 = _mm_loadu_ps(src + 4); | |||
| __m128 src3 = _mm_loadu_ps(src + 8); | |||
| __m128 src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| __m128 weight_data[4]; | |||
| weight_data[0] = _mm_loadu_ps(weight); | |||
| weight_data[1] = _mm_loadu_ps(weight + 4); | |||
| weight_data[2] = _mm_loadu_ps(weight + 8); | |||
| weight_data[3] = _mm_loadu_ps(weight + 12); | |||
| weight += 16; | |||
| __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); | |||
| __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); | |||
| __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); | |||
| __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); | |||
| for (int j = 1; j < 4; ++j) { | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[j], _mm_set_ps1(src1[j]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[j], _mm_set_ps1(src2[j]))); | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[j], _mm_set_ps1(src3[j]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[j], _mm_set_ps1(src4[j]))); | |||
| } | |||
| src1 = _mm_loadu_ps(src); | |||
| src2 = _mm_loadu_ps(src + 4); | |||
| src3 = _mm_loadu_ps(src + 8); | |||
| src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])); | |||
| __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])); | |||
| __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])); | |||
| __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])); | |||
| for (int j = 1; j < 4; ++j) { | |||
| dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[j], _mm_set_ps1(src1[j]))); | |||
| dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[j], _mm_set_ps1(src2[j]))); | |||
| dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[j], _mm_set_ps1(src3[j]))); | |||
| dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[j], _mm_set_ps1(src4[j]))); | |||
| } | |||
| if (ic4_tmp != 0) { | |||
| ic4_tmp -= 1; | |||
| src1 = _mm_loadu_ps(src); | |||
| src2 = _mm_loadu_ps(src + 4); | |||
| src3 = _mm_loadu_ps(src + 8); | |||
| src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| weight_data[0] = _mm_loadu_ps(weight); | |||
| weight_data[1] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); | |||
| for (; ic4_tmp != 0; ic4_tmp -= 1) { | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[1], _mm_set_ps1(src1[1]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[1], _mm_set_ps1(src2[1]))); | |||
| weight_data[2] = _mm_loadu_ps(weight); | |||
| weight_data[3] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[1], _mm_set_ps1(src3[1]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[1], _mm_set_ps1(src4[1]))); | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[2], _mm_set_ps1(src1[2]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[2], _mm_set_ps1(src2[2]))); | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[2], _mm_set_ps1(src3[2]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[2], _mm_set_ps1(src4[2]))); | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3]))); | |||
| src1 = _mm_loadu_ps(src); | |||
| src2 = _mm_loadu_ps(src + 4); | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3]))); | |||
| src3 = _mm_loadu_ps(src + 8); | |||
| src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); | |||
| dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); | |||
| dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); | |||
| dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); | |||
| dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[1], _mm_set_ps1(src1[1]))); | |||
| dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[1], _mm_set_ps1(src2[1]))); | |||
| dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[1], _mm_set_ps1(src3[1]))); | |||
| dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[1], _mm_set_ps1(src4[1]))); | |||
| dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[2], _mm_set_ps1(src1[2]))); | |||
| dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[2], _mm_set_ps1(src2[2]))); | |||
| dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[2], _mm_set_ps1(src3[2]))); | |||
| weight_data[0] = _mm_loadu_ps(weight); | |||
| weight_data[1] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[2], _mm_set_ps1(src4[2]))); | |||
| dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); | |||
| dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3]))); | |||
| dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); | |||
| src1 = _mm_loadu_ps(src); | |||
| src2 = _mm_loadu_ps(src + 4); | |||
| dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3]))); | |||
| src3 = _mm_loadu_ps(src + 8); | |||
| src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]))); | |||
| } | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]))); | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[1], _mm_set_ps1(src1[1]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[1], _mm_set_ps1(src2[1]))); | |||
| weight_data[2] = _mm_loadu_ps(weight); | |||
| weight_data[3] = _mm_loadu_ps(weight + 4); | |||
| weight += 8; | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[1], _mm_set_ps1(src3[1]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[1], _mm_set_ps1(src4[1]))); | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[2], _mm_set_ps1(src1[2]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[2], _mm_set_ps1(src2[2]))); | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[2], _mm_set_ps1(src3[2]))); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[2], _mm_set_ps1(src4[2]))); | |||
| dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3]))); | |||
| dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3]))); | |||
| dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3]))); | |||
| src1 = _mm_loadu_ps(src); | |||
| src2 = _mm_loadu_ps(src + 4); | |||
| dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3]))); | |||
| src3 = _mm_loadu_ps(src + 8); | |||
| src4 = _mm_loadu_ps(src + 12); | |||
| src += 16; | |||
| for (int j = 0; j < 4; ++j) { | |||
| dst5 = _mm_add_ps(dst5, _mm_mul_ps(weight_data[j], _mm_set_ps1(src1[j]))); | |||
| dst6 = _mm_add_ps(dst6, _mm_mul_ps(weight_data[j], _mm_set_ps1(src2[j]))); | |||
| dst7 = _mm_add_ps(dst7, _mm_mul_ps(weight_data[j], _mm_set_ps1(src3[j]))); | |||
| dst8 = _mm_add_ps(dst8, _mm_mul_ps(weight_data[j], _mm_set_ps1(src4[j]))); | |||
| } | |||
| } | |||
| _mm_storeu_ps(dst, dst1); | |||
| _mm_storeu_ps(dst + 4, dst2); | |||
| _mm_storeu_ps(dst + 8, dst3); | |||
| _mm_storeu_ps(dst + 12, dst4); | |||
| _mm_storeu_ps(dst + 16, dst5); | |||
| _mm_storeu_ps(dst + 20, dst6); | |||
| _mm_storeu_ps(dst + 24, dst7); | |||
| _mm_storeu_ps(dst + 28, dst8); | |||
| dst = dst_tmp + cal_num; | |||
| } | |||
| } | |||
| #endif | |||
| @@ -30,7 +30,6 @@ namespace mindspore::kernel { | |||
| namespace { | |||
| constexpr int kMaxInputNum = 2; | |||
| constexpr int kOutputNum = 1; | |||
| constexpr int kRank = 4; | |||
| } // namespace | |||
| int ResizeBaseCPUKernel::CheckParameters() { | |||
| @@ -113,7 +112,7 @@ int ResizeBaseCPUKernel::Init() { | |||
| auto input = in_tensors_.at(0); | |||
| auto input_shape = input->shape(); | |||
| if (!input_shape.empty() && input_shape.size() != kRank) { | |||
| if (!input_shape.empty() && input_shape.size() != COMM_SHAPE_SIZE) { | |||
| MS_LOG(ERROR) << "Resize op support input rank 4, got " << input_shape.size(); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -24,7 +24,7 @@ | |||
| #include "nnacl/fp16/conv_fp16.h" | |||
| #include "nnacl/fp16/winograd_utils_fp16.h" | |||
| #include "src/common/utils.h" | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| namespace mindspore::kernel { | |||
| class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||
| @@ -49,7 +49,7 @@ int TransposeFp16CPUKernel::Run() { | |||
| for (int i = 0; i < input_perm->ElementsNum(); ++i) { | |||
| param->perm_[i] = perm_data[i]; | |||
| } | |||
| for (int i = input_perm->ElementsNum(); i < 8; ++i) { | |||
| for (int i = input_perm->ElementsNum(); i < MAX_SHAPE_SIZE; ++i) { | |||
| param->perm_[i] = 0; | |||
| } | |||
| param->num_axes_ = input_perm->ElementsNum(); | |||
| @@ -54,7 +54,7 @@ int BatchToSpaceCPUKernel::Init() { | |||
| } | |||
| int BatchToSpaceCPUKernel::ReSize() { | |||
| MS_ASSERT(in_tensors_.at(0)->shape().size() == 4); | |||
| MS_ASSERT(in_tensors_.at(0)->shape().size() == COMM_SHAPE_SIZE); | |||
| return RET_OK; | |||
| } | |||
| @@ -19,8 +19,8 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/winograd_transform.h" | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/fp32/winograd_transform.h" | |||
| #include "nnacl/base/minimal_filtering_generator.h" | |||
| #include "nnacl/fp32/conv_winograd_fp32.h" | |||
| #include "src/runtime/kernel/arm/base/convolution_base.h" | |||
| @@ -28,7 +28,6 @@ using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Reverse; | |||
| namespace mindspore::kernel { | |||
| int ReverseCPUKernel::Stride(int index) { | |||
| int stride = 1; | |||
| for (size_t i = index + 1; i < in_tensors_.at(0)->shape().size(); ++i) { | |||
| @@ -39,7 +38,7 @@ int ReverseCPUKernel::Stride(int index) { | |||
| int ReverseCPUKernel::ReSize() { | |||
| data_size_ = in_tensors_.at(0)->ElementsNum(); | |||
| thread_sz_count_ = MSMIN(thread_count_, data_size_); | |||
| thread_sz_count_ = MSMIN(op_parameter_->thread_num_, data_size_); | |||
| thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); | |||
| auto *param = reinterpret_cast<ReverseParameter *>(op_parameter_); | |||
| @@ -18,11 +18,8 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "include/context.h" | |||
| #define REVERSE_STRIDE_MAX_SIZE 4 | |||
| using mindspore::lite::InnerContext; | |||
| namespace mindspore::kernel { | |||
| @@ -31,7 +28,7 @@ class ReverseCPUKernel : public LiteKernel { | |||
| ReverseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~ReverseCPUKernel() { | |||
| if (tmp_ != nullptr) { | |||
| free(tmp_); | |||
| @@ -49,10 +46,9 @@ class ReverseCPUKernel : public LiteKernel { | |||
| int thread_sz_count_ = 0; | |||
| int thread_sz_stride_ = 0; | |||
| int data_size_ = 0; | |||
| int strides_[REVERSE_STRIDE_MAX_SIZE] = {0}; | |||
| int inCount_[REVERSE_STRIDE_MAX_SIZE] = {0}; | |||
| int outCount_[REVERSE_STRIDE_MAX_SIZE] = {0}; | |||
| int thread_count_ = 1; | |||
| int strides_[COMM_SHAPE_SIZE] = {0}; | |||
| int inCount_[COMM_SHAPE_SIZE] = {0}; | |||
| int outCount_[COMM_SHAPE_SIZE] = {0}; | |||
| int *tmp_ = nullptr; | |||
| float *in_ptr_ = nullptr; | |||
| float *out_ptr_ = nullptr; | |||
| @@ -128,7 +128,7 @@ int TransposeCPUKernel::Run() { | |||
| for (int i = 0; i < input_perm->ElementsNum(); ++i) { | |||
| param->perm_[i] = perm_data[i]; | |||
| } | |||
| for (int i = input_perm->ElementsNum(); i < 8; ++i) { | |||
| for (int i = input_perm->ElementsNum(); i < MAX_SHAPE_SIZE; ++i) { | |||
| param->perm_[i] = 0; | |||
| } | |||
| } | |||
| @@ -19,8 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/winograd_transform.h" | |||
| #include "nnacl/fp32/winograd_transform.h" | |||
| #include "src/runtime/kernel/arm/base/convolution_base.h" | |||
| namespace mindspore::kernel { | |||
| @@ -77,13 +77,13 @@ void MulInt8CPUKernel::CheckIfFastImpl() { | |||
| auto in_tensor0 = in_tensors_.at(0); | |||
| auto in_tensor1 = in_tensors_.at(1); | |||
| if (in_tensor0->ElementsNum() != in_tensor1->ElementsNum()) { | |||
| if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 4) { | |||
| if (in_tensor0->shape().size() == COMM_SHAPE_SIZE && in_tensor1->shape().size() == COMM_SHAPE_SIZE) { | |||
| CheckSameShapeSize(in_tensor0->shape(), in_tensor1->shape()); | |||
| } else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == 4) { | |||
| } else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == COMM_SHAPE_SIZE) { | |||
| if (in_tensor0->ElementsNum() == in_tensor1->shape()[3]) { | |||
| fast_hw_broadcast_ = true; | |||
| } | |||
| } else if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 1) { | |||
| } else if (in_tensor0->shape().size() == COMM_SHAPE_SIZE && in_tensor1->shape().size() == 1) { | |||
| if (in_tensor1->ElementsNum() == in_tensor0->shape()[3]) { | |||
| fast_hw_broadcast_ = true; | |||
| input1_hw_broadcast_ = true; | |||
| @@ -43,7 +43,7 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, | |||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | |||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, | |||
| int32_t maxi, size_t per_channel) { | |||
| return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi, | |||
| return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, C8NUM), UP_ROUND(col, C8NUM), deep_4, input_sum, bias, mini, maxi, | |||
| output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel); | |||
| } | |||
| void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | |||
| @@ -67,7 +67,7 @@ if(PLATFORM_ARM32) | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "sse") | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c) | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/intrinsics/sse/*.c) | |||
| set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| set(KERNEL_OP_SRC | |||
| ${KERNEL_OP_SRC} | |||
| @@ -78,8 +78,8 @@ endif() | |||
| if("${X86_64_SIMD}" STREQUAL "avx") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2") | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c | |||
| ${LITE_DIR}/nnacl/x86_64_avx/*.c | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/intrinsics/sse/*.c | |||
| ${LITE_DIR}/nnacl/intrinsics/avx/*.c | |||
| ${LITE_DIR}/nnacl/assembly/avx/*.S) | |||
| set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| set(KERNEL_OP_SRC | |||