You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_bias_multi_thread.cpp 90 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122
  1. /**
  2. * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/arm_common/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/conv_bias.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. using namespace conv_bias;
  18. std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
  19. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  20. bool no_nonlinemode) {
  21. using namespace conv_bias;
  22. using Param = param::ConvBias;
  23. using NLMode = param::ConvBias::NonlineMode;
  24. std::vector<TestArg> args;
  25. auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
  26. size_t kernel, size_t stride, NLMode nlmode) {
  27. Param param;
  28. param.stride_h = stride;
  29. param.stride_w = stride;
  30. if (!no_pad) {
  31. param.pad_h = kernel / 2;
  32. param.pad_w = kernel / 2;
  33. } else {
  34. param.pad_h = 0;
  35. param.pad_w = 0;
  36. }
  37. param.nonlineMode = nlmode;
  38. args.emplace_back(param, TensorShape{n, ic, h, w},
  39. TensorShape{oc, ic, kernel, kernel}, TensorShape{});
  40. if (!no_bias) {
  41. args.emplace_back(param, TensorShape{n, ic, h, w},
  42. TensorShape{oc, ic, kernel, kernel},
  43. TensorShape{1, oc, 1, 1});
  44. }
  45. };
  46. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  47. if (!no_nonlinemode) {
  48. nonlinemode.emplace_back(NLMode::RELU);
  49. nonlinemode.emplace_back(NLMode::H_SWISH);
  50. }
  51. for (size_t n : {1, 2}) {
  52. for (auto nlmode : nonlinemode) {
  53. for (size_t ic : {1, 3, 7}) {
  54. for (size_t oc : {1, 3, 7}) {
  55. for (size_t size : {4, 6, 8, 14, 16, 18}) {
  56. for (size_t kern : kernel) {
  57. pack(n, oc, ic, size, size, kern, stride, nlmode);
  58. }
  59. }
  60. }
  61. }
  62. }
  63. }
  64. return args;
  65. }
  66. std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
  67. std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
  68. bool no_bias = false, bool no_nonlinemode = false,
  69. bool is_input_nchw = false, bool is_nchw44_dot = false,
  70. bool support_full_bias = false, bool support_sigmoid = false,
  71. bool only_no_bias = false) {
  72. using namespace conv_bias;
  73. using NLMode = param::ConvBias::NonlineMode;
  74. std::vector<TestArg> args;
  75. auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
  76. size_t kernel, size_t stride, size_t group, NLMode nlmode,
  77. megdnn::BiasMode bias_mode, int any_pad = -1) {
  78. constexpr int pack_c = 4;
  79. const size_t pad = any_pad >= 0 ? any_pad : kernel / 2;
  80. auto oc_per_group = oc / group;
  81. auto ic_per_group = ic / group;
  82. bool ok_group = (oc % group == 0 && ic % group == 0) &&
  83. oc_per_group % pack_c == 0 && oc_per_group > 0 &&
  84. ic_per_group > 0;
  85. bool nchw_disable = group > 1 || ic_per_group >= 4;
  86. bool nchw44_disable = ic_per_group % pack_c != 0;
  87. bool invalid_pad = (w + 2 * pad < kernel) || (h + 2 * pad < kernel);
  88. if (!(ok_group) || invalid_pad) {
  89. return;
  90. }
  91. if ((is_input_nchw && nchw_disable) ||
  92. (!is_input_nchw && nchw44_disable)) {
  93. return;
  94. }
  95. size_t kernel_h = kernel;
  96. size_t kernel_w = kernel;
  97. param::ConvBias param;
  98. if (!is_nchw44_dot) {
  99. param.format = param::ConvBias::Format::NCHW44;
  100. } else {
  101. param.format = param::ConvBias::Format::NCHW44_DOT;
  102. }
  103. param.stride_h = stride;
  104. param.stride_w = stride;
  105. param.pad_h = pad;
  106. param.pad_w = pad;
  107. param.nonlineMode = nlmode;
  108. auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
  109. auto weight_tensor_shape = TensorShape{
  110. oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
  111. auto bias_tensor_shape = TensorShape{};
  112. if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
  113. bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
  114. } else if (bias_mode == megdnn::BiasMode::BIAS) {
  115. bias_tensor_shape = {n, oc / pack_c,
  116. (h + 2 * pad - kernel) / stride + 1,
  117. (w + 2 * pad - kernel) / stride + 1, pack_c};
  118. }
  119. if (group == 1) {
  120. param.sparse = param::ConvBias::Sparse::DENSE;
  121. } else if (group > 1 && ic / group == 1 && oc / group == 1) {
  122. megdnn_assert(0, "not support channel wise");
  123. param.sparse = param::ConvBias::Sparse::GROUP;
  124. weight_tensor_shape = TensorShape{group / pack_c, 1, 1,
  125. kernel_h, kernel_w, pack_c};
  126. } else if (group > 1 && oc_per_group % pack_c == 0 && oc / group > 0 &&
  127. ic_per_group % pack_c == 0 && ic / group > 0) {
  128. param.sparse = param::ConvBias::Sparse::GROUP;
  129. weight_tensor_shape = TensorShape{group,
  130. oc_per_group / pack_c,
  131. ic_per_group / pack_c,
  132. kernel_h,
  133. kernel_w,
  134. pack_c,
  135. pack_c};
  136. }
  137. if (is_input_nchw) {
  138. src_tensor_shape = TensorShape{n, ic, h, w};
  139. weight_tensor_shape =
  140. TensorShape{oc / pack_c, kernel_h, kernel_w, ic, pack_c};
  141. }
  142. args.emplace_back(param, src_tensor_shape, weight_tensor_shape,
  143. bias_tensor_shape);
  144. };
  145. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  146. if (!no_nonlinemode) {
  147. nonlinemode.emplace_back(NLMode::RELU);
  148. nonlinemode.emplace_back(NLMode::H_SWISH);
  149. }
  150. if (support_sigmoid) {
  151. nonlinemode.emplace_back(NLMode::SIGMOID);
  152. }
  153. std::vector<megdnn::BiasMode> bias_mode;
  154. if (!only_no_bias) {
  155. bias_mode.emplace_back(megdnn::BiasMode::BROADCAST_CHANNEL_BIAS);
  156. if (no_bias) {
  157. bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
  158. }
  159. } else {
  160. bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
  161. }
  162. if (support_full_bias) {
  163. bias_mode.emplace_back(megdnn::BiasMode::BIAS);
  164. }
  165. for (auto bias : bias_mode)
  166. for (auto nlmode : nonlinemode)
  167. for (size_t n : {1, 2})
  168. for (size_t kernel : kernel_vec)
  169. for (size_t oc : {4, 12})
  170. for (size_t ic : {1, 3, 4, 12})
  171. for (size_t h : {3, 5, 12})
  172. for (size_t w : {7, 16, 23}) {
  173. for (size_t group = 1;
  174. group <=
  175. std::min(std::min(oc, ic), 4_z);
  176. ++group) {
  177. pack(n, oc, ic, h, w, kernel, stride,
  178. group, nlmode, bias);
  179. }
  180. }
  181. return args;
  182. }
  183. std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
  184. std::vector<size_t> kernel, size_t stride, bool no_bias,
  185. bool no_nonlinemode, bool no_full_bias) {
  186. using namespace conv_bias;
  187. using Param = param::ConvBias;
  188. using NLMode = param::ConvBias::NonlineMode;
  189. std::vector<TestArg> args;
  190. auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
  191. size_t stride, NLMode nlmode, bool pad) {
  192. Param param;
  193. param.stride_h = stride;
  194. param.stride_w = stride;
  195. if (pad) {
  196. param.pad_h = kernel / 2;
  197. param.pad_w = kernel / 2;
  198. } else {
  199. param.pad_h = 0;
  200. param.pad_w = 0;
  201. }
  202. param.nonlineMode = nlmode;
  203. param.format = param::ConvBias::Format::NCHW44;
  204. param.sparse = param::ConvBias::Sparse::GROUP;
  205. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  206. TensorShape{group, 1, 1, kernel, kernel, 4},
  207. TensorShape{});
  208. if (!no_bias) {
  209. args.emplace_back(param, TensorShape{n, group, h, w, 4},
  210. TensorShape{group, 1, 1, kernel, kernel, 4},
  211. TensorShape{1, group, 1, 1, 4});
  212. }
  213. if (!no_full_bias) {
  214. args.emplace_back(
  215. param, TensorShape{n, group, h, w, 4},
  216. TensorShape{group, 1, 1, kernel, kernel, 4},
  217. TensorShape{n, group,
  218. (h + 2 * param.pad_w - kernel) / stride + 1,
  219. (w + 2 * param.pad_w - kernel) / stride + 1,
  220. 4});
  221. }
  222. };
  223. std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
  224. if (!no_nonlinemode) {
  225. nonlinemode.emplace_back(NLMode::RELU);
  226. nonlinemode.emplace_back(NLMode::H_SWISH);
  227. }
  228. for (size_t n : {1, 2}) {
  229. for (auto nlmode : nonlinemode) {
  230. for (bool pad : {true}) {
  231. for (size_t group : {1, 2, 4, 7, 128}) {
  232. for (size_t size : {4, 6, 7, 9, 15, 40}) {
  233. for (size_t kern : kernel) {
  234. pack(n, group, size, size, kern, stride, nlmode,
  235. pad);
  236. }
  237. }
  238. }
  239. }
  240. for (bool pad : {false}) {
  241. for (size_t group : {1, 2, 7, 128}) {
  242. for (size_t size : {7, 9, 15, 40}) {
  243. for (size_t kern : kernel) {
  244. pack(n, group, size, size, kern, stride, nlmode,
  245. pad);
  246. }
  247. }
  248. }
  249. }
  250. }
  251. }
  252. return args;
  253. }
  254. void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
  255. Handle* handle, const char* algo_name) {
  256. Checker<ConvBias> checker(handle);
  257. checker.set_before_exec_callback(
  258. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  259. #if MEGDNN_ARMV7
  260. checker.set_epsilon(1);
  261. #endif
  262. UniformIntRNG rng{-50, 50};
  263. checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
  264. .set_dtype(1, dtype::QuantizedS8(0.01887994f))
  265. .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
  266. .set_dtype(4, dtype::QuantizedS8(0.49550694f))
  267. .set_rng(0, &rng)
  268. .set_rng(1, &rng)
  269. .set_rng(2, &rng);
  270. for (auto&& arg : args) {
  271. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  272. }
  273. }
  274. void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
  275. Handle* handle, const char* algo_name) {
  276. Checker<ConvBias> checker(handle);
  277. UniformIntRNG rng{-50, 50};
  278. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  279. .set_dtype(1, dtype::QuantizedS8(2.5f))
  280. .set_dtype(2, dtype::QuantizedS32(6.25f))
  281. .set_dtype(4, {});
  282. checker.set_before_exec_callback(
  283. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  284. for (auto&& arg : args) {
  285. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  286. }
  287. }
  288. void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
  289. Handle* handle, const char* algo_name) {
  290. Checker<ConvBias> checker(handle);
  291. checker.set_before_exec_callback(
  292. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  293. UniformIntRNG rng(0, 255);
  294. checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
  295. .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
  296. .set_dtype(2, dtype::QuantizedS32(0.04f))
  297. .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
  298. .set_rng(0, &rng)
  299. .set_rng(1, &rng)
  300. .set_rng(2, &rng);
  301. for (auto&& arg : args) {
  302. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  303. }
  304. }
  305. void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
  306. Handle* handle, const char* algo_name) {
  307. Checker<ConvBias> checker(handle);
  308. checker.set_before_exec_callback(
  309. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  310. NormalRNG rng(128.f);
  311. checker.set_rng(0, &rng).set_rng(1, &rng);
  312. checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
  313. .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
  314. .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
  315. .set_dtype(4, {});
  316. for (auto&& arg : args) {
  317. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  318. }
  319. }
  320. void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
  321. Handle* handle, const char* algo_name) {
  322. Checker<ConvBias> checker(handle);
  323. checker.set_before_exec_callback(
  324. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  325. checker.set_dtype(0, dtype::Int8());
  326. checker.set_dtype(1, dtype::Int8());
  327. checker.set_dtype(2, dtype::Int32());
  328. checker.set_dtype(4, dtype::Int32());
  329. for (auto&& arg : args) {
  330. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  331. }
  332. }
  333. /**********************************F32 direct************************/
  334. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) {
  335. check_conv_bias(
  336. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  337. handle(), "F32DIRECT_LARGE_GROUP");
  338. }
  339. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
  340. check_conv_bias(
  341. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  342. handle(), "F32DIRECT_SMALL_GROUP");
  343. }
  344. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
  345. check_conv_bias(get_nchw44_conv_bias_args({7}, 1, false, true, true, false,
  346. false, false),
  347. handle(), "F32_CONV_NCHW44_DIRECT");
  348. }
  349. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
  350. check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false,
  351. false, false, true, true),
  352. handle(), "F32_CONV_NCHW44_DIRECT");
  353. }
  354. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
  355. check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false,
  356. false, false, true, true),
  357. handle(), "F32_CONV_NCHW44_DIRECT");
  358. }
  359. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
  360. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
  361. false, false, false, true, true),
  362. handle(), "F32_CONV_NCHW44_DIRECT");
  363. }
  364. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) {
  365. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  366. handle(), "F32STRD1_LARGE_GROUP");
  367. }
  368. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) {
  369. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  370. handle(), "F32STRD1_SMALL_GROUP");
  371. }
  372. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
  373. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  374. handle(), "F32STRD2_LARGE_GROUP");
  375. }
  376. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
  377. check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  378. handle(), "F32STRD2_SMALL_GROUP");
  379. }
  380. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
  381. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
  382. false, true),
  383. handle(), "F32_CONV_NCHW_NCHW44");
  384. }
  385. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
  386. check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false,
  387. false, true),
  388. handle(), "F32_CONV_NCHW_NCHW44");
  389. }
  390. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
  391. check_conv_bias(
  392. get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
  393. handle(), "F32_CHANNEL_WISE_NCHW44");
  394. }
  395. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
  396. check_conv_bias(get_nchw44_channel_wise_args({5}, 1, false, false, false),
  397. handle(), "F32_CHANNEL_WISE_NCHW44");
  398. }
  399. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
  400. check_conv_bias(
  401. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
  402. handle(), "F32_CHANNEL_WISE_NCHW44");
  403. }
  404. /**********************************F16 direct************************/
  405. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  406. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
  407. NormalRNG rng(1);
  408. checker_conv_bias_f16(
  409. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  410. handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03);
  411. }
  412. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) {
  413. NormalRNG rng(1);
  414. checker_conv_bias_f16(
  415. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
  416. handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03);
  417. }
  418. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) {
  419. NormalRNG rng(1);
  420. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  421. handle(), rng, "F16STRD1_LARGE_GROUP", 0.03);
  422. }
  423. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) {
  424. NormalRNG rng(1);
  425. checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
  426. handle(), rng, "F16STRD1_SMALL_GROUP", 0.03);
  427. }
  428. #endif
  429. /**********************************algo 8816 direct************************/
  430. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) {
  431. checker_conv_bias_int8x8x16(
  432. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  433. "I8816DIRECT_LARGE_GROUP");
  434. }
  435. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) {
  436. checker_conv_bias_int8x8x16(
  437. get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
  438. "I8816DIRECT_SMALL_GROUP");
  439. }
  440. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) {
  441. checker_conv_bias_int8x8x16(
  442. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  443. "I8816STRD2_LARGE_GROUP");
  444. }
  445. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) {
  446. checker_conv_bias_int8x8x16(
  447. get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
  448. "I8816STRD2_SMALL_GROUP");
  449. }
  450. /**********************************algo 8-8-32 direct************************/
  451. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) {
  452. checker_conv_bias_int8x8x32_multi(
  453. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  454. "S8STRD1_LARGE_GROUP");
  455. }
  456. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) {
  457. checker_conv_bias_int8x8x32_multi(
  458. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  459. "S8STRD1_SMALL_GROUP");
  460. }
  461. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) {
  462. checker_conv_bias_int8x8x32_multi(
  463. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  464. "S8STRD2_LARGE_GROUP");
  465. }
  466. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
  467. checker_conv_bias_int8x8x32_multi(
  468. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  469. "S8STRD2_SMALL_GROUP");
  470. }
  471. TEST_F(ARM_COMMON_MULTI_THREADS,
  472. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
  473. checker_conv_bias_int8x8x32_multi(
  474. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
  475. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  476. }
  477. TEST_F(ARM_COMMON_MULTI_THREADS,
  478. CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
  479. checker_conv_bias_int8x8x32_multi(
  480. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
  481. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  482. }
  483. /********************************qint8 direct******************************/
  484. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) {
  485. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  486. {2, 3, 5, 7}, 1, false, false, false),
  487. handle(), "S8STRD1_LARGE_GROUP");
  488. }
  489. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) {
  490. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  491. {2, 3, 5, 7}, 1, false, false, false),
  492. handle(), "S8STRD1_SMALL_GROUP");
  493. }
  494. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) {
  495. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  496. {2, 3, 5, 7}, 2, false, false, false),
  497. handle(), "S8STRD2_LARGE_GROUP");
  498. }
  499. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
  500. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  501. {2, 3, 5, 7}, 2, false, false, false),
  502. handle(), "S8STRD2_SMALL_GROUP");
  503. }
  504. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
  505. checker_conv_bias_qint8x8x8(
  506. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
  507. handle(), "S8_NCHW44_DIRECT");
  508. }
  509. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
  510. checker_conv_bias_qint8x8x32(
  511. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true),
  512. handle(), "S8_NCHW44_DIRECT");
  513. }
  514. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
  515. checker_conv_bias_qint8x8x32(
  516. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),
  517. handle(), "S8_NCHW44_DIRECT");
  518. }
  519. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
  520. checker_conv_bias_qint8x8x8(
  521. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
  522. handle(), "S8_NCHW44_DIRECT");
  523. }
  524. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
  525. checker_conv_bias_qint8x8x8(
  526. get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
  527. handle(), "S8_CHAN_WISE_STRD1_NCHW44");
  528. }
  529. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
  530. checker_conv_bias_qint8x8x8(
  531. get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
  532. handle(), "S8_CHAN_WISE_STRD2_NCHW44");
  533. }
  534. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
  535. checker_conv_bias_qint8x8x8(
  536. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
  537. true),
  538. handle(), "S8_CONV_NCHW_NCHW44");
  539. }
  540. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
  541. checker_conv_bias_qint8x8x8(
  542. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
  543. true),
  544. handle(), "S8_CONV_NCHW_NCHW44");
  545. }
  546. /*****************************quint8 direct****************************/
  547. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) {
  548. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  549. {2, 3, 5, 7}, 1, false, false, false),
  550. handle(), "QU8STRD1_LARGE_GROUP");
  551. }
  552. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) {
  553. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  554. {2, 3, 5, 7}, 1, false, false, false),
  555. handle(), "QU8STRD1_SMALL_GROUP");
  556. }
  557. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) {
  558. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  559. {2, 3, 5, 7}, 2, false, false, false),
  560. handle(), "QU8STRD2_LARGE_GROUP");
  561. }
  562. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
  563. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  564. {2, 3, 5, 7}, 2, false, false, false),
  565. handle(), "QU8STRD2_SMALL_GROUP");
  566. }
  567. /****************************dot qint8 direct*************************/
  568. #if __ARM_FEATURE_DOTPROD
  569. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
  570. auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
  571. true);
  572. for (auto&& arg : args) {
  573. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  574. }
  575. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  576. args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
  577. true);
  578. for (auto&& arg : args) {
  579. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  580. }
  581. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
  582. }
  583. TEST_F(ARM_COMMON_MULTI_THREADS,
  584. CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
  585. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  586. {2, 3, 5, 7}, 1, false, false, false),
  587. handle(), "ARMDOTS8STRD1_LARGE_GROUP");
  588. }
  589. TEST_F(ARM_COMMON_MULTI_THREADS,
  590. CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
  591. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  592. {2, 3, 5, 7}, 1, false, false, false),
  593. handle(), "ARMDOTS8STRD1_SMALL_GROUP");
  594. }
  595. TEST_F(ARM_COMMON_MULTI_THREADS,
  596. CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
  597. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  598. {2, 3, 5, 7}, 2, false, false, false),
  599. handle(), "ARMDOTS8STRD2_LARGE_GROUP");
  600. }
  601. TEST_F(ARM_COMMON_MULTI_THREADS,
  602. CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
  603. checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
  604. {2, 3, 5, 7}, 2, false, false, false),
  605. handle(), "ARMDOTS8STRD2_SMALL_GROUP");
  606. }
  607. /****************************dot 8-8-32 direct*************************/
  608. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) {
  609. checker_conv_bias_qint8x8x32(
  610. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  611. "ARMDOTS8STRD1_LARGE_GROUP");
  612. }
  613. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) {
  614. checker_conv_bias_qint8x8x32(
  615. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  616. "ARMDOTS8STRD1_SMALL_GROUP");
  617. }
  618. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) {
  619. checker_conv_bias_qint8x8x32(
  620. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  621. "ARMDOTS8STRD2_LARGE_GROUP");
  622. }
  623. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) {
  624. checker_conv_bias_qint8x8x32(
  625. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  626. "ARMDOTS8STRD2_SMALL_GROUP");
  627. }
  628. /******************************dot quint8*****************************/
  629. TEST_F(ARM_COMMON_MULTI_THREADS,
  630. CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
  631. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  632. {2, 3, 5, 7}, 1, false, false, false),
  633. handle(), "ARMDOTU8STRD1_LARGE_GROUP");
  634. }
  635. TEST_F(ARM_COMMON_MULTI_THREADS,
  636. CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
  637. checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
  638. {2, 3, 5, 7}, 1, false, false, false),
  639. handle(), "ARMDOTU8STRD1_SMALL_GROUP");
  640. }
  641. TEST_F(ARM_COMMON_MULTI_THREADS,
  642. CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
  643. checker_conv_bias_quint8x8x8(
  644. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  645. handle(), "ARMDOTU8STRD2_LARGE_GROUP");
  646. }
  647. TEST_F(ARM_COMMON_MULTI_THREADS,
  648. CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
  649. checker_conv_bias_quint8x8x8(
  650. get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
  651. handle(), "ARMDOTU8STRD2_SMALL_GROUP");
  652. }
  653. /******************************dot quint8x8x32***********************/
  654. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) {
  655. checker_conv_bias_quint8x8x32(
  656. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  657. "ARMDOTU8STRD1_LARGE_GROUP");
  658. }
  659. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) {
  660. checker_conv_bias_quint8x8x32(
  661. get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
  662. "ARMDOTU8STRD1_SMALL_GROUP");
  663. }
  664. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) {
  665. checker_conv_bias_quint8x8x32(
  666. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  667. "ARMDOTU8STRD2_LARGE_GROUP");
  668. }
  669. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
  670. checker_conv_bias_quint8x8x32(
  671. get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
  672. "ARMDOTU8STRD2_SMALL_GROUP");
  673. }
  674. /******************************dot int8x8x8 nchw44 ***********************/
  675. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
  676. using namespace conv_bias;
  677. std::vector<TestArg> args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1);
  678. for (auto&& arg : args)
  679. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  680. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  681. }
  682. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
  683. using namespace conv_bias;
  684. std::vector<TestArg> args =
  685. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true);
  686. for (auto&& arg : args)
  687. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  688. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  689. }
  690. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
  691. using namespace conv_bias;
  692. std::vector<TestArg> args =
  693. get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true);
  694. for (auto&& arg : args)
  695. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  696. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  697. }
  698. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
  699. using namespace conv_bias;
  700. //! test qint8x8x8
  701. std::vector<TestArg> args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2);
  702. for (auto&& arg : args)
  703. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  704. checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  705. }
  706. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
  707. using namespace conv_bias;
  708. //! test qint8x8x8
  709. std::vector<TestArg> args =
  710. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true);
  711. for (auto&& arg : args)
  712. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  713. checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  714. }
  715. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
  716. using namespace conv_bias;
  717. //! test qint8x8x8
  718. std::vector<TestArg> args =
  719. get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true);
  720. for (auto&& arg : args)
  721. arg.param.format = param::ConvBias::Format::NCHW44_DOT;
  722. checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
  723. }
  724. #endif
  725. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
  726. using namespace conv_bias;
  727. std::vector<TestArg> args = get_winograd_mk_packed_args();
  728. Checker<ConvBiasForward> checker(handle());
  729. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
  730. }
  731. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
  732. using namespace conv_bias;
  733. std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
  734. Checker<ConvBiasForward> checker(handle());
  735. check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4,
  736. param::ConvBias::Format::NCHW44);
  737. }
  738. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
  739. using namespace conv_bias;
  740. std::vector<TestArg> args = get_winograd_args(3);
  741. Checker<ConvBiasForward> checker(handle());
  742. check_winograd("1:6:32", checker, args);
  743. }
  744. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
  745. using namespace conv_bias;
  746. std::vector<TestArg> args = get_winograd_mk_packed_args();
  747. Checker<ConvBiasForward> checker(handle());
  748. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
  749. }
  750. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
  751. using namespace conv_bias;
  752. std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
  753. Checker<ConvBiasForward> checker(handle());
  754. check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4,
  755. param::ConvBias::Format::NCHW44);
  756. }
  757. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
  758. using namespace conv_bias;
  759. std::vector<TestArg> args = get_winograd_args(4);
  760. Checker<ConvBiasForward> checker(handle());
  761. check_winograd("1:5:32", checker, args);
  762. }
  763. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
  764. using namespace conv_bias;
  765. std::vector<TestArg> args = get_winograd_args(5);
  766. Checker<ConvBiasForward> checker(handle());
  767. check_winograd("1:4:32", checker, args);
  768. }
  769. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
  770. using namespace conv_bias;
  771. std::vector<TestArg> args = get_winograd_args(3);
  772. Checker<ConvBiasForward> checker(handle());
  773. auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
  774. param::ConvBias param, Handle* handle) {
  775. megdnn_assert(param.format == param::ConvBias::Format::NCHW);
  776. auto winograd_preprocess_opr =
  777. handle->create_operator<WinogradFilterPreprocess>();
  778. winograd_preprocess_opr->param().output_block_size = m;
  779. TensorLayout filter_transform_layout;
  780. winograd_preprocess_opr->deduce_layout(tensors[1].layout,
  781. filter_transform_layout);
  782. size_t winograd_preprocess_workspace_in_bytes =
  783. winograd_preprocess_opr->get_workspace_in_bytes(
  784. tensors[1].layout, filter_transform_layout);
  785. auto conv_bias_opr = handle->create_operator<ConvBias>();
  786. conv_bias_opr->param() = param;
  787. conv_bias_opr->param().format = param::ConvBias::Format::NCHW_WINOGRAD;
  788. conv_bias_opr->param().output_block_size = m;
  789. size_t conv_bias_workspace_in_bytes =
  790. conv_bias_opr->get_workspace_in_bytes(
  791. tensors[0].layout, filter_transform_layout,
  792. tensors[2].layout, tensors[3].layout, tensors[4].layout,
  793. nullptr);
  794. WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
  795. conv_bias_workspace_in_bytes,
  796. winograd_preprocess_workspace_in_bytes});
  797. wb.set(malloc(wb.total_size_in_bytes()));
  798. TensorND filter_transform_tensor(wb.get(0),
  799. std::move(filter_transform_layout));
  800. winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
  801. wb.get_workspace(2));
  802. conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
  803. tensors[3], tensors[4], nullptr,
  804. wb.get_workspace(1));
  805. free(wb.ptr());
  806. };
  807. auto run = [&checker, &extra_impl](
  808. Handle* handle, const std::vector<TestArg>& args,
  809. const std::vector<size_t>& out_size, DType A_dtype,
  810. DType B_dtype, DType C_dtype, DType D_dtype,
  811. const float eps) {
  812. for (auto&& arg : args) {
  813. for (uint32_t m : out_size) {
  814. checker.set_extra_opr_impl(std::bind(extra_impl,
  815. std::placeholders::_1, m,
  816. arg.param, handle));
  817. checker.set_dtype(0, A_dtype)
  818. .set_dtype(1, B_dtype)
  819. .set_dtype(2, C_dtype)
  820. .set_dtype(4, D_dtype)
  821. .set_epsilon(eps)
  822. .set_param(arg.param)
  823. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  824. }
  825. }
  826. };
  827. run(handle(), args, {6}, dtype::Float32(), dtype::Float32(),
  828. dtype::Float32(), dtype::Float32(), 1e-3f);
  829. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  830. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  831. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  832. run(handle(), args, {6}, dtype::Float16(), dtype::Float16(),
  833. dtype::Float16(), dtype::Float16(), 0.35f);
  834. #endif
  835. }
  836. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) {
  837. using namespace conv_bias;
  838. std::vector<TestArg> nchw44_args = get_nchw44_conv_bias_args({3}, 1);
  839. Checker<ConvBiasForward> checker(handle());
  840. auto extra_impl = [](const TensorNDArray& tensors, uint32_t m,
  841. param::ConvBias param, Handle* handle) {
  842. megdnn_assert(param.format == param::ConvBias::Format::NCHW44);
  843. auto winograd_preprocess_opr =
  844. handle->create_operator<WinogradFilterPreprocess>();
  845. winograd_preprocess_opr->param().output_block_size = m;
  846. winograd_preprocess_opr->param().format = param::MatrixMul::Format::MK4;
  847. TensorLayout filter_transform_layout;
  848. winograd_preprocess_opr->deduce_layout(tensors[1].layout,
  849. filter_transform_layout);
  850. size_t winograd_preprocess_workspace_in_bytes =
  851. winograd_preprocess_opr->get_workspace_in_bytes(
  852. tensors[1].layout, filter_transform_layout);
  853. auto conv_bias_opr = handle->create_operator<ConvBias>();
  854. conv_bias_opr->param() = param;
  855. conv_bias_opr->param().format =
  856. param::ConvBias::Format::NCHW44_WINOGRAD;
  857. conv_bias_opr->param().output_block_size = m;
  858. size_t conv_bias_workspace_in_bytes =
  859. conv_bias_opr->get_workspace_in_bytes(
  860. tensors[0].layout, filter_transform_layout,
  861. tensors[2].layout, tensors[3].layout, tensors[4].layout,
  862. nullptr);
  863. WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
  864. conv_bias_workspace_in_bytes,
  865. winograd_preprocess_workspace_in_bytes});
  866. wb.set(malloc(wb.total_size_in_bytes()));
  867. TensorND filter_transform_tensor(wb.get(0),
  868. std::move(filter_transform_layout));
  869. winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor,
  870. wb.get_workspace(2));
  871. conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2],
  872. tensors[3], tensors[4], nullptr,
  873. wb.get_workspace(1));
  874. free(wb.ptr());
  875. };
  876. auto run = [&checker, &extra_impl](
  877. Handle* handle, const std::vector<TestArg>& args,
  878. const std::vector<size_t>& out_size, DType A_dtype,
  879. DType B_dtype, DType C_dtype, DType D_dtype,
  880. const float eps) {
  881. for (auto&& arg : args) {
  882. for (uint32_t m : out_size) {
  883. checker.set_extra_opr_impl(std::bind(extra_impl,
  884. std::placeholders::_1, m,
  885. arg.param, handle));
  886. checker.set_dtype(0, A_dtype)
  887. .set_dtype(1, B_dtype)
  888. .set_dtype(2, C_dtype)
  889. .set_dtype(4, D_dtype)
  890. .set_epsilon(eps)
  891. .set_param(arg.param)
  892. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  893. }
  894. }
  895. };
  896. run(handle(), nchw44_args, {2, 6}, dtype::Float32(), dtype::Float32(),
  897. dtype::Float32(), dtype::Float32(), 1e-3f);
  898. }
  899. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
  900. using namespace conv_bias;
  901. Checker<ConvBiasForward> checker(handle());
  902. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  903. const std::vector<size_t>& out_size, DType A_dtype,
  904. DType B_dtype, DType C_dtype, DType D_dtype,
  905. param::MatrixMul::Format format, float eps) {
  906. for (auto&& arg : args) {
  907. for (uint32_t m : out_size) {
  908. checker.set_extra_opr_impl(std::bind(
  909. winograd_algo_extra_impl, std::placeholders::_1, m,
  910. arg.param, handle, format));
  911. checker.set_dtype(0, A_dtype)
  912. .set_dtype(1, B_dtype)
  913. .set_dtype(2, C_dtype)
  914. .set_dtype(4, D_dtype)
  915. .set_epsilon(eps)
  916. .set_param(arg.param)
  917. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  918. }
  919. }
  920. };
  921. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  922. std::vector<TestArg> args_first_half(args.begin(),
  923. args.begin() + args.size() / 2);
  924. run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  925. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  926. 1e-3f);
  927. }
  928. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
  929. using namespace conv_bias;
  930. Checker<ConvBiasForward> checker(handle());
  931. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  932. const std::vector<size_t>& out_size, DType A_dtype,
  933. DType B_dtype, DType C_dtype, DType D_dtype,
  934. param::MatrixMul::Format format, float eps) {
  935. for (auto&& arg : args) {
  936. for (uint32_t m : out_size) {
  937. checker.set_extra_opr_impl(std::bind(
  938. winograd_algo_extra_impl, std::placeholders::_1, m,
  939. arg.param, handle, format));
  940. checker.set_dtype(0, A_dtype)
  941. .set_dtype(1, B_dtype)
  942. .set_dtype(2, C_dtype)
  943. .set_dtype(4, D_dtype)
  944. .set_epsilon(eps)
  945. .set_param(arg.param)
  946. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  947. }
  948. }
  949. };
  950. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  951. std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
  952. args.end());
  953. run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{},
  954. dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4,
  955. 1e-3f);
  956. }
  957. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  958. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
  959. using namespace conv_bias;
  960. Checker<ConvBiasForward> checker(handle());
  961. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  962. const std::vector<size_t>& out_size, DType A_dtype,
  963. DType B_dtype, DType C_dtype, DType D_dtype,
  964. param::MatrixMul::Format format, float eps) {
  965. for (auto&& arg : args) {
  966. for (uint32_t m : out_size) {
  967. checker.set_extra_opr_impl(std::bind(
  968. winograd_algo_extra_impl, std::placeholders::_1, m,
  969. arg.param, handle, format));
  970. checker.set_dtype(0, A_dtype)
  971. .set_dtype(1, B_dtype)
  972. .set_dtype(2, C_dtype)
  973. .set_dtype(4, D_dtype)
  974. .set_epsilon(eps)
  975. .set_param(arg.param)
  976. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  977. }
  978. }
  979. };
  980. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  981. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  982. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
  983. run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{},
  984. dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8,
  985. 0.25);
  986. }
  987. #endif
  988. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
  989. using namespace conv_bias;
  990. Checker<ConvBiasForward> checker(handle());
  991. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  992. const std::vector<size_t>& out_size, DType A_dtype,
  993. DType B_dtype, DType C_dtype, DType D_dtype,
  994. param::MatrixMul::Format format, float eps) {
  995. for (auto&& arg : args) {
  996. for (uint32_t m : out_size) {
  997. checker.set_extra_opr_impl(std::bind(
  998. winograd_algo_extra_impl, std::placeholders::_1, m,
  999. arg.param, handle, format));
  1000. checker.set_dtype(0, A_dtype)
  1001. .set_dtype(1, B_dtype)
  1002. .set_dtype(2, C_dtype)
  1003. .set_dtype(4, D_dtype)
  1004. .set_epsilon(eps)
  1005. .set_param(arg.param)
  1006. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1007. }
  1008. }
  1009. };
  1010. #if MEGDNN_AARCH64
  1011. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  1012. #else
  1013. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  1014. #endif
  1015. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1016. ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
  1017. std::vector<TestArg> quantized_args =
  1018. get_quantized_winograd_mk_packed_args(8);
  1019. UniformIntRNG int_rng{-50, 50};
  1020. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1021. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  1022. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  1023. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  1024. }
  1025. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
  1026. using namespace conv_bias;
  1027. Checker<ConvBiasForward> checker(handle());
  1028. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1029. const std::vector<size_t>& out_size, DType A_dtype,
  1030. DType B_dtype, DType C_dtype, DType D_dtype,
  1031. param::MatrixMul::Format format, float eps) {
  1032. for (auto&& arg : args) {
  1033. for (uint32_t m : out_size) {
  1034. checker.set_extra_opr_impl(std::bind(
  1035. winograd_algo_extra_impl, std::placeholders::_1, m,
  1036. arg.param, handle, format));
  1037. checker.set_dtype(0, A_dtype)
  1038. .set_dtype(1, B_dtype)
  1039. .set_dtype(2, C_dtype)
  1040. .set_dtype(4, D_dtype)
  1041. .set_epsilon(eps)
  1042. .set_param(arg.param)
  1043. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1044. }
  1045. }
  1046. };
  1047. #if MEGDNN_AARCH64
  1048. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  1049. #else
  1050. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  1051. #endif
  1052. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1053. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  1054. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
  1055. UniformIntRNG int_rng{-50, 50};
  1056. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1057. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  1058. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  1059. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  1060. }
  1061. TEST_F(ARM_COMMON_MULTI_THREADS,
  1062. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
  1063. using namespace conv_bias;
  1064. Checker<ConvBiasForward> checker(handle());
  1065. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1066. const std::vector<size_t>& out_size, DType A_dtype,
  1067. DType B_dtype, DType C_dtype, DType D_dtype,
  1068. param::MatrixMul::Format format, float eps) {
  1069. for (auto&& arg : args) {
  1070. for (uint32_t m : out_size) {
  1071. checker.set_extra_opr_impl(std::bind(
  1072. winograd_algo_extra_impl, std::placeholders::_1, m,
  1073. arg.param, handle, format));
  1074. checker.set_dtype(0, A_dtype)
  1075. .set_dtype(1, B_dtype)
  1076. .set_dtype(2, C_dtype)
  1077. .set_dtype(4, D_dtype)
  1078. .set_epsilon(eps)
  1079. .set_param(arg.param)
  1080. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1081. }
  1082. }
  1083. };
  1084. #if MEGDNN_AARCH64
  1085. const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
  1086. #else
  1087. const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
  1088. #endif
  1089. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1090. ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
  1091. std::vector<TestArg> quantized_args =
  1092. get_int8_nchw44_args(3, 4, false, true);
  1093. UniformIntRNG int_rng{-50, 50};
  1094. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1095. run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
  1096. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
  1097. dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
  1098. }
  1099. TEST_F(ARM_COMMON_MULTI_THREADS,
  1100. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
  1101. using namespace conv_bias;
  1102. Checker<ConvBiasForward> checker(handle());
  1103. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1104. const std::vector<size_t>& out_size, DType A_dtype,
  1105. DType B_dtype, DType C_dtype, DType D_dtype,
  1106. param::MatrixMul::Format format, float eps) {
  1107. for (auto&& arg : args) {
  1108. for (uint32_t m : out_size) {
  1109. checker.set_extra_opr_impl(std::bind(
  1110. winograd_algo_extra_impl, std::placeholders::_1, m,
  1111. arg.param, handle, format));
  1112. checker.set_dtype(0, A_dtype)
  1113. .set_dtype(1, B_dtype)
  1114. .set_dtype(2, C_dtype)
  1115. .set_dtype(4, D_dtype)
  1116. .set_epsilon(eps)
  1117. .set_param(arg.param)
  1118. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1119. }
  1120. }
  1121. };
  1122. float epsilon = 0.001;
  1123. #if MEGDNN_AARCH64
  1124. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  1125. #else
  1126. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  1127. #endif
  1128. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1129. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  1130. std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
  1131. UniformIntRNG int_rng{-50, 50};
  1132. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1133. run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
  1134. dtype::QuantizedS8(0.01887994f),
  1135. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  1136. dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
  1137. epsilon);
  1138. }
  1139. TEST_F(ARM_COMMON_MULTI_THREADS,
  1140. CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
  1141. using namespace conv_bias;
  1142. Checker<ConvBiasForward> checker(handle());
  1143. auto run = [&checker](Handle* handle, const std::vector<TestArg>& args,
  1144. const std::vector<size_t>& out_size, DType A_dtype,
  1145. DType B_dtype, DType C_dtype, DType D_dtype,
  1146. param::MatrixMul::Format format, float eps) {
  1147. for (auto&& arg : args) {
  1148. for (uint32_t m : out_size) {
  1149. checker.set_extra_opr_impl(std::bind(
  1150. winograd_algo_extra_impl, std::placeholders::_1, m,
  1151. arg.param, handle, format));
  1152. checker.set_dtype(0, A_dtype)
  1153. .set_dtype(1, B_dtype)
  1154. .set_dtype(2, C_dtype)
  1155. .set_dtype(4, D_dtype)
  1156. .set_epsilon(eps)
  1157. .set_param(arg.param)
  1158. .execs({arg.src, arg.filter, arg.bias, {}, {}});
  1159. }
  1160. }
  1161. };
  1162. float epsilon = 0.001;
  1163. #if MEGDNN_AARCH64
  1164. const char* matmul_name = "AARCH64_F32_MK4_4x16";
  1165. #else
  1166. const char* matmul_name = "ARMV7_F32_MK4_4x8";
  1167. #endif
  1168. checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
  1169. ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
  1170. std::vector<TestArg> quantized_args =
  1171. get_int8_nchw44_args(3, 4, true, true);
  1172. UniformIntRNG int_rng{-50, 50};
  1173. checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
  1174. run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
  1175. dtype::QuantizedS8(0.01887994f),
  1176. dtype::QuantizedS32(0.41113496f * 0.01887994f),
  1177. dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
  1178. epsilon);
  1179. }
  1180. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1181. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
  1182. using namespace conv_bias;
  1183. std::vector<TestArg> args = get_winograd_mk_packed_args();
  1184. Checker<ConvBiasForward> checker(handle());
  1185. check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
  1186. }
  1187. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
  1188. using namespace conv_bias;
  1189. std::vector<TestArg> args = get_winograd_args(5);
  1190. std::vector<TestArg> args_head_half(args.begin(),
  1191. args.begin() + args.size() / 2);
  1192. Checker<ConvBiasForward> checker(handle());
  1193. //! fp16 range -1.0 ~ 1.0
  1194. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1195. check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
  1196. }
  1197. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
  1198. using namespace conv_bias;
  1199. std::vector<TestArg> args = get_winograd_args(5);
  1200. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  1201. args.end());
  1202. Checker<ConvBiasForward> checker(handle());
  1203. //! fp16 range -1.0 ~ 1.0
  1204. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1205. check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
  1206. }
//! FIXME: This test may fail when run as part of
//! `ARM_COMMON.CONV_BIAS_WINOGRAD*`, but it passes when run as a single
//! testcase.
  1209. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
  1210. using namespace conv_bias;
  1211. std::vector<TestArg> args = get_winograd_args(3);
  1212. Checker<ConvBiasForward> checker(handle());
  1213. //! fp16 range -1.0 ~ 1.0
  1214. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1215. check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
  1216. }
  1217. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
  1218. using namespace conv_bias;
  1219. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  1220. std::vector<TestArg> args_head_half(args.begin(),
  1221. args.begin() + args.size() / 2);
  1222. Checker<ConvBiasForward> checker(handle());
  1223. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1224. check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
  1225. param::MatrixMul::Format::MK8);
  1226. }
  1227. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
  1228. using namespace conv_bias;
  1229. std::vector<TestArg> args = get_winograd_mk_packed_args(8);
  1230. std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
  1231. args.end());
  1232. Checker<ConvBiasForward> checker(handle());
  1233. Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
  1234. check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
  1235. param::MatrixMul::Format::MK8);
  1236. }
  1237. #endif
  1238. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
  1239. using namespace conv_bias;
  1240. std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
  1241. Checker<ConvBiasForward> checker(handle());
  1242. UniformIntRNG rng{-50, 50};
  1243. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1244. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1245. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1246. .set_dtype(4, dtype::QuantizedS8(60.25f))
  1247. .set_rng(0, &rng)
  1248. .set_rng(1, &rng)
  1249. .set_rng(2, &rng);
  1250. check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
  1251. }
  1252. void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
  1253. RNG* rng, float epsilon, DType type0, DType type1,
  1254. DType type2, DType type3, const char* algo_name) {
  1255. using namespace conv_bias;
  1256. Checker<ConvBias> checker(handle);
  1257. checker.set_before_exec_callback(
  1258. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  1259. checker.set_dtype(0, type0);
  1260. checker.set_dtype(1, type1);
  1261. checker.set_dtype(2, type2);
  1262. checker.set_dtype(4, type3);
  1263. checker.set_epsilon(epsilon);
  1264. if (NULL != rng) {
  1265. checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng);
  1266. }
  1267. for (auto&& arg : args) {
  1268. checker.set_param(arg.param).execs(
  1269. {arg.src, arg.filter, arg.bias, {}, {}});
  1270. }
  1271. }
  1272. // clang-format off
  1273. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2) {
  1274. #define cb(name) \
  1275. check_conv_bias( \
  1276. get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 2, false, false, false), \
  1277. handle(), name);
  1278. #if MEGDNN_AARCH64
  1279. cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
  1280. cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
  1281. cb("IM2COLMATMUL:FB_F32_K8X12X1")
  1282. #elif MEGDNN_ARMV7
  1283. cb("IM2COLMATMUL:ARMV7_F32")
  1284. #endif
  1285. #undef cb
  1286. }
  1287. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) {
  1288. #define cb(name) \
  1289. check_conv_bias( \
  1290. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false), \
  1291. handle(), name);
  1292. #if MEGDNN_AARCH64
  1293. cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
  1294. cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
  1295. cb("IM2COLMATMUL:FB_F32_K8X12X1")
  1296. #elif MEGDNN_ARMV7
  1297. cb("IM2COLMATMUL:ARMV7_F32")
  1298. cb("IM2COLMATMUL:FB_F32_K8X12X1")
  1299. #endif
  1300. #undef cb
  1301. }
  1302. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
  1303. UniformIntRNG rng{-50, 50};
  1304. #define cb(name) \
  1305. checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1306. false, true, true), \
  1307. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1308. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1309. dtype::QuantizedS8(60.25f), name); \
  1310. checker_conv_bias( \
  1311. get_conv_bias_args({1}, 2, false, false, false, true, true), \
  1312. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1313. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1314. dtype::QuantizedS8(60.25f), name);
  1315. float epsilon = 0.001;
  1316. #if MEGDNN_AARCH64
  1317. #if __ARM_FEATURE_DOTPROD
  1318. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
  1319. #else
  1320. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
  1321. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
  1322. #endif
  1323. #elif MEGDNN_ARMV7
  1324. epsilon = 1;
  1325. cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
  1326. #endif
  1327. #undef cb
  1328. }
  1329. #if __ARM_FEATURE_DOTPROD
  1330. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
  1331. UniformIntRNG rng{-50, 50};
  1332. #define cb(name) \
  1333. checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, \
  1334. false, false, false, true), \
  1335. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1336. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1337. dtype::QuantizedS8(60.25f), name); \
  1338. checker_conv_bias( \
  1339. get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true), \
  1340. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1341. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1342. dtype::QuantizedS8(60.25f), name);
  1343. float epsilon = 0.001;
  1344. #if MEGDNN_AARCH64
  1345. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1346. #elif MEGDNN_ARMV7
  1347. cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
  1348. #endif
  1349. #undef cb
  1350. }
  1351. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) {
  1352. UniformIntRNG rng{-50, 50};
  1353. #define cb(name) \
  1354. checker_conv_bias( \
  1355. get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1356. true, false, true, false, false, true), \
  1357. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1358. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
  1359. checker_conv_bias( \
  1360. get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
  1361. false, false, true), \
  1362. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1363. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);
  1364. float epsilon = 0.001;
  1365. #if MEGDNN_AARCH64
  1366. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1367. #elif MEGDNN_ARMV7
  1368. cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
  1369. #endif
  1370. #undef cb
  1371. }
  1372. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) {
  1373. UniformIntRNG rng{-50, 50};
  1374. #define cb(name) \
  1375. checker_conv_bias( \
  1376. get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1377. true, false, true, false, false, true), \
  1378. handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
  1379. dtype::Int32(), {}, name); \
  1380. checker_conv_bias( \
  1381. get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
  1382. false, false, true), \
  1383. handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
  1384. dtype::Int32(), {}, name);
  1385. float epsilon = 0.001;
  1386. #if MEGDNN_AARCH64
  1387. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1388. #elif MEGDNN_ARMV7
  1389. cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
  1390. #endif
  1391. #undef cb
  1392. }
  1393. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
  1394. UniformIntRNG rng{-50, 50};
  1395. #define cb(name) \
  1396. checker_conv_bias( \
  1397. get_nchw44_conv_bias_args({1}, 1, true, true, false, false, true), \
  1398. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1399. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1400. dtype::QuantizedS8(60.25f), name); \
  1401. checker_conv_bias( \
  1402. get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \
  1403. false, false, true), \
  1404. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1405. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
  1406. checker_conv_bias( \
  1407. get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \
  1408. false, false, true), \
  1409. handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
  1410. dtype::Int32(), {}, name);
  1411. float epsilon = 0.001;
  1412. #if MEGDNN_AARCH64
  1413. cb("CONV1x1:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD");
  1414. #elif MEGDNN_ARMV7
  1415. cb("CONV1x1:AARCH32_INT8_MK4_8X4X4_DOTPROD");
  1416. #endif
  1417. #undef cb
  1418. }
  1419. #endif
  1420. // clang-format on
  1421. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1422. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
  1423. NormalRNG rng(128.f);
  1424. #define cb(name) \
  1425. checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
  1426. false, true, true), \
  1427. handle(), &rng, epsilon, \
  1428. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1429. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1430. dtype::QuantizedS32(1.2 * 1.3), \
  1431. dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); \
  1432. checker_conv_bias( \
  1433. get_conv_bias_args({1}, 2, false, false, false, true, true), \
  1434. handle(), &rng, epsilon, \
  1435. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1436. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1437. dtype::QuantizedS32(1.2 * 1.3), \
  1438. dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
  1439. float epsilon = 0.001;
  1440. #if MEGDNN_AARCH64
  1441. #if __ARM_FEATURE_DOTPROD
  1442. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
  1443. #else
  1444. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
  1445. #endif
  1446. #elif MEGDNN_ARMV7
  1447. epsilon = 1;
  1448. cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
  1449. #endif
  1450. #undef cb
  1451. }
  1452. #endif
  1453. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1454. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
  1455. UniformIntRNG rng{-50, 50};
  1456. float epsilon = 0.001;
  1457. #define cb(name) \
  1458. checker_conv_bias( \
  1459. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
  1460. handle(), &rng, epsilon, \
  1461. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1462. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1463. dtype::QuantizedS32(1.2 * 1.3), {}, name); \
  1464. checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
  1465. &rng, epsilon, \
  1466. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1467. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1468. dtype::QuantizedS32(1.2 * 1.3), {}, name);
  1469. #if MEGDNN_AARCH64
  1470. #if __ARM_FEATURE_DOTPROD
  1471. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
  1472. #else
  1473. cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
  1474. #endif
  1475. #elif MEGDNN_ARMV7
  1476. #if __ARM_FEATURE_DOTPROD
  1477. cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
  1478. #endif
  1479. cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
  1480. #endif
  1481. #undef cb
  1482. }
  1483. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
  1484. UniformIntRNG rng{-50, 50};
  1485. float epsilon = 0.001;
  1486. #define cb(name) \
  1487. checker_conv_bias( \
  1488. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
  1489. handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
  1490. dtype::Int16{}, dtype::Int16{}, name); \
  1491. checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
  1492. &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
  1493. dtype::Int16{}, dtype::Int16{}, name);
  1494. #if MEGDNN_AARCH64
  1495. cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
  1496. cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
  1497. cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
  1498. #elif MEGDNN_ARMV7
  1499. cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
  1500. cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
  1501. cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
  1502. #endif
  1503. #undef cb
  1504. }
  1505. #endif
  1506. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1507. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
  1508. using namespace conv_bias;
  1509. param::ConvBias cur_param;
  1510. std::vector<conv_bias::TestArg> args =
  1511. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false);
  1512. std::vector<conv_bias::TestArg> args1 =
  1513. get_conv_bias_args({1}, 2, false, false, false);
  1514. args.insert(args.begin(), args1.begin(), args1.end());
  1515. NormalRNG rng(1);
  1516. #define cb(name) \
  1517. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{}, \
  1518. dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, \
  1519. name);
  1520. #if MEGDNN_AARCH64
  1521. cb("IM2COLMATMUL:AARCH64_F16_K8X24X1");
  1522. #elif MEGDNN_ARMV7
  1523. cb("IM2COLMATMUL:AARCH32_F16_K4X16X1");
  1524. #endif
  1525. #undef cb
  1526. }
  1527. #endif
  1528. void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
  1529. Handle* handle, const char* algo_name) {
  1530. using namespace conv_bias;
  1531. Checker<ConvBias> checker(handle);
  1532. checker.set_before_exec_callback(
  1533. conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
  1534. checker.set_dtype(0, dtype::Int8());
  1535. checker.set_dtype(1, dtype::Int8());
  1536. checker.set_dtype(2, dtype::Int32());
  1537. checker.set_dtype(4, dtype::Int32());
  1538. for (auto&& arg : args) {
  1539. checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
  1540. }
  1541. UniformIntRNG rng{-50, 50};
  1542. for (auto&& arg : args) {
  1543. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  1544. .set_dtype(1, dtype::QuantizedS8(2.5f))
  1545. .set_dtype(2, dtype::QuantizedS32(6.25f))
  1546. .set_dtype(4, {})
  1547. .set_rng(0, &rng)
  1548. .set_rng(1, &rng)
  1549. .set_rng(2, &rng)
  1550. .set_param(arg.param)
  1551. .execs({arg.src, arg.filter, {}, {}, {}});
  1552. }
  1553. }
  1554. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1555. #if !__ARM_FEATURE_DOTPROD
  1556. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
  1557. using namespace conv_bias;
  1558. std::vector<conv_bias::TestArg> args =
  1559. get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true);
  1560. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1561. #if MEGDNN_AARCH64
  1562. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1563. #else
  1564. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1565. #endif
  1566. #undef cb
  1567. }
  1568. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) {
  1569. using namespace conv_bias;
  1570. std::vector<conv_bias::TestArg> args =
  1571. get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true);
  1572. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1573. #if MEGDNN_AARCH64
  1574. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1575. #else
  1576. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1577. #endif
  1578. #undef cb
  1579. }
  1580. TEST_F(ARM_COMMON_MULTI_THREADS,
  1581. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S2) {
  1582. UniformIntRNG rng{-50, 50};
  1583. #define cb(name) \
  1584. checker_conv_bias(get_nchw44_conv_bias_args({3, 4, 6}, 2), handle(), &rng, \
  1585. epsilon, dtype::QuantizedS8(2.5f), \
  1586. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1587. dtype::QuantizedS8(60.25f), name);
  1588. float epsilon = 0.001;
  1589. #if MEGDNN_AARCH64
  1590. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1591. #else
  1592. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1593. #endif
  1594. #undef cb
  1595. }
  1596. TEST_F(ARM_COMMON_MULTI_THREADS,
  1597. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) {
  1598. UniformIntRNG rng{-50, 50};
  1599. #define cb(name) \
  1600. checker_conv_bias(get_nchw44_conv_bias_args({2, 5, 7}, 1), handle(), &rng, \
  1601. epsilon, dtype::QuantizedS8(2.5f), \
  1602. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1603. dtype::QuantizedS8(60.25f), name);
  1604. float epsilon = 0.001;
  1605. #if MEGDNN_AARCH64
  1606. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1607. #else
  1608. cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96");
  1609. #endif
  1610. #undef cb
  1611. }
  1612. #if MEGDNN_AARCH64
  1613. TEST_F(ARM_COMMON_MULTI_THREADS,
  1614. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
  1615. UniformIntRNG rng{-50, 50};
  1616. #define cb(name) \
  1617. checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng, \
  1618. epsilon, dtype::QuantizedS8(2.5f), \
  1619. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1620. dtype::QuantizedS8(60.25f), name);
  1621. float epsilon = 0.001;
  1622. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
  1623. #undef cb
  1624. }
  1625. #endif
  1626. #endif
  1627. #endif
  1628. #if MEGDNN_AARCH64
  1629. #if __ARM_FEATURE_DOTPROD
  1630. TEST_F(ARM_COMMON_MULTI_THREADS,
  1631. CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) {
  1632. UniformIntRNG rng{-50, 50};
  1633. #define cb(name) \
  1634. checker_conv_bias( \
  1635. get_nchw44_conv_bias_args({3}, 1, false, false, false, false, \
  1636. true, false, false, false), \
  1637. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1638. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1639. dtype::QuantizedS8(60.25f), name);
  1640. float epsilon = 0.001;
  1641. cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
  1642. #undef cb
  1643. }
  1644. #endif
  1645. #endif
  1646. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
  1647. using namespace conv_bias;
  1648. std::vector<conv_bias::TestArg> args =
  1649. get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
  1650. std::vector<conv_bias::TestArg> args1 =
  1651. get_conv_bias_args({1}, 2, false, true, true);
  1652. args.insert(args.begin(), args1.begin(), args1.end());
  1653. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1654. #if MEGDNN_AARCH64
  1655. #if __ARM_FEATURE_DOTPROD
  1656. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
  1657. #else
  1658. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
  1659. cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
  1660. #endif
  1661. #elif MEGDNN_ARMV7
  1662. #if __ARM_FEATURE_DOTPROD
  1663. cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
  1664. #endif
  1665. cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
  1666. #endif
  1667. #if MEGDNN_ARMV7
  1668. cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X2X16");
  1669. #endif
  1670. #undef cb
  1671. }
  1672. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
  1673. using namespace conv_bias;
  1674. std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
  1675. {2, 4, 7}, 1, false, false, false, false, false, true, true);
  1676. #if MEGDNN_AARCH64
  1677. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1678. #elif MEGDNN_ARMV7
  1679. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1680. #endif
  1681. }
  1682. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
  1683. using namespace conv_bias;
  1684. std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
  1685. {3, 5, 6}, 2, false, false, false, false, false, true, true);
  1686. #if MEGDNN_AARCH64
  1687. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1688. #elif MEGDNN_ARMV7
  1689. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1690. #endif
  1691. }
  1692. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE) {
  1693. using namespace conv_bias;
  1694. std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
  1695. {3}, 2, false, false, false, false, false, true, true, false);
  1696. #if MEGDNN_AARCH64
  1697. check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
  1698. #elif MEGDNN_ARMV7
  1699. check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
  1700. #endif
  1701. }
  1702. /***************************** Conv1x1 Algo Test ***********************/
  1703. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
  1704. using namespace conv_bias;
  1705. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
  1706. #if MEGDNN_AARCH64
  1707. check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32K8X12X1:24");
  1708. #elif MEGDNN_ARMV7
  1709. check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48");
  1710. #endif
  1711. std::vector<conv_bias::TestArg> gemv_args;
  1712. for (auto&& arg : args)
  1713. if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1714. gemv_args.emplace_back(arg);
  1715. }
  1716. check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
  1717. }
  1718. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
  1719. using namespace conv_bias;
  1720. std::vector<conv_bias::TestArg> args =
  1721. get_nchw44_conv_bias_args({1}, 1, true, false, false);
  1722. #if MEGDNN_AARCH64
  1723. check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
  1724. #elif MEGDNN_ARMV7
  1725. check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
  1726. #endif
  1727. }
  1728. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
  1729. using namespace conv_bias;
  1730. std::vector<conv_bias::TestArg> args =
  1731. get_nchw44_conv_bias_args({1}, 1, true, false, false);
  1732. std::vector<conv_bias::TestArg> args_of_4;
  1733. for (auto&& arg : args) {
  1734. if (arg.src.shape[2] * arg.src.shape[3] % 4 == 0) {
  1735. args_of_4.push_back(arg);
  1736. }
  1737. }
  1738. #if MEGDNN_AARCH64
  1739. check_conv_bias(args_of_4, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24");
  1740. #elif MEGDNN_ARMV7
  1741. check_conv_bias(args_of_4, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48");
  1742. #endif
  1743. }
  1744. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  1745. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
  1746. using namespace conv_bias;
  1747. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
  1748. NormalRNG rng(1);
  1749. #if MEGDNN_AARCH64
  1750. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
  1751. dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  1752. "CONV1x1:AARCH64_F16_K8X24X1:48");
  1753. #elif MEGDNN_ARMV7
  1754. checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{},
  1755. dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
  1756. "CONV1x1:AARCH32_F16_K4X16X1:24");
  1757. #endif
  1758. std::vector<conv_bias::TestArg> gemv_args;
  1759. for (auto&& arg : args)
  1760. if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1761. gemv_args.emplace_back(arg);
  1762. }
  1763. check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
  1764. }
  1765. #endif
  1766. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
  1767. UniformIntRNG rng{-50, 50};
  1768. float epsilon = 0.001;
  1769. std::vector<conv_bias::TestArg> args =
  1770. get_conv_bias_1x1_args(false, false, true, true);
  1771. #define cb(name) \
  1772. checker_conv_bias(args, handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1773. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1774. dtype::QuantizedS8(60.25f), name);
  1775. #if MEGDNN_AARCH64
  1776. #if __ARM_FEATURE_DOTPROD
  1777. cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
  1778. #else
  1779. cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
  1780. cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
  1781. #endif
  1782. #elif MEGDNN_ARMV7
  1783. epsilon = 1;
  1784. cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
  1785. #endif
  1786. #undef cb
  1787. std::vector<conv_bias::TestArg> gemv_args;
  1788. for (auto&& arg : args)
  1789. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1790. gemv_args.emplace_back(arg);
  1791. }
  1792. checker_conv_bias(gemv_args, handle(), &rng, epsilon,
  1793. dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
  1794. dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
  1795. "CONV1x1_GEMV");
  1796. }
  1797. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1798. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
  1799. UniformIntRNG rng{-50, 50};
  1800. std::vector<conv_bias::TestArg> args =
  1801. get_conv_bias_1x1_args(false, false, true, true);
  1802. #define cb(name) \
  1803. checker_conv_bias(args, handle(), &rng, epsilon, \
  1804. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1805. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1806. dtype::QuantizedS32(1.2 * 1.3), \
  1807. dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
  1808. float epsilon = 0.001;
  1809. #if MEGDNN_AARCH64
  1810. #if __ARM_FEATURE_DOTPROD
  1811. cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
  1812. #else
  1813. cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
  1814. #endif
  1815. #elif MEGDNN_ARMV7
  1816. epsilon = 1;
  1817. cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
  1818. #endif
  1819. #undef cb
  1820. std::vector<conv_bias::TestArg> gemv_args;
  1821. for (auto&& arg : args)
  1822. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1823. gemv_args.emplace_back(arg);
  1824. }
  1825. checker_conv_bias(gemv_args, handle(), &rng, epsilon,
  1826. dtype::Quantized8Asymm(1.2f, (uint8_t)125),
  1827. dtype::Quantized8Asymm(1.3f, (uint8_t)129),
  1828. dtype::QuantizedS32(1.2 * 1.3),
  1829. dtype::Quantized8Asymm(50.3f, (uint8_t)120),
  1830. "CONV1x1_GEMV");
  1831. }
  1832. #endif
  1833. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  1834. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
  1835. NormalRNG rng(128.f);
  1836. float epsilon = 0.001;
  1837. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
  1838. #define cb(name) \
  1839. checker_conv_bias(args, handle(), &rng, epsilon, \
  1840. dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
  1841. dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
  1842. dtype::QuantizedS32(1.2 * 1.3), {}, name);
  1843. #if MEGDNN_AARCH64
  1844. #if __ARM_FEATURE_DOTPROD
  1845. cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
  1846. #else
  1847. cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
  1848. #endif
  1849. #elif MEGDNN_ARMV7
  1850. #if __ARM_FEATURE_DOTPROD
  1851. cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
  1852. #endif
  1853. cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
  1854. #endif
  1855. #undef cb
  1856. std::vector<conv_bias::TestArg> gemv_args;
  1857. for (auto&& arg : args)
  1858. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1859. gemv_args.emplace_back(arg);
  1860. }
  1861. checker_conv_bias(gemv_args, handle(), &rng, epsilon,
  1862. dtype::Quantized8Asymm(1.2f, (uint8_t)125),
  1863. dtype::Quantized8Asymm(1.3f, (uint8_t)129),
  1864. dtype::QuantizedS32(1.2 * 1.3), {}, "CONV1x1_GEMV");
  1865. }
  1866. TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
  1867. UniformIntRNG rng{-50, 50};
  1868. float epsilon = 0.001;
  1869. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
  1870. #define cb(name) \
  1871. checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \
  1872. dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
  1873. #if MEGDNN_AARCH64
  1874. cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
  1875. cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
  1876. #elif MEGDNN_ARMV7
  1877. cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
  1878. cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
  1879. #endif
  1880. cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
  1881. #undef cb
  1882. std::vector<conv_bias::TestArg> gemv_args;
  1883. for (auto&& arg : args)
  1884. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1885. gemv_args.emplace_back(arg);
  1886. }
  1887. checker_conv_bias(gemv_args, handle(), &rng, epsilon, dtype::Int8{},
  1888. dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
  1889. "CONV1x1_GEMV");
  1890. }
  1891. #endif
  1892. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
  1893. using namespace conv_bias;
  1894. std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
  1895. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1896. #if MEGDNN_AARCH64
  1897. #if __ARM_FEATURE_DOTPROD
  1898. cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
  1899. #else
  1900. cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
  1901. cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
  1902. #endif
  1903. #elif MEGDNN_ARMV7
  1904. #if __ARM_FEATURE_DOTPROD
  1905. cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
  1906. #endif
  1907. cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
  1908. #endif
  1909. #if MEGDNN_ARMV7
  1910. cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
  1911. #endif
  1912. #undef cb
  1913. std::vector<conv_bias::TestArg> gemv_args;
  1914. for (auto&& arg : args)
  1915. if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
  1916. gemv_args.emplace_back(arg);
  1917. }
  1918. checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
  1919. }
  1920. #ifndef __ARM_FEATURE_DOTPROD
  1921. TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
  1922. using namespace conv_bias;
  1923. std::vector<conv_bias::TestArg> args =
  1924. get_nchw44_conv_bias_args({1}, 1, true, true, true);
  1925. #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
  1926. #if MEGDNN_AARCH64
  1927. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1928. #elif MEGDNN_ARMV7
  1929. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1930. #endif
  1931. #undef cb
  1932. UniformIntRNG rng{-50, 50};
  1933. float epsilon = 0.001;
  1934. #define cb(name) \
  1935. checker_conv_bias(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
  1936. handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
  1937. dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
  1938. dtype::QuantizedS8(60.25f), name);
  1939. #if MEGDNN_AARCH64
  1940. cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
  1941. #elif MEGDNN_ARMV7
  1942. cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
  1943. #endif
  1944. #undef cb
  1945. }
  1946. #endif
  1947. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。