You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

fusion_pass.cpp 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. /**
  2. * \file src/jit/impl/fusion_pass.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/jit/fusion_pass.h"
  12. #include "megbrain/common.h"
  13. #include "megbrain/gopt/gtrans.h"
  14. #include "megbrain/jit/ast_c.h"
  15. #include "megbrain/jit/compiler.h"
  16. #include "megbrain/jit/internal_graph.h"
  17. #include "megbrain/opr/tensor_manip.h"
  18. #include "megbrain/serialization/serializer.h"
  19. #if MGB_JIT
  20. #if MGB_JIT_MLIR
  21. #include "./mlir/ir/each_mode.h"
  22. #endif
  23. using namespace mgb;
  24. using namespace gopt;
  25. using namespace jit;
/*!
 * \brief implementation of the JIT fusion pass
 *
 * A single Impl instance performs the whole pass in its constructor:
 * detect_fusion() grows fusable subgraphs via a reverse topological
 * traversal, then update_graph() rewrites each detected subgraph into a
 * single JITExecutor operator.
 */
class JITFusionPass::Impl final {
    using Mode = opr::Elemwise::Mode;
    using DepType = OperatorNodeBase::NodeProp::DepType;

    //! whether this pass runs after gradient computation; in the forward
    //! pass a stricter fusion condition is applied (see update_graph())
    const bool m_after_grad;
    //! enabled JIT features (e.g. REDUCE, DIMSHUFFLE)
    JITFeatureBits m_feature_bits;
    OptState& m_opt_state;
    //! cache for max_nr_input(), keyed by comp node
    CompNode::UnorderedMap<size_t> m_cn2max_nr_input;
    SubGraph::Rewriter m_rewriter;
    //! owns all InternalGraphGenerator objects; the maps below hold raw
    //! non-owning pointers into this storage
    SmallVector<std::unique_ptr<InternalGraphGenerator>> m_igraph_gen_storage;
    //! map from var to the subgraph generator it currently belongs to
    ThinHashMap<VarNode*, InternalGraphGenerator*> m_var2igraph_gen;
    //! map from var to its reader oprs and the corresponding dependency types
    ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>>
            m_var_readers;
    //! vars that are roots (outputs) of detected fusion subgraphs
    ThinHashSet<VarNode*> m_endpoint_set;

    //! create a new InternalGraphGenerator rooted at given opr
    InternalGraphGenerator* create_new_igraph_gen(OperatorNodeBase* opr);

    //! process a single operator, maintaining m_var2igraph_gen
    void process_opr(OperatorNodeBase* opr);

    //! max number of JIT subgraph inputs supported on given comp node
    size_t max_nr_input(CompNode cn);

    //! check whether all oprs which depend on the var are in i_graph
    bool test_all_readers_in_the_graph(VarNode* var,
                                       InternalGraphGenerator* i_graph);

    //! check shape to determine whether the opr should be added to the internal
    //! graph
    bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenerator* i_graph);

    //! use m_rewriter to update graph
    void update_graph();

    //! find the subgraph which can be fused
    void detect_fusion();

    //! check whether an opr can be fused
    bool can_be_fused(cg::OperatorNodeBase* opr) const;

    //! number of vars in \p vars that are not immutable scalars
    static size_t nr_non_const_vars(const VarNodeArray& vars) {
        size_t num = 0;
        for (auto i : vars) {
            num += !SymbolVar{i}.as_immutable_scalar().valid();
        }
        return num;
    }

public:
    Impl(bool after_grad, JITFeatureBits feature_bits, OptState& opt_state)
            : m_after_grad{after_grad},
              m_feature_bits{feature_bits},
              m_opt_state{opt_state},
              m_rewriter{opt_state.graph().make_rewriter()} {
        // the whole pass is executed here; the object is discarded afterwards
        detect_fusion();
        update_graph();
    }
};
  74. void JITFusionPass::Impl::detect_fusion() {
  75. std::vector<OperatorNodeBase*> topo_order;
  76. m_opt_state.graph().iter([this, &topo_order](OperatorNodeBase* opr) {
  77. topo_order.push_back(opr);
  78. for (auto&& i : opr->node_prop().dep_map()) {
  79. m_var_readers[i.first].emplace_back(opr, i.second);
  80. }
  81. });
  82. for (auto opr : reverse_adaptor(topo_order)) {
  83. if (can_be_fused(opr)) {
  84. process_opr(opr);
  85. }
  86. }
  87. }
/*!
 * \brief rewrite the graph, replacing detected fusion endpoints
 *
 * For every endpoint var whose subgraph contains at least two operators,
 * an internal graph is generated and the var is replaced by a JITExecutor
 * node fed with the (possibly already rewritten) original inputs.
 */
void JITFusionPass::Impl::update_graph() {
    auto process = [this](OperatorNodeBase* opr) {
        // skip oprs on devices without a JIT compiler backend
        if (!Compiler::is_supported_device(
                    opr->output(0)->comp_node().device_type()))
            return;
        // replace var with a fused JITExecutor if it is a subgraph endpoint
        auto fuse_varnode = [this](VarNode* var) {
            auto ig_gen_iter = m_var2igraph_gen.find(var);
            if (ig_gen_iter == m_var2igraph_gen.end()) {
                return;
            }
            auto ig_gen = ig_gen_iter->second;
            // fusing a single opr would gain nothing, hence the >= 2 check
            if (m_endpoint_set.count(var) != 0 &&
                ig_gen->opr_set().size() >= 2) {
                auto igraph = ig_gen->generate();
                auto&& inputs = ig_gen->orig_inps();
                if (m_after_grad || nr_non_const_vars(inputs) == 1) {
                    // in the forward pass, only fuse oprs with one non-const
                    // inp
                    // map each original input to its counterpart in the
                    // rewritten graph
                    VarNodeArray rewritten_inputs;
                    for (auto&& input : inputs) {
                        auto new_input = m_rewriter.get_var(input);
                        rewritten_inputs.push_back(new_input);
                    }
                    auto fusion_op =
                            JITExecutor::make(igraph, rewritten_inputs);
                    m_rewriter.replace_var(
                            var, fusion_op.node(),
                            mgb_ssprintf_log("fuse endpoint: %s",
                                             var->owner_opr()->cname())
                                    .c_str());
                }
            }
        };
        for (auto i : opr->input()) {
            if (!m_rewriter.has_manual_replace(i)) {
                // if input i is an endpoint, and the number of oprs in this
                // subgraph is at least 2, it gets fused here
                m_opt_state.call_with_opr(i->owner_opr(),
                                          [&] { fuse_varnode(i); });
            }
        }
        m_rewriter.auto_replace_outputs(opr);
        if (m_opt_state.graph().endpoint_contain(opr->output(0))) {
            // process final endpoint
            fuse_varnode(opr->output(0));
        }
    };
    m_opt_state.graph().iter(process);
    m_rewriter.apply_inplace();
}
  138. bool JITFusionPass::Impl::test_all_readers_in_the_graph(
  139. VarNode* var, InternalGraphGenerator* ig_gen) {
  140. for (auto&& reader : m_var_readers.at(var)) {
  141. if (reader.second & DepType::DEV_VALUE) {
  142. if (ig_gen->opr_set().count(reader.first) == 0) {
  143. return false;
  144. }
  145. }
  146. }
  147. return true;
  148. }
/*!
 * \brief check whether \p opr's shapes are compatible with the subgraph
 *      \p ig_gen so that it can be added as an internal node
 *
 * Without the REDUCE feature the opr's output shape must match the
 * subgraph output (or the dimshuffle input it feeds). With REDUCE
 * enabled, oprs located before a reduce must match the pre-reduce shape;
 * JITExecutor and Reduce oprs get dedicated checks on their input/output
 * shapes.
 */
bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr,
                                      InternalGraphGenerator* ig_gen) {
    if (!cg::is_static_var_shape(opr->output(0))) {
        // currently we do not handle dynamic shape in JIT
        return false;
    }
    if (!(m_feature_bits & JITFeatureBits::REDUCE)) {
        // By requiring opr output shape to be the same as final output shape,
        // we permit only one broadcast. If multiple broadcasts are fused
        // together, execution would be actually slower.
        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
            ig_gen->has_dimshuffle() &&
            ig_gen->oprs_depended_by_dimshuffe().count(opr)) {
            // opr feeds a dimshuffle: compare against the dimshuffle input
            // shape rather than the subgraph output shape
            return opr->output(0)->shape().eq_shape(
                    ig_gen->oprs_depended_by_dimshuffe()
                            .at(opr)
                            ->input(0)
                            ->shape());
        } else {
            return opr->output(0)->shape().eq_shape(ig_gen->output()->shape());
        }
    }
    // whether opr is an (indirect) dependency of a reduce inside the subgraph
    bool before_reduce = false;
    for (auto&& op_set : ig_gen->reduce_out_var_deps()) {
        if (op_set.second.count(opr)) {
            before_reduce = true;
            break;
        }
    }
    if (opr->same_type<JITExecutor>()) {
        // merging an already-fused JITExecutor into this subgraph
        auto jit = &opr->cast_final<JITExecutor>();
        bool jit_has_reduce = jit->has_reduce();
        auto jit_inp_shp = jit->broadcasted_input_shape();
        if (jit_has_reduce) {
            if (before_reduce)
                return jit_inp_shp.eq_shape(jit->output(0)->shape()) &&
                       jit_inp_shp.eq_shape(ig_gen->before_reduce_shape());
            else {
                bool ret = true;
                if (ig_gen->has_reduce()) {
                    ret &= jit_inp_shp.eq_shape(ig_gen->before_reduce_shape());
                }
                ret &= jit->output(0)->shape().eq_shape(
                        ig_gen->output()->shape());
                return ret;
            }
        }
    }
    if (opr->same_type<opr::Reduce>()) {
        // TODO: handle reduce target shape in sub graph (especially considering
        // placeholder has constant shape)
        //
        // The best way is to have a dedicated AST for the internal graph; but
        // we want to reuse the deduplication and gradient mechanisms from the
        // mgb cg
        auto reduce = &opr->cast_final<opr::Reduce>();
        if (before_reduce) {
            // only one pre-reduce shape is tracked, so both ends must match it
            return reduce->input(0)->shape().eq_shape(
                           ig_gen->before_reduce_shape()) &&
                   reduce->output(0)->shape().eq_shape(
                           ig_gen->before_reduce_shape());
        } else {
            bool ret = true;
            if (ig_gen->has_reduce()) {
                ret &= reduce->input(0)->shape().eq_shape(
                        ig_gen->before_reduce_shape());
            }
            ret &= reduce->output(0)->shape().eq_shape(
                    ig_gen->output()->shape());
            return ret;
        }
    }
    // plain oprs: match the pre-reduce shape or the subgraph output shape
    if (before_reduce) {
        return opr->output(0)->shape().eq_shape(ig_gen->before_reduce_shape());
    } else {
        return opr->output(0)->shape().eq_shape(ig_gen->output()->shape());
    }
}
  227. InternalGraphGenerator* JITFusionPass::Impl::create_new_igraph_gen(
  228. OperatorNodeBase* opr) {
  229. auto uptr = std::make_unique<InternalGraphGenerator>(opr);
  230. auto ptr = uptr.get();
  231. m_igraph_gen_storage.emplace_back(std::move(uptr));
  232. m_var2igraph_gen[opr->output(0)] = ptr;
  233. m_endpoint_set.insert(opr->output(0));
  234. return ptr;
  235. }
  236. void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
  237. auto max_nr_input = this->max_nr_input(opr->output(0)->comp_node());
  238. if (nr_non_const_vars(opr->input()) > max_nr_input ||
  239. !cg::is_static_var_shape(opr->output(0))) {
  240. return;
  241. }
  242. // dimshuffle should not be an endpoint, because megbrain has lazy
  243. // dimshuffle machanism
  244. InternalGraphGenerator* ig_gen = nullptr;
  245. if (m_var2igraph_gen.count(opr->output(0)) == 0) {
  246. // because of the reverse traversal, when an operator is being
  247. // processed but not in m_var2igraph_gen, means it is a endpoint of a
  248. // JIT subgraph.
  249. if (opr->same_type<opr::Dimshuffle>()) {
  250. return;
  251. }
  252. ig_gen = create_new_igraph_gen(opr);
  253. } else {
  254. ig_gen = m_var2igraph_gen[opr->output(0)];
  255. // if all oprs which depend on this elemwise opr's output were already
  256. // in the subgraph and the opr's comp_node is same with the subgraph's,
  257. // then this opr can be fused to this graph as an internal node rather
  258. // than a leaf.
  259. bool cond_readers =
  260. test_all_readers_in_the_graph(opr->output(0), ig_gen),
  261. cond_cn = opr->output(0)->comp_node() ==
  262. ig_gen->output()->comp_node(),
  263. cond_shp = check_shape(opr, ig_gen),
  264. cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
  265. cond_mlir_specific = true;
  266. #if MGB_JIT_MLIR
  267. //! FIXME mlir does't support broadcast currently.
  268. auto backend = MGB_GETENV("MGB_JIT_BACKEND");
  269. if (backend && !strcmp(backend, "MLIR")) {
  270. for (VarNode* var : opr->input()) {
  271. if (!SymbolVar{var}.as_immutable_scalar().valid()) {
  272. if (opr->node_prop().dep_map().at(var) &
  273. DepType::DEV_VALUE) {
  274. if (!var->shape().eq_shape(opr->output(0)->shape())) {
  275. cond_mlir_specific = false;
  276. }
  277. }
  278. }
  279. }
  280. }
  281. #endif
  282. if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
  283. cond_mlir_specific) {
  284. ig_gen->add_opr(opr);
  285. } else {
  286. if (opr->same_type<opr::Dimshuffle>()) {
  287. return;
  288. }
  289. // create a new sub graph starting from this opr
  290. mgb_log_debug(
  291. "JIT graph stopped at opr %s{%s}: cond: readers=%d cn=%d "
  292. "shp=%d nr_inp=%d",
  293. opr->cname(), opr->dyn_typeinfo()->name, cond_readers,
  294. cond_cn, cond_shp, cond_nr_inp);
  295. ig_gen = create_new_igraph_gen(opr);
  296. }
  297. }
  298. // handle const inputs
  299. for (auto&& i : opr->node_prop().dep_map()) {
  300. if (i.second & cg::OperatorNodeBase::NodeProp::DepType::DEV_VALUE) {
  301. if (SymbolVar{i.first}
  302. .as_immutable_scalar_require_shape()
  303. .valid()) {
  304. auto opr = i.first->owner_opr();
  305. mgb_assert(opr->same_type<opr::ImmutableTensor>(),
  306. "got imm scalar from non ImmutableTensor: %s{%s}",
  307. opr->cname(), opr->dyn_typeinfo()->name);
  308. ig_gen->add_opr(opr);
  309. continue;
  310. }
  311. }
  312. m_var2igraph_gen[i.first] = ig_gen;
  313. }
  314. }
  315. size_t JITFusionPass::Impl::max_nr_input(CompNode cn) {
  316. auto&& ret = m_cn2max_nr_input[cn];
  317. if (!ret) {
  318. ret = Compiler::get(*m_opt_state.graph().comp_graph(), cn)
  319. ->property()
  320. .max_nr_input;
  321. mgb_assert(ret);
  322. }
  323. return ret;
  324. }
/*!
 * \brief whether \p opr is a candidate for JIT fusion at all
 *
 * Accepts float Elemwise (mode-filtered for the MLIR backend), and — on
 * non-MLIR backends — PowC, float TypeCvt, float Reduce (when the REDUCE
 * feature is on) and Dimshuffle with pattern length <= 4 (when the
 * DIMSHUFFLE feature is on), plus already-fused JITExecutor nodes.
 */
bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
    if (!Compiler::is_supported_device(
                opr->output(0)->comp_node().device_type())) {
        return false;
    }
    //! As MLIR backend has some constraints
    const char* backend = MGB_GETENV("MGB_JIT_BACKEND");
    if (!backend) {
        backend = "DEFAULT";
    }
    // float elemwise
    if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
        bool ret = true;
#if MGB_JIT_MLIR
        if (!strcmp(backend, "MLIR")) {
            // only the elemwise modes enumerated for MLIR are accepted
            switch (elem->param().mode) {
#define cb(_, _mode)                 \
    case opr::Elemwise::Mode::_mode: \
        ret = true;                  \
        break;
                MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                default:
                    ret = false;
#undef cb
            }
#define FOREACH_ELEMWISE_SKIP_MODE(cb) cb(SIN)
            //! FIXME mlir on cuda doesn't support sin currently.
            if (opr->output(0)->comp_node().device_type() ==
                CompNode::DeviceType::CUDA) {
                switch (elem->param().mode) {
#define cb(_mode)                    \
    case opr::Elemwise::Mode::_mode: \
        ret = false;                 \
        break;
                    FOREACH_ELEMWISE_SKIP_MODE(cb)
                    default:
                        break;
#undef cb
                }
            }
#undef FOREACH_ELEMWISE_SKIP_MODE
        }
#endif  // MGB_JIT_MLIR
        return ret && ast_c::check_elem_mode(elem->param().mode) &&
               elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
    }
    // the remaining opr types are only supported on non-MLIR backends
    if (strcmp(backend, "MLIR")) {
        if (opr->same_type<opr::PowC>()) {
            return true;
        }
        // float typecvt (e.g. used in f16 training)
        if (opr->same_type<opr::TypeCvt>()) {
            auto category = opr->input(0)->dtype().category();
            if (category != opr->output(0)->dtype().category())
                return false;
            return category == DTypeCategory::FLOAT;
        }
        // float reduce
        if ((m_feature_bits & JITFeatureBits::REDUCE) &&
            opr->same_type<opr::Reduce>()) {
            return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
        }
        // dimshuffle
        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
            opr->same_type<opr::Dimshuffle>()) {
            auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
            return param.pattern_len <= 4;
        }
    }
    // existing JITExecutor
    if (opr->same_type<JITExecutor>())
        return true;
    return false;
}
  401. JITFusionPass::JITFusionPass(bool after_grad, int8_t jit_opt_level)
  402. : m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
  403. // TODO reduce and dimshuffle can not coexsit now.
  404. if (jit_opt_level >= 2) {
  405. m_feature_bits |= JITFeatureBits::REDUCE;
  406. } else {
  407. m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
  408. }
  409. }
//! name of this pass as shown in optimizer logs
const char* JITFusionPass::name() const {
    return mgb_cstr_log("fusion_pass");
}
void JITFusionPass::apply(OptState& opt) const {
    // Impl runs the whole pass in its constructor; the temporary is
    // discarded immediately afterwards
    Impl{m_after_grad, m_feature_bits, opt};
}
  416. #endif
  417. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台