// Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // // Unless required by applicable law or agreed to in writing, software distributed // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. #include "testutil.h" static int test_multiheadattention_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask) { const int qdim = q.w; const int kdim = k.w; const int vdim = v.w; ncnn::ParamDict pd; pd.set(0, embed_dim); pd.set(1, num_heads); pd.set(2, embed_dim * qdim); pd.set(3, kdim); pd.set(4, vdim); pd.set(5, attn_mask); std::vector weights(8); weights[0] = RandomMat(embed_dim * qdim); weights[1] = RandomMat(embed_dim); weights[2] = RandomMat(embed_dim * kdim); weights[3] = RandomMat(embed_dim); weights[4] = RandomMat(embed_dim * vdim); weights[5] = RandomMat(embed_dim); weights[6] = RandomMat(qdim * embed_dim); weights[7] = RandomMat(qdim); std::vector as(3); as[0] = q; as[1] = k; as[2] = v; if (attn_mask) { as.push_back(RandomMat(k.h, q.h)); } int ret = test_layer_oom("MultiHeadAttention", pd, weights, as, 1); if (ret != 0) { fprintf(stderr, "test_multiheadattention_oom failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask); } return ret; } static int test_multiheadattention_0() { return 0 || test_multiheadattention_oom(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0) || test_multiheadattention_oom(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1) || test_multiheadattention_oom(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0) || test_multiheadattention_oom(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1); } int main() { SRAND(7767517); return test_multiheadattention_0(); }