|
|
|
@@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ): |
|
|
|
dest.write("ai += {M}*2;") |
|
|
|
dest.write() |
|
|
|
|
|
|
|
|
|
|
|
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value |
|
|
|
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results |
|
|
|
accumulation_regs = a_regs * N |
|
|
|
dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k", |
|
|
|
a_regs=a_regs*2, accumulation_regs=accumulation_regs*2 |
|
|
|
) |
|
|
|
pass_regs = (accumulation_regs + a_regs)*2 |
|
|
|
tmp_regs = 32-pass_regs |
|
|
|
tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs |
|
|
|
if tmp_regs < 2: |
|
|
|
raise RuntimeError("Complex kernel would use too many registers!") |
|
|
|
|
|
|
|
@@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ): |
|
|
|
|
|
|
|
M = settings['M'].value |
|
|
|
N = settings['N'].value |
|
|
|
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value ) |
|
|
|
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value / |
|
|
|
settings['ELEN_PARAM'].value) |
|
|
|
a_regs = max(int(M/vlenmax), 1) |
|
|
|
|
|
|
|
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value |
|
|
|
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results |
|
|
|
accumulation_regs = a_regs * N |
|
|
|
required_regs = accumulation_regs + a_regs |
|
|
|
if is_complex: |
|
|
|
required_regs = required_regs * 2 + 2 |
|
|
|
@@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ): |
|
|
|
'''.format(tail_policy=settings['tail_policy'].value)) |
|
|
|
|
|
|
|
|
|
|
|
if required_regs > 32: |
|
|
|
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available".format( |
|
|
|
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else '') |
|
|
|
if required_regs > (32 // settings['LMUL_ACC'].value): |
|
|
|
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format( |
|
|
|
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value |
|
|
|
)) |
|
|
|
|
|
|
|
TRMM = (settings['op'].value == 'trmm') |
|
|
|
@@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ): |
|
|
|
def generate_M_tails( dest, settings, M, N ): |
|
|
|
M_tail = int(M/2) |
|
|
|
M_tail_min = settings['M_tail_scalar_from'].value |
|
|
|
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value ) |
|
|
|
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value |
|
|
|
/ settings['ELEN_PARAM'].value ) |
|
|
|
TRMM = (settings['op'].value == 'trmm') |
|
|
|
is_complex = settings['complex'].value |
|
|
|
generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real |
|
|
|
@@ -667,4 +670,4 @@ def main(): |
|
|
|
ERROR("unsupported kernel type {}".format(settings['op'])) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main() |
|
|
|
main() |