From 8a2eab111478ded313d2e43dd5274144ec0a65f7 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 9 Jun 2025 15:32:50 +0800 Subject: [PATCH] set localsize as multiple of subgroup size (#2483) * fix innerproduct gemm vulkan --- .../shader/innerproduct_gemm_wp4to8.comp | 2 +- .../vulkan/shader/innerproduct_gemm_wp8.comp | 2 +- .../shader/innerproduct_gemm_wp8to4.comp | 2 +- src/pipeline.cpp | 98 ++++++++++++++++++- 4 files changed, 100 insertions(+), 4 deletions(-) diff --git a/src/layer/vulkan/shader/innerproduct_gemm_wp4to8.comp b/src/layer/vulkan/shader/innerproduct_gemm_wp4to8.comp index 9135afb5c..497a5576d 100644 --- a/src/layer/vulkan/shader/innerproduct_gemm_wp4to8.comp +++ b/src/layer/vulkan/shader/innerproduct_gemm_wp4to8.comp @@ -68,7 +68,7 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= 1) + if (gx * 8 >= psc(outw) || gy >= psc(outh) || gz >= 1) return; afpvec8 sum; diff --git a/src/layer/vulkan/shader/innerproduct_gemm_wp8.comp b/src/layer/vulkan/shader/innerproduct_gemm_wp8.comp index b64cc8713..0c21e54de 100644 --- a/src/layer/vulkan/shader/innerproduct_gemm_wp8.comp +++ b/src/layer/vulkan/shader/innerproduct_gemm_wp8.comp @@ -68,7 +68,7 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= 1) + if (gx * 8 >= psc(outw) || gy >= psc(outh) || gz >= 1) return; afpvec8 sum; diff --git a/src/layer/vulkan/shader/innerproduct_gemm_wp8to4.comp b/src/layer/vulkan/shader/innerproduct_gemm_wp8to4.comp index 76998c49b..e4aa13a74 100644 --- a/src/layer/vulkan/shader/innerproduct_gemm_wp8to4.comp +++ b/src/layer/vulkan/shader/innerproduct_gemm_wp8to4.comp @@ -68,7 +68,7 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= 1) + if (gx * 4 >= psc(outw) || gy >= psc(outh) || gz >= 1) return; afpvec4 sum; diff --git a/src/pipeline.cpp b/src/pipeline.cpp index 35c069d9c..2354f9ede 100644 --- a/src/pipeline.cpp +++ b/src/pipeline.cpp @@ -122,13 +122,109 @@ void Pipeline::set_subgroup_size(uint32_t subgroup_size) d->subgroup_size = subgroup_size; } +static int count_trailing_zeros(unsigned int v) +{ + int cnt = 0; + while ((v & 1) == 0) + { + cnt++; + v >>= 1; + } + return cnt; +} + +// round up v to the next multiple of 2^k +static unsigned int round_up_pow2_mul(unsigned int v, int k) +{ + unsigned int m = 1u << k; + return ((v + m - 1) / m) * m; +} + +// adjust x, y, z so that new x * y * z is a multiple of size (size must be a power of two), and new x, y, z are no less than the inputs +// new values do not have to be integer multiples of the originals +// minimize the total increment (x'-x)+(y'-y)+(z'-z) +// additional constraint: if original y is 1, prefer not to adjust y; if original z is 1, prefer not to adjust z +static void adjust_xyz(int& x, int& y, int& z, const int subgroup_size) +{ + if (x * y * z % subgroup_size == 0) + return; + + int target_n = 0; + { + while ((1 << target_n) != subgroup_size) + target_n++; + } + + // subgroup shall usually be 4 ~ 128, sanitize the max possible size + target_n = std::min(target_n, 10); + + const int tx = count_trailing_zeros((unsigned int)x); + const int ty = count_trailing_zeros((unsigned int)y); + const int tz = count_trailing_zeros((unsigned int)z); + const int tn = tx + ty + tz; + + const int need = target_n - tn; + + if (z == 1) + { + if (y == 1) + { + // adjust x only + x = round_up_pow2_mul((unsigned int)x, target_n); + } + else if (x == 1) + { + // adjust y only + y = round_up_pow2_mul((unsigned int)y, target_n); + } + else + { + // adjust x and y + y = round_up_pow2_mul((unsigned int)y, ty + need / 2); + x = round_up_pow2_mul((unsigned int)x, tx + need - need / 2); + } + } + else if (y == 1) + { + if (x == 1) + { + // adjust z only + z = round_up_pow2_mul((unsigned int)z, target_n); + } + else + { + // adjust x and z + z = round_up_pow2_mul((unsigned int)z, tz + need / 2); + x = round_up_pow2_mul((unsigned int)x, tx + need - need / 2); + } + } + else if (x == 1) + { + // adjust y and z + z = round_up_pow2_mul((unsigned int)z, tz + need / 2); + y = round_up_pow2_mul((unsigned int)y, ty + need - need / 2); + } + else + { + // adjust x y z + z = round_up_pow2_mul((unsigned int)z, tz + need / 3); + y = round_up_pow2_mul((unsigned int)y, ty + (need - need / 3) / 2); + x = round_up_pow2_mul((unsigned int)x, tx + need - (need - need / 3) / 2); + } +} + void Pipeline::set_local_size_xyz(int w, int h, int c) { + // dispatch at least one subgroup + // make local size be multiple of subgroup size + // and metal is unhappy with arbitrary local size anyway + adjust_xyz(w, h, c, d->subgroup_size); + d->local_size_x = w; d->local_size_y = h; d->local_size_z = c; - // NCNN_LOGE("local size = %d %d %d", local_size_x, local_size_y, local_size_z); + // NCNN_LOGE("local size = %d %d %d", local_size_x, local_size_y, local_size_z); } int Pipeline::create(const uint32_t* spv_data, size_t spv_data_size, const std::vector& specializations)