@@ -83,6 +83,14 @@ void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bia
}

// Dispatch table and row-tile limits: kernel[i][j] computes a j-row by
// (i + 1) * 16-column output tile; max_shape[i] is the largest row count
// available at that column width (12 rows at 16/32 columns, 8 at 48, 6 at 64).
GemmAvx512Kernel kernel[C4NUM][C13NUM];
int max_shape[C4NUM] = {C12NUM, C12NUM, C8NUM, C6NUM};

#ifdef ENABLE_DEBUG
// Debug builds route every tile shape through the generic C reference kernel.
for (int i = 0; i < C4NUM; i++) {
  for (int j = 0; j < C13NUM; j++) {
    kernel[i][j] = GemmRowxColKernelFp32;
  }
}
#else
kernel[0][1] = nnacl_gemm_avx512_1x16_kernel_nhwc_fp32;
kernel[0][2] = nnacl_gemm_avx512_2x16_kernel_nhwc_fp32;
kernel[0][3] = nnacl_gemm_avx512_3x16_kernel_nhwc_fp32;

@@ -124,6 +132,8 @@ void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bia
kernel[3][4] = nnacl_gemm_avx512_4x64_kernel_nhwc_fp32;
kernel[3][5] = nnacl_gemm_avx512_5x64_kernel_nhwc_fp32;
kernel[3][6] = nnacl_gemm_avx512_6x64_kernel_nhwc_fp32;
#endif
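
// NOTE (added): a sketch of how this table is presumably consumed by the code
// below, which this hunk does not show. Columns are tiled in 16-wide blocks,
// up to C4NUM of them at a time, and the row tile is capped by max_shape, so
// the selection would look roughly like (names illustrative, not from the
// patch):
//   int row_block = MSMIN(max_shape[col_block - 1], row_remaining);
//   kernel[col_block - 1][row_block](/* args elided: signature not shown */);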

int inc_flag;
for (int k = 0; k < depth; k += k_block) {
  if (depth - k <= k_block) {
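
// NOTE (added): inc_flag presumably packs "first depth slice" and "last depth
// slice" bits so the micro-kernel initializes the accumulators (from bias or
// zero) only on the first k block and applies the activation and final store
// only on the last one. An illustrative encoding consistent with this loop:
//   inc_flag = ((k == 0) ? 0x1 : 0) | ((depth - k <= k_block) ? 0x2 : 0);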

@@ -162,8 +172,13 @@ void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *
if (act_type == ActType_Relu || act_type == ActType_Relu6) {
  // Presumably bit 1 of act_flag asks the kernels for the ReLU lower clamp;
  // a separate bit would mark the ReLU6 upper clamp.
  act_flag += C2NUM;
}

#ifdef ENABLE_DEBUG
GemmAvx512Kernel kernel[C4NUM] = {GemmRowxColKernelFp32, GemmRowxColKernelFp32, GemmRowxColKernelFp32,
                                  GemmRowxColKernelFp32};
#else
GemmAvx512Kernel kernel[C4NUM] = {nnacl_gemm_avx512_1x16_kernel_nhwc_fp32, nnacl_gemm_avx512_1x32_kernel_nhwc_fp32,
                                  nnacl_gemm_avx512_1x48_kernel_nhwc_fp32, nnacl_gemm_avx512_1x64_kernel_nhwc_fp32};
#endif
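
// NOTE (added): the matrix-vector path produces a single output row, so a 1-D
// table indexed by the number of 16-wide column blocks suffices: kernel[i]
// presumably computes a 1 x (16 * (i + 1)) strip, selected roughly as
//   kernel[col_block - 1](/* args elided */);  // names illustrative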

int inc_flag;
for (int k = 0; k < depth; k += k_block) {
  if (depth - k <= k_block) {