From 80d3c2ad95781211b77272a1cfc9d77ba7ec402d Mon Sep 17 00:00:00 2001 From: Masato Nakagawa Date: Tue, 11 Mar 2025 20:18:20 +0900 Subject: [PATCH] Add Improving Load Imbalance in Thread-Parallel GEMM --- driver/level3/level3_thread.c | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index 9b1aadf7d..77aaeee6b 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -591,7 +591,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG BLASLONG nthreads = args -> nthreads; - BLASLONG width, i, j, k, js; + BLASLONG width, width_n, i, j, k, js; BLASLONG m, n, n_from, n_to; int mode; #if defined(DYNAMIC_ARCH) @@ -740,18 +740,25 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG /* Partition (a step of) n into nthreads regions */ range_N[0] = js; num_parts = 0; - while (n > 0){ - width = blas_quickdivide(n + nthreads - num_parts - 1, nthreads - num_parts); - if (width < switch_ratio) { - width = switch_ratio; + for(j = 0; j < nthreads_n; j++){ + width_n = blas_quickdivide(n + nthreads_n - j - 1, nthreads_n - j); + n -= width_n; + for(i = 0; i < nthreads_m; i++){ + width = blas_quickdivide(width_n + nthreads_m - i - 1, nthreads_m - i); + if (width < switch_ratio) { + width = switch_ratio; + } + width = round_up(width_n, width, GEMM_PREFERED_SIZE); + + width_n -= width; + if (width_n < 0) { + width = width + width_n; + width_n = 0; + } + range_N[num_parts + 1] = range_N[num_parts] + width; + + num_parts ++; } - width = round_up(n, width, GEMM_PREFERED_SIZE); - - n -= width; - if (n < 0) width = width + n; - range_N[num_parts + 1] = range_N[num_parts] + width; - - num_parts ++; } for (j = num_parts; j < MAX_CPU_NUMBER; j++) { range_N[j + 1] = range_N[num_parts];