diff --git a/src/gpu.cpp b/src/gpu.cpp index c359a890d..00a711d09 100644 --- a/src/gpu.cpp +++ b/src/gpu.cpp @@ -5,6 +5,7 @@ #if NCNN_VULKAN +#include #include #include #include @@ -2014,6 +2015,79 @@ const std::vector& GpuInfo::queryCooperativeVec return d->queryCooperativeVectorSubPropertiesNV; } +void GpuInfo::get_optimal_cooperative_matrix_mnk(int M, int N, int K, VkComponentTypeKHR type, VkComponentTypeKHR acctype, VkScopeKHR scope, int& coopmat_M, int& coopmat_N, int& coopmat_K) const +{ + coopmat_M = 0; + coopmat_N = 0; + coopmat_K = 0; + + // collect mnk candidates + std::vector mnk_properties; + + if (d->support_VK_KHR_cooperative_matrix && d->queryCooperativeMatrixFeatures.cooperativeMatrix) + { + for (size_t i = 0; i < d->queryCooperativeMatrixSubProperties.size(); i++) + { + const VkCooperativeMatrixPropertiesKHR& cmp = d->queryCooperativeMatrixSubProperties[i]; + + if (cmp.AType == type && cmp.BType == type + && cmp.CType == acctype && cmp.ResultType == acctype + && cmp.scope == scope) + { + mnk_properties.push_back(cmp); + } + } + } + else if (d->support_VK_NV_cooperative_matrix && d->queryCooperativeMatrixFeaturesNV.cooperativeMatrix) + { + for (size_t i = 0; i < d->queryCooperativeMatrixSubPropertiesNV.size(); i++) + { + const VkCooperativeMatrixPropertiesNV& cmp = d->queryCooperativeMatrixSubPropertiesNV[i]; + + if (cmp.AType == (VkComponentTypeNV)type && cmp.BType == (VkComponentTypeNV)type + && cmp.CType == (VkComponentTypeNV)acctype && cmp.DType == (VkComponentTypeNV)acctype + && cmp.scope == (VkScopeNV)scope) + { + VkCooperativeMatrixPropertiesKHR cmp_khr; + cmp_khr.MSize = cmp.MSize; + cmp_khr.NSize = cmp.NSize; + cmp_khr.KSize = cmp.KSize; + + mnk_properties.push_back(cmp_khr); + } + } + } + + if (mnk_properties.empty() && (acctype == VK_COMPONENT_TYPE_FLOAT16_KHR || acctype == VK_COMPONENT_TYPE_BFLOAT16_KHR)) + { + // try acctype fp32 + return get_optimal_cooperative_matrix_mnk(M, N, K, type, VK_COMPONENT_TYPE_FLOAT32_KHR, scope, coopmat_M, coopmat_N, coopmat_K); + } + + if (mnk_properties.empty()) + return; + + // find the optimal, prefer the first mnk tuple with same cost + double min_cost = DBL_MAX; + for (size_t i = 0; i < mnk_properties.size(); i++) + { + const VkCooperativeMatrixPropertiesKHR& cmp = mnk_properties[i]; + + const int M_pad = (M + cmp.MSize - 1) / cmp.MSize * cmp.MSize; + const int N_pad = (N + cmp.NSize - 1) / cmp.NSize * cmp.NSize; + const int K_pad = (K + cmp.KSize - 1) / cmp.KSize * cmp.KSize; + + double cost = M_pad * N_pad * K_pad - M * N * K; + if (cost < min_cost) + { + min_cost = cost; + coopmat_M = cmp.MSize; + coopmat_N = cmp.NSize; + coopmat_K = cmp.KSize; + } + } +} + static int init_instance_core() { vkAllocateCommandBuffers = (PFN_vkAllocateCommandBuffers)vkGetInstanceProcAddr(g_instance, "vkAllocateCommandBuffers"); diff --git a/src/gpu.h b/src/gpu.h index b9c038de1..7863b2e21 100644 --- a/src/gpu.h +++ b/src/gpu.h @@ -381,6 +381,9 @@ public: const std::vector& queryCooperativeMatrixFlexibleDimensionsSubPropertiesNV() const; const std::vector& queryCooperativeVectorSubPropertiesNV() const; + // some utility functions + void get_optimal_cooperative_matrix_mnk(int M, int N, int K, VkComponentTypeKHR type, VkComponentTypeKHR acctype, VkScopeKHR scope, int& coopmat_M, int& coopmat_N, int& coopmat_K) const; + private: GpuInfo(const GpuInfo&); GpuInfo& operator=(const GpuInfo&); diff --git a/src/layer/vulkan/convolution_vulkan.cpp b/src/layer/vulkan/convolution_vulkan.cpp index 57acce84c..115d4d684 100644 --- a/src/layer/vulkan/convolution_vulkan.cpp +++ b/src/layer/vulkan/convolution_vulkan.cpp @@ -29,6 +29,16 @@ Convolution_vulkan::Convolution_vulkan() reshape_1x1xw = 0; reshape_w = 0; + + use_cooperative_matrix = false; + coopmat_M = 0; + coopmat_N = 0; + coopmat_K = 0; + UNROLL_SG_M = 1; + UNROLL_SG_N = 1; + UNROLL_SG_K = 1; + UNROLL_WG_M = 1; + UNROLL_WG_N = 1; } int Convolution_vulkan::load_param(const ParamDict& pd) @@ -684,99 +694,219 @@ int Convolution_vulkan::create_pipeline(const Option& _opt) } } } - else + else if (opt.use_sgemm_convolution && !is_conv1x1s1d1 && num_input >= 16 && num_output >= 16) { - // src = kw-kh-inch-outch - // dst = pa-pb-kw-kh-inch/pa-outch/pb - if (opt.use_sgemm_convolution && !is_conv1x1s1d1 && num_input >= 16 && num_output >= 16) + bool use_cooperative_matrix_16_8_8 = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0; + bool use_cooperative_matrix_16_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0; + if (vkdev->info.subgroup_size() != 32 && (!vkdev->info.support_subgroup_size_control() || vkdev->info.min_subgroup_size() > 32 || vkdev->info.max_subgroup_size() < 32)) + { + use_cooperative_matrix_16_8_8 = false; + use_cooperative_matrix_16_16_16 = false; + } + + if (use_cooperative_matrix_16_8_8) { - bool use_cooperative_matrix_16_8_8 = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0; - bool use_cooperative_matrix_16_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0; - if (vkdev->info.subgroup_size() != 32 && (!vkdev->info.support_subgroup_size_control() || vkdev->info.min_subgroup_size() > 32 || vkdev->info.max_subgroup_size() < 32)) + // dst = 8b-8a-maxk-inch/8a-outch/8b + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + + weight_data_packed.create(maxk * num_input / 8, num_output / 8, (size_t)4 * 8 * 8, 8 * 8); + + for (int q = 0; q + 7 < num_output; q += 8) { - use_cooperative_matrix_16_8_8 = false; - use_cooperative_matrix_16_16_16 = false; + float* g00 = weight_data_packed.row(q / 8); + + for (int p = 0; p + 7 < num_input; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 8; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = k00[k]; + g00++; + } + } + } + } } + } + else if (use_cooperative_matrix_16_16_16) + { + // dst = 16b-16a-maxk-inch/16a-outch/16b + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); - if (use_cooperative_matrix_16_8_8) - { - // dst = 8b-8a-maxk-inch/8a-outch/8b - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + weight_data_packed.create(maxk * num_input / 16, num_output / 16, (size_t)4 * 16 * 16, 16 * 16); - weight_data_packed.create(maxk * num_input / 8, num_output / 8, (size_t)4 * 8 * 8, 8 * 8); + for (int q = 0; q + 15 < num_output; q += 16) + { + float* g00 = weight_data_packed.row(q / 16); - for (int q = 0; q + 7 < num_output; q += 8) + for (int p = 0; p + 15 < num_input; p += 16) { - float* g00 = weight_data_packed.row(q / 8); - - for (int p = 0; p + 7 < num_input; p += 8) + for (int k = 0; k < maxk; k++) { - for (int k = 0; k < maxk; k++) + for (int i = 0; i < 16; i++) { - for (int i = 0; i < 8; i++) + for (int j = 0; j < 16; j++) { - for (int j = 0; j < 8; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - g00[0] = k00[k]; - g00++; - } + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = k00[k]; + g00++; } } } } } - else if (use_cooperative_matrix_16_16_16) - { - // dst = 16b-16a-maxk-inch/16a-outch/16b - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + } + else + { + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); - weight_data_packed.create(maxk * num_input / 16, num_output / 16, (size_t)4 * 16 * 16, 16 * 16); + weight_data_packed.create(maxk * num_input / elempack, num_output / out_elempack, (size_t)4 * elempack * out_elempack, elempack * out_elempack); - for (int q = 0; q + 15 < num_output; q += 16) - { - float* g00 = weight_data_packed.row(q / 16); + for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) + { + float* g00 = weight_data_packed.row(q / out_elempack); - for (int p = 0; p + 15 < num_input; p += 16) + for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + { + for (int k = 0; k < maxk; k++) { - for (int k = 0; k < maxk; k++) + for (int i = 0; i < out_elempack; i++) { - for (int i = 0; i < 16; i++) + const Mat k0 = weight_data_r2.channel(q + i); + + for (int j = 0; j < elempack; j++) { - for (int j = 0; j < 16; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - g00[0] = k00[k]; - g00++; - } + const float* k00 = k0.row(p + j); + g00[0] = k00[k]; + g00++; } } } } } - else - { - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + } + } + else if (is_conv1x1s1d1) + { + use_cooperative_matrix = vkdev->info.support_cooperative_matrix() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input >= 8 && num_output >= 8; - weight_data_packed.create(maxk * num_input / elempack, num_output / out_elempack, (size_t)4 * elempack * out_elempack, elempack * out_elempack); + if (use_cooperative_matrix) + { + int size = 1024; + if (shape_bordered_packed.dims == 3) + size = shape_bordered_packed.w * shape_bordered_packed.h; + + vkdev->info.get_optimal_cooperative_matrix_mnk(size, num_output, num_input, VK_COMPONENT_TYPE_FLOAT16_KHR, opt.use_fp16_arithmetic ? VK_COMPONENT_TYPE_FLOAT16_KHR : VK_COMPONENT_TYPE_FLOAT32_KHR, VK_SCOPE_SUBGROUP_KHR, coopmat_M, coopmat_N, coopmat_K); + + // assert coopmat_M != 0 && coopmat_N != 0 && coopmat_K != 0 + + UNROLL_SG_M = std::min((size + coopmat_M - 1) / coopmat_M, 2); + UNROLL_SG_N = std::min((num_output + coopmat_N - 1) / coopmat_N, 2); + UNROLL_SG_K = std::min((num_input + coopmat_K - 1) / coopmat_K, 2); + + UNROLL_WG_M = std::min((size + coopmat_M * UNROLL_SG_M - 1) / (coopmat_M * UNROLL_SG_M), 2); + UNROLL_WG_N = std::min((num_output + coopmat_N * UNROLL_SG_N - 1) / (coopmat_N * UNROLL_SG_N), 2); + + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // WG_UN+- -+ + // | | | + // v +---+ + + const int blocks_n = (num_output + coopmat_N * UNROLL_SG_N * UNROLL_WG_N - 1) / (coopmat_N * UNROLL_SG_N * UNROLL_WG_N); + // const int blocks_k = (num_input + coopmat_K * UNROLL_SG_K - 1) / (coopmat_K * UNROLL_SG_K); + const int kk = (num_input + coopmat_K - 1) / coopmat_K; + + weight_data_packed.create(coopmat_N * coopmat_K * UNROLL_SG_N * UNROLL_WG_N * kk, blocks_n); + for (int bn = 0; bn < blocks_n; bn++) + { + float* p = weight_data_packed.row(bn); - for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) + int k = 0; + for (; k + UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) { - float* g00 = weight_data_packed.row(q / out_elempack); + // const int ki = k * coopmat_K; - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + for (int wn = 0; wn < UNROLL_WG_N; wn++) { - for (int k = 0; k < maxk; k++) + for (int zk = 0; zk < UNROLL_SG_K; zk++) { - for (int i = 0; i < out_elempack; i++) + for (int zn = 0; zn < UNROLL_SG_N; zn++) { - const Mat k0 = weight_data_r2.channel(q + i); + for (int i = 0; i < coopmat_K; i++) + { + for (int j = 0; j < coopmat_N; j++) + { + const int gni = ((bn * UNROLL_WG_N + wn) * UNROLL_SG_N + zn) * coopmat_N + j; + const int gki = (k + zk) * coopmat_K + i; + + if (gni < num_output && gki < num_input) + { + *p++ = weight_data[gni * num_input + gki]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + for (; k < kk; k++) + { + // const int ki = k * coopmat_K; - for (int j = 0; j < elempack; j++) + for (int wn = 0; wn < UNROLL_WG_N; wn++) + { + // for (int zk = 0; zk < UNROLL_SG_K; zk++) + { + for (int zn = 0; zn < UNROLL_SG_N; zn++) + { + for (int i = 0; i < coopmat_K; i++) { - const float* k00 = k0.row(p + j); - g00[0] = k00[k]; - g00++; + for (int j = 0; j < coopmat_N; j++) + { + const int gni = ((bn * UNROLL_WG_N + wn) * UNROLL_SG_N + zn) * coopmat_N + j; + // const int gki = (k + zk) * coopmat_K + i; + const int gki = k * coopmat_K + i; + + if (gni < num_output && gki < num_input) + { + *p++ = weight_data[gni * num_input + gki]; + } + else + { + *p++ = 0.f; + } + } } } } @@ -786,94 +916,125 @@ int Convolution_vulkan::create_pipeline(const Option& _opt) } else { - bool use_cooperative_matrix_16_8_8 = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && is_conv1x1s1d1 && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0; - bool use_cooperative_matrix_16_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && is_conv1x1s1d1 && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0; - if (vkdev->info.subgroup_size() != 32 && (!vkdev->info.support_subgroup_size_control() || vkdev->info.min_subgroup_size() > 32 || vkdev->info.max_subgroup_size() < 32)) + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + + weight_data_packed.create(maxk, num_input / elempack, num_output / out_elempack, (size_t)4 * elempack * out_elempack, elempack * out_elempack); + + for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) { - use_cooperative_matrix_16_8_8 = false; - use_cooperative_matrix_16_16_16 = false; + float* g00 = weight_data_packed.channel(q / out_elempack); + + for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < out_elempack; i++) + { + const Mat k0 = weight_data_r2.channel(q + i); + + for (int j = 0; j < elempack; j++) + { + const float* k00 = k0.row(p + j); + g00[0] = k00[k]; + g00++; + } + } + } + } } + } + } + else + { + // src = kw-kh-inch-outch + // dst = pa-pb-kw-kh-inch/pa-outch/pb + bool use_cooperative_matrix_16_8_8 = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && is_conv1x1s1d1 && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0; + bool use_cooperative_matrix_16_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && is_conv1x1s1d1 && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0; + if (vkdev->info.subgroup_size() != 32 && (!vkdev->info.support_subgroup_size_control() || vkdev->info.min_subgroup_size() > 32 || vkdev->info.max_subgroup_size() < 32)) + { + use_cooperative_matrix_16_8_8 = false; + use_cooperative_matrix_16_16_16 = false; + } - if (use_cooperative_matrix_16_8_8) - { - // dst = 8b-8a-inch/8a-outch/8b - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + if (use_cooperative_matrix_16_8_8) + { + // dst = 8b-8a-inch/8a-outch/8b + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); - weight_data_packed.create(maxk, num_input / 8, num_output / 8, (size_t)4 * 8 * 8, 8 * 8); + weight_data_packed.create(maxk, num_input / 8, num_output / 8, (size_t)4 * 8 * 8, 8 * 8); - for (int q = 0; q + 7 < num_output; q += 8) - { - float* g00 = weight_data_packed.channel(q / 8); + for (int q = 0; q + 7 < num_output; q += 8) + { + float* g00 = weight_data_packed.channel(q / 8); - for (int p = 0; p + 7 < num_input; p += 8) + for (int p = 0; p + 7 < num_input; p += 8) + { + for (int k = 0; k < maxk; k++) { - for (int k = 0; k < maxk; k++) + for (int i = 0; i < 8; i++) { - for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) { - for (int j = 0; j < 8; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - g00[0] = k00[k]; - g00++; - } + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = k00[k]; + g00++; } } } } } - else if (use_cooperative_matrix_16_16_16) - { - // dst = 16b-16a-inch/16a-outch/16b - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + } + else if (use_cooperative_matrix_16_16_16) + { + // dst = 16b-16a-inch/16a-outch/16b + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); - weight_data_packed.create(maxk, num_input / 16, num_output / 16, (size_t)4 * 16 * 16, 16 * 16); + weight_data_packed.create(maxk, num_input / 16, num_output / 16, (size_t)4 * 16 * 16, 16 * 16); - for (int q = 0; q + 15 < num_output; q += 16) - { - float* g00 = weight_data_packed.channel(q / 16); + for (int q = 0; q + 15 < num_output; q += 16) + { + float* g00 = weight_data_packed.channel(q / 16); - for (int p = 0; p + 15 < num_input; p += 16) + for (int p = 0; p + 15 < num_input; p += 16) + { + for (int k = 0; k < maxk; k++) { - for (int k = 0; k < maxk; k++) + for (int i = 0; i < 16; i++) { - for (int i = 0; i < 16; i++) + for (int j = 0; j < 16; j++) { - for (int j = 0; j < 16; j++) - { - const float* k00 = weight_data_r2.channel(q + j).row(p + i); - g00[0] = k00[k]; - g00++; - } + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = k00[k]; + g00++; } } } } } - else - { - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + } + else + { + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); - weight_data_packed.create(maxk, num_input / elempack, num_output / out_elempack, (size_t)4 * elempack * out_elempack, elempack * out_elempack); + weight_data_packed.create(maxk, num_input / elempack, num_output / out_elempack, (size_t)4 * elempack * out_elempack, elempack * out_elempack); - for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) - { - float* g00 = weight_data_packed.channel(q / out_elempack); + for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) + { + float* g00 = weight_data_packed.channel(q / out_elempack); - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + { + for (int k = 0; k < maxk; k++) { - for (int k = 0; k < maxk; k++) + for (int i = 0; i < out_elempack; i++) { - for (int i = 0; i < out_elempack; i++) - { - const Mat k0 = weight_data_r2.channel(q + i); + const Mat k0 = weight_data_r2.channel(q + i); - for (int j = 0; j < elempack; j++) - { - const float* k00 = k0.row(p + j); - g00[0] = k00[k]; - g00++; - } + for (int j = 0; j < elempack; j++) + { + const float* k00 = k0.row(p + j); + g00[0] = k00[k]; + g00++; } } } @@ -971,68 +1132,74 @@ int Convolution_vulkan::create_pipeline(const Option& _opt) } else if (is_conv1x1s1d1) { - bool use_cooperative_matrix_16_8_8 = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0; - bool use_cooperative_matrix_16_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0; - if (vkdev->info.subgroup_size() != 32 && (!vkdev->info.support_subgroup_size_control() || vkdev->info.min_subgroup_size() > 32 || vkdev->info.max_subgroup_size() < 32)) - { - use_cooperative_matrix_16_8_8 = false; - use_cooperative_matrix_16_16_16 = false; - } - - std::vector specializations(4 + 8); - specializations[0].i = bias_term; - specializations[1].i = activation_type; - specializations[2].f = activation_params.w >= 1 ? activation_params[0] : 0.f; - specializations[3].f = activation_params.w == 2 ? activation_params[1] : 0.f; - specializations[4 + 0].i = shape_bordered_packed.w; - specializations[4 + 1].i = shape_bordered_packed.h; - specializations[4 + 2].i = shape_bordered_packed.c; - specializations[4 + 3].i = shape_bordered_packed.cstep; - specializations[4 + 4].i = out_shape_packed.w; - specializations[4 + 5].i = out_shape_packed.h; - specializations[4 + 6].i = out_shape_packed.c; - specializations[4 + 7].i = out_shape_packed.cstep; - - int shader_type_index = -1; - if (elempack == 1 && out_elempack == 1) shader_type_index = LayerShaderType::convolution_1x1s1d1; - if (elempack == 4 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack4_1x1s1d1; - if (elempack == 1 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack1to4_1x1s1d1; - if (elempack == 4 && out_elempack == 1) shader_type_index = LayerShaderType::convolution_pack4to1_1x1s1d1; - if (elempack == 8 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack8_1x1s1d1; - if (elempack == 1 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack1to8_1x1s1d1; - if (elempack == 8 && out_elempack == 1) shader_type_index = LayerShaderType::convolution_pack8to1_1x1s1d1; - if (elempack == 4 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack4to8_1x1s1d1; - if (elempack == 8 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack8to4_1x1s1d1; - - if (use_cooperative_matrix_16_8_8) - { - shader_type_index = LayerShaderType::convolution_pack4_1x1s1d1_cm_16_8_8; - } - else if (use_cooperative_matrix_16_16_16) - { - shader_type_index = LayerShaderType::convolution_pack4_1x1s1d1_cm_16_16_16; - } - - pipeline_convolution_1x1s1d1 = new Pipeline(vkdev); - if (use_cooperative_matrix_16_8_8) + if (use_cooperative_matrix) { - pipeline_convolution_1x1s1d1->set_subgroup_size(32); - pipeline_convolution_1x1s1d1->set_local_size_xyz(32, 1, 1); // 16_8_8 - } - else if (use_cooperative_matrix_16_16_16) - { - pipeline_convolution_1x1s1d1->set_subgroup_size(32); - pipeline_convolution_1x1s1d1->set_local_size_xyz(32, 1, 1); // 16_16_16 - } - else if (opt.use_shader_local_memory) - { - pipeline_convolution_1x1s1d1->set_local_size_xyz(8, 8, 1); + std::vector specializations(16 + 3); + specializations[0].i = bias_term; + specializations[1].i = activation_type; + specializations[2].f = activation_params.w >= 1 ? activation_params[0] : 0.f; + specializations[3].f = activation_params.w == 2 ? activation_params[1] : 0.f; + specializations[4].u32 = coopmat_M; + specializations[5].u32 = coopmat_N; + specializations[6].u32 = coopmat_K; + specializations[7].u32 = UNROLL_SG_M; + specializations[8].u32 = UNROLL_SG_N; + specializations[9].u32 = UNROLL_SG_K; + specializations[10].u32 = UNROLL_WG_M; + specializations[11].u32 = UNROLL_WG_N; + specializations[12].u32 = num_input; + specializations[13].u32 = num_output; + specializations[14].u32 = elempack; + specializations[15].u32 = out_elempack; + specializations[16 + 0].u32 = shape_bordered_packed.w * shape_bordered_packed.h; + specializations[16 + 1].u32 = shape_bordered_packed.cstep; + specializations[16 + 2].u32 = out_shape_packed.cstep; + + const int subgroup_size = vkdev->info.subgroup_size(); + + pipeline_convolution_1x1s1d1 = new Pipeline(vkdev); + pipeline_convolution_1x1s1d1->set_subgroup_size(subgroup_size); + pipeline_convolution_1x1s1d1->set_local_size_xyz(subgroup_size * UNROLL_WG_M * UNROLL_WG_N, 1, 1); + pipeline_convolution_1x1s1d1->create(LayerShaderType::convolution_1x1s1d1_cm, opt, specializations); } else { - pipeline_convolution_1x1s1d1->set_local_size_xyz(8, std::min(8, num_output / out_elempack), 1); + std::vector specializations(4 + 8); + specializations[0].i = bias_term; + specializations[1].i = activation_type; + specializations[2].f = activation_params.w >= 1 ? activation_params[0] : 0.f; + specializations[3].f = activation_params.w == 2 ? activation_params[1] : 0.f; + specializations[4 + 0].i = shape_bordered_packed.w; + specializations[4 + 1].i = shape_bordered_packed.h; + specializations[4 + 2].i = shape_bordered_packed.c; + specializations[4 + 3].i = shape_bordered_packed.cstep; + specializations[4 + 4].i = out_shape_packed.w; + specializations[4 + 5].i = out_shape_packed.h; + specializations[4 + 6].i = out_shape_packed.c; + specializations[4 + 7].i = out_shape_packed.cstep; + + int shader_type_index = -1; + if (elempack == 1 && out_elempack == 1) shader_type_index = LayerShaderType::convolution_1x1s1d1; + if (elempack == 4 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack4_1x1s1d1; + if (elempack == 1 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack1to4_1x1s1d1; + if (elempack == 4 && out_elempack == 1) shader_type_index = LayerShaderType::convolution_pack4to1_1x1s1d1; + if (elempack == 8 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack8_1x1s1d1; + if (elempack == 1 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack1to8_1x1s1d1; + if (elempack == 8 && out_elempack == 1) shader_type_index = LayerShaderType::convolution_pack8to1_1x1s1d1; + if (elempack == 4 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack4to8_1x1s1d1; + if (elempack == 8 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack8to4_1x1s1d1; + + pipeline_convolution_1x1s1d1 = new Pipeline(vkdev); + if (opt.use_shader_local_memory) + { + pipeline_convolution_1x1s1d1->set_local_size_xyz(8, 8, 1); + } + else + { + pipeline_convolution_1x1s1d1->set_local_size_xyz(8, std::min(8, num_output / out_elempack), 1); + } + pipeline_convolution_1x1s1d1->create(shader_type_index, opt, specializations); } - pipeline_convolution_1x1s1d1->create(shader_type_index, opt, specializations); } else { @@ -1138,6 +1305,16 @@ int Convolution_vulkan::destroy_pipeline(const Option& opt) reshape_w = 0; } + use_cooperative_matrix = false; + coopmat_M = 0; + coopmat_N = 0; + coopmat_K = 0; + UNROLL_SG_M = 1; + UNROLL_SG_N = 1; + UNROLL_SG_K = 1; + UNROLL_WG_M = 1; + UNROLL_WG_N = 1; + return 0; } @@ -1583,55 +1760,64 @@ int Convolution_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCom return 0; } - if (is_conv1x1s1d1) + else if (is_conv1x1s1d1) { - bool use_cooperative_matrix_16_8_8 = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && channels * elempack % 8 == 0 && num_output % 8 == 0; - bool use_cooperative_matrix_16_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_shader_pack8 && opt.use_fp16_storage && channels * elempack % 16 == 0 && num_output % 16 == 0; - if (vkdev->info.subgroup_size() != 32 && (!vkdev->info.support_subgroup_size_control() || vkdev->info.min_subgroup_size() > 32 || vkdev->info.max_subgroup_size() < 32)) - { - use_cooperative_matrix_16_8_8 = false; - use_cooperative_matrix_16_16_16 = false; - } - top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator); if (top_blob.empty()) return -100; - std::vector bindings(4); - bindings[0] = bottom_blob_bordered; - bindings[1] = top_blob; - bindings[2] = weight_data_gpu; - bindings[3] = bias_data_gpu; + const int num_input = channels * elempack; - std::vector constants(8); - constants[0].i = bottom_blob_bordered.w; - constants[1].i = bottom_blob_bordered.h; - constants[2].i = bottom_blob_bordered.c; - constants[3].i = bottom_blob_bordered.cstep; - constants[4].i = top_blob.w; - constants[5].i = top_blob.h; - constants[6].i = top_blob.c; - constants[7].i = top_blob.cstep; + if (use_cooperative_matrix) + { + std::vector bindings(4); + bindings[0] = bottom_blob_bordered; + bindings[1] = top_blob; + bindings[2] = weight_data_gpu; + bindings[3] = bias_data_gpu; - VkMat dispatcher; - dispatcher.w = (top_blob.w * top_blob.h + 3) / 4; - dispatcher.h = top_blob.c; - dispatcher.c = 1; + std::vector constants(3); + constants[0].u32 = bottom_blob_bordered.w * bottom_blob_bordered.h; + constants[1].u32 = bottom_blob_bordered.cstep; + constants[2].u32 = top_blob.cstep; - if (use_cooperative_matrix_16_8_8) - { - dispatcher.w = ((top_blob.w * top_blob.h + 15) / 16 + 1) / 2 * 32; - dispatcher.h = ((top_blob.c + 1) / 2 + 3) / 4; + const int blocks_x = (top_blob.w * top_blob.h + coopmat_M * UNROLL_SG_M * UNROLL_WG_M - 1) / (coopmat_M * UNROLL_SG_M * UNROLL_WG_M); + const int blocks_y = (num_output + coopmat_N * UNROLL_SG_N * UNROLL_WG_N - 1) / (coopmat_N * UNROLL_SG_N * UNROLL_WG_N); + + const int subgroup_size = vkdev->info.subgroup_size(); + + VkMat dispatcher; + dispatcher.w = (blocks_x * blocks_y) * (subgroup_size * UNROLL_WG_M * UNROLL_WG_N); + dispatcher.h = 1; dispatcher.c = 1; + + cmd.record_pipeline(pipeline_convolution_1x1s1d1, bindings, constants, dispatcher); } - else if (use_cooperative_matrix_16_16_16) + else { - dispatcher.w = ((top_blob.w * top_blob.h + 15) / 16 + 1) / 2 * 32; - dispatcher.h = ((top_blob.c + 3) / 4 + 1) / 2; + std::vector bindings(4); + bindings[0] = bottom_blob_bordered; + bindings[1] = top_blob; + bindings[2] = weight_data_gpu; + bindings[3] = bias_data_gpu; + + std::vector constants(8); + constants[0].i = bottom_blob_bordered.w; + constants[1].i = bottom_blob_bordered.h; + constants[2].i = bottom_blob_bordered.c; + constants[3].i = bottom_blob_bordered.cstep; + constants[4].i = top_blob.w; + constants[5].i = top_blob.h; + constants[6].i = top_blob.c; + constants[7].i = top_blob.cstep; + + VkMat dispatcher; + dispatcher.w = (top_blob.w * top_blob.h + 3) / 4; + dispatcher.h = top_blob.c; dispatcher.c = 1; - } - cmd.record_pipeline(pipeline_convolution_1x1s1d1, bindings, constants, dispatcher); + cmd.record_pipeline(pipeline_convolution_1x1s1d1, bindings, constants, dispatcher); + } return 0; } diff --git a/src/layer/vulkan/convolution_vulkan.h b/src/layer/vulkan/convolution_vulkan.h index 5ffccc3e5..aeb353735 100644 --- a/src/layer/vulkan/convolution_vulkan.h +++ b/src/layer/vulkan/convolution_vulkan.h @@ -53,6 +53,17 @@ public: // convolution as fc ncnn::Layer* reshape_1x1xw; ncnn::Layer* reshape_w; + + // cooperative matrix + bool use_cooperative_matrix; + int coopmat_M; + int coopmat_N; + int coopmat_K; + int UNROLL_SG_M; + int UNROLL_SG_N; + int UNROLL_SG_K; + int UNROLL_WG_M; + int UNROLL_WG_N; }; } // namespace ncnn diff --git a/src/layer/vulkan/shader/convolution_1x1s1d1_cm.comp b/src/layer/vulkan/shader/convolution_1x1s1d1_cm.comp new file mode 100644 index 000000000..4d46f572f --- /dev/null +++ b/src/layer/vulkan/shader/convolution_1x1s1d1_cm.comp @@ -0,0 +1,1253 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#version 450 + +#extension GL_GOOGLE_include_directive: require +#include "vulkan_activation.comp" + +#extension GL_EXT_control_flow_attributes: require + +#extension GL_KHR_shader_subgroup_basic: require + +#extension GL_KHR_memory_scope_semantics: require +#extension GL_EXT_shader_explicit_arithmetic_types: require +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#if ncnn_VK_KHR_cooperative_matrix +#extension GL_KHR_cooperative_matrix: require +#elif ncnn_VK_NV_cooperative_matrix +#extension GL_NV_cooperative_matrix: require +#endif + +layout (constant_id = 0) const int bias_term = 0; +layout (constant_id = 1) const int activation_type = 0; +layout (constant_id = 2) const float activation_param_0 = 0; +layout (constant_id = 3) const float activation_param_1 = 0; +layout (constant_id = 4) const uint M = 1; +layout (constant_id = 5) const uint N = 1; +layout (constant_id = 6) const uint K = 1; +layout (constant_id = 7) const uint UNROLL_SG_M = 2; +layout (constant_id = 8) const uint UNROLL_SG_N = 2; +layout (constant_id = 9) const uint UNROLL_SG_K = 2; +layout (constant_id = 10) const uint UNROLL_WG_M = 2; +layout (constant_id = 11) const uint UNROLL_WG_N = 2; +layout (constant_id = 12) const uint inch = 1; +layout (constant_id = 13) const uint outch = 1; +layout (constant_id = 14) const uint elempack = 1; +layout (constant_id = 15) const uint out_elempack = 1; + +#define shape_constant_id_offset 16 +layout (constant_id = shape_constant_id_offset + 0) const uint size = 0; +layout (constant_id = shape_constant_id_offset + 1) const uint cstep = 0; +layout (constant_id = shape_constant_id_offset + 2) const uint outcstep = 0; + +layout (binding = 0) readonly buffer bottom_blob { uvec2 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { uvec2 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { uvec2 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { uvec2 bias_data[]; }; + +layout (push_constant) uniform parameter +{ + uint size; + uint cstep; + uint outcstep; +} p; + +shared uvec2 tmp_v[UNROLL_WG_M][UNROLL_SG_M * UNROLL_SG_K * M * K / 4]; + +shared uvec2 tmp_k[UNROLL_WG_N][UNROLL_SG_N * UNROLL_SG_K * K * N / 4]; + +shared uvec2 tmp_o[UNROLL_WG_N * UNROLL_WG_M][UNROLL_SG_N * UNROLL_SG_M * M * N / 4]; + +void main() +{ + // assert gl_WorkGroupSize.x == gl_SubgroupSize + // but neither gl_SubgroupSize nor gl_WorkGroupSize.x is a constant + const uint local_size = ncnn_subgroupSize * UNROLL_WG_M * UNROLL_WG_N; + + // [ WG_UN * WG_UM * [ SG_UN * SG_UM * subgroup ] ] + + // <----WG_UN----> + // +---N--+-SG_UN+------+------+ + // | | | |XXXXXX| + // M | XXXX<----coopmat + // | | | |XXXXXX| + // +-- --SG0-- --+-- --SG2-- --+ + // | | | | | + // SG_UM | | + // | | | | | + // ^ +------+--WORKGROUP--+------+ + // | | | | | | + // | | | | + // | | | | | | + // WG_UM+-- --SG1-- --+-- --SG3-- --+ + // | | | | | | + // | | | | + // | | | | | | + // v +------+------+------+------+ + // + + const uint wgi = gl_WorkGroupID.x; + const uint sgi = gl_SubgroupID; + + const uint wgmm = (psc(size) + M * UNROLL_SG_M * UNROLL_WG_M - 1) / (M * UNROLL_SG_M * UNROLL_WG_M); + const uint wgnn = (outch + N * UNROLL_SG_N * UNROLL_WG_N - 1) / (N * UNROLL_SG_N * UNROLL_WG_N); + + const uint wgmi = wgi / wgnn; + const uint wgni = wgi % wgnn; + + const uint sgmi = sgi / UNROLL_WG_N; + const uint sgni = sgi % UNROLL_WG_N; + +// const uint mm = (psc(size) + M - 1) / M; +// const uint nn = (outch + N - 1) / N; + const uint kk = (inch + K - 1) / K; + + if (wgmi >= wgmm) + return; + + const uint li = gl_LocalInvocationID.x; + const uint si = gl_SubgroupInvocationID; + + const uint Md4 = M / 4; + const uint Nd4 = N / 4; + const uint Kd4 = K / 4; + + const uint ni = (wgni * UNROLL_WG_N + sgni) * UNROLL_SG_N; + const uint mi = (wgmi * UNROLL_WG_M + sgmi) * UNROLL_SG_M; + +#if ncnn_VK_KHR_cooperative_matrix + coopmat sum[UNROLL_SG_N][UNROLL_SG_M]; +#elif ncnn_VK_NV_cooperative_matrix +#if NCNN_fp16_arithmetic + fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum[UNROLL_SG_N][UNROLL_SG_M]; +#else + fcoopmatNV<32, gl_ScopeSubgroup, M, N> sum[UNROLL_SG_N][UNROLL_SG_M]; +#endif +#endif + + if (bias_term == 1) + { + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { +#if ncnn_VK_KHR_cooperative_matrix + coopmat bias; + coopMatLoad(bias, bias_data, ((wgni* UNROLL_WG_N + sgni) * UNROLL_SG_N + zn) * Nd4, 0, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + fcoopmatNV<16, gl_ScopeSubgroup, M, N> bias; + coopMatLoadNV(bias, bias_data, ((wgni* UNROLL_WG_N + sgni) * UNROLL_SG_N + zn) * Nd4, 0, false); +#endif + + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { +#if NCNN_fp16_arithmetic + sum[zn][zm] = bias; +#else +#if ncnn_VK_KHR_cooperative_matrix + sum[zn][zm] = coopmat(bias); +#elif ncnn_VK_NV_cooperative_matrix + sum[zn][zm] = fcoopmatNV<32, gl_ScopeSubgroup, M, N>(bias); +#endif +#endif + } + } + } + else + { + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { +#if ncnn_VK_KHR_cooperative_matrix + sum[zn][zm] = coopmat(0.f); +#elif ncnn_VK_NV_cooperative_matrix +#if NCNN_fp16_arithmetic + sum[zn][zm] = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(0.f); +#else + sum[zn][zm] = fcoopmatNV<32, gl_ScopeSubgroup, M, N>(0.f); +#endif +#endif + } + } + } + + uint k = 0; + + if (kk >= UNROLL_SG_K * 2) + { + // local stack and shared memory ping-pong + + // prefetch + uvec2 prefetch_tmp_v[(UNROLL_SG_M * UNROLL_SG_K * M * K / 4 + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N)]; + uvec2 prefetch_tmp_k[(UNROLL_SG_N * UNROLL_SG_K * K * N / 4 + (ncnn_subgroupSize * UNROLL_WG_M - 1)) / (ncnn_subgroupSize * UNROLL_WG_M)]; + + // prefetch the very first + { + const uint ki = 0; + + // load bottom_blob + { + if (elempack == 1) + { + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint cstepd4 = psc(cstep) / 4; + + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Md4_K_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + uvec2 v = gk < inch && gm < cstepd4 ? bottom_blob_data[gk * cstepd4 + gm] : uvec2(0); + + prefetch_tmp_v[q] = v; + } + } + } + else // if (elempack == 4) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint inchd4 = inch / 4; + + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Kd4_M_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zmi = zmij / Kd4; + const uint j = zmij % Kd4; + + const uint gm = mi * M + zmi; + const uint gk = ki / 4 + zk * Kd4 + j; + + uvec2 v = gk < inchd4 && gm < psc(cstep) ? bottom_blob_data[gk * psc(cstep) + gm] : uvec2(0); + + prefetch_tmp_v[q] = v; + } + } + } + } + + // load weight + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + // weight_data coopmat_N * coopmat_K * UNROLL_SG_N * UNROLL_WG_N * kk, blocks_n + + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + + const uint w_offset = ((wgni * kk) * UNROLL_WG_N) * Nd4_K_USGN + ((k / UNROLL_SG_K) * UNROLL_WG_N + sgni) * Nd4_K_USGN_USGK; + + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (ncnn_subgroupSize * UNROLL_WG_M - 1)) / (ncnn_subgroupSize * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * ncnn_subgroupSize + si; + + if (Nd4_K_USGN_USGK % (ncnn_subgroupSize * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + prefetch_tmp_k[q] = weight_data[w_offset + siq]; + } + } + } + } + + k += UNROLL_SG_K; + + for (; k + UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) + { + barrier(); + + // copy prefetch to shared memory + { + // load bottom_blob + { + if (elempack == 1) + { + const uint cstepd4 = psc(cstep) / 4; + + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Md4_K_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + tmp_v[sgmi][siq] = prefetch_tmp_v[q]; + } + } + } + else // if (elempack == 4) + { + const uint inchd4 = inch / 4; + + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Kd4_M_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + tmp_v[sgmi][siq] = prefetch_tmp_v[q]; + } + } + } + } + + // load weight + { + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + + const uint w_offset = ((wgni * kk) * UNROLL_WG_N) * Nd4_K_USGN + ((k / UNROLL_SG_K) * UNROLL_WG_N + sgni) * Nd4_K_USGN_USGK; + + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (ncnn_subgroupSize * UNROLL_WG_M - 1)) / (ncnn_subgroupSize * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * ncnn_subgroupSize + si; + + if (Nd4_K_USGN_USGK % (ncnn_subgroupSize * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + tmp_k[sgni][siq] = prefetch_tmp_k[q]; + } + } + } + } + + barrier(); + + // prefetch the next + { + const uint ki = k * K; + + // load bottom_blob + { + if (elempack == 1) + { + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint cstepd4 = psc(cstep) / 4; + + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Md4_K_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + uvec2 v = gk < inch && gm < cstepd4 ? bottom_blob_data[gk * cstepd4 + gm] : uvec2(0); + + prefetch_tmp_v[q] = v; + } + } + } + else // if (elempack == 4) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint inchd4 = inch / 4; + + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Kd4_M_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zmi = zmij / Kd4; + const uint j = zmij % Kd4; + + const uint gm = mi * M + zmi; + const uint gk = ki / 4 + zk * Kd4 + j; + + uvec2 v = gk < inchd4 && gm < psc(cstep) ? bottom_blob_data[gk * psc(cstep) + gm] : uvec2(0); + + prefetch_tmp_v[q] = v; + } + } + } + } + + // load weight + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + // weight_data coopmat_N * coopmat_K * UNROLL_SG_N * UNROLL_WG_N * kk, blocks_n + + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + + const uint w_offset = ((wgni * kk) * UNROLL_WG_N) * Nd4_K_USGN + ((k / UNROLL_SG_K) * UNROLL_WG_N + sgni) * Nd4_K_USGN_USGK; + + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (ncnn_subgroupSize * UNROLL_WG_M - 1)) / (ncnn_subgroupSize * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * ncnn_subgroupSize + si; + + if (Nd4_K_USGN_USGK % (ncnn_subgroupSize * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + prefetch_tmp_k[q] = weight_data[w_offset + siq]; + } + } + } + } + +#if ncnn_VK_KHR_cooperative_matrix + coopmat A[UNROLL_SG_M]; + coopmat B[UNROLL_SG_N]; +#elif ncnn_VK_NV_cooperative_matrix + fcoopmatNV<16, gl_ScopeSubgroup, M, K> A[UNROLL_SG_M]; + fcoopmatNV<16, gl_ScopeSubgroup, K, N> B[UNROLL_SG_N]; +#endif + + [[unroll]] for (uint zk = 0; zk < UNROLL_SG_K; zk++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { + if (elempack == 1) + { + #if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Md4 * K), Md4, gl_CooperativeMatrixLayoutColumnMajor); + #elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Md4 * K), Md4, true); + #endif + } + else // if (elempack == 4) + { + #if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4 * M), Kd4, gl_CooperativeMatrixLayoutRowMajor); + #elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4 * M), Kd4, false); + #endif + } + } + + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + #if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_k[sgni], (zk * UNROLL_SG_N + zn) * (Nd4 * K), Nd4, gl_CooperativeMatrixLayoutRowMajor); + #elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_k[sgni], (zk * UNROLL_SG_N + zn) * (Nd4 * K), Nd4, false); + #endif + } + + // sum += k * v + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { + #if ncnn_VK_KHR_cooperative_matrix + sum[zn][zm] = coopMatMulAdd(A[zm], B[zn], sum[zn][zm]); + #elif ncnn_VK_NV_cooperative_matrix + sum[zn][zm] = coopMatMulAddNV(A[zm], B[zn], sum[zn][zm]); + #endif + } + } + } + } + + barrier(); + + // the last copy prefetch to shared memory + { + // load bottom_blob + { + if (elempack == 1) + { + const uint cstepd4 = psc(cstep) / 4; + + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Md4_K_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + tmp_v[sgmi][siq] = prefetch_tmp_v[q]; + } + } + } + else // if (elempack == 4) + { + const uint inchd4 = inch / 4; + + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Kd4_M_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + tmp_v[sgmi][siq] = prefetch_tmp_v[q]; + } + } + } + } + + // load weight + { + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + + const uint w_offset = ((wgni * kk) * UNROLL_WG_N) * Nd4_K_USGN + ((k / UNROLL_SG_K) * UNROLL_WG_N + sgni) * Nd4_K_USGN_USGK; + + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (ncnn_subgroupSize * UNROLL_WG_M - 1)) / (ncnn_subgroupSize * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * ncnn_subgroupSize + si; + + if (Nd4_K_USGN_USGK % (ncnn_subgroupSize * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + tmp_k[sgni][siq] = prefetch_tmp_k[q]; + } + } + } + } + + barrier(); + +#if ncnn_VK_KHR_cooperative_matrix + coopmat A[UNROLL_SG_M]; + coopmat B[UNROLL_SG_N]; +#elif ncnn_VK_NV_cooperative_matrix + fcoopmatNV<16, gl_ScopeSubgroup, M, K> A[UNROLL_SG_M]; + fcoopmatNV<16, gl_ScopeSubgroup, K, N> B[UNROLL_SG_N]; +#endif + + [[unroll]] for (uint zk = 0; zk < UNROLL_SG_K; zk++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { + if (elempack == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Md4 * K), Md4, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Md4 * K), Md4, true); +#endif + } + else // if (elempack == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4 * M), Kd4, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4 * M), Kd4, false); +#endif + } + } + + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_k[sgni], (zk * UNROLL_SG_N + zn) * (Nd4 * K), Nd4, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_k[sgni], (zk * UNROLL_SG_N + zn) * (Nd4 * K), Nd4, false); +#endif + } + + // sum += k * v + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { +#if ncnn_VK_KHR_cooperative_matrix + sum[zn][zm] = coopMatMulAdd(A[zm], B[zn], sum[zn][zm]); +#elif ncnn_VK_NV_cooperative_matrix + sum[zn][zm] = coopMatMulAddNV(A[zm], B[zn], sum[zn][zm]); +#endif + } + } + } + } + else if (kk >= UNROLL_SG_K) + { + // no ping-pong version + + const uint ki = 0; + + // load bottom_blob + { + if (elempack == 1) + { + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint cstepd4 = psc(cstep) / 4; + + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Md4_K_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + uvec2 v = gk < inch && gm < cstepd4 ? bottom_blob_data[gk * cstepd4 + gm] : uvec2(0); + + tmp_v[sgmi][siq] = v; + } + } + } + else // if (elempack == 4) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint inchd4 = inch / 4; + + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Kd4_M_USGM_USGK % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zmi = zmij / Kd4; + const uint j = zmij % Kd4; + + const uint gm = mi * M + zmi; + const uint gk = ki / 4 + zk * Kd4 + j; + + uvec2 v = gk < inchd4 && gm < psc(cstep) ? bottom_blob_data[gk * psc(cstep) + gm] : uvec2(0); + + tmp_v[sgmi][siq] = v; + } + } + } + } + + // load weight + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + // weight_data coopmat_N * coopmat_K * UNROLL_SG_N * UNROLL_WG_N * kk, blocks_n + + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + + const uint w_offset = ((wgni * kk) * UNROLL_WG_N) * Nd4_K_USGN + ((k / UNROLL_SG_K) * UNROLL_WG_N + sgni) * Nd4_K_USGN_USGK; + + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (ncnn_subgroupSize * UNROLL_WG_M - 1)) / (ncnn_subgroupSize * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * ncnn_subgroupSize + si; + + if (Nd4_K_USGN_USGK % (ncnn_subgroupSize * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + tmp_k[sgni][siq] = weight_data[w_offset + siq]; + } + } + } + + barrier(); + +#if ncnn_VK_KHR_cooperative_matrix + coopmat A[UNROLL_SG_M]; + coopmat B[UNROLL_SG_N]; +#elif ncnn_VK_NV_cooperative_matrix + fcoopmatNV<16, gl_ScopeSubgroup, M, K> A[UNROLL_SG_M]; + fcoopmatNV<16, gl_ScopeSubgroup, K, N> B[UNROLL_SG_N]; +#endif + + [[unroll]] for (uint zk = 0; zk < UNROLL_SG_K; zk++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { + if (elempack == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Md4 * K), Md4, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Md4 * K), Md4, true); +#endif + } + else // if (elempack == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4 * M), Kd4, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4 * M), Kd4, false); +#endif + } + } + + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_k[sgni], (zk * UNROLL_SG_N + zn) * (Nd4 * K), Nd4, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_k[sgni], (zk * UNROLL_SG_N + zn) * (Nd4 * K), Nd4, false); +#endif + } + + // sum += k * v + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { +#if ncnn_VK_KHR_cooperative_matrix + sum[zn][zm] = coopMatMulAdd(A[zm], B[zn], sum[zn][zm]); +#elif ncnn_VK_NV_cooperative_matrix + sum[zn][zm] = coopMatMulAddNV(A[zm], B[zn], sum[zn][zm]); +#endif + } + } + } + + k += UNROLL_SG_K; + } + + for (; k < kk; k++) + { + const uint ki = k * K; + + barrier(); + + // load bottom_blob + { + if (elempack == 1) + { + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // WG_UM+- -+ + // | | | + // v +---+ + + const uint cstepd4 = psc(cstep) / 4; + + const uint Md4_K_USGM = Md4 * K * UNROLL_SG_M; + const uint Md4_K_USGM_d_subgroupsize = (Md4_K_USGM + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Md4_K_USGM % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Md4_K_USGM) + { + const uint zm = siq / (Md4 * K); + const uint ij = siq % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + const uint gk = ki + i; + const uint gm = (mi + zm) * Md4 + j; + + uvec2 v = gk < inch && gm < cstepd4 ? bottom_blob_data[gk * cstepd4 + gm] : uvec2(0); + + tmp_v[sgmi][siq] = v; + } + } + } + else // if (elempack == 4) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // WG_UM+- -+ + // | | | + // v +---+ + + const uint inchd4 = inch / 4; + + const uint Kd4_M_USGM = Kd4 * M * UNROLL_SG_M; + const uint Kd4_M_USGM_d_subgroupsize = (Kd4_M_USGM + (ncnn_subgroupSize * UNROLL_WG_N - 1)) / (ncnn_subgroupSize * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * ncnn_subgroupSize + si; + + if (Kd4_M_USGM % (ncnn_subgroupSize * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM) + { + const uint zmi = siq / Kd4; + const uint j = siq % Kd4; + + const uint gm = mi * M + zmi; + const uint gk = ki / 4 + j; + + uvec2 v = gk < inchd4 && gm < psc(cstep) ? bottom_blob_data[gk * psc(cstep) + gm] : uvec2(0); + + tmp_v[sgmi][siq] = v; + } + } + } + } + + // load weight + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // WG_UN+- -+ + // | | | + // v +---+ + + // weight_data coopmat_N * coopmat_K * UNROLL_SG_N * UNROLL_WG_N * kk, blocks_n + + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + + const uint w_offset = ((wgni * kk + k) * UNROLL_WG_N + sgni) * Nd4_K_USGN; + + const uint Nd4_K_USGN_d_subgroupsize = (Nd4_K_USGN + (ncnn_subgroupSize * UNROLL_WG_M - 1)) / (ncnn_subgroupSize * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * ncnn_subgroupSize + si; + + if (Nd4_K_USGN % (ncnn_subgroupSize * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN) + { + tmp_k[sgni][siq] = weight_data[w_offset + siq]; + } + } + } + + barrier(); + +#if ncnn_VK_KHR_cooperative_matrix + coopmat A[UNROLL_SG_M]; + coopmat B[UNROLL_SG_N]; +#elif ncnn_VK_NV_cooperative_matrix + fcoopmatNV<16, gl_ScopeSubgroup, M, K> A[UNROLL_SG_M]; + fcoopmatNV<16, gl_ScopeSubgroup, K, N> B[UNROLL_SG_N]; +#endif + + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { + if (elempack == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], zm * (Md4 * K), Md4, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], zm * (Md4 * K), Md4, true); +#endif + } + else // if (elempack == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_v[sgmi], zm * (Kd4 * M), Kd4, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_v[sgmi], zm * (Kd4 * M), Kd4, false); +#endif + } + } + + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_k[sgni], zn * (Nd4 * K), Nd4, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_k[sgni], zn * (Nd4 * K), Nd4, false); +#endif + } + + // sum += k * v + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { +#if ncnn_VK_KHR_cooperative_matrix + sum[zn][zm] = coopMatMulAdd(A[zm], B[zn], sum[zn][zm]); +#elif ncnn_VK_NV_cooperative_matrix + sum[zn][zm] = coopMatMulAddNV(A[zm], B[zn], sum[zn][zm]); +#endif + } + } + } + + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) + { + if (out_elempack == 1) + { +#if ncnn_VK_KHR_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); +#else + coopmat sum_fp16 = coopmat(sum[zn][zm]); + coopMatStore(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); +#endif +#elif ncnn_VK_NV_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStoreNV(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, true); +#else + fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); + coopMatStoreNV(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, true); +#endif +#endif + } + else // if (out_elempack == 4) + { +#if ncnn_VK_KHR_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); +#else + coopmat sum_fp16 = coopmat(sum[zn][zm]); + coopMatStore(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); +#endif +#elif ncnn_VK_NV_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStoreNV(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); +#else + fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); + coopMatStoreNV(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); +#endif +#endif + } + } + } + + barrier(); + + // store top_blob + { + if (out_elempack == 1) + { + // +-M-+ + // N | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UN+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +- -+ + // | | | + // | +- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint outcstepd4 = psc(outcstep) / 4; + + const uint Md4_N_USGM_USGN = Md4 * N * UNROLL_SG_M * UNROLL_SG_N; + const uint Md4_N_USGM_USGN_d_subgroupsize = (Md4_N_USGM_USGN + ncnn_subgroupSize - 1) / ncnn_subgroupSize; + [[unroll]] for (uint q = 0; q < Md4_N_USGM_USGN_d_subgroupsize; q++) + { + const uint siq = si + q * ncnn_subgroupSize; + + if (Md4_N_USGM_USGN % ncnn_subgroupSize == 0 || siq < Md4_N_USGM_USGN) + { + const uint zn = siq / (Md4 * N * UNROLL_SG_M); + const uint zmij = siq % (Md4 * N * UNROLL_SG_M); + const uint zm = zmij / (Md4 * N); + const uint ij = zmij % (Md4 * N); + const uint i = ij / Md4; + const uint j = ij % Md4; + + const uint gn = (ni + zn) * N + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gn < outch && gm < outcstepd4) + { + uvec2 sum = tmp_o[sgi][siq]; + + if (activation_type == 0) + { + top_blob_data[gn * outcstepd4 + gm] = sum; + } + else + { + afpvec4 v = afpvec4(unpackHalf2x16(sum.r), unpackHalf2x16(sum.g)); + + v = activation_afpvec4(v, activation_type, activation_param_0, activation_param_1); + + top_blob_data[gn * outcstepd4 + gm] = uvec2(packHalf2x16(vec4(v).rg), packHalf2x16(vec4(v).ba)); + } + } + } + } + } + else // if (out_elempack == 4) + { + // +-N-+ + // M | + // +---+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UN+---+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +---+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +---+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +---+ + // | | | + // v +---+ + + const uint outchd4 = outch / 4; + + const uint Nd4_M_USGM_USGN = Nd4 * M * UNROLL_SG_M * UNROLL_SG_N; + const uint Nd4_M_USGM_USGN_d_subgroupsize = (Nd4_M_USGM_USGN + ncnn_subgroupSize - 1) / ncnn_subgroupSize; + [[unroll]] for (uint q = 0; q < Nd4_M_USGM_USGN_d_subgroupsize; q++) + { + const uint siq = si + q * ncnn_subgroupSize; + + if (Nd4_M_USGM_USGN % ncnn_subgroupSize == 0 || siq < Nd4_M_USGM_USGN) + { + const uint zn = siq / (Nd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Nd4 * M * UNROLL_SG_M); + const uint zmi = zmij / Nd4; + const uint j = zmij % Nd4; + + const uint gn = (ni + zn) * Nd4 + j; + const uint gm = mi * M + zmi; + + if (gn < outchd4 && gm < psc(outcstep)) + { + uvec2 sum = tmp_o[sgi][siq]; + + if (activation_type == 0) + { + top_blob_data[gn * psc(outcstep) + gm] = sum; + } + else + { + afpvec4 v = afpvec4(unpackHalf2x16(sum.r), unpackHalf2x16(sum.g)); + + v = activation_afpvec4(v, activation_type, activation_param_0, activation_param_1); + + top_blob_data[gn * psc(outcstep) + gm] = uvec2(packHalf2x16(vec4(v).rg), packHalf2x16(vec4(v).ba)); + } + } + } + } + } + } +} diff --git a/src/layer/vulkan/shader/convolution_pack4_1x1s1d1_cm_16_16_16.comp b/src/layer/vulkan/shader/convolution_pack4_1x1s1d1_cm_16_16_16.comp deleted file mode 100644 index b0093a6d3..000000000 --- a/src/layer/vulkan/shader/convolution_pack4_1x1s1d1_cm_16_16_16.comp +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright 2023 Tencent -// SPDX-License-Identifier: BSD-3-Clause - -#version 450 - -#extension GL_GOOGLE_include_directive: enable -#include "vulkan_activation.comp" - -#extension GL_KHR_memory_scope_semantics: require -#extension GL_EXT_shader_explicit_arithmetic_types: require -#extension GL_EXT_shader_explicit_arithmetic_types_float16: require -#if ncnn_VK_KHR_cooperative_matrix -#extension GL_KHR_cooperative_matrix: require -#elif ncnn_VK_NV_cooperative_matrix -#extension GL_NV_cooperative_matrix: require -#endif - -layout (constant_id = 0) const int bias_term = 0; -layout (constant_id = 1) const int activation_type = 0; -layout (constant_id = 2) const float activation_param_0 = 0; -layout (constant_id = 3) const float activation_param_1 = 0; - -#define shape_constant_id_offset 4 -layout (constant_id = shape_constant_id_offset + 0) const int w = 0; -layout (constant_id = shape_constant_id_offset + 1) const int h = 0; -layout (constant_id = shape_constant_id_offset + 2) const int c = 0; -layout (constant_id = shape_constant_id_offset + 3) const int cstep = 0; - -layout (constant_id = shape_constant_id_offset + 4) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 5) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 6) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 7) const int outcstep = 0; - -layout (binding = 0) readonly buffer bottom_blob { uvec2 bottom_blob_data[]; }; -layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; -layout (binding = 2) readonly buffer weight_blob { uvec2 weight_data[]; }; -layout (binding = 3) readonly buffer bias_blob { uvec2 bias_data[]; }; - -layout (push_constant) uniform parameter -{ - int w; - int h; - int c; - int cstep; - - int outw; - int outh; - int outc; - int outcstep; -} p; - -#define UNROLL_INCH 2 - -shared uvec2 tmp_v0[UNROLL_INCH * 16*4]; -shared uvec2 tmp_v1[UNROLL_INCH * 16*4]; -shared uvec2 tmp_k0[UNROLL_INCH * 16*4]; -shared uvec2 tmp_k1[UNROLL_INCH * 16*4]; - -void main() -{ - int gx = int(gl_GlobalInvocationID.x) / 32 * 2 * 16; - int gy = int(gl_GlobalInvocationID.y) * 2 * 4; - - const int lx = int(gl_LocalInvocationID.x); - - const int lxd16 = lx / 16; // 0 1 - const int lxm16 = lx % 16; // 0 1 2 3 .... 15 - -#if ncnn_VK_KHR_cooperative_matrix - coopmat sum0; - coopmat sum1; - coopmat sum2; - coopmat sum3; -#elif ncnn_VK_NV_cooperative_matrix -#if NCNN_fp16_arithmetic - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum1; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum2; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum3; -#else - fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum0; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum1; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum2; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum3; -#endif -#endif - - if (bias_term == 1) - { -#if ncnn_VK_KHR_cooperative_matrix - coopmat bias0; - coopmat bias1; - - coopMatLoad(bias0, bias_data, gy, 0, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(bias1, bias_data, gy + 4, 0, gl_CooperativeMatrixLayoutRowMajor); -#elif ncnn_VK_NV_cooperative_matrix - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> bias0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> bias1; - - coopMatLoadNV(bias0, bias_data, gy, 0, false); - coopMatLoadNV(bias1, bias_data, gy + 4, 0, false); -#endif - -#if NCNN_fp16_arithmetic - sum0 = bias0; - sum1 = bias0; - sum2 = bias1; - sum3 = bias1; -#else -#if ncnn_VK_KHR_cooperative_matrix - sum0 = coopmat(bias0); - sum1 = coopmat(bias0); - sum2 = coopmat(bias1); - sum3 = coopmat(bias1); -#elif ncnn_VK_NV_cooperative_matrix - sum0 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(bias0); - sum1 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(bias0); - sum2 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(bias1); - sum3 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(bias1); -#endif -#endif - } - else - { -#if ncnn_VK_KHR_cooperative_matrix - sum0 = coopmat(0.f); - sum1 = coopmat(0.f); - sum2 = coopmat(0.f); - sum3 = coopmat(0.f); -#elif ncnn_VK_NV_cooperative_matrix -#if NCNN_fp16_arithmetic - sum0 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(0.f); - sum1 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(0.f); - sum2 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(0.f); - sum3 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(0.f); -#else - sum0 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(0.f); - sum1 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(0.f); - sum2 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(0.f); - sum3 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(0.f); -#endif -#endif - } - - const int N = psc(c) / 4; - - int z = 0; - for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH) - { - { - for (int j = 0; j < 4; j++) - { - const int tmp_i = lxd16*16*4 + lxm16 * 4 + j; - - const int v_offset = ((z + lxd16) * 4 + j) * psc(outcstep) + (gx + lxm16); - - tmp_v0[tmp_i] = (gx + lxm16) < psc(outcstep) ? bottom_blob_data[v_offset] : uvec2(0); - tmp_v1[tmp_i] = (gx + lxm16 + 16) < psc(outcstep) ? bottom_blob_data[v_offset + 16] : uvec2(0); - - const int w_offset = gy * psc(c) * 4 + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j); - - tmp_k0[tmp_i] = weight_data[w_offset]; - tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * 16]; - } - } - - barrier(); - - for (int z4 = 0; z4 < UNROLL_INCH; z4++) - { -#if ncnn_VK_KHR_cooperative_matrix - coopmat A0; - coopmat A1; - coopMatLoad(A0, tmp_v0, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(A1, tmp_v1, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - - coopmat B0; - coopmat B1; - coopMatLoad(B0, tmp_k0, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B1, tmp_k1, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - - // sum += v * k - sum0 = coopMatMulAdd(A0, B0, sum0); - sum1 = coopMatMulAdd(A1, B0, sum1); - sum2 = coopMatMulAdd(A0, B1, sum2); - sum3 = coopMatMulAdd(A1, B1, sum3); -#elif ncnn_VK_NV_cooperative_matrix - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> A0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> A1; - coopMatLoadNV(A0, tmp_v0, z4*16*4, 4, false); - coopMatLoadNV(A1, tmp_v1, z4*16*4, 4, false); - - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> B0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> B1; - coopMatLoadNV(B0, tmp_k0, z4*16*4, 4, false); - coopMatLoadNV(B1, tmp_k1, z4*16*4, 4, false); - - // sum += v * k - sum0 = coopMatMulAddNV(A0, B0, sum0); - sum1 = coopMatMulAddNV(A1, B0, sum1); - sum2 = coopMatMulAddNV(A0, B1, sum2); - sum3 = coopMatMulAddNV(A1, B1, sum3); -#endif - } - - barrier(); - } - - if (z < N) - { - const int remain = N - z; - - if (lxd16 == 0) - { - for (int j = 0; j < 4; j++) - { - const int tmp_i = lxd16*16*4 + lxm16 * 4 + j; - - const int v_offset = ((z + lxd16) * 4 + j) * psc(outcstep) + (gx + lxm16); - - tmp_v0[tmp_i] = (gx + lxm16) < psc(outcstep) ? bottom_blob_data[v_offset] : uvec2(0); - tmp_v1[tmp_i] = (gx + lxm16 + 16) < psc(outcstep) ? bottom_blob_data[v_offset + 16] : uvec2(0); - - const int w_offset = gy * psc(c) * 4 + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j); - - tmp_k0[tmp_i] = weight_data[w_offset]; - tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * 16]; - } - } - - barrier(); - - for (int z4 = 0; z4 < remain; z4++) - { -#if ncnn_VK_KHR_cooperative_matrix - coopmat A0; - coopmat A1; - coopMatLoad(A0, tmp_v0, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(A1, tmp_v1, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - - coopmat B0; - coopmat B1; - coopMatLoad(B0, tmp_k0, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B1, tmp_k1, z4*16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - - // sum += v * k - sum0 = coopMatMulAdd(A0, B0, sum0); - sum1 = coopMatMulAdd(A1, B0, sum1); - sum2 = coopMatMulAdd(A0, B1, sum2); - sum3 = coopMatMulAdd(A1, B1, sum3); -#elif ncnn_VK_NV_cooperative_matrix - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> A0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> A1; - coopMatLoadNV(A0, tmp_v0, z4*16*4, 4, false); - coopMatLoadNV(A1, tmp_v1, z4*16*4, 4, false); - - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> B0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> B1; - coopMatLoadNV(B0, tmp_k0, z4*16*4, 4, false); - coopMatLoadNV(B1, tmp_k1, z4*16*4, 4, false); - - // sum += v * k - sum0 = coopMatMulAddNV(A0, B0, sum0); - sum1 = coopMatMulAddNV(A1, B0, sum1); - sum2 = coopMatMulAddNV(A0, B1, sum2); - sum3 = coopMatMulAddNV(A1, B1, sum3); -#endif - } - - barrier(); - } - - if (gx >= psc(outcstep) || gy >= psc(outc)) - return; - -#if ncnn_VK_KHR_cooperative_matrix -#if NCNN_fp16_arithmetic - coopMatStore(sum0, tmp_v0, 0, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum1, tmp_v1, 0, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum2, tmp_v0, 16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum3, tmp_v1, 16*4, 4, gl_CooperativeMatrixLayoutRowMajor); -#else - coopmat sum0_fp16 = coopmat(sum0); - coopmat sum1_fp16 = coopmat(sum1); - coopmat sum2_fp16 = coopmat(sum2); - coopmat sum3_fp16 = coopmat(sum3); - - coopMatStore(sum0_fp16, tmp_v0, 0, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum1_fp16, tmp_v1, 0, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum2_fp16, tmp_v0, 16*4, 4, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum3_fp16, tmp_v1, 16*4, 4, gl_CooperativeMatrixLayoutRowMajor); -#endif -#elif ncnn_VK_NV_cooperative_matrix -#if NCNN_fp16_arithmetic - coopMatStoreNV(sum0, tmp_v0, 0, 4, false); - coopMatStoreNV(sum1, tmp_v1, 0, 4, false); - coopMatStoreNV(sum2, tmp_v0, 16*4, 4, false); - coopMatStoreNV(sum3, tmp_v1, 16*4, 4, false); -#else - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum0_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(sum0); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum1_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(sum1); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum2_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(sum2); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 16> sum3_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>(sum3); - - coopMatStoreNV(sum0_fp16, tmp_v0, 0, 4, false); - coopMatStoreNV(sum1_fp16, tmp_v1, 0, 4, false); - coopMatStoreNV(sum2_fp16, tmp_v0, 16*4, 4, false); - coopMatStoreNV(sum3_fp16, tmp_v1, 16*4, 4, false); -#endif -#endif - - barrier(); - - { - for (int j = 0; j < 4; j++) - { - const int tmp_vi = lxm16 * 4 + j + lxd16*16*4; - - uvec2 sum0_u2 = tmp_v0[tmp_vi]; - uvec2 sum1_u2 = tmp_v1[tmp_vi]; - - afpvec4 sum0 = afpvec4(unpackHalf2x16(sum0_u2.x), unpackHalf2x16(sum0_u2.y)); - afpvec4 sum1 = afpvec4(unpackHalf2x16(sum1_u2.x), unpackHalf2x16(sum1_u2.y)); - - sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); - sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); - - const int gi = (gy + lxd16 * 4 + j) * psc(outcstep) + (gx + lxm16); - - if (gy + lxd16 * 4 + j < psc(outc)) - { - if (gx + lxm16 < psc(outcstep)) buffer_st4(top_blob_data, gi, sum0); - if (gx + lxm16 + 16 < psc(outcstep)) buffer_st4(top_blob_data, gi + 16, sum1); - } - } - } -} diff --git a/src/layer/vulkan/shader/convolution_pack4_1x1s1d1_cm_16_8_8.comp b/src/layer/vulkan/shader/convolution_pack4_1x1s1d1_cm_16_8_8.comp deleted file mode 100644 index 9922263ec..000000000 --- a/src/layer/vulkan/shader/convolution_pack4_1x1s1d1_cm_16_8_8.comp +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright 2023 Tencent -// SPDX-License-Identifier: BSD-3-Clause - -#version 450 - -#extension GL_GOOGLE_include_directive: enable -#include "vulkan_activation.comp" - -#extension GL_KHR_memory_scope_semantics: require -#extension GL_EXT_shader_explicit_arithmetic_types: require -#extension GL_EXT_shader_explicit_arithmetic_types_float16: require -#if ncnn_VK_KHR_cooperative_matrix -#extension GL_KHR_cooperative_matrix: require -#elif ncnn_VK_NV_cooperative_matrix -#extension GL_NV_cooperative_matrix: require -#endif - -layout (constant_id = 0) const int bias_term = 0; -layout (constant_id = 1) const int activation_type = 0; -layout (constant_id = 2) const float activation_param_0 = 0; -layout (constant_id = 3) const float activation_param_1 = 0; - -#define shape_constant_id_offset 4 -layout (constant_id = shape_constant_id_offset + 0) const int w = 0; -layout (constant_id = shape_constant_id_offset + 1) const int h = 0; -layout (constant_id = shape_constant_id_offset + 2) const int c = 0; -layout (constant_id = shape_constant_id_offset + 3) const int cstep = 0; - -layout (constant_id = shape_constant_id_offset + 4) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 5) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 6) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 7) const int outcstep = 0; - -layout (binding = 0) readonly buffer bottom_blob { uvec2 bottom_blob_data[]; }; -layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; -layout (binding = 2) readonly buffer weight_blob { uvec2 weight_data[]; }; -layout (binding = 3) readonly buffer bias_blob { uvec2 bias_data[]; }; - -layout (push_constant) uniform parameter -{ - int w; - int h; - int c; - int cstep; - - int outw; - int outh; - int outc; - int outcstep; -} p; - -#define UNROLL_INCH 4 - -shared uvec2 tmp_v0[UNROLL_INCH * 16*2]; -shared uvec2 tmp_v1[UNROLL_INCH * 16*2]; -shared uvec2 tmp_k0[UNROLL_INCH * 8*2]; -shared uvec2 tmp_k1[UNROLL_INCH * 8*2]; -shared uvec2 tmp_k2[UNROLL_INCH * 8*2]; -shared uvec2 tmp_k3[UNROLL_INCH * 8*2]; - -void main() -{ - int gx = int(gl_GlobalInvocationID.x) / 32 * 2 * 16; - int gy = int(gl_GlobalInvocationID.y) * 2 * 4; - - const int lx = int(gl_LocalInvocationID.x); - - const int lxd8 = lx / 8; // 0 1 2 3 - const int lxm8 = lx % 8; // 0 1 2 3 .... 7 - -#if ncnn_VK_KHR_cooperative_matrix - coopmat sum0; - coopmat sum1; - coopmat sum2; - coopmat sum3; - coopmat sum4; - coopmat sum5; - coopmat sum6; - coopmat sum7; -#elif ncnn_VK_NV_cooperative_matrix -#if NCNN_fp16_arithmetic - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum1; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum2; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum3; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum4; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum5; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum6; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum7; -#else - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum0; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum1; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum2; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum3; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum4; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum5; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum6; - fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> sum7; -#endif -#endif - - if (bias_term == 1) - { -#if ncnn_VK_KHR_cooperative_matrix - coopmat bias0; - coopmat bias1; - coopmat bias2; - coopmat bias3; - - coopMatLoad(bias0, bias_data, gy, 0, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(bias1, bias_data, gy + 2, 0, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(bias2, bias_data, gy + 4, 0, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(bias3, bias_data, gy + 6, 0, gl_CooperativeMatrixLayoutRowMajor); -#elif ncnn_VK_NV_cooperative_matrix - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> bias0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> bias1; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> bias2; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> bias3; - - coopMatLoadNV(bias0, bias_data, gy, 0, false); - coopMatLoadNV(bias1, bias_data, gy + 2, 0, false); - coopMatLoadNV(bias2, bias_data, gy + 4, 0, false); - coopMatLoadNV(bias3, bias_data, gy + 6, 0, false); -#endif - -#if NCNN_fp16_arithmetic - sum0 = bias0; - sum1 = bias0; - sum2 = bias1; - sum3 = bias1; - sum4 = bias2; - sum5 = bias2; - sum6 = bias3; - sum7 = bias3; -#else -#if ncnn_VK_KHR_cooperative_matrix - sum0 = coopmat(bias0); - sum1 = coopmat(bias0); - sum2 = coopmat(bias1); - sum3 = coopmat(bias1); - sum4 = coopmat(bias2); - sum5 = coopmat(bias2); - sum6 = coopmat(bias3); - sum7 = coopmat(bias3); -#elif ncnn_VK_NV_cooperative_matrix - sum0 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias0); - sum1 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias0); - sum2 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias1); - sum3 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias1); - sum4 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias2); - sum5 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias2); - sum6 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias3); - sum7 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(bias3); -#endif -#endif - } - else - { -#if ncnn_VK_KHR_cooperative_matrix - sum0 = coopmat(0.f); - sum1 = coopmat(0.f); - sum2 = coopmat(0.f); - sum3 = coopmat(0.f); - sum4 = coopmat(0.f); - sum5 = coopmat(0.f); - sum6 = coopmat(0.f); - sum7 = coopmat(0.f); -#elif ncnn_VK_NV_cooperative_matrix -#if NCNN_fp16_arithmetic - sum0 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); - sum1 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); - sum2 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); - sum3 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); - sum4 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); - sum5 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); - sum6 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); - sum7 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(0.f); -#else - sum0 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); - sum1 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); - sum2 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); - sum3 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); - sum4 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); - sum5 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); - sum6 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); - sum7 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 8>(0.f); -#endif -#endif - } - - const int N = psc(c) / 2; - - int z = 0; - for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH) - { - { - for (int j = 0; j < 2; j++) - { - const int tmp_vi = lxd8*16*2 + lxm8 * 2 + j; - - int v_offset = ((z + lxd8) * 2 + j) * psc(outcstep) + (gx + lxm8); - - tmp_v0[tmp_vi] = (gx + lxm8) < psc(outcstep) ? bottom_blob_data[v_offset] : uvec2(0); - tmp_v0[tmp_vi + 16] = (gx + lxm8 + 8) < psc(outcstep) ? bottom_blob_data[v_offset + 8] : uvec2(0); - tmp_v1[tmp_vi] = (gx + lxm8 + 16) < psc(outcstep) ? bottom_blob_data[v_offset + 16] : uvec2(0); - tmp_v1[tmp_vi + 16] = (gx + lxm8 + 24) < psc(outcstep) ? bottom_blob_data[v_offset + 24] : uvec2(0); - - const int tmp_ki = lxd8*8*2 + lxm8 * 2 + j; - - int w_offset = gy * psc(c) * 4 + (z + lxd8) * 2 * 8 + (lxm8 * 2 + j); - - tmp_k0[tmp_ki] = weight_data[w_offset]; - tmp_k1[tmp_ki] = weight_data[w_offset + psc(c) * 8]; - tmp_k2[tmp_ki] = weight_data[w_offset + psc(c) * 16]; - tmp_k3[tmp_ki] = weight_data[w_offset + psc(c) * 24]; - } - } - - barrier(); - - for (int z4 = 0; z4 < UNROLL_INCH; z4++) - { -#if ncnn_VK_KHR_cooperative_matrix - coopmat A0; - coopmat A1; - coopMatLoad(A0, tmp_v0, z4*16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(A1, tmp_v1, z4*16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - - coopmat B0; - coopmat B1; - coopmat B2; - coopmat B3; - coopMatLoad(B0, tmp_k0, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B1, tmp_k1, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B2, tmp_k2, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B3, tmp_k3, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - - // sum += v * k - sum0 = coopMatMulAdd(A0, B0, sum0); - sum1 = coopMatMulAdd(A1, B0, sum1); - sum2 = coopMatMulAdd(A0, B1, sum2); - sum3 = coopMatMulAdd(A1, B1, sum3); - sum4 = coopMatMulAdd(A0, B2, sum4); - sum5 = coopMatMulAdd(A1, B2, sum5); - sum6 = coopMatMulAdd(A0, B3, sum6); - sum7 = coopMatMulAdd(A1, B3, sum7); -#elif ncnn_VK_NV_cooperative_matrix - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> A0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> A1; - coopMatLoadNV(A0, tmp_v0, z4*16*2, 2, false); - coopMatLoadNV(A1, tmp_v1, z4*16*2, 2, false); - - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B0; - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B1; - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B2; - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B3; - coopMatLoadNV(B0, tmp_k0, z4*8*2, 2, false); - coopMatLoadNV(B1, tmp_k1, z4*8*2, 2, false); - coopMatLoadNV(B2, tmp_k2, z4*8*2, 2, false); - coopMatLoadNV(B3, tmp_k3, z4*8*2, 2, false); - - // sum += v * k - sum0 = coopMatMulAddNV(A0, B0, sum0); - sum1 = coopMatMulAddNV(A1, B0, sum1); - sum2 = coopMatMulAddNV(A0, B1, sum2); - sum3 = coopMatMulAddNV(A1, B1, sum3); - sum4 = coopMatMulAddNV(A0, B2, sum4); - sum5 = coopMatMulAddNV(A1, B2, sum5); - sum6 = coopMatMulAddNV(A0, B3, sum6); - sum7 = coopMatMulAddNV(A1, B3, sum7); -#endif - } - - barrier(); - } - - if (z < N) - { - const int remain = N - z; - - if (lxd8 < remain) - { - for (int j = 0; j < 2; j++) - { - const int tmp_vi = lxd8*16*2 + lxm8 * 2 + j; - - int v_offset = ((z + lxd8) * 2 + j) * psc(outcstep) + (gx + lxm8); - - tmp_v0[tmp_vi] = (gx + lxm8) < psc(outcstep) ? bottom_blob_data[v_offset] : uvec2(0); - tmp_v0[tmp_vi + 16] = (gx + lxm8 + 8) < psc(outcstep) ? bottom_blob_data[v_offset + 8] : uvec2(0); - tmp_v1[tmp_vi] = (gx + lxm8 + 16) < psc(outcstep) ? bottom_blob_data[v_offset + 16] : uvec2(0); - tmp_v1[tmp_vi + 16] = (gx + lxm8 + 24) < psc(outcstep) ? bottom_blob_data[v_offset + 24] : uvec2(0); - - const int tmp_ki = lxd8*8*2 + lxm8 * 2 + j; - - int w_offset = gy * psc(c) * 4 + (z + lxd8) * 2 * 8 + (lxm8 * 2 + j); - - tmp_k0[tmp_ki] = weight_data[w_offset]; - tmp_k1[tmp_ki] = weight_data[w_offset + psc(c) * 8]; - tmp_k2[tmp_ki] = weight_data[w_offset + psc(c) * 16]; - tmp_k3[tmp_ki] = weight_data[w_offset + psc(c) * 24]; - } - } - - barrier(); - - for (int z4 = 0; z4 < remain; z4++) - { -#if ncnn_VK_KHR_cooperative_matrix - coopmat A0; - coopmat A1; - coopMatLoad(A0, tmp_v0, z4*16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(A1, tmp_v1, z4*16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - - coopmat B0; - coopmat B1; - coopmat B2; - coopmat B3; - coopMatLoad(B0, tmp_k0, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B1, tmp_k1, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B2, tmp_k2, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(B3, tmp_k3, z4*8*2, 2, gl_CooperativeMatrixLayoutRowMajor); - - // sum += v * k - sum0 = coopMatMulAdd(A0, B0, sum0); - sum1 = coopMatMulAdd(A1, B0, sum1); - sum2 = coopMatMulAdd(A0, B1, sum2); - sum3 = coopMatMulAdd(A1, B1, sum3); - sum4 = coopMatMulAdd(A0, B2, sum4); - sum5 = coopMatMulAdd(A1, B2, sum5); - sum6 = coopMatMulAdd(A0, B3, sum6); - sum7 = coopMatMulAdd(A1, B3, sum7); -#elif ncnn_VK_NV_cooperative_matrix - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> A0; - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> A1; - coopMatLoadNV(A0, tmp_v0, z4*16*2, 2, false); - coopMatLoadNV(A1, tmp_v1, z4*16*2, 2, false); - - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B0; - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B1; - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B2; - fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B3; - coopMatLoadNV(B0, tmp_k0, z4*8*2, 2, false); - coopMatLoadNV(B1, tmp_k1, z4*8*2, 2, false); - coopMatLoadNV(B2, tmp_k2, z4*8*2, 2, false); - coopMatLoadNV(B3, tmp_k3, z4*8*2, 2, false); - - // sum += v * k - sum0 = coopMatMulAddNV(A0, B0, sum0); - sum1 = coopMatMulAddNV(A1, B0, sum1); - sum2 = coopMatMulAddNV(A0, B1, sum2); - sum3 = coopMatMulAddNV(A1, B1, sum3); - sum4 = coopMatMulAddNV(A0, B2, sum4); - sum5 = coopMatMulAddNV(A1, B2, sum5); - sum6 = coopMatMulAddNV(A0, B3, sum6); - sum7 = coopMatMulAddNV(A1, B3, sum7); -#endif - } - - barrier(); - } - - if (gx >= psc(outcstep) || gy >= psc(outc)) - return; - -#if ncnn_VK_KHR_cooperative_matrix -#if NCNN_fp16_arithmetic - coopMatStore(sum0, tmp_v0, 0, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum1, tmp_v1, 0, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum2, tmp_v0, 16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum3, tmp_v1, 16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum4, tmp_v0, 16*4, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum5, tmp_v1, 16*4, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum6, tmp_v0, 16*6, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum7, tmp_v1, 16*6, 2, gl_CooperativeMatrixLayoutRowMajor); -#else - coopmat sum0_fp16 = coopmat(sum0); - coopmat sum1_fp16 = coopmat(sum1); - coopmat sum2_fp16 = coopmat(sum2); - coopmat sum3_fp16 = coopmat(sum3); - coopmat sum4_fp16 = coopmat(sum4); - coopmat sum5_fp16 = coopmat(sum5); - coopmat sum6_fp16 = coopmat(sum6); - coopmat sum7_fp16 = coopmat(sum7); - - coopMatStore(sum0_fp16, tmp_v0, 0, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum1_fp16, tmp_v1, 0, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum2_fp16, tmp_v0, 16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum3_fp16, tmp_v1, 16*2, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum4_fp16, tmp_v0, 16*4, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum5_fp16, tmp_v1, 16*4, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum6_fp16, tmp_v0, 16*6, 2, gl_CooperativeMatrixLayoutRowMajor); - coopMatStore(sum7_fp16, tmp_v1, 16*6, 2, gl_CooperativeMatrixLayoutRowMajor); -#endif -#elif ncnn_VK_NV_cooperative_matrix -#if NCNN_fp16_arithmetic - coopMatStoreNV(sum0, tmp_v0, 0, 2, false); - coopMatStoreNV(sum1, tmp_v1, 0, 2, false); - coopMatStoreNV(sum2, tmp_v0, 16*2, 2, false); - coopMatStoreNV(sum3, tmp_v1, 16*2, 2, false); - coopMatStoreNV(sum4, tmp_v0, 16*4, 2, false); - coopMatStoreNV(sum5, tmp_v1, 16*4, 2, false); - coopMatStoreNV(sum6, tmp_v0, 16*6, 2, false); - coopMatStoreNV(sum7, tmp_v1, 16*6, 2, false); -#else - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum0_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum0); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum1_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum1); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum2_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum2); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum3_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum3); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum4_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum4); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum5_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum5); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum6_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum6); - fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> sum7_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(sum7); - - coopMatStoreNV(sum0_fp16, tmp_v0, 0, 2, false); - coopMatStoreNV(sum1_fp16, tmp_v1, 0, 2, false); - coopMatStoreNV(sum2_fp16, tmp_v0, 16*2, 2, false); - coopMatStoreNV(sum3_fp16, tmp_v1, 16*2, 2, false); - coopMatStoreNV(sum4_fp16, tmp_v0, 16*4, 2, false); - coopMatStoreNV(sum5_fp16, tmp_v1, 16*4, 2, false); - coopMatStoreNV(sum6_fp16, tmp_v0, 16*6, 2, false); - coopMatStoreNV(sum7_fp16, tmp_v1, 16*6, 2, false); -#endif -#endif - - barrier(); - - const int lxd16 = lx / 16; // 0 1 - const int lxm16 = lx % 16; // 0 1 2 3 .... 15 - - { - for (int j = 0; j < 4; j++) - { - const int tmp_vi = lxm16 * 2 + lxd16 + j*16*2; - const int gi = (gy + lxd16 + j*2) * psc(outcstep) + (gx + lxm16); - - if (gy + j * 2 + lxd16 < psc(outc)) - { - if (gx + lxm16 < psc(outcstep)) - { - uvec2 sum0_u2 = tmp_v0[tmp_vi]; - afpvec4 sum0 = afpvec4(unpackHalf2x16(sum0_u2.x), unpackHalf2x16(sum0_u2.y)); - sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); - buffer_st4(top_blob_data, gi, sum0); - } - if (gx + lxm16 + 16 < psc(outcstep)) - { - uvec2 sum1_u2 = tmp_v1[tmp_vi]; - afpvec4 sum1 = afpvec4(unpackHalf2x16(sum1_u2.x), unpackHalf2x16(sum1_u2.y)); - sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); - buffer_st4(top_blob_data, gi + 16, sum1); - } - } - } - } -}