You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

algos.cpp 63 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416
  1. #include "src/aarch64/matrix_mul/algos.h"
  2. #include "src/aarch64/matrix_mul/fp16/strategy.h"
  3. #include "src/aarch64/matrix_mul/fp32/strategy.h"
  4. #include "src/aarch64/matrix_mul/int16/strategy.h"
  5. #include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
  6. #include "src/aarch64/matrix_mul/int8/strategy.h"
  7. #include "src/aarch64/matrix_mul/int8_dot/strategy.h"
  8. #include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
  9. #include "src/aarch64/matrix_mul/quint8/strategy.h"
  10. #include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
  11. #include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
  12. #include "src/common/utils.h"
  13. #include "src/fallback/matrix_mul/gemm_impl.h"
  14. #include "midout.h"
// midout instrumentation domain shared by every matmul kernel in this file.
MIDOUT_DECL(megdnn_aarch64_matmul_kern)
using namespace megdnn;
using namespace aarch64;
  18. /* ===================== F32K8X12X1 algo ===================== */
  19. bool MatrixMulImpl::AlgoF32K8x12x1::usable(const KernSizeParam& kern_size_param) const {
  20. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  21. kern_size_param.B_type == kern_size_param.A_type &&
  22. kern_size_param.C_type == kern_size_param.A_type &&
  23. kern_size_param.A_type == dtype::Float32() &&
  24. kern_size_param.format == param::MatrixMul::Format::DEFAULT;
  25. }
// Workspace query: size of the packing buffers the interleaved f32 GEMM
// (8x12 strategy) needs for this (M, N, K) problem.
size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
  43. MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
  44. const KernSizeParam&) const {
  45. auto f32_kern_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
  46. MIDOUT_BEGIN(
  47. megdnn_aarch64_matmul_kern,
  48. midout_iv("AlgoF32K8x12x1::get_kern"_hash)) {
  49. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  50. auto trA = kern_param.trA, trB = kern_param.trB;
  51. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  52. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  53. C_type = kern_param.C_type;
  54. const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
  55. auto Cptr = kern_param.C<float>();
  56. aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, C_type);
  57. megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_8x12>(
  58. M, N, K, trA, trB, strategy)
  59. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  60. }
  61. MIDOUT_END();
  62. };
  63. return f32_kern_8x12;
  64. }
// Expose the 8x12 f32 strategy to conv-bias im2col (DEFAULT format).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, "AlgoF32K8x12x1Impl"_hash,
        aarch64::matmul::sgemm_8x12, float, float, AlgoDataType::FLOAT32, DEFAULT);
  68. /* ===================== F32_MK4_8X12X1 algo ===================== */
  69. bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable(
  70. const KernSizeParam& kern_size_param) const {
  71. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  72. kern_size_param.B_type == kern_size_param.A_type &&
  73. kern_size_param.C_type == kern_size_param.A_type &&
  74. kern_size_param.A_type == dtype::Float32() &&
  75. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  76. !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 &&
  77. kern_size_param.K % 4 == 0;
  78. }
// Workspace for the MK4 f32 interleaved GEMM (packing buffers).
size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Kernel closure for the MK4 f32 GEMM; reads everything from KernParam.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
        const KernSizeParam&) const {
    auto f32_kern_mk4_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(
                megdnn_aarch64_matmul_kern,
                midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_mk4_8x12;
}
// im2col registration for the MK4 f32 strategy (format matches usable()).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoF32MK4_8x12x1, megdnn_aarch64_matmul_kern, "AlgoF32MK4_8x12x1Impl"_hash,
        aarch64::matmul::sgemm_mk4_8x12, float, float, AlgoDataType::FLOAT32, MK4);
  121. /* ===================== F32K4X16X1 algo ===================== */
  122. bool MatrixMulImpl::AlgoF32K4x16x1::usable(const KernSizeParam& kern_size_param) const {
  123. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  124. kern_size_param.B_type == kern_size_param.A_type &&
  125. kern_size_param.C_type == kern_size_param.A_type &&
  126. kern_size_param.A_type == dtype::Float32() &&
  127. kern_size_param.format == param::MatrixMul::Format::DEFAULT;
  128. }
// Workspace for the 4x16 f32 interleaved GEMM (packing buffers).
size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoF32K4x16x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_4x16>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Kernel closure for the 4x16 f32 GEMM; reads everything from KernParam.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
        const KernSizeParam&) const {
    auto f32_kern_4x16 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(
                megdnn_aarch64_matmul_kern,
                midout_iv("AlgoF32K4x16x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_4x16>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_4x16;
}
  168. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
  169. AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, "AlgoF32K4x16x1Impl"_hash,
  170. aarch64::matmul::sgemm_4x16, float, float, AlgoDataType::FLOAT32, MK4);
  171. /* ===================== F32MK4_4x16 algo ===================== */
  172. bool MatrixMulImpl::AlgoF32MK4_4x16::usable(
  173. const KernSizeParam& kern_size_param) const {
  174. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  175. kern_size_param.C_type == dtype::Float32() &&
  176. kern_size_param.B_type == dtype::Float32() &&
  177. kern_size_param.A_type == dtype::Float32() &&
  178. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  179. !kern_size_param.trA && !kern_size_param.trB;
  180. }
// Workspace for the no-pack 4x16 f32 strategy; note the second template
// argument `false` selects the non-packing GemmInterleaved variant.
size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoF32MK4_4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::sgemm_nopack_4x16, false>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Kernel closure for the no-pack MK4 4x16 f32 strategy.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern(
        const KernSizeParam&) const {
    auto f32_kern_mk4_4x16 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(
                megdnn_aarch64_matmul_kern,
                midout_iv("AlgoF32MK4_4x16::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_nopack_4x16, false>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_mk4_4x16;
}
  221. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  222. /* ===================== F16 K8x24x1 algo ===================== */
namespace {
// Shared fp16 kernel body: runs the interleaved hgemm_8x24 GEMM with all
// shapes, strides and pointers taken from the KernParam.
void f16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("f16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_float16>(), Bptr = kern_param.B<dt_float16>();
        auto Cptr = kern_param.C<dt_float16>();
        aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_8x24>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  241. bool MatrixMulImpl::AlgoF16K8x24x1::usable(const KernSizeParam& kern_size_param) const {
  242. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  243. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  244. kern_size_param.C_type == kern_size_param.A_type &&
  245. kern_size_param.B_type == kern_size_param.A_type &&
  246. kern_size_param.A_type == dtype::Float16();
  247. }
// Workspace for the 8x24 fp16 interleaved GEMM (packing buffers).
size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoF16K8x24x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_8x24>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// The fp16 kernel needs no per-size state, so return the shared free function.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
        const KernSizeParam&) const {
    return f16_kern;
}
// im2col registration for the fp16 8x24 strategy.
// NOTE(review): "Alog" (sic) in the tag below is a long-standing typo; it is
// kept byte-identical because the string feeds the registered impl's hash id.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, "AlogF16K8x24x1Impl"_hash,
        aarch64::matmul::hgemm_8x24, dt_float16, dt_float16, AlgoDataType::FLOAT16,
        DEFAULT);
  273. /* ===================== F16_MK8_8x8 algo ===================== */
  274. bool MatrixMulImpl::AlgoF16MK8_8x8::usable(const KernSizeParam& kern_size_param) const {
  275. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  276. kern_size_param.C_type == kern_size_param.A_type &&
  277. kern_size_param.B_type == kern_size_param.A_type &&
  278. kern_size_param.A_type == dtype::Float16() &&
  279. kern_size_param.format == param::MatrixMul::Format::MK8 &&
  280. !kern_size_param.trA && !kern_size_param.trB;
  281. }
// Workspace for the no-pack fp16 8x8 strategy (`false` = non-packing
// GemmInterleaved variant).
size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoF16MK8_8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_nopack_f16_8x8, false>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Kernel closure for the no-pack MK8 fp16 8x8 strategy.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
        const KernSizeParam&) const {
    auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(
                megdnn_aarch64_matmul_kern,
                midout_iv("AlgoF16MK8_8x8::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<dt_float16>(),
                       Bptr = kern_param.B<dt_float16>();
            auto Cptr = kern_param.C<dt_float16>();
            aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<
                    aarch64::matmul::gemm_nopack_f16_8x8, false>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return kern_mk8_8x8;
}
  324. /* ==================== F16_MK8_16x12x1 algo ====================*/
  325. bool MatrixMulImpl::AlgoF16MK8_16x12x1::usable(
  326. const KernSizeParam& kern_size_param) const {
  327. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  328. kern_size_param.C_type == kern_size_param.A_type &&
  329. kern_size_param.B_type == kern_size_param.A_type &&
  330. kern_size_param.A_type == dtype::Float16() &&
  331. kern_size_param.format == param::MatrixMul::Format::MK8 &&
  332. !kern_size_param.trA && !kern_size_param.trB;
  333. }
// Workspace for the MK8 fp16 16x12 interleaved GEMM (packing buffers).
size_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoF16MK8_16x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::hgemm_mk8_16x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_mk8_16x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Kernel closure for the MK8 fp16 16x12 strategy.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_kern(
        const KernSizeParam&) const {
    auto kern_mk8_16x12x1 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(
                megdnn_aarch64_matmul_kern,
                midout_iv("AlgoF16MK8_16x12x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<dt_float16>(),
                       Bptr = kern_param.B<dt_float16>();
            auto Cptr = kern_param.C<dt_float16>();
            aarch64::matmul::hgemm_mk8_16x12 strategy(M, N, K, A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_mk8_16x12>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return kern_mk8_16x12x1;
}
  374. #endif
  375. #if MGB_ENABLE_DOT
  376. /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */
namespace {
// int8 -> int32 kernel using the dot-product 8x12 strategy (gated on
// MGB_ENABLE_DOT at compile time and cpuinfo_has_arm_neon_dot() in usable()).
void int8x8x32_k8x12x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("int8x8x32_k8x12x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x12>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  397. bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable(
  398. const KernSizeParam& kern_size_param) const {
  399. if (!cpuinfo_has_arm_neon_dot()) {
  400. return false;
  401. }
  402. return can_be_treated_as_int8x8x32(kern_size_param);
  403. }
// Workspace for the dot-product s8 8x12 interleaved GEMM (packing buffers).
size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoInt8x8x32K8x12x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Stateless kernel: return the shared anonymous-namespace function.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_k8x12x4_dotprod_kern;
}
// im2col registration for the dot-product s8 8x12 strategy (DEFAULT format).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x32K8x12x4DotProd, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, aarch64::matmul::gemm_s8_8x12, int8_t,
        int32_t, AlgoDataType::QINT8X8X32, DEFAULT);
  429. /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace {
// int8 -> int32 dot-product kernel for the MK4_DOT blocked layout.
void int8x8x32_mk4_8x12x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_8x12>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  450. bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable(
  451. const KernSizeParam& kern_size_param) const {
  452. if (!cpuinfo_has_arm_neon_dot()) {
  453. return false;
  454. }
  455. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  456. (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
  457. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  458. (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
  459. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
  460. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  461. kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
  462. !kern_size_param.trA && !kern_size_param.trB;
  463. }
// Workspace for the MK4_DOT s8 8x12 interleaved GEMM (packing buffers).
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Stateless kernel: return the shared anonymous-namespace function.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_mk4_8x12x4_dotprod_kern;
}
// im2col registration for the MK4_DOT s8 strategy (format matches usable()).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x32MK4_8x12x4DotProd, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, aarch64::matmul::gemm_mk4_s8_8x12,
        int8_t, int32_t, AlgoDataType::QINT8X8X32, MK4_DOT);
  489. #endif
  490. /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
namespace {
// int8 -> int32 kernel for the MK4 blocked layout, 4x4 tile strategy
// (no dot-product requirement).
void int8x8x32_mk4_4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern, midout_iv("int8x8x32_mk4_4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  510. bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::usable(const KernSizeParam& param) const {
  511. return param.A_type.enumv() == param.B_type.enumv() &&
  512. (param.A_type.enumv() == DTypeEnum::Int8 ||
  513. param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  514. (param.C_type.enumv() == DTypeEnum::Int32 ||
  515. param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
  516. param.compute_mode == Param::ComputeMode::DEFAULT &&
  517. param.format == param::MatrixMul::Format::MK4 && !param.trA && !param.trB &&
  518. param.M % 4 == 0 && param.K % 4 == 0;
  519. }
  520. bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::preferred(
  521. const KernSizeParam& kern_size_param) const {
  522. return kern_size_param.K > 16;
  523. }
// Workspace for the MK4 s8 4x4 interleaved GEMM (packing buffers).
size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoInt8x8x32MK4_4x4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_4x4>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the instrumented body out.
    return 0;
}
// Stateless kernel: return the shared anonymous-namespace function.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_mk4_4x4x16_kern;
}
// im2col registration for the MK4 s8 4x4 strategy (format matches usable()).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x32MK4_4x4x16, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x32MK4_4x4x16Impl"_hash, aarch64::matmul::gemm_mk4_s8_4x4, int8_t,
        int32_t, AlgoDataType::QINT8X8X32, MK4);
  549. /* ===================== Int8x8x32 K4x4x16 algo ===================== */
namespace {
// int8 -> int32 kernel in the default layout, 4x4 tile strategy.
void int8x8x32_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x32_k4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
// Usable for any problem that the shared int8x8x32 dtype/format gate accepts.
bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x32(kern_size_param);
}
// Heuristic: 4x4x16 pays off only for deep-K problems.
bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::preferred(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.K > 16;
}
// Workspace size as reported by GemmInterleaved for the gemm_s8_4x4 strategy.
size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoInt8x8x32K4x4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_4x4>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when the midout region above is compiled out.
    return 0;
}
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_k4x4x16_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x32K4x4x16, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x32K4x4x16Impl"_hash, aarch64::matmul::gemm_s8_4x4, int8_t, int32_t,
        AlgoDataType::QINT8X8X32, DEFAULT);
  601. /* ===================== Int8x8x32 K8x8x8 algo ===================== */
  602. namespace {
  603. void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
  604. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x32_k8x8x8_kern"_hash)) {
  605. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  606. auto trA = kern_param.trA, trB = kern_param.trB;
  607. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  608. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  609. C_type = kern_param.C_type;
  610. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  611. auto Cptr = kern_param.C<dt_int32>();
  612. aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  613. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x8>(
  614. M, N, K, trA, trB, strategy)
  615. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  616. }
  617. MIDOUT_END();
  618. }
  619. } // anonymous namespace
// Usable for any problem that the shared int8x8x32 dtype/format gate accepts.
bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x32(kern_size_param);
}
// Complements AlgoInt8x8x32K4x4x16::preferred (K > 16): 8x8x8 wins on
// shallow-K problems.
bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::preferred(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.K <= 16;
}
  628. size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace(
  629. const KernSizeParam& kern_size_param) const {
  630. MIDOUT_BEGIN(
  631. megdnn_aarch64_matmul_kern,
  632. midout_iv("AlgoInt8x8x32K8x8x8::get_workspace"_hash)) {
  633. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  634. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  635. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  636. C_type = kern_size_param.C_type;
  637. aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  638. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8_8x8>(
  639. M, N, K, trA, trB, strategy)
  640. .get_workspace_size();
  641. }
  642. MIDOUT_END();
  643. return 0;
  644. }
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_k8x8x8_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x32K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x32K8x8x8Impl"_hash,
        aarch64::matmul::gemm_s8_8x8, int8_t, int32_t, AlgoDataType::QINT8X8X32,
        DEFAULT);
  653. /* ===================== Int8x8x16 K8x8x8 algo ===================== */
  654. namespace {
  655. void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
  656. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_k8x8x8_kern"_hash)) {
  657. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  658. auto trA = kern_param.trA, trB = kern_param.trB;
  659. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  660. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  661. C_type = kern_param.C_type;
  662. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  663. auto Cptr = kern_param.C<dt_int16>();
  664. aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, C_type);
  665. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_8x8>(
  666. M, N, K, trA, trB, strategy)
  667. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  668. }
  669. MIDOUT_END();
  670. }
  671. } // anonymous namespace
// Usable for default-format, default-compute-mode problems that pass the
// shared int8x8x16 dtype gate.
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
// Heuristic: 8x8x8 blocking wins on shallow-K problems.
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.K <= 16;
}
  682. size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace(
  683. const KernSizeParam& kern_size_param) const {
  684. MIDOUT_BEGIN(
  685. megdnn_aarch64_matmul_kern,
  686. midout_iv("AlgoInt8x8x16K8x8x8::get_workspace"_hash)) {
  687. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  688. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  689. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  690. C_type = kern_size_param.C_type;
  691. aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, C_type);
  692. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_8x8>(
  693. M, N, K, trA, trB, strategy)
  694. .get_workspace_size();
  695. }
  696. MIDOUT_END();
  697. return 0;
  698. }
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_k8x8x8_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x16K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x16K8x8x8Impl"_hash,
        aarch64::matmul::gemm_s8x8x16_8x8, int8_t, int16_t, AlgoDataType::INT8X8X16,
        DEFAULT);
  707. /* ===================== Int8x8x16 K4x4x16 algo ===================== */
  708. namespace {
  709. void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
  710. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_k4x4x16_kern"_hash)) {
  711. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  712. auto trA = kern_param.trA, trB = kern_param.trB;
  713. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  714. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  715. C_type = kern_param.C_type;
  716. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  717. auto Cptr = kern_param.C<dt_int16>();
  718. aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, C_type);
  719. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_4x4>(
  720. M, N, K, trA, trB, strategy)
  721. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  722. }
  723. MIDOUT_END();
  724. }
  725. } // anonymous namespace
// Usable for default-format, default-compute-mode problems that pass the
// shared int8x8x16 dtype gate.
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
// Unconditionally preferred among the int8x8x16 DEFAULT-format algos.
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred(
        const KernSizeParam& kern_size_param) const {
    MEGDNN_MARK_USED_VAR(kern_size_param);
    return true;
}
  737. size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace(
  738. const KernSizeParam& kern_size_param) const {
  739. MIDOUT_BEGIN(
  740. megdnn_aarch64_matmul_kern,
  741. midout_iv("AlgoInt8x8x16K4x4x16::get_workspace"_hash)) {
  742. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  743. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  744. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  745. C_type = kern_size_param.C_type;
  746. aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, C_type);
  747. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_4x4>(
  748. M, N, K, trA, trB, strategy)
  749. .get_workspace_size();
  750. }
  751. MIDOUT_END();
  752. return 0;
  753. }
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_k4x4x16_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x16K4x4x16, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x16K4x4x16Impl"_hash, aarch64::matmul::gemm_s8x8x16_4x4, int8_t,
        int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 MK4 16x12x4 algo ===================== */
  763. namespace {
  764. void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) {
  765. MIDOUT_BEGIN(
  766. megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) {
  767. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  768. auto trA = kern_param.trA, trB = kern_param.trB;
  769. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  770. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  771. C_type = kern_param.C_type;
  772. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  773. auto Cptr = kern_param.C<dt_int16>();
  774. aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(
  775. M, N, K, A_type, B_type, C_type);
  776. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(
  777. M, N, K, trA, trB, strategy)
  778. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  779. }
  780. MIDOUT_END();
  781. }
  782. } // anonymous namespace
// Usable for non-transposed MK4-format int8x8x16 problems whose M and K are
// multiples of the MK4 block size (4).
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::MK4 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 &&
           kern_size_param.K % 4 == 0;
}
// The a53-suffixed 16x12 strategy targets in-order little cores, so prefer it
// only when cpuinfo reports a Cortex-A53/A55.
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred(const KernSizeParam&) const {
#if !MGB_ENABLE_CPUINFO
    // Without cpuinfo we cannot identify the core; never prefer.
    return false;
#else
    auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__
    // cpuinfo cannot probe the core inside a TEE; treat it as unknown.
    arch = cpuinfo_uarch_unknown;
#endif
    bool little_core =
            arch == cpuinfo_uarch_cortex_a53 || arch == cpuinfo_uarch_cortex_a55;
    return little_core;
#endif
}
  804. size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace(
  805. const KernSizeParam& kern_size_param) const {
  806. MIDOUT_BEGIN(
  807. megdnn_aarch64_matmul_kern,
  808. midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) {
  809. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  810. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  811. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  812. C_type = kern_size_param.C_type;
  813. aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(
  814. M, N, K, A_type, B_type, C_type);
  815. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_mk4_16x12_a53>(
  816. M, N, K, trA, trB, strategy)
  817. .get_workspace_size();
  818. }
  819. MIDOUT_END();
  820. return 0;
  821. }
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_16x12x4_kern;
}
// Register packed-GEMM hooks; the _DETAIL variant additionally spells out the
// compute dtype (int16) separately from the output dtype.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x16MK4_16x12x4Impl"_hash,
        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t,
        AlgoDataType::INT8X8X16, MK4);
  831. /* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */
  832. namespace {
  833. void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) {
  834. MIDOUT_BEGIN(
  835. megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) {
  836. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  837. auto trA = kern_param.trA, trB = kern_param.trB;
  838. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  839. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  840. C_type = kern_param.C_type;
  841. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  842. auto Cptr = kern_param.C<dt_int16>();
  843. aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(
  844. M, N, K, A_type, B_type, C_type);
  845. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(
  846. M, N, K, trA, trB, strategy)
  847. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  848. }
  849. MIDOUT_END();
  850. }
  851. } // anonymous namespace
// Usable for non-transposed MK4-format int8x8x16 problems whose M and K are
// multiples of the MK4 block size (4).
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::MK4 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 &&
           kern_size_param.K % 4 == 0;
}
// The a72-suffixed 4x4 strategy targets out-of-order big cores: prefer it on
// anything that is NOT an A53/A55 (the complement of the 16x12x4 algo above).
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred(const KernSizeParam&) const {
#if !MGB_ENABLE_CPUINFO
    // Without cpuinfo we cannot identify the core; never prefer.
    return false;
#else
    auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__
    // cpuinfo cannot probe the core inside a TEE; treat it as unknown.
    arch = cpuinfo_uarch_unknown;
#endif
    bool little_core =
            arch == cpuinfo_uarch_cortex_a53 || arch == cpuinfo_uarch_cortex_a55;
    return !little_core;
#endif
}
  873. size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace(
  874. const KernSizeParam& kern_size_param) const {
  875. MIDOUT_BEGIN(
  876. megdnn_aarch64_matmul_kern,
  877. midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) {
  878. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  879. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  880. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  881. C_type = kern_size_param.C_type;
  882. aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(
  883. M, N, K, A_type, B_type, C_type);
  884. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_mk4_4x4_a72>(
  885. M, N, K, trA, trB, strategy)
  886. .get_workspace_size();
  887. }
  888. MIDOUT_END();
  889. return 0;
  890. }
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_4x4x8_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x16MK4_4x4x8, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x16MK4_4x4x8_Impl"_hash, aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72,
        int8_t, int16_t, AlgoDataType::INT8X8X16, MK4);
  899. /* ===================== Int16x16x32 K12x8x1 algo ===================== */
  900. namespace {
  901. void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) {
  902. MIDOUT_BEGIN(
  903. megdnn_aarch64_matmul_kern, midout_iv("int16x16x32_k12x8x1_kern"_hash)) {
  904. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  905. auto trA = kern_param.trA, trB = kern_param.trB;
  906. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  907. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  908. C_type = kern_param.C_type;
  909. const auto Aptr = kern_param.A<dt_int16>(), Bptr = kern_param.B<dt_int16>();
  910. auto Cptr = kern_param.C<dt_int32>();
  911. aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, C_type);
  912. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s16_12x8x1>(
  913. M, N, K, trA, trB, strategy)
  914. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  915. }
  916. MIDOUT_END();
  917. }
  918. } // anonymous namespace
// Usable for plain int16 x int16 -> int32 problems in the default format and
// compute mode (A and B must share the same dtype enum).
bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT &&
           kern_size_param.A_type.enumv() == DTypeEnum::Int16 &&
           kern_size_param.C_type.enumv() == DTypeEnum::Int32;
}
// Unconditionally preferred among the int16x16x32 DEFAULT-format algos.
bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::preferred(
        const KernSizeParam& kern_size_param) const {
    MEGDNN_MARK_USED_VAR(kern_size_param);
    return true;
}
  932. size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace(
  933. const KernSizeParam& kern_size_param) const {
  934. MIDOUT_BEGIN(
  935. megdnn_aarch64_matmul_kern,
  936. midout_iv("AlgoInt16x16x32K12x8x1::get_workspace"_hash)) {
  937. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  938. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  939. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  940. C_type = kern_size_param.C_type;
  941. aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, C_type);
  942. return megdnn::matmul::GemmInterleaved<matmul::gemm_s16_12x8x1>(
  943. M, N, K, trA, trB, strategy)
  944. .get_workspace_size();
  945. }
  946. MIDOUT_END();
  947. return 0;
  948. }
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern(
        const KernSizeParam&) const {
    return int16x16x32_k12x8x1_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt16x16x32K12x8x1, megdnn_aarch64_matmul_kern,
        "AlgoInt16x16x32K12x8x1Impl"_hash, aarch64::matmul::gemm_s16_12x8x1, int16_t,
        int32_t, AlgoDataType::INT16X16X32, DEFAULT);
  957. /* ===================== Int16x16x32MK8_8x8 algo ===================== */
  958. bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable(
  959. const KernSizeParam& kern_size_param) const {
  960. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  961. kern_size_param.C_type == dtype::Int32() &&
  962. kern_size_param.B_type == dtype::Int16() &&
  963. kern_size_param.A_type == dtype::Int16() &&
  964. kern_size_param.format == param::MatrixMul::Format::MK8 &&
  965. !kern_size_param.trA && !kern_size_param.trB;
  966. }
// Workspace size for the nopack s16 strategy. Note the strategy ctor takes
// only dtypes (no M/N/K) and GemmInterleaved is instantiated with a `false`
// second template argument — presumably the no-packing mode; confirm in
// GemmInterleaved's declaration.
size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoInt16x16x32MK8_8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_nopack_s16_8x8, false>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when the midout region above is compiled out.
    return 0;
}
// Unlike the other algos this one has no file-scope wrapper: a captureless
// lambda is defined here and implicitly converted to the plain function
// pointer type kern_t on return.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern(
        const KernSizeParam&) const {
    auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(
                megdnn_aarch64_matmul_kern,
                midout_iv("AlgoInt16x16x32MK8_8x8::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<dt_int16>(), Bptr = kern_param.B<dt_int16>();
            auto Cptr = kern_param.C<dt_int32>();
            // Nopack strategy: ctor takes only dtypes, and GemmInterleaved is
            // instantiated with the `false` (no-packing) template argument.
            aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<
                    aarch64::matmul::gemm_nopack_s16_8x8, false>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return kern_mk8_8x8;
}
  1008. #if MGB_ENABLE_DOT
  1009. /* ==================== Quint8 K8x8x4 Dotprod algo ==================== */
  1010. namespace {
  1011. void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
  1012. MIDOUT_BEGIN(
  1013. megdnn_aarch64_matmul_kern, midout_iv("quint8_k8x8x4_dotprod_kern"_hash)) {
  1014. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  1015. auto trA = kern_param.trA, trB = kern_param.trB;
  1016. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  1017. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  1018. C_type = kern_param.C_type;
  1019. const auto Aptr = kern_param.A<dt_uint8>(), Bptr = kern_param.B<dt_uint8>();
  1020. auto Cptr = kern_param.C<dt_int32>();
  1021. aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type);
  1022. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8_dot>(
  1023. M, N, K, trA, trB, strategy)
  1024. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  1025. }
  1026. MIDOUT_END();
  1027. }
  1028. } // anonymous namespace
// Usable only when the CPU reports the ARMv8.2 dot-product extension and the
// problem is Quantized8Asymm x Quantized8Asymm -> QuantizedS32 in the default
// format/compute mode.
bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable(
        const KernSizeParam& kern_size_param) const {
    if (!cpuinfo_has_arm_neon_dot()) {
        return false;
    }
    return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
// Workspace size as reported by GemmInterleaved for gemm_u8_8x8_dot.
size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoQuint8K8x8x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8_dot>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when the midout region above is compiled out.
    return 0;
}
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern(
        const KernSizeParam&) const {
    return quint8_k8x8x4_dotprod_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoQuint8K8x8x4DotProd, megdnn_aarch64_matmul_kern,
        "AlgoQuint8K8x8x4DotProdImpl"_hash, aarch64::matmul::gemm_u8_8x8_dot, uint8_t,
        int32_t, AlgoDataType::QUINT8X8X32, DEFAULT);
  1065. /* ===================== Quint8 Gemv DotProd algo ===================== */
namespace {
// GEMV wrapper for the N == 1 case: bypasses GemmInterleaved and calls the
// dedicated quint8 gemv-like kernel, forwarding the asymmetric zero points of
// A and B so the kernel can de-bias the uint8 operands.
void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern, midout_iv("quint8_gemv_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_uint8>(), Bptr = kern_param.B<dt_uint8>();
        auto Cptr = kern_param.C<dt_int32>();
        auto A_type = kern_param.A_type, B_type = kern_param.B_type;
        aarch64::matmul::gemv_like_quint8(
                Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC,
                A_type.param<dtype::Quantized8Asymm>().zero_point,
                B_type.param<dtype::Quantized8Asymm>().zero_point);
    }
    MIDOUT_END();
}
}  // anonymous namespace
// Usable only with the dot-product extension, for quantized-asymm uint8 GEMV
// shapes: a single, contiguously stored output column (N == 1, LDB == 1) and
// no transposition.
bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable(
        const KernSizeParam& kern_size_param) const {
    if (!cpuinfo_has_arm_neon_dot()) {
        return false;
    }
    return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB && kern_size_param.N == 1 &&
           kern_size_param.LDB == 1;
}
  1096. bool MatrixMulImpl::AlgoQuint8GemvDotProd::preferred(
  1097. const KernSizeParam& kern_size_param) const {
  1098. auto N = kern_size_param.N, LDB = kern_size_param.LDB;
  1099. return (N == 1 && LDB == 1);
  1100. }
// Returns the file-local GEMV kernel wrapper (no workspace override: the gemv
// path needs none).
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern(
        const KernSizeParam&) const {
    return quint8_gemv_dotprod_kern;
}
  1105. #endif
  1106. /* ===================== Quint8 K8x8x8 algo ===================== */
  1107. namespace {
  1108. void quint8_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
  1109. MIDOUT_BEGIN(
  1110. megdnn_aarch64_matmul_kern, midout_iv("quint8_gemv_dotprod_kern"_hash)) {
  1111. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  1112. auto trA = kern_param.trA, trB = kern_param.trB;
  1113. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  1114. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  1115. C_type = kern_param.C_type;
  1116. const auto Aptr = kern_param.A<dt_uint8>(), Bptr = kern_param.B<dt_uint8>();
  1117. auto Cptr = kern_param.C<dt_int32>();
  1118. aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  1119. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
  1120. M, N, K, trA, trB, strategy)
  1121. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  1122. }
  1123. MIDOUT_END();
  1124. }
  1125. } // anonymous namespace
// Non-dot fallback: same dtype gate as the dot-product algo but with no
// cpuinfo feature requirement.
bool MatrixMulImpl::AlgoQuint8K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
// Workspace size as reported by GemmInterleaved for gemm_u8_8x8.
size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoQuint8K8x8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when the midout region above is compiled out.
    return 0;
}
// Returns the file-local kernel wrapper for this algo.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern(
        const KernSizeParam&) const {
    return quint8_k8x8x8_kern;
}
// Register packed-GEMM hooks for this algo (helper macro; im2col consumer).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoQuint8K8x8x8, megdnn_aarch64_matmul_kern, "AlgoQuint8K8x8x8Impl"_hash,
        aarch64::matmul::gemm_u8_8x8, uint8_t, int32_t, AlgoDataType::QUINT8X8X32,
        DEFAULT);
/* ===================== Int8x8x16 MK4 K8x8x8 algo ===================== */
  1160. namespace {
  1161. void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
  1162. MIDOUT_BEGIN(
  1163. megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) {
  1164. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  1165. auto trA = kern_param.trA, trB = kern_param.trB;
  1166. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  1167. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  1168. C_type = kern_param.C_type;
  1169. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  1170. auto Cptr = kern_param.C<dt_int16>();
  1171. aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(
  1172. M, N, K, A_type, B_type, C_type);
  1173. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(
  1174. M, N, K, trA, trB, strategy)
  1175. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  1176. }
  1177. MIDOUT_END();
  1178. }
  1179. } // anonymous namespace
// Usable for non-transposed MK4-format int8x8x16 problems whose M and K are
// multiples of the MK4 block size (4).
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::MK4 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 &&
           kern_size_param.K % 4 == 0;
}
// Unconditionally preferred among the int8x8x16 MK4 algos.
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred(const KernSizeParam&) const {
    return true;
}
  1191. size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace(
  1192. const KernSizeParam& kern_size_param) const {
  1193. MIDOUT_BEGIN(
  1194. megdnn_aarch64_matmul_kern,
  1195. midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) {
  1196. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  1197. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  1198. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  1199. C_type = kern_size_param.C_type;
  1200. aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(
  1201. M, N, K, A_type, B_type, C_type);
  1202. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_mk4_8x8x8>(
  1203. M, N, K, trA, trB, strategy)
  1204. .get_workspace_size();
  1205. }
  1206. MIDOUT_END();
  1207. return 0;
  1208. }
// Hands back the file-local kernel entry point; the kernel itself re-reads
// all sizes/strides from the KernParam at execution time.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_8x8x8_kern;
}
// Registers this GEMM (int8 in, int16 out, MK4 format) so the im2col
// convolution path can reuse it as its inner matmul.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt8x8x16MK4_K8x8x8, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x16MK4_K8x8x8Impl"_hash, aarch64::matmul::gemm_s8x8x16_mk4_8x8x8,
        int8_t, int16_t, AlgoDataType::INT8X8X16, MK4);
  1217. /* ===================== Int4x4x16 K8x8x8 algo ===================== */
  1218. namespace {
  1219. void int4x4x16_k8x8x16_kern(const MatrixMulImpl::KernParam& kern_param) {
  1220. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int4x4x16_k8x8x8_kern"_hash)) {
  1221. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  1222. auto trA = kern_param.trA, trB = kern_param.trB;
  1223. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  1224. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  1225. C_type = kern_param.C_type;
  1226. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  1227. auto Cptr = kern_param.C<dt_int16>();
  1228. aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(
  1229. M, N, K, A_type, B_type, C_type);
  1230. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s4x4x16_s4_8x8x8>(
  1231. M, N, K, trA, trB, strategy)
  1232. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
  1233. }
  1234. MIDOUT_END();
  1235. }
  1236. } // anonymous namespace
  1237. bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::usable(
  1238. const KernSizeParam& kern_size_param) const {
  1239. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  1240. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS4 &&
  1241. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16 &&
  1242. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  1243. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  1244. (kern_size_param.K & 1) == 0 && (kern_size_param.N & 1) == 0;
  1245. }
  1246. bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred(
  1247. const KernSizeParam& kern_size_param) const {
  1248. MEGDNN_MARK_USED_VAR(kern_size_param);
  1249. return true;
  1250. }
  1251. size_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_workspace(
  1252. const KernSizeParam& kern_size_param) const {
  1253. MIDOUT_BEGIN(
  1254. megdnn_aarch64_matmul_kern,
  1255. midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) {
  1256. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
  1257. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  1258. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  1259. C_type = kern_size_param.C_type;
  1260. aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(
  1261. M, N, K, A_type, B_type, C_type);
  1262. return megdnn::matmul::GemmInterleaved<matmul::gemm_s4x4x16_s4_8x8x8>(
  1263. M, N, K, trA, trB, strategy)
  1264. .get_workspace_size();
  1265. }
  1266. MIDOUT_END();
  1267. }
// Hands back the file-local kernel entry point (note the "k8x8x16" spelling in
// the function name — presumably a typo for 8x8x8; kept as-is since this is
// the symbol's declared name).
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int4x4x16_k8x8x16_kern;
}
// Registers this GEMM (s4 in, s16 out, DEFAULT format) so the im2col
// convolution path can reuse it as its inner matmul.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoInt4x4x16K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt4x4x16K8x8x8Impl"_hash,
        aarch64::matmul::gemm_s4x4x16_s4_8x8x8, int8_t, int16_t,
        AlgoDataType::INT4X4X16, DEFAULT);
  1276. // vim: syntax=cpp.doxygen