From: @lzkcode Reviewed-by: @zhang_xue_tong Signed-off-by:tags/v1.1.0
| @@ -43,7 +43,7 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p | |||
| void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, | |||
| size_t plane_size, size_t stride, size_t relu_type) { | |||
| #ifndef ENABLE_ARM | |||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) | |||
| PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM); | |||
| #else | |||
| size_t oc8mod = output_channel % C8NUM; | |||
| @@ -34,7 +34,7 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ | |||
| int out_channel = conv_param->output_channel_; | |||
| int thread_count = conv_param->thread_num_; | |||
| int output_count = out_h * out_w; | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| const int cal_num = C4NUM; | |||
| #else | |||
| const int cal_num = C12NUM; | |||
| @@ -58,7 +58,7 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ | |||
| int out_offset = thread_id * cal_num * out_channel + out_batch_offset; | |||
| float *gemm_output = output_data + out_offset; | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||
| #else | |||
| RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||
| @@ -112,7 +112,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const | |||
| float *dst_ptr = gemm_out + task_id * gemm_out_offset; | |||
| float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | |||
| for (int i = 0; i < input_unit_square; ++i) { | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); | |||
| #else | |||
| RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); | |||
| @@ -41,7 +41,7 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds | |||
| size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; | |||
| size_t output_plane = conv_param->output_w_ * conv_param->output_h_; | |||
| int oc8 = UP_ROUND(output_channel, C8NUM); | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| const int tile_num = 4; | |||
| #else | |||
| const int tile_num = 12; | |||
| @@ -190,6 +190,62 @@ void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||
| : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); | |||
| #elif ENABLE_X86_64_SSE | |||
| __m128 src1 = _mm_loadu_ps(src_c); | |||
| __m128 src2 = _mm_loadu_ps(src_c + col); | |||
| __m128 src3 = _mm_loadu_ps(src_c + 2 * col); | |||
| __m128 src4 = _mm_loadu_ps(src_c + 3 * col); | |||
| src_c += 4 * col; | |||
| __m128 src12L = _mm_unpacklo_ps(src1, src2); | |||
| __m128 src12H = _mm_unpackhi_ps(src1, src2); | |||
| __m128 src34L = _mm_unpacklo_ps(src3, src4); | |||
| __m128 src34H = _mm_unpackhi_ps(src3, src4); | |||
| __m128 dst0 = _mm_movelh_ps(src12L, src34L); | |||
| __m128 dst3 = _mm_movehl_ps(src34L, src12L); | |||
| __m128 dst6 = _mm_movelh_ps(src12H, src34H); | |||
| __m128 dst9 = _mm_movehl_ps(src34H, src12H); | |||
| __m128 src5 = _mm_loadu_ps(src_c); | |||
| __m128 src6 = _mm_loadu_ps(src_c + col); | |||
| __m128 src7 = _mm_loadu_ps(src_c + 2 * col); | |||
| __m128 src8 = _mm_loadu_ps(src_c + 3 * col); | |||
| src_c += 4 * col; | |||
| __m128 src56L = _mm_unpacklo_ps(src5, src6); | |||
| __m128 src56H = _mm_unpackhi_ps(src5, src6); | |||
| __m128 src78L = _mm_unpacklo_ps(src7, src8); | |||
| __m128 src78H = _mm_unpackhi_ps(src7, src8); | |||
| __m128 dst1 = _mm_movelh_ps(src56L, src78L); | |||
| __m128 dst4 = _mm_movehl_ps(src78L, src56L); | |||
| __m128 dst7 = _mm_movelh_ps(src56H, src78H); | |||
| __m128 dst10 = _mm_movehl_ps(src78H, src56H); | |||
| __m128 src9 = _mm_loadu_ps(src_c); | |||
| __m128 src10 = _mm_loadu_ps(src_c + col); | |||
| __m128 src11 = _mm_loadu_ps(src_c + 2 * col); | |||
| __m128 src12 = _mm_loadu_ps(src_c + 3 * col); | |||
| src_c += 4 * col; | |||
| __m128 src910L = _mm_unpacklo_ps(src9, src10); | |||
| __m128 src910H = _mm_unpackhi_ps(src9, src10); | |||
| __m128 src1112L = _mm_unpacklo_ps(src11, src12); | |||
| __m128 src1112H = _mm_unpackhi_ps(src11, src12); | |||
| __m128 dst2 = _mm_movelh_ps(src910L, src1112L); | |||
| __m128 dst5 = _mm_movehl_ps(src1112L, src910L); | |||
| __m128 dst8 = _mm_movelh_ps(src910H, src1112H); | |||
| __m128 dst11 = _mm_movehl_ps(src1112H, src910H); | |||
| _mm_storeu_ps(dst_c, dst0); | |||
| _mm_storeu_ps(dst_c + 4, dst1); | |||
| _mm_storeu_ps(dst_c + 8, dst2); | |||
| _mm_storeu_ps(dst_c + 12, dst3); | |||
| _mm_storeu_ps(dst_c + 16, dst4); | |||
| _mm_storeu_ps(dst_c + 20, dst5); | |||
| _mm_storeu_ps(dst_c + 24, dst6); | |||
| _mm_storeu_ps(dst_c + 28, dst7); | |||
| _mm_storeu_ps(dst_c + 32, dst8); | |||
| _mm_storeu_ps(dst_c + 36, dst9); | |||
| _mm_storeu_ps(dst_c + 40, dst10); | |||
| _mm_storeu_ps(dst_c + 44, dst11); | |||
| #else | |||
| for (int tr = 0; tr < C12NUM; tr++) { | |||
| for (int tc = 0; tc < C4NUM; tc++) { | |||
| @@ -365,6 +421,35 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||
| : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); | |||
| #elif ENABLE_X86_64_SSE | |||
| /* 8x4 row-major to col-major */ | |||
| __m128 src1 = _mm_loadu_ps(src_c); | |||
| __m128 src2 = _mm_loadu_ps(src_c + col); | |||
| __m128 src3 = _mm_loadu_ps(src_c + 2 * col); | |||
| __m128 src4 = _mm_loadu_ps(src_c + 3 * col); | |||
| src_c += 4 * col; | |||
| __m128 src12L = _mm_unpacklo_ps(src1, src2); // x5 | |||
| __m128 src12H = _mm_unpackhi_ps(src1, src2); // x1 | |||
| __m128 src34L = _mm_unpacklo_ps(src3, src4); // x | |||
| __m128 src34H = _mm_unpackhi_ps(src3, src4); | |||
| _mm_storeu_ps(dst_c, _mm_movelh_ps(src12L, src34L)); | |||
| _mm_storeu_ps(dst_c + 8, _mm_movehl_ps(src34L, src12L)); | |||
| _mm_storeu_ps(dst_c + 16, _mm_movelh_ps(src12H, src34H)); | |||
| _mm_storeu_ps(dst_c + 24, _mm_movehl_ps(src34H, src12H)); | |||
| __m128 src5 = _mm_loadu_ps(src_c); | |||
| __m128 src6 = _mm_loadu_ps(src_c + col); | |||
| __m128 src7 = _mm_loadu_ps(src_c + 2 * col); | |||
| __m128 src8 = _mm_loadu_ps(src_c + 3 * col); | |||
| src_c += 4 * col; | |||
| __m128 src56L = _mm_unpacklo_ps(src5, src6); | |||
| __m128 src56H = _mm_unpackhi_ps(src5, src6); | |||
| __m128 src78L = _mm_unpacklo_ps(src7, src8); | |||
| __m128 src78H = _mm_unpackhi_ps(src7, src8); | |||
| _mm_storeu_ps(dst_c + 4, _mm_movelh_ps(src56L, src78L)); | |||
| _mm_storeu_ps(dst_c + 12, _mm_movehl_ps(src78L, src56L)); | |||
| _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H)); | |||
| _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H)); | |||
| #else | |||
| for (int tr = 0; tr < 8; tr++) { | |||
| for (int tc = 0; tc < 4; tc++) { | |||
| @@ -434,6 +519,26 @@ void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||
| : "r10", "r12", "q0", "q1", "q2", "q3"); | |||
| #elif ENABLE_X86_64_SSE | |||
| __m128 src1 = _mm_loadu_ps(src_c); | |||
| __m128 src2 = _mm_loadu_ps(src_c + col); | |||
| __m128 src3 = _mm_loadu_ps(src_c + 2 * col); | |||
| __m128 src4 = _mm_loadu_ps(src_c + 3 * col); | |||
| src_c += 4 * col; | |||
| __m128 src12L = _mm_unpacklo_ps(src1, src2); | |||
| __m128 src12H = _mm_unpackhi_ps(src1, src2); | |||
| __m128 src34L = _mm_unpacklo_ps(src3, src4); | |||
| __m128 src34H = _mm_unpackhi_ps(src3, src4); | |||
| __m128 dst0 = _mm_movelh_ps(src12L, src34L); | |||
| __m128 dst1 = _mm_movehl_ps(src34L, src12L); | |||
| __m128 dst2 = _mm_movelh_ps(src12H, src34H); | |||
| __m128 dst3 = _mm_movehl_ps(src34H, src12H); | |||
| _mm_storeu_ps(dst_c, dst0); | |||
| _mm_storeu_ps(dst_c + 4, dst1); | |||
| _mm_storeu_ps(dst_c + 8, dst2); | |||
| _mm_storeu_ps(dst_c + 12, dst3); | |||
| #else | |||
| for (int tr = 0; tr < C4NUM; tr++) { | |||
| for (int tc = 0; tc < C4NUM; tc++) { | |||
| @@ -565,6 +670,12 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT | |||
| } else { | |||
| MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||
| } | |||
| #elif ENABLE_X86_64_SSE | |||
| if (out_type == OutType_C8) { | |||
| MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); | |||
| } else { | |||
| MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||
| } | |||
| #else | |||
| MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); | |||
| #endif | |||
| @@ -47,6 +47,11 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi | |||
| int col, int stride, size_t writeNhwc, size_t WriteWino); | |||
| void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode); | |||
| #elif ENABLE_X86_64_SSE | |||
| void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, size_t writeNhwc, size_t WriteWino); | |||
| void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode); | |||
| #endif | |||
| #ifdef ENABLE_NNACL_INFER_SHAPE | |||
| @@ -221,35 +221,23 @@ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *ma | |||
| return NNACL_OK; | |||
| } | |||
| #ifdef ENABLE_ARM | |||
| void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b, float32x4_t *matrix_c, | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | |||
| const float *bias, int m, int k, int n) { | |||
| if (bias == NULL) { | |||
| int count = 0; | |||
| for (int h = 0; h < m; h++) { | |||
| int h_offset = h * k; | |||
| for (int w = 0; w < n; w++) { | |||
| float32x4_t res = vmovq_n_f32(0); | |||
| for (int i = 0; i < k; i++) { | |||
| res = vmlaq_f32(res, matrix_a[h_offset + i], matrix_b[w + i * n]); | |||
| } | |||
| matrix_c[count] = res; | |||
| count++; | |||
| } | |||
| } | |||
| } else { | |||
| int count = 0; | |||
| float32x4_t bias_ptr = vld1q_f32(bias); | |||
| for (int h = 0; h < m; h++) { | |||
| int h_offset = h * k; | |||
| for (int w = 0; w < n; w++) { | |||
| float32x4_t res = vmovq_n_f32(0); | |||
| for (int i = 0; i < k; i++) { | |||
| res = vmlaq_f32(res, matrix_a[h_offset + i], matrix_b[w + i * n]); | |||
| } | |||
| matrix_c[count] = vaddq_f32(res, bias_ptr); | |||
| count++; | |||
| int count = 0; | |||
| MS_FLOAT32X4 bias_ptr = MS_MOVQ_F32(0); | |||
| if (bias != NULL) { | |||
| bias_ptr = MS_LDQ_F32(bias); | |||
| } | |||
| for (int h = 0; h < m; h++) { | |||
| int h_offset = h * k; | |||
| for (int w = 0; w < n; w++) { | |||
| MS_FLOAT32X4 res = MS_MOVQ_F32(0); | |||
| for (int i = 0; i < k; i++) { | |||
| res = MS_MLAQ_F32(res, matrix_a[h_offset + i], matrix_b[w + i * n]); | |||
| } | |||
| matrix_c[count] = MS_ADDQ_F32(res, bias_ptr); | |||
| count++; | |||
| } | |||
| } | |||
| } | |||
| @@ -52,8 +52,8 @@ void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float * | |||
| int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, | |||
| int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack); | |||
| #ifdef ENABLE_ARM | |||
| void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b, float32x4_t *matrix_c, | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | |||
| const float *bias, int m, int k, int n); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| @@ -17,6 +17,14 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_OP_BASE_H_ | |||
| #define MINDSPORE_LITE_NNACL_OP_BASE_H_ | |||
| #ifdef ENABLE_ARM | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #ifdef ENABLE_X86_64_SSE | |||
| #include <nmmintrin.h> | |||
| #endif | |||
| #include <stdint.h> | |||
| #include <stdlib.h> | |||
| #include <stdbool.h> | |||
| @@ -70,4 +78,30 @@ typedef struct OpParameter { | |||
| typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType; | |||
| #ifdef ENABLE_ARM | |||
| #define MS_FLOAT32X4 float32x4_t | |||
| #define MS_LDQ_F32 vld1q_f32 | |||
| #define MS_ADDQ_F32 vaddq_f32 | |||
| #define MS_MOVQ_F32 vmovq_n_f32 | |||
| #define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32. | |||
| #define MS_SUBQ_F32 vsubq_f32 | |||
| #define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3) | |||
| #define MS_STQ_F32 vst1q_f32 | |||
| #define MS_MAXQ_F32 vmaxq_f32 | |||
| #define MS_MINQ_F32 vminq_f32 | |||
| #define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2) | |||
| #elif defined(ENABLE_X86_64_SSE) | |||
| #define MS_FLOAT32X4 __m128 | |||
| #define MS_LDQ_F32 _mm_loadu_ps | |||
| #define MS_ADDQ_F32 _mm_add_ps | |||
| #define MS_MOVQ_F32 _mm_set_ps1 | |||
| #define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32. | |||
| #define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3)) | |||
| #define MS_STQ_F32 _mm_storeu_ps | |||
| #define MS_SUBQ_F32 _mm_sub_ps | |||
| #define MS_MAXQ_F32 _mm_max_ps | |||
| #define MS_MINQ_F32 _mm_min_ps | |||
| #define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, _mm_set_ps1(src2)) | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_OP_BASE_H_ | |||
| @@ -79,21 +79,21 @@ void GeneralInputTransformUnit(const float *src_data, float *dst_data, const flo | |||
| int src_step, int dst_step, int in_unit) { | |||
| int len = in_unit * in_unit; | |||
| if (len > MAX_LEN) return; | |||
| #ifdef ENABLE_ARM | |||
| float32x4_t src[MAX_LEN]; | |||
| float32x4_t t[MAX_LEN]; | |||
| float32x4_t m[MAX_LEN]; | |||
| float32x4_t vec_b[MAX_LEN]; | |||
| float32x4_t vec_bt[MAX_LEN]; | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| MS_FLOAT32X4 src[MAX_LEN]; | |||
| MS_FLOAT32X4 t[MAX_LEN]; | |||
| MS_FLOAT32X4 m[MAX_LEN]; | |||
| MS_FLOAT32X4 vec_b[MAX_LEN]; | |||
| MS_FLOAT32X4 vec_bt[MAX_LEN]; | |||
| for (int i = 0; i < len; i++) { | |||
| src[i] = vld1q_f32(src_data + i * src_step); | |||
| vec_b[i] = vdupq_n_f32(matrix_b[i]); | |||
| vec_bt[i] = vdupq_n_f32(matrix_bt[i]); | |||
| src[i] = MS_LDQ_F32(src_data + i * src_step); | |||
| vec_b[i] = MS_MOVQ_F32(matrix_b[i]); | |||
| vec_bt[i] = MS_MOVQ_F32(matrix_bt[i]); | |||
| } | |||
| MatrixMultiplyVec(vec_bt, src, t, NULL, in_unit, in_unit, in_unit); | |||
| MatrixMultiplyVec(t, vec_b, m, NULL, in_unit, in_unit, in_unit); | |||
| for (int i = 0; i < len; i++) { | |||
| vst1q_f32(dst_data + i * dst_step, m[i]); | |||
| MS_STQ_F32(dst_data + i * dst_step, m[i]); | |||
| } | |||
| #else | |||
| float src[MAX_LEN]; | |||
| @@ -98,7 +98,7 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { | |||
| int Convolution1x1CPUKernel::InitConv1x1Param() { | |||
| int hw_tile = C12NUM; | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| hw_tile = C4NUM; | |||
| #endif | |||
| if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | |||
| @@ -170,7 +170,7 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) { | |||
| float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; | |||
| float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||
| @@ -197,7 +197,7 @@ int Convolution1x1CPUKernel::Run() { | |||
| auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| pack_input_ = | |||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | |||
| #else | |||
| @@ -221,7 +221,7 @@ int Convolution1x1CPUKernel::Run() { | |||
| if (multi_thread_by_hw_) { | |||
| ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_); | |||
| } else { | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| @@ -115,7 +115,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) { | |||
| return RET_OK; | |||
| } | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_; | |||
| MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | |||
| tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_, | |||
| @@ -169,7 +169,7 @@ int DeConvolutionCPUKernel::InitRunBuf() { | |||
| return RET_NULL_PTR; | |||
| } | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| tmp_buffer_ = | |||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float))); | |||
| #else | |||
| @@ -181,7 +181,7 @@ int DeConvolutionCPUKernel::InitRunBuf() { | |||
| return RET_NULL_PTR; | |||
| } | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| pack_input_ = | |||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | |||
| #else | |||
| @@ -209,7 +209,7 @@ int DeConvolutionCPUKernel::Run() { | |||
| input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_; | |||
| output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_; | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| @@ -71,7 +71,7 @@ int FullconnectionCPUKernel::ReSize() { | |||
| memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float)); | |||
| } | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| a_pack_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_4_ * fc_param_->deep_ * sizeof(float))); | |||
| if (a_pack_ptr_ == nullptr) { | |||
| return RET_MEMORY_FAILED; | |||
| @@ -120,7 +120,7 @@ void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { | |||
| return; | |||
| } | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | |||
| @@ -65,7 +65,7 @@ int MatmulCPUKernel::MallocMatrixABuffer() { | |||
| params_->row_4_ = UP_ROUND(params_->row_, C4NUM); | |||
| params_->row_12_ = UP_ROUND(params_->row_, C12NUM); | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| a_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_4_ * params_->deep_ * sizeof(float))); | |||
| if (a_pack_ptr_ == nullptr) { | |||
| FreeTmpBuffer(); | |||
| @@ -176,7 +176,7 @@ void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { | |||
| for (int i = 0; i < params_->batch; i++) { | |||
| float *src = src_ptr + i * params_->deep_ * params_->row_; | |||
| #ifdef ENABLE_ARM32 | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| float *dst = dst_ptr + i * params_->deep_ * params_->row_4_; | |||
| if (params_->a_transpose_) { | |||
| RowMajor2Row4Major(src, dst, params_->deep_, params_->row_); | |||