Browse Source

rough vulkan gemm and multiheadattention (#4618)

tags/20230517
nihui GitHub 3 years ago
parent
commit
b640574b88
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 2596 additions and 34 deletions
  1. +1
    -1
      docs/developer-guide/operators.md
  2. +11
    -11
      src/layer/arm/multiheadattention_arm.cpp
  3. +9
    -9
      src/layer/multiheadattention.cpp
  4. +1
    -1
      src/layer/multiheadattention.h
  5. +425
    -0
      src/layer/vulkan/gemm_vulkan.cpp
  6. +56
    -0
      src/layer/vulkan/gemm_vulkan.h
  7. +705
    -0
      src/layer/vulkan/multiheadattention_vulkan.cpp
  8. +57
    -0
      src/layer/vulkan/multiheadattention_vulkan.h
  9. +453
    -0
      src/layer/vulkan/shader/gemm.comp
  10. +81
    -0
      src/layer/vulkan/shader/multiheadattention_qk_cross.comp
  11. +90
    -0
      src/layer/vulkan/shader/multiheadattention_qk_cross_pack1to4.comp
  12. +186
    -0
      src/layer/vulkan/shader/multiheadattention_qk_cross_pack4.comp
  13. +81
    -0
      src/layer/vulkan/shader/multiheadattention_qk_cross_pack4to1.comp
  14. +81
    -0
      src/layer/vulkan/shader/multiheadattention_qkv_cross.comp
  15. +81
    -0
      src/layer/vulkan/shader/multiheadattention_qkv_cross_pack1to4.comp
  16. +177
    -0
      src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4.comp
  17. +87
    -0
      src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4to1.comp
  18. +6
    -6
      src/layer/x86/multiheadattention_x86.cpp
  19. +7
    -5
      tests/test_multiheadattention.cpp
  20. +1
    -1
      tools/modelwriter.h

+ 1
- 1
docs/developer-guide/operators.md View File

@@ -1185,7 +1185,7 @@ y = affine(out)
| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | embed_dim | int | 0 | |
| 1 | num_head | int | 1 | |
| 1 | num_heads | int | 1 | |
| 2 | weight_data_size| int | 0 | |
| 3 | kdim | int | embed_dim | |
| 4 | vdim | int | embed_dim | |


+ 11
- 11
src/layer/arm/multiheadattention_arm.cpp View File

@@ -68,7 +68,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt)
Option optopt = optn;

{
const int embed_dim_per_head = embed_dim / num_head;
const int embed_dim_per_head = embed_dim / num_heads;
const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);

q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
@@ -240,7 +240,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt)
optopt.use_fp16_storage = false;

{
const int embed_dim_per_head = embed_dim / num_head;
const int embed_dim_per_head = embed_dim / num_heads;
const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);

q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
@@ -530,7 +530,7 @@ int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::v
const Mat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1];
const Mat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2];

const int embed_dim_per_head = embed_dim / num_head;
const int embed_dim_per_head = embed_dim / num_heads;
const int src_seqlen = q_blob.h * q_blob.elempack;
const int dst_seqlen = k_blob.h * k_blob.elempack;

@@ -555,9 +555,9 @@ int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::v
Mat k_affine;
k_gemm->forward(k_blob, k_affine, optn);

Mat qk_cross(dst_seqlen, src_seqlen * num_head, 2u, optn.blob_allocator);
Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 2u, optn.blob_allocator);
#pragma omp parallel for num_threads(optn.num_threads)
for (int i = 0; i < num_head; i++)
for (int i = 0; i < num_heads; i++)
{
std::vector<Mat> qk_bottom_blobs(2);
qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
@@ -577,9 +577,9 @@ int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::v
Mat v_affine;
v_gemm->forward(v_blob, v_affine, optn);

Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 2u, optn.blob_allocator);
Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 2u, optn.blob_allocator);
#pragma omp parallel for num_threads(optn.num_threads)
for (int i = 0; i < num_head; i++)
for (int i = 0; i < num_heads; i++)
{
std::vector<Mat> qkv_bottom_blobs(2);
qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);
@@ -605,9 +605,9 @@ int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::v
Mat k_affine;
k_gemm->forward(k_blob, k_affine, opt32);

Mat qk_cross(dst_seqlen, src_seqlen * num_head, 4u, opt32.blob_allocator);
Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 4u, opt32.blob_allocator);
#pragma omp parallel for num_threads(opt32.num_threads)
for (int i = 0; i < num_head; i++)
for (int i = 0; i < num_heads; i++)
{
std::vector<Mat> qk_bottom_blobs(2);
qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
@@ -627,9 +627,9 @@ int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::v
Mat v_affine;
v_gemm->forward(v_blob, v_affine, opt32);

Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 4u, opt32.blob_allocator);
Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt32.blob_allocator);
#pragma omp parallel for num_threads(opt32.num_threads)
for (int i = 0; i < num_head; i++)
for (int i = 0; i < num_heads; i++)
{
std::vector<Mat> qkv_bottom_blobs(2);
qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);


+ 9
- 9
src/layer/multiheadattention.cpp View File

@@ -25,7 +25,7 @@ MultiHeadAttention::MultiHeadAttention()
int MultiHeadAttention::load_param(const ParamDict& pd)
{
embed_dim = pd.get(0, 0);
num_head = pd.get(1, 1);
num_heads = pd.get(1, 1);
weight_data_size = pd.get(2, 0);
kdim = pd.get(3, embed_dim);
vdim = pd.get(4, embed_dim);
@@ -79,7 +79,7 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto

const int src_seqlen = q_blob.h;
const int dst_seqlen = k_blob.h;
const int embed_dim_per_head = embed_dim / num_head;
const int embed_dim_per_head = embed_dim / num_heads;

// assert k_blob.h == v_blob.h

@@ -88,18 +88,18 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
if (top_blob.empty())
return -1;

Mat xq(embed_dim_per_head, src_seqlen, num_head, 4u, opt.workspace_allocator);
Mat xk(embed_dim_per_head, dst_seqlen, num_head, 4u, opt.workspace_allocator);
Mat xv(dst_seqlen, embed_dim_per_head, num_head, 4u, opt.workspace_allocator);
Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator);
Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator);
Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator);

Mat xqk(dst_seqlen, src_seqlen, num_head, 4u, opt.workspace_allocator);
Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);

Mat xqkv(embed_dim_per_head, num_head, src_seqlen, 4u, opt.workspace_allocator);
Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator);

const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_head; q++)
for (int q = 0; q < num_heads; q++)
{
// xq = affine(q) * inv_sqrt_embed_dim_per_head
{
@@ -233,7 +233,7 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
// xqkv = xqk * xv
// xqk (dst_seqlen, src_seqlen)
// xv (dst_seqlen, embed_dim_per_head)
// out (embed_dim_per_head, num_head, src_seqlen)
// out (embed_dim_per_head, num_heads, src_seqlen)
{
const Mat xqkm = xqk.channel(q);
const Mat xvm = xv.channel(q);


+ 1
- 1
src/layer/multiheadattention.h View File

@@ -32,7 +32,7 @@ public:

public:
int embed_dim;
int num_head;
int num_heads;
int weight_data_size;
int kdim;
int vdim;


+ 425
- 0
src/layer/vulkan/gemm_vulkan.cpp View File

@@ -0,0 +1,425 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "gemm_vulkan.h"

#include "layer_shader_type.h"

namespace ncnn {

// Construct the layer with both Vulkan support flags enabled.
// The compute pipeline itself is only created in create_pipeline().
Gemm_vulkan::Gemm_vulkan()
{
    // no pipeline exists yet
    pipeline_gemm = 0;

    support_image_storage = true;
    support_vulkan = true;
}

int Gemm_vulkan::create_pipeline(const Option& opt)
{
// const Mat& shape = top_shapes.empty() ? Mat() : top_shapes[0];

// int elempack = 1;
// if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;

// size_t elemsize;
// if (opt.use_fp16_storage)
// {
// elemsize = elempack * 2u;
// }
// else if (opt.use_fp16_packed)
// {
// elemsize = elempack == 1 ? 4u : elempack * 2u;
// }
// else
// {
// elemsize = elempack * 4u;
// }

// Mat shape_packed;
// if (shape.dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack);

if (constantA)
{
A_data_packed = transA ? A_data.reshape(constantM, constantK) : A_data.reshape(constantK, constantM);
}
if (constantB)
{
B_data_packed = transB ? B_data.reshape(constantK, constantN) : B_data.reshape(constantN, constantK);
}
if (constantC)
{
C_data_packed = C_data;
}

std::vector<vk_specialization_type> specializations(15);
specializations[0].f = alpha;
specializations[1].f = beta;
specializations[2].i = transA;
specializations[3].i = transB;
specializations[4].i = constantA;
specializations[5].i = constantB;
specializations[6].i = constantC;
specializations[7].i = constantM;
specializations[8].i = constantN;
specializations[9].i = constantK;
specializations[10].i = constant_broadcast_type_C;
specializations[11].i = output_N1M;
specializations[12].i = output_elempack;
specializations[13].i = output_elemtype;
specializations[14].i = output_transpose;

Mat local_size_xyz;
// if (shape_packed.dims == 2)
// {
// local_size_xyz.w = std::min(8, shape_packed.w);
// local_size_xyz.h = std::min(8, shape_packed.h);
// local_size_xyz.c = 1;
// }

// pack1
// if (shape.dims == 0 || elempack == 1)
{
pipeline_gemm = new Pipeline(vkdev);
pipeline_gemm->set_optimal_local_size_xyz(local_size_xyz);
if (opt.use_shader_local_memory)
{
pipeline_gemm->set_local_size_xyz(8, 8, 1);
}
pipeline_gemm->create(LayerShaderType::gemm, opt, specializations);
}

return 0;
}

int Gemm_vulkan::destroy_pipeline(const Option& /*opt*/)
{
delete pipeline_gemm;
pipeline_gemm = 0;

return 0;
}

// Record uploads of the constant A/B/C operands to the GPU, then drop the host-side copies.
// Each operand goes to its image variant when image storage is enabled, otherwise to its buffer variant.
// Always returns 0.
int Gemm_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
{
    const bool to_image = support_image_storage && opt.use_image_storage;

    if (constantA)
    {
        if (to_image)
            cmd.record_upload(A_data_packed, A_data_gpu_image, opt);
        else
            cmd.record_upload(A_data_packed, A_data_gpu, opt);

        // host copy is not needed once the upload has been recorded
        A_data_packed.release();
    }

    if (constantB)
    {
        if (to_image)
            cmd.record_upload(B_data_packed, B_data_gpu_image, opt);
        else
            cmd.record_upload(B_data_packed, B_data_gpu, opt);

        B_data_packed.release();
    }

    if (constantC)
    {
        if (to_image)
            cmd.record_upload(C_data_packed, C_data_gpu_image, opt);
        else
            cmd.record_upload(C_data_packed, C_data_gpu, opt);

        C_data_packed.release();
    }

    return 0;
}

// Record the gemm computation for buffer-backed (VkMat) blobs.
//
// bottom_blobs supplies whichever of A, B, C are not layer constants: A is taken from
// bottom_blobs[0], B from the next blob (or bottom_blobs[0] when A is constant), and the
// last blob is taken as C. top_blobs[0] receives the result.
//
// Returns 0 on success, -100 if the output blob cannot be allocated.
int Gemm_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
{
    const VkMat& A0 = constantA ? A_data_gpu : bottom_blobs[0];
    const VkMat& B0 = constantB ? B_data_gpu : constantA ? bottom_blobs[0] : bottom_blobs[1];
    const VkMat& C0 = constantC ? C_data_gpu : bottom_blobs[bottom_blobs.size() - 1];

    // the shader is fed elempack-1 operands
    VkMat A;
    VkMat B;
    VkMat C;
    vkdev->convert_packing(A0, A, 1, cmd, opt);
    vkdev->convert_packing(B0, B, 1, cmd, opt);
    vkdev->convert_packing(C0, C, 1, cmd, opt);

    // constant sizes win; otherwise derive them from the operand shapes
    const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h);
    const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w;
    const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w;

    // Start from -1, which matches none of the 0..4 codes assigned below, so that a C blob
    // whose shape fits no known pattern gives a well-defined value instead of an
    // uninitialized read.
    // NOTE(review): confirm the shader treats an out-of-range code as "no C term".
    int broadcast_type_C = -1;
    if (constantC)
    {
        broadcast_type_C = constant_broadcast_type_C;
    }
    else
    {
        // The checks run in order and a later match overrides an earlier one
        // (for example when M == N, or when M or N is 1).
        if (C.dims == 1 && C.w == 1)
        {
            // scalar
            broadcast_type_C = 0;
        }
        if (C.dims == 1 && C.w == M)
        {
            // M
            // auto broadcast from h to w is the ncnn-style convention
            broadcast_type_C = 1;
        }
        if (C.dims == 1 && C.w == N)
        {
            // N
            broadcast_type_C = 4;
        }
        if (C.dims == 2 && C.w == 1 && C.h == M)
        {
            // Mx1
            broadcast_type_C = 2;
        }
        if (C.dims == 2 && C.w == N && C.h == M)
        {
            // MxN
            broadcast_type_C = 3;
        }
        if (C.dims == 2 && C.w == N && C.h == 1)
        {
            // 1xN
            broadcast_type_C = 4;
        }
    }

    // output uses the element size of the (elempack-1) A operand
    size_t elemsize = A.elemsize;

    VkMat& top_blob = top_blobs[0];
    if (output_transpose)
    {
        if (output_N1M)
            top_blob.create(M, 1, N, elemsize, opt.blob_vkallocator);
        else
            top_blob.create(M, N, elemsize, opt.blob_vkallocator);
    }
    else
    {
        if (output_N1M)
            top_blob.create(N, 1, M, elemsize, opt.blob_vkallocator);
        else
            top_blob.create(N, M, elemsize, opt.blob_vkallocator);
    }
    if (top_blob.empty())
        return -100;

    std::vector<VkMat> bindings(4);
    bindings[0] = top_blob;
    bindings[1] = A;
    bindings[2] = B;
    bindings[3] = C;

    // push constants: problem size, C broadcast code, and dims/stride of A, B and the output
    std::vector<vk_constant_type> constants(10);
    constants[0].i = M;
    constants[1].i = N;
    constants[2].i = K;
    constants[3].i = broadcast_type_C;
    constants[4].i = A.dims;
    constants[5].i = A.dims == 3 ? A.cstep : transA ? M : K;
    constants[6].i = B.dims;
    constants[7].i = B.dims == 3 ? B.cstep : transB ? K : N;
    constants[8].i = top_blob.dims;
    constants[9].i = top_blob.dims == 3 ? top_blob.cstep : top_blob.w;

    const Pipeline* pipeline = pipeline_gemm;

    // dispatch grid is half the output extent (rounded up) in each dimension
    // (presumably each invocation produces a 2x2 output tile - TODO confirm against gemm.comp)
    VkMat dispatcher;
    dispatcher.w = (N + 1) / 2;
    dispatcher.h = (M + 1) / 2;
    dispatcher.c = 1;
    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

    // pick the output packing from the output height unless output_elempack forces one
    int out_elempack = 1;
    {
        int outh = output_transpose ? N : M;
        out_elempack = opt.use_shader_pack8 && outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1;
    }
    if (output_elempack)
        out_elempack = output_elempack;

    if (out_elempack != 1)
    {
        VkMat top_blob0;
        vkdev->convert_packing(top_blob, top_blob0, out_elempack, cmd, opt);
        top_blobs[0] = top_blob0;
    }

    return 0;
}

// Single-blob convenience overload: wraps the input and output in one-element
// vectors and delegates to the vector-based VkMat forward().
// The output is copied back even when the delegate reports an error.
int Gemm_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const
{
    std::vector<VkMat> inputs(1);
    inputs[0] = bottom_blob;

    std::vector<VkMat> outputs(1);

    const int status = forward(inputs, outputs, cmd, opt);

    top_blob = outputs[0];
    return status;
}

// Record the gemm computation for image-backed (VkImageMat) blobs.
//
// bottom_blobs supplies whichever of A, B, C are not layer constants: A is taken from
// bottom_blobs[0], B from the next blob (or bottom_blobs[0] when A is constant), and the
// last blob is taken as C. top_blobs[0] receives the result.
//
// Returns 0 on success, -100 if the output blob cannot be allocated.
int Gemm_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const
{
    const VkImageMat& A0 = constantA ? A_data_gpu_image : bottom_blobs[0];
    const VkImageMat& B0 = constantB ? B_data_gpu_image : constantA ? bottom_blobs[0] : bottom_blobs[1];
    const VkImageMat& C0 = constantC ? C_data_gpu_image : bottom_blobs[bottom_blobs.size() - 1];

    // the shader is fed elempack-1 operands
    VkImageMat A;
    VkImageMat B;
    VkImageMat C;
    vkdev->convert_packing(A0, A, 1, cmd, opt);
    vkdev->convert_packing(B0, B, 1, cmd, opt);
    vkdev->convert_packing(C0, C, 1, cmd, opt);

    // constant sizes win; otherwise derive them from the operand shapes
    const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h);
    const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w;
    const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w;

    // Start from -1, which matches none of the 0..4 codes assigned below, so that a C blob
    // whose shape fits no known pattern gives a well-defined value instead of an
    // uninitialized read.
    // NOTE(review): confirm the shader treats an out-of-range code as "no C term".
    int broadcast_type_C = -1;
    if (constantC)
    {
        broadcast_type_C = constant_broadcast_type_C;
    }
    else
    {
        // The checks run in order and a later match overrides an earlier one
        // (for example when M == N, or when M or N is 1).
        if (C.dims == 1 && C.w == 1)
        {
            // scalar
            broadcast_type_C = 0;
        }
        if (C.dims == 1 && C.w == M)
        {
            // M
            // auto broadcast from h to w is the ncnn-style convention
            broadcast_type_C = 1;
        }
        if (C.dims == 1 && C.w == N)
        {
            // N
            broadcast_type_C = 4;
        }
        if (C.dims == 2 && C.w == 1 && C.h == M)
        {
            // Mx1
            broadcast_type_C = 2;
        }
        if (C.dims == 2 && C.w == N && C.h == M)
        {
            // MxN
            broadcast_type_C = 3;
        }
        if (C.dims == 2 && C.w == N && C.h == 1)
        {
            // 1xN
            broadcast_type_C = 4;
        }
    }

    // output uses the element size of the (elempack-1) A operand
    size_t elemsize = A.elemsize;

    VkImageMat& top_blob = top_blobs[0];
    if (output_transpose)
    {
        if (output_N1M)
            top_blob.create(M, 1, N, elemsize, opt.blob_vkallocator);
        else
            top_blob.create(M, N, elemsize, opt.blob_vkallocator);
    }
    else
    {
        if (output_N1M)
            top_blob.create(N, 1, M, elemsize, opt.blob_vkallocator);
        else
            top_blob.create(N, M, elemsize, opt.blob_vkallocator);
    }
    if (top_blob.empty())
        return -100;

    std::vector<VkImageMat> bindings(4);
    bindings[0] = top_blob;
    bindings[1] = A;
    bindings[2] = B;
    bindings[3] = C;

    // push constants: problem size, C broadcast code and operand dims;
    // the three stride slots are passed as 0 on the image path
    std::vector<vk_constant_type> constants(10);
    constants[0].i = M;
    constants[1].i = N;
    constants[2].i = K;
    constants[3].i = broadcast_type_C;
    constants[4].i = A.dims;
    constants[5].i = 0; //A.w;
    constants[6].i = B.dims;
    constants[7].i = 0; //B.w;
    constants[8].i = top_blob.dims;
    constants[9].i = 0; //top_blob.w;

    const Pipeline* pipeline = pipeline_gemm;

    // dispatch grid is half the output extent (rounded up) in each dimension
    // (presumably each invocation produces a 2x2 output tile - TODO confirm against gemm.comp)
    VkImageMat dispatcher;
    dispatcher.w = (N + 1) / 2;
    dispatcher.h = (M + 1) / 2;
    dispatcher.c = 1;
    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

    // pick the output packing from the output height unless output_elempack forces one
    int out_elempack = 1;
    {
        int outh = output_transpose ? N : M;
        out_elempack = opt.use_shader_pack8 && outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1;
    }
    if (output_elempack)
        out_elempack = output_elempack;

    if (out_elempack != 1)
    {
        VkImageMat top_blob0;
        vkdev->convert_packing(top_blob, top_blob0, out_elempack, cmd, opt);
        top_blobs[0] = top_blob0;
    }

    return 0;
}

// Single-blob convenience overload: wraps the input and output in one-element
// vectors and delegates to the vector-based VkImageMat forward().
// The output is copied back even when the delegate reports an error.
int Gemm_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const
{
    std::vector<VkImageMat> inputs(1);
    inputs[0] = bottom_blob;

    std::vector<VkImageMat> outputs(1);

    const int status = forward(inputs, outputs, cmd, opt);

    top_blob = outputs[0];
    return status;
}

} // namespace ncnn

+ 56
- 0
src/layer/vulkan/gemm_vulkan.h View File

@@ -0,0 +1,56 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_GEMM_VULKAN_H
#define LAYER_GEMM_VULKAN_H

#include "gemm.h"

namespace ncnn {

// Vulkan implementation of the Gemm layer, with buffer (VkMat) and image (VkImageMat) code paths.
class Gemm_vulkan : public virtual Gemm
{
public:
    Gemm_vulkan();

    // build / tear down the gemm compute pipeline
    virtual int create_pipeline(const Option& opt);
    virtual int destroy_pipeline(const Option& opt);

    // record uploads of the constant operands
    virtual int upload_model(VkTransfer& cmd, const Option& opt);

    // keep the base-class forward overloads visible next to the Vulkan ones
    using Gemm::forward;
    virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const;
    virtual int forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const;
    virtual int forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const;
    virtual int forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const;

public:
    // host-side staging copies of the constant operands
    Mat A_data_packed;
    Mat B_data_packed;
    Mat C_data_packed;

    // constant operands in buffer storage
    VkMat A_data_gpu;
    VkMat B_data_gpu;
    VkMat C_data_gpu;

    // constant operands in image storage
    VkImageMat A_data_gpu_image;
    VkImageMat B_data_gpu_image;
    VkImageMat C_data_gpu_image;

    // compute pipeline for the gemm shader
    Pipeline* pipeline_gemm;
};

} // namespace ncnn

#endif // LAYER_GEMM_VULKAN_H

+ 705
- 0
src/layer/vulkan/multiheadattention_vulkan.cpp View File

@@ -0,0 +1,705 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "multiheadattention_vulkan.h"

#include "layer_shader_type.h"
#include "layer_type.h"

namespace ncnn {

// Construct the layer with both Vulkan support flags enabled and every
// sub-layer / pipeline pointer cleared; they are created in create_pipeline().
MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
{
    support_image_storage = true;
    support_vulkan = true;

    // helper layers
    q_gemm = 0;
    k_gemm = 0;
    v_gemm = 0;
    qk_softmax = 0;
    o_gemm = 0;

    // q*k cross-product pipelines, one per packing combination
    pipeline_multiheadattention_qk_cross = 0;
    pipeline_multiheadattention_qk_cross_pack4 = 0;
    pipeline_multiheadattention_qk_cross_pack1to4 = 0;
    pipeline_multiheadattention_qk_cross_pack4to1 = 0;

    // qk*v cross-product pipelines, one per packing combination
    pipeline_multiheadattention_qkv_cross = 0;
    pipeline_multiheadattention_qkv_cross_pack4 = 0;
    pipeline_multiheadattention_qkv_cross_pack1to4 = 0;
    pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
}

int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
{
const int embed_dim_per_head = embed_dim / num_heads;
{
const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);

q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
q_gemm->vkdev = vkdev;
ncnn::ParamDict pd;
pd.set(0, inv_sqrt_embed_dim_per_head);
pd.set(1, 1.f);
pd.set(2, 0); // transA
pd.set(3, 1); // transB
pd.set(4, 1); // constantA
pd.set(5, 0); // constantB
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
pd.set(9, embed_dim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
// pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
q_gemm->load_param(pd);
Mat weights[2];
weights[0] = q_weight_data;
weights[1] = q_bias_data;
q_gemm->load_model(ModelBinFromMatArray(weights));
q_gemm->create_pipeline(opt);
}

{
k_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
k_gemm->vkdev = vkdev;
ncnn::ParamDict pd;
pd.set(2, 0); // transA
pd.set(3, 1); // transB
pd.set(4, 1); // constantA
pd.set(5, 0); // constantB
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
pd.set(9, kdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
// pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
k_gemm->load_param(pd);
Mat weights[2];
weights[0] = k_weight_data;
weights[1] = k_bias_data;
k_gemm->load_model(ModelBinFromMatArray(weights));
k_gemm->create_pipeline(opt);
}

{
v_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
v_gemm->vkdev = vkdev;
ncnn::ParamDict pd;
pd.set(2, 0); // transA
pd.set(3, 1); // transB
pd.set(4, 1); // constantA
pd.set(5, 0); // constantB
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
pd.set(9, vdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
// pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
v_gemm->load_param(pd);
Mat weights[2];
weights[0] = v_weight_data;
weights[1] = v_bias_data;
v_gemm->load_model(ModelBinFromMatArray(weights));
v_gemm->create_pipeline(opt);
}

{
std::vector<vk_specialization_type> specializations(4);
specializations[0].i = 0; //constantM;
specializations[1].i = 0; //constantN;
specializations[2].i = 0; //embed_dim_per_head;//constantK;
specializations[3].i = num_heads;

{
pipeline_multiheadattention_qk_cross = new Pipeline(vkdev);
pipeline_multiheadattention_qk_cross->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qk_cross->create(LayerShaderType::multiheadattention_qk_cross, opt, specializations);
}
{
pipeline_multiheadattention_qk_cross_pack4 = new Pipeline(vkdev);
pipeline_multiheadattention_qk_cross_pack4->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qk_cross_pack4->create(LayerShaderType::multiheadattention_qk_cross_pack4, opt, specializations);
}
{
pipeline_multiheadattention_qk_cross_pack1to4 = new Pipeline(vkdev);
pipeline_multiheadattention_qk_cross_pack1to4->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qk_cross_pack1to4->create(LayerShaderType::multiheadattention_qk_cross_pack1to4, opt, specializations);
}
{
pipeline_multiheadattention_qk_cross_pack4to1 = new Pipeline(vkdev);
pipeline_multiheadattention_qk_cross_pack4to1->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qk_cross_pack4to1->create(LayerShaderType::multiheadattention_qk_cross_pack4to1, opt, specializations);
}
}
{
std::vector<vk_specialization_type> specializations(4);
specializations[0].i = 0; //constantM;
specializations[1].i = 0; //embed_dim_per_head;//constantN;
specializations[2].i = 0; //constantK;
specializations[3].i = num_heads;

{
pipeline_multiheadattention_qkv_cross = new Pipeline(vkdev);
pipeline_multiheadattention_qkv_cross->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qkv_cross->create(LayerShaderType::multiheadattention_qkv_cross, opt, specializations);
}
{
pipeline_multiheadattention_qkv_cross_pack4 = new Pipeline(vkdev);
pipeline_multiheadattention_qkv_cross_pack4->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qkv_cross_pack4->create(LayerShaderType::multiheadattention_qkv_cross_pack4, opt, specializations);
}
{
pipeline_multiheadattention_qkv_cross_pack1to4 = new Pipeline(vkdev);
pipeline_multiheadattention_qkv_cross_pack1to4->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qkv_cross_pack1to4->create(LayerShaderType::multiheadattention_qkv_cross_pack1to4, opt, specializations);
}
{
pipeline_multiheadattention_qkv_cross_pack4to1 = new Pipeline(vkdev);
pipeline_multiheadattention_qkv_cross_pack4to1->set_local_size_xyz(8, 8, 1);
pipeline_multiheadattention_qkv_cross_pack4to1->create(LayerShaderType::multiheadattention_qkv_cross_pack4to1, opt, specializations);
}
}

{
qk_softmax = ncnn::create_layer(ncnn::LayerType::Softmax);
qk_softmax->vkdev = vkdev;
ncnn::ParamDict pd;
pd.set(0, -1);
pd.set(1, 1);
qk_softmax->load_param(pd);
qk_softmax->load_model(ModelBinFromMatArray(0));
qk_softmax->create_pipeline(opt);
}

{
o_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
o_gemm->vkdev = vkdev;
ncnn::ParamDict pd;
pd.set(2, 1); // transA
pd.set(3, 1); // transB
pd.set(4, 0); // constantA
pd.set(5, 1); // constantB
pd.set(6, 1); // constantC
pd.set(7, 0); // M = outch
pd.set(8, embed_dim); // N = size
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
o_gemm->load_param(pd);
Mat weights[2];
weights[0] = out_weight_data;
weights[1] = out_bias_data;
o_gemm->load_model(ModelBinFromMatArray(weights));
o_gemm->create_pipeline(opt);
}

return 0;
}

int MultiHeadAttention_vulkan::destroy_pipeline(const Option& opt)
{
if (q_gemm)
{
q_gemm->destroy_pipeline(opt);
delete q_gemm;
q_gemm = 0;
}

if (k_gemm)
{
k_gemm->destroy_pipeline(opt);
delete k_gemm;
k_gemm = 0;
}

if (v_gemm)
{
v_gemm->destroy_pipeline(opt);
delete v_gemm;
v_gemm = 0;
}

delete pipeline_multiheadattention_qk_cross;
pipeline_multiheadattention_qk_cross = 0;

delete pipeline_multiheadattention_qk_cross_pack4;
pipeline_multiheadattention_qk_cross_pack4 = 0;

delete pipeline_multiheadattention_qk_cross_pack1to4;
pipeline_multiheadattention_qk_cross_pack1to4 = 0;

delete pipeline_multiheadattention_qk_cross_pack4to1;
pipeline_multiheadattention_qk_cross_pack4to1 = 0;

delete pipeline_multiheadattention_qkv_cross;
pipeline_multiheadattention_qkv_cross = 0;

delete pipeline_multiheadattention_qkv_cross_pack4;
pipeline_multiheadattention_qkv_cross_pack4 = 0;

delete pipeline_multiheadattention_qkv_cross_pack1to4;
pipeline_multiheadattention_qkv_cross_pack1to4 = 0;

delete pipeline_multiheadattention_qkv_cross_pack4to1;
pipeline_multiheadattention_qkv_cross_pack4to1 = 0;

if (qk_softmax)
{
qk_softmax->destroy_pipeline(opt);
delete qk_softmax;
qk_softmax = 0;
}

if (o_gemm)
{
o_gemm->destroy_pipeline(opt);
delete o_gemm;
o_gemm = 0;
}

return 0;
}

// Forward the weight upload to each gemm helper layer that exists,
// in q, k, v, o order. Always returns 0.
int MultiHeadAttention_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
{
    Layer* const gemm_layers[4] = {q_gemm, k_gemm, v_gemm, o_gemm};

    for (int i = 0; i < 4; i++)
    {
        if (gemm_layers[i])
            gemm_layers[i]->upload_model(cmd, opt);
    }

    return 0;
}

// Record multi-head attention for buffer-backed (VkMat) blobs.
//
// bottom_blobs holds q and optionally k and v: with one blob it is used for q, k and v;
// with two blobs the second is used for both k and v. top_blobs[0] receives the result.
//
// Stages recorded, in order: q and k projection, per-head q*k cross product, softmax,
// v projection, per-head qk*v cross product, output projection.
//
// Returns 0 on success, -100 if an intermediate blob cannot be allocated, or the
// non-zero status reported by a helper layer.
int MultiHeadAttention_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
{
    const VkMat& q_blob = bottom_blobs[0];
    const VkMat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1];
    const VkMat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2];

    // status of the helper-layer calls; any failure aborts the recording
    int ret = 0;

    VkMat q_affine;
    ret = q_gemm->forward(q_blob, q_affine, cmd, opt);
    if (ret != 0)
        return ret;

    VkMat k_affine;
    ret = k_gemm->forward(k_blob, k_affine, cmd, opt);
    if (ret != 0)
        return ret;

    VkMat qk_cross;
    {
        int M = q_affine.w;
        int N = k_affine.w;
        int K = q_affine.h * q_affine.elempack / num_heads;
        int B = num_heads;

        // pack8 selection is disabled here; only pack4 and pack1 are used
        int K_elempack = K % 4 == 0 ? 4 : 1;
        int M_elempack = M % 4 == 0 ? 4 : 1;
        int MB_elempack = (M * B) % 4 == 0 ? 4 : 1;
        size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack;

        if (opt.use_fp16_packed && !opt.use_fp16_storage)
        {
            // M_elempack is either 4 or 1 at this point
            if (M_elempack == 4) M_elemsize = 4 * 2u;
            if (M_elempack == 1) M_elemsize = 4u;
        }

        // unpack q and k down to the packing usable along K
        if (K_elempack < q_affine.elempack)
        {
            VkMat tmp;
            vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt);
            q_affine = tmp;
        }
        if (K_elempack < k_affine.elempack)
        {
            VkMat tmp;
            vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt);
            k_affine = tmp;
        }

        qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator);
        if (qk_cross.empty())
            return -100;

        std::vector<VkMat> bindings(3);
        bindings[0] = q_affine;
        bindings[1] = k_affine;
        bindings[2] = qk_cross;

        std::vector<vk_constant_type> constants(4);
        constants[0].i = M / M_elempack;
        constants[1].i = N;
        constants[2].i = K / K_elempack;
        constants[3].i = B;

        VkMat dispatcher;
        dispatcher.w = N;
        dispatcher.h = M / M_elempack;
        dispatcher.c = B;

        // choose the shader variant by (input packing along K, output packing along M)
        const Pipeline* pipeline = 0;
        if (K_elempack == 1 && M_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qk_cross;
        }
        if (K_elempack == 1 && M_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qk_cross_pack1to4;
        }
        if (K_elempack == 4 && M_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qk_cross_pack4to1;
        }
        if (K_elempack == 4 && M_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qk_cross_pack4;
        }

        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

        // repack when the full M*B extent allows a wider packing than M alone
        if (MB_elempack > M_elempack)
        {
            VkMat tmp;
            vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt);
            qk_cross = tmp;
        }
    }

    q_affine.release();
    k_affine.release();

    ret = qk_softmax->forward_inplace(qk_cross, cmd, opt);
    if (ret != 0)
        return ret;

    VkMat v_affine;
    ret = v_gemm->forward(v_blob, v_affine, cmd, opt);
    if (ret != 0)
        return ret;

    VkMat qkv_cross;
    {
        int M = qk_cross.h * qk_cross.elempack / num_heads;
        int N = v_affine.h * v_affine.elempack / num_heads;
        int K = v_affine.w;
        int B = num_heads;

        // pack8 selection is disabled here; only pack4 and pack1 are used
        int M_elempack = M % 4 == 0 ? 4 : 1;
        int N_elempack = N % 4 == 0 ? 4 : 1;
        int NB_elempack = (N * B) % 4 == 0 ? 4 : 1;
        size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack;

        if (opt.use_fp16_packed && !opt.use_fp16_storage)
        {
            // N_elempack is either 4 or 1 at this point
            if (N_elempack == 4) N_elemsize = 4 * 2u;
            if (N_elempack == 1) N_elemsize = 4u;
        }

        // unpack the inputs down to the packing usable per head
        if (M_elempack < qk_cross.elempack)
        {
            VkMat tmp;
            vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt);
            qk_cross = tmp;
        }

        if (N_elempack < v_affine.elempack)
        {
            VkMat tmp;
            vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt);
            v_affine = tmp;
        }

        qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator);
        if (qkv_cross.empty())
            return -100;

        std::vector<VkMat> bindings(3);
        bindings[0] = qk_cross;
        bindings[1] = v_affine;
        bindings[2] = qkv_cross;

        std::vector<vk_constant_type> constants(4);
        constants[0].i = M / M_elempack;
        constants[1].i = N / N_elempack;
        constants[2].i = K;
        constants[3].i = B;

        VkMat dispatcher;
        dispatcher.w = N / N_elempack;
        dispatcher.h = M / M_elempack;
        dispatcher.c = B;

        // choose the shader variant by (input packing along M, output packing along N)
        const Pipeline* pipeline = 0;
        if (M_elempack == 1 && N_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qkv_cross;
        }
        if (M_elempack == 1 && N_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qkv_cross_pack1to4;
        }
        if (M_elempack == 4 && N_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qkv_cross_pack4to1;
        }
        if (M_elempack == 4 && N_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qkv_cross_pack4;
        }

        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

        // repack when the full N*B extent allows a wider packing than N alone
        if (NB_elempack > N_elempack)
        {
            VkMat tmp;
            vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt);
            qkv_cross = tmp;
        }
    }

    qk_cross.release();
    v_affine.release();

    // the output projection's status is the status of the whole layer
    return o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt);
}

// Multi-head attention forward pass for image-backed blobs (VkImageMat).
//
// bottom_blobs: {q}, {q, k} or {q, k, v}. With a single input q is reused as k and v;
//               with two inputs k is reused as v.
// top_blobs:    top_blobs[0] receives the output of o_gemm.
// cmd:          compute command object handed to every sub-layer and used to record
//               the two cross-product pipelines.
// opt:          supplies the blob allocator and the fp16 packed/storage flags.
//
// Returns 0 on success, -100 when an intermediate blob cannot be allocated, otherwise
// the first non-zero status reported by a sub-layer.
int MultiHeadAttention_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const
{
    const VkImageMat& q_blob = bottom_blobs[0];
    const VkImageMat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1];
    const VkImageMat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2];

    // Sub-layer failures are propagated instead of being dropped, so that no
    // cross shader is recorded on a blob that was never produced.
    VkImageMat q_affine;
    int ret = q_gemm->forward(q_blob, q_affine, cmd, opt);
    if (ret != 0)
        return ret;

    VkImageMat k_affine;
    ret = k_gemm->forward(k_blob, k_affine, cmd, opt);
    if (ret != 0)
        return ret;

    // qk_cross: for each of the B heads, an M x N product of q_affine and k_affine
    // reduced over K. The result is stored with w = N and h = M * B.
    VkImageMat qk_cross;
    {
        const int M = q_affine.w;
        const int N = k_affine.w;
        const int K = q_affine.h * q_affine.elempack / num_heads;
        const int B = num_heads;

        // int K_elempack = opt.use_shader_pack8 && K % 8 == 0 ? 8 : K % 4 == 0 ? 4 : 1;
        // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
        // int MB_elempack = opt.use_shader_pack8 && (M * B) % 8 == 0 ? 8 : (M * B) % 4 == 0 ? 4 : 1;
        const int K_elempack = K % 4 == 0 ? 4 : 1;
        const int M_elempack = M % 4 == 0 ? 4 : 1;
        const int MB_elempack = (M * B) % 4 == 0 ? 4 : 1;
        size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack;

        if (opt.use_fp16_packed && !opt.use_fp16_storage)
        {
            if (M_elempack == 8) M_elemsize = 8 * 2u;
            if (M_elempack == 4) M_elemsize = 4 * 2u;
            if (M_elempack == 1) M_elemsize = 4u;
        }

        // The cross shaders expect both inputs packed by K_elempack along h.
        if (K_elempack < q_affine.elempack)
        {
            VkImageMat tmp;
            vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt);
            q_affine = tmp;
        }
        if (K_elempack < k_affine.elempack)
        {
            VkImageMat tmp;
            vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt);
            k_affine = tmp;
        }

        qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator);
        if (qk_cross.empty())
            return -100;

        std::vector<VkImageMat> bindings(3);
        bindings[0] = q_affine;
        bindings[1] = k_affine;
        bindings[2] = qk_cross;

        std::vector<vk_constant_type> constants(4);
        constants[0].i = M / M_elempack;
        constants[1].i = N;
        constants[2].i = K / K_elempack;
        constants[3].i = B;

        VkImageMat dispatcher;
        dispatcher.w = N;
        dispatcher.h = M / M_elempack;
        dispatcher.c = B;

        // One pipeline per (input packing, output packing) pair.
        const Pipeline* pipeline = 0;
        if (K_elempack == 1 && M_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qk_cross;
        }
        else if (K_elempack == 1 && M_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qk_cross_pack1to4;
        }
        else if (K_elempack == 4 && M_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qk_cross_pack4to1;
        }
        else if (K_elempack == 4 && M_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qk_cross_pack4;
        }

        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

        // Repack over the whole M * B height when that allows a wider packing.
        if (MB_elempack > M_elempack)
        {
            VkImageMat tmp;
            vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt);
            qk_cross = tmp;
        }
    }

    q_affine.release();
    k_affine.release();

    ret = qk_softmax->forward_inplace(qk_cross, cmd, opt);
    if (ret != 0)
        return ret;

    VkImageMat v_affine;
    ret = v_gemm->forward(v_blob, v_affine, cmd, opt);
    if (ret != 0)
        return ret;

    // qkv_cross: for each of the B heads, an M x N product of qk_cross and v_affine
    // reduced over K. The result is stored with w = M and h = N * B.
    VkImageMat qkv_cross;
    {
        const int M = qk_cross.h * qk_cross.elempack / num_heads;
        const int N = v_affine.h * v_affine.elempack / num_heads;
        const int K = v_affine.w;
        const int B = num_heads;

        // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
        // int N_elempack = opt.use_shader_pack8 && N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
        // int NB_elempack = opt.use_shader_pack8 && (N * B) % 8 == 0 ? 8 : (N * B) % 4 == 0 ? 4 : 1;
        const int M_elempack = M % 4 == 0 ? 4 : 1;
        const int N_elempack = N % 4 == 0 ? 4 : 1;
        const int NB_elempack = (N * B) % 4 == 0 ? 4 : 1;
        size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack;

        if (opt.use_fp16_packed && !opt.use_fp16_storage)
        {
            if (N_elempack == 8) N_elemsize = 8 * 2u;
            if (N_elempack == 4) N_elemsize = 4 * 2u;
            if (N_elempack == 1) N_elemsize = 4u;
        }

        if (M_elempack < qk_cross.elempack)
        {
            VkImageMat tmp;
            vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt);
            qk_cross = tmp;
        }

        if (N_elempack < v_affine.elempack)
        {
            VkImageMat tmp;
            vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt);
            v_affine = tmp;
        }

        qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator);
        if (qkv_cross.empty())
            return -100;

        std::vector<VkImageMat> bindings(3);
        bindings[0] = qk_cross;
        bindings[1] = v_affine;
        bindings[2] = qkv_cross;

        std::vector<vk_constant_type> constants(4);
        constants[0].i = M / M_elempack;
        constants[1].i = N / N_elempack;
        constants[2].i = K;
        constants[3].i = B;

        VkImageMat dispatcher;
        dispatcher.w = N / N_elempack;
        dispatcher.h = M / M_elempack;
        dispatcher.c = B;

        const Pipeline* pipeline = 0;
        if (M_elempack == 1 && N_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qkv_cross;
        }
        else if (M_elempack == 1 && N_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qkv_cross_pack1to4;
        }
        else if (M_elempack == 4 && N_elempack == 1)
        {
            pipeline = pipeline_multiheadattention_qkv_cross_pack4to1;
        }
        else if (M_elempack == 4 && N_elempack == 4)
        {
            pipeline = pipeline_multiheadattention_qkv_cross_pack4;
        }

        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

        if (NB_elempack > N_elempack)
        {
            VkImageMat tmp;
            vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt);
            qkv_cross = tmp;
        }
    }

    qk_cross.release();
    v_affine.release();

    // The output projection's status is the status of the whole layer.
    return o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt);
}

} // namespace ncnn

+ 57
- 0
src/layer/vulkan/multiheadattention_vulkan.h View File

@@ -0,0 +1,57 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_MULTIHEADATTENTION_VULKAN_H
#define LAYER_MULTIHEADATTENTION_VULKAN_H

#include "multiheadattention.h"

namespace ncnn {

// Vulkan variant of the MultiHeadAttention layer.
// forward() chains the sub-layers and pipelines declared below:
// q_gemm / k_gemm -> qk_cross pipeline -> qk_softmax -> v_gemm -> qkv_cross pipeline -> o_gemm.
class MultiHeadAttention_vulkan : virtual public MultiHeadAttention
{
public:
MultiHeadAttention_vulkan();

// NOTE(review): presumably these create and free the sub-layers and pipelines
// declared below; their definitions are not part of this header - confirm.
virtual int create_pipeline(const Option& opt);
virtual int destroy_pipeline(const Option& opt);

virtual int upload_model(VkTransfer& cmd, const Option& opt);

// keep the base-class forward overloads visible next to the two declared here
using MultiHeadAttention::forward;
// buffer-backed blobs
virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const;
// image-backed blobs
virtual int forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const;

public:
// sub-layers run by forward() on the q, k and v inputs and on the final result
Layer* q_gemm;
Layer* k_gemm;
Layer* v_gemm;
Layer* o_gemm;

// applied in place to the q-k cross product
Layer* qk_softmax;

// q-k cross product, chosen in forward() from (K_elempack, M_elempack):
// (1,1) plain, (4,4) pack4, (1,4) pack1to4, (4,1) pack4to1
Pipeline* pipeline_multiheadattention_qk_cross;
Pipeline* pipeline_multiheadattention_qk_cross_pack4;
Pipeline* pipeline_multiheadattention_qk_cross_pack1to4;
Pipeline* pipeline_multiheadattention_qk_cross_pack4to1;

// (softmaxed q-k) times v, chosen in forward() from (M_elempack, N_elempack):
// (1,1) plain, (4,4) pack4, (1,4) pack1to4, (4,1) pack4to1
Pipeline* pipeline_multiheadattention_qkv_cross;
Pipeline* pipeline_multiheadattention_qkv_cross_pack4;
Pipeline* pipeline_multiheadattention_qkv_cross_pack1to4;
Pipeline* pipeline_multiheadattention_qkv_cross_pack4to1;
};

} // namespace ncnn

#endif // LAYER_MULTIHEADATTENTION_VULKAN_H

+ 453
- 0
src/layer/vulkan/shader/gemm.comp View File

@@ -0,0 +1,453 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// number of K steps staged in shared memory per iteration of the tiled loop in main()
#define LOCAL_MEMORY_UNROLL_INCH 8

// Specialization constants.
// alpha scales the final sum, beta scales the value loaded from C:
// out = (C * beta + A x B) * alpha
layout (constant_id = 0) const float alpha = 1.f;
layout (constant_id = 1) const float beta = 1.f;
// 1 selects the transposed addressing of A / B in main()
layout (constant_id = 2) const int transA = 0;
layout (constant_id = 3) const int transB = 0;
// constantA and constantB are not referenced by main() in this file
layout (constant_id = 4) const int constantA = 0;
layout (constant_id = 5) const int constantB = 0;
// when 1, constant_broadcast_type_C is used instead of p.broadcast_type_C
layout (constant_id = 6) const int constantC = 0;
// problem size, read through psc()
// NOTE(review): psc() is defined outside this file; presumably it falls back to
// the push constant of the same name when the value here is 0 - confirm
layout (constant_id = 7) const int M = 0;
layout (constant_id = 8) const int N = 0;
layout (constant_id = 9) const int K = 0;
layout (constant_id = 10) const int constant_broadcast_type_C = 0;
// 1 stores the image output at (x, 0, z) instead of (x, y, 0)
layout (constant_id = 11) const int output_N1M = 0;
// output_elempack and output_elemtype are not referenced by main() in this file
layout (constant_id = 12) const int output_elempack = 0;
layout (constant_id = 13) const int output_elemtype = 0;
// 1 swaps the row and column index when storing the result
layout (constant_id = 14) const int output_transpose = 0;

// TODO psc more

// binding 0 is the output, bindings 1-3 are A, B and C
#if NCNN_image_shader
layout (binding = 0, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
layout (binding = 1) uniform unfp sampler3D A_blob_3d;
layout (binding = 2) uniform unfp sampler3D B_blob_3d;
layout (binding = 3) uniform unfp sampler3D C_blob_3d;
#else
layout (binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
layout (binding = 1) readonly buffer A_blob { sfp A_blob_data[]; };
layout (binding = 2) readonly buffer B_blob { sfp B_blob_data[]; };
layout (binding = 3) readonly buffer C_blob { sfp C_blob_data[]; };
#endif

// Push constants.
// A_hstep / B_hstep / outhstep are the row strides used by the buffer path;
// A_dims / B_dims select the (x, 0, z) image addressing when equal to 3;
// outdims is not referenced by main() in this file.
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int broadcast_type_C;
int A_dims;
int A_hstep;
int B_dims;
int B_hstep;
int outdims;
int outhstep;
} p;

#if NCNN_shader_local_memory
// staging tiles indexed by local invocation id: [ly or lx][k within chunk][2 rows or columns]
// NOTE(review): the first extent of 8 implies a local size of at most 8 in x and y - confirm
shared lfp tmp_a[8][LOCAL_MEMORY_UNROLL_INCH][2];
shared lfp tmp_b[8][LOCAL_MEMORY_UNROLL_INCH][2];
#endif

// Computes a 2x2 tile of out = (C * beta + A x B) * alpha per invocation.
// sum0..sum3 hold the elements (gy, gx), (gy, gx + 1), (gy + 1, gx), (gy + 1, gx + 1).
void main()
{
// top-left corner of this invocation's 2x2 output tile
int gx = int(gl_GlobalInvocationID.x) * 2;
int gy = int(gl_GlobalInvocationID.y) * 2;
int gz = int(gl_GlobalInvocationID.z);

// With local memory enabled the bounds check is deferred until after the loops
// NOTE(review): presumably so that every invocation reaches the barrier() calls - confirm
#if !NCNN_shader_local_memory
if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
return;
#endif

afp sum0 = afp(0.f);
afp sum1 = afp(0.f);
afp sum2 = afp(0.f);
afp sum3 = afp(0.f);

const int broadcast_type_C = constantC == 1 ? constant_broadcast_type_C : p.broadcast_type_C;

// Seed the sums from C. Any broadcast type other than 0..4 leaves them at zero.
//   0: one value for the whole output
//   1, 2: one value per output row (indexed by gy)
//   3: one value per output element
//   4: one value per output column (indexed by gx)
#if NCNN_image_shader
if (broadcast_type_C == 0)
{
sum0 = image3d_ld1(C_blob_3d, ivec3(0, 0, 0));
sum1 = sum0;
sum2 = sum0;
sum3 = sum0;
}
if (broadcast_type_C == 1)
{
sum0 = image3d_ld1(C_blob_3d, ivec3(gy, 0, 0));
sum1 = sum0;
sum2 = image3d_ld1(C_blob_3d, ivec3(gy + 1, 0, 0));
sum3 = sum2;
}
if (broadcast_type_C == 2)
{
sum0 = image3d_ld1(C_blob_3d, ivec3(0, gy, 0));
sum1 = sum0;
sum2 = image3d_ld1(C_blob_3d, ivec3(0, gy + 1, 0));
sum3 = sum2;
}
if (broadcast_type_C == 3)
{
sum0 = image3d_ld1(C_blob_3d, ivec3(gx, gy, 0));
sum1 = image3d_ld1(C_blob_3d, ivec3(gx + 1, gy, 0));
sum2 = image3d_ld1(C_blob_3d, ivec3(gx, gy + 1, 0));
sum3 = image3d_ld1(C_blob_3d, ivec3(gx + 1, gy + 1, 0));
}
if (broadcast_type_C == 4)
{
sum0 = image3d_ld1(C_blob_3d, ivec3(gx, 0, 0));
sum1 = image3d_ld1(C_blob_3d, ivec3(gx + 1, 0, 0));
sum2 = sum0;
sum3 = sum1;
}
#else
if (broadcast_type_C == 0)
{
sum0 = buffer_ld1(C_blob_data, 0);
sum1 = sum0;
sum2 = sum0;
sum3 = sum0;
}
if (broadcast_type_C == 1 || broadcast_type_C == 2)
{
sum0 = buffer_ld1(C_blob_data, gy);
sum1 = sum0;
sum2 = buffer_ld1(C_blob_data, gy + 1);
sum3 = sum2;
}
if (broadcast_type_C == 3)
{
// C is addressed with a row stride of N here
const int ci = gy * psc(N) + gx;
sum0 = buffer_ld1(C_blob_data, ci);
sum1 = buffer_ld1(C_blob_data, ci + 1);
sum2 = buffer_ld1(C_blob_data, ci + psc(N));
sum3 = buffer_ld1(C_blob_data, ci + psc(N) + 1);
}
if (broadcast_type_C == 4)
{
sum0 = buffer_ld1(C_blob_data, gx);
sum1 = buffer_ld1(C_blob_data, gx + 1);
sum2 = sum0;
sum3 = sum1;
}
#endif

sum0 *= afp(beta);
sum1 *= afp(beta);
sum2 *= afp(beta);
sum3 *= afp(beta);

#if !NCNN_image_shader && NCNN_shader_local_memory
// Buffer path with shared-memory staging: K is consumed in chunks of
// LOCAL_MEMORY_UNROLL_INCH, then a remainder chunk.
const int NN = psc(K);

const int lx = int(gl_LocalInvocationID.x);
const int ly = int(gl_LocalInvocationID.y);

int k = 0;
for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
{
{
// each invocation stages two A values (rows gy, gy + 1 at k + lx) ...
if (transA == 1)
{
const int ai = (k + lx) * p.A_hstep + gy;
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + 1]);
}
else
{
const int ai = gy * p.A_hstep + (k + lx);
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + p.A_hstep]);
}

// ... and two B values (columns gx, gx + 1 at k + ly)
if (transB == 1)
{
const int bi = gx * p.B_hstep + (k + ly);
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + p.B_hstep]);
}
else
{
const int bi = (k + ly) * p.B_hstep + gx;
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + 1]);
}
}

// staged tiles must be complete before any invocation reads them
barrier();

for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
{
afp a0 = lfp2afp(tmp_a[ly][k4][0]);
afp a1 = lfp2afp(tmp_a[ly][k4][1]);

afp b0 = lfp2afp(tmp_b[lx][k4][0]);
afp b1 = lfp2afp(tmp_b[lx][k4][1]);

sum0 += a0 * b0;
sum1 += a0 * b1;
sum2 += a1 * b0;
sum3 += a1 * b1;
}

// all reads must finish before the next iteration overwrites the tiles
barrier();
}

// remainder chunk: fewer than LOCAL_MEMORY_UNROLL_INCH values of k are left
if (k < NN)
{
const int remain = NN - k;

if (lx < remain)
{
if (transA == 1)
{
const int ai = (k + lx) * p.A_hstep + gy;
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + 1]);
}
else
{
const int ai = gy * p.A_hstep + (k + lx);
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + p.A_hstep]);
}
}

if (ly < remain)
{
if (transB == 1)
{
const int bi = gx * p.B_hstep + (k + ly);
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + p.B_hstep]);
}
else
{
const int bi = (k + ly) * p.B_hstep + gx;
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + 1]);
}
}

barrier();

for (int k4 = 0; k4 < remain; k4++)
{
afp a0 = lfp2afp(tmp_a[ly][k4][0]);
afp a1 = lfp2afp(tmp_a[ly][k4][1]);

afp b0 = lfp2afp(tmp_b[lx][k4][0]);
afp b1 = lfp2afp(tmp_b[lx][k4][1]);

sum0 += a0 * b0;
sum1 += a0 * b1;
sum2 += a1 * b0;
sum3 += a1 * b1;
}
}
#else
// Direct path: every invocation loads its own two A and two B values per k.
for (int k = 0; k < psc(K); k++)
{
afp a0;
afp a1;
afp b0;
afp b1;
#if NCNN_image_shader
// dims == 3 addresses the blob as (x, 0, z), otherwise as (x, y, 0)
if (transA == 1)
{
if (p.A_dims == 3)
{
a0 = image3d_ld1(A_blob_3d, ivec3(gy, 0, k));
a1 = image3d_ld1(A_blob_3d, ivec3(gy + 1, 0, k));
}
else
{
a0 = image3d_ld1(A_blob_3d, ivec3(gy, k, 0));
a1 = image3d_ld1(A_blob_3d, ivec3(gy + 1, k, 0));
}
}
else
{
if (p.A_dims == 3)
{
a0 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy));
a1 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy + 1));
}
else
{
a0 = image3d_ld1(A_blob_3d, ivec3(k, gy, 0));
a1 = image3d_ld1(A_blob_3d, ivec3(k, gy + 1, 0));
}
}

if (transB == 1)
{
if (p.B_dims == 3)
{
b0 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx));
b1 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx + 1));
}
else
{
b0 = image3d_ld1(B_blob_3d, ivec3(k, gx, 0));
b1 = image3d_ld1(B_blob_3d, ivec3(k, gx + 1, 0));
}
}
else
{
if (p.B_dims == 3)
{
b0 = image3d_ld1(B_blob_3d, ivec3(gx, 0, k));
b1 = image3d_ld1(B_blob_3d, ivec3(gx + 1, 0, k));
}
else
{
b0 = image3d_ld1(B_blob_3d, ivec3(gx, k, 0));
b1 = image3d_ld1(B_blob_3d, ivec3(gx + 1, k, 0));
}
}
#else
if (transA == 1)
{
const int ai = k * p.A_hstep + gy;
a0 = buffer_ld1(A_blob_data, ai);
a1 = buffer_ld1(A_blob_data, ai + 1);
}
else
{
const int ai = gy * p.A_hstep + k;
a0 = buffer_ld1(A_blob_data, ai);
a1 = buffer_ld1(A_blob_data, ai + p.A_hstep);
}

if (transB == 1)
{
const int bi = gx * p.B_hstep + k;
b0 = buffer_ld1(B_blob_data, bi);
b1 = buffer_ld1(B_blob_data, bi + p.B_hstep);
}
else
{
const int bi = k * p.B_hstep + gx;
b0 = buffer_ld1(B_blob_data, bi);
b1 = buffer_ld1(B_blob_data, bi + 1);
}
#endif

sum0 += a0 * b0;
sum1 += a0 * b1;
sum2 += a1 * b0;
sum3 += a1 * b1;
}
#endif

// deferred bounds check for the local-memory build (see the top of main)
#if NCNN_shader_local_memory
if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
return;
#endif

sum0 *= afp(alpha);
sum1 *= afp(alpha);
sum2 *= afp(alpha);
sum3 *= afp(alpha);

// Store the tile. The second row / column is written only when it lies inside M / N.
#if NCNN_image_shader
if (output_transpose == 1)
{
if (output_N1M == 1)
{
image3d_st1(top_blob_3d, ivec3(gy, 0, gx), sum0);
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx), sum2);
if (gx + 1 < psc(N))
{
image3d_st1(top_blob_3d, ivec3(gy, 0, gx + 1), sum1);
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx + 1), sum3);
}
}
else
{
image3d_st1(top_blob_3d, ivec3(gy, gx, 0), sum0);
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx, 0), sum2);
if (gx + 1 < psc(N))
{
image3d_st1(top_blob_3d, ivec3(gy, gx + 1, 0), sum1);
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx + 1, 0), sum3);
}
}
}
else
{
if (output_N1M == 1)
{
image3d_st1(top_blob_3d, ivec3(gx, 0, gy), sum0);
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy), sum1);
if (gy + 1 < psc(M))
{
image3d_st1(top_blob_3d, ivec3(gx, 0, gy + 1), sum2);
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy + 1), sum3);
}
}
else
{
image3d_st1(top_blob_3d, ivec3(gx, gy, 0), sum0);
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy, 0), sum1);
if (gy + 1 < psc(M))
{
image3d_st1(top_blob_3d, ivec3(gx, gy + 1, 0), sum2);
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy + 1, 0), sum3);
}
}
}
#else
if (output_transpose == 1)
{
// transposed output: gx selects the row, gy the column
const int gi = gx * p.outhstep + gy;

buffer_st1(top_blob_data, gi, sum0);
if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + 1, sum2);
if (gx + 1 < psc(N))
{
buffer_st1(top_blob_data, gi + p.outhstep, sum1);
if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);
}
}
else
{
const int gi = gy * p.outhstep + gx;

buffer_st1(top_blob_data, gi, sum0);
if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + 1, sum1);
if (gy + 1 < psc(M))
{
buffer_st1(top_blob_data, gi + p.outhstep, sum2);
if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);
}
}
#endif
}

+ 81
- 0
src/layer/vulkan/shader/multiheadattention_qk_cross.comp View File

@@ -0,0 +1,81 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Problem size, read in main() through psc().
// M bounds gy (output row within one batch), N bounds gx (output column),
// K is the reduction length and B bounds gz (batch index).
// NOTE(review): psc() is defined outside this file; presumably it falls back to
// the push constant of the same name when the value here is 0 - confirm
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// bindings 0 and 1 are the two scalar inputs, binding 2 is the scalar output
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D q_blob_3d;
layout (binding = 1) uniform unfp sampler3D k_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkcross_blob_3d;
#else
layout (binding = 0) readonly buffer q_blob { sfp q_blob_data[]; };
layout (binding = 1) readonly buffer k_blob { sfp k_blob_data[]; };
layout (binding = 2) writeonly buffer qkcross_blob { sfp qkcross_blob_data[]; };
#endif

// push-constant copies of the four sizes above
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

// One invocation produces one scalar of the q-k cross product:
// out[head][row][col] = sum over d of q[head][d][row] * k[head][d][col].
void main()
{
    const ivec3 gid = ivec3(gl_GlobalInvocationID);
    const int col = gid.x;
    const int row = gid.y;
    const int head = gid.z;

    // nothing to do outside the N x M x B output volume
    const bool inside = col < psc(N) && row < psc(M) && head < psc(B);
    if (!inside)
        return;

#if NCNN_image_shader
    // first image row belonging to this head
    const int y0 = head * psc(K);
#else
    // offsets of the d == 0 element for this row / column
    const int q_base = head * psc(M) * psc(K) + row;
    const int k_base = head * psc(N) * psc(K) + col;
#endif

    afp acc = afp(0.f);

    for (int d = 0; d < psc(K); d++)
    {
#if NCNN_image_shader
        afp qv = image3d_ld1(q_blob_3d, ivec3(row, y0 + d, 0));
        afp kv = image3d_ld1(k_blob_3d, ivec3(col, y0 + d, 0));
#else
        const int qi = q_base + d * psc(M);
        afp qv = buffer_ld1(q_blob_data, qi);

        const int ki = k_base + d * psc(N);
        afp kv = buffer_ld1(k_blob_data, ki);
#endif

        acc += qv * kv;
    }

#if NCNN_image_shader
    image3d_st1(qkcross_blob_3d, ivec3(col, head * psc(M) + row, 0), acc);
#else
    const int oi = head * psc(M) * psc(N) + row * psc(N) + col;
    buffer_st1(qkcross_blob_data, oi, acc);
#endif
}

+ 90
- 0
src/layer/vulkan/shader/multiheadattention_qk_cross_pack1to4.comp View File

@@ -0,0 +1,90 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Problem size, read in main() through psc().
// M bounds gy and counts packed output rows: each gy covers the four input
// columns gy * 4 .. gy * 4 + 3. N bounds gx, K is the reduction length and
// B bounds gz (batch index).
// NOTE(review): psc() is defined outside this file; presumably it falls back to
// the push constant of the same name when the value here is 0 - confirm
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// scalar inputs on bindings 0 and 1, 4-component output on binding 2
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D q_blob_3d;
layout (binding = 1) uniform unfp sampler3D k_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkcross_blob_3d;
#else
layout (binding = 0) readonly buffer q_blob { sfp q_blob_data[]; };
layout (binding = 1) readonly buffer k_blob { sfp k_blob_data[]; };
layout (binding = 2) writeonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; };
#endif

// push-constant copies of the four sizes above
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

// One invocation produces one 4-wide output element: the four lanes are the
// q-k products for q columns row * 4 .. row * 4 + 3 against k column col.
void main()
{
    const ivec3 gid = ivec3(gl_GlobalInvocationID);
    const int col = gid.x;
    const int row = gid.y;
    const int head = gid.z;

    // nothing to do outside the N x M x B output volume
    const bool inside = col < psc(N) && row < psc(M) && head < psc(B);
    if (!inside)
        return;

#if NCNN_image_shader
    // first q column of this packed row and first image row of this head
    const int x0 = row * 4;
    const int y0 = head * psc(K);
#else
    // scalar offsets of the d == 0 elements
    const int q_base = (head * psc(M) * psc(K) + row) * 4;
    const int q_step = psc(M) * 4;
    const int k_base = head * psc(N) * psc(K) + col;
#endif

    afpvec4 acc = afpvec4(0.f);

    for (int d = 0; d < psc(K); d++)
    {
#if NCNN_image_shader
        const int y = y0 + d;
        afp qa = image3d_ld1(q_blob_3d, ivec3(x0, y, 0));
        afp qb = image3d_ld1(q_blob_3d, ivec3(x0 + 1, y, 0));
        afp qc = image3d_ld1(q_blob_3d, ivec3(x0 + 2, y, 0));
        afp qd = image3d_ld1(q_blob_3d, ivec3(x0 + 3, y, 0));

        afp kv = image3d_ld1(k_blob_3d, ivec3(col, y, 0));
#else
        const int qi = q_base + d * q_step;
        afp qa = buffer_ld1(q_blob_data, qi);
        afp qb = buffer_ld1(q_blob_data, qi + 1);
        afp qc = buffer_ld1(q_blob_data, qi + 2);
        afp qd = buffer_ld1(q_blob_data, qi + 3);

        const int ki = k_base + d * psc(N);
        afp kv = buffer_ld1(k_blob_data, ki);
#endif

        // lane i accumulates q[i] * k
        acc += afpvec4(qa, qb, qc, qd) * kv;
    }

#if NCNN_image_shader
    image3d_st4(qkcross_blob_3d, ivec3(col, head * psc(M) + row, 0), acc);
#else
    const int oi = head * psc(M) * psc(N) + row * psc(N) + col;
    buffer_st4(qkcross_blob_data, oi, acc);
#endif
}

+ 186
- 0
src/layer/vulkan/shader/multiheadattention_qk_cross_pack4.comp View File

@@ -0,0 +1,186 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// number of K steps staged in shared memory per iteration of the tiled loop in main()
#define LOCAL_MEMORY_UNROLL_INCH 8

// Problem size, read in main() through psc().
// M bounds gy and counts packed output rows (each gy covers q columns
// gy * 4 .. gy * 4 + 3), N bounds gx, K is the number of 4-wide reduction
// steps and B bounds gz (batch index).
// NOTE(review): psc() is defined outside this file; presumably it falls back to
// the push constant of the same name when the value here is 0 - confirm
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// 4-component inputs on bindings 0 and 1, 4-component output on binding 2
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D q_blob_3d;
layout (binding = 1) uniform unfp sampler3D k_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkcross_blob_3d;
#else
layout (binding = 0) readonly buffer q_blob { sfpvec4 q_blob_data[]; };
layout (binding = 1) readonly buffer k_blob { sfpvec4 k_blob_data[]; };
layout (binding = 2) writeonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; };
#endif

// push-constant copies of the four sizes above
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

#if NCNN_shader_local_memory
// staging tiles indexed by local invocation id:
// tmp_q[ly][k within chunk][one of the 4 q columns], tmp_k[lx][k within chunk]
// NOTE(review): the first extent of 8 implies a local size of at most 8 in x and y - confirm
shared lfpvec4 tmp_q[8][LOCAL_MEMORY_UNROLL_INCH][4];
shared lfpvec4 tmp_k[8][LOCAL_MEMORY_UNROLL_INCH];
#endif

// One invocation produces one 4-wide output element. Lane i is the sum over k of
// dot(q[column gy * 4 + i], k[column gx]), where every q / k element is itself a
// 4-vector along the reduction axis.
void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);

// With local memory enabled the bounds check is deferred until after the loops
// NOTE(review): presumably so that every invocation reaches the barrier() calls - confirm
#if !NCNN_shader_local_memory
if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B))
return;
#endif

afpvec4 sum = afpvec4(0.f);

#if !NCNN_image_shader && NCNN_shader_local_memory
// Buffer path with shared-memory staging: K is consumed in chunks of
// LOCAL_MEMORY_UNROLL_INCH, then a remainder chunk.
const int NN = psc(K);

const int lx = int(gl_LocalInvocationID.x);
const int ly = int(gl_LocalInvocationID.y);

// this invocation stages q at reduction step lx and k at reduction step ly
int ai = (gz * psc(M) * psc(K) + lx * psc(M) + gy) * 4;
int bi = gz * psc(N) * psc(K) + ly * psc(N) + gx;

int k = 0;
for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
{
{
tmp_q[ly][lx][0] = sfp2lfpvec4(q_blob_data[ai]);
tmp_q[ly][lx][1] = sfp2lfpvec4(q_blob_data[ai + 1]);
tmp_q[ly][lx][2] = sfp2lfpvec4(q_blob_data[ai + 2]);
tmp_q[ly][lx][3] = sfp2lfpvec4(q_blob_data[ai + 3]);
}

{
tmp_k[lx][ly] = sfp2lfpvec4(k_blob_data[bi]);
}

// staged tiles must be complete before any invocation reads them
barrier();

for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
{
afpvec4 q0 = lfp2afpvec4(tmp_q[ly][k4][0]);
afpvec4 q1 = lfp2afpvec4(tmp_q[ly][k4][1]);
afpvec4 q2 = lfp2afpvec4(tmp_q[ly][k4][2]);
afpvec4 q3 = lfp2afpvec4(tmp_q[ly][k4][3]);

afpvec4 k0 = lfp2afpvec4(tmp_k[lx][k4]);

sum.r += dot(q0, k0);
sum.g += dot(q1, k0);
sum.b += dot(q2, k0);
sum.a += dot(q3, k0);
}

// advance both offsets by one chunk of reduction steps
ai += LOCAL_MEMORY_UNROLL_INCH * psc(M) * 4;
bi += LOCAL_MEMORY_UNROLL_INCH * psc(N);

// all reads must finish before the next iteration overwrites the tiles
barrier();
}

// remainder chunk: fewer than LOCAL_MEMORY_UNROLL_INCH reduction steps are left
if (k < NN)
{
const int remain = NN - k;

if (lx < remain)
{
tmp_q[ly][lx][0] = sfp2lfpvec4(q_blob_data[ai]);
tmp_q[ly][lx][1] = sfp2lfpvec4(q_blob_data[ai + 1]);
tmp_q[ly][lx][2] = sfp2lfpvec4(q_blob_data[ai + 2]);
tmp_q[ly][lx][3] = sfp2lfpvec4(q_blob_data[ai + 3]);
}

if (ly < remain)
{
tmp_k[lx][ly] = sfp2lfpvec4(k_blob_data[bi]);
}

barrier();

for (int k4 = 0; k4 < remain; k4++)
{
afpvec4 q0 = lfp2afpvec4(tmp_q[ly][k4][0]);
afpvec4 q1 = lfp2afpvec4(tmp_q[ly][k4][1]);
afpvec4 q2 = lfp2afpvec4(tmp_q[ly][k4][2]);
afpvec4 q3 = lfp2afpvec4(tmp_q[ly][k4][3]);

afpvec4 k0 = lfp2afpvec4(tmp_k[lx][k4]);

sum.r += dot(q0, k0);
sum.g += dot(q1, k0);
sum.b += dot(q2, k0);
sum.a += dot(q3, k0);
}
}
#else
// Direct path: every invocation loads its own four q vectors and one k vector per step.
for (int k = 0; k < psc(K); k++)
{
#if NCNN_image_shader
afpvec4 q0 = image3d_ld4(q_blob_3d, ivec3(gy * 4, gz * psc(K) + k, 0));
afpvec4 q1 = image3d_ld4(q_blob_3d, ivec3(gy * 4 + 1, gz * psc(K) + k, 0));
afpvec4 q2 = image3d_ld4(q_blob_3d, ivec3(gy * 4 + 2, gz * psc(K) + k, 0));
afpvec4 q3 = image3d_ld4(q_blob_3d, ivec3(gy * 4 + 3, gz * psc(K) + k, 0));

afpvec4 k0 = image3d_ld4(k_blob_3d, ivec3(gx, gz * psc(K) + k, 0));
#else
const int ai = (gz * psc(M) * psc(K) + k * psc(M) + gy) * 4;
afpvec4 q0 = buffer_ld4(q_blob_data, ai);
afpvec4 q1 = buffer_ld4(q_blob_data, ai + 1);
afpvec4 q2 = buffer_ld4(q_blob_data, ai + 2);
afpvec4 q3 = buffer_ld4(q_blob_data, ai + 3);

const int bi = gz * psc(N) * psc(K) + k * psc(N) + gx;
afpvec4 k0 = buffer_ld4(k_blob_data, bi);
#endif

sum.r += dot(q0, k0);
sum.g += dot(q1, k0);
sum.b += dot(q2, k0);
sum.a += dot(q3, k0);
}
#endif

// deferred bounds check for the local-memory build (see the top of main)
#if NCNN_shader_local_memory
if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B))
return;
#endif

#if NCNN_image_shader
image3d_st4(qkcross_blob_3d, ivec3(gx, gz * psc(M) + gy, 0), sum);
#else
const int gi = gz * psc(M) * psc(N) + gy * psc(N) + gx;
buffer_st4(qkcross_blob_data, gi, sum);
#endif
}

+ 81
- 0
src/layer/vulkan/shader/multiheadattention_qk_cross_pack4to1.comp View File

@@ -0,0 +1,81 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Problem size, read in main() through psc().
// M bounds gy (output row within one batch), N bounds gx (output column),
// K is the number of 4-wide reduction steps and B bounds gz (batch index).
// NOTE(review): psc() is defined outside this file; presumably it falls back to
// the push constant of the same name when the value here is 0 - confirm
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// 4-component inputs on bindings 0 and 1, scalar output on binding 2
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D q_blob_3d;
layout (binding = 1) uniform unfp sampler3D k_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkcross_blob_3d;
#else
layout (binding = 0) readonly buffer q_blob { sfpvec4 q_blob_data[]; };
layout (binding = 1) readonly buffer k_blob { sfpvec4 k_blob_data[]; };
layout (binding = 2) writeonly buffer qkcross_blob { sfp qkcross_blob_data[]; };
#endif

// push-constant copies of the four sizes above
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

// One invocation produces one scalar of the q-k cross product. Each reduction
// step contributes the dot product of a 4-wide q element with a 4-wide k element.
void main()
{
    const ivec3 gid = ivec3(gl_GlobalInvocationID);
    const int col = gid.x;
    const int row = gid.y;
    const int head = gid.z;

    // nothing to do outside the N x M x B output volume
    const bool inside = col < psc(N) && row < psc(M) && head < psc(B);
    if (!inside)
        return;

#if NCNN_image_shader
    // first image row belonging to this head
    const int y0 = head * psc(K);
#else
    // offsets (in 4-wide elements) of the d == 0 element for this row / column
    const int q_base = head * psc(M) * psc(K) + row;
    const int k_base = head * psc(N) * psc(K) + col;
#endif

    afp acc = afp(0.f);

    for (int d = 0; d < psc(K); d++)
    {
#if NCNN_image_shader
        afpvec4 qv = image3d_ld4(q_blob_3d, ivec3(row, y0 + d, 0));
        afpvec4 kv = image3d_ld4(k_blob_3d, ivec3(col, y0 + d, 0));
#else
        const int qi = q_base + d * psc(M);
        afpvec4 qv = buffer_ld4(q_blob_data, qi);

        const int ki = k_base + d * psc(N);
        afpvec4 kv = buffer_ld4(k_blob_data, ki);
#endif

        acc += dot(qv, kv);
    }

#if NCNN_image_shader
    image3d_st1(qkcross_blob_3d, ivec3(col, head * psc(M) + row, 0), acc);
#else
    const int oi = head * psc(M) * psc(N) + row * psc(N) + col;
    buffer_st1(qkcross_blob_data, oi, acc);
#endif
}

+ 81
- 0
src/layer/vulkan/shader/multiheadattention_qkv_cross.comp View File

@@ -0,0 +1,81 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Optional fp16 extensions, enabled by build-time NCNN_* flags.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Problem sizes as specialization constants, all defaulting to 0.
// main() bounds-checks x against N, y against M, z against B and iterates K times.
// NOTE(review): psc() is not defined in this file; presumably it falls back to the
// push-constant value of the same name when the specialization constant is 0 - confirm.
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// Bindings 0 and 1 are the two read-only scalar inputs (qk cross result and v),
// binding 2 is the write-only scalar output.
// unfp / imfmtc1 / sfp are type macros supplied outside this file.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d;
layout (binding = 1) uniform unfp sampler3D v_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkvcross_blob_3d;
#else
layout (binding = 0) readonly buffer qkcross_blob { sfp qkcross_blob_data[]; };
layout (binding = 1) readonly buffer v_blob { sfp v_blob_data[]; };
layout (binding = 2) writeonly buffer qkvcross_blob { sfp qkvcross_blob_data[]; };
#endif

// Push-constant block carrying the same four sizes as the specialization constants above.
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

// Produces one scalar element of the qkv cross output.
// Invocation x is bounded by N, y by M and z by B. The result is the sum over K of
// qk[z][y][k] * v[z][x][k], stored at position (y, z * N + x) of the output.
void main()
{
    const int col = int(gl_GlobalInvocationID.x);
    const int row = int(gl_GlobalInvocationID.y);
    const int bat = int(gl_GlobalInvocationID.z);

    // drop invocations that fall outside the N x M x B extent
    const bool inside = col < psc(N) && row < psc(M) && bat < psc(B);
    if (!inside)
        return;

    // row of each source that this invocation walks along K
    const int qk_row = bat * psc(M) + row;
    const int v_row = bat * psc(N) + col;

    afp acc = afp(0.f);

    for (int ki = 0; ki < psc(K); ki++)
    {
#if NCNN_image_shader
        afp qk_val = image3d_ld1(qkcross_blob_3d, ivec3(ki, qk_row, 0));
        afp v_val = image3d_ld1(v_blob_3d, ivec3(ki, v_row, 0));
#else
        // both sources hold K consecutive elements per row
        const int qk_idx = qk_row * psc(K) + ki;
        const int v_idx = v_row * psc(K) + ki;

        afp qk_val = buffer_ld1(qkcross_blob_data, qk_idx);
        afp v_val = buffer_ld1(v_blob_data, v_idx);
#endif

        acc += qk_val * v_val;
    }

#if NCNN_image_shader
    const ivec3 out_pos = ivec3(row, v_row, 0);
    image3d_st1(qkvcross_blob_3d, out_pos, acc);
#else
    // output rows are M elements long and indexed like the v rows
    const int out_idx = v_row * psc(M) + row;
    buffer_st1(qkvcross_blob_data, out_idx, acc);
#endif
}

+ 81
- 0
src/layer/vulkan/shader/multiheadattention_qkv_cross_pack1to4.comp View File

@@ -0,0 +1,81 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Optional fp16 extensions, enabled by build-time NCNN_* flags.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Problem sizes as specialization constants, all defaulting to 0.
// main() bounds-checks x against N, y against M, z against B and iterates K times.
// NOTE(review): psc() is not defined in this file; presumably it falls back to the
// push-constant value of the same name when the specialization constant is 0 - confirm.
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// Binding 0 is the scalar qk cross input, binding 1 the 4-wide v input,
// binding 2 the write-only 4-wide output.
// unfp / imfmtc4 / sfp / sfpvec4 are type macros supplied outside this file.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d;
layout (binding = 1) uniform unfp sampler3D v_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkvcross_blob_3d;
#else
layout (binding = 0) readonly buffer qkcross_blob { sfp qkcross_blob_data[]; };
layout (binding = 1) readonly buffer v_blob { sfpvec4 v_blob_data[]; };
layout (binding = 2) writeonly buffer qkvcross_blob { sfpvec4 qkvcross_blob_data[]; };
#endif

// Push-constant block carrying the same four sizes as the specialization constants above.
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

// Produces one 4-wide element of the qkv cross output from a scalar qk input and a
// 4-wide v input.
// Invocation x is bounded by N, y by M and z by B. The result is the sum over K of
// qk[z][y][k] * v[z][x][k] (scalar times vec4), stored at (y, z * N + x).
void main()
{
    const int col = int(gl_GlobalInvocationID.x);
    const int row = int(gl_GlobalInvocationID.y);
    const int bat = int(gl_GlobalInvocationID.z);

    // drop invocations that fall outside the N x M x B extent
    const bool inside = col < psc(N) && row < psc(M) && bat < psc(B);
    if (!inside)
        return;

    // row of each source that this invocation walks along K
    const int qk_row = bat * psc(M) + row;
    const int v_row = bat * psc(N) + col;

    afpvec4 acc = afpvec4(0.f);

    for (int ki = 0; ki < psc(K); ki++)
    {
#if NCNN_image_shader
        afp qk_val = image3d_ld1(qkcross_blob_3d, ivec3(ki, qk_row, 0));
        afpvec4 v_val = image3d_ld4(v_blob_3d, ivec3(ki, v_row, 0));
#else
        // both sources hold K consecutive elements per row
        const int qk_idx = qk_row * psc(K) + ki;
        const int v_idx = v_row * psc(K) + ki;

        afp qk_val = buffer_ld1(qkcross_blob_data, qk_idx);
        afpvec4 v_val = buffer_ld4(v_blob_data, v_idx);
#endif

        acc += qk_val * v_val;
    }

#if NCNN_image_shader
    const ivec3 out_pos = ivec3(row, v_row, 0);
    image3d_st4(qkvcross_blob_3d, out_pos, acc);
#else
    // output rows are M elements long and indexed like the v rows
    const int out_idx = v_row * psc(M) + row;
    buffer_st4(qkvcross_blob_data, out_idx, acc);
#endif
}

+ 177
- 0
src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4.comp View File

@@ -0,0 +1,177 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Optional fp16 extensions, enabled by build-time NCNN_* flags.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Number of K steps staged through shared memory per iteration of the tiled loop in main().
#define LOCAL_MEMORY_UNROLL_INCH 8

// Problem sizes as specialization constants, all defaulting to 0.
// main() bounds-checks x against N, y against M, z against B and iterates over K.
// NOTE(review): psc() is not defined in this file; presumably it falls back to the
// push-constant value of the same name when the specialization constant is 0 - confirm.
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// Bindings 0 and 1 are the two read-only 4-wide inputs (qk cross result and v),
// binding 2 is the write-only 4-wide output.
// unfp / imfmtc4 / sfpvec4 are type macros supplied outside this file.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d;
layout (binding = 1) uniform unfp sampler3D v_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D qkvcross_blob_3d;
#else
layout (binding = 0) readonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; };
layout (binding = 1) readonly buffer v_blob { sfpvec4 v_blob_data[]; };
layout (binding = 2) writeonly buffer qkvcross_blob { sfpvec4 qkvcross_blob_data[]; };
#endif

// Push-constant block carrying the same four sizes as the specialization constants above.
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

// Workgroup-shared staging tiles, indexed in main() by the local invocation ids.
// NOTE(review): the first dimension (8) and LOCAL_MEMORY_UNROLL_INCH together assume
// local x/y ids stay below 8; the workgroup size is set outside this file - confirm.
#if NCNN_shader_local_memory
shared lfpvec4 tmp_qk[8][LOCAL_MEMORY_UNROLL_INCH];
shared lfpvec4 tmp_v[8][LOCAL_MEMORY_UNROLL_INCH];
#endif

// Entry point: each invocation accumulates four vec4 sums over K, one per component of
// the 4-wide qk value, each multiplying that component with the same 4-wide v value.
// The four sums are stored as four consecutive vec4 outputs.
void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);

// Without local memory, out-of-range invocations leave immediately.
// With NCNN_shader_local_memory the same check is deferred until after the accumulation.
// NOTE(review): presumably so that every invocation of the workgroup takes part in the
// shared-memory loads and reaches barrier() - confirm.
#if !NCNN_shader_local_memory
if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B))
return;
#endif

afpvec4 sum0 = afpvec4(0.f);
afpvec4 sum1 = afpvec4(0.f);
afpvec4 sum2 = afpvec4(0.f);
afpvec4 sum3 = afpvec4(0.f);

// Shared-memory path: only taken for buffer storage with local memory enabled.
#if !NCNN_image_shader && NCNN_shader_local_memory
const int NN = psc(K);

const int lx = int(gl_LocalInvocationID.x);
const int ly = int(gl_LocalInvocationID.y);

// Each invocation fetches one element per tile: the qk element at K offset lx of its
// own row, and the v element at K offset ly of its own column.
// NOTE(review): invocations with gx >= N or gy >= M also perform these reads because
// the bounds check is deferred; confirm that reading past the valid extent is acceptable.
int ai = gz * psc(M) * psc(K) + gy * psc(K) + lx;
int bi = gz * psc(N) * psc(K) + gx * psc(K) + ly;

// Full tiles of LOCAL_MEMORY_UNROLL_INCH steps along K.
int k = 0;
for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
{
{
tmp_qk[ly][lx] = sfp2lfpvec4(qkcross_blob_data[ai]);
}

{
tmp_v[lx][ly] = sfp2lfpvec4(v_blob_data[bi]);
}

// all staged values must be written before any invocation reads the tiles
barrier();

for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
{
afpvec4 qk0 = lfp2afpvec4(tmp_qk[ly][k4]);

afpvec4 v0 = lfp2afpvec4(tmp_v[lx][k4]);

sum0 += qk0.r * v0;
sum1 += qk0.g * v0;
sum2 += qk0.b * v0;
sum3 += qk0.a * v0;
}

// advance both source offsets to the next tile
ai += LOCAL_MEMORY_UNROLL_INCH;
bi += LOCAL_MEMORY_UNROLL_INCH;

// all reads of this tile must finish before the next iteration overwrites it
barrier();
}

// Tail: fewer than LOCAL_MEMORY_UNROLL_INCH steps left; only the first `remain`
// tile slots are filled and consumed.
if (k < NN)
{
const int remain = NN - k;

if (lx < remain)
{
tmp_qk[ly][lx] = sfp2lfpvec4(qkcross_blob_data[ai]);
}

if (ly < remain)
{
tmp_v[lx][ly] = sfp2lfpvec4(v_blob_data[bi]);
}

barrier();

for (int k4 = 0; k4 < remain; k4++)
{
afpvec4 qk0 = lfp2afpvec4(tmp_qk[ly][k4]);

afpvec4 v0 = lfp2afpvec4(tmp_v[lx][k4]);

sum0 += qk0.r * v0;
sum1 += qk0.g * v0;
sum2 += qk0.b * v0;
sum3 += qk0.a * v0;
}
}
#else
// Direct path: one qk load and one v load per K step, no shared memory involved.
for (int k = 0; k < psc(K); k++)
{
#if NCNN_image_shader
afpvec4 qk0 = image3d_ld4(qkcross_blob_3d, ivec3(k, gz * psc(M) + gy, 0));
afpvec4 v0 = image3d_ld4(v_blob_3d, ivec3(k, gz * psc(N) + gx, 0));
#else
const int ai = gz * psc(M) * psc(K) + gy * psc(K) + k;
afpvec4 qk0 = buffer_ld4(qkcross_blob_data, ai);

const int bi = gz * psc(N) * psc(K) + gx * psc(K) + k;
afpvec4 v0 = buffer_ld4(v_blob_data, bi);
#endif

sum0 += qk0.r * v0;
sum1 += qk0.g * v0;
sum2 += qk0.b * v0;
sum3 += qk0.a * v0;
}
#endif

// Deferred bounds check for the local-memory configuration: nothing is stored for
// invocations outside the N x M x B extent.
#if NCNN_shader_local_memory
if (gx >= psc(N) || gy >= psc(M) || gz >= psc(B))
return;
#endif

// The four sums occupy four consecutive output slots starting at gy * 4.
#if NCNN_image_shader
image3d_st4(qkvcross_blob_3d, ivec3(gy * 4, gz * psc(N) + gx, 0), sum0);
image3d_st4(qkvcross_blob_3d, ivec3(gy * 4 + 1, gz * psc(N) + gx, 0), sum1);
image3d_st4(qkvcross_blob_3d, ivec3(gy * 4 + 2, gz * psc(N) + gx, 0), sum2);
image3d_st4(qkvcross_blob_3d, ivec3(gy * 4 + 3, gz * psc(N) + gx, 0), sum3);
#else
const int gi = (gz * psc(M) * psc(N) + gx * psc(M) + gy) * 4;

buffer_st4(qkvcross_blob_data, gi, sum0);
buffer_st4(qkvcross_blob_data, gi + 1, sum1);
buffer_st4(qkvcross_blob_data, gi + 2, sum2);
buffer_st4(qkvcross_blob_data, gi + 3, sum3);
#endif
}

+ 87
- 0
src/layer/vulkan/shader/multiheadattention_qkv_cross_pack4to1.comp View File

@@ -0,0 +1,87 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Optional fp16 extensions, enabled by build-time NCNN_* flags.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Problem sizes as specialization constants, all defaulting to 0.
// main() bounds-checks x against N, y against M, z against B and iterates K times.
// NOTE(review): psc() is not defined in this file; presumably it falls back to the
// push-constant value of the same name when the specialization constant is 0 - confirm.
layout (constant_id = 0) const int M = 0;
layout (constant_id = 1) const int N = 0;
layout (constant_id = 2) const int K = 0;
layout (constant_id = 3) const int B = 0;

// Binding 0 is the 4-wide qk cross input, binding 1 the scalar v input,
// binding 2 the write-only scalar output.
// unfp / imfmtc1 / sfp / sfpvec4 are type macros supplied outside this file.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D qkcross_blob_3d;
layout (binding = 1) uniform unfp sampler3D v_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D qkvcross_blob_3d;
#else
layout (binding = 0) readonly buffer qkcross_blob { sfpvec4 qkcross_blob_data[]; };
layout (binding = 1) readonly buffer v_blob { sfp v_blob_data[]; };
layout (binding = 2) writeonly buffer qkvcross_blob { sfp qkvcross_blob_data[]; };
#endif

// Push-constant block carrying the same four sizes as the specialization constants above.
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int B;
} p;

// Produces four scalar elements of the qkv cross output from a 4-wide qk input and a
// scalar v input.
// Invocation x is bounded by N, y by M and z by B. The 4-wide sum over K of
// qk[z][y][k] * v[z][x][k] is split into its components, which are written to the four
// consecutive output slots starting at y * 4 in row z * N + x.
void main()
{
    const int col = int(gl_GlobalInvocationID.x);
    const int row = int(gl_GlobalInvocationID.y);
    const int bat = int(gl_GlobalInvocationID.z);

    // drop invocations that fall outside the N x M x B extent
    const bool inside = col < psc(N) && row < psc(M) && bat < psc(B);
    if (!inside)
        return;

    // row of each source that this invocation walks along K
    const int qk_row = bat * psc(M) + row;
    const int v_row = bat * psc(N) + col;

    afpvec4 acc = afpvec4(0.f);

    for (int ki = 0; ki < psc(K); ki++)
    {
#if NCNN_image_shader
        afpvec4 qk_val = image3d_ld4(qkcross_blob_3d, ivec3(ki, qk_row, 0));
        afp v_val = image3d_ld1(v_blob_3d, ivec3(ki, v_row, 0));
#else
        // both sources hold K consecutive elements per row
        const int qk_idx = qk_row * psc(K) + ki;
        const int v_idx = v_row * psc(K) + ki;

        afpvec4 qk_val = buffer_ld4(qkcross_blob_data, qk_idx);
        afp v_val = buffer_ld1(v_blob_data, v_idx);
#endif

        acc += qk_val * v_val;
    }

#if NCNN_image_shader
    // four neighbouring texels along x, one per component of the sum
    const int out_x = row * 4;

    image3d_st1(qkvcross_blob_3d, ivec3(out_x, v_row, 0), acc.r);
    image3d_st1(qkvcross_blob_3d, ivec3(out_x + 1, v_row, 0), acc.g);
    image3d_st1(qkvcross_blob_3d, ivec3(out_x + 2, v_row, 0), acc.b);
    image3d_st1(qkvcross_blob_3d, ivec3(out_x + 3, v_row, 0), acc.a);
#else
    // four neighbouring scalars, one per component of the sum
    const int out_idx = (v_row * psc(M) + row) * 4;

    buffer_st1(qkvcross_blob_data, out_idx, acc.r);
    buffer_st1(qkvcross_blob_data, out_idx + 1, acc.g);
    buffer_st1(qkvcross_blob_data, out_idx + 2, acc.b);
    buffer_st1(qkvcross_blob_data, out_idx + 3, acc.a);
#endif
}

+ 6
- 6
src/layer/x86/multiheadattention_x86.cpp View File

@@ -39,7 +39,7 @@ MultiHeadAttention_x86::MultiHeadAttention_x86()
int MultiHeadAttention_x86::create_pipeline(const Option& opt)
{
{
const int embed_dim_per_head = embed_dim / num_head;
const int embed_dim_per_head = embed_dim / num_heads;
const float inv_sqrt_embed_dim_per_head = 1.f / sqrt(embed_dim_per_head);

q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm);
@@ -271,7 +271,7 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
const Mat& k_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs[1];
const Mat& v_blob = bottom_blobs.size() == 1 ? q_blob : bottom_blobs.size() == 2 ? k_blob : bottom_blobs[2];

const int embed_dim_per_head = embed_dim / num_head;
const int embed_dim_per_head = embed_dim / num_heads;
const int src_seqlen = q_blob.h * q_blob.elempack;
const int dst_seqlen = k_blob.h * k_blob.elempack;

@@ -281,9 +281,9 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
Mat k_affine;
k_gemm->forward(k_blob, k_affine, opt);

Mat qk_cross(dst_seqlen, src_seqlen * num_head, 4u, opt.blob_allocator);
Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 4u, opt.blob_allocator);
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < num_head; i++)
for (int i = 0; i < num_heads; i++)
{
std::vector<Mat> qk_bottom_blobs(2);
qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
@@ -303,9 +303,9 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
Mat v_affine;
v_gemm->forward(v_blob, v_affine, opt);

Mat qkv_cross(src_seqlen, embed_dim_per_head * num_head, 4u, opt.blob_allocator);
Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt.blob_allocator);
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < num_head; i++)
for (int i = 0; i < num_heads; i++)
{
std::vector<Mat> qkv_bottom_blobs(2);
qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);


+ 7
- 5
tests/test_multiheadattention.cpp View File

@@ -46,7 +46,7 @@ static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const
int ret = test_layer<ncnn::MultiHeadAttention>("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d)\n", q.w, q.h, k.w, k.h, v.w, v.h);
fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) num_heads=%d kdim=%d vdim=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, num_heads, kdim, vdim);
}

return ret;
@@ -82,7 +82,7 @@ static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& k
int ret = test_layer<ncnn::MultiHeadAttention>("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d)\n", q.w, q.h, kv.w, kv.h);
fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, num_heads, kvdim);
}

return ret;
@@ -115,7 +115,7 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
int ret = test_layer<ncnn::MultiHeadAttention>("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d)\n", a.w, a.h);
fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) num_heads=%d\n", a.w, a.h, num_heads);
}

return ret;
@@ -124,6 +124,8 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
static int test_multiheadattention_0()
{
return 0
|| test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 2, 32, 20)
|| test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 2, 32, 18)
|| test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 4, 64, 64)
|| test_multiheadattention(RandomMat(64, 127), RandomMat(64, 127), RandomMat(64, 127), 16, 64, 64)
|| test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 2, 44, 55)
@@ -146,8 +148,8 @@ static int test_multiheadattention_1()
static int test_multiheadattention_2()
{
return 0
|| test_multiheadattention_sameqkv(RandomMat(64, 128), 8)
|| test_multiheadattention_sameqkv(RandomMat(64, 127), 32);
|| test_multiheadattention_sameqkv(RandomMat(64, 128), 4)
|| test_multiheadattention_sameqkv(RandomMat(64, 127), 8);
}

int main()


+ 1
- 1
tools/modelwriter.h View File

@@ -1922,7 +1922,7 @@ int ModelWriter::save(const char* parampath, const char* binpath)
ncnn::MultiHeadAttention* op_default = (ncnn::MultiHeadAttention*)layer_default;

fprintf_param_value(" 0=%d", embed_dim)
fprintf_param_value(" 1=%d", num_head)
fprintf_param_value(" 1=%d", num_heads)
fprintf_param_value(" 2=%d", weight_data_size)
fprintf_param_value(" 3=%d", kdim)
fprintf_param_value(" 4=%d", vdim)


Loading…
Cancel
Save