You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_multiheadattention.cpp 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. // Copyright 2021 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
  5. {
  6. const int qdim = q.w;
  7. const int kdim = k.w;
  8. const int vdim = v.w;
  9. ncnn::ParamDict pd;
  10. pd.set(0, embed_dim);
  11. pd.set(1, num_heads);
  12. pd.set(2, embed_dim * qdim);
  13. pd.set(3, kdim);
  14. pd.set(4, vdim);
  15. pd.set(5, attn_mask);
  16. std::vector<ncnn::Mat> weights(8);
  17. weights[0] = RandomMat(embed_dim * qdim);
  18. weights[1] = RandomMat(embed_dim);
  19. weights[2] = RandomMat(embed_dim * kdim);
  20. weights[3] = RandomMat(embed_dim);
  21. weights[4] = RandomMat(embed_dim * vdim);
  22. weights[5] = RandomMat(embed_dim);
  23. weights[6] = RandomMat(qdim * embed_dim);
  24. weights[7] = RandomMat(qdim);
  25. std::vector<ncnn::Mat> as(3);
  26. as[0] = q;
  27. as[1] = k;
  28. as[2] = v;
  29. if (attn_mask)
  30. {
  31. as.push_back(RandomMat(k.h, q.h));
  32. }
  33. float epsilon = 0.005;
  34. int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
  35. if (ret != 0)
  36. {
  37. fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
  38. }
  39. return ret;
  40. }
  41. static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads)
  42. {
  43. const int qdim = q.w;
  44. const int kvdim = kv.w;
  45. ncnn::ParamDict pd;
  46. pd.set(0, embed_dim);
  47. pd.set(1, num_heads);
  48. pd.set(2, embed_dim * qdim);
  49. pd.set(3, kvdim);
  50. pd.set(4, kvdim);
  51. std::vector<ncnn::Mat> weights(8);
  52. weights[0] = RandomMat(embed_dim * qdim);
  53. weights[1] = RandomMat(embed_dim);
  54. weights[2] = RandomMat(embed_dim * kvdim);
  55. weights[3] = RandomMat(embed_dim);
  56. weights[4] = RandomMat(embed_dim * kvdim);
  57. weights[5] = RandomMat(embed_dim);
  58. weights[6] = RandomMat(qdim * embed_dim);
  59. weights[7] = RandomMat(qdim);
  60. std::vector<ncnn::Mat> as(2);
  61. as[0] = q;
  62. as[1] = kv;
  63. float epsilon = 0.005;
  64. int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
  65. if (ret != 0)
  66. {
  67. fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim);
  68. }
  69. return ret;
  70. }
  71. static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads)
  72. {
  73. const int qdim = a.w;
  74. ncnn::ParamDict pd;
  75. pd.set(0, embed_dim);
  76. pd.set(1, num_heads);
  77. pd.set(2, embed_dim * qdim);
  78. pd.set(3, qdim);
  79. pd.set(4, qdim);
  80. pd.set(6, 0.7f / sqrtf(embed_dim / num_heads));
  81. std::vector<ncnn::Mat> weights(8);
  82. weights[0] = RandomMat(embed_dim * qdim);
  83. weights[1] = RandomMat(embed_dim);
  84. weights[2] = RandomMat(embed_dim * qdim);
  85. weights[3] = RandomMat(embed_dim);
  86. weights[4] = RandomMat(embed_dim * qdim);
  87. weights[5] = RandomMat(embed_dim);
  88. weights[6] = RandomMat(qdim * embed_dim);
  89. weights[7] = RandomMat(qdim);
  90. std::vector<ncnn::Mat> as(1);
  91. as[0] = a;
  92. float epsilon = 0.005;
  93. int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
  94. if (ret != 0)
  95. {
  96. fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads);
  97. }
  98. return ret;
  99. }
  100. static int test_multiheadattention_0()
  101. {
  102. return 0
  103. || test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
  104. || test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
  105. || test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0)
  106. || test_multiheadattention(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1)
  107. || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0)
  108. || test_multiheadattention(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1)
  109. || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
  110. || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
  111. }
  112. static int test_multiheadattention_1()
  113. {
  114. return 0
  115. || test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4)
  116. || test_multiheadattention_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16)
  117. || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2)
  118. || test_multiheadattention_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4)
  119. || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3)
  120. || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3);
  121. }
  122. static int test_multiheadattention_2()
  123. {
  124. return 0
  125. || test_multiheadattention_sameqkv(RandomMat(64, 128), 64, 4)
  126. || test_multiheadattention_sameqkv(RandomMat(48, 127), 64, 8);
  127. }
  128. int main()
  129. {
  130. SRAND(7767517);
  131. return 0
  132. || test_multiheadattention_0()
  133. || test_multiheadattention_1()
  134. || test_multiheadattention_2();
  135. }