GitOrigin-RevId: 0f64b9f70f
tags/v1.0.0-rc1
@@ -214,8 +214,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x16::usable(
kern_size_param.B_type == dtype::Float32() &&
kern_size_param.A_type == dtype::Float32() &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace(
@@ -330,8 +329,7 @@ bool MatrixMulImpl::AlgoF16MK8_8x8::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float16() &&
kern_size_param.format == param::MatrixMul::Format::MK8 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace(
@@ -918,8 +916,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable(
kern_size_param.B_type == dtype::Int16() &&
kern_size_param.A_type == dtype::Int16() &&
kern_size_param.format == param::MatrixMul::Format::MK8 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace(
@@ -21,6 +21,76 @@ using namespace aarch64::matmul;
namespace {
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
LDB *= sizeof(dt_float16);
asm volatile(
".arch armv8.2-a+fp16\n"
"subs %w[K], %w[K], #8\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n"
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n"
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"fmla v24.8h, v16.8h, v0.h[0]\n"
"fmla v25.8h, v17.8h, v0.h[1]\n"
"fmla v26.8h, v18.8h, v0.h[2]\n"
"fmla v27.8h, v19.8h, v0.h[3]\n"
"beq 2f\n"
"1:\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n"
"fmla v28.8h, v20.8h, v0.h[4]\n"
"fmla v29.8h, v21.8h, v0.h[5]\n"
"fmla v30.8h, v22.8h, v0.h[6]\n"
"fmla v31.8h, v23.8h, v0.h[7]\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n"
"fmla v24.8h, v16.8h, v0.h[0]\n"
"fmla v25.8h, v17.8h, v0.h[1]\n"
"fmla v26.8h, v18.8h, v0.h[2]\n"
"fmla v27.8h, v19.8h, v0.h[3]\n"
"subs %w[K], %w[K], #8\n"
"bne 1b\n"
"2:\n"
"fmla v28.8h, v20.8h, v0.h[4]\n"
"fmla v29.8h, v21.8h, v0.h[5]\n"
"fmla v30.8h, v22.8h, v0.h[6]\n"
"fmla v31.8h, v23.8h, v0.h[7]\n"
"fadd v24.8h, v24.8h, v25.8h\n"
"fadd v26.8h, v26.8h, v27.8h\n"
"fadd v28.8h, v28.8h, v29.8h\n"
"fadd v30.8h, v30.8h, v31.8h\n"
"fadd v24.8h, v24.8h, v26.8h\n"
"fadd v28.8h, v28.8h, v30.8h\n"
"fadd v24.8h, v24.8h, v28.8h\n"
"st1 {v24.4s}, [%[output]], 16\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
}
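
For readers who do not want to decode the assembly: under the MK8 no-pack layout, each 8x8 tile of A is stored with the eight output rows contiguous for each k index, B contributes one column of eight fp16 values per K-block at a stride of LDB elements, and kern_8x1 produces a single 8-element output column. A rough scalar equivalent follows; the helper name and the float accumulator are illustrative only (the assembly accumulates in fp16 across eight separate registers and folds them with fadd at the end):

    // Hedged scalar sketch of the 8x1 MK8 micro-kernel above (assumed layout).
    template <typename T>
    void kern_8x1_reference(const T* a_ptr, const T* b_ptr, int LDB, int K,
                            T* output) {
        float acc[8] = {0};                       // v24..v31, collapsed into one sum
        for (int kb = 0; kb < K; kb += 8) {       // one iteration of the asm loop
            for (int kk = 0; kk < 8; ++kk)        // one fmla per lane of v0
                for (int row = 0; row < 8; ++row)
                    acc[row] += float(a_ptr[kk * 8 + row]) * float(b_ptr[kk]);
            a_ptr += 8 * 8;                       // next 8x8 tile of A
            b_ptr += LDB;                         // next K-block of this B column
        }
        for (int row = 0; row < 8; ++row)
            output[row] = T(acc[row]);
    }
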
// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v0-v3
@@ -416,7 +486,7 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
constexpr static size_t NB = 8;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
@@ -428,8 +498,17 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
if (N - n >= 4) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK;
output += MB * CALCBLK;
n += 4;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
@@ -20,6 +20,54 @@ using namespace aarch64::matmul;
namespace {
void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
float* output) {
LDB *= sizeof(float);
asm volatile(
"subs %w[K], %w[K], #4\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"prfm pstl1keep, [%[b_ptr]]\n"
"fmla v16.4s, v4.4s, v0.s[0]\n"
"fmla v17.4s, v5.4s, v0.s[1]\n"
"beq 2f\n"
"1:\n"
"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n"
"fmla v18.4s, v6.4s, v0.s[2]\n"
"fmla v19.4s, v7.4s, v0.s[3]\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"prfm pstl1keep, [%[b_ptr]]\n"
"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n"
"fmla v16.4s, v4.4s, v0.s[0]\n"
"fmla v17.4s, v5.4s, v0.s[1]\n"
"subs %w[K], %w[K], #4\n"
"bne 1b\n"
"2:\n"
"fmla v18.4s, v6.4s, v0.s[2]\n"
"fmla v19.4s, v7.4s, v0.s[3]\n"
"fadd v16.4s, v16.4s, v18.4s\n"
"fadd v17.4s, v17.4s, v19.4s\n"
"fadd v16.4s, v16.4s, v17.4s\n"
"st1 {v16.4s}, [%[output]], 16\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc",
"memory");
}
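
The fp32 version is the same idea at MK4 width: A is packed as (M/4, K/4, 4, 4) with the four output rows contiguous for each k, B supplies four floats per K-block at a stride of LDB elements, and the kernel keeps one partial accumulator per B lane (v16-v19) that the trailing fadd instructions fold into a single vector. A hedged scalar sketch with those partial sums written out explicitly (names are illustrative):

    // Scalar analogue of the 4x1 MK4 micro-kernel above.
    void kern_4x1_reference(const float* a_ptr, const float* b_ptr, size_t LDB,
                            size_t K, float* output) {
        float acc[4][4] = {};                    // acc[lane][row] ~ v16..v19
        for (size_t kb = 0; kb < K; kb += 4) {
            for (int lane = 0; lane < 4; ++lane)
                for (int row = 0; row < 4; ++row)
                    acc[lane][row] += a_ptr[lane * 4 + row] * b_ptr[lane];
            a_ptr += 4 * 4;                      // next 4x4 tile of A
            b_ptr += LDB;                        // next K-block of this B column
        }
        for (int row = 0; row < 4; ++row)        // the final fadd reduction
            output[row] =
                    (acc[0][row] + acc[2][row]) + (acc[1][row] + acc[3][row]);
    }
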
// Overview of register layout:
//
// A 4x4 block of A is stored in register v4-v7
@@ -117,7 +165,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v18", "v19", "cc", "memory");
}
// Overview of register layout:
@@ -535,7 +584,7 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B,
constexpr static size_t NB = 16;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4)
for (size_t m = 0; m < M; m += MB) {
@@ -547,21 +596,23 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B,
cur_B += KB * NB;
output += MB * NB;
}
switch (N - n) {
case 4:
kern_4x4(A, cur_B, LDB, K, output);
break;
case 8:
kern_4x8(A, cur_B, LDB, K, output);
break;
case 12:
kern_4x8(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK * 2;
output += MB * CALCBLK * 2;
kern_4x4(A, cur_B, LDB, K, output);
break;
default:
break;
if (N - n >= 8) {
kern_4x8(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK * 2;
output += MB * CALCBLK * 2;
n += 8;
}
if (N - n >= 4) {
kern_4x4(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK;
output += MB * CALCBLK;
n += 4;
}
while (n < N) {
kern_4x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
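
The switch it replaces could only finish tails of exactly 4, 8, or 12 columns, which is why the usable() checks above had to insist on N % 4 == 0. The new code peels the tail greedily instead: full 16-column blocks first, then at most one 8-column and one 4-column block, then single columns via kern_4x1. A standalone illustration of the resulting decomposition (the helper name is hypothetical; the real code uses the main 16-wide loop plus kern_4x8, kern_4x4 and kern_4x1 at these widths):

    #include <cstdio>
    #include <vector>

    // Column widths produced by the tail-peeling order above, for NB == 16.
    std::vector<int> split_columns(int N) {
        std::vector<int> widths;
        int n = 0;
        for (; n + 16 <= N; n += 16) widths.push_back(16);
        if (N - n >= 8) { widths.push_back(8); n += 8; }
        if (N - n >= 4) { widths.push_back(4); n += 4; }
        while (n < N) { widths.push_back(1); ++n; }
        return widths;
    }

    int main() {
        for (int w : split_columns(27)) std::printf("%d ", w);  // prints: 16 8 1 1 1
        std::printf("\n");
        return 0;
    }
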
@@ -20,6 +20,82 @@ using namespace aarch64::matmul;
namespace {
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! B is only advanced by the post-indexed load below, so LDB just needs to be
//! converted from elements to bytes.
LDB *= sizeof(dt_int16);
asm volatile(
"subs %w[K], %w[K], #8\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"smull v16.4s, v24.4h, v0.h[0]\n"
"smull2 v17.4s, v24.8h, v0.h[0]\n"
"smull v18.4s, v25.4h, v0.h[1]\n"
"smull2 v19.4s, v25.8h, v0.h[1]\n"
"smull v20.4s, v26.4h, v0.h[2]\n"
"smull2 v21.4s, v26.8h, v0.h[2]\n"
"smull v22.4s, v27.4h, v0.h[3]\n"
"smull2 v23.4s, v27.8h, v0.h[3]\n"
"beq 2f\n"
"1:\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n"
"smlal v16.4s, v28.4h, v0.h[4]\n"
"smlal2 v17.4s, v28.8h, v0.h[4]\n"
"smlal v18.4s, v29.4h, v0.h[5]\n"
"smlal2 v19.4s, v29.8h, v0.h[5]\n"
"smlal v20.4s, v30.4h, v0.h[6]\n"
"smlal2 v21.4s, v30.8h, v0.h[6]\n"
"smlal v22.4s, v31.4h, v0.h[7]\n"
"smlal2 v23.4s, v31.8h, v0.h[7]\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n"
"smlal v16.4s, v24.4h, v0.h[0]\n"
"smlal2 v17.4s, v24.8h, v0.h[0]\n"
"smlal v18.4s, v25.4h, v0.h[1]\n"
"smlal2 v19.4s, v25.8h, v0.h[1]\n"
"smlal v20.4s, v26.4h, v0.h[2]\n"
"smlal2 v21.4s, v26.8h, v0.h[2]\n"
"smlal v22.4s, v27.4h, v0.h[3]\n"
"smlal2 v23.4s, v27.8h, v0.h[3]\n"
"subs %w[K], %w[K], #8\n"
"bne 1b\n"
"2:\n"
"smlal v16.4s, v28.4h, v0.h[4]\n"
"smlal2 v17.4s, v28.8h, v0.h[4]\n"
"smlal v18.4s, v29.4h, v0.h[5]\n"
"smlal2 v19.4s, v29.8h, v0.h[5]\n"
"smlal v20.4s, v30.4h, v0.h[6]\n"
"smlal2 v21.4s, v30.8h, v0.h[6]\n"
"smlal v22.4s, v31.4h, v0.h[7]\n"
"smlal2 v23.4s, v31.8h, v0.h[7]\n"
"add v16.4s, v16.4s, v18.4s\n"
"add v20.4s, v20.4s, v22.4s\n"
"add v17.4s, v17.4s, v19.4s\n"
"add v21.4s, v21.4s, v23.4s\n"
"add v16.4s, v16.4s, v20.4s\n"
"add v17.4s, v17.4s, v21.4s\n"
"st1 {v16.4s, v17.4s}, [%[output]], 32\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}
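
The int16 flavour follows the same 8x1 pattern, but every product is widened to 32 bits: each smull/smull2 (and later smlal/smlal2) pair handles the low and high halves of an 8-lane int16 row of A, so each row vector needs two 4-lane int32 accumulators and the result is stored as eight int32 values. A hedged scalar equivalent, with the same layout assumptions as the sketches above and illustrative names:

    #include <cstdint>

    // Scalar sketch of the int16 x int16 -> int32 8x1 MK8 micro-kernel above.
    void kern_8x1_s16_reference(const int16_t* a_ptr, const int16_t* b_ptr,
                                int LDB, int K, int32_t* output) {
        int32_t acc[8] = {0};
        for (int kb = 0; kb < K; kb += 8) {
            for (int kk = 0; kk < 8; ++kk)
                for (int row = 0; row < 8; ++row)
                    acc[row] += int32_t(a_ptr[kk * 8 + row]) * int32_t(b_ptr[kk]);
            a_ptr += 8 * 8;     // next 8x8 int16 tile of A
            b_ptr += LDB;       // next K-block of this B column
        }
        for (int row = 0; row < 8; ++row)
            output[row] = acc[row];
    }
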
// Overview of register layout:
//
// A 8x1 cell of Lhs is stored in 16bit in v24-v27
@@ -636,7 +712,7 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
constexpr static size_t NB = 8;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
@@ -648,8 +724,17 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
if (N - n >= 4) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK;
output += MB * CALCBLK;
n += 4;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
@@ -390,7 +390,7 @@ void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf,
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16,
megdnn_arm_common_winograd_f16_F23_8x8, cb, __fp16, __fp16,
bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start,
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
@@ -875,8 +875,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x8::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() &&
kern_size_param.N % 4 == 0 && !kern_size_param.trA &&
!kern_size_param.trB;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace(
@@ -911,8 +910,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_4x8::usable(
kern_size_param.A_type == dtype::Int16() &&
kern_size_param.B_type == dtype::Int16() &&
kern_size_param.C_type == dtype::Int32() &&
kern_size_param.N % 4 == 0 && !kern_size_param.trA &&
!kern_size_param.trB;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace(
@@ -969,8 +967,7 @@ bool MatrixMulImpl::AlgoF16MK8_4x8::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float16() &&
kern_size_param.format == param::MatrixMul::Format::MK8 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace(
@@ -21,6 +21,66 @@ using namespace armv7::matmul;
namespace {
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
LDB = (LDB - 4) * sizeof(dt_float16);
asm volatile(
"subs %[K], #8\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vmul.f16 q12, q4, d0[0]\n"
"vmul.f16 q13, q5, d0[1]\n"
"vmul.f16 q14, q6, d0[2]\n"
"vmul.f16 q15, q7, d0[3]\n"
"beq 2f\n"
"1:\n"
"vmla.f16 q12, q8, d1[0]\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vmla.f16 q13, q9, d1[1]\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vmla.f16 q14, q10, d1[2]\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vmla.f16 q15, q11, d1[3]\n"
"vmla.f16 q12, q4, d0[0]\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vmla.f16 q13, q5, d0[1]\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vmla.f16 q14, q6, d0[2]\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vmla.f16 q15, q7, d0[3]\n"
"subs %[K], #8\n"
"bne 1b\n"
"2:\n"
"vmla.f16 q12, q8, d1[0]\n"
"vmla.f16 q13, q9, d1[1]\n"
"vmla.f16 q14, q10, d1[2]\n"
"vmla.f16 q15, q11, d1[3]\n"
"vadd.f16 q12, q12, q14\n"
"vadd.f16 q13, q13, q15\n"
"vadd.f16 q12, q12, q13\n"
"vst1.32 {d24, d25}, [%[output]]!\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15",
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24",
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory");
}
// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v4-v11
@@ -45,7 +105,7 @@ namespace {
// | v3[0-7]| |v15[0-7]|
// +--------+ +--------+--------+
// Accumulator
void kern_4x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
//! Each iteration loads 64 values from B, but the post-incrementing loads have
//! already advanced the pointer by 48 * 2 bytes, so we subtract 48 here.
@@ -179,19 +239,25 @@ void gemm_nopack_f16_4x8::kern(const dt_float16* A, size_t LDA,
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 4;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
dt_float16* output = C + (m / MB) * LDC;
const dt_float16* cur_B = B;
for (size_t n = 0; n < N; n += NB) {
kern_4x8(A, cur_B, LDB, K, output);
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
}
@@ -20,6 +20,58 @@ using namespace armv7::matmul;
namespace {
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB = (LDB - 4) * sizeof(float);
asm volatile(
"subs %[K], %[K], #4\n"
"vld1.32 {d8-d11}, [%[A]]!\n"
"vld1.32 {d12-d15}, [%[A]]!\n"
"veor q8, q8 \n"
"veor q9, q9 \n"
"veor q10, q10 \n"
"veor q11, q11 \n"
"vld1.32 {d0-d1}, [%[B]]!\n"
"vmla.f32 q8, q4, d0[0]\n"
"vmla.f32 q9, q5, d0[1]\n"
"beq 2f\n"
"1:\n"
"vld1.32 {d8-d11}, [%[A]]!\n"
"vmla.f32 q10, q6, d1[0]\n"
"vmla.f32 q11, q7, d1[1]\n"
"add %[B], %[B], %[LDB]\n"
"vld1.32 {d0-d1}, [%[B]]!\n"
"vld1.32 {d12-d15}, [%[A]]!\n"
"vmla.f32 q8, q4, d0[0]\n"
"vmla.f32 q9, q5, d0[1]\n"
"subs %[K], %[K], #4\n"
"bne 1b\n"
"2:\n"
"vmla.f32 q10, q6, d1[0]\n"
"vmla.f32 q11, q7, d1[1]\n"
"vadd.f32 q8, q8, q10\n"
"vadd.f32 q9, q9, q11\n"
"vadd.f32 q8, q8, q9\n"
"vst1.32 {d16, d17}, [%[C]]!\n"
: [ A ] "+r"(A), [ B ] "+r"(B), [ K ] "+r"(K), [ C ] "+r"(C)
: [ LDB ] "r"(LDB)
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15",
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "cc",
"memory");
}
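
A note on the stride handling in these armv7 kernels: the post-incrementing vld1 has already moved B forward by the four floats it consumed, so the explicit add only has to cover the remaining LDB - 4 elements, which is exactly what the (LDB - 4) * sizeof(float) adjustment at the top of the function encodes. In plain C the per-K-block pointer walk looks roughly like this (hypothetical helper; LDB is assumed to be the stride in elements between consecutive K-blocks of B):

    // Equivalent pointer bookkeeping for one column of B.
    const float* walk_b_column(const float* B, size_t LDB, size_t K) {
        for (size_t k = 0; k < K; k += 4) {
            B += 4;        // vld1.32 {d0-d1}, [%[B]]!  consumes 4 floats
            B += LDB - 4;  // add %[B], %[B], %[LDB]    skips to the next K-block
        }
        return B;          // net advance: exactly LDB floats per K-block
    }
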
// Overview of register layout:
//
// A 8x4 cell of Rhs is stored in 32bit in q0-q3, load 4 register each time
@@ -268,9 +320,9 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B,
constexpr size_t MB = 4;
constexpr size_t KB = 4;
constexpr size_t NB = 8;
constexpr size_t CALCBLK = 4;
constexpr size_t NB_HALF = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
@@ -282,8 +334,17 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B,
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
if (N - n >= 4) {
kern_4x4(A, cur_B, LDB, K, output);
cur_B += KB * NB_HALF;
output += MB * NB_HALF;
n += 4;
}
while (n < N) {
kern_4x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
@@ -20,6 +20,91 @@ using namespace armv7::matmul;
namespace {
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! Each K-block consumes 8 values from B, but the first vld1 has already
//! advanced the pointer by 4 * 2 bytes on its own, so we subtract 4 here.
LDB = (LDB - 4) * sizeof(dt_int16);
asm volatile(
"subs %[K], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vmull.s16 q12, d8, d0[0]\n"
"vmull.s16 q13, d9, d0[0]\n"
"vmull.s16 q14, d10, d0[1]\n"
"vmull.s16 q15, d11, d0[1]\n"
"vmlal.s16 q12, d12, d0[2]\n"
"vmlal.s16 q13, d13, d0[2]\n"
"vmlal.s16 q14, d14, d0[3]\n"
"vmlal.s16 q15, d15, d0[3]\n"
"beq 2f\n"
"1:\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vmlal.s16 q12, d16, d1[0]\n"
"vmlal.s16 q13, d17, d1[0]\n"
"vmlal.s16 q14, d18, d1[1]\n"
"vmlal.s16 q15, d19, d1[1]\n"
"vmlal.s16 q12, d20, d1[2]\n"
"vmlal.s16 q13, d21, d1[2]\n"
"vmlal.s16 q14, d22, d1[3]\n"
"vmlal.s16 q15, d23, d1[3]\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vmlal.s16 q12, d8, d0[0]\n"
"vmlal.s16 q13, d9, d0[0]\n"
"vmlal.s16 q14, d10, d0[1]\n"
"vmlal.s16 q15, d11, d0[1]\n"
"vmlal.s16 q12, d12, d0[2]\n"
"vmlal.s16 q13, d13, d0[2]\n"
"vmlal.s16 q14, d14, d0[3]\n"
"vmlal.s16 q15, d15, d0[3]\n"
"subs %[K], %[K], #8\n"
"bne 1b\n"
"2:\n"
"vmlal.s16 q12, d16, d1[0]\n"
"vmlal.s16 q13, d17, d1[0]\n"
"vmlal.s16 q14, d18, d1[1]\n"
"vmlal.s16 q15, d19, d1[1]\n"
"vmlal.s16 q12, d20, d1[2]\n"
"vmlal.s16 q13, d21, d1[2]\n"
"vmlal.s16 q14, d22, d1[3]\n"
"vmlal.s16 q15, d23, d1[3]\n"
"vadd.s32 q12, q12, q14\n"
"vadd.s32 q13, q13, q15\n"
"vst1.32 {d24, d25, d26, d27}, [%[output]]!\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15",
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24",
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory");
}
// Overview of register layout:
//
// A 4x8 cell of Rhs is stored in 16bit in q0-q3
@@ -40,7 +125,7 @@ namespace {
// | q3[0-7]| |q14[0-3]|v15[0-3]|
// +--------+ +--------+--------+
// Accumulator
void kern_4x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! Each iteration loads 16 values from B, but the post-incrementing loads have
//! already advanced the pointer by 16 * 2 bytes, so we subtract 16 here.
@@ -247,19 +332,25 @@ void gemm_nopack_s16_4x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 4;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
dt_int32* output = C + (m / MB) * LDC;
const dt_int16* cur_B = B;
for (size_t n = 0; n < N; n += NB) {
kern_4x8(A, cur_B, LDB, K, output);
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
}
@@ -427,9 +427,6 @@ public:
"The winograd remain oc is not times of OC_BLOCK_SIZE");
if (format == param::MatrixMul::Format::MK4 ||
format == param::MatrixMul::Format::MK8) {
#if !MEGDNN_X86
nr_tiles_in_unit = round_up<size_t>(nr_tiles_in_unit, 4);
#endif
megdnn_assert(nr_tiles_in_unit <= unit_tile_size,
"nr_tiles_in_unit: %zu TILE_SIZE:%zu",
nr_tiles_in_unit, unit_tile_size);
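
The deleted branch padded nr_tiles_in_unit up to a multiple of 4 on ARM builds because the MK4/MK8 no-pack matmuls used to reject N values that were not multiples of 4; with the kern_*x1 tail kernels added in this change any N is accepted, so the padding (and the work spent on up to three dummy tiles) is no longer needed. For reference, the dropped call rounded up like this (round_up is shown as a generic helper with the usual semantics, which is an assumption about the megdnn utility):

    #include <cstddef>

    // Assumed semantics of round_up(v, align): smallest multiple of align >= v.
    template <typename T>
    T round_up(T v, T align) {
        return (v + align - 1) / align * align;
    }
    // e.g. round_up<size_t>(13, 4) == 16, so 13 tiles used to be padded to 16.
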
@@ -38,10 +38,9 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) {
}
TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) {
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 4);
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 1);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -52,10 +51,9 @@ TEST_F(AARCH64, MATRIX_MUL_F16_K8X24X1) {
}
TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul(
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 4);
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 1);
}
#endif
@@ -116,10 +114,9 @@ TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) {
}
TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) {
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
handle(), "AARCH64_INT16X16X32_MK8_8X8",
param::MatrixMul::Format::MK8, 4);
param::MatrixMul::Format::MK8, 1);
}
//! FIXME: need to add tests of GEMV and QUINT8
@@ -26,7 +26,7 @@ TEST_F(ARMV7, MATRIX_MUL) {
TEST_F(ARMV7, MATRIX_MUL_MK4) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4);
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}
TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
@@ -66,7 +66,7 @@ TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) {
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) {
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
handle(), "ARMV7_INT16X16X32_MK8_4X8",
param::MatrixMul::Format::MK8, 4);
param::MatrixMul::Format::MK8, 1);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -78,7 +78,7 @@ TEST_F(ARMV7, MATRIX_MUL_FP16) {
TEST_F(ARMV7, MATRIX_MUL_F16_MK8) {
matrix_mul::check_matrix_mul(
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 4);
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 1);
}
#endif
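
The last argument of check_matrix_mul changing from 4 to 1 in these tests appears to be the N-dimension granularity that the deleted "nbase should be 4" comments referred to; with 1, shapes whose N is not a multiple of 4 are generated as well, which is what exercises the new kern_8x1/kern_4x1 tail paths. The helper's internals are not shown here, so the following is only a hypothetical sketch of the kind of sweep the parameter suggests:

    #include <vector>

    // Hypothetical illustration (not the real test helper): the N values a sweep
    // would produce for a given nbase.
    std::vector<int> n_candidates(int nbase, int n_max) {
        std::vector<int> out;
        for (int n = nbase; n <= n_max; n += nbase) out.push_back(n);
        return out;
    }
    // n_candidates(4, 16) -> {4, 8, 12, 16}
    // n_candidates(1, 16) -> {1, 2, 3, ..., 16}   (hits the 1-column kernels)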