/*
   Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
   SPDX-License-Identifier: BSD-3-Clause-Clear
*/

/*--------------------------------------------------------------------------
 * SME1 based matrix multiplication code for FP32 input matrices to an FP32
 * output matrix
 * C = A*B
 * A: Left input matrix of dimension M x K
 * B: Right input matrix of dimension K x N
 * C: Result matrix of dimension M x N
 *
 * Usage of function:
 * sgemm_direct_sme1_2VLx2VL( uint64_t M, uint64_t K, uint64_t N,
 *                            const float * restrict A_base,
 *                            const float * restrict B_base,
 *                            float * restrict C_base);
 *
 * Syntax / ABI: GNU as, AArch64 (AAPCS64), run through the C preprocessor.
 *
 * Memory access pattern (as implemented by the addressing below):
 *  - A is read one vector (SVLs floats) per K step; consecutive K steps of
 *    one row block are SVLb bytes apart and the second row block of a
 *    2VL x 2VL tile starts K*SVLs floats after the first.  Lane i of an A
 *    vector feeds row i of the ZA tile.
 *  - B is read row-wise with a row stride of N floats.
 *  - C is written row-wise with a row stride of N floats.
 *
 * ZA tile assignment (row predicate, column predicate):
 *    za0 = (p2, p0)   za1 = (p2, p1)
 *    za2 = (p3, p0)   za3 = (p3, p1)
 *
 * Register usage:
 *  - x19-x24 and d8-d15 are saved on entry and restored on exit.  d8-d15
 *    must be saved explicitly because a change of PSTATE.SM (smstart /
 *    smstop) resets the vector registers they alias.
 *  - x18 (platform register) is intentionally not used.
 *  - x6-x17, z0-z7, p0-p7, ZA and the flags are clobbered.
 *
 * NOTE(review): ZA is enabled, zeroed and disabled here without consulting
 * TPIDR2_EL0; this assumes the caller has no live ZA state - confirm
 * against the callers.
 *
 * K == 0 stores zeros to C and does not read A or B.
----------------------------------------------------------------------------*/

#define M          x0    //M dimension
#define K          x1    //K dimension
#define N          x2    //N dimension
#define A_base     x3    //Pointer to left matrix(A)
#define B_base     x4    //Pointer to right matrix(B)
#define C_base     x5    //Pointer to result matrix(C)
#define Aptr       x6    //Pointer to traverse A
#define Aptr_end   x7    //Pointer to end of row of A
#define Cptr       x8    //Pointer to traverse C
#define Cptr0      x9    //2nd Pointer to traverse C
#define Cptr1      x10   //3rd Pointer to traverse C
#define Bptr       x11   //Pointer to traverse B
#define Bptr0      x12   //2nd Pointer to traverse B
#define ZA_row     w13   //ZA slice index (psel/st1w need w12-w15)
#define N_exit     x14   //Exit condition for N loop
#define K_exit     x15   //Exit condition for K loop
#define M_cntr     x16   //M loop counter
#define C1         x17   //Constant1: N*(SVLs+1);SVLs-No. of 32-bit elements
#define C2         x23   //Constant2: N + SVLs (callee-saved, not x18)
#define C3         x19   //Constant3: K*SVLs + SVLs
#define C4         x20   //Constant4: SVLs, then SVLs-2 once setup is done
#define C4_w       w20   //32-bit view of C4
#define C5         x21   //Constant5: K*SVLs
#define C6         x22   //Constant6: N*SVLs

#define SGEMM_FRAME_SIZE 112   //6 GPRs (48 bytes) + d8-d15 (64 bytes)

    .text
    .global sgemm_direct_sme1_2VLx2VL

sgemm_direct_sme1_2VLx2VL:
    // Prologue: save callee-saved GPRs and the low halves of v8-v15 while
    // still in non-streaming mode.
    stp     x19, x20, [sp, #-SGEMM_FRAME_SIZE]!
    stp     x21, x22, [sp, #16]
    stp     x23, x24, [sp, #32]         //x24 only pads the pair
    stp     d8,  d9,  [sp, #48]
    stp     d10, d11, [sp, #64]
    stp     d12, d13, [sp, #80]
    stp     d14, d15, [sp, #96]

    smstart

    cntw    C4                          //SVLs
    mul     C5, C4, K                   //K*SVLs
    mul     C6, C4, N                   //N*SVLs
    add     C1, C6, N                   //N*SVLs + N
    add     N_exit, B_base, N, lsl #2   //N_Loop exit condition
    mov     M_cntr, #0
    add     C2, N, C4                   //N + SVLs
    add     C3, C5, C4                  //K*SVLs + SVLs
    whilelt p2.s, M_cntr, M             //Tile 0,1 predicate (M dimension)
    sub     C4_w, C4_w, #2              //SVLs-2 (store loop bound)

    // ---- M loop: one pass handles two row blocks (2*SVLs rows of C) ----
.Lm_loop:
    incw    M_cntr
    whilelt p3.s, M_cntr, M             //Tile 2,3 predicate (M dimension)
    mov     Bptr, B_base
    mov     Cptr, C_base
    whilelt p0.b, Bptr, N_exit          //Tile 0,2 predicate (N dimension)

    // ---- N loop: one pass handles two column blocks (2*SVLs columns) ----
.Ln_loop:
    mov     Aptr, A_base
    mov     Bptr0, Bptr
    mov     Cptr0, Cptr
    addvl   Cptr1, Cptr, #1             //Cptr1 = Cptr + SVLb
    addvl   Bptr, Bptr, #1
    whilelt p1.b, Bptr, N_exit          //Tile 1,3 predicate (N dimension)
    add     Aptr_end, A_base, C5, lsl #2 //A_base + K*SVLs floats
    addvl   K_exit, Aptr_end, #-1       //Exit condition for K loop

    zero    {za}
    // K == 0: nothing to accumulate, store the zeroed tiles without
    // touching A or B.
    cbz     K, .Lk_done

    // Prime the software pipeline with k = 0.
    ld1w    {z1.s}, p2/z, [Aptr]                    //A row block 0, k
    ld1w    {z2.s}, p0/z, [Bptr0]                   //B col block 0, k
    fmopa   za0.s, p2/m, p0/m, z1.s, z2.s
    ld1w    {z5.s}, p3/z, [Aptr, C5, lsl #2]        //A row block 1, k
    addvl   Aptr, Aptr, #1                          //Aptr -> k+1

    // ---- K loop: two K steps per pass; entered with z1/z2/z5 holding
    // step k and za0 already updated for step k ----
.Lk_loop:
    fmopa   za2.s, p3/m, p0/m, z5.s, z2.s
    ld1w    {z3.s}, p1/z, [Bptr0, #1, MUL VL]       //B col block 1, k
    fmopa   za1.s, p2/m, p1/m, z1.s, z3.s
    fmopa   za3.s, p3/m, p1/m, z5.s, z3.s

    // With K <= 2 only step 0 is done here; a possible step 1 is left to
    // the tail.  Branching before the next A load keeps K == 1 from
    // reading past its row block.  Unsigned compare: K is uint64_t.
    cmp     K, #2
    b.ls    .Lk_tail_check

    ld1w    {z0.s}, p2/z, [Aptr]                    //A row block 0, k+1
    ld1w    {z6.s}, p0/z, [Bptr0, N, lsl #2]        //B col block 0, k+1
    fmopa   za0.s, p2/m, p0/m, z0.s, z6.s
    ld1w    {z4.s}, p3/z, [Aptr, C5, lsl #2]        //A row block 1, k+1
    fmopa   za2.s, p3/m, p0/m, z4.s, z6.s
    ld1w    {z7.s}, p1/z, [Bptr0, C2, lsl #2]       //B col block 1, k+1
    add     Bptr0, Bptr0, N, lsl #3                 //Bptr0 += 2 rows of B
    fmopa   za1.s, p2/m, p1/m, z0.s, z7.s
    ld1w    {z1.s}, p2/z, [Aptr, #1, MUL VL]        //A row block 0, k+2
    fmopa   za3.s, p3/m, p1/m, z4.s, z7.s
    ld1w    {z2.s}, p0/z, [Bptr0]                   //B col block 0, k+2
    fmopa   za0.s, p2/m, p0/m, z1.s, z2.s
    ld1w    {z5.s}, p3/z, [Aptr, C3, lsl #2]        //A row block 1, k+2
    addvl   Aptr, Aptr, #2                          //Aptr += 2*SVLb
    cmp     Aptr, K_exit
    b.lo    .Lk_loop                                //pointer compare: unsigned

    // Drain the pipeline: finish step k+2 for the remaining three tiles.
    fmopa   za2.s, p3/m, p0/m, z5.s, z2.s
    ld1w    {z3.s}, p1/z, [Bptr0, #1, MUL VL]
    fmopa   za1.s, p2/m, p1/m, z1.s, z3.s
    fmopa   za3.s, p3/m, p1/m, z5.s, z3.s

    // Reached both from the drain above and from the K <= 2 early exit.
.Lk_tail_check:
    add     Bptr0, Bptr0, N, lsl #2                 //Bptr0 += 1 row of B
    cmp     Aptr, Aptr_end
    b.hs    .Lk_done                                //no K step left

    // One K step is left over.
    ld1w    {z1.s}, p2/z, [Aptr]
    ld1w    {z2.s}, p0/z, [Bptr0]
    ld1w    {z3.s}, p1/z, [Bptr0, #1, MUL VL]
    fmopa   za0.s, p2/m, p0/m, z1.s, z2.s
    ld1w    {z5.s}, p3/z, [Aptr, C5, lsl #2]
    fmopa   za2.s, p3/m, p0/m, z5.s, z2.s
    fmopa   za1.s, p2/m, p1/m, z1.s, z3.s
    fmopa   za3.s, p3/m, p1/m, z5.s, z3.s

    // ---- Store the four tiles to C, one horizontal slice per row.
    // psel yields the column predicate when the row is inside M and an
    // all-false predicate otherwise, so rows beyond M are not written. ----
.Lk_done:
    mov     ZA_row, #0
    psel    p4, p0, p2.s[ZA_row, 0]
    psel    p5, p1, p2.s[ZA_row, 0]
    psel    p6, p0, p3.s[ZA_row, 0]
    psel    p7, p1, p3.s[ZA_row, 0]
    st1w    {za0h.s[ZA_row, #0]}, p4, [Cptr0]               //Cptr0
    st1w    {za1h.s[ZA_row, #0]}, p5, [Cptr1]               //Cptr1
    st1w    {za2h.s[ZA_row, #0]}, p6, [Cptr0, C6, lsl #2]   //Cptr0 + N*SVLs
    st1w    {za3h.s[ZA_row, #0]}, p7, [Cptr1, C6, lsl #2]   //Cptr1 + N*SVLs

    // Each pass stores the odd row of the current pair, advances two rows
    // of C and stores the even row of the next pair.
.Lstore_loop:
    psel    p4, p0, p2.s[ZA_row, 1]
    psel    p5, p1, p2.s[ZA_row, 1]
    psel    p6, p0, p3.s[ZA_row, 1]
    psel    p7, p1, p3.s[ZA_row, 1]
    st1w    {za0h.s[ZA_row, #1]}, p4, [Cptr0, N, lsl #2]    //Cptr0 + N
    st1w    {za1h.s[ZA_row, #1]}, p5, [Cptr1, N, lsl #2]    //Cptr1 + N
    st1w    {za2h.s[ZA_row, #1]}, p6, [Cptr0, C1, lsl #2]   //Cptr0 + N*(SVLs+1)
    st1w    {za3h.s[ZA_row, #1]}, p7, [Cptr1, C1, lsl #2]   //Cptr1 + N*(SVLs+1)

    add     Cptr0, Cptr0, N, lsl #3     //Cptr0 += 2*N FP32 elements
    add     Cptr1, Cptr1, N, lsl #3     //Cptr1 += 2*N FP32 elements
    add     ZA_row, ZA_row, #2

    psel    p4, p0, p2.s[ZA_row, 0]
    psel    p5, p1, p2.s[ZA_row, 0]
    psel    p6, p0, p3.s[ZA_row, 0]
    psel    p7, p1, p3.s[ZA_row, 0]
    st1w    {za0h.s[ZA_row, #0]}, p4, [Cptr0]
    st1w    {za1h.s[ZA_row, #0]}, p5, [Cptr1]
    st1w    {za2h.s[ZA_row, #0]}, p6, [Cptr0, C6, lsl #2]
    st1w    {za3h.s[ZA_row, #0]}, p7, [Cptr1, C6, lsl #2]

    cmp     ZA_row, C4_w                //row index vs SVLs-2
    b.lo    .Lstore_loop

    // Last (odd) row of each tile.
    psel    p4, p0, p2.s[ZA_row, 1]
    psel    p5, p1, p2.s[ZA_row, 1]
    psel    p6, p0, p3.s[ZA_row, 1]
    psel    p7, p1, p3.s[ZA_row, 1]
    st1w    {za0h.s[ZA_row, #1]}, p4, [Cptr0, N, lsl #2]
    st1w    {za1h.s[ZA_row, #1]}, p5, [Cptr1, N, lsl #2]
    st1w    {za2h.s[ZA_row, #1]}, p6, [Cptr0, C1, lsl #2]
    st1w    {za3h.s[ZA_row, #1]}, p7, [Cptr1, C1, lsl #2]

    // Next pair of column blocks.
    addvl   Cptr, Cptr, #2
    addvl   Bptr, Bptr, #1              //second addvl of this pass: +2*SVLb in total
    whilelt p0.b, Bptr, N_exit          //Tile 0,2 predicate (N dimension)
    b.first .Ln_loop

    // Next pair of row blocks.
    add     A_base, A_base, C5, lsl #3  //A_base += 2*K*SVLs FP32 elements
    add     C_base, C_base, C6, lsl #3  //C_base += 2*N*SVLs FP32 elements
    incw    M_cntr
    whilelt p2.s, M_cntr, M             //Tile 0,1 predicate (M dimension)
    b.first .Lm_loop

    smstop

    // Epilogue: restore in non-streaming mode, after smstop has reset the
    // vector registers.
    ldp     d14, d15, [sp, #96]
    ldp     d12, d13, [sp, #80]
    ldp     d10, d11, [sp, #64]
    ldp     d8,  d9,  [sp, #48]
    ldp     x23, x24, [sp, #32]
    ldp     x21, x22, [sp, #16]
    ldp     x19, x20, [sp], #SGEMM_FRAME_SIZE
    ret