You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

framework.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. /**
  2. * \file src/gopt/impl/framework.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/gopt/framework.h"
  13. #include "megbrain/gopt/basic_arith.h"
  14. #include "megbrain/gopt/gtrans.h"
  15. #include "megbrain/gopt/inference.h"
  16. #include "megbrain/gopt/misc.h"
  17. #include "megbrain/graph/cg.h"
  18. #include "megbrain/graph/event.h"
  19. #include "megbrain/graph/exc_extra_info.h"
  20. #include "megbrain/serialization/opr_shallow_copy.h"
  21. #include "megbrain/serialization/serializer.h"
  22. #include "megbrain/utils/timer.h"
  23. #if MGB_JIT
  24. #include "megbrain/jit/fusion_pass.h"
  25. #endif
  26. #if MGB_ENABLE_TENSOR_RT
  27. #include "megbrain/tensorrt/opr_replace.h"
  28. #endif
  29. #include "megbrain/gopt/layout_transform_context.h"
  30. #include "megbrain/gopt/layout_transform_pass.h"
  31. #include "megbrain/gopt/profiler.h"
  32. #include "megbrain/gopt/solver.h"
  33. using namespace mgb;
  34. using namespace gopt;
  35. /* ================ SubGraph ================ */
//! Shallow-copy \p opr onto its rewritten inputs if any input var has been
//! replaced, recording the old->new output mapping so later readers see the
//! new vars. Returns the new opr, or \p opr itself when no input changed.
OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(OperatorNodeBase* opr) {
    // reuse a cached vector to avoid a per-call allocation
    auto&& new_inp = m_opr_new_inp_cache;
    new_inp.clear();
    new_inp.reserve(opr->input().size());
    bool has_replaced_inp = false;
    for (auto i : opr->input()) {
        auto new_var = get_var(i);
        if (new_var != i) {
            has_replaced_inp = true;
            new_inp.push_back(new_var);
        } else {
            new_inp.push_back(i);
        }
    }
    if (has_replaced_inp) {
        auto new_opr = serialization::copy_opr_shallow(*opr, new_inp, opr->config());
        auto &&out0 = opr->output(), &&out1 = new_opr->output();
        size_t i = 0;
        auto err_msg = [opr, new_opr] {
            return ssprintf(
                    "bad opr copy: src=%s{%s} dst=%s{%s}", opr->cname(),
                    opr->dyn_typeinfo()->name, new_opr->cname(),
                    new_opr->dyn_typeinfo()->name);
        };
        MGB_MARK_USED_VAR(err_msg);
        // opr output size mismatch may be caused by:
        // 0) inplace arith optimization (e.g. PowC need an extra workspace)
        // 1) other post-insert optimization (e.g. const folding)
        // we can't handle only usable_output here, since some output var with
        // volatile flag could be the graph's endpoint (e.g. RemoteSend)
        for (; i < std::min(out0.size(), out1.size()); ++i) {
            bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                 v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT);
            mgb_assert(v0 == v1, "%s", err_msg().c_str());
            // `true` marks an auto-replace entry; it may later be overwritten
            // by an explicit replace_var() (which stores `false`)
            auto&& ins = m_varmap.insert({out0[i], {true, nullptr}});
            mgb_assert(
                    ins.second || ins.first->second.first,
                    "opr output already replaced");
            // handle repeated call on the same opr
            ins.first->second.second = out1[i];
            on_var_replaced(out0[i], out1[i], nullptr);
        }
        // any outputs beyond the common prefix must be volatile
        // (e.g. workspace vars added or removed by post-insert optimization)
        for (; i < out0.size(); ++i) {
            mgb_assert(
                    out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), "%s",
                    err_msg().c_str());
        }
        for (; i < out1.size(); ++i) {
            mgb_assert(
                    out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), "%s",
                    err_msg().c_str());
        }
        return new_opr;
    }
    return opr;
}
//! Record an explicit replacement src -> dst. Throws InternalError if the
//! replacement would form a cycle in the replace map; asserts if src was
//! already explicitly replaced with a different var.
void SubGraph::Rewriter::replace_var(VarNode* src, VarNode* dst, const char* msg) {
    if (src == dst)
        return;
    // Optimizers should not create a loop in variable replace map.
    mgb_throw_if(
            get_var_internal(dst).second == src, InternalError,
            "dst %s maps back to src %s in SubGraph::Rewriter::replace_var",
            dst->cname(), src->cname());
    auto&& ins = m_varmap.insert({src, {false, dst}});
    if (!ins.second) {
        auto&& old_rep = ins.first->second;
        // an auto-replace record (first == true) may be overridden; an
        // explicit record may only be repeated with the same dst
        mgb_assert(
                old_rep.first || old_rep.second == dst, "can not replace a var twice");
        old_rep.first = false;
        old_rep.second = dst;
    }
    on_var_replaced(src, dst, msg);
}
  110. void SubGraph::Rewriter::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
  111. if (auto state = m_owner_graph->owner_opt_state()) {
  112. state->on_var_replaced(src, dst, msg);
  113. }
  114. }
  115. void SubGraph::Rewriter::apply_inplace() const {
  116. m_owner_graph->m_endpoint_oprs.clear();
  117. m_owner_graph->m_endpoint_vars_set.clear();
  118. for (auto&& var : m_owner_graph->m_endpoint_vars) {
  119. var = get_var(var.node());
  120. m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr());
  121. m_owner_graph->m_endpoint_vars_set.insert(var.node());
  122. }
  123. }
//! Resolve \p var through the replace map, returning {flag, final_var}. The
//! flag is the AND of the per-hop flags, so it stays `true` only while every
//! hop on the chain is an auto-replace record; visited entries are updated to
//! point directly at the final var (path compression).
std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) {
    // The implementation is (manually) unrolled once, background:
    // git-core/brain-sdk/MegBrain/merge_requests/486#note_76971
    auto it = m_varmap.find(var);
    if (it == m_varmap.end()) {
        // never replaced: identity mapping, flagged as auto
        return {true, var};
    }
    mgb_assert(it->second.second != var, "loop detected in m_varmap");
    auto it_next = m_varmap.find(it->second.second);
    if (it_next == m_varmap.end()) {
        return it->second;
    }
    mgb_assert(
            it_next->second.second != it->second.second, "loop detected in m_varmap");
    // recurse for longer chains, then compress both visited entries
    auto next = get_var_internal(it_next->second.second);
    it_next->second = {next.first & it_next->second.first, next.second};
    return it->second = {it_next->second.first & it->second.first, next.second};
}
  142. SubGraph::SubGraph(const SymbolVarArray& endpoint_vars)
  143. : m_endpoint_vars(endpoint_vars) {
  144. mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty");
  145. m_comp_graph = endpoint_vars[0].node()->owner_graph();
  146. for (auto i : endpoint_vars) {
  147. m_endpoint_oprs.insert(i.node()->owner_opr());
  148. m_endpoint_vars_set.insert(i.node());
  149. mgb_assert(
  150. m_comp_graph == i.node()->owner_graph(),
  151. "endpoints belong to different computing graphs");
  152. }
  153. }
//! Visit all oprs that the endpoints depend on, in topological order. When an
//! OptState is attached, the current-iteration properties (source opr,
//! priority, stream propagation type) are staged around each callback so that
//! oprs created inside \p cb inherit them via the OprInserted handler
//! registered in the OptState constructor.
void SubGraph::iter(const Callback& cb, std::shared_ptr<ExtraDep> extra_dep) const {
    Callback on_opr;
    if (m_owner_opt_state) {
        on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) {
            state->m_opr_property_flag = OprPropertyFlag::ALL;
            state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr);
            state->m_cur_iter_opr_priority = opr->node_prop().attribute().priority;
            state->m_cur_iter_opr_stream_prop_type =
                    state->m_comp_node_opt.stream_prop_type(opr->output(0));
            mgb_assert(state->m_oprs_inserted.empty());
            cb(opr);
            // reset so oprs created outside the callback are not attributed
            // to this iteration
            state->m_opr_property_flag = OprPropertyFlag::NONE;
            state->m_cur_iter_src_opr = nullptr;
            state->m_oprs_inserted.clear();
        };
    } else {
        on_opr = cb;
    }
    cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)};
    for (auto i : m_endpoint_oprs)
        dep_iter.add(i);
}
//! Count, for each non-volatile var in this sub graph, the number of oprs
//! that depend on its device value; endpoint vars get one extra count since
//! the graph output itself acts as a reader.
ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
    ThinHashMap<VarNode*, size_t> ret;
    auto cb = [&](OperatorNodeBase* opr) {
        // inputs are visited before their readers (topological order), so
        // every dep target is already present in `ret`
        for (auto&& i : opr->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) {
                ++ret.at(i.first);
            }
        }
        for (auto i : opr->output()) {
            if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto ins = ret.insert({i, 0});
                mgb_assert(ins.second);
            }
        }
    };
    iter(cb);
    for (auto i : m_endpoint_vars_set) {
        auto iter = ret.find(i);
        if (iter == ret.end()) {
            // volatile endpoints (e.g. RemoteSend) were skipped above
            mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
            ret[i] = 1;
        } else {
            ++ret.at(i);
        }
    }
    return ret;
}
  203. /* ================ UniqReaderCheck ================ */
//! Snapshot per-var device-value reader counts so uniqueness queries can be
//! answered without re-traversing the graph.
UniqReaderCheck::UniqReaderCheck(const SubGraph& graph)
        : m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} {}
//! Copy reader counts from the outputs of \p opr to the corresponding outputs
//! of \p repl_opr after an auto replacement, so uniqueness queries keep
//! working on the new vars.
void UniqReaderCheck::update_on_opr_auto_replace(
        OperatorNodeBase* opr, OperatorNodeBase* repl_opr) {
    auto non_volatile_size = [](const VarNodeArray& vars) -> size_t {
        size_t size = 0;
        for (size_t i = 0; i < vars.size(); ++i) {
            if (!vars[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                size++;
            }
        }
        return size;
    };
    if (opr != repl_opr) {
        auto &&o0 = opr->output(), &&o1 = repl_opr->output();
        mgb_assert(non_volatile_size(o0) == non_volatile_size(o1));
        // NOTE(review): o1[i] is indexed while looping over o0; this relies
        // on trailing extra outputs being volatile and hence absent from the
        // map (see auto_replace_outputs) -- confirm that invariant holds
        for (size_t i = 0; i < o0.size(); ++i) {
            auto iter = m_var2nr_val_dep.find(o0[i]);
            if (iter != m_var2nr_val_dep.end()) {
                auto n = iter->second;
                m_var2nr_val_dep[o1[i]] = n;
            }
        }
    }
}
  229. /* ================ OptState ================ */
//! Attach optimizer state to the sub graph: clears the per-graph var replace
//! map and registers an OprInserted handler that propagates the properties of
//! the opr currently being visited (source opr, priority, stream property;
//! see SubGraph::iter and call_with_opr) to newly created oprs.
OptState::OptState(const GraphOptimizer* owner_optimizer, const SubGraph& graph)
        : m_owner_optimizer{owner_optimizer},
          m_var_replace_map{const_cast<ThinHashMap<VarNode*, VarNode*>*>(
                  &GraphOptimizer::var_replace_map(*graph.comp_graph()))},
          m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
          m_graph{graph} {
    // only one OptState may own a SubGraph at a time
    mgb_assert(!m_graph.m_owner_opt_state);
    m_var_replace_map->clear();
    m_graph.m_owner_opt_state = this;
    m_oprs_inserted.clear();
    auto on_opr_insert = [this](const cg::event::OprInserted& ev) {
        auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY;
        if (need_src_opr)
            mgb_assert(
                    m_cur_iter_src_opr,
                    "opr %s{%s} created outside from "
                    "SubGraph::iter",
                    ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
        if (ev.exc || ev.is_dedup)
            return;
        auto&& new_attr = ev.opr->node_prop().attribute();
        // remember which properties were forced onto this opr, for the
        // consistency checks in on_var_replaced()
        auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
        mgb_assert(ins.second);
        if (need_src_opr && !new_attr.src_opr) {
            auto src_opr = m_cur_iter_src_opr;
            if (ev.opr != src_opr)
                new_attr.src_opr = src_opr;
            ins.first->second |= OprPropertyFlag::SOURCE_OPR;
        }
        if (need_priority) {
            new_attr.priority = m_cur_iter_opr_priority;
            if (!ev.opr->update_priority()) {
                ins.first->second |= OprPropertyFlag::PRIORITY;
            }
        }
        auto csp = m_cur_iter_opr_stream_prop_type;
        if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) {
            for (auto i : ev.opr->output())
                m_comp_node_opt.register_stream_var(i, csp);
        }
    };
    m_on_opr_insert_handler =
            graph.comp_graph()->event().register_receiver<cg::event::OprInserted>(
                    on_opr_insert);
}
//! Validate and record a var replacement:
//!  - volatile src vars only come from auto_replace_outputs(); sanity check
//!    and return
//!  - oprs inserted during this optimization get their forced properties
//!    (source opr, priority) cross-checked against the replaced var's opr
//!  - infer-type / dtype / shape compatibility is enforced according to
//!    m_var_replace_check_flag
//!  - finally the mapping is written to the per-graph var_replace_map and
//!    the replacement is logged (when logging is enabled)
void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
    if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
        // this can only happen in auto_replace_outputs()
        mgb_assert(
                dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
                src->owner_opr()->dyn_typeinfo() == dst->owner_opr()->dyn_typeinfo());
        mgb_assert(!msg);
        return;
    }
    //! check_property
    {
        auto iter = m_oprs_inserted.find(dst->owner_opr());
        if (iter != m_oprs_inserted.end()) {
            auto &&src_attr = src->owner_opr()->node_prop().attribute(),
                 &&dst_attr = dst->owner_opr()->node_prop().attribute();
            auto opr_info = [&](OperatorNodeBase* opr) {
                return opr ? opr->name() + "(" + std::to_string(opr->id()) + ")"
                           : "NULL";
            };
            auto err_msg = [&] {
                std::string ret = "Please contact Engine group:\n";
                ret += "src opr: ";
                ret += opr_info(src->owner_opr());
                ret += ", dst opr: ";
                ret += opr_info(dst->owner_opr());
                return ret;
            };
            MGB_MARK_USED_VAR(err_msg);
            if (iter->second & OprPropertyFlag::SOURCE_OPR) {
                auto &&src_rt = get_opr_root_source_opr(src->owner_opr()),
                     &&dst_rt = get_opr_root_source_opr(dst->owner_opr());
                mgb_assert(
                        dst_rt == src_rt,
                        "%s\nsrc source_opr: %s, dst source_opr: %s\n",
                        err_msg().c_str(), opr_info(src_rt).c_str(),
                        opr_info(dst_rt).c_str());
            }
            if (iter->second & OprPropertyFlag::PRIORITY) {
                mgb_assert(
                        src_attr.priority == dst_attr.priority,
                        "%s\nsrc priority: %d, dst priority %d\n", err_msg().c_str(),
                        src_attr.priority, dst_attr.priority);
            }
        }
    }
    {
        bool suc = true;
        SmallVector<std::string> fail_chks;
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_INFER_TYPE) {
            auto&& mgr = src->owner_graph()->static_infer_manager();
            auto it0 = mgr.get_infer_type(src), it1 = mgr.get_infer_type(dst);
            using cg::static_infer::InferType;
            // only check whether inferable
            auto norm = [](InferType::Flag f) -> bool {
                return f & (InferType::RT_STATIC | InferType::CONST);
            };
            // shape inferability must be preserved exactly; value
            // inferability may only improve (false <= true)
            if (!(norm(it0.shape) == norm(it1.shape) &&
                  norm(it0.value) <= norm(it1.value))) {
                suc = false;
                fail_chks.push_back("infer-type");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_DTYPE) {
            if (src->dtype() != dst->dtype()) {
                suc = false;
                fail_chks.push_back("dtype");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_SHAPE) {
            if (!(src->shape().eq_shape(dst->shape()))) {
                suc = false;
                fail_chks.push_back("shape");
            }
        }
        if (!suc) {
            std::string fail_msg = "{";
            for (size_t i = 0; i < fail_chks.size(); i++) {
                fail_msg += fail_chks[i];
                if (i < fail_chks.size() - 1) {
                    fail_msg += ",";
                }
            }
            fail_msg += "}";
            mgb_throw_raw(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{src->owner_opr()}
                            .make<InternalError>(ssprintf(
                                    "%s mismatch for replace_var: %s", fail_msg.c_str(),
                                    cg::dump_var_info({src, dst}).c_str())));
        }
    }
    // keep a user-assigned name on the replacement var
    if (src->has_name_set() && !dst->has_name_set()) {
        dst->name(src->name());
    }
    (*m_var_replace_map)[src] = dst;
    // dst should be considered as newly inserted, and previous replace
    // record should be ignored
    m_var_replace_map->erase(dst);
#if MGB_ENABLE_LOGGING
    if (msg && m_owner_optimizer->verbosity()) {
        m_log_msg.append("\n ")
                .append(std::to_string(m_log_nr_item))
                .append(": ")
                .append(src->owner_opr()->cname())
                .append(" => ")
                .append(dst->owner_opr()->cname())
                .append(" (")
                .append(msg)
                .append(")");
    }
    ++m_log_nr_item;
#endif
}
  388. size_t OptState::flush_log(const char* title) {
  389. if (m_owner_optimizer->verbosity() >= 2) {
  390. if (m_log_msg.empty()) {
  391. m_log_msg = mgb_cstr_log(" no var replacement logged");
  392. }
  393. mgb_log("%s%s", title, m_log_msg.c_str());
  394. m_log_msg.clear();
  395. }
  396. auto ret = m_log_nr_item;
  397. m_log_nr_item = 0;
  398. return ret;
  399. }
//! Run \p func as if executing inside SubGraph::iter on \p opr: temporarily
//! installs the opr's source opr / priority / stream property as the
//! current-iteration values so the OprInserted handler attributes newly
//! created oprs to \p opr. State is restored even if \p func throws.
void OptState::call_with_opr(
        OperatorNodeBase* opr, thin_function<void(void)> func,
        OprPropertyFlag opr_property_flag) {
    auto src_opr = cg::get_opr_root_source_opr(opr);
    auto opr_priority = opr->node_prop().attribute().priority;
    auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0));
    ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted;
    // swapping twice restores the original state: once before func(), once
    // in the finally block below
    auto swap_properties =
            [&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
                if (need_src_opr) {
                    std::swap(m_cur_iter_src_opr, src_opr);
                }
                if (need_priority) {
                    std::swap(m_cur_iter_opr_priority, opr_priority);
                }
                std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
                std::swap(m_opr_property_flag, opr_property_flag);
                std::swap(m_oprs_inserted, oprs_inserted);
            };
    MGB_TRY {
        swap_properties();
        func();
    }
    MGB_FINALLY({ swap_properties(); });
}
  426. /* ================ RecursiveSubGraphRewriteHelper ================ */
  427. RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept = default;
  428. RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state)
  429. : m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} {}
  430. void RecursiveSubGraphRewriteHelper::apply() {
  431. using namespace std::placeholders;
  432. m_opt_state.graph().iter(
  433. std::bind(&RecursiveSubGraphRewriteHelper::on_opr, this, _1));
  434. m_rewriter.apply_inplace();
  435. }
//! Visit one opr during apply(): auto-replace its inputs, then repeatedly try
//! to transform the (single-output) result via process_opr(), using an
//! explicit stack to recurse into the internal vars produced by each
//! transformation before revisiting the transformed result.
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) {
    auto on_new_opr = [this](OperatorNodeBase* opr) {
        auto repl_opr = m_rewriter.auto_replace_outputs(opr);
        return on_new_opr_check_should_process(opr, repl_opr);
    };
    if (!on_new_opr(opr))
        return;
    // only single-output oprs participate in the recursive rewrite
    auto orig_out = get_opr_single_output_var(opr);
    if (!orig_out)
        return;
    mgb_assert(m_opr_stack.empty());
    m_opr_stack.push_back({orig_out, m_rewriter.get_var(orig_out)->owner_opr()});
    bool first = true;
    while (!m_opr_stack.empty()) {
        auto cur_frame = m_opr_stack.back();
        m_opr_stack.pop_back();
        auto cur_opr = cur_frame.opr;
        bool should_process;
        if (first) {
            // the root opr already passed on_new_opr() above
            should_process = true;
            first = false;
        } else {
            should_process = on_new_opr(cur_opr);
        }
        auto cur_out = get_opr_single_output_var(cur_opr);
        mgb_assert(cur_out);
        cur_out = m_rewriter.get_var(cur_out);
        if (should_process) {
            auto trans = process_opr(cur_out);
            if (trans.valid()) {
                // revisit the transformed result for the same original var;
                // internal vars are pushed on top so they are handled first
                m_opr_stack.push_back({cur_frame.orig_var, trans->result->owner_opr()});
                for (auto i : reverse_adaptor(trans->internal)) {
                    if (i)
                        m_opr_stack.push_back({i, i->owner_opr()});
                }
                if (trans->msg) {
                    if (!m_log_msg.empty())
                        m_log_msg.push_back(';');
                    m_log_msg.append(trans->msg);
                }
                continue;
            }
        }
        // no further transformation applies: commit the replacement
        auto src = cur_frame.orig_var;
        if (m_rewriter.get_var(src) != cur_out) {
            const char* msg = nullptr;
            // only attach the accumulated log message on the outermost frame
            if (m_opr_stack.empty()) {
                msg = m_log_msg.c_str();
            }
            m_rewriter.replace_var(src, cur_out, msg);
            after_replace_var(src, cur_out);
            if (m_opr_stack.empty()) {
                m_log_msg.clear();
                break;
            }
        }
    }
}
  494. /* ================ GraphOptimizer ================ */
GraphOptimizer::~GraphOptimizer() noexcept = default;

//! Holds the per-graph var replacement map as user data attached to the
//! computing graph options, so its lifetime follows the graph (see
//! var_replace_map()).
class GraphOptimizer::VarReplaceMapStorage : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

public:
    ThinHashMap<VarNode*, VarNode*> map;
};
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage);
  502. GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) {
  503. mgb_assert(!pass->m_owner_optimizer);
  504. pass->m_owner_optimizer = this;
  505. m_passes.emplace_back(std::move(pass));
  506. return *this;
  507. }
//! Run every registered pass over \p graph and return the rewritten sub
//! graph. graph_opt_level is forced to 1 while each pass runs (so graph
//! construction inside a pass is not recursively re-optimized) and restored
//! afterwards, including on exception.
SubGraph GraphOptimizer::apply(const SubGraph& graph) const {
    RealTimer timer;
    OptState state{this, graph};
    size_t tot_nr_replace = 0;
    // first update output var shapes of all oprs
    state.graph().iter(cg::update_output_var_shapes);
    auto&& opt = graph.comp_graph()->options();
    auto orig_setting = opt.graph_opt_level;
    Pass* cur_pass = nullptr;
    MGB_MARK_USED_VAR(cur_pass);
    MGB_TRY {
        for (auto&& i : m_passes) {
            // each pass starts with the strictest replacement checks; a pass
            // may relax them via the state
            state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL);
            cur_pass = i.get();
            opt.graph_opt_level = 1;
            i->apply(state);
            tot_nr_replace += state.flush_log(
                    mgb_ssprintf_log("apply optimization pass %s:", i->name()).c_str());
        }
    }
    MGB_CATCH(std::exception & exc, {
        mgb_log_error(
                "error while applying optimization pass %s: %s", cur_pass->name(),
                exc.what());
        opt.graph_opt_level = orig_setting;
        throw;
    })
    MGB_FINALLY(opt.graph_opt_level = orig_setting);
    if (verbosity() >= 1) {
        mgb_log_debug(
                "graph optimization: applied %zu passes, "
                "total %zu var(s) replaced; time=%.2fms",
                m_passes.size(), tot_nr_replace, timer.get_msecs());
    }
    return state.graph();
}
  544. const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const {
  545. if (m_passes.empty()) {
  546. // this check is necessary, since OptState would clear
  547. // var_replace_map()
  548. return *this;
  549. }
  550. auto g = apply({{vars.begin(), vars.end()}});
  551. for (size_t i = 0; i < vars.size(); ++i) {
  552. vars[i] = g.endpoint_vars()[i].node();
  553. }
  554. return *this;
  555. }
//! Register the preset pass pipeline. \p after_grad enables passes that are
//! only valid after gradient expansion; \p inference_opt additionally enables
//! inference-only transforms (param fusion/merging, format conversion, ...);
//! \p comp_graph_opt supplies JIT-related settings. Pass order is significant.
GraphOptimizer& GraphOptimizer::add_preset_passes(
        bool after_grad, const OptimizeForInferenceOptions* inference_opt,
        const ComputingGraph::Options* comp_graph_opt) {
    // for inference, params are also treated as const vars
    auto cv_type =
            inference_opt ? ConstVarType::IMMUTABLE_AND_PARAM : ConstVarType::IMMUTABLE;
    if (inference_opt) {
        add_pass<ConvertBatchNormToElemwisePass>();
    }
    if (!after_grad || inference_opt) {
        add_pass<CondExecConstPredicateFolding>();
    }
    if (after_grad || inference_opt) {
        add_pass<RemoveNonComputingOprPass>();
    }
    add_pass<DelayBroadcastPass>();
    add_pass<ExpandFusedArithPass>();
    add_pass<NormalizeArithChainPass>();
    if (inference_opt) {
        add_pass<ParamRedistributePass>();
        add_pass<ParamFusePass>();
    }
    add_pass<ArithMulDistributePass>();
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<ArithFusePass>();
    // reorder again because shapes of fused oprs might change
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<FinalArithTransformPass>();
    add_pass<RemoveRedundantTypeCvtPass>();
    add_pass<RemoveRedundantCopyPass>();
#if MGB_JIT
    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
    int jit_opt_level = 0;
    JITConfig jit_config;
    // for more detail on what is happening here, see comments on the
    // constructor of class JITFusionPass in fusion_pass.h
    if (comp_graph_opt) {
        jit_opt_level = comp_graph_opt->graph_opt.jit;
        if (comp_graph_opt->graph_opt_level >= 3) {
            jit_opt_level = std::max(jit_opt_level, 1);
        }
        jit_config = comp_graph_opt->graph_opt.jit_config;
    }
    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
    if (need_jit && after_grad) {
        add_pass<gopt::RecompTypeCvtPass>();
    }
#endif
    // combine astype and reduce.
    // Note: apply this pass before JITFusion, so the TypeCvt which
    // read by both Reduce and Elemwise could be fused correctly.
    add_pass<CombineAstypeAndReducePass>();
#if MGB_JIT
    if (need_jit) {
        add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
    }
#endif
    if (inference_opt) {
        add_pass<ParamFusePass>();
        add_passes_for_optimize_options(*inference_opt);
    }
    if (inference_opt) {
        // merge params to reduce loading time and graph overhead
        add_pass<ParamMergePass>();
        add_pass<FuseDeconvCvtPass>();
    }
    if (inference_opt) {
        // remove shape hint after inference optimization
        add_pass<RemoveShapeHintPass>();
    }
    return *this;
}
  627. const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map(
  628. ComputingGraph& graph) {
  629. auto storage =
  630. graph.options().user_data.get_user_data_or_create<VarReplaceMapStorage>();
  631. return storage->map;
  632. }
  633. VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) {
  634. auto&& map = var_replace_map(*(var->owner_graph()));
  635. for (;;) {
  636. auto iter = map.find(var);
  637. if (iter == map.end())
  638. return var;
  639. var = iter->second;
  640. }
  641. }
//! Overload for callers holding a const options object; delegates to the
//! non-const overload with `reset` left at its default, which is expected
//! not to modify `options` (making the const_cast safe -- presumably the
//! default is false; confirm against the header declaration).
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        const cg::GraphCommonOptimizeOptions& options) {
    return add_passes_for_optimize_options(
            const_cast<cg::GraphCommonOptimizeOptions&>(options));
}
//! Register passes for every optimize option that has been set on \p options.
//! When \p reset is true each handled option is cleared so callers can detect
//! unconsumed options. A final ParamFusePass is appended if any option fired.
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        cg::GraphCommonOptimizeOptions& options, bool reset) {
    bool need_param_fuse = false;
// expands to: if the option is set, run its passes, mark that param fusion
// is needed, and optionally clear the option flag
#define cb(_option, _passes)             \
    if (options.has_set_##_option()) {   \
        _passes need_param_fuse = true;  \
        if (reset) {                     \
            options.disable_##_option(); \
        }                                \
    }
    cb(fuse_preprocess, {
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
    });
    cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
    cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
    cb(nchw4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nhwcd4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(ConvertFormatPass::make_nhwcd4_converter());
    });
    cb(nchw88, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44_dot, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchw44DotPass::make_nchw44_dot_converter());
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw32, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableTensorCorePass::make_tensorcore_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });
    cb(chwn4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableCHWN4Pass::make_chwn4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nchw64, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<PaddingChannelPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW64Pass::make_nchw64_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });
    cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
    cb(fuse_conv_bias_with_z, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
    }
    return *this;
}
//! Register passes for the graph-tuning options (currently the global layout
//! transform); a final ParamFusePass/ParamMergePass pair is appended if any
//! option fired.
const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
        const GraphTuningOptions& options) {
    bool need_param_fuse = false;
// expands to: if the option is set, run its passes and mark that param
// fusion is needed
#define cb(_options, _passes)           \
    if (options.has_set_##_options()) { \
        _passes need_param_fuse = true; \
    }
    using Target = GraphTuningOptions::Target;
    cb(layout_transform, {
        add_pass<FuseConvBiasNonlinPass>();
        if (options.target == Target::CUDA)
            add_pass<FuseConvBiasZPass>();
        add_pass(LayoutTransformPass::make(options.target));
        add_pass<ShuffleShuffleRemovePass>();
        if (options.target == Target::CUDA) {
            add_pass(FuseNCHW4Int8Preprocess::make());
            add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
            add_pass<FoldingConvBiasDimshufflePass>();
            add_pass<FoldingConvBiasTypecvtPass>();
#endif
        }
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
        add_pass<ParamMergePass>();
    }
    return *this;
}
  764. /* ================ ConstVarPropogateBase ================ */
//! Propagate const-ness to \p opr: it is considered const if it is itself a
//! const var (per m_const_var_type), or if all inputs are const and the opr
//! has no side effects. Results are memoized in m_oprinfo; the returned flags
//! describe the const-ness of the opr's inputs.
ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(OperatorNodeBase* opr) {
    using ProfFlag = OperatorNodeBase::NodeProp::Flag;
    auto&& info = m_oprinfo[opr];
    if (info.processed)
        return info.result;
    info.processed = true;
#if MGB_ENABLE_JSON
    (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(false);
#endif
    AddOprResult ret{false, false, false};
    auto make_ret = [&ret, &info]() {
        info.result = ret;
        return ret;
    };
    if (is_const_var(m_const_var_type, opr)) {
        auto sz = var_mem_size(opr->output(0));
        mgb_assert(
                sz || opr->output(0)->contain_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE));
        info.is_const = true;
        info.max_size = sz;
        return make_ret();
    }
    if (opr->input().empty())
        return make_ret();
    // oprs that mutate their inputs or are impure can never be folded
    if (opr->node_prop().contain(
                ProfFlag::FORCE_UPDATE_INPUT_VAR | ProfFlag::IMPURE_FUNC)) {
        return make_ret();
    }
    size_t max_input_size = 0;
    ret.all_const_inp = true;
    for (auto i : opr->input()) {
        auto io = i->owner_opr();
        auto iter = m_oprinfo.find(io);
        if (iter == m_oprinfo.end()) {
            // process dependencies on demand (recursive)
            add_opr(io);
            iter = m_oprinfo.find(io);
            mgb_assert(iter != m_oprinfo.end());
        }
        auto&& src = iter->second;
        if (src.is_const) {
            update_max(max_input_size, src.max_size);
            ret.has_const_inp = true;
            // "midconst": const by propagation but not a const var itself
            if (!is_const_var(m_const_var_type, i->owner_opr())) {
                ret.has_midconst_inp = true;
            }
        } else {
            ret.all_const_inp = false;
        }
    }
    if (ret.all_const_inp) {
#if MGB_ENABLE_JSON
        (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(true);
#endif
        info.max_size = max_input_size;
        info.is_const = true;
    }
    return make_ret();
}
  823. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台