You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

conv_bias_multi_thread.cpp 50 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209
/**
 * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
  12. #include "megdnn/dtype.h"
  13. #include "test/arm_common/fixture.h"
  14. #include "test/common/benchmarker.h"
  15. #include "test/common/conv_bias.h"
  16. #include "test/arm_common/cpuinfo_help.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. using namespace conv_bias;
  20. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  21. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  22. bool no_nonlinemode) {
  23. using namespace conv_bias;
  24. using Param = param::ConvBias;
  25. using NLMode = param::ConvBias::NonlineMode;
  26. std::vector<TestArg> args;
  27. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
  28. size_t kernel, size_t stride, NLMode nlmode) {
  29. Param param;
  30. param.stride_h = stride;
  31. param.stride_w = stride;
  32. if (!no_pad) {
  33. param.pad_h = kernel / 2;
  34. param.pad_w = kernel / 2;
  35. } else {
  36. param.pad_h = 0;
  37. param.pad_w = 0;
  38. }
  39. param.nonlineMode = nlmode;
  40. args.emplace_back(param, TensorShape{n, ic, h, w},
  41. TensorShape{oc, ic, kernel, kernel}, TensorShape{});
  42. if (!no_bias) {
  43. args.emplace_back(param, TensorShape{n, ic, h, w},
  44. TensorShape{oc, ic, kernel, kernel},
  45. TensorShape{1, oc, 1, 1});
  46. }
  47. };
  48. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  49. if (!no_nonlinemode) {
  50. nonlinemode.emplace_back(NLMode::RELU);
  51. nonlinemode.emplace_back(NLMode::H_SWISH);
  52. }
  53. for (size_t n : {1, 2}) {
  54. for (auto nlmode : nonlinemode) {
  55. for (size_t ic : {1, 3, 7}) {
  56. for (size_t oc : {1, 3, 7}) {
  57. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  58. for (size_t kern : kernel) {
  59. pack(n, oc, ic, size, size, kern, stride, nlmode);
  60. }
  61. }
  62. }
  63. }
  64. }
  65. }
  66. return args;
  67. }
  68. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  69. std::vector<size_t> kernel, size_t stride, bool no_bias,
  70. bool no_nonlinemode, bool no_full_bias) {
  71. using namespace conv_bias;
  72. using Param = param::ConvBias;
  73. using NLMode = param::ConvBias::NonlineMode;
  74. std::vector<TestArg> args;
  75. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  76. size_t stride, NLMode nlmode, bool pad) {
  77. Param param;
  78. param.stride_h = stride;
  79. param.stride_w = stride;
  80. if (pad) {
  81. param.pad_h = kernel / 2;
  82. param.pad_w = kernel / 2;
  83. } else {
  84. param.pad_h = 0;
  85. param.pad_w = 0;
  86. }
  87. param.nonlineMode = nlmode;
  88. param.format = param::ConvBias::Format::NCHW44;
  89. param.sparse = param::ConvBias::Sparse::GROUP;
  90. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  91. TensorShape{group, 1, 1, kernel, kernel, 4},
  92. TensorShape{});
  93. if (!no_bias) {
  94. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  95. TensorShape{group, 1, 1, kernel, kernel, 4},
  96. TensorShape{1, group, 1, 1, 4});
  97. }
  98. if (!no_full_bias) {
  99. args.emplace_back(
  100. param, TensorShape{n, group, h, w, 4},
  101. TensorShape{group, 1, 1, kernel, kernel, 4},
  102. TensorShape{n, group,
  103. (h + 2 * param.pad_w - kernel) / stride + 1,
  104. (w + 2 * param.pad_w - kernel) / stride + 1,
  105. 4});
  106. }
  107. };
  108. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  109. if (!no_nonlinemode) {
  110. nonlinemode.emplace_back(NLMode::RELU);
  111. nonlinemode.emplace_back(NLMode::H_SWISH);
  112. }
  113. for (size_t n : {1, 2}) {
  114. for (auto nlmode : nonlinemode) {
  115. for (bool pad : {true}) {
  116. for (size_t group : {1, 2, 4, 7, 128}) {
  117. for (size_t size : {4, 6, 7, 9, 15, 40}) {
  118. for (size_t kern : kernel) {
  119. pack(n, group, size, size, kern, stride, nlmode,
  120. pad);
  121. }
  122. }
  123. }
  124. }
  125. for (bool pad : {false}) {
  126. for (size_t group : {1, 2, 7, 128}) {
  127. for (size_t size : {7, 9, 15, 40}) {
  128. for (size_t kern : kernel) {
  129. pack(n, group, size, size, kern, stride, nlmode,
  130. pad);
  131. }
  132. }
  133. }
  134. }
  135. }
  136. }
  137. return args;
  138. }
  139. void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
  140. Handle* handle, const char* algo_name) {
  141. Checker<ConvBias> checker(handle);
  142. checker.set_before_exec_callback(
  143. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  144. #if MEGDNN_ARMV7
  145. checker.set_epsilon(1);
  146. #endif
  147. UniformIntRNG rng{-50, 50};
  148. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  149. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  150. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  151. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  152. .set_rng(0, &rng)
  153. .set_rng(1, &rng)
  154. .set_rng(2, &rng);
  155. for (auto&& arg : args) {
  156. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  157. }
  158. }
  159. void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
  160. Handle* handle, const char* algo_name) {
  161. Checker<ConvBias> checker(handle);
  162. UniformIntRNG rng{-50, 50};
  163. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  164. .set_dtype(1, dtype::QuantizedS8(2.5f))
  165. .set_dtype(2, dtype::QuantizedS32(6.25f))
  166. .set_dtype(4, {});
  167. checker.set_before_exec_callback(
  168. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  169. for (auto&& arg : args) {
  170. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  171. }
  172. }
  173. void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
  174. Handle* handle, const char* algo_name) {
  175. Checker<ConvBias> checker(handle);
  176. checker.set_before_exec_callback(
  177. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  178. UniformIntRNG rng(0, 255);
  179. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  180. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  181. .set_dtype(2, dtype::QuantizedS32(0.04f))
  182. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  183. .set_rng(0, &rng)
  184. .set_rng(1, &rng)
  185. .set_rng(2, &rng);
  186. for (auto&& arg : args) {
  187. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  188. }
  189. }
  190. void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
  191. Handle* handle, const char* algo_name) {
  192. Checker<ConvBias> checker(handle);
  193. checker.set_before_exec_callback(
  194. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  195. NormalRNG rng(128.f);
  196. checker.set_rng(0, &rng).set_rng(1, &rng);
  197. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  198. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  199. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  200. .set_dtype(4, {});
  201. for (auto&& arg : args) {
  202. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  203. }
  204. }
  205. void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
  206. Handle* handle, const char* algo_name) {
  207. Checker<ConvBias> checker(handle);
  208. checker.set_before_exec_callback(
  209. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  210. checker.set_dtype(0, dtype::Int8());
  211. checker.set_dtype(1, dtype::Int8());
  212. checker.set_dtype(2, dtype::Int32());
  213. checker.set_dtype(4, dtype::Int32());
  214. for (auto&& arg : args) {
  215. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  216. }
  217. }
  218. /**********************************F32 direct************************/
  219. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) {
  220. check_conv_bias(
  221. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  222. handle(), "F32DIRECT");
  223. }
  224. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
  225. //! k=7 s=1
  226. check_conv_bias(get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE,
  227. BR_AND_NO_BIASMODE, 1),
  228. handle(), "F32_CONV_NCHW44_DIRECT");
  229. }
  230. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
  231. check_conv_bias(
  232. get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
  233. handle(), "F32_CONV_NCHW44_DIRECT");
  234. }
  235. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
  236. check_conv_bias(
  237. get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
  238. handle(), "F32_CONV_NCHW44_DIRECT");
  239. }
  240. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
  241. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE,
  242. ONLY_BR_BIASMODE, 2),
  243. handle(), "F32_CONV_NCHW44_DIRECT");
  244. }
  245. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) {
  246. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  247. handle(), "F32STRD1");
  248. }
  249. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
  250. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  251. handle(), "F32STRD2");
  252. }
  253. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
  254. check_conv_bias(
  255. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  256. ONLY_BR_BIASMODE, 2, false, true),
  257. handle(), "F32_CONV_NCHW_NCHW44");
  258. }
  259. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
  260. check_conv_bias(
  261. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  262. ONLY_BR_BIASMODE, 1, false, true),
  263. handle(), "F32_CONV_NCHW_NCHW44");
  264. }
  265. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
  266. check_conv_bias(
  267. get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
  268. handle(), "F32_CHANNEL_WISE_NCHW44");
  269. }
  270. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
  271. check_conv_bias(get_nchw44_channel_wise_args({5}, 1, false, false, false),
  272. handle(), "F32_CHANNEL_WISE_NCHW44");
  273. }
  274. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
  275. check_conv_bias(
  276. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
  277. handle(), "F32_CHANNEL_WISE_NCHW44");
  278. }
  279. /**********************************F16 direct************************/
  280. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  281. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
  282. NormalRNG rng(1);
  283. checker_conv_bias_f16(
  284. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  285. handle(), rng, "F16DIRECT", 0.03);
  286. }
  287. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
  288. NormalRNG rng(1);
  289. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  290. handle(), rng, "F16STRD1", 0.03);
  291. }
  292. #endif
  293. /**********************************algo 8816 direct************************/
  294. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
  295. checker_conv_bias_int8x8x16(
  296. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  297. "I8816DIRECT");
  298. }
  299. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
  300. checker_conv_bias_int8x8x16(
  301. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  302. "I8816STRD2");
  303. }
  304. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
  305. checker_conv_bias_int8x8x16(
  306. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  307. ONLY_NO_BIASMODE, 2, false, true),
  308. handle(), "I8816_CONV_NCHW_NCHW44");
  309. }
  310. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
  311. checker_conv_bias_int8x8x16(
  312. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  313. ONLY_NO_BIASMODE, 1, false, true),
  314. handle(), "I8816_CONV_NCHW_NCHW44");
  315. }
  316. /**********************************algo 8-8-32 direct************************/
  317. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
  318. checker_conv_bias_int8x8x32_multi(
  319. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  320. "S8STRD1");
  321. }
  322. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
  323. checker_conv_bias_int8x8x32_multi(
  324. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  325. "S8STRD2");
  326. }
  327. TEST_F(ARM_COMMON_MULTI_THREADS,
  328. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
  329. checker_conv_bias_int8x8x32_multi(
  330. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
  331. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  332. }
  333. TEST_F(ARM_COMMON_MULTI_THREADS,
  334. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
  335. checker_conv_bias_int8x8x32_multi(
  336. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
  337. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  338. }
  339. TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
  340. Checker<ConvBias> checker(handle());
  341. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  342. "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
  343. checker.set_dtype(0, dtype::Int8());
  344. checker.set_dtype(1, dtype::Int8());
  345. checker.set_dtype(2, dtype::Int16());
  346. checker.set_dtype(4, dtype::Int16());
  347. auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
  348. for (auto&& arg : args) {
  349. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  350. }
  351. }
  352. TEST_F(ARM_COMMON_MULTI_THREADS,
  353. CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
  354. Checker<ConvBias> checker(handle());
  355. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  356. "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
  357. checker.set_dtype(0, dtype::Int8());
  358. checker.set_dtype(1, dtype::Int8());
  359. checker.set_dtype(2, dtype::Int16());
  360. checker.set_dtype(4, dtype::Int16());
  361. auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
  362. for (auto&& arg : args) {
  363. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  364. }
  365. }
  366. /********************************qint8 direct******************************/
  367. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
  368. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  369. {2, 3, 5, 7}, 1, false, false, false),
  370. handle(), "S8STRD1");
  371. }
  372. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
  373. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  374. {2, 3, 5, 7}, 2, false, false, false),
  375. handle(), "S8STRD2");
  376. }
  377. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
  378. checker_conv_bias_qint8x8x8(
  379. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  380. ONLY_BR_BIASMODE, 1),
  381. handle(), "S8_NCHW44_DIRECT");
  382. }
  383. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
  384. checker_conv_bias_int8x8x16(
  385. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  386. ONLY_BR_BIASMODE, 1),
  387. handle(), "S8x8x16_NCHW44_DIRECT");
  388. }
  389. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
  390. checker_conv_bias_int8x8x16(
  391. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  392. ONLY_BR_BIASMODE, 2),
  393. handle(), "S8x8x16_NCHW44_DIRECT");
  394. }
  395. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
  396. checker_conv_bias_qint8x8x32(
  397. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  398. ONLY_BR_BIASMODE, 1),
  399. handle(), "S8_NCHW44_DIRECT");
  400. }
  401. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
  402. checker_conv_bias_qint8x8x32(
  403. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  404. ONLY_NO_BIASMODE, 2),
  405. handle(), "S8_NCHW44_DIRECT");
  406. }
  407. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
  408. checker_conv_bias_qint8x8x8(
  409. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  410. BR_AND_NO_BIASMODE, 2),
  411. handle(), "S8_NCHW44_DIRECT");
  412. }
  413. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
  414. checker_conv_bias_qint8x8x8(
  415. get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
  416. BR_AND_NO_BIASMODE, 1),
  417. handle(), "S8_NCHW44_DIRECT");
  418. checker_conv_bias_qint8x8x8(
  419. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
  420. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  421. }
  422. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
  423. checker_conv_bias_qint8x8x8(
  424. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
  425. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  426. }
  427. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
  428. checker_conv_bias_qint8x8x8(
  429. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  430. BR_AND_NO_BIASMODE, 1, false, true),
  431. handle(), "S8_CONV_NCHW_NCHW44");
  432. }
  433. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
  434. checker_conv_bias_qint8x8x8(
  435. get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  436. BR_AND_NO_BIASMODE, 2, false, true),
  437. handle(), "S8_CONV_NCHW_NCHW44");
  438. }
  439. /*****************************quint8 direct****************************/
  440. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
  441. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  442. {2, 3, 5, 7}, 1, false, false, false),
  443. handle(), "QU8STRD1");
  444. }
  445. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
  446. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  447. {2, 3, 5, 7}, 2, false, false, false),
  448. handle(), "QU8STRD2");
  449. }
  450. /****************************dot qint8 direct*************************/
  451. #if __ARM_FEATURE_DOTPROD
  452. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
  453. auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  454. BR_AND_NO_BIASMODE, 2, false, true);
  455. for (auto&& arg : args) {
  456. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  457. }
  458. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  459. args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
  460. BR_AND_NO_BIASMODE, 1, false, true);
  461. for (auto&& arg : args) {
  462. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  463. }
  464. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  465. }
  466. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
  467. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  468. {2, 3, 5, 7}, 1, false, false, false),
  469. handle(), "ARMDOTS8STRD1");
  470. }
  471. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
  472. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  473. {2, 3, 5, 7}, 2, false, false, false),
  474. handle(), "ARMDOTS8STRD2");
  475. }
  476. /****************************dot 8-8-32 direct*************************/
  477. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
  478. checker_conv_bias_qint8x8x32(
  479. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  480. "ARMDOTS8STRD1");
  481. }
  482. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
  483. checker_conv_bias_qint8x8x32(
  484. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  485. "ARMDOTS8STRD2");
  486. }
  487. /******************************dot quint8*****************************/
  488. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
  489. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  490. {2, 3, 5, 7}, 1, false, false, false),
  491. handle(), "ARMDOTU8STRD1");
  492. }
  493. //! TODO: this test without test kernel size=3, add it will case buss error now
  494. //! in armv7
  495. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
  496. checker_conv_bias_quint8x8x8(
  497. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  498. handle(), "ARMDOTU8STRD2");
  499. }
  500. /******************************dot quint8x8x32***********************/
  501. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
  502. checker_conv_bias_quint8x8x32(
  503. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  504. "ARMDOTU8STRD1");
  505. }
  506. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
  507. checker_conv_bias_quint8x8x32(
  508. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  509. "ARMDOTU8STRD2");
  510. }
  511. /******************************dot int8x8x8 nchw44 ***********************/
  512. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
  513. using namespace conv_bias;
  514. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  515. {2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1);
  516. for (auto&& arg : args)
  517. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  518. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  519. }
  520. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
  521. using namespace conv_bias;
  522. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  523. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  524. for (auto&& arg : args)
  525. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  526. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  527. }
  528. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
  529. using namespace conv_bias;
  530. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  531. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
  532. for (auto&& arg : args)
  533. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  534. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  535. }
  536. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
  537. using namespace conv_bias;
  538. //! test qint8x8x8
  539. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  540. {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2);
  541. for (auto&& arg : args)
  542. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  543. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  544. }
  545. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
  546. using namespace conv_bias;
  547. //! test qint8x8x8
  548. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  549. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
  550. for (auto&& arg : args)
  551. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  552. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  553. }
  554. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
  555. using namespace conv_bias;
  556. //! test qint8x8x8
  557. std::vector<TestArg> args = get_nchw44_conv_bias_args(
  558. {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
  559. for (auto&& arg : args)
  560. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  561. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  562. }
  563. #endif
  564. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
  565. using namespace conv_bias;
  566. std::vector<TestArg> args = get_winograd_mk_packed_args();
  567. Checker<ConvBiasForward> checker(handle());
  568. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
  569. }
  570. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
  571. using namespace conv_bias;
  572. std::vector<TestArg> args =
  573. get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
  574. Checker<ConvBiasForward> checker(handle());
  575. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4,
  576. param::ConvBias::Format::NCHW44);
  577. }
  578. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
  579. using namespace conv_bias;
  580. std::vector<TestArg> args = get_winograd_args(3);
  581. Checker<ConvBiasForward> checker(handle());
  582. check_winograd("1:6:32", checker, args);
  583. }
  584. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
  585. using namespace conv_bias;
  586. std::vector<TestArg> args = get_winograd_mk_packed_args();
  587. Checker<ConvBiasForward> checker(handle());
  588. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
  589. }
  590. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
  591. using namespace conv_bias;
  592. std::vector<TestArg> args =
  593. get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
  594. Checker<ConvBiasForward> checker(handle());
  595. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4,
  596. param::ConvBias::Format::NCHW44);
  597. }
  598. //! uncomment it when low precision mode is ok
  599. #if 0
  600. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
  601. using namespace conv_bias;
  602. std::vector<TestArg> args =
  603. get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
  604. Checker<ConvBiasForward> checker(handle());
  605. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  606. param::ConvBias::Format::NCHW44);
  607. }
  608. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
  609. using namespace conv_bias;
  610. std::vector<TestArg> args =
  611. get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
  612. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  613. handle());
  614. check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
  615. param::ConvBias::Format::NCHW44);
  616. }
  617. #endif
  618. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
  619. using namespace conv_bias;
  620. std::vector<TestArg> args = get_winograd_args(4);
  621. Checker<ConvBiasForward> checker(handle());
  622. check_winograd("1:5:32", checker, args);
  623. }
  624. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
  625. using namespace conv_bias;
  626. std::vector<TestArg> args = get_winograd_args(5);
  627. Checker<ConvBiasForward> checker(handle());
  628. check_winograd("1:4:32", checker, args);
  629. }
  630. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  631. using namespace conv_bias;
  632. std::vector<TestArg> args = get_winograd_args(3);
  633. Checker<ConvBiasForward> checker(handle());
  634. auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
  635. param::ConvBias param, Handle* handle) {
  636. megdnn_assert(param.format == param::ConvBias::Format::NCHW);
  637. auto winograd_preprocess_opr =
  638. handle->create_operator<WinogradFilterPreprocess>();
  639. winograd_preprocess_opr->param().output_block_size = m;
  640. TensorLayout filter_transform_layout;
  641. winograd_preprocess_opr->deduce_layout(tensors[1].layout,
  642. filter_transform_layout);
  643. size_t winograd_preprocess_workspace_in_bytes =
  644. winograd_preprocess_opr->get_workspace_in_bytes(
  645. tensors[1].layout, filter_transform_layout);
  646. auto conv_bias_opr = handle->create_operator<ConvBias>();
  647. conv_bias_opr->param() = param;
  648. conv_bias_opr->param().format = param::ConvBias::Format::NCHW_WINOGRAD;
  649. conv_bias_opr->param().output_block_size = m;
  650. size_t conv_bias_workspace_in_bytes =
  651. conv_bias_opr->get_workspace_in_bytes(
  652. tensors[0].layout, filter_transform_layout,
  653. tensors[2].layout, tensors[3].layout, tensors[4].layout,
  654. nullptr);
  655. WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
  656. conv_bias_workspace_in_bytes,
  657. winograd_preprocess_workspace_in_bytes});
  658. wb.set(malloc(wb.total_size_in_bytes()));
  659. TensorND filter_transform_tensor(wb.get(0),
  660. std::move(filter_transform_layout));
  661. winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
  662. wb.get_workspace(2));
  663. conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
  664. tensors[3], tensors[4], nullptr,
  665. wb.get_workspace(1));
  666. free(wb.ptr());
  667. };
  668. auto run = [&checker, &extra_impl](
  669. Handle* handle, const std::vector<TestArg>& args,
  670. const std::vector<size_t>& out_size, DType A_dtype,
  671. DType B_dtype, DType C_dtype, DType D_dtype,
  672. const float eps) {
  673. for (auto&& arg : args) {
  674. for (uint32_t m : out_size) {
  675. checker.set_extra_opr_impl(std::bind(extra_impl,
  676. std::placeholders::_1, m,
  677. arg.param, handle));
  678. checker.set_dtype(0, A_dtype)
  679. .set_dtype(1, B_dtype)
  680. .set_dtype(2, C_dtype)
  681. .set_dtype(4, D_dtype)
  682. .set_epsilon(eps)
  683. .set_param(arg.param)
  684. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  685. }
  686. }
  687. };
  688. run(handle(), args, {6}, dtype::Float32(), dtype::Float32(),
  689. dtype::Float32(), dtype::Float32(), 1e-3f);
  690. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  691. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  692. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  693. run(handle(), args, {6}, dtype::Float16(), dtype::Float16(),
  694. dtype::Float16(), dtype::Float16(), 0.35f);
  695. #endif
  696. }
  697. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  698. using namespace conv_bias;
  699. Checker<ConvBiasForward> checker(handle());
  700. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  701. const std::vector<size_t>& out_size, DType A_dtype,
  702. DType B_dtype, DType C_dtype, DType D_dtype,
  703. param::MatrixMul::Format format, float eps) {
  704. for (auto&& arg : args) {
  705. for (uint32_t m : out_size) {
  706. checker.set_extra_opr_impl(std::bind(
  707. winograd_algo_extra_impl, std::placeholders::_1, m,
  708. arg.param, handle, format));
  709. checker.set_dtype(0, A_dtype)
  710. .set_dtype(1, B_dtype)
  711. .set_dtype(2, C_dtype)
  712. .set_dtype(4, D_dtype)
  713. .set_epsilon(eps)
  714. .set_param(arg.param)
  715. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  716. }
  717. }
  718. };
  719. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  720. std::vector<TestArg> args_first_half(args.begin(),
  721. args.begin() + args.size() / 2);
  722. run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  723. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  724. 1e-3f);
  725. }
  726. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  727. using namespace conv_bias;
  728. Checker<ConvBiasForward> checker(handle());
  729. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  730. const std::vector<size_t>& out_size, DType A_dtype,
  731. DType B_dtype, DType C_dtype, DType D_dtype,
  732. param::MatrixMul::Format format, float eps) {
  733. for (auto&& arg : args) {
  734. for (uint32_t m : out_size) {
  735. checker.set_extra_opr_impl(std::bind(
  736. winograd_algo_extra_impl, std::placeholders::_1, m,
  737. arg.param, handle, format));
  738. checker.set_dtype(0, A_dtype)
  739. .set_dtype(1, B_dtype)
  740. .set_dtype(2, C_dtype)
  741. .set_dtype(4, D_dtype)
  742. .set_epsilon(eps)
  743. .set_param(arg.param)
  744. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  745. }
  746. }
  747. };
  748. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  749. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
  750. args.end());
  751. run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  752. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  753. 1e-3f);
  754. }
  755. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  756. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  757. using namespace conv_bias;
  758. Checker<ConvBiasForward> checker(handle());
  759. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  760. const std::vector<size_t>& out_size, DType A_dtype,
  761. DType B_dtype, DType C_dtype, DType D_dtype,
  762. param::MatrixMul::Format format, float eps) {
  763. for (auto&& arg : args) {
  764. for (uint32_t m : out_size) {
  765. checker.set_extra_opr_impl(std::bind(
  766. winograd_algo_extra_impl, std::placeholders::_1, m,
  767. arg.param, handle, format));
  768. checker.set_dtype(0, A_dtype)
  769. .set_dtype(1, B_dtype)
  770. .set_dtype(2, C_dtype)
  771. .set_dtype(4, D_dtype)
  772. .set_epsilon(eps)
  773. .set_param(arg.param)
  774. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  775. }
  776. }
  777. };
  778. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  779. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  780. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  781. run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{},
  782. dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8,
  783. 0.25);
  784. }
  785. #endif
  786. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  787. using namespace conv_bias;
  788. Checker<ConvBiasForward> checker(handle());
  789. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  790. const std::vector<size_t>& out_size, DType A_dtype,
  791. DType B_dtype, DType C_dtype, DType D_dtype,
  792. param::MatrixMul::Format format, float eps) {
  793. for (auto&& arg : args) {
  794. for (uint32_t m : out_size) {
  795. checker.set_extra_opr_impl(std::bind(
  796. winograd_algo_extra_impl, std::placeholders::_1, m,
  797. arg.param, handle, format));
  798. checker.set_dtype(0, A_dtype)
  799. .set_dtype(1, B_dtype)
  800. .set_dtype(2, C_dtype)
  801. .set_dtype(4, D_dtype)
  802. .set_epsilon(eps)
  803. .set_param(arg.param)
  804. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  805. }
  806. }
  807. };
  808. #if MEGDNN_AARCH64
  809. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  810. #else
  811. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  812. #endif
  813. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  814. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  815. std::vector<TestArg> quantized_args =
  816. get_quantized_winograd_mk_packed_args(8);
  817. UniformIntRNG int_rng{-50, 50};
  818. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  819. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  820. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  821. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  822. }
  823. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
  824. using namespace conv_bias;
  825. Checker<ConvBiasForward> checker(handle());
  826. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  827. const std::vector<size_t>& out_size, DType A_dtype,
  828. DType B_dtype, DType C_dtype, DType D_dtype,
  829. param::MatrixMul::Format format, float eps) {
  830. for (auto&& arg : args) {
  831. for (uint32_t m : out_size) {
  832. checker.set_extra_opr_impl(std::bind(
  833. winograd_algo_extra_impl, std::placeholders::_1, m,
  834. arg.param, handle, format));
  835. checker.set_dtype(0, A_dtype)
  836. .set_dtype(1, B_dtype)
  837. .set_dtype(2, C_dtype)
  838. .set_dtype(4, D_dtype)
  839. .set_epsilon(eps)
  840. .set_param(arg.param)
  841. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  842. }
  843. }
  844. };
  845. #if MEGDNN_AARCH64
  846. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  847. #else
  848. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  849. #endif
  850. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  851. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  852. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
  853. UniformIntRNG int_rng{-50, 50};
  854. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  855. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  856. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  857. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  858. }
  859. TEST_F(ARM_COMMON_MULTI_THREADS,
  860. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
  861. using namespace conv_bias;
  862. Checker<ConvBiasForward> checker(handle());
  863. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  864. const std::vector<size_t>& out_size, DType A_dtype,
  865. DType B_dtype, DType C_dtype, DType D_dtype,
  866. param::MatrixMul::Format format, float eps) {
  867. for (auto&& arg : args) {
  868. for (uint32_t m : out_size) {
  869. checker.set_extra_opr_impl(std::bind(
  870. winograd_algo_extra_impl, std::placeholders::_1, m,
  871. arg.param, handle, format));
  872. checker.set_dtype(0, A_dtype)
  873. .set_dtype(1, B_dtype)
  874. .set_dtype(2, C_dtype)
  875. .set_dtype(4, D_dtype)
  876. .set_epsilon(eps)
  877. .set_param(arg.param)
  878. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  879. }
  880. }
  881. };
  882. #if MEGDNN_AARCH64
  883. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  884. #else
  885. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  886. #endif
  887. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  888. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  889. std::vector<TestArg> quantized_args =
  890. get_int8_nchw44_args(3, 4, false, true);
  891. UniformIntRNG int_rng{-50, 50};
  892. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  893. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  894. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  895. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  896. }
  897. TEST_F(ARM_COMMON_MULTI_THREADS,
  898. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
  899. using namespace conv_bias;
  900. Checker<ConvBiasForward> checker(handle());
  901. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  902. const std::vector<size_t>& out_size, DType A_dtype,
  903. DType B_dtype, DType C_dtype, DType D_dtype,
  904. param::MatrixMul::Format format, float eps) {
  905. for (auto&& arg : args) {
  906. for (uint32_t m : out_size) {
  907. checker.set_extra_opr_impl(std::bind(
  908. winograd_algo_extra_impl, std::placeholders::_1, m,
  909. arg.param, handle, format));
  910. checker.set_dtype(0, A_dtype)
  911. .set_dtype(1, B_dtype)
  912. .set_dtype(2, C_dtype)
  913. .set_dtype(4, D_dtype)
  914. .set_epsilon(eps)
  915. .set_param(arg.param)
  916. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  917. }
  918. }
  919. };
  920. float epsilon = 0.001;
  921. #if MEGDNN_AARCH64
  922. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  923. #else
  924. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  925. #endif
  926. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  927. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  928. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
  929. UniformIntRNG int_rng{-50, 50};
  930. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  931. run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
  932. dtype::QuantizedS8(0.01887994f),
  933. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  934. dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
  935. epsilon);
  936. }
  937. TEST_F(ARM_COMMON_MULTI_THREADS,
  938. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
  939. using namespace conv_bias;
  940. Checker<ConvBiasForward> checker(handle());
  941. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  942. const std::vector<size_t>& out_size, DType A_dtype,
  943. DType B_dtype, DType C_dtype, DType D_dtype,
  944. param::MatrixMul::Format format, float eps) {
  945. for (auto&& arg : args) {
  946. for (uint32_t m : out_size) {
  947. checker.set_extra_opr_impl(std::bind(
  948. winograd_algo_extra_impl, std::placeholders::_1, m,
  949. arg.param, handle, format));
  950. checker.set_dtype(0, A_dtype)
  951. .set_dtype(1, B_dtype)
  952. .set_dtype(2, C_dtype)
  953. .set_dtype(4, D_dtype)
  954. .set_epsilon(eps)
  955. .set_param(arg.param)
  956. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  957. }
  958. }
  959. };
  960. float epsilon = 0.001;
  961. #if MEGDNN_AARCH64
  962. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  963. #else
  964. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  965. #endif
  966. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  967. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  968. std::vector<TestArg> quantized_args =
  969. get_int8_nchw44_args(3, 4, true, true);
  970. UniformIntRNG int_rng{-50, 50};
  971. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  972. run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
  973. dtype::QuantizedS8(0.01887994f),
  974. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  975. dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
  976. epsilon);
  977. }
  978. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  979. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  980. using namespace conv_bias;
  981. std::vector<TestArg> args = get_winograd_mk_packed_args();
  982. Checker<ConvBiasForward> checker(handle());
  983. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  984. }
  985. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  986. using namespace conv_bias;
  987. std::vector<TestArg> args = get_winograd_args(5);
  988. std::vector<TestArg> args_head_half(args.begin(),
  989. args.begin() + args.size() / 2);
  990. Checker<ConvBiasForward> checker(handle());
  991. //! fp16 range -1.0 ~ 1.0
  992. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  993. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  994. }
  995. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  996. using namespace conv_bias;
  997. std::vector<TestArg> args = get_winograd_args(5);
  998. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  999. args.end());
  1000. Checker<ConvBiasForward> checker(handle());
  1001. //! fp16 range -1.0 ~ 1.0
  1002. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1003. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  1004. }
  1005. //! FIXME: This test may be failed if run `ARM_COMMON.CONV_BIAS_WINOGRAD*`, but
  1006. //! it will pass when run single testcase
  1007. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  1008. using namespace conv_bias;
  1009. std::vector<TestArg> args = get_winograd_args(3);
  1010. Checker<ConvBiasForward> checker(handle());
  1011. //! fp16 range -1.0 ~ 1.0
  1012. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1013. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  1014. }
  1015. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  1016. using namespace conv_bias;
  1017. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  1018. std::vector<TestArg> args_head_half(args.begin(),
  1019. args.begin() + args.size() / 2);
  1020. Checker<ConvBiasForward> checker(handle());
  1021. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1022. check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
  1023. param::MatrixMul::Format::MK8);
  1024. }
  1025. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  1026. using namespace conv_bias;
  1027. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  1028. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  1029. args.end());
  1030. Checker<ConvBiasForward> checker(handle());
  1031. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1032. check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
  1033. param::MatrixMul::Format::MK8);
  1034. }
  1035. #endif
  1036. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  1037. using namespace conv_bias;
  1038. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1039. Checker<ConvBiasForward> checker(handle());
  1040. UniformIntRNG rng{-50, 50};
  1041. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1042. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1043. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1044. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1045. .set_rng(0, &rng)
  1046. .set_rng(1, &rng)
  1047. .set_rng(2, &rng);
  1048. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1049. }
  1050. TEST_F(ARM_COMMON_MULTI_THREADS,
  1051. CONV_BIAS_WINOGRAD_INT8_8X8_WEIGHT_PREPROCESS) {
  1052. using namespace conv_bias;
  1053. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1054. Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
  1055. handle());
  1056. UniformIntRNG rng{-50, 50};
  1057. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1058. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1059. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1060. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1061. .set_rng(0, &rng)
  1062. .set_rng(1, &rng)
  1063. .set_rng(2, &rng);
  1064. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1065. }
  1066. // clang-format on
  1067. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。