
inference.cpp 171 kB

/**
 * \file src/gopt/impl/inference.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/utils/shared_set.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megdnn/tensor_format.h"

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/tensorrt_opr.h"
#endif

#include "megbrain/gopt/misc.h"

using namespace mgb;
using namespace gopt;

namespace {
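// Editorial note, inferred from the rewrite logic below: param_merge() folds
// every SharedDeviceTensor(Format) opr in the graph into a single
// MultipleDeviceTensorHolder(Format), so all parameters are held by one opr
// instead of many individual ones.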
template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
void param_merge(OptState& opt_state) {
    auto rewriter = opt_state.graph().make_rewriter();
    ThinHashMap<OperatorNodeBase*, size_t> opr2idx;
    std::vector<OperatorNodeBase*> all_oprs;
    typename MultipleDeviceTensorHolder::ValueArray all_values;
    auto cb_find_opr = [&](cg::OperatorNodeBase* opr) {
        if (opr->same_type<SharedDeviceTensor>()) {
            auto p = &opr->cast_final<SharedDeviceTensor>();
            // record the index of this opr's value in all_values
            opr2idx[p] = all_values.size();
            all_values.push_back(p->dev_data());
            all_oprs.push_back(p);
        }
    };
    opt_state.graph().iter(cb_find_opr);
    SymbolVarArray new_vars;
    auto cb_replace = [&](cg::OperatorNodeBase* opr) {
        auto iter = opr2idx.find(opr);
        if (iter == opr2idx.end()) {
            rewriter.auto_replace_outputs(opr);
        } else {
            if (new_vars.empty()) {
                // new oprs must be created in iter callback; so we populate
                // new_vars lazily
                new_vars = MultipleDeviceTensorHolder::make(
                        *opt_state.graph().comp_graph(), std::move(all_values),
                        {ssprintf("merged%zu", all_values.size())});
                for (size_t i = 0; i < new_vars.size(); ++i) {
                    auto src = all_oprs[i]->output(0);
                    if (src->has_name_set()) {
                        new_vars[i].rename(src->name());
                    }
                }
            }
            rewriter.replace_var(
                    opr->output(0), new_vars.at(iter->second).node(),
                    mgb_cstr_log("replace multi SharedDeviceTensor(Format) to "
                                 "MultipleDeviceTensorHolder(Format)"));
        }
    };
    opt_state.graph().iter(cb_replace);
    rewriter.apply_inplace();
}
} // anonymous namespace

/* ================ global functions ================ */
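// optimize_for_inference() is a thin convenience wrapper: it builds a
// GraphOptimizer with the preset inference passes and returns the rewritten
// endpoint vars. A minimal usage sketch (hedged: `y` is a hypothetical
// endpoint SymbolVar, and the available option fields depend on the header):
//
//     gopt::OptimizeForInferenceOptions opt;
//     SymbolVarArray optimized = gopt::optimize_for_inference({y}, opt);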
SymbolVarArray gopt::optimize_for_inference(
        const SymbolVarArray& dest_vars,
        const OptimizeForInferenceOptions& opt) {
    return gopt::GraphOptimizer()
            .add_preset_passes(false, &opt,
                               &dest_vars[0].node()->owner_graph()->options())
            .apply({dest_vars})
            .endpoint_vars();
}

namespace {
void modify_conv_policy(opr::mixin::Convolution& conv,
                        megdnn::param::ExecutionPolicy::Strategy strategy) {
    auto policy = conv.execution_policy_transient();
    policy.strategy = strategy;
    conv.set_execution_policy(policy);
}
template <typename Opr>
void inplace_conv_opr_profile_modifier(OperatorNodeBase& opr) {
    modify_conv_policy(
            opr.cast_final_safe<Opr>(),
            opr::mixin::Convolution::ExecutionPolicy::Strategy::PROFILE);
}
template <typename Opr>
void inplace_conv_opr_profile_cache_modifier(OperatorNodeBase& opr) {
    modify_conv_policy(opr.cast_final_safe<Opr>(),
                       opr::mixin::Convolution::ExecutionPolicy::Strategy::
                               PROFILE_HEURISTIC);
}
void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv,
                                        size_t workspace_limit) {
    auto policy = conv.execution_policy_transient();
    policy.workspace_limit = workspace_limit;
    conv.set_execution_policy(policy);
}
template <typename Opr>
void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
                                               size_t workspace_limit) {
    modify_conv_policy_workspace_limit(opr.cast_final_safe<Opr>(),
                                       workspace_limit);
}
} // anonymous namespace
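// X-macro listing every operator type that participates in FastRun algorithm
// selection; each use site expands cb(OprType) once per opr to build a
// typeinfo -> modifier dispatch table. (Editorial note; the macro itself is
// unchanged below.)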
#define MGB_FOREACH_FASTRUN_OPR(cb)                                           \
    cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData), \
    cb(ConvolutionBackwardFilter), cb(Convolution3DForward),                  \
    cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter),           \
    cb(LocalShareForward), cb(LocalShareBackwardData),                        \
    cb(LocalShareBackwardFilter), cb(DeformableConvForward),                  \
    cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData),         \
    cb(BatchConvBiasForward),

void gopt::enable_opr_algo_profiling_inplace(
        const VarNodeArrayView& dest_vars) {
#if MGB_ENABLE_FASTRUN
    static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers =
            {
#define CONV(t) {opr::t::typeinfo(), &inplace_conv_opr_profile_modifier<opr::t>}
                    MGB_FOREACH_FASTRUN_OPR(CONV)
#undef CONV
            };
    auto on_opr = [&](OperatorNodeBase* opr) {
        auto iter = modifiers.find(opr->dyn_typeinfo());
        if (iter != modifiers.end()) {
            iter->second(*opr);
        }
    };
    cg::DepOprIter dep_iter{on_opr};
    for (auto i : dest_vars) {
        dep_iter.add(i);
    }
#else
    mgb_throw(MegBrainError, "fastrun is disabled at compile time");
#endif
}
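// The cache variant below uses the same dispatch pattern but binds
// Strategy::PROFILE_HEURISTIC instead of PROFILE; going by the enum name,
// this prefers cached profiling results and falls back to the heuristic
// (editorial reading of the strategy name).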
void gopt::enable_opr_use_profiling_cache_inplace(
        const VarNodeArrayView& dest_vars) {
    static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers =
            {
#define CONV(t) \
    {opr::t::typeinfo(), &inplace_conv_opr_profile_cache_modifier<opr::t>}
                    MGB_FOREACH_FASTRUN_OPR(CONV)
#undef CONV
            };
    auto on_opr = [&](OperatorNodeBase* opr) {
        auto iter = modifiers.find(opr->dyn_typeinfo());
        if (iter != modifiers.end()) {
            iter->second(*opr);
        }
    };
    cg::DepOprIter dep_iter{on_opr};
    for (auto i : dest_vars) {
        dep_iter.add(i);
    }
}

void gopt::set_opr_algo_workspace_limit_inplace(
        const VarNodeArrayView& dest_vars, size_t workspace_limit) {
    static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
            modifiers = {
#define CONV(t) \
    {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>}
                    MGB_FOREACH_FASTRUN_OPR(CONV)
#undef CONV
            };
    auto on_opr = [&](OperatorNodeBase* opr) {
        auto iter = modifiers.find(opr->dyn_typeinfo());
        if (iter != modifiers.end()) {
            iter->second(*opr, workspace_limit);
        }
    };
    cg::DepOprIter dep_iter{on_opr};
    for (auto i : dest_vars) {
        dep_iter.add(i);
    }
}
#undef MGB_FOREACH_FASTRUN_OPR

/* ================ ParamRedistributePass ================ */
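// Editorial overview: this pass rewrites expressions such as f(g(a, b), c)
// where b and c are const params, reassociating or distributing the binary
// ops so that const subexpressions end up adjacent and can later be folded
// by ParamFusePass. The implementation below is the authoritative source.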
const char* ParamRedistributePass::name() const {
    return mgb_cstr_log("param_redistribute");
}

class ParamRedistributePass::Impl final : public RecursiveSubGraphRewriteHelper {
    ConstVarPropogate m_cvprop;
    UniqReaderCheck m_uniq_reader_check;
    //! oprs already processed in try_distribute_then_reassociate() should be
    //! skipped in on_new_opr_check_should_process()
    ThinHashSet<OperatorNodeBase*> m_opr_blacklist;
    std::string m_distribute_reasso_log_msg;

    //! try applying BinaryTrans20::associtive
    GTransResult try_reassociate(OperatorNodeBase* opr);

    //! try applying BinaryTrans20::distributive_add
    GTransResult try_distribute_add(OperatorNodeBase* opr);

    //! try distributing MUL/DIV over ADD/SUB and then apply try_reassociate()
    GTransResult try_distribute_then_reassociate(OperatorNodeBase* opr);

    GTransResult process_opr(VarNode* out_var) override;

    bool on_new_opr_check_should_process(
            OperatorNodeBase* opr, OperatorNodeBase* repl_opr) override {
        m_uniq_reader_check.update_on_opr_auto_replace(opr, repl_opr);
        auto ins = m_cvprop.add_opr(opr);
        return ins.has_const_inp && !ins.all_const_inp &&
               !m_opr_blacklist.count(opr);
    };

    void after_replace_var(VarNode* orig_var, VarNode* new_var) override {
        m_uniq_reader_check.update_on_opr_auto_replace(orig_var->owner_opr(),
                                                       new_var->owner_opr());
    }

    /*!
     * \brief try to reorder opr inputs to a const one and a non-const one
     *
     * return true if it can be reformulated as f(nci, ci), where nci is
     * non-const and ci is const.
     */
    bool reorder_for_normconst(OperatorNodeBase* opr,
                               bool& swap_inp, VarNode*& nci, VarNode*& ci);

public:
    Impl(OptState& state);
};
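// process_opr() below fixes the order in which the three rewrites are
// attempted: plain reassociation first, then distribute-add, then the more
// aggressive distribute-then-reassociate. (Editorial note; the order is
// exactly what the code encodes.)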
GTransResult ParamRedistributePass::Impl::process_opr(VarNode* out_var) {
    auto opr = out_var->owner_opr();
    auto trans = try_reassociate(opr);
    if (!trans.valid()) {
        trans = try_distribute_add(opr);
        if (!trans.valid())
            trans = try_distribute_then_reassociate(opr);
    }
    return trans;
}

GTransResult ParamRedistributePass::Impl::try_reassociate(
        OperatorNodeBase* opr) {
    // apply BinaryTrans20::associtive() if opr is of the form f(g(a, b), c)
    // and b and c are const
    bool swap_fop_inp = false, swap_gop_inp = false;
    VarNode *a, *b, *c, *ab;
    if (!reorder_for_normconst(opr, swap_fop_inp, ab, c))
        return None;
    if (!m_uniq_reader_check(ab))
        return None;
    if (!reorder_for_normconst(ab->owner_opr(), swap_gop_inp, a, b))
        return None;
    return BinaryTrans20::associtive().apply(opr, swap_fop_inp, swap_gop_inp);
}

GTransResult ParamRedistributePass::Impl::try_distribute_add(
        OperatorNodeBase* opr) {
    if (opr->same_type<opr::Elemwise>() || opr->input().size() != 2)
        return None;
    if (!m_cvprop.is_const(opr->input(1)))
        return None;
    auto ab = as_elem_opr(opr->input(0)->owner_opr(), opr::Elemwise::Mode::ADD);
    if (ab) {
        bool swap;
        VarNode *a, *b;
        if (reorder_for_normconst(ab, swap, a, b)) {
            return BinaryTrans20::distributive_add().apply(
                    opr, false, swap);
        }
    }
    return None;
}

GTransResult ParamRedistributePass::Impl::try_distribute_then_reassociate(
        OperatorNodeBase* opr) {
    if (!opr->same_type<opr::Elemwise>())
        return None;
    using Mode = opr::Elemwise::Mode;
    auto mode = opr->cast_final<opr::Elemwise>().param().mode;
    if (!(mode == Mode::MUL || mode == Mode::TRUE_DIV))
        return None;
    VarNode *a, *b;
    bool swap;
    if (!reorder_for_normconst(opr, swap, a, b))
        return None;
    auto chain_pred = [this](OperatorNodeBase* opr) {
        if (as_elem_opr(opr, Mode::ADD)) {
            auto var = opr->output(0);
            return m_uniq_reader_check(var) || m_cvprop.is_const(var);
        }
        return false;
    };
    auto chain = extract_opr_leaves(a, chain_pred);
    if (chain.size() <= 1)
        return None;
    std::unordered_map<VarNode*, VarNode*> repl_map;
    m_distribute_reasso_log_msg.clear();
    int nr_fail = 0, nr_succ = 0;
    for (auto&& var : chain) {
        {
            auto iter = repl_map.find(var);
            if (iter != repl_map.end()) {
                var = iter->second;
                continue;
            }
        }
        auto vnew = (SymbolVar{var} * b).node();
        m_opr_blacklist.insert(vnew->owner_opr());
        if (!m_cvprop.is_const(var)) {
            auto trans = try_reassociate(vnew->owner_opr());
            if (!trans.valid()) {
                // allow at most one failed redistribution
                if (nr_fail)
                    return None;
                ++nr_fail;
            } else {
                ++nr_succ;
                vnew = trans->result;
                if (!m_distribute_reasso_log_msg.empty()) {
                    m_distribute_reasso_log_msg.append(mgb_cstr_log(";"));
                }
                m_distribute_reasso_log_msg.append(trans->msg);
            }
        }
        repl_map[var] = vnew;
        var = vnew;
    }
    if (nr_succ) {
        m_distribute_reasso_log_msg.insert(0,
                mgb_cstr_log("distribute_mul("));
        m_distribute_reasso_log_msg.append(mgb_cstr_log(")"));
        return GTransResultItem{
                elemwise_reduce_var_list(chain, Mode::ADD),
                m_distribute_reasso_log_msg.c_str(),
                {}};
    }
    return None;
}

bool ParamRedistributePass::Impl::reorder_for_normconst(
        OperatorNodeBase* opr, bool& swap_inp, VarNode*& nci, VarNode*& ci) {
    if (opr->input().size() != 2)
        return false;
    nci = opr->input(0);
    ci = opr->input(1);
    if (!m_cvprop.is_const(ci)) {
        if (!is_commutable_binary(opr) || !m_cvprop.is_const(nci))
            return false;
        swap_inp = true;
        std::swap(nci, ci);
    } else {
        if (m_cvprop.is_const(nci))
            return false;
        swap_inp = false;
    }
    return true;
}

ParamRedistributePass::Impl::Impl(OptState& state)
        : RecursiveSubGraphRewriteHelper{state},
          m_cvprop{ConstVarType::IMMUTABLE_AND_PARAM},
          m_uniq_reader_check{state.graph()} {
    auto cg = state.graph().comp_graph();
    auto on_new_opr = [this](const cg::event::OprInserted& ev) {
        if (!ev.is_dedup && !ev.exc) {
            // call add_opr eagerly to avoid deep recursion
            m_cvprop.add_opr(ev.opr);
        }
    };
    auto hdl = cg->event().register_receiver<cg::event::OprInserted>(
            on_new_opr);
    apply();
}

void ParamRedistributePass::apply(OptState& state) const {
    Impl{state};
}

/* ================ ParamFusePass ================ */
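// Editorial overview: ParamFusePass walks the graph with a const-var
// propagator, evaluates any subgraph whose inputs are all const params, and
// replaces it with the precomputed tensor, unless doing so would grow the
// param size past m_param_grow_limit (checked in on_midconst_opr below).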
class ParamFusePass::ConstVarPropogateWithSizeCheck final
        : public ConstVarPropogateBase {
public:
    //! rewrite a var; reader == nullptr means needed by endpoint
    using VarRewriter = std::function<
            void(VarNode* var, OperatorNodeBase* reader)>;

    ConstVarPropogateWithSizeCheck(
            const ParamFusePass& pf, OptState& opt_state,
            const VarRewriter& rewriter)
            : ConstVarPropogateBase{ConstVarType::IMMUTABLE_AND_PARAM},
              m_owner{pf}, m_opt_state{opt_state}, m_rewriter{rewriter} {
    }

private:
    const ParamFusePass& m_owner;
    OptState& m_opt_state;
    VarRewriter m_rewriter;

    void on_midconst_opr(
            OperatorNodeBase* opr, size_t max_src_size) override {
        for (auto var : opr->output()) {
            if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
                continue;
            auto osize = var_mem_size(var);
            if (osize >= max_src_size &&
                osize - max_src_size > m_owner.m_param_grow_limit) {
                return;
            }
            // const oprs should be evaluated when output is used by another
            // non-const opr or output is needed by the user
            if (m_opt_state.graph().endpoint_contain(var)) {
                m_rewriter(var, nullptr);
            }
        }
    }
};

/*!
 * \brief get name for new param
 */
class ParamFusePass::VarNamer {
#if MGB_BUILD_SLIM_SERVING
public:
    const std::string& name(VarNode*) {
        static std::string ret("fuse");
        return ret;
    }
#else
    using SrcSet = SharedSet<OperatorNodeBase*>;
    //! map from var to source SharedDeviceTensor/MultiSharedDeviceHolder oprs
    //! that it depends on
    ThinHashMap<OperatorNodeBase*, SrcSet> m_opr2srcs;
    std::string m_name_cache;
    std::vector<const char*> m_cur_name;

    SrcSet& get_src_set(OperatorNodeBase* opr) {
        auto opr_typeinfo = opr->dyn_typeinfo();
        auto iter = m_opr2srcs.find(opr);
        if (iter != m_opr2srcs.end()) {
            return iter->second;
        }
        auto&& ret = m_opr2srcs[opr];
        if (opr->input().empty()) {
            if (opr_typeinfo == opr::SharedDeviceTensor::typeinfo() ||
                opr_typeinfo == opr::MultipleDeviceTensorHolder::typeinfo()) {
                ret.insert(opr);
            } else {
                mgb_assert(opr_typeinfo == opr::ImmutableTensor::typeinfo());
            }
            return ret;
        }
        for (auto i : opr->input()) {
            ret.merge_from(get_src_set(i->owner_opr()));
        }
        return ret;
    }

public:
    const std::string& name(VarNode* var) {
        m_cur_name.clear();
        for (auto i : get_src_set(var->owner_opr())) {
            m_cur_name.push_back(i->cname());
        }
        auto cmp = [](const char* x, const char* y) {
            return strcmp(x, y) < 0;
        };
        std::sort(m_cur_name.begin(), m_cur_name.end(), cmp);
        m_name_cache.clear();
        m_name_cache.append(mgb_cstr_log("fuse("));
        bool first = true;
        for (auto i : m_cur_name) {
            if (first) {
                first = false;
            } else {
                m_name_cache.push_back(',');
            }
            m_name_cache.append(i);
        }
        m_name_cache.append(mgb_cstr_log(
                ssprintf("):%s@%zu", var->cname(), var->id())));
        return m_name_cache;
    }
#endif
};

const char* ParamFusePass::name() const {
    return mgb_cstr_log("param_fuse");
}
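// Editorial note on the flow of apply(): replace_single_var() compiles a
// tiny graph that computes the const var, copies the result out, then
// re-materializes it as an ImmutableTensor when the value is statically
// inferable in the default format, or as a SharedDeviceTensor(WithFormat)
// otherwise.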
void ParamFusePass::apply(OptState& state) const {
    auto rewriter = state.graph().make_rewriter();
    auto cg = state.graph().comp_graph();
    ThinHashSet<VarNode*> processed_var;
    VarNamer var_namer;
    // reader: null if used as endvar
    auto replace_single_var = [&](VarNode* var, OperatorNodeBase* reader) {
        if (!processed_var.insert(var).second)
            return;
        auto inferred_val = std::make_shared<DeviceTensorND>(
                var->comp_node(), var->dtype());
        auto cb = [&](DeviceTensorND& val) {
            // retain format of val
            mgb_assert(val.format() == var->format());
            inferred_val->format(val.format())
                    .resize(val.shape())
                    .copy_from_fixlayout(val);
        };
        {
            auto orig_level = cg->options().log_level;
            cg->options().log_level = 0;
            MGB_TRY {
                cg->compile({{var, cb}})->execute();
            } MGB_FINALLY(cg->options().log_level = orig_level);
        }
        SymbolVar new_var;
        bool is_default_format = var->layout().format.is_default();
        if (cg::is_static_var_value(var) && is_default_format) {
            // use ImmutableTensor for inferable vars
            HostTensorND hv;
            hv.copy_from(*inferred_val).sync();
            new_var = opr::ImmutableTensor::make(
                    *var->owner_graph(), hv, var_namer.name(var));
        } else {
            if (is_default_format) {
                new_var = opr::SharedDeviceTensor::make(
                        *var->owner_graph(), inferred_val, var_namer.name(var));
            } else {
                new_var = opr::SharedDeviceTensorWithFormat::make(
                        *var->owner_graph(), inferred_val, var_namer.name(var));
            }
        }
        std::string log;
        if (reader) {
            log = mgb_ssprintf_log(
                    "due to read by %s{%s}",
                    reader->cname(), reader->dyn_typeinfo()->name);
        } else {
            log = mgb_cstr_log("as endpoint");
        }
        rewriter.replace_var(var, new_var.node(), log.c_str());
    };
    ConstVarPropogateWithSizeCheck cvprop{*this, state, replace_single_var};
    auto on_opr = [&](OperatorNodeBase* opr) {
        auto add_ret = cvprop.add_opr(opr);
        if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
            for (auto i : opr->input()) {
                if (cvprop.is_midconst(i)) {
                    state.call_with_opr(i->owner_opr(),
                                        [&] { replace_single_var(i, opr); });
                }
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ ConvertF32ToF16Pass ================ */
const char* ConvertF32ToF16Pass::name() const {
    return mgb_cstr_log("convert_f32_to_f16");
}
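// Editorial overview: apply() rewrites each opr through m_opr_replace_func;
// for graph endpoints whose dtype changed (f32 -> f16), it inserts a TypeCvt
// back to the original dtype so the graph's external interface is preserved.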
void ConvertF32ToF16Pass::apply(OptState& state) const {
    state.set_var_replace_check_flag(m_var_replace_check_flag);
    auto rewriter = state.graph().make_rewriter();
    VarNodeArray new_inp_cache;
    auto on_opr = [this, &rewriter, &new_inp_cache,
                   &state](OperatorNodeBase* opr) {
        auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
        if (it != m_opr_replace_func.end()) {
            auto&& new_inp = new_inp_cache;
            new_inp.clear();
            new_inp.reserve(opr->input().size());
            for (auto i : opr->input()) {
                new_inp.push_back(rewriter.get_var(i));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
            mgb_assert(origin_out.size() == cur_out.size(),
                       "bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(),
                       opr->dyn_typeinfo()->name, new_opr->cname(),
                       new_opr->dyn_typeinfo()->name);
            //! change the output type if it's the endpoint
            for (size_t i = 0; i < origin_out.size(); i++) {
                if (state.graph().endpoint_contain(origin_out[i]) &&
                    origin_out[i]->dtype().enumv() !=
                            cur_out[i]->dtype().enumv()) {
                    rewriter.replace_var(
                            origin_out[i],
                            opr::TypeCvt::make(cur_out[i],
                                               origin_out[i]->dtype())
                                    .node(),
                            nullptr);
                } else {
                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
                }
            }
        } else {
            auto new_opr = rewriter.auto_replace_outputs(opr);
            auto&& out = opr->output();
            auto&& new_out = new_opr->output();
            for (size_t i = 0; i < out.size(); i++) {
                if (state.graph().endpoint_contain(out[i]) &&
                    new_out[i]->dtype().enumv() != out[i]->dtype().enumv()) {
                    rewriter.replace_var(
                            new_out[i],
                            opr::TypeCvt::make(new_out[i],
                                               out[i]->dtype())
                                    .node(),
                            nullptr);
                }
            }
        }
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
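// make() below installs one replace function per source opr type. The data
// sources (Host2DeviceCopy, SharedDeviceTensor, ImmutableTensor) get a
// TypeCvt to f16 appended to their f32 outputs; compute oprs (Convolution,
// MatrixMul, Reduce) can optionally keep f32 accumulation via use_f32_comp.
// (Editorial summary of the lambdas that follow.)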
std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
        bool use_f32_comp) {
#if MEGDNN_DISABLE_FLOAT16
    mgb_throw(SystemError, "float16 disabled at compile time.");
#else
    auto replace_h2d_opr = [](OperatorNodeBase* opr,
                              const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& h2d_opr = opr->cast_final_safe<opr::Host2DeviceCopy>();
        if (h2d_opr.output(0)->dtype() == dtype::Float32()) {
            auto cvt_var =
                    opr::TypeCvt::make(h2d_opr.output(0), dtype::Float16(), {});
            return cvt_var.node()->owner_opr();
        }
        return opr;
    };
    auto replace_sdt_opr = [](OperatorNodeBase* opr,
                              const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& sdt_opr = opr->cast_final_safe<opr::SharedDeviceTensor>();
        if (sdt_opr.output(0)->dtype() == dtype::Float32()) {
            auto cvt_var =
                    opr::TypeCvt::make(sdt_opr.output(0), dtype::Float16(), {});
            return cvt_var.node()->owner_opr();
        }
        return opr;
    };
    auto replace_imt_opr = [](OperatorNodeBase* opr,
                              const VarNodeArray& new_inp) {
        mgb_assert(opr->same_type<opr::ImmutableTensor>());
        mgb_assert(opr->input().size() == new_inp.size());
        auto& imt_opr = opr->cast_final_safe<opr::ImmutableTensor>();
        if (imt_opr.output(0)->dtype() == dtype::Float32()) {
            auto cvt_var =
                    opr::TypeCvt::make(imt_opr.output(0), dtype::Float16(), {});
            return cvt_var.node()->owner_opr();
        }
        return opr;
    };
    auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        auto new_param = conv_opr.param();
        if (use_f32_comp) {
            new_param.compute_mode =
                    megdnn::param::Convolution::ComputeMode::FLOAT32;
        }
        mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
                   "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
                   new_inp[0]->name().c_str(),
                   new_inp[0]->owner_opr()->name().c_str());
        mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
                   "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
                   new_inp[1]->name().c_str(),
                   new_inp[1]->owner_opr()->name().c_str());
        auto new_conv_opr = opr::Convolution::make(
                new_inp[0], new_inp[1], new_param, conv_opr.execution_policy(),
                conv_opr.config());
        return new_conv_opr.node()->owner_opr();
    };
    auto replace_matmul_opr = [use_f32_comp](OperatorNodeBase* opr,
                                             const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& matmul_opr = opr->cast_final_safe<opr::MatrixMul>();
        auto new_param = matmul_opr.param();
        if (use_f32_comp) {
            new_param.compute_mode =
                    megdnn::param::MatrixMul::ComputeMode::FLOAT32;
        }
        auto new_matmul_opr = opr::MatrixMul::make(
                new_inp[0], new_inp[1], new_param, matmul_opr.config());
        return new_matmul_opr.node()->owner_opr();
    };
    auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr,
                                             const VarNodeArray& new_inp) {
        auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
        auto new_param = reduce_opr.param();
        if (use_f32_comp) {
            new_param.data_type =
                    megdnn::param::Reduce::DataType::FLOAT_O16xC32;
        }
        if (opr->input().size() == 1) {
            auto new_reduce_opr = opr::Reduce::make(new_inp[0], new_param, {},
                                                    reduce_opr.config());
            return new_reduce_opr.node()->owner_opr();
        } else {
            mgb_assert(opr->input().size() == 2, "invalid input size %zu",
                       opr->input().size());
            auto new_reduce_opr = opr::Reduce::make(
                    new_inp[0], new_param, new_inp[1], reduce_opr.config());
            return new_reduce_opr.node()->owner_opr();
        }
    };
    auto replace_cvt_opr = [](OperatorNodeBase* opr,
                              const VarNodeArray& new_inp) {
        auto& cvt_opr = opr->cast_final_safe<opr::TypeCvt>();
        SymbolVar new_cvt;
        if (cvt_opr.output(0)->dtype() == dtype::Float32()) {
            new_cvt = opr::TypeCvt::make(new_inp[0], dtype::Float16(),
                                         cvt_opr.config());
        } else {
            new_cvt = opr::TypeCvt::make(
                    new_inp[0], cvt_opr.output()[0]->dtype(), cvt_opr.config());
        }
        return new_cvt.node()->owner_opr();
    };
    auto replace_warp_opr = [](OperatorNodeBase* opr,
                               const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size() &&
                   (new_inp.size() == 3 || new_inp.size() == 4));
        auto& warp_opr = opr->cast_final<opr::WarpPerspective>();
        // mat tensor must be float32
        auto new_mat = new_inp[1];
        if (new_inp[1]->dtype() != dtype::Float32()) {
            if (try_cast_as_op<opr::TypeCvt>(new_mat->owner_opr()) &&
                new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
                new_mat = new_mat->owner_opr()->input(0);
            else
                new_mat = opr::TypeCvt::make(
                        new_inp[1], dtype::Float32(), {}).node();
        }
        SymbolVar new_warp;
        if (new_inp.size() == 3) {
            new_warp = opr::WarpPerspective::make(new_inp[0], new_mat,
                                                  new_inp[2], warp_opr.param(),
                                                  warp_opr.config());
        } else {
            mgb_assert(new_inp.size() == 4);
            new_warp = opr::WarpPerspective::make(
                    new_inp[0], new_mat, new_inp[2], new_inp[3],
                    warp_opr.param(), warp_opr.config());
        }
        return new_warp.node()->owner_opr();
    };
    auto ret = std::make_unique<ConvertF32ToF16Pass>();
    // don't check dtype
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
                                    VarReplaceCheckFlag::CHECK_DTYPE);
    auto&& replace_func = ret->m_opr_replace_func;
    replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
    replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr;
    replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr;
    replace_func[opr::ImmutableTensor::typeinfo()] = replace_imt_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
    replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
    return ret;
#endif
}

/* ================ ConvertFormatPass ================ */
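// Editorial overview: this pass converts an NCHW graph to the NHWCD4
// (Image2DPack4) layout used by image-based backends. Inputs, weights and
// biases are wrapped in RelayoutFormat oprs, each opr's format param is
// switched to NHWCD4, and endpoint vars are relayouted back to NCHW.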
void ConvertFormatPass::apply(OptState& state) const {
    state.set_var_replace_check_flag(m_var_replace_check_flag);
    auto rewriter = state.graph().make_rewriter();
    VarNodeArray new_inp_cache;
    auto on_opr = [this, &state, &rewriter,
                   &new_inp_cache](OperatorNodeBase* opr) {
        auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
        if (it != m_opr_replace_func.end()) {
            auto&& new_inp = new_inp_cache;
            new_inp.clear();
            new_inp.reserve(opr->input().size());
            for (auto i : opr->input()) {
                new_inp.push_back(rewriter.get_var(i));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size(),
                       "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
                       "dst.size=%zu",
                       opr->cname(), opr->dyn_typeinfo()->name,
                       new_opr->cname(), new_opr->dyn_typeinfo()->name,
                       out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); i++) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(!out1[i]->contain_flag(
                            VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    auto dst_is_image = dst->format().type() ==
                                        TensorFormat::Type::IMAGE2D_PACK4;
                    if (!dst_is_image &&
                        !src->owner_opr()->same_type<opr::ImmutableTensor>()) {
                        mgb_log_warn(
                                "convert NHWCD4 replaced to non-img format: "
                                "dst_opr=%s{%s} format=%s",
                                dst->owner_opr()->cname(),
                                dst->owner_opr()->dyn_typeinfo()->name,
                                dst->format().to_string().c_str());
                    }
                    if (state.graph().endpoint_contain(src) && dst_is_image) {
                        // relayout back to NCHW for output vars
                        dst = opr::RelayoutFormat::make(
                                      dst, {opr::RelayoutFormat::Param::Mode::
                                                    NHWCD4I_NCHW})
                                      .node();
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}

std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
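    // filter_mode picks the weight relayout mode: dense vs. group conv, with
    // the *_DOT variants for quantized (dot-product) filters and
    // INTER_WEIGHT_CHANI for 1x1-per-group (channel-wise) filters.
    // (Editorial summary of the lambda below.)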
    auto filter_mode =
            [](const megdnn::param::Convolution::Sparse conv_mode,
               const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
        bool use_dot = false;
        if (filter->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
            filter->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm)
            use_dot = true;
        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
            if (use_dot)
                return megdnn::param::RelayoutFormat::Mode::
                        INTER_WEIGHT_DENSEI_DOT;
            return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI;
        } else {
            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
            if (filter->shape()[1] == 1 && filter->shape()[2] == 1) {
                return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI;
            } else {
                if (use_dot)
                    return megdnn::param::RelayoutFormat::Mode::
                            INTER_WEIGHT_GROUPI_DOT;
                return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_GROUPI;
            }
        }
    };
    auto replace_conv_opr = [&filter_mode](OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        mgb_assert(conv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NHWCD4");
        VarNode *conv_src = nullptr, *conv_weights = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            size_t group, icpg, ocpg;
            if (conv_opr.param().sparse ==
                megdnn::param::Convolution::Sparse::DENSE) {
                group = 1;
                icpg = new_inp[1]->shape()[1];
                ocpg = new_inp[1]->shape()[0];
            } else {
                mgb_assert(conv_opr.param().sparse ==
                           megdnn::param::Convolution::Sparse::GROUP);
                group = new_inp[1]->shape()[0];
                icpg = new_inp[1]->shape()[2];
                ocpg = new_inp[1]->shape()[1];
            }
            if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
                auto param = megdnn::param::RelayoutFormat();
                param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
                auto rf = opr::RelayoutFormat::make(new_inp[0], param);
                conv_src = rf.node();
            } else {
                // can not convert to hwcd4
                return serialization::copy_opr_shallow(*opr, new_inp,
                                                       opr->config());
            }
        } else {
            size_t ocpg;
            bool is_channel_wise = false;
            if (conv_opr.param().sparse ==
                megdnn::param::Convolution::Sparse::DENSE) {
                ocpg = new_inp[1]->shape()[0];
            } else {
                mgb_assert(conv_opr.param().sparse ==
                           megdnn::param::Convolution::Sparse::GROUP);
                size_t icpg = new_inp[1]->shape()[2];
                ocpg = new_inp[1]->shape()[1];
                if (icpg == 1 && ocpg == 1) {
                    is_channel_wise = true;
                }
            }
            if (ocpg % 4 != 0 && !is_channel_wise) {
                VarNodeArray t_inp = new_inp;
                auto param = megdnn::param::RelayoutFormat();
                param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
                auto rf = opr::RelayoutFormat::make(new_inp[0], param);
                t_inp[0] = rf.node();
                auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
                                                               opr->config());
                return new_opr;
            }
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            conv_src = new_inp[0];
        }
        mgb_assert(new_inp[1]->format().type() !=
                   TensorFormat::Type::IMAGE2D_PACK4);
        auto param = megdnn::param::RelayoutFormat();
        param.mode = filter_mode(conv_opr.param().sparse, new_inp[1]);
        auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
        conv_weights = relayout_weight.node();
        auto new_param = conv_opr.param();
        new_param.format = megdnn::param::Convolution::Format::NHWCD4;
        mgb_assert(conv_src->shape().ndim == 5 &&
                   conv_src->format().type() ==
                           TensorFormat::Type::IMAGE2D_PACK4);
        auto new_conv_opr = opr::Convolution::make(
                conv_src, conv_weights, new_param, conv_opr.execution_policy(),
                conv_opr.config());
        OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
        mgb_assert(new_conv_opr.shape().ndim == 5 &&
                   new_conv_opr.format().type() ==
                           TensorFormat::Type::IMAGE2D_PACK4);
        return ret;
    };
    auto replace_conv_bias_opr = [&filter_mode](OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
        mgb_assert(conv_bias_opr.param().format ==
                           megdnn::param::ConvBias::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NHWCD4");
        VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr,
                *conv_bias_bias = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            size_t group, icpg, ocpg;
            if (conv_bias_opr.param().sparse ==
                megdnn::param::ConvBias::Sparse::DENSE) {
                group = 1;
                icpg = new_inp[1]->shape()[1];
                ocpg = new_inp[1]->shape()[0];
            } else {
                mgb_assert(conv_bias_opr.param().sparse ==
                           megdnn::param::ConvBias::Sparse::GROUP);
                group = new_inp[1]->shape()[0];
                icpg = new_inp[1]->shape()[2];
                ocpg = new_inp[1]->shape()[1];
            }
            if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
                auto param = megdnn::param::RelayoutFormat();
                param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
                auto rf = opr::RelayoutFormat::make(new_inp[0], param);
                conv_bias_src = rf.node();
            } else {
                // can not convert to hwcd4
                return serialization::copy_opr_shallow(*opr, new_inp,
                                                       opr->config());
            }
        } else {
            size_t ocpg;
            bool is_channel_wise = false;
            if (conv_bias_opr.param().sparse ==
                megdnn::param::ConvBias::Sparse::DENSE) {
                ocpg = new_inp[1]->shape()[0];
            } else {
                mgb_assert(conv_bias_opr.param().sparse ==
                           megdnn::param::ConvBias::Sparse::GROUP);
                size_t icpg = new_inp[1]->shape()[2];
                ocpg = new_inp[1]->shape()[1];
                if (icpg == 1 && ocpg == 1) {
                    is_channel_wise = true;
                }
            }
            if (ocpg % 4 != 0 && !is_channel_wise) {
                VarNodeArray t_inp = new_inp;
                auto param = megdnn::param::RelayoutFormat();
                param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
                auto rf = opr::RelayoutFormat::make(new_inp[0], param);
                t_inp[0] = rf.node();
                auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
                                                               opr->config());
                return new_opr;
            }
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            conv_bias_src = new_inp[0];
        }
        mgb_assert(new_inp[1]->format().type() !=
                   TensorFormat::Type::IMAGE2D_PACK4);
        auto param = megdnn::param::RelayoutFormat();
        param.mode = filter_mode(conv_bias_opr.param().sparse, new_inp[1]);
        auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
        conv_bias_weights = relayout_weight.node();
        param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
        auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
        conv_bias_bias = relayout_bias.node();
        auto new_param = conv_bias_opr.param();
        new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
        mgb_assert(conv_bias_src->shape().ndim == 5 &&
                   conv_bias_src->format().type() ==
                           TensorFormat::Type::IMAGE2D_PACK4);
        auto new_conv_bias_opr = opr::ConvBias::make(
                conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
                conv_bias_opr.execution_policy(), conv_bias_opr.config());
        OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
        mgb_assert(new_conv_bias_opr.shape().ndim == 5 &&
                   new_conv_bias_opr.format().type() ==
                           TensorFormat::Type::IMAGE2D_PACK4);
        return ret;
    };
    auto replace_deconv_opr = [&filter_mode](OperatorNodeBase* opr,
                                             const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
        mgb_assert(deconv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NHWCD4");
        VarNode *deconv_src = nullptr, *deconv_weights = nullptr;
        if (new_inp[1]->shape().ndim == 4) {
            // new input src is NCHW
            size_t group, icpg, ocpg;
            if (deconv_opr.param().sparse ==
                megdnn::param::Convolution::Sparse::DENSE) {
                group = 1;
                icpg = new_inp[0]->shape()[0];
                ocpg = new_inp[0]->shape()[1];
            } else {
                mgb_assert(deconv_opr.param().sparse ==
                           megdnn::param::Convolution::Sparse::GROUP);
                group = new_inp[0]->shape()[0];
                icpg = new_inp[0]->shape()[1];
                ocpg = new_inp[0]->shape()[2];
            }
            if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
                auto param = megdnn::param::RelayoutFormat();
                param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
                auto rf = opr::RelayoutFormat::make(new_inp[1], param);
                deconv_src = rf.node();
            } else {
                // can not convert to hwcd4
                return serialization::copy_opr_shallow(*opr, new_inp,
                                                       opr->config());
            }
        } else {
            //! XXXX, fix me, check filter size
            size_t ocpg;
            if (deconv_opr.param().sparse ==
                megdnn::param::Convolution::Sparse::DENSE) {
                ocpg = new_inp[0]->shape()[1];
            } else {
                mgb_assert(deconv_opr.param().sparse ==
                           megdnn::param::Convolution::Sparse::GROUP);
                ocpg = new_inp[0]->shape()[2];
            }
            if (ocpg % 4 != 0) {
                VarNodeArray t_inp = new_inp;
                auto param = megdnn::param::RelayoutFormat();
                param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
                auto rf = opr::RelayoutFormat::make(new_inp[1], param);
                t_inp[1] = rf.node();
                auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
                                                               opr->config());
                return new_opr;
            }
            // new input src is NHWCD4
            auto&& fmt = new_inp[1]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[1]->shape().ndim == 5 && fmt.align_axis() == 2);
            deconv_src = new_inp[1];
        }
        mgb_assert(new_inp[0]->format().type() !=
                   TensorFormat::Type::IMAGE2D_PACK4);
        auto param = megdnn::param::RelayoutFormat();
        param.mode = filter_mode(deconv_opr.param().sparse, new_inp[0]);
        auto relayout_weight = opr::RelayoutFormat::make(new_inp[0], param);
        deconv_weights = relayout_weight.node();
        auto new_param = deconv_opr.param();
        new_param.format = megdnn::param::Convolution::Format::NHWCD4;
        mgb_assert(deconv_src->shape().ndim == 5 &&
                   deconv_src->format().type() ==
                           TensorFormat::Type::IMAGE2D_PACK4);
        auto new_deconv_opr = opr::ConvolutionBackwardData::make(
                deconv_weights, deconv_src, new_param,
                deconv_opr.execution_policy(), deconv_opr.config());
        OperatorNodeBase* ret = new_deconv_opr.node()->owner_opr();
        mgb_assert(new_deconv_opr.shape().ndim == 5 &&
                   new_deconv_opr.format().type() ==
                           TensorFormat::Type::IMAGE2D_PACK4);
        return ret;
    };
    /* This helper function guarantees that the format-convert pass won't
     * change an output var's channel count: changing it would cause a
     * channel-mismatch problem when replacing conv/conv_bias operators.
     */
    auto replace_helper = [](OperatorNodeBase* opr,
                             const VarNodeArray& new_inp) -> OperatorNodeBase* {
        auto&& new_shp = new_inp[0]->shape();
        size_t inp_channel = new_shp[1];
        if (new_shp.eq_shape(opr->input(0)->shape()) && inp_channel % 4 != 0) {
            auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
                                                           opr->config());
            return new_opr;
        }
        return nullptr;
    };
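    // Illustrative example (assumed shapes, not from the source): for an
    // input of shape (1, 3, 224, 224) the channel count 3 is not divisible
    // by 4, so replace_helper returns a shallow copy that keeps the operator
    // in NCHW; for (1, 16, 224, 224) it returns nullptr and the caller goes
    // on to perform the NHWCD4 conversion.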
    auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr,
                                               const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
        mgb_assert(resize_opr.param().format ==
                           megdnn::param::Resize::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = resize_opr.param();
        new_param.format = megdnn::param::Resize::Format::NHWCD4;
        auto new_resize_opr = opr::ResizeForward::make(
                inp, new_inp[1], new_param, opr->config());
        return new_resize_opr.node()->owner_opr();
    };
    auto replace_warp_perspective_opr = [replace_helper](
                                                OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
        mgb_assert(warp_opr.param().format ==
                           megdnn::param::WarpPerspective::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = warp_opr.param();
        new_param.format = megdnn::param::WarpPerspective::Format::NHWCD4;
        SymbolVar new_warp_opr;
        if (new_inp.size() == 3) {
            new_warp_opr = opr::WarpPerspectiveForward::make(
                    inp, new_inp[1], nullptr, new_inp[2], new_param,
                    opr->config());
        } else {
            mgb_assert(new_inp.size() == 4);
            new_warp_opr = opr::WarpPerspectiveForward::make(
                    inp, new_inp[1], new_inp[2], new_inp[3], new_param,
                    opr->config());
        }
        return new_warp_opr.node()->owner_opr();
    };
    auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr,
                                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
        mgb_assert(warp_opr.param().format ==
                           megdnn::param::WarpAffine::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = warp_opr.param();
        new_param.format = megdnn::param::WarpAffine::Format::NHWCD4;
        SymbolVar new_warp_opr;
        new_warp_opr = opr::WarpAffineForward::make(inp, new_inp[1], new_inp[2],
                                                    new_param, opr->config());
        return new_warp_opr.node()->owner_opr();
    };
    auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
        mgb_assert(pooling_opr.param().format ==
                           megdnn::param::Pooling::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = pooling_opr.param();
        new_param.format = megdnn::param::Pooling::Format::NHWCD4;
        auto new_pooling_opr =
                opr::PoolingForward::make(inp, new_param, opr->config());
        return new_pooling_opr.node()->owner_opr();
    };
    auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
        if (!inp->shape().eq_shape(new_inp->shape())) {
            mgb_assert(inp->shape().ndim == 4 &&
                       inp->format().type() !=
                               TensorFormat::Type::IMAGE2D_PACK4);
            mgb_assert(new_inp->shape().ndim == 5 &&
                       new_inp->format().type() ==
                               TensorFormat::Type::IMAGE2D_PACK4);
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
            auto rf = opr::RelayoutFormat::make(new_inp, param);
            return rf.node();
        }
        return new_inp;
    };
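    // Shape sketch for var_to_chw (illustrative assumption about the NHWCD4
    // dimension order): a var produced in NHWCD4, e.g. (N, H, C/4, W, 4) in
    // IMAGE2D_PACK4 format, is relayouted back to plain NCHW (N, C, H, W) so
    // that format-agnostic oprs can keep consuming it unchanged.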
    auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray t_inp = new_inp;
        for (size_t i = 0; i < opr->input().size(); i++) {
            t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, t_inp, opr->config());
        return new_opr;
    };
    auto replace_elemwise_opr = [](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool has_inp_changed = false;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (!new_inp[i]->format().is_default()) {
                has_inp_changed = true;
                break;
            }
        }
        if (has_inp_changed) {
            // assumption: all inputs are changed from nchw to nhwcd4
            auto t_inp = new_inp;
            for (size_t i = 0; i < opr->input().size(); i++) {
                if (new_inp[i]->shape().ndim == 4) {
                    auto param = megdnn::param::RelayoutFormat();
                    param.mode =
                            megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
                    auto rf = opr::RelayoutFormat::make(new_inp[i], param);
                    t_inp[i] = rf.node();
                } else {
                    mgb_assert((new_inp[i]->shape().ndim == 5 &&
                                new_inp[i]->format().type() ==
                                        TensorFormat::Type::IMAGE2D_PACK4) ||
                               new_inp[i]->shape().is_scalar());
                }
            }
            return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
        } else {
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };
    /* This helper function converts the first input to the NCHW format to
     * handle operations that do not support the NHWCD4 format.
     */
    auto relayout_first_inp_to_chw =
            [var_to_chw](OperatorNodeBase* opr,
                         const VarNodeArray& new_inp) -> OperatorNodeBase* {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray t_inp = new_inp;
        t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
        return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
    };
    auto ret = std::make_unique<ConvertFormatPass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
    replace_func[opr::GroupLocalForward::typeinfo()] =
            relayout_first_inp_to_chw;
    return ret;
}

/* ================ ConvertBatchNormToElemwisePass ================ */
const char* ConvertBatchNormToElemwisePass::name() const {
    return "convert_batch_norm";
}

void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
    auto rewriter = state.graph().make_rewriter();
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
            if (bn->input().size() == 5) {
                mgb_assert(bn->param().fwd_mode ==
                           opr::BatchNorm::Param::FwdMode::INFERENCE);
                SymbolVar x = {rewriter.get_var(bn->input(0))};
                SymbolVar scale = {rewriter.get_var(bn->input(1))};
                SymbolVar bias = {rewriter.get_var(bn->input(2))};
                SymbolVar mean = {rewriter.get_var(bn->input(3))};
                SymbolVar variance = {rewriter.get_var(bn->input(4))};
                SymbolVar invsqrt_variance = opr::PowC::make(variance, {-0.5});
                auto res = scale * (x - mean) * invsqrt_variance + bias;
                rewriter.replace_var(
                        opr->output(4), res.node(),
                        mgb_cstr_log(
                                "replace batch_norm(x, scale, bias, mean, "
                                "variance) "
                                "-> (scale * (x - mean) / sqrt(variance) + bias)"));
                return;
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
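/* Worked example (illustrative, under the same INFERENCE assumption as the
 * assert above): with broadcastable params scale = 2, bias = 1, mean = 0.5
 * and variance = 0.25, an input element x = 1.5 is rewritten to
 * 2 * (1.5 - 0.5) * 0.25^(-0.5) + 1 = 5. PowC(variance, {-0.5}) computes the
 * reciprocal square root, so no explicit division is emitted in the graph.
 */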
/* ================ FuseConvBiasNonlinPass ================ */
const char* FuseConvBiasNonlinPass::name() const {
    return "combine_conv_bias_and_relu";
}

void FuseConvBiasNonlinPass::apply(OptState& state) const {
    std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
    state.graph().iter([&m_deps](OperatorNodeBase* opr) {
        for (auto& inp : opr->input()) {
            m_deps[inp].push_back(opr);
        }
    });
    auto rewriter = state.graph().make_rewriter();
    using Mode = opr::Elemwise::Param::Mode;
    using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
    auto get_nonlinearity_mode = [&](opr::Elemwise* elem) -> NonlineMode {
        if (elem->param().mode == Mode::FUSE_ADD_RELU ||
            elem->param().mode == Mode::RELU) {
            return NonlineMode::RELU;
        } else if (elem->param().mode == Mode::FUSE_ADD_SIGMOID ||
                   elem->param().mode == Mode::SIGMOID) {
            return NonlineMode::SIGMOID;
        } else {
            return NonlineMode::IDENTITY;
        }
    };
    auto try_fuse_bias_nonlinearity = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 2);
        can_be_fused &= (elem->param().mode == Mode::FUSE_ADD_RELU) ||
                        (elem->param().mode == Mode::FUSE_ADD_TANH) ||
                        (elem->param().mode == Mode::FUSE_ADD_SIGMOID);
        return can_be_fused;
    };
    auto try_fuse_bias = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 2);
        can_be_fused &= (elem->param().mode == Mode::ADD);
        return can_be_fused;
    };
    auto try_fuse_nonlinearity = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 1);
        can_be_fused &= (elem->param().mode == Mode::RELU) ||
                        (elem->param().mode == Mode::TANH) ||
                        (elem->param().mode == Mode::SIGMOID);
        return can_be_fused;
    };
    auto convert_to_conv_bias_param = [&](const opr::Convolution::Param& param)
            -> opr::ConvBiasForward::Param {
        using Param = opr::ConvBiasForward::Param;
        return opr::ConvBiasForward::Param{Param::NonlineMode::IDENTITY,
                                           param.mode,
                                           param.sparse,
                                           param.format,
                                           param.pad_h,
                                           param.pad_w,
                                           param.stride_h,
                                           param.stride_w,
                                           param.dilate_h,
                                           param.dilate_w};
    };
    auto check_bias_shape = [&](opr::Convolution* conv, VarNode* bias) -> bool {
        bool valid_bias_shape = true;
        using Format = opr::Convolution::Param::Format;
        using Sparse = opr::Convolution::Param::Sparse;
        auto dst_shape = conv->output(0)->shape();
        auto filter_shape = conv->input(1)->shape();
        auto bias_shape = bias->shape();
        if (dst_shape.eq_shape(bias_shape)) {
            return valid_bias_shape;
        }
        size_t OC = filter_shape[0];
        if (conv->param().sparse == Sparse::GROUP) {
            OC *= filter_shape[1];
        }
        if (conv->param().format == Format::NCHW) {
            valid_bias_shape &=
                    ((bias_shape.ndim == 4) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == OC) && (bias_shape[2] == 1) &&
                     (bias_shape[3] == 1));
        } else if (conv->param().format == Format::NCHW4) {
            valid_bias_shape &=
                    ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == OC / 4) && (bias_shape[2] == 1) &&
                     (bias_shape[3] == 1) && (bias_shape[4] == 4));
        } else if (conv->param().format == Format::NHWC) {
            valid_bias_shape &= ((bias_shape.ndim == 4) &&
                                 (bias_shape[0] == 1) && (bias_shape[1] == 1) &&
                                 (bias_shape[2] == 1) && (bias_shape[3] == OC));
        } else {
            valid_bias_shape &=
                    ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == 1) && (bias_shape[2] == OC) &&
                     (bias_shape[3] == 1) && (bias_shape[4] == 4));
            mgb_assert(conv->param().format == Format::NHWCD4);
        }
        return valid_bias_shape;
    };
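    /* Bias-shape summary implied by the checks above (derived from the code,
     * with OC computed from the filter shape exactly as done there):
     *   NCHW   -> (1, OC,   1, 1)
     *   NCHW4  -> (1, OC/4, 1, 1, 4)
     *   NHWC   -> (1, 1, 1, OC)
     *   NHWCD4 -> (1, 1, OC, 1, 4)
     * A bias whose shape equals the conv output shape is always accepted.
     */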
    auto try_fuse_typecvt = [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
        mgb_assert(typecvt->input().size() == 1);
        auto conv_bias = try_cast_as_op<opr::ConvBias>(
                rewriter.get_var(typecvt->input(0))->owner_opr());
        if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
            typecvt->output(0)->dtype().enumv() !=
                    DTypeTrait<dtype::QuantizedS8>::enumv)
            return nullptr;
        auto config = conv_bias->config();
        config.output_dtype(typecvt->output(0)->dtype());
        if (conv_bias->input().size() == 3) {
            // conv + bias
            return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
                                       conv_bias->input(2), conv_bias->param(),
                                       conv_bias->execution_policy(), config)
                    .node()
                    ->owner_opr();
        } else {
            // conv without bias
            return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
                                       conv_bias->param(),
                                       conv_bias->execution_policy(), config)
                    .node()
                    ->owner_opr();
        }
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        auto check_conv = [](opr::Convolution* conv) -> bool {
            return conv->param().format ==
                           megdnn::param::Convolution::Format::NHWCD4 ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NHWC ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NCHW ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NCHW4;
        };
        if (auto elem = try_cast_as_op<opr::Elemwise>(opr)) {
            if (try_fuse_bias_nonlinearity(elem) || try_fuse_bias(elem)) {
                auto inp1 = rewriter.get_var(elem->input(0));
                auto inp2 = rewriter.get_var(elem->input(1));
                opr::Convolution* conv = nullptr;
                size_t bias_idx = 0;
                if (inp1->owner_opr()->same_type<opr::Convolution>() &&
                    m_deps[elem->input(0)].size() == 1) {
                    conv = try_cast_as_op<opr::Convolution>(inp1->owner_opr());
                    bias_idx = 1;
                } else if (inp2->owner_opr()->same_type<opr::Convolution>() &&
                           m_deps[elem->input(1)].size() == 1) {
                    conv = try_cast_as_op<opr::Convolution>(inp2->owner_opr());
                    bias_idx = 0;
                }
                auto bias_inp = rewriter.get_var(elem->input(bias_idx));
                if (conv && check_conv(conv) &&
                    check_bias_shape(conv, bias_inp)) {
                    opr::ConvBiasForward::Param param =
                            convert_to_conv_bias_param(conv->param());
                    param.nonlineMode = get_nonlinearity_mode(elem);
                    auto new_var =
                            opr::ConvBiasForward::make(
                                    conv->input(0), conv->input(1), bias_inp,
                                    param, conv->execution_policy(),
                                    conv->config())
                                    .node();
                    rewriter.replace_var(
                            opr->output(0), new_var,
                            mgb_cstr_log("replace nonlinearity(conv(x, w) + b) "
                                         "-> conv_bias(x, w, b)"));
                    return;
                }
            } else if (try_fuse_nonlinearity(elem)) {
                auto inp = rewriter.get_var(elem->input(0));
                {
                    auto conv =
                            try_cast_as_op<opr::Convolution>(inp->owner_opr());
                    if (conv && check_conv(conv) &&
                        m_deps[elem->input(0)].size() == 1) {
                        opr::ConvBiasForward::Param param =
                                convert_to_conv_bias_param(conv->param());
                        param.nonlineMode = get_nonlinearity_mode(elem);
                        auto new_var = opr::ConvBiasForward::make(
                                               conv->input(0), conv->input(1),
                                               param, conv->execution_policy(),
                                               conv->config())
                                               .node();
                        rewriter.replace_var(
                                opr->output(0), new_var,
                                mgb_cstr_log("replace nonlinearity(conv(x, w)) "
                                             "-> conv_bias(x, w)"));
                        return;
                    }
                }
                {
                    auto conv = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
                    auto check_conv_bias = [&](opr::ConvBias* opr) {
                        return opr->param().format ==
                                       opr::ConvBias::Param::Format::NHWC ||
                               opr->param().format ==
                                       opr::ConvBias::Param::Format::NCHW ||
                               opr->param().format ==
                                       opr::ConvBias::Param::Format::NCHW4;
                    };
                    if (conv && check_conv_bias(conv) &&
                        m_deps[elem->input(0)].size() == 1) {
                        auto param = conv->param();
                        param.nonlineMode = get_nonlinearity_mode(elem);
                        auto new_var = opr::ConvBiasForward::make(
                                               conv->input(0), conv->input(1),
                                               conv->input(2), param,
                                               conv->execution_policy(),
                                               conv->config())
                                               .node();
                        rewriter.replace_var(
                                opr->output(0), new_var,
                                mgb_cstr_log("replace nonlinearity(conv(x, w)) "
                                             "-> conv_bias(x, w)"));
                        return;
                    }
                }
            }
        } else if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
            auto new_opr = try_fuse_typecvt(typecvt);
            if (new_opr) {
                rewriter.replace_var(
                        opr->output(0), new_opr->output(0),
                        mgb_cstr_log("replace typecvt(conv_bias(x, w, b)) -> "
                                     "conv_bias(x, w, b)"));
                return;
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ FuseConvBiasZPass ================ */
const char* FuseConvBiasZPass::name() const {
    return "combine_conv_bias_and_z";
}

void FuseConvBiasZPass::apply(OptState& state) const {
    UniqReaderCheck uniq_reader_check{state.graph()};
    auto rewriter = state.graph().make_rewriter();
    using Mode = opr::Elemwise::Param::Mode;
    using MultiMode = opr::ElemwiseMultiType::Param::Mode;
    using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
    auto check_conv_bias = [](opr::ConvBias* conv_bias) -> bool {
        return conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NHWC ||
               conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NCHW ||
               conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NCHW4;
    };
    auto check_fuse_shape = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
        bool valid_fuse_shape = true;
        auto z_shape = z->shape();
        auto bias_shape = conv_bias->input(2)->shape();
        auto conv_bias_shape = conv_bias->output(0)->shape();
        valid_fuse_shape &= (!conv_bias_shape.eq_shape(bias_shape));
        valid_fuse_shape &= conv_bias_shape.eq_shape(z_shape);
        return valid_fuse_shape;
    };
    auto check_fuse_dtype = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
        return conv_bias->output(0)->dtype().enumv() == z->dtype().enumv();
    };
    auto get_convbias_nonline_mode = [&](OperatorNodeBase* opr) -> NonlineMode {
        if (opr->same_type<opr::Elemwise>()) {
            auto elem = try_cast_as_op<opr::Elemwise>(opr);
            if (elem->param().mode == Mode::FUSE_ADD_RELU)
                return NonlineMode::RELU;
        }
        if (opr->same_type<opr::ElemwiseMultiType>()) {
            auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
            if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
                return NonlineMode::RELU;
        }
        return NonlineMode::IDENTITY;
    };
    auto try_replace_var_node = [&](OperatorNodeBase* opr) {
        opr::ConvBias* conv_bias = nullptr;
        size_t z_idx = 0;
        size_t nr_inps = opr->input().size();
        for (size_t i = 0; i < nr_inps; i++) {
            auto inp = rewriter.get_var(opr->input(i));
            if (inp->owner_opr()->same_type<opr::ConvBias>()) {
                auto cb = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
                if (cb->input().size() == 3 &&
                    cb->param().nonlineMode ==
                            opr::ConvBias::Param::NonlineMode::IDENTITY &&
                    uniq_reader_check(opr->input(i))) {
                    conv_bias = cb;
                    z_idx = nr_inps - i - 1;
                    break;
                }
            }
        }
        auto z_inp = rewriter.get_var(opr->input(z_idx));
        if (conv_bias && check_conv_bias(conv_bias) &&
            check_fuse_shape(conv_bias, z_inp) &&
            check_fuse_dtype(conv_bias, z_inp)) {
            auto param = conv_bias->param();
            param.nonlineMode = get_convbias_nonline_mode(opr);
            auto config = conv_bias->config();
            auto new_var = opr::ConvBiasForward::make(
                                   conv_bias->input(0), conv_bias->input(1),
                                   conv_bias->input(2), z_inp, param,
                                   conv_bias->execution_policy(),
                                   config.output_dtype(opr->output(0)->dtype()))
                                   .node();
            rewriter.replace_var(
                    opr->output(0), new_var,
                    mgb_cstr_log("replace "
                                 "nonlinearity(conv_bias(x,w,b) + z) "
                                 "-> conv_bias(x, w, b, z)"));
            uniq_reader_check.update_on_opr_auto_replace(opr,
                                                         new_var->owner_opr());
            return true;
        }
        return false;
    };
    auto try_fuse_elemwise = [&](OperatorNodeBase* opr) {
        if (!opr->same_type<opr::Elemwise>())
            return false;
        auto elem = try_cast_as_op<opr::Elemwise>(opr);
        if (elem->input().size() != 2)
            return false;
        if (elem->param().mode != Mode::ADD &&
            elem->param().mode != Mode::FUSE_ADD_RELU)
            return false;
        return try_replace_var_node(opr);
    };
    auto try_fuse_elemwise_multi_type = [&](OperatorNodeBase* opr) {
        if (!opr->same_type<opr::ElemwiseMultiType>())
            return false;
        auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
        if (elem->input().size() != 2)
            return false;
        if (elem->param().mode != MultiMode::QADD &&
            elem->param().mode != MultiMode::QFUSE_ADD_RELU)
            return false;
        return try_replace_var_node(opr);
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (try_fuse_elemwise(opr))
            return;
        if (try_fuse_elemwise_multi_type(opr))
            return;
        auto new_opr = rewriter.auto_replace_outputs(opr);
        uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ FuseDeconvCvtPass ================ */
const char* FuseDeconvCvtPass::name() const {
    return "combine_deconv_and_typecvt";
}

void FuseDeconvCvtPass::apply(OptState& state) const {
    std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
    state.graph().iter([&m_deps](OperatorNodeBase* opr) {
        for (auto& inp : opr->input()) {
            m_deps[inp].push_back(opr);
        }
    });
    UniqReaderCheck uniq_reader_check{state.graph()};
    auto rewriter = state.graph().make_rewriter();
    auto try_fuse_deconv_typecvt =
            [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
        mgb_assert(typecvt->input().size() == 1);
        auto deconv = try_cast_as_op<opr::ConvolutionBackwardData>(
                rewriter.get_var(typecvt->input(0))->owner_opr());
        if (!deconv || m_deps.count(typecvt->input(0)) != 1 ||
            typecvt->output(0)->dtype().enumv() !=
                    DTypeTrait<dtype::QuantizedS8>::enumv) {
            return nullptr;
        }
        if (!uniq_reader_check(deconv->output(0)))
            return nullptr;
        auto config = deconv->config();
        config.output_dtype(typecvt->output(0)->dtype());
        return opr::ConvolutionBackwardData::make(
                       deconv->input(0), deconv->input(1), deconv->param(),
                       deconv->execution_policy(), config)
                .node()
                ->owner_opr();
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
            if (auto deconv_new = try_fuse_deconv_typecvt(typecvt)) {
                rewriter.replace_var(
                        opr->output(0), deconv_new->output(0),
                        mgb_cstr_log("replace typecvt(deconv(x, w)) -> "
                                     "deconv(x, w)"));
                uniq_reader_check.update_on_opr_auto_replace(opr, deconv_new);
                return;
            }
        }
        auto new_opr = rewriter.auto_replace_outputs(opr);
        uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ ParamMergePass ================ */
const char* ParamMergePass::name() const {
    return mgb_cstr_log("param_merge");
}

void ParamMergePass::apply(OptState& opt_state) const {
    param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
            opt_state);
    param_merge<opr::SharedDeviceTensorWithFormat,
                opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
}
/* ================ TensorReformatPass =============== */
/*!
 * \brief relayout placeholder opr
 *
 * RelayoutPlaceholder oprs act as placeholders in the ComputingGraph during
 * the graph opt pass `TensorReformatPass`. These oprs are introduced into a
 * ComputingGraph for conveniently discovering further optimization
 * opportunities (such as fusing consecutive relayouts, or translating into
 * optimized implementations). They are canonized to have a shape infer, so
 * the output's shape can be correctly deduced during the opt pass.
 *
 * Note that these oprs are only used as intermediate representations before
 * being translated to MegBrain oprs, so they should not get involved in any
 * actual computing.
 */
MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
                     cg::SingleCNOperatorNodeBase) // {
public:
    //! relayout type of this opr
    enum class LayoutType {
        NCHW4_TO_NCHW32,              //!< from nchw4 layout to nchw32 layout
        NCHW32_TO_NCHW4,              //!< from nchw32 layout to nchw4 layout
        NCHW4_TO_CHWN4,               //!< from nchw4 layout to chwn4 layout
        CHWN4_TO_NCHW4,               //!< from chwn4 layout to nchw4 layout
        NCHW_TO_NCHW88,               //!< from nchw layout to nchw88 layout
        NCHW88_TO_NCHW,               //!< from nchw88 layout to nchw layout
        WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< dense weight from nchw layout to
                                      //!< nchw88 layout
        WEIGHT_NCHW_TO_NCHW88_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw88 layout
        WEIGHT_NCHW_TO_NCHW88_CHAN,   //!< channel wise weight from nchw
                                      //!< layout to nchw88 layout
        //!< hybrid weight: input layout is nchw, output layout is nchw88;
        //!< special for weights whose nchw shape is like {64, 2, 3, 3},
        //!< relaid out to {8, 3, 3, 2, 8}
        WEIGHT_HYBIRD_NCHW_NCHW88,
    };

    RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);

    /*!
     * \param src_var the input var
     * \param layout_type tensor layout transform type of this relayout
     *     placeholder as described in LayoutType
     */
    static SymbolVar make(VarNode* src_var, LayoutType layout_type);

    LayoutType layout_type() const { return m_layout_type; }

private:
    void init_output_static_infer_desc() override;
    void scn_do_execute() override;
    void init_output_comp_node() override;
    const LayoutType m_layout_type;
};
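/* Usage sketch (illustrative, mirroring how the passes below use the
 * placeholder; not additional API): a replace rule first marks a pending
 * layout change,
 *
 *     auto sym = TensorReformatPass::RelayoutPlaceholder::make(
 *             var, TensorReformatPass::RelayoutPlaceholder::LayoutType::
 *                          NCHW_TO_NCHW88);
 *
 * and translate_pass() later rewrites the placeholder into real
 * Reshape/Dimshuffle (or RelayoutFormat) oprs. scn_do_execute() throws, so a
 * placeholder must never survive into an executable graph.
 */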
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);

TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder(
        VarNode* src_var, LayoutType layout_type)
        : Super(src_var->owner_graph(), {}, "RelayoutPlaceholder", {src_var}),
          m_layout_type{layout_type} {
    add_input({src_var});
    add_equivalence_component<ScalarHash<LayoutType>>(m_layout_type);
    add_output(None)->dtype(src_var->dtype());
}

void TensorReformatPass::RelayoutPlaceholder::scn_do_execute() {
    mgb_throw(InternalError, "RelayoutPlaceholder opr can not be executed");
}

void TensorReformatPass::RelayoutPlaceholder::init_output_comp_node() {
    output(0)->comp_node(input(0)->comp_node());
}

void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    DepVal deps;
    for (auto i : input())
        deps.push_back({i, DepType::SHAPE});
    auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
        TensorShape inp_shape = inp.val[0].shape();
        dst = inp_shape;
        if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] * 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] / 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[1];
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[0];
            dst[4] = inp_shape[4];
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[3];
            dst[1] = inp_shape[0];
            dst[2] = inp_shape[1];
            dst[3] = inp_shape[2];
            dst[4] = inp_shape[4];
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 8);
            dst.ndim = 4;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_DENSE) {
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0 &&
                       inp_shape[1] % 8 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 8;
            dst[5] = 8;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_GROUP) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 8 == 0 &&
                       inp_shape[2] % 8 == 0);
            dst.ndim = 7;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2] / 8;
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 8;
            dst[6] = 8;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_CHAN) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
                       inp_shape[2] == 1 && inp_shape[0] % 8 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[1];
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 8;
        } else {
            mgb_assert(
                    layout_type() ==
                    RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88);
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[1];
            dst[4] = 8;
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
}
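/* Shape-inference examples implied by the branches above (illustrative input
 * shapes, not from the source):
 *   NCHW4_TO_NCHW32 : (8, 16, 7, 7, 4)  -> (8, 2, 7, 7, 32)
 *   NCHW_TO_NCHW88  : (1, 32, 14, 14)   -> (1, 4, 14, 14, 8)
 *   WEIGHT_HYBIRD_NCHW_NCHW88: (64, 2, 3, 3) -> (8, 3, 3, 2, 8)
 * Every rule preserves the total element count.
 */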
SymbolVar TensorReformatPass::RelayoutPlaceholder::make(
        VarNode* src_var, LayoutType layout_type) {
    return src_var->owner_graph()
            ->insert_opr(
                    std::make_unique<RelayoutPlaceholder>(src_var, layout_type))
            ->output(0);
}

void TensorReformatPass::insert_pass(OptState& opt) const {
    opt.set_var_replace_check_flag(m_var_replace_check_flag);
    auto rewriter = opt.graph().make_rewriter();
    VarNodeArray new_inp_cache;
    auto on_opr = [this, &opt, &rewriter,
                   &new_inp_cache](OperatorNodeBase* opr) {
        auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
        if (it != m_opr_replace_func.end()) {
            auto& new_inp = new_inp_cache;
            new_inp.clear();
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size(),
                       "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
                       "dst.size=%zu",
                       opr->cname(), opr->dyn_typeinfo()->name,
                       new_opr->cname(), new_opr->dyn_typeinfo()->name,
                       out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(!out1[i]->contain_flag(
                            VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src)) {
                        // additional process on endpoint var node
                        dst = on_graph_endpoint_var(dst, src);
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}

void TensorReformatPass::translate_pass(OptState& opt) const {
    ThinHashMap<RelayoutPlaceholder::LayoutType,
                thin_function<VarNode*(VarNode*)>>
            reformat;
    using LayoutType = RelayoutPlaceholder::LayoutType;
    reformat[LayoutType::NCHW4_TO_CHWN4] = [](VarNode* inp) -> VarNode* {
        megdnn::param::RelayoutFormat param;
        param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
        auto reformat = opr::RelayoutFormat::make(inp, param);
        return reformat.node();
    };
    reformat[LayoutType::CHWN4_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
        megdnn::param::RelayoutFormat param;
        param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
        auto reformat = opr::RelayoutFormat::make(inp, param);
        return reformat.node();
    };
    reformat[LayoutType::NCHW4_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
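    /* Illustrative trace of the NCHW4_TO_NCHW32 rule above, for an assumed
     * input shape (n, c, h, w, 4) with c % 8 == 0:
     *   reshape                     -> (n, c/8, 8, h, w, 4)
     *   dimshuffle {0, 1, 3, 4, 2, 5} -> (n, c/8, h, w, 8, 4)
     *   reshape                     -> (n, c/8, h, w, 32)
     * i.e. eight channel groups of 4 are fused into one group of 32; the
     * NCHW32_TO_NCHW4 rule below is the exact inverse.
     */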
    reformat[LayoutType::NCHW32_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 8, sub(2), sub(3), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::NCHW88_TO_NCHW] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 8, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp0);
        return y1.node();
    };
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_DENSE] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(1) / 8, sub(2), sub(3), cv(8), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 8, cv(8), sub(2) / 8,
                                        cv(8), sub(3), sub(4)},
                                       0),
             tshp1 = opr::Concat::make({sub(0), sub(1) / 8, sub(2) / 8, sub(3),
                                        sub(4), cv(8), cv(8)},
                                       0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_CHAN] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(1), sub(2), sub(3), sub(4), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(2), sub(3), sub(1), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
        if (opr->same_type<RelayoutPlaceholder>()) {
            auto ph = try_cast_as_op<RelayoutPlaceholder>(opr);
            auto new_inp = rewriter.get_var(opr->input(0));
            mgb_assert(reformat.count(ph->layout_type()),
                       "no replace rule can be found for layout_type(%u)",
                       static_cast<uint32_t>(ph->layout_type()));
            auto new_var = reformat[ph->layout_type()](new_inp);
            rewriter.replace_var(opr->output(0), new_var,
                                 mgb_cstr_log("replace relayout placeholder"));
            return;
        }
        rewriter.auto_replace_outputs(opr);
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}

void TensorReformatPass::apply(OptState& opt) const {
    insert_pass(opt);
    translate_pass(opt);
}
/* ================ EnableTensorCorePass =============== */
VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var,
                                                     VarNode* orig_var) const {
    if (!orig_var->shape().eq_shape(new_var->shape())) {
        return RelayoutPlaceholder::make(
                       new_var,
                       RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4)
                .node();
    }
    return new_var;
}

std::unique_ptr<EnableTensorCorePass>
EnableTensorCorePass::make_tensorcore_converter() {
    // replace rule for conv bias opr
    auto replace_conv_bias_opr = [](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        using Param = megdnn::param::ConvBias;
        using Format = Param::Format;
        using Sparse = Param::Sparse;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
        if (conv_bias.param().format != Format::NCHW4 ||
            conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
            size_t nr_inps = opr->input().size();
            bool shape_has_changed = false;
            for (size_t i = 0; i < nr_inps; ++i) {
                if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                    shape_has_changed = true;
                }
            }
            MGB_MARK_USED_VAR(shape_has_changed);
            mgb_assert(
                    !shape_has_changed,
                    "EnableTensorCorePass assumes that the shape of inputs of "
                    "ConvBias operators whose output dtype is not QuantizedS8 "
                    "can not be changed in this opt pass");
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()),
                   "EnableTensorCorePass assumes that the filter tensor of a "
                   "conv_bias operator can not be changed by other operators");
        VarNode* orig_filter = opr->input(1);
        auto is_nchw4 = [](TensorShape shape) -> bool {
            return shape.ndim == 5 && shape[4] == 4;
        };
        auto is_nchw32 = [](TensorShape shape) -> bool {
            return shape.ndim == 5 && shape[4] == 32;
        };
        bool can_replace_nchw32 = false;
        VarNode *src = nullptr, *weight = nullptr, *bias = nullptr,
                *z_inp = nullptr;
        // process src tensor
        if (is_nchw4(new_inp[0]->shape())) {  // new input is NCHW4 layout
            size_t group = 1, icpg, ocpg;
            if (conv_bias.param().sparse == Sparse::DENSE) {
                icpg = orig_filter->shape()[1] * 4;
                ocpg = orig_filter->shape()[0];
            } else {
                mgb_assert(conv_bias.param().sparse == Sparse::GROUP);
                group = orig_filter->shape()[0];
                icpg = orig_filter->shape()[2];
                ocpg = orig_filter->shape()[1];
                if (icpg == 1 && ocpg == 1) {  // channel wise conv
                    group *= 4;
                } else {
                    icpg *= 4;
                }
            }
            // the nchw32 layout requires input height and width of at least 3
            size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
            if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 &&
                iw >= 3) {
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[0],
                        RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
                src = symvar.node();
                can_replace_nchw32 = true;
            } else {
                src = new_inp[0];
            }
        } else {  // new input is NCHW32 layout
            mgb_assert(is_nchw32(new_inp[0]->shape()));
            size_t group = 1, ocpg;
            if (conv_bias.param().sparse == Sparse::DENSE) {
                ocpg = orig_filter->shape()[0];
            } else {
                mgb_assert(conv_bias.param().sparse == Sparse::GROUP);
                size_t icpg = orig_filter->shape()[2];
                ocpg = orig_filter->shape()[1];
                if (icpg == 1 && ocpg == 1) {
                    group *= 4;
                } else {
                    icpg *= 4;
                }
            }
            size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
            if (group == 1 && ocpg % 32 == 0 && ih >= 3 && iw >= 3) {
                can_replace_nchw32 = true;
                src = new_inp[0];
            } else {
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[0],
                        RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                src = symvar.node();
            }
        }
        // process filter tensor
        if (can_replace_nchw32) {
            auto symvar = RelayoutPlaceholder::make(
                    new_inp[1],
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
            weight = symvar.node();
        } else {
            weight = new_inp[1];
        }
        if (new_inp.size() == 2) {
            if (can_replace_nchw32) {
                auto param = conv_bias.param();
                param.format = Format::NCHW32;
                auto new_opr = opr::ConvBiasForward::make(
                        src, weight, param, conv_bias.execution_policy(),
                        conv_bias.config());
                return new_opr.node()->owner_opr();
            } else {
                VarNodeArray inps{src, weight};
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                return new_opr;
            }
        }
        auto process_inp = [&](VarNode* inp) -> VarNode* {
            if (can_replace_nchw32) {
                if (is_nchw4(inp->shape())) {
                    auto symvar = RelayoutPlaceholder::make(
                            inp,
                            RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
                    return symvar.node();
                } else {
                    mgb_assert(is_nchw32(inp->shape()));
                    return inp;
                }
            } else {
                if (is_nchw4(inp->shape())) {
                    return inp;
                } else {
                    mgb_assert(is_nchw32(inp->shape()));
                    auto symvar = RelayoutPlaceholder::make(
                            inp,
                            RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                    return symvar.node();
                }
            }
        };
        // process bias tensor
        bias = process_inp(new_inp[2]);
        if (new_inp.size() == 3) {
            if (can_replace_nchw32) {
                auto param = conv_bias.param();
                param.format = Format::NCHW32;
                auto new_opr = opr::ConvBiasForward::make(
                        src, weight, bias, param, conv_bias.execution_policy(),
                        conv_bias.config());
                return new_opr.node()->owner_opr();
            } else {
                VarNodeArray inps{src, weight, bias};
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                return new_opr;
            }
        }
        // process z_inp tensor
        z_inp = process_inp(new_inp[3]);
        if (can_replace_nchw32) {
            auto param = conv_bias.param();
            param.format = Format::NCHW32;
            auto new_opr = opr::ConvBiasForward::make(
                    src, weight, bias, z_inp, param,
                    conv_bias.execution_policy(), conv_bias.config());
            return new_opr.node()->owner_opr();
        }
        VarNodeArray inps{src, weight, bias, z_inp};
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    // replace rule for elemwise-like oprs
    // (oprs that support both the NCHW4 and the NCHW32 layout)
    auto replace_elemwise_like_opr = [](OperatorNodeBase* opr,
                                        const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        size_t nr_inps = new_inp.size();
        size_t nr_shape_changed = 0;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                nr_shape_changed++;
            }
        }
        if (nr_shape_changed) {
            auto inps = new_inp;
            if (nr_shape_changed >=
                nr_inps / 2) {  // NCHW32 > NCHW4 -> use NCHW32
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW4_TO_NCHW32);
                        inps[i] = symvar.node();
                    }
                }
            } else {  // NCHW32 < NCHW4 -> use NCHW4
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW32_TO_NCHW4);
                        inps[i] = symvar.node();
                    }
                }
            }
            return serialization::copy_opr_shallow(*opr, inps, opr->config());
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
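    // Illustrative example of the majority vote above (assumed input mix):
    // for an elemwise opr whose inputs now arrive as {NCHW32, NCHW32, NCHW4},
    // two of three inputs changed shape, so the remaining NCHW4 input is
    // relayouted to NCHW32; with {NCHW32, NCHW4, NCHW4, NCHW4} only one of
    // four changed, which falls below the threshold, so the single NCHW32
    // input is converted back to NCHW4. Either way the number of inserted
    // relayout placeholders is minimized.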
    // for oprs that only support the NCHW4 layout
    auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr,
                                    const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray inps = new_inp;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                mgb_assert(opr->input(i)->shape().ndim == 5 &&
                           opr->input(i)->shape()[4] == 4);
                mgb_assert(new_inp[i]->shape().ndim == 5 &&
                           new_inp[i]->shape()[4] == 32);
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[i],
                        RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                inps[i] = symvar.node();
            }
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    auto replace_non_nchw4_opr = [](OperatorNodeBase* opr,
                                    const VarNodeArray new_inp) {
        size_t nr_inps = opr->input().size();
        bool shape_has_changed = false;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                shape_has_changed = true;
            }
        }
        mgb_assert(!shape_has_changed,
                   "EnableTensorCorePass assumes that the input shapes of "
                   "non-nchw4 operators cannot be changed in this opt pass");
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    auto replace_warp_affine_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpAffineForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_warp_perspective_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpPerspectiveForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp =
                        opr->cast_final_safe<opr::WarpPerspectiveForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
                                      OperatorNodeBase* opr,
                                      const VarNodeArray new_inp) {
        using Param = opr::ResizeForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& resize = opr->cast_final_safe<opr::ResizeForward>();
        if (resize.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        return replace_inps_to_nchw4(opr, new_inp);
    };
    auto replace_pooling_opr = [replace_non_nchw4_opr](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray new_inp) {
        using Param = opr::PoolingForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
        if (pooling.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        size_t nr_inps = opr->input().size();
        MGB_MARK_USED_VAR(nr_inps);
        mgb_assert(nr_inps == 1);
        if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) {
            mgb_assert(opr->input(0)->shape().ndim == 5 &&
                       opr->input(0)->shape()[4] == 4);
            mgb_assert(new_inp[0]->shape().ndim == 5 &&
                       new_inp[0]->shape()[4] == 32);
            auto new_param = pooling.param();
            new_param.format = Format::NCHW32;
            auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
                                                         opr->config());
            return new_pooling.node()->owner_opr();
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    auto ret = std::make_unique<EnableTensorCorePass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr;
    // elemwise like
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] =
            replace_elemwise_like_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr;
    // format aware
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    // to nchw4
    replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
    return ret;
}
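/*
 * A minimal usage sketch for the converter built above (descriptive comment
 * only; the factory name make_tensorcore_converter is an assumption, chosen
 * by analogy with make_chwn4_converter below, and the pass line-up mirrors
 * reformat_to_chwn4_transform_dest_vars_inplace at the end of this file):
 *
 *     gopt::GraphOptimizer optimizer;
 *     optimizer.add_pass<FuseConvBiasNonlinPass>();
 *     optimizer.add_pass(EnableTensorCorePass::make_tensorcore_converter());
 *     optimizer.add_pass<ShuffleShuffleRemovePass>();
 *     optimizer.apply_inplace(dest_vars);
 */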
/* ================ EnableCHWN4Pass =============== */
VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var,
                                                VarNode* /* orig_var */) const {
    if (m_varshape_changed.count(new_var)) {
        return RelayoutPlaceholder::make(
                       new_var,
                       RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4)
                .node();
    }
    return new_var;
}

std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
    auto ret = std::make_unique<EnableCHWN4Pass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    auto&& varshape_changed = ret->m_varshape_changed;
    // replace rule for conv bias opr
    auto replace_conv_bias_opr = [&varshape_changed](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        using Param = megdnn::param::ConvBias;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
        if (conv_bias.param().format != Format::NCHW4 ||
            conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
            size_t nr_inps = new_inp.size();
            bool shape_has_changed = false;
            for (size_t i = 0; i < nr_inps; ++i) {
                if (varshape_changed.count(new_inp[i])) {
                    shape_has_changed = true;
                    break;
                }
            }
            mgb_assert(
                    !shape_has_changed,
                    "EnableCHWN4Pass assumes that the input shapes of "
                    "ConvBias operators whose output dtype is not QuantizedS8 "
                    "cannot be changed in this opt pass");
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
        mgb_assert(varshape_changed.count(new_inp[1]) == 0,
                   "EnableCHWN4Pass assumes that the filter tensor of a "
                   "conv_bias operator cannot be changed by other operators");
        VarNode *src = nullptr, *weight = nullptr, *bias = nullptr,
                *z_inp = nullptr;
        // process src tensor
        if (varshape_changed.count(new_inp[0]) ==
            0) {  // new input is in NCHW4 layout
            // group conv is currently not supported
            auto symvar = RelayoutPlaceholder::make(
                    new_inp[0],
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
            src = symvar.node();
        } else {  // new input is already in CHWN4 layout
            src = new_inp[0];
        }
        // process weight tensor
        {
            auto symvar = RelayoutPlaceholder::make(
                    new_inp[1],
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
            weight = symvar.node();
        }
        if (new_inp.size() == 2) {
            auto param = conv_bias.param();
            param.format = Format::CHWN4;
            auto new_opr = opr::ConvBiasForward::make(
                    src, weight, param, conv_bias.execution_policy(),
                    conv_bias.config());
            varshape_changed.insert(new_opr.node());
            return new_opr.node()->owner_opr();
        }
        auto process_inp = [&](VarNode* inp) -> VarNode* {
            if (varshape_changed.count(inp) == 0) {
                auto symvar = RelayoutPlaceholder::make(
                        inp, RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
                return symvar.node();
            } else {
                return inp;
            }
        };
        // process bias tensor
        bias = process_inp(new_inp[2]);
        if (new_inp.size() == 3) {
            auto param = conv_bias.param();
            param.format = Format::CHWN4;
            auto new_opr = opr::ConvBiasForward::make(
                    src, weight, bias, param, conv_bias.execution_policy(),
                    conv_bias.config());
            varshape_changed.insert(new_opr.node());
            return new_opr.node()->owner_opr();
        }
        // process z_inp tensor
        z_inp = process_inp(new_inp[3]);
        auto param = conv_bias.param();
        param.format = Format::CHWN4;
        auto new_opr = opr::ConvBiasForward::make(
                src, weight, bias, z_inp, param, conv_bias.execution_policy(),
                conv_bias.config());
        varshape_changed.insert(new_opr.node());
        return new_opr.node()->owner_opr();
    };
    // replace rule for elemwise-like oprs
    // (oprs that support both the NCHW4 and CHWN4 layouts)
    auto replace_elemwise_like_opr = [&varshape_changed](
                                             OperatorNodeBase* opr,
                                             const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        size_t nr_inps = new_inp.size();
        size_t nr_shape_changed = 0;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (varshape_changed.count(new_inp[i])) {
                nr_shape_changed++;
            }
        }
        if (nr_shape_changed) {
            auto inps = new_inp;
            if (nr_shape_changed >=
                nr_inps / 2) {  // more CHWN4 inputs than NCHW4 -> use CHWN4
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (varshape_changed.count(new_inp[i]) == 0) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW4_TO_CHWN4);
                        inps[i] = symvar.node();
                    }
                }
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                varshape_changed.insert(new_opr->output(0));
                return new_opr;
            } else {  // fewer CHWN4 inputs than NCHW4 -> use NCHW4
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (varshape_changed.count(new_inp[i])) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    CHWN4_TO_NCHW4);
                        inps[i] = symvar.node();
                    }
                }
                return serialization::copy_opr_shallow(*opr, inps,
                                                       opr->config());
            }
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // for oprs that only support the NCHW4 layout
    auto replace_inps_to_nchw4 = [&varshape_changed](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray inps = new_inp;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (varshape_changed.count(new_inp[i])) {
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[i],
                        RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4);
                inps[i] = symvar.node();
            }
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    auto replace_non_nchw4_opr = [&varshape_changed](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray new_inp) {
        size_t nr_inps = opr->input().size();
        bool shape_has_changed = false;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (varshape_changed.count(new_inp[i])) {
                shape_has_changed = true;
            }
        }
        mgb_assert(!shape_has_changed,
                   "EnableCHWN4Pass assumes that the input shapes of "
                   "non-nchw4 operators cannot be changed in this opt pass");
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // capture by copy to avoid use after return
    auto replace_warp_affine_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpAffineForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_warp_perspective_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpPerspectiveForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp =
                        opr->cast_final_safe<opr::WarpPerspectiveForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
                                      OperatorNodeBase* opr,
                                      const VarNodeArray new_inp) {
        using Param = opr::ResizeForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& resize = opr->cast_final_safe<opr::ResizeForward>();
        if (resize.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        return replace_inps_to_nchw4(opr, new_inp);
    };
    auto replace_pooling_opr = [&varshape_changed, replace_non_nchw4_opr](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray new_inp) {
        using Param = opr::PoolingForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
        if (pooling.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        size_t nr_inps = opr->input().size();
        MGB_MARK_USED_VAR(nr_inps);
        mgb_assert(nr_inps == 1);
        if (varshape_changed.count(new_inp[0])) {
            auto new_param = pooling.param();
            new_param.format = Format::CHWN4;
            auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
                                                         opr->config());
            varshape_changed.insert(new_pooling.node());
            return new_pooling.node()->owner_opr();
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr;
    // elemwise like
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] =
            replace_elemwise_like_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr;
    // format aware
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    // to nchw4
    replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4;
    return ret;
}
/* ================ EnableNchwxxPass =============== */
VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
                                                 VarNode* orig_var) const {
    if (!orig_var->shape().eq_shape(new_var->shape())) {
        return RelayoutPlaceholder::make(
                       new_var,
                       RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
                .node();
    }
    return new_var;
}

std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
        size_t pack_c_size) {
    auto ret = std::make_unique<EnableNchwxxPass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    //! The first field is whether the conv can be transformed to nchwxx;
    //! the second is the filter transform mode
    using RelayoutMode = RelayoutPlaceholder::LayoutType;
    using TestFilterResult = std::pair<TransType, RelayoutMode>;
    RelayoutMode weight_to_nchwxx_mode_dense =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW88_DENSE;
    RelayoutMode weight_to_nchwxx_mode_group =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW88_GROUP;
    RelayoutMode weight_to_nchwxx_mode_chan =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW88_CHAN;
    RelayoutMode hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW88;
    RelayoutMode src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW88;
    RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW88_TO_NCHW;
    megdnn::param::ConvBias::Format conv_bias_format =
            megdnn::param::ConvBias::Format::NCHW88;
    megdnn::param::Convolution::Format conv_format =
            megdnn::param::ConvolutionV0::Format::NCHW88;
    megdnn::param::Pooling::Format pooling_format =
            megdnn::param::Pooling::Format::NCHW88;
    std::string converter_pass_name = "conv_format_nchw88";
    mgb_assert(pack_c_size == static_cast<size_t>(8),
               "The ConvertFormatPass to nchwxx only supports NCHW88 now!");
    auto test_trans_nchwxx =
            [pack_c_size, weight_to_nchwxx_mode_dense,
             weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan,
             hybrid_nchw_nchwxx](
                    const megdnn::param::Convolution::Sparse conv_mode,
                    const VarNode* filter) -> TestFilterResult {
        TestFilterResult ret{TransType::TRANS_NONE, {}};
        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
            size_t IC = filter->shape()[1];
            size_t OC = filter->shape()[0];
            if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = weight_to_nchwxx_mode_dense;
            } else if (IC < pack_c_size && OC % pack_c_size == 0) {
                ret.first = TransType::TRANS_HYBIRD_NCHWXX;
                ret.second = hybrid_nchw_nchwxx;
            }
        } else {
            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
            size_t group = filter->shape()[0];
            size_t ocpg = filter->shape()[1];
            size_t icpg = filter->shape()[2];
            if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = weight_to_nchwxx_mode_chan;
            } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = weight_to_nchwxx_mode_group;
            }
        }
        return ret;
    };
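    // Worked examples of the dispatch above (descriptive comment only, with
    // pack_c_size == 8): a dense conv with filter (OC=32, IC=16, FH, FW)
    // satisfies IC % 8 == 0 && OC % 8 == 0 -> TRANS_PURE_NCHWXX; a
    // first-layer dense conv with filter (OC=32, IC=3, FH, FW) has IC < 8 but
    // OC % 8 == 0 -> TRANS_HYBIRD_NCHWXX (NCHW src, NCHW88 dst); a
    // channel-wise conv with filter (group=64, 1, 1, FH, FW) ->
    // TRANS_PURE_NCHWXX with the per-channel weight mode. Anything else
    // keeps TRANS_NONE and stays in NCHW.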
    auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode,
                             src_to_nchw_mode](OperatorNodeBase* opr,
                                               const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        mgb_assert(conv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to "
                   "NCHWXX");
        auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]);
        //! cannot be transformed to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The original filter is not in NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, insert a RelayoutPlaceholder back to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src =
                        RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode);
                temp_inp[0] = new_src.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            //! transform the filter to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The original filter is not in NCHW mode");
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            //! transform the src to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                         src_to_nchwxx_mode);
                conv_src = new_src.node();
            }
            auto new_param = conv_opr.param();
            new_param.format = conv_format;
            mgb_assert(conv_src->shape().ndim == 5 &&
                               conv_filter->shape().ndim >= 6,
                       "The conv src is not transformed to nchwxx");
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst is not transformed to nchwxx");
            return new_opr;
        } else {
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            mgb_assert(conv_src->shape().ndim == 4 &&
                               conv_filter->shape().ndim == 5,
                       "The hybrid conv expects an NCHW src and a 5-dim "
                       "filter");
            auto new_param = conv_opr.param();
            new_param.format = conv_format;
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst is not transformed to nchwxx");
            return new_opr;
        }
    };
    auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format,
                                  src_to_nchwxx_mode, src_to_nchw_mode](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
        mgb_assert(conv_bias_opr.param().format ==
                           megdnn::param::ConvBias::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to "
                   "NCHWXX");
        auto is_trans =
                test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]);
        //! cannot be transformed to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The original filter is not in NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, insert a RelayoutPlaceholder back to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src =
                        RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode);
                temp_inp[0] = new_src.node();
            }
            //! if the bias is nchwxx, relayout it back to nchw as well
            if (temp_inp[2]->shape().ndim == 5) {
                auto new_bias =
                        RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode);
                temp_inp[2] = new_bias.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            //! transform the filter to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The original filter is not in NCHW mode");
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! transform the src to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                         src_to_nchwxx_mode);
                conv_bias_src = new_src.node();
            }
            //! transform the bias to nchwxx mode; the bias may be a scalar
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                          src_to_nchwxx_mode);
                conv_bias_bias = new_bias.node();
            }
            auto new_param = conv_bias_opr.param();
            new_param.format = conv_bias_format;
            mgb_assert(conv_bias_src->shape().ndim == 5 &&
                               conv_bias_filter->shape().ndim >= 6,
                       "The conv_bias src is not transformed to nchwxx");
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst is not transformed to nchwxx");
            return new_opr;
        } else {
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! transform the bias to nchwxx mode; the bias may be a scalar
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                          src_to_nchwxx_mode);
                conv_bias_bias = new_bias.node();
            }
            mgb_assert(conv_bias_src->shape().ndim == 4 &&
                       conv_bias_filter->shape().ndim == 5);
            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                       conv_bias_bias->shape().is_scalar());
            auto new_param = conv_bias_opr.param();
            new_param.format = conv_bias_format;
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst is not transformed to nchwxx");
            return new_opr;
        }
    };
    auto replace_pooling_opr = [=](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
        mgb_assert(pooling_opr.param().format ==
                           megdnn::param::Pooling::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to "
                   "NCHWxx");
        VarNode* inp = new_inp[0];
        //! if the input is already nchwxx
        if (inp->shape().ndim == 5) {
            auto new_param = pooling_opr.param();
            new_param.format = pooling_format;
            auto new_pooling_opr =
                    opr::PoolingForward::make(inp, new_param, opr->config());
            mgb_assert(new_pooling_opr.shape().ndim == 5,
                       "The pooling dst is not transformed to nchwxx");
            return new_pooling_opr.node()->owner_opr();
        } else {
            auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
                                                           opr->config());
            return new_opr;
        }
    };
    auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool has_inp_changed = false;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (new_inp[i]->shape().ndim == 5) {
                has_inp_changed = true;
                break;
            }
        }
        if (has_inp_changed) {
            auto temp_inp = new_inp;
            for (size_t i = 0; i < opr->input().size(); i++) {
                if (new_inp[i]->shape().ndim == 4) {
                    auto new_var = RelayoutPlaceholder::make(
                            new_inp[i], src_to_nchwxx_mode);
                    temp_inp[i] = new_var.node();
                } else {
                    mgb_assert((new_inp[i]->shape().ndim == 5) ||
                               new_inp[i]->shape().is_scalar());
                }
            }
            return serialization::copy_opr_shallow(*opr, temp_inp,
                                                   opr->config());
        } else {
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };
    auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray temp_inp = new_inp;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                mgb_assert(opr->input(i)->shape().ndim == 4);
                mgb_assert(new_inp[i]->shape().ndim == 5);
                auto new_var =
                        RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
                temp_inp[i] = new_var.node();
            }
        }
        return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
    };
    ret->set_name(converter_pass_name);
    auto&& replace_func = ret->m_opr_replace_func;
    //! supported in nchwxx
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
    //! not supported yet
    replace_func[opr::ConvolutionBackwardData::typeinfo()] =
            relayout_inp_to_nchw;
    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            relayout_inp_to_nchw;
    replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
    return ret;
}
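/*
 * A minimal usage sketch for the nchwxx converter (descriptive comment only;
 * the optimizer driver is assumed, mirroring
 * reformat_to_chwn4_transform_dest_vars_inplace at the end of this file, and
 * pack_c_size must be 8 per the assertion above):
 *
 *     gopt::GraphOptimizer optimizer;
 *     optimizer.add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
 *     optimizer.add_pass<ParamFusePass>();
 *     optimizer.apply_inplace(dest_vars);
 */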
/* ==================== ShuffleShuffleRemovePass ================= */
class ShuffleShuffleRemovePass::Impl {
    using TensorFormat = opr::ConvBias::Param::Format;

    OptState& m_opt_state;
    ThinHashMap<std::pair<TensorFormat, TensorFormat>,
                thin_function<VarNode*(VarNode*)>>
            m_reformat;
    class AbstractShuffleOpr;

    void detect_shuffle_operations();
    void do_replace();

public:
    Impl(OptState& opt_state) : m_opt_state{opt_state} {
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
            return y1.node();
        };
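        // Shape trace for the NCHW -> NCHW4 reformat above (descriptive
        // comment only, with illustrative sizes): e.g. (8, 64, 7, 7)
        // --Reshape--> (8, 16, 4, 7, 7) --Dimshuffle{0,1,3,4,2}-->
        // (8, 16, 7, 7, 4), i.e. channels are packed in groups of 4 into the
        // innermost dimension.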
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW32)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
            return y1.node();
        };
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW32)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp0 = opr::Concat::make(
                         {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)},
                         0),
                 tshp1 = opr::Concat::make(
                         {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
            auto y0 = opr::Reshape::make(x, tshp0);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
            auto y2 = opr::Reshape::make(y1, tshp1);
            return y2.node();
        };
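        // Shape trace for the NCHW4 -> NCHW32 reformat above (descriptive
        // comment only, with illustrative sizes): e.g. (8, 16, 7, 7, 4)
        // --Reshape--> (8, 2, 8, 7, 7, 4) --Dimshuffle{0,1,3,4,2,5}-->
        // (8, 2, 7, 7, 8, 4) --Reshape--> (8, 2, 7, 7, 32), merging the 8x4
        // packed channels into one 32-channel innermost dimension.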
        m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp0 = opr::Concat::make(
                         {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8},
                         0),
                 tshp1 = opr::Concat::make(
                         {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
            auto y0 = opr::Reshape::make(x, tshp0);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
            auto y2 = opr::Reshape::make(y1, tshp1);
            return y2.node();
        };
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::CHWN4)] =
                [](VarNode* inp) -> VarNode* {
            megdnn::param::RelayoutFormat param;
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
            auto reformat = opr::RelayoutFormat::make(inp, param);
            return reformat.node();
        };
        m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            megdnn::param::RelayoutFormat param;
            param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
            auto reformat = opr::RelayoutFormat::make(inp, param);
            return reformat.node();
        };
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::CHWN4)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2});
            return y1.node();
        };
        m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(3), sub(0) * 4, sub(1), sub(2)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {3, 0, 4, 1, 2});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        detect_shuffle_operations();
        do_replace();
    }
};

/*!
 * \brief abstract operator representation of shuffle operation
 */
MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
                     cg::SingleCNOperatorNodeBase)  // {
public:
    AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
                       TensorFormat out_format);

    static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
                          TensorFormat out_format);

    TensorFormat inp_format() const { return m_inp_format; }
    TensorFormat out_format() const { return m_out_format; }

private:
    void init_output_static_infer_desc() override;
    void scn_do_execute() override;
    const TensorFormat m_inp_format;
    const TensorFormat m_out_format;
};
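// AbstractShuffleOpr is a pure placeholder: it records only the (input
// format, output format) pair of a detected reformat and can never be
// executed (scn_do_execute below throws). detect_shuffle_operations()
// substitutes it for concrete Reshape/Dimshuffle/RelayoutFormat chains, and
// do_replace() later folds runs of these placeholders and TypeCvts back into
// at most one concrete reformat via m_reformat.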
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);

void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::scn_do_execute() {
    mgb_throw(InternalError, "AbstractShuffleOpr cannot be executed");
}

void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::
        init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    DepVal deps;
    for (auto i : input())
        deps.push_back({i, DepType::SHAPE});
    auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
        TensorShape inp_shape = inp.val[0].shape();
        if (m_inp_format == TensorFormat::NCHW4 &&
            m_out_format == TensorFormat::NCHW32) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst = inp_shape;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] * 8;
        } else if (m_inp_format == TensorFormat::NCHW32 &&
                   m_out_format == TensorFormat::NCHW4) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
            dst = inp_shape;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] / 8;
        } else if (m_inp_format == TensorFormat::NCHW &&
                   m_out_format == TensorFormat::NCHW4) {
            mgb_assert(inp_shape.ndim == 4);
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 4;
        } else if (m_inp_format == TensorFormat::NCHW4 &&
                   m_out_format == TensorFormat::NCHW) {
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst.ndim = 4;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
        } else if (m_inp_format == TensorFormat::NCHW4 &&
                   m_out_format == TensorFormat::CHWN4) {
            dst.ndim = 5;
            dst[0] = inp_shape[1];
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[0];
            dst[4] = inp_shape[4];
        } else if (m_inp_format == TensorFormat::CHWN4 &&
                   m_out_format == TensorFormat::NCHW4) {
            dst.ndim = 5;
            dst[0] = inp_shape[3];
            dst[1] = inp_shape[0];
            dst[2] = inp_shape[1];
            dst[3] = inp_shape[2];
            dst[4] = inp_shape[4];
        } else {
            mgb_throw(InternalError,
                      "Unsupported input format and output format.");
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
}

ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::AbstractShuffleOpr(
        VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format)
        : Super(inpvar->owner_graph(), {}, "AbstractShuffleOpr", {inpvar}),
          m_inp_format{inp_format},
          m_out_format{out_format} {
    add_input({inpvar});
    add_equivalence_component<ScalarHash<TensorFormat>>(m_inp_format);
    add_equivalence_component<ScalarHash<TensorFormat>>(m_out_format);
    add_output(None)->dtype(inpvar->dtype());
}

SymbolVar ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::make(
        VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) {
    return inpvar->owner_graph()
            ->insert_opr(std::make_unique<AbstractShuffleOpr>(
                    inpvar, inp_format, out_format))
            ->output(0);
}

void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
    auto rewriter = m_opt_state.graph().make_rewriter();
    auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()};
    auto try_reshape_shuffle = [&rewriter,
                                &uniq_reader_check](OperatorNodeBase* opr) {
        // check shuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(opr);
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 5)
            return false;
        bool is_nchw2nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                             param.pattern[2] == 3 && param.pattern[3] == 4 &&
                             param.pattern[4] == 2 &&
                             opr->output(0)->shape()[4] == 4;
        if (!is_nchw2nchw4)
            return false;
        if (!uniq_reader_check(shuffle->input(0)))
            return false;
        // check reshape
        auto reshape = try_cast_as_op<opr::Reshape>(opr->input(0)->owner_opr());
        if (reshape == nullptr)
            return false;
        auto inp_var = rewriter.get_var(reshape->input(0));
        auto abstract_shuffle = AbstractShuffleOpr::make(
                inp_var, TensorFormat::NCHW, TensorFormat::NCHW4);
        rewriter.replace_var(
                opr->output(0), abstract_shuffle.node(),
                mgb_cstr_log("replace reformat(nchw -> nchw4) to "
                             "AbstractShuffleOpr(nchw -> nchw4)."));
        return true;
    };
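    // Pattern matched by try_reshape_shuffle above (descriptive comment
    // only): a Reshape producing (N, C/4, 4, H, W) followed by a
    // Dimshuffle{0,1,3,4,2} whose output has 4 in the last dimension, i.e.
    // exactly the NCHW -> NCHW4 reformat emitted by m_reformat; the pair is
    // collapsed into a single AbstractShuffleOpr(NCHW -> NCHW4).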
    auto try_reshape_shuffle_reshape = [&rewriter, &uniq_reader_check](
                                               OperatorNodeBase* opr) {
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        if (!uniq_reader_check(reshape1->input(0)))
            return false;
        // check shuffle
        auto shuffle =
                try_cast_as_op<opr::Dimshuffle>(opr->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
        bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
                               shuffle->input(0)->shape()[5] == 4 &&
                               shuffle->input(0)->shape()[2] == 8;
        bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 4 && param.pattern[3] == 2 &&
                               param.pattern[4] == 3 && param.pattern[5] == 5 &&
                               shuffle->input(0)->shape()[4] == 8 &&
                               shuffle->input(0)->shape()[5] == 4;
        if (!is_nchw42nchw32 && !is_nchw322nchw4)
            return false;
        if (!uniq_reader_check(shuffle->input(0)))
            return false;
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        auto inp_var = rewriter.get_var(reshape2->input(0));
        TensorFormat inp_format = is_nchw42nchw32 ? TensorFormat::NCHW4
                                                  : TensorFormat::NCHW32,
                     out_format = is_nchw42nchw32 ? TensorFormat::NCHW32
                                                  : TensorFormat::NCHW4;
        auto abstract_shuffle =
                AbstractShuffleOpr::make(inp_var, inp_format, out_format);
        std::string reformat_type =
                is_nchw42nchw32 ? "nchw4 -> nchw32" : "nchw32 -> nchw4";
        rewriter.replace_var(opr->output(0), abstract_shuffle.node(),
                             mgb_cstr_log(ssprintf("replace reformat(%s) to "
                                                   "AbstractShuffleOpr(%s).",
                                                   reformat_type.c_str(),
                                                   reformat_type.c_str())
                                                  .c_str()));
        return true;
    };
    auto try_shuffle_reshape = [&rewriter,
                                &uniq_reader_check](OperatorNodeBase* opr) {
        // check reshape
        auto reshape = try_cast_as_op<opr::Reshape>(opr);
        if (reshape == nullptr)
            return false;
        if (!uniq_reader_check(reshape->input(0)))
            return false;
        // check shuffle
        auto shuffle =
                try_cast_as_op<opr::Dimshuffle>(opr->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 5)
            return false;
        bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                             param.pattern[2] == 4 && param.pattern[3] == 2 &&
                             param.pattern[4] == 3 &&
                             shuffle->input(0)->shape()[4] == 4;
        if (!is_nchw42nchw)
            return false;
        auto inp_var = rewriter.get_var(shuffle->input(0));
        auto abstract_shuffle = AbstractShuffleOpr::make(
                inp_var, TensorFormat::NCHW4, TensorFormat::NCHW);
        rewriter.replace_var(
                opr->output(0), abstract_shuffle.node(),
                mgb_cstr_log("replace reformat(nchw4 -> nchw) to "
                             "AbstractShuffleOpr(nchw4 -> nchw)."));
        return true;
    };
    auto try_relayout_format = [&rewriter](OperatorNodeBase* opr) {
        // check relayout format
        auto reformat = try_cast_as_op<opr::RelayoutFormat>(opr);
        if (reformat == nullptr)
            return false;
        auto&& param = reformat->param();
        if (param.mode != opr::RelayoutFormat::Param::Mode::CHWN4_NCHW4 &&
            param.mode != opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4)
            return false;
        auto inp_var = rewriter.get_var(reformat->input(0));
        cg::SymbolVar abstract_shuffle;
        if (param.mode == opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) {
            abstract_shuffle = AbstractShuffleOpr::make(
                    inp_var, TensorFormat::NCHW4, TensorFormat::CHWN4);
        } else {
            abstract_shuffle = AbstractShuffleOpr::make(
                    inp_var, TensorFormat::CHWN4, TensorFormat::NCHW4);
        }
        rewriter.replace_var(
                opr->output(0), abstract_shuffle.node(),
                mgb_cstr_log("replace reformat(nchw4 <-> chwn4) to "
                             "AbstractShuffleOpr(nchw4 <-> chwn4)."));
        return true;
    };
    auto on_opr = [&try_reshape_shuffle, &try_shuffle_reshape,
                   &try_reshape_shuffle_reshape, &try_relayout_format,
                   &rewriter, &uniq_reader_check](OperatorNodeBase* opr) {
        if (!try_reshape_shuffle_reshape(opr) && !try_reshape_shuffle(opr) &&
            !try_shuffle_reshape(opr) && !try_relayout_format(opr)) {
            auto new_opr = rewriter.auto_replace_outputs(opr);
            uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
        }
    };
    m_opt_state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
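// The pass runs in two phases: detect_shuffle_operations() above rewrites
// every recognized Reshape/Dimshuffle/RelayoutFormat combination into an
// AbstractShuffleOpr that carries only (inp_format, out_format), and
// do_replace() below walks each chain of AbstractShuffleOpr/TypeCvt nodes
// from its endpoint and re-emits at most one concrete reformat (via
// m_reformat) plus the needed TypeCvts, so back-to-back shuffles cancel out.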
void ShuffleShuffleRemovePass::Impl::do_replace() {
    auto rewriter = m_opt_state.graph().make_rewriter();
    auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()};
    ThinHashMap<VarNode*, VarNode*> var2endpoint;
    ThinHashSet<VarNode*> trt_opr_inps;
    SmallVector<OperatorNodeBase*> topo_order;

    auto cb = [&topo_order, &trt_opr_inps](OperatorNodeBase* opr) {
        topo_order.push_back(opr);
        MGB_MARK_USED_VAR(trt_opr_inps);
#if MGB_ENABLE_TENSOR_RT
        if (opr->same_type<opr::TensorRTOpr>()) {
            for (auto&& inp : opr->input())
                trt_opr_inps.insert(inp);
        }
#endif
    };
    m_opt_state.graph().iter(cb);

    for (auto&& opr : reverse_adaptor(topo_order)) {
        if (opr->same_type<opr::TypeCvt>() ||
            opr->same_type<AbstractShuffleOpr>()) {
            auto find = var2endpoint.find(opr->output(0));
            if (find != var2endpoint.end()) {
                if (uniq_reader_check(opr->output(0))) {
                    var2endpoint[opr->input(0)] = find->second;
                } else {
                    var2endpoint[opr->input(0)] = opr->output(0);
                }
            } else {
                var2endpoint[opr->input(0)] = opr->output(0);
            }
        }
    }

    auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps,
                   &var2endpoint](OperatorNodeBase* opr) {
        MGB_MARK_USED_VAR(trt_opr_inps);
        bool cond_opr = opr->same_type<opr::TypeCvt>() ||
                        opr->same_type<AbstractShuffleOpr>();
        if (cond_opr) {
            bool cond_endpoint = var2endpoint[opr->input(0)] == opr->output(0);
            if (!cond_endpoint)
                return;
            auto cur = opr;
            auto var = opr->output(0), inp_var = opr->input(0);
            bool force_folding_typecvt = false;
            bool first_shuffle = false;
            // initialize inp_format and out_format
            TensorFormat out_format = TensorFormat::NCHW,
                         inp_format = out_format;
            megdnn::DType inp_dtype = cur->input(0)->dtype(),
                          out_dtype = cur->output(0)->dtype();
            SmallVector<megdnn::DType> out_dtype_vec;
            while (cond_opr) {
                if (cur->same_type<AbstractShuffleOpr>()) {
                    auto shuffle = try_cast_as_op<AbstractShuffleOpr>(cur);
                    inp_format = shuffle->inp_format();
                    if (!first_shuffle) {
                        out_format = shuffle->out_format();
                        first_shuffle = true;
                    }
                } else {
                    mgb_assert(cur->same_type<opr::TypeCvt>());
                    out_dtype_vec.push_back(cur->output(0)->dtype());
                }
                inp_var = cur->input(0);
                bool cond_reader = uniq_reader_check(inp_var);
                if (!cond_reader)
                    break;
                cur = cur->input(0)->owner_opr();
                cond_opr = cur->same_type<opr::TypeCvt>() ||
                           cur->same_type<AbstractShuffleOpr>();
            }
            std::reverse(out_dtype_vec.begin(), out_dtype_vec.end());
#if MGB_ENABLE_TENSOR_RT
            force_folding_typecvt =
                    inp_var->owner_opr()->same_type<opr::TensorRTOpr>() ||
                    trt_opr_inps.count(var);
#endif
            auto new_var = rewriter.get_var(inp_var);
            if (inp_format != out_format) {
                new_var = m_reformat[std::make_pair(inp_format, out_format)](
                        new_var);
            }
            if (force_folding_typecvt) {
                inp_dtype = inp_var->dtype();
                if (inp_dtype != out_dtype) {
                    auto type_cvt = opr::TypeCvt::make(new_var, out_dtype);
                    new_var = type_cvt.node();
                }
            } else {
                // guard against an empty vector when the chain contains no
                // TypeCvt at all
                if (out_dtype_vec.empty() ||
                    out_dtype_vec.back() != var->dtype())
                    out_dtype_vec.push_back(var->dtype());
                for (auto&& dtype : out_dtype_vec) {
                    auto type_cvt = opr::TypeCvt::make(new_var, dtype);
                    new_var = type_cvt.node();
                }
            }
            rewriter.replace_var(
                    var, new_var,
                    mgb_cstr_log("replace Dimshuffle and TypeCvt chain"));
        } else {
            auto new_opr = rewriter.auto_replace_outputs(opr);
            uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
        }
    };
    m_opt_state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
const char* ShuffleShuffleRemovePass::name() const {
    return mgb_cstr_log("shuffle shuffle remove pass");
}

void ShuffleShuffleRemovePass::apply(OptState& opt) const {
    opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE |
                                   VarReplaceCheckFlag::CHECK_DTYPE);
    Impl{opt};
}

void gopt::reformat_to_chwn4_transform_dest_vars_inplace(
        mgb::cg::VarNodeArray& dest_vars) {
    gopt::GraphOptimizer optimizer;
    optimizer.add_pass<FuseConvBiasNonlinPass>();
    optimizer.add_pass<FuseConvBiasZPass>();
    optimizer.add_pass(EnableCHWN4Pass::make_chwn4_converter());
    optimizer.add_pass<ShuffleShuffleRemovePass>();
    optimizer.add_pass<RemoveRedundantTypeCvtPass>();
    optimizer.add_pass<ParamFusePass>();
    optimizer.apply_inplace(dest_vars);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
