
io_remote.cpp 9.0 kB

/**
 * \file src/opr-mm/impl/io_remote.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/io_remote.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/serialization/sereg.h"

using namespace mgb;
using namespace opr;
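// Fetch the raw CUDA stream bound to the variable's comp node. It is used
// below to build the MegRay communication context on the same stream, so
// transfers stay ordered with the computation that produces or consumes the
// data.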
cudaStream_t get_stream(VarNode* var) {
    return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
/* ===================== RemoteSend ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);

RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                       std::shared_ptr<GroupClient> group_client,
                       bool is_grad, const OperatorNodeConfig& config) :
        Super(var->owner_graph(), config, "remote_send", {var}),
        m_is_grad(is_grad) {
    m_key = key;
    m_group_client = group_client;
    add_input({var});
    auto ovar = add_output(None);
    if (!m_is_grad) {
        ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
    }
    add_equivalence_component<ScalarHash<void*>>(this);
}

SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                           std::shared_ptr<GroupClient> group_client,
                           bool is_grad, const OperatorNodeConfig& config) {
    return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
                                                    is_grad, config);
}
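// scn_do_execute() lazily sets up the channel on first run: the operator
// registers itself with the group server under m_key as rank 0 of a 2-rank
// group (reusing a cached registration when the imperative proxy graph is
// active), obtains an NCCL-backed MegRay communicator, and wraps the output
// comp node's CUDA stream in a MegRay context. Every run then sends the
// input tensor's raw bytes to peer rank 1; when the operator is part of a
// gradient computation (m_is_grad), it also fills its scalar output with
// zero so the graph has a well-defined value.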
void RemoteSend::scn_do_execute() {
    if (!m_init) {
        auto&& comp_node = output(0)->comp_node();
        bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
        struct GroupManager::RegisterInfo reg_info;
        if (use_cache and RegInfoCache::has_info(m_key)) {
            reg_info = RegInfoCache::get_info(m_key);
        } else {
            // rank 0 for RemoteSend
            reg_info = m_group_client->opr_register(m_key, 2, 0, false,
                                                    comp_node.get_uid());
            if (use_cache) {
                RegInfoCache::set_info(m_key, reg_info);
            }
        }
        m_megray_comm = MegRayCommBuilder::get_megray_comm(
                reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
        m_init = true;
    }

    mgb_assert(m_init);
    size_t data_size = 1;
    auto&& tensor = input(0)->dev_tensor();
    auto&& ishp = tensor.shape();
    for (size_t i = 0; i < ishp.ndim; i++) {
        data_size *= ishp[i];
    }
    data_size *= tensor.dtype().size();
    auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");

    if (m_is_grad) {
        auto&& dest = output(0)->dev_tensor();
        if (m_output_val.empty()) {
            m_output_val.comp_node(dest.comp_node())
                    .dtype(dest.dtype())
                    .resize({1});
            memset(m_output_val.raw_ptr(), 0, m_output_val.dtype().size());
        }
        dest.copy_from_fixlayout(m_output_val);
    }
}
void RemoteSend::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto do_infer = [this](TensorShape& dest, const InpVal&) {
        if (m_is_grad) {
            dest = {1};
        } else {
            dest = {0};
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, do_infer});
}

cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
    auto prop = RemoteIOBase::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    return prop;
}
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RemoteSend) {
    mgb_assert(opr.is_grad());
    return RemoteRecv::make(opr.key() + ":grad",
                            *opr.owner_graph(), opr.group_client(),
                            OperatorNodeConfig{opr.comp_node()}.name(
                                    opr.name() + ":grad_recv"),
                            opr.input(0)->shape(), opr.input(0)->dtype())
            .node();
}
#endif
/* ===================== RemoteRecv ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);

RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                       std::shared_ptr<GroupClient> group_client,
                       const OperatorNodeConfig& config,
                       const TensorShape& shape, DType dtype) :
        Super(&graph, config, "remote_recv", {}),
        m_shape(shape), m_dtype(dtype) {
    m_key = key;
    m_group_client = group_client;
    add_output(None)
            ->dtype(dtype)
            .add_flag(VarNode::Flag::NO_MEM_RECLAIM)
            .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
    add_equivalence_component<ScalarHash<void*>>(this);
}

SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                           std::shared_ptr<GroupClient> group_client,
                           const OperatorNodeConfig& config,
                           const TensorShape& shape, DType dtype) {
    auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
            key, graph, group_client, config, shape, dtype));
    return opr->output(0);
}
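// The receiving side mirrors RemoteSend: on the first execution it registers
// under the same key as rank 1 of the 2-rank group, obtains the matching
// NCCL-backed MegRay communicator and a CUDA context on the output stream,
// and then receives the peer's raw bytes directly into the output tensor,
// whose shape and dtype were fixed at construction time.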
void RemoteRecv::scn_do_execute() {
    if (!m_init) {
        auto&& comp_node = output(0)->comp_node();
        bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
        struct GroupManager::RegisterInfo reg_info;
        if (use_cache and RegInfoCache::has_info(m_key)) {
            reg_info = RegInfoCache::get_info(m_key);
        } else {
            // rank 1 for RemoteRecv
            reg_info = m_group_client->opr_register(
                    m_key, 2, false, 1,
                    comp_node.get_uid());
            if (use_cache) {
                RegInfoCache::set_info(m_key, reg_info);
            }
        }
        m_megray_comm = MegRayCommBuilder::get_megray_comm(
                reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
        m_init = true;
    }

    mgb_assert(m_init);
    size_t data_size = 1;
    auto&& tensor = output(0)->dev_tensor();
    auto&& ishp = tensor.shape();
    for (size_t i = 0; i < ishp.ndim; i++) {
        data_size *= ishp[i];
    }
    data_size *= tensor.dtype().size();
    auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
}
void RemoteRecv::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto do_infer = [this](TensorShape& dest, const InpVal&) {
        dest = m_shape;
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, do_infer});
}

cg::OperatorNodeBase::NodeProp* RemoteRecv::do_make_node_prop() const {
    auto prop = RemoteIOBase::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
    if (input().size() == 1)
        prop->reset_dep_type(input(), {NodeProp::DepType::DEV_COMP_ORDER});
    return prop;
}
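// The registrations below let the graph copier and serialization machinery
// rebuild a RemoteSend or RemoteRecv when an existing graph is shallow-copied:
// each functor simply re-invokes the corresponding make() with the original
// key and group client but the new inputs and config.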
/* ===================== shallow copy ===================== */

namespace mgb {
namespace opr {

cg::OperatorNodeBase* opr_shallow_copy_remote_send(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.size() == 1);
    auto&& opr = opr_.cast_final_safe<RemoteSend>();
    return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
                            opr.is_grad(), config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(RemoteSend, opr_shallow_copy_remote_send);

cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<RemoteRecv>();
    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                            opr.group_client(), config, inputs[0]->shape(),
                            inputs[0]->dtype())
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
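Taken together, a point-to-point transfer is built by inserting a RemoteSend on the producing process and a RemoteRecv with the same key on the consuming one. The snippet below is a minimal sketch of that pairing, not taken from the repository: `client` is assumed to be a concrete GroupClient shared by both processes, `x` an existing SymbolVar on the sender, and `graph` and `shape` placeholders on the receiver.

// Sender process (registers as rank 0): forward the value of `x` to the peer.
auto send_marker = opr::RemoteSend::make(
        "feature:0", x, client, /*is_grad=*/false, {});

// Receiver process (registers as rank 1): same key, explicit shape and dtype,
// output pinned to a chosen comp node.
auto y = opr::RemoteRecv::make(
        "feature:0", *graph, client,
        OperatorNodeConfig{CompNode::load("gpu0")}, shape, dtype::Float32());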

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build to choose from. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.