You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gemm_4.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "testutil.h"
  15. #if NCNN_INT8
  16. static int test_gemm_int8(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose)
  17. {
  18. ncnn::ParamDict pd;
  19. pd.set(0, alpha);
  20. pd.set(1, 1.f); // beta
  21. pd.set(2, transA);
  22. pd.set(3, transB);
  23. pd.set(14, output_transpose);
  24. pd.set(18, 2); // int8_scale_term
  25. pd.set(20, TILE_M);
  26. pd.set(21, TILE_N);
  27. pd.set(22, TILE_K);
  28. std::vector<ncnn::Mat> weights(0);
  29. std::vector<ncnn::Mat> a(2);
  30. a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  31. a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  32. Randomize(a[0], -10.f, 10.f);
  33. Randomize(a[1], -10.f, 10.f);
  34. int ret = test_layer("Gemm", pd, weights, a);
  35. if (ret != 0)
  36. {
  37. fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose);
  38. }
  39. return ret;
  40. }
  41. static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K)
  42. {
  43. return 0
  44. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0)
  45. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0)
  46. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0)
  47. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0)
  48. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1)
  49. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1)
  50. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1)
  51. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1);
  52. }
  53. #endif // NCNN_INT8
  54. int main()
  55. {
  56. SRAND(7767517);
  57. #if NCNN_INT8
  58. int mnk[][3] = {
  59. {1, 1, 1},
  60. {2, 2, 2},
  61. {3, 3, 3},
  62. {4, 4, 4},
  63. {5, 5, 5},
  64. {6, 6, 6},
  65. {7, 7, 7},
  66. {8, 8, 8},
  67. {15, 15, 15},
  68. {16, 16, 16},
  69. {24, 24, 24},
  70. {31, 31, 31},
  71. {31, 32, 31},
  72. {32, 31, 32},
  73. {32, 32, 32},
  74. {20, 32, 20},
  75. {40, 40, 40},
  76. {47, 47, 47},
  77. {48, 48, 48},
  78. {52, 52, 52},
  79. {63, 64, 63},
  80. {64, 63, 64},
  81. {64, 64, 64}
  82. };
  83. int tile_mnk[][3] = {
  84. {1, 1, 1},
  85. {2, 2, 2},
  86. {4, 4, 4},
  87. {8, 8, 8},
  88. {12, 12, 12},
  89. {16, 16, 16},
  90. {20, 20, 20},
  91. {24, 24, 24},
  92. {28, 28, 28}
  93. };
  94. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  95. int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3;
  96. for (int i = 0; i < mnk_count; i++)
  97. {
  98. int M = mnk[i][0];
  99. int N = mnk[i][1];
  100. int K = mnk[i][2];
  101. for (int j = 0; j < tile_mnk_count; j++)
  102. {
  103. int TILE_M = tile_mnk[j][0];
  104. int TILE_N = tile_mnk[j][1];
  105. int TILE_K = tile_mnk[j][2];
  106. if (TILE_M >= M && TILE_N >= N && TILE_K >= K)
  107. continue;
  108. int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K);
  109. if (ret != 0)
  110. return ret;
  111. }
  112. // test no tiling
  113. int ret = test_gemm_0(M, N, K, 100, 100, 100);
  114. if (ret != 0)
  115. return ret;
  116. }
  117. #else
  118. // test nothing for non-int8 build
  119. #endif
  120. return 0;
  121. }