You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_multiheadattention_1.cpp 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. // Copyright 2024 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. #if NCNN_INT8
  5. static int test_multiheadattention_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
  6. {
  7. const int qdim = q.w;
  8. const int kdim = k.w;
  9. const int vdim = v.w;
  10. ncnn::ParamDict pd;
  11. pd.set(0, embed_dim);
  12. pd.set(1, num_heads);
  13. pd.set(2, embed_dim * qdim);
  14. pd.set(3, kdim);
  15. pd.set(4, vdim);
  16. pd.set(5, attn_mask);
  17. pd.set(6, 1.f / sqrtf(embed_dim / num_heads));
  18. pd.set(18, 2); // int8_scale_term
  19. std::vector<ncnn::Mat> weights(12);
  20. weights[0] = RandomS8Mat(embed_dim * qdim);
  21. weights[1] = RandomMat(embed_dim);
  22. weights[2] = RandomS8Mat(embed_dim * kdim);
  23. weights[3] = RandomMat(embed_dim);
  24. weights[4] = RandomS8Mat(embed_dim * vdim);
  25. weights[5] = RandomMat(embed_dim);
  26. weights[6] = RandomS8Mat(qdim * embed_dim);
  27. weights[7] = RandomMat(qdim);
  28. weights[8] = RandomMat(embed_dim, 160.f, 200.f);
  29. weights[9] = RandomMat(embed_dim, 160.f, 200.f);
  30. weights[10] = RandomMat(embed_dim, 160.f, 200.f);
  31. weights[11] = RandomMat(1, 160.f, 200.f);
  32. std::vector<ncnn::Mat> as(3);
  33. as[0] = q;
  34. as[1] = k;
  35. as[2] = v;
  36. if (attn_mask)
  37. {
  38. as.push_back(RandomMat(k.h, q.h));
  39. }
  40. float epsilon = 0.1;
  41. int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
  42. if (ret != 0)
  43. {
  44. fprintf(stderr, "test_multiheadattention_int8 failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
  45. }
  46. return ret;
  47. }
  48. static int test_multiheadattention_int8_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads)
  49. {
  50. const int qdim = q.w;
  51. const int kvdim = kv.w;
  52. ncnn::ParamDict pd;
  53. pd.set(0, embed_dim);
  54. pd.set(1, num_heads);
  55. pd.set(2, embed_dim * qdim);
  56. pd.set(3, kvdim);
  57. pd.set(4, kvdim);
  58. pd.set(6, 1.f / sqrtf(embed_dim / num_heads));
  59. pd.set(18, 2); // int8_scale_term
  60. std::vector<ncnn::Mat> weights(12);
  61. weights[0] = RandomS8Mat(embed_dim * qdim);
  62. weights[1] = RandomMat(embed_dim);
  63. weights[2] = RandomS8Mat(embed_dim * kvdim);
  64. weights[3] = RandomMat(embed_dim);
  65. weights[4] = RandomS8Mat(embed_dim * kvdim);
  66. weights[5] = RandomMat(embed_dim);
  67. weights[6] = RandomS8Mat(qdim * embed_dim);
  68. weights[7] = RandomMat(qdim);
  69. weights[8] = RandomMat(embed_dim, 160.f, 200.f);
  70. weights[9] = RandomMat(embed_dim, 160.f, 200.f);
  71. weights[10] = RandomMat(embed_dim, 160.f, 200.f);
  72. weights[11] = RandomMat(1, 160.f, 200.f);
  73. std::vector<ncnn::Mat> as(2);
  74. as[0] = q;
  75. as[1] = kv;
  76. float epsilon = 0.1;
  77. int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
  78. if (ret != 0)
  79. {
  80. fprintf(stderr, "test_multiheadattention_int8_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim);
  81. }
  82. return ret;
  83. }
  84. static int test_multiheadattention_int8_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads)
  85. {
  86. const int qdim = a.w;
  87. ncnn::ParamDict pd;
  88. pd.set(0, embed_dim);
  89. pd.set(1, num_heads);
  90. pd.set(2, embed_dim * qdim);
  91. pd.set(3, qdim);
  92. pd.set(4, qdim);
  93. pd.set(6, 1.f / sqrtf(embed_dim / num_heads));
  94. pd.set(18, 2); // int8_scale_term
  95. std::vector<ncnn::Mat> weights(12);
  96. weights[0] = RandomS8Mat(embed_dim * qdim);
  97. weights[1] = RandomMat(embed_dim);
  98. weights[2] = RandomS8Mat(embed_dim * qdim);
  99. weights[3] = RandomMat(embed_dim);
  100. weights[4] = RandomS8Mat(embed_dim * qdim);
  101. weights[5] = RandomMat(embed_dim);
  102. weights[6] = RandomS8Mat(qdim * embed_dim);
  103. weights[7] = RandomMat(qdim);
  104. weights[8] = RandomMat(embed_dim, 160.f, 200.f);
  105. weights[9] = RandomMat(embed_dim, 160.f, 200.f);
  106. weights[10] = RandomMat(embed_dim, 160.f, 200.f);
  107. weights[11] = RandomMat(1, 160.f, 200.f);
  108. std::vector<ncnn::Mat> as(1);
  109. as[0] = a;
  110. float epsilon = 0.1;
  111. int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
  112. if (ret != 0)
  113. {
  114. fprintf(stderr, "test_multiheadattention_int8_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads);
  115. }
  116. return ret;
  117. }
  118. static int test_multiheadattention_0()
  119. {
  120. return 0
  121. || test_multiheadattention_int8(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
  122. || test_multiheadattention_int8(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
  123. || test_multiheadattention_int8(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0)
  124. || test_multiheadattention_int8(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1)
  125. || test_multiheadattention_int8(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0)
  126. || test_multiheadattention_int8(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1)
  127. || test_multiheadattention_int8(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
  128. || test_multiheadattention_int8(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
  129. }
  130. static int test_multiheadattention_1()
  131. {
  132. return 0
  133. || test_multiheadattention_int8_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4)
  134. || test_multiheadattention_int8_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16)
  135. || test_multiheadattention_int8_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2)
  136. || test_multiheadattention_int8_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4)
  137. || test_multiheadattention_int8_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3)
  138. || test_multiheadattention_int8_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3);
  139. }
  140. static int test_multiheadattention_2()
  141. {
  142. return 0
  143. || test_multiheadattention_int8_sameqkv(RandomMat(64, 128), 64, 4)
  144. || test_multiheadattention_int8_sameqkv(RandomMat(48, 127), 64, 8);
  145. }
  146. #endif
  147. int main()
  148. {
  149. SRAND(7767517);
  150. #if NCNN_INT8
  151. return 0
  152. || test_multiheadattention_0()
  153. || test_multiheadattention_1()
  154. || test_multiheadattention_2();
  155. #else
  156. // test nothing
  157. return 0;
  158. #endif
  159. }