You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

framework.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/timer.h"
#if MGB_JIT
#include "megbrain/jit/fusion_pass.h"
#endif
#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/opr_replace.h"
#endif
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/gopt/layout_transform_pass.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"

#include <algorithm>
  22. using namespace mgb;
  23. using namespace gopt;
/* ================ SubGraph ================ */
//! Re-create \p opr with its inputs remapped through this rewriter. If no
//! input has a pending replacement, \p opr is returned unchanged; otherwise a
//! shallow copy with the new inputs is made and each old output is mapped to
//! the corresponding new output. Returns the opr that should be used from now
//! on.
OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(OperatorNodeBase* opr) {
    // reuse a member cache to avoid re-allocating the input array per call
    auto&& new_inp = m_opr_new_inp_cache;
    new_inp.clear();
    new_inp.reserve(opr->input().size());
    bool has_replaced_inp = false;
    for (auto i : opr->input()) {
        auto new_var = get_var(i);
        if (new_var != i) {
            has_replaced_inp = true;
            new_inp.push_back(new_var);
        } else {
            new_inp.push_back(i);
        }
    }
    if (has_replaced_inp) {
        auto new_opr = serialization::copy_opr_shallow(*opr, new_inp, opr->config());
        auto &&out0 = opr->output(), &&out1 = new_opr->output();
        size_t i = 0;
        auto err_msg = [opr, new_opr] {
            return ssprintf(
                    "bad opr copy: src=%s{%s} dst=%s{%s}", opr->cname(),
                    opr->dyn_typeinfo()->name, new_opr->cname(),
                    new_opr->dyn_typeinfo()->name);
        };
        MGB_MARK_USED_VAR(err_msg);
        // opr output size mismatch may be caused by:
        // 0) inplace arith optimization (e.g. PowC need an extra workspace)
        // 1) other post-insert optimization (e.g. const folding)
        // we can't handle only usable_output here, since some output var with
        // volatile flag could be the graph's endpoint (e.g. RemoteSend)
        for (; i < std::min(out0.size(), out1.size()); ++i) {
            // paired outputs must agree on volatility, otherwise the copy is
            // structurally different from the source opr
            bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                 v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT);
            mgb_assert(v0 == v1, "%s", err_msg().c_str());
            //! rename new var
            out1[i]->name(out0[i]->cname());
            // first element of the mapped pair is true to mark this as an
            // automatic replacement (see replace_var / get_var_internal)
            auto&& ins = m_varmap.insert({out0[i], {true, nullptr}});
            mgb_assert(
                    ins.second || ins.first->second.first,
                    "opr output already replaced");
            // handle repeated call on the same opr
            ins.first->second.second = out1[i];
            on_var_replaced(out0[i], out1[i], nullptr);
        }
        // any surplus outputs on either side must be volatile-only
        for (; i < out0.size(); ++i) {
            mgb_assert(
                    out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), "%s",
                    err_msg().c_str());
        }
        for (; i < out1.size(); ++i) {
            mgb_assert(
                    out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), "%s",
                    err_msg().c_str());
        }
        return new_opr;
    }
    return opr;
}
//! Record that all uses of \p src should be rewritten to \p dst. \p msg is an
//! optional human-readable reason that is forwarded to the replacement log.
void SubGraph::Rewriter::replace_var(VarNode* src, VarNode* dst, const char* msg) {
    if (src == dst)
        return;
    // Optimizers should not create a loop in variable replace map.
    mgb_throw_if(
            get_var_internal(dst).second == src, InternalError,
            "dst %s maps back to src %s in SubGraph::Rewriter::replace_var",
            dst->cname(), src->cname());
    // the bool flag is false for explicit replacements (contrast with
    // auto_replace_outputs, which inserts true)
    auto&& ins = m_varmap.insert({src, {false, dst}});
    if (!ins.second) {
        auto&& old_rep = ins.first->second;
        // an existing entry may only be overwritten if it came from
        // auto_replace_outputs, or already maps to the same destination
        mgb_assert(
                old_rep.first || old_rep.second == dst, "can not replace a var twice");
        old_rep.first = false;
        old_rep.second = dst;
    }
    on_var_replaced(src, dst, msg);
}
  101. void SubGraph::Rewriter::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
  102. if (auto state = m_owner_graph->owner_opt_state()) {
  103. state->on_var_replaced(src, dst, msg);
  104. }
  105. }
  106. void SubGraph::Rewriter::apply_inplace() const {
  107. m_owner_graph->m_endpoint_oprs.clear();
  108. m_owner_graph->m_endpoint_vars_set.clear();
  109. for (auto&& var : m_owner_graph->m_endpoint_vars) {
  110. var = get_var(var.node());
  111. m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr());
  112. m_owner_graph->m_endpoint_vars_set.insert(var.node());
  113. }
  114. }
//! Resolve the replacement chain starting at \p var with path compression.
//! Returns {auto_flag, final_var}: auto_flag is the AND of the per-edge flags
//! along the chain (true only if every hop was recorded by
//! auto_replace_outputs); final_var is the last var in the chain. Visited
//! entries are updated to point directly at the final var.
std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) {
    // The implementation is (manually) unrolled once, background:
    // git-core/brain-sdk/MegBrain/merge_requests/486#note_76971
    auto it = m_varmap.find(var);
    if (it == m_varmap.end()) {
        // no replacement recorded: var maps to itself
        return {true, var};
    }
    mgb_assert(it->second.second != var, "loop detected in m_varmap");
    auto it_next = m_varmap.find(it->second.second);
    if (it_next == m_varmap.end()) {
        return it->second;
    }
    mgb_assert(
            it_next->second.second != it->second.second, "loop detected in m_varmap");
    // recurse for chains longer than two, compressing both visited entries
    auto next = get_var_internal(it_next->second.second);
    it_next->second = {next.first & it_next->second.first, next.second};
    return it->second = {it_next->second.first & it->second.first, next.second};
}
  133. SubGraph::SubGraph(const SymbolVarArray& endpoint_vars)
  134. : m_endpoint_vars(endpoint_vars) {
  135. mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty");
  136. m_comp_graph = endpoint_vars[0].node()->owner_graph();
  137. for (auto i : endpoint_vars) {
  138. m_endpoint_oprs.insert(i.node()->owner_opr());
  139. m_endpoint_vars_set.insert(i.node());
  140. mgb_assert(
  141. m_comp_graph == i.node()->owner_graph(),
  142. "endpoints belong to different computing graphs");
  143. }
  144. }
//! Visit every operator of the sub-graph in dependency order. When an
//! OptState owns this sub-graph, the callback is wrapped so that oprs created
//! inside \p cb inherit source-opr, priority and stream-propagation properties
//! of the opr currently being visited (see the OprInserted handler installed
//! by the OptState constructor).
void SubGraph::iter(const Callback& cb, std::shared_ptr<ExtraDep> extra_dep) const {
    Callback on_opr;
    if (m_owner_opt_state) {
        on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) {
            // publish the current opr's properties so newly inserted oprs
            // can copy them
            state->m_opr_property_flag = OprPropertyFlag::ALL;
            state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr);
            state->m_cur_iter_opr_priority = opr->node_prop().attribute().priority;
            state->m_cur_iter_opr_stream_prop_type =
                    state->m_comp_node_opt.stream_prop_type(opr->output(0));
            mgb_assert(state->m_oprs_inserted.empty());
            cb(opr);
            // reset state after the callback so stray insertions are caught
            state->m_opr_property_flag = OprPropertyFlag::NONE;
            state->m_cur_iter_src_opr = nullptr;
            state->m_oprs_inserted.clear();
        };
    } else {
        on_opr = cb;
    }
    cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)};
    for (auto i : m_endpoint_oprs)
        dep_iter.add(i);
}
//! Count, for every non-volatile var in the sub-graph, how many operators
//! depend on its device value; endpoint vars get one extra count since the
//! graph output itself is a reader. Used e.g. by UniqReaderCheck.
ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
    ThinHashMap<VarNode*, size_t> ret;
    auto cb = [&](OperatorNodeBase* opr) {
        // inputs are visited before this opr (dependency order), so their
        // entries already exist in ret
        for (auto&& i : opr->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) {
                ++ret.at(i.first);
            }
        }
        // register this opr's non-volatile outputs with a zero count
        for (auto i : opr->output()) {
            if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto ins = ret.insert({i, 0});
                mgb_assert(ins.second);
            }
        }
    };
    iter(cb);
    for (auto i : m_endpoint_vars_set) {
        auto iter = ret.find(i);
        if (iter == ret.end()) {
            // only volatile vars are absent from ret; an endpoint may still
            // be volatile (e.g. RemoteSend output)
            mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
            ret[i] = 1;
        } else {
            ++ret.at(i);
        }
    }
    return ret;
}
/* ================ UniqReaderCheck ================ */
//! Snapshot the value-dependency reader counts of \p graph so uniqueness of
//! readers can be queried while the graph is being rewritten.
UniqReaderCheck::UniqReaderCheck(const SubGraph& graph)
        : m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} {}
  197. void UniqReaderCheck::update_on_opr_auto_replace(
  198. OperatorNodeBase* opr, OperatorNodeBase* repl_opr) {
  199. auto non_volatile_size = [](const VarNodeArray& vars) -> size_t {
  200. size_t size = 0;
  201. for (size_t i = 0; i < vars.size(); ++i) {
  202. if (!vars[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  203. size++;
  204. }
  205. }
  206. return size;
  207. };
  208. if (opr != repl_opr) {
  209. auto &&o0 = opr->output(), &&o1 = repl_opr->output();
  210. mgb_assert(non_volatile_size(o0) == non_volatile_size(o1));
  211. for (size_t i = 0; i < o0.size(); ++i) {
  212. auto iter = m_var2nr_val_dep.find(o0[i]);
  213. if (iter != m_var2nr_val_dep.end()) {
  214. auto n = iter->second;
  215. m_var2nr_val_dep[o1[i]] = n;
  216. }
  217. }
  218. }
  219. }
/* ================ OptState ================ */
//! Attach optimization state to \p graph: clears the graph-level
//! var-replace map and installs an OprInserted handler that copies source-opr,
//! priority and stream properties from the opr currently being iterated onto
//! any opr created during a pass (see SubGraph::iter and call_with_opr).
OptState::OptState(const GraphOptimizer* owner_optimizer, const SubGraph& graph)
        : m_owner_optimizer{owner_optimizer},
          // const_cast: var_replace_map() exposes a const view, but OptState
          // is the writer of this map
          m_var_replace_map{const_cast<ThinHashMap<VarNode*, VarNode*>*>(
                  &GraphOptimizer::var_replace_map(*graph.comp_graph()))},
          m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
          m_graph{graph} {
    // a sub-graph can be owned by at most one OptState at a time
    mgb_assert(!m_graph.m_owner_opt_state);
    m_var_replace_map->clear();
    m_graph.m_owner_opt_state = this;
    m_oprs_inserted.clear();
    auto on_opr_insert = [this](const cg::event::OprInserted& ev) {
        auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY;
        if (need_src_opr)
            mgb_assert(
                    m_cur_iter_src_opr,
                    "opr %s{%s} created outside from "
                    "SubGraph::iter",
                    ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
        // ignore failed or deduplicated insertions
        if (ev.exc || ev.is_dedup)
            return;
        auto&& new_attr = ev.opr->node_prop().attribute();
        // track which properties we forced onto this opr, for later
        // consistency checks in on_var_replaced()
        auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
        mgb_assert(ins.second);
        if (need_src_opr && !new_attr.src_opr) {
            auto src_opr = m_cur_iter_src_opr;
            if (ev.opr != src_opr)
                new_attr.src_opr = src_opr;
            ins.first->second |= OprPropertyFlag::SOURCE_OPR;
        }
        if (need_priority) {
            // inherit the priority of the opr being visited; only record
            // the flag if the opr did not recompute its own priority
            new_attr.priority = m_cur_iter_opr_priority;
            if (!ev.opr->update_priority()) {
                ins.first->second |= OprPropertyFlag::PRIORITY;
            }
        }
        // propagate comp-node stream properties to all outputs
        auto csp = m_cur_iter_opr_stream_prop_type;
        if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) {
            for (auto i : ev.opr->output())
                m_comp_node_opt.register_stream_var(i, csp);
        }
    };
    m_on_opr_insert_handler =
            graph.comp_graph()->event().register_receiver<cg::event::OprInserted>(
                    on_opr_insert);
}
//! Validate and record a var replacement performed by a pass: checks property
//! consistency (source opr, priority) for oprs inserted during the pass,
//! checks infer-type/dtype/shape compatibility according to
//! m_var_replace_check_flag, updates the graph-level replace map, and appends
//! to the verbose log.
void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
    if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
        // this can only happen in auto_replace_outputs()
        mgb_assert(
                dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
                src->owner_opr()->dyn_typeinfo() == dst->owner_opr()->dyn_typeinfo());
        mgb_assert(!msg);
        return;
    }
    //! check_property
    {
        auto iter = m_oprs_inserted.find(dst->owner_opr());
        if (iter != m_oprs_inserted.end()) {
            // dst's opr was created during this pass: properties that were
            // copied onto it must match those of src's opr
            auto &&src_attr = src->owner_opr()->node_prop().attribute(),
                 &&dst_attr = dst->owner_opr()->node_prop().attribute();
            auto opr_info = [&](OperatorNodeBase* opr) {
                return opr ? opr->name() + "(" + std::to_string(opr->id()) + ")"
                           : "NULL";
            };
            auto err_msg = [&] {
                std::string ret = "Please contact Engine group:\n";
                ret += "src opr: ";
                ret += opr_info(src->owner_opr());
                ret += ", dst opr: ";
                ret += opr_info(dst->owner_opr());
                return ret;
            };
            MGB_MARK_USED_VAR(err_msg);
            if (iter->second & OprPropertyFlag::SOURCE_OPR) {
                auto &&src_rt = get_opr_root_source_opr(src->owner_opr()),
                     &&dst_rt = get_opr_root_source_opr(dst->owner_opr());
                mgb_assert(
                        dst_rt == src_rt,
                        "%s\nsrc source_opr: %s, dst source_opr: %s\n",
                        err_msg().c_str(), opr_info(src_rt).c_str(),
                        opr_info(dst_rt).c_str());
            }
            if (iter->second & OprPropertyFlag::PRIORITY) {
                mgb_assert(
                        src_attr.priority == dst_attr.priority,
                        "%s\nsrc priority: %d, dst priority %d\n", err_msg().c_str(),
                        src_attr.priority, dst_attr.priority);
            }
        }
    }
    // collect all failed compatibility checks before throwing, so the error
    // message names every mismatching aspect at once
    {
        bool suc = true;
        SmallVector<std::string> fail_chks;
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_INFER_TYPE) {
            auto&& mgr = src->owner_graph()->static_infer_manager();
            auto it0 = mgr.get_infer_type(src), it1 = mgr.get_infer_type(dst);
            using cg::static_infer::InferType;
            // only check whether inferable
            auto norm = [](InferType::Flag f) -> bool {
                return f & (InferType::RT_STATIC | InferType::CONST);
            };
            // shape inferability must be preserved exactly; value
            // inferability may only improve (<= on bool)
            if (!(norm(it0.shape) == norm(it1.shape) &&
                  norm(it0.value) <= norm(it1.value))) {
                suc = false;
                fail_chks.push_back("infer-type");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_DTYPE) {
            if (src->dtype() != dst->dtype()) {
                suc = false;
                fail_chks.push_back("dtype");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_SHAPE) {
            if (!(src->shape().eq_shape(dst->shape()))) {
                suc = false;
                fail_chks.push_back("shape");
            }
        }
        if (!suc) {
            // e.g. "{dtype,shape}"
            std::string fail_msg = "{";
            for (size_t i = 0; i < fail_chks.size(); i++) {
                fail_msg += fail_chks[i];
                if (i < fail_chks.size() - 1) {
                    fail_msg += ",";
                }
            }
            fail_msg += "}";
            mgb_throw_raw(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{src->owner_opr()}
                            .make<InternalError>(ssprintf(
                                    "%s mismatch for replace_var: %s", fail_msg.c_str(),
                                    cg::dump_var_info({src, dst}).c_str())));
        }
    }
    // preserve a user-assigned name across the replacement
    if (src->has_name_set() && !dst->has_name_set()) {
        dst->name(src->name());
    }
    (*m_var_replace_map)[src] = dst;
    // dst should be considered as newly inserted, and previous replace
    // record should be ignored
    m_var_replace_map->erase(dst);
#if MGB_ENABLE_LOGGING
    if (msg && m_owner_optimizer->verbosity()) {
        m_log_msg.append("\n ")
                .append(std::to_string(m_log_nr_item))
                .append(": ")
                .append(src->owner_opr()->cname())
                .append(" => ")
                .append(dst->owner_opr()->cname())
                .append(" (")
                .append(msg)
                .append(")");
    }
    ++m_log_nr_item;
#endif
}
  379. size_t OptState::flush_log(const char* title) {
  380. if (m_owner_optimizer->verbosity() >= 2) {
  381. if (m_log_msg.empty()) {
  382. m_log_msg = mgb_cstr_log(" no var replacement logged");
  383. }
  384. mgb_log("%s%s", title, m_log_msg.c_str());
  385. m_log_msg.clear();
  386. }
  387. auto ret = m_log_nr_item;
  388. m_log_nr_item = 0;
  389. return ret;
  390. }
//! Invoke \p func as if it were running inside SubGraph::iter while visiting
//! \p opr: the current-iteration properties (source opr, priority, stream
//! prop type, inserted-opr map) are swapped in before the call and swapped
//! back afterwards, even on exception.
void OptState::call_with_opr(
        OperatorNodeBase* opr, thin_function<void(void)> func,
        OprPropertyFlag opr_property_flag) {
    auto src_opr = cg::get_opr_root_source_opr(opr);
    auto opr_priority = opr->node_prop().attribute().priority;
    auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0));
    ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted;
    // swapping twice restores the original state, so the same closure serves
    // as both setup and teardown
    auto swap_properties =
            [&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
                if (need_src_opr) {
                    std::swap(m_cur_iter_src_opr, src_opr);
                }
                if (need_priority) {
                    std::swap(m_cur_iter_opr_priority, opr_priority);
                }
                std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
                std::swap(m_opr_property_flag, opr_property_flag);
                std::swap(m_oprs_inserted, oprs_inserted);
            };
    MGB_TRY {
        swap_properties();
        func();
    }
    MGB_FINALLY({ swap_properties(); });
}
/* ================ RecursiveSubGraphRewriteHelper ================ */
RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept = default;
//! Bind the helper to an optimization state and create a rewriter over its
//! sub-graph; actual rewriting happens in apply().
RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state)
        : m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} {}
  421. void RecursiveSubGraphRewriteHelper::apply() {
  422. using namespace std::placeholders;
  423. m_opt_state.graph().iter(
  424. std::bind(&RecursiveSubGraphRewriteHelper::on_opr, this, _1));
  425. m_rewriter.apply_inplace();
  426. }
//! Visit one operator: repeatedly apply process_opr() to it and to any
//! intermediate oprs produced by a transformation, using an explicit stack
//! (m_opr_stack) instead of recursion, until a fixed point is reached; then
//! record the final replacement with the accumulated log message.
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) {
    // remap inputs and let the subclass decide whether to process the opr
    auto on_new_opr = [this](OperatorNodeBase* opr) {
        auto repl_opr = m_rewriter.auto_replace_outputs(opr);
        return on_new_opr_check_should_process(opr, repl_opr);
    };
    if (!on_new_opr(opr))
        return;
    // only single-output oprs participate in recursive rewriting
    auto orig_out = get_opr_single_output_var(opr);
    if (!orig_out)
        return;
    mgb_assert(m_opr_stack.empty());
    m_opr_stack.push_back({orig_out, m_rewriter.get_var(orig_out)->owner_opr()});
    bool first = true;
    while (!m_opr_stack.empty()) {
        auto cur_frame = m_opr_stack.back();
        m_opr_stack.pop_back();
        auto cur_opr = cur_frame.opr;
        bool should_process;
        if (first) {
            // the root opr already passed on_new_opr() above
            should_process = true;
            first = false;
        } else {
            should_process = on_new_opr(cur_opr);
        }
        auto cur_out = get_opr_single_output_var(cur_opr);
        mgb_assert(cur_out);
        cur_out = m_rewriter.get_var(cur_out);
        if (should_process) {
            auto trans = process_opr(cur_out);
            if (trans.valid()) {
                // re-examine the transformed result, then its internal vars
                // (pushed in reverse so they pop in original order)
                m_opr_stack.push_back({cur_frame.orig_var, trans->result->owner_opr()});
                for (auto i : reverse_adaptor(trans->internal)) {
                    if (i)
                        m_opr_stack.push_back({i, i->owner_opr()});
                }
                if (trans->msg) {
                    // accumulate messages from chained transformations
                    if (!m_log_msg.empty())
                        m_log_msg.push_back(';');
                    m_log_msg.append(trans->msg);
                }
                continue;
            }
        }
        // no further transformation: commit the replacement; the log message
        // is only attached when this finishes the whole stack
        auto src = cur_frame.orig_var;
        if (m_rewriter.get_var(src) != cur_out) {
            const char* msg = nullptr;
            if (m_opr_stack.empty()) {
                msg = m_log_msg.c_str();
            }
            m_rewriter.replace_var(src, cur_out, msg);
            after_replace_var(src, cur_out);
            if (m_opr_stack.empty()) {
                m_log_msg.clear();
                break;
            }
        }
    }
}
/* ================ GraphOptimizer ================ */
GraphOptimizer::~GraphOptimizer() noexcept = default;
//! Holder for the per-graph src-var -> dst-var replacement map; stored in the
//! computing graph's user-data container so it shares the graph's lifetime
//! (see var_replace_map()).
class GraphOptimizer::VarReplaceMapStorage : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

public:
    ThinHashMap<VarNode*, VarNode*> map;
};
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage);
  493. GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) {
  494. mgb_assert(!pass->m_owner_optimizer);
  495. pass->m_owner_optimizer = this;
  496. m_passes.emplace_back(std::move(pass));
  497. return *this;
  498. }
//! Run every registered pass over \p graph in order and return the optimized
//! sub-graph. graph_opt_level is forced to 1 while passes run (to disable
//! conflicting graph-level optimizations) and restored afterwards, even on
//! exception.
SubGraph GraphOptimizer::apply(const SubGraph& graph) const {
    RealTimer timer;
    OptState state{this, graph};
    size_t tot_nr_replace = 0;
    // first update output var shapes of all oprs
    state.graph().iter(cg::update_output_var_shapes);
    auto&& opt = graph.comp_graph()->options();
    auto orig_setting = opt.graph_opt_level;
    Pass* cur_pass = nullptr;
    MGB_MARK_USED_VAR(cur_pass);
    MGB_TRY {
        for (auto&& i : m_passes) {
            // each pass starts with full replacement checking; a pass may
            // relax this via OptState::set_var_replace_check_flag
            state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL);
            cur_pass = i.get();
            opt.graph_opt_level = 1;
            i->apply(state);
            tot_nr_replace += state.flush_log(
                    mgb_ssprintf_log("apply optimization pass %s:", i->name()).c_str());
        }
    }
    MGB_CATCH(std::exception & exc, {
        mgb_log_error(
                "error while applying optimization pass %s: %s", cur_pass->name(),
                exc.what());
        opt.graph_opt_level = orig_setting;
        throw;
    })
    MGB_FINALLY(opt.graph_opt_level = orig_setting);
    if (verbosity() >= 1) {
        mgb_log_debug(
                "graph optimization: applied %zu passes, "
                "total %zu var(s) replaced; time=%.2fms",
                m_passes.size(), tot_nr_replace, timer.get_msecs());
    }
    return state.graph();
}
  535. const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const {
  536. if (m_passes.empty()) {
  537. // this check is necessary, since OptState would clear
  538. // var_replace_map()
  539. return *this;
  540. }
  541. auto g = apply({{vars.begin(), vars.end()}});
  542. for (size_t i = 0; i < vars.size(); ++i) {
  543. vars[i] = g.endpoint_vars()[i].node();
  544. }
  545. return *this;
  546. }
//! Register the standard pass pipeline. \p after_grad selects the passes that
//! are only valid once gradients have been expanded; \p inference_opt enables
//! inference-only transformations (param fusion/merging etc.) and appends the
//! passes derived from its options; \p comp_graph_opt supplies JIT settings.
//! NOTE: pass order is significant — see the inline comments.
GraphOptimizer& GraphOptimizer::add_preset_passes(
        bool after_grad, const OptimizeForInferenceOptions* inference_opt,
        const ComputingGraph::Options* comp_graph_opt) {
    // for inference, params are constants too and may be reordered/folded
    auto cv_type =
            inference_opt ? ConstVarType::IMMUTABLE_AND_PARAM : ConstVarType::IMMUTABLE;
    if (inference_opt) {
        add_pass<ConvertBatchNormToElemwisePass>();
    }
    if (!after_grad || inference_opt) {
        add_pass<CondExecConstPredicateFolding>();
    }
    if (after_grad || inference_opt) {
        add_pass<RemoveNonComputingOprPass>();
    }
    add_pass<DelayBroadcastPass>();
    add_pass<ExpandFusedArithPass>();
    add_pass<NormalizeArithChainPass>();
    if (inference_opt) {
        add_pass<ParamRedistributePass>();
        add_pass<ParamFusePass>();
    }
    add_pass<ArithMulDistributePass>();
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<ArithFusePass>();
    // reorder again because shapes of fused oprs might change
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<FinalArithTransformPass>();
    add_pass<RemoveRedundantTypeCvtPass>();
    add_pass<RemoveRedundantCopyPass>();
    //! Only arm_common implement Fuse TypeCvt and Elemwise optimized kernel
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) && !MGB_OPENCL && !MGB_CUDA
    add_pass<FuseTypecvtElemwisePass>();
#endif
#if MGB_JIT
    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
    int jit_opt_level = 0;
    JITConfig jit_config;
    // for more detail on what is happening here, see comments on the
    // constuctor of class JITFusionPass in fusion_pass.h
    if (comp_graph_opt) {
        jit_opt_level = comp_graph_opt->graph_opt.jit;
        // graph_opt_level >= 3 implies at least JIT level 1
        if (comp_graph_opt->graph_opt_level >= 3) {
            jit_opt_level = std::max(jit_opt_level, 1);
        }
        jit_config = comp_graph_opt->graph_opt.jit_config;
    }
    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
    if (need_jit && after_grad) {
        add_pass<gopt::RecompTypeCvtPass>();
    }
#endif
    // combine astype and reduce.
    // Note: apply this pass before JITFusion, so the TypeCvt which
    // read by both Reduce and Elemwise could be fused correctly.
    add_pass<CombineAstypeAndReducePass>();
#if MGB_JIT
    if (need_jit) {
        add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
    }
#endif
    if (inference_opt) {
        add_pass<ParamFusePass>();
        add_passes_for_optimize_options(*inference_opt);
    }
    if (inference_opt) {
        // merge params to reduce loading time and graph overhead
        add_pass<ParamMergePass>();
        add_pass<FuseDeconvCvtPass>();
    }
    if (inference_opt) {
        // remove shape hint after inference optimization
        add_pass<RemoveShapeHintPass>();
    }
    return *this;
}
  622. const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map(
  623. ComputingGraph& graph) {
  624. auto storage =
  625. graph.options().user_data.get_user_data_or_create<VarReplaceMapStorage>();
  626. return storage->map;
  627. }
  628. VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) {
  629. auto&& map = var_replace_map(*(var->owner_graph()));
  630. for (;;) {
  631. auto iter = map.find(var);
  632. if (iter == map.end())
  633. return var;
  634. var = iter->second;
  635. }
  636. }
//! Const-options overload: delegates to the mutable overload without
//! resetting the option flags. The const_cast is safe because the delegate
//! only mutates \p options when its `reset` parameter is set.
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        const cg::GraphCommonOptimizeOptions& options) {
    return add_passes_for_optimize_options(
            const_cast<cg::GraphCommonOptimizeOptions&>(options));
}
//! Append the passes implied by each set flag in \p options (layout
//! conversions, fusions, fp16 conversion, ...). When \p reset is true the
//! consumed flags are cleared from \p options. A ParamFusePass is appended at
//! the end if any option-driven pass was added.
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        cg::GraphCommonOptimizeOptions& options, bool reset) {
    bool need_param_fuse = false;
// for each option: run _passes if the option is set, then optionally clear it
#define cb(_option, _passes)             \
    if (options.has_set_##_option()) {   \
        _passes need_param_fuse = true;  \
        if (reset) {                     \
            options.disable_##_option(); \
        }                                \
    }
    cb(fuse_grain, {
        add_pass<FoldingReduceMeanPass>();
        add_pass<FoldingGlobalPoolingPass>();
    });
    cb(fuse_preprocess, {
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
    });
    cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
    cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
    cb(nchw4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nhwcd4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(ConvertFormatPass::make_nhwcd4_converter());
    });
    cb(nchw88, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44_dot, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchw44DotPass::make_nchw44_dot_converter());
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw32, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableTensorCorePass::make_tensorcore_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });
    cb(chwn4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableCHWN4Pass::make_chwn4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nchw64, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<PaddingChannelPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW64Pass::make_nchw64_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });
    cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
    cb(fuse_conv_bias_with_z, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
    }
    return *this;
}
//! Append the passes implied by global layout-transform tuning \p options;
//! CUDA targets additionally get conv-bias-z fusion and the NCHW4/dimshuffle
//! folding passes. Param fuse/merge passes are appended if anything was added.
const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
        const GraphTuningOptions& options) {
    bool need_param_fuse = false;
// run _passes when the given option flag has been set
#define cb(_options, _passes)            \
    if (options.has_set_##_options()) {  \
        _passes need_param_fuse = true;  \
    }
    using Target = GraphTuningOptions::Target;
    cb(layout_transform, {
        add_pass<FuseConvBiasNonlinPass>();
        if (options.target == Target::CUDA)
            add_pass<FuseConvBiasZPass>();
        add_pass(LayoutTransformPass::make(options.target));
        add_pass<ShuffleShuffleRemovePass>();
        if (options.target == Target::CUDA) {
            add_pass(FuseNCHW4Int8Preprocess::make());
            add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
            add_pass<FoldingConvBiasDimshufflePass>();
            add_pass<FoldingConvBiasTypecvtPass>();
#endif
        }
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
        add_pass<ParamMergePass>();
    }
    return *this;
}
/* ================ ConstVarPropogateBase ================ */
//! Classify \p opr for const-var propagation: an opr is const if it is itself
//! a const var (per m_const_var_type) or if all of its inputs are const and
//! it has no impure/force-update flags. Results are memoized in m_oprinfo;
//! input oprs are processed recursively on demand. Returns flags describing
//! the const-ness of the opr's inputs.
ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(OperatorNodeBase* opr) {
    using ProfFlag = OperatorNodeBase::NodeProp::Flag;
    // NOTE(review): `info` is held across the recursive add_opr() calls
    // below; this presumes ThinHashMap references stay valid across
    // insertion/rehash — confirm against ThinHashMap's implementation.
    auto&& info = m_oprinfo[opr];
    if (info.processed)
        return info.result;
    info.processed = true;
#if MGB_ENABLE_JSON
    (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(false);
#endif
    AddOprResult ret{false, false, false};
    auto make_ret = [&ret, &info]() {
        info.result = ret;
        return ret;
    };
    if (is_const_var(m_const_var_type, opr)) {
        auto sz = var_mem_size(opr->output(0));
        mgb_assert(
                sz || opr->output(0)->contain_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE));
        info.is_const = true;
        info.max_size = sz;
        return make_ret();
    }
    // source oprs that are not const vars can never be const
    if (opr->input().empty())
        return make_ret();
    // oprs with side effects cannot be folded
    if (opr->node_prop().contain(
                ProfFlag::FORCE_UPDATE_INPUT_VAR | ProfFlag::IMPURE_FUNC)) {
        return make_ret();
    }
    size_t max_input_size = 0;
    ret.all_const_inp = true;
    for (auto i : opr->input()) {
        auto io = i->owner_opr();
        auto iter = m_oprinfo.find(io);
        if (iter == m_oprinfo.end()) {
            // input opr not yet classified: recurse, then re-find since the
            // recursive call mutates m_oprinfo
            add_opr(io);
            iter = m_oprinfo.find(io);
            mgb_assert(iter != m_oprinfo.end());
        }
        auto&& src = iter->second;
        if (src.is_const) {
            update_max(max_input_size, src.max_size);
            ret.has_const_inp = true;
            // "midconst": const by propagation but not itself a const var
            if (!is_const_var(m_const_var_type, i->owner_opr())) {
                ret.has_midconst_inp = true;
            }
        } else {
            ret.all_const_inp = false;
        }
    }
    if (ret.all_const_inp) {
#if MGB_ENABLE_JSON
        (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(true);
#endif
        info.max_size = max_input_size;
        info.is_const = true;
    }
    return make_ret();
}
  822. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}