/** * \file dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #include #if !(__ARM_FEATURE_DOTPROD) #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" namespace megdnn { namespace aarch64 { namespace matmul_mk4_4x4x16 { /** * Overview of register layout: * * A 16x4 cell of Rhs is stored in 8bit in v0-q3. * B 16x4 cell of Lhs is stored in 8bit in q4-q7 * C 8x16 block of accumulators is stored in 8bit in q8--q31. * * \warning Fast kernel operating on int8 operands. * It is assumed that one of the two int8 operands only takes values * in [-127, 127], while the other may freely range in [-128, 127]. * The issue with both operands taking the value -128 is that: * -128*-128 + -128*-128 == -32768 overflows int16. * Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 * range. That is the basic idea of this kernel. * * * +--------+--------+---------+---------+ * |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]| * Rhs +--------+--------+---------+---------+ * | | | | | * * Lhs | | | | | * * +--------+ - - - - +-------------------------------------+ * |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]| * |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]| * |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]| * |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]| * +--------+ - - - - +-------------------------------------+ * * Accumulator */ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, bool is_first_k) { K = div_ceil(K, 16); const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; asm volatile( // load accumulator C "ld1 {v0.16b}, [%[a_ptr]], #16\n" "eor v16.16b, v16.16b, v16.16b\n" "eor v17.16b, v17.16b, v17.16b\n" "eor v18.16b, v18.16b, v18.16b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "eor v19.16b, v19.16b, v19.16b\n" "eor v20.16b, v19.16b, v19.16b\n" "eor v21.16b, v19.16b, v19.16b\n" "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n" "eor v22.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[a_ptr], #32]\n" "eor v23.16b, v19.16b, v19.16b\n" "eor v24.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[b_ptr], #32]\n" "eor v25.16b, v19.16b, v19.16b\n" "eor v26.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[b_ptr], #64]\n" "eor v27.16b, v19.16b, v19.16b\n" "eor v28.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[a_ptr], #64]\n" "eor v29.16b, v19.16b, v19.16b\n" "eor v30.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[b_ptr], #128]\n" "eor v31.16b, v19.16b, v19.16b\n" //! if K==1 jump to compute last K "cmp %w[k], #2\n" "beq 2f\n" "blt 3f\n" //! K>2 "1:\n" //! First k "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "ld1 {v0.16b}, [%[a_ptr]], #16\n" "smull v12.8h, v3.8b, v4.8b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "smull v10.8h, v2.8b, v6.8b\n" "smull v11.8h, v2.8b, v7.8b\n" "ld1 {v4.16b}, [%[b_ptr]], #16\n" "smull v14.8h, v3.8b, v6.8b\n" "ld1 {v5.16b}, [%[b_ptr]], #16\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v24.4s, v8.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "sadalp v25.4s, v9.8h\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v28.4s, v12.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" //! Second k "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "sadalp v26.4s, v10.8h\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "sadalp v27.4s, v11.8h\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "sadalp v30.4s, v14.8h\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "sadalp v31.4s, v15.8h\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "ld1 {v0.16b}, [%[a_ptr]], #16\n" "smull v12.8h, v3.8b, v4.8b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "sub %w[k], %w[k], #2\n" "cmp %w[k], #2\n" "smull v10.8h, v2.8b, v6.8b\n" "ld1 {v4.16b}, [%[b_ptr]], #16\n" "smull v11.8h, v2.8b, v7.8b\n" "ld1 {v5.16b}, [%[b_ptr]], #16\n" "smull v14.8h, v3.8b, v6.8b\n" "sadalp v24.4s, v8.8h\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v25.4s, v9.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "sadalp v28.4s, v12.8h\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v26.4s, v10.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" "sadalp v27.4s, v11.8h\n" "sadalp v30.4s, v14.8h\n" "sadalp v31.4s, v15.8h\n" "bgt 1b\n" "blt 3f\n" //! K==2 "2:\n" "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "ld1 {v0.16b}, [%[a_ptr]], #16\n" "smull v12.8h, v3.8b, v4.8b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "smull v10.8h, v2.8b, v6.8b\n" "smull v11.8h, v2.8b, v7.8b\n" "ld1 {v4.16b}, [%[b_ptr]], #16\n" "smull v14.8h, v3.8b, v6.8b\n" "ld1 {v5.16b}, [%[b_ptr]], #16\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v24.4s, v8.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "sadalp v25.4s, v9.8h\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v28.4s, v12.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" "sadalp v26.4s, v10.8h\n" "sadalp v27.4s, v11.8h\n" "sadalp v30.4s, v14.8h\n" "sadalp v31.4s, v15.8h\n" //! K==1 "3:\n" "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "smull v12.8h, v3.8b, v4.8b\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "smull v10.8h, v2.8b, v6.8b\n" "sadalp v24.4s, v8.8h\n" "smull v11.8h, v2.8b, v7.8b\n" "sadalp v25.4s, v9.8h\n" "smull v14.8h, v3.8b, v6.8b\n" "sadalp v28.4s, v12.8h\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v26.4s, v10.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v27.4s, v11.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" "sadalp v30.4s, v14.8h\n" "sadalp v31.4s, v15.8h\n" "addp v4.4s, v16.4s, v20.4s\n" "addp v5.4s, v24.4s, v28.4s\n" "addp v6.4s, v17.4s, v21.4s\n" "addp v7.4s, v25.4s, v29.4s\n" "addp v8.4s, v18.4s, v22.4s\n" "addp v9.4s, v26.4s, v30.4s\n" "addp v10.4s, v19.4s, v23.4s\n" "addp v11.4s, v27.4s, v31.4s\n" "cmp %w[is_first_k], #1\n" "addp v0.4s, v4.4s, v5.4s\n" "addp v1.4s, v6.4s, v7.4s\n" "addp v2.4s, v8.4s, v9.4s\n" "addp v3.4s, v10.4s, v11.4s\n" "beq 6f\n" "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output]]\n" "add v0.4s, v0.4s, v8.4s\n" "add v1.4s, v1.4s, v9.4s\n" "add v2.4s, v2.4s, v10.4s\n" "add v3.4s, v3.4s, v11.4s\n" "6:\n" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); } static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, int32_t* output, bool is_first_k, size_t remain_n) { K = div_ceil(K, 16); const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; asm volatile( // load accumulator C "ld1 {v0.16b}, [%[a_ptr]], #16\n" "eor v16.16b, v16.16b, v16.16b\n" "eor v17.16b, v17.16b, v17.16b\n" "eor v18.16b, v18.16b, v18.16b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "eor v19.16b, v19.16b, v19.16b\n" "eor v20.16b, v19.16b, v19.16b\n" "eor v21.16b, v19.16b, v19.16b\n" "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n" "eor v22.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[a_ptr], #32]\n" "eor v23.16b, v19.16b, v19.16b\n" "eor v24.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[b_ptr], #32]\n" "eor v25.16b, v19.16b, v19.16b\n" "eor v26.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[b_ptr], #64]\n" "eor v27.16b, v19.16b, v19.16b\n" "eor v28.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[a_ptr], #64]\n" "eor v29.16b, v19.16b, v19.16b\n" "eor v30.16b, v19.16b, v19.16b\n" "PRFM PLDL1KEEP, [%[b_ptr], #128]\n" "eor v31.16b, v19.16b, v19.16b\n" //! if K==1 jump to compute last K "cmp %w[k], #2\n" "beq 2f\n" "blt 3f\n" //! K>2 "1:\n" //! First k "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "ld1 {v0.16b}, [%[a_ptr]], #16\n" "smull v12.8h, v3.8b, v4.8b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "smull v10.8h, v2.8b, v6.8b\n" "smull v11.8h, v2.8b, v7.8b\n" "ld1 {v4.16b}, [%[b_ptr]], #16\n" "smull v14.8h, v3.8b, v6.8b\n" "ld1 {v5.16b}, [%[b_ptr]], #16\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v24.4s, v8.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "sadalp v25.4s, v9.8h\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v28.4s, v12.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" //! Second k "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "sadalp v26.4s, v10.8h\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "sadalp v27.4s, v11.8h\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "sadalp v30.4s, v14.8h\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "sadalp v31.4s, v15.8h\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "ld1 {v0.16b}, [%[a_ptr]], #16\n" "smull v12.8h, v3.8b, v4.8b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "sub %w[k], %w[k], #2\n" "cmp %w[k], #2\n" "smull v10.8h, v2.8b, v6.8b\n" "ld1 {v4.16b}, [%[b_ptr]], #16\n" "smull v11.8h, v2.8b, v7.8b\n" "ld1 {v5.16b}, [%[b_ptr]], #16\n" "smull v14.8h, v3.8b, v6.8b\n" "sadalp v24.4s, v8.8h\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v25.4s, v9.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "sadalp v28.4s, v12.8h\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v26.4s, v10.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" "sadalp v27.4s, v11.8h\n" "sadalp v30.4s, v14.8h\n" "sadalp v31.4s, v15.8h\n" "bgt 1b\n" "blt 3f\n" //! K==2 "2:\n" "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "ld1 {v0.16b}, [%[a_ptr]], #16\n" "smull v12.8h, v3.8b, v4.8b\n" "ld1 {v1.16b}, [%[a_ptr]], #16\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "smull v10.8h, v2.8b, v6.8b\n" "smull v11.8h, v2.8b, v7.8b\n" "ld1 {v4.16b}, [%[b_ptr]], #16\n" "smull v14.8h, v3.8b, v6.8b\n" "ld1 {v5.16b}, [%[b_ptr]], #16\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v24.4s, v8.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "sadalp v25.4s, v9.8h\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v28.4s, v12.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" "sadalp v26.4s, v10.8h\n" "sadalp v27.4s, v11.8h\n" "sadalp v30.4s, v14.8h\n" "sadalp v31.4s, v15.8h\n" //! K==1 "3:\n" "smull v8.8h, v0.8b, v4.8b\n" "smull v9.8h, v0.8b, v5.8b\n" "ld1 {v6.16b}, [%[b_ptr]], #16\n" "smull v12.8h, v1.8b, v4.8b\n" "smull v13.8h, v1.8b, v5.8b\n" "ld1 {v7.16b}, [%[b_ptr]], #16\n" "smlal2 v8.8h, v0.16b, v4.16b\n" "smlal2 v9.8h, v0.16b, v5.16b\n" "smlal2 v12.8h, v1.16b, v4.16b\n" "smlal2 v13.8h, v1.16b, v5.16b\n" "smull v10.8h, v0.8b, v6.8b\n" "ld1 {v2.16b}, [%[a_ptr]], #16\n" "smull v11.8h, v0.8b, v7.8b\n" "smull v14.8h, v1.8b, v6.8b\n" "ld1 {v3.16b}, [%[a_ptr]], #16\n" "smull v15.8h, v1.8b, v7.8b\n" "sadalp v16.4s, v8.8h\n" "smlal2 v10.8h, v0.16b, v6.16b\n" "sadalp v17.4s, v9.8h\n" "smlal2 v11.8h, v0.16b, v7.16b\n" "sadalp v20.4s, v12.8h\n" "smlal2 v14.8h, v1.16b, v6.16b\n" "sadalp v21.4s, v13.8h\n" "smlal2 v15.8h, v1.16b, v7.16b\n" "smull v8.8h, v2.8b, v4.8b\n" "smull v9.8h, v2.8b, v5.8b\n" "smull v12.8h, v3.8b, v4.8b\n" "smull v13.8h, v3.8b, v5.8b\n" "sadalp v18.4s, v10.8h\n" "smlal2 v8.8h, v2.16b, v4.16b\n" "sadalp v19.4s, v11.8h\n" "smlal2 v9.8h, v2.16b, v5.16b\n" "sadalp v22.4s, v14.8h\n" "smlal2 v12.8h, v3.16b, v4.16b\n" "sadalp v23.4s, v15.8h\n" "smlal2 v13.8h, v3.16b, v5.16b\n" "smull v10.8h, v2.8b, v6.8b\n" "sadalp v24.4s, v8.8h\n" "smull v11.8h, v2.8b, v7.8b\n" "sadalp v25.4s, v9.8h\n" "smull v14.8h, v3.8b, v6.8b\n" "sadalp v28.4s, v12.8h\n" "smull v15.8h, v3.8b, v7.8b\n" "sadalp v29.4s, v13.8h\n" "smlal2 v10.8h, v2.16b, v6.16b\n" "smlal2 v11.8h, v2.16b, v7.16b\n" "sadalp v26.4s, v10.8h\n" "smlal2 v14.8h, v3.16b, v6.16b\n" "sadalp v27.4s, v11.8h\n" "smlal2 v15.8h, v3.16b, v7.16b\n" "sadalp v30.4s, v14.8h\n" "sadalp v31.4s, v15.8h\n" "addp v4.4s, v16.4s, v20.4s\n" "addp v5.4s, v24.4s, v28.4s\n" "addp v6.4s, v17.4s, v21.4s\n" "addp v7.4s, v25.4s, v29.4s\n" "addp v8.4s, v18.4s, v22.4s\n" "addp v9.4s, v26.4s, v30.4s\n" "addp v10.4s, v19.4s, v23.4s\n" "addp v11.4s, v27.4s, v31.4s\n" "addp v0.4s, v4.4s, v5.4s\n" "addp v1.4s, v6.4s, v7.4s\n" "addp v2.4s, v8.4s, v9.4s\n" "addp v3.4s, v10.4s, v11.4s\n" "cmp %w[is_first_k], #1\n" "beq 6f\n" "cmp %w[remain_n], #3\n" "beq 1003f\n" "cmp %w[remain_n], #2\n" "beq 1002f\n" "cmp %w[remain_n], #1\n" "beq 1001f\n" "1003:\n" "ld1 {v8.4s, v9.4s, v10.4s}, [%[output]]\n" "add v0.4s, v0.4s, v8.4s\n" "add v1.4s, v1.4s, v9.4s\n" "add v2.4s, v2.4s, v10.4s\n" "b 6f\n" "1002:\n" "ld1 {v8.4s, v9.4s}, [%[output]]\n" "add v0.4s, v0.4s, v8.4s\n" "add v1.4s, v1.4s, v9.4s\n" "b 6f\n" "1001:\n" "ld1 {v8.4s}, [%[output]]\n" "add v0.4s, v0.4s, v8.4s\n" "6:\n" "cmp %w[remain_n], #3\n" "beq 10003f\n" "cmp %w[remain_n], #2\n" "beq 10002f\n" "cmp %w[remain_n], #1\n" "beq 10001f\n" "10003:\n" "str q2, [%[output], #32]\n" "10002:\n" "str q1, [%[output], #16]\n" "10001:\n" "str q0, [%[output]]\n" "7:\n" : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); } static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, int kmax) { //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} int8_t zerobuff[4][64]; std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, "mk4 matmul with m is not times of 4"); megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, "mk4 matmul with k is not times of 4"); size_t roundk = round_up(kmax - k0, 16); size_t out_offset = roundk * 4; int y = y0; int start_y = y0 / 4; for (; y + 15 < ymax; y += 16, start_y += 4) { const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; const int8_t* inptr1 = inptr0 + ldin; const int8_t* inptr2 = inptr1 + ldin; const int8_t* inptr3 = inptr2 + ldin; int8_t* output = outptr + (y - y0) / 4 * out_offset; prefetch_2x(inptr0); prefetch_2x(inptr1); prefetch_2x(inptr2); prefetch_2x(inptr3); int K = kmax - k0; for (; K > 15; K -= 16) { transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, out_offset); output += 64; } if (K > 0) { std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); std::memcpy(zerobuff[1], inptr1, sizeof(int8_t) * K * 4); std::memcpy(zerobuff[2], inptr2, sizeof(int8_t) * K * 4); std::memcpy(zerobuff[3], inptr3, sizeof(int8_t) * K * 4); inptr0 = zerobuff[0]; inptr1 = zerobuff[1]; inptr2 = zerobuff[2]; inptr3 = zerobuff[3]; transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, out_offset); output += 64; } } for (; y + 3 < ymax; y += 4, start_y++) { const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; int8_t* output = outptr + (y - y0) / 4 * out_offset; prefetch_2x(inptr0); int K = kmax - k0; for (; K > 15; K -= 16) { transpose_interleave_1x4_4_b(inptr0, output); output += 64; } if (K > 0) { std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); inptr0 = zerobuff[0]; transpose_interleave_1x4_4_b(inptr0, output); output += 64; } } } static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int32_t zerobuff[4]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; const int ICB = (ksize) / 4; const int ksize4 = round_up(ICB, 4) * 4; int32_t* outptr = reinterpret_cast(out); megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, "mk4 matmul with k is not times of 4"); int k = k0 / 4; for (; k + 3 < ICB; k += 4) { const int32_t* inptr0 = reinterpret_cast(in + k * ldin + x0); const int32_t* inptr1 = reinterpret_cast(in + (k + 1) * ldin + x0); const int32_t* inptr2 = reinterpret_cast(in + (k + 2) * ldin + x0); const int32_t* inptr3 = reinterpret_cast(in + (k + 3) * ldin + x0); int32_t* outptr_inner = outptr; int x = x0; for (; x + 3 < xmax; x += 4) { transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); outptr_inner += ksize4; } if (x < xmax) { for (; x < xmax; x++) { *outptr_inner++ = *inptr0++; *outptr_inner++ = *inptr1++; *outptr_inner++ = *inptr2++; *outptr_inner++ = *inptr3++; } } outptr += 4 * 4; } if (k < ICB) { const int32_t* inptr0 = reinterpret_cast(in + k * ldin + x0); const int32_t* inptr1 = reinterpret_cast(in + (k + 1) * ldin + x0); const int32_t* inptr2 = reinterpret_cast(in + (k + 2) * ldin + x0); const int32_t* inptr3 = reinterpret_cast(in + (k + 3) * ldin + x0); int32_t* outptr_inner = outptr; int x = x0; for (; x + 3 < xmax; x += 4) { if (k + 3 >= ICB) { switch (k + 3 - ICB) { case 2: inptr1 = zerobuff; MEGDNN_FALLTHRU case 1: inptr2 = zerobuff; MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; default: megdnn_assert(0); } } transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); outptr_inner += ksize4; } if (x < xmax) { if (k + 3 >= ICB) { switch (k + 3 - ICB) { case 2: inptr1 = zerobuff; MEGDNN_FALLTHRU case 1: inptr2 = zerobuff; MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; default: megdnn_assert(0); } } for (; x < xmax; x++) { *outptr_inner++ = *inptr0++; *outptr_inner++ = *inptr1++; *outptr_inner++ = *inptr2++; *outptr_inner++ = *inptr3++; } } outptr += 4 * 4; } } } // namespace matmul_4x4x16 } // namespace aarch64 } // namespace megdnn #endif // vim: syntax=cpp.doxygen