You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

fusion_pass.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. /**
  2. * \file src/jit/impl/fusion_pass.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/jit/fusion_pass.h"
  12. #include "megbrain/common.h"
  13. #include "megbrain/gopt/gtrans.h"
  14. #include "megbrain/jit/ast_c.h"
  15. #include "megbrain/jit/compiler.h"
  16. #include "megbrain/jit/internal_graph.h"
  17. #include "megbrain/opr/tensor_manip.h"
  18. #include "megbrain/serialization/serializer.h"
  19. #if MGB_JIT
  20. using namespace mgb;
  21. using namespace gopt;
  22. using namespace jit;
//! Implementation of JITFusionPass: walks the graph in reverse topological
//! order, grows fusible subgraphs via InternalGraphGenerator, then rewrites
//! each subgraph endpoint into a JITExecutor opr.
class JITFusionPass::Impl final {
    using Mode = opr::Elemwise::Mode;
    using DepType = OperatorNodeBase::NodeProp::DepType;

    //! whether the pass runs after gradient computation; in the forward-only
    //! case fusion is restricted (see update_graph: single non-const input)
    const bool m_after_grad;
    //! enabled JIT features (REDUCE / DIMSHUFFLE), set by JITFusionPass ctor
    JITFeatureBits m_feature_bits;
    OptState& m_opt_state;
    //! per-comp-node cache of the compiler's max_nr_input property
    CompNode::UnorderedMap<size_t> m_cn2max_nr_input;
    SubGraph::Rewriter m_rewriter;
    //! owns all InternalGraphGenerator instances; other maps hold raw ptrs
    SmallVector<std::unique_ptr<InternalGraphGenerator>> m_igraph_gen_storage;
    //! map from var to the subgraph generator it currently belongs to
    ThinHashMap<VarNode*, InternalGraphGenerator*> m_var2igraph_gen;
    //! map from var to its reader oprs and the corresponding dependency types
    ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>>
            m_var_readers;
    //! vars that are the root (endpoint) of some fusible subgraph
    ThinHashSet<VarNode*> m_endpoint_set;

    //! create a new InternalGraphGenerator rooted at given opr
    InternalGraphGenerator* create_new_igraph_gen(OperatorNodeBase* opr);

    //! process a single operator, maintaining m_var2igraph_gen
    void process_opr(OperatorNodeBase* opr);

    //! max number of JIT inputs supported by the compiler for \p cn (cached)
    size_t max_nr_input(CompNode cn);

    //! check whether all oprs which depend on the var are in i_graph
    bool test_all_readers_in_the_graph(VarNode* var,
                                       InternalGraphGenerator* i_graph);

    //! check shape to determine whether the opr should be added to the internal
    //! graph
    bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenerator* i_graph);

    //! use m_rewriter to update graph
    void update_graph();

    //! find the subgraph which can be fused
    void detect_fusion();

    //! check whether an opr can be fused
    bool can_be_fused(cg::OperatorNodeBase* opr) const;

    //! number of inputs that are not immutable scalars (const inputs do not
    //! count against the JIT input limit)
    static size_t nr_non_const_vars(const VarNodeArray& vars) {
        size_t num = 0;
        for (auto i : vars) {
            num += !SymbolVar{i}.as_immutable_scalar().valid();
        }
        return num;
    }

public:
    //! runs the whole pass in the constructor: detect subgraphs, then rewrite
    Impl(bool after_grad, JITFeatureBits feature_bits, OptState& opt_state)
            : m_after_grad{after_grad},
              m_feature_bits{feature_bits},
              m_opt_state{opt_state},
              m_rewriter{opt_state.graph().make_rewriter()} {
        detect_fusion();
        update_graph();
    }
};
  71. void JITFusionPass::Impl::detect_fusion() {
  72. std::vector<OperatorNodeBase*> topo_order;
  73. m_opt_state.graph().iter([this, &topo_order](OperatorNodeBase* opr) {
  74. topo_order.push_back(opr);
  75. for (auto&& i : opr->node_prop().dep_map()) {
  76. m_var_readers[i.first].emplace_back(opr, i.second);
  77. }
  78. });
  79. for (auto opr : reverse_adaptor(topo_order)) {
  80. if (can_be_fused(opr)) {
  81. process_opr(opr);
  82. }
  83. }
  84. }
//! Rewrite the graph: for every detected subgraph endpoint, replace the
//! endpoint var with a JITExecutor built from the generated internal graph.
void JITFusionPass::Impl::update_graph() {
    auto process = [this](OperatorNodeBase* opr) {
        // skip oprs on devices the JIT compiler does not support
        if (!Compiler::is_supported_device(
                    opr->output(0)->comp_node().device_type()))
            return;
        // replace \p var with a JITExecutor if it is the endpoint of a
        // subgraph containing at least two oprs
        auto fuse_varnode = [this](VarNode* var) {
            auto ig_gen_iter = m_var2igraph_gen.find(var);
            if (ig_gen_iter == m_var2igraph_gen.end()) {
                return;
            }
            auto ig_gen = ig_gen_iter->second;
            if (m_endpoint_set.count(var) != 0 &&
                ig_gen->opr_set().size() >= 2) {
                auto igraph = ig_gen->generate();
                auto&& inputs = ig_gen->orig_inps();
                if (m_after_grad || nr_non_const_vars(inputs) == 1) {
                    // in the forward pass, only fuse oprs with one non-const
                    // inp
                    // map original inputs through the rewriter so the
                    // JITExecutor reads already-rewritten vars
                    VarNodeArray rewritten_inputs;
                    for (auto&& input : inputs) {
                        auto new_input = m_rewriter.get_var(input);
                        rewritten_inputs.push_back(new_input);
                    }
                    auto fusion_op =
                            JITExecutor::make(igraph, rewritten_inputs);
                    m_rewriter.replace_var(
                            var, fusion_op.node(),
                            mgb_ssprintf_log("fuse endpoint: %s",
                                             var->owner_opr()->cname())
                                    .c_str());
                }
            }
        };
        for (auto i : opr->input()) {
            if (!m_rewriter.has_manual_replace(i)) {
                // if input i is a endpoint, and number of oprs in this subgraph
                // is greater than 2
                m_opt_state.call_with_opr(i->owner_opr(),
                                          [&] { fuse_varnode(i); });
            }
        }
        m_rewriter.auto_replace_outputs(opr);
        if (m_opt_state.graph().endpoint_contain(opr->output(0))) {
            // process final endpoint
            fuse_varnode(opr->output(0));
        }
    };
    m_opt_state.graph().iter(process);
    m_rewriter.apply_inplace();
}
  135. bool JITFusionPass::Impl::test_all_readers_in_the_graph(
  136. VarNode* var, InternalGraphGenerator* ig_gen) {
  137. for (auto&& reader : m_var_readers.at(var)) {
  138. if (reader.second & DepType::DEV_VALUE) {
  139. if (ig_gen->opr_set().count(reader.first) == 0) {
  140. return false;
  141. }
  142. }
  143. }
  144. return true;
  145. }
//! Decide, from shapes only, whether \p opr may be added to the subgraph of
//! \p ig_gen. Without REDUCE support only shape-preserving oprs (plus a
//! single broadcast / a dimshuffle exception) are allowed; with REDUCE
//! support, shapes are checked against the before-reduce and output shapes.
bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr,
                                      InternalGraphGenerator* ig_gen) {
    if (!cg::is_static_var_shape(opr->output(0))) {
        // currently we do not handle dynamic shape in JIT
        return false;
    }
    if (!(m_feature_bits & JITFeatureBits::REDUCE)) {
        // By requiring opr output shape to be the same as final output shape,
        // we permit only one broadcast. If multiple broadcasts are fused
        // together, execution would be actually slower.
        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
            ig_gen->has_dimshuffle() &&
            ig_gen->oprs_depended_by_dimshuffe().count(opr)) {
            // oprs feeding a dimshuffle are compared against the
            // dimshuffle's input shape instead of the subgraph output
            return opr->output(0)->shape().eq_shape(
                    ig_gen->oprs_depended_by_dimshuffe()
                            .at(opr)
                            ->input(0)
                            ->shape());
        } else {
            return opr->output(0)->shape().eq_shape(ig_gen->output()->shape());
        }
    }
    // REDUCE enabled: determine whether opr sits on the input side of some
    // reduce in the subgraph
    bool before_reduce = false;
    for (auto&& op_set : ig_gen->reduce_out_var_deps()) {
        if (op_set.second.count(opr)) {
            before_reduce = true;
            break;
        }
    }
    if (opr->same_type<JITExecutor>()) {
        auto jit = &opr->cast_final<JITExecutor>();
        bool jit_has_reduce = jit->has_reduce();
        auto jit_inp_shp = jit->broadcasted_input_shape();
        if (jit_has_reduce) {
            if (before_reduce)
                return jit_inp_shp.eq_shape(jit->output(0)->shape()) &&
                       jit_inp_shp.eq_shape(ig_gen->before_reduce_shape());
            else {
                bool ret = true;
                if (ig_gen->has_reduce()) {
                    ret &= jit_inp_shp.eq_shape(ig_gen->before_reduce_shape());
                }
                ret &= jit->output(0)->shape().eq_shape(
                        ig_gen->output()->shape());
                return ret;
            }
        }
        // NOTE(review): a JITExecutor without reduce falls through to the
        // generic before/after-reduce checks below
    }
    if (opr->same_type<opr::Reduce>()) {
        // TODO: handle reduce target shape in sub graph (especially considering
        // placeholder has constant shape)
        //
        // The best way is to have a dedicated AST for the internal graph; but
        // we want to reuse the deduplication and gradient mechanisms from the
        // mgb cg
        auto reduce = &opr->cast_final<opr::Reduce>();
        if (before_reduce) {
            return reduce->input(0)->shape().eq_shape(
                           ig_gen->before_reduce_shape()) &&
                   reduce->output(0)->shape().eq_shape(
                           ig_gen->before_reduce_shape());
        } else {
            bool ret = true;
            if (ig_gen->has_reduce()) {
                ret &= reduce->input(0)->shape().eq_shape(
                        ig_gen->before_reduce_shape());
            }
            ret &= reduce->output(0)->shape().eq_shape(
                    ig_gen->output()->shape());
            return ret;
        }
    }
    // generic oprs: output shape must match the relevant reference shape
    if (before_reduce) {
        return opr->output(0)->shape().eq_shape(ig_gen->before_reduce_shape());
    } else {
        return opr->output(0)->shape().eq_shape(ig_gen->output()->shape());
    }
}
  224. InternalGraphGenerator* JITFusionPass::Impl::create_new_igraph_gen(
  225. OperatorNodeBase* opr) {
  226. auto uptr = std::make_unique<InternalGraphGenerator>(opr);
  227. auto ptr = uptr.get();
  228. m_igraph_gen_storage.emplace_back(std::move(uptr));
  229. m_var2igraph_gen[opr->output(0)] = ptr;
  230. m_endpoint_set.insert(opr->output(0));
  231. return ptr;
  232. }
  233. void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
  234. auto max_nr_input = this->max_nr_input(opr->output(0)->comp_node());
  235. if (nr_non_const_vars(opr->input()) > max_nr_input ||
  236. !cg::is_static_var_shape(opr->output(0))) {
  237. return;
  238. }
  239. // dimshuffle should not be an endpoint, because megbrain has lazy
  240. // dimshuffle machanism
  241. InternalGraphGenerator* ig_gen = nullptr;
  242. if (m_var2igraph_gen.count(opr->output(0)) == 0) {
  243. // because of the reverse traversal, when an operator is being
  244. // processed but not in m_var2igraph_gen, means it is a endpoint of a
  245. // JIT subgraph.
  246. if (opr->same_type<opr::Dimshuffle>()) {
  247. return;
  248. }
  249. ig_gen = create_new_igraph_gen(opr);
  250. } else {
  251. ig_gen = m_var2igraph_gen[opr->output(0)];
  252. // if all oprs which depend on this elemwise opr's output were already
  253. // in the subgraph and the opr's comp_node is same with the subgraph's,
  254. // then this opr can be fused to this graph as an internal node rather
  255. // than a leaf.
  256. bool cond_readers =
  257. test_all_readers_in_the_graph(opr->output(0), ig_gen),
  258. cond_cn = opr->output(0)->comp_node() ==
  259. ig_gen->output()->comp_node(),
  260. cond_shp = check_shape(opr, ig_gen),
  261. cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input;
  262. if (cond_readers && cond_cn && cond_shp && cond_nr_inp) {
  263. ig_gen->add_opr(opr);
  264. } else {
  265. if (opr->same_type<opr::Dimshuffle>()) {
  266. return;
  267. }
  268. // create a new sub graph starting from this opr
  269. mgb_log_debug(
  270. "JIT graph stopped at opr %s{%s}: cond: readers=%d cn=%d "
  271. "shp=%d nr_inp=%d",
  272. opr->cname(), opr->dyn_typeinfo()->name, cond_readers,
  273. cond_cn, cond_shp, cond_nr_inp);
  274. ig_gen = create_new_igraph_gen(opr);
  275. }
  276. }
  277. // handle const inputs
  278. for (auto&& i : opr->node_prop().dep_map()) {
  279. if (i.second & cg::OperatorNodeBase::NodeProp::DepType::DEV_VALUE) {
  280. if (SymbolVar{i.first}
  281. .as_immutable_scalar_require_shape()
  282. .valid()) {
  283. auto opr = i.first->owner_opr();
  284. mgb_assert(opr->same_type<opr::ImmutableTensor>(),
  285. "got imm scalar from non ImmutableTensor: %s{%s}",
  286. opr->cname(), opr->dyn_typeinfo()->name);
  287. ig_gen->add_opr(opr);
  288. continue;
  289. }
  290. }
  291. m_var2igraph_gen[i.first] = ig_gen;
  292. }
  293. }
  294. size_t JITFusionPass::Impl::max_nr_input(CompNode cn) {
  295. auto&& ret = m_cn2max_nr_input[cn];
  296. if (!ret) {
  297. ret = Compiler::get(*m_opt_state.graph().comp_graph(), cn)
  298. ->property()
  299. .max_nr_input;
  300. mgb_assert(ret);
  301. }
  302. return ret;
  303. }
  304. bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
  305. if (!Compiler::is_supported_device(
  306. opr->output(0)->comp_node().device_type())) {
  307. return false;
  308. }
  309. // float elemwise
  310. if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
  311. return ast_c::check_elem_mode(elem->param().mode) &&
  312. elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
  313. }
  314. if (opr->same_type<opr::PowC>()) {
  315. return true;
  316. }
  317. // float typecvt (e.g. used in f16 training)
  318. if (opr->same_type<opr::TypeCvt>()) {
  319. auto category = opr->input(0)->dtype().category();
  320. if (category != opr->output(0)->dtype().category())
  321. return false;
  322. return category == DTypeCategory::FLOAT;
  323. }
  324. // float reduce
  325. if ((m_feature_bits & JITFeatureBits::REDUCE) &&
  326. opr->same_type<opr::Reduce>()) {
  327. return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
  328. }
  329. // dimshuffle
  330. if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
  331. opr->same_type<opr::Dimshuffle>()) {
  332. auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
  333. return param.pattern_len <= 4;
  334. }
  335. // existing JITExecutor
  336. if (opr->same_type<JITExecutor>())
  337. return true;
  338. return false;
  339. }
  340. JITFusionPass::JITFusionPass(bool after_grad, int8_t jit_opt_level)
  341. : m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
  342. // TODO reduce and dimshuffle can not coexsit now.
  343. if (jit_opt_level >= 2) {
  344. m_feature_bits |= JITFeatureBits::REDUCE;
  345. } else {
  346. m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
  347. }
  348. }
//! Pass identifier used by the graph-optimization framework for logging.
const char* JITFusionPass::name() const {
    return mgb_cstr_log("fusion_pass");
}
//! Run the pass: Impl's constructor performs detection and rewriting, so
//! constructing a temporary is the whole job.
void JITFusionPass::apply(OptState& opt) const {
    Impl{m_after_grad, m_feature_bits, opt};
}
  355. #endif
  356. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)