You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_embed.cpp 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. // Copyright 2024 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. static int test_embed(int words, int num_output, int input_dim, int bias)
  5. {
  6. ncnn::ParamDict pd;
  7. pd.set(0, num_output);
  8. pd.set(1, input_dim);
  9. pd.set(2, bias);
  10. pd.set(3, num_output * input_dim);
  11. std::vector<ncnn::Mat> weights(bias ? 2 : 1);
  12. weights[0] = RandomMat(num_output * input_dim);
  13. if (bias)
  14. weights[1] = RandomMat(num_output);
  15. ncnn::Mat a(words);
  16. RandomizeInt(a, 0, input_dim);
  17. int ret = test_layer("Embed", pd, weights, a);
  18. if (ret != 0)
  19. {
  20. fprintf(stderr, "test_embed failed words=%d num_output=%d input_dim=%d bias=%d\n", words, num_output, input_dim, bias);
  21. }
  22. return ret;
  23. }
  24. static int test_embed_0()
  25. {
  26. return 0
  27. || test_embed(128, 128, 128, 0)
  28. || test_embed(128, 128, 128, 1)
  29. || test_embed(127, 127, 127, 0)
  30. || test_embed(127, 127, 127, 1)
  31. || test_embed(124, 124, 124, 0)
  32. || test_embed(124, 124, 124, 1);
  33. }
  34. #if NCNN_INT8
  35. static int test_embed_int8(int words, int num_output, int input_dim, int bias)
  36. {
  37. ncnn::ParamDict pd;
  38. pd.set(0, num_output);
  39. pd.set(1, input_dim);
  40. pd.set(2, bias);
  41. pd.set(3, num_output * input_dim);
  42. pd.set(18, 2);
  43. std::vector<ncnn::Mat> weights(bias ? 3 : 2);
  44. weights[0] = RandomS8Mat(num_output * input_dim);
  45. if (bias)
  46. {
  47. weights[1] = RandomMat(num_output);
  48. weights[2] = RandomMat(1, 100.f, 200.f);
  49. }
  50. else
  51. {
  52. weights[1] = RandomMat(1, 100.f, 200.f);
  53. }
  54. ncnn::Mat a(words);
  55. RandomizeInt(a, 0, input_dim);
  56. int ret = test_layer("Embed", pd, weights, a);
  57. if (ret != 0)
  58. {
  59. fprintf(stderr, "test_embed_int8 failed words=%d num_output=%d input_dim=%d bias=%d\n", words, num_output, input_dim, bias);
  60. }
  61. return ret;
  62. }
  63. static int test_embed_1()
  64. {
  65. return 0
  66. || test_embed_int8(128, 128, 128, 0)
  67. || test_embed_int8(128, 128, 128, 1)
  68. || test_embed_int8(127, 127, 127, 0)
  69. || test_embed_int8(127, 127, 127, 1)
  70. || test_embed_int8(124, 124, 124, 0)
  71. || test_embed_int8(124, 124, 124, 1);
  72. }
  73. #endif // NCNN_INT8
  74. int main()
  75. {
  76. SRAND(7767517);
  77. #if NCNN_INT8
  78. return test_embed_0() || test_embed_1();
  79. #else
  80. return test_embed_0();
  81. #endif
  82. }