You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gemm_1.cpp 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. // Copyright 2023 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. static int test_gemm(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose)
  5. {
  6. ncnn::ParamDict pd;
  7. pd.set(0, alpha);
  8. pd.set(1, 1.f); // beta
  9. pd.set(2, transA);
  10. pd.set(3, transB);
  11. pd.set(14, output_transpose);
  12. pd.set(20, TILE_M);
  13. pd.set(21, TILE_N);
  14. pd.set(22, TILE_K);
  15. std::vector<ncnn::Mat> weights(0);
  16. std::vector<ncnn::Mat> a(2);
  17. a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  18. a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  19. Randomize(a[0]);
  20. Randomize(a[1]);
  21. int ret = test_layer("Gemm", pd, weights, a);
  22. if (ret != 0)
  23. {
  24. fprintf(stderr, "test_gemm failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose);
  25. }
  26. return ret;
  27. }
  28. static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K)
  29. {
  30. return 0
  31. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0)
  32. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0)
  33. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0)
  34. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0)
  35. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1)
  36. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1)
  37. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1)
  38. || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1);
  39. }
  40. int main()
  41. {
  42. SRAND(7767517);
  43. int mnk[][3] = {
  44. {1, 1, 1},
  45. {2, 2, 2},
  46. {3, 3, 3},
  47. {4, 4, 4},
  48. {5, 5, 5},
  49. {6, 6, 6},
  50. {7, 7, 7},
  51. {8, 8, 8},
  52. {15, 15, 15},
  53. {16, 16, 16},
  54. {24, 24, 24},
  55. {31, 31, 31},
  56. {31, 32, 31},
  57. {32, 31, 32},
  58. {32, 32, 32},
  59. {20, 32, 20},
  60. {40, 40, 40},
  61. {47, 47, 47},
  62. {48, 48, 48},
  63. {52, 52, 52},
  64. {63, 64, 63},
  65. {64, 63, 64},
  66. {64, 64, 64}
  67. };
  68. int tile_mnk[][3] = {
  69. {1, 1, 1},
  70. {2, 2, 2},
  71. {4, 4, 4},
  72. {8, 8, 8},
  73. {12, 12, 12},
  74. {16, 16, 16},
  75. {20, 20, 20},
  76. {24, 24, 24},
  77. {28, 28, 28}
  78. };
  79. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  80. int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3;
  81. for (int i = 0; i < mnk_count; i++)
  82. {
  83. int M = mnk[i][0];
  84. int N = mnk[i][1];
  85. int K = mnk[i][2];
  86. for (int j = 0; j < tile_mnk_count; j++)
  87. {
  88. int TILE_M = tile_mnk[j][0];
  89. int TILE_N = tile_mnk[j][1];
  90. int TILE_K = tile_mnk[j][2];
  91. if (TILE_M >= M && TILE_N >= N && TILE_K >= K)
  92. continue;
  93. int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K);
  94. if (ret != 0)
  95. return ret;
  96. }
  97. // test no tiling
  98. int ret = test_gemm_0(M, N, K, 100, 100, 100);
  99. if (ret != 0)
  100. return ret;
  101. }
  102. return 0;
  103. }