| @@ -19,7 +19,7 @@ if(NOT MSVC) | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2 -mfma") | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "avx512") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2 -mfma") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2 -mfma -mavx512f") | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "sse") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1") | |||
| @@ -20,6 +20,13 @@ | |||
| int Fp32Relu(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #if defined(ENABLE_AVX512) | |||
| MS_FLOAT32X16 zero_16 = MS_MOV512_F32(0.0f); | |||
| for (; i <= length - C16NUM; i += C16NUM) { | |||
| MS_ST512_F32(dst + i, MS_MAX512_F32(MS_LD512_F32(src + i), zero_16)); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f); | |||
| for (; i <= length - 8; i += 8) { | |||
| @@ -49,6 +56,16 @@ int Int32Relu(const int32_t *src, int length, int32_t *dst) { | |||
| int Fp32Relu6(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #if defined(ENABLE_AVX512) | |||
| MS_FLOAT32X16 zero_16 = MS_MOV512_F32(0.0f); | |||
| MS_FLOAT32X16 six_16 = MS_MOV512_F32(6.0f); | |||
| for (; i <= length - C16NUM; i += C16NUM) { | |||
| MS_FLOAT32X16 dst_tmp = MS_MAX512_F32(MS_LD512_F32(src + i), zero_16); | |||
| dst_tmp = MS_MIN512_F32(dst_tmp, six_16); | |||
| MS_ST512_F32(dst + i, dst_tmp); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f); | |||
| @@ -109,6 +126,14 @@ int LRelu(const float *src, int length, float *dst, float alpha) { | |||
| int Sigmoid(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #if defined(ENABLE_AVX512) | |||
| for (; i <= length - C16NUM; i += C16NUM) { | |||
| simd_exp_avx512(MS_SUB512_F32(MS_MOV512_F32(0.0f), (MS_LD512_F32(src + i))), dst + i); | |||
| MS_ST512_F32(dst + i, | |||
| MS_DIV512_F32(MS_MOV512_F32(1.0f), MS_ADD512_F32(MS_MOV512_F32(1.0f), MS_LD512_F32(dst + i)))); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| for (; i <= length - 8; i += 8) { | |||
| simd_exp_avx(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src + i))), dst + i); | |||
| @@ -145,6 +170,13 @@ float TanhOpt(float src) { | |||
| int Tanh(const float *src, int length, float *dst) { | |||
| int i = 0; | |||
| #if defined(ENABLE_AVX512) | |||
| for (; i <= length - C16NUM; i += C16NUM) { | |||
| MS_FLOAT32X16 input = MS_LD512_F32(src + i); | |||
| MS_ST512_F32(dst + i, MS_TANHX16_F32(input)); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| for (; i <= length - 8; i += 8) { | |||
| MS_FLOAT32X8 input = MS_LD256_F32(src + i); | |||
| @@ -232,6 +232,14 @@ int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *til | |||
| int ElementAdd(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_AVX512 | |||
| for (; index <= size - C16NUM; index += C16NUM) { | |||
| MS_FLOAT32X16 vin0 = MS_LD512_F32(in0 + index); | |||
| MS_FLOAT32X16 vin1 = MS_LD512_F32(in1 + index); | |||
| MS_FLOAT32X16 vout = MS_ADD512_F32(vin0, vin1); | |||
| MS_ST512_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| @@ -256,6 +264,17 @@ int ElementAdd(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_AVX512 | |||
| MS_FLOAT32X16 zeros_16 = MS_MOV512_F32(0.0f); | |||
| for (; index <= size - C16NUM; index += C16NUM) { | |||
| MS_FLOAT32X16 vin0 = MS_LD512_F32(in0 + index); | |||
| MS_FLOAT32X16 vin1 = MS_LD512_F32(in1 + index); | |||
| MS_FLOAT32X16 vout = MS_ADD512_F32(vin0, vin1); | |||
| __mmask16 mask = MS_CMP512_F32(vout, zeros_16, 30); | |||
| vout = MS_BLEND512_F32(mask, zeros_16, vout); | |||
| MS_ST512_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| @@ -285,6 +304,16 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { | |||
| int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #ifdef ENABLE_AVX512 | |||
| MS_FLOAT32X16 zeros_16 = MS_MOV512_F32(0.0f); | |||
| MS_FLOAT32X16 bounds_16 = MS_MOV512_F32(6.0f); | |||
| for (; index <= size - C16NUM; index += C16NUM) { | |||
| MS_FLOAT32X16 vin0 = MS_LD512_F32(in0 + index); | |||
| MS_FLOAT32X16 vin1 = MS_LD512_F32(in1 + index); | |||
| MS_FLOAT32X16 vout = MS_MIN512_F32(MS_MAX512_F32(MS_ADD512_F32(vin0, vin1), zeros_16), bounds_16); | |||
| MS_ST512_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); | |||
| MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); | |||
| @@ -19,8 +19,8 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/exp_parameter.h" | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| #include "nnacl/intrinsics/ms_simd_instructions.h" | |||
| #ifdef ENABLE_AVX512 | |||
| #include "nnacl/intrinsics/ms_simd_avx512_instructions.h" | |||
| #endif | |||
| #ifdef __cplusplus | |||
| @@ -55,6 +55,34 @@ static inline void simd_exp(MS_FLOAT32X4 input, float *dst) { | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX512) | |||
| static inline void simd_exp_avx512(MS_FLOAT32X16 input, float *dst) { | |||
| static MS_FLOAT32X16 maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, | |||
| 98.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f}; | |||
| static MS_FLOAT32X16 minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, | |||
| -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f}; | |||
| static MS_FLOAT32X16 param[] = { | |||
| {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, | |||
| 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, | |||
| {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, | |||
| 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, | |||
| {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, | |||
| 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, | |||
| {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, | |||
| 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, | |||
| {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, | |||
| {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}}; | |||
| input = MS_MAX512_F32(minv, MS_MIN512_F32(input, maxv)); | |||
| MS_INT32X16 integer = MS_CVT512PS_EPI32(MS_DIV512_F32(input, param[0])); | |||
| MS_FLOAT32X16 decimal = MS_SUB512_F32(input, MS_MUL512_F32(MS_CVT512EPI32_PS(integer), param[0])); | |||
| MS_INT32X16 int_exp = MS_SLLI512_EPI32(MS_ADD512_EPI32(integer, MS_MOV512_EPI32(127)), 23); | |||
| MS_FLOAT32X16 tmp = MS_MUL512_F32(decimal, (MS_ADD512_F32(param[2], MS_MUL512_F32(decimal, param[1])))); | |||
| tmp = MS_MUL512_F32(decimal, MS_ADD512_F32(param[4], MS_MUL512_F32(decimal, MS_ADD512_F32(param[3], tmp)))); | |||
| MS_FLOAT32X16 decimal_exp = MS_ADD512_F32(param[5], MS_MUL512_F32(decimal, MS_ADD512_F32(param[5], tmp))); | |||
| MS_ST512_F32(dst, MS_MUL512_F32(decimal_exp, MS_CAST512_F32_S32(int_exp))); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) { | |||
| static MS_FLOAT32X8 maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f}; | |||
| @@ -24,6 +24,15 @@ int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *til | |||
| int ElementMul(const float *in0, const float *in1, float *out, int size) { | |||
| int index = 0; | |||
| #if defined(ENABLE_AVX512) | |||
| for (; index <= size - C16NUM; index += C16NUM) { | |||
| MS_FLOAT32X16 vin0 = MS_LD512_F32(in0 + index); | |||
| MS_FLOAT32X16 vin1 = MS_LD512_F32(in1 + index); | |||
| MS_FLOAT32X16 vout = MS_MUL512_F32(vin0, vin1); | |||
| MS_ST512_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_AVX) | |||
| for (; index <= size - C8NUM; index += C8NUM) { | |||
| MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); | |||
| @@ -32,6 +41,7 @@ int ElementMul(const float *in0, const float *in1, float *out, int size) { | |||
| MS_ST256_F32(out + index, vout); | |||
| } | |||
| #endif | |||
| #if defined(ENABLE_NEON) || defined(ENABLE_SSE) | |||
| for (; index <= size - C4NUM; index += C4NUM) { | |||
| MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); | |||
| @@ -0,0 +1,89 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version C2NUM.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-C2NUM.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | |||
| #define MINDSPORE_NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | |||
| #include <math.h> | |||
| #ifdef _MSC_VER | |||
| #include <immintrin.h> | |||
| #define MS_F32X16_GETI(src, i) src.m512_f32[i] | |||
| #else | |||
| #include <x86intrin.h> | |||
| #define MS_F32X16_GETI(src, i) src[i] | |||
| #endif | |||
| #define MS_FLOAT32X16 __m512 | |||
| #define MS_INT32X16 __m512i | |||
| #define MS_LD512_F32 _mm512_loadu_ps | |||
| #define MS_LD512_EPI32(src) _mm512_loadu_si512((__m512i const *)(src)) | |||
| #define MS_ADD512_F32 _mm512_add_ps | |||
| #define MS_ADD512_EPI32 _mm512_add_epi32 | |||
| #define MS_MOV512_F32 _mm512_set1_ps | |||
| #define MS_MOV512_EPI32 _mm512_set1_epi32 | |||
| #define MS_MLA512_F32(src1, src2, src3) _mm512_add_ps(src1, _mm512_mul_ps(src2, src3)) | |||
| #define MS_ST512_F32 _mm512_storeu_ps | |||
| #define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2) | |||
| #define MS_SUB512_F32 _mm512_sub_ps | |||
| #define MS_MAX512_F32 _mm512_max_ps | |||
| #define MS_MAX512_EPI32 _mm512_max_epi32 | |||
| #define MS_MIN512_F32 _mm512_min_ps | |||
| #define MS_MIN512_EPI32 _mm512_min_epi32 | |||
| #define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2) | |||
| #define MS_MUL512_EPI32(src1, src2) _mm512_mul_epi32(src1, src2) | |||
| #define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2) | |||
| #define MS_MUL512_N_F32(src1, src2) _mm512_mul_ps(src1, _mm512_set1_ps(src2)) | |||
| #define MS_MUL512_N_EPI32(src1, src2) _mm512_mul_epi32(src1, _mm512_set1_epi32(src2)) | |||
| #define MS_DIV512_N_F32(src1, src2) _mm512_div_ps(src1, _mm512_set1_ps(src2)) | |||
| #define MS_SLLI512_EPI32(src1, src2) _mm512_slli_epi32(src1, src2) | |||
| #define MS_CVT512PS_EPI32(src) _mm512_cvttps_epi32(src) | |||
| #define MS_CVT512EPI32_PS(src) _mm512_cvtepi32_ps(src) // truncate float to int | |||
| #define MS_CMP512_F32(src1, src2, src3) _mm512_cmp_ps_mask(src1, src2, src3) | |||
| #define MS_CMPGT512_EPI32(src1, src2) _mm512_cmpgt_epi32(src1, src2) | |||
| #define MS_BLEND512_F32(src1, src2, src3) _mm512_mask_blend_ps(src1, src2, src3) | |||
| #define MS_BLEND512_EPI32(src1, src2, src3) _mm512_mask_blend_epi32(src1, src2, src3) | |||
| #define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src) | |||
| static inline MS_FLOAT32X16 MS_TANHX16_F32(MS_FLOAT32X16 src) { | |||
| static const MS_FLOAT32X16 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, | |||
| 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f}; | |||
| static const MS_FLOAT32X16 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, | |||
| 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f}; | |||
| static const MS_FLOAT32X16 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, | |||
| 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, | |||
| 135135.0f, 135135.0f, 135135.0f, 135135.0f}; | |||
| static const MS_FLOAT32X16 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, | |||
| 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f}; | |||
| static const MS_FLOAT32X16 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, | |||
| 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f}; | |||
| static const MS_FLOAT32X16 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, | |||
| 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f}; | |||
| static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, | |||
| -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; | |||
| static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, | |||
| 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; | |||
| MS_FLOAT32X16 square = MS_MUL512_F32(src, src); | |||
| MS_FLOAT32X16 a = MS_MUL512_F32( | |||
| MS_ADD512_F32(MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_ADD512_F32(square, data0), square), data1), square), | |||
| data2), | |||
| src); | |||
| MS_FLOAT32X16 b = MS_ADD512_F32( | |||
| MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(data3, square), data4), square), data5), | |||
| square), | |||
| data2); | |||
| return MS_MIN512_F32(MS_MAX512_F32(MS_DIV512_F32(a, b), neg), pos); | |||
| } | |||
| #endif | |||
| @@ -22,6 +22,11 @@ | |||
| #include <stdbool.h> | |||
| #include <string.h> | |||
| #include <limits.h> | |||
| #ifdef ENABLE_AVX512 | |||
| #include "nnacl/intrinsics/ms_simd_avx512_instructions.h" | |||
| #endif | |||
| #if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM) | |||
| #include "nnacl/intrinsics/ms_simd_instructions.h" | |||
| #endif | |||
| @@ -14,6 +14,7 @@ option(MSLITE_ENABLE_NPU "enable npu, only arm64 or arm32 support" off) | |||
| option(MSLITE_ENABLE_TRAIN "enable train" on) | |||
| option(MSLITE_ENABLE_SSE "enable SSE instruction set, only x86_64 support" off) | |||
| option(MSLITE_ENABLE_AVX "enable AVX instruction set, only x86_64 support" off) | |||
| option(MSLITE_ENABLE_AVX512 "enable AVX512 instruction set, only x86_64 support" off) | |||
| option(MSLITE_ENABLE_CONVERTER "enable converter, only x86_64 support" on) | |||
| option(MSLITE_ENABLE_TOOLS "enable tools" on) | |||
| option(MSLITE_ENABLE_TESTCASES "enable testcase" off) | |||
| @@ -57,6 +58,9 @@ endif() | |||
| if(DEFINED ENV{MSLITE_ENABLE_AVX}) | |||
| set(MSLITE_ENABLE_AVX $ENV{MSLITE_ENABLE_AVX}) | |||
| endif() | |||
| if(DEFINED ENV{MSLITE_ENABLE_AVX512}) | |||
| set(MSLITE_ENABLE_AVX512 $ENV{MSLITE_ENABLE_AVX512}) | |||
| endif() | |||
| if(DEFINED ENV{MSLITE_ENABLE_CONVERTER}) | |||
| set(MSLITE_ENABLE_CONVERTER $ENV{MSLITE_ENABLE_CONVERTER}) | |||
| endif() | |||
| @@ -178,6 +182,7 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32) | |||
| set(PLATFORM_ARM "on") | |||
| set(MSLITE_ENABLE_SSE off) | |||
| set(MSLITE_ENABLE_AVX off) | |||
| set(MSLITE_ENABLE_AVX512 off) | |||
| set(MSLITE_ENABLE_CONVERTER off) | |||
| set(MSLITE_ENABLE_RUNTIME_GLOG off) | |||
| set(MSLITE_ENABLE_RUNTIME_CONVERT off) | |||
| @@ -189,7 +194,7 @@ else() | |||
| set(MSLITE_ENABLE_NPU off) | |||
| endif() | |||
| if(MSLITE_ENABLE_SSE OR MSLITE_ENABLE_AVX OR WIN32 OR MSLITE_ENABLE_ACL) | |||
| if(MSLITE_ENABLE_SSE OR MSLITE_ENABLE_AVX OR MSLITE_ENABLE_AVX512 OR WIN32 OR MSLITE_ENABLE_ACL) | |||
| set(MSLITE_ENABLE_RUNTIME_CONVERT off) | |||
| endif() | |||
| @@ -229,6 +234,7 @@ message(STATUS "\tMSLITE_ENABLE_NPU = \t${MSLITE_ENABLE_NPU}") | |||
| message(STATUS "\tMSLITE_ENABLE_TRAIN = \t${MSLITE_ENABLE_TRAIN}") | |||
| message(STATUS "\tMSLITE_ENABLE_SSE = \t${MSLITE_ENABLE_SSE}") | |||
| message(STATUS "\tMSLITE_ENABLE_AVX = \t${MSLITE_ENABLE_AVX}") | |||
| message(STATUS "\tMSLITE_ENABLE_AVX512 = \t${MSLITE_ENABLE_AVX512}") | |||
| message(STATUS "\tMSLITE_ENABLE_CONVERTER = \t${MSLITE_ENABLE_CONVERTER}") | |||
| message(STATUS "\tMSLITE_ENABLE_TOOLS = \t${MSLITE_ENABLE_TOOLS}") | |||
| message(STATUS "\tMSLITE_ENABLE_TESTCASES = \t${MSLITE_ENABLE_TESTCASES}") | |||
| @@ -494,7 +500,16 @@ if(PLATFORM_ARM64) | |||
| endif() | |||
| if(NOT PLATFORM_ARM) | |||
| if(MSLITE_ENABLE_AVX) | |||
| if(MSLITE_ENABLE_AVX512) | |||
| set(X86_64_SIMD "avx512") | |||
| add_compile_definitions(ENABLE_SSE) | |||
| add_compile_definitions(ENABLE_AVX) | |||
| add_compile_definitions(ENABLE_AVX512) | |||
| if(NOT MSVC) | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma -mavx512f") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma -mavx512f") | |||
| endif() | |||
| elseif(MSLITE_ENABLE_AVX) | |||
| set(X86_64_SIMD "avx") | |||
| add_compile_definitions(ENABLE_SSE) | |||
| add_compile_definitions(ENABLE_AVX) | |||
| @@ -200,7 +200,6 @@ void ArithmeticCPUKernel::FreeConstTileBuff() { | |||
| input1_ptr_ = nullptr; | |||
| input1_broadcast_ = false; | |||
| } | |||
| return; | |||
| } | |||
| void ArithmeticCPUKernel::InitRunFunction(int primitive_type) { | |||