
opr_shallow_copy.cpp 7.3 kB

/**
 * \file src/serialization/impl/opr_shallow_copy.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
 *
 */

#include "megbrain/serialization/opr_shallow_copy.h"

#include "megbrain/gopt/basic_arith.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h"
#include "megbrain/utils/big_key_hashmap.h"

using namespace mgb;
using namespace serialization;

namespace {

//! dump single opr to memory for shallow copy
class OprDumpContextMemory final : public OprDumpContextRawPOD {
    std::vector<uint8_t> m_buf;

    void write_raw(const void* data, size_t size) override {
        auto pos = m_buf.size();
        auto end = pos + size;
        if (end > m_buf.capacity())
            m_buf.reserve(end * 2);
        m_buf.resize(end);
        memcpy(m_buf.data() + pos, data, size);
    }

    void dump_tensor(const std::string&, const HostTensorND&,
                     TensorWriteMethod) override {
        mgb_throw(GraphError,
                  "OprDumpContextMemory does not support dump tensor");
    }

    const GraphDumpConfig& config() const override {
        mgb_throw(GraphError, "OprDumpContextMemory has no associated config");
    }

public:
    OprDumpContextMemory() : OprDumpContextRawPOD(false) {}

    auto&& buf() const { return m_buf; }
};

//! load single opr from memory for shallow copy
class OprLoadContextMemory final : public OprLoadContextRawPOD {
    const uint8_t* m_ptr;
    size_t m_size, m_pos = 0;
    ComputingGraph* m_graph;

    void read_raw(void* dest, size_t size) override {
        auto end = m_pos + size;
        mgb_assert(end <= m_size);
        memcpy(dest, m_ptr + m_pos, size);
        m_pos = end;
    }

    ComputingGraph& graph() override { return *m_graph; }

    std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }

    std::shared_ptr<DeviceTensorND> load_tensor_shared() override {
        mgb_assert(0);
    }

    const GraphLoadConfig& config() const override {
        mgb_throw(GraphError, "OprLoadContextMemory has no associated config");
    }

public:
    OprLoadContextMemory(ComputingGraph* graph,
                         const OprDumpContextMemory& dumper)
            : OprLoadContextRawPOD(false),
              m_ptr{dumper.buf().data()},
              m_size{dumper.buf().size()},
              m_graph{graph} {}

    ~OprLoadContextMemory() { mgb_assert(m_pos == m_size); }
};

class ShallowCopyCacheContainer final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    struct HashEq {
        template <typename T>
        static bool eq(const T& x, const T& y) {
            return x == y;
        }
        static bool eq(const OperatorNodeConfig& x,
                       const OperatorNodeConfig& y) {
            return x.is_same(y);
        }
        static size_t hash(const void* ptr) {
            return std::hash<const void*>{}(ptr);
        }
        static size_t hash(const VarNodeArray& inputs) {
            return PODHash<VarNode*>::perform(inputs.data(), inputs.size());
        }
        static size_t hash(const OperatorNodeConfig& config) {
            return config.hash();
        }
    };

public:
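    // Maps the key (source opr, new inputs, config) to the operator
    // produced by a previous shallow copy; HashEq above supplies the
    // per-component hash and equality functions.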
    big_key_hash_map::BigKeyHashMap<
            cg::OperatorNodeBase*, HashEq,
            big_key_hash_map::Copy<const cg::OperatorNodeBase*>,
            big_key_hash_map::Ref<VarNodeArray>,
            big_key_hash_map::Ref<OperatorNodeConfig>>
            cache;
};
MGB_TYPEINFO_OBJ_IMPL(ShallowCopyCacheContainer);

}  // anonymous namespace

ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs) const {
    if (!m_owner_graph) {
        if (inputs.empty())
            return opr.owner_graph();
        return inputs[0]->owner_graph();
    }
    if (!inputs.empty())
        mgb_assert(m_owner_graph == inputs[0]->owner_graph());
    return m_owner_graph;
}

cg::OperatorNodeBase* serialization::copy_opr_shallow(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
        const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
    auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
    mgb_assert(registry, "could not find OprReceiver to copy opr %s{%s}",
               opr.cname(), opr.dyn_typeinfo()->name);
    mgb_assert(inputs.size() == opr.input().size());
    auto dst_og = ctx.owner_graph(opr, inputs);
    auto do_copy = [&]() {
        auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
        auto ret = registry->shallow_copy(ctx, opr, inputs, config);
        if (dst_og != opr.owner_graph() ||
            opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
            auto&& attr = ret->node_prop().attribute();
            if (!attr.src_opr) {
                auto src = cg::get_opr_root_source_opr(
                        const_cast<cg::OperatorNodeBase*>(&opr));
                if (ret != src)
                    attr.src_opr = src;
            }
            if (!attr.priority) {
                // priority may have been changed by OprInserted event handlers
                // (like in python case)
                attr.priority = opr.node_prop().attribute().priority;
            }
        }
        return ret;
    };
    cg::OperatorNodeBase* ret;
    if (dst_og == opr.owner_graph()) {
        // use cache for copy in same graph
        auto&& cache =
                dst_og->options()
                        .user_data
                        .get_user_data_or_create<ShallowCopyCacheContainer>()
                        ->cache;
        auto ins = cache.get(&opr, inputs, config);
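        // get() returns (inserted, value_ptr): on a miss the fresh slot is
        // filled by an actual copy; on a hit the cached opr is reused, with
        // its output shapes refreshed in case input shapes have changed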
        if (ins.first) {
            *ins.second = do_copy();
        } else {
            cg::update_output_var_shapes(*ins.second);
        }
        ret = *ins.second;
    } else {
        ret = do_copy();
    }
    mgb_assert(gopt::has_inplace_basic_arith_opt(opr) ||
               (( // outputs match
                 opr.usable_output().size() ==
                 ret->usable_output().size()) &&
                ( // new opr is returned
                 (&opr != ret) || opr.input() == inputs)),
               "bad opr copy: src=%s{%s} dst=%s{%s}", opr.cname(),
               opr.dyn_typeinfo()->name, ret->cname(),
               ret->dyn_typeinfo()->name);
    return ret;
}
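
// The default implementation round-trips the operator through an in-memory
// dump/load using its registered serializers, substituting the new inputs
// and config at load time.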
cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
        const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
    mgb_assert(registry && registry->dumper && registry->loader,
               "can not shallow_copy operator %s{%s}: "
               "no dumper/loader registered",
               opr.cname(), opr.dyn_typeinfo()->name);
    OprDumpContextMemory dumper;
    registry->dumper(dumper, opr);
    OprLoadContextMemory loader{opr.owner_graph(), dumper};
    return registry->loader(loader, inputs, config);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
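
For reference, a minimal sketch of how a caller might use copy_opr_shallow to rewire an existing operator onto new input vars. The rewire helper is hypothetical, and the default-constructed OprShallowCopyContext is an assumption: it is consistent with the null m_owner_graph path handled in owner_graph() above, under which the destination graph is deduced from the inputs.

    #include "megbrain/serialization/opr_shallow_copy.h"

    using namespace mgb;

    // Hypothetical helper: duplicate `opr` so that it reads from
    // `new_inputs` instead of its original inputs.
    cg::OperatorNodeBase* rewire(const cg::OperatorNodeBase& opr,
                                 const VarNodeArray& new_inputs,
                                 const OperatorNodeConfig& config) {
        // Assumed: a default context leaves m_owner_graph unset, so the
        // destination graph is taken from new_inputs (or from opr itself
        // when they are empty), per OprShallowCopyContext::owner_graph.
        serialization::OprShallowCopyContext ctx;
        return serialization::copy_opr_shallow(opr, new_inputs, config, ctx);
    }

Since the copy here stays in the source graph whenever the deduced destination matches, repeated calls with the same (opr, inputs, config) key would be served from the per-graph ShallowCopyCacheContainer rather than creating duplicate operators.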
