inference.cpp

  1. /**
  2. * \file src/gopt/impl/inference.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/gopt/inference.h"
  12. #include "megbrain/gopt/gtrans.h"
  13. #include "megbrain/gopt/basic_arith.h"
  14. #include "megbrain/graph/event.h"
  15. #include "megbrain/opr/dnn/batch_norm.h"
  16. #include "megbrain/opr/dnn/local.h"
  17. #include "megbrain/utils/shared_set.h"
  18. #include "megbrain/serialization/opr_shallow_copy.h"
  19. #include "megbrain/opr/basic_arith.h"
  20. #include "megbrain/opr/dnn/convolution.h"
  21. #include "megbrain/opr/blas.h"
  22. #include "megbrain/opr/misc.h"
  23. #include "megbrain/opr/utility.h"
  24. #include "megbrain/opr/dnn/pooling.h"
  25. #include "megbrain/opr/tensor_manip.h"
  26. #include "megbrain/opr/imgproc.h"
  27. #include "megbrain/opr/nn_int.h"
  28. #include "megdnn/tensor_format.h"
  29. #if MGB_ENABLE_TENSOR_RT
  30. #include "megbrain/tensorrt/tensorrt_opr.h"
  31. #endif
  32. #include "megbrain/gopt/misc.h"
  33. using namespace mgb;
  34. using namespace gopt;
  35. namespace {
  36. template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
  37. void param_merge(OptState& opt_state) {
  38. auto rewriter = opt_state.graph().make_rewriter();
  39. ThinHashMap<OperatorNodeBase*, size_t> opr2idx;
  40. std::vector<OperatorNodeBase*> all_oprs;
  41. typename MultipleDeviceTensorHolder::ValueArray all_values;
  42. auto cb_find_opr = [&](cg::OperatorNodeBase* opr) {
  43. if (opr->same_type<SharedDeviceTensor>()) {
  44. auto p = &opr->cast_final<SharedDeviceTensor>();
  45. // ShredD may be manu
  46. opr2idx[p] = all_values.size();
  47. all_values.push_back(p->dev_data());
  48. all_oprs.push_back(p);
  49. }
  50. };
  51. opt_state.graph().iter(cb_find_opr);
  52. SymbolVarArray new_vars;
  53. auto cb_replace = [&](cg::OperatorNodeBase* opr) {
  54. auto iter = opr2idx.find(opr);
  55. if (iter == opr2idx.end()) {
  56. rewriter.auto_replace_outputs(opr);
  57. } else {
  58. if (new_vars.empty()) {
  59. // new oprs must be created in iter callback; so we populate
  60. // new_vars lazily
  61. new_vars = MultipleDeviceTensorHolder::make(
  62. *opt_state.graph().comp_graph(), std::move(all_values),
  63. {ssprintf("merged%zu", all_values.size())});
  64. for (size_t i = 0; i < new_vars.size(); ++i) {
  65. auto src = all_oprs[i]->output(0);
  66. if (src->has_name_set()) {
  67. new_vars[i].rename(src->name());
  68. }
  69. }
  70. }
  71. rewriter.replace_var(
  72. opr->output(0), new_vars.at(iter->second).node(),
  73. mgb_cstr_log("replace multi SharedDeviceTensor(Format) to "
  74. "MultipleDeviceTensorHolder(Format)"));
  75. }
  76. };
  77. opt_state.graph().iter(cb_replace);
  78. rewriter.apply_inplace();
  79. }
  80. }
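// Usage sketch: param_merge() above is a template over a single-tensor holder
// opr and its merged counterpart; the only pairing visible in this section is
// SharedDeviceTensor / MultipleDeviceTensorHolder, so the instantiation below
// is an assumption about how a merging pass would call it:
//
//     void param_merge_pass_sketch(OptState& opt_state) {   // hypothetical wrapper
//         param_merge<opr::SharedDeviceTensor,
//                     opr::MultipleDeviceTensorHolder>(opt_state);
//     }
//
// A "...WithFormat" variant of the same pairing would be instantiated the same
// way if those opr types are available.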
  81. /* ================ global functions ================ */
  82. SymbolVarArray gopt::optimize_for_inference(
  83. const SymbolVarArray& dest_vars,
  84. const OptimizeForInferenceOptions& opt) {
  85. return gopt::GraphOptimizer()
  86. .add_preset_passes(false, &opt,
  87. &dest_vars[0].node()->owner_graph()->options())
  88. .apply({dest_vars})
  89. .endpoint_vars();
  90. }
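// Usage sketch for the entry point above; the variable names are illustrative
// and the available OptimizeForInferenceOptions fields depend on the build:
//
//     SymbolVarArray dest_vars = {/* network outputs */};
//     OptimizeForInferenceOptions opt;              // default options
//     SymbolVarArray optimized =
//             gopt::optimize_for_inference(dest_vars, opt);
//
// Internally this is just GraphOptimizer::add_preset_passes() followed by
// apply() and endpoint_vars(), as shown in the body.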
  91. namespace {
  92. void modify_conv_policy(opr::mixin::Convolution& conv,
  93. megdnn::param::ExecutionPolicy::Strategy strategy) {
  94. auto policy = conv.execution_policy_transient();
  95. policy.strategy = strategy;
  96. conv.set_execution_policy(policy);
  97. }
  98. template <typename Opr>
  99. void inplace_conv_opr_profile_modifier(OperatorNodeBase& opr) {
  100. modify_conv_policy(
  101. opr.cast_final_safe<Opr>(),
  102. opr::mixin::Convolution::ExecutionPolicy::Strategy::PROFILE);
  103. }
  104. template <typename Opr>
  105. void inplace_conv_opr_profile_cache_modifier(OperatorNodeBase& opr) {
  106. modify_conv_policy(opr.cast_final_safe<Opr>(),
  107. opr::mixin::Convolution::ExecutionPolicy::Strategy::
  108. PROFILE_HEURISTIC);
  109. }
  110. void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv,
  111. size_t workspace_limit) {
  112. auto policy = conv.execution_policy_transient();
  113. policy.workspace_limit = workspace_limit;
  114. conv.set_execution_policy(policy);
  115. }
  116. template <typename Opr>
  117. void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
  118. size_t workspace_limit) {
  119. modify_conv_policy_workspace_limit(opr.cast_final_safe<Opr>(),
  120. workspace_limit);
  121. }
  122. } // anonymous namespace
  123. #define MGB_FOREACH_FASTRUN_OPR(cb) \
  124. cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData), \
  125. cb(ConvolutionBackwardFilter), cb(Convolution3DForward), \
  126. cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter), \
  127. cb(LocalShareForward), cb(LocalShareBackwardData), \
  128. cb(LocalShareBackwardFilter), cb(DeformableConvForward), \
  129. cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \
  130. cb(BatchConvBiasForward),
  131. void gopt::enable_opr_algo_profiling_inplace(
  132. const VarNodeArrayView& dest_vars) {
  133. #if MGB_ENABLE_FASTRUN
  134. static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers =
  135. {
  136. #define CONV(t) {opr::t::typeinfo(), &inplace_conv_opr_profile_modifier<opr::t>}
  137. MGB_FOREACH_FASTRUN_OPR(CONV)
  138. #undef CONV
  139. };
  140. auto on_opr = [&](OperatorNodeBase* opr) {
  141. auto iter = modifiers.find(opr->dyn_typeinfo());
  142. if (iter != modifiers.end()) {
  143. iter->second(*opr);
  144. }
  145. };
  146. cg::DepOprIter dep_iter{on_opr};
  147. for (auto i : dest_vars) {
  148. dep_iter.add(i);
  149. }
  150. #else
  151. mgb_throw(MegBrainError, "fastrun is disabled at compile time");
  152. #endif
  153. }
  154. void gopt::enable_opr_use_profiling_cache_inplace(
  155. const VarNodeArrayView& dest_vars) {
  156. static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers =
  157. {
  158. #define CONV(t) \
  159. {opr::t::typeinfo(), &inplace_conv_opr_profile_cache_modifier<opr::t>}
  160. MGB_FOREACH_FASTRUN_OPR(CONV)
  161. #undef CONV
  162. };
  163. auto on_opr = [&](OperatorNodeBase* opr) {
  164. auto iter = modifiers.find(opr->dyn_typeinfo());
  165. if (iter != modifiers.end()) {
  166. iter->second(*opr);
  167. }
  168. };
  169. cg::DepOprIter dep_iter{on_opr};
  170. for (auto i : dest_vars) {
  171. dep_iter.add(i);
  172. }
  173. }
  174. void gopt::set_opr_algo_workspace_limit_inplace(
  175. const VarNodeArrayView& dest_vars, size_t workspace_limit) {
  176. static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
  177. modifiers = {
  178. #define CONV(t) \
  179. {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>}
  180. MGB_FOREACH_FASTRUN_OPR(CONV)
  181. #undef CONV
  182. };
  183. auto on_opr = [&](OperatorNodeBase* opr) {
  184. auto iter = modifiers.find(opr->dyn_typeinfo());
  185. if (iter != modifiers.end()) {
  186. iter->second(*opr, workspace_limit);
  187. }
  188. };
  189. cg::DepOprIter dep_iter{on_opr};
  190. for (auto i : dest_vars) {
  191. dep_iter.add(i);
  192. }
  193. }
  194. #undef MGB_FOREACH_FASTRUN_OPR
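// Usage sketch for the three in-place helpers above, applied to the endpoint
// vars before compiling; the VarNodeArrayView construction and the 256 MiB
// limit are illustrative:
//
//     SymbolVar y = /* network output */;
//     gopt::enable_opr_algo_profiling_inplace({y.node()});   // needs MGB_ENABLE_FASTRUN
//     gopt::set_opr_algo_workspace_limit_inplace({y.node()}, 256u << 20);
//
// Both helpers walk the dependency graph with cg::DepOprIter and patch the
// ExecutionPolicy of every fastrun-capable opr they encounter.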
  195. /* ================ ParamRedistributePass ================ */
  196. const char* ParamRedistributePass::name() const {
  197. return mgb_cstr_log("param_redistribute");
  198. }
  199. class ParamRedistributePass::Impl final: public RecursiveSubGraphRewriteHelper {
  200. ConstVarPropogate m_cvprop;
  201. UniqReaderCheck m_uniq_reader_check;
  202. //! oprs already processed in try_distribute_then_reassociate() should be
  203. //! skipped in on_new_opr_check_should_process()
  204. ThinHashSet<OperatorNodeBase*> m_opr_blacklist;
  205. std::string m_distribute_reasso_log_msg;
  206. //! try applying BinaryTrans20::associtive
  207. GTransResult try_reassociate(OperatorNodeBase *opr);
  208. //! try applying BinaryTrans20::distributive_add
  209. GTransResult try_distribute_add(OperatorNodeBase *opr);
  210. //! try distribute MUL/DIV over ADD/SUB and then apply
  211. GTransResult try_distribute_then_reassociate(OperatorNodeBase *opr);
  212. GTransResult process_opr(VarNode *out_var) override;
  213. bool on_new_opr_check_should_process(
  214. OperatorNodeBase*opr, OperatorNodeBase *repl_opr) override {
  215. m_uniq_reader_check.update_on_opr_auto_replace(opr, repl_opr);
  216. auto ins = m_cvprop.add_opr(opr);
  217. return ins.has_const_inp && !ins.all_const_inp &&
  218. !m_opr_blacklist.count(opr);
  219. };
  220. void after_replace_var(VarNode *orig_var, VarNode* new_var) override {
  221. m_uniq_reader_check.update_on_opr_auto_replace(orig_var->owner_opr(),
  222. new_var->owner_opr());
  223. }
  224. /*!
  225. * \brief try to reorder opr inputs to a const one and a non-const one
  226. *
  227. * return true if it can be reformulated as f(nci, ci), where nci is
  228. * non-const and ci is const.
  229. */
  230. bool reorder_for_normconst(OperatorNodeBase *opr,
  231. bool &swap_inp, VarNode *&nci, VarNode *&ci);
  232. public:
  233. Impl(OptState &state);
  234. };
  235. GTransResult ParamRedistributePass::Impl::process_opr(VarNode *out_var) {
  236. auto opr = out_var->owner_opr();
  237. auto trans = try_reassociate(opr);
  238. if (!trans.valid()) {
  239. trans = try_distribute_add(opr);
  240. if (!trans.valid())
  241. trans = try_distribute_then_reassociate(opr);
  242. }
  243. return trans;
  244. }
  245. GTransResult ParamRedistributePass::Impl::try_reassociate(
  246. OperatorNodeBase *opr) {
  247. // apply BinaryTrans20::associtive if opr is of the form f(g(a, b), c) and b and c are
  248. // const
  249. bool swap_fop_inp = false, swap_gop_inp = false;
  250. VarNode *a, *b, *c, *ab;
  251. if (!reorder_for_normconst(opr, swap_fop_inp, ab, c))
  252. return None;
  253. if (!m_uniq_reader_check(ab))
  254. return None;
  255. if (!reorder_for_normconst(ab->owner_opr(), swap_gop_inp, a, b))
  256. return None;
  257. return BinaryTrans20::associtive().apply(opr, swap_fop_inp, swap_gop_inp);
  258. }
  259. GTransResult ParamRedistributePass::Impl::try_distribute_add(
  260. OperatorNodeBase *opr) {
  261. if (opr->same_type<opr::Elemwise>() || opr->input().size() != 2)
  262. return None;
  263. if (!m_cvprop.is_const(opr->input(1)))
  264. return None;
  265. auto ab = as_elem_opr(opr->input(0)->owner_opr(), opr::Elemwise::Mode::ADD);
  266. if (ab) {
  267. bool swap;
  268. VarNode *a, *b;
  269. if (reorder_for_normconst(ab, swap, a, b)) {
  270. return BinaryTrans20::distributive_add().apply(
  271. opr, false, swap);
  272. }
  273. }
  274. return None;
  275. }
  276. GTransResult ParamRedistributePass::Impl::try_distribute_then_reassociate(
  277. OperatorNodeBase *opr) {
  278. if (!opr->same_type<opr::Elemwise>())
  279. return None;
  280. using Mode = opr::Elemwise::Mode;
  281. auto mode = opr->cast_final<opr::Elemwise>().param().mode;
  282. if (!(mode == Mode::MUL || mode == Mode::TRUE_DIV))
  283. return None;
  284. VarNode *a, *b;
  285. bool swap;
  286. if (!reorder_for_normconst(opr, swap, a, b))
  287. return None;
  288. auto chain_pred = [this](OperatorNodeBase *opr) {
  289. if (as_elem_opr(opr, Mode::ADD)) {
  290. auto var = opr->output(0);
  291. return m_uniq_reader_check(var) || m_cvprop.is_const(var);
  292. }
  293. return false;
  294. };
  295. auto chain = extract_opr_leaves(a, chain_pred);
  296. if (chain.size() <= 1)
  297. return None;
  298. std::unordered_map<VarNode*, VarNode*> repl_map;
  299. m_distribute_reasso_log_msg.clear();
  300. int nr_fail = 0, nr_succ = 0;
  301. for (auto &&var: chain) {
  302. {
  303. auto iter = repl_map.find(var);
  304. if (iter != repl_map.end()) {
  305. var = iter->second;
  306. continue;
  307. }
  308. }
  309. auto vnew = (SymbolVar{var} * b).node();
  310. m_opr_blacklist.insert(vnew->owner_opr());
  311. if (!m_cvprop.is_const(var)) {
  312. auto trans = try_reassociate(vnew->owner_opr());
  313. if (!trans.valid()) {
  314. // allow at most one failed redistribution
  315. if (nr_fail)
  316. return None;
  317. ++ nr_fail;
  318. } else {
  319. ++ nr_succ;
  320. vnew = trans->result;
  321. if (!m_distribute_reasso_log_msg.empty()) {
  322. m_distribute_reasso_log_msg.append(mgb_cstr_log(";"));
  323. }
  324. m_distribute_reasso_log_msg.append(trans->msg);
  325. }
  326. }
  327. repl_map[var] = vnew;
  328. var = vnew;
  329. }
  330. if (nr_succ) {
  331. m_distribute_reasso_log_msg.insert(0,
  332. mgb_cstr_log("distribute_mul("));
  333. m_distribute_reasso_log_msg.append(mgb_cstr_log(")"));
  334. return GTransResultItem{
  335. elemwise_reduce_var_list(chain, Mode::ADD),
  336. m_distribute_reasso_log_msg.c_str(),
  337. {}};
  338. }
  339. return None;
  340. }
  341. bool ParamRedistributePass::Impl::reorder_for_normconst(
  342. OperatorNodeBase *opr, bool &swap_inp, VarNode *&nci, VarNode *&ci) {
  343. if (opr->input().size() != 2)
  344. return false;
  345. nci = opr->input(0);
  346. ci = opr->input(1);
  347. if (!m_cvprop.is_const(ci)) {
  348. if (!is_commutable_binary(opr) || !m_cvprop.is_const(nci))
  349. return false;
  350. swap_inp = true;
  351. std::swap(nci, ci);
  352. } else {
  353. if (m_cvprop.is_const(nci))
  354. return false;
  355. swap_inp = false;
  356. }
  357. return true;
  358. }
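// Concrete example of the pattern handled by try_reassociate(), assuming x is
// non-const and c1/c2 are IMMUTABLE_AND_PARAM vars:
//
//     opr = f(g(a, b), c)   with   g(a, b) = (x * c1),  c = c2
//     (x * c1) * c2   --->   x * (c1 * c2)
//
// reorder_for_normconst() only determines the (non-const, const) ordering and
// the swap flags; the actual rewrite is performed by BinaryTrans20::associtive.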
  359. ParamRedistributePass::Impl::Impl(OptState &state):
  360. RecursiveSubGraphRewriteHelper{state},
  361. m_cvprop{ConstVarType::IMMUTABLE_AND_PARAM},
  362. m_uniq_reader_check{state.graph()}
  363. {
  364. auto cg = state.graph().comp_graph();
  365. auto on_new_opr = [this](const cg::event::OprInserted &ev) {
  366. if (!ev.is_dedup && !ev.exc) {
  367. // call add_opr eagerly to avoid deep recursion
  368. m_cvprop.add_opr(ev.opr);
  369. }
  370. };
  371. auto hdl = cg->event().register_receiver
  372. <cg::event::OprInserted>(on_new_opr);
  373. apply();
  374. }
  375. void ParamRedistributePass::apply(OptState &state) const {
  376. Impl{state};
  377. }
  378. /* ================ ParamFusePass ================ */
  379. class ParamFusePass::ConstVarPropogateWithSizeCheck final:
  380. public ConstVarPropogateBase
  381. {
  382. public:
  383. //! rewrite a var; reader == nullptr means needed by endpoint
  384. using VarRewriter = std::function<
  385. void(VarNode *var, OperatorNodeBase *reader)>;
  386. ConstVarPropogateWithSizeCheck(
  387. const ParamFusePass &pf, OptState &opt_state,
  388. const VarRewriter &rewriter):
  389. ConstVarPropogateBase{ConstVarType::IMMUTABLE_AND_PARAM},
  390. m_owner{pf}, m_opt_state{opt_state}, m_rewriter{rewriter}
  391. {
  392. }
  393. private:
  394. const ParamFusePass &m_owner;
  395. OptState &m_opt_state;
  396. VarRewriter m_rewriter;
  397. void on_midconst_opr(
  398. OperatorNodeBase *opr, size_t max_src_size) override {
  399. for (auto var: opr->output()) {
  400. if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
  401. continue;
  402. auto osize = var_mem_size(var);
  403. if (osize >= max_src_size &&
  404. osize - max_src_size > m_owner.m_param_grow_limit) {
  405. return;
  406. }
  407. // const oprs should be evaluated when output is used by another
  408. // non-const opr or output is needed by the user
  409. if (m_opt_state.graph().endpoint_contain(var)) {
  410. m_rewriter(var, nullptr);
  411. }
  412. }
  413. }
  414. };
  415. /*!
  416. * \brief get name for new param
  417. */
  418. class ParamFusePass::VarNamer {
  419. #if MGB_BUILD_SLIM_SERVING
  420. public:
  421. const std::string& name(VarNode*) {
  422. static std::string ret("fuse");
  423. return ret;
  424. }
  425. #else
  426. using SrcSet = SharedSet<OperatorNodeBase*>;
  427. //! map from var to source SharedDeviceTensor/MultipleDeviceTensorHolder oprs
  428. //! that it depends on
  429. ThinHashMap<OperatorNodeBase*, SrcSet> m_opr2srcs;
  430. std::string m_name_cache;
  431. std::vector<const char*> m_cur_name;
  432. SrcSet& get_src_set(OperatorNodeBase* opr) {
  433. auto opr_typeinfo = opr->dyn_typeinfo();
  434. auto iter = m_opr2srcs.find(opr);
  435. if (iter != m_opr2srcs.end()) {
  436. return iter->second;
  437. }
  438. auto &&ret = m_opr2srcs[opr];
  439. if (opr->input().empty()) {
  440. if (opr_typeinfo == opr::SharedDeviceTensor::typeinfo() ||
  441. opr_typeinfo == opr::MultipleDeviceTensorHolder::typeinfo()) {
  442. ret.insert(opr);
  443. } else {
  444. mgb_assert(opr_typeinfo == opr::ImmutableTensor::typeinfo());
  445. }
  446. return ret;
  447. }
  448. for (auto i: opr->input()) {
  449. ret.merge_from(get_src_set(i->owner_opr()));
  450. }
  451. return ret;
  452. }
  453. public:
  454. const std::string& name(VarNode *var) {
  455. m_cur_name.clear();
  456. for (auto i : get_src_set(var->owner_opr())) {
  457. m_cur_name.push_back(i->cname());
  458. }
  459. auto cmp = [](const char *x, const char *y) {
  460. return strcmp(x, y) < 0;
  461. };
  462. std::sort(m_cur_name.begin(), m_cur_name.end(), cmp);
  463. m_name_cache.clear();
  464. m_name_cache.append(mgb_cstr_log("fuse("));
  465. bool first = true;
  466. for (auto i: m_cur_name) {
  467. if (first) {
  468. first = false;
  469. } else {
  470. m_name_cache.push_back(',');
  471. }
  472. m_name_cache.append(i);
  473. }
  474. m_name_cache.append(mgb_cstr_log(
  475. ssprintf("):%s@%zu", var->cname(), var->id())));
  476. return m_name_cache;
  477. }
  478. #endif
  479. };
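// Example of the name VarNamer::name() produces (non-slim build): for a var
// "conv_out" with id 42 whose value depends on params "b" and "w", the result
// is "fuse(b,w):conv_out@42" (source names are sorted before joining).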
  480. const char* ParamFusePass::name() const {
  481. return mgb_cstr_log("param_fuse");
  482. }
  483. void ParamFusePass::apply(OptState &state) const {
  484. auto rewriter = state.graph().make_rewriter();
  485. auto cg = state.graph().comp_graph();
  486. ThinHashSet<VarNode*> processed_var;
  487. VarNamer var_namer;
  488. // reader: null if used as endvar
  489. auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
  490. if (!processed_var.insert(var).second)
  491. return;
  492. auto inferred_val = std::make_shared<DeviceTensorND>(
  493. var->comp_node(), var->dtype());
  494. auto cb = [&](DeviceTensorND& val) {
  495. // retain format of val
  496. mgb_assert(val.format() == var->format());
  497. inferred_val->format(val.format())
  498. .resize(val.shape())
  499. .copy_from_fixlayout(val);
  500. };
  501. {
  502. auto orig_level = cg->options().log_level;
  503. cg->options().log_level = 0;
  504. MGB_TRY {
  505. cg->compile({{var, cb}})->execute();
  506. } MGB_FINALLY(cg->options().log_level = orig_level);
  507. }
  508. SymbolVar new_var;
  509. bool is_default_format = var->layout().format.is_default();
  510. if (cg::is_static_var_value(var) && is_default_format) {
  511. // use ImmutableTensor for inferable vars
  512. HostTensorND hv;
  513. hv.copy_from(*inferred_val).sync();
  514. new_var = opr::ImmutableTensor::make(
  515. *var->owner_graph(), hv, var_namer.name(var));
  516. } else {
  517. if (is_default_format) {
  518. new_var = opr::SharedDeviceTensor::make(
  519. *var->owner_graph(), inferred_val, var_namer.name(var));
  520. } else {
  521. new_var = opr::SharedDeviceTensorWithFormat::make(
  522. *var->owner_graph(), inferred_val, var_namer.name(var));
  523. }
  524. }
  525. std::string log;
  526. if (reader) {
  527. log = mgb_ssprintf_log(
  528. "due to read by %s{%s}",
  529. reader->cname(), reader->dyn_typeinfo()->name);
  530. } else {
  531. log = mgb_cstr_log("as endpoint");
  532. }
  533. rewriter.replace_var(var, new_var.node(), log.c_str());
  534. };
  535. ConstVarPropogateWithSizeCheck cvprop{*this, state, replace_single_var};
  536. auto on_opr = [&](OperatorNodeBase *opr) {
  537. auto add_ret = cvprop.add_opr(opr);
  538. if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
  539. for (auto i: opr->input()) {
  540. if (cvprop.is_midconst(i)) {
  541. state.call_with_opr(i->owner_opr(),
  542. [&]{replace_single_var(i, opr);});
  543. }
  544. }
  545. }
  546. rewriter.auto_replace_outputs(opr);
  547. };
  548. state.graph().iter(on_opr);
  549. rewriter.apply_inplace();
  550. }
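// Usage sketch: the passes defined so far can be chained through a
// GraphOptimizer; the exact pass list below is only an example:
//
//     gopt::GraphOptimizer optimizer;
//     optimizer.add_pass<ParamRedistributePass>();
//     optimizer.add_pass<ParamFusePass>();
//     auto optimized = optimizer.apply({dest_vars}).endpoint_vars();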
  551. /* ================ ConvertF32ToF16Pass ================ */
  552. const char* ConvertF32ToF16Pass::name() const {
  553. return mgb_cstr_log("convert_f32_to_f16");
  554. }
  555. void ConvertF32ToF16Pass::apply(OptState& state) const {
  556. state.set_var_replace_check_flag(m_var_replace_check_flag);
  557. auto rewriter = state.graph().make_rewriter();
  558. VarNodeArray new_inp_cache;
  559. auto on_opr = [this, &rewriter, &new_inp_cache,
  560. &state](OperatorNodeBase* opr) {
  561. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  562. if (it != m_opr_replace_func.end()) {
  563. auto&& new_inp = new_inp_cache;
  564. new_inp.clear();
  565. new_inp.reserve(opr->input().size());
  566. for (auto i: opr->input()) {
  567. new_inp.push_back(rewriter.get_var(i));
  568. }
  569. auto new_opr = (it->second)(opr, new_inp);
  570. auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
  571. mgb_assert(origin_out.size() == cur_out.size(),
  572. "bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(),
  573. opr->dyn_typeinfo()->name, new_opr->cname(),
  574. new_opr->dyn_typeinfo()->name);
  575. //! change the output type if it's the endpoint
  576. for (size_t i = 0; i < origin_out.size(); i++) {
  577. if (state.graph().endpoint_contain(origin_out[i]) &&
  578. origin_out[i]->dtype().enumv() !=
  579. cur_out[i]->dtype().enumv()) {
  580. rewriter.replace_var(
  581. origin_out[i],
  582. opr::TypeCvt::make(cur_out[i],
  583. origin_out[i]->dtype())
  584. .node(),
  585. nullptr);
  586. } else {
  587. rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
  588. }
  589. }
  590. } else {
  591. auto new_opr = rewriter.auto_replace_outputs(opr);
  592. auto&& out = opr->output();
  593. auto&& new_out = new_opr->output();
  594. for (size_t i = 0; i < out.size(); i++) {
  595. if (state.graph().endpoint_contain(out[i]) &&
  596. new_out[i]->dtype().enumv() != out[i]->dtype().enumv()) {
  597. rewriter.replace_var(
  598. new_out[i],
  599. opr::TypeCvt::make(new_out[i],
  600. out[i]->dtype())
  601. .node(),
  602. nullptr);
  603. }
  604. }
  605. }
  606. };
  607. state.graph().iter(on_opr);
  608. rewriter.apply_inplace();
  609. }
  610. std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
  611. bool use_f32_comp) {
  612. #if MEGDNN_DISABLE_FLOAT16
  613. mgb_throw(SystemError, "float16 disabled at compile time.");
  614. #else
  615. auto replace_h2d_opr = [](OperatorNodeBase* opr,
  616. const VarNodeArray& new_inp) {
  617. mgb_assert(opr->input().size() == new_inp.size());
  618. auto& h2d_opr = opr->cast_final_safe<opr::Host2DeviceCopy>();
  619. if (h2d_opr.output(0)->dtype() == dtype::Float32()) {
  620. auto cvt_var =
  621. opr::TypeCvt::make(h2d_opr.output(0), dtype::Float16(), {});
  622. return cvt_var.node()->owner_opr();
  623. }
  624. return opr;
  625. };
  626. auto replace_sdt_opr = [](OperatorNodeBase* opr,
  627. const VarNodeArray& new_inp) {
  628. mgb_assert(opr->input().size() == new_inp.size());
  629. auto& sdt_opr = opr->cast_final_safe<opr::SharedDeviceTensor>();
  630. if (sdt_opr.output(0)->dtype() == dtype::Float32()) {
  631. auto cvt_var =
  632. opr::TypeCvt::make(sdt_opr.output(0), dtype::Float16(), {});
  633. return cvt_var.node()->owner_opr();
  634. }
  635. return opr;
  636. };
  637. auto replace_imt_opr = [](OperatorNodeBase* opr,
  638. const VarNodeArray& new_inp) {
  639. mgb_assert(opr->same_type<opr::ImmutableTensor>());
  640. mgb_assert(opr->input().size() == new_inp.size());
  641. auto& imt_opr = opr->cast_final_safe<opr::ImmutableTensor>();
  642. if (imt_opr.output(0)->dtype() == dtype::Float32()) {
  643. auto cvt_var =
  644. opr::TypeCvt::make(imt_opr.output(0), dtype::Float16(), {});
  645. return cvt_var.node()->owner_opr();
  646. }
  647. return opr;
  648. };
  649. auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr,
  650. const VarNodeArray& new_inp) {
  651. mgb_assert(opr->input().size() == new_inp.size());
  652. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  653. auto new_param = conv_opr.param();
  654. if (use_f32_comp) {
  655. new_param.compute_mode =
  656. megdnn::param::Convolution::ComputeMode::FLOAT32;
  657. }
  658. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  659. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  660. new_inp[0]->name().c_str(),
  661. new_inp[0]->owner_opr()->name().c_str());
  662. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  663. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  664. new_inp[1]->name().c_str(),
  665. new_inp[1]->owner_opr()->name().c_str());
  666. auto new_conv_opr = opr::Convolution::make(
  667. new_inp[0], new_inp[1], new_param, conv_opr.execution_policy(),
  668. conv_opr.config());
  669. return new_conv_opr.node()->owner_opr();
  670. };
  671. auto replace_matmul_opr = [use_f32_comp](OperatorNodeBase* opr,
  672. const VarNodeArray& new_inp) {
  673. mgb_assert(opr->input().size() == new_inp.size());
  674. auto& matmul_opr = opr->cast_final_safe<opr::MatrixMul>();
  675. auto new_param = matmul_opr.param();
  676. if (use_f32_comp) {
  677. new_param.compute_mode =
  678. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  679. }
  680. auto new_matmul_opr = opr::MatrixMul::make(
  681. new_inp[0], new_inp[1], new_param, matmul_opr.config());
  682. return new_matmul_opr.node()->owner_opr();
  683. };
  684. auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr,
  685. const VarNodeArray& new_inp) {
  686. auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
  687. auto new_param = reduce_opr.param();
  688. if (use_f32_comp) {
  689. new_param.data_type =
  690. megdnn::param::Reduce::DataType::FLOAT_O16xC32;
  691. }
  692. if (opr->input().size() == 1) {
  693. auto new_matmul_opr = opr::Reduce::make(new_inp[0], new_param, {},
  694. reduce_opr.config());
  695. return new_matmul_opr.node()->owner_opr();
  696. } else {
  697. mgb_assert(opr->input().size() == 2, "invalid input size %zu",
  698. opr->input().size());
  699. auto new_matmul_opr = opr::Reduce::make(
  700. new_inp[0], new_param, new_inp[1], reduce_opr.config());
  701. return new_matmul_opr.node()->owner_opr();
  702. }
  703. };
  704. auto replace_cvt_opr = [](OperatorNodeBase* opr,
  705. const VarNodeArray& new_inp) {
  706. auto& cvt_opr = opr->cast_final_safe<opr::TypeCvt>();
  707. SymbolVar new_cvt;
  708. if (cvt_opr.output(0)->dtype() == dtype::Float32()) {
  709. new_cvt = opr::TypeCvt::make(new_inp[0], dtype::Float16(),
  710. cvt_opr.config());
  711. } else {
  712. new_cvt = opr::TypeCvt::make(
  713. new_inp[0], cvt_opr.output()[0]->dtype(), cvt_opr.config());
  714. }
  715. return new_cvt.node()->owner_opr();
  716. };
  717. auto replace_warp_opr = [](OperatorNodeBase* opr,
  718. const VarNodeArray& new_inp) {
  719. mgb_assert(opr->input().size() == new_inp.size() &&
  720. (new_inp.size() == 3 || new_inp.size() == 4));
  721. auto& warp_opr = opr->cast_final<opr::WarpPerspective>();
  722. // mat tensor must be float32
  723. auto new_mat = new_inp[1];
  724. if (new_inp[1]->dtype() != dtype::Float32()) {
  725. if (try_cast_as_op<opr::TypeCvt>(new_mat->owner_opr()) &&
  726. new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
  727. new_mat = new_mat->owner_opr()->input(0);
  728. else
  729. new_mat =
  730. opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
  731. }
  732. SymbolVar new_warp;
  733. if (new_inp.size() == 3) {
  734. new_warp = opr::WarpPerspective::make(new_inp[0], new_mat,
  735. new_inp[2], warp_opr.param(),
  736. warp_opr.config());
  737. } else {
  738. mgb_assert(new_inp.size() == 4);
  739. new_warp = opr::WarpPerspective::make(
  740. new_inp[0], new_mat, new_inp[2], new_inp[3],
  741. warp_opr.param(), warp_opr.config());
  742. }
  743. return new_warp.node()->owner_opr();
  744. };
  745. auto replace_remap_opr = [](OperatorNodeBase* opr,
  746. const VarNodeArray& new_inp) {
  747. mgb_assert(opr->input().size() == new_inp.size() &&
  748. (new_inp.size() == 2));
  749. auto& remap_opr = opr->cast_final<opr::Remap>();
  750. // map tensor must be float32
  751. auto new_map = new_inp[1];
  752. if (new_inp[1]->dtype() != dtype::Float32()) {
  753. if (try_cast_as_op<opr::TypeCvt>(new_map->owner_opr()) &&
  754. new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
  755. new_map = new_map->owner_opr()->input(0);
  756. else
  757. new_map =
  758. opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
  759. }
  760. SymbolVar new_remap;
  761. new_remap = opr::Remap::make(new_inp[0], new_map,
  762. remap_opr.param(),
  763. remap_opr.config());
  764. return new_remap.node()->owner_opr();
  765. };
  766. auto ret = std::make_unique<ConvertF32ToF16Pass>();
  767. // don't check dtype
  768. ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
  769. VarReplaceCheckFlag::CHECK_DTYPE);
  770. auto&& replace_func = ret->m_opr_replace_func;
  771. replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
  772. replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
  773. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  774. replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr;
  775. replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr;
  776. replace_func[opr::ImmutableTensor::typeinfo()] = replace_imt_opr;
  777. replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
  778. replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
  779. replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
  780. return ret;
  781. #endif
  782. }
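// Usage sketch: building the pass with float32 accumulation enabled for
// conv/matmul/reduce (the bool maps to use_f32_comp above); passing the
// unique_ptr to GraphOptimizer::add_pass is assumed to be supported:
//
//     auto pass = ConvertF32ToF16Pass::make(/* use_f32_comp = */ true);
//     gopt::GraphOptimizer optimizer;
//     optimizer.add_pass(std::move(pass));
//     auto f16_vars = optimizer.apply({dest_vars}).endpoint_vars();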
  783. /* ================ ConvertFormatPass ================ */
  784. void ConvertFormatPass::apply(OptState& state) const {
  785. state.set_var_replace_check_flag(m_var_replace_check_flag);
  786. auto rewriter = state.graph().make_rewriter();
  787. VarNodeArray new_inp_cache;
  788. auto on_opr = [this, &state, &rewriter,
  789. &new_inp_cache](OperatorNodeBase* opr) {
  790. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  791. if (it != m_opr_replace_func.end()) {
  792. auto&& new_inp = new_inp_cache;
  793. new_inp.clear();
  794. new_inp.reserve(opr->input().size());
  795. for (auto i : opr->input()) {
  796. new_inp.push_back(rewriter.get_var(i));
  797. }
  798. auto new_opr = (it->second)(opr, new_inp);
  799. auto &&out0 = opr->output(), &&out1 = new_opr->output();
  800. mgb_assert(out0.size() == out1.size(),
  801. "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
  802. "dst.size=%zu",
  803. opr->cname(), opr->dyn_typeinfo()->name,
  804. new_opr->cname(), new_opr->dyn_typeinfo()->name,
  805. out0.size(), out1.size());
  806. for (size_t i = 0; i < out0.size(); i++) {
  807. if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  808. mgb_assert(!out1[i]->contain_flag(
  809. VarNode::Flag::VOLATILE_CONTENT));
  810. auto src = out0[i];
  811. auto dst = out1[i];
  812. auto dst_is_image = dst->format().type() ==
  813. TensorFormat::Type::IMAGE2D_PACK4;
  814. if (!dst_is_image &&
  815. !src->owner_opr()->same_type<opr::ImmutableTensor>()) {
  816. mgb_log_warn(
  817. "convert NHWCD4 replaced to non-img format: "
  818. "dst_opr=%s{%s} format=%s",
  819. dst->owner_opr()->cname(),
  820. dst->owner_opr()->dyn_typeinfo()->name,
  821. dst->format().to_string().c_str());
  822. }
  823. if (state.graph().endpoint_contain(src) && dst_is_image) {
  824. // relayout back to NCHW for output vars
  825. dst = opr::RelayoutFormat::make(
  826. dst, {opr::RelayoutFormat::Param::Mode::
  827. NHWCD4I_NCHW})
  828. .node();
  829. }
  830. rewriter.replace_var(src, dst, nullptr);
  831. }
  832. }
  833. } else {
  834. rewriter.auto_replace_outputs(opr);
  835. }
  836. };
  837. state.graph().iter(on_opr);
  838. rewriter.apply_inplace();
  839. }
  840. std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
  841. auto filter_mode =
  842. [](const megdnn::param::Convolution::Sparse conv_mode,
  843. const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
  844. bool use_dot = false;
  845. if (filter->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
  846. filter->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm)
  847. use_dot = true;
  848. if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
  849. if (use_dot)
  850. return megdnn::param::RelayoutFormat::Mode::
  851. INTER_WEIGHT_DENSEI_DOT;
  852. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI;
  853. } else {
  854. mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
  855. if (filter->shape()[1] == 1 && filter->shape()[2] == 1) {
  856. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI;
  857. } else {
  858. if (use_dot)
  859. return megdnn::param::RelayoutFormat::Mode::
  860. INTER_WEIGHT_GROUPI_DOT;
  861. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_GROUPI;
  862. }
  863. }
  864. };
  865. auto replace_conv_opr = [&filter_mode](OperatorNodeBase* opr,
  866. const VarNodeArray& new_inp) {
  867. mgb_assert(opr->input().size() == new_inp.size());
  868. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  869. mgb_assert(conv_opr.param().format ==
  870. megdnn::param::Convolution::Format::NCHW,
  871. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  872. VarNode *conv_src = nullptr, *conv_weights = nullptr;
  873. if (new_inp[0]->shape().ndim == 4) {
  874. // new input src is NCHW
  875. size_t group, icpg, ocpg;
  876. if (conv_opr.param().sparse ==
  877. megdnn::param::Convolution::Sparse::DENSE) {
  878. group = 1;
  879. icpg = new_inp[1]->shape()[1];
  880. ocpg = new_inp[1]->shape()[0];
  881. } else {
  882. mgb_assert(conv_opr.param().sparse ==
  883. megdnn::param::Convolution::Sparse::GROUP);
  884. group = new_inp[1]->shape()[0];
  885. icpg = new_inp[1]->shape()[2];
  886. ocpg = new_inp[1]->shape()[1];
  887. }
  888. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  889. auto param = megdnn::param::RelayoutFormat();
  890. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  891. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  892. conv_src = rf.node();
  893. } else {
  894. // can not convert to hwcd4
  895. return serialization::copy_opr_shallow(*opr, new_inp,
  896. opr->config());
  897. }
  898. } else {
  899. size_t ocpg;
  900. bool is_channel_wise = false;
  901. if (conv_opr.param().sparse ==
  902. megdnn::param::Convolution::Sparse::DENSE) {
  903. ocpg = new_inp[1]->shape()[0];
  904. } else {
  905. mgb_assert(conv_opr.param().sparse ==
  906. megdnn::param::Convolution::Sparse::GROUP);
  907. size_t icpg = new_inp[1]->shape()[2];
  908. ocpg = new_inp[1]->shape()[1];
  909. if (icpg == 1 && ocpg == 1) {
  910. is_channel_wise = true;
  911. }
  912. }
  913. if (ocpg % 4 != 0 && !is_channel_wise) {
  914. VarNodeArray t_inp = new_inp;
  915. auto param = megdnn::param::RelayoutFormat();
  916. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  917. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  918. t_inp[0] = rf.node();
  919. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  920. opr->config());
  921. return new_opr;
  922. }
  923. // new input src is NHWCD4
  924. auto&& fmt = new_inp[0]
  925. ->format()
  926. .as_impl<megdnn::Image2DPack4TensorFormat>();
  927. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  928. conv_src = new_inp[0];
  929. }
  930. mgb_assert(new_inp[1]->format().type() !=
  931. TensorFormat::Type::IMAGE2D_PACK4);
  932. auto param = megdnn::param::RelayoutFormat();
  933. param.mode = filter_mode(conv_opr.param().sparse, new_inp[1]);
  934. auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
  935. conv_weights = relayout_weight.node();
  936. auto new_param = conv_opr.param();
  937. new_param.format = megdnn::param::Convolution::Format::NHWCD4;
  938. mgb_assert(conv_src->shape().ndim == 5 &&
  939. conv_src->format().type() ==
  940. TensorFormat::Type::IMAGE2D_PACK4);
  941. auto new_conv_opr = opr::Convolution::make(
  942. conv_src, conv_weights, new_param, conv_opr.execution_policy(),
  943. conv_opr.config());
  944. OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
  945. mgb_assert(new_conv_opr.shape().ndim == 5 &&
  946. new_conv_opr.format().type() ==
  947. TensorFormat::Type::IMAGE2D_PACK4);
  948. return ret;
  949. };
  950. auto replace_conv_bias_opr = [&filter_mode](OperatorNodeBase* opr,
  951. const VarNodeArray& new_inp) {
  952. mgb_assert(opr->input().size() == new_inp.size());
  953. auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  954. mgb_assert(conv_bias_opr.param().format ==
  955. megdnn::param::ConvBias::Format::NCHW,
  956. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  957. VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr,
  958. *conv_bias_bias = nullptr;
  959. if (new_inp[0]->shape().ndim == 4) {
  960. // new input src is NCHW
  961. size_t group, icpg, ocpg;
  962. if (conv_bias_opr.param().sparse ==
  963. megdnn::param::ConvBias::Sparse::DENSE) {
  964. group = 1;
  965. icpg = new_inp[1]->shape()[1];
  966. ocpg = new_inp[1]->shape()[0];
  967. } else {
  968. mgb_assert(conv_bias_opr.param().sparse ==
  969. megdnn::param::ConvBias::Sparse::GROUP);
  970. group = new_inp[1]->shape()[0];
  971. icpg = new_inp[1]->shape()[2];
  972. ocpg = new_inp[1]->shape()[1];
  973. }
  974. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  975. auto param = megdnn::param::RelayoutFormat();
  976. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  977. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  978. conv_bias_src = rf.node();
  979. } else {
  980. // can not convert to hwcd4
  981. return serialization::copy_opr_shallow(*opr, new_inp,
  982. opr->config());
  983. }
  984. } else {
  985. size_t ocpg;
  986. bool is_channel_wise = false;
  987. if (conv_bias_opr.param().sparse ==
  988. megdnn::param::ConvBias::Sparse::DENSE) {
  989. ocpg = new_inp[1]->shape()[0];
  990. } else {
  991. mgb_assert(conv_bias_opr.param().sparse ==
  992. megdnn::param::ConvBias::Sparse::GROUP);
  993. size_t icpg = new_inp[1]->shape()[2];
  994. ocpg = new_inp[1]->shape()[1];
  995. if (icpg == 1 && ocpg == 1) {
  996. is_channel_wise = true;
  997. }
  998. }
  999. if (ocpg % 4 != 0 && !is_channel_wise) {
  1000. VarNodeArray t_inp = new_inp;
  1001. auto param = megdnn::param::RelayoutFormat();
  1002. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1003. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1004. t_inp[0] = rf.node();
  1005. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1006. opr->config());
  1007. return new_opr;
  1008. }
  1009. // new input src is NHWCD4
  1010. auto&& fmt = new_inp[0]
  1011. ->format()
  1012. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1013. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1014. conv_bias_src = new_inp[0];
  1015. }
  1016. mgb_assert(new_inp[1]->format().type() !=
  1017. TensorFormat::Type::IMAGE2D_PACK4);
  1018. auto param = megdnn::param::RelayoutFormat();
  1019. param.mode = filter_mode(conv_bias_opr.param().sparse, new_inp[1]);
  1020. auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
  1021. conv_bias_weights = relayout_weight.node();
  1022. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1023. auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
  1024. conv_bias_bias = relayout_bias.node();
  1025. auto new_param = conv_bias_opr.param();
  1026. new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
  1027. mgb_assert(conv_bias_src->shape().ndim == 5 &&
  1028. conv_bias_src->format().type() ==
  1029. TensorFormat::Type::IMAGE2D_PACK4);
  1030. auto new_conv_bias_opr = opr::ConvBias::make(
  1031. conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
  1032. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1033. OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
  1034. mgb_assert(new_conv_bias_opr.shape().ndim == 5 &&
  1035. new_conv_bias_opr.format().type() ==
  1036. TensorFormat::Type::IMAGE2D_PACK4);
  1037. return ret;
  1038. };
  1039. auto replace_deconv_opr = [&filter_mode](OperatorNodeBase* opr,
  1040. const VarNodeArray& new_inp) {
  1041. mgb_assert(opr->input().size() == new_inp.size());
  1042. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  1043. mgb_assert(deconv_opr.param().format ==
  1044. megdnn::param::Convolution::Format::NCHW,
  1045. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1046. VarNode *deconv_src = nullptr, *deconv_weights = nullptr;
  1047. if (new_inp[1]->shape().ndim == 4) {
  1048. // new input src is NCHW
  1049. size_t group, icpg, ocpg;
  1050. if (deconv_opr.param().sparse ==
  1051. megdnn::param::Convolution::Sparse::DENSE) {
  1052. group = 1;
  1053. icpg = new_inp[0]->shape()[0];
  1054. ocpg = new_inp[0]->shape()[1];
  1055. } else {
  1056. mgb_assert(deconv_opr.param().sparse ==
  1057. megdnn::param::Convolution::Sparse::GROUP);
  1058. group = new_inp[0]->shape()[0];
  1059. icpg = new_inp[0]->shape()[1];
  1060. ocpg = new_inp[0]->shape()[2];
  1061. }
  1062. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1063. auto param = megdnn::param::RelayoutFormat();
  1064. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1065. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1066. deconv_src = rf.node();
  1067. } else {
  1068. // can not convert to hwcd4
  1069. return serialization::copy_opr_shallow(*opr, new_inp,
  1070. opr->config());
  1071. }
  1072. } else {
  1073. //! XXXX, fix me, check filter size
  1074. size_t ocpg;
  1075. if (deconv_opr.param().sparse ==
  1076. megdnn::param::Convolution::Sparse::DENSE) {
  1077. ocpg = new_inp[0]->shape()[1];
  1078. } else {
  1079. mgb_assert(deconv_opr.param().sparse ==
  1080. megdnn::param::Convolution::Sparse::GROUP);
  1081. ocpg = new_inp[0]->shape()[2];
  1082. }
  1083. if (ocpg % 4 != 0) {
  1084. VarNodeArray t_inp = new_inp;
  1085. auto param = megdnn::param::RelayoutFormat();
  1086. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1087. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1088. t_inp[1] = rf.node();
  1089. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1090. opr->config());
  1091. return new_opr;
  1092. }
  1093. // new input src is NHWCD4
  1094. auto&& fmt = new_inp[1]
  1095. ->format()
  1096. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1097. mgb_assert(new_inp[1]->shape().ndim == 5 && fmt.align_axis() == 2);
  1098. deconv_src = new_inp[1];
  1099. }
  1100. mgb_assert(new_inp[0]->format().type() !=
  1101. TensorFormat::Type::IMAGE2D_PACK4);
  1102. auto param = megdnn::param::RelayoutFormat();
  1103. param.mode = filter_mode(deconv_opr.param().sparse, new_inp[0]);
  1104. auto relayout_weight = opr::RelayoutFormat::make(new_inp[0], param);
  1105. deconv_weights = relayout_weight.node();
  1106. auto new_param = deconv_opr.param();
  1107. new_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1108. mgb_assert(deconv_src->shape().ndim == 5 &&
  1109. deconv_src->format().type() ==
  1110. TensorFormat::Type::IMAGE2D_PACK4);
  1111. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  1112. deconv_weights, deconv_src, new_param,
  1113. deconv_opr.execution_policy(), deconv_opr.config());
  1114. OperatorNodeBase* ret = new_deconv_opr.node()->owner_opr();
  1115. mgb_assert(new_deconv_opr.shape().ndim == 5 &&
  1116. new_deconv_opr.format().type() ==
  1117. TensorFormat::Type::IMAGE2D_PACK4);
  1118. return ret;
  1119. };
  1120. /* This helper function guarantees that the format convert pass won't change
  1121. * the output var's channel count. Changing the output channel count would
  1122. * cause a channel-mismatch problem when replacing conv/conv_bias operators.
  1123. */
  1124. auto replace_helper = [](OperatorNodeBase* opr,
  1125. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1126. auto&& new_shp = new_inp[0]->shape();
  1127. size_t inp_channel = new_shp[1];
  1128. if (new_shp.eq_shape(opr->input(0)->shape()) && inp_channel % 4 != 0) {
  1129. auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
  1130. opr->config());
  1131. return new_opr;
  1132. }
  1133. return nullptr;
  1134. };
  1135. auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr,
  1136. const VarNodeArray& new_inp) {
  1137. mgb_assert(opr->input().size() == new_inp.size());
  1138. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1139. return opr_shallow_copy;
  1140. }
  1141. auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
  1142. mgb_assert(resize_opr.param().format ==
  1143. megdnn::param::Resize::Format::NCHW,
  1144. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1145. VarNode* inp = nullptr;
  1146. if (new_inp[0]->shape().ndim == 4) {
  1147. auto param = megdnn::param::RelayoutFormat();
  1148. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1149. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1150. inp = rf.node();
  1151. } else {
  1152. // new input src is NHWCD4
  1153. auto&& fmt = new_inp[0]
  1154. ->format()
  1155. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1156. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1157. inp = new_inp[0];
  1158. }
  1159. auto new_param = resize_opr.param();
  1160. new_param.format = megdnn::param::Resize::Format::NHWCD4;
  1161. auto new_resize_opr = opr::ResizeForward::make(
  1162. inp, new_inp[1], new_param, opr->config());
  1163. return new_resize_opr.node()->owner_opr();
  1164. };
  1165. auto replace_warp_perspective_opr = [replace_helper](
  1166. OperatorNodeBase* opr,
  1167. const VarNodeArray& new_inp) {
  1168. mgb_assert(opr->input().size() == new_inp.size());
  1169. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1170. return opr_shallow_copy;
  1171. }
  1172. auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
  1173. mgb_assert(warp_opr.param().format ==
  1174. megdnn::param::WarpPerspective::Format::NCHW,
  1175. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1176. VarNode* inp = nullptr;
  1177. if (new_inp[0]->shape().ndim == 4) {
  1178. // new input src is NCHW
  1179. auto param = megdnn::param::RelayoutFormat();
  1180. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1181. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1182. inp = rf.node();
  1183. } else {
  1184. // new input src is NHWCD4
  1185. auto&& fmt = new_inp[0]
  1186. ->format()
  1187. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1188. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1189. inp = new_inp[0];
  1190. }
  1191. auto new_param = warp_opr.param();
  1192. new_param.format = megdnn::param::WarpPerspective::Format::NHWCD4;
  1193. SymbolVar new_warp_opr;
  1194. if (new_inp.size() == 3) {
  1195. new_warp_opr = opr::WarpPerspectiveForward::make(
  1196. inp, new_inp[1], nullptr, new_inp[2], new_param,
  1197. opr->config());
  1198. } else {
  1199. mgb_assert(new_inp.size() == 4);
  1200. new_warp_opr = opr::WarpPerspectiveForward::make(
  1201. inp, new_inp[1], new_inp[2], new_inp[3], new_param,
  1202. opr->config());
  1203. }
  1204. return new_warp_opr.node()->owner_opr();
  1205. };
  1206. auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr,
  1207. const VarNodeArray& new_inp) {
  1208. mgb_assert(opr->input().size() == new_inp.size());
  1209. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1210. return opr_shallow_copy;
  1211. }
  1212. auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
  1213. mgb_assert(warp_opr.param().format ==
  1214. megdnn::param::WarpAffine::Format::NCHW,
  1215. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1216. VarNode* inp = nullptr;
  1217. if (new_inp[0]->shape().ndim == 4) {
  1218. // new input src is NHWCD4
  1219. auto param = megdnn::param::RelayoutFormat();
  1220. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1221. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1222. inp = rf.node();
  1223. } else {
  1224. // new input src is NHWCD
  1225. auto&& fmt = new_inp[0]
  1226. ->format()
  1227. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1228. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1229. inp = new_inp[0];
  1230. }
  1231. auto new_param = warp_opr.param();
  1232. new_param.format = megdnn::param::WarpAffine::Format::NHWCD4;
  1233. SymbolVar new_warp_opr;
  1234. new_warp_opr = opr::WarpAffineForward::make(inp, new_inp[1], new_inp[2],
  1235. new_param, opr->config());
  1236. return new_warp_opr.node()->owner_opr();
  1237. };
    auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
            return opr_shallow_copy;
        }
        auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
        mgb_assert(pooling_opr.param().format ==
                           megdnn::param::Pooling::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to NHWCD4");
        VarNode* inp = nullptr;
        if (new_inp[0]->shape().ndim == 4) {
            // new input src is NCHW
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
            auto rf = opr::RelayoutFormat::make(new_inp[0], param);
            inp = rf.node();
        } else {
            // new input src is already NHWCD4
            auto&& fmt = new_inp[0]
                                 ->format()
                                 .as_impl<megdnn::Image2DPack4TensorFormat>();
            mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
            inp = new_inp[0];
        }
        auto new_param = pooling_opr.param();
        new_param.format = megdnn::param::Pooling::Format::NHWCD4;
        auto new_pooling_opr =
                opr::PoolingForward::make(inp, new_param, opr->config());
        return new_pooling_opr.node()->owner_opr();
    };
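    // var_to_chw converts a var that was rewritten into the NHWCD4
    // (IMAGE2D_PACK4) layout back to plain NCHW through a RelayoutFormat
    // (NHWCD4I_NCHW); vars whose shape is unchanged are passed through as-is.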
    auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
        if (!inp->shape().eq_shape(new_inp->shape())) {
            mgb_assert(inp->shape().ndim == 4 &&
                       inp->format().type() !=
                               TensorFormat::Type::IMAGE2D_PACK4);
            mgb_assert(new_inp->shape().ndim == 5 &&
                       new_inp->format().type() ==
                               TensorFormat::Type::IMAGE2D_PACK4);
            auto param = megdnn::param::RelayoutFormat();
            param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
            auto rf = opr::RelayoutFormat::make(new_inp, param);
            return rf.node();
        }
        return new_inp;
    };
    auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray t_inp = new_inp;
        for (size_t i = 0; i < opr->input().size(); i++) {
            t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, t_inp, opr->config());
        return new_opr;
    };
    auto replace_elemwise_opr = [](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool has_inp_changed = false;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (!new_inp[i]->format().is_default()) {
                has_inp_changed = true;
                break;
            }
        }
        if (has_inp_changed) {
            // assumption: all inputs are changed from nchw to nhwcd4
            auto t_inp = new_inp;
            for (size_t i = 0; i < opr->input().size(); i++) {
                if (new_inp[i]->shape().ndim == 4) {
                    auto param = megdnn::param::RelayoutFormat();
                    param.mode =
                            megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
                    auto rf = opr::RelayoutFormat::make(new_inp[i], param);
                    t_inp[i] = rf.node();
                } else {
                    mgb_assert((new_inp[i]->shape().ndim == 5 &&
                                new_inp[i]->format().type() ==
                                        TensorFormat::Type::IMAGE2D_PACK4) ||
                               new_inp[i]->shape().is_scalar());
                }
            }
            return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
        } else {
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };
    /* This helper converts only the first input back to the NCHW format, to
     * handle operators that do not support the NHWCD4 format on that input.
     */
    auto relayout_first_inp_to_chw =
            [var_to_chw](OperatorNodeBase* opr,
                         const VarNodeArray& new_inp) -> OperatorNodeBase* {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray t_inp = new_inp;
        t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
        return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
    };
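    // Register the replace functions: conv-like, pooling, resize and warp
    // operators are rewritten to run in NHWCD4; Elemwise follows its inputs;
    // shape-sensitive operators (Concat, Reshape, Subtensor, ...) fall back to
    // NCHW through relayout_inp_to_chw; Local/GroupLocal only relayout their
    // first input.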
    auto ret = std::make_unique<ConvertFormatPass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
    replace_func[opr::GroupLocalForward::typeinfo()] =
            relayout_first_inp_to_chw;
    return ret;
}
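/*
 * Hedged usage sketch (not part of this file): a pass assembled like the one
 * above is normally run through gopt::GraphOptimizer. The add_pass / apply /
 * endpoint_vars calls and the make_convert_format_pass() factory name below
 * are assumptions for illustration only, not a verbatim API reference:
 *
 *     // `dest` is a SymbolVar endpoint of the graph to be optimized.
 *     SymbolVarArray optimized = gopt::GraphOptimizer{}
 *             .add_pass(make_convert_format_pass())  // hypothetical factory
 *             .apply({{dest}})
 *             .endpoint_vars();
 */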
/* ================ ConvertBatchNormToElemwisePass ================ */
const char* ConvertBatchNormToElemwisePass::name() const {
    return "convert_batch_norm";
}
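// In inference mode a 5-input BatchNorm has constant statistics, so its
// normalization output (output(4)) can be rewritten as pure elemwise math:
//     y = scale * (x - mean) * variance^(-1/2) + bias
// For example, with x = 5, mean = 1, variance = 4, scale = 2, bias = 3 this
// gives 2 * (5 - 1) / 2 + 3 = 7.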
void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
    auto rewriter = state.graph().make_rewriter();
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
            if (bn->input().size() == 5) {
                mgb_assert(bn->param().fwd_mode ==
                           opr::BatchNorm::Param::FwdMode::INFERENCE);
                SymbolVar x = {rewriter.get_var(bn->input(0))};
                SymbolVar scale = {rewriter.get_var(bn->input(1))};
                SymbolVar bias = {rewriter.get_var(bn->input(2))};
                SymbolVar mean = {rewriter.get_var(bn->input(3))};
                SymbolVar variance = {rewriter.get_var(bn->input(4))};
                SymbolVar invsqrt_variance = opr::PowC::make(variance, {-0.5});
                auto res = scale * (x - mean) * invsqrt_variance + bias;
                rewriter.replace_var(
                        opr->output(4), res.node(),
                        mgb_cstr_log(
                                "replace batch_norm(x, scale, bias, mean, "
                                "variance) "
                                "-> (scale * (x - mean) / sqrt(variance) + bias)"));
                return;
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ FuseConvBiasNonlinPass ================ */
const char* FuseConvBiasNonlinPass::name() const {
    return "combine_conv_bias_and_relu";
}
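// This pass fuses three patterns into ConvBias:
//   * an Elemwise ADD / FUSE_ADD_RELU / FUSE_ADD_TANH / FUSE_ADD_SIGMOID of a
//     Convolution result with a bias,
//   * a unary RELU / TANH / SIGMOID applied directly to a Convolution or
//     ConvBias,
//   * a TypeCvt to QuantizedS8 applied to a ConvBias (folded via output_dtype).
// Bias/nonlinearity fusion additionally requires the convolution output to
// have a single reader, tracked through m_deps.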
void FuseConvBiasNonlinPass::apply(OptState& state) const {
    std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
    state.graph().iter([&m_deps](OperatorNodeBase* opr) {
        for (auto& inp : opr->input()) {
            m_deps[inp].push_back(opr);
        }
    });
    auto rewriter = state.graph().make_rewriter();
    using Mode = opr::Elemwise::Param::Mode;
    using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
    auto get_nonlinearity_mode = [&](opr::Elemwise* elem) -> NonlineMode {
        if (elem->param().mode == Mode::FUSE_ADD_RELU ||
            elem->param().mode == Mode::RELU) {
            return NonlineMode::RELU;
        } else if (elem->param().mode == Mode::FUSE_ADD_SIGMOID ||
                   elem->param().mode == Mode::SIGMOID) {
            return NonlineMode::SIGMOID;
        } else {
            return NonlineMode::IDENTITY;
        }
    };
    auto try_fuse_bias_nonlinearity = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 2);
        can_be_fused &= (elem->param().mode == Mode::FUSE_ADD_RELU) ||
                        (elem->param().mode == Mode::FUSE_ADD_TANH) ||
                        (elem->param().mode == Mode::FUSE_ADD_SIGMOID);
        return can_be_fused;
    };
    auto try_fuse_bias = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 2);
        can_be_fused &= (elem->param().mode == Mode::ADD);
        return can_be_fused;
    };
    auto try_fuse_nonlinearity = [&](opr::Elemwise* elem) -> bool {
        bool can_be_fused = true;
        can_be_fused &= (elem->input().size() == 1);
        can_be_fused &= (elem->param().mode == Mode::RELU) ||
                        (elem->param().mode == Mode::TANH) ||
                        (elem->param().mode == Mode::SIGMOID);
        return can_be_fused;
    };
    auto convert_to_conv_bias_param = [&](const opr::Convolution::Param& param)
            -> opr::ConvBiasForward::Param {
        using Param = opr::ConvBiasForward::Param;
        return opr::ConvBiasForward::Param{Param::NonlineMode::IDENTITY,
                                           param.mode,
                                           param.sparse,
                                           param.format,
                                           param.pad_h,
                                           param.pad_w,
                                           param.stride_h,
                                           param.stride_w,
                                           param.dilate_h,
                                           param.dilate_w};
    };
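    // check_bias_shape accepts either a bias with exactly the output shape of
    // the convolution, or a broadcastable per-channel bias whose layout depends
    // on the format (per the branches below): NCHW (1, OC, 1, 1),
    // NCHW4 (1, OC/4, 1, 1, 4), NHWC (1, 1, 1, OC), NHWCD4 (1, 1, OC, 1, 4).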
    auto check_bias_shape = [&](opr::Convolution* conv, VarNode* bias) -> bool {
        bool valid_bias_shape = true;
        using Format = opr::Convolution::Param::Format;
        using Sparse = opr::Convolution::Param::Sparse;
        auto dst_shape = conv->output(0)->shape();
        auto filter_shape = conv->input(1)->shape();
        auto bias_shape = bias->shape();
        if (dst_shape.eq_shape(bias_shape)) {
            return valid_bias_shape;
        }
        size_t OC = filter_shape[0];
        if (conv->param().sparse == Sparse::GROUP) {
            OC *= filter_shape[1];
        }
        if (conv->param().format == Format::NCHW) {
            valid_bias_shape &=
                    ((bias_shape.ndim == 4) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == OC) && (bias_shape[2] == 1) &&
                     (bias_shape[3] == 1));
        } else if (conv->param().format == Format::NCHW4) {
            valid_bias_shape &=
                    ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == OC / 4) && (bias_shape[2] == 1) &&
                     (bias_shape[3] == 1) && (bias_shape[4] == 4));
        } else if (conv->param().format == Format::NHWC) {
            valid_bias_shape &= ((bias_shape.ndim == 4) &&
                                 (bias_shape[0] == 1) && (bias_shape[1] == 1) &&
                                 (bias_shape[2] == 1) && (bias_shape[3] == OC));
        } else {
            valid_bias_shape &=
                    ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
                     (bias_shape[1] == 1) && (bias_shape[2] == OC) &&
                     (bias_shape[3] == 1) && (bias_shape[4] == 4));
            mgb_assert(conv->param().format == Format::NHWCD4);
        }
        return valid_bias_shape;
    };
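    // try_fuse_typecvt folds a TypeCvt whose output dtype is QuantizedS8 into
    // the preceding ConvBias by rebuilding the ConvBias with its
    // config.output_dtype() set to the TypeCvt output dtype.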
    auto try_fuse_typecvt = [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
        mgb_assert(typecvt->input().size() == 1);
        auto conv_bias = try_cast_as_op<opr::ConvBias>(
                rewriter.get_var(typecvt->input(0))->owner_opr());
        if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
            typecvt->output(0)->dtype().enumv() !=
                    DTypeTrait<dtype::QuantizedS8>::enumv)
            return nullptr;
        auto config = conv_bias->config();
        config.output_dtype(typecvt->output(0)->dtype());
        if (conv_bias->input().size() == 3) {
            // conv + bias
            return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
                                       conv_bias->input(2), conv_bias->param(),
                                       conv_bias->execution_policy(), config)
                    .node()
                    ->owner_opr();
        } else {
            // conv without bias
            return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
                                       conv_bias->param(),
                                       conv_bias->execution_policy(), config)
                    .node()
                    ->owner_opr();
        }
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        auto check_conv = [](opr::Convolution* conv) -> bool {
            return conv->param().format ==
                           megdnn::param::Convolution::Format::NHWCD4 ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NHWC ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NCHW ||
                   conv->param().format ==
                           megdnn::param::Convolution::Format::NCHW4;
        };
        if (auto elem = try_cast_as_op<opr::Elemwise>(opr)) {
            if (try_fuse_bias_nonlinearity(elem) || try_fuse_bias(elem)) {
                auto inp1 = rewriter.get_var(elem->input(0));
                auto inp2 = rewriter.get_var(elem->input(1));
                opr::Convolution* conv = nullptr;
                size_t bias_idx = 0;
                if (inp1->owner_opr()->same_type<opr::Convolution>() &&
                    m_deps[elem->input(0)].size() == 1) {
                    conv = try_cast_as_op<opr::Convolution>(inp1->owner_opr());
                    bias_idx = 1;
                } else if (inp2->owner_opr()->same_type<opr::Convolution>() &&
                           m_deps[elem->input(1)].size() == 1) {
                    conv = try_cast_as_op<opr::Convolution>(inp2->owner_opr());
                    bias_idx = 0;
                }
                auto bias_inp = rewriter.get_var(elem->input(bias_idx));
                if (conv && check_conv(conv) &&
                    check_bias_shape(conv, bias_inp)) {
                    opr::ConvBiasForward::Param param =
                            convert_to_conv_bias_param(conv->param());
                    param.nonlineMode = get_nonlinearity_mode(elem);
                    auto new_var =
                            opr::ConvBiasForward::make(
                                    conv->input(0), conv->input(1), bias_inp,
                                    param, conv->execution_policy(),
                                    conv->config())
                                    .node();
                    rewriter.replace_var(
                            opr->output(0), new_var,
                            mgb_cstr_log("replace nonlinearity(conv(x, w) + b) "
                                         "-> conv_bias(x, w, b)"));
                    return;
                }
            } else if (try_fuse_nonlinearity(elem)) {
                auto inp = rewriter.get_var(elem->input(0));
                {
                    auto conv =
                            try_cast_as_op<opr::Convolution>(inp->owner_opr());
                    if (conv && check_conv(conv) &&
                        m_deps[elem->input(0)].size() == 1) {
                        opr::ConvBiasForward::Param param =
                                convert_to_conv_bias_param(conv->param());
                        param.nonlineMode = get_nonlinearity_mode(elem);
                        auto new_var = opr::ConvBiasForward::make(
                                               conv->input(0), conv->input(1),
                                               param, conv->execution_policy(),
                                               conv->config())
                                               .node();
                        rewriter.replace_var(
                                opr->output(0), new_var,
                                mgb_cstr_log("replace nonlinearity(conv(x, w)) "
                                             "-> conv_bias(x, w)"));
                        return;
                    }
                }
                {
                    auto conv = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
                    auto check_conv_bias = [&](opr::ConvBias* opr) {
                        return opr->param().format ==
                                       opr::ConvBias::Param::Format::NHWC ||
                               opr->param().format ==
                                       opr::ConvBias::Param::Format::NCHW ||
                               opr->param().format ==
                                       opr::ConvBias::Param::Format::NCHW4;
                    };
                    if (conv && check_conv_bias(conv) &&
                        m_deps[elem->input(0)].size() == 1) {
                        auto param = conv->param();
                        param.nonlineMode = get_nonlinearity_mode(elem);
                        auto new_var = opr::ConvBiasForward::make(
                                               conv->input(0), conv->input(1),
                                               conv->input(2), param,
                                               conv->execution_policy(),
                                               conv->config())
                                               .node();
                        rewriter.replace_var(
                                opr->output(0), new_var,
                                mgb_cstr_log("replace nonlinearity(conv(x, w)) "
                                             "-> conv_bias(x, w)"));
                        return;
                    }
                }
            }
        } else if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
            auto new_opr = try_fuse_typecvt(typecvt);
            if (new_opr) {
                rewriter.replace_var(
                        opr->output(0), new_opr->output(0),
                        mgb_cstr_log("replace typecvt(conv_bias(x, w, b)) -> "
                                     "conv_bias(x, w, b)"));
                return;
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ FuseConvBiasZPass ================ */
const char* FuseConvBiasZPass::name() const {
    return "combine_conv_bias_and_z";
}
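// This pass fuses "conv_bias(x, w, b) + z" (elemwise ADD / FUSE_ADD_RELU and
// their quantized counterparts QADD / QFUSE_ADD_RELU) into a single 4-input
// ConvBias, keeping the RELU variants as the fused nonlinearity.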
void FuseConvBiasZPass::apply(OptState& state) const {
    UniqReaderCheck uniq_reader_check{state.graph()};
    auto rewriter = state.graph().make_rewriter();
    using Mode = opr::Elemwise::Param::Mode;
    using MultiMode = opr::ElemwiseMultiType::Param::Mode;
    using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
    auto check_conv_bias = [](opr::ConvBias* conv_bias) -> bool {
        return conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NHWC ||
               conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NCHW ||
               conv_bias->param().format ==
                       megdnn::param::ConvBias::Format::NCHW4;
    };
    auto check_fuse_shape = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
        bool valid_fuse_shape = true;
        auto z_shape = z->shape();
        auto bias_shape = conv_bias->input(2)->shape();
        auto conv_bias_shape = conv_bias->output(0)->shape();
        valid_fuse_shape &= (!conv_bias_shape.eq_shape(bias_shape));
        valid_fuse_shape &= conv_bias_shape.eq_shape(z_shape);
        return valid_fuse_shape;
    };
    auto check_fuse_dtype = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
        return conv_bias->output(0)->dtype().enumv() == z->dtype().enumv();
    };
    auto get_convbias_nonline_mode = [&](OperatorNodeBase* opr) -> NonlineMode {
        if (opr->same_type<opr::Elemwise>()) {
            auto elem = try_cast_as_op<opr::Elemwise>(opr);
            if (elem->param().mode == Mode::FUSE_ADD_RELU)
                return NonlineMode::RELU;
        }
        if (opr->same_type<opr::ElemwiseMultiType>()) {
            auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
            if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
                return NonlineMode::RELU;
        }
        return NonlineMode::IDENTITY;
    };
    auto try_replace_var_node = [&](OperatorNodeBase* opr) {
        opr::ConvBias* conv_bias = nullptr;
        size_t z_idx = 0;
        size_t nr_inps = opr->input().size();
        for (size_t i = 0; i < nr_inps; i++) {
            auto inp = rewriter.get_var(opr->input(i));
            if (inp->owner_opr()->same_type<opr::ConvBias>()) {
                auto cb = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
                if (cb->input().size() == 3 &&
                    cb->param().nonlineMode ==
                            opr::ConvBias::Param::NonlineMode::IDENTITY &&
                    uniq_reader_check(opr->input(i))) {
                    conv_bias = cb;
                    z_idx = nr_inps - i - 1;
                    break;
                }
            }
        }
        auto z_inp = rewriter.get_var(opr->input(z_idx));
        if (conv_bias && check_conv_bias(conv_bias) &&
            check_fuse_shape(conv_bias, z_inp) &&
            check_fuse_dtype(conv_bias, z_inp)) {
            auto param = conv_bias->param();
            param.nonlineMode = get_convbias_nonline_mode(opr);
            auto config = conv_bias->config();
            auto new_var = opr::ConvBiasForward::make(
                                   conv_bias->input(0), conv_bias->input(1),
                                   conv_bias->input(2), z_inp, param,
                                   conv_bias->execution_policy(),
                                   config.output_dtype(opr->output(0)->dtype()))
                                   .node();
            rewriter.replace_var(
                    opr->output(0), new_var,
                    mgb_cstr_log("replace "
                                 "nonlinearity(conv_bias(x,w,b) + z) "
                                 "-> conv_bias(x, w, b, z)"));
            uniq_reader_check.update_on_opr_auto_replace(opr,
                                                         new_var->owner_opr());
            return true;
        }
        return false;
    };
    auto try_fuse_elemwise = [&](OperatorNodeBase* opr) {
        if (!opr->same_type<opr::Elemwise>())
            return false;
        auto elem = try_cast_as_op<opr::Elemwise>(opr);
        if (elem->input().size() != 2)
            return false;
        if (elem->param().mode != Mode::ADD &&
            elem->param().mode != Mode::FUSE_ADD_RELU)
            return false;
        return try_replace_var_node(opr);
    };
    auto try_fuse_elemwise_multi_type = [&](OperatorNodeBase* opr) {
        if (!opr->same_type<opr::ElemwiseMultiType>())
            return false;
        auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
        if (elem->input().size() != 2)
            return false;
        if (elem->param().mode != MultiMode::QADD &&
            elem->param().mode != MultiMode::QFUSE_ADD_RELU)
            return false;
        return try_replace_var_node(opr);
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (try_fuse_elemwise(opr))
            return;
        if (try_fuse_elemwise_multi_type(opr))
            return;
        auto new_opr = rewriter.auto_replace_outputs(opr);
        uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ FuseDeconvCvtPass ================ */
const char* FuseDeconvCvtPass::name() const {
    return "combine_deconv_and_typecvt";
}
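// This pass folds a TypeCvt to QuantizedS8 that follows a
// ConvolutionBackwardData (deconvolution) into the deconvolution itself, by
// rebuilding it with config.output_dtype() set; the deconv output must also
// pass the UniqReaderCheck so no other consumer observes the original dtype.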
void FuseDeconvCvtPass::apply(OptState& state) const {
    std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
    state.graph().iter([&m_deps](OperatorNodeBase* opr) {
        for (auto& inp : opr->input()) {
            m_deps[inp].push_back(opr);
        }
    });
    UniqReaderCheck uniq_reader_check{state.graph()};
    auto rewriter = state.graph().make_rewriter();
    auto try_fuse_deconv_typecvt =
            [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
        mgb_assert(typecvt->input().size() == 1);
        auto deconv = try_cast_as_op<opr::ConvolutionBackwardData>(
                rewriter.get_var(typecvt->input(0))->owner_opr());
        if (!deconv || m_deps.count(typecvt->input(0)) != 1 ||
            typecvt->output(0)->dtype().enumv() !=
                    DTypeTrait<dtype::QuantizedS8>::enumv) {
            return nullptr;
        }
        if (!uniq_reader_check(deconv->output(0)))
            return nullptr;
        auto config = deconv->config();
        config.output_dtype(typecvt->output(0)->dtype());
        return opr::ConvolutionBackwardData::make(
                       deconv->input(0), deconv->input(1), deconv->param(),
                       deconv->execution_policy(), config)
                .node()
                ->owner_opr();
    };
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
            if (auto deconv_new = try_fuse_deconv_typecvt(typecvt)) {
                rewriter.replace_var(
                        opr->output(0), deconv_new->output(0),
                        mgb_cstr_log("replace typecvt(deconv(x, w)) -> "
                                     "deconv(x, w)"));
                uniq_reader_check.update_on_opr_auto_replace(opr, deconv_new);
                return;
            }
        }
        auto new_opr = rewriter.auto_replace_outputs(opr);
        uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
/* ================ ParamMergePass ================ */
const char* ParamMergePass::name() const {
    return mgb_cstr_log("param_merge");
}
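// ParamMergePass merges many small parameter operators into one holder per
// kind: plain SharedDeviceTensor params become a MultipleDeviceTensorHolder,
// and SharedDeviceTensorWithFormat params become a
// MultipleDeviceTensorWithFormatHolder, via the param_merge<> helper.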
void ParamMergePass::apply(OptState& opt_state) const {
    param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
            opt_state);
    param_merge<opr::SharedDeviceTensorWithFormat,
                opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
