
inference.cpp 92 kB

  1. /**
  2. * \file src/gopt/impl/inference.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/gopt/inference.h"
  12. #include "megbrain/gopt/gtrans.h"
  13. #include "megbrain/gopt/basic_arith.h"
  14. #include "megbrain/graph/event.h"
  15. #include "megbrain/opr/dnn/batch_norm.h"
  16. #include "megbrain/opr/dnn/local.h"
  17. #include "megbrain/opr/search_policy/algo_chooser_helper.h"
  18. #include "megbrain/opr/search_policy/profiler.h"
  19. #include "megbrain/utils/shared_set.h"
  20. #include "megbrain/serialization/opr_shallow_copy.h"
  21. #include "megbrain/opr/basic_arith.h"
  22. #include "megbrain/opr/dnn/convolution.h"
  23. #include "megbrain/opr/blas.h"
  24. #include "megbrain/opr/misc.h"
  25. #include "megbrain/opr/utility.h"
  26. #include "megbrain/opr/dnn/pooling.h"
  27. #include "megbrain/opr/tensor_manip.h"
  28. #include "megbrain/opr/imgproc.h"
  29. #include "megbrain/opr/nn_int.h"
  30. #include "megbrain/opr/tensor_gen.h"
  31. #include "megbrain/utils/hash_ct.h"
  32. #include "megdnn/tensor_format.h"
  33. #if MGB_ENABLE_TENSOR_RT
  34. #include "megbrain/tensorrt/tensorrt_opr.h"
  35. #endif
  36. #include "megbrain/gopt/misc.h"
  37. #include "megbrain/utils/hash_ct.h"
  38. #include "midout.h"
  39. MIDOUT_DECL(megbrain_inference)
  40. #define MIDOUT_B(tag) \
  41. MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) {
  42. #define MIDOUT_E \
  43. } \
  44. MIDOUT_END();
  45. using namespace mgb;
  46. using namespace gopt;
  47. namespace {
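//! param_merge: collect every SharedDeviceTensor (or the *Format variant) opr
//! in the graph and replace them with a single MultipleDeviceTensorHolder
//! holding all of their values, so the graph keeps one parameter opr instead
//! of many; explicitly set var names are preserved on the merged outputs.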
  48. template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
  49. void param_merge(OptState& opt_state) {
  50. auto rewriter = opt_state.graph().make_rewriter();
  51. ThinHashMap<OperatorNodeBase*, size_t> opr2idx;
  52. std::vector<OperatorNodeBase*> all_oprs;
  53. typename MultipleDeviceTensorHolder::ValueArray all_values;
  54. auto cb_find_opr = [&](cg::OperatorNodeBase* opr) {
  55. if (opr->same_type<SharedDeviceTensor>()) {
  56. auto p = &opr->cast_final<SharedDeviceTensor>();
  57. // ShredD may be manu
  58. opr2idx[p] = all_values.size();
  59. all_values.push_back(p->dev_data());
  60. all_oprs.push_back(p);
  61. }
  62. };
  63. opt_state.graph().iter(cb_find_opr);
  64. SymbolVarArray new_vars;
  65. auto cb_replace = [&](cg::OperatorNodeBase* opr) {
  66. auto iter = opr2idx.find(opr);
  67. if (iter == opr2idx.end()) {
  68. rewriter.auto_replace_outputs(opr);
  69. } else {
  70. if (new_vars.empty()) {
  71. // new oprs must be created in the iter callback, so we populate
  72. // new_vars lazily
  73. new_vars = MultipleDeviceTensorHolder::make(
  74. *opt_state.graph().comp_graph(), std::move(all_values),
  75. {ssprintf("merged%zu", all_values.size())});
  76. for (size_t i = 0; i < new_vars.size(); ++i) {
  77. auto src = all_oprs[i]->output(0);
  78. if (src->has_name_set()) {
  79. new_vars[i].rename(src->name());
  80. }
  81. }
  82. }
  83. rewriter.replace_var(
  84. opr->output(0), new_vars.at(iter->second).node(),
  85. mgb_cstr_log("replace multi SharedDeviceTensor(Format) to "
  86. "MultipleDeviceTensorHolder(Format)"));
  87. }
  88. };
  89. opt_state.graph().iter(cb_replace);
  90. rewriter.apply_inplace();
  91. }
  92. }
  93. /* ================ global functions ================ */
  94. SymbolVarArray gopt::optimize_for_inference(
  95. const SymbolVarArray& dest_vars,
  96. const OptimizeForInferenceOptions& opt) {
  97. return gopt::GraphOptimizer()
  98. .add_preset_passes(false, &opt,
  99. &dest_vars[0].node()->owner_graph()->options())
  100. .apply({dest_vars})
  101. .endpoint_vars();
  102. }
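// Usage sketch (illustrative only; the option setters on
// OptimizeForInferenceOptions are assumed here, see megbrain/gopt/inference.h):
//   gopt::OptimizeForInferenceOptions opt;
//   opt.enable_fuse_conv_bias_nonlinearity();  // assumed setter name
//   SymbolVarArray optimized = gopt::optimize_for_inference({y}, opt);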
  103. namespace {
  104. void modify_conv_strategy(
  105. opr::mixin::AlgoChooserHelper& conv,
  106. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  107. auto policy = conv.execution_policy_transient();
  108. policy.strategy = strategy;
  109. conv.set_execution_policy(policy);
  110. }
  111. template <typename Opr>
  112. void inplace_conv_opr_modifier(
  113. OperatorNodeBase& opr,
  114. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  115. modify_conv_strategy(
  116. opr.cast_final_safe<Opr>(),
  117. strategy);
  118. }
  119. void modify_conv_policy_workspace_limit(opr::mixin::AlgoChooserHelper& conv,
  120. size_t workspace_limit) {
  121. auto policy = conv.execution_policy_transient();
  122. policy.workspace_limit = workspace_limit;
  123. conv.set_execution_policy(policy);
  124. }
  125. template <typename Opr>
  126. void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
  127. size_t workspace_limit) {
  128. modify_conv_policy_workspace_limit(opr.cast_final_safe<Opr>(),
  129. workspace_limit);
  130. }
  131. } // anonymous namespace
  132. void gopt::modify_opr_algo_strategy_inplace(
  133. const VarNodeArrayView& dest_vars,
  134. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  135. #if !MGB_ENABLE_FASTRUN
  136. using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
  137. if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) {
  138. mgb_throw(MegBrainError, "fastrun is disabled at compile time");
  139. }
  140. #endif
  141. const ThinHashMap<Typeinfo*, std::function<void(OperatorNodeBase&)>>
  142. modifiers = {
  143. #define CONV(t) \
  144. {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \
  145. std::placeholders::_1, strategy)},
  146. MGB_FOREACH_FASTRUN_OPR(CONV)
  147. #undef CONV
  148. };
  149. auto on_opr = [&](OperatorNodeBase* opr) {
  150. auto iter = modifiers.find(opr->dyn_typeinfo());
  151. if (iter != modifiers.end()) {
  152. iter->second(*opr);
  153. }
  154. };
  155. cg::DepOprIter dep_iter{on_opr};
  156. for (auto i : dest_vars) {
  157. dep_iter.add(i);
  158. }
  159. }
  160. void gopt::enable_opr_algo_profiling_inplace(
  161. const VarNodeArrayView& dest_vars) {
  162. modify_opr_algo_strategy_inplace(
  163. dest_vars,
  164. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE);
  165. }
  166. void gopt::enable_opr_use_profiling_cache_inplace(
  167. const VarNodeArrayView& dest_vars) {
  168. modify_opr_algo_strategy_inplace(
  169. dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy::
  170. Strategy::PROFILE_HEURISTIC);
  171. }
  172. void gopt::set_opr_algo_workspace_limit_inplace(
  173. const VarNodeArrayView& dest_vars, size_t workspace_limit) {
  174. static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
  175. modifiers = {
  176. #define CONV(t) \
  177. {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>},
  178. MGB_FOREACH_FASTRUN_OPR(CONV)
  179. #undef CONV
  180. };
  181. auto on_opr = [&](OperatorNodeBase* opr) {
  182. auto iter = modifiers.find(opr->dyn_typeinfo());
  183. if (iter != modifiers.end()) {
  184. iter->second(*opr, workspace_limit);
  185. }
  186. };
  187. cg::DepOprIter dep_iter{on_opr};
  188. for (auto i : dest_vars) {
  189. dep_iter.add(i);
  190. }
  191. }
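// Usage sketch (illustrative only): enable profiling-based algo selection and
// cap the algo workspace for every fastrun-capable opr reachable from the
// endpoints; the 256 MiB limit is just an example value.
//   gopt::enable_opr_algo_profiling_inplace(dest_vars);
//   gopt::set_opr_algo_workspace_limit_inplace(dest_vars, 256 * 1024 * 1024);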
  192. /* ================ ParamRedistributePass ================ */
  193. const char* ParamRedistributePass::name() const {
  194. return mgb_cstr_log("param_redistribute");
  195. }
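//! ParamRedistributePass reassociates and redistributes elementwise arithmetic
//! so that constant parameters end up next to each other (e.g. a form
//! f(g(a, b), c) with b and c const is rewritten so the two constants meet),
//! which lets the later constant-folding pass (ParamFusePass) precompute them.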
  196. class ParamRedistributePass::Impl final: public RecursiveSubGraphRewriteHelper {
  197. ConstVarPropogate m_cvprop;
  198. UniqReaderCheck m_uniq_reader_check;
  199. //! oprs already processed in try_distribute_then_reassociate() should be
  200. //! skipped in on_new_opr_check_should_process()
  201. ThinHashSet<OperatorNodeBase*> m_opr_blacklist;
  202. std::string m_distribute_reasso_log_msg;
  203. //! try applying BinaryTrans20::associtive
  204. GTransResult try_reassociate(OperatorNodeBase *opr);
  205. //! try applying BinaryTrans20::distributive_add
  206. GTransResult try_distribute_add(OperatorNodeBase *opr);
  207. //! try to distribute MUL/DIV over ADD/SUB and then reassociate
  208. GTransResult try_distribute_then_reassociate(OperatorNodeBase *opr);
  209. GTransResult process_opr(VarNode *out_var) override;
  210. bool on_new_opr_check_should_process(
  211. OperatorNodeBase*opr, OperatorNodeBase *repl_opr) override {
  212. m_uniq_reader_check.update_on_opr_auto_replace(opr, repl_opr);
  213. auto ins = m_cvprop.add_opr(opr);
  214. return ins.has_const_inp && !ins.all_const_inp &&
  215. !m_opr_blacklist.count(opr);
  216. };
  217. void after_replace_var(VarNode *orig_var, VarNode* new_var) override {
  218. m_uniq_reader_check.update_on_opr_auto_replace(orig_var->owner_opr(),
  219. new_var->owner_opr());
  220. }
  221. /*!
  222. * \brief try to reorder opr inputs to a const one and a non-const one
  223. *
  224. * return true if it can be reformulated as f(nci, ci), where nci is
  225. * non-const and ci is const.
  226. */
  227. bool reorder_for_normconst(OperatorNodeBase *opr,
  228. bool &swap_inp, VarNode *&nci, VarNode *&ci);
  229. public:
  230. Impl(OptState &state);
  231. };
  232. GTransResult ParamRedistributePass::Impl::process_opr(VarNode *out_var) {
  233. auto opr = out_var->owner_opr();
  234. auto trans = try_reassociate(opr);
  235. if (!trans.valid()) {
  236. trans = try_distribute_add(opr);
  237. if (!trans.valid())
  238. trans = try_distribute_then_reassociate(opr);
  239. }
  240. return trans;
  241. }
  242. GTransResult ParamRedistributePass::Impl::try_reassociate(
  243. OperatorNodeBase *opr) {
  244. // apply BinaryTrans20::associtive() if opr is of the form f(g(a, b), c) and b and c are
  245. // const
  246. bool swap_fop_inp = false, swap_gop_inp = false;
  247. VarNode *a, *b, *c, *ab;
  248. if (!reorder_for_normconst(opr, swap_fop_inp, ab, c))
  249. return None;
  250. if (!m_uniq_reader_check(ab))
  251. return None;
  252. if (!reorder_for_normconst(ab->owner_opr(), swap_gop_inp, a, b))
  253. return None;
  254. return BinaryTrans20::associtive().apply(opr, swap_fop_inp, swap_gop_inp);
  255. }
  256. GTransResult ParamRedistributePass::Impl::try_distribute_add(
  257. OperatorNodeBase *opr) {
  258. if (opr->same_type<opr::Elemwise>() || opr->input().size() != 2)
  259. return None;
  260. if (!m_cvprop.is_const(opr->input(1)))
  261. return None;
  262. auto ab = as_elem_opr(opr->input(0)->owner_opr(), opr::Elemwise::Mode::ADD);
  263. if (ab) {
  264. bool swap;
  265. VarNode *a, *b;
  266. if (reorder_for_normconst(ab, swap, a, b)) {
  267. return BinaryTrans20::distributive_add().apply(
  268. opr, false, swap);
  269. }
  270. }
  271. return None;
  272. }
  273. GTransResult ParamRedistributePass::Impl::try_distribute_then_reassociate(
  274. OperatorNodeBase *opr) {
  275. if (!opr->same_type<opr::Elemwise>())
  276. return None;
  277. using Mode = opr::Elemwise::Mode;
  278. auto mode = opr->cast_final<opr::Elemwise>().param().mode;
  279. if (!(mode == Mode::MUL || mode == Mode::TRUE_DIV))
  280. return None;
  281. VarNode *a, *b;
  282. bool swap;
  283. if (!reorder_for_normconst(opr, swap, a, b))
  284. return None;
  285. auto chain_pred = [this](OperatorNodeBase *opr) {
  286. if (as_elem_opr(opr, Mode::ADD)) {
  287. auto var = opr->output(0);
  288. return m_uniq_reader_check(var) || m_cvprop.is_const(var);
  289. }
  290. return false;
  291. };
  292. auto chain = extract_opr_leaves(a, chain_pred);
  293. if (chain.size() <= 1)
  294. return None;
  295. std::unordered_map<VarNode*, VarNode*> repl_map;
  296. m_distribute_reasso_log_msg.clear();
  297. int nr_fail = 0, nr_succ = 0;
  298. for (auto &&var: chain) {
  299. {
  300. auto iter = repl_map.find(var);
  301. if (iter != repl_map.end()) {
  302. var = iter->second;
  303. continue;
  304. }
  305. }
  306. auto vnew = (SymbolVar{var} * b).node();
  307. m_opr_blacklist.insert(vnew->owner_opr());
  308. if (!m_cvprop.is_const(var)) {
  309. auto trans = try_reassociate(vnew->owner_opr());
  310. if (!trans.valid()) {
  311. // allow at most one failed redistribution
  312. if (nr_fail)
  313. return None;
  314. ++ nr_fail;
  315. } else {
  316. ++ nr_succ;
  317. vnew = trans->result;
  318. if (!m_distribute_reasso_log_msg.empty()) {
  319. m_distribute_reasso_log_msg.append(mgb_cstr_log(";"));
  320. }
  321. m_distribute_reasso_log_msg.append(trans->msg);
  322. }
  323. }
  324. repl_map[var] = vnew;
  325. var = vnew;
  326. }
  327. if (nr_succ) {
  328. m_distribute_reasso_log_msg.insert(0,
  329. mgb_cstr_log("distribute_mul("));
  330. m_distribute_reasso_log_msg.append(mgb_cstr_log(")"));
  331. return GTransResultItem{
  332. elemwise_reduce_var_list(chain, Mode::ADD),
  333. m_distribute_reasso_log_msg.c_str(),
  334. {}};
  335. }
  336. return None;
  337. }
  338. bool ParamRedistributePass::Impl::reorder_for_normconst(
  339. OperatorNodeBase *opr, bool &swap_inp, VarNode *&nci, VarNode *&ci) {
  340. if (opr->input().size() != 2)
  341. return false;
  342. nci = opr->input(0);
  343. ci = opr->input(1);
  344. if (!m_cvprop.is_const(ci)) {
  345. if (!is_commutable_binary(opr) || !m_cvprop.is_const(nci))
  346. return false;
  347. swap_inp = true;
  348. std::swap(nci, ci);
  349. } else {
  350. if (m_cvprop.is_const(nci))
  351. return false;
  352. swap_inp = false;
  353. }
  354. return true;
  355. }
  356. ParamRedistributePass::Impl::Impl(OptState &state):
  357. RecursiveSubGraphRewriteHelper{state},
  358. m_cvprop{ConstVarType::IMMUTABLE_AND_PARAM},
  359. m_uniq_reader_check{state.graph()}
  360. {
  361. auto cg = state.graph().comp_graph();
  362. auto on_new_opr = [this](const cg::event::OprInserted &ev) {
  363. if (!ev.is_dedup && !ev.exc) {
  364. // call add_opr eagerly to avoid deep recursion
  365. m_cvprop.add_opr(ev.opr);
  366. }
  367. };
  368. auto hdl = cg->event().register_receiver
  369. <cg::event::OprInserted>(on_new_opr);
  370. apply();
  371. }
  372. void ParamRedistributePass::apply(OptState &state) const {
  373. MIDOUT_B("ParamRedistributePass::apply")
  374. Impl{state};
  375. MIDOUT_E
  376. }
  377. /* ================ ParamFusePass ================ */
  378. /*!
  379. * \brief get name for new param
  380. */
  381. class ParamFusePass::VarNamer {
  382. #if MGB_BUILD_SLIM_SERVING
  383. public:
  384. const std::string& name(VarNode*) {
  385. static std::string ret("fuse");
  386. return ret;
  387. }
  388. #else
  389. using SrcSet = SharedSet<OperatorNodeBase*>;
  390. //! map from var to source SharedDeviceTensor/MultipleDeviceTensorHolder oprs
  391. //! that it depends on
  392. ThinHashMap<OperatorNodeBase*, SrcSet> m_opr2srcs;
  393. std::string m_name_cache;
  394. std::vector<const char*> m_cur_name;
  395. SrcSet& get_src_set(OperatorNodeBase* opr) {
  396. auto opr_typeinfo = opr->dyn_typeinfo();
  397. auto iter = m_opr2srcs.find(opr);
  398. if (iter != m_opr2srcs.end()) {
  399. return iter->second;
  400. }
  401. auto &&ret = m_opr2srcs[opr];
  402. if (opr->input().empty()) {
  403. if (opr_typeinfo == opr::SharedDeviceTensor::typeinfo() ||
  404. opr_typeinfo == opr::MultipleDeviceTensorHolder::typeinfo()) {
  405. ret.insert(opr);
  406. } else {
  407. mgb_assert(opr_typeinfo == opr::ImmutableTensor::typeinfo());
  408. }
  409. return ret;
  410. }
  411. for (auto i: opr->input()) {
  412. ret.merge_from(get_src_set(i->owner_opr()));
  413. }
  414. return ret;
  415. }
  416. public:
  417. const std::string& name(VarNode *var) {
  418. m_cur_name.clear();
  419. for (auto i : get_src_set(var->owner_opr())) {
  420. m_cur_name.push_back(i->cname());
  421. }
  422. auto cmp = [](const char *x, const char *y) {
  423. return strcmp(x, y) < 0;
  424. };
  425. std::sort(m_cur_name.begin(), m_cur_name.end(), cmp);
  426. m_name_cache.clear();
  427. m_name_cache.append(mgb_cstr_log("fuse("));
  428. bool first = true;
  429. for (auto i: m_cur_name) {
  430. if (first) {
  431. first = false;
  432. } else {
  433. m_name_cache.push_back(',');
  434. }
  435. m_name_cache.append(i);
  436. }
  437. m_name_cache.append(mgb_cstr_log(
  438. ssprintf("):%s@%zu", var->cname(), var->id())));
  439. return m_name_cache;
  440. }
  441. #endif
  442. };
  443. const char* ParamFusePass::name() const {
  444. return mgb_cstr_log("param_fuse");
  445. }
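//! ParamFusePass::apply finds vars whose value depends only on constant
//! parameters / immutable tensors, evaluates each one by compiling and running
//! a small subgraph, and replaces it with an ImmutableTensor (when statically
//! inferable with default format) or a const SharedDeviceTensor(WithFormat).
//! m_param_grow_limit bounds how much larger the fused value may be than the
//! largest constant input it was computed from.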
  446. void ParamFusePass::apply(OptState &state) const {
  447. MIDOUT_B("ParamFusePass::apply")
  448. auto rewriter = state.graph().make_rewriter();
  449. auto cg = state.graph().comp_graph();
  450. ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
  451. state.graph().iter([&cvprop](OperatorNodeBase *opr) {
  452. cvprop.add_opr(opr);
  453. });
  454. ThinHashSet<VarNode*> processed_var;
  455. VarNamer var_namer;
  456. // reader: null if used as endvar
  457. auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
  458. if (!processed_var.insert(var).second)
  459. return;
  460. auto inferred_val = std::make_shared<DeviceTensorND>(
  461. var->comp_node(), var->dtype());
  462. auto cb = [&](DeviceTensorND& val) {
  463. // retain format of val
  464. mgb_assert(val.format() == var->format());
  465. inferred_val->format(val.format())
  466. .resize(val.shape())
  467. .copy_from_fixlayout(val);
  468. };
  469. {
  470. auto orig_level = cg->options().log_level;
  471. cg->options().log_level = 0;
  472. MGB_TRY {
  473. cg->compile({{var, cb}})->execute();
  474. } MGB_FINALLY(cg->options().log_level = orig_level);
  475. }
  476. SymbolVar new_var;
  477. bool is_default_format = var->format().is_default();
  478. if (cg::is_static_var_value(var) && is_default_format) {
  479. // use ImmutableTensor for inferable vars
  480. HostTensorND hv;
  481. hv.copy_from(*inferred_val).sync();
  482. new_var = opr::ImmutableTensor::make(
  483. *var->owner_graph(), hv, var_namer.name(var));
  484. } else {
  485. if (is_default_format) {
  486. new_var = opr::SharedDeviceTensor::make_const(
  487. *var->owner_graph(), inferred_val, var_namer.name(var));
  488. } else {
  489. new_var = opr::SharedDeviceTensorWithFormat::make_const(
  490. *var->owner_graph(), inferred_val, var_namer.name(var));
  491. }
  492. }
  493. std::string log;
  494. if (reader) {
  495. log = mgb_ssprintf_log(
  496. "due to read by %s{%s}",
  497. reader->cname(), reader->dyn_typeinfo()->name);
  498. } else {
  499. log = mgb_cstr_log("as endpoint");
  500. }
  501. rewriter.replace_var(var, new_var.node(), log.c_str());
  502. };
  503. auto replace_opr = [&](OperatorNodeBase* opr) {
  504. auto add_ret = cvprop.opr_rst(opr);
  505. if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
  506. for (auto i: opr->input()) {
  507. if (cvprop.is_midconst(i)) {
  508. state.call_with_opr(i->owner_opr(),
  509. [&]{replace_single_var(i, opr);});
  510. }
  511. }
  512. }
  513. rewriter.auto_replace_outputs(opr);
  514. //! we should deal with midconst vars after auto_replace_outputs, as
  515. //! on_midconst_opr would replace the endpoint output, which may cause a
  516. //! double replace.
  517. if (add_ret.all_const_inp) {
  518. for (auto var : opr->output()) {
  519. if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
  520. continue;
  521. auto osize = ConstVarPropogate::var_mem_size(var);
  522. if (osize >= cvprop.max_size(opr) &&
  523. osize - cvprop.max_size(opr) > m_param_grow_limit) {
  524. return;
  525. }
  526. // const oprs should be evaluated when their output is used by another
  527. // non-const opr or is needed by the user
  528. if (state.graph().endpoint_contain(var)) {
  529. replace_single_var(var, nullptr);
  530. }
  531. }
  532. }
  533. };
  534. state.graph().iter(replace_opr);
  535. rewriter.apply_inplace();
  536. MIDOUT_E
  537. }
  538. /* ================ ConvertF32ToF16Pass ================ */
  539. const char* ConvertF32ToF16Pass::name() const {
  540. return mgb_cstr_log("convert_f32_to_f16");
  541. }
  542. void ConvertF32ToF16Pass::apply(OptState& state) const {
  543. MIDOUT_B("ConvertF32ToF16Pass::apply")
  544. state.set_var_replace_check_flag(m_var_replace_check_flag);
  545. auto rewriter = state.graph().make_rewriter();
  546. VarNodeArray new_inp_cache;
  547. // record original output dtype
  548. const SymbolVarArray& vars = state.graph().endpoint_vars();
  549. std::vector<DType> dtypes;
  550. for (size_t i = 0; i < vars.size(); i++) {
  551. dtypes.push_back(vars[i].node()->dtype());
  552. }
  553. auto on_opr = [this, &rewriter, &new_inp_cache](OperatorNodeBase* opr) {
  554. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  555. if (it != m_opr_replace_func.end()) {
  556. auto&& new_inp = new_inp_cache;
  557. new_inp.clear();
  558. new_inp.reserve(opr->input().size());
  559. for (auto i: opr->input()) {
  560. new_inp.push_back(rewriter.get_var(i));
  561. }
  562. auto new_opr = (it->second)(opr, new_inp);
  563. auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
  564. mgb_assert(origin_out.size() == cur_out.size(),
  565. "bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(),
  566. opr->dyn_typeinfo()->name, new_opr->cname(),
  567. new_opr->dyn_typeinfo()->name);
  568. for (size_t i = 0; i < origin_out.size(); i++) {
  569. rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
  570. }
  571. } else {
  572. rewriter.auto_replace_outputs(opr);
  573. }
  574. };
  575. state.graph().iter(on_opr);
  576. rewriter.apply_inplace();
  577. // recover output dtype
  578. rewriter = state.graph().make_rewriter();
  579. const SymbolVarArray& endpoints = state.graph().endpoint_vars();
  580. auto replace_output = [&]() {
  581. for (size_t i = 0; i < endpoints.size(); i++) {
  582. VarNode* var = endpoints[i].node();
  583. if (var->dtype().enumv() != dtypes[i].enumv()) {
  584. auto new_var = opr::TypeCvt::make(var, dtypes[i]).node();
  585. rewriter.replace_var(var, new_var, nullptr);
  586. }
  587. }
  588. };
  589. mgb_assert(endpoints.size() > 0);
  590. auto opr = endpoints[0].node()->owner_opr();
  591. state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
  592. rewriter.apply_inplace();
  593. MIDOUT_E
  594. }
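//! Build the f32->f16 conversion pass: data-source oprs (Host2DeviceCopy,
//! SharedDeviceTensor, ImmutableTensor, Linspace) get a TypeCvt to f16
//! appended to their output, while compute oprs (Convolution, ConvBias,
//! MatrixMul, BatchedMatrixMul, Reduce, WarpPerspective, Remap, TypeCvt) are
//! recreated on the converted inputs; with use_f32_comp the new oprs keep
//! float32 computation internally for accuracy.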
  595. std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
  596. bool use_f32_comp) {
  597. #if MEGDNN_DISABLE_FLOAT16
  598. mgb_throw(SystemError, "float16 disabled at compile time.");
  599. #else
  600. auto replace_h2d_opr = [](OperatorNodeBase* opr,
  601. const VarNodeArray& new_inp) {
  602. mgb_assert(opr->input().size() == new_inp.size());
  603. auto& h2d_opr = opr->cast_final_safe<opr::Host2DeviceCopy>();
  604. if (h2d_opr.output(0)->dtype() == dtype::Float32()) {
  605. auto cvt_var =
  606. opr::TypeCvt::make(h2d_opr.output(0), dtype::Float16(), {});
  607. return cvt_var.node()->owner_opr();
  608. }
  609. return opr;
  610. };
  611. auto replace_sdt_opr = [](OperatorNodeBase* opr,
  612. const VarNodeArray& new_inp) {
  613. mgb_assert(opr->input().size() == new_inp.size());
  614. auto& sdt_opr = opr->cast_final_safe<opr::SharedDeviceTensor>();
  615. if (sdt_opr.output(0)->dtype() == dtype::Float32()) {
  616. auto cvt_var =
  617. opr::TypeCvt::make(sdt_opr.output(0), dtype::Float16(), {});
  618. return cvt_var.node()->owner_opr();
  619. }
  620. return opr;
  621. };
  622. auto replace_imt_opr = [](OperatorNodeBase* opr,
  623. const VarNodeArray& new_inp) {
  624. mgb_assert(opr->same_type<opr::ImmutableTensor>());
  625. mgb_assert(opr->input().size() == new_inp.size());
  626. auto& imt_opr = opr->cast_final_safe<opr::ImmutableTensor>();
  627. if (imt_opr.output(0)->dtype() == dtype::Float32()) {
  628. auto cvt_var =
  629. opr::TypeCvt::make(imt_opr.output(0), dtype::Float16(), {});
  630. return cvt_var.node()->owner_opr();
  631. }
  632. return opr;
  633. };
  634. auto replace_lsp_opr = [](OperatorNodeBase* opr,
  635. const VarNodeArray& new_inp) {
  636. mgb_assert(opr->same_type<opr::Linspace>());
  637. mgb_assert(opr->input().size() == new_inp.size());
  638. auto& lsp_opr = opr->cast_final_safe<opr::Linspace>();
  639. if (lsp_opr.output(0)->dtype() != dtype::Float16()) {
  640. auto cvt_var =
  641. opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {});
  642. return cvt_var.node()->owner_opr();
  643. }
  644. return opr;
  645. };
  646. auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr,
  647. const VarNodeArray& new_inp) {
  648. mgb_assert(opr->input().size() == new_inp.size());
  649. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  650. auto new_param = conv_opr.param();
  651. if (use_f32_comp) {
  652. new_param.compute_mode =
  653. megdnn::param::Convolution::ComputeMode::FLOAT32;
  654. }
  655. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  656. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  657. new_inp[0]->name().c_str(),
  658. new_inp[0]->owner_opr()->name().c_str());
  659. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  660. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  661. new_inp[1]->name().c_str(),
  662. new_inp[1]->owner_opr()->name().c_str());
  663. auto new_conv_opr = opr::Convolution::make(
  664. new_inp[0], new_inp[1], new_param, conv_opr.execution_policy(),
  665. conv_opr.config());
  666. return new_conv_opr.node()->owner_opr();
  667. };
  668. auto replace_deconv_opr = [use_f32_comp](OperatorNodeBase* opr,
  669. const VarNodeArray& new_inp) {
  670. mgb_assert(opr->input().size() == new_inp.size());
  671. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  672. auto new_param = deconv_opr.param();
  673. if (use_f32_comp) {
  674. new_param.compute_mode =
  675. megdnn::param::Convolution::ComputeMode::FLOAT32;
  676. }
  677. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  678. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  679. new_inp[0]->name().c_str(),
  680. new_inp[0]->owner_opr()->name().c_str());
  681. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  682. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  683. new_inp[1]->name().c_str(),
  684. new_inp[1]->owner_opr()->name().c_str());
  685. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  686. new_inp[0], new_inp[1], new_param,
  687. deconv_opr.execution_policy(), deconv_opr.config());
  688. return new_deconv_opr.node()->owner_opr();
  689. };
  690. auto replace_convbias_opr = [use_f32_comp](OperatorNodeBase* opr,
  691. const VarNodeArray& new_inp) {
  692. auto& convbias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  693. auto new_param = convbias_opr.param();
  694. if (use_f32_comp) {
  695. new_param.compute_mode =
  696. megdnn::param::ConvBias::ComputeMode::FLOAT32;
  697. }
  698. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  699. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  700. new_inp[0]->name().c_str(),
  701. new_inp[0]->owner_opr()->name().c_str());
  702. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  703. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  704. new_inp[1]->name().c_str(),
  705. new_inp[1]->owner_opr()->name().c_str());
  706. if(opr->input().size() == 2) {
  707. auto new_conv_opr = opr::ConvBias::make(
  708. new_inp[0], new_inp[1], new_param,
  709. convbias_opr.execution_policy(), convbias_opr.config());
  710. return new_conv_opr.node()->owner_opr();
  711. } else if(opr->input().size() == 3) {
  712. auto new_conv_opr = opr::ConvBias::make(
  713. new_inp[0], new_inp[1], new_inp[2], new_param,
  714. convbias_opr.execution_policy(), convbias_opr.config());
  715. return new_conv_opr.node()->owner_opr();
  716. } else {
  717. mgb_assert(opr->input().size() == 4, "invalid input size %zu",
  718. opr->input().size());
  719. auto new_conv_opr = opr::ConvBias::make(
  720. new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param,
  721. convbias_opr.execution_policy(), convbias_opr.config());
  722. return new_conv_opr.node()->owner_opr();
  723. }
  724. };
  725. auto replace_matmul_opr = [use_f32_comp](OperatorNodeBase* opr,
  726. const VarNodeArray& new_inp) {
  727. mgb_assert(opr->input().size() == new_inp.size());
  728. auto& matmul_opr = opr->cast_final_safe<opr::MatrixMul>();
  729. auto new_param = matmul_opr.param();
  730. if (use_f32_comp) {
  731. new_param.compute_mode =
  732. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  733. }
  734. auto new_matmul_opr = opr::MatrixMul::make(
  735. new_inp[0], new_inp[1], new_param,
  736. matmul_opr.execution_policy(), matmul_opr.config());
  737. return new_matmul_opr.node()->owner_opr();
  738. };
  739. auto replace_batched_matmul_opr = [use_f32_comp](
  740. OperatorNodeBase* opr,
  741. const VarNodeArray& new_inp) {
  742. mgb_assert(opr->input().size() == new_inp.size());
  743. auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>();
  744. auto new_param = matmul_opr.param();
  745. if (use_f32_comp) {
  746. new_param.compute_mode =
  747. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  748. }
  749. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  750. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  751. new_inp[0]->name().c_str(),
  752. new_inp[0]->owner_opr()->name().c_str());
  753. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  754. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  755. new_inp[1]->name().c_str(),
  756. new_inp[1]->owner_opr()->name().c_str());
  757. auto new_matmul_opr = opr::BatchedMatrixMul::make(
  758. new_inp[0], new_inp[1], new_param,
  759. matmul_opr.execution_policy(), matmul_opr.config());
  760. return new_matmul_opr.node()->owner_opr();
  761. };
  762. auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr,
  763. const VarNodeArray& new_inp) {
  764. auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
  765. auto new_param = reduce_opr.param();
  766. if (use_f32_comp) {
  767. new_param.data_type =
  768. megdnn::param::Reduce::DataType::FLOAT_O16xC32;
  769. }
  770. if (opr->input().size() == 1) {
  771. auto new_matmul_opr = opr::Reduce::make(new_inp[0], new_param, {},
  772. reduce_opr.config());
  773. return new_matmul_opr.node()->owner_opr();
  774. } else {
  775. mgb_assert(opr->input().size() == 2, "invalid input size %zu",
  776. opr->input().size());
  777. auto new_matmul_opr = opr::Reduce::make(
  778. new_inp[0], new_param, new_inp[1], reduce_opr.config());
  779. return new_matmul_opr.node()->owner_opr();
  780. }
  781. };
  782. auto replace_cvt_opr = [](OperatorNodeBase* opr,
  783. const VarNodeArray& new_inp) {
  784. auto& cvt_opr = opr->cast_final_safe<opr::TypeCvt>();
  785. SymbolVar new_cvt;
  786. if (cvt_opr.output(0)->dtype() == dtype::Float32()) {
  787. new_cvt = opr::TypeCvt::make(new_inp[0], dtype::Float16(),
  788. cvt_opr.config());
  789. } else {
  790. new_cvt = opr::TypeCvt::make(
  791. new_inp[0], cvt_opr.output()[0]->dtype(), cvt_opr.config());
  792. }
  793. return new_cvt.node()->owner_opr();
  794. };
  795. auto replace_warp_opr = [](OperatorNodeBase* opr,
  796. const VarNodeArray& new_inp) {
  797. mgb_assert(opr->input().size() == new_inp.size() &&
  798. (new_inp.size() == 3 || new_inp.size() == 4));
  799. auto& warp_opr = opr->cast_final<opr::WarpPerspective>();
  800. // mat tensor must be float32
  801. auto new_mat = new_inp[1];
  802. if (new_inp[1]->dtype() != dtype::Float32()) {
  803. if (try_cast_as_op<opr::TypeCvt>(new_mat->owner_opr()) &&
  804. new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
  805. new_mat = new_mat->owner_opr()->input(0);
  806. else
  807. new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
  808. .node();
  809. }
  810. SymbolVar new_warp;
  811. if (new_inp.size() == 3) {
  812. new_warp = opr::WarpPerspective::make(new_inp[0], new_mat,
  813. new_inp[2], warp_opr.param(),
  814. warp_opr.config());
  815. } else {
  816. mgb_assert(new_inp.size() == 4);
  817. new_warp = opr::WarpPerspective::make(
  818. new_inp[0], new_mat, new_inp[2], new_inp[3],
  819. warp_opr.param(), warp_opr.config());
  820. }
  821. return new_warp.node()->owner_opr();
  822. };
  823. auto replace_remap_opr = [](OperatorNodeBase* opr,
  824. const VarNodeArray& new_inp) {
  825. mgb_assert(opr->input().size() == new_inp.size() &&
  826. (new_inp.size() == 2));
  827. auto& remap_opr = opr->cast_final<opr::Remap>();
  828. // map tensor must be float32
  829. auto new_map = new_inp[1];
  830. if (new_inp[1]->dtype() != dtype::Float32()) {
  831. if (try_cast_as_op<opr::TypeCvt>(new_map->owner_opr()) &&
  832. new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
  833. new_map = new_map->owner_opr()->input(0);
  834. else
  835. new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
  836. .node();
  837. }
  838. SymbolVar new_remap;
  839. new_remap = opr::Remap::make(new_inp[0], new_map,
  840. remap_opr.param(),
  841. remap_opr.config());
  842. return new_remap.node()->owner_opr();
  843. };
  844. auto ret = std::make_unique<ConvertF32ToF16Pass>();
  845. // don't check dtype
  846. ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
  847. VarReplaceCheckFlag::CHECK_DTYPE);
  848. auto&& replace_func = ret->m_opr_replace_func;
  849. replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr;
  850. replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
  851. replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
  852. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  853. replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
  854. replace_func[opr::ConvBias::typeinfo()] = replace_convbias_opr;
  855. replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr;
  856. replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr;
  857. replace_func[opr::ImmutableTensor::typeinfo()] = replace_imt_opr;
  858. replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
  859. replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
  860. replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
  861. replace_func[opr::BatchedMatrixMul::typeinfo()] =
  862. replace_batched_matmul_opr;
  863. return ret;
  864. #endif
  865. }
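// Usage sketch (illustrative only; GraphOptimizer::add_pass accepting a
// std::unique_ptr<Pass> is assumed to be available):
//   gopt::GraphOptimizer optimizer;
//   optimizer.add_pass(ConvertF32ToF16Pass::make(/* use_f32_comp */ true));
//   auto out = optimizer.apply({{y}}).endpoint_vars();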
  866. /* ================ ConvertFormatPass ================ */
  867. void ConvertFormatPass::apply(OptState& state) const {
  868. MIDOUT_B("ConvertFormatPass::apply")
  869. state.set_var_replace_check_flag(m_var_replace_check_flag);
  870. auto rewriter = state.graph().make_rewriter();
  871. VarNodeArray new_inp_cache;
  872. auto on_opr = [this, &state, &rewriter,
  873. &new_inp_cache](OperatorNodeBase* opr) {
  874. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  875. if (it != m_opr_replace_func.end()) {
  876. auto&& new_inp = new_inp_cache;
  877. new_inp.clear();
  878. new_inp.reserve(opr->input().size());
  879. for (auto i : opr->input()) {
  880. new_inp.push_back(rewriter.get_var(i));
  881. }
  882. auto new_opr = (it->second)(opr, new_inp);
  883. auto &&out0 = opr->output(), &&out1 = new_opr->output();
  884. mgb_assert(out0.size() == out1.size(),
  885. "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
  886. "dst.size=%zu",
  887. opr->cname(), opr->dyn_typeinfo()->name,
  888. new_opr->cname(), new_opr->dyn_typeinfo()->name,
  889. out0.size(), out1.size());
  890. for (size_t i = 0; i < out0.size(); i++) {
  891. if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  892. mgb_assert(!out1[i]->contain_flag(
  893. VarNode::Flag::VOLATILE_CONTENT));
  894. auto src = out0[i];
  895. auto dst = out1[i];
  896. auto dst_is_image = dst->format().type() ==
  897. TensorFormat::Type::IMAGE2D_PACK4;
  898. if (!dst_is_image &&
  899. !src->owner_opr()->same_type<opr::ImmutableTensor>()) {
  900. mgb_log_warn(
  901. "convert NHWCD4 replaced to non-img format: "
  902. "dst_opr=%s{%s} format=%s",
  903. dst->owner_opr()->cname(),
  904. dst->owner_opr()->dyn_typeinfo()->name,
  905. dst->format().to_string().c_str());
  906. }
  907. if (state.graph().endpoint_contain(src) && dst_is_image) {
  908. // relayout back to NCHW for output vars
  909. dst = opr::RelayoutFormat::make(
  910. dst, {opr::RelayoutFormat::Param::Mode::
  911. NHWCD4I_NCHW})
  912. .node();
  913. }
  914. rewriter.replace_var(src, dst, nullptr);
  915. }
  916. }
  917. } else {
  918. rewriter.auto_replace_outputs(opr);
  919. }
  920. };
  921. state.graph().iter(on_opr);
  922. rewriter.apply_inplace();
  923. MIDOUT_E
  924. }
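//! Build the NCHW -> NHWCD4 (IMAGE2D_PACK4) converter: each replace_*_opr
//! lambda below inserts opr::RelayoutFormat to convert inputs/weights and then
//! recreates the opr with Format::NHWCD4; oprs whose channel counts are not
//! 4-aligned (and are not channel-wise convs) fall back to a shallow copy in
//! NCHW.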
  925. std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
  926. MIDOUT_B("ConvertFormatPass::make")
  927. auto filter_mode =
  928. [](const megdnn::param::Convolution::Sparse conv_mode,
  929. const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
  930. bool use_dot = false;
  931. if (filter->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
  932. filter->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm)
  933. use_dot = true;
  934. if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
  935. if (use_dot)
  936. return megdnn::param::RelayoutFormat::Mode::
  937. INTER_WEIGHT_DENSEI_DOT;
  938. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI;
  939. } else {
  940. mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
  941. if (filter->shape()[1] == 1 && filter->shape()[2] == 1) {
  942. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI;
  943. } else {
  944. if (use_dot)
  945. return megdnn::param::RelayoutFormat::Mode::
  946. INTER_WEIGHT_GROUPI_DOT;
  947. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_GROUPI;
  948. }
  949. }
  950. };
  951. auto size_one_conv_to_dense_conv =
  952. [](VarNode* origin_filter_input,
  953. megdnn::param::Convolution::Sparse sparse) {
  954. VarNode* reshaped_filter = origin_filter_input;
  955. bool is_size_one_group_conv = false;
  956. if (sparse == megdnn::param::Convolution::Sparse::GROUP &&
  957. origin_filter_input->shape()[0] == 1) {
  958. is_size_one_group_conv = true;
  959. auto new_shape = origin_filter_input->shape();
  960. new_shape.ndim = 4;
  961. for (int i = 0; i < 4; i++) {
  962. new_shape[i] = origin_filter_input->shape()[i + 1];
  963. }
  964. SymbolVar new_var(origin_filter_input);
  965. reshaped_filter = new_var.reshape(new_shape).node();
  966. }
  967. return std::make_tuple(reshaped_filter, is_size_one_group_conv);
  968. };
  969. auto replace_conv_opr =
  970. [&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
  971. const VarNodeArray& new_inp) {
  972. mgb_assert(opr->input().size() == new_inp.size());
  973. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  974. mgb_assert(conv_opr.param().format ==
  975. megdnn::param::Convolution::Format::NCHW,
  976. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  977. VarNode *conv_src = nullptr, *conv_weights = nullptr;
  978. if (new_inp[0]->shape().ndim == 4) {
  979. // new input src is NCHW
  980. size_t group, icpg, ocpg;
  981. if (conv_opr.param().sparse ==
  982. megdnn::param::Convolution::Sparse::DENSE) {
  983. group = 1;
  984. icpg = new_inp[1]->shape()[1];
  985. ocpg = new_inp[1]->shape()[0];
  986. } else {
  987. mgb_assert(conv_opr.param().sparse ==
  988. megdnn::param::Convolution::Sparse::GROUP);
  989. group = new_inp[1]->shape()[0];
  990. icpg = new_inp[1]->shape()[2];
  991. ocpg = new_inp[1]->shape()[1];
  992. }
  993. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  994. auto param = megdnn::param::RelayoutFormat();
  995. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  996. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  997. conv_src = rf.node();
  998. } else {
  999. // can not convert to hwcd4
  1000. return serialization::copy_opr_shallow(*opr, new_inp,
  1001. opr->config());
  1002. }
  1003. } else {
  1004. size_t ocpg;
  1005. bool is_channel_wise = false;
  1006. if (conv_opr.param().sparse ==
  1007. megdnn::param::Convolution::Sparse::DENSE) {
  1008. ocpg = new_inp[1]->shape()[0];
  1009. } else {
  1010. mgb_assert(conv_opr.param().sparse ==
  1011. megdnn::param::Convolution::Sparse::GROUP);
  1012. size_t icpg = new_inp[1]->shape()[2];
  1013. ocpg = new_inp[1]->shape()[1];
  1014. if (icpg == 1 && ocpg == 1) {
  1015. is_channel_wise = true;
  1016. }
  1017. }
  1018. if (ocpg % 4 != 0 && !is_channel_wise) {
  1019. VarNodeArray t_inp = new_inp;
  1020. auto param = megdnn::param::RelayoutFormat();
  1021. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1022. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1023. t_inp[0] = rf.node();
  1024. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1025. opr->config());
  1026. return new_opr;
  1027. }
  1028. // new input src is NHWCD4
  1029. auto&& fmt = new_inp[0]
  1030. ->format()
  1031. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1032. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1033. conv_src = new_inp[0];
  1034. }
  1035. VarNode* reshaped_filter;
  1036. bool is_size_one_group_conv;
  1037. std::tie(reshaped_filter, is_size_one_group_conv) =
  1038. size_one_conv_to_dense_conv(new_inp[1],
  1039. conv_opr.param().sparse);
  1040. auto new_conv_param = conv_opr.param();
  1041. if (is_size_one_group_conv) {
  1042. new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
  1043. }
  1044. mgb_assert(new_inp[1]->format().type() !=
  1045. TensorFormat::Type::IMAGE2D_PACK4);
  1046. auto param = megdnn::param::RelayoutFormat();
  1047. param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
  1048. auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
  1049. conv_weights = relayout_weight.node();
  1050. new_conv_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1051. mgb_assert(conv_src->shape().ndim == 5 &&
  1052. conv_src->format().type() ==
  1053. TensorFormat::Type::IMAGE2D_PACK4);
  1054. auto new_conv_opr = opr::Convolution::make(
  1055. conv_src, conv_weights, new_conv_param, conv_opr.execution_policy(),
  1056. conv_opr.config());
  1057. OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
  1058. mgb_assert(new_conv_opr.shape().ndim == 5 &&
  1059. new_conv_opr.format().type() ==
  1060. TensorFormat::Type::IMAGE2D_PACK4);
  1061. return ret;
  1062. };
  1063. auto replace_conv_bias_opr =
  1064. [&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
  1065. const VarNodeArray& new_inp) {
  1066. mgb_assert(opr->input().size() == new_inp.size());
  1067. auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  1068. mgb_assert(conv_bias_opr.param().format ==
  1069. megdnn::param::ConvBias::Format::NCHW,
  1070. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1071. VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr,
  1072. *conv_bias_bias = nullptr;
  1073. if (new_inp[0]->shape().ndim == 4) {
  1074. // new input src is NCHW
  1075. size_t group, icpg, ocpg;
  1076. if (conv_bias_opr.param().sparse ==
  1077. megdnn::param::ConvBias::Sparse::DENSE) {
  1078. group = 1;
  1079. icpg = new_inp[1]->shape()[1];
  1080. ocpg = new_inp[1]->shape()[0];
  1081. } else {
  1082. mgb_assert(conv_bias_opr.param().sparse ==
  1083. megdnn::param::ConvBias::Sparse::GROUP);
  1084. group = new_inp[1]->shape()[0];
  1085. icpg = new_inp[1]->shape()[2];
  1086. ocpg = new_inp[1]->shape()[1];
  1087. }
  1088. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1089. auto param = megdnn::param::RelayoutFormat();
  1090. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1091. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1092. conv_bias_src = rf.node();
  1093. } else {
  1094. // can not convert to hwcd4
  1095. return serialization::copy_opr_shallow(*opr, new_inp,
  1096. opr->config());
  1097. }
  1098. } else {
  1099. size_t ocpg;
  1100. bool is_channel_wise = false;
  1101. if (conv_bias_opr.param().sparse ==
  1102. megdnn::param::ConvBias::Sparse::DENSE) {
  1103. ocpg = new_inp[1]->shape()[0];
  1104. } else {
  1105. mgb_assert(conv_bias_opr.param().sparse ==
  1106. megdnn::param::ConvBias::Sparse::GROUP);
  1107. size_t icpg = new_inp[1]->shape()[2];
  1108. ocpg = new_inp[1]->shape()[1];
  1109. if (icpg == 1 && ocpg == 1) {
  1110. is_channel_wise = true;
  1111. }
  1112. }
  1113. if (ocpg % 4 != 0 && !is_channel_wise) {
  1114. VarNodeArray t_inp = new_inp;
  1115. auto param = megdnn::param::RelayoutFormat();
  1116. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1117. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1118. t_inp[0] = rf.node();
  1119. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1120. opr->config());
  1121. return new_opr;
  1122. }
  1123. // new input src is NHWCD4
  1124. auto&& fmt = new_inp[0]
  1125. ->format()
  1126. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1127. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1128. conv_bias_src = new_inp[0];
  1129. }
  1130. mgb_assert(new_inp[1]->format().type() !=
  1131. TensorFormat::Type::IMAGE2D_PACK4);
  1132. VarNode* reshaped_filter;
  1133. bool is_size_one_group_conv;
  1134. std::tie(reshaped_filter, is_size_one_group_conv) =
  1135. size_one_conv_to_dense_conv(new_inp[1],
  1136. conv_bias_opr.param().sparse);
  1137. auto new_conv_param = conv_bias_opr.param();
  1138. if (is_size_one_group_conv) {
  1139. new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
  1140. }
  1141. auto param = megdnn::param::RelayoutFormat();
  1142. param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
  1143. auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
  1144. conv_bias_weights = relayout_weight.node();
  1145. mgb_assert(new_inp.size() < 4,
  1146. "ConvertFormat pass does not support fuse Z");
  1147. bool has_bias = new_inp.size() > 2;
  1148. if (has_bias &&
  1149. new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) {
  1150. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1151. auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
  1152. conv_bias_bias = relayout_bias.node();
  1153. } else if (has_bias) {
  1154. conv_bias_bias = new_inp[2];
  1155. }
  1156. new_conv_param.format = megdnn::param::ConvBias::Format::NHWCD4;
  1157. mgb_assert(conv_bias_src->shape().ndim == 5 &&
  1158. conv_bias_src->format().type() ==
  1159. TensorFormat::Type::IMAGE2D_PACK4);
  1160. SymbolVar new_conv_bias_opr;
  1161. if (has_bias) {
  1162. new_conv_bias_opr = opr::ConvBias::make(
  1163. conv_bias_src, conv_bias_weights, conv_bias_bias, new_conv_param,
  1164. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1165. } else {
  1166. new_conv_bias_opr = opr::ConvBias::make(
  1167. conv_bias_src, conv_bias_weights, new_conv_param,
  1168. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1169. }
  1170. OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
  1171. mgb_assert(new_conv_bias_opr.shape().ndim == 5 &&
  1172. new_conv_bias_opr.format().type() ==
  1173. TensorFormat::Type::IMAGE2D_PACK4);
  1174. return ret;
  1175. };
  1176. auto replace_deconv_opr = [&filter_mode](OperatorNodeBase* opr,
  1177. const VarNodeArray& new_inp) {
  1178. mgb_assert(opr->input().size() == new_inp.size());
  1179. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  1180. mgb_assert(deconv_opr.param().format ==
  1181. megdnn::param::Convolution::Format::NCHW,
  1182. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1183. VarNode *deconv_src = nullptr, *deconv_weights = nullptr;
  1184. if (new_inp[1]->shape().ndim == 4) {
  1185. // new input src is NCHW
  1186. size_t group, icpg, ocpg;
  1187. if (deconv_opr.param().sparse ==
  1188. megdnn::param::Convolution::Sparse::DENSE) {
  1189. group = 1;
  1190. icpg = new_inp[0]->shape()[0];
  1191. ocpg = new_inp[0]->shape()[1];
  1192. } else {
  1193. mgb_assert(deconv_opr.param().sparse ==
  1194. megdnn::param::Convolution::Sparse::GROUP);
  1195. group = new_inp[0]->shape()[0];
  1196. icpg = new_inp[0]->shape()[1];
  1197. ocpg = new_inp[0]->shape()[2];
  1198. }
  1199. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1200. auto param = megdnn::param::RelayoutFormat();
  1201. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1202. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1203. deconv_src = rf.node();
  1204. } else {
  1205. // can not convert to hwcd4
  1206. return serialization::copy_opr_shallow(*opr, new_inp,
  1207. opr->config());
  1208. }
  1209. } else {
  1210. //! XXXX, fix me, check filter size
  1211. size_t ocpg;
  1212. if (deconv_opr.param().sparse ==
  1213. megdnn::param::Convolution::Sparse::DENSE) {
  1214. ocpg = new_inp[0]->shape()[1];
  1215. } else {
  1216. mgb_assert(deconv_opr.param().sparse ==
  1217. megdnn::param::Convolution::Sparse::GROUP);
  1218. ocpg = new_inp[0]->shape()[2];
  1219. }
  1220. if (ocpg % 4 != 0) {
  1221. VarNodeArray t_inp = new_inp;
  1222. auto param = megdnn::param::RelayoutFormat();
  1223. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1224. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1225. t_inp[1] = rf.node();
  1226. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1227. opr->config());
  1228. return new_opr;
  1229. }
  1230. // new input src is NHWCD4
  1231. auto&& fmt = new_inp[1]
  1232. ->format()
  1233. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1234. mgb_assert(new_inp[1]->shape().ndim == 5 && fmt.align_axis() == 2);
  1235. deconv_src = new_inp[1];
  1236. }
  1237. mgb_assert(new_inp[0]->format().type() !=
  1238. TensorFormat::Type::IMAGE2D_PACK4);
  1239. auto param = megdnn::param::RelayoutFormat();
  1240. param.mode = filter_mode(deconv_opr.param().sparse, new_inp[0]);
  1241. auto relayout_weight = opr::RelayoutFormat::make(new_inp[0], param);
  1242. deconv_weights = relayout_weight.node();
  1243. auto new_param = deconv_opr.param();
  1244. new_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1245. mgb_assert(deconv_src->shape().ndim == 5 &&
  1246. deconv_src->format().type() ==
  1247. TensorFormat::Type::IMAGE2D_PACK4);
  1248. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  1249. deconv_weights, deconv_src, new_param,
  1250. deconv_opr.execution_policy(), deconv_opr.config());
  1251. OperatorNodeBase* ret = new_deconv_opr.node()->owner_opr();
  1252. mgb_assert(new_deconv_opr.shape().ndim == 5 &&
  1253. new_deconv_opr.format().type() ==
  1254. TensorFormat::Type::IMAGE2D_PACK4);
  1255. return ret;
  1256. };
1257. /* This helper function guarantees that the format convert pass won't change
1258. * the output var's channel count. Changing the output channel count would
1259. * cause a channel-mismatch problem when replacing conv/conv_bias operators.
1260. */
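// For example, if an input keeps its original NCHW shape and its channel
// count is not a multiple of 4 (so it was never converted to NHWCD4), the
// operator is rebuilt as a shallow copy on the NCHW inputs instead of being
// converted to NHWCD4.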
  1261. auto replace_helper = [](OperatorNodeBase* opr,
  1262. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1263. auto&& new_shp = new_inp[0]->shape();
  1264. size_t inp_channel = new_shp[1];
1265. if (new_shp.eq_shape(opr->input(0)->shape()) && inp_channel % 4 != 0) {
  1266. auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
  1267. opr->config());
  1268. return new_opr;
  1269. }
  1270. return nullptr;
  1271. };
  1272. auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr,
  1273. const VarNodeArray& new_inp) {
  1274. mgb_assert(opr->input().size() == new_inp.size());
  1275. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1276. return opr_shallow_copy;
  1277. }
  1278. auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
  1279. mgb_assert(resize_opr.param().format ==
  1280. megdnn::param::Resize::Format::NCHW,
  1281. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1282. VarNode* inp = nullptr;
  1283. if (new_inp[0]->shape().ndim == 4) {
  1284. auto param = megdnn::param::RelayoutFormat();
  1285. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1286. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1287. inp = rf.node();
  1288. } else {
1289. // new input src is NHWCD4
  1290. auto&& fmt = new_inp[0]
  1291. ->format()
  1292. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1293. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1294. inp = new_inp[0];
  1295. }
  1296. auto new_param = resize_opr.param();
  1297. new_param.format = megdnn::param::Resize::Format::NHWCD4;
  1298. auto new_resize_opr = opr::ResizeForward::make(
  1299. inp, new_inp[1], new_param, opr->config());
  1300. return new_resize_opr.node()->owner_opr();
  1301. };
  1302. auto replace_warp_perspective_opr = [replace_helper](
  1303. OperatorNodeBase* opr,
  1304. const VarNodeArray& new_inp) {
  1305. mgb_assert(opr->input().size() == new_inp.size());
  1306. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1307. return opr_shallow_copy;
  1308. }
  1309. auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
  1310. mgb_assert(warp_opr.param().format ==
  1311. megdnn::param::WarpPerspective::Format::NCHW,
  1312. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1313. VarNode* inp = nullptr;
  1314. if (new_inp[0]->shape().ndim == 4) {
  1315. // new input src is NCHW
  1316. auto param = megdnn::param::RelayoutFormat();
  1317. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1318. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1319. inp = rf.node();
  1320. } else {
1321. // new input src is NHWCD4
  1322. auto&& fmt = new_inp[0]
  1323. ->format()
  1324. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1325. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1326. inp = new_inp[0];
  1327. }
  1328. auto new_param = warp_opr.param();
  1329. new_param.format = megdnn::param::WarpPerspective::Format::NHWCD4;
  1330. SymbolVar new_warp_opr;
  1331. if (new_inp.size() == 3) {
  1332. new_warp_opr = opr::WarpPerspectiveForward::make(
  1333. inp, new_inp[1], nullptr, new_inp[2], new_param,
  1334. opr->config());
  1335. } else {
  1336. mgb_assert(new_inp.size() == 4);
  1337. new_warp_opr = opr::WarpPerspectiveForward::make(
  1338. inp, new_inp[1], new_inp[2], new_inp[3], new_param,
  1339. opr->config());
  1340. }
  1341. return new_warp_opr.node()->owner_opr();
  1342. };
  1343. auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr,
  1344. const VarNodeArray& new_inp) {
  1345. mgb_assert(opr->input().size() == new_inp.size());
  1346. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1347. return opr_shallow_copy;
  1348. }
  1349. auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
  1350. mgb_assert(warp_opr.param().format ==
  1351. megdnn::param::WarpAffine::Format::NCHW,
  1352. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1353. VarNode* inp = nullptr;
  1354. if (new_inp[0]->shape().ndim == 4) {
  1355. // new input src is NCHW
  1356. auto param = megdnn::param::RelayoutFormat();
  1357. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1358. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1359. inp = rf.node();
  1360. } else {
1361. // new input src is NHWCD4
  1362. auto&& fmt = new_inp[0]
  1363. ->format()
  1364. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1365. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1366. inp = new_inp[0];
  1367. }
  1368. auto new_param = warp_opr.param();
  1369. new_param.format = megdnn::param::WarpAffine::Format::NHWCD4;
  1370. SymbolVar new_warp_opr;
  1371. new_warp_opr = opr::WarpAffineForward::make(inp, new_inp[1], new_inp[2],
  1372. new_param, opr->config());
  1373. return new_warp_opr.node()->owner_opr();
  1374. };
  1375. auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr,
  1376. const VarNodeArray& new_inp) {
  1377. mgb_assert(opr->input().size() == new_inp.size());
  1378. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1379. return opr_shallow_copy;
  1380. }
  1381. auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
  1382. mgb_assert(pooling_opr.param().format ==
  1383. megdnn::param::Pooling::Format::NCHW,
  1384. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1385. VarNode* inp = nullptr;
  1386. if (new_inp[0]->shape().ndim == 4) {
  1387. // new input src is NCHW
  1388. auto param = megdnn::param::RelayoutFormat();
  1389. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1390. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1391. inp = rf.node();
  1392. } else {
1393. // new input src is NHWCD4
  1394. auto&& fmt = new_inp[0]
  1395. ->format()
  1396. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1397. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1398. inp = new_inp[0];
  1399. }
  1400. auto new_param = pooling_opr.param();
  1401. new_param.format = megdnn::param::Pooling::Format::NHWCD4;
  1402. auto new_pooling_opr =
  1403. opr::PoolingForward::make(inp, new_param, opr->config());
  1404. return new_pooling_opr.node()->owner_opr();
  1405. };
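// var_to_chw converts a single var back to NCHW when the pass has already
// rewritten it into NHWCD4 (ndim == 5, IMAGE2D_PACK4 format); vars that are
// still in NCHW are returned unchanged.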
  1406. auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
  1407. if (!inp->shape().eq_shape(new_inp->shape())) {
  1408. mgb_assert(inp->shape().ndim == 4 &&
  1409. inp->format().type() !=
  1410. TensorFormat::Type::IMAGE2D_PACK4);
  1411. mgb_assert(new_inp->shape().ndim == 5 &&
  1412. new_inp->format().type() ==
  1413. TensorFormat::Type::IMAGE2D_PACK4);
  1414. auto param = megdnn::param::RelayoutFormat();
  1415. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1416. auto rf = opr::RelayoutFormat::make(new_inp, param);
  1417. return rf.node();
  1418. }
  1419. return new_inp;
  1420. };
  1421. auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr,
  1422. const VarNodeArray& new_inp) {
  1423. mgb_assert(opr->input().size() == new_inp.size());
  1424. VarNodeArray t_inp = new_inp;
  1425. for (size_t i = 0; i < opr->input().size(); i++) {
  1426. t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
  1427. }
  1428. auto new_opr =
  1429. serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1430. return new_opr;
  1431. };
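// An Elemwise opr is executed in NHWCD4 only when at least one input has
// already been converted and every other input is either a scalar or a
// 4-dim NCHW tensor whose channel count is divisible by 4; otherwise all
// inputs are relayouted back to NCHW first.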
  1432. auto replace_elemwise_opr = [&relayout_inp_to_chw](
  1433. OperatorNodeBase* opr,
  1434. const VarNodeArray& new_inp) {
  1435. mgb_assert(opr->input().size() == new_inp.size());
  1436. bool has_inp_changed = false;
  1437. bool can_exec_cd4 = true;
  1438. for (size_t i = 0; i < opr->input().size(); i++) {
  1439. if (!new_inp[i]->format().is_default()) {
  1440. has_inp_changed = true;
  1441. } else if (new_inp[i]->shape().ndim == 4) {
  1442. if (new_inp[i]->shape()[1] % 4 != 0) {
  1443. can_exec_cd4 = false;
  1444. }
1445. //! cd4 elemwise with a scalar input is supported
  1446. } else if (!new_inp[i]->shape().is_scalar()) {
  1447. can_exec_cd4 = false;
  1448. }
  1449. }
  1450. if (!can_exec_cd4) {
  1451. return relayout_inp_to_chw(opr, new_inp);
  1452. }
  1453. if (has_inp_changed) {
  1454. // assumption: all inputs are changed from nchw to nhwcd4
  1455. auto t_inp = new_inp;
  1456. for (size_t i = 0; i < opr->input().size(); i++) {
  1457. if (new_inp[i]->shape().ndim == 4) {
  1458. auto param = megdnn::param::RelayoutFormat();
  1459. param.mode =
  1460. megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1461. auto rf = opr::RelayoutFormat::make(new_inp[i], param);
  1462. t_inp[i] = rf.node();
  1463. } else {
  1464. mgb_assert((new_inp[i]->shape().ndim == 5 &&
  1465. new_inp[i]->format().type() ==
  1466. TensorFormat::Type::IMAGE2D_PACK4) ||
  1467. new_inp[i]->shape().is_scalar());
  1468. }
  1469. }
  1470. return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1471. } else {
  1472. return serialization::copy_opr_shallow(*opr, new_inp,
  1473. opr->config());
  1474. }
  1475. };
  1476. /* This helper function converts the first input to the NCHW format to
1477. * handle operators that do not support the NHWCD4 format.
  1478. */
  1479. auto relayout_first_inp_to_chw =
  1480. [var_to_chw](OperatorNodeBase* opr,
  1481. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1482. mgb_assert(opr->input().size() == new_inp.size());
  1483. VarNodeArray t_inp = new_inp;
  1484. t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
  1485. return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1486. };
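// Build the pass object and register the per-operator replace functions:
// format-aware oprs (conv, conv_bias, deconv, pooling, resize, warp,
// elemwise) get dedicated NHWCD4 handlers, while shape-sensitive oprs
// (Concat, Reshape, Subtensor, ...) are forced back to NCHW via the
// relayout helpers above.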
  1487. auto ret = std::make_unique<ConvertFormatPass>();
  1488. ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
  1489. auto&& replace_func = ret->m_opr_replace_func;
  1490. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  1491. replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
  1492. replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
  1493. replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
  1494. replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
  1495. replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw;
  1496. replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
  1497. replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
  1498. replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
  1499. replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_chw;
  1500. replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_chw;
  1501. replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
  1502. replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
  1503. replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
  1504. replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
  1505. replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
  1506. replace_func[opr::WarpPerspectiveForward::typeinfo()] =
  1507. replace_warp_perspective_opr;
  1508. replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
  1509. replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
  1510. replace_func[opr::GroupLocalForward::typeinfo()] =
  1511. relayout_first_inp_to_chw;
  1512. return ret;
  1513. MIDOUT_E
  1514. }
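/* A minimal usage sketch for the passes defined in this file. It assumes a
 * gopt::GraphOptimizer driver with add_pass<>() / apply_inplace(); the driver
 * interface is an assumption for illustration, not something defined here.
 *
 *     VarNodeArray vars;                      // graph endpoints
 *     for (auto&& out : network_outputs)      // network_outputs: SymbolVars
 *         vars.push_back(out.node());
 *     gopt::GraphOptimizer optimizer;
 *     optimizer.add_pass<ConvertBatchNormToElemwisePass>();
 *     optimizer.add_pass<FuseConvBiasNonlinPass>();
 *     optimizer.apply_inplace(vars);
 */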
1515. /* ================ ConvertBatchNormToElemwisePass ================ */
  1516. const char* ConvertBatchNormToElemwisePass::name() const {
  1517. return "convert_batch_norm";
  1518. }
  1519. void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
  1520. MIDOUT_B("ConvertBatchNormToElemwisePass::apply")
  1521. auto rewriter = state.graph().make_rewriter();
  1522. auto on_opr = [&](OperatorNodeBase* opr) {
  1523. if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
  1524. if (bn->input().size() == 5) {
  1525. mgb_assert(bn->param().fwd_mode ==
  1526. opr::BatchNorm::Param::FwdMode::INFERENCE);
  1527. SymbolVar x = {rewriter.get_var(bn->input(0))};
  1528. SymbolVar scale = {rewriter.get_var(bn->input(1))};
  1529. SymbolVar bias = {rewriter.get_var(bn->input(2))};
  1530. SymbolVar mean = {rewriter.get_var(bn->input(3))};
  1531. SymbolVar variance = {rewriter.get_var(bn->input(4))};
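// Inference-mode BN is rewritten as
//   y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
// where 1 / sqrt(variance + epsilon) is computed via PowC with exponent -0.5.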
  1532. SymbolVar invsqrt_variance = opr::PowC::make(variance
  1533. + variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5});
  1534. auto res = scale * (x - mean) * invsqrt_variance + bias;
  1535. rewriter.replace_var(
  1536. opr->output(4), res.node(),
  1537. mgb_cstr_log(
  1538. "replace batch_norm(x, scale, bias, mean, "
  1539. "varience) "
  1540. "-> (sclae * (x - mean) / sqrt(variance)) + b)"));
  1541. return;
  1542. }
  1543. }
  1544. rewriter.auto_replace_outputs(opr);
  1545. };
  1546. state.graph().iter(on_opr);
  1547. rewriter.apply_inplace();
  1548. MIDOUT_E
  1549. }
  1550. /* ================ FuseConvBiasNonlinPass ================ */
  1551. const char* FuseConvBiasNonlinPass::name() const {
  1552. return "combine_conv_bias_and_relu";
  1553. }
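// Fuses patterns such as relu(conv(x, w) + b), sigmoid(conv(x, w)) and
// conv(x, w) + b into a single ConvBias operator, and folds a following
// quantizing TypeCvt (QuantizedS32 -> QuantizedS8) into the ConvBias
// output dtype.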
  1554. void FuseConvBiasNonlinPass::apply(OptState& state) const {
  1555. MIDOUT_B("FuseConvBiasNonlinPass::apply")
  1556. std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
  1557. state.graph().iter([&m_deps](OperatorNodeBase* opr) {
  1558. for (auto& inp : opr->input()) {
  1559. m_deps[inp].push_back(opr);
  1560. }
  1561. });
  1562. auto rewriter = state.graph().make_rewriter();
  1563. using Mode = opr::Elemwise::Param::Mode;
  1564. using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
  1565. auto get_nonlinearity_mode = [&](opr::Elemwise* elem) -> NonlineMode {
  1566. if (elem->param().mode == Mode::FUSE_ADD_RELU ||
  1567. elem->param().mode == Mode::RELU) {
  1568. return NonlineMode::RELU;
  1569. } else if (elem->param().mode == Mode::FUSE_ADD_SIGMOID ||
  1570. elem->param().mode == Mode::SIGMOID) {
  1571. return NonlineMode::SIGMOID;
  1572. } else {
  1573. return NonlineMode::IDENTITY;
  1574. }
  1575. };
  1576. auto try_fuse_bias_nonlinearity = [&](opr::Elemwise* elem) -> bool {
  1577. bool can_be_fused = true;
  1578. can_be_fused &= (elem->input().size() == 2);
  1579. can_be_fused &= (elem->param().mode == Mode::FUSE_ADD_RELU) ||
  1580. (elem->param().mode == Mode::FUSE_ADD_TANH) ||
  1581. (elem->param().mode == Mode::FUSE_ADD_SIGMOID);
  1582. return can_be_fused;
  1583. };
  1584. auto try_fuse_bias = [&](opr::Elemwise* elem) -> bool {
  1585. bool can_be_fused = true;
  1586. can_be_fused &= (elem->input().size() == 2);
  1587. can_be_fused &= (elem->param().mode == Mode::ADD);
  1588. return can_be_fused;
  1589. };
  1590. auto try_fuse_nonlinearity = [&](opr::Elemwise* elem) -> bool {
  1591. bool can_be_fused = true;
  1592. can_be_fused &= (elem->input().size() == 1);
  1593. can_be_fused &= (elem->param().mode == Mode::RELU) ||
  1594. (elem->param().mode == Mode::TANH) ||
  1595. (elem->param().mode == Mode::SIGMOID);
  1596. return can_be_fused;
  1597. };
  1598. auto convert_to_conv_bias_param = [&](const opr::Convolution::Param& param)
  1599. -> opr::ConvBiasForward::Param {
  1600. using Param = opr::ConvBiasForward::Param;
  1601. return opr::ConvBiasForward::Param{Param::NonlineMode::IDENTITY,
  1602. param.mode,
  1603. param.sparse,
  1604. param.format,
  1605. param.pad_h,
  1606. param.pad_w,
  1607. param.stride_h,
  1608. param.stride_w,
  1609. param.dilate_h,
  1610. param.dilate_w,
  1611. 0,
  1612. param.compute_mode};
  1613. };
  1614. auto check_bias_shape = [&](opr::Convolution* conv, VarNode* bias) -> bool {
  1615. bool valid_bias_shape = true;
  1616. using Format = opr::Convolution::Param::Format;
  1617. using Sparse = opr::Convolution::Param::Sparse;
  1618. auto dst_shape = conv->output(0)->shape();
  1619. auto filter_shape = conv->input(1)->shape();
  1620. auto bias_shape = bias->shape();
1621. //! Pay attention: make sure the bias node is not a const provider when
1622. //! batch > 1, as that causes a shape assertion problem in ConvBias:
1623. //! if the input shape is resized, the bias shape cannot be updated
1624. //! accordingly, so do not fuse conv and bias in this situation.
  1625. if (dst_shape.eq_shape(bias_shape) && !cg::is_const_var_shape(bias)) {
  1626. return valid_bias_shape;
  1627. }
  1628. size_t OC = filter_shape[0];
  1629. if (conv->param().sparse == Sparse::GROUP) {
  1630. OC *= filter_shape[1];
  1631. }
  1632. if (conv->param().format == Format::NCHW) {
  1633. valid_bias_shape &=
  1634. ((bias_shape.ndim == 4) && (bias_shape[0] == 1) &&
  1635. (bias_shape[1] == OC) && (bias_shape[2] == 1) &&
  1636. (bias_shape[3] == 1));
  1637. } else if (conv->param().format == Format::NCHW4) {
  1638. valid_bias_shape &=
  1639. ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
  1640. (bias_shape[1] == OC / 4) && (bias_shape[2] == 1) &&
  1641. (bias_shape[3] == 1) && bias_shape[4] == 4);
  1642. } else if (conv->param().format == Format::NHWC) {
  1643. valid_bias_shape &= ((bias_shape.ndim == 4) &&
  1644. (bias_shape[0] == 1) && (bias_shape[1] == 1) &&
  1645. (bias_shape[2] == 1) && (bias_shape[3] == OC));
  1646. } else {
  1647. valid_bias_shape &=
  1648. ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
  1649. (bias_shape[1] == 1) && (bias_shape[2] == OC) &&
  1650. (bias_shape[3] == 1) && (bias_shape[4] == 4));
  1651. mgb_assert(conv->param().format == Format::NHWCD4);
  1652. }
  1653. return valid_bias_shape;
  1654. };
  1655. auto try_fuse_typecvt = [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
  1656. mgb_assert(typecvt->input().size() == 1);
  1657. auto conv_bias = try_cast_as_op<opr::ConvBias>(
  1658. rewriter.get_var(typecvt->input(0))->owner_opr());
  1659. if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
  1660. typecvt->output(0)->dtype().enumv() !=
  1661. DTypeTrait<dtype::QuantizedS8>::enumv ||
  1662. typecvt->input(0)->dtype().enumv() !=
  1663. DTypeTrait<dtype::QuantizedS32>::enumv)
  1664. return nullptr;
  1665. auto config = conv_bias->config();
  1666. config.output_dtype(typecvt->output(0)->dtype());
  1667. if (conv_bias->input().size() == 3) {
  1668. // conv + bias
  1669. return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
  1670. conv_bias->input(2), conv_bias->param(),
  1671. conv_bias->execution_policy(), config)
  1672. .node()
  1673. ->owner_opr();
  1674. } else {
  1675. // conv without bias
  1676. return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
  1677. conv_bias->param(),
  1678. conv_bias->execution_policy(), config)
  1679. .node()
  1680. ->owner_opr();
  1681. }
  1682. };
  1683. auto on_opr = [&](OperatorNodeBase* opr) {
  1684. auto check_conv = [](opr::Convolution* conv) -> bool {
  1685. return conv->param().format ==
  1686. megdnn::param::Convolution::Format::NHWCD4 ||
  1687. conv->param().format ==
  1688. megdnn::param::Convolution::Format::NHWC ||
  1689. conv->param().format ==
  1690. megdnn::param::Convolution::Format::NCHW ||
  1691. conv->param().format ==
  1692. megdnn::param::Convolution::Format::NCHW4
  1693. ;
  1694. };
  1695. if (auto elem = try_cast_as_op<opr::Elemwise>(opr)) {
  1696. if (try_fuse_bias_nonlinearity(elem) || try_fuse_bias(elem)) {
  1697. auto inp1 = rewriter.get_var(elem->input(0));
  1698. auto inp2 = rewriter.get_var(elem->input(1));
  1699. opr::Convolution* conv = nullptr;
  1700. size_t bias_idx = 0;
  1701. if (inp1->owner_opr()->same_type<opr::Convolution>() &&
  1702. m_deps[elem->input(0)].size() == 1) {
  1703. conv = try_cast_as_op<opr::Convolution>(inp1->owner_opr());
  1704. bias_idx = 1;
  1705. } else if (inp2->owner_opr()->same_type<opr::Convolution>() &&
  1706. m_deps[elem->input(1)].size() == 1) {
  1707. conv = try_cast_as_op<opr::Convolution>(inp2->owner_opr());
  1708. bias_idx = 0;
  1709. }
  1710. auto bias_inp = rewriter.get_var(elem->input(bias_idx));
  1711. if (conv && check_conv(conv) &&
  1712. check_bias_shape(conv, bias_inp)) {
  1713. opr::ConvBiasForward::Param param =
  1714. convert_to_conv_bias_param(conv->param());
  1715. param.nonlineMode = get_nonlinearity_mode(elem);
  1716. auto new_var =
  1717. opr::ConvBiasForward::make(
  1718. conv->input(0), conv->input(1), bias_inp,
  1719. param, conv->execution_policy(),
  1720. conv->config())
  1721. .node();
  1722. rewriter.replace_var(
  1723. opr->output(0), new_var,
  1724. mgb_cstr_log("replace nonlinearity(conv(x, w) + b) "
  1725. "-> conv_bias(x, w, b)"));
  1726. return;
  1727. }
  1728. } else if (try_fuse_nonlinearity(elem)) {
  1729. auto inp = rewriter.get_var(elem->input(0));
  1730. {
  1731. auto conv =
  1732. try_cast_as_op<opr::Convolution>(inp->owner_opr());
  1733. if (conv && check_conv(conv) &&
  1734. m_deps[elem->input(0)].size() == 1) {
  1735. opr::ConvBiasForward::Param param =
  1736. convert_to_conv_bias_param(conv->param());
  1737. param.nonlineMode = get_nonlinearity_mode(elem);
  1738. auto new_var = opr::ConvBiasForward::make(
  1739. conv->input(0), conv->input(1),
  1740. param, conv->execution_policy(),
  1741. conv->config())
  1742. .node();
  1743. rewriter.replace_var(
  1744. opr->output(0), new_var,
  1745. mgb_cstr_log("replace nonlinearity(conv(x, w)) "
  1746. "-> conv_bias(x, w)"));
  1747. return;
  1748. }
  1749. }
  1750. {
  1751. auto conv = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
  1752. auto check_conv_bias = [&](opr::ConvBias* opr) {
  1753. return opr->param().format ==
  1754. opr::ConvBias::Param::Format::NHWC ||
  1755. opr->param().format ==
  1756. opr::ConvBias::Param::Format::NCHW ||
  1757. opr->param().format ==
  1758. opr::ConvBias::Param::Format::NCHW4
  1759. ;
  1760. };
  1761. if (conv && check_conv_bias(conv) &&
  1762. m_deps[elem->input(0)].size() == 1) {
  1763. auto param = conv->param();
  1764. param.nonlineMode = get_nonlinearity_mode(elem);
  1765. auto new_var = opr::ConvBiasForward::make(
  1766. conv->input(0), conv->input(1),
  1767. conv->input(2), param,
  1768. conv->execution_policy(),
  1769. conv->config())
  1770. .node();
  1771. rewriter.replace_var(
  1772. opr->output(0), new_var,
  1773. mgb_cstr_log("replace nonlinearity(conv(x, w)) "
  1774. "-> conv_bias(x, w)"));
  1775. return;
  1776. }
  1777. }
  1778. }
  1779. } else if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
  1780. auto new_opr = try_fuse_typecvt(typecvt);
  1781. if (new_opr) {
  1782. rewriter.replace_var(
  1783. opr->output(0), new_opr->output(0),
  1784. mgb_cstr_log("replace typecvt(conv_bias(x, w, b)) -> "
  1785. "conv_bias(x, w, b)"));
  1786. return;
  1787. }
  1788. }
  1789. rewriter.auto_replace_outputs(opr);
  1790. };
  1791. state.graph().iter(on_opr);
  1792. rewriter.apply_inplace();
  1793. MIDOUT_E
  1794. }
  1795. /* ================ FuseConvBiasZPass ================ */
  1796. const char* FuseConvBiasZPass::name() const {
  1797. return "combine_conv_bias_and_z";
  1798. }
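// Fuses an elementwise add of a ConvBias result with another tensor z,
// e.g. relu(conv_bias(x, w, b) + z), into a four-input conv_bias(x, w, b, z);
// the quantized ElemwiseMultiType variants (QADD, QFUSE_ADD_RELU,
// QFUSE_ADD_H_SWISH) are handled as well.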
  1799. void FuseConvBiasZPass::apply(OptState& state) const {
  1800. MIDOUT_B("FuseConvBiasZPass::apply")
  1801. UniqReaderCheck uniq_reader_check{state.graph()};
  1802. auto rewriter = state.graph().make_rewriter();
  1803. using Mode = opr::Elemwise::Param::Mode;
  1804. using MultiMode = opr::ElemwiseMultiType::Param::Mode;
  1805. using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
  1806. auto check_conv_bias = [](opr::ConvBias* conv_bias) -> bool {
  1807. return conv_bias->param().format ==
  1808. megdnn::param::ConvBias::Format::NHWC ||
  1809. conv_bias->param().format ==
  1810. megdnn::param::ConvBias::Format::NCHW ||
  1811. conv_bias->param().format ==
  1812. megdnn::param::ConvBias::Format::NCHW4
  1813. ;
  1814. };
  1815. auto check_fuse_shape = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1816. bool valid_fuse_shape = true;
  1817. auto z_shape = z->shape();
  1818. auto bias_shape = conv_bias->input(2)->shape();
  1819. auto conv_bias_shape = conv_bias->output(0)->shape();
  1820. valid_fuse_shape &= (!conv_bias_shape.eq_shape(bias_shape));
  1821. valid_fuse_shape &= conv_bias_shape.eq_shape(z_shape);
  1822. return valid_fuse_shape;
  1823. };
  1824. auto check_fuse_dtype = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1825. return conv_bias->output(0)->dtype().enumv() == z->dtype().enumv();
  1826. };
  1827. auto get_convbias_nonline_mode = [&](OperatorNodeBase* opr) -> NonlineMode {
  1828. if (opr->same_type<opr::Elemwise>()) {
  1829. auto elem = try_cast_as_op<opr::Elemwise>(opr);
  1830. if (elem->param().mode == Mode::FUSE_ADD_RELU)
  1831. return NonlineMode::RELU;
  1832. }
  1833. if (opr->same_type<opr::ElemwiseMultiType>()) {
  1834. auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
  1835. if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
  1836. return NonlineMode::RELU;
  1837. else if (elem->param().mode == MultiMode::QFUSE_ADD_H_SWISH)
  1838. return NonlineMode::H_SWISH;
  1839. }
  1840. return NonlineMode::IDENTITY;
  1841. };
  1842. auto try_replace_var_node = [&](OperatorNodeBase* opr) {
  1843. opr::ConvBias* conv_bias = nullptr;
  1844. size_t z_idx = 0;
  1845. size_t nr_inps = opr->input().size();
  1846. for (size_t i = 0; i < nr_inps; i++) {
  1847. auto inp = rewriter.get_var(opr->input(i));
  1848. if (inp->owner_opr()->same_type<opr::ConvBias>()) {
  1849. auto cb = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
  1850. if (cb->input().size() == 3 &&
  1851. cb->param().nonlineMode ==
  1852. opr::ConvBias::Param::NonlineMode::IDENTITY &&
  1853. uniq_reader_check(opr->input(i))) {
  1854. conv_bias = cb;
  1855. z_idx = nr_inps - i - 1;
  1856. break;
  1857. }
  1858. }
  1859. }
  1860. auto z_inp = rewriter.get_var(opr->input(z_idx));
  1861. if (conv_bias && check_conv_bias(conv_bias) &&
  1862. check_fuse_shape(conv_bias, z_inp) &&
  1863. check_fuse_dtype(conv_bias, z_inp)) {
  1864. auto param = conv_bias->param();
  1865. param.nonlineMode = get_convbias_nonline_mode(opr);
  1866. auto config = conv_bias->config();
  1867. auto new_var = opr::ConvBiasForward::make(
  1868. conv_bias->input(0), conv_bias->input(1),
  1869. conv_bias->input(2), z_inp, param,
  1870. conv_bias->execution_policy(),
  1871. config.output_dtype(opr->output(0)->dtype()))
  1872. .node();
  1873. rewriter.replace_var(
  1874. opr->output(0), new_var,
  1875. mgb_cstr_log("replace "
  1876. "nonlinearity(conv_bias(x,w,b) + z) "
  1877. "-> conv_bias(x, w, b, z)"));
  1878. uniq_reader_check.update_on_opr_auto_replace(opr,
  1879. new_var->owner_opr());
  1880. return true;
  1881. }
  1882. return false;
  1883. };
  1884. auto try_fuse_elemwise = [&](OperatorNodeBase* opr) {
  1885. if (!opr->same_type<opr::Elemwise>())
  1886. return false;
  1887. auto elem = try_cast_as_op<opr::Elemwise>(opr);
  1888. if (elem->input().size() != 2)
  1889. return false;
  1890. if (elem->param().mode != Mode::ADD &&
  1891. elem->param().mode != Mode::FUSE_ADD_RELU)
  1892. return false;
  1893. return try_replace_var_node(opr);
  1894. };
  1895. auto try_fuse_elemwise_multi_type = [&](OperatorNodeBase* opr) {
  1896. if (!opr->same_type<opr::ElemwiseMultiType>())
  1897. return false;
  1898. auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
  1899. if (elem->input().size() != 2)
  1900. return false;
  1901. if (elem->param().mode != MultiMode::QADD &&
  1902. elem->param().mode != MultiMode::QFUSE_ADD_RELU &&
  1903. elem->param().mode != MultiMode::QFUSE_ADD_H_SWISH)
  1904. return false;
  1905. return try_replace_var_node(opr);
  1906. };
  1907. auto on_opr = [&](OperatorNodeBase* opr) {
  1908. if (try_fuse_elemwise(opr))
  1909. return;
  1910. if (try_fuse_elemwise_multi_type(opr))
  1911. return;
  1912. auto new_opr = rewriter.auto_replace_outputs(opr);
  1913. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  1914. };
  1915. state.graph().iter(on_opr);
  1916. rewriter.apply_inplace();
  1917. MIDOUT_E
  1918. }
  1919. /* ================ FuseDeconvCvtPass ================ */
  1920. const char* FuseDeconvCvtPass::name() const {
  1921. return "combine_deconv_and_typecvt";
  1922. }
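// Folds a TypeCvt to QuantizedS8 that immediately follows a
// ConvolutionBackwardData into the deconvolution's output dtype, provided
// the deconvolution output has no other readers.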
  1923. void FuseDeconvCvtPass::apply(OptState& state) const {
  1924. MIDOUT_B("FuseDeconvCvtPass::apply")
  1925. std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
  1926. state.graph().iter([&m_deps](OperatorNodeBase* opr) {
  1927. for (auto& inp : opr->input()) {
  1928. m_deps[inp].push_back(opr);
  1929. }
  1930. });
  1931. UniqReaderCheck uniq_reader_check{state.graph()};
  1932. auto rewriter = state.graph().make_rewriter();
  1933. auto try_fuse_deconv_typecvt =
  1934. [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
  1935. mgb_assert(typecvt->input().size() == 1);
  1936. auto deconv = try_cast_as_op<opr::ConvolutionBackwardData>(
  1937. rewriter.get_var(typecvt->input(0))->owner_opr());
  1938. if (!deconv
  1939. || m_deps.count(typecvt->input(0)) != 1 ||
  1940. typecvt->output(0)->dtype().enumv() !=
  1941. DTypeTrait<dtype::QuantizedS8>::enumv) {
  1942. return nullptr;
  1943. }
  1944. if (!uniq_reader_check(deconv->output(0)))
  1945. return nullptr;
  1946. auto config = deconv->config();
  1947. config.output_dtype(typecvt->output(0)->dtype());
  1948. return opr::ConvolutionBackwardData::make(
  1949. deconv->input(0), deconv->input(1), deconv->param(),
  1950. deconv->execution_policy(), config)
  1951. .node()
  1952. ->owner_opr();
  1953. };
  1954. auto on_opr = [&](OperatorNodeBase* opr) {
  1955. if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
  1956. if (auto deconv_new = try_fuse_deconv_typecvt(typecvt)) {
  1957. rewriter.replace_var(
  1958. opr->output(0), deconv_new->output(0),
  1959. mgb_cstr_log("replace typecvt(deconv(x, w)) -> "
  1960. "deconv(x, w)"));
  1961. uniq_reader_check.update_on_opr_auto_replace(opr, deconv_new);
  1962. return;
  1963. }
  1964. }
  1965. auto new_opr = rewriter.auto_replace_outputs(opr);
  1966. uniq_reader_check.update_on_opr_auto_replace(
  1967. opr, new_opr);
  1968. };
  1969. state.graph().iter(on_opr);
  1970. rewriter.apply_inplace();
  1971. MIDOUT_E
  1972. }
  1973. /* ================ ParamMergePass ================ */
  1974. const char* ParamMergePass::name() const {
  1975. return mgb_cstr_log("param_merge");
  1976. }
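// Merges individual SharedDeviceTensor / SharedDeviceTensorWithFormat
// parameters into a single MultipleDeviceTensorHolder /
// MultipleDeviceTensorWithFormatHolder operator.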
  1977. void ParamMergePass::apply(OptState& opt_state) const {
  1978. MIDOUT_B("ParamMergePass::apply")
  1979. param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
  1980. opt_state);
  1981. param_merge<opr::SharedDeviceTensorWithFormat,
  1982. opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
  1983. MIDOUT_E
  1984. }
  1985. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
