You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tensor_reformat.cpp 140 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021
  1. /**
  2. * \file src/gopt/impl/tensor_reformat.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/gopt/basic_arith.h"
  13. #include "megbrain/gopt/gtrans.h"
  14. #include "megbrain/gopt/inference.h"
  15. #include "megbrain/graph/event.h"
  16. #include "megbrain/opr/basic_arith.h"
  17. #include "megbrain/opr/blas.h"
  18. #include "megbrain/opr/dnn/batch_norm.h"
  19. #include "megbrain/opr/dnn/convolution.h"
  20. #include "megbrain/opr/dnn/local.h"
  21. #include "megbrain/opr/dnn/pooling.h"
  22. #include "megbrain/opr/imgproc.h"
  23. #include "megbrain/opr/misc.h"
  24. #include "megbrain/opr/nn_int.h"
  25. #include "megbrain/opr/tensor_manip.h"
  26. #include "megbrain/opr/utility.h"
  27. #include "megbrain/serialization/opr_shallow_copy.h"
  28. #include "megbrain/utils/shared_set.h"
  29. #include "megdnn/opr_param_defs.h"
  30. #include "megdnn/tensor_format.h"
  31. #if MGB_ENABLE_TENSOR_RT
  32. #include "megbrain/tensorrt/tensorrt_opr.h"
  33. #endif
  34. #include "megbrain/gopt/misc.h"
  35. #include "megbrain/utils/hash_ct.h"
  36. #include "midout.h"
  37. MIDOUT_DECL(megbrain_tensor_reformat)
  38. #define MIDOUT_B(tag) \
  39. MIDOUT_BEGIN(megbrain_tensor_reformat, midout_iv(MGB_HASH_STR(tag))) {
  40. #define MIDOUT_E \
  41. } \
  42. MIDOUT_END();
  43. using namespace mgb;
  44. using namespace gopt;
  45. /* ================ TensorReformatPass =============== */
  46. /*!
  47. * \brief relayout placeholder opr
  48. *
  49. * RelayoutPlaceholder oprs act as the placeholders of the ComputingGraph
  50. * during graph opt pass `TensorReformatPass`. These oprs are introduced
  51. * into a ComputingGraph for conveniently discovering further optimize
  52. * opportunities (such as fuse consecutive relayouts, translate into
  53. * optimized implementations). They are canonized to have a shape infer, so
  54. * the output's shape can be correctly deduced during the opt pass.
  55. *
  56. * Note that the oprs in the ComputingGraph are only used as intermediate
  57. * representations before being translated to MegBrain oprs, so the
  58. * oprs should not get involved in any actual computing.
  59. */
MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
                     cg::SingleCNOperatorNodeBase) // {
public:
    //! relayout type of this opr; names the source and destination tensor
    //! layouts of the transform the placeholder stands for
    enum class LayoutType {
        NCHW4_TO_NCHW32,              //!< from nchw4 layout to nchw32 layout
        NCHW32_TO_NCHW4,              //!< from nchw32 layout to nchw4 layout
        NCHW4_TO_CHWN4,               //!< from nchw4 layout to chwn4 layout
        CHWN4_TO_NCHW4,               //!< from chwn4 layout to nchw4 layout
        NCHW_TO_NCHW4,                //!< from nchw layout to nchw4 layout
        NCHW_TO_NCHW4_IC_SMALL_CONV,  ///< from nchw layout to nchw4 whose
                                      ///< channel size less than 4
        NCHW4_TO_NCHW,                //!< from nchw4 layout to nchw layout
        NCHW_TO_NCHW88,               //!< from nchw layout to nchw88 layout
        NCHW88_TO_NCHW,               //!< from nchw88 layout to nchw layout
        WEIGHT_NCHW_TO_NCHW4_DENSE,   //!< weight from nchw layout to nchw4
                                      //!< layout
        WEIGHT_NCHW_TO_NCHW4_GROUP,   //!< group weight from nchw layout to
                                      //!< nchw4 layout
        WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,  //!< weight from nchw layout
                                                   //!< to nchw4 layout whose
                                                   //!< channel size less than 4
        WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                      //!< layout
        WEIGHT_NCHW_TO_NCHW88_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw88 layout
        WEIGHT_NCHW_TO_NCHW88_CHAN,   //!< channel wise weight from nchw layout
                                      //!< to nchw88 layout
        //!< the weight layout of input is nchw output is nchw88, special for
        //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
        WEIGHT_HYBIRD_NCHW_NCHW88,
        WEIGHT_NCHW_TO_NCHW44_DENSE,  //!< weight from nchw layout to nchw44
                                      //!< layout
        WEIGHT_NCHW_TO_NCHW44_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw44 layout
        WEIGHT_NCHW_TO_NCHW44_CHAN,   //!< channel wise weight from nchw layout
                                      //!< to nchw44 layout
        //!< the weight layout of input is nchw output is nchw44, special for
        //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
        WEIGHT_HYBIRD_NCHW_NCHW44,
        WEIGHT_NCHW_TO_NCHW44_DOT_DENSE,  //!< weight from NCHW44 layout to
                                          //!< NCHW44_DOT layout dense
        WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,  //!< weight from NCHW44 layout to
                                          //!< NCHW44_DOT layout group
    };

    /*!
     * \param src_var the input var
     * \param layout_type tensor layout transform type of this relayout
     *      placeholder as described in LayoutType
     */
    RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);

    //! factory: insert a placeholder for \p src_var into its owner graph
    static SymbolVar make(VarNode* src_var, LayoutType layout_type);

    //! the relayout transform this placeholder represents
    LayoutType layout_type() const { return m_layout_type; }

private:
    void init_output_static_infer_desc() override;
    //! always throws: placeholders must be translated away before execution
    void scn_do_execute() override;
    void init_output_comp_node() override;
    const LayoutType m_layout_type;  //!< immutable relayout kind, set at ctor
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);

TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder(
        VarNode* src_var, LayoutType layout_type)
        : Super(src_var->owner_graph(), {}, "RelayoutPlaceholder", {src_var}),
          m_layout_type{layout_type} {
    add_input({src_var});
    // the layout type participates in the opr's equivalence key, so two
    // placeholders on the same var with different relayouts are never
    // deduplicated into one
    add_equivalence_component<ScalarHash<LayoutType>>(m_layout_type);
    // output dtype follows the input; the shape is provided by the static
    // shape inference registered in init_output_static_infer_desc()
    add_output(None)->dtype(src_var->dtype());
}
void TensorReformatPass::RelayoutPlaceholder::scn_do_execute() {
    // Placeholders are intermediate-representation-only oprs: translate_pass
    // must replace every one of them with real relayout oprs before the graph
    // runs. Reaching this point indicates a pass bug.
    mgb_throw(InternalError, "RelayoutPlaceholder opr can not be executed");
}
void TensorReformatPass::RelayoutPlaceholder::init_output_comp_node() {
    // single input, single output: the output stays on the input's comp node
    output(0)->comp_node(input(0)->comp_node());
}
void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    // output shape depends on the SHAPE of every input var
    DepVal deps;
    for (auto i : input())
        deps.push_back({i, DepType::SHAPE});
    // Compute the output shape from the input shape for the relayout kind of
    // this placeholder. This is the only behavior the placeholder implements;
    // it performs no actual data movement.
    auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
        TensorShape inp_shape = inp.val[0].shape();
        dst = inp_shape;
        if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32) {
            // (N, C/4, H, W, 4) -> (N, C/32, H, W, 32)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] * 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) {
            // (N, C/32, H, W, 32) -> (N, C/4, H, W, 4)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] / 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4) {
            // (N, C/4, H, W, 4) -> (C/4, H, W, N, 4): pure dim permutation
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[1];
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[0];
            dst[4] = inp_shape[4];
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) {
            // (C/4, H, W, N, 4) -> (N, C/4, H, W, 4): inverse permutation
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[3];
            dst[1] = inp_shape[0];
            dst[2] = inp_shape[1];
            dst[3] = inp_shape[2];
            dst[4] = inp_shape[4];
        } else if (layout_type() ==
                           RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 ||
                   layout_type() == RelayoutPlaceholder::LayoutType::
                                            NCHW_TO_NCHW4_IC_SMALL_CONV) {
            // (N, C, H, W) -> (N, ceil(C/4), H, W, 4); the IC_SMALL variant
            // pads channels < 4 up to one group of 4
            if (layout_type() ==
                RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
            } else {
                mgb_assert(layout_type() ==
                           RelayoutPlaceholder::LayoutType::
                                   NCHW_TO_NCHW4_IC_SMALL_CONV);
                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
            }
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = (inp_shape[1] + 4 - 1) / 4;  // ceil division for padding
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 4;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
            // (N, C/4, H, W, 4) -> (N, C, H, W)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst.ndim = 4;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW4_DENSE ||
                   layout_type() ==
                           RelayoutPlaceholder::LayoutType::
                                   WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) {
            // dense conv weight (OC, IC, FH, FW) ->
            // (OC, ceil(IC/4), FH, FW, 4)
            if (layout_type() ==
                RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) {
                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
            } else {
                mgb_assert(layout_type() ==
                           RelayoutPlaceholder::LayoutType::
                                   WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV);
                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
            }
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = (inp_shape[1] + 4 - 1) / 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 4;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW4_GROUP) {
            // group conv weight (G, OCPG, ICPG, FH, FW) ->
            // (G, OCPG, ICPG/4, FH, FW, 4)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[2] % 4 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1];
            dst[2] = inp_shape[2] / 4;
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 4;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
            // (N, C, H, W) -> (N, C/8, H, W, 8)
            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) {
            // (N, C/8, H, W, 8) -> (N, C, H, W)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 8);
            dst.ndim = 4;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_DENSE) {
            // dense weight (OC, IC, FH, FW) -> (OC/8, IC/8, FH, FW, 8, 8)
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0 &&
                       inp_shape[1] % 8 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 8;
            dst[5] = 8;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_GROUP) {
            // group weight (G, OCPG, ICPG, FH, FW) ->
            // (G, OCPG/8, ICPG/8, FH, FW, 8, 8)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 8 == 0 &&
                       inp_shape[2] % 8 == 0);
            dst.ndim = 7;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2] / 8;
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 8;
            dst[6] = 8;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_CHAN) {
            // channel-wise weight (G, 1, 1, FH, FW) -> (G/8, 1, 1, FH, FW, 8)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
                       inp_shape[2] == 1 && inp_shape[0] % 8 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[1];
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) {
            // hybrid first-conv weight (OC, IC, FH, FW) ->
            // (OC/8, FH, FW, IC, 8)
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[1];
            dst[4] = 8;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_DENSE ||
                   layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_DOT_DENSE) {
            // dense weight (OC, IC, FH, FW) -> (OC/4, IC/4, FH, FW, 4, 4);
            // same output shape for the DOT variant
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
                       inp_shape[1] % 4 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 4;
            dst[1] = inp_shape[1] / 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 4;
            dst[5] = 4;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_GROUP ||
                   layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_DOT_GROUP) {
            // group weight (G, OCPG, ICPG, FH, FW) ->
            // (G, OCPG/4, ICPG/4, FH, FW, 4, 4)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
                       inp_shape[2] % 4 == 0);
            dst.ndim = 7;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 4;
            dst[2] = inp_shape[2] / 4;
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 4;
            dst[6] = 4;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_CHAN) {
            // channel-wise weight (G, 1, 1, FH, FW) -> (G/4, 1, 1, FH, FW, 4)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
                       inp_shape[2] == 1 && inp_shape[0] % 4 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 4;
            dst[1] = inp_shape[1];
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 4;
        } else {
            // the only remaining kind: hybrid first-conv weight
            // (OC, IC, FH, FW) -> (OC/4, FH, FW, IC, 4)
            mgb_assert(
                    layout_type() ==
                    RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44);
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0] / 4;
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[1];
            dst[4] = 4;
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
}
  348. SymbolVar TensorReformatPass::RelayoutPlaceholder::make(
  349. VarNode* src_var, LayoutType layout_type) {
  350. return src_var->owner_graph()
  351. ->insert_opr(
  352. std::make_unique<RelayoutPlaceholder>(src_var, layout_type))
  353. ->output(0);
  354. }
void TensorReformatPass::insert_pass(OptState& opt) const {
    // Phase 1 of the pass: walk the graph and, for every opr that has a
    // registered replace functor, build a replacement opr (which may feed on
    // RelayoutPlaceholder vars); other oprs are copied through unchanged.
    opt.set_var_replace_check_flag(m_var_replace_check_flag);
    auto rewriter = opt.graph().make_rewriter();
    // reused across oprs to avoid reallocating the input array per opr
    VarNodeArray new_inp_cache;
    auto on_opr = [this, &opt, &rewriter,
                   &new_inp_cache](OperatorNodeBase* opr) {
        auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
        if (it != m_opr_replace_func.end()) {
            // gather the (possibly already rewritten) inputs of this opr
            auto& new_inp = new_inp_cache;
            new_inp.clear();
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size(),
                       "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
                       "dst.size=%zu",
                       opr->cname(), opr->dyn_typeinfo()->name,
                       new_opr->cname(), new_opr->dyn_typeinfo()->name,
                       out0.size(), out1.size());
            // pairwise map old outputs to new outputs, skipping volatile
            // (internal-only) vars that carry no user-visible value
            for (size_t i = 0; i < out0.size(); ++i) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(!out1[i]->contain_flag(
                            VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src)) {
                        // additional process on endpoint var node (e.g.
                        // converting back to the original layout)
                        dst = on_graph_endpoint_var(dst, src);
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}
  397. void TensorReformatPass::translate_pass(OptState& opt) const {
  398. ThinHashMap<RelayoutPlaceholder::LayoutType,
  399. thin_function<VarNode*(VarNode*)>>
  400. reformat;
  401. using LayoutType = RelayoutPlaceholder::LayoutType;
  402. reformat[LayoutType::NCHW4_TO_CHWN4] = [](VarNode* inp) -> VarNode* {
  403. megdnn::param::RelayoutFormat param;
  404. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
  405. auto reformat = opr::RelayoutFormat::make(inp, param);
  406. return reformat.node();
  407. };
  408. reformat[LayoutType::CHWN4_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
  409. megdnn::param::RelayoutFormat param;
  410. param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
  411. auto reformat = opr::RelayoutFormat::make(inp, param);
  412. return reformat.node();
  413. };
  414. reformat[LayoutType::NCHW4_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
  415. auto x = SymbolVar(inp);
  416. auto xshp = opr::GetVarShape::make(x);
  417. auto cv = [&x](int v) { return x.make_scalar(v); };
  418. auto sub = [&xshp, &cv](int idx) {
  419. return opr::IndexAt::make(xshp, {{0, cv(idx)}});
  420. };
  421. auto tshp0 = opr::Concat::make(
  422. {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0),
  423. tshp1 = opr::Concat::make(
  424. {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
  425. auto y0 = opr::Reshape::make(x, tshp0);
  426. auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
  427. auto y2 = opr::Reshape::make(y1, tshp1);
  428. return y2.node();
  429. };
  430. reformat[LayoutType::NCHW32_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
  431. auto x = SymbolVar(inp);
  432. auto xshp = opr::GetVarShape::make(x);
  433. auto cv = [&x](int v) { return x.make_scalar(v); };
  434. auto sub = [&xshp, &cv](int idx) {
  435. return opr::IndexAt::make(xshp, {{0, cv(idx)}});
  436. };
  437. auto tshp0 = opr::Concat::make(
  438. {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0),
  439. tshp1 = opr::Concat::make(
  440. {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
  441. auto y0 = opr::Reshape::make(x, tshp0);
  442. auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
  443. auto y2 = opr::Reshape::make(y1, tshp1);
  444. return y2.node();
  445. };
  446. reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] =
  447. [](VarNode* inp) -> VarNode* {
  448. auto x = SymbolVar(inp);
  449. auto y = opr::RelayoutFormat::make(
  450. x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
  451. return y.node();
  452. };
  453. reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] =
  454. [](VarNode* inp) -> VarNode* {
  455. auto x = SymbolVar(inp);
  456. auto y = opr::RelayoutFormat::make(
  457. x, megdnn::param::RelayoutFormat::Mode::
  458. NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
  459. return y.node();
  460. };
  461. reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
  462. auto x = SymbolVar(inp);
  463. auto xshp = opr::GetVarShape::make(x);
  464. auto cv = [&x](int v) { return x.make_scalar(v); };
  465. auto sub = [&xshp, &cv](int idx) {
  466. return opr::IndexAt::make(xshp, {{0, cv(idx)}});
  467. };
  468. auto tshp0 = opr::Concat::make(
  469. {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
  470. auto y0 = opr::Reshape::make(x, tshp0);
  471. auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
  472. return y1.node();
  473. };
  474. reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
  475. auto x = SymbolVar(inp);
  476. auto xshp = opr::GetVarShape::make(x);
  477. auto cv = [&x](int v) { return x.make_scalar(v); };
  478. auto sub = [&xshp, &cv](int idx) {
  479. return opr::IndexAt::make(xshp, {{0, cv(idx)}});
  480. };
  481. auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
  482. auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
  483. auto y1 = opr::Reshape::make(y0, tshp0);
  484. return y1.node();
  485. };
  486. reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] =
  487. [](VarNode* inp) -> VarNode* {
  488. auto x = SymbolVar(inp);
  489. auto xshp = opr::GetVarShape::make(x);
  490. auto cv = [&x](int v) { return x.make_scalar(v); };
  491. auto sub = [&xshp, &cv](int idx) {
  492. return opr::IndexAt::make(xshp, {{0, cv(idx)}});
  493. };
  494. auto tshp0 = opr::Concat::make(
  495. {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
  496. tshp1 = opr::Concat::make(
  497. {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
  498. auto y0 = opr::Reshape::make(x, tshp0);
  499. auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
  500. auto y2 = opr::Reshape::make(y1, tshp1);
  501. return y2.node();
  502. };
    // Group conv weight (G, O, I, H, W) -> (G, O, I/4, H, W, 4): split the
    // input-channel axis into (I/4, 4) and move the 4-pack to the last axis.
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1), sub(2) / 4, cv(4), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1), sub(2) / 4, sub(3), sub(4), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 2, 4, 5, 3});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Feature map NCHW -> NCHW88: reshape (N, C, H, W) to (N, C/8, 8, H, W),
    // permute to (N, C/8, H, W, 8), then reshape to pin the final shape.
    reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 8, sub(2), sub(3), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Feature map NCHW88 -> NCHW: permute (N, C/8, H, W, 8) to
    // (N, C/8, 8, H, W), then reshape back to (N, C, H, W).
    reformat[LayoutType::NCHW88_TO_NCHW] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 8, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp0);
        return y1.node();
    };
    // Dense conv weight (O, I, H, W) -> (O/8, I/8, H, W, 8, 8): split both
    // channel axes into 8-packs, reshape to (O/8, 8, I/8, 8, H, W), then
    // permute so the input-8 pack precedes the output-8 pack at the tail.
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_DENSE] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(1) / 8, sub(2), sub(3), cv(8), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Group conv weight (G, O, I, H, W) -> (G, O/8, I/8, H, W, 8, 8): split
    // both per-group channel axes into 8-packs and move them to the tail
    // (input-8 before output-8, matching the dense rule above).
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 8, cv(8), sub(2) / 8,
                                        cv(8), sub(3), sub(4)},
                                       0),
             tshp1 = opr::Concat::make({sub(0), sub(1) / 8, sub(2) / 8, sub(3),
                                        sub(4), cv(8), cv(8)},
                                       0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Channel-wise conv weight: split dim 0 into (d0/8, 8) and move the
    // 8-pack to the last axis, keeping the other four dims in order.
    // Presumably the 5-dim input is (group, 1, 1, FH, FW) — TODO confirm
    // against callers.
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_CHAN] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(1), sub(2), sub(3), sub(4), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Hybrid weight (O, I, H, W) -> (O/8, H, W, I, 8): split the output
    // channel into (O/8, 8), then permute so spatial dims come before the
    // full (unpacked) input channel, with the 8-pack last.
    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(2), sub(3), sub(1), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Dense conv weight (O, I, H, W) -> (O/4, I/4, H, W, 4, 4): 4-pack
    // analogue of WEIGHT_NCHW_TO_NCHW88_DENSE (input-4 pack before output-4).
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Group conv weight (G, O, I, H, W) -> (G, O/4, I/4, H, W, 4, 4): 4-pack
    // analogue of WEIGHT_NCHW_TO_NCHW88_GROUP.
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
                                        cv(4), sub(3), sub(4)},
                                       0),
             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
                                        sub(4), cv(4), cv(4)},
                                       0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Channel-wise conv weight: split dim 0 into (d0/4, 4) and move the
    // 4-pack to the last axis; 4-pack analogue of WEIGHT_NCHW_TO_NCHW88_CHAN.
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Hybrid weight (O, I, H, W) -> (O/4, H, W, I, 4): 4-pack analogue of
    // WEIGHT_HYBIRD_NCHW_NCHW88.
    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Dense conv weight for dot kernels: (O, I, H, W) ->
    // (O/4, I/4, H, W, 4, 4), but with the OUTPUT-4 pack before the input-4
    // pack at the tail (permute {0,2,4,5,1,3}), unlike the non-DOT dense rule.
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 1, 3});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Group conv weight for dot kernels: (G, O, I, H, W) ->
    // (G, O/4, I/4, H, W, 4, 4) with the output-4 pack before the input-4
    // pack (permute {0,1,3,5,6,2,4}), mirroring the DOT_DENSE rule above.
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant v; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
                                        cv(4), sub(3), sub(4)},
                                       0),
             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
                                        sub(4), cv(4), cv(4)},
                                       0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // Walk the graph and expand every RelayoutPlaceholder into the concrete
    // reshape/dimshuffle chain registered in `reformat`; all other operators
    // are copied through with their (possibly replaced) inputs.
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
        if (opr->same_type<RelayoutPlaceholder>()) {
            auto ph = try_cast_as_op<RelayoutPlaceholder>(opr);
            // use the (possibly already rewritten) version of the input var
            auto new_inp = rewriter.get_var(opr->input(0));
            mgb_assert(reformat.count(ph->layout_type()),
                       "no replace rule can be found for layout_type(%u)",
                       static_cast<uint32_t>(ph->layout_type()));
            auto new_var = reformat[ph->layout_type()](new_inp);
            rewriter.replace_var(opr->output(0), new_var,
                                 mgb_cstr_log("replace relayout placeholder"));
            return;
        }
        rewriter.auto_replace_outputs(opr);
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
  741. }
// Entry point of the pass: first insert RelayoutPlaceholder oprs wherever a
// layout conversion is required (insert_pass), then translate each
// placeholder into real reshape/dimshuffle operator chains (translate_pass).
void TensorReformatPass::apply(OptState& opt) const {
    MIDOUT_B("TensorReformatPass::apply")
    insert_pass(opt);
    translate_pass(opt);
    MIDOUT_E
}
  748. /* ================ EnableTensorCorePass =============== */
  749. VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var,
  750. VarNode* orig_var) const {
  751. if (!orig_var->shape().eq_shape(new_var->shape())) {
  752. return RelayoutPlaceholder::make(
  753. new_var,
  754. RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4)
  755. .node();
  756. }
  757. return new_var;
  758. }
// Build the converter that rewrites quantized-int8 NCHW4 subgraphs into
// NCHW32 layout (for Tensor Core conv kernels), inserting
// RelayoutPlaceholder oprs at every layout boundary.
std::unique_ptr<EnableTensorCorePass>
EnableTensorCorePass::make_tensorcore_converter() {
    MIDOUT_B("EnableTensorCorePass::make")
    // replace rule for conv bias opr
    auto replace_conv_bias_opr = [](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        using Param = megdnn::param::ConvBias;
        using Format = Param::Format;
        using Sparse = Param::Sparse;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
        // only NCHW4 conv_bias with QuantizedS8 output may use NCHW32;
        // any other conv_bias must see unchanged input shapes
        if (conv_bias.param().format != Format::NCHW4 ||
            conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
            size_t nr_inps = opr->input().size();
            bool shape_has_changed = false;
            for (size_t i = 0; i < nr_inps; ++i) {
                if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                    shape_has_changed = true;
                }
            }
            MGB_MARK_USED_VAR(shape_has_changed);
            mgb_assert(
                    !shape_has_changed,
                    "EnableTensorCorePass assumes that the shape of inputs of"
                    "ConvBias operators whose output dtype is not QuantizedS8 "
                    "can not be changed in this opt pass");
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()),
                   "EnableTensorCorePass assumes that filter tensor of "
                   "conv_bias operator can not be changed by other operators");
        VarNode* orig_filter = opr->input(1);
        // layout detectors based on the trailing packed-channel dimension
        auto is_nchw4 = [](TensorShape shape) -> bool {
            return shape.ndim == 5 && shape[4] == 4;
        };
        auto is_nchw32 = [](TensorShape shape) -> bool {
            return shape.ndim == 5 && shape[4] == 32;
        };
        bool can_replace_nchw32 = false;
        VarNode *src = nullptr, *weight = nullptr, *bias = nullptr,
                *z_inp = nullptr;
        // process src tensor
        if (is_nchw4(new_inp[0]->shape())) { // new input is NCHW4 layout
            // derive group count and per-group channel counts from the
            // ORIGINAL (NCHW4-layout) filter shape
            size_t group = 1, icpg, ocpg;
            if (conv_bias.param().sparse == Sparse::DENSE) {
                icpg = orig_filter->shape()[1] * 4;
                ocpg = orig_filter->shape()[0];
            } else {
                mgb_assert(conv_bias.param().sparse == Sparse::GROUP);
                group = orig_filter->shape()[0];
                icpg = orig_filter->shape()[2];
                ocpg = orig_filter->shape()[1];
                if (icpg == 1 && ocpg == 1) { // channel wise conv
                    group *= 4;
                } else {
                    icpg *= 4;
                }
            }
            // nchw32 layout need that input width and height are larger than 3
            size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
            if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 &&
                iw >= 3) {
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[0],
                        RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
                src = symvar.node();
                can_replace_nchw32 = true;
            } else {
                src = new_inp[0];
            }
        } else { // new input is NCHW32 layout
            mgb_assert(is_nchw32(new_inp[0]->shape()));
            size_t group = 1, ocpg;
            if (conv_bias.param().sparse == Sparse::DENSE) {
                ocpg = orig_filter->shape()[0];
            } else {
                mgb_assert(conv_bias.param().sparse == Sparse::GROUP);
                size_t icpg = orig_filter->shape()[2];
                ocpg = orig_filter->shape()[1];
                if (icpg == 1 && ocpg == 1) {
                    group *= 4;
                } else {
                    // NOTE(review): dead store — this icpg is local to the
                    // else-branch and never read afterwards
                    icpg *= 4;
                }
            }
            // the icpg % 32 check is unnecessary here: the input already
            // being NCHW32 implies the input channel constraint was met
            size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
            if (group == 1 && ocpg % 32 == 0 && ih >= 3 && iw >= 3) {
                can_replace_nchw32 = true;
                src = new_inp[0];
            } else {
                // fall back to NCHW4 for this conv
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[0],
                        RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                src = symvar.node();
            }
        }
        // process filter tensor: follows the decision made for src
        if (can_replace_nchw32) {
            auto symvar = RelayoutPlaceholder::make(
                    new_inp[1],
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
            weight = symvar.node();
        } else {
            weight = new_inp[1];
        }
        // 2 inputs: conv without bias/z
        if (new_inp.size() == 2) {
            if (can_replace_nchw32) {
                auto param = conv_bias.param();
                param.format = Format::NCHW32;
                auto new_opr = opr::ConvBiasForward::make(
                        src, weight, param, conv_bias.execution_policy(),
                        conv_bias.config());
                return new_opr.node()->owner_opr();
            } else {
                VarNodeArray inps{src, weight};
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                return new_opr;
            }
        }
        // convert an extra input (bias / z) to whichever layout was chosen
        auto process_inp = [&](VarNode* inp) -> VarNode* {
            if (can_replace_nchw32) {
                if (is_nchw4(inp->shape())) {
                    auto symvar = RelayoutPlaceholder::make(
                            inp,
                            RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
                    return symvar.node();
                } else {
                    mgb_assert(is_nchw32(inp->shape()));
                    return inp;
                }
            } else {
                if (is_nchw4(inp->shape())) {
                    return inp;
                } else {
                    mgb_assert(is_nchw32(inp->shape()));
                    auto symvar = RelayoutPlaceholder::make(
                            inp,
                            RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                    return symvar.node();
                }
            }
        };
        // process bias tensor
        bias = process_inp(new_inp[2]);
        if (new_inp.size() == 3) {
            if (can_replace_nchw32) {
                auto param = conv_bias.param();
                param.format = Format::NCHW32;
                auto new_opr = opr::ConvBiasForward::make(
                        src, weight, bias, param, conv_bias.execution_policy(),
                        conv_bias.config());
                return new_opr.node()->owner_opr();
            } else {
                VarNodeArray inps{src, weight, bias};
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                return new_opr;
            }
        }
        // process z_inp tensor
        z_inp = process_inp(new_inp[3]);
        if (can_replace_nchw32) {
            auto param = conv_bias.param();
            param.format = Format::NCHW32;
            auto new_opr = opr::ConvBiasForward::make(
                    src, weight, bias, z_inp, param,
                    conv_bias.execution_policy(), conv_bias.config());
            return new_opr.node()->owner_opr();
        }
        VarNodeArray inps{src, weight, bias, z_inp};
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    // replace rule for elemwise like opr
    // for oprs support NCHW4 and NCHW32 layout
    auto replace_elemwise_like_opr = [](OperatorNodeBase* opr,
                                        const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        // count inputs whose layout was changed (NCHW4 -> NCHW32) upstream
        size_t nr_inps = new_inp.size();
        size_t nr_shape_changed = 0;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                nr_shape_changed++;
            }
        }
        if (nr_shape_changed) {
            auto inps = new_inp;
            // majority vote: convert the minority side so fewer relayouts
            // are inserted
            if (nr_shape_changed >=
                nr_inps / 2) { // NCHW32 > NCHW4 -> use NCHW32
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW4_TO_NCHW32);
                        inps[i] = symvar.node();
                    }
                }
            } else { // NCHW32 < NCHW4 -> use NCHW4
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW32_TO_NCHW4);
                        inps[i] = symvar.node();
                    }
                }
            }
            return serialization::copy_opr_shallow(*opr, inps, opr->config());
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // for oprs only supports NCHW4 layout
    auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr,
                                    const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray inps = new_inp;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                // a changed shape here can only be NCHW4 -> NCHW32;
                // convert it back to NCHW4
                mgb_assert(opr->input(i)->shape().ndim == 5 &&
                           opr->input(i)->shape()[4] == 4);
                mgb_assert(new_inp[i]->shape().ndim == 5 &&
                           new_inp[i]->shape()[4] == 32);
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[i],
                        RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                inps[i] = symvar.node();
            }
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    // for oprs that must not observe any layout change at all
    auto replace_non_nchw4_opr = [](OperatorNodeBase* opr,
                                    const VarNodeArray new_inp) {
        size_t nr_inps = opr->input().size();
        bool shape_has_changed = false;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                shape_has_changed = true;
            }
        }
        mgb_assert(!shape_has_changed,
                   "EnableTensorCorePass assumes that inputs' shape of "
                   "non-nchw4 operators "
                   "can not be changed in this opt "
                   "pass");
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // format-aware oprs: keep NCHW4 inputs if the opr itself runs in NCHW4,
    // otherwise forbid layout change
    auto replace_warp_affine_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpAffineForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_warp_perspective_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpPerspectiveForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp =
                        opr->cast_final_safe<opr::WarpPerspectiveForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
                                      OperatorNodeBase* opr,
                                      const VarNodeArray new_inp) {
        using Param = opr::ResizeForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& resize = opr->cast_final_safe<opr::ResizeForward>();
        if (resize.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        return replace_inps_to_nchw4(opr, new_inp);
    };
    // pooling supports NCHW32 directly: switch its param format instead of
    // converting the input back to NCHW4
    auto replace_pooling_opr = [replace_non_nchw4_opr](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray new_inp) {
        using Param = opr::PoolingForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
        if (pooling.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        size_t nr_inps = opr->input().size();
        MGB_MARK_USED_VAR(nr_inps);
        mgb_assert(nr_inps == 1);
        if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) {
            mgb_assert(opr->input(0)->shape().ndim == 5 &&
                       opr->input(0)->shape()[4] == 4);
            mgb_assert(new_inp[0]->shape().ndim == 5 &&
                       new_inp[0]->shape()[4] == 32);
            auto new_param = pooling.param();
            new_param.format = Format::NCHW32;
            auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
                                                         opr->config());
            return new_pooling.node()->owner_opr();
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    auto ret = std::make_unique<EnableTensorCorePass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr;
    // elemwise like
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] =
            replace_elemwise_like_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr;
    // format aware
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    // to nchw4
    replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
    return ret;
    MIDOUT_E
}
  1098. /* ================ EnableCHWN4Pass =============== */
  1099. VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var,
  1100. VarNode* /* orig_var */) const {
  1101. if (m_varshape_changed.count(new_var)) {
  1102. return RelayoutPlaceholder::make(
  1103. new_var, RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4)
  1104. .node();
  1105. }
  1106. return new_var;
  1107. }
  1108. std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
  1109. MIDOUT_B("EnableCHWN4Pass::make")
  1110. auto ret = std::make_unique<EnableCHWN4Pass>();
  1111. ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
  1112. auto&& replace_func = ret->m_opr_replace_func;
  1113. auto&& varshape_changed = ret->m_varshape_changed;
  1114. // replace rule for conv bias opr
  1115. auto replace_conv_bias_opr = [&varshape_changed](
  1116. OperatorNodeBase* opr,
  1117. const VarNodeArray& new_inp) {
  1118. using Param = megdnn::param::ConvBias;
  1119. using Format = Param::Format;
  1120. mgb_assert(opr->input().size() == new_inp.size());
  1121. auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
  1122. if (conv_bias.param().format != Format::NCHW4 ||
  1123. conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
  1124. size_t nr_inps = new_inp.size();
  1125. bool shape_has_changed = false;
  1126. for (size_t i = 0; i < nr_inps; ++i) {
  1127. if (varshape_changed.count(new_inp[i])) {
  1128. shape_has_changed = true;
  1129. break;
  1130. }
  1131. }
  1132. mgb_assert(
  1133. !shape_has_changed,
  1134. "EnableCHWN4Pass assumes that the shape of inputs of"
  1135. "ConvBias operators whose output dtype is not QuantizedS8 "
  1136. "can not be changed in this opt pass");
  1137. return serialization::copy_opr_shallow(*opr, new_inp,
  1138. opr->config());
  1139. }
  1140. mgb_assert(varshape_changed.count(new_inp[1]) == 0,
  1141. "EnableCHWN4Pass assumes that filter tensor of "
  1142. "conv_bias operator can not be changed by other operators");
  1143. VarNode *src = nullptr, *weight = nullptr, *bias = nullptr,
  1144. *z_inp = nullptr;
  1145. // process src tensor
  1146. if (varshape_changed.count(new_inp[0]) ==
  1147. 0) { // new input is NCHW4 layout
  1148. // currently not support group conv
  1149. auto symvar = RelayoutPlaceholder::make(
  1150. new_inp[0],
  1151. RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
  1152. src = symvar.node();
  1153. } else { // new input is NCHW32 layout
  1154. src = new_inp[0];
  1155. }
  1156. // process weight tensor
  1157. {
  1158. auto symvar = RelayoutPlaceholder::make(
  1159. new_inp[1],
  1160. RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
  1161. weight = symvar.node();
  1162. }
  1163. if (new_inp.size() == 2) {
  1164. auto param = conv_bias.param();
  1165. param.format = Format::CHWN4;
  1166. auto new_opr = opr::ConvBiasForward::make(
  1167. src, weight, param, conv_bias.execution_policy(),
  1168. conv_bias.config());
  1169. varshape_changed.insert(new_opr.node());
  1170. return new_opr.node()->owner_opr();
  1171. }
  1172. auto process_inp = [&](VarNode* inp) -> VarNode* {
  1173. if (varshape_changed.count(inp) == 0) {
  1174. auto symvar = RelayoutPlaceholder::make(
  1175. inp, RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
  1176. return symvar.node();
  1177. } else {
  1178. return inp;
  1179. }
  1180. };
  1181. // process bias tensor
  1182. bias = process_inp(new_inp[2]);
  1183. if (new_inp.size() == 3) {
  1184. auto param = conv_bias.param();
  1185. param.format = Format::CHWN4;
  1186. auto new_opr = opr::ConvBiasForward::make(
  1187. src, weight, bias, param, conv_bias.execution_policy(),
  1188. conv_bias.config());
  1189. varshape_changed.insert(new_opr.node());
  1190. return new_opr.node()->owner_opr();
  1191. }
  1192. // process z_inp tensor
  1193. z_inp = process_inp(new_inp[3]);
  1194. auto param = conv_bias.param();
  1195. param.format = Format::CHWN4;
  1196. auto new_opr = opr::ConvBiasForward::make(
  1197. src, weight, bias, z_inp, param, conv_bias.execution_policy(),
  1198. conv_bias.config());
  1199. varshape_changed.insert(new_opr.node());
  1200. return new_opr.node()->owner_opr();
  1201. };
  1202. // replace rule for elemwise like opr
  1203. // for oprs support NCHW4 and CHWN4 layout
  1204. auto replace_elemwise_like_opr = [&varshape_changed](
  1205. OperatorNodeBase* opr,
  1206. const VarNodeArray new_inp) {
  1207. mgb_assert(opr->input().size() == new_inp.size());
  1208. size_t nr_inps = new_inp.size();
  1209. size_t nr_shape_changed = 0;
  1210. for (size_t i = 0; i < nr_inps; ++i) {
  1211. if (varshape_changed.count(new_inp[i])) {
  1212. nr_shape_changed++;
  1213. }
  1214. }
  1215. if (nr_shape_changed) {
  1216. auto inps = new_inp;
  1217. if (nr_shape_changed >=
  1218. nr_inps / 2) { // CHWN4 > NCHW4 -> use CHWN4
  1219. for (size_t i = 0; i < nr_inps; ++i) {
  1220. if (varshape_changed.count(new_inp[i]) == 0) {
  1221. auto symvar = RelayoutPlaceholder::make(
  1222. new_inp[i], RelayoutPlaceholder::LayoutType::
  1223. NCHW4_TO_CHWN4);
  1224. inps[i] = symvar.node();
  1225. }
  1226. }
  1227. auto new_opr = serialization::copy_opr_shallow(*opr, inps,
  1228. opr->config());
  1229. varshape_changed.insert(new_opr->output(0));
  1230. return new_opr;
  1231. } else { // CHWN4 < NCHW4 -> use NCHW4
  1232. for (size_t i = 0; i < nr_inps; ++i) {
  1233. if (varshape_changed.count(new_inp[i])) {
  1234. auto symvar = RelayoutPlaceholder::make(
  1235. new_inp[i], RelayoutPlaceholder::LayoutType::
  1236. CHWN4_TO_NCHW4);
  1237. inps[i] = symvar.node();
  1238. }
  1239. }
  1240. return serialization::copy_opr_shallow(*opr, inps,
  1241. opr->config());
  1242. }
  1243. }
  1244. return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1245. };
  1246. // for oprs only supports NCHW4 layout
  1247. auto replace_inps_to_nchw4 = [&varshape_changed](
  1248. OperatorNodeBase* opr,
  1249. const VarNodeArray new_inp) {
  1250. mgb_assert(opr->input().size() == new_inp.size());
  1251. VarNodeArray inps = new_inp;
  1252. for (size_t i = 0; i < opr->input().size(); ++i) {
  1253. if (varshape_changed.count(new_inp[i])) {
  1254. auto symvar = RelayoutPlaceholder::make(
  1255. new_inp[i],
  1256. RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4);
  1257. inps[i] = symvar.node();
  1258. }
  1259. }
  1260. auto new_opr =
  1261. serialization::copy_opr_shallow(*opr, inps, opr->config());
  1262. return new_opr;
  1263. };
  1264. auto replace_non_nchw4_opr = [&varshape_changed](
  1265. OperatorNodeBase* opr,
  1266. const VarNodeArray new_inp) {
  1267. size_t nr_inps = opr->input().size();
  1268. bool shape_has_changed = false;
  1269. for (size_t i = 0; i < nr_inps; ++i) {
  1270. if (varshape_changed.count(new_inp[i])) {
  1271. shape_has_changed = true;
  1272. }
  1273. }
  1274. mgb_assert(!shape_has_changed,
  1275. "EnableCHWN4Pass assumes that inputs' shape of "
  1276. "non-nchw4 operators "
  1277. "can not be changed in this opt "
  1278. "pass");
  1279. return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1280. };
  1281. // capture by copy to avoid use after return
  1282. auto replace_warp_affine_opr =
  1283. [replace_inps_to_nchw4, replace_non_nchw4_opr](
  1284. OperatorNodeBase* opr, const VarNodeArray new_inp) {
  1285. using Param = opr::WarpAffineForward::Param;
  1286. using Format = Param::Format;
  1287. mgb_assert(opr->input().size() == new_inp.size());
  1288. auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
  1289. if (warp.param().format != Format::NCHW4) {
  1290. return replace_non_nchw4_opr(opr, new_inp);
  1291. }
  1292. return replace_inps_to_nchw4(opr, new_inp);
  1293. };
  1294. auto replace_warp_perspective_opr =
  1295. [replace_inps_to_nchw4, replace_non_nchw4_opr](
  1296. OperatorNodeBase* opr, const VarNodeArray new_inp) {
  1297. using Param = opr::WarpPerspectiveForward::Param;
  1298. using Format = Param::Format;
  1299. mgb_assert(opr->input().size() == new_inp.size());
  1300. auto& warp =
  1301. opr->cast_final_safe<opr::WarpPerspectiveForward>();
  1302. if (warp.param().format != Format::NCHW4) {
  1303. return replace_non_nchw4_opr(opr, new_inp);
  1304. }
  1305. return replace_inps_to_nchw4(opr, new_inp);
  1306. };
  1307. auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
  1308. OperatorNodeBase* opr,
  1309. const VarNodeArray new_inp) {
  1310. using Param = opr::ResizeForward::Param;
  1311. using Format = Param::Format;
  1312. mgb_assert(opr->input().size() == new_inp.size());
  1313. auto& resize = opr->cast_final_safe<opr::ResizeForward>();
  1314. if (resize.param().format != Format::NCHW4) {
  1315. return replace_non_nchw4_opr(opr, new_inp);
  1316. }
  1317. return replace_inps_to_nchw4(opr, new_inp);
  1318. };
  1319. auto replace_pooling_opr = [&varshape_changed, replace_non_nchw4_opr](
  1320. OperatorNodeBase* opr,
  1321. const VarNodeArray new_inp) {
  1322. using Param = opr::PoolingForward::Param;
  1323. using Format = Param::Format;
  1324. mgb_assert(opr->input().size() == new_inp.size());
  1325. auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
  1326. if (pooling.param().format != Format::NCHW4) {
  1327. return replace_non_nchw4_opr(opr, new_inp);
  1328. }
  1329. size_t nr_inps = opr->input().size();
  1330. MGB_MARK_USED_VAR(nr_inps);
  1331. mgb_assert(nr_inps == 1);
  1332. if (varshape_changed.count(new_inp[0])) {
  1333. auto new_param = pooling.param();
  1334. new_param.format = Format::CHWN4;
  1335. auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
  1336. opr->config());
  1337. varshape_changed.insert(new_pooling.node());
  1338. return new_pooling.node()->owner_opr();
  1339. }
  1340. return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1341. };
// register the CHWN4 conversion handlers
replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr;
// elemwise like: layout agnostic, handled uniformly
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] =
        replace_elemwise_like_opr;
replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr;
// format aware: the handler inspects the opr's format param
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
        replace_warp_perspective_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
// to nchw4: these oprs only accept NCHW4, so CHWN4 inputs get relayouted
replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4;
replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4;
replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4;
replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4;
return ret;
MIDOUT_E
}
  1365. /* ================ EnableNCHW4Pass ================ */
  1366. VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
  1367. VarNode* orig_var) const {
  1368. if (!orig_var->shape().eq_shape(new_var->shape())) {
  1369. return RelayoutPlaceholder::make(
  1370. new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
  1371. .node();
  1372. }
  1373. return new_var;
  1374. }
//! build the NCHW -> NCHW4 conversion pass: sets up the relayout modes
//! and target formats, then registers per-opr replace functions below
std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
    MIDOUT_B("EnableNCHW4Pass::make")
    auto ret = std::make_unique<EnableNCHW4Pass>();
    // var shapes legitimately change (NCHW -> NCHW4), so disable the check
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    using RelayoutMode = RelayoutPlaceholder::LayoutType;
    megdnn::param::Convolution::Format conv_format =
            megdnn::param::Convolution::Format::NCHW4;
    megdnn::param::ConvBias::Format conv_bias_format =
            megdnn::param::ConvBias::Format::NCHW4;
    megdnn::param::BatchConvBias::Format batch_conv_bias_format =
            megdnn::param::BatchConvBias::Format::NCHW4;
    RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
    RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
    RelayoutMode weight_to_nchw4_mode_dense =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE;
    RelayoutMode weight_to_nchw4_mode_group =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
    // relayout modes chosen for a conv's weight and src inputs
    struct ConvMode {
        RelayoutMode weight;
        RelayoutMode src;
    };
  1396. auto trans_nchw4 =
  1397. [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group,
  1398. src_to_nchw4_mode](
  1399. const megdnn::param::Convolution::Sparse conv_mode,
  1400. const VarNode* filter) -> ConvMode {
  1401. if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
  1402. mgb_assert(filter->shape().ndim == 4,
  1403. "The origin filter is not NCHW mode");
  1404. size_t IC = filter->shape()[1];
  1405. if (IC < 4) {
  1406. return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,
  1407. RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV};
  1408. } else {
  1409. return {weight_to_nchw4_mode_dense, src_to_nchw4_mode};
  1410. }
  1411. } else {
  1412. mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
  1413. mgb_assert(filter->shape().ndim == 5,
  1414. "The origin filter if not NCHW mode");
  1415. size_t IC = filter->shape()[2];
  1416. mgb_assert(IC % 4 == 0,
  1417. "The input channel should be divisible by 4 for group "
  1418. "conv");
  1419. return {weight_to_nchw4_mode_group, src_to_nchw4_mode};
  1420. }
  1421. };
// convert ConvolutionForward from NCHW to NCHW4
auto replace_conv_opr = [trans_nchw4, conv_format](
                                OperatorNodeBase* opr,
                                const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
    // only NCHW convs are converted; anything else is copied as-is
    if (conv_opr.param().format !=
        megdnn::param::Convolution::Format::NCHW) {
        return serialization::copy_opr_shallow(*opr, new_inp,
                                               opr->config());
    }
    auto conv_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]);
    VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
    // src: NCHW --> NCHW4 (a 5-d src means a previous replacement
    // already produced an NCHW4 var)
    if (new_inp[0]->shape().ndim != 5) {
        mgb_assert(new_inp[0]->shape().ndim == 4);
        auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
        conv_src = new_src.node();
    }
    // weight: NCHW --> NCHW4
    auto new_filter =
            RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
    conv_filter = new_filter.node();
    // format: NCHW --> NCHW4
    auto new_param = conv_opr.param();
    new_param.format = conv_format;
    // dst
    auto new_conv_opr = opr::Convolution::make(
            conv_src, conv_filter, new_param, conv_opr.execution_policy(),
            conv_opr.config());
    OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
    mgb_assert(new_conv_opr.shape().ndim == 5,
               "The conv dst dim is not trans to nchw4");
    return new_opr;
};
// convert BatchConvBiasForward from NCHW to NCHW4; the input array is
// (src, filter[, bias[, z]]) depending on new_inp.size()
auto replace_batch_conv_bias_opr = [batch_conv_bias_format,
                                    src_to_nchw4_mode](
                                           OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    auto& batch_conv_bias_opr =
            opr->cast_final_safe<opr::BatchConvBiasForward>();
    // only NCHW oprs are converted; others are copied untouched
    if (batch_conv_bias_opr.param().format !=
        megdnn::param::BatchConvBias::Format::NCHW) {
        return serialization::copy_opr_shallow(*opr, new_inp,
                                               opr->config());
    }
    // NOTE(review): this assert is vacuous — the early return above
    // already guarantees the format is NCHW here
    mgb_assert(batch_conv_bias_opr.param().format ==
                       megdnn::param::BatchConvBias::Format::NCHW,
               "ConvertFormat Pass only support converting NCHW to NCHW4");
    // what should be converted: src, weight
    VarNode *src = new_inp[0], *filter = new_inp[1];
    // src: NCHW --> NCHW4 (already 5-d means a previous replacement
    // produced an NCHW4 var)
    if (new_inp[0]->shape().ndim != 5) {
        mgb_assert(new_inp[0]->shape().ndim == 4);
        auto new_src =
                RelayoutPlaceholder::make(new_inp[0], src_to_nchw4_mode);
        src = new_src.node();
    }
    // weight: BNCHW --> BNCHW4
    // only support dense mode, which is similar with conv->group.
    auto weight_mode =
            RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP;
    auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
    filter = new_filter.node();
    // format: NCHW --> NCHW4
    auto new_param = batch_conv_bias_opr.param();
    new_param.format = batch_conv_bias_format;
    // two inputs: no bias and no z term
    if (new_inp.size() == 2) {
        auto dst = opr::BatchConvBias::make(
                src, filter, new_param,
                batch_conv_bias_opr.execution_policy(),
                batch_conv_bias_opr.config());
        OperatorNodeBase* new_opr = dst.node()->owner_opr();
        mgb_assert(dst.shape().ndim == 5,
                   "The conv_bias dst dim is not trans to nchw4");
        return new_opr;
    }
    // bias: NCHW --> NCHW4 (non-4-d bias, e.g. scalar, is left as-is)
    VarNode* bias = new_inp[2];
    if (new_inp[2]->shape().ndim == 4) {
        auto new_bias =
                RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
        bias = new_bias.node();
    }
    // three inputs: bias present, no z term
    if (new_inp.size() == 3) {
        auto dst = opr::BatchConvBias::make(
                src, filter, bias, new_param,
                batch_conv_bias_opr.execution_policy(),
                batch_conv_bias_opr.config());
        OperatorNodeBase* new_opr = dst.node()->owner_opr();
        mgb_assert(dst.shape().ndim == 5,
                   "The conv_bias dst dim is not trans to nchw4");
        return new_opr;
    }
    // z_inp: NCHW --> NCHW4
    VarNode* z_inp = new_inp[3];
    if (new_inp[3]->shape().ndim == 4) {
        auto new_z =
                RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
        z_inp = new_z.node();
    }
    auto dst =
            opr::BatchConvBias::make(src, filter, bias, z_inp, new_param,
                                     batch_conv_bias_opr.execution_policy(),
                                     batch_conv_bias_opr.config());
    OperatorNodeBase* new_opr = dst.node()->owner_opr();
    mgb_assert(dst.shape().ndim == 5,
               "The conv_bias dst dim is not trans to nchw4");
    return new_opr;
};
// convert ConvBiasForward from NCHW to NCHW4; the input array is
// (src, filter[, bias[, z]]) depending on new_inp.size()
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
                              src_to_nchw4_mode](
                                     OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
    // only NCHW oprs are converted; others are copied untouched
    if (conv_bias_opr.param().format !=
        megdnn::param::Convolution::Format::NCHW) {
        return serialization::copy_opr_shallow(*opr, new_inp,
                                               opr->config());
    }
    // what should be converted: src, weight
    VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
    auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
    // src: NCHW --> NCHW4
    if (new_inp[0]->shape().ndim != 5) {
        mgb_assert(new_inp[0]->shape().ndim == 4);
        auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
        conv_bias_src = new_src.node();
    }
    // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
    auto new_filter =
            RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
    conv_bias_filter = new_filter.node();
    // format: NCHW --> NCHW4
    auto new_param = conv_bias_opr.param();
    new_param.format = conv_bias_format;
    // two inputs: no bias and no z term
    if (new_inp.size() == 2) {
        auto new_conv_bias_opr = opr::ConvBias::make(
                conv_bias_src, conv_bias_filter, new_param,
                conv_bias_opr.execution_policy(), conv_bias_opr.config());
        OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
        mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                   "The conv_bias dst dim is not trans to nchw4");
        return new_opr;
    }
    // bias: NCHW --> NCHW4 (non-4-d bias, e.g. scalar, is left as-is)
    VarNode* conv_bias_bias = new_inp[2];
    if (new_inp[2]->shape().ndim == 4) {
        auto new_bias =
                RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
        conv_bias_bias = new_bias.node();
    }
    // three inputs: bias present, no z term
    if (new_inp.size() == 3) {
        auto new_conv_bias_opr = opr::ConvBias::make(
                conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                conv_bias_opr.execution_policy(), conv_bias_opr.config());
        OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
        mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                   "The conv_bias dst dim is not trans to nchw4");
        return new_opr;
    }
    // z_inp: NCHW --> NCHW4
    VarNode* z_inp = new_inp[3];
    if (new_inp[3]->shape().ndim == 4) {
        auto new_z =
                RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
        z_inp = new_z.node();
    }
    auto new_conv_bias_opr = opr::ConvBias::make(
            conv_bias_src, conv_bias_filter, conv_bias_bias, z_inp,
            new_param, conv_bias_opr.execution_policy(),
            conv_bias_opr.config());
    OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
    mgb_assert(new_conv_bias_opr.shape().ndim == 5,
               "The conv_bias dst dim is not trans to nchw4");
    return new_opr;
};
  1600. auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
  1601. const VarNodeArray& new_inp) {
  1602. mgb_assert(opr->input().size() == new_inp.size());
  1603. bool has_inp_changed = false;
  1604. for (size_t i = 0; i < opr->input().size(); i++) {
  1605. if (new_inp[i]->shape().ndim == 5) {
  1606. has_inp_changed = true;
  1607. break;
  1608. }
  1609. }
  1610. if (has_inp_changed) {
  1611. auto temp_inp = new_inp;
  1612. for (size_t i = 0; i < opr->input().size(); i++) {
  1613. if (new_inp[i]->shape().ndim == 4) {
  1614. auto new_var = RelayoutPlaceholder::make(new_inp[i],
  1615. src_to_nchw4_mode);
  1616. temp_inp[i] = new_var.node();
  1617. } else {
  1618. mgb_assert((new_inp[i]->shape().ndim == 5) ||
  1619. new_inp[i]->shape().is_scalar());
  1620. }
  1621. }
  1622. return serialization::copy_opr_shallow(*opr, temp_inp,
  1623. opr->config());
  1624. } else {
  1625. return serialization::copy_opr_shallow(*opr, new_inp,
  1626. opr->config());
  1627. }
  1628. };
  1629. auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
  1630. const VarNodeArray& new_inp) {
  1631. mgb_assert(opr->input().size() == new_inp.size());
  1632. VarNodeArray temp_inp = new_inp;
  1633. for (size_t i = 0; i < opr->input().size(); i++) {
  1634. if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
  1635. mgb_assert(opr->input(i)->shape().ndim == 4);
  1636. mgb_assert(new_inp[i]->shape().ndim == 5);
  1637. auto new_var =
  1638. RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
  1639. temp_inp[i] = new_var.node();
  1640. }
  1641. }
  1642. return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
  1643. };
  1644. auto replace_pooling_opr = [](OperatorNodeBase* opr,
  1645. const VarNodeArray& new_inp) {
  1646. using Param = opr::PoolingForward::Param;
  1647. using Format = Param::Format;
  1648. mgb_assert(opr->input().size() == new_inp.size());
  1649. auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
  1650. if (pooling.param().format != Format::NCHW) {
  1651. return opr;
  1652. }
  1653. if (new_inp[0]->shape().ndim == 5) {
  1654. mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
  1655. auto new_param = pooling.param();
  1656. new_param.format = Format::NCHW4;
  1657. auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
  1658. opr->config());
  1659. mgb_assert(new_pooling.shape().ndim == 5,
  1660. "out var of Pooling opr after transform must be 5 (got: "
  1661. "%zu).",
  1662. new_pooling.shape().ndim);
  1663. return new_pooling.node()->owner_opr();
  1664. }
  1665. auto new_opr =
  1666. serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1667. return new_opr;
  1668. };
// convert ResizeForward to NCHW4 when its src input already became 5-d
auto replace_resize_opr = [](OperatorNodeBase* opr,
                             const VarNodeArray& new_inp) {
    using Param = opr::ResizeForward::Param;
    using Format = Param::Format;
    mgb_assert(opr->input().size() == new_inp.size());
    auto& resize = opr->cast_final_safe<opr::ResizeForward>();
    // NOTE(review): unlike the pooling handler, there is no check that
    // resize.param().format is NCHW before switching it to NCHW4 —
    // presumably only NCHW resizes reach this pass; confirm.
    if (new_inp[0]->shape().ndim == 5) {
        // only QuantizedS8 vars are expected in NCHW4 here
        mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
        auto new_param = resize.param();
        new_param.format = Format::NCHW4;
        auto new_resize = opr::ResizeForward::make(
                new_inp[0], new_inp[1], new_param, opr->config());
        mgb_assert(new_resize.shape().ndim == 5,
                   "out var of Resize opr after transform must be 5 (got: "
                   "%zu).",
                   new_resize.shape().ndim);
        return new_resize.node()->owner_opr();
    }
    auto new_opr =
            serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    return new_opr;
};
// convert WarpPerspectiveForward to NCHW4 when its src became 5-d;
// inputs are (src, mat[, mat_idx], out_shape)
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr,
                                       const VarNodeArray& new_inp) {
    using Param = opr::WarpPerspective::Param;
    using Format = Param::Format;
    mgb_assert(opr->input().size() == new_inp.size());
    auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>();
    if (new_inp[0]->shape().ndim == 5) {
        // only QuantizedS8 vars are expected in NCHW4 here
        mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
        auto new_param = warp.param();
        new_param.format = Format::NCHW4;
        SymbolVar new_warp;
        if (new_inp.size() == 3) {
            // no mat_idx input: pass nullptr in its place
            new_warp = opr::WarpPerspectiveForward::make(
                    new_inp[0], new_inp[1], nullptr, new_inp[2], new_param,
                    opr->config());
        } else {
            mgb_assert(new_inp.size() == 4);
            new_warp = opr::WarpPerspectiveForward::make(
                    new_inp[0], new_inp[1], new_inp[2], new_inp[3],
                    new_param, opr->config());
        }
        mgb_assert(new_warp.shape().ndim == 5,
                   "out var of WarpPerspective opr after transform must be "
                   "5 (got: "
                   "%zu).",
                   new_warp.shape().ndim);
        return new_warp.node()->owner_opr();
    }
    auto new_opr =
            serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    return new_opr;
};
// register the NCHW4 conversion handlers
auto&& replace_func = ret->m_opr_replace_func;
//! supported nchw4
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
        replace_warp_perspective_opr;
// layout-agnostic oprs: inputs unified to NCHW4 on demand
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
//! not supported nchw4: inputs relayouted back to NCHW
replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
        relayout_inp_to_nchw;
replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
return ret;
MIDOUT_E
}
  1750. /* ================ EnableNchwxxPass =============== */
  1751. VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
  1752. VarNode* orig_var) const {
  1753. if (!orig_var->shape().eq_shape(new_var->shape())) {
  1754. if (m_pack_c_size == 8) {
  1755. return RelayoutPlaceholder::make(
  1756. new_var,
  1757. RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
  1758. .node();
  1759. } else if (m_pack_c_size == 4) {
  1760. return RelayoutPlaceholder::make(
  1761. new_var,
  1762. RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
  1763. .node();
  1764. }
  1765. }
  1766. return new_var;
  1767. }
  1768. //! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv
  1769. static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic,
  1770. const size_t pack_c_size, const size_t fh,
  1771. const size_t fw, const size_t stride_h,
  1772. const size_t stride_w) {
  1773. return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw &&
  1774. stride_h == stride_w && (stride_h == 1 || stride_h == 2) &&
  1775. (fh == 2 || fh == 3 || fh == 5 || fh == 7);
  1776. }
//! populate the opr replace table; every mode/format below defaults to
//! NCHW88 and is overwritten for pack_c_size == 4 (NCHW44)
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
    using RelayoutMode = RelayoutPlaceholder::LayoutType;
    using TestFilterResult = std::pair<TransType, RelayoutMode>;
    RelayoutMode weight_to_nchwxx_mode_dense =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW88_DENSE;
    RelayoutMode weight_to_nchwxx_mode_group =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW88_GROUP;
    RelayoutMode weight_to_nchwxx_mode_chan =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW88_CHAN;
    // hybrid: src stays NCHW while the weight is packed for NCHWxx output
    RelayoutMode hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW88;
    RelayoutMode src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW88;
    RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW88_TO_NCHW;
    megdnn::param::ConvBias::Format conv_bias_format =
            megdnn::param::ConvBias::Format::NCHW88;
    megdnn::param::Convolution::Format conv_format =
            megdnn::param::ConvolutionV0::Format::NCHW88;
    megdnn::param::Pooling::Format pooling_format =
            megdnn::param::Pooling::Format::NCHW88;
    std::string convter_pass_name = "conv_format_nchw88";
    if (pack_c_size == 4) {
        weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
        weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
        weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
        hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4;
        src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
        conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
        conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
        pooling_format = megdnn::param::Pooling::Format::NCHW44;
        convter_pass_name = "conv_format_nchw44";
    }
    // classify how a conv filter can be converted: TRANS_PURE_NCHWXX
    // (channels divisible by the pack size, or channel-wise group conv),
    // TRANS_HYBIRD_NCHWXX (optimized nchw -> nchwxx conv for small IC),
    // or TRANS_NONE; also returns the matching weight relayout mode
    auto test_trans_nchwxx =
            [pack_c_size, weight_to_nchwxx_mode_dense,
             weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan,
             hybrid_nchw_nchwxx](
                    const megdnn::param::Convolution::Sparse conv_mode,
                    const VarNode* filter, const size_t stride_h,
                    const size_t stride_w) -> TestFilterResult {
        TestFilterResult ret{TransType::TRANS_NONE, {}};
        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
            // dense filter layout: {OC, IC, FH, FW}
            size_t OC = filter->shape()[0];
            size_t IC = filter->shape()[1];
            size_t FH = filter->shape()[2];
            size_t FW = filter->shape()[3];
            if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = weight_to_nchwxx_mode_dense;
            } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
                                         stride_w)) {
                ret.first = TransType::TRANS_HYBIRD_NCHWXX;
                ret.second = hybrid_nchw_nchwxx;
            }
        } else {
            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
            // group filter layout: {group, ocpg, icpg, FH, FW}
            size_t group = filter->shape()[0];
            size_t ocpg = filter->shape()[1];
            size_t icpg = filter->shape()[2];
            // channel-wise conv: one in/out channel per group
            if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = weight_to_nchwxx_mode_chan;
            } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = weight_to_nchwxx_mode_group;
            }
        }
        return ret;
    };
    // convert ConvolutionForward to NCHWxx; falls back to NCHW (with any
    // nchwxx inputs relayouted back) when the filter cannot be converted
    auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode,
                             src_to_nchw_mode](OperatorNodeBase* opr,
                                               const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        mgb_assert(conv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHWXX");
        auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1],
                                          conv_opr.param().stride_h,
                                          conv_opr.param().stride_w);
        //! can not trans to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, should RelayoutPlaceholder to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src =
                        RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode);
                temp_inp[0] = new_src.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            //! filter trans to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            //! src trans to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                         src_to_nchwxx_mode);
                conv_src = new_src.node();
            }
            auto new_param = conv_opr.param();
            new_param.format = conv_format;
            mgb_assert(conv_src->shape().ndim == 5 &&
                               conv_filter->shape().ndim >= 6,
                       "The conv src dim is not trans to nchwxx");
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        } else {
            //! hybrid mode: src stays 4-d NCHW, only the weight is packed
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            mgb_assert(conv_src->shape().ndim == 4 &&
                               conv_filter->shape().ndim == 5,
                       "The src and filter is OK");
            auto new_param = conv_opr.param();
            new_param.format = conv_format;
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        }
    };
    // convert ConvBiasForward to NCHWxx; mirrors replace_conv_opr but
    // additionally handles the bias input.
    // NOTE(review): every branch indexes new_inp[2] unconditionally, so
    // this handler assumes a ConvBias always has a bias input
    // (new_inp.size() >= 3) — a 2-input ConvBias would read out of
    // range; confirm the caller guarantees this.
    auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format,
                                  src_to_nchwxx_mode, src_to_nchw_mode](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
        mgb_assert(conv_bias_opr.param().format ==
                           megdnn::param::ConvBias::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHWXX");
        auto is_trans = test_trans_nchwxx(
                conv_bias_opr.param().sparse, new_inp[1],
                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
        //! can not trans to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, should RelayoutPlaceholder to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src =
                        RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode);
                temp_inp[0] = new_src.node();
            }
            //! the bias is nchwxx
            if (temp_inp[2]->shape().ndim == 5) {
                auto new_bias =
                        RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode);
                temp_inp[2] = new_bias.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            //! filter trans to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! src trans to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                         src_to_nchwxx_mode);
                conv_bias_src = new_src.node();
            }
            //! bias trans to nchwxx mode, bias may be scale
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                          src_to_nchwxx_mode);
                conv_bias_bias = new_bias.node();
            }
            auto new_param = conv_bias_opr.param();
            new_param.format = conv_bias_format;
            mgb_assert(conv_bias_src->shape().ndim == 5 &&
                               conv_bias_filter->shape().ndim >= 6,
                       "The conv_bias src dim is not trans to nchwxx");
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst dim is not trans to nchwxx");
            return new_opr;
        } else {
            //! hybrid mode: src stays 4-d NCHW, only the weight is packed
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! bias trans to nchwxx mode, bias may be scale
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                          src_to_nchwxx_mode);
                conv_bias_bias = new_bias.node();
            }
            mgb_assert(conv_bias_src->shape().ndim == 4 &&
                       conv_bias_filter->shape().ndim == 5);
            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                       conv_bias_bias->shape().is_scalar());
            auto new_param = conv_bias_opr.param();
            new_param.format = conv_bias_format;
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        }
    };
    //! Pooling replace rule: if the (already rewritten) input is in nchwxx
    //! layout (ndim == 5), re-create the pooling opr with the nchwxx format;
    //! otherwise keep the opr as-is via a shallow copy in nchw.
    auto replace_pooling_opr = [=](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
        //! the pass only rewrites oprs that start out in NCHW
        mgb_assert(pooling_opr.param().format ==
                           megdnn::param::Pooling::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHWxx");
        VarNode* inp = new_inp[0];
        //! if input is nchwxx
        if (inp->shape().ndim == 5) {
            auto new_param = pooling_opr.param();
            //! pooling_format is captured from the enclosing convert function
            new_param.format = pooling_format;
            auto new_pooling_opr =
                    opr::PoolingForward::make(inp, new_param, opr->config());
            mgb_assert(new_pooling_opr.shape().ndim == 5,
                       "The pooling dst dim is not trans to nchwxx");
            return new_pooling_opr.node()->owner_opr();
        } else {
            //! input stayed 4-d: nothing to convert, reuse the original opr
            auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
                                                           opr->config());
            return new_opr;
        }
    };
    //! When input change and all input can convert to nchwxx, this opr will run
    //! in nchwxx mode, else it will run in nchw mode, for example concat and
    //! elemwise opr
    auto replace_multi_inp_opr = [=](OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        //! first scan: did any input get converted to nchwxx, and can every
        //! remaining 4-d input be packed (channel divisible by pack_c_size)?
        bool has_inp_changed = false;
        bool can_exec_ncwxx = true;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (new_inp[i]->shape().ndim == 5) {
                has_inp_changed = true;
            } else if (new_inp[i]->shape().ndim == 4) {
                if (new_inp[i]->shape()[1] % pack_c_size != 0) {
                    can_exec_ncwxx = false;
                }
            }
        }
        if (has_inp_changed) {
            auto temp_inp = new_inp;
            if (can_exec_ncwxx) {
                //! unify all inputs to nchwxx: relayout the 4-d ones; anything
                //! else must already be 5-d (nchwxx) or a scalar (broadcast)
                for (size_t i = 0; i < opr->input().size(); i++) {
                    if (new_inp[i]->shape().ndim == 4) {
                        auto new_var = RelayoutPlaceholder::make(
                                new_inp[i], src_to_nchwxx_mode);
                        temp_inp[i] = new_var.node();
                    } else {
                        mgb_assert((new_inp[i]->shape().ndim == 5) ||
                                   new_inp[i]->shape().is_scalar());
                    }
                }
            } else {
                //! mixed layouts that cannot all be packed: fall back to nchw
                //! by relayouting the 5-d inputs back
                for (size_t i = 0; i < opr->input().size(); i++) {
                    if (new_inp[i]->shape().ndim == 5) {
                        auto new_var = RelayoutPlaceholder::make(
                                new_inp[i], src_to_nchw_mode);
                        temp_inp[i] = new_var.node();
                    }
                }
            }
            return serialization::copy_opr_shallow(*opr, temp_inp,
                                                   opr->config());
        } else {
            //! no input changed layout: plain shallow copy
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };
    //! Fallback rule for oprs that do not support nchwxx: any input whose
    //! shape no longer matches the original opr's input (i.e. it was converted
    //! to nchwxx) is relayouted back to nchw before a shallow copy.
    auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray temp_inp = new_inp;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                //! shape mismatch can only come from a 4-d -> 5-d conversion
                mgb_assert(opr->input(i)->shape().ndim == 4);
                mgb_assert(new_inp[i]->shape().ndim == 5);
                auto new_var =
                        RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
                temp_inp[i] = new_var.node();
            }
        }
        return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
    };
    //! Register the replace rules defined above into the pass's dispatch map.
    auto&& replace_func = m_opr_replace_func;
    //! oprs supported in nchwxx layout
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr;
    replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr;
    //! oprs not supported yet: force their inputs back to nchw
    replace_func[opr::ConvolutionBackwardData::typeinfo()] =
            relayout_inp_to_nchw;
    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            relayout_inp_to_nchw;
    replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
  2122. }
  2123. std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
  2124. size_t pack_c_size) {
  2125. MIDOUT_B("EnableNchwxxPass::make")
  2126. auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
  2127. ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
  2128. std::string convter_pass_name = "conv_format_nchw88";
  2129. if (pack_c_size == 4) {
  2130. convter_pass_name = "conv_format_nchw44";
  2131. }
  2132. ret->fill_opr_convert_fun(pack_c_size);
  2133. ret->set_name(convter_pass_name);
  2134. return ret;
  2135. MIDOUT_E
  2136. }
  2137. /* ================ EnableNchw44DotPass =============== */
  2138. VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
  2139. VarNode* orig_var) const {
  2140. if (!orig_var->shape().eq_shape(new_var->shape())) {
  2141. return RelayoutPlaceholder::make(
  2142. new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
  2143. .node();
  2144. }
  2145. return new_var;
  2146. }
/*!
 * \brief factory for the NCHW -> NCHW44_DOT conversion pass
 *
 * Installs replace rules for Convolution and ConvBias that rewrite them into
 * dot-product-friendly NCHW44_DOT layout when the filter shape allows it; all
 * other opr rules come from fill_opr_convert_fun(4).
 */
std::unique_ptr<EnableNchw44DotPass>
EnableNchw44DotPass::make_nchw44_dot_converter() {
    MIDOUT_B("EnableNchw44DotPass::make")
    auto ret = std::make_unique<EnableNchw44DotPass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    //! First is whether the conv can trans to nchwxx, second is the filter
    //! trans mode
    using RelayoutMode = RelayoutPlaceholder::LayoutType;
    //! decision record: how (if at all) a conv can be converted
    struct TestTransResult {
        TransType trans_type;
        RelayoutMode relayout_mod;   //!< weight relayout to apply
        megdnn::param::ConvolutionV0::Format conv_format;
    };
    constexpr size_t pack_c_size = 4_z;
    //! classify a conv by sparsity and filter shape:
    //! - DENSE with IC,OC divisible by 4 -> pure NCHW44_DOT
    //! - DENSE first-conv cases (small IC) -> hybrid (nchw src, packed weight)
    //! - GROUP chanwise (icpg==ocpg==1) -> NCHW44 (dot has no chanwise kernel)
    //! - GROUP with icpg,ocpg divisible by 4 -> NCHW44_DOT group
    auto test_trans_nchw44_dot =
            [](const megdnn::param::Convolution::Sparse conv_mode,
               const VarNode* filter, const size_t stride_h,
               const size_t stride_w) -> TestTransResult {
        TestTransResult ret{TransType::TRANS_NONE, {}, {}};
        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
            size_t OC = filter->shape()[0];
            size_t IC = filter->shape()[1];
            size_t FH = filter->shape()[2];
            size_t FW = filter->shape()[3];
            if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                ret.trans_type = TransType::TRANS_PURE_NCHWXX;
                ret.relayout_mod =
                        RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
            } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
                                         stride_w)) {
                ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
                ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
            }
        } else {
            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
            size_t group = filter->shape()[0];
            size_t ocpg = filter->shape()[1];
            size_t icpg = filter->shape()[2];
            if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
                //! chanwise conv runs in plain NCHW44, not the dot variant
                ret.trans_type = TransType::TRANS_PURE_NCHWXX;
                ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
            } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
                ret.trans_type = TransType::TRANS_PURE_NCHWXX;
                ret.relayout_mod =
                        RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
            }
        }
        return ret;
    };
    //! Convolution replace rule, driven by test_trans_nchw44_dot.
    auto replace_conv_opr = [test_trans_nchw44_dot](
                                    OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        mgb_assert(conv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to "
                   "NCHW44_DOT");
        auto is_trans = test_trans_nchw44_dot(
                conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
                conv_opr.param().stride_w);
        //! can not trans to nchwxx
        if (is_trans.trans_type == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, should RelayoutPlaceholder to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
                temp_inp[0] = new_src.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
            //! filter trans to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                        is_trans.relayout_mod);
            conv_filter = new_filter.node();
            //! src trans to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                conv_src = new_src.node();
            }
            auto new_param = conv_opr.param();
            new_param.format = is_trans.conv_format;
            //! packed weight is 6-d (dense) or 7-d (group)
            mgb_assert(conv_src->shape().ndim == 5 &&
                               conv_filter->shape().ndim >= 6,
                       "The conv src dim is not trans to nchwxx");
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        } else {
            //! hybrid: src stays nchw (4-d), only the weight is packed (5-d)
            mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                        is_trans.relayout_mod);
            conv_filter = new_filter.node();
            mgb_assert(conv_src->shape().ndim == 4 &&
                               conv_filter->shape().ndim == 5,
                       "The src and filter is OK");
            auto new_param = conv_opr.param();
            new_param.format = is_trans.conv_format;
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        }
    };
    //! ConvBias replace rule: same scheme as replace_conv_opr, plus bias
    //! handling (a 4-d bias is packed alongside the src; a scalar is kept).
    auto replace_conv_bias_opr = [test_trans_nchw44_dot](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
        mgb_assert(conv_bias_opr.param().format ==
                           megdnn::param::ConvBias::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHWXX");
        auto is_trans = test_trans_nchw44_dot(
                conv_bias_opr.param().sparse, new_inp[1],
                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
        //! can not trans to nchwxx
        if (is_trans.trans_type == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, should RelayoutPlaceholder to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
                temp_inp[0] = new_src.node();
            }
            //! the bias is nchwxx
            if (temp_inp[2]->shape().ndim == 5) {
                auto new_bias = RelayoutPlaceholder::make(
                        new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
                temp_inp[2] = new_bias.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            //! filter trans to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                        is_trans.relayout_mod);
            conv_bias_filter = new_filter.node();
            //! src trans to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                conv_bias_src = new_src.node();
            }
            //! bias trans to nchwxx mode, bias may be scale
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(
                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
                conv_bias_bias = new_bias.node();
            }
            auto new_param = conv_bias_opr.param();
            new_param.format = is_trans.conv_format;
            mgb_assert(conv_bias_src->shape().ndim == 5 &&
                               conv_bias_filter->shape().ndim >= 6,
                       "The conv_bias src dim is not trans to nchwxx");
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst dim is not trans to nchwxx");
            return new_opr;
        } else {
            //! hybrid: src stays nchw, weight packed; bias may be packed
            mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                        is_trans.relayout_mod);
            conv_bias_filter = new_filter.node();
            //! bias trans to nchwxx mode, bias may be scale
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(
                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
                conv_bias_bias = new_bias.node();
            }
            mgb_assert(conv_bias_src->shape().ndim == 4 &&
                       conv_bias_filter->shape().ndim == 5);
            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                       conv_bias_bias->shape().is_scalar());
            auto new_param = conv_bias_opr.param();
            new_param.format = is_trans.conv_format;
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        }
    };
    //! start from the generic nchw44 rules, then override conv/conv_bias with
    //! the dot-specific rules above
    ret->fill_opr_convert_fun(4);
    auto&& replace_func = ret->m_opr_replace_func;
    //! oprs supported in nchwxx layout
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    return ret;
    MIDOUT_E
}
  2378. /* ==================== ShuffleShuffleRemovePass ================= */
/*!
 * \brief implementation of ShuffleShuffleRemovePass
 *
 * Builds a table of layout-conversion subgraph builders (m_reformat), then
 * detects reshape+dimshuffle chains in the graph and replaces redundant pairs.
 * Each m_reformat entry maps a (src format, dst format) pair to a function
 * that emits the equivalent reshape/dimshuffle (or RelayoutFormat) subgraph.
 */
class ShuffleShuffleRemovePass::Impl {
    using TensorFormat = opr::ConvBias::Param::Format;

    OptState& m_opt_state;
    //! (src format, dst format) -> subgraph builder
    ThinHashMap<std::pair<TensorFormat, TensorFormat>,
                thin_function<VarNode*(VarNode*)>>
            m_reformat;

    class AbstractShuffleOpr;

    void detect_shuffle_operations();
    void do_replace();

public:
    Impl(OptState& opt_state) : m_opt_state{opt_state} {
        //! NCHW -> NCHW4: (N,C,H,W) -> (N,C/4,4,H,W) -> (N,C/4,H,W,4)
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            //! sub(i) = i-th element of the input's shape, as a symbolic var
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
            return y1.node();
        };
        //! NCHW -> NCHW32: same scheme with a packing width of 32
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW32)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
            return y1.node();
        };
        //! NCHW4 -> NCHW: (N,C/4,H,W,4) -> (N,C/4,4,H,W) -> (N,C,H,W)
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        //! NCHW32 -> NCHW: inverse of NCHW -> NCHW32
        m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        //! NCHW4 -> NCHW32: fold 8 channel groups of 4 into one group of 32
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW32)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp0 = opr::Concat::make(
                         {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)},
                         0),
                 tshp1 = opr::Concat::make(
                         {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
            auto y0 = opr::Reshape::make(x, tshp0);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
            auto y2 = opr::Reshape::make(y1, tshp1);
            return y2.node();
        };
        //! NCHW32 -> NCHW4: split one group of 32 into 8 groups of 4
        m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp0 = opr::Concat::make(
                         {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8},
                         0),
                 tshp1 = opr::Concat::make(
                         {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
            auto y0 = opr::Reshape::make(x, tshp0);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
            auto y2 = opr::Reshape::make(y1, tshp1);
            return y2.node();
        };
        //! NCHW4 <-> CHWN4 go through the dedicated RelayoutFormat opr
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::CHWN4)] =
                [](VarNode* inp) -> VarNode* {
            megdnn::param::RelayoutFormat param;
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
            auto reformat = opr::RelayoutFormat::make(inp, param);
            return reformat.node();
        };
        m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            megdnn::param::RelayoutFormat param;
            param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
            auto reformat = opr::RelayoutFormat::make(inp, param);
            return reformat.node();
        };
        //! NCHW -> CHWN4: (N,C,H,W) -> (N,C/4,4,H,W) -> (C/4,H,W,N,4)
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::CHWN4)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2});
            return y1.node();
        };
        //! CHWN4 -> NCHW: (C/4,H,W,N,4) -> (N,C/4,4,H,W) -> (N,C,H,W)
        m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(3), sub(0) * 4, sub(1), sub(2)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {3, 0, 4, 1, 2});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        //! run the two-phase rewrite: abstract the shuffles, then merge them
        detect_shuffle_operations();
        do_replace();
    }
};
  2533. /*!
  2534. * \brief abstract operator representation of shuffle operation
  2535. */
//! Placeholder opr standing for a layout-conversion (shuffle) subgraph; it
//! only records the source/destination formats and is never executed — it is
//! replaced by a real reformat subgraph (or eliminated) in do_replace().
MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
                     cg::SingleCNOperatorNodeBase) // {
public:
    AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
                       TensorFormat out_format);

    static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
                          TensorFormat out_format);

    //! layout of the input var
    TensorFormat inp_format() const { return m_inp_format; }

    //! layout of the output var
    TensorFormat out_format() const { return m_out_format; }

private:
    void init_output_static_infer_desc() override;
    void scn_do_execute() override;
    const TensorFormat m_inp_format;
    const TensorFormat m_out_format;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);

//! AbstractShuffleOpr is a rewrite-time placeholder only; reaching execution
//! means do_replace() failed to substitute it, which is an internal error.
void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::scn_do_execute() {
    mgb_throw(InternalError, "AbstractShuffleOpr cannot be executed");
}
//! Shape inference for the placeholder: derive the output shape from the
//! input shape according to the (inp_format, out_format) pair, so downstream
//! shape-dependent rewrites keep working before the placeholder is replaced.
void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::
        init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    DepVal deps;
    for (auto i : input())
        deps.push_back({i, DepType::SHAPE});
    auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
        TensorShape inp_shape = inp.val[0].shape();
        if (m_inp_format == TensorFormat::NCHW4 &&
            m_out_format == TensorFormat::NCHW32) {
            //! (N,C/4,H,W,4) -> (N,C/32,H,W,32)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst = inp_shape;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] * 8;
        } else if (m_inp_format == TensorFormat::NCHW32 &&
                   m_out_format == TensorFormat::NCHW4) {
            //! (N,C/32,H,W,32) -> (N,C/4,H,W,4)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
            dst = inp_shape;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] / 8;
        } else if (m_inp_format == TensorFormat::NCHW &&
                   m_out_format == TensorFormat::NCHW4) {
            //! (N,C,H,W) -> (N,C/4,H,W,4)
            mgb_assert(inp_shape.ndim == 4);
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 4;
        } else if (m_inp_format == TensorFormat::NCHW4 &&
                   m_out_format == TensorFormat::NCHW) {
            //! (N,C/4,H,W,4) -> (N,C,H,W)
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst.ndim = 4;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
        } else if (m_inp_format == TensorFormat::NCHW4 &&
                   m_out_format == TensorFormat::CHWN4) {
            //! (N,C/4,H,W,4) -> (C/4,H,W,N,4): pure permutation
            dst.ndim = 5;
            dst[0] = inp_shape[1];
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[0];
            dst[4] = inp_shape[4];
        } else if (m_inp_format == TensorFormat::CHWN4 &&
                   m_out_format == TensorFormat::NCHW4) {
            //! (C/4,H,W,N,4) -> (N,C/4,H,W,4): inverse permutation
            dst.ndim = 5;
            dst[0] = inp_shape[3];
            dst[1] = inp_shape[0];
            dst[2] = inp_shape[1];
            dst[3] = inp_shape[2];
            dst[4] = inp_shape[4];
        } else {
            mgb_throw(InternalError,
                      "Unsupported input format and output format.");
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
}
//! Constructor: records the format pair and registers both formats as
//! equivalence components so graph deduplication distinguishes placeholders
//! with different conversions on the same input.
ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::AbstractShuffleOpr(
        VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format)
        : Super(inpvar->owner_graph(), {}, "AbstractShuffleOpr", {inpvar}),
          m_inp_format{inp_format},
          m_out_format{out_format} {
    add_input({inpvar});
    add_equivalence_component<ScalarHash<TensorFormat>>(m_inp_format);
    add_equivalence_component<ScalarHash<TensorFormat>>(m_out_format);
    //! output dtype follows the input; shape comes from static infer
    add_output(None)->dtype(inpvar->dtype());
}
  2638. SymbolVar ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::make(
  2639. VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) {
  2640. return inpvar->owner_graph()
  2641. ->insert_opr(std::make_unique<AbstractShuffleOpr>(
  2642. inpvar, inp_format, out_format))
  2643. ->output(0);
  2644. }
  2645. void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
  2646. auto rewriter = m_opt_state.graph().make_rewriter();
  2647. auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()};
  2648. auto try_reshape_shuffle = [&rewriter,
  2649. &uniq_reader_check](OperatorNodeBase* opr) {
  2650. // check shuffle
  2651. auto shuffle = try_cast_as_op<opr::Dimshuffle>(opr);
  2652. if (shuffle == nullptr)
  2653. return false;
  2654. auto&& param = shuffle->param();
  2655. if (param.pattern_len != 5)
  2656. return false;
  2657. bool is_nchw2nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  2658. param.pattern[2] == 3 && param.pattern[3] == 4 &&
  2659. param.pattern[4] == 2 &&
  2660. opr->output(0)->shape()[4] == 4;
  2661. if (!is_nchw2nchw4)
  2662. return false;
  2663. if (!uniq_reader_check(shuffle->input(0)))
  2664. return false;
  2665. // check reshape
  2666. auto reshape = try_cast_as_op<opr::Reshape>(opr->input(0)->owner_opr());
  2667. if (reshape == nullptr)
  2668. return false;
  2669. auto inp_var = rewriter.get_var(reshape->input(0));
  2670. auto abstract_shuffle = AbstractShuffleOpr::make(
  2671. inp_var, TensorFormat::NCHW, TensorFormat::NCHW4);
  2672. rewriter.replace_var(
  2673. opr->output(0), abstract_shuffle.node(),
  2674. mgb_cstr_log("replace reformat(nchw -> nchw4) to "
  2675. "AbstractShuffleOpr(nchw -> nchw4)."));
  2676. return true;
  2677. };
  2678. auto try_reshape_shuffle_reshape = [&rewriter, &uniq_reader_check](
  2679. OperatorNodeBase* opr) {
  2680. // check reshape
  2681. auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
  2682. if (reshape1 == nullptr)
  2683. return false;
  2684. if (!uniq_reader_check(reshape1->input(0)))
  2685. return false;
  2686. // check shuffle
  2687. auto shuffle =
  2688. try_cast_as_op<opr::Dimshuffle>(opr->input(0)->owner_opr());
  2689. if (shuffle == nullptr)
  2690. return false;
  2691. auto&& param = shuffle->param();
  2692. if (param.pattern_len != 6)
  2693. return false;
  2694. bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  2695. param.pattern[2] == 3 && param.pattern[3] == 4 &&
  2696. param.pattern[4] == 2 && param.pattern[5] == 5 &&
  2697. shuffle->input(0)->shape()[5] == 4 &&
  2698. shuffle->input(0)->shape()[2] == 8;
  2699. bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  2700. param.pattern[2] == 4 && param.pattern[3] == 2 &&
  2701. param.pattern[4] == 3 && param.pattern[5] == 5 &&
  2702. shuffle->input(0)->shape()[4] == 8 &&
  2703. shuffle->input(0)->shape()[5] == 4;
  2704. if (!is_nchw42nchw32 && !is_nchw322nchw4)
  2705. return false;
  2706. if (!uniq_reader_check(shuffle->input(0)))
  2707. return false;
  2708. // check reshape
  2709. auto reshape2 =
  2710. try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
  2711. if (reshape2 == nullptr)
  2712. return false;
  2713. auto inp_var = rewriter.get_var(reshape2->input(0));
  2714. TensorFormat inp_format = is_nchw42nchw32 ? TensorFormat::NCHW4
  2715. : TensorFormat::NCHW32,
  2716. out_format = is_nchw42nchw32 ? TensorFormat::NCHW32
  2717. : TensorFormat::NCHW4;
  2718. auto abstract_shuffle =
  2719. AbstractShuffleOpr::make(inp_var, inp_format, out_format);
  2720. std::string reformat_type =
  2721. is_nchw42nchw32 ? "nchw4 -> nchw32" : "nchw32 -> nchw4";
  2722. rewriter.replace_var(opr->output(0), abstract_shuffle.node(),
  2723. mgb_cstr_log(ssprintf("replace reformat(%s) to "
  2724. "AbstractShuffleOpr(%s).",
  2725. reformat_type.c_str(),
  2726. reformat_type.c_str())
  2727. .c_str()));
  2728. return true;
  2729. };
  2730. auto try_shuffle_reshape = [&rewriter,
  2731. &uniq_reader_check](OperatorNodeBase* opr) {
  2732. // check reshape
  2733. auto reshape = try_cast_as_op<opr::Reshape>(opr);
  2734. if (reshape == nullptr)
  2735. return false;
  2736. if (!uniq_reader_check(reshape->input(0)))
  2737. return false;
  2738. // check shuffle
  2739. auto shuffle =
  2740. try_cast_as_op<opr::Dimshuffle>(opr->input(0)->owner_opr());
  2741. if (shuffle == nullptr)
  2742. return false;
  2743. auto&& param = shuffle->param();
  2744. if (param.pattern_len != 5)
  2745. return false;
  2746. bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  2747. param.pattern[2] == 4 && param.pattern[3] == 2 &&
  2748. param.pattern[4] == 3 &&
  2749. shuffle->input(0)->shape()[4] == 4;
  2750. if (!is_nchw42nchw)
  2751. return false;
  2752. auto inp_var = rewriter.get_var(shuffle->input(0));
  2753. auto abstract_shuffle = AbstractShuffleOpr::make(
  2754. inp_var, TensorFormat::NCHW4, TensorFormat::NCHW);
  2755. rewriter.replace_var(
  2756. opr->output(0), abstract_shuffle.node(),
  2757. mgb_cstr_log("replace reformat(nchw4 -> nchw) to "
  2758. "AbstractShuffleOpr(nchw4 -> nchw)."));
  2759. return true;
  2760. };
  2761. auto try_relayout_format = [&rewriter](OperatorNodeBase* opr) {
  2762. // check relayout format
  2763. auto reformat = try_cast_as_op<opr::RelayoutFormat>(opr);
  2764. if (reformat == nullptr)
  2765. return false;
  2766. auto&& param = reformat->param();
  2767. if (param.mode != opr::RelayoutFormat::Param::Mode::CHWN4_NCHW4 &&
  2768. param.mode != opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4)
  2769. return false;
  2770. auto inp_var = rewriter.get_var(reformat->input(0));
  2771. cg::SymbolVar abstract_shuffle;
  2772. if (param.mode == opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) {
  2773. abstract_shuffle = AbstractShuffleOpr::make(
  2774. inp_var, TensorFormat::NCHW4, TensorFormat::CHWN4);
  2775. } else {
  2776. abstract_shuffle = AbstractShuffleOpr::make(
  2777. inp_var, TensorFormat::CHWN4, TensorFormat::NCHW4);
  2778. }
  2779. rewriter.replace_var(
  2780. opr->output(0), abstract_shuffle.node(),
  2781. mgb_cstr_log("replace reformat(nchw4 -> nchw) to "
  2782. "AbstractShuffleOpr(nchw4 -> nchw)."));
  2783. return true;
  2784. };
  2785. auto on_opr = [&try_reshape_shuffle, &try_shuffle_reshape,
  2786. &try_reshape_shuffle_reshape, &try_relayout_format,
  2787. &rewriter, &uniq_reader_check](OperatorNodeBase* opr) {
  2788. if (!try_reshape_shuffle_reshape(opr) && !try_reshape_shuffle(opr) &&
  2789. !try_shuffle_reshape(opr) && !try_relayout_format(opr)) {
  2790. auto new_opr = rewriter.auto_replace_outputs(opr);
  2791. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  2792. }
  2793. };
  2794. m_opt_state.graph().iter(on_opr);
  2795. rewriter.apply_inplace();
  2796. }
  2797. void ShuffleShuffleRemovePass::Impl::do_replace() {
  2798. auto rewriter = m_opt_state.graph().make_rewriter();
  2799. auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()};
  2800. ThinHashMap<VarNode*, VarNode*> var2endpoint;
  2801. ThinHashSet<VarNode*> trt_opr_inps;
  2802. SmallVector<OperatorNodeBase*> topo_order;
  2803. auto cb = [&topo_order, &trt_opr_inps](OperatorNodeBase* opr) {
  2804. topo_order.push_back(opr);
  2805. MGB_MARK_USED_VAR(trt_opr_inps);
  2806. #if MGB_ENABLE_TENSOR_RT
  2807. if (opr->same_type<opr::TensorRTOpr>()) {
  2808. for (auto&& inp : opr->input())
  2809. trt_opr_inps.insert(inp);
  2810. }
  2811. #endif
  2812. };
  2813. m_opt_state.graph().iter(cb);
  2814. for (auto&& opr : reverse_adaptor(topo_order)) {
  2815. if (opr->same_type<opr::TypeCvt>() ||
  2816. opr->same_type<AbstractShuffleOpr>()) {
  2817. auto find = var2endpoint.find(opr->output(0));
  2818. if (find != var2endpoint.end()) {
  2819. if (uniq_reader_check(opr->output(0))) {
  2820. var2endpoint[opr->input(0)] = find->second;
  2821. } else {
  2822. var2endpoint[opr->input(0)] = opr->output(0);
  2823. }
  2824. } else {
  2825. var2endpoint[opr->input(0)] = opr->output(0);
  2826. }
  2827. }
  2828. }
  2829. auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps,
  2830. &var2endpoint](OperatorNodeBase* opr) {
  2831. MGB_MARK_USED_VAR(trt_opr_inps);
  2832. bool cond_opr = opr->same_type<opr::TypeCvt>() ||
  2833. opr->same_type<AbstractShuffleOpr>();
  2834. if (cond_opr) {
  2835. bool cond_endpoint = var2endpoint[opr->input(0)] == opr->output(0);
  2836. if (!cond_endpoint)
  2837. return;
  2838. auto cur = opr;
  2839. auto var = opr->output(0), inp_var = opr->input(0);
  2840. bool force_folding_typecvt = false;
  2841. bool first_shuffle = false;
  2842. // initialize inp_format and out_format
  2843. TensorFormat out_format = TensorFormat::NCHW,
  2844. inp_format = out_format;
  2845. megdnn::DType inp_dtype = cur->input(0)->dtype(),
  2846. out_dtype = cur->output(0)->dtype();
  2847. SmallVector<megdnn::DType> out_dtype_vec;
  2848. while (cond_opr) {
  2849. if (cur->same_type<AbstractShuffleOpr>()) {
  2850. auto shuffle = try_cast_as_op<AbstractShuffleOpr>(cur);
  2851. inp_format = shuffle->inp_format();
  2852. if (!first_shuffle) {
  2853. out_format = shuffle->out_format();
  2854. first_shuffle = true;
  2855. }
  2856. } else {
  2857. mgb_assert(cur->same_type<opr::TypeCvt>());
  2858. out_dtype_vec.push_back(cur->output(0)->dtype());
  2859. }
  2860. inp_var = cur->input(0);
  2861. bool cond_reader = uniq_reader_check(inp_var);
  2862. if (!cond_reader)
  2863. break;
  2864. cur = cur->input(0)->owner_opr();
  2865. cond_opr = cur->same_type<opr::TypeCvt>() ||
  2866. cur->same_type<AbstractShuffleOpr>();
  2867. }
  2868. std::reverse(out_dtype_vec.begin(), out_dtype_vec.end());
  2869. #if MGB_ENABLE_TENSOR_RT
  2870. force_folding_typecvt =
  2871. inp_var->owner_opr()->same_type<opr::TensorRTOpr>() ||
  2872. trt_opr_inps.count(var);
  2873. #endif
  2874. auto new_var = rewriter.get_var(inp_var);
  2875. if (inp_format != out_format) {
  2876. mgb_assert(m_reformat.find(std::make_pair(
  2877. inp_format, out_format)) != m_reformat.end(),
  2878. "Unsupported shuffle shuffle remove pass");
  2879. new_var = m_reformat[std::make_pair(inp_format, out_format)](
  2880. new_var);
  2881. }
  2882. if (force_folding_typecvt) {
  2883. inp_dtype = inp_var->dtype();
  2884. if (inp_dtype != out_dtype) {
  2885. auto type_cvt = opr::TypeCvt::make(new_var, out_dtype);
  2886. new_var = type_cvt.node();
  2887. }
  2888. } else {
  2889. if (out_dtype_vec.back() != var->dtype())
  2890. out_dtype_vec.push_back(var->dtype());
  2891. for (auto&& dtype : out_dtype_vec) {
  2892. auto type_cvt = opr::TypeCvt::make(new_var, dtype);
  2893. new_var = type_cvt.node();
  2894. }
  2895. }
  2896. rewriter.replace_var(
  2897. var, new_var,
  2898. mgb_cstr_log("replace Dimshuffle and TypeCvt chain"));
  2899. } else {
  2900. auto new_opr = rewriter.auto_replace_outputs(opr);
  2901. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  2902. }
  2903. };
  2904. m_opt_state.graph().iter(on_opr);
  2905. rewriter.apply_inplace();
  2906. }
  2907. const char* ShuffleShuffleRemovePass::name() const {
  2908. return mgb_cstr_log("shuffle shuffle remove pass");
  2909. }
  2910. void ShuffleShuffleRemovePass::apply(OptState& opt) const {
  2911. MIDOUT_B("ShuffleShuffleRemovePass::apply")
  2912. opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE |
  2913. VarReplaceCheckFlag::CHECK_DTYPE);
  2914. Impl{opt};
  2915. MIDOUT_E
  2916. }
  2917. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台