You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

batchmatmul.py 70 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: batchmatmul"""
  17. from functools import reduce
  18. import akg.topi
  19. import akg.tvm
  20. from akg.tvm.hybrid import script
  21. from akg.ops.math import cast
  22. from akg.utils import custom_tiling as ct_util
  23. from akg.utils import validation_check as vc_util
  24. from akg.utils.format_transform import get_shape, get_bytes
  25. from akg.utils.math import greatest_common_divisor, least_common_multiple
  26. from akg.utils.kernel_exec import product_is_mini
  27. from akg.utils import dynamic_shape as ds
  28. batchmatmul_set_dim_map = {
  29. # 2D
  30. str((256, 1024, 4096, "float32", False, True)): ((16, 16), (16, 16), (16, 16)),
  31. str((160, 1024, 1024, "float32", False, False)): ((1, 1), (16, 16), (1024, 1024)),
  32. str((8192, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1024), (16, 16)),
  33. str((1024, 1024, 8192, "float32", True, False)): ((8, 8), (8, 8), (512, 512)),
  34. str((1024, 1024, 2, "float32", False, False)): ((64, 64), (64, 64), (2, 2)),
  35. str((1024, 1024, 4096, "float32", False, False)): ((1, 1), (16, 16), (512, 512)),
  36. str((2, 1024, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)),
  37. str((30522, 1024, 1280, "float32", True, False)): ((3, 3), (64, 64), (128, 128)),
  38. str((1024, 4096, 8192, "float32", True, False)): ((32, 32), (32, 32), (32, 32)),
  39. str((2, 1024, 64, "float32", True, False)): ((4, 4), (64, 64), (64, 64)),
  40. str((160, 30522, 1024, "float32", False, True)): ((1, 1), (3, 3), (512, 512)),
  41. str((1024, 1024, 64, "float32", True, False)): ((16, 16), (16, 16), (64, 64)),
  42. str((4096, 1024, 8192, "float32", True, False)): ((16, 16), (64, 64), (16, 16)),
  43. str((1280, 1024, 30522, "float32", False, False)): ((1, 1), (512, 512), (3, 3)),
  44. str((8192, 1024, 4096, "float32", False, True)): ((1, 1), (16, 16), (1024, 1024)),
  45. str((1280, 30522, 1024, "float32", False, True)): ((1, 1), (3, 3), (512, 512)),
  46. str((8192, 4096, 1024, "float32", False, False)): ((1, 1), (64, 64), (256, 256)),
  47. str((768, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)),
  48. str((3072, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)),
  49. str((2, 768, 64, "float32", True, False)): ((2, 2), (64, 64), (64, 64)),
  50. str((768, 1024, 8192, "float16", False, False)): ((16, 1), (16, 1), (64, 1)),
  51. str((33, 64, 16384, "float32", True, False)): ((1, 1), (16, 16), (128, 128)),
  52. str((768, 768, 64, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
  53. str((8192, 768, 21128, "float32", False, False)): ((4, 4), (128, 128), (4, 4)),
  54. str((2, 768, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)),
  55. str((8192, 768, 768, "float16", False, True)): ((1, 1), (16, 16), (768, 768)),
  56. str((21128, 768, 8192, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
  57. str((768, 1024, 768, "float32", True, False)): ((2, 2), (128, 128), (128, 128)),
  58. str((16384, 16384, 33, "float32", True, False)): ((16, 16), (64, 64), (33, 33)),
  59. str((21128, 768, 8192, "float32", False, False)): ((4, 4), (128, 128), (64, 64)),
  60. str((1280, 1280, 1024, "float32", False, True)): ((4, 4), (32, 32), (128, 128)),
  61. str((1280, 768, 21128, "float32", False, False)): ((1, 1), (768, 768), (8, 8)),
  62. str((8192, 768, 768, "float32", False, False)): ((1, 1), (8, 8), (768, 768)),
  63. str((20, 768, 32000, "float32", False, False)): ((20, 20), (48, 48), (32, 32)),
  64. str((21128, 768, 1280, "float32", True, False)): ((2, 2), (32, 32), (32, 32)),
  65. str((768, 3072, 1892, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
  66. str((33, 64, 16384, "float32", False, True)): ((2, 2), (32, 32), (32, 32)),
  67. str((8192, 3072, 768, "float32", False, True)): ((16, 16), (16, 16), (16, 16)),
  68. str((2, 8192, 768, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
  69. str((8192, 768, 3072, "float32", False, False)): ((32, 32), (32, 32), (32, 32)),
  70. str((8192, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (768, 1)),
  71. str((8192, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)),
  72. str((8192, 768, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
  73. str((768, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
  74. str((8192, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
  75. str((8192, 768, 3072, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
  76. str((21128, 768, 5120, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
  77. str((21128, 768, 2560, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
  78. str((1024, 2, 4, "float32", True, False)): ((512, 1), (2, 1), (4, 1)),
  79. str((21128, 1024, 21128, "float32", False, False)): ((4, 4), (512, 512), (4, 4)),
  80. str((320, 768, 21128, "float32", False, False)): ((40, 40), (128, 128), (4, 4)),
  81. str((5120, 1024, 21128, "float32", False, False)): ((64, 64), (128, 128), (4, 4)),
  82. str((16384, 4096, 1024, "float32", False, False)): ((4, 4), (64, 64), (64, 64)),
  83. str((1024, 4096, 16384, "float32", True, False)): ((64, 64), (64, 64), (4, 4)),
  84. str((768, 3072, 8192, "float16", True, False)): ((1, 1), (16, 1), (512, 1)),
  85. str((2048, 768, 3072, "float32", False, False)): ((8, 1), (16, 1), (128, 1)),
  86. str((1024, 2, 128, "float32", True, False)): ((32, 1), (2, 1), (128, 1)),
  87. str((1024, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)),
  88. str((1024, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)),
  89. str((1024, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)),
  90. str((1024, 2, 8, "float32", True, False)): ((512, 1), (2, 1), (8, 1)),
  91. str((768, 2, 128, "float32", True, False)): ((32, 1), (2, 1), (128, 1)),
  92. str((768, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)),
  93. str((768, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)),
  94. str((768, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)),
  95. str((768, 3072, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)),
  96. str((3072, 768, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)),
  97. str((65536, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  98. str((10240, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  99. str((10240, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  100. str((21128, 768, 21128, "float32", True, False)): ((1, 1), (768, 1), (16, 1)),
  101. str((2560, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  102. str((5120, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  103. str((32768, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  104. str((20480, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  105. str((128, 1024, 4096, "float32", False, True)): ((1, 1), (16, 1), (16, 1)),
  106. str((128, 1024, 4096, "float16", False, True)): ((1, 1), (16, 1), (32, 1)),
  107. str((128, 4096, 1024, "float16", False, True)): ((1, 1), (32, 1), (32, 1)),
  108. str((20480, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  109. str((65536, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
  110. str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  111. str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  112. str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  113. str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  114. str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
  115. str((65536, 768, 3072, "float32", False, True)): ((1, 1), (192, 1), (96, 1)),
  116. str((16384, 1024, 4096, "float32", False, True)): ((1, 1), (128, 1), (128, 1)),
  117. str((1024, 4096, 32768, "float32", True, False)): ((8, 1), (1024, 1), (4, 1)),
  118. str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
  119. str((2, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
  120. str((8, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
  121. str((4, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
  122. # lenet5
  123. str((32, 10, 84, 'float16', False, True)): ((1, 1), (16, 1), (84, 1)),
  124. # alexnet
  125. str((32, 4096, 9216, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)),
  126. str((32, 10, 4096, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)),
  127. # 3D
  128. str((128, 128, 64, 1536, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)),
  129. str((128, 768, 128, 64, "float32", False, True)): ((4, 4), (8, 8), (16, 16), (64, 64)),
  130. str((128, 128, 64, 6144, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)),
  131. str((128, 128, 64, 16384, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
  132. str((128, 128, 64, 2048, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
  133. str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
  134. str((128, 128, 64, 8192, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
  135. str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (64, 1)),
  136. str((128, 128, 64, 12288, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)),
  137. # 4D (the entries that follow also mix in keys with fewer shape dimensions)
  138. str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (12, 12), (8, 8), (8, 8), (32, 32)),
  139. str((1, 768, 2, "float32", False, False)): ((768, 1), (1, 1)),
  140. str((20, 768, 21128, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  141. str((128, 12, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  142. str((128, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  143. str((1, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
  144. str((128, 128, 64, 12, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
  145. str((1, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
  146. str((20, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  147. str((1, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)),
  148. str((20, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  149. str((128, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  150. str((1, 768, 768, "float32", False, False)): ((768, 1), (16, 1)),
  151. str((21128, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  152. str((128, 12, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (128, 1), (64, 1)),
  153. str((128, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  154. str((40, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
  155. str((2, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
  156. str((2, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)),
  157. str((40, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  158. str((256, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  159. str((2, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
  160. str((2, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
  161. str((128, 128, 64, 24, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (12, 1)),
  162. str((256, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  163. str((128, 24, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  164. str((2, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  165. str((21128, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  166. str((40, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
  167. str((256, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  168. str((128, 24, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  169. str((128, 48, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  170. str((512, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  171. str((128, 48, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  172. str((512, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  173. str((4, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
  174. str((512, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  175. str((4, 768, 2, "float32", False, False)): ((1, 1), (768, 1), (2, 1)),
  176. str((80, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  177. str((4, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
  178. str((21128, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  179. str((80, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  180. str((128, 128, 64, 48, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
  181. str((4, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  182. str((4, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
  183. str((80, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  184. str((128, 96, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  185. str((1024, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  186. str((160, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  187. str((8, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
  188. str((8, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)),
  189. str((8, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  190. str((1024, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  191. str((128, 96, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  192. str((128, 128, 64, 96, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
  193. str((8, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
  194. str((1024, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  195. str((21128, 768, 160, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  196. str((160, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  197. str((8, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
  198. str((160, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  199. str((16, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)),
  200. str((320, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  201. str((16, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
  202. str((2048, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  203. str((16, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
  204. str((128, 128, 64, 192, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
  205. str((128, 192, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  206. str((21128, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  207. str((128, 192, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  208. str((320, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  209. str((16, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
  210. str((320, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  211. str((2048, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  212. str((2048, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  213. str((16, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  214. str((4096, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  215. str((128, 128, 64, 384, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
  216. str((4096, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  217. str((640, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  218. str((128, 384, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  219. str((128, 384, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  220. str((32, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
  221. str((32, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
  222. str((32, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
  223. str((32, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  224. str((4096, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  225. str((640, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  226. str((32, 768, 2, "float32", False, False)): ((4, 1), (768, 1), (2, 1)),
  227. str((21128, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  228. str((640, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  229. str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
  230. str((8192, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  231. str((1280, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  232. str((8192, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  233. str((128, 768, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  234. str((128, 128, 64, 768, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
  235. str((21128, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  236. str((1280, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  237. str((64, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
  238. str((1280, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  239. str((64, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
  240. str((64, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
  241. str((64, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)),
  242. str((8192, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  243. str((128, 768, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  244. str((128, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
  245. str((16384, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  246. str((16384, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  247. str((128, 1536, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  248. str((128, 768, 2, "float32", False, False)): ((16, 1), (768, 1), (2, 1)),
  249. str((2560, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
  250. str((21128, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  251. str((128, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
  252. str((2560, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  253. str((128, 128, 64, 1536, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
  254. str((128, 1536, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
  255. str((2560, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  256. str((128, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
  257. str((16384, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
  258. str((1, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (4, 1), (64, 1), (8, 1)),
  259. str((1, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (4, 1)),
  260. str((128, 16, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
  261. str((128, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  262. str((20, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  263. str((128, 4096, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  264. str((128, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  265. str((128, 16, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
  266. str((20, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
  267. str((1, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
  268. str((1, 1024, 2, "float32", False, False)): ((1024, 1), (2, 1)),
  269. str((21128, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  270. str((20, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  271. str((128, 128, 64, 16, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (8, 1)),
  272. str((1, 1024, 1024, "float32", False, False)): ((1024, 1), (8, 1)),
  273. str((128, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  274. str((1, 1024, 1024, "float32", False, True)): ((8, 1), (1024, 1)),
  275. str((20, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  276. str((128, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  277. str((128, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  278. str((1, 2, 1024, "float32", False, True)): ((2, 1), (1024, 1)),
  279. str((2, 1024, 1, "float32", True, False)): ((2, 1), (1024, 1)),
  280. str((1024, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  281. str((1024, 1024, 1, "float32", True, False)): ((16, 1), (1024, 1)),
  282. str((1024, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  283. str((4096, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  284. str((1024, 4096, 128, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  285. str((128, 128, 64, 32, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
  286. str((128, 32, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
  287. str((256, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
  288. str((256, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
  289. str((40, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
  290. str((2, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
  291. str((2, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
  292. str((21128, 1024, 40, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  293. str((2, 1024, 2, "float32", False, False)): ((2, 1), (1024, 1), (2, 1)),
  294. str((256, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  295. str((40, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  296. str((2, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  297. str((40, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  298. str((128, 32, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
  299. str((2, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
  300. str((2, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  301. str((256, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  302. str((256, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  303. str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
  304. str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  305. str((40, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  306. str((1024, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  307. str((1024, 1024, 40, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  308. str((4096, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  309. str((2, 1024, 2, "float32", True, False)): ((2, 1), (1024, 1), (2, 1)),
  310. str((1024, 1024, 2, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
  311. str((1024, 4096, 256, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  312. str((4, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
  313. str((4, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
  314. str((80, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
  315. str((4, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  316. str((128, 64, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
  317. str((128, 128, 64, 64, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
  318. str((21128, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  319. str((4, 1024, 2, "float32", False, False)): ((4, 1), (1024, 1), (2, 1)),
  320. str((128, 64, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
  321. str((80, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  322. str((512, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
  323. str((512, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  324. str((512, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
  325. str((4, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
  326. str((80, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  327. str((512, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  328. str((80, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  329. str((4, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  330. str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
  331. str((512, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  332. str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  333. str((1024, 1024, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  334. str((2, 1024, 4, "float32", True, False)): ((2, 1), (1024, 1), (4, 1)),
  335. str((4096, 1024, 512, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  336. str((1024, 1024, 4, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
  337. str((1024, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  338. str((1024, 4096, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  339. str((4096, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
  340. str((4096, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  341. str((4096, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  342. str((8192, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  343. str((8192, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  344. str((8192, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
  345. str((16384, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  346. str((16384, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  347. str((16384, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
  348. str((3072, 768, 128, "float32", True, False)): ((4, 1), (768, 1), (4, 1)),
  349. str((768, 768, 1, "float32", True, False)): ((16, 1), (768, 1)),
  350. str((768, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  351. str((768, 3072, 128, "float32", True, False)): ((4, 1), (3072, 1), (1, 1)),
  352. str((2, 768, 1, "float32", True, False)): ((2, 1), (768, 1)),
  353. str((768, 768, 128, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  354. str((768, 768, 2, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
  355. str((768, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  356. str((3072, 768, 256, "float32", True, False)): ((4, 1), (768, 1), (4, 1)),
  357. str((768, 3072, 256, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  358. str((768, 768, 256, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  359. str((2, 768, 2, "float32", True, False)): ((2, 1), (768, 1), (2, 1)),
  360. str((768, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  361. str((3072, 768, 512, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  362. str((768, 3072, 512, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  363. str((768, 768, 4, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
  364. str((768, 768, 512, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  365. str((2, 768, 4, "float32", True, False)): ((2, 1), (768, 1), (4, 1)),
  366. str((768, 768, 1024, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  367. str((768, 3072, 1024, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  368. str((3072, 768, 1024, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  369. str((768, 768, 8, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
  370. str((768, 768, 160, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  371. str((2, 768, 8, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  372. str((3072, 768, 2048, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  373. str((768, 3072, 2048, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  374. str((768, 768, 2048, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  375. str((768, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  376. str((768, 768, 16, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
  377. str((2, 768, 16, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
  378. str((768, 768, 32, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
  379. str((2, 768, 32, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
  380. str((3072, 768, 4096, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  381. str((768, 3072, 4096, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  382. str((768, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  383. str((768, 768, 4096, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  384. str((768, 768, 64, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
  385. str((2, 768, 64, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
  386. str((768, 3072, 8192, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  387. str((768, 768, 8192, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  388. str((768, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  389. str((3072, 768, 8192, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  390. str((768, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  391. str((3072, 768, 16384, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  392. str((2, 768, 128, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
  393. str((768, 768, 16384, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
  394. str((768, 3072, 16384, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  395. str((8, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
  396. str((21128, 1024, 160, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
  397. str((1024, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
  398. str((160, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (4, 1)),
  399. str((8, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
  400. str((160, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  401. str((8, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  402. str((1024, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
  403. str((160, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
  404. str((1024, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  405. str((8, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
  406. str((128, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
  407. str((128, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
  408. str((128, 128, 64, 128, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
  409. str((8, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)),
  410. str((21128, 1024, 320, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
  411. str((16, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  412. str((2048, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
  413. str((16, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)),
  414. str((16, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
  415. str((16, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
  416. str((2048, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
  417. str((320, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
  418. str((2048, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  419. str((320, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
  420. str((128, 256, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
  421. str((320, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (4, 1)),
  422. str((128, 256, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
  423. str((128, 128, 64, 256, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
  424. str((16, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
  425. str((1024, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  426. str((8, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
  427. str((160, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  428. str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  429. str((1024, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  430. str((8, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  431. str((16, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
  432. str((320, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  433. str((16, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
  434. str((2048, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  435. str((2048, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  436. str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  437. str((1024, 4096, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  438. str((1024, 1024, 8, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
  439. str((2, 1024, 8, "float32", True, False)): ((2, 1), (1024, 1), (8, 1)),
  440. str((1024, 1024, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  441. str((4096, 1024, 1024, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  442. str((1024, 1024, 160, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  443. str((4096, 1024, 2048, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  444. str((1024, 1024, 320, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
  445. str((1024, 1024, 16, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
  446. str((1024, 4096, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  447. str((2, 1024, 16, "float32", True, False)): ((2, 1), (1024, 1), (16, 1)),
  448. str((1024, 1024, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
  449. str((32, 1001, 2048, "float16", False, True)): ((1, 1), (77, 1), (256, 1)),
  450. str((1001, 2048, 32, "float16", True, False)): ((1, 1), (2048, 1), (4, 1)),
  451. str((32, 2048, 1001, "float16", False, False)): ((1, 1), (2048, 1), (4, 1)),
  452. str((32, 1001, 2048, "float32", False, True)): ((1, 1), (7, 1), (2048, 1)),
  453. str((1001, 2048, 32, "float32", True, False)): ((1, 1), (2048, 1), (4, 1)),
  454. str((32, 2048, 1001, "float32", False, False)): ((1, 1), (2048, 1), (4, 1)),
  455. str((768, 3072, 131072, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
  456. str((3072, 768, 131072, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
  457. str((65536, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
  458. str((131072, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  459. str((32768, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  460. str((65536, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  461. str((131072, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  462. str((131072, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
  463. str((65536, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  464. str((32768, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  465. str((65536, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
  466. str((10240, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
  467. str((21128, 1024, 20480, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
  468. str((2048, 3072, 768, "float16", False, False)): ((1, 1), (768, 1), (16, 1)),
  469. str((2048, 768, 3072, "float16", False, False)): ((2, 1), (768, 1), (8, 1)),
  470. str((10240, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  471. str((20480, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
  472. str((21128, 768, 20480, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
  473. str((20480, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
  474. str((32, 10, 2048, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
  475. str((32, 10, 2048, "float16", False, True)): ((1, 1), (8, 1), (2048, 1)),
  476. str((32, 10, 4096, "float16", False, True)): ((1, 1), (2, 1), (4096, 1)),
  477. str((768, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
  478. str((3072, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
  479. str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
  480. str((32, 9216, 4096, "float32", False, False)): ((1, 1), (9216, 1), (1, 1)),
  481. str((128, 128, 64, 3072, "float32", True, False)): ((1, 1), (64, 1), (64, 1), (1, 1)),
  482. str((21128, 1024, 10240, "float32", True, False)): ((8, 1), (1024, 1), (1, 1)),
  483. str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (1536, 1)),
  484. str((3072, 768, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)),
  485. str((768, 3072, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)),
  486. str((4096, 1024, 65536, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)), # auto tiling crash
  487. str((1024, 4096, 65536, "float32", True, False)): ((1, 1), (4096, 1), (4, 1)), # auto tiling crash
  488. str((16384, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)), # auto tiling crash
  489. str((16384, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash
  490. str((1024, 4096, 1024, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash
  491. str((131072, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
  492. # Alexnet shape
  493. str((32, 10, 4096, "float32", False, True)): ((32, 1), (10, 1), (1, 1)), # auto tiling crash
  494. str((32, 4096, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)), # auto tiling crash
  495. str((32, 9216, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)), # auto tiling crash
  496. str((32, 4096, 9216, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash
  497. str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (96, 1)), # auto tiling crash
  498. str((32, 4096, 4096, "float32", False, False)): ((1, 1), (256, 1), (64, 1)), # performance optimization
  499. # Alexnet shape
  500. str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
  501. str((3072, 768, 4096, "float16", True, False)): ((2, 1), (768, 1), (8, 1)),
  502. str((768, 3072, 4096, "float16", True, False)): ((2, 1), (3072, 1), (2, 1)),
  503. str((400, 120, 32, "float16", True, False)): ((2, 1), (32, 1), (2, 1)),
  504. }
  505. CORE_NUM = 2 if product_is_mini() else 32
  506. MINIMAL_FOR_MULTICORE = CORE_NUM * 512
  507. def get_best_align_elem(tensor_size, tensor_dtype):
  508. """Get the best tiling factor for alignment axis."""
  509. basic_align_elem = int(ct_util.BLOCK_SIZE / get_bytes(tensor_dtype))
  510. lcm = least_common_multiple(tensor_size, basic_align_elem)
  511. gcd = greatest_common_divisor(tensor_size, basic_align_elem)
  512. if gcd != 1:
  513. return gcd
  514. if lcm < tensor_size:
  515. return min(tensor_size, lcm)
  516. return -1
  517. def get_shape_pos_map(tensor_shape):
  518. """Mapping tensor shape to corresponding axis position."""
  519. batch_pos = [i for i in range(len(tensor_shape) - 3) if tensor_shape[i] != 1]
  520. mnk = dict()
  521. pos_map = {0: "m", 1: "n", 2: "k"}
  522. count = -1
  523. for i, _ in enumerate(batch_pos):
  524. count += 1
  525. mnk["b%s" % str(i)] = count
  526. for i, shp in enumerate(tensor_shape[-3:]):
  527. if shp != 1:
  528. count += 1
  529. mnk[pos_map[i]] = count
  530. return batch_pos, mnk
  531. def batchmatmul_tiling_strategy(shape, align_dtype, attrs):
  532. """This is an efficient version of tiling strategy for batchmatmul."""
  533. if len(shape) < 3:
  534. raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. "
  535. "Current length of shape is {}".format(len(shape)))
  536. strategy = list()
  537. m, n, k = shape[-3:]
  538. batch_pos, mnk = get_shape_pos_map(shape)
  539. total_size = reduce(lambda x, y: int(x) * int(y), shape)
  540. # set minimal tile for n as the basic block size for alignment
  541. align_elem = int(ct_util.BLOCK_SIZE / get_bytes(align_dtype))
  542. tile_n = get_best_align_elem(n, align_dtype)
  543. # if n is smaller than block size, it is not safe to open multi-core for now
  544. if n < align_elem:
  545. n_constraint = [ct_util.TileConstraint.FACTOR]
  546. used_core = CORE_NUM
  547. attrs["enable_multicore"] = 0
  548. else:
  549. n_constraint = [ct_util.TileConstraint.MOD, ct_util.TileConstraint.MIN]
  550. used_core = 1
  551. if total_size >= MINIMAL_FOR_MULTICORE and used_core < CORE_NUM:
  552. # set maximal tile for batch according to multi-core usage
  553. for i, p in enumerate(batch_pos):
  554. use = min(shape[p], int((CORE_NUM - 1 + used_core) / used_core))
  555. used_core *= use
  556. max_b = int(shape[p] / use)
  557. strategy.append(ct_util.create_constraint_on_axis(values=max_b,
  558. constraints=ct_util.TileConstraint.MAX,
  559. band=0,
  560. axis=i)[0])
  561. tile_k = get_best_align_elem(k, align_dtype)
  562. k_constraint = ct_util.TileConstraint.MIN
  563. total_size /= max(1, int(k / tile_k))
  564. # set minimal tile for m according to multi-core usage when there is no expansion
  565. tile_m = -1
  566. m_constraint = ct_util.TileConstraint.MAX
  567. m_per_block = m
  568. max_core = min(CORE_NUM, int(total_size / MINIMAL_FOR_MULTICORE))
  569. if greatest_common_divisor(n, align_elem) != 1 and used_core < max_core:
  570. left_core = int((max_core - 1 + used_core) / used_core)
  571. core_limit = max(1, int(m / greatest_common_divisor(left_core, m)))
  572. nk_in_mem = int(n / max(1, tile_n)) * int(k / tile_k)
  573. balance_limit = max(1, int(m / greatest_common_divisor(nk_in_mem, m)))
  574. tile_m = min(core_limit, balance_limit)
  575. m_per_block = int(m / tile_m)
  576. # for large m case, it is more efficient to balance memory bound and calculation bound
  577. if m_per_block > int(n / max(1, tile_n)) * int(k / tile_k):
  578. tile_m = max(min(m, align_elem), tile_m)
  579. k_constraint = ct_util.TileConstraint.FACTOR
  580. # create constraints based on previous analysis
  581. if m != 1:
  582. strategy.append(ct_util.create_constraint_on_axis(values=tile_m,
  583. constraints=m_constraint,
  584. band=0,
  585. axis=mnk["m"])[0])
  586. if n != 1:
  587. for constraint in n_constraint:
  588. strategy.append(ct_util.create_constraint_on_axis(values=tile_n,
  589. constraints=constraint,
  590. band=0,
  591. axis=mnk["n"])[0])
  592. if k != 1:
  593. strategy.append(ct_util.create_constraint_on_axis(values=tile_k,
  594. constraints=k_constraint,
  595. band=0,
  596. axis=mnk["k"])[0])
  597. higher_priority_pos = mnk["k"] if k >= n else mnk["n"]
  598. strategy.append(ct_util.create_constraint_on_axis(values=0,
  599. constraints=ct_util.TileConstraint.SET_PRIORITY,
  600. band=0,
  601. axis=higher_priority_pos)[0])
  602. strategy.append(ct_util.modify_common_constraints(0.7, ct_util.TileConstraint.SET_MEM_RATIO))
  603. attrs["custom_tiling"] = strategy
  604. return attrs
  605. def batchmatmul_tiling_strategy_dynamic(shape, output, attrs):
  606. """This is an efficient version of tiling strategy for batchmatmul."""
  607. if len(shape) < 3:
  608. raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. "
  609. "Current length of shape is {}".format(len(shape)))
  610. strategy = list()
  611. _, mnk = get_shape_pos_map(shape)
  612. # create constraints based on previous analysis
  613. strategy.append(ct_util.create_constraint_on_axis(values=1,
  614. constraints=ct_util.TileConstraint.FACTOR,
  615. band=0,
  616. axis=mnk["m"])[0])
  617. strategy.append(ct_util.create_constraint_on_axis(values="FULL",
  618. constraints=ct_util.TileConstraint.MAX,
  619. band=0,
  620. axis=mnk["n"])[0])
  621. strategy.append(ct_util.create_constraint_on_axis(values=8,
  622. constraints=ct_util.TileConstraint.FACTOR,
  623. band=0,
  624. axis=mnk["k"])[0])
  625. strategy.append(ct_util.modify_common_constraints(0.7, ct_util.TileConstraint.SET_MEM_RATIO))
  626. attrs["custom_tiling"] = strategy
  627. attrs["dynamic_shape"] = ds.set_dynamic_shape_limit_for_tensor(output, 2048, [1,])
  628. return attrs
  629. def get_mnk_from_matrix(shape_a_list, shape_b_list, trans_a, trans_b):
  630. """Get m, n and k value from input tensor shapes."""
  631. m, k = shape_a_list[-2], shape_a_list[-1]
  632. if trans_a:
  633. m, k = k, m
  634. n = shape_b_list[-2] if trans_b else shape_b_list[-1]
  635. return [m, n, k]
  636. def batchmatmul_set_dim(a_value, b_value, trans_a, trans_b):
  637. """This function is used to set dim info in attrs by set_dim_map."""
  638. shape_a_list = get_shape(a_value)
  639. shape_b_list = get_shape(b_value)
  640. m, n, k = get_mnk_from_matrix(shape_a_list, shape_b_list, trans_a, trans_b)
  641. key = ()
  642. if len(shape_a_list) > 2:
  643. key += tuple(shape_a_list[:-2])
  644. key += (m, n, k, a_value.dtype, trans_a, trans_b)
  645. set_dims = ct_util.set_dims_by_key(str(key), batchmatmul_set_dim_map)
  646. return set_dims, str(key)
  647. def batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b):
  648. """This function is used to set dim info in attrs by set_dim_map of batchmatmul with bias."""
  649. return batchmatmul_set_dim(a_value, b_value, trans_a, trans_b)
  650. def batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b):
  651. """This function is used to set dim info in attrs by set_dim_map of batchmatmul without bias."""
  652. return batchmatmul_set_dim(a_value, b_value, trans_a, trans_b)
  653. @ct_util.reg_set_dim_func(batchmatmul_bias_set_dim)
  654. @vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
  655. def batchmatmul_bias(a_value, b_value, bias_value, trans_a, trans_b):
  656. """
  657. Multiplies two tensors in batches and adds bias to the output.
  658. Args:
  659. a_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of type float16 or float32 with shape(..., r_A, c_A).
  660. b_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of same type as a_value with shape(..., r_B, c_B).
  661. bias_value (tvm.tensor.Tensor): The bias tensor added to the result of a_value * b_value.
  662. Should be of same type as a_value, broadcast is allowed.
  663. trans_a (bool): Specifies whether a_value is transposed or not, default value is False.
  664. trans_b (bool): Specifies whether b_value is transposed or not, default value is False.
  665. Returns:
  666. tvm.tensor.Tensor of same type as a_value with shape(..., r_C, c_C).
  667. r_C = c_A if trans_a else r_A
  668. c_C = r_B if trans_b else c_B
  669. """
  670. if not isinstance(trans_a, bool):
  671. raise TypeError("trans_a should be of type Boolean.")
  672. if not isinstance(trans_b, bool):
  673. raise TypeError("trans_b should be of type Boolean.")
  674. vc_util.ops_dtype_check([a_value.dtype, b_value.dtype, bias_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
  675. vc_util.elemwise_dtype_check(a_value.dtype, b_value.dtype)
  676. vc_util.elemwise_dtype_check(a_value.dtype, bias_value.dtype)
  677. vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
  678. if len(a_value.shape) not in [2, 3, 4]:
  679. raise ValueError("Batch matmul only support 2D, 3D and 4D now.")
  680. c_value = batchmatmul(a_value, b_value, trans_a, trans_b)
  681. if isinstance(c_value, (tuple, list)):
  682. c_value = c_value[0]
  683. vc_util.auto_broadcast_check(get_shape(bias_value), get_shape(c_value))
  684. shape_c_list = get_shape(c_value)
  685. bias_value = akg.topi.broadcast_to(bias_value, shape_c_list)
  686. dim_info = batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b)
  687. if isinstance(dim_info, (tuple, list)):
  688. dim_info = dim_info[0]
  689. attrs = {}
  690. attrs["enable_compute_in_place"] = True
  691. if dim_info != "":
  692. attrs["dim"] = dim_info
  693. batch = get_shape(a_value)[:-2]
  694. mnk = get_mnk_from_matrix(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
  695. attrs = batchmatmul_tiling_strategy(batch + mnk, c_value.dtype, attrs)
  696. return akg.tvm.compute(bias_value.shape,
  697. lambda *indice: c_value(*indice) + bias_value(*indice), name='matmul_bias_output'), attrs
  698. @ct_util.reg_set_dim_func(batchmatmul_no_bias_set_dim)
  699. @vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
  700. def batchmatmul(a_value, b_value, trans_a=False, trans_b=False):
  701. """
  702. Multiplies two tensors in batches.
  703. Args:
  704. a_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of type float16 or float32 with shape(..., r_A, c_A).
  705. b_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of same type as a_value with shape(..., r_B, c_B).
  706. trans_a (bool): Specifies whether a_value is transposed or not, default value is False.
  707. trans_b (bool): Specifies whether b_value is transposed or not, default value is False.
  708. Returns:
  709. tvm.tensor.Tensor of same type as a_value with shape(..., r_C, c_C).
  710. r_C = c_A if trans_a else r_A
  711. c_C = r_B if trans_b else c_B
  712. """
  713. if not isinstance(trans_a, bool):
  714. raise TypeError("trans_a should be of type Boolean.")
  715. if not isinstance(trans_b, bool):
  716. raise TypeError("trans_b should be of type Boolean.")
  717. vc_util.ops_dtype_check([a_value.dtype, b_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
  718. vc_util.elemwise_dtype_check(a_value.dtype, b_value.dtype)
  719. vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
  720. if len(a_value.shape) not in [2, 3, 4]:
  721. raise ValueError("Batch matmul only support 2D, 3D and 4D now.")
  722. dtype = a_value.dtype
  723. if dtype == 'float16':
  724. if len(a_value.shape) == 2:
  725. c_value = vectormatmul_2d_cast(a_value, b_value, trans_a, trans_b, "float32")
  726. elif len(a_value.shape) == 3:
  727. c_value = vectormatmul_3d_cast(a_value, b_value, trans_a, trans_b, "float32")
  728. else:
  729. c_value = vectormatmul_4d_cast(a_value, b_value, trans_a, trans_b, "float32")
  730. else:
  731. if len(a_value.shape) == 2:
  732. c_value = vectormatmul_2d(a_value, b_value, trans_a, trans_b)
  733. elif len(a_value.shape) == 3:
  734. c_value = vectormatmul_3d(a_value, b_value, trans_a, trans_b)
  735. else:
  736. c_value = vectormatmul_4d(a_value, b_value, trans_a, trans_b)
  737. dim_info = batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b)
  738. if isinstance(dim_info, (tuple, list)):
  739. dim_info = dim_info[0]
  740. attrs = {}
  741. attrs["enable_compute_in_place"] = True
  742. if dim_info != "":
  743. attrs["dim"] = dim_info
  744. mnk = get_mnk_from_matrix(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
  745. batch = get_shape(a_value)[:-2]
  746. is_dynamic = ds.shape_is_dynamic([a_value, b_value])
  747. if not is_dynamic:
  748. attrs = batchmatmul_tiling_strategy(batch + mnk, c_value.dtype, attrs)
  749. else:
  750. attrs = batchmatmul_tiling_strategy_dynamic(batch + mnk, c_value, attrs)
  751. attrs["enable_pre_storage_write_simplify"] = True
  752. attrs["enable_sink_allocate"] = True
  753. attrs["enable_double_buffer"] = False
  754. return c_value, attrs
def vectormatmul_3d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of 3D batchmatmul.

    Args:
        a_value (tvm.tensor.Tensor): Left operand; its shape is unpacked as (bs, m, k),
            or as (bs, k, m) when trans_a is set.
        b_value (tvm.tensor.Tensor): Right operand; n is its second-to-last dim when
            trans_b is set, otherwise its last dim.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        The (bs, m, n) result produced by the selected hybrid kernel.

    Raises:
        ValueError: If both trans_a and trans_b are set.
    """
    if trans_a:
        bs, k, m = a_value.shape
    else:
        bs, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype
    zero = akg.tvm.const(0.0, dtype=dtype)

    # The kernels below read bs, m, k and n through capture=locals(), so those local
    # names must stay in sync with the names used inside the kernel bodies.

    # Neither operand transposed: a is indexed [bs, m, k], b is indexed [bs, k, n].
    # t_1 holds every elementwise product; t_2 is zeroed and then accumulates over k.
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_m, i_k] * b[i_bs, i_k, i_n]
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    # Only b transposed: b is indexed [bs, n, k]; the product and its accumulation
    # happen in the same innermost k loop.
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_n in range(0, n):
                    t_2[i_bs, i_m, i_n] = zero
                    for i_k in range(0, k):
                        t_1[i_bs, i_m, i_n, i_k] = a[i_bs, i_m, i_k] * b[i_bs, i_n, i_k]
                        t_2[i_bs, i_m, i_n] = t_1[i_bs, i_m, i_n, i_k] + t_2[i_bs, i_m, i_n]
        return t_2

    # Only a transposed: a is indexed [bs, k, m]; otherwise the same loop structure
    # as matmul_hybrid_f_f.
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_k, i_m] * b[i_bs, i_k, i_n]
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    # Pick the kernel matching the transpose flags; there is no kernel for both set.
    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero)
    else:
        raise ValueError('Not support both transpose yet')
    return c_value
def vectormatmul_4d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of 4D batchmatmul.

    Args:
        a_value (tvm.tensor.Tensor): Left operand; its shape is unpacked as (bs1, bs2, m, k),
            or as (bs1, bs2, k, m) when trans_a is set.
        b_value (tvm.tensor.Tensor): Right operand; n is its second-to-last dim when
            trans_b is set, otherwise its last dim.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        The (bs1, bs2, m, n) result produced by the selected hybrid kernel.

    Raises:
        ValueError: If both trans_a and trans_b are set.
    """
    if trans_a:
        bs1, bs2, k, m = a_value.shape
    else:
        bs1, bs2, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype
    zero = akg.tvm.const(0.0, dtype=dtype)

    # The kernels below read bs1, bs2, m, k and n through capture=locals(), so those
    # local names must stay in sync with the names used inside the kernel bodies.

    # Neither operand transposed: a is indexed [bs1, bs2, m, k], b is indexed [bs1, bs2, k, n].
    # t_1 holds every elementwise product; t_2 is zeroed and then accumulates over k.
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_k, i_n]
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    # Only b transposed: b is indexed [bs1, bs2, n, k]; the product and its accumulation
    # happen in the same innermost k loop.
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs1, bs2, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i_n] = zero
                        for i_k in range(0, k):
                            t_1[i_bs1, i_bs2, i_m, i_n, i_k] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_n, i_k]
                            t_2[i_bs1, i_bs2, i_m, i_n] = t_1[i_bs1, i_bs2, i_m, i_n, i_k] + t_2[i_bs1, i_bs2, i_m, i_n]
        return t_2

    # Only a transposed: a is indexed [bs1, bs2, k, m]; otherwise the same loop structure
    # as matmul_hybrid_f_f.
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_k, i_m] * b[i_bs1, i_bs2, i_k, i_n]
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    # Pick the kernel matching the transpose flags; there is no kernel for both set.
    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero)
    else:
        raise ValueError('Not support both transpose yet')
    return c_value
  def vectormatmul_4d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
      """dsl implementation with data type cast for 4D batchmatmul.

      Every operand element is converted to cast_dtype before it is multiplied, the
      products are summed over k, and the sum is passed through cast.cast with the
      dtype of a_value.

      Args:
          a_value: Left operand, read as [b1, b2, m, k], or as [b1, b2, k, m] when trans_a is set.
          b_value: Right operand, read as [b1, b2, k, n], or as [b1, b2, n, k] when trans_b is set.
          trans_a: Truthy when the last two axes of a_value are swapped.
          trans_b: Truthy when the last two axes of b_value are swapped.
          cast_dtype: dtype the operand elements are converted to before multiplying.

      Returns:
          Tensor of shape (b1, b2, m, n) for every combination of the transpose flags.
      """
      if trans_a:
          b1, b2, k, m = a_value.shape
      else:
          b1, b2, m, k = a_value.shape
      n = b_value.shape[-2] if trans_b else b_value.shape[-1]
      dtype = a_value.dtype

      def matmul_4d_dsl(a_value, b_value, trans_a, trans_b):
          # Every variant emits the products in (b1, b2, m, n, k) order; only the operand indexing differs.
          if not trans_a and not trans_b:
              ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                        lambda i_b1, i_b2, i_m, i_n, i_k:
                                        a_value[i_b1, i_b2, i_m, i_k].astype(cast_dtype) *
                                        b_value[i_b1, i_b2, i_k, i_n].astype(cast_dtype),
                                        name="ele_mul")
          elif not trans_a and trans_b:
              ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                        lambda i_b1, i_b2, i_m, i_n, i_k:
                                        a_value[i_b1, i_b2, i_m, i_k].astype(cast_dtype) *
                                        b_value[i_b1, i_b2, i_n, i_k].astype(cast_dtype),
                                        name="ele_mul")
          elif trans_a and not trans_b:
              ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                        lambda i_b1, i_b2, i_m, i_n, i_k:
                                        a_value[i_b1, i_b2, i_k, i_m].astype(cast_dtype) *
                                        b_value[i_b1, i_b2, i_k, i_n].astype(cast_dtype),
                                        name="ele_mul")
          else:
              # both transposed; the operand order (b * a) is kept as it was
              ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                        lambda i_b1, i_b2, i_m, i_n, i_k:
                                        b_value[i_b1, i_b2, i_n, i_k].astype(cast_dtype) *
                                        a_value[i_b1, i_b2, i_k, i_m].astype(cast_dtype),
                                        name="ele_mul")
          reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
          output_shape = (b1, b2, m, n)
          # the reduction runs over the trailing k axis of ele_mul
          m_c = akg.tvm.compute(output_shape, lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)],
                                                                      axis=reduce_axis), name="matmul_compute")
          return m_c

      c_cast = matmul_4d_dsl(a_value, b_value, trans_a, trans_b)
      # The reduced tensor is already (b1, b2, m, n), also when both operands are transposed,
      # so no trailing transpose is applied; a (1, 0) permutation does not fit a rank-4 tensor anyway.
      return cast.cast(c_cast, dtype)
  def vectormatmul_3d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
      """dsl implementation with data type cast for 3D batchmatmul.

      Every operand element is converted to cast_dtype before it is multiplied, the
      products are summed over k, and the sum is passed through cast.cast with the
      dtype of a_value.

      Args:
          a_value: Left operand, read as [b, m, k], or as [b, k, m] when trans_a is set.
          b_value: Right operand, read as [b, k, n], or as [b, n, k] when trans_b is set.
          trans_a: Truthy when the last two axes of a_value are swapped.
          trans_b: Truthy when the last two axes of b_value are swapped.
          cast_dtype: dtype the operand elements are converted to before multiplying.

      Returns:
          Tensor of shape (b, m, n) for every combination of the transpose flags.
      """
      if trans_a:
          b, k, m = a_value.shape
      else:
          b, m, k = a_value.shape
      n = b_value.shape[-2] if trans_b else b_value.shape[-1]
      dtype = a_value.dtype

      def matmul_3d_dsl(a_value, b_value, trans_a, trans_b):
          # Every variant emits the products in (b, m, n, k) order; only the operand indexing differs.
          if not trans_a and not trans_b:
              ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                        a_value[i_b, i_m, i_k].astype(cast_dtype) *
                                        b_value[i_b, i_k, i_n].astype(cast_dtype),
                                        name="ele_mul")
          elif not trans_a and trans_b:
              ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                        a_value[i_b, i_m, i_k].astype(cast_dtype) *
                                        b_value[i_b, i_n, i_k].astype(cast_dtype),
                                        name="ele_mul")
          elif trans_a and not trans_b:
              ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                        a_value[i_b, i_k, i_m].astype(cast_dtype) *
                                        b_value[i_b, i_k, i_n].astype(cast_dtype),
                                        name="ele_mul")
          else:
              # both transposed; the operand order (b * a) is kept as it was
              ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                        b_value[i_b, i_n, i_k].astype(cast_dtype) *
                                        a_value[i_b, i_k, i_m].astype(cast_dtype),
                                        name="ele_mul")
          reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
          output_shape = (b, m, n)
          # the reduction runs over the trailing k axis of ele_mul
          m_c = akg.tvm.compute(output_shape, lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                                name="matmul_compute")
          return m_c

      c_cast = matmul_3d_dsl(a_value, b_value, trans_a, trans_b)
      # The reduced tensor is already (b, m, n), also when both operands are transposed,
      # so no trailing transpose is applied; a (1, 0) permutation does not fit a rank-3 tensor anyway.
      return cast.cast(c_cast, dtype)
  def vectormatmul_2d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
      """dsl implementation with data type cast for 2D batchmatmul.

      Every operand element is converted to cast_dtype before it is multiplied, the
      products are summed over k, and the sum is passed through cast.cast with the
      dtype of a_value.

      Args:
          a_value: Left operand, read as [m, k], or as [k, m] when trans_a is set.
          b_value: Right operand, read as [k, n], or as [n, k] when trans_b is set.
          trans_a: Truthy when the two axes of a_value are swapped.
          trans_b: Truthy when the two axes of b_value are swapped.
          cast_dtype: dtype the operand elements are converted to before multiplying.

      Returns:
          Tensor of shape (m, n) for every combination of the transpose flags.
      """
      if trans_a:
          k, m = a_value.shape
      else:
          m, k = a_value.shape
      n = b_value.shape[-2] if trans_b else b_value.shape[-1]
      dtype = a_value.dtype

      # Casting the whole float16 tensor to float32 up front makes the AutoPoly pass take a long time,
      # so the cast is applied to each element at the point where it is multiplied.
      def matmul_2d(a_value, b_value, trans_a, trans_b):
          # Every variant emits the products in (m, n, k) order; only the operand indexing differs.
          if not trans_a and not trans_b:
              ele_mul = akg.tvm.compute((m, n, k), lambda i_m, i_n, i_k: a_value[i_m, i_k].astype(cast_dtype)
                                        * b_value[i_k, i_n].astype(cast_dtype),
                                        name="ele_mul")
          elif not trans_a and trans_b:
              ele_mul = akg.tvm.compute((m, n, k), lambda i_m, i_n, i_k: a_value[i_m, i_k].astype(cast_dtype) *
                                        b_value[i_n, i_k].astype(cast_dtype),
                                        name="ele_mul")
          elif trans_a and not trans_b:
              ele_mul = akg.tvm.compute((m, n, k), lambda i_m, i_n, i_k: a_value[i_k, i_m].astype(cast_dtype) *
                                        b_value[i_k, i_n].astype(cast_dtype),
                                        name="ele_mul")
          else:
              # both transposed; the operand order (b * a) is kept as it was
              ele_mul = akg.tvm.compute((m, n, k), lambda i_m, i_n, i_k: b_value[i_n, i_k].astype(cast_dtype) *
                                        a_value[i_k, i_m].astype(cast_dtype),
                                        name="ele_mul")
          reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
          output_shape = (m, n)
          # the reduction runs over the trailing k axis of ele_mul
          m_c = akg.tvm.compute(output_shape, lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                                name="matmul_compute")
          return m_c

      c_cast = matmul_2d(a_value, b_value, trans_a, trans_b)
      # The reduced tensor is already (m, n), also when both operands are transposed; transposing it
      # again would hand back (n, m), so the result is returned as computed.
      return cast.cast(c_cast, dtype)
  def vectormatmul_2d(a_value, b_value, trans_a, trans_b):
      """hybrid implementation for 2D batchmatmul.

      One hybrid-script kernel is defined per combination of the transpose flags and
      the matching one is invoked with the sizes m, n and k taken from the operand shapes.

      Args:
          a_value: Left operand, read as [m, k], or as [k, m] when trans_a is set.
          b_value: Right operand, read as [k, n], or as [n, k] when trans_b is set.
          trans_a: Truthy when the two axes of a_value are swapped.
          trans_b: Truthy when the two axes of b_value are swapped.

      Returns:
          The (m, n) product tensor. When both flags are set the kernel result is
          (n, m) and is transposed before it is returned.
      """
      if trans_a:
          k, m = a_value.shape
      else:
          m, k = a_value.shape
      if trans_b:
          n = b_value.shape[-2]
      else:
          n = b_value.shape[-1]
      dtype = a_value.dtype
      # constant the kernels use to initialise their accumulators
      zero = akg.tvm.const(0.0, dtype=dtype)

      # The kernel bodies below are compiled by the hybrid-script front end.
      # NOTE(review): this kernel returns an output_tensor while the three siblings return an
      # allocate(..., 'local') buffer -- confirm whether the difference is intentional.
      @script(capture=locals())
      def matmul_hybrid_f_f(a, b, zero, mv, nv, kv):
          # t_2[m, n] = sum over k of a[m, k] * b[k, n]
          t_1 = allocate((mv, kv, nv), a.dtype, 'local')
          t_2 = output_tensor((mv, nv), a.dtype)
          for i_m in range(0, mv):
              # step 1: all products for this output row
              for i_k in range(0, kv):
                  for i_n in range(0, nv):
                      t_1[i_m, i_k, i_n] = a[i_m, i_k] * b[i_k, i_n]
              # step 2: clear the output row
              for i1_n in range(0, nv):
                  t_2[i_m, i1_n] = zero
              # step 3: add the products up along k
              for i1_k in range(0, kv):
                  for i1_n in range(0, nv):
                      t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
          return t_2

      @script(capture=locals())
      def matmul_hybrid_f_t(a, b, zero, mv, nv, kv):
          # t_2[m, n] = sum over k of a[m, k] * b[n, k]
          t_1 = allocate((mv, nv, kv), a.dtype, 'local')
          t_2 = allocate((mv, nv), a.dtype, 'local')
          for i_m in range(0, mv):
              for i_n in range(0, nv):
                  # reset the accumulator, then multiply and accumulate in the same k loop
                  t_2[i_m, i_n] = zero
                  for i_k in range(0, kv):
                      t_1[i_m, i_n, i_k] = a[i_m, i_k] * b[i_n, i_k]
                      t_2[i_m, i_n] = t_1[i_m, i_n, i_k] + t_2[i_m, i_n]
          return t_2

      @script(capture=locals())
      def matmul_hybrid_t_f(a, b, zero, mv, nv, kv):
          # t_2[m, n] = sum over k of a[k, m] * b[k, n]
          t_1 = allocate((mv, kv, nv), a.dtype, 'local')
          t_2 = allocate((mv, nv), a.dtype, 'local')
          for i_m in range(0, mv):
              # step 1: all products for this output row
              for i_k in range(0, kv):
                  for i_n in range(0, nv):
                      t_1[i_m, i_k, i_n] = a[i_k, i_m] * b[i_k, i_n]
              # step 2: clear the output row
              for i1_n in range(0, nv):
                  t_2[i_m, i1_n] = zero
              # step 3: add the products up along k
              for i1_k in range(0, kv):
                  for i1_n in range(0, nv):
                      t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
          return t_2

      @script(capture=locals())
      def matmul_hybrid_t_t(a, b, zero, mv, nv, kv):
          # t_2[n, m] = sum over k of b[n, k] * a[k, m]; the result is laid out as (n, m)
          t_1 = allocate((nv, kv, mv), a.dtype, 'local')
          t_2 = allocate((nv, mv), a.dtype, 'local')
          for i_n in range(0, nv):
              # step 1: all products for this row of the (n, m) result
              for i_m in range(0, mv):
                  for i_k in range(0, kv):
                      t_1[i_n, i_k, i_m] = b[i_n, i_k] * a[i_k, i_m]
              # step 2: clear the row
              for i1_m in range(0, mv):
                  t_2[i_n, i1_m] = zero
              # step 3: add the products up along k
              for i1_k in range(0, kv):
                  for i2_m in range(0, mv):
                      t_2[i_n, i2_m] = t_2[i_n, i2_m] + t_1[i_n, i1_k, i2_m]
          return t_2

      # Dispatch on the transpose flags.
      if not trans_a and not trans_b:
          c_value = matmul_hybrid_f_f(a_value, b_value, zero, m, n, k)
      elif not trans_a and trans_b:
          c_value = matmul_hybrid_f_t(a_value, b_value, zero, m, n, k)
      elif trans_a and not trans_b:
          c_value = matmul_hybrid_t_f(a_value, b_value, zero, m, n, k)
      else:
          # the t_t kernel yields (n, m); swap the axes to hand back (m, n)
          c1 = matmul_hybrid_t_t(a_value, b_value, zero, m, n, k)
          c_value = akg.topi.transpose(c1, (1, 0))
      return c_value