From b640574b88f3433d4e11ddc233f6ce57dfe256e8 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 26 Apr 2023 15:13:27 +0800 Subject: [PATCH] rough vulkan gemm and multiheadattention (#4618) --- docs/developer-guide/operators.md | 2 +- src/layer/arm/multiheadattention_arm.cpp | 22 +- src/layer/multiheadattention.cpp | 18 +- src/layer/multiheadattention.h | 2 +- src/layer/vulkan/gemm_vulkan.cpp | 425 +++++++++++ src/layer/vulkan/gemm_vulkan.h | 56 ++ .../vulkan/multiheadattention_vulkan.cpp | 705 ++++++++++++++++++ src/layer/vulkan/multiheadattention_vulkan.h | 57 ++ src/layer/vulkan/shader/gemm.comp | 453 +++++++++++ .../shader/multiheadattention_qk_cross.comp | 81 ++ .../multiheadattention_qk_cross_pack1to4.comp | 90 +++ .../multiheadattention_qk_cross_pack4.comp | 186 +++++ .../multiheadattention_qk_cross_pack4to1.comp | 81 ++ .../shader/multiheadattention_qkv_cross.comp | 81 ++ ...multiheadattention_qkv_cross_pack1to4.comp | 81 ++ .../multiheadattention_qkv_cross_pack4.comp | 177 +++++ ...multiheadattention_qkv_cross_pack4to1.comp | 87 +++ src/layer/x86/multiheadattention_x86.cpp | 12 +- tests/test_multiheadattention.cpp | 12 +- tools/modelwriter.h | 2 +- 20 files changed, 2596 insertions(+), 34 deletions(-) create mode 100644 src/layer/vulkan/gemm_vulkan.cpp create mode 100644 src/layer/vulkan/gemm_vulkan.h create mode 100644 src/layer/vulkan/multiheadattention_vulkan.cpp create mode 100644 src/layer/vulkan/multiheadattention_vulkan.h create mode 100644 src/layer/vulkan/shader/gemm.comp create mode 100644 src/layer/vulkan/shader/multiheadattention_qk_cross.comp create mode 100644 src/layer/vulkan/shader/multiheadattention_qk_cross_pack1to4.comp create mode 100644 src/layer/vulkan/shader/multiheadattention_qk_cross_pack4.comp create mode 100644 src/layer/vulkan/shader/multiheadattention_qk_cross_pack4to1.comp create mode 100644 src/layer/vulkan/shader/multiheadattention_qkv_cross.comp create mode 100644 
src/layer/vulkan/shader/multiheadattention_qkv_cross_pack1to4.comp create mode 100644 src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4.comp create mode 100644 src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4to1.comp diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index c40db96a7..797e1aa08 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -1185,7 +1185,7 @@ y = affine(out) | param id | name | type | default | description | | --------- | ------------- | ----- | --------- | ----------------- | | 0 | embed_dim | int | 0 | | -| 1 | num_head | int | 1 | | +| 1 | num_heads | int | 1 | | | 2 | weight_data_size| int | 0 | | | 3 | kdim | int | embed_dim | | | 4 | vdim | int | embed_dim | | diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp index 57372a033..2ea7094bc 100644 --- a/src/layer/arm/multiheadattention_arm.cpp +++ b/src/layer/arm/multiheadattention_arm.cpp @@ -68,7 +68,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) Option optopt = optn; { - const int embed_dim_per_head = embed_dim / num_head; + const int embed_dim_per_head = embed_dim / num_heads; const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head); q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); @@ -240,7 +240,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) optopt.use_fp16_storage = false; { - const int embed_dim_per_head = embed_dim / num_head; + const int embed_dim_per_head = embed_dim / num_heads; const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head); q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); @@ -530,7 +530,7 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v const Mat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1]; const Mat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? 
k_blob : bottom_blobs[2]; - const int embed_dim_per_head = embed_dim / num_head; + const int embed_dim_per_head = embed_dim / num_heads; const int src_seqlen = q_blob.h * q_blob.elempack; const int dst_seqlen = k_blob.h * k_blob.elempack; @@ -555,9 +555,9 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v Mat k_affine; k_gemm->forward(k_blob, k_affine, optn); - Mat qk_cross(dst_seqlen, src_seqlen * num_head, 2u, optn.blob_allocator); + Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 2u, optn.blob_allocator); #pragma omp parallel for num_threads(optn.num_threads) - for (int i = 0; i < num_head; i++) + for (int i = 0; i < num_heads; i++) { std::vector qk_bottom_blobs(2); qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); @@ -577,9 +577,9 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v Mat v_affine; v_gemm->forward(v_blob, v_affine, optn); - Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 2u, optn.blob_allocator); + Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 2u, optn.blob_allocator); #pragma omp parallel for num_threads(optn.num_threads) - for (int i = 0; i < num_head; i++) + for (int i = 0; i < num_heads; i++) { std::vector qkv_bottom_blobs(2); qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen); @@ -605,9 +605,9 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v Mat k_affine; k_gemm->forward(k_blob, k_affine, opt32); - Mat qk_cross(dst_seqlen, src_seqlen * num_head, 4u, opt32.blob_allocator); + Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 4u, opt32.blob_allocator); #pragma omp parallel for num_threads(opt32.num_threads) - for (int i = 0; i < num_head; i++) + for (int i = 0; i < num_heads; i++) { std::vector qk_bottom_blobs(2); qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); @@ -627,9 +627,9 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, 
std::v Mat v_affine; v_gemm->forward(v_blob, v_affine, opt32); - Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 4u, opt32.blob_allocator); + Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt32.blob_allocator); #pragma omp parallel for num_threads(opt32.num_threads) - for (int i = 0; i < num_head; i++) + for (int i = 0; i < num_heads; i++) { std::vector qkv_bottom_blobs(2); qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen); diff --git a/src/layer/multiheadattention.cpp b/src/layer/multiheadattention.cpp index 966df81d4..c9a6bfc84 100644 --- a/src/layer/multiheadattention.cpp +++ b/src/layer/multiheadattention.cpp @@ -25,7 +25,7 @@ MultiHeadAttention::MultiHeadAttention() int MultiHeadAttention::load_param(const ParamDict& pd) { embed_dim = pd.get(0, 0); - num_head = pd.get(1, 1); + num_heads = pd.get(1, 1); weight_data_size = pd.get(2, 0); kdim = pd.get(3, embed_dim); vdim = pd.get(4, embed_dim); @@ -79,7 +79,7 @@ int MultiHeadAttention::forward(const std::vector& bottom_blobs, std::vecto const int src_seqlen = q_blob.h; const int dst_seqlen = k_blob.h; - const int embed_dim_per_head = embed_dim / num_head; + const int embed_dim_per_head = embed_dim / num_heads; // assert k_blob.h == v_blob.h @@ -88,18 +88,18 @@ int MultiHeadAttention::forward(const std::vector& bottom_blobs, std::vecto if (top_blob.empty()) return -1; - Mat xq(embed_dim_per_head, src_seqlen, num_head, 4u, opt.workspace_allocator); - Mat xk(embed_dim_per_head, dst_seqlen, num_head, 4u, opt.workspace_allocator); - Mat xv(dst_seqlen, embed_dim_per_head, num_head, 4u, opt.workspace_allocator); + Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator); + Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator); + Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator); - Mat xqk(dst_seqlen, src_seqlen, num_head, 4u, opt.workspace_allocator); + Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, 
opt.workspace_allocator); - Mat xqkv(embed_dim_per_head, num_head, src_seqlen, 4u, opt.workspace_allocator); + Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator); const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head); #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_head; q++) + for (int q = 0; q < num_heads; q++) { // xq = affine(q) * inv_sqrt_embed_dim_per_head { @@ -233,7 +233,7 @@ int MultiHeadAttention::forward(const std::vector& bottom_blobs, std::vecto // xqkv = xqk * xv // xqk (dst_seqlen, src_seqlen) // xv (dst_seqlen, embed_dim_per_head) - // out (embed_dim_per_head, num_head, src_seqlen) + // out (embed_dim_per_head, num_heads, src_seqlen) { const Mat xqkm = xqk.channel(q); const Mat xvm = xv.channel(q); diff --git a/src/layer/multiheadattention.h b/src/layer/multiheadattention.h index 2de5213ca..add711b11 100644 --- a/src/layer/multiheadattention.h +++ b/src/layer/multiheadattention.h @@ -32,7 +32,7 @@ public: public: int embed_dim; - int num_head; + int num_heads; int weight_data_size; int kdim; int vdim; diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp new file mode 100644 index 000000000..ad768c63d --- /dev/null +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -0,0 +1,425 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "gemm_vulkan.h"
+
+#include "layer_shader_type.h"
+
+namespace ncnn {
+
+Gemm_vulkan::Gemm_vulkan()
+{
+    support_vulkan = true;
+    support_image_storage = true;
+
+    pipeline_gemm = 0;
+}
+
+int Gemm_vulkan::create_pipeline(const Option& opt)
+{
+    // const Mat& shape = top_shapes.empty() ? Mat() : top_shapes[0];
+
+    // int elempack = 1;
+    // if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;
+
+    // size_t elemsize;
+    // if (opt.use_fp16_storage)
+    // {
+    //     elemsize = elempack * 2u;
+    // }
+    // else if (opt.use_fp16_packed)
+    // {
+    //     elemsize = elempack == 1 ? 4u : elempack * 2u;
+    // }
+    // else
+    // {
+    //     elemsize = elempack * 4u;
+    // }
+
+    // Mat shape_packed;
+    // if (shape.dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack);
+
+    if (constantA)
+    {
+        A_data_packed = transA ? A_data.reshape(constantM, constantK) : A_data.reshape(constantK, constantM);
+    }
+    if (constantB)
+    {
+        B_data_packed = transB ? B_data.reshape(constantK, constantN) : B_data.reshape(constantN, constantK);
+    }
+    if (constantC)
+    {
+        C_data_packed = C_data;
+    }
+
+    std::vector<vk_specialization_type> specializations(15); // one slot per value written below, indices 0..14
+    specializations[0].f = alpha;
+    specializations[1].f = beta;
+    specializations[2].i = transA;
+    specializations[3].i = transB;
+    specializations[4].i = constantA;
+    specializations[5].i = constantB;
+    specializations[6].i = constantC;
+    specializations[7].i = constantM;
+    specializations[8].i = constantN;
+    specializations[9].i = constantK;
+    specializations[10].i = constant_broadcast_type_C;
+    specializations[11].i = output_N1M;
+    specializations[12].i = output_elempack;
+    specializations[13].i = output_elemtype;
+    specializations[14].i = output_transpose;
+
+    Mat local_size_xyz; // left empty: the shape-based sizing below is commented out
+    // if (shape_packed.dims == 2)
+    // {
+    //     local_size_xyz.w = std::min(8, shape_packed.w);
+    //     local_size_xyz.h = std::min(8, shape_packed.h);
+    //     local_size_xyz.c = 1;
+    // }
+
+    // pack1
+    // if (shape.dims == 0 || elempack == 1)
+    {
+        pipeline_gemm = new Pipeline(vkdev);
+        pipeline_gemm->set_optimal_local_size_xyz(local_size_xyz);
+        if (opt.use_shader_local_memory)
+        {
+            pipeline_gemm->set_local_size_xyz(8, 8, 1);
+        }
+        pipeline_gemm->create(LayerShaderType::gemm, opt, specializations);
+    }
+
+    return 0;
+}
+
+int Gemm_vulkan::destroy_pipeline(const Option& /*opt*/)
+{
+    delete pipeline_gemm;
+    pipeline_gemm = 0;
+
+    return 0;
+}
+
+int Gemm_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
+{
+    if (constantA)
+    {
+        if (support_image_storage && opt.use_image_storage)
+        {
+            cmd.record_upload(A_data_packed, A_data_gpu_image, opt);
+        }
+        else
+        {
+            cmd.record_upload(A_data_packed, A_data_gpu, opt);
+        }
+
+        A_data_packed.release();
+    }
+
+    if (constantB)
+    {
+        if (support_image_storage && opt.use_image_storage)
+        {
+            cmd.record_upload(B_data_packed, B_data_gpu_image, opt);
+        }
+        else
+        {
+            cmd.record_upload(B_data_packed, B_data_gpu, opt);
+        }
+
+        B_data_packed.release();
+    }
+
+    if (constantC)
+    {
+        if (support_image_storage && opt.use_image_storage)
+        {
+            cmd.record_upload(C_data_packed, C_data_gpu_image, opt);
+        }
+        else
+        {
+            cmd.record_upload(C_data_packed, C_data_gpu, opt);
+        }
+
+        C_data_packed.release();
+    }
+
+    return 0;
+}
+
+int Gemm_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
+{
+    const VkMat& A0 = constantA ? A_data_gpu : bottom_blobs[0];
+    const VkMat& B0 = constantB ? B_data_gpu : constantA ? bottom_blobs[0] : bottom_blobs[1];
+    const VkMat& C0 = constantC ? C_data_gpu : bottom_blobs[bottom_blobs.size() - 1];
+
+    VkMat A;
+    VkMat B;
+    VkMat C;
+    vkdev->convert_packing(A0, A, 1, cmd, opt);
+    vkdev->convert_packing(B0, B, 1, cmd, opt);
+    vkdev->convert_packing(C0, C, 1, cmd, opt);
+
+    const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h);
+    const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w;
+    const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w;
+
+    int broadcast_type_C = -1; // was uninitialized; -1 = C matched no shape below (assumes the shader ignores codes outside 0..4 - TODO confirm in gemm.comp)
+    if (constantC)
+    {
+        broadcast_type_C = constant_broadcast_type_C;
+    }
+    else
+    {
+        if (C.dims == 1 && C.w == 1)
+        {
+            // scalar
+            broadcast_type_C = 0;
+        }
+        if (C.dims == 1 && C.w == M)
+        {
+            // M
+            // auto broadcast from h to w is the ncnn-style convention
+            broadcast_type_C = 1;
+        }
+        if (C.dims == 1 && C.w == N)
+        {
+            // N
+            broadcast_type_C = 4;
+        }
+        if (C.dims == 2 && C.w == 1 && C.h == M)
+        {
+            // Mx1
+            broadcast_type_C = 2;
+        }
+        if (C.dims == 2 && C.w == N && C.h == M)
+        {
+            // MxN
+            broadcast_type_C = 3;
+        }
+        if (C.dims == 2 && C.w == N && C.h == 1)
+        {
+            // 1xN
+            broadcast_type_C = 4;
+        }
+    }
+
+    int elempack = A.elempack; // NOTE(review): not read anywhere below
+    size_t elemsize = A.elemsize;
+
+    VkMat& top_blob = top_blobs[0];
+    if (output_transpose)
+    {
+        if (output_N1M)
+            top_blob.create(M, 1, N, elemsize, opt.blob_vkallocator);
+        else
+            top_blob.create(M, N, elemsize, opt.blob_vkallocator);
+    }
+    else
+    {
+        if (output_N1M)
+            top_blob.create(N, 1, M, elemsize, opt.blob_vkallocator);
+        else
+            top_blob.create(N, M, elemsize, opt.blob_vkallocator);
+    }
+    if (top_blob.empty())
+        return -100;
+
+    std::vector<VkMat> bindings(4);
+    bindings[0] = top_blob;
+    bindings[1] = A;
+    bindings[2] = B;
+    bindings[3] = C;
+
+    std::vector<vk_constant_type> constants(10);
+    constants[0].i = M;
+    constants[1].i = N;
+    constants[2].i = K;
+    constants[3].i = broadcast_type_C;
+    constants[4].i = A.dims;
+    constants[5].i = A.dims == 3 ? A.cstep : transA ? M : K;
+    constants[6].i = B.dims;
+    constants[7].i = B.dims == 3 ? B.cstep : transB ? K : N;
+    constants[8].i = top_blob.dims;
+    constants[9].i = top_blob.dims == 3 ? top_blob.cstep : top_blob.w;
+
+    const Pipeline* pipeline = pipeline_gemm;
+
+    VkMat dispatcher;
+    dispatcher.w = (N + 1) / 2; // half of N and M, rounded up
+    dispatcher.h = (M + 1) / 2;
+    dispatcher.c = 1;
+    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
+
+    int out_elempack = 1;
+    {
+        int outh = output_transpose ? N : M;
+        out_elempack = opt.use_shader_pack8 && outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1;
+    }
+    if (output_elempack)
+        out_elempack = output_elempack;
+
+    if (out_elempack != 1)
+    {
+        VkMat top_blob0;
+        vkdev->convert_packing(top_blob, top_blob0, out_elempack, cmd, opt);
+        top_blobs[0] = top_blob0;
+    }
+
+    return 0;
+}
+
+int Gemm_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const
+{
+    std::vector<VkMat> bottom_blobs(1);
+    std::vector<VkMat> top_blobs(1);
+    bottom_blobs[0] = bottom_blob;
+    int ret = forward(bottom_blobs, top_blobs, cmd, opt);
+    top_blob = top_blobs[0];
+    return ret;
+}
+
+int Gemm_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const
+{
+    const VkImageMat& A0 = constantA ? A_data_gpu_image : bottom_blobs[0];
+    const VkImageMat& B0 = constantB ? B_data_gpu_image : constantA ? bottom_blobs[0] : bottom_blobs[1];
+    const VkImageMat& C0 = constantC ? C_data_gpu_image : bottom_blobs[bottom_blobs.size() - 1];
+
+    VkImageMat A;
+    VkImageMat B;
+    VkImageMat C;
+    vkdev->convert_packing(A0, A, 1, cmd, opt);
+    vkdev->convert_packing(B0, B, 1, cmd, opt);
+    vkdev->convert_packing(C0, C, 1, cmd, opt);
+
+    const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h);
+    const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w;
+    const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w;
+
+    int broadcast_type_C = -1; // was uninitialized; -1 = C matched no shape below (assumes the shader ignores codes outside 0..4 - TODO confirm in gemm.comp)
+    if (constantC)
+    {
+        broadcast_type_C = constant_broadcast_type_C;
+    }
+    else
+    {
+        if (C.dims == 1 && C.w == 1)
+        {
+            // scalar
+            broadcast_type_C = 0;
+        }
+        if (C.dims == 1 && C.w == M)
+        {
+            // M
+            // auto broadcast from h to w is the ncnn-style convention
+            broadcast_type_C = 1;
+        }
+        if (C.dims == 1 && C.w == N)
+        {
+            // N
+            broadcast_type_C = 4;
+        }
+        if (C.dims == 2 && C.w == 1 && C.h == M)
+        {
+            // Mx1
+            broadcast_type_C = 2;
+        }
+        if (C.dims == 2 && C.w == N && C.h == M)
+        {
+            // MxN
+            broadcast_type_C = 3;
+        }
+        if (C.dims == 2 && C.w == N && C.h == 1)
+        {
+            // 1xN
+            broadcast_type_C = 4;
+        }
+    }
+
+    int elempack = A.elempack; // NOTE(review): not read anywhere below
+    size_t elemsize = A.elemsize;
+
+    VkImageMat& top_blob = top_blobs[0];
+    if (output_transpose)
+    {
+        if (output_N1M)
+            top_blob.create(M, 1, N, elemsize, opt.blob_vkallocator);
+        else
+            top_blob.create(M, N, elemsize, opt.blob_vkallocator);
+    }
+    else
+    {
+        if (output_N1M)
+            top_blob.create(N, 1, M, elemsize, opt.blob_vkallocator);
+        else
+            top_blob.create(N, M, elemsize, opt.blob_vkallocator);
+    }
+    if (top_blob.empty())
+        return -100;
+
+    std::vector<VkImageMat> bindings(4);
+    bindings[0] = top_blob;
+    bindings[1] = A;
+    bindings[2] = B;
+    bindings[3] = C;
+
+    std::vector<vk_constant_type> constants(10);
+    constants[0].i = M;
+    constants[1].i = N;
+    constants[2].i = K;
+    constants[3].i = broadcast_type_C;
+    constants[4].i = A.dims;
+    constants[5].i = 0; //A.w;
+    constants[6].i = B.dims;
+    constants[7].i = 0; //B.w;
+    constants[8].i = top_blob.dims;
+    constants[9].i = 0; //top_blob.w;
+
+    const Pipeline* pipeline = pipeline_gemm;
+
+    VkImageMat dispatcher;
+    dispatcher.w = (N + 1) / 2; // half of N and M, rounded up
+    dispatcher.h = (M + 1) / 2;
+    dispatcher.c = 1;
+    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
+
+    int out_elempack = 1;
+    {
+        int outh = output_transpose ? N : M;
+        out_elempack = opt.use_shader_pack8 && outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1;
+    }
+    if (output_elempack)
+        out_elempack = output_elempack;
+
+    if (out_elempack != 1)
+    {
+        VkImageMat top_blob0;
+        vkdev->convert_packing(top_blob, top_blob0, out_elempack, cmd, opt);
+        top_blobs[0] = top_blob0;
+    }
+
+    return 0;
+}
+
+int Gemm_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const
+{
+    std::vector<VkImageMat> bottom_blobs(1);
+    std::vector<VkImageMat> top_blobs(1);
+    bottom_blobs[0] = bottom_blob;
+    int ret = forward(bottom_blobs, top_blobs, cmd, opt);
+    top_blob = top_blobs[0];
+    return ret;
+}
+
+} // namespace ncnn
diff --git a/src/layer/vulkan/gemm_vulkan.h b/src/layer/vulkan/gemm_vulkan.h
new file mode 100644
index 000000000..4edbc2f54
--- /dev/null
+++ b/src/layer/vulkan/gemm_vulkan.h
@@ -0,0 +1,56 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef LAYER_GEMM_VULKAN_H
+#define LAYER_GEMM_VULKAN_H
+
+#include "gemm.h"
+
+namespace ncnn {
+
+class Gemm_vulkan : virtual public Gemm
+{
+public:
+    Gemm_vulkan();
+
+    virtual int create_pipeline(const Option& opt);
+    virtual int destroy_pipeline(const Option& opt);
+
+    virtual int upload_model(VkTransfer& cmd, const Option& opt);
+
+    using Gemm::forward;
+    virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const; // buffer-storage path
+    virtual int forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const;
+    virtual int forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const; // image-storage path
+    virtual int forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const;
+
+public:
+    Mat A_data_packed;
+    Mat B_data_packed;
+    Mat C_data_packed;
+
+    VkMat A_data_gpu;
+    VkMat B_data_gpu;
+    VkMat C_data_gpu;
+
+    VkImageMat A_data_gpu_image;
+    VkImageMat B_data_gpu_image;
+    VkImageMat C_data_gpu_image;
+
+    Pipeline* pipeline_gemm;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_GEMM_VULKAN_H
diff --git a/src/layer/vulkan/multiheadattention_vulkan.cpp b/src/layer/vulkan/multiheadattention_vulkan.cpp
new file mode 100644
index 000000000..27bd6f7f5
--- /dev/null
+++ b/src/layer/vulkan/multiheadattention_vulkan.cpp
@@ -0,0 +1,705 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License.
You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "multiheadattention_vulkan.h"
+
+#include "layer_shader_type.h"
+#include "layer_type.h"
+
+namespace ncnn {
+
+MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
+{
+    support_vulkan = true;
+    support_image_storage = true;
+
+    q_gemm = 0;
+    k_gemm = 0;
+    v_gemm = 0;
+
+    qk_softmax = 0;
+
+    o_gemm = 0;
+
+    pipeline_multiheadattention_qk_cross = 0;
+    pipeline_multiheadattention_qk_cross_pack4 = 0;
+    pipeline_multiheadattention_qk_cross_pack1to4 = 0;
+    pipeline_multiheadattention_qk_cross_pack4to1 = 0;
+
+    pipeline_multiheadattention_qkv_cross = 0;
+    pipeline_multiheadattention_qkv_cross_pack4 = 0;
+    pipeline_multiheadattention_qkv_cross_pack1to4 = 0;
+    pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
+}
+
+int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
+{
+    const int embed_dim_per_head = embed_dim / num_heads;
+    {
+        const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);
+
+        q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        q_gemm->vkdev = vkdev;
+        ncnn::ParamDict pd;
+        pd.set(0, inv_sqrt_embed_dim_per_head);
+        pd.set(1, 1.f);
+        pd.set(2, 0); // transA
+        pd.set(3, 1); // transB
+        pd.set(4, 1); // constantA
+        pd.set(5, 0); // constantB
+        pd.set(6, 1); // constantC
+        pd.set(7, embed_dim); // M
+        pd.set(8, 0); // N
+        pd.set(9, embed_dim); // K
+        pd.set(10, 1); // constant_broadcast_type_C
+        pd.set(11, 0); // output_N1M
+        // pd.set(12, 1); // output_elempack
+        pd.set(14, 0); // output_transpose
+        q_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = q_weight_data;
+        weights[1] = q_bias_data;
+        q_gemm->load_model(ModelBinFromMatArray(weights));
+        q_gemm->create_pipeline(opt);
+    }
+
+    {
+        k_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        k_gemm->vkdev = vkdev;
+        ncnn::ParamDict pd;
+        pd.set(2, 0); // transA
+        pd.set(3, 1); // transB
+        pd.set(4, 1); // constantA
+        pd.set(5, 0); // constantB
+        pd.set(6, 1); // constantC
+        pd.set(7, embed_dim); // M
+        pd.set(8, 0); // N
+        pd.set(9, kdim); // K
+        pd.set(10, 1); // constant_broadcast_type_C
+        pd.set(11, 0); // output_N1M
+        // pd.set(12, 1); // output_elempack
+        pd.set(14, 0); // output_transpose
+        k_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = k_weight_data;
+        weights[1] = k_bias_data;
+        k_gemm->load_model(ModelBinFromMatArray(weights));
+        k_gemm->create_pipeline(opt);
+    }
+
+    {
+        v_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        v_gemm->vkdev = vkdev;
+        ncnn::ParamDict pd;
+        pd.set(2, 0); // transA
+        pd.set(3, 1); // transB
+        pd.set(4, 1); // constantA
+        pd.set(5, 0); // constantB
+        pd.set(6, 1); // constantC
+        pd.set(7, embed_dim); // M
+        pd.set(8, 0); // N
+        pd.set(9, vdim); // K
+        pd.set(10, 1); // constant_broadcast_type_C
+        pd.set(11, 0); // output_N1M
+        // pd.set(12, 1); // output_elempack
+        pd.set(14, 0); // output_transpose
+        v_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = v_weight_data;
+        weights[1] = v_bias_data;
+        v_gemm->load_model(ModelBinFromMatArray(weights));
+        v_gemm->create_pipeline(opt);
+    }
+
+    {
+        std::vector<vk_specialization_type> specializations(4); // shared by the four qk_cross pipelines below
+        specializations[0].i = 0; //constantM;
+        specializations[1].i = 0; //constantN;
+        specializations[2].i = 0; //embed_dim_per_head;//constantK;
+        specializations[3].i = num_heads;
+
+        {
+            pipeline_multiheadattention_qk_cross = new Pipeline(vkdev);
+            pipeline_multiheadattention_qk_cross->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qk_cross->create(LayerShaderType::multiheadattention_qk_cross, opt, specializations);
+        }
+        {
+            pipeline_multiheadattention_qk_cross_pack4 = new Pipeline(vkdev);
+            pipeline_multiheadattention_qk_cross_pack4->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qk_cross_pack4->create(LayerShaderType::multiheadattention_qk_cross_pack4, opt, specializations);
+        }
+        {
+            pipeline_multiheadattention_qk_cross_pack1to4 = new Pipeline(vkdev);
+            pipeline_multiheadattention_qk_cross_pack1to4->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qk_cross_pack1to4->create(LayerShaderType::multiheadattention_qk_cross_pack1to4, opt, specializations);
+        }
+        {
+            pipeline_multiheadattention_qk_cross_pack4to1 = new Pipeline(vkdev);
+            pipeline_multiheadattention_qk_cross_pack4to1->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qk_cross_pack4to1->create(LayerShaderType::multiheadattention_qk_cross_pack4to1, opt, specializations);
+        }
+    }
+    {
+        std::vector<vk_specialization_type> specializations(4); // shared by the four qkv_cross pipelines below
+        specializations[0].i = 0; //constantM;
+        specializations[1].i = 0; //embed_dim_per_head;//constantN;
+        specializations[2].i = 0; //constantK;
+        specializations[3].i = num_heads;
+
+        {
+            pipeline_multiheadattention_qkv_cross = new Pipeline(vkdev);
+            pipeline_multiheadattention_qkv_cross->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qkv_cross->create(LayerShaderType::multiheadattention_qkv_cross, opt, specializations);
+        }
+        {
+            pipeline_multiheadattention_qkv_cross_pack4 = new Pipeline(vkdev);
+            pipeline_multiheadattention_qkv_cross_pack4->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qkv_cross_pack4->create(LayerShaderType::multiheadattention_qkv_cross_pack4, opt, specializations);
+        }
+        {
+            pipeline_multiheadattention_qkv_cross_pack1to4 = new Pipeline(vkdev);
+            pipeline_multiheadattention_qkv_cross_pack1to4->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qkv_cross_pack1to4->create(LayerShaderType::multiheadattention_qkv_cross_pack1to4, opt, specializations);
+        }
+        {
+            pipeline_multiheadattention_qkv_cross_pack4to1 = new Pipeline(vkdev);
+            pipeline_multiheadattention_qkv_cross_pack4to1->set_local_size_xyz(8, 8, 1);
+            pipeline_multiheadattention_qkv_cross_pack4to1->create(LayerShaderType::multiheadattention_qkv_cross_pack4to1, opt, specializations);
+        }
+    }
+
+    {
+        qk_softmax = ncnn::create_layer(ncnn::LayerType::Softmax);
+        qk_softmax->vkdev = vkdev;
+        ncnn::ParamDict pd;
+        pd.set(0, -1);
+        pd.set(1, 1);
+        qk_softmax->load_param(pd);
+        qk_softmax->load_model(ModelBinFromMatArray(0));
+        qk_softmax->create_pipeline(opt);
+    }
+
+    {
+        o_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
+        o_gemm->vkdev = vkdev;
+        ncnn::ParamDict pd;
+        pd.set(2, 1); // transA
+        pd.set(3, 1); // transB
+        pd.set(4, 0); // constantA
+        pd.set(5, 1); // constantB
+        pd.set(6, 1); // constantC
+        pd.set(7, 0); // M = outch
+        pd.set(8, embed_dim); // N = size
+        pd.set(9, embed_dim); // K = maxk*inch
+        pd.set(10, 4); // constant_broadcast_type_C
+        pd.set(11, 0); // output_N1M
+        o_gemm->load_param(pd);
+        Mat weights[2];
+        weights[0] = out_weight_data;
+        weights[1] = out_bias_data;
+        o_gemm->load_model(ModelBinFromMatArray(weights));
+        o_gemm->create_pipeline(opt);
+    }
+
+    return 0;
+}
+
+int MultiHeadAttention_vulkan::destroy_pipeline(const Option& opt)
+{
+    if (q_gemm)
+    {
+        q_gemm->destroy_pipeline(opt);
+        delete q_gemm;
+        q_gemm = 0;
+    }
+
+    if (k_gemm)
+    {
+        k_gemm->destroy_pipeline(opt);
+        delete k_gemm;
+        k_gemm = 0;
+    }
+
+    if (v_gemm)
+    {
+        v_gemm->destroy_pipeline(opt);
+        delete v_gemm;
+        v_gemm = 0;
+    }
+
+    delete pipeline_multiheadattention_qk_cross;
+    pipeline_multiheadattention_qk_cross = 0;
+
+    delete pipeline_multiheadattention_qk_cross_pack4;
+    pipeline_multiheadattention_qk_cross_pack4 = 0;
+
+    delete pipeline_multiheadattention_qk_cross_pack1to4;
+    pipeline_multiheadattention_qk_cross_pack1to4 = 0;
+
+    delete pipeline_multiheadattention_qk_cross_pack4to1;
+    pipeline_multiheadattention_qk_cross_pack4to1 = 0;
+
+    delete pipeline_multiheadattention_qkv_cross;
+    pipeline_multiheadattention_qkv_cross = 0;
+
+    delete pipeline_multiheadattention_qkv_cross_pack4;
+    pipeline_multiheadattention_qkv_cross_pack4 = 0;
+
+    delete pipeline_multiheadattention_qkv_cross_pack1to4;
+    pipeline_multiheadattention_qkv_cross_pack1to4 = 0;
+
+    delete pipeline_multiheadattention_qkv_cross_pack4to1;
+    pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
+
+    if (qk_softmax)
+    {
+        qk_softmax->destroy_pipeline(opt);
+        delete qk_softmax;
+        qk_softmax = 0;
+    }
+
+    if (o_gemm)
+    {
+        o_gemm->destroy_pipeline(opt);
+        delete o_gemm;
+        o_gemm = 0;
+    }
+
+    return 0;
+}
+
+int MultiHeadAttention_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
+{
+    if (q_gemm)
+    {
+        q_gemm->upload_model(cmd, opt);
+    }
+
+    if (k_gemm)
+    {
+        k_gemm->upload_model(cmd, opt);
+    }
+
+    if (v_gemm)
+    {
+        v_gemm->upload_model(cmd, opt);
+    }
+
+    if (o_gemm)
+    {
+        o_gemm->upload_model(cmd, opt);
+    }
+
+    return 0;
+}
+
+int MultiHeadAttention_vulkan::forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const
+{
+    const VkMat& q_blob = bottom_blobs[0];
+    const VkMat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1];
+    const VkMat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2];
+
+    const int embed_dim_per_head = embed_dim / num_heads;
+    const int src_seqlen = q_blob.h * q_blob.elempack;
+    const int dst_seqlen = k_blob.h * k_blob.elempack;
+
+    VkMat q_affine;
+    q_gemm->forward(q_blob, q_affine, cmd, opt);
+
+    VkMat k_affine;
+    k_gemm->forward(k_blob, k_affine, cmd, opt);
+
+    VkMat qk_cross;
+    {
+        int M = q_affine.w;
+        int N = k_affine.w;
+        int K = q_affine.h * q_affine.elempack / num_heads;
+        int B = num_heads;
+
+        // int K_elempack = opt.use_shader_pack8 && K % 8 == 0 ? 8 : K % 4 == 0 ? 4 : 1;
+        // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
+        // int MB_elempack = opt.use_shader_pack8 && (M * B) % 8 == 0 ? 8 : (M * B) % 4 == 0 ? 4 : 1;
+        int K_elempack = K % 4 == 0 ? 4 : 1;
+        int M_elempack = M % 4 == 0 ?
4 : 1; + int MB_elempack = (M * B) % 4 == 0 ? 4 : 1; + size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (M_elempack == 8) M_elemsize = 8 * 2u; + if (M_elempack == 4) M_elemsize = 4 * 2u; + if (M_elempack == 1) M_elemsize = 4u; + } + + if (K_elempack < q_affine.elempack) + { + VkMat tmp; + vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt); + q_affine = tmp; + } + if (K_elempack < k_affine.elempack) + { + VkMat tmp; + vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt); + k_affine = tmp; + } + + qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator); + if (qk_cross.empty()) + return -100; + + std::vector bindings(3); + bindings[0] = q_affine; + bindings[1] = k_affine; + bindings[2] = qk_cross; + + std::vector constants(4); + constants[0].i = M / M_elempack; + constants[1].i = N; + constants[2].i = K / K_elempack; + constants[3].i = B; + + VkMat dispatcher; + dispatcher.w = N; + dispatcher.h = M / M_elempack; + dispatcher.c = B; + + const Pipeline* pipeline = 0; + if (K_elempack == 1 && M_elempack == 1) + { + pipeline = pipeline_multiheadattention_qk_cross; + } + if (K_elempack == 1 && M_elempack == 4) + { + pipeline = pipeline_multiheadattention_qk_cross_pack1to4; + } + if (K_elempack == 4 && M_elempack == 1) + { + pipeline = pipeline_multiheadattention_qk_cross_pack4to1; + } + if (K_elempack == 4 && M_elempack == 4) + { + pipeline = pipeline_multiheadattention_qk_cross_pack4; + } + + cmd.record_pipeline(pipeline, bindings, constants, dispatcher); + + if (MB_elempack > M_elempack) + { + VkMat tmp; + vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt); + qk_cross = tmp; + } + } + + q_affine.release(); + k_affine.release(); + + qk_softmax->forward_inplace(qk_cross, cmd, opt); + + VkMat v_affine; + v_gemm->forward(v_blob, v_affine, cmd, opt); + + VkMat qkv_cross; + { + int M = qk_cross.h * qk_cross.elempack / num_heads; + int N = 
v_affine.h * v_affine.elempack / num_heads; + int K = v_affine.w; + int B = num_heads; + + // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1; + // int N_elempack = opt.use_shader_pack8 && N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1; + // int NB_elempack = opt.use_shader_pack8 && (N * B) % 8 == 0 ? 8 : (N * B) % 4 == 0 ? 4 : 1; + int M_elempack = M % 4 == 0 ? 4 : 1; + int N_elempack = N % 4 == 0 ? 4 : 1; + int NB_elempack = (N * B) % 4 == 0 ? 4 : 1; + size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (N_elempack == 8) N_elemsize = 8 * 2u; + if (N_elempack == 4) N_elemsize = 4 * 2u; + if (N_elempack == 1) N_elemsize = 4u; + } + + if (M_elempack < qk_cross.elempack) + { + VkMat tmp; + vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt); + qk_cross = tmp; + } + + if (N_elempack < v_affine.elempack) + { + VkMat tmp; + vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt); + v_affine = tmp; + } + + qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator); + if (qkv_cross.empty()) + return -100; + + std::vector bindings(3); + bindings[0] = qk_cross; + bindings[1] = v_affine; + bindings[2] = qkv_cross; + + std::vector constants(4); + constants[0].i = M / M_elempack; + constants[1].i = N / N_elempack; + constants[2].i = K; + constants[3].i = B; + + VkMat dispatcher; + dispatcher.w = N / N_elempack; + dispatcher.h = M / M_elempack; + dispatcher.c = B; + + const Pipeline* pipeline = 0; + if (M_elempack == 1 && N_elempack == 1) + { + pipeline = pipeline_multiheadattention_qkv_cross; + } + if (M_elempack == 1 && N_elempack == 4) + { + pipeline = pipeline_multiheadattention_qkv_cross_pack1to4; + } + if (M_elempack == 4 && N_elempack == 1) + { + pipeline = pipeline_multiheadattention_qkv_cross_pack4to1; + } + if (M_elempack == 4 && N_elempack == 4) + { + pipeline = pipeline_multiheadattention_qkv_cross_pack4; + } + + 
cmd.record_pipeline(pipeline, bindings, constants, dispatcher); + + if (NB_elempack > N_elempack) + { + VkMat tmp; + vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt); + qkv_cross = tmp; + } + } + + qk_cross.release(); + v_affine.release(); + + o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt); + + return 0; +} + +int MultiHeadAttention_vulkan::forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const +{ + const VkImageMat& q_blob = bottom_blobs[0]; + const VkImageMat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1]; + const VkImageMat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2]; + + const int embed_dim_per_head = embed_dim / num_heads; + const int src_seqlen = q_blob.h * q_blob.elempack; + const int dst_seqlen = k_blob.h * k_blob.elempack; + + VkImageMat q_affine; + q_gemm->forward(q_blob, q_affine, cmd, opt); + + VkImageMat k_affine; + k_gemm->forward(k_blob, k_affine, cmd, opt); + + VkImageMat qk_cross; + { + int M = q_affine.w; + int N = k_affine.w; + int K = q_affine.h * q_affine.elempack / num_heads; + int B = num_heads; + + // int K_elempack = opt.use_shader_pack8 && K % 8 == 0 ? 8 : K % 4 == 0 ? 4 : 1; + // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1; + // int MB_elempack = opt.use_shader_pack8 && (M * B) % 8 == 0 ? 8 : (M * B) % 4 == 0 ? 4 : 1; + int K_elempack = K % 4 == 0 ? 4 : 1; + int M_elempack = M % 4 == 0 ? 4 : 1; + int MB_elempack = (M * B) % 4 == 0 ? 
4 : 1; + size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (M_elempack == 8) M_elemsize = 8 * 2u; + if (M_elempack == 4) M_elemsize = 4 * 2u; + if (M_elempack == 1) M_elemsize = 4u; + } + + if (K_elempack < q_affine.elempack) + { + VkImageMat tmp; + vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt); + q_affine = tmp; + } + if (K_elempack < k_affine.elempack) + { + VkImageMat tmp; + vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt); + k_affine = tmp; + } + + qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator); + if (qk_cross.empty()) + return -100; + + std::vector bindings(3); + bindings[0] = q_affine; + bindings[1] = k_affine; + bindings[2] = qk_cross; + + std::vector constants(4); + constants[0].i = M / M_elempack; + constants[1].i = N; + constants[2].i = K / K_elempack; + constants[3].i = B; + + VkImageMat dispatcher; + dispatcher.w = N; + dispatcher.h = M / M_elempack; + dispatcher.c = B; + + const Pipeline* pipeline = 0; + if (K_elempack == 1 && M_elempack == 1) + { + pipeline = pipeline_multiheadattention_qk_cross; + } + if (K_elempack == 1 && M_elempack == 4) + { + pipeline = pipeline_multiheadattention_qk_cross_pack1to4; + } + if (K_elempack == 4 && M_elempack == 1) + { + pipeline = pipeline_multiheadattention_qk_cross_pack4to1; + } + if (K_elempack == 4 && M_elempack == 4) + { + pipeline = pipeline_multiheadattention_qk_cross_pack4; + } + + cmd.record_pipeline(pipeline, bindings, constants, dispatcher); + + if (MB_elempack > M_elempack) + { + VkImageMat tmp; + vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt); + qk_cross = tmp; + } + } + + q_affine.release(); + k_affine.release(); + + qk_softmax->forward_inplace(qk_cross, cmd, opt); + + VkImageMat v_affine; + v_gemm->forward(v_blob, v_affine, cmd, opt); + + VkImageMat qkv_cross; + { + int M = qk_cross.h * qk_cross.elempack / num_heads; + int N = v_affine.h * 
v_affine.elempack / num_heads; + int K = v_affine.w; + int B = num_heads; + + // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1; + // int N_elempack = opt.use_shader_pack8 && N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1; + // int NB_elempack = opt.use_shader_pack8 && (N * B) % 8 == 0 ? 8 : (N * B) % 4 == 0 ? 4 : 1; + int M_elempack = M % 4 == 0 ? 4 : 1; + int N_elempack = N % 4 == 0 ? 4 : 1; + int NB_elempack = (N * B) % 4 == 0 ? 4 : 1; + size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (N_elempack == 8) N_elemsize = 8 * 2u; + if (N_elempack == 4) N_elemsize = 4 * 2u; + if (N_elempack == 1) N_elemsize = 4u; + } + + if (M_elempack < qk_cross.elempack) + { + VkImageMat tmp; + vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt); + qk_cross = tmp; + } + + if (N_elempack < v_affine.elempack) + { + VkImageMat tmp; + vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt); + v_affine = tmp; + } + + qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator); + if (qkv_cross.empty()) + return -100; + + std::vector bindings(3); + bindings[0] = qk_cross; + bindings[1] = v_affine; + bindings[2] = qkv_cross; + + std::vector constants(4); + constants[0].i = M / M_elempack; + constants[1].i = N / N_elempack; + constants[2].i = K; + constants[3].i = B; + + VkImageMat dispatcher; + dispatcher.w = N / N_elempack; + dispatcher.h = M / M_elempack; + dispatcher.c = B; + + const Pipeline* pipeline = 0; + if (M_elempack == 1 && N_elempack == 1) + { + pipeline = pipeline_multiheadattention_qkv_cross; + } + if (M_elempack == 1 && N_elempack == 4) + { + pipeline = pipeline_multiheadattention_qkv_cross_pack1to4; + } + if (M_elempack == 4 && N_elempack == 1) + { + pipeline = pipeline_multiheadattention_qkv_cross_pack4to1; + } + if (M_elempack == 4 && N_elempack == 4) + { + pipeline = pipeline_multiheadattention_qkv_cross_pack4; + } + + 
cmd.record_pipeline(pipeline, bindings, constants, dispatcher); + + if (NB_elempack > N_elempack) + { + VkImageMat tmp; + vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt); + qkv_cross = tmp; + } + } + + qk_cross.release(); + v_affine.release(); + + o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt); + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/vulkan/multiheadattention_vulkan.h b/src/layer/vulkan/multiheadattention_vulkan.h new file mode 100644 index 000000000..49662db47 --- /dev/null +++ b/src/layer/vulkan/multiheadattention_vulkan.h @@ -0,0 +1,57 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#ifndef LAYER_MULTIHEADATTENTION_VULKAN_H +#define LAYER_MULTIHEADATTENTION_VULKAN_H + +#include "multiheadattention.h" + +namespace ncnn { + +class MultiHeadAttention_vulkan : virtual public MultiHeadAttention +{ +public: + MultiHeadAttention_vulkan(); + + virtual int create_pipeline(const Option& opt); + virtual int destroy_pipeline(const Option& opt); + + virtual int upload_model(VkTransfer& cmd, const Option& opt); + + using MultiHeadAttention::forward; + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const; + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const; + +public: + Layer* q_gemm; + Layer* k_gemm; + Layer* v_gemm; + Layer* o_gemm; + + Layer* qk_softmax; + + Pipeline* pipeline_multiheadattention_qk_cross; + Pipeline* pipeline_multiheadattention_qk_cross_pack4; + Pipeline* pipeline_multiheadattention_qk_cross_pack1to4; + Pipeline* pipeline_multiheadattention_qk_cross_pack4to1; + + Pipeline* pipeline_multiheadattention_qkv_cross; + Pipeline* pipeline_multiheadattention_qkv_cross_pack4; + Pipeline* pipeline_multiheadattention_qkv_cross_pack1to4; + Pipeline* pipeline_multiheadattention_qkv_cross_pack4to1; +}; + +} // namespace ncnn + +#endif // LAYER_MULTIHEADATTENTION_VULKAN_H diff --git a/src/layer/vulkan/shader/gemm.comp b/src/layer/vulkan/shader/gemm.comp new file mode 100644 index 000000000..b08e5daa1 --- /dev/null +++ b/src/layer/vulkan/shader/gemm.comp @@ -0,0 +1,453 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// number of k steps staged through shared memory per iteration
#define LOCAL_MEMORY_UNROLL_INCH 8

// Specialization constants. M, N, K and the push constants of the same name are
// read through psc().
layout (constant_id = 0) const float alpha = 1.f;
layout (constant_id = 1) const float beta = 1.f;
layout (constant_id = 2) const int transA = 0;
layout (constant_id = 3) const int transB = 0;
layout (constant_id = 4) const int constantA = 0;
layout (constant_id = 5) const int constantB = 0;
layout (constant_id = 6) const int constantC = 0;
layout (constant_id = 7) const int M = 0;
layout (constant_id = 8) const int N = 0;
layout (constant_id = 9) const int K = 0;
layout (constant_id = 10) const int constant_broadcast_type_C = 0;
layout (constant_id = 11) const int output_N1M = 0;
layout (constant_id = 12) const int output_elempack = 0;
layout (constant_id = 13) const int output_elemtype = 0;
layout (constant_id = 14) const int output_transpose = 0;

// TODO psc more

#if NCNN_image_shader
layout (binding = 0, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
layout (binding = 1) uniform unfp sampler3D A_blob_3d;
layout (binding = 2) uniform unfp sampler3D B_blob_3d;
layout (binding = 3) uniform unfp sampler3D C_blob_3d;
#else
layout (binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
layout (binding = 1) readonly buffer A_blob { sfp A_blob_data[]; };
layout (binding = 2) readonly buffer B_blob { sfp B_blob_data[]; };
layout (binding = 3) readonly buffer C_blob { sfp C_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int broadcast_type_C;
    int A_dims;
    int A_hstep;
    int B_dims;
    int B_hstep;
    int outdims;
    int outhstep;
} p;

#if NCNN_shader_local_memory
// Staging tiles for the shared-memory path, indexed [local id][k step][tile lane].
// NOTE(review): the first dimension of 8 presumably matches an 8x8 local workgroup
// size chosen by the host code - confirm where the pipeline is created.
shared lfp tmp_a[8][LOCAL_MEMORY_UNROLL_INCH][2];
shared lfp tmp_b[8][LOCAL_MEMORY_UNROLL_INCH][2];
#endif

// Each invocation produces a 2x2 output tile:
//   sum0 -> (gy, gx)      sum1 -> (gy, gx + 1)
//   sum2 -> (gy + 1, gx)  sum3 -> (gy + 1, gx + 1)
// The accumulators start from C scaled by beta, gather the products of A and B
// over k, and are scaled by alpha before being stored.
void main()
{
    int gx = int(gl_GlobalInvocationID.x) * 2;
    int gy = int(gl_GlobalInvocationID.y) * 2;
    int gz = int(gl_GlobalInvocationID.z);

    // With shared memory enabled the bounds check is postponed until after the
    // k loop, so every invocation of the workgroup reaches the barrier() calls.
#if !NCNN_shader_local_memory
    if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
        return;
#endif

    afp sum0 = afp(0.f);
    afp sum1 = afp(0.f);
    afp sum2 = afp(0.f);
    afp sum3 = afp(0.f);

    const int broadcast_type_C = constantC == 1 ? constant_broadcast_type_C : p.broadcast_type_C;

    // Seed the accumulators from C. As read below: type 0 is a single value,
    // types 1 and 2 are indexed by gy only, type 3 by both gx and gy, type 4 by
    // gx only. Any other value leaves the accumulators at zero.
    // NOTE(review): the gx + 1 / gy + 1 loads here and in the k loops are not
    // bounds checked, only the stores at the end are - confirm this is safe for
    // odd M or N.
#if NCNN_image_shader
    if (broadcast_type_C == 0)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(0, 0, 0));
        sum1 = sum0;
        sum2 = sum0;
        sum3 = sum0;
    }
    if (broadcast_type_C == 1)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(gy, 0, 0));
        sum1 = sum0;
        sum2 = image3d_ld1(C_blob_3d, ivec3(gy + 1, 0, 0));
        sum3 = sum2;
    }
    if (broadcast_type_C == 2)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(0, gy, 0));
        sum1 = sum0;
        sum2 = image3d_ld1(C_blob_3d, ivec3(0, gy + 1, 0));
        sum3 = sum2;
    }
    if (broadcast_type_C == 3)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(gx, gy, 0));
        sum1 = image3d_ld1(C_blob_3d, ivec3(gx + 1, gy, 0));
        sum2 = image3d_ld1(C_blob_3d, ivec3(gx, gy + 1, 0));
        sum3 = image3d_ld1(C_blob_3d, ivec3(gx + 1, gy + 1, 0));
    }
    if (broadcast_type_C == 4)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(gx, 0, 0));
        sum1 = image3d_ld1(C_blob_3d, ivec3(gx + 1, 0, 0));
        sum2 = sum0;
        sum3 = sum1;
    }
#else
    if (broadcast_type_C == 0)
    {
        sum0 = buffer_ld1(C_blob_data, 0);
        sum1 = sum0;
        sum2 = sum0;
        sum3 = sum0;
    }
    if (broadcast_type_C == 1 || broadcast_type_C == 2)
    {
        sum0 = buffer_ld1(C_blob_data, gy);
        sum1 = sum0;
        sum2 = buffer_ld1(C_blob_data, gy + 1);
        sum3 = sum2;
    }
    if (broadcast_type_C == 3)
    {
        const int ci = gy * psc(N) + gx;
        sum0 = buffer_ld1(C_blob_data, ci);
        sum1 = buffer_ld1(C_blob_data, ci + 1);
        sum2 = buffer_ld1(C_blob_data, ci + psc(N));
        sum3 = buffer_ld1(C_blob_data, ci + psc(N) + 1);
    }
    if (broadcast_type_C == 4)
    {
        sum0 = buffer_ld1(C_blob_data, gx);
        sum1 = buffer_ld1(C_blob_data, gx + 1);
        sum2 = sum0;
        sum3 = sum1;
    }
#endif

    sum0 *= afp(beta);
    sum1 *= afp(beta);
    sum2 *= afp(beta);
    sum3 *= afp(beta);

#if !NCNN_image_shader && NCNN_shader_local_memory
    // Shared-memory path (buffer storage only): the workgroup stages
    // LOCAL_MEMORY_UNROLL_INCH k steps of A and B into tmp_a / tmp_b, then every
    // invocation accumulates from the staged values.
    const int NN = psc(K);

    const int lx = int(gl_LocalInvocationID.x);
    const int ly = int(gl_LocalInvocationID.y);

    int k = 0;
    for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
    {
        {
            // this invocation stages k step lx of A and k step ly of B
            if (transA == 1)
            {
                const int ai = (k + lx) * p.A_hstep + gy;
                tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
                tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + 1]);
            }
            else
            {
                const int ai = gy * p.A_hstep + (k + lx);
                tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
                tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + p.A_hstep]);
            }

            if (transB == 1)
            {
                const int bi = gx * p.B_hstep + (k + ly);
                tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
                tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + p.B_hstep]);
            }
            else
            {
                const int bi = (k + ly) * p.B_hstep + gx;
                tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
                tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + 1]);
            }
        }

        // staged values must be visible before anyone reads them
        barrier();

        for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
        {
            afp a0 = lfp2afp(tmp_a[ly][k4][0]);
            afp a1 = lfp2afp(tmp_a[ly][k4][1]);

            afp b0 = lfp2afp(tmp_b[lx][k4][0]);
            afp b1 = lfp2afp(tmp_b[lx][k4][1]);

            sum0 += a0 * b0;
            sum1 += a0 * b1;
            sum2 += a1 * b0;
            sum3 += a1 * b1;
        }

        // all reads must finish before the next iteration overwrites the tiles
        barrier();
    }

    // tail: fewer than LOCAL_MEMORY_UNROLL_INCH k steps remain
    if (k < NN)
    {
        const int remain = NN - k;

        if (lx < remain)
        {
            if (transA == 1)
            {
                const int ai = (k + lx) * p.A_hstep + gy;
                tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
                tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + 1]);
            }
            else
            {
                const int ai = gy * p.A_hstep + (k + lx);
                tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
                tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + p.A_hstep]);
            }
        }

        if (ly < remain)
        {
            if (transB == 1)
            {
                const int bi = gx * p.B_hstep + (k + ly);
                tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
                tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + p.B_hstep]);
            }
            else
            {
                const int bi = (k + ly) * p.B_hstep + gx;
                tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
                tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + 1]);
            }
        }

        barrier();

        for (int k4 = 0; k4 < remain; k4++)
        {
            afp a0 = lfp2afp(tmp_a[ly][k4][0]);
            afp a1 = lfp2afp(tmp_a[ly][k4][1]);

            afp b0 = lfp2afp(tmp_b[lx][k4][0]);
            afp b1 = lfp2afp(tmp_b[lx][k4][1]);

            sum0 += a0 * b0;
            sum1 += a0 * b1;
            sum2 += a1 * b0;
            sum3 += a1 * b1;
        }
    }
#else
    // Direct path: every invocation reads its own operands for each k.
    for (int k = 0; k < psc(K); k++)
    {
        afp a0;
        afp a1;
        afp b0;
        afp b1;
#if NCNN_image_shader
        // for 3-dim images the row index is carried in z instead of y
        if (transA == 1)
        {
            if (p.A_dims == 3)
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(gy, 0, k));
                a1 = image3d_ld1(A_blob_3d, ivec3(gy + 1, 0, k));
            }
            else
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(gy, k, 0));
                a1 = image3d_ld1(A_blob_3d, ivec3(gy + 1, k, 0));
            }
        }
        else
        {
            if (p.A_dims == 3)
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy));
                a1 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy + 1));
            }
            else
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(k, gy, 0));
                a1 = image3d_ld1(A_blob_3d, ivec3(k, gy + 1, 0));
            }
        }

        if (transB == 1)
        {
            if (p.B_dims == 3)
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx));
                b1 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx + 1));
            }
            else
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(k, gx, 0));
                b1 = image3d_ld1(B_blob_3d, ivec3(k, gx + 1, 0));
            }
        }
        else
        {
            if (p.B_dims == 3)
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(gx, 0, k));
                b1 = image3d_ld1(B_blob_3d, ivec3(gx + 1, 0, k));
            }
            else
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(gx, k, 0));
                b1 = image3d_ld1(B_blob_3d, ivec3(gx + 1, k, 0));
            }
        }
#else
        if (transA == 1)
        {
            const int ai = k * p.A_hstep + gy;
            a0 = buffer_ld1(A_blob_data, ai);
            a1 = buffer_ld1(A_blob_data, ai + 1);
        }
        else
        {
            const int ai = gy * p.A_hstep + k;
            a0 = buffer_ld1(A_blob_data, ai);
            a1 = buffer_ld1(A_blob_data, ai + p.A_hstep);
        }

        if (transB == 1)
        {
            const int bi = gx * p.B_hstep + k;
            b0 = buffer_ld1(B_blob_data, bi);
            b1 = buffer_ld1(B_blob_data, bi + p.B_hstep);
        }
        else
        {
            const int bi = k * p.B_hstep + gx;
            b0 = buffer_ld1(B_blob_data, bi);
            b1 = buffer_ld1(B_blob_data, bi + 1);
        }
#endif

        sum0 += a0 * b0;
        sum1 += a0 * b1;
        sum2 += a1 * b0;
        sum3 += a1 * b1;
    }
#endif

    // postponed bounds check for the shared-memory build (see top of main)
#if NCNN_shader_local_memory
    if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
        return;
#endif

    sum0 *= afp(alpha);
    sum1 *= afp(alpha);
    sum2 *= afp(alpha);
    sum3 *= afp(alpha);

    // Store the tile. The second row / column is written only when it lies inside
    // the M x N output; output_transpose swaps the roles of gx and gy, and
    // output_N1M moves the row index from y to z for image storage.
#if NCNN_image_shader
    if (output_transpose == 1)
    {
        if (output_N1M == 1)
        {
            image3d_st1(top_blob_3d, ivec3(gy, 0, gx), sum0);
            if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx), sum2);
            if (gx + 1 < psc(N))
            {
                image3d_st1(top_blob_3d, ivec3(gy, 0, gx + 1), sum1);
                if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx + 1), sum3);
            }
        }
        else
        {
            image3d_st1(top_blob_3d, ivec3(gy, gx, 0), sum0);
            if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx, 0), sum2);
            if (gx + 1 < psc(N))
            {
                image3d_st1(top_blob_3d, ivec3(gy, gx + 1, 0), sum1);
                if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx + 1, 0), sum3);
            }
        }
    }
    else
    {
        if (output_N1M == 1)
        {
            image3d_st1(top_blob_3d, ivec3(gx, 0, gy), sum0);
            if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy), sum1);
            if (gy + 1 < psc(M))
            {
                image3d_st1(top_blob_3d, ivec3(gx, 0, gy + 1), sum2);
                if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy + 1), sum3);
            }
        }
        else
        {
            image3d_st1(top_blob_3d, ivec3(gx, gy, 0), sum0);
            if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy, 0), sum1);
            if (gy + 1 < psc(M))
            {
                image3d_st1(top_blob_3d, ivec3(gx, gy + 1, 0), sum2);
                if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy + 1, 0), sum3);
            }
        }
    }
#else
    if (output_transpose == 1)
    {
        const int gi = gx * p.outhstep + gy;

        buffer_st1(top_blob_data, gi, sum0);
        if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + 1, sum2);
        if (gx + 1 < psc(N))
        {
            buffer_st1(top_blob_data, gi + p.outhstep, sum1);
            if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);
        }
    }
    else
    {
        const int gi = gy * p.outhstep + gx;

        buffer_st1(top_blob_data, gi, sum0);
        if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + 1, sum1);
        if (gy + 1 < psc(M))
        {
            buffer_st1(top_blob_data, gi + p.outhstep, sum2);
            if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);
        }
    }
#endif
}
diff --git a/src/layer/vulkan/shader/multiheadattention_qk_cross.comp b/src/layer/vulkan/shader/multiheadattention_qk_cross.comp
new file mode 100644
index 000000000..134393e54
--- /dev/null
+++ b/src/layer/vulkan/shader/multiheadattention_qk_cross.comp
@@ -0,0 +1,81 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D q_blob_3d;
layout (binding = 1) uniform unfp sampler3D k_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkcross_blob_3d;
#else
layout (binding = 0) readonly buffer q_blob { sfp q_blob_data[]; };
layout (binding = 1) readonly buffer k_blob { sfp k_blob_data[]; };
layout (binding = 2) writeonly buffer qkcross_blob { sfp qkcross_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int B;
} p;

// Unpacked q*k product: one invocation per output element (col, row) of one head.
// Accumulates q[head][kk][row] * k[head][kk][col] over kk in [0, K).
void main()
{
    const int col = int(gl_GlobalInvocationID.x);
    const int row = int(gl_GlobalInvocationID.y);
    const int head = int(gl_GlobalInvocationID.z);

    if (col >= psc(N) || row >= psc(M) || head >= psc(B))
        return;

    afp acc = afp(0.f);

    for (int kk = 0; kk < psc(K); kk++)
    {
        // row of the q / k blobs that holds step kk of this head
        const int krow = head * psc(K) + kk;

#if NCNN_image_shader
        afp qv = image3d_ld1(q_blob_3d, ivec3(row, krow, 0));
        afp kv = image3d_ld1(k_blob_3d, ivec3(col, krow, 0));
#else
        afp qv = buffer_ld1(q_blob_data, krow * psc(M) + row);
        afp kv = buffer_ld1(k_blob_data, krow * psc(N) + col);
#endif

        acc += qv * kv;
    }

    // output row of this head inside the stacked result
    const int out_row = head * psc(M) + row;

#if NCNN_image_shader
    image3d_st1(qkcross_blob_3d, ivec3(col, out_row, 0), acc);
#else
    buffer_st1(qkcross_blob_data, out_row * psc(N) + col, acc);
#endif
}
diff --git a/src/layer/vulkan/shader/multiheadattention_qk_cross_pack1to4.comp b/src/layer/vulkan/shader/multiheadattention_qk_cross_pack1to4.comp
new file mode 100644
index 000000000..e396770b4
--- /dev/null
+++ b/src/layer/vulkan/shader/multiheadattention_qk_cross_pack1to4.comp
@@ -0,0 +1,90 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D q_blob_3d;
layout (binding = 1) uniform unfp sampler3D k_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkcross_blob_3d;
#else
layout (binding = 0) readonly buffer q_blob { sfp q_blob_data[]; };
layout (binding = 1) readonly buffer k_blob { sfp k_blob_data[]; };
layout (binding = 2) writeonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int B;
} p;

// q*k product with unpacked inputs and a 4-packed output: one invocation writes
// one vec4 whose lanes are four consecutive q columns (row4 * 4 .. row4 * 4 + 3)
// multiplied with the same k value, accumulated over kk in [0, K).
void main()
{
    const int col = int(gl_GlobalInvocationID.x);
    const int row4 = int(gl_GlobalInvocationID.y);
    const int head = int(gl_GlobalInvocationID.z);

    if (col >= psc(N) || row4 >= psc(M) || head >= psc(B))
        return;

    afpvec4 acc = afpvec4(0.f);

    for (int kk = 0; kk < psc(K); kk++)
    {
        // row of the q / k blobs that holds step kk of this head
        const int krow = head * psc(K) + kk;

#if NCNN_image_shader
        const int qx = row4 * 4;
        afp q0 = image3d_ld1(q_blob_3d, ivec3(qx, krow, 0));
        afp q1 = image3d_ld1(q_blob_3d, ivec3(qx + 1, krow, 0));
        afp q2 = image3d_ld1(q_blob_3d, ivec3(qx + 2, krow, 0));
        afp q3 = image3d_ld1(q_blob_3d, ivec3(qx + 3, krow, 0));

        afp kv = image3d_ld1(k_blob_3d, ivec3(col, krow, 0));
#else
        const int q_base = (krow * psc(M) + row4) * 4;
        afp q0 = buffer_ld1(q_blob_data, q_base);
        afp q1 = buffer_ld1(q_blob_data, q_base + 1);
        afp q2 = buffer_ld1(q_blob_data, q_base + 2);
        afp q3 = buffer_ld1(q_blob_data, q_base + 3);

        afp kv = buffer_ld1(k_blob_data, krow * psc(N) + col);
#endif

        acc.r += q0 * kv;
        acc.g += q1 * kv;
        acc.b += q2 * kv;
        acc.a += q3 * kv;
    }

    // output row of this head inside the stacked result
    const int out_row = head * psc(M) + row4;

#if NCNN_image_shader
    image3d_st4(qkcross_blob_3d, ivec3(col, out_row, 0), acc);
#else
    buffer_st4(qkcross_blob_data, out_row * psc(N) + col, acc);
#endif
}
diff --git a/src/layer/vulkan/shader/multiheadattention_qk_cross_pack4.comp b/src/layer/vulkan/shader/multiheadattention_qk_cross_pack4.comp
new file mode 100644
index 000000000..e20b64b90
--- /dev/null
+++ b/src/layer/vulkan/shader/multiheadattention_qk_cross_pack4.comp
@@ -0,0 +1,186 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied.
// See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// number of k steps staged through shared memory per iteration
#define LOCAL_MEMORY_UNROLL_INCH 8

layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D q_blob_3d;
layout (binding = 1) uniform unfp sampler3D k_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkcross_blob_3d;
#else
layout (binding = 0) readonly buffer q_blob { sfpvec4 q_blob_data[]; };
layout (binding = 1) readonly buffer k_blob { sfpvec4 k_blob_data[]; };
layout (binding = 2) writeonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int B;
} p;

#if NCNN_shader_local_memory
// Staging tiles for the shared-memory path; tmp_q keeps four vec4 per k step,
// one for each lane of the packed output.
// NOTE(review): the first dimension of 8 presumably matches an 8x8 local workgroup
// size chosen by the host code - confirm where the pipeline is created.
shared lfpvec4 tmp_q[8][LOCAL_MEMORY_UNROLL_INCH][4];
shared lfpvec4 tmp_k[8][LOCAL_MEMORY_UNROLL_INCH];
#endif

// q*k product with 4-packed inputs and a 4-packed output. One invocation writes
// one vec4 at (gx, gy) of head gz; lane i is the sum over k of
// dot(q vec4 number gy * 4 + i, k vec4 at gx).
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // With shared memory enabled the bounds check is postponed until after the
    // k loop, so every invocation of the workgroup reaches the barrier() calls.
#if !NCNN_shader_local_memory
    if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B))
        return;
#endif

    afpvec4 sum = afpvec4(0.f);

#if !NCNN_image_shader && NCNN_shader_local_memory
    // Shared-memory path (buffer storage only).
    const int NN = psc(K);

    const int lx = int(gl_LocalInvocationID.x);
    const int ly = int(gl_LocalInvocationID.y);

    // this invocation stages k step lx of q and k step ly of k
    int ai = (gz * psc(M) * psc(K) + lx * psc(M) + gy) * 4;
    int bi = gz * psc(N) * psc(K) + ly * psc(N) + gx;

    int k = 0;
    for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
    {
        {
            tmp_q[ly][lx][0] = sfp2lfpvec4(q_blob_data[ai]);
            tmp_q[ly][lx][1] = sfp2lfpvec4(q_blob_data[ai + 1]);
            tmp_q[ly][lx][2] = sfp2lfpvec4(q_blob_data[ai + 2]);
            tmp_q[ly][lx][3] = sfp2lfpvec4(q_blob_data[ai + 3]);
        }

        {
            tmp_k[lx][ly] = sfp2lfpvec4(k_blob_data[bi]);
        }

        // staged values must be visible before anyone reads them
        barrier();

        for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
        {
            afpvec4 q0 = lfp2afpvec4(tmp_q[ly][k4][0]);
            afpvec4 q1 = lfp2afpvec4(tmp_q[ly][k4][1]);
            afpvec4 q2 = lfp2afpvec4(tmp_q[ly][k4][2]);
            afpvec4 q3 = lfp2afpvec4(tmp_q[ly][k4][3]);

            afpvec4 k0 = lfp2afpvec4(tmp_k[lx][k4]);

            sum.r += dot(q0, k0);
            sum.g += dot(q1, k0);
            sum.b += dot(q2, k0);
            sum.a += dot(q3, k0);
        }

        // advance both read positions by LOCAL_MEMORY_UNROLL_INCH k steps
        ai += LOCAL_MEMORY_UNROLL_INCH * psc(M) * 4;
        bi += LOCAL_MEMORY_UNROLL_INCH * psc(N);

        // all reads must finish before the next iteration overwrites the tiles
        barrier();
    }

    // tail: fewer than LOCAL_MEMORY_UNROLL_INCH k steps remain
    if (k < NN)
    {
        const int remain = NN - k;

        if (lx < remain)
        {
            tmp_q[ly][lx][0] = sfp2lfpvec4(q_blob_data[ai]);
            tmp_q[ly][lx][1] = sfp2lfpvec4(q_blob_data[ai + 1]);
            tmp_q[ly][lx][2] = sfp2lfpvec4(q_blob_data[ai + 2]);
            tmp_q[ly][lx][3] = sfp2lfpvec4(q_blob_data[ai + 3]);
        }

        if (ly < remain)
        {
            tmp_k[lx][ly] = sfp2lfpvec4(k_blob_data[bi]);
        }

        barrier();

        for (int k4 = 0; k4 < remain; k4++)
        {
            afpvec4 q0 = lfp2afpvec4(tmp_q[ly][k4][0]);
            afpvec4 q1 = lfp2afpvec4(tmp_q[ly][k4][1]);
            afpvec4 q2 = lfp2afpvec4(tmp_q[ly][k4][2]);
            afpvec4 q3 = lfp2afpvec4(tmp_q[ly][k4][3]);

            afpvec4 k0 = lfp2afpvec4(tmp_k[lx][k4]);

            sum.r += dot(q0, k0);
            sum.g += dot(q1, k0);
            sum.b += dot(q2, k0);
            sum.a += dot(q3, k0);
        }
    }
#else
    // Direct path: every invocation reads its own operands for each k.
    for (int k = 0; k < psc(K); k++)
    {
#if NCNN_image_shader
        afpvec4 q0 = image3d_ld4(q_blob_3d, ivec3(gy * 4, gz * psc(K) + k, 0));
        afpvec4 q1 = image3d_ld4(q_blob_3d, ivec3(gy * 4 + 1, gz * psc(K) + k, 0));
        afpvec4 q2 = image3d_ld4(q_blob_3d, ivec3(gy * 4 + 2, gz * psc(K) + k, 0));
        afpvec4 q3 = image3d_ld4(q_blob_3d, ivec3(gy * 4 + 3, gz * psc(K) + k, 0));

        afpvec4 k0 = image3d_ld4(k_blob_3d, ivec3(gx, gz * psc(K) + k, 0));
#else
        const int ai = (gz * psc(M) * psc(K) + k * psc(M) + gy) * 4;
        afpvec4 q0 = buffer_ld4(q_blob_data, ai);
        afpvec4 q1 = buffer_ld4(q_blob_data, ai + 1);
        afpvec4 q2 = buffer_ld4(q_blob_data, ai + 2);
        afpvec4 q3 = buffer_ld4(q_blob_data, ai + 3);

        const int bi = gz * psc(N) * psc(K) + k * psc(N) + gx;
        afpvec4 k0 = buffer_ld4(k_blob_data, bi);
#endif

        sum.r += dot(q0, k0);
        sum.g += dot(q1, k0);
        sum.b += dot(q2, k0);
        sum.a += dot(q3, k0);
    }
#endif

    // postponed bounds check for the shared-memory build (see top of main)
#if NCNN_shader_local_memory
    if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B))
        return;
#endif

#if NCNN_image_shader
    image3d_st4(qkcross_blob_3d, ivec3(gx, gz * psc(M) + gy, 0), sum);
#else
    const int gi = gz * psc(M) * psc(N) + gy * psc(N) + gx;
    buffer_st4(qkcross_blob_data, gi, sum);
#endif
}
diff --git a/src/layer/vulkan/shader/multiheadattention_qk_cross_pack4to1.comp b/src/layer/vulkan/shader/multiheadattention_qk_cross_pack4to1.comp
new file mode 100644
index 000000000..f0a1983b6
--- /dev/null
+++ b/src/layer/vulkan/shader/multiheadattention_qk_cross_pack4to1.comp
@@ -0,0 +1,81 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +layout (constant_id = 0) const int M = 0; +layout (constant_id = 1) const int N = 0; +layout (constant_id = 2) const int K = 0; +layout (constant_id = 3) const int B = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D q_blob_3d; +layout (binding = 1) uniform unfp sampler3D k_blob_3d; +layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkcross_blob_3d; +#else +layout (binding = 0) readonly buffer q_blob { sfpvec4 q_blob_data[]; }; +layout (binding = 1) readonly buffer k_blob { sfpvec4 k_blob_data[]; }; +layout (binding = 2) writeonly buffer qkcross_blob { sfp qkcross_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int M; + int N; + int K; + int B; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + + if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B)) + return; + + afp sum = afp(0.f); + + for (int k = 0; k < psc(K); k++) + { +#if NCNN_image_shader + afpvec4 q0 = image3d_ld4(q_blob_3d, ivec3(gy, gz * psc(K) + k, 0)); + + afpvec4 k0 = image3d_ld4(k_blob_3d, ivec3(gx, gz * psc(K) + k, 0)); +#else + const int ai = gz * psc(M) * psc(K) + k * psc(M) + gy; + afpvec4 q0 = buffer_ld4(q_blob_data, ai); + + const int bi = gz * psc(N) * psc(K) + k * psc(N) + gx; + afpvec4 k0 = buffer_ld4(k_blob_data, bi); +#endif + + sum += dot(q0, k0); + } + +#if NCNN_image_shader + image3d_st1(qkcross_blob_3d, ivec3(gx, gz * psc(M) + gy, 0), sum); +#else + const int gi = gz * psc(M) * psc(N) + gy * psc(N) + gx; + buffer_st1(qkcross_blob_data, gi, sum); +#endif +} diff --git a/src/layer/vulkan/shader/multiheadattention_qkv_cross.comp b/src/layer/vulkan/shader/multiheadattention_qkv_cross.comp new file mode 100644 index 000000000..e1a017f09 --- 
/dev/null +++ b/src/layer/vulkan/shader/multiheadattention_qkv_cross.comp @@ -0,0 +1,81 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +layout (constant_id = 0) const int M = 0; +layout (constant_id = 1) const int N = 0; +layout (constant_id = 2) const int K = 0; +layout (constant_id = 3) const int B = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d; +layout (binding = 1) uniform unfp sampler3D v_blob_3d; +layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkvcross_blob_3d; +#else +layout (binding = 0) readonly buffer qkcross_blob { sfp qkcross_blob_data[]; }; +layout (binding = 1) readonly buffer v_blob { sfp v_blob_data[]; }; +layout (binding = 2) writeonly buffer qkvcross_blob { sfp qkvcross_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int M; + int N; + int K; + int B; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + + if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B)) + return; + + afp sum = afp(0.f); + + for (int k = 0; 
k < psc(K); k++) + { +#if NCNN_image_shader + afp qk0 = image3d_ld1(qkcross_blob_3d, ivec3(k, gz * psc(M) + gy, 0)); + afp v0 = image3d_ld1(v_blob_3d, ivec3(k, gz * psc(N) + gx, 0)); +#else + const int ai = gz * psc(M) * psc(K) + gy * psc(K) + k; + afp qk0 = buffer_ld1(qkcross_blob_data, ai); + + const int bi = gz * psc(N) * psc(K) + gx * psc(K) + k; + afp v0 = buffer_ld1(v_blob_data, bi); +#endif + + sum += qk0 * v0; + } + +#if NCNN_image_shader + image3d_st1(qkvcross_blob_3d, ivec3(gy, gz * psc(N) + gx, 0), sum); +#else + const int gi = gz * psc(M) * psc(N) + gx * psc(M) + gy; + + buffer_st1(qkvcross_blob_data, gi, sum); +#endif +} diff --git a/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack1to4.comp b/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack1to4.comp new file mode 100644 index 000000000..ba6a3cec7 --- /dev/null +++ b/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack1to4.comp @@ -0,0 +1,81 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +layout (constant_id = 0) const int M = 0; +layout (constant_id = 1) const int N = 0; +layout (constant_id = 2) const int K = 0; +layout (constant_id = 3) const int B = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d; +layout (binding = 1) uniform unfp sampler3D v_blob_3d; +layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkvcross_blob_3d; +#else +layout (binding = 0) readonly buffer qkcross_blob { sfp qkcross_blob_data[]; }; +layout (binding = 1) readonly buffer v_blob { sfpvec4 v_blob_data[]; }; +layout (binding = 2) writeonly buffer qkvcross_blob { sfpvec4 qkvcross_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int M; + int N; + int K; + int B; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + + if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B)) + return; + + afpvec4 sum = afpvec4(0.f); + + for (int k = 0; k < psc(K); k++) + { +#if NCNN_image_shader + afp qk0 = image3d_ld1(qkcross_blob_3d, ivec3(k, gz * psc(M) + gy, 0)); + afpvec4 v0 = image3d_ld4(v_blob_3d, ivec3(k, gz * psc(N) + gx, 0)); +#else + const int ai = gz * psc(M) * psc(K) + gy * psc(K) + k; + afp qk0 = buffer_ld1(qkcross_blob_data, ai); + + const int bi = gz * psc(N) * psc(K) + gx * psc(K) + k; + afpvec4 v0 = buffer_ld4(v_blob_data, bi); +#endif + + sum += qk0 * v0; + } + +#if NCNN_image_shader + image3d_st4(qkvcross_blob_3d, ivec3(gy, gz * psc(N) + gx, 0), sum); +#else + const int gi = gz * psc(M) * psc(N) + gx * psc(M) + gy; + + buffer_st4(qkvcross_blob_data, gi, sum); +#endif +} diff --git a/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4.comp b/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4.comp new 
file mode 100644 index 000000000..ec518d329 --- /dev/null +++ b/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4.comp @@ -0,0 +1,177 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#define LOCAL_MEMORY_UNROLL_INCH 8 + +layout (constant_id = 0) const int M = 0; +layout (constant_id = 1) const int N = 0; +layout (constant_id = 2) const int K = 0; +layout (constant_id = 3) const int B = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d; +layout (binding = 1) uniform unfp sampler3D v_blob_3d; +layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkvcross_blob_3d; +#else +layout (binding = 0) readonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; }; +layout (binding = 1) readonly buffer v_blob { sfpvec4 v_blob_data[]; }; +layout (binding = 2) writeonly buffer qkvcross_blob { sfpvec4 qkvcross_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int M; + int N; + int K; + int B; +} p; + +#if NCNN_shader_local_memory +shared lfpvec4 tmp_qk[8][LOCAL_MEMORY_UNROLL_INCH]; +shared lfpvec4 tmp_v[8][LOCAL_MEMORY_UNROLL_INCH]; +#endif 
+ +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + +#if !NCNN_shader_local_memory + if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B)) + return; +#endif + + afpvec4 sum0 = afpvec4(0.f); + afpvec4 sum1 = afpvec4(0.f); + afpvec4 sum2 = afpvec4(0.f); + afpvec4 sum3 = afpvec4(0.f); + +#if !NCNN_image_shader && NCNN_shader_local_memory + const int NN = psc(K); + + const int lx = int(gl_LocalInvocationID.x); + const int ly = int(gl_LocalInvocationID.y); + + int ai = gz * psc(M) * psc(K) + gy * psc(K) + lx; + int bi = gz * psc(N) * psc(K) + gx * psc(K) + ly; + + int k = 0; + for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH) + { + { + tmp_qk[ly][lx] = sfp2lfpvec4(qkcross_blob_data[ai]); + } + + { + tmp_v[lx][ly] = sfp2lfpvec4(v_blob_data[bi]); + } + + barrier(); + + for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++) + { + afpvec4 qk0 = lfp2afpvec4(tmp_qk[ly][k4]); + + afpvec4 v0 = lfp2afpvec4(tmp_v[lx][k4]); + + sum0 += qk0.r * v0; + sum1 += qk0.g * v0; + sum2 += qk0.b * v0; + sum3 += qk0.a * v0; + } + + ai += LOCAL_MEMORY_UNROLL_INCH; + bi += LOCAL_MEMORY_UNROLL_INCH; + + barrier(); + } + + if (k < NN) + { + const int remain = NN - k; + + if (lx < remain) + { + tmp_qk[ly][lx] = sfp2lfpvec4(qkcross_blob_data[ai]); + } + + if (ly < remain) + { + tmp_v[lx][ly] = sfp2lfpvec4(v_blob_data[bi]); + } + + barrier(); + + for (int k4 = 0; k4 < remain; k4++) + { + afpvec4 qk0 = lfp2afpvec4(tmp_qk[ly][k4]); + + afpvec4 v0 = lfp2afpvec4(tmp_v[lx][k4]); + + sum0 += qk0.r * v0; + sum1 += qk0.g * v0; + sum2 += qk0.b * v0; + sum3 += qk0.a * v0; + } + } +#else + for (int k = 0; k < psc(K); k++) + { +#if NCNN_image_shader + afpvec4 qk0 = image3d_ld4(qkcross_blob_3d, ivec3(k, gz * psc(M) + gy, 0)); + afpvec4 v0 = image3d_ld4(v_blob_3d, ivec3(k, gz * psc(N) + gx, 0)); +#else + const int ai = gz * psc(M) * psc(K) + gy * psc(K) + k; + afpvec4 qk0 = 
buffer_ld4(qkcross_blob_data, ai); + + const int bi = gz * psc(N) * psc(K) + gx * psc(K) + k; + afpvec4 v0 = buffer_ld4(v_blob_data, bi); +#endif + + sum0 += qk0.r * v0; + sum1 += qk0.g * v0; + sum2 += qk0.b * v0; + sum3 += qk0.a * v0; + } +#endif + +#if NCNN_shader_local_memory + if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B)) + return; +#endif + +#if NCNN_image_shader + image3d_st4(qkvcross_blob_3d, ivec3(gy * 4, gz * psc(N) + gx, 0), sum0); + image3d_st4(qkvcross_blob_3d, ivec3(gy * 4 + 1, gz * psc(N) + gx, 0), sum1); + image3d_st4(qkvcross_blob_3d, ivec3(gy * 4 + 2, gz * psc(N) + gx, 0), sum2); + image3d_st4(qkvcross_blob_3d, ivec3(gy * 4 + 3, gz * psc(N) + gx, 0), sum3); +#else + const int gi = (gz * psc(M) * psc(N) + gx * psc(M) + gy) * 4; + + buffer_st4(qkvcross_blob_data, gi, sum0); + buffer_st4(qkvcross_blob_data, gi + 1, sum1); + buffer_st4(qkvcross_blob_data, gi + 2, sum2); + buffer_st4(qkvcross_blob_data, gi + 3, sum3); +#endif +} diff --git a/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4to1.comp b/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4to1.comp new file mode 100644 index 000000000..0fad65987 --- /dev/null +++ b/src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4to1.comp @@ -0,0 +1,87 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +layout (constant_id = 0) const int M = 0; +layout (constant_id = 1) const int N = 0; +layout (constant_id = 2) const int K = 0; +layout (constant_id = 3) const int B = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d; +layout (binding = 1) uniform unfp sampler3D v_blob_3d; +layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkvcross_blob_3d; +#else +layout (binding = 0) readonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; }; +layout (binding = 1) readonly buffer v_blob { sfp v_blob_data[]; }; +layout (binding = 2) writeonly buffer qkvcross_blob { sfp qkvcross_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int M; + int N; + int K; + int B; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + + if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B)) + return; + + afpvec4 sum = afpvec4(0.f); + + for (int k = 0; k < psc(K); k++) + { +#if NCNN_image_shader + afpvec4 qk0 = image3d_ld4(qkcross_blob_3d, ivec3(k, gz * psc(M) + gy, 0)); + afp v0 = image3d_ld1(v_blob_3d, ivec3(k, gz * psc(N) + gx, 0)); +#else + const int ai = gz * psc(M) * psc(K) + gy * psc(K) + k; + afpvec4 qk0 = buffer_ld4(qkcross_blob_data, ai); + + const int bi = gz * psc(N) * psc(K) + gx * psc(K) + k; + afp v0 = buffer_ld1(v_blob_data, bi); +#endif + + sum += qk0 * v0; + } + +#if NCNN_image_shader + image3d_st1(qkvcross_blob_3d, ivec3(gy * 4, gz * psc(N) + gx, 0), sum.r); + image3d_st1(qkvcross_blob_3d, ivec3(gy * 4 + 1, gz * psc(N) + gx, 0), sum.g); + image3d_st1(qkvcross_blob_3d, ivec3(gy * 4 + 2, gz * psc(N) + gx, 0), sum.b); + image3d_st1(qkvcross_blob_3d, ivec3(gy * 4 + 3, gz * psc(N) + gx, 0), sum.a); +#else + const int gi = (gz 
* psc(M) * psc(N) + gx * psc(M) + gy) * 4; + + buffer_st1(qkvcross_blob_data, gi, sum.r); + buffer_st1(qkvcross_blob_data, gi + 1, sum.g); + buffer_st1(qkvcross_blob_data, gi + 2, sum.b); + buffer_st1(qkvcross_blob_data, gi + 3, sum.a); +#endif +} diff --git a/src/layer/x86/multiheadattention_x86.cpp b/src/layer/x86/multiheadattention_x86.cpp index 0b8eacbb4..707419649 100644 --- a/src/layer/x86/multiheadattention_x86.cpp +++ b/src/layer/x86/multiheadattention_x86.cpp @@ -39,7 +39,7 @@ MultiHeadAttention_x86::MultiHeadAttention_x86() int MultiHeadAttention_x86::create_pipeline(const Option& opt) { { - const int embed_dim_per_head = embed_dim / num_head; + const int embed_dim_per_head = embed_dim / num_heads; const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head); q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); @@ -271,7 +271,7 @@ int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::v const Mat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1]; const Mat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? 
k_blob : bottom_blobs[2]; - const int embed_dim_per_head = embed_dim / num_head; + const int embed_dim_per_head = embed_dim / num_heads; const int src_seqlen = q_blob.h * q_blob.elempack; const int dst_seqlen = k_blob.h * k_blob.elempack; @@ -281,9 +281,9 @@ int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::v Mat k_affine; k_gemm->forward(k_blob, k_affine, opt); - Mat qk_cross(dst_seqlen, src_seqlen * num_head, 4u, opt.blob_allocator); + Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 4u, opt.blob_allocator); #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < num_head; i++) + for (int i = 0; i < num_heads; i++) { std::vector qk_bottom_blobs(2); qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); @@ -303,9 +303,9 @@ int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::v Mat v_affine; v_gemm->forward(v_blob, v_affine, opt); - Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 4u, opt.blob_allocator); + Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt.blob_allocator); #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < num_head; i++) + for (int i = 0; i < num_heads; i++) { std::vector qkv_bottom_blobs(2); qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen); diff --git a/tests/test_multiheadattention.cpp b/tests/test_multiheadattention.cpp index 7ed18c4fe..9f6ce518e 100644 --- a/tests/test_multiheadattention.cpp +++ b/tests/test_multiheadattention.cpp @@ -46,7 +46,7 @@ static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); if (ret != 0) { - fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d)\n", q.w, q.h, k.w, k.h, v.w, v.h); + fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) num_heads=%d kdim=%d vdim=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, num_heads, 
kdim, vdim); } return ret; @@ -82,7 +82,7 @@ static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& k int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); if (ret != 0) { - fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d)\n", q.w, q.h, kv.w, kv.h); + fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, num_heads, kvdim); } return ret; @@ -115,7 +115,7 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads) int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); if (ret != 0) { - fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d)\n", a.w, a.h); + fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) num_heads=%d\n", a.w, a.h, num_heads); } return ret; @@ -124,6 +124,8 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads) static int test_multiheadattention_0() { return 0 + || test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 2, 32, 20) + || test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 2, 32, 18) || test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 4, 64, 64) || test_multiheadattention(RandomMat(64, 127), RandomMat(64, 127), RandomMat(64, 127), 16, 64, 64) || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 2, 44, 55) @@ -146,8 +148,8 @@ static int test_multiheadattention_1() static int test_multiheadattention_2() { return 0 - || test_multiheadattention_sameqkv(RandomMat(64, 128), 8) - || test_multiheadattention_sameqkv(RandomMat(64, 127), 32); + || test_multiheadattention_sameqkv(RandomMat(64, 128), 4) + || test_multiheadattention_sameqkv(RandomMat(64, 127), 8); } int main() diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 8c74d5946..b165b6c74 100644 --- a/tools/modelwriter.h +++ 
b/tools/modelwriter.h @@ -1922,7 +1922,7 @@ int ModelWriter::save(const char* parampath, const char* binpath) ncnn::MultiHeadAttention* op_default = (ncnn::MultiHeadAttention*)layer_default; fprintf_param_value(" 0=%d", embed_dim) - fprintf_param_value(" 1=%d", num_head) + fprintf_param_value(" 1=%d", num_heads) fprintf_param_value(" 2=%d", weight_data_size) fprintf_param_value(" 3=%d", kdim) fprintf_param_value(" 4=%d", vdim)