@@ -173,3 +173,4 @@ mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspo
 mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::GetWeights
 mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
+mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
@@ -2173,7 +2173,7 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
 #endif
 #endif
 
-void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row) {
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep) {
   int index = 0;
 #ifdef ENABLE_AVX512
   __m512 b_data16 = _mm512_set1_ps(b[0]);
@@ -2213,3 +2213,58 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
     c[index] = a[index] * b[0] + bias[0];
   }
 }
+
+void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k) {
+  // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
+  int m_index = 0;
+#ifdef ENABLE_AVX512
+  // block 8
+  for (; m_index <= m - C8NUM; m_index += C8NUM) {
+    int k_index = 0;
+    MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
+    MS_SET_ZERO512X8_F32(dst16_)
+    for (; k_index <= k - C16NUM; k_index += C16NUM) {
+      __m512 weight = _mm512_loadu_ps(b + k_index);
+      MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
+      MS_FMADD512X8_F32(src, weight, dst16_)
+    }
+    MS_F32X8_GETI(dst, 0) += _mm512_reduce_add_ps(dst16_1);
+    MS_F32X8_GETI(dst, 1) += _mm512_reduce_add_ps(dst16_2);
+    MS_F32X8_GETI(dst, 2) += _mm512_reduce_add_ps(dst16_3);
+    MS_F32X8_GETI(dst, 3) += _mm512_reduce_add_ps(dst16_4);
+    MS_F32X8_GETI(dst, 4) += _mm512_reduce_add_ps(dst16_5);
+    MS_F32X8_GETI(dst, 5) += _mm512_reduce_add_ps(dst16_6);
+    MS_F32X8_GETI(dst, 6) += _mm512_reduce_add_ps(dst16_7);
+    MS_F32X8_GETI(dst, 7) += _mm512_reduce_add_ps(dst16_8);
+    for (; k_index < k; k_index++) {
+      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
+      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
+      MS_F32X8_GETI(dst, 2) += b[k_index] * a[m_index * k + k_index + 2 * k];
+      MS_F32X8_GETI(dst, 3) += b[k_index] * a[m_index * k + k_index + 3 * k];
+      MS_F32X8_GETI(dst, 4) += b[k_index] * a[m_index * k + k_index + 4 * k];
+      MS_F32X8_GETI(dst, 5) += b[k_index] * a[m_index * k + k_index + 5 * k];
+      MS_F32X8_GETI(dst, 6) += b[k_index] * a[m_index * k + k_index + 6 * k];
+      MS_F32X8_GETI(dst, 7) += b[k_index] * a[m_index * k + k_index + 7 * k];
+    }
+    MS_ST256_F32(c + m_index, dst);
+  }
+#endif
+  // block 1
+  for (; m_index < m; m_index++) {
+    c[m_index] = bias[0];
+    int k_index = 0;
+#ifdef ENABLE_AVX512
+    __m512 dst1 = _mm512_setzero_ps();
+    for (; k_index <= k - C16NUM; k_index += C16NUM) {
+      __m512 weight = _mm512_loadu_ps(b + k_index);
+      __m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
+      dst1 = _mm512_fmadd_ps(weight, a1, dst1);
+    }
+    c[m_index] += _mm512_reduce_add_ps(dst1);
+#endif
+    for (; k_index < k; k_index++) {
+      c[m_index] += b[k_index] * a[m_index * k + k_index];
+    }
+  }
+}
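Everything GemmIsNotPackOptimize computes reduces to a plain GEMV with a broadcast bias: c[i] = dot(a[i, :], b) + bias[0]. The AVX512 path above merely blocks that loop by eight rows and sixteen depth elements, with scalar tails because m and k are rarely multiples of C8NUM and C16NUM. A minimal scalar sketch of the same contract (GemvRef and the test values are our illustration, not part of nnacl), usable as a correctness oracle for the vector path:

#include <stdio.h>

/* Scalar reference for the [m, k] * [k, 1] ==>> [m, 1] product above. */
static void GemvRef(const float *a, const float *b, float *c, const float *bias, int m, int k) {
  for (int i = 0; i < m; ++i) {
    float acc = bias[0];
    for (int j = 0; j < k; ++j) {
      acc += a[i * k + j] * b[j];  /* dot product of row i with the single column b */
    }
    c[i] = acc;
  }
}

int main(void) {
  const float a[2 * 3] = {1, 2, 3, 4, 5, 6};  /* m = 2, k = 3, row-major */
  const float b[3] = {1, 0, -1};
  const float bias[1] = {0.5f};
  float c[2];
  GemvRef(a, b, c, bias, 2, 3);
  printf("%g %g\n", c[0], c[1]);  /* prints -1.5 -1.5 */
  return 0;
}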
@@ -125,7 +125,9 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
 void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                 int col, int stride, int out_type);
-void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row);
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep);
+void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
 
 #ifdef __cplusplus
 }
@@ -512,7 +512,7 @@ void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, in
   for (; oc < oc_remainder_c8; oc += C8NUM) {
     const float *cur_src = src + index_batch + oc;
     float *cur_dst = dst + oc;
-    LOAD256X16_F32(r, cur_src, channel);
+    MS_LOAD256X16_F32(r, cur_src, channel);
     STORE256X16_F32(cur_dst, stride, r);
   }
   for (; oc < oc_remainder; ++oc) {
@@ -821,7 +821,7 @@ inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_
 #ifdef ENABLE_AVX
 inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
-  LOAD256X8_F32(src, src_ptr, src_stride)
+  MS_LOAD256X8_F32(src, src_ptr, src_stride)
   __m256 r1 = _mm256_unpacklo_ps(src1, src2);
   __m256 r2 = _mm256_unpackhi_ps(src1, src2);
   __m256 r3 = _mm256_unpacklo_ps(src3, src4);
@@ -35,7 +35,7 @@
 #define MS_ADD512_EPI32 _mm512_add_epi32
 #define MS_MOV512_F32 _mm512_set1_ps
 #define MS_MOV512_EPI32 _mm512_set1_epi32
-#define MS_MLA512_F32(src1, src2, src3) _mm512_add_ps(src1, _mm512_mul_ps(src2, src3))
+#define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1)
 #define MS_ST512_F32 _mm512_storeu_ps
 #define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2)
 #define MS_SUB512_F32 _mm512_sub_ps
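The MS_MLA512_F32 change above swaps a separate multiply plus add for a single fused multiply-add. Both forms compute src1 + src2 * src3 elementwise, but _mm512_fmadd_ps issues one instruction and rounds once instead of twice, which matters for both throughput and accuracy in the long dot-product loops that expand through this macro. A side-by-side sketch (helper names are ours; compile with -mavx512f):

#include <immintrin.h>

/* Old expansion: two instructions, two roundings. */
static inline __m512 mla_old(__m512 src1, __m512 src2, __m512 src3) {
  return _mm512_add_ps(src1, _mm512_mul_ps(src2, src3));
}

/* New expansion: one FMA, a single rounding of src2 * src3 + src1. */
static inline __m512 mla_new(__m512 src1, __m512 src2, __m512 src3) {
  return _mm512_fmadd_ps(src2, src3, src1);
}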
@@ -93,4 +93,51 @@ static inline MS_FLOAT32X16 MS_TANHX16_F32(MS_FLOAT32X16 src) {
   return MS_MIN512_F32(MS_MAX512_F32(MS_DIV512_F32(a, b), neg), pos);
 }
 #endif
+#define MS_LOAD512X8_F32(src, input_ptr, num)               \
+  MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr);           \
+  MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); \
+  MS_FLOAT32X16 src##5 = MS_LD512_F32(input_ptr + 4 * num); \
+  MS_FLOAT32X16 src##6 = MS_LD512_F32(input_ptr + 5 * num); \
+  MS_FLOAT32X16 src##7 = MS_LD512_F32(input_ptr + 6 * num); \
+  MS_FLOAT32X16 src##8 = MS_LD512_F32(input_ptr + 7 * num);
+
+#define MS_LOAD512X4_F32(src, input_ptr, num)               \
+  MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr);           \
+  MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num);
+
+#define MS_FMADD512X8_F32(src, weight, dst)       \
+  dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \
+  dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \
+  dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \
+  dst##4 = MS_MLA512_F32(dst##4, src##4, weight); \
+  dst##5 = MS_MLA512_F32(dst##5, src##5, weight); \
+  dst##6 = MS_MLA512_F32(dst##6, src##6, weight); \
+  dst##7 = MS_MLA512_F32(dst##7, src##7, weight); \
+  dst##8 = MS_MLA512_F32(dst##8, src##8, weight);
+
+#define MS_FMADD512X4_F32(src, weight, dst)       \
+  dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \
+  dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \
+  dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \
+  dst##4 = MS_MLA512_F32(dst##4, src##4, weight);
+
+#define MS_SET_ZERO512X8_F32(dst)               \
+  MS_FLOAT32X16 dst##1 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##2 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##3 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##4 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##5 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##6 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##7 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##8 = _mm512_setzero_ps();
+
+#define MS_SET_ZERO512X4_F32(dst)               \
+  MS_FLOAT32X16 dst##1 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##2 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##3 = _mm512_setzero_ps();   \
+  MS_FLOAT32X16 dst##4 = _mm512_setzero_ps();
 #endif  // MINDSPORE_NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
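The X8/X4 helpers above work by token pasting: MS_SET_ZERO512X8_F32(dst16_) declares dst16_1 through dst16_8, and the matching load/FMA macros reference the same prefix, which is why GemmIsNotPackOptimize can later reduce dst16_1 .. dst16_8 by name. A sketch of what the X4 trio expands to in raw intrinsics (dot4x16 is our illustrative name; compile with -mavx512f):

#include <immintrin.h>

/* Accumulates one 16-wide slice of the weight column against four
 * consecutive rows of a row-major matrix whose row stride is num floats. */
static inline void dot4x16(const float *rows, const float *w_ptr, int num, __m512 out[4]) {
  __m512 weight = _mm512_loadu_ps(w_ptr);
  __m512 acc1 = _mm512_setzero_ps();            /* MS_SET_ZERO512X4_F32(acc) */
  __m512 acc2 = _mm512_setzero_ps();
  __m512 acc3 = _mm512_setzero_ps();
  __m512 acc4 = _mm512_setzero_ps();
  __m512 src1 = _mm512_loadu_ps(rows);          /* MS_LOAD512X4_F32(src, rows, num) */
  __m512 src2 = _mm512_loadu_ps(rows + 1 * num);
  __m512 src3 = _mm512_loadu_ps(rows + 2 * num);
  __m512 src4 = _mm512_loadu_ps(rows + 3 * num);
  acc1 = _mm512_fmadd_ps(src1, weight, acc1);   /* MS_FMADD512X4_F32(src, weight, acc) */
  acc2 = _mm512_fmadd_ps(src2, weight, acc2);
  acc3 = _mm512_fmadd_ps(src3, weight, acc3);
  acc4 = _mm512_fmadd_ps(src4, weight, acc4);
  out[0] = acc1; out[1] = acc2; out[2] = acc3; out[3] = acc4;
}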
@@ -76,7 +76,7 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
   return dst;
 }
 
-#define LOAD256X8_F32(src, input_ptr, num)                 \
+#define MS_LOAD256X8_F32(src, input_ptr, num)              \
   MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
   MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
   MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
@@ -86,7 +86,7 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
   MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
   MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
 
-#define LOAD256X16_F32(src, input_ptr, num)                \
+#define MS_LOAD256X16_F32(src, input_ptr, num)             \
   MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
   MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
   MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
@@ -154,4 +154,35 @@ static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
   return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
 }
 #endif
+#define MS_FMADD256X8_F32(src, weight, dst)       \
+  dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \
+  dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \
+  dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \
+  dst##4 = MS_MLA256_F32(dst##4, src##4, weight); \
+  dst##5 = MS_MLA256_F32(dst##5, src##5, weight); \
+  dst##6 = MS_MLA256_F32(dst##6, src##6, weight); \
+  dst##7 = MS_MLA256_F32(dst##7, src##7, weight); \
+  dst##8 = MS_MLA256_F32(dst##8, src##8, weight);
+
+#define MS_SET_ZERO256X8_F32(dst)              \
+  MS_FLOAT32X8 dst##1 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##2 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##3 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##4 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##5 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##6 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##7 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##8 = _mm256_setzero_ps();
+
+#define MS_FMADD256X4_F32(src, weight, dst)       \
+  dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \
+  dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \
+  dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \
+  dst##4 = MS_MLA256_F32(dst##4, src##4, weight);
+
+#define MS_SET_ZERO256X4_F32(dst)              \
+  MS_FLOAT32X8 dst##1 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##2 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##3 = _mm256_setzero_ps();   \
+  MS_FLOAT32X8 dst##4 = _mm256_setzero_ps();
 #endif  // MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
@@ -152,4 +152,51 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
   return dst;
 }
 #endif
+#define MS_FMADD128X8_F32(src, weight, dst)     \
+  dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
+  dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
+  dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
+  dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \
+  dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \
+  dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \
+  dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \
+  dst##8 = MS_MLAQ_F32(src##8, weight, dst##8);
+
+#define MS_LOAD128X4_F32(src, input_ptr, num)            \
+  MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
+  MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num);
+
+#define MS_FMADD128X4_F32(src, weight, dst)     \
+  dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
+  dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
+  dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
+  dst##4 = MS_MLAQ_F32(src##4, weight, dst##4);
+
+#define MS_LOAD128X8_F32(src, input_ptr, num)            \
+  MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
+  MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
+  MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
+  MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
+  MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
+  MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
+
+#define MS_SET_ZERO128X8_F32(dst)            \
+  MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f);
+
+#define MS_SET_ZERO128X4_F32(dst)            \
+  MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);
 #endif  // MINDSPORE_NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
@@ -90,16 +90,6 @@ static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
   return dst;
 }
 
-#define LOAD128X8_F32(src, input_ptr, num)               \
-  MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
-  MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
-  MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
-  MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
-  MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
-  MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
-  MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
-  MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
-
 #define STORE128X8_F32(output_ptr, num, dst) \
   MS_STQ_F32(output_ptr + 0 * num, dst##1);  \
   MS_STQ_F32(output_ptr + 1 * num, dst##2);  \
@@ -137,4 +127,51 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
   return dst;
 }
 #endif
+#define MS_FMADD128X8_F32(src, weight, dst)     \
+  dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
+  dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
+  dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
+  dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \
+  dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \
+  dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \
+  dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \
+  dst##8 = MS_MLAQ_F32(src##8, weight, dst##8);
+
+#define MS_LOAD128X4_F32(src, input_ptr, num)            \
+  MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
+  MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num);
+
+#define MS_FMADD128X4_F32(src, weight, dst)     \
+  dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
+  dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
+  dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
+  dst##4 = MS_MLAQ_F32(src##4, weight, dst##4);
+
+#define MS_LOAD128X8_F32(src, input_ptr, num)            \
+  MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
+  MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
+  MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
+  MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
+  MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
+  MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
+
+#define MS_SET_ZERO128X8_F32(dst)            \
+  MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f);
+
+#define MS_SET_ZERO128X4_F32(dst)            \
+  MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f);   \
+  MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);
 #endif  // MINDSPORE_NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
@@ -37,10 +37,8 @@ int MatmulRun(const void *cdata, int task_id, float, float) {
 }
 
 MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
-  if (is_pack_) {
-    FreeResizeBufA();
-    FreeResizeBufB();
-  }
+  FreeResizeBufA();
+  FreeResizeBufB();
   if (is_pack_ && out_need_aligned_ && oc_res_ != 0 && output_data_ != nullptr) {
     free(output_data_);
     output_data_ = nullptr;
@@ -250,7 +248,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
     const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
     const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
     float *c = output_data_ + index * params_->row_ * params_->col_;
-    GemmIsNotPack(a, b, c, &bias, params_->row_);
+    gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
   }
   return RET_OK;
 }
@@ -294,7 +292,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
   return RET_OK;
 }
 
-void MatmulFp32BaseCPUKernel::init_global_variable() {
+int MatmulFp32BaseCPUKernel::init_global_variable() {
 #ifdef ENABLE_AVX512
   matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
   matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col64Major : RowMajor2Row64Major;
@@ -335,18 +333,27 @@ void MatmulFp32BaseCPUKernel::init_global_variable() {
   // need not aligned
   col_step_ = params_->col_;
 #endif
+  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
+  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
+  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
+  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
+  if (params_->col_ == 1 && params_->b_const_) {
+    is_pack_ = false;
+    matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_;
+    matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_;
+    matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
+    matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
+  } else {
+    matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
+    matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
+  }
+  return RET_OK;
 }
 
 int MatmulFp32BaseCPUKernel::Prepare() {
   CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
   init_global_variable();
-  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
-  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
-  matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
-  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
-  MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
-  matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
   if (matrix_a_pack_size_ < 0 || matrix_b_pack_size_ < 0) {
     MS_LOG(ERROR) << "Matrix pack size is negative "
                   << "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size=" << matrix_b_pack_size_;
@@ -358,6 +365,8 @@ int MatmulFp32BaseCPUKernel::Prepare() {
     return ret;
   }
   if (params_->a_const_) {
+    auto a_tensor = in_tensors_[0];
+    CHECK_NULL_RETURN(a_tensor);
     if (InitBufferA() != RET_OK) {
       return RET_ERROR;
     }
@@ -394,10 +403,6 @@ int MatmulFp32BaseCPUKernel::ReSize() {
     set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
   }
   GetThreadCuttingPolicy();
-  if (params_->col_ == 1 && params_->deep_ == 1) {
-    is_pack_ = false;
-    parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch;
-  }
   auto ret = InitTmpOutBuffer();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "InitTmpOutBuffer error!";
@@ -438,7 +443,7 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
 }
 
 void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
-  if (params_->batch >= op_parameter_->thread_num_) {
+  if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && params_->b_const_)) {
     thread_count_ = op_parameter_->thread_num_;
     batch_stride_ = UP_DIV(params_->batch, thread_count_);
     batch_split_ = true;
@@ -453,6 +458,15 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
     batch_split_ = false;
     parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC;
   }
+  if (params_->col_ == 1 && params_->b_const_) {
+    is_pack_ = false;
+    parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch;
+    if (params_->deep_ == 1) {
+      gemmIsNotPackFun = GemmIsNotPack;
+    } else {
+      gemmIsNotPackFun = GemmIsNotPackOptimize;
+    }
+  }
 }
 
 int MatmulFp32BaseCPUKernel::Run() {
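Together with GemmIsNotPackOptimize above and the GemmIsNotPackFun typedef in the header below, this hunk turns the no-pack path into a small strategy dispatch: a constant [deep, 1] right-hand side skips packing entirely, and the kernel picks the broadcast variant when deep_ == 1 (each output is a[i] * b[0] + bias[0]) or the full dot-product variant otherwise. A standalone restatement of the selection (SelectGemvKernel is our illustrative name; the two kernels are the ones declared in the patched nnacl header):

#include <stdbool.h>
#include <stddef.h>

typedef void (*GemmIsNotPackFun)(const float *a, const float *b, float *c, const float *bias, int m, int k);

void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep);
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);

/* Mirrors the branch added to GetThreadCuttingPolicy. */
static GemmIsNotPackFun SelectGemvKernel(int col, int deep, bool b_const) {
  if (col != 1 || !b_const) {
    return NULL; /* packed GEMM path keeps its own kernels */
  }
  return (deep == 1) ? GemmIsNotPack : GemmIsNotPackOptimize;
}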
@@ -517,12 +531,20 @@ int MatmulFp32BaseCPUKernel::Run() {
     PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
   }
   if (!params_->a_const_) {
-    FreeResizeBufA();
+    if (is_pack_) {
+      FreeResizeBufA();
+    } else {
+      a_pack_ptr_ = nullptr;
+    }
   }
   if (!params_->b_const_) {
-    FreeResizeBufB();
+    if (is_pack_) {
+      FreeResizeBufB();
+    } else {
+      b_pack_ptr_ = nullptr;
+    }
   }
   return RET_OK;
 }
 }  // namespace mindspore::kernel
@@ -33,6 +33,8 @@ using GemmFun = void (*)(const float *a, const float *b, float *c, const float *
                          const int depth, const int cur_col, const int col_align, const int row);
 using GemvFun = void (*)(const float *a, const float *b, float *c, const float *bias, const int act_type,
                          const int depth, const int cur_col, const int col_align);
+using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k);
+
 class MatmulFp32BaseCPUKernel : public InnerKernel {
  public:
   MatmulFp32BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -62,7 +64,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   void FreeBiasBuf();
   int InitBiasData();
   void InitParameter();
-  void init_global_variable();
+  int init_global_variable();
 
  private:
   void ResizeParameter();
@@ -105,6 +107,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   GemmFun gemmCalFun = nullptr;
   GemvFun gemvCalFun = nullptr;
 #endif
+  GemmIsNotPackFun gemmIsNotPackFun = nullptr;
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_