  1. /**
  2. * \file src/gopt/impl/inference.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/gopt/inference.h"
  12. #include "megbrain/gopt/gtrans.h"
  13. #include "megbrain/gopt/basic_arith.h"
  14. #include "megbrain/graph/event.h"
  15. #include "megbrain/opr/dnn/batch_norm.h"
  16. #include "megbrain/opr/dnn/local.h"
  17. #include "megbrain/opr/search_policy/algo_chooser_helper.h"
  18. #include "megbrain/opr/search_policy/profiler.h"
  19. #include "megbrain/utils/shared_set.h"
  20. #include "megbrain/serialization/opr_shallow_copy.h"
  21. #include "megbrain/opr/basic_arith.h"
  22. #include "megbrain/opr/dnn/convolution.h"
  23. #include "megbrain/opr/blas.h"
  24. #include "megbrain/opr/misc.h"
  25. #include "megbrain/opr/utility.h"
  26. #include "megbrain/opr/dnn/pooling.h"
  27. #include "megbrain/opr/tensor_manip.h"
  28. #include "megbrain/opr/imgproc.h"
  29. #include "megbrain/opr/nn_int.h"
  30. #include "megbrain/opr/tensor_gen.h"
  31. #include "megbrain/utils/hash_ct.h"
  32. #include "megdnn/tensor_format.h"
  33. #if MGB_ENABLE_TENSOR_RT
  34. #include "megbrain/tensorrt/tensorrt_opr.h"
  35. #endif
  36. #if MGB_CUDA
  37. #include <cudnn.h>
  38. #endif
  39. #include "megbrain/gopt/misc.h"
  40. #include "megbrain/utils/hash_ct.h"
  41. #include "midout.h"
  42. MIDOUT_DECL(megbrain_inference)
  43. #define MIDOUT_B(tag) \
  44. MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) {
  45. #define MIDOUT_E \
  46. } \
  47. MIDOUT_END();
  48. using namespace mgb;
  49. using namespace gopt;
  50. namespace {
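// param_merge: collect every SharedDeviceTensor(Format) opr in the graph and
// replace the whole set with a single MultipleDeviceTensorHolder(Format)
// holding their values, so the parameters are owned by one operator instead
// of many.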
  51. template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
  52. void param_merge(OptState& opt_state) {
  53. auto rewriter = opt_state.graph().make_rewriter();
  54. ThinHashMap<OperatorNodeBase*, size_t> opr2idx;
  55. std::vector<OperatorNodeBase*> all_oprs;
  56. typename MultipleDeviceTensorHolder::ValueArray all_values;
  57. auto cb_find_opr = [&](cg::OperatorNodeBase* opr) {
  58. if (opr->same_type<SharedDeviceTensor>()) {
  59. auto p = &opr->cast_final<SharedDeviceTensor>();
  60. // record the opr and its value; they are merged into one holder below
  61. opr2idx[p] = all_values.size();
  62. all_values.push_back(p->dev_data());
  63. all_oprs.push_back(p);
  64. }
  65. };
  66. opt_state.graph().iter(cb_find_opr);
  67. SymbolVarArray new_vars;
  68. auto cb_replace = [&](cg::OperatorNodeBase* opr) {
  69. auto iter = opr2idx.find(opr);
  70. if (iter == opr2idx.end()) {
  71. rewriter.auto_replace_outputs(opr);
  72. } else {
  73. if (new_vars.empty()) {
  74. // new oprs must be created in iter callback; so we populate
  75. // new_vars lazily
  76. new_vars = MultipleDeviceTensorHolder::make(
  77. *opt_state.graph().comp_graph(), std::move(all_values),
  78. {ssprintf("merged%zu", all_values.size())});
  79. for (size_t i = 0; i < new_vars.size(); ++i) {
  80. auto src = all_oprs[i]->output(0);
  81. if (src->has_name_set()) {
  82. new_vars[i].rename(src->name());
  83. }
  84. }
  85. }
  86. rewriter.replace_var(
  87. opr->output(0), new_vars.at(iter->second).node(),
  88. mgb_cstr_log("replace multi SharedDeviceTensor(Format) to "
  89. "MultipleDeviceTensorHolder(Format)"));
  90. }
  91. };
  92. opt_state.graph().iter(cb_replace);
  93. rewriter.apply_inplace();
  94. }
  95. } // namespace
  96. /* ================ global functions ================ */
  97. SymbolVarArray gopt::optimize_for_inference(
  98. const SymbolVarArray& dest_vars,
  99. const OptimizeForInferenceOptions& opt) {
  100. return gopt::GraphOptimizer()
  101. .add_preset_passes(false, &opt,
  102. &dest_vars[0].node()->owner_graph()->options())
  103. .apply({dest_vars})
  104. .endpoint_vars();
  105. }
  106. namespace {
  107. void modify_conv_strategy(
  108. opr::mixin::AlgoChooserHelper& conv,
  109. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  110. auto policy = conv.execution_policy_transient();
  111. policy.strategy = strategy;
  112. conv.set_execution_policy(policy);
  113. }
  114. template <typename Opr>
  115. void inplace_conv_opr_modifier(
  116. OperatorNodeBase& opr,
  117. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  118. modify_conv_strategy(
  119. opr.cast_final_safe<Opr>(),
  120. strategy);
  121. }
  122. void modify_conv_policy_workspace_limit(opr::mixin::AlgoChooserHelper& conv,
  123. size_t workspace_limit) {
  124. auto policy = conv.execution_policy_transient();
  125. policy.workspace_limit = workspace_limit;
  126. conv.set_execution_policy(policy);
  127. }
  128. template <typename Opr>
  129. void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
  130. size_t workspace_limit) {
  131. modify_conv_policy_workspace_limit(opr.cast_final_safe<Opr>(),
  132. workspace_limit);
  133. }
  134. } // anonymous namespace
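// Set the algorithm-selection strategy on every fastrun-capable operator
// (see MGB_FOREACH_FASTRUN_OPR) reachable from dest_vars.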
  135. void gopt::modify_opr_algo_strategy_inplace(
  136. const VarNodeArrayView& dest_vars,
  137. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  138. #if !MGB_ENABLE_FASTRUN
  139. using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
  140. if ((strategy & S::PROFILE) && !(strategy & S::HEURISTIC)) {
  141. mgb_throw(MegBrainError, "fastrun is disabled at compile time");
  142. }
  143. #endif
  144. const ThinHashMap<Typeinfo*, std::function<void(OperatorNodeBase&)>>
  145. modifiers = {
  146. #define CONV(t) \
  147. {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \
  148. std::placeholders::_1, strategy)},
  149. MGB_FOREACH_FASTRUN_OPR(CONV)
  150. #undef CONV
  151. };
  152. auto on_opr = [&](OperatorNodeBase* opr) {
  153. auto iter = modifiers.find(opr->dyn_typeinfo());
  154. if (iter != modifiers.end()) {
  155. iter->second(*opr);
  156. }
  157. };
  158. cg::DepOprIter dep_iter{on_opr};
  159. for (auto i : dest_vars) {
  160. dep_iter.add(i);
  161. }
  162. }
  163. void gopt::enable_opr_algo_profiling_inplace(
  164. const VarNodeArrayView& dest_vars) {
  165. modify_opr_algo_strategy_inplace(
  166. dest_vars,
  167. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE);
  168. }
  169. void gopt::enable_opr_use_profiling_cache_inplace(
  170. const VarNodeArrayView& dest_vars) {
  171. using S = megdnn::param::ExecutionPolicy::Strategy;
  172. modify_opr_algo_strategy_inplace(dest_vars, S::PROFILE | S::HEURISTIC);
  173. }
  174. void gopt::set_opr_algo_workspace_limit_inplace(
  175. const VarNodeArrayView& dest_vars, size_t workspace_limit) {
  176. static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
  177. modifiers = {
  178. #define CONV(t) \
  179. {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>},
  180. MGB_FOREACH_FASTRUN_OPR(CONV)
  181. #undef CONV
  182. };
  183. auto on_opr = [&](OperatorNodeBase* opr) {
  184. auto iter = modifiers.find(opr->dyn_typeinfo());
  185. if (iter != modifiers.end()) {
  186. iter->second(*opr, workspace_limit);
  187. }
  188. };
  189. cg::DepOprIter dep_iter{on_opr};
  190. for (auto i : dest_vars) {
  191. dep_iter.add(i);
  192. }
  193. }
  194. /* ================ ParamRedistributePass ================ */
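// ParamRedistributePass reorders arithmetic so that constant parameters end
// up next to each other: it re-associates f(g(a, b), c) when b and c are
// const, and distributes MUL/TRUE_DIV over ADD chains, producing constant
// subexpressions that ParamFusePass can later fold.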
  195. const char* ParamRedistributePass::name() const {
  196. return mgb_cstr_log("param_redistribute");
  197. }
  198. class ParamRedistributePass::Impl final: public RecursiveSubGraphRewriteHelper {
  199. ConstVarPropogate m_cvprop;
  200. UniqReaderCheck m_uniq_reader_check;
  201. //! oprs already processed in try_distribute_then_reassociate() should be
  202. //! skipped in on_new_opr_check_should_process()
  203. ThinHashSet<OperatorNodeBase*> m_opr_blacklist;
  204. std::string m_distribute_reasso_log_msg;
  205. //! try applying BinaryTrans20::associtive
  206. GTransResult try_reassociate(OperatorNodeBase *opr);
  207. //! try applying BinaryTrans20::distributive_add
  208. GTransResult try_distribute_add(OperatorNodeBase *opr);
  209. //! try distributing MUL/DIV over ADD/SUB and then reassociating each term
  210. GTransResult try_distribute_then_reassociate(OperatorNodeBase *opr);
  211. GTransResult process_opr(VarNode *out_var) override;
  212. bool on_new_opr_check_should_process(
  213. OperatorNodeBase*opr, OperatorNodeBase *repl_opr) override {
  214. m_uniq_reader_check.update_on_opr_auto_replace(opr, repl_opr);
  215. auto ins = m_cvprop.add_opr(opr);
  216. return ins.has_const_inp && !ins.all_const_inp &&
  217. !m_opr_blacklist.count(opr);
  218. };
  219. void after_replace_var(VarNode *orig_var, VarNode* new_var) override {
  220. m_uniq_reader_check.update_on_opr_auto_replace(orig_var->owner_opr(),
  221. new_var->owner_opr());
  222. }
  223. /*!
  224. * \brief try to reorder opr inputs to a const one and a non-const one
  225. *
  226. * return true if it can be reformulated as f(nci, ci), where nci is
  227. * non-const and ci is const.
  228. */
  229. bool reorder_for_normconst(OperatorNodeBase *opr,
  230. bool &swap_inp, VarNode *&nci, VarNode *&ci);
  231. public:
  232. Impl(OptState &state);
  233. };
  234. GTransResult ParamRedistributePass::Impl::process_opr(VarNode *out_var) {
  235. auto opr = out_var->owner_opr();
  236. auto trans = try_reassociate(opr);
  237. if (!trans.valid()) {
  238. trans = try_distribute_add(opr);
  239. if (!trans.valid())
  240. trans = try_distribute_then_reassociate(opr);
  241. }
  242. return trans;
  243. }
  244. GTransResult ParamRedistributePass::Impl::try_reassociate(
  245. OperatorNodeBase *opr) {
  246. // apply BinaryTrans20::associtive if opr is of the form f(g(a, b), c) and b and c are
  247. // const
  248. bool swap_fop_inp = false, swap_gop_inp = false;
  249. VarNode *a, *b, *c, *ab;
  250. if (!reorder_for_normconst(opr, swap_fop_inp, ab, c))
  251. return None;
  252. if (!m_uniq_reader_check(ab))
  253. return None;
  254. if (!reorder_for_normconst(ab->owner_opr(), swap_gop_inp, a, b))
  255. return None;
  256. return BinaryTrans20::associtive().apply(opr, swap_fop_inp, swap_gop_inp);
  257. }
  258. GTransResult ParamRedistributePass::Impl::try_distribute_add(
  259. OperatorNodeBase *opr) {
  260. if (opr->same_type<opr::Elemwise>() || opr->input().size() != 2)
  261. return None;
  262. if (!m_cvprop.is_const(opr->input(1)))
  263. return None;
  264. auto ab = as_elem_opr(opr->input(0)->owner_opr(), opr::Elemwise::Mode::ADD);
  265. if (ab) {
  266. bool swap;
  267. VarNode *a, *b;
  268. if (reorder_for_normconst(ab, swap, a, b)) {
  269. return BinaryTrans20::distributive_add().apply(
  270. opr, false, swap);
  271. }
  272. }
  273. return None;
  274. }
  275. GTransResult ParamRedistributePass::Impl::try_distribute_then_reassociate(
  276. OperatorNodeBase *opr) {
  277. if (!opr->same_type<opr::Elemwise>())
  278. return None;
  279. using Mode = opr::Elemwise::Mode;
  280. auto mode = opr->cast_final<opr::Elemwise>().param().mode;
  281. if (!(mode == Mode::MUL || mode == Mode::TRUE_DIV))
  282. return None;
  283. VarNode *a, *b;
  284. bool swap;
  285. if (!reorder_for_normconst(opr, swap, a, b))
  286. return None;
  287. auto chain_pred = [this](OperatorNodeBase *opr) {
  288. if (as_elem_opr(opr, Mode::ADD)) {
  289. auto var = opr->output(0);
  290. return m_uniq_reader_check(var) || m_cvprop.is_const(var);
  291. }
  292. return false;
  293. };
  294. auto chain = extract_opr_leaves(a, chain_pred);
  295. if (chain.size() <= 1)
  296. return None;
  297. std::unordered_map<VarNode*, VarNode*> repl_map;
  298. m_distribute_reasso_log_msg.clear();
  299. int nr_fail = 0, nr_succ = 0;
  300. for (auto &&var: chain) {
  301. {
  302. auto iter = repl_map.find(var);
  303. if (iter != repl_map.end()) {
  304. var = iter->second;
  305. continue;
  306. }
  307. }
  308. auto vnew = (SymbolVar{var} * b).node();
  309. m_opr_blacklist.insert(vnew->owner_opr());
  310. if (!m_cvprop.is_const(var)) {
  311. auto trans = try_reassociate(vnew->owner_opr());
  312. if (!trans.valid()) {
  313. // allow at most one failed redistribution
  314. if (nr_fail)
  315. return None;
  316. ++ nr_fail;
  317. } else {
  318. ++ nr_succ;
  319. vnew = trans->result;
  320. if (!m_distribute_reasso_log_msg.empty()) {
  321. m_distribute_reasso_log_msg.append(mgb_cstr_log(";"));
  322. }
  323. m_distribute_reasso_log_msg.append(trans->msg);
  324. }
  325. }
  326. repl_map[var] = vnew;
  327. var = vnew;
  328. }
  329. if (nr_succ) {
  330. m_distribute_reasso_log_msg.insert(0,
  331. mgb_cstr_log("distribute_mul("));
  332. m_distribute_reasso_log_msg.append(mgb_cstr_log(")"));
  333. return GTransResultItem{
  334. elemwise_reduce_var_list(chain, Mode::ADD),
  335. m_distribute_reasso_log_msg.c_str(),
  336. {}};
  337. }
  338. return None;
  339. }
  340. bool ParamRedistributePass::Impl::reorder_for_normconst(
  341. OperatorNodeBase *opr, bool &swap_inp, VarNode *&nci, VarNode *&ci) {
  342. if (opr->input().size() != 2)
  343. return false;
  344. nci = opr->input(0);
  345. ci = opr->input(1);
  346. if (!m_cvprop.is_const(ci)) {
  347. if (!is_commutable_binary(opr) || !m_cvprop.is_const(nci))
  348. return false;
  349. swap_inp = true;
  350. std::swap(nci, ci);
  351. } else {
  352. if (m_cvprop.is_const(nci))
  353. return false;
  354. swap_inp = false;
  355. }
  356. return true;
  357. }
  358. ParamRedistributePass::Impl::Impl(OptState &state):
  359. RecursiveSubGraphRewriteHelper{state},
  360. m_cvprop{ConstVarType::IMMUTABLE_AND_PARAM},
  361. m_uniq_reader_check{state.graph()}
  362. {
  363. auto cg = state.graph().comp_graph();
  364. auto on_new_opr = [this](const cg::event::OprInserted &ev) {
  365. if (!ev.is_dedup && !ev.exc) {
  366. // call add_opr eagerly to avoid deep recursion
  367. m_cvprop.add_opr(ev.opr);
  368. }
  369. };
  370. auto hdl = cg->event().register_receiver
  371. <cg::event::OprInserted>(on_new_opr);
  372. apply();
  373. }
  374. void ParamRedistributePass::apply(OptState &state) const {
  375. MIDOUT_B("ParamRedistributePass::apply")
  376. Impl{state};
  377. MIDOUT_E
  378. }
  379. /* ================ ParamFusePass ================ */
  380. /*!
  381. * \brief get name for new param
  382. */
  383. class ParamFusePass::VarNamer {
  384. #if MGB_BUILD_SLIM_SERVING
  385. public:
  386. const std::string& name(VarNode*) {
  387. static std::string ret("fuse");
  388. return ret;
  389. }
  390. #else
  391. using SrcSet = SharedSet<OperatorNodeBase*>;
  392. //! map from opr to the source SharedDeviceTensor/MultipleDeviceTensorHolder
  393. //! oprs that it depends on
  394. ThinHashMap<OperatorNodeBase*, SrcSet> m_opr2srcs;
  395. std::string m_name_cache;
  396. std::vector<const char*> m_cur_name;
  397. SrcSet& get_src_set(OperatorNodeBase* opr) {
  398. auto opr_typeinfo = opr->dyn_typeinfo();
  399. auto iter = m_opr2srcs.find(opr);
  400. if (iter != m_opr2srcs.end()) {
  401. return iter->second;
  402. }
  403. auto &&ret = m_opr2srcs[opr];
  404. if (opr->input().empty()) {
  405. if (opr_typeinfo == opr::SharedDeviceTensor::typeinfo() ||
  406. opr_typeinfo == opr::MultipleDeviceTensorHolder::typeinfo()) {
  407. ret.insert(opr);
  408. } else {
  409. mgb_assert(opr_typeinfo == opr::ImmutableTensor::typeinfo());
  410. }
  411. return ret;
  412. }
  413. for (auto i: opr->input()) {
  414. ret.merge_from(get_src_set(i->owner_opr()));
  415. }
  416. return ret;
  417. }
  418. public:
  419. const std::string& name(VarNode *var) {
  420. m_cur_name.clear();
  421. for (auto i : get_src_set(var->owner_opr())) {
  422. m_cur_name.push_back(i->cname());
  423. }
  424. auto cmp = [](const char *x, const char *y) {
  425. return strcmp(x, y) < 0;
  426. };
  427. std::sort(m_cur_name.begin(), m_cur_name.end(), cmp);
  428. m_name_cache.clear();
  429. m_name_cache.append(mgb_cstr_log("fuse("));
  430. bool first = true;
  431. for (auto i: m_cur_name) {
  432. if (first) {
  433. first = false;
  434. } else {
  435. m_name_cache.push_back(',');
  436. }
  437. m_name_cache.append(i);
  438. }
  439. m_name_cache.append(mgb_cstr_log(
  440. ssprintf("):%s@%zu", var->cname(), var->id())));
  441. return m_name_cache;
  442. }
  443. #endif
  444. };
  445. const char* ParamFusePass::name() const {
  446. return mgb_cstr_log("param_fuse");
  447. }
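// ParamFusePass: propagate constness through the graph, evaluate each var
// whose inputs are all constant by compiling and running a small sub-graph,
// and replace it with an ImmutableTensor / SharedDeviceTensor(WithFormat)
// holding the precomputed value.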
  448. void ParamFusePass::apply(OptState &state) const {
  449. MIDOUT_B("ParamFusePass::apply")
  450. auto rewriter = state.graph().make_rewriter();
  451. auto cg = state.graph().comp_graph();
  452. ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
  453. state.graph().iter([&cvprop](OperatorNodeBase *opr) {
  454. cvprop.add_opr(opr);
  455. });
  456. ThinHashSet<VarNode*> processed_var;
  457. VarNamer var_namer;
  458. // reader: null if used as endvar
  459. auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
  460. if (!processed_var.insert(var).second)
  461. return;
  462. auto inferred_val = std::make_shared<DeviceTensorND>(
  463. var->comp_node(), var->dtype());
  464. auto cb = [&](DeviceTensorND& val) {
  465. // retain format of val
  466. mgb_assert(val.format() == var->format());
  467. inferred_val->format(val.format())
  468. .resize(val.shape())
  469. .copy_from_fixlayout(val);
  470. };
  471. {
  472. auto orig_level = cg->options().log_level;
  473. cg->options().log_level = 0;
  474. MGB_TRY {
  475. cg->compile({{var, cb}})->execute();
  476. } MGB_FINALLY(cg->options().log_level = orig_level);
  477. }
  478. SymbolVar new_var;
  479. bool is_default_format = var->format().is_default();
  480. bool is_lowbit_aligned = var->format().is_lowbit_aligned();
  481. if (cg::is_static_var_value(var) &&
  482. (is_default_format || is_lowbit_aligned)) {
  483. // use ImmutableTensor for inferable vars
  484. HostTensorND hv;
  485. hv.copy_from(*inferred_val).sync();
  486. new_var = opr::ImmutableTensor::make(
  487. *var->owner_graph(), hv, var_namer.name(var));
  488. } else {
  489. if (is_default_format || is_lowbit_aligned) {
  490. new_var = opr::SharedDeviceTensor::make_const(
  491. *var->owner_graph(), inferred_val, var_namer.name(var));
  492. } else {
  493. new_var = opr::SharedDeviceTensorWithFormat::make_const(
  494. *var->owner_graph(), inferred_val, var_namer.name(var));
  495. }
  496. }
  497. std::string log;
  498. if (reader) {
  499. log = mgb_ssprintf_log(
  500. "due to read by %s{%s}",
  501. reader->cname(), reader->dyn_typeinfo()->name);
  502. } else {
  503. log = mgb_cstr_log("as endpoint");
  504. }
  505. rewriter.replace_var(var, new_var.node(), log.c_str());
  506. };
  507. auto replace_opr = [&](OperatorNodeBase* opr) {
  508. auto add_ret = cvprop.opr_rst(opr);
  509. if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
  510. for (auto i: opr->input()) {
  511. if (cvprop.is_midconst(i)) {
  512. state.call_with_opr(i->owner_opr(),
  513. [&]{replace_single_var(i, opr);});
  514. }
  515. }
  516. }
  517. rewriter.auto_replace_outputs(opr);
  518. //! we should deal with midconst var after auto_replace_outputs, as
  519. //! on_midconst_opr will replace the endpoint output which may cause
  520. //! double replace.
  521. if (add_ret.all_const_inp) {
  522. for (auto var : opr->output()) {
  523. if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
  524. continue;
  525. auto osize = ConstVarPropogate::var_mem_size(var);
  526. if (osize >= cvprop.max_size(opr) &&
  527. osize - cvprop.max_size(opr) > m_param_grow_limit) {
  528. return;
  529. }
  530. // const oprs should be evaluated when output is used by another
  531. // non-const opr or output is needed by the user
  532. if (state.graph().endpoint_contain(var)) {
  533. replace_single_var(var, nullptr);
  534. }
  535. }
  536. }
  537. };
  538. state.graph().iter(replace_opr);
  539. rewriter.apply_inplace();
  540. MIDOUT_E
  541. }
  542. /* ================ One2OneOprReplacePass ================ */
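// ConvertF32ToF16Pass rewrites supported oprs to take Float16 inputs
// (optionally keeping Float32 as the compute mode) and appends TypeCvt on the
// endpoints so the graph outputs keep their original dtypes.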
  543. const char* ConvertF32ToF16Pass::name() const {
  544. return mgb_cstr_log("convert_f32_to_f16");
  545. }
  546. void ConvertF32ToF16Pass::apply(OptState& state) const {
  547. MIDOUT_B("ConvertF32ToF16Pass::apply")
  548. state.set_var_replace_check_flag(m_var_replace_check_flag);
  549. auto rewriter = state.graph().make_rewriter();
  550. VarNodeArray new_inp_cache;
  551. // record original output dtype
  552. const SymbolVarArray& vars = state.graph().endpoint_vars();
  553. std::vector<DType> dtypes;
  554. for (size_t i = 0; i < vars.size(); i++) {
  555. dtypes.push_back(vars[i].node()->dtype());
  556. }
  557. auto on_opr = [this, &rewriter, &new_inp_cache](OperatorNodeBase* opr) {
  558. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  559. if (it != m_opr_replace_func.end()) {
  560. auto&& new_inp = new_inp_cache;
  561. new_inp.clear();
  562. new_inp.reserve(opr->input().size());
  563. for (auto i: opr->input()) {
  564. new_inp.push_back(rewriter.get_var(i));
  565. }
  566. auto new_opr = (it->second)(opr, new_inp);
  567. auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
  568. mgb_assert(origin_out.size() == cur_out.size(),
  569. "bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(),
  570. opr->dyn_typeinfo()->name, new_opr->cname(),
  571. new_opr->dyn_typeinfo()->name);
  572. for (size_t i = 0; i < origin_out.size(); i++) {
  573. rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
  574. }
  575. } else {
  576. rewriter.auto_replace_outputs(opr);
  577. }
  578. };
  579. state.graph().iter(on_opr);
  580. rewriter.apply_inplace();
  581. // recover output dtype
  582. rewriter = state.graph().make_rewriter();
  583. const SymbolVarArray& endpoints = state.graph().endpoint_vars();
  584. auto replace_output = [&]() {
  585. for (size_t i = 0; i < endpoints.size(); i++) {
  586. VarNode* var = endpoints[i].node();
  587. if (var->dtype().enumv() != dtypes[i].enumv()) {
  588. auto new_var = opr::TypeCvt::make(var, dtypes[i]).node();
  589. rewriter.replace_var(var, new_var, nullptr);
  590. }
  591. }
  592. };
  593. mgb_assert(endpoints.size() > 0);
  594. auto opr = endpoints[0].node()->owner_opr();
  595. state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
  596. rewriter.apply_inplace();
  597. MIDOUT_E
  598. }
  599. std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
  600. bool use_f32_comp) {
  601. #if MEGDNN_DISABLE_FLOAT16
  602. mgb_throw(SystemError, "float16 disabled at compile time.");
  603. #else
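// Each replace_*_opr lambda below receives the original operator and its
// already-converted inputs (new_inp), and returns the operator that should
// replace it in the converted graph.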
  604. auto replace_h2d_opr = [](OperatorNodeBase* opr,
  605. const VarNodeArray& new_inp) {
  606. mgb_assert(opr->input().size() == new_inp.size());
  607. auto& h2d_opr = opr->cast_final_safe<opr::Host2DeviceCopy>();
  608. if (h2d_opr.output(0)->dtype() == dtype::Float32()) {
  609. auto cvt_var =
  610. opr::TypeCvt::make(h2d_opr.output(0), dtype::Float16(), {});
  611. return cvt_var.node()->owner_opr();
  612. }
  613. return opr;
  614. };
  615. auto replace_sdt_opr = [](OperatorNodeBase* opr,
  616. const VarNodeArray& new_inp) {
  617. mgb_assert(opr->input().size() == new_inp.size());
  618. auto& sdt_opr = opr->cast_final_safe<opr::SharedDeviceTensor>();
  619. if (sdt_opr.output(0)->dtype() == dtype::Float32()) {
  620. auto cvt_var =
  621. opr::TypeCvt::make(sdt_opr.output(0), dtype::Float16(), {});
  622. return cvt_var.node()->owner_opr();
  623. }
  624. return opr;
  625. };
  626. auto replace_imt_opr = [](OperatorNodeBase* opr,
  627. const VarNodeArray& new_inp) {
  628. mgb_assert(opr->same_type<opr::ImmutableTensor>());
  629. mgb_assert(opr->input().size() == new_inp.size());
  630. auto& imt_opr = opr->cast_final_safe<opr::ImmutableTensor>();
  631. if (imt_opr.output(0)->dtype() == dtype::Float32()) {
  632. auto cvt_var =
  633. opr::TypeCvt::make(imt_opr.output(0), dtype::Float16(), {});
  634. return cvt_var.node()->owner_opr();
  635. }
  636. return opr;
  637. };
  638. auto replace_lsp_opr = [](OperatorNodeBase* opr,
  639. const VarNodeArray& new_inp) {
  640. mgb_assert(opr->same_type<opr::Linspace>());
  641. mgb_assert(opr->input().size() == new_inp.size());
  642. auto& lsp_opr = opr->cast_final_safe<opr::Linspace>();
  643. if (lsp_opr.output(0)->dtype() != dtype::Float16()) {
  644. auto cvt_var =
  645. opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {});
  646. return cvt_var.node()->owner_opr();
  647. }
  648. return opr;
  649. };
  650. auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr,
  651. const VarNodeArray& new_inp) {
  652. mgb_assert(opr->input().size() == new_inp.size());
  653. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  654. auto new_param = conv_opr.param();
  655. if (use_f32_comp) {
  656. new_param.compute_mode =
  657. megdnn::param::Convolution::ComputeMode::FLOAT32;
  658. }
  659. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  660. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  661. new_inp[0]->name().c_str(),
  662. new_inp[0]->owner_opr()->name().c_str());
  663. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  664. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  665. new_inp[1]->name().c_str(),
  666. new_inp[1]->owner_opr()->name().c_str());
  667. auto new_conv_opr = opr::Convolution::make(
  668. new_inp[0], new_inp[1], new_param, conv_opr.execution_policy(),
  669. conv_opr.config());
  670. return new_conv_opr.node()->owner_opr();
  671. };
  672. auto replace_deconv_opr = [use_f32_comp](OperatorNodeBase* opr,
  673. const VarNodeArray& new_inp) {
  674. mgb_assert(opr->input().size() == new_inp.size());
  675. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  676. auto new_param = deconv_opr.param();
  677. if (use_f32_comp) {
  678. new_param.compute_mode =
  679. megdnn::param::Convolution::ComputeMode::FLOAT32;
  680. }
  681. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  682. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  683. new_inp[0]->name().c_str(),
  684. new_inp[0]->owner_opr()->name().c_str());
  685. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  686. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  687. new_inp[1]->name().c_str(),
  688. new_inp[1]->owner_opr()->name().c_str());
  689. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  690. new_inp[0], new_inp[1], new_param,
  691. deconv_opr.execution_policy(), deconv_opr.config());
  692. return new_deconv_opr.node()->owner_opr();
  693. };
  694. auto replace_convbias_opr = [use_f32_comp](OperatorNodeBase* opr,
  695. const VarNodeArray& new_inp) {
  696. auto& convbias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  697. auto new_param = convbias_opr.param();
  698. if (use_f32_comp) {
  699. new_param.compute_mode =
  700. megdnn::param::ConvBias::ComputeMode::FLOAT32;
  701. }
  702. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  703. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  704. new_inp[0]->name().c_str(),
  705. new_inp[0]->owner_opr()->name().c_str());
  706. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  707. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  708. new_inp[1]->name().c_str(),
  709. new_inp[1]->owner_opr()->name().c_str());
  710. if(opr->input().size() == 2) {
  711. auto new_conv_opr = opr::ConvBias::make(
  712. new_inp[0], new_inp[1], new_param,
  713. convbias_opr.execution_policy(), convbias_opr.config());
  714. return new_conv_opr.node()->owner_opr();
  715. } else if(opr->input().size() == 3) {
  716. auto new_conv_opr = opr::ConvBias::make(
  717. new_inp[0], new_inp[1], new_inp[2], new_param,
  718. convbias_opr.execution_policy(), convbias_opr.config());
  719. return new_conv_opr.node()->owner_opr();
  720. } else {
  721. mgb_assert(opr->input().size() == 4, "invalid input size %zu",
  722. opr->input().size());
  723. auto new_conv_opr = opr::ConvBias::make(
  724. new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param,
  725. convbias_opr.execution_policy(), convbias_opr.config());
  726. return new_conv_opr.node()->owner_opr();
  727. }
  728. };
  729. auto replace_matmul_opr = [use_f32_comp](OperatorNodeBase* opr,
  730. const VarNodeArray& new_inp) {
  731. mgb_assert(opr->input().size() == new_inp.size());
  732. auto& matmul_opr = opr->cast_final_safe<opr::MatrixMul>();
  733. auto new_param = matmul_opr.param();
  734. if (use_f32_comp) {
  735. new_param.compute_mode =
  736. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  737. }
  738. auto new_matmul_opr = opr::MatrixMul::make(
  739. new_inp[0], new_inp[1], new_param,
  740. matmul_opr.execution_policy(), matmul_opr.config());
  741. return new_matmul_opr.node()->owner_opr();
  742. };
  743. auto replace_batched_matmul_opr = [use_f32_comp](
  744. OperatorNodeBase* opr,
  745. const VarNodeArray& new_inp) {
  746. mgb_assert(opr->input().size() == new_inp.size());
  747. auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>();
  748. auto new_param = matmul_opr.param();
  749. if (use_f32_comp) {
  750. new_param.compute_mode =
  751. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  752. }
  753. mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
  754. "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
  755. new_inp[0]->name().c_str(),
  756. new_inp[0]->owner_opr()->name().c_str());
  757. mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
  758. "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
  759. new_inp[1]->name().c_str(),
  760. new_inp[1]->owner_opr()->name().c_str());
  761. auto new_matmul_opr = opr::BatchedMatrixMul::make(
  762. new_inp[0], new_inp[1], new_param,
  763. matmul_opr.execution_policy(), matmul_opr.config());
  764. return new_matmul_opr.node()->owner_opr();
  765. };
  766. auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr,
  767. const VarNodeArray& new_inp) {
  768. auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
  769. auto new_param = reduce_opr.param();
  770. if (use_f32_comp) {
  771. new_param.data_type =
  772. megdnn::param::Reduce::DataType::FLOAT_O16xC32;
  773. }
  774. if (opr->input().size() == 1) {
  775. auto new_matmul_opr = opr::Reduce::make(new_inp[0], new_param, {},
  776. reduce_opr.config());
  777. return new_matmul_opr.node()->owner_opr();
  778. } else {
  779. mgb_assert(opr->input().size() == 2, "invalid input size %zu",
  780. opr->input().size());
  781. auto new_matmul_opr = opr::Reduce::make(
  782. new_inp[0], new_param, new_inp[1], reduce_opr.config());
  783. return new_matmul_opr.node()->owner_opr();
  784. }
  785. };
  786. auto replace_cvt_opr = [](OperatorNodeBase* opr,
  787. const VarNodeArray& new_inp) {
  788. auto& cvt_opr = opr->cast_final_safe<opr::TypeCvt>();
  789. SymbolVar new_cvt;
  790. if (cvt_opr.output(0)->dtype() == dtype::Float32()) {
  791. new_cvt = opr::TypeCvt::make(new_inp[0], dtype::Float16(),
  792. cvt_opr.config());
  793. } else {
  794. new_cvt = opr::TypeCvt::make(
  795. new_inp[0], cvt_opr.output()[0]->dtype(), cvt_opr.config());
  796. }
  797. return new_cvt.node()->owner_opr();
  798. };
  799. auto replace_warp_opr = [](OperatorNodeBase* opr,
  800. const VarNodeArray& new_inp) {
  801. mgb_assert(opr->input().size() == new_inp.size() &&
  802. (new_inp.size() == 3 || new_inp.size() == 4));
  803. auto& warp_opr = opr->cast_final<opr::WarpPerspective>();
  804. // mat tensor must be float32
  805. auto new_mat = new_inp[1];
  806. if (new_inp[1]->dtype() != dtype::Float32()) {
  807. if (try_cast_as_op<opr::TypeCvt>(new_mat->owner_opr()) &&
  808. new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
  809. new_mat = new_mat->owner_opr()->input(0);
  810. else
  811. new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
  812. .node();
  813. }
  814. SymbolVar new_warp;
  815. if (new_inp.size() == 3) {
  816. new_warp = opr::WarpPerspective::make(new_inp[0], new_mat,
  817. new_inp[2], warp_opr.param(),
  818. warp_opr.config());
  819. } else {
  820. mgb_assert(new_inp.size() == 4);
  821. new_warp = opr::WarpPerspective::make(
  822. new_inp[0], new_mat, new_inp[2], new_inp[3],
  823. warp_opr.param(), warp_opr.config());
  824. }
  825. return new_warp.node()->owner_opr();
  826. };
  827. auto replace_remap_opr = [](OperatorNodeBase* opr,
  828. const VarNodeArray& new_inp) {
  829. mgb_assert(opr->input().size() == new_inp.size() &&
  830. (new_inp.size() == 2));
  831. auto& remap_opr = opr->cast_final<opr::Remap>();
  832. // map tensor must be float32
  833. auto new_map = new_inp[1];
  834. if (new_inp[1]->dtype() != dtype::Float32()) {
  835. if (try_cast_as_op<opr::TypeCvt>(new_map->owner_opr()) &&
  836. new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
  837. new_map = new_map->owner_opr()->input(0);
  838. else
  839. new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
  840. .node();
  841. }
  842. SymbolVar new_remap;
  843. new_remap = opr::Remap::make(new_inp[0], new_map,
  844. remap_opr.param(),
  845. remap_opr.config());
  846. return new_remap.node()->owner_opr();
  847. };
  848. auto ret = std::make_unique<ConvertF32ToF16Pass>();
  849. // don't check dtype
  850. ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
  851. VarReplaceCheckFlag::CHECK_DTYPE);
  852. auto&& replace_func = ret->m_opr_replace_func;
  853. replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr;
  854. replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
  855. replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
  856. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  857. replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
  858. replace_func[opr::ConvBias::typeinfo()] = replace_convbias_opr;
  859. replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr;
  860. replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr;
  861. replace_func[opr::ImmutableTensor::typeinfo()] = replace_imt_opr;
  862. replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
  863. replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
  864. replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
  865. replace_func[opr::BatchedMatrixMul::typeinfo()] =
  866. replace_batched_matmul_opr;
  867. return ret;
  868. #endif
  869. }
  870. /* ================ ConvertFormatPass ================ */
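// ConvertFormatPass::apply replaces each supported NCHW operator via
// m_opr_replace_func and, for endpoint vars left in IMAGE2D_PACK4 format,
// relayouts them back to NCHW.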
  871. void ConvertFormatPass::apply(OptState& state) const {
  872. MIDOUT_B("ConvertFormatPass::apply")
  873. state.set_var_replace_check_flag(m_var_replace_check_flag);
  874. auto rewriter = state.graph().make_rewriter();
  875. VarNodeArray new_inp_cache;
  876. auto on_opr = [this, &state, &rewriter,
  877. &new_inp_cache](OperatorNodeBase* opr) {
  878. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  879. if (it != m_opr_replace_func.end()) {
  880. auto&& new_inp = new_inp_cache;
  881. new_inp.clear();
  882. new_inp.reserve(opr->input().size());
  883. for (auto i : opr->input()) {
  884. new_inp.push_back(rewriter.get_var(i));
  885. }
  886. auto new_opr = (it->second)(opr, new_inp);
  887. auto &&out0 = opr->output(), &&out1 = new_opr->output();
  888. mgb_assert(out0.size() == out1.size(),
  889. "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
  890. "dst.size=%zu",
  891. opr->cname(), opr->dyn_typeinfo()->name,
  892. new_opr->cname(), new_opr->dyn_typeinfo()->name,
  893. out0.size(), out1.size());
  894. for (size_t i = 0; i < out0.size(); i++) {
  895. if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  896. mgb_assert(!out1[i]->contain_flag(
  897. VarNode::Flag::VOLATILE_CONTENT));
  898. auto src = out0[i];
  899. auto dst = out1[i];
  900. auto dst_is_image = dst->format().type() ==
  901. TensorFormat::Type::IMAGE2D_PACK4;
  902. if (!dst_is_image &&
  903. !src->owner_opr()->same_type<opr::ImmutableTensor>()) {
  904. mgb_log_warn(
  905. "convert NHWCD4 replaced to non-img format: "
  906. "dst_opr=%s{%s} format=%s",
  907. dst->owner_opr()->cname(),
  908. dst->owner_opr()->dyn_typeinfo()->name,
  909. dst->format().to_string().c_str());
  910. }
  911. if (state.graph().endpoint_contain(src) && dst_is_image) {
  912. // relayout back to NCHW for output vars
  913. dst = opr::RelayoutFormat::make(
  914. dst, {opr::RelayoutFormat::Param::Mode::
  915. NHWCD4I_NCHW})
  916. .node();
  917. }
  918. rewriter.replace_var(src, dst, nullptr);
  919. }
  920. }
  921. } else {
  922. rewriter.auto_replace_outputs(opr);
  923. }
  924. };
  925. state.graph().iter(on_opr);
  926. rewriter.apply_inplace();
  927. MIDOUT_E
  928. }
  929. std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
  930. MIDOUT_B("ConvertFormatPass::make")
  931. auto filter_mode =
  932. [](const megdnn::param::Convolution::Sparse conv_mode,
  933. const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
  934. bool use_dot = false;
  935. if (filter->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
  936. filter->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm)
  937. use_dot = true;
  938. if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
  939. if (use_dot)
  940. return megdnn::param::RelayoutFormat::Mode::
  941. INTER_WEIGHT_DENSEI_DOT;
  942. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI;
  943. } else {
  944. mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP,
  945. MegBrainError, "mode error");
  946. if (filter->shape()[1] == 1 && filter->shape()[2] == 1) {
  947. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI;
  948. } else {
  949. if (use_dot)
  950. return megdnn::param::RelayoutFormat::Mode::
  951. INTER_WEIGHT_GROUPI_DOT;
  952. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_GROUPI;
  953. }
  954. }
  955. };
  956. auto size_one_conv_to_dense_conv =
  957. [](VarNode* origin_filter_input,
  958. megdnn::param::Convolution::Sparse sparse) {
  959. VarNode* reshaped_filter = origin_filter_input;
  960. bool is_size_one_group_conv = false;
  961. if (sparse == megdnn::param::Convolution::Sparse::GROUP &&
  962. origin_filter_input->shape()[0] == 1) {
  963. is_size_one_group_conv = true;
  964. auto new_shape = origin_filter_input->shape();
  965. new_shape.ndim = 4;
  966. for (int i = 0; i < 4; i++) {
  967. new_shape[i] = origin_filter_input->shape()[i + 1];
  968. }
  969. SymbolVar new_var(origin_filter_input);
  970. reshaped_filter = new_var.reshape(new_shape).node();
  971. }
  972. return std::make_tuple(reshaped_filter, is_size_one_group_conv);
  973. };
  974. auto replace_conv_opr =
  975. [&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
  976. const VarNodeArray& new_inp) {
  977. mgb_assert(opr->input().size() == new_inp.size());
  978. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  979. mgb_throw_if(
  980. conv_opr.param().format !=
  981. megdnn::param::Convolution::Format::NCHW,
  982. MegBrainError,
  983. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  984. VarNode *conv_src = nullptr, *conv_weights = nullptr;
  985. if (new_inp[0]->shape().ndim == 4) {
  986. // new input src is NCHW
  987. size_t group, icpg, ocpg;
  988. if (conv_opr.param().sparse ==
  989. megdnn::param::Convolution::Sparse::DENSE) {
  990. group = 1;
  991. icpg = new_inp[1]->shape()[1];
  992. ocpg = new_inp[1]->shape()[0];
  993. } else {
  994. mgb_throw_if(conv_opr.param().sparse !=
  995. megdnn::param::Convolution::Sparse::GROUP,
  996. MegBrainError, "ERROR mode");
  997. group = new_inp[1]->shape()[0];
  998. icpg = new_inp[1]->shape()[2];
  999. ocpg = new_inp[1]->shape()[1];
  1000. }
  1001. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1002. auto param = megdnn::param::RelayoutFormat();
  1003. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1004. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1005. conv_src = rf.node();
  1006. } else {
  1007. // can not convert to hwcd4
  1008. return serialization::copy_opr_shallow(*opr, new_inp,
  1009. opr->config());
  1010. }
  1011. } else {
  1012. size_t ocpg;
  1013. bool is_channel_wise = false;
  1014. if (conv_opr.param().sparse ==
  1015. megdnn::param::Convolution::Sparse::DENSE) {
  1016. ocpg = new_inp[1]->shape()[0];
  1017. } else {
  1018. mgb_throw_if(conv_opr.param().sparse !=
  1019. megdnn::param::Convolution::Sparse::GROUP,
  1020. MegBrainError, "ERROR mode");
  1021. size_t icpg = new_inp[1]->shape()[2];
  1022. ocpg = new_inp[1]->shape()[1];
  1023. if (icpg == 1 && ocpg == 1) {
  1024. is_channel_wise = true;
  1025. }
  1026. }
  1027. if (ocpg % 4 != 0 && !is_channel_wise) {
  1028. VarNodeArray t_inp = new_inp;
  1029. auto param = megdnn::param::RelayoutFormat();
  1030. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1031. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1032. t_inp[0] = rf.node();
  1033. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1034. opr->config());
  1035. return new_opr;
  1036. }
  1037. // new input src is NHWCD4
  1038. auto&& fmt = new_inp[0]
  1039. ->format()
  1040. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1041. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1042. conv_src = new_inp[0];
  1043. }
  1044. VarNode* reshaped_filter;
  1045. bool is_size_one_group_conv;
  1046. std::tie(reshaped_filter, is_size_one_group_conv) =
  1047. size_one_conv_to_dense_conv(new_inp[1],
  1048. conv_opr.param().sparse);
  1049. auto new_conv_param = conv_opr.param();
  1050. if (is_size_one_group_conv) {
  1051. new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
  1052. }
  1053. mgb_assert(new_inp[1]->format().type() !=
  1054. TensorFormat::Type::IMAGE2D_PACK4);
  1055. auto param = megdnn::param::RelayoutFormat();
  1056. param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
  1057. auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
  1058. conv_weights = relayout_weight.node();
  1059. new_conv_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1060. mgb_assert(conv_src->shape().ndim == 5 &&
  1061. conv_src->format().type() ==
  1062. TensorFormat::Type::IMAGE2D_PACK4);
  1063. auto new_conv_opr = opr::Convolution::make(
  1064. conv_src, conv_weights, new_conv_param, conv_opr.execution_policy(),
  1065. conv_opr.config());
  1066. OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
  1067. mgb_assert(new_conv_opr.shape().ndim == 5 &&
  1068. new_conv_opr.format().type() ==
  1069. TensorFormat::Type::IMAGE2D_PACK4);
  1070. return ret;
  1071. };
  1072. auto replace_conv_bias_opr =
  1073. [&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
  1074. const VarNodeArray& new_inp) {
  1075. mgb_assert(opr->input().size() == new_inp.size());
  1076. auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  1077. mgb_throw_if(
  1078. conv_bias_opr.param().format !=
  1079. megdnn::param::ConvBias::Format::NCHW,
  1080. MegBrainError,
  1081. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1082. VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr,
  1083. *conv_bias_bias = nullptr;
  1084. if (new_inp[0]->shape().ndim == 4) {
  1085. // new input src is NCHW
  1086. size_t group, icpg, ocpg;
  1087. if (conv_bias_opr.param().sparse ==
  1088. megdnn::param::ConvBias::Sparse::DENSE) {
  1089. group = 1;
  1090. icpg = new_inp[1]->shape()[1];
  1091. ocpg = new_inp[1]->shape()[0];
  1092. } else {
  1093. mgb_throw_if(conv_bias_opr.param().sparse !=
  1094. megdnn::param::ConvBias::Sparse::GROUP,
  1095. MegBrainError, "mode error");
  1096. group = new_inp[1]->shape()[0];
  1097. icpg = new_inp[1]->shape()[2];
  1098. ocpg = new_inp[1]->shape()[1];
  1099. }
  1100. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1101. auto param = megdnn::param::RelayoutFormat();
  1102. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1103. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1104. conv_bias_src = rf.node();
  1105. } else {
  1106. // can not convert to hwcd4
  1107. return serialization::copy_opr_shallow(*opr, new_inp,
  1108. opr->config());
  1109. }
  1110. } else {
  1111. size_t ocpg;
  1112. bool is_channel_wise = false;
  1113. if (conv_bias_opr.param().sparse ==
  1114. megdnn::param::ConvBias::Sparse::DENSE) {
  1115. ocpg = new_inp[1]->shape()[0];
  1116. } else {
  1117. mgb_throw_if(conv_bias_opr.param().sparse !=
  1118. megdnn::param::ConvBias::Sparse::GROUP,
  1119. MegBrainError, "ERROR mode");
  1120. size_t icpg = new_inp[1]->shape()[2];
  1121. ocpg = new_inp[1]->shape()[1];
  1122. if (icpg == 1 && ocpg == 1) {
  1123. is_channel_wise = true;
  1124. }
  1125. }
  1126. if (ocpg % 4 != 0 && !is_channel_wise) {
  1127. VarNodeArray t_inp = new_inp;
  1128. auto param = megdnn::param::RelayoutFormat();
  1129. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1130. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1131. t_inp[0] = rf.node();
  1132. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1133. opr->config());
  1134. return new_opr;
  1135. }
  1136. // new input src is NHWCD4
  1137. auto&& fmt = new_inp[0]
  1138. ->format()
  1139. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1140. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1141. conv_bias_src = new_inp[0];
  1142. }
  1143. mgb_assert(new_inp[1]->format().type() !=
  1144. TensorFormat::Type::IMAGE2D_PACK4);
  1145. VarNode* reshaped_filter;
  1146. bool is_size_one_group_conv;
  1147. std::tie(reshaped_filter, is_size_one_group_conv) =
  1148. size_one_conv_to_dense_conv(new_inp[1],
  1149. conv_bias_opr.param().sparse);
  1150. auto new_conv_param = conv_bias_opr.param();
  1151. if (is_size_one_group_conv) {
  1152. new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
  1153. }
  1154. auto param = megdnn::param::RelayoutFormat();
  1155. param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
  1156. auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
  1157. conv_bias_weights = relayout_weight.node();
  1158. mgb_assert(new_inp.size() < 4,
  1159. "ConvertFormat pass does not support fuse Z");
  1160. bool has_bias = new_inp.size() > 2;
  1161. if (has_bias &&
  1162. new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) {
  1163. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1164. auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
  1165. conv_bias_bias = relayout_bias.node();
  1166. } else if (has_bias) {
  1167. conv_bias_bias = new_inp[2];
  1168. }
  1169. new_conv_param.format = megdnn::param::ConvBias::Format::NHWCD4;
  1170. mgb_assert(conv_bias_src->shape().ndim == 5 &&
  1171. conv_bias_src->format().type() ==
  1172. TensorFormat::Type::IMAGE2D_PACK4);
  1173. SymbolVar new_conv_bias_opr;
  1174. if (has_bias) {
  1175. new_conv_bias_opr = opr::ConvBias::make(
  1176. conv_bias_src, conv_bias_weights, conv_bias_bias, new_conv_param,
  1177. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1178. } else {
  1179. new_conv_bias_opr = opr::ConvBias::make(
  1180. conv_bias_src, conv_bias_weights, new_conv_param,
  1181. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1182. }
  1183. OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
  1184. mgb_assert(new_conv_bias_opr.shape().ndim == 5 &&
  1185. new_conv_bias_opr.format().type() ==
  1186. TensorFormat::Type::IMAGE2D_PACK4);
  1187. return ret;
  1188. };
  1189. auto replace_deconv_opr = [&filter_mode](OperatorNodeBase* opr,
  1190. const VarNodeArray& new_inp) {
  1191. mgb_assert(opr->input().size() == new_inp.size());
  1192. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  1193. mgb_throw_if(
  1194. deconv_opr.param().format !=
  1195. megdnn::param::Convolution::Format::NCHW,
  1196. MegBrainError,
  1197. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1198. VarNode *deconv_src = nullptr, *deconv_weights = nullptr;
  1199. if (new_inp[1]->shape().ndim == 4) {
  1200. // new input src is NCHW
  1201. size_t group, icpg, ocpg;
  1202. if (deconv_opr.param().sparse ==
  1203. megdnn::param::Convolution::Sparse::DENSE) {
  1204. group = 1;
  1205. icpg = new_inp[0]->shape()[0];
  1206. ocpg = new_inp[0]->shape()[1];
  1207. } else {
  1208. mgb_throw_if(deconv_opr.param().sparse !=
  1209. megdnn::param::Convolution::Sparse::GROUP,
  1210. MegBrainError, "mode error");
  1211. group = new_inp[0]->shape()[0];
  1212. icpg = new_inp[0]->shape()[1];
  1213. ocpg = new_inp[0]->shape()[2];
  1214. }
  1215. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1216. auto param = megdnn::param::RelayoutFormat();
  1217. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1218. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1219. deconv_src = rf.node();
  1220. } else {
  1221. // can not convert to hwcd4
  1222. return serialization::copy_opr_shallow(*opr, new_inp,
  1223. opr->config());
  1224. }
  1225. } else {
  1226. //! XXXX, fix me, check filter size
  1227. size_t ocpg;
  1228. if (deconv_opr.param().sparse ==
  1229. megdnn::param::Convolution::Sparse::DENSE) {
  1230. ocpg = new_inp[0]->shape()[1];
  1231. } else {
  1232. mgb_throw_if(deconv_opr.param().sparse !=
  1233. megdnn::param::Convolution::Sparse::GROUP,
  1234. MegBrainError, "mode error");
  1235. ocpg = new_inp[0]->shape()[2];
  1236. }
  1237. if (ocpg % 4 != 0) {
  1238. VarNodeArray t_inp = new_inp;
  1239. auto param = megdnn::param::RelayoutFormat();
  1240. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1241. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1242. t_inp[1] = rf.node();
  1243. auto new_opr = serialization::copy_opr_shallow(*opr, t_inp,
  1244. opr->config());
  1245. return new_opr;
  1246. }
  1247. // new input src is NHWCD4
  1248. auto&& fmt = new_inp[1]
  1249. ->format()
  1250. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1251. mgb_assert(new_inp[1]->shape().ndim == 5 && fmt.align_axis() == 2);
  1252. deconv_src = new_inp[1];
  1253. }
  1254. mgb_assert(new_inp[0]->format().type() !=
  1255. TensorFormat::Type::IMAGE2D_PACK4);
  1256. auto param = megdnn::param::RelayoutFormat();
  1257. param.mode = filter_mode(deconv_opr.param().sparse, new_inp[0]);
  1258. auto relayout_weight = opr::RelayoutFormat::make(new_inp[0], param);
  1259. deconv_weights = relayout_weight.node();
  1260. auto new_param = deconv_opr.param();
  1261. new_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1262. mgb_assert(deconv_src->shape().ndim == 5 &&
  1263. deconv_src->format().type() ==
  1264. TensorFormat::Type::IMAGE2D_PACK4);
  1265. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  1266. deconv_weights, deconv_src, new_param,
  1267. deconv_opr.execution_policy(), deconv_opr.config());
  1268. OperatorNodeBase* ret = new_deconv_opr.node()->owner_opr();
  1269. mgb_assert(new_deconv_opr.shape().ndim == 5 &&
  1270. new_deconv_opr.format().type() ==
  1271. TensorFormat::Type::IMAGE2D_PACK4);
  1272. return ret;
  1273. };
1274. /* This helper function guarantees that the format conversion pass will not
1275. * change the output var's channel count. Changing the channel count would
1276. * cause a channel-mismatch problem when replacing conv/conv_bias operators.
1277. */
  1278. auto replace_helper = [](OperatorNodeBase* opr,
  1279. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1280. auto&& new_shp = new_inp[0]->shape();
  1281. size_t inp_channel = new_shp[1];
1282. if (new_shp.eq_shape(opr->input(0)->shape()) && inp_channel % 4 != 0) {
  1283. auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
  1284. opr->config());
  1285. return new_opr;
  1286. }
  1287. return nullptr;
  1288. };
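// The single-feature-map operators below (Resize, WarpPerspective, WarpAffine
// and Pooling) all follow the same pattern: if the incoming var is still a
// 4-dim NCHW tensor, a RelayoutFormat NCHW_NHWCD4I is inserted in front of
// it; if it already arrived as a 5-dim IMAGE2D_PACK4 var it is used directly.
// Afterwards only param().format is switched to NHWCD4.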
  1289. auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr,
  1290. const VarNodeArray& new_inp) {
  1291. mgb_assert(opr->input().size() == new_inp.size());
  1292. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1293. return opr_shallow_copy;
  1294. }
  1295. auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
  1296. mgb_throw_if(
  1297. resize_opr.param().format !=
  1298. megdnn::param::Resize::Format::NCHW,
  1299. MegBrainError,
  1300. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1301. VarNode* inp = nullptr;
  1302. if (new_inp[0]->shape().ndim == 4) {
  1303. auto param = megdnn::param::RelayoutFormat();
  1304. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1305. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1306. inp = rf.node();
  1307. } else {
1308. // new input src is NHWCD4
  1309. auto&& fmt = new_inp[0]
  1310. ->format()
  1311. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1312. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1313. inp = new_inp[0];
  1314. }
  1315. auto new_param = resize_opr.param();
  1316. new_param.format = megdnn::param::Resize::Format::NHWCD4;
  1317. auto new_resize_opr = opr::ResizeForward::make(
  1318. inp, new_inp[1], new_param, opr->config());
  1319. return new_resize_opr.node()->owner_opr();
  1320. };
  1321. auto replace_warp_perspective_opr = [replace_helper](
  1322. OperatorNodeBase* opr,
  1323. const VarNodeArray& new_inp) {
  1324. mgb_assert(opr->input().size() == new_inp.size());
  1325. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1326. return opr_shallow_copy;
  1327. }
  1328. auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
  1329. mgb_throw_if(
  1330. warp_opr.param().format !=
  1331. megdnn::param::WarpPerspective::Format::NCHW,
  1332. MegBrainError,
  1333. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1334. VarNode* inp = nullptr;
  1335. if (new_inp[0]->shape().ndim == 4) {
  1336. // new input src is NCHW
  1337. auto param = megdnn::param::RelayoutFormat();
  1338. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1339. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1340. inp = rf.node();
  1341. } else {
1342. // new input src is NHWCD4
  1343. auto&& fmt = new_inp[0]
  1344. ->format()
  1345. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1346. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1347. inp = new_inp[0];
  1348. }
  1349. auto new_param = warp_opr.param();
  1350. new_param.format = megdnn::param::WarpPerspective::Format::NHWCD4;
  1351. SymbolVar new_warp_opr;
  1352. if (new_inp.size() == 3) {
  1353. new_warp_opr = opr::WarpPerspectiveForward::make(
  1354. inp, new_inp[1], nullptr, new_inp[2], new_param,
  1355. opr->config());
  1356. } else {
  1357. mgb_assert(new_inp.size() == 4);
  1358. new_warp_opr = opr::WarpPerspectiveForward::make(
  1359. inp, new_inp[1], new_inp[2], new_inp[3], new_param,
  1360. opr->config());
  1361. }
  1362. return new_warp_opr.node()->owner_opr();
  1363. };
  1364. auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr,
  1365. const VarNodeArray& new_inp) {
  1366. mgb_assert(opr->input().size() == new_inp.size());
  1367. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1368. return opr_shallow_copy;
  1369. }
  1370. auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
  1371. mgb_throw_if(
  1372. warp_opr.param().format !=
  1373. megdnn::param::WarpAffine::Format::NCHW,
  1374. MegBrainError,
  1375. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1376. VarNode* inp = nullptr;
  1377. if (new_inp[0]->shape().ndim == 4) {
  1378. // new input src is NCHW
  1379. auto param = megdnn::param::RelayoutFormat();
  1380. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1381. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1382. inp = rf.node();
  1383. } else {
1384. // new input src is NHWCD4
  1385. auto&& fmt = new_inp[0]
  1386. ->format()
  1387. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1388. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1389. inp = new_inp[0];
  1390. }
  1391. auto new_param = warp_opr.param();
  1392. new_param.format = megdnn::param::WarpAffine::Format::NHWCD4;
  1393. SymbolVar new_warp_opr;
  1394. new_warp_opr = opr::WarpAffineForward::make(inp, new_inp[1], new_inp[2],
  1395. new_param, opr->config());
  1396. return new_warp_opr.node()->owner_opr();
  1397. };
  1398. auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr,
  1399. const VarNodeArray& new_inp) {
  1400. mgb_assert(opr->input().size() == new_inp.size());
  1401. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1402. return opr_shallow_copy;
  1403. }
  1404. auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
  1405. mgb_throw_if(
  1406. pooling_opr.param().format !=
  1407. megdnn::param::Pooling::Format::NCHW,
  1408. MegBrainError,
  1409. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1410. VarNode* inp = nullptr;
  1411. if (new_inp[0]->shape().ndim == 4) {
  1412. // new input src is NCHW
  1413. auto param = megdnn::param::RelayoutFormat();
  1414. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1415. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1416. inp = rf.node();
  1417. } else {
1418. // new input src is NHWCD4
  1419. auto&& fmt = new_inp[0]
  1420. ->format()
  1421. .as_impl<megdnn::Image2DPack4TensorFormat>();
  1422. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1423. inp = new_inp[0];
  1424. }
  1425. auto new_param = pooling_opr.param();
  1426. new_param.format = megdnn::param::Pooling::Format::NHWCD4;
  1427. auto new_pooling_opr =
  1428. opr::PoolingForward::make(inp, new_param, opr->config());
  1429. return new_pooling_opr.node()->owner_opr();
  1430. };
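// var_to_chw converts a var that this pass has already turned into NHWCD4
// (5-dim, IMAGE2D_PACK4 format) back to a plain 4-dim NCHW tensor with a
// RelayoutFormat NHWCD4I_NCHW operator; vars whose shape is unchanged are
// passed through untouched. It serves operators that must consume NCHW data.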
  1431. auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
  1432. if (!inp->shape().eq_shape(new_inp->shape())) {
  1433. mgb_assert(inp->shape().ndim == 4 &&
  1434. inp->format().type() !=
  1435. TensorFormat::Type::IMAGE2D_PACK4);
  1436. mgb_assert(new_inp->shape().ndim == 5 &&
  1437. new_inp->format().type() ==
  1438. TensorFormat::Type::IMAGE2D_PACK4);
  1439. auto param = megdnn::param::RelayoutFormat();
  1440. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1441. auto rf = opr::RelayoutFormat::make(new_inp, param);
  1442. return rf.node();
  1443. }
  1444. return new_inp;
  1445. };
  1446. auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr,
  1447. const VarNodeArray& new_inp) {
  1448. mgb_assert(opr->input().size() == new_inp.size());
  1449. VarNodeArray t_inp = new_inp;
  1450. for (size_t i = 0; i < opr->input().size(); i++) {
  1451. t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
  1452. }
  1453. auto new_opr =
  1454. serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1455. return new_opr;
  1456. };
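// Elemwise (and TypeCvt, which reuses this handler) may stay in NHWCD4 only
// if every input is already NHWCD4, is a 4-dim tensor whose channel count is
// a multiple of 4, or is a scalar. Otherwise all inputs are relayouted back
// to NCHW and the operator is shallow-copied; if it does stay in NHWCD4, any
// remaining NCHW inputs are converted with RelayoutFormat NCHW_NHWCD4I.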
  1457. auto replace_elemwise_opr = [&relayout_inp_to_chw](
  1458. OperatorNodeBase* opr,
  1459. const VarNodeArray& new_inp) {
  1460. mgb_assert(opr->input().size() == new_inp.size());
  1461. bool has_inp_changed = false;
  1462. bool can_exec_cd4 = true;
  1463. for (size_t i = 0; i < opr->input().size(); i++) {
  1464. if (!new_inp[i]->format().is_default()) {
  1465. has_inp_changed = true;
  1466. } else if (new_inp[i]->shape().ndim == 4) {
  1467. if (new_inp[i]->shape()[1] % 4 != 0) {
  1468. can_exec_cd4 = false;
  1469. }
1470. //! cd4 elemwise with scalar is unsupported
  1471. } else if (!new_inp[i]->shape().is_scalar()) {
  1472. can_exec_cd4 = false;
  1473. }
  1474. }
  1475. if (!can_exec_cd4) {
  1476. return relayout_inp_to_chw(opr, new_inp);
  1477. }
  1478. if (has_inp_changed) {
  1479. // assumption: all inputs are changed from nchw to nhwcd4
  1480. auto t_inp = new_inp;
  1481. for (size_t i = 0; i < opr->input().size(); i++) {
  1482. if (new_inp[i]->shape().ndim == 4) {
  1483. auto param = megdnn::param::RelayoutFormat();
  1484. param.mode =
  1485. megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1486. auto rf = opr::RelayoutFormat::make(new_inp[i], param);
  1487. t_inp[i] = rf.node();
  1488. } else {
  1489. mgb_assert((new_inp[i]->shape().ndim == 5 &&
  1490. new_inp[i]->format().type() ==
  1491. TensorFormat::Type::IMAGE2D_PACK4) ||
  1492. new_inp[i]->shape().is_scalar());
  1493. }
  1494. }
  1495. return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1496. } else {
  1497. return serialization::copy_opr_shallow(*opr, new_inp,
  1498. opr->config());
  1499. }
  1500. };
  1501. /* This helper function converts the first input to the NCHW format to
  1502. * handle operations that do not support NHWCD4 format
  1503. */
  1504. auto relayout_first_inp_to_chw =
  1505. [var_to_chw](OperatorNodeBase* opr,
  1506. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1507. mgb_assert(opr->input().size() == new_inp.size());
  1508. VarNodeArray t_inp = new_inp;
  1509. t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
  1510. return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1511. };
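// Build the pass object and register the replacement functions. Operators
// that have no NHWCD4 kernel (Concat, Reshape, Reduce, Subtensor, ...) simply
// get their inputs converted back to NCHW via relayout_inp_to_chw.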
  1512. auto ret = std::make_unique<ConvertFormatPass>();
  1513. ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
  1514. auto&& replace_func = ret->m_opr_replace_func;
  1515. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  1516. replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
  1517. replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
  1518. replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
  1519. replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
  1520. replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw;
  1521. replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
  1522. replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
  1523. replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
  1524. replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_chw;
  1525. replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_chw;
  1526. replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
  1527. replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
  1528. replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
  1529. replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
  1530. replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
  1531. replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
  1532. replace_func[opr::WarpPerspectiveForward::typeinfo()] =
  1533. replace_warp_perspective_opr;
  1534. replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
  1535. replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
  1536. replace_func[opr::GroupLocalForward::typeinfo()] =
  1537. relayout_first_inp_to_chw;
  1538. return ret;
  1539. MIDOUT_E
  1540. }
1541. /* ================ ConvertBatchNormToElemwisePass ================ */
  1542. const char* ConvertBatchNormToElemwisePass::name() const {
  1543. return "convert_batch_norm";
  1544. }
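// For a 5-input BatchNorm running in inference mode this pass expands the
// operator into plain elemwise arithmetic:
//     y = scale * (x - mean) / sqrt(variance + epsilon) + bias
// where 1 / sqrt(.) is expressed with PowC(., -0.5); the var holding the
// BatchNorm result is then replaced with this elemwise expression.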
  1545. void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
  1546. MIDOUT_B("ConvertBatchNormToElemwisePass::apply")
  1547. auto rewriter = state.graph().make_rewriter();
  1548. auto on_opr = [&](OperatorNodeBase* opr) {
  1549. if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
  1550. if (bn->input().size() == 5) {
  1551. mgb_assert(bn->param().fwd_mode ==
  1552. opr::BatchNorm::Param::FwdMode::INFERENCE);
  1553. SymbolVar x = {rewriter.get_var(bn->input(0))};
  1554. SymbolVar scale = {rewriter.get_var(bn->input(1))};
  1555. SymbolVar bias = {rewriter.get_var(bn->input(2))};
  1556. SymbolVar mean = {rewriter.get_var(bn->input(3))};
  1557. SymbolVar variance = {rewriter.get_var(bn->input(4))};
  1558. SymbolVar invsqrt_variance = opr::PowC::make(variance
  1559. + variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5});
  1560. auto res = scale * (x - mean) * invsqrt_variance + bias;
  1561. rewriter.replace_var(
  1562. opr->output(4), res.node(),
  1563. mgb_cstr_log(
  1564. "replace batch_norm(x, scale, bias, mean, "
  1565. "varience) "
  1566. "-> (sclae * (x - mean) / sqrt(variance)) + b)"));
  1567. return;
  1568. }
  1569. }
  1570. rewriter.auto_replace_outputs(opr);
  1571. };
  1572. state.graph().iter(on_opr);
  1573. rewriter.apply_inplace();
  1574. MIDOUT_E
  1575. }
  1576. /* ================ FuseConvBiasNonlinPass ================ */
  1577. const char* FuseConvBiasNonlinPass::name() const {
  1578. return "combine_conv_bias_and_relu";
  1579. }
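// This pass folds the following patterns into a single ConvBias operator:
//   * conv(x, w) + b, either as a plain ADD or as FUSE_ADD_RELU/TANH/SIGMOID
//   * nonlinearity(conv(x, w)) and nonlinearity(conv_bias(x, w, b))
//   * typecvt(conv_bias(...)) narrowing QuantizedS32 to QuantizedS8
// m_deps records the readers of every var, so a convolution is only fused
// when the elemwise operator is its sole reader.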
  1580. void FuseConvBiasNonlinPass::apply(OptState& state) const {
  1581. MIDOUT_B("FuseConvBiasNonlinPass::apply")
  1582. std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
  1583. state.graph().iter([&m_deps](OperatorNodeBase* opr) {
  1584. for (auto& inp : opr->input()) {
  1585. m_deps[inp].push_back(opr);
  1586. }
  1587. });
  1588. auto rewriter = state.graph().make_rewriter();
  1589. using Mode = opr::Elemwise::Param::Mode;
  1590. using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
  1591. auto get_nonlinearity_mode = [&](opr::Elemwise* elem) -> NonlineMode {
  1592. if (elem->param().mode == Mode::FUSE_ADD_RELU ||
  1593. elem->param().mode == Mode::RELU) {
  1594. return NonlineMode::RELU;
  1595. } else if (elem->param().mode == Mode::FUSE_ADD_SIGMOID ||
  1596. elem->param().mode == Mode::SIGMOID) {
  1597. return NonlineMode::SIGMOID;
  1598. } else {
  1599. return NonlineMode::IDENTITY;
  1600. }
  1601. };
  1602. auto try_fuse_bias_nonlinearity = [&](opr::Elemwise* elem) -> bool {
  1603. bool can_be_fused = true;
  1604. can_be_fused &= (elem->input().size() == 2);
  1605. can_be_fused &= (elem->param().mode == Mode::FUSE_ADD_RELU) ||
  1606. (elem->param().mode == Mode::FUSE_ADD_TANH) ||
  1607. (elem->param().mode == Mode::FUSE_ADD_SIGMOID);
  1608. return can_be_fused;
  1609. };
  1610. auto try_fuse_bias = [&](opr::Elemwise* elem) -> bool {
  1611. bool can_be_fused = true;
  1612. can_be_fused &= (elem->input().size() == 2);
  1613. can_be_fused &= (elem->param().mode == Mode::ADD);
  1614. return can_be_fused;
  1615. };
  1616. auto try_fuse_nonlinearity = [&](opr::Elemwise* elem) -> bool {
  1617. bool can_be_fused = true;
  1618. can_be_fused &= (elem->input().size() == 1);
  1619. can_be_fused &= (elem->param().mode == Mode::RELU) ||
  1620. (elem->param().mode == Mode::TANH) ||
  1621. (elem->param().mode == Mode::SIGMOID);
  1622. return can_be_fused;
  1623. };
  1624. auto convert_to_conv_bias_param = [&](const opr::Convolution::Param& param)
  1625. -> opr::ConvBiasForward::Param {
  1626. using Param = opr::ConvBiasForward::Param;
  1627. return opr::ConvBiasForward::Param{Param::NonlineMode::IDENTITY,
  1628. param.mode,
  1629. param.sparse,
  1630. param.format,
  1631. param.pad_h,
  1632. param.pad_w,
  1633. param.stride_h,
  1634. param.stride_w,
  1635. param.dilate_h,
  1636. param.dilate_w,
  1637. param.compute_mode};
  1638. };
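// check_bias_shape accepts either a bias whose shape equals the conv output
// (provided its shape is not constant, see the note inside) or a bias that is
// broadcastable per output channel. The per-channel layouts checked below are
//   NCHW : (1, OC, 1, 1)      NCHW4 : (1, OC/4, 1, 1, 4)
//   NHWC : (1, 1, 1, OC)      NHWCD4: (1, 1, OC, 1, 4)
// with OC derived from the filter shape (group * ocpg for grouped convolution).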
  1639. auto check_bias_shape = [&](opr::Convolution* conv, VarNode* bias) -> bool {
  1640. bool valid_bias_shape = true;
  1641. using Format = opr::Convolution::Param::Format;
  1642. using Sparse = opr::Convolution::Param::Sparse;
  1643. auto dst_shape = conv->output(0)->shape();
  1644. auto filter_shape = conv->input(1)->shape();
  1645. auto bias_shape = bias->shape();
1646. //! Pay attention: make sure the bias node is not a const provider when
1647. //! batch > 1, because that causes a shape assertion problem in convbias:
1648. //! if the input shape is resized, the bias shape cannot be updated along
1649. //! with it, so do not fuse conv and bias in this situation.
  1650. if (dst_shape.eq_shape(bias_shape) && !cg::is_const_var_shape(bias)) {
  1651. return valid_bias_shape;
  1652. }
  1653. size_t OC = filter_shape[0];
  1654. if (conv->param().sparse == Sparse::GROUP) {
  1655. OC *= filter_shape[1];
  1656. }
  1657. if (conv->param().format == Format::NCHW) {
  1658. valid_bias_shape &=
  1659. ((bias_shape.ndim == 4) && (bias_shape[0] == 1) &&
  1660. (bias_shape[1] == OC) && (bias_shape[2] == 1) &&
  1661. (bias_shape[3] == 1));
  1662. } else if (conv->param().format == Format::NCHW4) {
  1663. valid_bias_shape &=
  1664. ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
  1665. (bias_shape[1] == OC / 4) && (bias_shape[2] == 1) &&
  1666. (bias_shape[3] == 1) && bias_shape[4] == 4);
  1667. } else if (conv->param().format == Format::NHWC) {
  1668. valid_bias_shape &= ((bias_shape.ndim == 4) &&
  1669. (bias_shape[0] == 1) && (bias_shape[1] == 1) &&
  1670. (bias_shape[2] == 1) && (bias_shape[3] == OC));
  1671. } else {
  1672. valid_bias_shape &=
  1673. ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
  1674. (bias_shape[1] == 1) && (bias_shape[2] == OC) &&
  1675. (bias_shape[3] == 1) && (bias_shape[4] == 4));
  1676. mgb_assert(conv->param().format == Format::NHWCD4);
  1677. }
  1678. return valid_bias_shape;
  1679. };
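// try_fuse_typecvt folds a TypeCvt that narrows a ConvBias result from
// QuantizedS32 to QuantizedS8 into the ConvBias itself by setting the target
// dtype on the ConvBias config; it returns the rebuilt operator, or nullptr
// when the pattern does not match.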
  1680. auto try_fuse_typecvt = [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
  1681. mgb_assert(typecvt->input().size() == 1);
  1682. auto conv_bias = try_cast_as_op<opr::ConvBias>(
  1683. rewriter.get_var(typecvt->input(0))->owner_opr());
  1684. if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
  1685. typecvt->output(0)->dtype().enumv() !=
  1686. DTypeTrait<dtype::QuantizedS8>::enumv ||
  1687. typecvt->input(0)->dtype().enumv() !=
  1688. DTypeTrait<dtype::QuantizedS32>::enumv)
  1689. return nullptr;
  1690. auto config = conv_bias->config();
  1691. config.output_dtype(typecvt->output(0)->dtype());
  1692. if (conv_bias->input().size() == 3) {
  1693. // conv + bias
  1694. return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
  1695. conv_bias->input(2), conv_bias->param(),
  1696. conv_bias->execution_policy(), config)
  1697. .node()
  1698. ->owner_opr();
  1699. } else {
  1700. // conv without bias
  1701. return opr::ConvBias::make(conv_bias->input(0), conv_bias->input(1),
  1702. conv_bias->param(),
  1703. conv_bias->execution_policy(), config)
  1704. .node()
  1705. ->owner_opr();
  1706. }
  1707. };
  1708. auto on_opr = [&](OperatorNodeBase* opr) {
  1709. auto check_conv = [](opr::Convolution* conv) -> bool {
  1710. return conv->param().format ==
  1711. megdnn::param::Convolution::Format::NHWCD4 ||
  1712. conv->param().format ==
  1713. megdnn::param::Convolution::Format::NHWC ||
  1714. conv->param().format ==
  1715. megdnn::param::Convolution::Format::NCHW ||
  1716. conv->param().format ==
  1717. megdnn::param::Convolution::Format::NCHW4
  1718. ;
  1719. };
  1720. if (auto elem = try_cast_as_op<opr::Elemwise>(opr)) {
  1721. if (try_fuse_bias_nonlinearity(elem) || try_fuse_bias(elem)) {
  1722. auto inp1 = rewriter.get_var(elem->input(0));
  1723. auto inp2 = rewriter.get_var(elem->input(1));
  1724. opr::Convolution* conv = nullptr;
  1725. size_t bias_idx = 0;
  1726. if (inp1->owner_opr()->same_type<opr::Convolution>() &&
  1727. m_deps[elem->input(0)].size() == 1) {
  1728. conv = try_cast_as_op<opr::Convolution>(inp1->owner_opr());
  1729. bias_idx = 1;
  1730. } else if (inp2->owner_opr()->same_type<opr::Convolution>() &&
  1731. m_deps[elem->input(1)].size() == 1) {
  1732. conv = try_cast_as_op<opr::Convolution>(inp2->owner_opr());
  1733. bias_idx = 0;
  1734. }
  1735. auto bias_inp = rewriter.get_var(elem->input(bias_idx));
  1736. if (conv && check_conv(conv) &&
  1737. check_bias_shape(conv, bias_inp)) {
  1738. opr::ConvBiasForward::Param param =
  1739. convert_to_conv_bias_param(conv->param());
  1740. param.nonlineMode = get_nonlinearity_mode(elem);
  1741. auto new_var =
  1742. opr::ConvBiasForward::make(
  1743. conv->input(0), conv->input(1), bias_inp,
  1744. param, conv->execution_policy(),
  1745. conv->config())
  1746. .node();
  1747. rewriter.replace_var(
  1748. opr->output(0), new_var,
  1749. mgb_cstr_log("replace nonlinearity(conv(x, w) + b) "
  1750. "-> conv_bias(x, w, b)"));
  1751. return;
  1752. }
  1753. } else if (try_fuse_nonlinearity(elem)) {
  1754. auto inp = rewriter.get_var(elem->input(0));
  1755. {
  1756. auto conv =
  1757. try_cast_as_op<opr::Convolution>(inp->owner_opr());
  1758. if (conv && check_conv(conv) &&
  1759. m_deps[elem->input(0)].size() == 1) {
  1760. opr::ConvBiasForward::Param param =
  1761. convert_to_conv_bias_param(conv->param());
  1762. param.nonlineMode = get_nonlinearity_mode(elem);
  1763. auto new_var = opr::ConvBiasForward::make(
  1764. conv->input(0), conv->input(1),
  1765. param, conv->execution_policy(),
  1766. conv->config())
  1767. .node();
  1768. rewriter.replace_var(
  1769. opr->output(0), new_var,
  1770. mgb_cstr_log("replace nonlinearity(conv(x, w)) "
  1771. "-> conv_bias(x, w)"));
  1772. return;
  1773. }
  1774. }
  1775. {
  1776. auto conv = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
  1777. auto check_conv_bias = [&](opr::ConvBias* opr) {
  1778. return opr->param().format ==
  1779. opr::ConvBias::Param::Format::NHWC ||
  1780. opr->param().format ==
  1781. opr::ConvBias::Param::Format::NCHW ||
  1782. opr->param().format ==
  1783. opr::ConvBias::Param::Format::NCHW4
  1784. ;
  1785. };
  1786. if (conv && check_conv_bias(conv) &&
  1787. m_deps[elem->input(0)].size() == 1) {
  1788. auto param = conv->param();
  1789. param.nonlineMode = get_nonlinearity_mode(elem);
  1790. auto new_var = opr::ConvBiasForward::make(
  1791. conv->input(0), conv->input(1),
  1792. conv->input(2), param,
  1793. conv->execution_policy(),
  1794. conv->config())
  1795. .node();
  1796. rewriter.replace_var(
  1797. opr->output(0), new_var,
  1798. mgb_cstr_log("replace nonlinearity(conv(x, w)) "
  1799. "-> conv_bias(x, w)"));
  1800. return;
  1801. }
  1802. }
  1803. }
  1804. } else if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
  1805. auto new_opr = try_fuse_typecvt(typecvt);
  1806. if (new_opr) {
  1807. rewriter.replace_var(
  1808. opr->output(0), new_opr->output(0),
  1809. mgb_cstr_log("replace typecvt(conv_bias(x, w, b)) -> "
  1810. "conv_bias(x, w, b)"));
  1811. return;
  1812. }
  1813. }
  1814. rewriter.auto_replace_outputs(opr);
  1815. };
  1816. state.graph().iter(on_opr);
  1817. rewriter.apply_inplace();
  1818. MIDOUT_E
  1819. }
  1820. /* ================ FuseConvBiasZPass ================ */
  1821. const char* FuseConvBiasZPass::name() const {
  1822. return "combine_conv_bias_and_z";
  1823. }
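// This pass looks for an elemwise ADD / FUSE_ADD_RELU (or the quantized
// ElemwiseMultiType counterparts) in which one operand is produced by a
// three-input ConvBias with IDENTITY nonlinearity and a single reader. The
// other operand becomes the extra z input of a four-input ConvBias
// (x, w, b, z), and the nonlinearity mode is taken from the elemwise mode.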
  1824. void FuseConvBiasZPass::apply(OptState& state) const {
  1825. MIDOUT_B("FuseConvBiasZPass::apply")
  1826. UniqReaderCheck uniq_reader_check{state.graph()};
  1827. auto rewriter = state.graph().make_rewriter();
  1828. using Mode = opr::Elemwise::Param::Mode;
  1829. using MultiMode = opr::ElemwiseMultiType::Param::Mode;
  1830. using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
  1831. auto check_conv_bias = [](opr::ConvBias* conv_bias) -> bool {
  1832. return conv_bias->param().format ==
  1833. megdnn::param::ConvBias::Format::NHWC ||
  1834. conv_bias->param().format ==
  1835. megdnn::param::ConvBias::Format::NCHW ||
  1836. conv_bias->param().format ==
  1837. megdnn::param::ConvBias::Format::NCHW4
  1838. ;
  1839. };
  1840. auto check_fuse_shape = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1841. bool valid_fuse_shape = true;
  1842. auto z_shape = z->shape();
  1843. auto bias_shape = conv_bias->input(2)->shape();
  1844. auto conv_bias_shape = conv_bias->output(0)->shape();
  1845. valid_fuse_shape &= (!conv_bias_shape.eq_shape(bias_shape));
  1846. valid_fuse_shape &= conv_bias_shape.eq_shape(z_shape);
  1847. return valid_fuse_shape;
  1848. };
  1849. auto check_fuse_dtype = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1850. return conv_bias->output(0)->dtype().enumv() == z->dtype().enumv();
  1851. };
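// When building with CUDA against cuDNN 8, the check below additionally
// rejects fusion if the z operand is the same var as the convolution input.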
  1852. #if MGB_CUDA && (CUDNN_MAJOR == 8)
  1853. auto check_fuse_param = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1854. return conv_bias->input(0) != z;
  1855. };
  1856. #endif
  1857. auto get_convbias_nonline_mode = [&](OperatorNodeBase* opr) -> NonlineMode {
  1858. if (opr->same_type<opr::Elemwise>()) {
  1859. auto elem = try_cast_as_op<opr::Elemwise>(opr);
  1860. if (elem->param().mode == Mode::FUSE_ADD_RELU)
  1861. return NonlineMode::RELU;
  1862. }
  1863. if (opr->same_type<opr::ElemwiseMultiType>()) {
  1864. auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
  1865. if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
  1866. return NonlineMode::RELU;
  1867. else if (elem->param().mode == MultiMode::QFUSE_ADD_H_SWISH)
  1868. return NonlineMode::H_SWISH;
  1869. }
  1870. return NonlineMode::IDENTITY;
  1871. };
  1872. auto try_replace_var_node = [&](OperatorNodeBase* opr) {
  1873. opr::ConvBias* conv_bias = nullptr;
  1874. size_t z_idx = 0;
  1875. size_t nr_inps = opr->input().size();
  1876. for (size_t i = 0; i < nr_inps; i++) {
  1877. auto inp = rewriter.get_var(opr->input(i));
  1878. if (inp->owner_opr()->same_type<opr::ConvBias>()) {
  1879. auto cb = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
  1880. if (cb->input().size() == 3 &&
  1881. cb->param().nonlineMode ==
  1882. opr::ConvBias::Param::NonlineMode::IDENTITY &&
  1883. uniq_reader_check(opr->input(i))) {
  1884. conv_bias = cb;
  1885. z_idx = nr_inps - i - 1;
  1886. break;
  1887. }
  1888. }
  1889. }
  1890. auto z_inp = rewriter.get_var(opr->input(z_idx));
  1891. if (conv_bias && check_conv_bias(conv_bias) &&
  1892. check_fuse_shape(conv_bias, z_inp) &&
  1893. #if MGB_CUDA && (CUDNN_MAJOR == 8)
  1894. check_fuse_param(conv_bias, z_inp) &&
  1895. #endif
  1896. check_fuse_dtype(conv_bias, z_inp)) {
  1897. auto param = conv_bias->param();
  1898. param.nonlineMode = get_convbias_nonline_mode(opr);
  1899. auto config = conv_bias->config();
  1900. auto new_var = opr::ConvBiasForward::make(
  1901. conv_bias->input(0), conv_bias->input(1),
  1902. conv_bias->input(2), z_inp, param,
  1903. conv_bias->execution_policy(),
  1904. config.output_dtype(opr->output(0)->dtype()))
  1905. .node();
  1906. rewriter.replace_var(
  1907. opr->output(0), new_var,
  1908. mgb_cstr_log("replace "
  1909. "nonlinearity(conv_bias(x,w,b) + z) "
  1910. "-> conv_bias(x, w, b, z)"));
  1911. uniq_reader_check.update_on_opr_auto_replace(opr,
  1912. new_var->owner_opr());
  1913. return true;
  1914. }
  1915. return false;
  1916. };
  1917. auto try_fuse_elemwise = [&](OperatorNodeBase* opr) {
  1918. if (!opr->same_type<opr::Elemwise>())
  1919. return false;
  1920. auto elem = try_cast_as_op<opr::Elemwise>(opr);
  1921. if (elem->input().size() != 2)
  1922. return false;
  1923. if (elem->param().mode != Mode::ADD &&
  1924. elem->param().mode != Mode::FUSE_ADD_RELU)
  1925. return false;
  1926. return try_replace_var_node(opr);
  1927. };
  1928. auto try_fuse_elemwise_multi_type = [&](OperatorNodeBase* opr) {
  1929. if (!opr->same_type<opr::ElemwiseMultiType>())
  1930. return false;
  1931. auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
  1932. if (elem->input().size() != 2)
  1933. return false;
  1934. if (elem->param().mode != MultiMode::QADD &&
  1935. elem->param().mode != MultiMode::QFUSE_ADD_RELU &&
  1936. elem->param().mode != MultiMode::QFUSE_ADD_H_SWISH)
  1937. return false;
  1938. return try_replace_var_node(opr);
  1939. };
  1940. auto on_opr = [&](OperatorNodeBase* opr) {
  1941. if (try_fuse_elemwise(opr))
  1942. return;
  1943. if (try_fuse_elemwise_multi_type(opr))
  1944. return;
  1945. auto new_opr = rewriter.auto_replace_outputs(opr);
  1946. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  1947. };
  1948. state.graph().iter(on_opr);
  1949. rewriter.apply_inplace();
  1950. MIDOUT_E
  1951. }
  1952. /* ================ FuseDeconvCvtPass ================ */
  1953. const char* FuseDeconvCvtPass::name() const {
  1954. return "combine_deconv_and_typecvt";
  1955. }
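// Analogous to the typecvt fusion in FuseConvBiasNonlinPass, this pass folds
// a TypeCvt to QuantizedS8 that directly follows a ConvolutionBackwardData
// into the deconv by setting the target dtype on its config, provided the
// deconv output has no reader other than the TypeCvt.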
  1956. void FuseDeconvCvtPass::apply(OptState& state) const {
  1957. MIDOUT_B("FuseDeconvCvtPass::apply")
  1958. std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
  1959. state.graph().iter([&m_deps](OperatorNodeBase* opr) {
  1960. for (auto& inp : opr->input()) {
  1961. m_deps[inp].push_back(opr);
  1962. }
  1963. });
  1964. UniqReaderCheck uniq_reader_check{state.graph()};
  1965. auto rewriter = state.graph().make_rewriter();
  1966. auto try_fuse_deconv_typecvt =
  1967. [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
  1968. mgb_assert(typecvt->input().size() == 1);
  1969. auto deconv = try_cast_as_op<opr::ConvolutionBackwardData>(
  1970. rewriter.get_var(typecvt->input(0))->owner_opr());
  1971. if (!deconv
  1972. || m_deps.count(typecvt->input(0)) != 1 ||
  1973. typecvt->output(0)->dtype().enumv() !=
  1974. DTypeTrait<dtype::QuantizedS8>::enumv) {
  1975. return nullptr;
  1976. }
  1977. if (!uniq_reader_check(deconv->output(0)))
  1978. return nullptr;
  1979. auto config = deconv->config();
  1980. config.output_dtype(typecvt->output(0)->dtype());
  1981. return opr::ConvolutionBackwardData::make(
  1982. deconv->input(0), deconv->input(1), deconv->param(),
  1983. deconv->execution_policy(), config)
  1984. .node()
  1985. ->owner_opr();
  1986. };
  1987. auto on_opr = [&](OperatorNodeBase* opr) {
  1988. if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
  1989. if (auto deconv_new = try_fuse_deconv_typecvt(typecvt)) {
  1990. rewriter.replace_var(
  1991. opr->output(0), deconv_new->output(0),
  1992. mgb_cstr_log("replace typecvt(deconv(x, w)) -> "
  1993. "deconv(x, w)"));
  1994. uniq_reader_check.update_on_opr_auto_replace(opr, deconv_new);
  1995. return;
  1996. }
  1997. }
  1998. auto new_opr = rewriter.auto_replace_outputs(opr);
  1999. uniq_reader_check.update_on_opr_auto_replace(
  2000. opr, new_opr);
  2001. };
  2002. state.graph().iter(on_opr);
  2003. rewriter.apply_inplace();
  2004. MIDOUT_E
  2005. }
  2006. /* ================ ParamMergePass ================ */
  2007. const char* ParamMergePass::name() const {
  2008. return mgb_cstr_log("param_merge");
  2009. }
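// ParamMergePass packs many SharedDeviceTensor (and the WithFormat variant)
// parameter operators into a single MultipleDeviceTensorHolder, leaving the
// graph with fewer parameter operators to manage.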
  2010. void ParamMergePass::apply(OptState& opt_state) const {
  2011. MIDOUT_B("ParamMergePass::apply")
  2012. param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
  2013. opt_state);
  2014. param_merge<opr::SharedDeviceTensorWithFormat,
  2015. opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
  2016. MIDOUT_E
  2017. }
  2018. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
