You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

framework.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. /**
  2. * \file src/gopt/impl/framework.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/gopt/framework.h"
  13. #include "megbrain/gopt/basic_arith.h"
  14. #include "megbrain/gopt/gtrans.h"
  15. #include "megbrain/gopt/inference.h"
  16. #include "megbrain/gopt/misc.h"
  17. #include "megbrain/graph/cg.h"
  18. #include "megbrain/graph/event.h"
  19. #include "megbrain/graph/exc_extra_info.h"
  20. #include "megbrain/serialization/opr_shallow_copy.h"
  21. #include "megbrain/serialization/serializer.h"
  22. #include "megbrain/utils/timer.h"
  23. #if MGB_JIT
  24. #include "megbrain/jit/fusion_pass.h"
  25. #endif
  26. #if MGB_ENABLE_TENSOR_RT
  27. #include "megbrain/tensorrt/opr_replace.h"
  28. #endif
  29. #include "megbrain/gopt/layout_transform_context.h"
  30. #include "megbrain/gopt/layout_transform_pass.h"
  31. #include "megbrain/gopt/profiler.h"
  32. #include "megbrain/gopt/solver.h"
  33. using namespace mgb;
  34. using namespace gopt;
  35. /* ================ SubGraph ================ */
/*!
 * \brief make a shallow copy of \p opr with inputs remapped through the
 *      current replace map, if any input has been replaced.
 *
 * If no input of \p opr has a replacement, \p opr itself is returned
 * unchanged. Otherwise the copy's outputs are recorded in m_varmap as
 * "auto" replacements (flag == true) for the original outputs, and
 * on_var_replaced() is fired for each pair.
 *
 * \return the opr whose outputs should be used from now on (either \p opr
 *      or the shallow copy)
 */
OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
        OperatorNodeBase* opr) {
    // reuse a member buffer for the rebuilt input list to avoid a heap
    // allocation on every call
    auto&& new_inp = m_opr_new_inp_cache;
    new_inp.clear();
    new_inp.reserve(opr->input().size());
    bool has_replaced_inp = false;
    for (auto i : opr->input()) {
        auto new_var = get_var(i);
        if (new_var != i) {
            has_replaced_inp = true;
            new_inp.push_back(new_var);
        } else {
            new_inp.push_back(i);
        }
    }
    if (has_replaced_inp) {
        auto new_opr =
                serialization::copy_opr_shallow(*opr, new_inp, opr->config());
        auto &&out0 = opr->output(), &&out1 = new_opr->output();
        size_t i = 0;
        auto err_msg = [opr, new_opr] {
            return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}", opr->cname(),
                            opr->dyn_typeinfo()->name, new_opr->cname(),
                            new_opr->dyn_typeinfo()->name);
        };
        MGB_MARK_USED_VAR(err_msg);
        // opr output size mismatch may be caused by:
        // 0) inplace arith optimization (e.g. PowC need an extra workspace)
        // 1) other post-insert optimization (e.g. const folding)
        // we can't handle only usable_output here, since some output var with
        // volatile flag could be the graph's endpoint (e.g. RemoteSend)
        for (; i < std::min(out0.size(), out1.size()); ++i) {
            // volatile-ness of corresponding outputs must agree, otherwise
            // the shallow copy changed the opr's output layout
            bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                 v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT);
            mgb_assert(v0 == v1, "%s", err_msg().c_str());
            // first == true marks this entry as an auto replacement, which
            // replace_var() is later allowed to overwrite
            auto&& ins = m_varmap.insert({out0[i], {true, nullptr}});
            mgb_assert(ins.second || ins.first->second.first,
                       "opr output already replaced");
            // handle repeated call on the same opr
            ins.first->second.second = out1[i];
            on_var_replaced(out0[i], out1[i], nullptr);
        }
        // any trailing outputs present on only one side must be volatile
        // (e.g. an extra workspace var)
        for (; i < out0.size(); ++i) {
            mgb_assert(out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                       "%s", err_msg().c_str());
        }
        for (; i < out1.size(); ++i) {
            mgb_assert(out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                       "%s", err_msg().c_str());
        }
        return new_opr;
    }
    return opr;
}
/*!
 * \brief record an explicit replacement so that future reads of \p src
 *      resolve to \p dst; \p msg is an optional human-readable reason that
 *      is forwarded to the logging machinery
 */
void SubGraph::Rewriter::replace_var(VarNode* src, VarNode* dst,
                                     const char* msg) {
    if (src == dst)
        return;
    // Optimizers should not create a loop in variable replace map.
    mgb_throw_if(
            get_var_internal(dst).second == src, InternalError,
            "dst %s maps back to src %s in SubGraph::Rewriter::replace_var",
            dst->cname(), src->cname());
    // first == false marks an explicit (non-auto) replacement
    auto&& ins = m_varmap.insert({src, {false, dst}});
    if (!ins.second) {
        auto&& old_rep = ins.first->second;
        // an existing entry may only be overwritten if it came from
        // auto_replace_outputs() (first == true) or targets the same dst
        mgb_assert(old_rep.first || old_rep.second == dst,
                   "can not replace a var twice");
        old_rep.first = false;
        old_rep.second = dst;
    }
    on_var_replaced(src, dst, msg);
}
  109. void SubGraph::Rewriter::on_var_replaced(VarNode* src, VarNode* dst,
  110. const char* msg) {
  111. if (auto state = m_owner_graph->owner_opt_state()) {
  112. state->on_var_replaced(src, dst, msg);
  113. }
  114. }
  115. void SubGraph::Rewriter::apply_inplace() const {
  116. m_owner_graph->m_endpoint_oprs.clear();
  117. m_owner_graph->m_endpoint_vars_set.clear();
  118. for (auto&& var : m_owner_graph->m_endpoint_vars) {
  119. var = get_var(var.node());
  120. m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr());
  121. m_owner_graph->m_endpoint_vars_set.insert(var.node());
  122. }
  123. }
/*!
 * \brief resolve the final replacement target of \p var, compressing the
 *      chain in m_varmap along the way
 *
 * \return {flag, target} where target is the fully-resolved var and flag is
 *      the AND of the "auto replacement" flags on the path ({true, var} if
 *      \p var has no replacement at all)
 */
std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) {
    // The implementation is (manually) unrolled once, background:
    // git-core/brain-sdk/MegBrain/merge_requests/486#note_76971
    auto it = m_varmap.find(var);
    if (it == m_varmap.end()) {
        // no replacement recorded
        return {true, var};
    }
    mgb_assert(it->second.second != var, "loop detected in m_varmap");
    auto it_next = m_varmap.find(it->second.second);
    if (it_next == m_varmap.end()) {
        // chain of length one: already fully resolved
        return it->second;
    }
    mgb_assert(it_next->second.second != it->second.second,
               "loop detected in m_varmap");
    // recurse on the rest of the chain, then path-compress both visited
    // entries so subsequent lookups are O(1)
    auto next = get_var_internal(it_next->second.second);
    it_next->second = {next.first & it_next->second.first, next.second};
    return it->second = {it_next->second.first & it->second.first, next.second};
}
  142. SubGraph::SubGraph(const SymbolVarArray& endpoint_vars)
  143. : m_endpoint_vars(endpoint_vars) {
  144. mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty");
  145. m_comp_graph = endpoint_vars[0].node()->owner_graph();
  146. for (auto i : endpoint_vars) {
  147. m_endpoint_oprs.insert(i.node()->owner_opr());
  148. m_endpoint_vars_set.insert(i.node());
  149. mgb_assert(m_comp_graph == i.node()->owner_graph(),
  150. "endpoints belong to different computing graphs");
  151. }
  152. }
/*!
 * \brief visit every opr reachable from the endpoints in dependency order
 *      (inputs before consumers, via cg::DepOprIter), invoking \p cb on each
 *
 * \param extra_dep optional extra dependency edges honored by the iterator
 */
void SubGraph::iter(const Callback& cb,
                    std::shared_ptr<ExtraDep> extra_dep) const {
    Callback on_opr;
    if (m_owner_opt_state) {
        // when driven by an optimizer, publish the current opr's properties
        // into the OptState before calling back, so that oprs created inside
        // the callback inherit source-opr / priority / stream propagation
        // (consumed by the OprInserted handler registered in OptState's ctor)
        on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) {
            state->m_opr_property_flag = OprPropertyFlag::ALL;
            state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr);
            state->m_cur_iter_opr_priority =
                    opr->node_prop().attribute().priority;
            state->m_cur_iter_opr_stream_prop_type =
                    state->m_comp_node_opt.stream_prop_type(opr->output(0));
            mgb_assert(state->m_oprs_inserted.empty());
            cb(opr);
            // clear the published state after the callback returns
            state->m_opr_property_flag = OprPropertyFlag::NONE;
            state->m_cur_iter_src_opr = nullptr;
            state->m_oprs_inserted.clear();
        };
    } else {
        on_opr = cb;
    }
    cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)};
    for (auto i : m_endpoint_oprs)
        dep_iter.add(i);
}
  177. ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
  178. ThinHashMap<VarNode*, size_t> ret;
  179. auto cb = [&](OperatorNodeBase* opr) {
  180. for (auto&& i : opr->node_prop().dep_map()) {
  181. if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) {
  182. ++ret.at(i.first);
  183. }
  184. }
  185. for (auto i : opr->output()) {
  186. if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  187. auto ins = ret.insert({i, 0});
  188. mgb_assert(ins.second);
  189. }
  190. }
  191. };
  192. iter(cb);
  193. for (auto i : m_endpoint_vars_set) {
  194. auto iter = ret.find(i);
  195. if (iter == ret.end()) {
  196. mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
  197. ret[i] = 1;
  198. } else {
  199. ++ret.at(i);
  200. }
  201. }
  202. return ret;
  203. }
  204. /* ================ UniqReaderCheck ================ */
//! snapshot the per-var reader counts of \p graph for uniqueness queries
UniqReaderCheck::UniqReaderCheck(const SubGraph& graph)
        : m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} {}
  207. void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr,
  208. OperatorNodeBase* repl_opr) {
  209. auto non_volatile_size = [](const VarNodeArray& vars) -> size_t {
  210. size_t size = 0;
  211. for (size_t i = 0; i < vars.size(); ++i) {
  212. if (!vars[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  213. size++;
  214. }
  215. }
  216. return size;
  217. };
  218. if (opr != repl_opr) {
  219. auto &&o0 = opr->output(), &&o1 = repl_opr->output();
  220. mgb_assert(non_volatile_size(o0) == non_volatile_size(o1));
  221. for (size_t i = 0; i < o0.size(); ++i) {
  222. auto iter = m_var2nr_val_dep.find(o0[i]);
  223. if (iter != m_var2nr_val_dep.end()) {
  224. auto n = iter->second;
  225. m_var2nr_val_dep[o1[i]] = n;
  226. }
  227. }
  228. }
  229. }
  230. /* ================ OptState ================ */
/*!
 * \brief bind this optimization state to \p graph and install an
 *      OprInserted handler that propagates source opr, priority and stream
 *      properties (published by SubGraph::iter / call_with_opr) onto newly
 *      created oprs
 *
 * Also clears the graph-level var_replace_map, which this state writes to.
 */
OptState::OptState(const GraphOptimizer* owner_optimizer, const SubGraph& graph)
        : m_owner_optimizer{owner_optimizer},
          m_var_replace_map{const_cast<ThinHashMap<VarNode*, VarNode*>*>(
                  &GraphOptimizer::var_replace_map(*graph.comp_graph()))},
          m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
          m_graph{graph} {
    // a subgraph can be owned by at most one OptState at a time
    mgb_assert(!m_graph.m_owner_opt_state);
    m_var_replace_map->clear();
    m_graph.m_owner_opt_state = this;
    m_oprs_inserted.clear();
    auto on_opr_insert = [this](const cg::event::OprInserted& ev) {
        auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY;
        if (need_src_opr)
            mgb_assert(m_cur_iter_src_opr,
                       "opr %s{%s} created outside from "
                       "SubGraph::iter",
                       ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
        // ignore failed or deduplicated insertions
        if (ev.exc || ev.is_dedup)
            return;
        auto&& new_attr = ev.opr->node_prop().attribute();
        // remember which properties we assigned, so on_var_replaced() can
        // later verify them against the replaced opr
        auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
        mgb_assert(ins.second);
        if (need_src_opr && !new_attr.src_opr) {
            auto src_opr = m_cur_iter_src_opr;
            // an opr must not be its own source
            if (ev.opr != src_opr)
                new_attr.src_opr = src_opr;
            ins.first->second |= OprPropertyFlag::SOURCE_OPR;
        }
        if (need_priority) {
            new_attr.priority = m_cur_iter_opr_priority;
            if (!ev.opr->update_priority()) {
                ins.first->second |= OprPropertyFlag::PRIORITY;
            }
        }
        // propagate stream property (e.g. copy stream) to all outputs
        auto csp = m_cur_iter_opr_stream_prop_type;
        if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) {
            for (auto i : ev.opr->output())
                m_comp_node_opt.register_stream_var(i, csp);
        }
    };
    m_on_opr_insert_handler =
            graph.comp_graph()
                    ->event()
                    .register_receiver<cg::event::OprInserted>(on_opr_insert);
}
/*!
 * \brief bookkeeping and validation after \p src has been replaced by
 *      \p dst: verify opr properties and (optionally) infer-type / dtype /
 *      shape consistency, then record the replacement in the graph-level
 *      var_replace_map and the optimization log
 *
 * \param msg optional human-readable reason; must be null for volatile vars
 */
void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
    if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
        // this can only happen in auto_replace_outputs()
        mgb_assert(dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
                   src->owner_opr()->dyn_typeinfo() ==
                           dst->owner_opr()->dyn_typeinfo());
        mgb_assert(!msg);
        return;
    }
    //! check_property
    {
        // if dst's opr was inserted during this optimization, verify that
        // the properties assigned at insertion time (recorded by the
        // OprInserted handler) agree with those of src's opr
        auto iter = m_oprs_inserted.find(dst->owner_opr());
        if (iter != m_oprs_inserted.end()) {
            auto &&src_attr = src->owner_opr()->node_prop().attribute(),
                 &&dst_attr = dst->owner_opr()->node_prop().attribute();
            auto opr_info = [&](OperatorNodeBase* opr) {
                return opr ? opr->name() + "(" + std::to_string(opr->id()) + ")"
                           : "NULL";
            };
            auto err_msg = [&] {
                std::string ret = "Please contact Engine group:\n";
                ret += "src opr: ";
                ret += opr_info(src->owner_opr());
                ret += ", dst opr: ";
                ret += opr_info(dst->owner_opr());
                return ret;
            };
            MGB_MARK_USED_VAR(err_msg);
            if (iter->second & OprPropertyFlag::SOURCE_OPR) {
                // both oprs must trace back to the same root source opr
                auto &&src_rt = get_opr_root_source_opr(src->owner_opr()),
                     &&dst_rt = get_opr_root_source_opr(dst->owner_opr());
                mgb_assert(dst_rt == src_rt,
                           "%s\nsrc source_opr: %s, dst source_opr: %s\n",
                           err_msg().c_str(), opr_info(src_rt).c_str(),
                           opr_info(dst_rt).c_str());
            }
            if (iter->second & OprPropertyFlag::PRIORITY) {
                mgb_assert(src_attr.priority == dst_attr.priority,
                           "%s\nsrc priority: %d, dst priority %d\n",
                           err_msg().c_str(), src_attr.priority,
                           dst_attr.priority);
            }
        }
    }
    {
        // optional src/dst consistency checks selected by
        // m_var_replace_check_flag; all failures are collected before
        // throwing so the message lists every mismatched aspect
        bool suc = true;
        SmallVector<std::string> fail_chks;
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_INFER_TYPE) {
            auto&& mgr = src->owner_graph()->static_infer_manager();
            auto it0 = mgr.get_infer_type(src), it1 = mgr.get_infer_type(dst);
            using cg::static_infer::InferType;
            // only check whether inferable
            auto norm = [](InferType::Flag f) -> bool {
                return f & (InferType::RT_STATIC | InferType::CONST);
            };
            // dst's value inferability must not be weaker than src's
            if (!(norm(it0.shape) == norm(it1.shape) &&
                  norm(it0.value) <= norm(it1.value))) {
                suc = false;
                fail_chks.push_back("infer-type");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_DTYPE) {
            if (src->dtype() != dst->dtype()) {
                suc = false;
                fail_chks.push_back("dtype");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_SHAPE) {
            if (!(src->shape().eq_shape(dst->shape()))) {
                suc = false;
                fail_chks.push_back("shape");
            }
        }
        if (!suc) {
            std::string fail_msg = "{";
            for (size_t i = 0; i < fail_chks.size(); i++) {
                fail_msg += fail_chks[i];
                if (i < fail_chks.size() - 1) {
                    fail_msg += ",";
                }
            }
            fail_msg += "}";
            // attach src's opr info so the error points at the culprit
            mgb_throw_raw(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{src->owner_opr()}
                            .make<InternalError>(ssprintf(
                                    "%s mismatch for replace_var: %s",
                                    fail_msg.c_str(),
                                    cg::dump_var_info({src, dst}).c_str())));
        }
    }
    // keep a user-assigned name on the replacement var
    if (src->has_name_set() && !dst->has_name_set()) {
        dst->name(src->name());
    }
    (*m_var_replace_map)[src] = dst;
    // dst should be considered as newly inserted, and previous replace
    // record should be ignored
    m_var_replace_map->erase(dst);
#if MGB_ENABLE_LOGGING
    if (msg && m_owner_optimizer->verbosity()) {
        m_log_msg.append("\n ")
                .append(std::to_string(m_log_nr_item))
                .append(": ")
                .append(src->owner_opr()->cname())
                .append(" => ")
                .append(dst->owner_opr()->cname())
                .append(" (")
                .append(msg)
                .append(")");
    }
    ++m_log_nr_item;
#endif
}
  389. size_t OptState::flush_log(const char* title) {
  390. if (m_owner_optimizer->verbosity() >= 2) {
  391. if (m_log_msg.empty()) {
  392. m_log_msg = mgb_cstr_log(" no var replacement logged");
  393. }
  394. mgb_log("%s%s", title, m_log_msg.c_str());
  395. m_log_msg.clear();
  396. }
  397. auto ret = m_log_nr_item;
  398. m_log_nr_item = 0;
  399. return ret;
  400. }
/*!
 * \brief run \p func with the property context (source opr, priority,
 *      stream prop type, inserted-opr table) of \p opr temporarily
 *      installed, restoring the previous context afterwards even if \p func
 *      throws
 *
 * \param opr_property_flag selects which properties (SOURCE_OPR / PRIORITY)
 *      newly created oprs should inherit while \p func runs
 */
void OptState::call_with_opr(OperatorNodeBase* opr,
                             thin_function<void(void)> func,
                             OprPropertyFlag opr_property_flag) {
    // capture the properties to install while func runs
    auto src_opr = cg::get_opr_root_source_opr(opr);
    auto opr_priority = opr->node_prop().attribute().priority;
    auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0));
    ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted;
    // swapping twice restores the original state, so the same lambda both
    // installs and restores the context
    auto swap_properties =
            [&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
                if (need_src_opr) {
                    std::swap(m_cur_iter_src_opr, src_opr);
                }
                if (need_priority) {
                    std::swap(m_cur_iter_opr_priority, opr_priority);
                }
                std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
                std::swap(m_opr_property_flag, opr_property_flag);
                std::swap(m_oprs_inserted, oprs_inserted);
            };
    MGB_TRY {
        swap_properties();
        func();
    }
    // restore runs even when func throws
    MGB_FINALLY({ swap_properties(); });
}
  427. /* ================ RecursiveSubGraphRewriteHelper ================ */
RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept =
        default;
//! bind to \p state and create a rewriter on its subgraph
RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state)
        : m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} {}
  432. void RecursiveSubGraphRewriteHelper::apply() {
  433. using namespace std::placeholders;
  434. m_opt_state.graph().iter(
  435. std::bind(&RecursiveSubGraphRewriteHelper::on_opr, this, _1));
  436. m_rewriter.apply_inplace();
  437. }
/*!
 * \brief process one opr: auto-replace its outputs, then repeatedly apply
 *      process_opr() to the (single-output) result via an explicit work
 *      stack, so that transformations that produce new intermediate oprs
 *      are themselves re-processed until a fixed point is reached
 */
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) {
    // remap inputs and ask the subclass whether this opr is interesting
    auto on_new_opr = [this](OperatorNodeBase* opr) {
        auto repl_opr = m_rewriter.auto_replace_outputs(opr);
        return on_new_opr_check_should_process(opr, repl_opr);
    };
    if (!on_new_opr(opr))
        return;
    // only single-output oprs take part in the recursive rewrite
    auto orig_out = get_opr_single_output_var(opr);
    if (!orig_out)
        return;
    mgb_assert(m_opr_stack.empty());
    // each frame remembers the original var whose replacement is being
    // computed, plus the opr currently producing the candidate result
    m_opr_stack.push_back(
            {orig_out, m_rewriter.get_var(orig_out)->owner_opr()});
    bool first = true;
    while (!m_opr_stack.empty()) {
        auto cur_frame = m_opr_stack.back();
        m_opr_stack.pop_back();
        auto cur_opr = cur_frame.opr;
        bool should_process;
        if (first) {
            // the root opr was already vetted by on_new_opr() above
            should_process = true;
            first = false;
        } else {
            should_process = on_new_opr(cur_opr);
        }
        auto cur_out = get_opr_single_output_var(cur_opr);
        mgb_assert(cur_out);
        cur_out = m_rewriter.get_var(cur_out);
        if (should_process) {
            auto trans = process_opr(cur_out);
            if (trans.valid()) {
                // re-examine the transformed result for this frame, and
                // push the new internal vars (in reverse, so they are
                // popped in forward order) to be processed first
                m_opr_stack.push_back(
                        {cur_frame.orig_var, trans->result->owner_opr()});
                for (auto i : reverse_adaptor(trans->internal)) {
                    if (i)
                        m_opr_stack.push_back({i, i->owner_opr()});
                }
                if (trans->msg) {
                    // accumulate ';'-separated messages for the final log
                    if (!m_log_msg.empty())
                        m_log_msg.push_back(';');
                    m_log_msg.append(trans->msg);
                }
                continue;
            }
        }
        // no further transformation: commit the replacement for this frame
        auto src = cur_frame.orig_var;
        if (m_rewriter.get_var(src) != cur_out) {
            const char* msg = nullptr;
            // only attach the log message when finishing the root frame
            if (m_opr_stack.empty()) {
                msg = m_log_msg.c_str();
            }
            m_rewriter.replace_var(src, cur_out, msg);
            after_replace_var(src, cur_out);
            if (m_opr_stack.empty()) {
                m_log_msg.clear();
                break;
            }
        }
    }
}
  498. /* ================ GraphOptimizer ================ */
GraphOptimizer::~GraphOptimizer() noexcept = default;
//! storage for the src->dst var replacement map, attached to a
//! ComputingGraph's options as user data (see var_replace_map(), which
//! creates it on demand)
class GraphOptimizer::VarReplaceMapStorage
        : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;
public:
    ThinHashMap<VarNode*, VarNode*> map;
};
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage);
  507. GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) {
  508. mgb_assert(!pass->m_owner_optimizer);
  509. pass->m_owner_optimizer = this;
  510. m_passes.emplace_back(std::move(pass));
  511. return *this;
  512. }
/*!
 * \brief run all registered passes on \p graph and return the optimized
 *      subgraph; the graph's graph_opt_level option is forced to 1 while a
 *      pass runs and restored afterwards (also on error)
 */
SubGraph GraphOptimizer::apply(const SubGraph& graph) const {
    RealTimer timer;
    OptState state{this, graph};
    size_t tot_nr_replace = 0;
    // first update output var shapes of all oprs
    state.graph().iter(cg::update_output_var_shapes);
    auto&& opt = graph.comp_graph()->options();
    auto orig_setting = opt.graph_opt_level;
    // tracked only for the error message below
    Pass* cur_pass = nullptr;
    MGB_MARK_USED_VAR(cur_pass);
    MGB_TRY {
        for (auto&& i : m_passes) {
            // each pass starts with full replacement checking; a pass may
            // relax this via the state
            state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL);
            cur_pass = i.get();
            opt.graph_opt_level = 1;
            i->apply(state);
            tot_nr_replace += state.flush_log(
                    mgb_ssprintf_log("apply optimization pass %s:", i->name())
                            .c_str());
        }
    }
    MGB_CATCH(std::exception & exc, {
        mgb_log_error("error while applying optimization pass %s: %s",
                      cur_pass->name(), exc.what());
        opt.graph_opt_level = orig_setting;
        throw;
    })
    MGB_FINALLY(opt.graph_opt_level = orig_setting);
    if (verbosity() >= 1) {
        mgb_log_debug(
                "graph optimization: applied %zu passes, "
                "total %zu var(s) replaced; time=%.2fms",
                m_passes.size(), tot_nr_replace, timer.get_msecs());
    }
    return state.graph();
}
  549. const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const {
  550. if (m_passes.empty()) {
  551. // this check is necessary, since OptState would clear
  552. // var_replace_map()
  553. return *this;
  554. }
  555. auto g = apply({{vars.begin(), vars.end()}});
  556. for (size_t i = 0; i < vars.size(); ++i) {
  557. vars[i] = g.endpoint_vars()[i].node();
  558. }
  559. return *this;
  560. }
/*!
 * \brief register the standard pass pipeline
 *
 * \param after_grad whether gradients have already been generated
 * \param inference_opt non-null to enable inference-only optimizations
 *      (param fusing/merging, format conversion, etc.)
 * \param comp_graph_opt used to derive JIT settings when MGB_JIT is enabled
 */
GraphOptimizer& GraphOptimizer::add_preset_passes(
        bool after_grad, const OptimizeForInferenceOptions* inference_opt,
        const ComputingGraph::Options* comp_graph_opt) {
    // for inference, params may also be treated as const
    auto cv_type = inference_opt ? ConstVarType::IMMUTABLE_AND_PARAM
                                 : ConstVarType::IMMUTABLE;
    if (inference_opt) {
        add_pass<ConvertBatchNormToElemwisePass>();
    }
    if (!after_grad || inference_opt) {
        add_pass<CondExecConstPredicateFolding>();
    }
    if (after_grad || inference_opt) {
        add_pass<RemoveNonComputingOprPass>();
    }
    add_pass<DelayBroadcastPass>();
    add_pass<ExpandFusedArithPass>();
    add_pass<NormalizeArithChainPass>();
    if (inference_opt) {
        add_pass<ParamRedistributePass>();
        add_pass<ParamFusePass>();
    }
    add_pass<ArithMulDistributePass>();
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<ArithFusePass>();
    // reorder again because shapes of fused oprs might change
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<FinalArithTransformPass>();
    add_pass<RemoveRedundantTypeCvtPass>();
    add_pass<RemoveRedundantCopyPass>();
#if MGB_JIT
    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
    int jit_opt_level = 0;
    JITConfig jit_config;
    // for more detail on what is happening here, see comments on the
    // constructor of class JITFusionPass in fusion_pass.h
    if (comp_graph_opt) {
        jit_opt_level = comp_graph_opt->graph_opt.jit;
        // graph_opt_level >= 3 implies at least JIT level 1
        if (comp_graph_opt->graph_opt_level >= 3) {
            jit_opt_level = std::max(jit_opt_level, 1);
        }
        jit_config = comp_graph_opt->graph_opt.jit_config;
    }
    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
    if (need_jit && after_grad) {
        add_pass<gopt::RecompTypeCvtPass>();
    }
#endif
    // combine astype and reduce.
    // Note: apply this pass before JITFusion, so the TypeCvt which
    // read by both Reduce and Elemwise could be fused correctly.
    add_pass<CombineAstypeAndReducePass>();
#if MGB_JIT
    if (need_jit) {
        add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
    }
#endif
    if (inference_opt) {
        add_pass<ParamFusePass>();
        add_passes_for_optimize_options(*inference_opt);
    }
    if (inference_opt) {
        // merge params to reduce loading time and graph overhead
        add_pass<ParamMergePass>();
        add_pass<FuseDeconvCvtPass>();
    }
    if (inference_opt) {
        // remove shape hint after inference optimization
        add_pass<RemoveShapeHintPass>();
    }
    return *this;
}
  632. const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map(
  633. ComputingGraph& graph) {
  634. auto storage =
  635. graph.options()
  636. .user_data.get_user_data_or_create<VarReplaceMapStorage>();
  637. return storage->map;
  638. }
  639. VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) {
  640. auto&& map = var_replace_map(*(var->owner_graph()));
  641. for (;;) {
  642. auto iter = map.find(var);
  643. if (iter == map.end())
  644. return var;
  645. var = iter->second;
  646. }
  647. }
  648. const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
  649. const cg::GraphCommonOptimizeOptions& options) {
  650. return add_passes_for_optimize_options(
  651. const_cast<cg::GraphCommonOptimizeOptions&>(options));
  652. }
/*!
 * \brief translate each option that has been set in \p options into its
 *      pass sequence; when \p reset is true, every handled option is
 *      cleared afterwards
 */
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        cg::GraphCommonOptimizeOptions& options, bool reset) {
    bool need_param_fuse = false;
    // expands to: if option _option is set, run _passes (a braced statement
    // list), mark that ParamFusePass is needed, and optionally clear the
    // option
#define cb(_option, _passes)             \
    if (options.has_set_##_option()) {   \
        _passes need_param_fuse = true;  \
        if (reset) {                     \
            options.disable_##_option(); \
        }                                \
    }
    cb(fuse_preprocess, {
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
    });
    cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
    cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
    cb(nchw4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nhwcd4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(ConvertFormatPass::make_nhwcd4_converter());
    });
    cb(nchw88, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44_dot, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchw44DotPass::make_nchw44_dot_converter());
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw32, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableTensorCorePass::make_tensorcore_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });
    cb(chwn4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableCHWN4Pass::make_chwn4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nchw64, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<PaddingChannelPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW64Pass::make_nchw64_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });
    cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
    cb(fuse_conv_bias_with_z, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
    }
    return *this;
}
/*!
 * \brief translate graph-tuning options into passes; currently only the
 *      layout_transform option is handled
 */
const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
        const GraphTuningOptions& options) {
    bool need_param_fuse = false;
    // expands to: if option _options is set, run _passes and mark that the
    // param fuse/merge passes are needed
#define cb(_options, _passes)           \
    if (options.has_set_##_options()) { \
        _passes need_param_fuse = true; \
    }
    cb(layout_transform, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(LayoutTransformPass::make(options.target));
        add_pass<ShuffleShuffleRemovePass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
        add_pass<FoldingConvBiasTypecvtPass>();
#endif
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
        add_pass<ParamMergePass>();
    }
    return *this;
}
  766. /* ================ ConstVarPropogateBase ================ */
/*!
 * \brief classify \p opr for const-var propagation, memoizing the result
 *      in m_oprinfo
 *
 * An opr is const if it is itself a const var source (per
 * m_const_var_type), or if all of its inputs are const and it has no flag
 * that forbids folding (FORCE_UPDATE_INPUT_VAR / IMPURE_FUNC).
 *
 * \return flags describing the opr's inputs: all-const / has-const /
 *      has-midconst (const but not a direct const var source)
 */
ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
        OperatorNodeBase* opr) {
    using ProfFlag = OperatorNodeBase::NodeProp::Flag;
    auto&& info = m_oprinfo[opr];
    // memoized: each opr is analyzed at most once
    if (info.processed)
        return info.result;
    info.processed = true;
#if MGB_ENABLE_JSON
    (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(false);
#endif
    AddOprResult ret{false, false, false};
    auto make_ret = [&ret, &info]() {
        info.result = ret;
        return ret;
    };
    if (is_const_var(m_const_var_type, opr)) {
        auto sz = var_mem_size(opr->output(0));
        mgb_assert(sz || opr->output(0)->contain_flag(
                                 VarNode::Flag::ALLOW_EMPTY_SHAPE));
        info.is_const = true;
        info.max_size = sz;
        return make_ret();
    }
    // an opr without inputs (that is not a const var source) cannot be const
    if (opr->input().empty())
        return make_ret();
    // oprs with side effects on inputs or impure semantics must not be
    // treated as const
    if (opr->node_prop().contain(ProfFlag::FORCE_UPDATE_INPUT_VAR |
                                 ProfFlag::IMPURE_FUNC)) {
        return make_ret();
    }
    size_t max_input_size = 0;
    ret.all_const_inp = true;
    for (auto i : opr->input()) {
        auto io = i->owner_opr();
        auto iter = m_oprinfo.find(io);
        if (iter == m_oprinfo.end()) {
            // input opr not seen yet: recurse to classify it first
            add_opr(io);
            iter = m_oprinfo.find(io);
            mgb_assert(iter != m_oprinfo.end());
        }
        auto&& src = iter->second;
        if (src.is_const) {
            update_max(max_input_size, src.max_size);
            ret.has_const_inp = true;
            // const but not a direct const var source => "mid-const"
            if (!is_const_var(m_const_var_type, i->owner_opr())) {
                ret.has_midconst_inp = true;
            }
        } else {
            ret.all_const_inp = false;
        }
    }
    if (ret.all_const_inp) {
#if MGB_ENABLE_JSON
        (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(true);
#endif
        info.max_size = max_input_size;
        info.is_const = true;
    }
    return make_ret();
}
  826. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台