
graph_opt.cpp 5.9 kB

/**
 * \file src/core/impl/graph/graph_opt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./graph_opt.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/serialization/serializer.h"

using namespace mgb;
using namespace cg;

constexpr size_t MAX_CONST_FOLDING_SIZE = 1024;
// Check whether an operator equal to opr has already been inserted; return
// the saved operator (or its const-folded replacement) if so, else nullptr.
OperatorNodeBase* GraphOptimizer::insert_pre(OperatorNodeBase* opr) {
    auto hash = opr->hash();
    auto iter = m_opr_hash_list.find(hash);
    if (iter != m_opr_hash_list.end()) {
        for (auto i : iter->second) {
            if (i->is_same(*opr)) {
                if (opr->owner_graph()->options().log_level >= 2) {
                    mgb_log_debug(
                            "opr %s{%s} already exists as %s, "
                            "do not insert again",
                            opr->cname(), opr->dyn_typeinfo()->name, i->cname());
                }
                mgb_assert(i->output().size() == opr->output().size());
                if (opr->usable_output().size() == 1) {
                    // if the saved operator's output was const-folded,
                    // return the folded ImmutableTensor instead
                    auto c = m_const_map.find(i->output(0));
                    if (c != m_const_map.end())
                        return c->second;
                }
                return i;
            }
        }
    }
    return nullptr;
}
// Record opr in the hash list and try the rewrite passes on its output;
// returns the (possibly replaced) operator.
OperatorNodeBase* GraphOptimizer::insert_post(OperatorNodeBase* opr) {
    bool already_inserted = false;
    auto hash = opr->hash();
    auto iter = m_opr_hash_list.find(hash);
    if (iter != m_opr_hash_list.end()) {
        for (auto i : iter->second) {
            if (i->is_same(*opr)) {
                already_inserted = true;
                // If the hash of the operator to be saved is already present
                // in m_opr_hash_list, validate that the to-be-saved operator
                // is the original one we saved. If this fails, it usually
                // means insert_post is not paired with a corresponding
                // insert_pre, or the caller did not use the saved operator
                // returned by insert_pre.
                mgb_assert(i == opr);
            }
        }
    }
    if (!already_inserted) {
        m_opr_hash_list[hash].push_back(opr);
    }
#if !MGB_BUILD_SLIM_SERVING
    // in eager mode, return the original opr without running the opt passes
    if (opr->owner_graph()->options().eager_evaluation)
        return opr;
#endif
    OperatorNodeBase* ret = nullptr;
    static const std::array<OperatorNodeBase* (GraphOptimizer::*)(VarNode*), 3>
            passes = {
                    &GraphOptimizer::merge_bcast,
                    &GraphOptimizer::swap_typecvt_and_bcast,
                    &GraphOptimizer::replace_const_var,
            };
    for (auto pass : passes) {
        // the rewrite passes only handle single-output operators
        if (opr->usable_output().size() > 1)
            break;
        ret = (this->*pass)(opr->output(0));
        opr = ret ? ret : opr;
    }
    return opr;
}
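// Illustrative sketch of the intended caller protocol (a hypothetical
// helper, not part of this file): insert_pre is queried first, and
// insert_post is only called with the operator that will actually be kept,
// which is what the mgb_assert(i == opr) above enforces.
//
//     OperatorNodeBase* dedup_insert(GraphOptimizer& opt, OperatorNodeBase* opr) {
//         if (auto* saved = opt.insert_pre(opr))
//             return saved;             // duplicate: reuse the saved operator
//         return opt.insert_post(opr);  // register opr, then run rewrite passes
//     }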
namespace {
// Match a two-operator chain ending at var: var's owner must have type
// \p type and its first input's owner must have type \p prev_type.
Maybe<std::pair<OperatorNodeBase*, OperatorNodeBase*>> match_oprs_in_chain(
        VarNode* var, Typeinfo* type, Typeinfo* prev_type) {
    auto opr = var->owner_opr();
    if (opr->input().size() == 0)
        return {};
    if (opr->dyn_typeinfo() != type)
        return {};
    auto prev_opr = opr->input(0)->owner_opr();
    if (prev_opr->dyn_typeinfo() != prev_type)
        return {};
    return std::pair<OperatorNodeBase*, OperatorNodeBase*>{opr, prev_opr};
}
}  // namespace
OperatorNodeBase* GraphOptimizer::merge_bcast(VarNode* var) {
    if (!is_const_var_value(var))
        return nullptr;
    auto bcast_type = opr::Broadcast::typeinfo();
    auto oprs = match_oprs_in_chain(var, bcast_type, bcast_type);
    if (!oprs.valid())
        return nullptr;
    auto opr = oprs->first;
    auto prev_opr = oprs->second;
    auto new_bcast = opr::Broadcast::make(
            prev_opr->input(0), opr->output(0)->shape(), opr->config());
    return new_bcast.node()->owner_opr();
}
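// Example of the rewrite above (illustrative): for a const-valued chain
// Broadcast(Broadcast(x, s1), s2), the inner broadcast is dropped and a
// single Broadcast(x, s2) is emitted.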
OperatorNodeBase* GraphOptimizer::swap_typecvt_and_bcast(VarNode* var) {
    if (!is_const_var_value(var))
        return nullptr;
    auto oprs = match_oprs_in_chain(
            var, opr::TypeCvt::typeinfo(), opr::Broadcast::typeinfo());
    if (!oprs.valid())
        return nullptr;
    auto opr = oprs->first;
    auto prev_opr = oprs->second;
    auto new_cvt =
            opr::TypeCvt::make(prev_opr->input(0), var->dtype(), opr->config());
    auto new_bcast = opr::Broadcast::make(
            new_cvt, prev_opr->output(0)->shape(), prev_opr->config());
    return new_bcast.node()->owner_opr();
}
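// Example of the rewrite above (illustrative): TypeCvt(Broadcast(x, s), dt)
// becomes Broadcast(TypeCvt(x, dt), s), so the type conversion runs on the
// smaller pre-broadcast tensor.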
OperatorNodeBase* GraphOptimizer::replace_const_var(VarNode* var) {
    if (!is_const_var_value(var))
        return nullptr;
    {
        auto type = var->owner_opr()->dyn_typeinfo();
        if (type == opr::ImmutableTensor::typeinfo())
            return nullptr;
    }
    auto&& mgr = var->owner_graph()->static_infer_manager();
    auto&& shp = mgr.infer_shape(var);
    if (shp.total_nr_elems() >= MAX_CONST_FOLDING_SIZE)
        return nullptr;
    auto&& infer_val = mgr.infer_value(var);
    if (!infer_val.layout().is_contiguous()) {
        return nullptr;
    }
    HostTensorND val;
    val.copy_from(infer_val);
    auto imm = opr::ImmutableTensor::make(
                       *var->owner_graph(), val,
                       OperatorNodeConfig{}.comp_node(var->comp_node()))
                       .node()
                       ->owner_opr();
    m_const_map[var] = imm;
    mgb_assert(imm->output(0)->dtype() == var->dtype());
    return imm;
}
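// Example of the fold above (illustrative): a const-valued var with fewer
// than MAX_CONST_FOLDING_SIZE elements and a contiguous layout is replaced
// by an ImmutableTensor holding its statically inferred value; larger or
// non-contiguous values are left untouched.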
  151. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
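
For context, the insert_pre/insert_post pair above is a hash-consing (common-subexpression-elimination) scheme: operators are bucketed by hash() and compared with is_same(), and duplicates are replaced by the operator saved first. Below is a minimal self-contained sketch of the same idea; the Node and Dedup types are illustrative stand-ins, not MegEngine API.

#include <cstddef>
#include <unordered_map>
#include <vector>

// Simplified stand-in for an operator node: deduplication only needs a
// hash and an equality check over (kind, inputs).
struct Node {
    int op;                     // operator kind
    std::vector<Node*> inputs;  // inputs, assumed already deduplicated

    std::size_t hash() const {
        std::size_t h = static_cast<std::size_t>(op);
        for (auto* i : inputs)
            h = h * 1000003u + reinterpret_cast<std::size_t>(i);
        return h;
    }
    bool is_same(const Node& rhs) const {
        return op == rhs.op && inputs == rhs.inputs;
    }
};

class Dedup {
    // same layout as m_opr_hash_list: hash -> all saved nodes with that hash
    std::unordered_map<std::size_t, std::vector<Node*>> m_hash_list;

public:
    // insert_pre and insert_post collapsed into one call: return the
    // previously saved equivalent node, or save and return the new one
    Node* insert(Node* node) {
        auto& bucket = m_hash_list[node->hash()];
        for (auto* i : bucket)
            if (i->is_same(*node))
                return i;  // duplicate: reuse the saved node
        bucket.push_back(node);
        return node;
    }
};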

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, visit MegStudio.