
inference.cpp 88 kB

  1. /**
  2. * \file src/gopt/impl/inference.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/gopt/inference.h"
  12. #include "megbrain/gopt/gtrans.h"
  13. #include "megbrain/gopt/basic_arith.h"
  14. #include "megbrain/graph/event.h"
  15. #include "megbrain/opr/dnn/batch_norm.h"
  16. #include "megbrain/opr/dnn/local.h"
  17. #include "megbrain/utils/shared_set.h"
  18. #include "megbrain/serialization/opr_shallow_copy.h"
  19. #include "megbrain/opr/basic_arith.h"
  20. #include "megbrain/opr/dnn/convolution.h"
  21. #include "megbrain/opr/blas.h"
  22. #include "megbrain/opr/misc.h"
  23. #include "megbrain/opr/utility.h"
  24. #include "megbrain/opr/dnn/pooling.h"
  25. #include "megbrain/opr/tensor_manip.h"
  26. #include "megbrain/opr/imgproc.h"
  27. #include "megbrain/opr/nn_int.h"
  28. #include "megbrain/opr/tensor_gen.h"
  29. #include "megbrain/utils/hash_ct.h"
  30. #include "megdnn/tensor_format.h"
  31. #if MGB_ENABLE_TENSOR_RT
  32. #include "megbrain/tensorrt/tensorrt_opr.h"
  33. #endif
  34. #include "megbrain/gopt/misc.h"
  35. #include "megbrain/utils/hash_ct.h"
  36. #include "midout.h"
  37. MIDOUT_DECL(megbrain_inference)
  38. #define MIDOUT_B(tag) \
  39. MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) {
  40. #define MIDOUT_E \
  41. } \
  42. MIDOUT_END();
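// Usage sketch (mirrors the passes below): each optimization pass brackets its
// apply() body with these macros so midout can record which passes are actually
// compiled and exercised, e.g.
//     MIDOUT_B("SomePass::apply")
//     ... pass body ...
//     MIDOUT_E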
  43. using namespace mgb;
  44. using namespace gopt;
  45. namespace {
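// Brief summary of param_merge below: collect every SharedDeviceTensor(Format)
// opr in the graph and replace all of them with the outputs of a single
// MultipleDeviceTensorHolder(Format) opr, so the parameters are held by one
// operator instead of many.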
  46. template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
  47. void param_merge(OptState& opt_state) {
  48. auto rewriter = opt_state.graph().make_rewriter();
  49. ThinHashMap<OperatorNodeBase*, size_t> opr2idx;
  50. std::vector<OperatorNodeBase*> all_oprs;
  51. typename MultipleDeviceTensorHolder::ValueArray all_values;
  52. auto cb_find_opr = [&](cg::OperatorNodeBase* opr) {
  53. if (opr->same_type<SharedDeviceTensor>()) {
  54. auto p = &opr->cast_final<SharedDeviceTensor>();
  55. // ShredD may be manu
  56. opr2idx[p] = all_values.size();
  57. all_values.push_back(p->dev_data());
  58. all_oprs.push_back(p);
  59. }
  60. };
  61. opt_state.graph().iter(cb_find_opr);
  62. SymbolVarArray new_vars;
  63. auto cb_replace = [&](cg::OperatorNodeBase* opr) {
  64. auto iter = opr2idx.find(opr);
  65. if (iter == opr2idx.end()) {
  66. rewriter.auto_replace_outputs(opr);
  67. } else {
  68. if (new_vars.empty()) {
  69. // new oprs must be created in iter callback; so we populate
  70. // new_vars lazily
  71. new_vars = MultipleDeviceTensorHolder::make(
  72. *opt_state.graph().comp_graph(), std::move(all_values),
  73. {ssprintf("merged%zu", all_values.size())});
  74. for (size_t i = 0; i < new_vars.size(); ++i) {
  75. auto src = all_oprs[i]->output(0);
  76. if (src->has_name_set()) {
  77. new_vars[i].rename(src->name());
  78. }
  79. }
  80. }
  81. rewriter.replace_var(
  82. opr->output(0), new_vars.at(iter->second).node(),
  83. mgb_cstr_log("replace multi SharedDeviceTensor(Format) to "
  84. "MultipleDeviceTensorHolder(Format)"));
  85. }
  86. };
  87. opt_state.graph().iter(cb_replace);
  88. rewriter.apply_inplace();
  89. }
  90. }
  91. /* ================ global functions ================ */
  92. SymbolVarArray gopt::optimize_for_inference(
  93. const SymbolVarArray& dest_vars,
  94. const OptimizeForInferenceOptions& opt) {
  95. return gopt::GraphOptimizer()
  96. .add_preset_passes(false, &opt,
  97. &dest_vars[0].node()->owner_graph()->options())
  98. .apply({dest_vars})
  99. .endpoint_vars();
  100. }
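// Hypothetical caller sketch (the names `graph_output` and `opts` are assumed,
// not from this file): run the preset inference passes on an endpoint var and
// use the returned optimized endpoints.
//     gopt::OptimizeForInferenceOptions opts;
//     SymbolVarArray optimized =
//             gopt::optimize_for_inference({graph_output}, opts);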
  101. namespace {
  102. void modify_conv_strategy(
  103. opr::mixin::Convolution& conv,
  104. opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
  105. auto policy = conv.execution_policy_transient();
  106. policy.strategy = strategy;
  107. conv.set_execution_policy(policy);
  108. }
  109. template <typename Opr>
  110. void inplace_conv_opr_modifier(
  111. OperatorNodeBase& opr,
  112. opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
  113. modify_conv_strategy(
  114. opr.cast_final_safe<Opr>(),
  115. strategy);
  116. }
  117. void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv,
  118. size_t workspace_limit) {
  119. auto policy = conv.execution_policy_transient();
  120. policy.workspace_limit = workspace_limit;
  121. conv.set_execution_policy(policy);
  122. }
  123. template <typename Opr>
  124. void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
  125. size_t workspace_limit) {
  126. modify_conv_policy_workspace_limit(opr.cast_final_safe<Opr>(),
  127. workspace_limit);
  128. }
  129. } // anonymous namespace
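// MGB_FOREACH_FASTRUN_OPR is an X-macro: the functions below instantiate it
// with a local CONV(t) definition to build a Typeinfo* -> modifier table
// covering every opr type whose execution policy supports fastrun/profiling.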
  130. #define MGB_FOREACH_FASTRUN_OPR(cb) \
  131. cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData), \
  132. cb(ConvolutionBackwardFilter), cb(Convolution3DForward), \
  133. cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter), \
  134. cb(LocalShareForward), cb(LocalShareBackwardData), \
  135. cb(LocalShareBackwardFilter), cb(DeformableConvForward), \
  136. cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \
  137. cb(BatchConvBiasForward),
  138. void gopt::modify_opr_algo_strategy_inplace(
  139. const VarNodeArrayView& dest_vars,
  140. opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
  141. #if !MGB_ENABLE_FASTRUN
  142. using S = opr::mixin::Convolution::ExecutionPolicy::Strategy;
  143. if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) {
  144. mgb_throw(MegBrainError, "fastrun is disabled at compile time");
  145. }
  146. #endif
  147. const ThinHashMap<Typeinfo*, std::function<void(OperatorNodeBase&)>>
  148. modifiers = {
  149. #define CONV(t) \
  150. {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \
  151. std::placeholders::_1, strategy)}
  152. MGB_FOREACH_FASTRUN_OPR(CONV)
  153. #undef CONV
  154. };
  155. auto on_opr = [&](OperatorNodeBase* opr) {
  156. auto iter = modifiers.find(opr->dyn_typeinfo());
  157. if (iter != modifiers.end()) {
  158. iter->second(*opr);
  159. }
  160. };
  161. cg::DepOprIter dep_iter{on_opr};
  162. for (auto i : dest_vars) {
  163. dep_iter.add(i);
  164. }
  165. }
  166. void gopt::enable_opr_algo_profiling_inplace(
  167. const VarNodeArrayView& dest_vars) {
  168. modify_opr_algo_strategy_inplace(dest_vars,
  169. opr::mixin::Convolution::ExecutionPolicy::
  170. Strategy::PROFILE);
  171. }
  172. void gopt::enable_opr_use_profiling_cache_inplace(
  173. const VarNodeArrayView& dest_vars) {
  174. modify_opr_algo_strategy_inplace(dest_vars,
  175. opr::mixin::Convolution::ExecutionPolicy::
  176. Strategy::PROFILE_HEURISTIC);
  177. }
  178. void gopt::set_opr_algo_workspace_limit_inplace(
  179. const VarNodeArrayView& dest_vars, size_t workspace_limit) {
  180. static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
  181. modifiers = {
  182. #define CONV(t) \
  183. {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>}
  184. MGB_FOREACH_FASTRUN_OPR(CONV)
  185. #undef CONV
  186. };
  187. auto on_opr = [&](OperatorNodeBase* opr) {
  188. auto iter = modifiers.find(opr->dyn_typeinfo());
  189. if (iter != modifiers.end()) {
  190. iter->second(*opr, workspace_limit);
  191. }
  192. };
  193. cg::DepOprIter dep_iter{on_opr};
  194. for (auto i : dest_vars) {
  195. dep_iter.add(i);
  196. }
  197. }
  198. #undef MGB_FOREACH_FASTRUN_OPR
  199. /* ================ ParamRedistributePass ================ */
  200. const char* ParamRedistributePass::name() const {
  201. return mgb_cstr_log("param_redistribute");
  202. }
  203. class ParamRedistributePass::Impl final: public RecursiveSubGraphRewriteHelper {
  204. ConstVarPropogate m_cvprop;
  205. UniqReaderCheck m_uniq_reader_check;
  206. //! oprs already processed in try_distribute_then_reassociate() should be
  207. //! skipped in on_new_opr_check_should_process()
  208. ThinHashSet<OperatorNodeBase*> m_opr_blacklist;
  209. std::string m_distribute_reasso_log_msg;
  210. //! try applying BinaryTrans20::associtive
  211. GTransResult try_reassociate(OperatorNodeBase *opr);
  212. //! try applying BinaryTrans20::distributive_add
  213. GTransResult try_distribute_add(OperatorNodeBase *opr);
  214. //! try distribute MUL/DIV over ADD/SUB and then apply
  215. GTransResult try_distribute_then_reassociate(OperatorNodeBase *opr);
  216. GTransResult process_opr(VarNode *out_var) override;
  217. bool on_new_opr_check_should_process(
  218. OperatorNodeBase*opr, OperatorNodeBase *repl_opr) override {
  219. m_uniq_reader_check.update_on_opr_auto_replace(opr, repl_opr);
  220. auto ins = m_cvprop.add_opr(opr);
  221. return ins.has_const_inp && !ins.all_const_inp &&
  222. !m_opr_blacklist.count(opr);
  223. };
  224. void after_replace_var(VarNode *orig_var, VarNode* new_var) override {
  225. m_uniq_reader_check.update_on_opr_auto_replace(orig_var->owner_opr(),
  226. new_var->owner_opr());
  227. }
  228. /*!
  229. * \brief try to reorder opr inputs to a const one and a non-const one
  230. *
  231. * return true if it can be reformulated as f(nci, ci), where nci is
  232. * non-const and ci is const.
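 * For example, ADD(p, x) with p a const param is reordered to nci=x, ci=p
 * with swap_inp set to true, while ADD(x, p) already matches and needs no
 * swap.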
  233. */
  234. bool reorder_for_normconst(OperatorNodeBase *opr,
  235. bool &swap_inp, VarNode *&nci, VarNode *&ci);
  236. public:
  237. Impl(OptState &state);
  238. };
  239. GTransResult ParamRedistributePass::Impl::process_opr(VarNode *out_var) {
  240. auto opr = out_var->owner_opr();
  241. auto trans = try_reassociate(opr);
  242. if (!trans.valid()) {
  243. trans = try_distribute_add(opr);
  244. if (!trans.valid())
  245. trans = try_distribute_then_reassociate(opr);
  246. }
  247. return trans;
  248. }
  249. GTransResult ParamRedistributePass::Impl::try_reassociate(
  250. OperatorNodeBase *opr) {
  251. // apply BinaryAssociative0 if opr is the form f(g(a, b), c) and b and c are
  252. // const
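// Conceptually this lets the two const inputs meet in one subexpression:
// f(g(x, p1), p2) is rewritten so that p1 and p2 end up combined in a fully
// const subexpression that ParamFusePass can later fold (the exact rewritten
// form depends on what BinaryTrans20::associtive produces).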
  253. bool swap_fop_inp = false, swap_gop_inp = false;
  254. VarNode *a, *b, *c, *ab;
  255. if (!reorder_for_normconst(opr, swap_fop_inp, ab, c))
  256. return None;
  257. if (!m_uniq_reader_check(ab))
  258. return None;
  259. if (!reorder_for_normconst(ab->owner_opr(), swap_gop_inp, a, b))
  260. return None;
  261. return BinaryTrans20::associtive().apply(opr, swap_fop_inp, swap_gop_inp);
  262. }
  263. GTransResult ParamRedistributePass::Impl::try_distribute_add(
  264. OperatorNodeBase *opr) {
  265. if (opr->same_type<opr::Elemwise>() || opr->input().size() != 2)
  266. return None;
  267. if (!m_cvprop.is_const(opr->input(1)))
  268. return None;
  269. auto ab = as_elem_opr(opr->input(0)->owner_opr(), opr::Elemwise::Mode::ADD);
  270. if (ab) {
  271. bool swap;
  272. VarNode *a, *b;
  273. if (reorder_for_normconst(ab, swap, a, b)) {
  274. return BinaryTrans20::distributive_add().apply(
  275. opr, false, swap);
  276. }
  277. }
  278. return None;
  279. }
  280. GTransResult ParamRedistributePass::Impl::try_distribute_then_reassociate(
  281. OperatorNodeBase *opr) {
  282. if (!opr->same_type<opr::Elemwise>())
  283. return None;
  284. using Mode = opr::Elemwise::Mode;
  285. auto mode = opr->cast_final<opr::Elemwise>().param().mode;
  286. if (!(mode == Mode::MUL || mode == Mode::TRUE_DIV))
  287. return None;
  288. VarNode *a, *b;
  289. bool swap;
  290. if (!reorder_for_normconst(opr, swap, a, b))
  291. return None;
  292. auto chain_pred = [this](OperatorNodeBase *opr) {
  293. if (as_elem_opr(opr, Mode::ADD)) {
  294. auto var = opr->output(0);
  295. return m_uniq_reader_check(var) || m_cvprop.is_const(var);
  296. }
  297. return false;
  298. };
  299. auto chain = extract_opr_leaves(a, chain_pred);
  300. if (chain.size() <= 1)
  301. return None;
  302. std::unordered_map<VarNode*, VarNode*> repl_map;
  303. m_distribute_reasso_log_msg.clear();
  304. int nr_fail = 0, nr_succ = 0;
  305. for (auto &&var: chain) {
  306. {
  307. auto iter = repl_map.find(var);
  308. if (iter != repl_map.end()) {
  309. var = iter->second;
  310. continue;
  311. }
  312. }
  313. auto vnew = (SymbolVar{var} * b).node();
  314. m_opr_blacklist.insert(vnew->owner_opr());
  315. if (!m_cvprop.is_const(var)) {
  316. auto trans = try_reassociate(vnew->owner_opr());
  317. if (!trans.valid()) {
  318. // allow at most one failed redistribution
  319. if (nr_fail)
  320. return None;
  321. ++ nr_fail;
  322. } else {
  323. ++ nr_succ;
  324. vnew = trans->result;
  325. if (!m_distribute_reasso_log_msg.empty()) {
  326. m_distribute_reasso_log_msg.append(mgb_cstr_log(";"));
  327. }
  328. m_distribute_reasso_log_msg.append(trans->msg);
  329. }
  330. }
  331. repl_map[var] = vnew;
  332. var = vnew;
  333. }
  334. if (nr_succ) {
  335. m_distribute_reasso_log_msg.insert(0,
  336. mgb_cstr_log("distribute_mul("));
  337. m_distribute_reasso_log_msg.append(mgb_cstr_log(")"));
  338. return GTransResultItem{
  339. elemwise_reduce_var_list(chain, Mode::ADD),
  340. m_distribute_reasso_log_msg.c_str(),
  341. {}};
  342. }
  343. return None;
  344. }
  345. bool ParamRedistributePass::Impl::reorder_for_normconst(
  346. OperatorNodeBase *opr, bool &swap_inp, VarNode *&nci, VarNode *&ci) {
  347. if (opr->input().size() != 2)
  348. return false;
  349. nci = opr->input(0);
  350. ci = opr->input(1);
  351. if (!m_cvprop.is_const(ci)) {
  352. if (!is_commutable_binary(opr) || !m_cvprop.is_const(nci))
  353. return false;
  354. swap_inp = true;
  355. std::swap(nci, ci);
  356. } else {
  357. if (m_cvprop.is_const(nci))
  358. return false;
  359. swap_inp = false;
  360. }
  361. return true;
  362. }
  363. ParamRedistributePass::Impl::Impl(OptState &state):
  364. RecursiveSubGraphRewriteHelper{state},
  365. m_cvprop{ConstVarType::IMMUTABLE_AND_PARAM},
  366. m_uniq_reader_check{state.graph()}
  367. {
  368. auto cg = state.graph().comp_graph();
  369. auto on_new_opr = [this](const cg::event::OprInserted &ev) {
  370. if (!ev.is_dedup && !ev.exc) {
  371. // call add_opr eagerly to avoid deep recursion
  372. m_cvprop.add_opr(ev.opr);
  373. }
  374. };
  375. auto hdl = cg->event().register_receiver
  376. <cg::event::OprInserted>(on_new_opr);
  377. apply();
  378. }
  379. void ParamRedistributePass::apply(OptState &state) const {
  380. MIDOUT_B("ParamRedistributePass::apply")
  381. Impl{state};
  382. MIDOUT_E
  383. }
  384. /* ================ ParamFusePass ================ */
  385. /*!
  386. * \brief get name for new param
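 * e.g. a var that depends on params named "a" and "b" is named
 * "fuse(a,b):<var_name>@<var_id>" in non-slim builds; slim builds just use
 * "fuse".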
  387. */
  388. class ParamFusePass::VarNamer {
  389. #if MGB_BUILD_SLIM_SERVING
  390. public:
  391. const std::string& name(VarNode*) {
  392. static std::string ret("fuse");
  393. return ret;
  394. }
  395. #else
  396. using SrcSet = SharedSet<OperatorNodeBase*>;
  397. //! map from var to source SharedDeviceTensor/MultipleDeviceTensorHolder oprs
  398. //! that it depends on
  399. ThinHashMap<OperatorNodeBase*, SrcSet> m_opr2srcs;
  400. std::string m_name_cache;
  401. std::vector<const char*> m_cur_name;
  402. SrcSet& get_src_set(OperatorNodeBase* opr) {
  403. auto opr_typeinfo = opr->dyn_typeinfo();
  404. auto iter = m_opr2srcs.find(opr);
  405. if (iter != m_opr2srcs.end()) {
  406. return iter->second;
  407. }
  408. auto &&ret = m_opr2srcs[opr];
  409. if (opr->input().empty()) {
  410. if (opr_typeinfo == opr::SharedDeviceTensor::typeinfo() ||
  411. opr_typeinfo == opr::MultipleDeviceTensorHolder::typeinfo()) {
  412. ret.insert(opr);
  413. } else {
  414. mgb_assert(opr_typeinfo == opr::ImmutableTensor::typeinfo());
  415. }
  416. return ret;
  417. }
  418. for (auto i: opr->input()) {
  419. ret.merge_from(get_src_set(i->owner_opr()));
  420. }
  421. return ret;
  422. }
  423. public:
  424. const std::string& name(VarNode *var) {
  425. m_cur_name.clear();
  426. for (auto i : get_src_set(var->owner_opr())) {
  427. m_cur_name.push_back(i->cname());
  428. }
  429. auto cmp = [](const char *x, const char *y) {
  430. return strcmp(x, y) < 0;
  431. };
  432. std::sort(m_cur_name.begin(), m_cur_name.end(), cmp);
  433. m_name_cache.clear();
  434. m_name_cache.append(mgb_cstr_log("fuse("));
  435. bool first = true;
  436. for (auto i: m_cur_name) {
  437. if (first) {
  438. first = false;
  439. } else {
  440. m_name_cache.push_back(',');
  441. }
  442. m_name_cache.append(i);
  443. }
  444. m_name_cache.append(mgb_cstr_log(
  445. ssprintf("):%s@%zu", var->cname(), var->id())));
  446. return m_name_cache;
  447. }
  448. #endif
  449. };
  450. const char* ParamFusePass::name() const {
  451. return mgb_cstr_log("param_fuse");
  452. }
  453. void ParamFusePass::apply(OptState &state) const {
  454. MIDOUT_B("ParamFusePass::apply")
  455. auto rewriter = state.graph().make_rewriter();
  456. auto cg = state.graph().comp_graph();
  457. ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
  458. state.graph().iter([&cvprop](OperatorNodeBase *opr) {
  459. cvprop.add_opr(opr);
  460. });
  461. ThinHashSet<VarNode*> processed_var;
  462. VarNamer var_namer;
  463. // reader: null if used as endvar
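// replace_single_var (below) evaluates a const-foldable var by compiling and
// running a small subgraph for it, then replaces it with an ImmutableTensor
// (when its value is statically inferable and the format is default) or a
// const SharedDeviceTensor(WithFormat) otherwise.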
  464. auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
  465. if (!processed_var.insert(var).second)
  466. return;
  467. auto inferred_val = std::make_shared<DeviceTensorND>(
  468. var->comp_node(), var->dtype());
  469. auto cb = [&](DeviceTensorND& val) {
  470. // retain format of val
  471. mgb_assert(val.format() == var->format());
  472. inferred_val->format(val.format())
  473. .resize(val.shape())
  474. .copy_from_fixlayout(val);
  475. };
  476. {
  477. auto orig_level = cg->options().log_level;
  478. cg->options().log_level = 0;
  479. MGB_TRY {
  480. cg->compile({{var, cb}})->execute();
  481. } MGB_FINALLY(cg->options().log_level = orig_level);
  482. }
  483. SymbolVar new_var;
  484. bool is_default_format = var->layout().format.is_default();
  485. if (cg::is_static_var_value(var) && is_default_format) {
  486. // use ImmutableTensor for inferable vars
  487. HostTensorND hv;
  488. hv.copy_from(*inferred_val).sync();
  489. new_var = opr::ImmutableTensor::make(
  490. *var->owner_graph(), hv, var_namer.name(var));
  491. } else {
  492. if (is_default_format) {
  493. new_var = opr::SharedDeviceTensor::make_const(
  494. *var->owner_graph(), inferred_val, var_namer.name(var));
  495. } else {
  496. new_var = opr::SharedDeviceTensorWithFormat::make_const(
  497. *var->owner_graph(), inferred_val, var_namer.name(var));
  498. }
  499. }
  500. std::string log;
  501. if (reader) {
  502. log = mgb_ssprintf_log(
  503. "due to read by %s{%s}",
  504. reader->cname(), reader->dyn_typeinfo()->name);
  505. } else {
  506. log = mgb_cstr_log("as endpoint");
  507. }
  508. rewriter.replace_var(var, new_var.node(), log.c_str());
  509. };
  510. auto replace_opr = [&](OperatorNodeBase* opr) {
  511. auto add_ret = cvprop.opr_rst(opr);
  512. if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
  513. for (auto i: opr->input()) {
  514. if (cvprop.is_midconst(i)) {
  515. state.call_with_opr(i->owner_opr(),
  516. [&]{replace_single_var(i, opr);});
  517. }
  518. }
  519. }
  520. rewriter.auto_replace_outputs(opr);
  521. //! we should deal with midconst var after auto_replace_outputs, as
  522. //! on_midconst_opr will replace the endpoint output which may cause
  523. //! double replace.
  524. if (add_ret.all_const_inp) {
  525. for (auto var : opr->output()) {
  526. if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
  527. continue;
  528. auto osize = ConstVarPropogate::var_mem_size(var);
  529. if (osize >= cvprop.max_size(opr) &&
  530. osize - cvprop.max_size(opr) > m_param_grow_limit) {
  531. return;
  532. }
  533. // const oprs should be evaluated when output is used by another
  534. // non-const opr or output is needed by the user
  535. if (state.graph().endpoint_contain(var)) {
  536. replace_single_var(var, nullptr);
  537. }
  538. }
  539. }
  540. };
  541. state.graph().iter(replace_opr);
  542. rewriter.apply_inplace();
  543. MIDOUT_E
  544. }
  545. /* ================ One2OneOprReplacePass ================ */
  546. const char* ConvertF32ToF16Pass::name() const {
  547. return mgb_cstr_log("convert_f32_to_f16");
  548. }
  549. void ConvertF32ToF16Pass::apply(OptState& state) const {
  550. MIDOUT_B("ConvertF32ToF16Pass::apply")
  551. state.set_var_replace_check_flag(m_var_replace_check_flag);
  552. auto rewriter = state.graph().make_rewriter();
  553. VarNodeArray new_inp_cache;
  554. // record original output dtype
  555. const SymbolVarArray& vars = state.graph().endpoint_vars();
  556. std::vector<DType> dtypes;
  557. for (size_t i = 0; i < vars.size(); i++) {
  558. dtypes.push_back(vars[i].node()->dtype());
  559. }
  560. auto on_opr = [this, &rewriter, &new_inp_cache](OperatorNodeBase* opr) {
  561. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  562. if (it != m_opr_replace_func.end()) {
  563. auto&& new_inp = new_inp_cache;
  564. new_inp.clear();
  565. new_inp.reserve(opr->input().size());
  566. for (auto i: opr->input()) {
  567. new_inp.push_back(rewriter.get_var(i));
  568. }
  569. auto new_opr = (it->second)(opr, new_inp);
  570. auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
  571. mgb_assert(origin_out.size() == cur_out.size(),
  572. "bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(),
  573. opr->dyn_typeinfo()->name, new_opr->cname(),
  574. new_opr->dyn_typeinfo()->name);
  575. for (size_t i = 0; i < origin_out.size(); i++) {
  576. rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
  577. }
  578. } else {
  579. rewriter.auto_replace_outputs(opr);
  580. }
  581. };
  582. state.graph().iter(on_opr);
  583. rewriter.apply_inplace();
  584. // recover output dtype
  585. rewriter = state.graph().make_rewriter();
  586. const SymbolVarArray& endpoints = state.graph().endpoint_vars();
  587. auto replace_output = [&]() {
  588. for (size_t i = 0; i < endpoints.size(); i++) {
  589. VarNode* var = endpoints[i].node();
  590. if (var->dtype().enumv() != dtypes[i].enumv()) {
  591. auto new_var = opr::TypeCvt::make(var, dtypes[i]).node();
  592. rewriter.replace_var(var, new_var, nullptr);
  593. }
  594. }
  595. };
  596. mgb_assert(endpoints.size() > 0);
  597. auto opr = endpoints[0].node()->owner_opr();
  598. state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
  599. rewriter.apply_inplace();
  600. MIDOUT_E
  601. }
  602. std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
  603. bool use_f32_comp) {
  604. #if MEGDNN_DISABLE_FLOAT16
  605. mgb_throw(SystemError, "float16 disabled at compile time.");
  606. #else
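// Each replace_*_opr lambda below receives the original opr together with its
// already-converted inputs and returns the replacement opr: source oprs
// (Host2DeviceCopy, SharedDeviceTensor, ImmutableTensor, Linspace) get a
// TypeCvt to float16 appended to their output, while compute oprs are rebuilt
// on the float16 inputs, optionally keeping float32 accumulation via
// compute_mode/data_type when use_f32_comp is set.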
  607. auto replace_h2d_opr = [](OperatorNodeBase* opr,
  608. const VarNodeArray& new_inp) {
  609. mgb_assert(opr->input().size() == new_inp.size());
  610. auto& h2d_opr = opr->cast_final_safe<opr::Host2DeviceCopy>();
  611. if (h2d_opr.output(0)->dtype() == dtype::Float32()) {
  612. auto cvt_var =
  613. opr::TypeCvt::make(h2d_opr.output(0), dtype::Float16(), {});
  614. return cvt_var.node()->owner_opr();
  615. }
  616. return opr;
  617. };
  618. auto replace_sdt_opr = [](OperatorNodeBase* opr,
  619. const VarNodeArray& new_inp) {
  620. mgb_assert(opr->input().size() == new_inp.size());
  621. auto& sdt_opr = opr->cast_final_safe<opr::SharedDeviceTensor>();
  622. if (sdt_opr.output(0)->dtype() == dtype::Float32()) {
  623. auto cvt_var =
  624. opr::TypeCvt::make(sdt_opr.output(0), dtype::Float16(), {});
  625. return cvt_var.node()->owner_opr();
  626. }
  627. return opr;
  628. };
  629. auto replace_imt_opr = [](OperatorNodeBase* opr,
  630. const VarNodeArray& new_inp) {
  631. mgb_assert(opr->same_type<opr::ImmutableTensor>());
  632. mgb_assert(opr->input().size() == new_inp.size());
  633. auto& imt_opr = opr->cast_final_safe<opr::ImmutableTensor>();
  634. if (imt_opr.output(0)->dtype() == dtype::Float32()) {
  635. auto cvt_var =
  636. opr::TypeCvt::make(imt_opr.output(0), dtype::Float16(), {});
  637. return cvt_var.node()->owner_opr();
  638. }
  639. return opr;
  640. };
  641. auto replace_lsp_opr = [](OperatorNodeBase* opr,
  642. const VarNodeArray& new_inp) {
  643. mgb_assert(opr->same_type<opr::Linspace>());
  644. mgb_assert(opr->input().size() == new_inp.size());
  645. auto& lsp_opr = opr->cast_final_safe<opr::Linspace>();
  646. if (lsp_opr.output(0)->dtype() != dtype::Float16()) {
  647. auto cvt_var =
  648. opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {});
  649. return cvt_var.node()->owner_opr();
  650. }
  651. return opr;
  652. };
  653. auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr,
  654. const VarNodeArray& new_inp) {
  655. mgb_assert(opr->input().size() == new_inp.size());
  656. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  657. auto new_param = conv_opr.param();
  658. if (use_f32_comp) {
  659. new_param.compute_mode =
  660. megdnn::param::Convolution::ComputeMode::FLOAT32;
  661. }
  662. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  663. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  664. new_inp[0]->name().c_str(),
  665. new_inp[0]->owner_opr()->name().c_str());
  666. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  667. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  668. new_inp[1]->name().c_str(),
  669. new_inp[1]->owner_opr()->name().c_str());
  670. auto new_conv_opr = opr::Convolution::make(
  671. new_inp[0], new_inp[1], new_param, conv_opr.execution_policy(),
  672. conv_opr.config());
  673. return new_conv_opr.node()->owner_opr();
  674. };
  675. auto replace_convbias_opr = [use_f32_comp](OperatorNodeBase* opr,
  676. const VarNodeArray& new_inp) {
  677. auto& convbias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  678. auto new_param = convbias_opr.param();
  679. if (use_f32_comp) {
  680. new_param.compute_mode =
  681. megdnn::param::ConvBias::ComputeMode::FLOAT32;
  682. }
  683. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  684. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  685. new_inp[0]->name().c_str(),
  686. new_inp[0]->owner_opr()->name().c_str());
  687. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  688. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  689. new_inp[1]->name().c_str(),
  690. new_inp[1]->owner_opr()->name().c_str());
  691. if(opr->input().size() == 2) {
  692. auto new_conv_opr = opr::ConvBias::make(
  693. new_inp[0], new_inp[1], new_param, convbias_opr.execution_policy(),
  694. convbias_opr.config());
  695. return new_conv_opr.node()->owner_opr();
  696. } else if(opr->input().size() == 3) {
  697. auto new_conv_opr = opr::ConvBias::make(
  698. new_inp[0], new_inp[1], new_inp[2], new_param, convbias_opr.execution_policy(),
  699. convbias_opr.config());
  700. return new_conv_opr.node()->owner_opr();
  701. } else {
  702. mgb_assert(opr->input().size() == 4, "invalid input size %zu",
  703. opr->input().size());
  704. auto new_conv_opr = opr::ConvBias::make(
  705. new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, convbias_opr.execution_policy(),
  706. convbias_opr.config());
  707. return new_conv_opr.node()->owner_opr();
  708. }
  709. };
  710. auto replace_matmul_opr = [use_f32_comp](OperatorNodeBase* opr,
  711. const VarNodeArray& new_inp) {
  712. mgb_assert(opr->input().size() == new_inp.size());
  713. auto& matmul_opr = opr->cast_final_safe<opr::MatrixMul>();
  714. auto new_param = matmul_opr.param();
  715. if (use_f32_comp) {
  716. new_param.compute_mode =
  717. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  718. }
  719. auto new_matmul_opr = opr::MatrixMul::make(
  720. new_inp[0], new_inp[1], new_param, matmul_opr.config());
  721. return new_matmul_opr.node()->owner_opr();
  722. };
  723. auto replace_batched_matmul_opr = [use_f32_comp](
  724. OperatorNodeBase* opr,
  725. const VarNodeArray& new_inp) {
  726. mgb_assert(opr->input().size() == new_inp.size());
  727. auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>();
  728. auto new_param = matmul_opr.param();
  729. if (use_f32_comp) {
  730. new_param.compute_mode =
  731. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  732. }
  733. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  734. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  735. new_inp[0]->name().c_str(),
  736. new_inp[0]->owner_opr()->name().c_str());
  737. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  738. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  739. new_inp[1]->name().c_str(),
  740. new_inp[1]->owner_opr()->name().c_str());
  741. auto new_matmul_opr = opr::BatchedMatrixMul::make(
  742. new_inp[0], new_inp[1], new_param, matmul_opr.config());
  743. return new_matmul_opr.node()->owner_opr();
  744. };
  745. auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr,
  746. const VarNodeArray& new_inp) {
  747. auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
  748. auto new_param = reduce_opr.param();
  749. if (use_f32_comp) {
  750. new_param.data_type =
  751. megdnn::param::Reduce::DataType::FLOAT_O16xC32;
  752. }
  753. if (opr->input().size() == 1) {
  754. auto new_matmul_opr = opr::Reduce::make(new_inp[0], new_param, {},
  755. reduce_opr.config());
  756. return new_matmul_opr.node()->owner_opr();
  757. } else {
  758. mgb_assert(opr->input().size() == 2, "invalid input size %zu",
  759. opr->input().size());
  760. auto new_matmul_opr = opr::Reduce::make(
  761. new_inp[0], new_param, new_inp[1], reduce_opr.config());
  762. return new_matmul_opr.node()->owner_opr();
  763. }
  764. };
  765. auto replace_cvt_opr = [](OperatorNodeBase* opr,
  766. const VarNodeArray& new_inp) {
  767. auto& cvt_opr = opr->cast_final_safe<opr::TypeCvt>();
  768. SymbolVar new_cvt;
  769. if (cvt_opr.output(0)->dtype() == dtype::Float32()) {
  770. new_cvt = opr::TypeCvt::make(new_inp[0], dtype::Float16(),
  771. cvt_opr.config());
  772. } else {
  773. new_cvt = opr::TypeCvt::make(
  774. new_inp[0], cvt_opr.output()[0]->dtype(), cvt_opr.config());
  775. }
  776. return new_cvt.node()->owner_opr();
  777. };
  778. auto replace_warp_opr = [](OperatorNodeBase* opr,
  779. const VarNodeArray& new_inp) {
  780. mgb_assert(opr->input().size() == new_inp.size() &&
  781. (new_inp.size() == 3 || new_inp.size() == 4));
  782. auto& warp_opr = opr->cast_final<opr::WarpPerspective>();
  783. // mat tensor must be float32
  784. auto new_mat = new_inp[1];
  785. if (new_inp[1]->dtype() != dtype::Float32()) {
  786. if (try_cast_as_op<opr::TypeCvt>(new_mat->owner_opr()) &&
  787. new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
  788. new_mat = new_mat->owner_opr()->input(0);
  789. else
  790. new_mat =
  791. opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
  792. }
  793. SymbolVar new_warp;
  794. if (new_inp.size() == 3) {
  795. new_warp = opr::WarpPerspective::make(new_inp[0], new_mat,
  796. new_inp[2], warp_opr.param(),
  797. warp_opr.config());
  798. } else {
  799. mgb_assert(new_inp.size() == 4);
  800. new_warp = opr::WarpPerspective::make(
  801. new_inp[0], new_mat, new_inp[2], new_inp[3],
  802. warp_opr.param(), warp_opr.config());
  803. }
  804. return new_warp.node()->owner_opr();
  805. };
  806. auto replace_remap_opr = [](OperatorNodeBase* opr,
  807. const VarNodeArray& new_inp) {
  808. mgb_assert(opr->input().size() == new_inp.size() &&
  809. (new_inp.size() == 2));
  810. auto& remap_opr = opr->cast_final<opr::Remap>();
  811. // map tensor must be float32
  812. auto new_map = new_inp[1];
  813. if (new_inp[1]->dtype() != dtype::Float32()) {
  814. if (try_cast_as_op<opr::TypeCvt>(new_map->owner_opr()) &&
  815. new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
  816. new_map = new_map->owner_opr()->input(0);
  817. else
  818. new_map =
  819. opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
  820. }
  821. SymbolVar new_remap;
  822. new_remap = opr::Remap::make(new_inp[0], new_map,
  823. remap_opr.param(),
  824. remap_opr.config());
  825. return new_remap.node()->owner_opr();
  826. };
  827. auto ret = std::make_unique<ConvertF32ToF16Pass>();
  828. // don't check dtype
  829. ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
  830. VarReplaceCheckFlag::CHECK_DTYPE);
  831. auto&& replace_func = ret->m_opr_replace_func;
  832. replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr;
  833. replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
  834. replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
  835. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  836. replace_func[opr::ConvBias::typeinfo()] = replace_convbias_opr;
  837. replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr;
  838. replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr;
  839. replace_func[opr::ImmutableTensor::typeinfo()] = replace_imt_opr;
  840. replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
  841. replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
  842. replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
  843. replace_func[opr::BatchedMatrixMul::typeinfo()] =
  844. replace_batched_matmul_opr;
  845. return ret;
  846. #endif
  847. }
  848. /* ================ ConvertFormatPass ================ */
  849. void ConvertFormatPass::apply(OptState& state) const {
  850. MIDOUT_B("ConvertFormatPass::apply")
  851. state.set_var_replace_check_flag(m_var_replace_check_flag);
  852. auto rewriter = state.graph().make_rewriter();
  853. VarNodeArray new_inp_cache;
  854. auto on_opr = [this, &state, &rewriter,
  855. &new_inp_cache](OperatorNodeBase* opr) {
  856. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  857. if (it != m_opr_replace_func.end()) {
  858. auto&& new_inp = new_inp_cache;
  859. new_inp.clear();
  860. new_inp.reserve(opr->input().size());
  861. for (auto i : opr->input()) {
  862. new_inp.push_back(rewriter.get_var(i));
  863. }
  864. auto new_opr = (it->second)(opr, new_inp);
  865. auto &&out0 = opr->output(), &&out1 = new_opr->output();
  866. mgb_assert(out0.size() == out1.size(),
  867. "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
  868. "dst.size=%zu",
  869. opr->cname(), opr->dyn_typeinfo()->name,
  870. new_opr->cname(), new_opr->dyn_typeinfo()->name,
  871. out0.size(), out1.size());
  872. for (size_t i = 0; i < out0.size(); i++) {
  873. if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  874. mgb_assert(!out1[i]->contain_flag(
  875. VarNode::Flag::VOLATILE_CONTENT));
  876. auto src = out0[i];
  877. auto dst = out1[i];
  878. auto dst_is_image = dst->format().type() ==
  879. TensorFormat::Type::IMAGE2D_PACK4;
  880. if (!dst_is_image &&
  881. !src->owner_opr()->same_type<opr::ImmutableTensor>()) {
  882. mgb_log_warn(
  883. "convert NHWCD4 replaced to non-img format: "
  884. "dst_opr=%s{%s} format=%s",
  885. dst->owner_opr()->cname(),
  886. dst->owner_opr()->dyn_typeinfo()->name,
  887. dst->format().to_string().c_str());
  888. }
  889. if (state.graph().endpoint_contain(src) && dst_is_image) {
  890. // relayout back to NCHW for output vars
  891. dst = opr::RelayoutFormat::make(
  892. dst, {opr::RelayoutFormat::Param::Mode::
  893. NHWCD4I_NCHW})
  894. .node();
  895. }
  896. rewriter.replace_var(src, dst, nullptr);
  897. }
  898. }
  899. } else {
  900. rewriter.auto_replace_outputs(opr);
  901. }
  902. };
  903. state.graph().iter(on_opr);
  904. rewriter.apply_inplace();
  905. MIDOUT_E
  906. }
  907. std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
  908. MIDOUT_B("ConvertFormatPass::make")
  909. auto filter_mode =
  910. [](const megdnn::param::Convolution::Sparse conv_mode,
  911. const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
  912. bool use_dot = false;
  913. if (filter->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
  914. filter->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm)
  915. use_dot = true;
  916. if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
  917. if (use_dot)
  918. return megdnn::param::RelayoutFormat::Mode::
  919. INTER_WEIGHT_DENSEI_DOT;
  920. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI;
  921. } else {
  922. mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
  923. if (filter->shape()[1] == 1 && filter->shape()[2] == 1) {
  924. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI;
  925. } else {
  926. if (use_dot)
  927. return megdnn::param::RelayoutFormat::Mode::
  928. INTER_WEIGHT_GROUPI_DOT;
  929. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_GROUPI;
  930. }
  931. }
  932. };
  933. auto replace_conv_opr = [&filter_mode](OperatorNodeBase* opr,
  934. const VarNodeArray& new_inp) {
  935. mgb_assert(opr->input().size() == new_inp.size());
  936. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  937. mgb_assert(conv_opr.param().format ==
  938. megdnn::param::Convolution::Format::NCHW,
  939. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  940. VarNode *conv_src = nullptr, *conv_weights = nullptr;
  941. if (new_inp[0]->shape().ndim == 4) {
  942. // new input src is NCHW
  943. size_t group, icpg, ocpg;
  944. if (conv_opr.param().sparse ==
  945. megdnn::param::Convolution::Sparse::DENSE) {
  946. group = 1;
  947. icpg = new_inp[1]->shape()[1];
  948. ocpg = new_inp[1]->shape()[0];
  949. } else {
  950. mgb_assert(conv_opr.param().sparse ==
  951. megdnn::param::Convolution::Sparse::GROUP);
  952. group = new_inp[1]->shape()[0];
  953. icpg = new_inp[1]->shape()[2];
  954. ocpg = new_inp[1]->shape()[1];
  955. }
  956. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  957. auto param = megdnn::param::RelayoutFormat();
  958. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  959. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  960. conv_src = rf.node();
  961. } else {
  962. // can not convert to hwcd4
  963. return serialization::copy_opr_shallow(*opr, new_inp,
  964. opr->config());
  965. }
  966. } else {
  967. size_t ocpg;
  968. bool is_channel_wise = false;
  969. if (conv_opr.param().sparse ==
  970. megdnn::param::Convolution::Sparse::DENSE) {
  971. ocpg = new_inp[1]->shape()[0];
  972. } else {
  973. mgb_assert(conv_opr.param().sparse ==
  974. megdnn::param::Convolution::Sparse::GROUP);
  975. size_t icpg = new_inp[1]->shape()[2];
  976. ocpg = new_inp[1]->shape()[1];
  977. if (icpg == 1 && ocpg == 1) {
  978. is_channel_wise = true;
  979. }
  980. }
  981. if (ocpg % 4 != 0 && !is_channel_wise) {
  982. VarNodeArray t_inp = new_inp;
  983. auto param = megdnn::param::RelayoutFormat();
  984. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  985. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  986. t_inp[0] = rf.node();
  987. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  988. opr->config());
  989. return new_opr;
  990. }
  991. // new input src is NHWCD4
  992. auto&& fmt = new_inp[0]
  993. ->format()
  994. .as_impl<megdnn::Image2DPack4TensorFormat>();
  995. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  996. conv_src = new_inp[0];
  997. }
  998. mgb_assert(new_inp[1]->format().type() !=
  999. TensorFormat::Type::IMAGE2D_PACK4);
  1000. auto param = megdnn::param::RelayoutFormat();
  1001. param.mode = filter_mode(conv_opr.param().sparse, new_inp[1]);
  1002. auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
  1003. conv_weights = relayout_weight.node();
  1004. auto new_param = conv_opr.param();
  1005. new_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1006. mgb_assert(conv_src->shape().ndim == 5 &&
  1007. conv_src->format().type() ==
  1008. TensorFormat::Type::IMAGE2D_PACK4);
  1009. auto new_conv_opr = opr::Convolution::make(
  1010. conv_src, conv_weights, new_param, conv_opr.execution_policy(),
  1011. conv_opr.config());
  1012. OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
  1013. mgb_assert(new_conv_opr.shape().ndim == 5 &&
  1014. new_conv_opr.format().type() ==
  1015. TensorFormat::Type::IMAGE2D_PACK4);
  1016. return ret;
  1017. };
  1018. auto replace_conv_bias_opr = [&filter_mode](OperatorNodeBase* opr,
  1019. const VarNodeArray& new_inp) {
  1020. mgb_assert(opr->input().size() == new_inp.size());
  1021. auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  1022. mgb_assert(conv_bias_opr.param().format ==
  1023. megdnn::param::ConvBias::Format::NCHW,
  1024. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1025. VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr,
  1026. *conv_bias_bias = nullptr;
  1027. if (new_inp[0]->shape().ndim == 4) {
  1028. // new input src is NCHW
  1029. size_t group, icpg, ocpg;
  1030. if (conv_bias_opr.param().sparse ==
  1031. megdnn::param::ConvBias::Sparse::DENSE) {
  1032. group = 1;
  1033. icpg = new_inp[1]->shape()[1];
  1034. ocpg = new_inp[1]->shape()[0];
  1035. } else {
  1036. mgb_assert(conv_bias_opr.param().sparse ==
  1037. megdnn::param::ConvBias::Sparse::GROUP);
  1038. group = new_inp[1]->shape()[0];
  1039. icpg = new_inp[1]->shape()[2];
  1040. ocpg = new_inp[1]->shape()[1];
  1041. }
  1042. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1043. auto param = megdnn::param::RelayoutFormat();
  1044. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1045. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1046. conv_bias_src = rf.node();
  1047. } else {
  1048. // can not convert to hwcd4
  1049. return serialization::copy_opr_shallow(*opr, new_inp,
  1050. opr->config());
  1051. }
  1052. } else {
  1053. size_t ocpg;
  1054. bool is_channel_wise = false;
  1055. if (conv_bias_opr.param().sparse ==
  1056. megdnn::param::ConvBias::Sparse::DENSE) {
  1057. ocpg = new_inp[1]->shape()[0];
  1058. } else {
  1059. mgb_assert(conv_bias_opr.param().sparse ==
  1060. megdnn::param::ConvBias::Sparse::GROUP);
  1061. size_t icpg = new_inp[1]->shape()[2];
  1062. ocpg = new_inp[1]->shape()[1];
  1063. if (icpg == 1 && ocpg == 1) {
  1064. is_channel_wise = true;
  1065. }
  1066. }
  1067. if (ocpg % 4 != 0 && !is_channel_wise) {
  1068. VarNodeArray t_inp = new_inp;
  1069. auto param = megdnn::param::RelayoutFormat();
  1070. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1071. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1072. t_inp[0] = rf.node();
  1073. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1074. opr->config());
  1075. return new_opr;
  1076. }
  1077. // new input src is NHWCD4
  1078. auto&& fmt = new_inp[0]
  1079. ->format()
  1080. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1081. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1082. conv_bias_src = new_inp[0];
  1083. }
  1084. mgb_assert(new_inp[1]->format().type() !=
  1085. TensorFormat::Type::IMAGE2D_PACK4);
  1086. auto param = megdnn::param::RelayoutFormat();
  1087. param.mode = filter_mode(conv_bias_opr.param().sparse, new_inp[1]);
  1088. auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
  1089. conv_bias_weights = relayout_weight.node();
  1090. mgb_assert(new_inp.size() < 4,
  1091. "ConvertFormat pass does not support fuse Z");
  1092. bool has_bias = new_inp.size() > 2;
  1093. if (has_bias &&
  1094. new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) {
  1095. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1096. auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
  1097. conv_bias_bias = relayout_bias.node();
  1098. } else if (has_bias) {
  1099. conv_bias_bias = new_inp[2];
  1100. }
  1101. auto new_param = conv_bias_opr.param();
  1102. new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
  1103. mgb_assert(conv_bias_src->shape().ndim == 5 &&
  1104. conv_bias_src->format().type() ==
  1105. TensorFormat::Type::IMAGE2D_PACK4);
  1106. SymbolVar new_conv_bias_opr;
  1107. if (has_bias) {
  1108. new_conv_bias_opr = opr::ConvBias::make(
  1109. conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
  1110. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1111. } else {
  1112. new_conv_bias_opr = opr::ConvBias::make(
  1113. conv_bias_src, conv_bias_weights, new_param,
  1114. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1115. }
  1116. OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
  1117. mgb_assert(new_conv_bias_opr.shape().ndim == 5 &&
  1118. new_conv_bias_opr.format().type() ==
  1119. TensorFormat::Type::IMAGE2D_PACK4);
  1120. return ret;
  1121. };
  1122. auto replace_deconv_opr = [&filter_mode](OperatorNodeBase* opr,
  1123. const VarNodeArray& new_inp) {
  1124. mgb_assert(opr->input().size() == new_inp.size());
  1125. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  1126. mgb_assert(deconv_opr.param().format ==
  1127. megdnn::param::Convolution::Format::NCHW,
  1128. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1129. VarNode *deconv_src = nullptr, *deconv_weights = nullptr;
  1130. if (new_inp[1]->shape().ndim == 4) {
  1131. // new input src is NCHW
  1132. size_t group, icpg, ocpg;
  1133. if (deconv_opr.param().sparse ==
  1134. megdnn::param::Convolution::Sparse::DENSE) {
  1135. group = 1;
  1136. icpg = new_inp[0]->shape()[0];
  1137. ocpg = new_inp[0]->shape()[1];
  1138. } else {
  1139. mgb_assert(deconv_opr.param().sparse ==
  1140. megdnn::param::Convolution::Sparse::GROUP);
  1141. group = new_inp[0]->shape()[0];
  1142. icpg = new_inp[0]->shape()[1];
  1143. ocpg = new_inp[0]->shape()[2];
  1144. }
  1145. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1146. auto param = megdnn::param::RelayoutFormat();
  1147. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1148. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1149. deconv_src = rf.node();
  1150. } else {
  1151. // can not convert to hwcd4
  1152. return serialization::copy_opr_shallow(*opr, new_inp,
  1153. opr->config());
  1154. }
  1155. } else {
  1156. //! XXXX, fix me, check filter size
  1157. size_t ocpg;
  1158. if (deconv_opr.param().sparse ==
  1159. megdnn::param::Convolution::Sparse::DENSE) {
  1160. ocpg = new_inp[0]->shape()[1];
  1161. } else {
  1162. mgb_assert(deconv_opr.param().sparse ==
  1163. megdnn::param::Convolution::Sparse::GROUP);
  1164. ocpg = new_inp[0]->shape()[2];
  1165. }
  1166. if (ocpg % 4 != 0) {
  1167. VarNodeArray t_inp = new_inp;
  1168. auto param = megdnn::param::RelayoutFormat();
  1169. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1170. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1171. t_inp[1] = rf.node();
  1172. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1173. opr->config());
  1174. return new_opr;
  1175. }
  1176. // new input src is NHWCD4
  1177. auto&& fmt = new_inp[1]
  1178. ->format()
  1179. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1180. mgb_assert(new_inp[1]->shape().ndim == 5 && fmt.align_axis() == 2);
  1181. deconv_src = new_inp[1];
  1182. }
  1183. mgb_assert(new_inp[0]->format().type() !=
  1184. TensorFormat::Type::IMAGE2D_PACK4);
  1185. auto param = megdnn::param::RelayoutFormat();
  1186. param.mode = filter_mode(deconv_opr.param().sparse, new_inp[0]);
  1187. auto relayout_weight = opr::RelayoutFormat::make(new_inp[0], param);
  1188. deconv_weights = relayout_weight.node();
  1189. auto new_param = deconv_opr.param();
  1190. new_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1191. mgb_assert(deconv_src->shape().ndim == 5 &&
  1192. deconv_src->format().type() ==
  1193. TensorFormat::Type::IMAGE2D_PACK4);
  1194. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  1195. deconv_weights, deconv_src, new_param,
  1196. deconv_opr.execution_policy(), deconv_opr.config());
  1197. OperatorNodeBase* ret = new_deconv_opr.node()->owner_opr();
  1198. mgb_assert(new_deconv_opr.shape().ndim == 5 &&
  1199. new_deconv_opr.format().type() ==
  1200. TensorFormat::Type::IMAGE2D_PACK4);
  1201. return ret;
  1202. };
  1203. /* This helper function guarantees that the format convert pass does not
  1204. * change an output var's channel count. Changing the output channel count
  1205. * would cause a channel mismatch when replacing the conv/conv_bias operators.
  1206. */
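// Concretely: if the converted input still has the original NCHW shape and its
// channel count is not a multiple of 4, the opr is shallow-copied unchanged
// instead of being converted to NHWCD4.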
  1207. auto replace_helper = [](OperatorNodeBase* opr,
  1208. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1209. auto&& new_shp = new_inp[0]->shape();
  1210. size_t inp_channel = new_shp[1];
  1211. if (new_shp.eq_shape(opr->input(0)->shape()) && inp_channel % 4 != 0) {
  1212. auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
  1213. opr->config());
  1214. return new_opr;
  1215. }
  1216. return nullptr;
  1217. };
    auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr,
                                               const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
        mgb_assert(resize_opr.param().format ==
                           megdnn::param::Resize::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = resize_opr.param();
        new_param.format = megdnn::param::Resize::Format::NHWCD4;
        auto new_resize_opr = opr::ResizeForward::make(
                inp, new_inp[1], new_param, opr->config());
        return new_resize_opr.node()->owner_opr();
    };
    auto replace_warp_perspective_opr = [replace_helper](
                                                OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
        mgb_assert(warp_opr.param().format ==
                           megdnn::param::WarpPerspective::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = warp_opr.param();
        new_param.format = megdnn::param::WarpPerspective::Format::NHWCD4;
        SymbolVar new_warp_opr;
        if (new_inp.size() == 3) {
            new_warp_opr = opr::WarpPerspectiveForward::make(
                    inp, new_inp[1], nullptr, new_inp[2], new_param,
                    opr->config());
        } else {
            mgb_assert(new_inp.size() == 4);
            new_warp_opr = opr::WarpPerspectiveForward::make(
                    inp, new_inp[1], new_inp[2], new_inp[3], new_param,
                    opr->config());
        }
        return new_warp_opr.node()->owner_opr();
    };
    auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr,
                                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
        mgb_assert(warp_opr.param().format ==
                           megdnn::param::WarpAffine::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = warp_opr.param();
        new_param.format = megdnn::param::WarpAffine::Format::NHWCD4;
        SymbolVar new_warp_opr;
        new_warp_opr = opr::WarpAffineForward::make(inp, new_inp[1], new_inp[2],
                                                    new_param, opr->config());
        return new_warp_opr.node()->owner_opr();
    };
    auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
        mgb_assert(pooling_opr.param().format ==
                           megdnn::param::Pooling::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = pooling_opr.param();
        new_param.format = megdnn::param::Pooling::Format::NHWCD4;
        auto new_pooling_opr =
                opr::PoolingForward::make(inp, new_param, opr->config());
        return new_pooling_opr.node()->owner_opr();
    };
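    /* Note on the layout (an explanatory sketch, not code from this pass):
     * NHWCD4 stores a (N, C, H, W) tensor as (N, H, (C + 3) / 4, W, 4) in an
     * IMAGE2D_PACK4 format with align_axis == 2; e.g. a (1, 8, 32, 32) NCHW
     * tensor corresponds to a (1, 32, 2, 32, 4) NHWCD4 tensor. var_to_chw
     * below relayouts such a var back to NCHW whenever the replaced var no
     * longer matches the original 4-D shape. */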
    auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
        if (!inp->shape().eq_shape(new_inp->shape())) {
            mgb_assert(inp->shape().ndim == 4 &&
                       inp->format().type() !=
                               TensorFormat::Type::IMAGE2D_PACK4);
            mgb_assert(new_inp->shape().ndim == 5 &&
                       new_inp->format().type() ==
                               TensorFormat::Type::IMAGE2D_PACK4);
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
            auto rf = opr::RelayoutFormat::make(new_inp, param);
            return rf.node();
        }
        return new_inp;
    };
    auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray t_inp = new_inp;
        for (size_t i = 0; i < opr->input().size(); i++) {
            t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, t_inp, opr->config());
        return new_opr;
    };
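    /* Elemwise oprs are layout-agnostic, but all of their inputs must share
     * one format: if any input already carries a non-default (image) format,
     * every remaining 4-D NCHW input is relayouted to NHWCD4 as well, while
     * scalar inputs are allowed to pass through unchanged. */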
    auto replace_elemwise_opr = [](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool has_inp_changed = false;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (!new_inp[i]->format().is_default()) {
                has_inp_changed = true;
                break;
            }
        }
        if (has_inp_changed) {
            // assumption: all inputs are changed from nchw to nhwcd4
            auto t_inp = new_inp;
            for (size_t i = 0; i < opr->input().size(); i++) {
                if (new_inp[i]->shape().ndim == 4) {
                    auto param = megdnn::param::RelayoutFormat();
                    param.mode =
                            megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
                    auto rf = opr::RelayoutFormat::make(new_inp[i], param);
                    t_inp[i] = rf.node();
                } else {
                    mgb_assert((new_inp[i]->shape().ndim == 5 &&
                                new_inp[i]->format().type() ==
                                        TensorFormat::Type::IMAGE2D_PACK4) ||
                               new_inp[i]->shape().is_scalar());
                }
            }
            return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
        } else {
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };
    /* This helper function converts the first input back to NCHW for
     * operators that do not support the NHWCD4 format.
     */
    auto relayout_first_inp_to_chw =
            [var_to_chw](OperatorNodeBase* opr,
                         const VarNodeArray& new_inp) -> OperatorNodeBase* {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray t_inp = new_inp;
        t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
        return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
    };
    auto ret = std::make_unique<ConvertFormatPass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
    replace_func[opr::GroupLocalForward::typeinfo()] =
            relayout_first_inp_to_chw;
    return ret;
    MIDOUT_E
}
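/* A minimal usage sketch for the pass built above. The factory name
 * (ConvertFormatPass::make_nhwcd4_converter) and the exact GraphOptimizer
 * calls are assumptions about the surrounding gopt API, not something shown
 * in this section:
 *
 *     gopt::GraphOptimizer optimizer;
 *     optimizer.add_pass(ConvertFormatPass::make_nhwcd4_converter());
 *     optimizer.apply_inplace(output_vars);  // output_vars: VarNodeArray
 */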
/* ================ ConvertBatchNormPass ================ */
const char* ConvertBatchNormToElemwisePass::name() const {
    return "convert_batch_norm";
}
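/* For a 5-input BatchNorm running in inference mode, the normalization is
 * rewritten into plain elemwise arithmetic on the forwarded statistics:
 *     y = scale * (x - mean) / sqrt(variance + epsilon) + bias
 * which is exactly what the rewriter below materializes via PowC{-0.5}. */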
void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
    MIDOUT_B("ConvertBatchNormToElemwisePass::apply")
    auto rewriter = state.graph().make_rewriter();
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
            if (bn->input().size() == 5) {
                mgb_assert(bn->param().fwd_mode ==
                           opr::BatchNorm::Param::FwdMode::INFERENCE);
                SymbolVar x = {rewriter.get_var(bn->input(0))};
                SymbolVar scale = {rewriter.get_var(bn->input(1))};
                SymbolVar bias = {rewriter.get_var(bn->input(2))};
                SymbolVar mean = {rewriter.get_var(bn->input(3))};
                SymbolVar variance = {rewriter.get_var(bn->input(4))};
                SymbolVar invsqrt_variance = opr::PowC::make(
                        variance + variance.make_scalar_dt(
                                           float(bn->param().epsilon)),
                        {-0.5});
                auto res = scale * (x - mean) * invsqrt_variance + bias;
                rewriter.replace_var(
                        opr->output(4), res.node(),
                        mgb_cstr_log(
                                "replace batch_norm(x, scale, bias, mean, "
                                "variance) "
                                "-> (scale * (x - mean) / sqrt(variance + eps) "
                                "+ bias)"));
                return;
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
/* ================ FuseConvBiasNonlinPass ================ */
const char* FuseConvBiasNonlinPass::name() const {
    return "combine_conv_bias_and_relu";
}
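/* Fuses conv(x, w) followed by a bias add and/or a nonlinearity into a single
 * ConvBias operator, and folds a TypeCvt from QuantizedS32 to QuantizedS8
 * that immediately follows a ConvBias into the ConvBias output dtype. */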
void FuseConvBiasNonlinPass::apply(OptState& state) const {
    MIDOUT_B("FuseConvBiasNonlinPass::apply")
    std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
    state.graph().iter([&m_deps](OperatorNodeBase* opr) {
        for (auto& inp : opr->input()) {
            m_deps[inp].push_back(opr);
        }
    });
    auto rewriter = state.graph().make_rewriter();
    using Mode = opr::Elemwise::Param::Mode;
    using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
    auto get_nonlinearity_mode = [&](opr::Elemwise* elem) -> NonlineMode {
        if (elem->param().mode == Mode::FUSE_ADD_RELU ||
            elem->param().mode == Mode::RELU) {
            return NonlineMode::RELU;
        } else if (elem->param().mode == Mode::FUSE_ADD_SIGMOID ||
                   elem->param().mode == Mode::SIGMOID) {
            return NonlineMode::SIGMOID;
        } else {
            return NonlineMode::IDENTITY;
        }
    };
    auto try_fuse_bias_nonlinearity = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 2);
        can_be_fused &= (elem->param().mode == Mode::FUSE_ADD_RELU) ||
                        (elem->param().mode == Mode::FUSE_ADD_TANH) ||
                        (elem->param().mode == Mode::FUSE_ADD_SIGMOID);
        return can_be_fused;
    };
    auto try_fuse_bias = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 2);
        can_be_fused &= (elem->param().mode == Mode::ADD);
        return can_be_fused;
    };
    auto try_fuse_nonlinearity = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 1);
        can_be_fused &= (elem->param().mode == Mode::RELU) ||
                        (elem->param().mode == Mode::TANH) ||
                        (elem->param().mode == Mode::SIGMOID);
        return can_be_fused;
    };
    auto convert_to_conv_bias_param = [&](const opr::Convolution::Param& param)
            -> opr::ConvBiasForward::Param {
        using Param = opr::ConvBiasForward::Param;
        return opr::ConvBiasForward::Param{Param::NonlineMode::IDENTITY,
                                           param.mode,
                                           param.sparse,
                                           param.format,
                                           param.pad_h,
                                           param.pad_w,
                                           param.stride_h,
                                           param.stride_w,
                                           param.dilate_h,
                                           param.dilate_w,
                                           0,
                                           param.compute_mode};
    };
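    /* A bias can be folded only when its shape equals the conv output shape
     * or matches the per-format channel-broadcast pattern checked below,
     * e.g. (1, OC, 1, 1) for NCHW or (1, 1, 1, OC) for NHWC. */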
    auto check_bias_shape = [&](opr::Convolution* conv, VarNode* bias) -> bool {
        bool valid_bias_shape = true;
        using Format = opr::Convolution::Param::Format;
        using Sparse = opr::Convolution::Param::Sparse;
        auto dst_shape = conv->output(0)->shape();
        auto filter_shape = conv->input(1)->shape();
        auto bias_shape = bias->shape();
        if (dst_shape.eq_shape(bias_shape)) {
            return valid_bias_shape;
        }
        size_t OC = filter_shape[0];
        if (conv->param().sparse == Sparse::GROUP) {
            OC *= filter_shape[1];
        }
        if (conv->param().format == Format::NCHW) {
            valid_bias_shape &=
                    ((bias_shape.ndim == 4) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == OC) && (bias_shape[2] == 1) &&
                     (bias_shape[3] == 1));
        } else if (conv->param().format == Format::NCHW4) {
            valid_bias_shape &=
                    ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == OC / 4) && (bias_shape[2] == 1) &&
                     (bias_shape[3] == 1) && bias_shape[4] == 4);
        } else if (conv->param().format == Format::NHWC) {
            valid_bias_shape &= ((bias_shape.ndim == 4) &&
                                 (bias_shape[0] == 1) && (bias_shape[1] == 1) &&
                                 (bias_shape[2] == 1) && (bias_shape[3] == OC));
        } else {
            valid_bias_shape &=
                    ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == 1) && (bias_shape[2] == OC) &&
                     (bias_shape[3] == 1) && (bias_shape[4] == 4));
            mgb_assert(conv->param().format == Format::NHWCD4);
        }
        return valid_bias_shape;
    };
    auto try_fuse_typecvt = [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
        mgb_assert(typecvt->input().size() == 1);
        auto conv_bias = try_cast_as_op<opr::ConvBias>(
                rewriter.get_var(typecvt->input(0))->owner_opr());
        if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
            typecvt->output(0)->dtype().enumv() !=
                    DTypeTrait<dtype::QuantizedS8>::enumv ||
            typecvt->input(0)->dtype().enumv() !=
                    DTypeTrait<dtype::QuantizedS32>::enumv)
            return nullptr;
        auto config = conv_bias->config();
        config.output_dtype(typecvt->output(0)->dtype());
        if (conv_bias->input().size() == 3) {
            // conv + bias
            return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
                                       conv_bias->input(2), conv_bias->param(),
                                       conv_bias->execution_policy(), config)
                    .node()
                    ->owner_opr();
        } else {
            // conv without bias
            return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
                                       conv_bias->param(),
                                       conv_bias->execution_policy(), config)
                    .node()
                    ->owner_opr();
        }
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        auto check_conv = [](opr::Convolution* conv) -> bool {
            return conv->param().format ==
                           megdnn::param::Convolution::Format::NHWCD4 ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NHWC ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NCHW ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NCHW4;
        };
        if (auto elem = try_cast_as_op<opr::Elemwise>(opr)) {
            if (try_fuse_bias_nonlinearity(elem) || try_fuse_bias(elem)) {
                auto inp1 = rewriter.get_var(elem->input(0));
                auto inp2 = rewriter.get_var(elem->input(1));
                opr::Convolution* conv = nullptr;
                size_t bias_idx = 0;
                if (inp1->owner_opr()->same_type<opr::Convolution>() &&
                    m_deps[elem->input(0)].size() == 1) {
                    conv = try_cast_as_op<opr::Convolution>(inp1->owner_opr());
                    bias_idx = 1;
                } else if (inp2->owner_opr()->same_type<opr::Convolution>() &&
                           m_deps[elem->input(1)].size() == 1) {
                    conv = try_cast_as_op<opr::Convolution>(inp2->owner_opr());
                    bias_idx = 0;
                }
                auto bias_inp = rewriter.get_var(elem->input(bias_idx));
                if (conv && check_conv(conv) &&
                    check_bias_shape(conv, bias_inp)) {
                    opr::ConvBiasForward::Param param =
                            convert_to_conv_bias_param(conv->param());
                    param.nonlineMode = get_nonlinearity_mode(elem);
                    auto new_var =
                            opr::ConvBiasForward::make(
                                    conv->input(0), conv->input(1), bias_inp,
                                    param, conv->execution_policy(),
                                    conv->config())
                                    .node();
                    rewriter.replace_var(
                            opr->output(0), new_var,
                            mgb_cstr_log("replace nonlinearity(conv(x, w) + b) "
                                         "-> conv_bias(x, w, b)"));
                    return;
                }
            } else if (try_fuse_nonlinearity(elem)) {
                auto inp = rewriter.get_var(elem->input(0));
                {
                    auto conv =
                            try_cast_as_op<opr::Convolution>(inp->owner_opr());
                    if (conv && check_conv(conv) &&
                        m_deps[elem->input(0)].size() == 1) {
                        opr::ConvBiasForward::Param param =
                                convert_to_conv_bias_param(conv->param());
                        param.nonlineMode = get_nonlinearity_mode(elem);
                        auto new_var = opr::ConvBiasForward::make(
                                               conv->input(0), conv->input(1),
                                               param, conv->execution_policy(),
                                               conv->config())
                                               .node();
                        rewriter.replace_var(
                                opr->output(0), new_var,
                                mgb_cstr_log("replace nonlinearity(conv(x, w)) "
                                             "-> conv_bias(x, w)"));
                        return;
                    }
                }
                {
                    auto conv = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
                    auto check_conv_bias = [&](opr::ConvBias* opr) {
                        return opr->param().format ==
                                       opr::ConvBias::Param::Format::NHWC ||
                               opr->param().format ==
                                       opr::ConvBias::Param::Format::NCHW ||
                               opr->param().format ==
                                       opr::ConvBias::Param::Format::NCHW4;
                    };
                    if (conv && check_conv_bias(conv) &&
                        m_deps[elem->input(0)].size() == 1) {
                        auto param = conv->param();
                        param.nonlineMode = get_nonlinearity_mode(elem);
                        auto new_var = opr::ConvBiasForward::make(
                                               conv->input(0), conv->input(1),
                                               conv->input(2), param,
                                               conv->execution_policy(),
                                               conv->config())
                                               .node();
                        rewriter.replace_var(
                                opr->output(0), new_var,
                                mgb_cstr_log("replace nonlinearity(conv(x, w)) "
                                             "-> conv_bias(x, w)"));
                        return;
                    }
                }
            }
        } else if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
            auto new_opr = try_fuse_typecvt(typecvt);
            if (new_opr) {
                rewriter.replace_var(
                        opr->output(0), new_opr->output(0),
                        mgb_cstr_log("replace typecvt(conv_bias(x, w, b)) -> "
                                     "conv_bias(x, w, b)"));
                return;
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
/* ================ FuseConvBiasZPass ================ */
const char* FuseConvBiasZPass::name() const {
    return "combine_conv_bias_and_z";
}
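/* Fuses an elemwise ADD / FUSE_ADD_RELU (or the quantized QADD /
 * QFUSE_ADD_RELU / QFUSE_ADD_H_SWISH variants) of a ConvBias output with
 * another var z into the four-input form conv_bias(x, w, b, z), provided the
 * ConvBias output has a single reader and the shapes and dtypes match. */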
void FuseConvBiasZPass::apply(OptState& state) const {
    MIDOUT_B("FuseConvBiasZPass::apply")
    UniqReaderCheck uniq_reader_check{state.graph()};
    auto rewriter = state.graph().make_rewriter();
    using Mode = opr::Elemwise::Param::Mode;
    using MultiMode = opr::ElemwiseMultiType::Param::Mode;
    using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
    auto check_conv_bias = [](opr::ConvBias* conv_bias) -> bool {
        return conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NHWC ||
               conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NCHW ||
               conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NCHW4;
    };
    auto check_fuse_shape = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
        bool valid_fuse_shape = true;
        auto z_shape = z->shape();
        auto bias_shape = conv_bias->input(2)->shape();
        auto conv_bias_shape = conv_bias->output(0)->shape();
        valid_fuse_shape &= (!conv_bias_shape.eq_shape(bias_shape));
        valid_fuse_shape &= conv_bias_shape.eq_shape(z_shape);
        return valid_fuse_shape;
    };
    auto check_fuse_dtype = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
        return conv_bias->output(0)->dtype().enumv() == z->dtype().enumv();
    };
    auto get_convbias_nonline_mode = [&](OperatorNodeBase* opr) -> NonlineMode {
        if (opr->same_type<opr::Elemwise>()) {
            auto elem = try_cast_as_op<opr::Elemwise>(opr);
            if (elem->param().mode == Mode::FUSE_ADD_RELU)
                return NonlineMode::RELU;
        }
        if (opr->same_type<opr::ElemwiseMultiType>()) {
            auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
            if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
                return NonlineMode::RELU;
            else if (elem->param().mode == MultiMode::QFUSE_ADD_H_SWISH)
                return NonlineMode::H_SWISH;
        }
        return NonlineMode::IDENTITY;
    };
    auto try_replace_var_node = [&](OperatorNodeBase* opr) {
        opr::ConvBias* conv_bias = nullptr;
        size_t z_idx = 0;
        size_t nr_inps = opr->input().size();
        for (size_t i = 0; i < nr_inps; i++) {
            auto inp = rewriter.get_var(opr->input(i));
            if (inp->owner_opr()->same_type<opr::ConvBias>()) {
                auto cb = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
                if (cb->input().size() == 3 &&
                    cb->param().nonlineMode ==
                            opr::ConvBias::Param::NonlineMode::IDENTITY &&
                    uniq_reader_check(opr->input(i))) {
                    conv_bias = cb;
                    z_idx = nr_inps - i - 1;
                    break;
                }
            }
        }
        auto z_inp = rewriter.get_var(opr->input(z_idx));
        if (conv_bias && check_conv_bias(conv_bias) &&
            check_fuse_shape(conv_bias, z_inp) &&
            check_fuse_dtype(conv_bias, z_inp)) {
            auto param = conv_bias->param();
            param.nonlineMode = get_convbias_nonline_mode(opr);
            auto config = conv_bias->config();
            auto new_var = opr::ConvBiasForward::make(
                                   conv_bias->input(0), conv_bias->input(1),
                                   conv_bias->input(2), z_inp, param,
                                   conv_bias->execution_policy(),
                                   config.output_dtype(opr->output(0)->dtype()))
                                   .node();
            rewriter.replace_var(
                    opr->output(0), new_var,
                    mgb_cstr_log("replace "
                                 "nonlinearity(conv_bias(x,w,b) + z) "
                                 "-> conv_bias(x, w, b, z)"));
            uniq_reader_check.update_on_opr_auto_replace(opr,
                                                         new_var->owner_opr());
            return true;
        }
        return false;
    };
    auto try_fuse_elemwise = [&](OperatorNodeBase* opr) {
        if (!opr->same_type<opr::Elemwise>())
            return false;
        auto elem = try_cast_as_op<opr::Elemwise>(opr);
        if (elem->input().size() != 2)
            return false;
        if (elem->param().mode != Mode::ADD &&
            elem->param().mode != Mode::FUSE_ADD_RELU)
            return false;
        return try_replace_var_node(opr);
    };
    auto try_fuse_elemwise_multi_type = [&](OperatorNodeBase* opr) {
        if (!opr->same_type<opr::ElemwiseMultiType>())
            return false;
        auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
        if (elem->input().size() != 2)
            return false;
        if (elem->param().mode != MultiMode::QADD &&
            elem->param().mode != MultiMode::QFUSE_ADD_RELU &&
            elem->param().mode != MultiMode::QFUSE_ADD_H_SWISH)
            return false;
        return try_replace_var_node(opr);
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (try_fuse_elemwise(opr))
            return;
        if (try_fuse_elemwise_multi_type(opr))
            return;
        auto new_opr = rewriter.auto_replace_outputs(opr);
        uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
/* ================ FuseDeconvCvtPass ================ */
const char* FuseDeconvCvtPass::name() const {
    return "combine_deconv_and_typecvt";
}
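/* Folds a TypeCvt to QuantizedS8 that immediately follows a
 * ConvolutionBackwardData into the deconvolution itself, by rebuilding the
 * deconv with its output dtype set to the TypeCvt target dtype. */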
void FuseDeconvCvtPass::apply(OptState& state) const {
    MIDOUT_B("FuseDeconvCvtPass::apply")
    std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
    state.graph().iter([&m_deps](OperatorNodeBase* opr) {
        for (auto& inp : opr->input()) {
            m_deps[inp].push_back(opr);
        }
    });
    UniqReaderCheck uniq_reader_check{state.graph()};
    auto rewriter = state.graph().make_rewriter();
    auto try_fuse_deconv_typecvt =
            [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
        mgb_assert(typecvt->input().size() == 1);
        auto deconv = try_cast_as_op<opr::ConvolutionBackwardData>(
                rewriter.get_var(typecvt->input(0))->owner_opr());
        if (!deconv || m_deps.count(typecvt->input(0)) != 1 ||
            typecvt->output(0)->dtype().enumv() !=
                    DTypeTrait<dtype::QuantizedS8>::enumv) {
            return nullptr;
        }
        if (!uniq_reader_check(deconv->output(0)))
            return nullptr;
        auto config = deconv->config();
        config.output_dtype(typecvt->output(0)->dtype());
        return opr::ConvolutionBackwardData::make(
                       deconv->input(0), deconv->input(1), deconv->param(),
                       deconv->execution_policy(), config)
                .node()
                ->owner_opr();
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
            if (auto deconv_new = try_fuse_deconv_typecvt(typecvt)) {
                rewriter.replace_var(
                        opr->output(0), deconv_new->output(0),
                        mgb_cstr_log("replace typecvt(deconv(x, w)) -> "
                                     "deconv(x, w)"));
                uniq_reader_check.update_on_opr_auto_replace(opr, deconv_new);
                return;
            }
        }
        auto new_opr = rewriter.auto_replace_outputs(opr);
        uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
/* ================ ParamMergePass ================ */
const char* ParamMergePass::name() const {
    return mgb_cstr_log("param_merge");
}
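/* Merges individual parameter tensors (SharedDeviceTensor and its WithFormat
 * variant) into a single MultipleDeviceTensorHolder, so the graph carries one
 * holder operator instead of one operator per parameter. */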
void ParamMergePass::apply(OptState& opt_state) const {
    MIDOUT_B("ParamMergePass::apply")
    param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
            opt_state);
    param_merge<opr::SharedDeviceTensorWithFormat,
                opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
    MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
