
mm_handler.cpp

/**
 * \file python_module/src/cpp/mm_handler.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 */

#include "mm_handler.h"

#include "megbrain/exception.h"
#include "megbrain_config.h"

#if MGB_ENABLE_OPR_MM

#include "zmq_rpc.h"

#include <future>
/* ======================== GroupServerProxy ========================== */

/*!
 * A proxy that receives zmqrpc calls and forwards them to the NCCL manager.
 */
#define RUNSERVER(rpc_name)                                        \
    if (std::strcmp(describe, #rpc_name) == 0) {                   \
        std::string output;                                        \
        rpc_name(input_ptr, input_len, &output);                   \
        reply.rebuild(output.length());                            \
        memcpy(reply.data(), output.data(), output.length());      \
        return;                                                    \
    }

class GroupServerProxy final : public ZmqRpc::ZmqRpcServerImpl {
public:
    void solve_request(zmq::message_t& request,
                       zmq::message_t& reply) override {
        // Request layout: a null-terminated method name followed by the
        // serialized protobuf payload.
        char* describe = (char*)request.data();
        void* input_ptr = (char*)request.data() + strlen(describe) + 1;
        size_t input_len = request.size() - strlen(describe) - 1;
        RUNSERVER(opr_register);
        RUNSERVER(set_output_shape);
        RUNSERVER(get_output_shape);
        RUNSERVER(gather_uid);
        RUNSERVER(group_barrier);
        mgb_assert(false, "invalid rpc request");
    }

private:
    void opr_register(void* input_ptr, size_t input_len, std::string* output);
    void set_output_shape(void* input_ptr, size_t input_len,
                          std::string* output);
    void get_output_shape(void* input_ptr, size_t input_len,
                          std::string* output);
    void gather_uid(void* input_ptr, size_t input_len, std::string* output);
    void group_barrier(void* input_ptr, size_t input_len, std::string* output);

private:
    GroupManager m_mgr;
};
#undef RUNSERVER
#define INFO_INIT(space, name)               \
    using Request = space::name##Request;    \
    using Response = space::name##Response;  \
    Request req;                             \
    Response rsp;                            \
    req.ParseFromArray(input_ptr, input_len);

void GroupServerProxy::opr_register(void* input_ptr, size_t input_len,
                                    std::string* output) {
    INFO_INIT(mm_handler, OprRegister);
    uint64_t hash = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
                                       req.rank(), req.stream());
    rsp.set_hash(hash);
    rsp.SerializeToString(output);
}

void GroupServerProxy::set_output_shape(void* input_ptr, size_t input_len,
                                        std::string* output) {
    INFO_INIT(mm_handler, SetOutputShape);
    auto&& shape_proto = req.shape();
    TensorShape shape;
    shape.ndim = shape_proto.ndim();
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape.shape[i] = shape_proto.shape(i);
    }
    m_mgr.set_output_shape(req.key(), shape);
    rsp.SerializeToString(output);
}

void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len,
                                        std::string* output) {
    INFO_INIT(mm_handler, GetOutputShape);
    auto shape = m_mgr.get_output_shape(req.key());
    auto&& shape_proto = *rsp.mutable_shape();
    shape_proto.set_ndim(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape_proto.add_shape(shape[i]);
    }
    rsp.SerializeToString(output);
}

void GroupServerProxy::gather_uid(void* input_ptr, size_t input_len,
                                  std::string* output) {
    INFO_INIT(mm_handler, GatherUid);
    auto uid = req.uid();
    auto uids = m_mgr.gather_uid(uid, req.key(), req.size(), req.rank());
    for (size_t i = 0; i < uids.size(); i++) {
        rsp.add_uids();
        rsp.set_uids(i, uids[i].data(), uids[i].size());
    }
    rsp.SerializeToString(output);
}

void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len,
                                     std::string* output) {
    INFO_INIT(mm_handler, GroupBarrier);
    uint32_t rsp_size = m_mgr.group_barrier(req.size(), req.rank());
    rsp.set_size(rsp_size);
    rsp.SerializeToString(output);
}
#undef INFO_INIT
/* ======================== GroupClientProxy ========================== */

#define INFO_INIT(space, f_name, name)       \
    using Request = space::name##Request;    \
    using Response = space::name##Response;  \
    std::string func_name = #f_name;         \
    Request req;                             \
    Response rsp;

#define SOLVE_REQUEST(name, req, rsp)                              \
    std::string req_str;                                           \
    mgb_assert(req.SerializeToString(&req_str));                   \
    zmq::message_t send(req_str.length() + name.length() + 1);     \
    zmq::message_t recv;                                           \
    memcpy(send.data(), name.data(), name.length() + 1);           \
    memcpy((char*)send.data() + name.length() + 1, req_str.data(), \
           req_str.length());                                      \
    m_stub->request(send, recv);                                   \
    mgb_assert(rsp.ParseFromArray(recv.data(), recv.size()));

uint64_t GroupClientProxy::opr_register(const std::string& key,
                                        size_t nr_devices, uint32_t rank,
                                        uintptr_t stream) {
    INFO_INIT(mm_handler, opr_register, OprRegister)
    req.set_key(key);
    req.set_rank(rank);
    req.set_stream(stream);
    req.set_nr_expected_devices(nr_devices);
    SOLVE_REQUEST(func_name, req, rsp);
    return rsp.hash();
}

void GroupClientProxy::set_output_shape(const std::string& key,
                                        const TensorShape& shape) {
    INFO_INIT(mm_handler, set_output_shape, SetOutputShape)
    req.set_key(key);
    auto&& shape_proto = *req.mutable_shape();
    shape_proto.set_ndim(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape_proto.add_shape(shape[i]);
    }
    SOLVE_REQUEST(func_name, req, rsp);
}

TensorShape GroupClientProxy::get_output_shape(const std::string& key) {
    INFO_INIT(mm_handler, get_output_shape, GetOutputShape)
    req.set_key(key);
    SOLVE_REQUEST(func_name, req, rsp);
    TensorShape shape;
    shape.ndim = rsp.shape().ndim();
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape[i] = rsp.shape().shape(i);
    }
    return shape;
}

std::vector<std::string> GroupClientProxy::gather_uid(const std::string& uid,
                                                      const std::string& key,
                                                      uint32_t size,
                                                      uint32_t rank) {
    INFO_INIT(mm_handler, gather_uid, GatherUid);
    req.set_uid(uid.data(), uid.size());
    req.set_key(key.data(), key.size());
    req.set_size(size);
    req.set_rank(rank);
    SOLVE_REQUEST(func_name, req, rsp);
    std::vector<std::string> rst;
    for (size_t i = 0; i < size; i++) {
        rst.push_back(rsp.uids(i));
    }
    return rst;
}

uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
    INFO_INIT(mm_handler, group_barrier, GroupBarrier);
    req.set_size(size);
    req.set_rank(rank);
    SOLVE_REQUEST(func_name, req, rsp);
    return rsp.size();
}
#undef INFO_INIT
#undef SOLVE_REQUEST
/* ======================== ZmqRpcServerMgr ========================== */

class ZmqRpcServerMgr {
    struct ServerInfo {
        std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
    };

public:
    int create_zmqrpc_server(const std::string& server_addr, int port,
                             std::unique_ptr<ZmqRpc::ZmqRpcServerImpl> service) {
        MGB_LOCK_GUARD(m_mtx);
        auto server = std::make_unique<ZmqRpc::ZmqRpcServer>(
                "tcp://" + server_addr, port, std::move(service));
        port = server->port();
        if (port == -1) {
            return -1;
        }
        auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port);
        server->run();
        auto ins = m_addr2server.emplace(full_srv_addr,
                                         ServerInfo{std::move(server)});
        mgb_assert(ins.second);
        return port;
    }

    static ZmqRpcServerMgr* get_zmqrpc_server_mgr() {
        static ZmqRpcServerMgr mgr;
        return &mgr;
    }

private:
    std::unordered_map<std::string, ServerInfo> m_addr2server;
    std::mutex m_mtx;
};
/*! See declaration in src/cpp/megbrain_config.h.
 * Create the mm server. Port 0 is permitted; zmqrpc then decides which port
 * to use.
 */
int _config::create_mm_server(const std::string& server_addr, int port) {
    return ZmqRpcServerMgr::get_zmqrpc_server_mgr()->create_zmqrpc_server(
            server_addr, port, std::make_unique<GroupServerProxy>());
}

/* ======================== Group Barrier ========================== */

/*! See declaration in src/cpp/megbrain_config.h.
 * Block until all ranks in the group reach this barrier.
 */
void _config::group_barrier(const std::string& server_addr, int port,
                            uint32_t size, uint32_t rank) {
    mgb_assert(rank < size, "invalid rank %d", rank);
    auto group_mgr = std::make_shared<GroupClientProxy>(
            ssprintf("%s:%d", server_addr.c_str(), port));
    uint32_t rsp = group_mgr->group_barrier(size, rank);
    mgb_assert(rsp != 0, "rank already registered: %d", rank);
    mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
}

#else

int _config::create_mm_server(const std::string& server_addr, int port) {
    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
    return 0;
}

void _config::group_barrier(const std::string& server_addr, int port,
                            uint32_t size, uint32_t rank) {
    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
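
For context, here is a minimal usage sketch of the two entry points this file exports: one process starts the zmqrpc server, then every rank blocks on the barrier until all of them have arrived. The `main` harness, the loopback address, and the 4-rank group size are illustrative assumptions, not part of the file; it also presumes a build with MGB_ENABLE_OPR_MM enabled.

#include "megbrain_config.h"

#include <thread>
#include <vector>

int main() {
    // Port 0 lets zmqrpc pick a free port; the return value is the port
    // actually bound, or -1 on failure.
    int port = _config::create_mm_server("127.0.0.1", 0);

    // Simulate a 4-rank group: each thread returns from group_barrier()
    // only once all ranks have reached it.
    const uint32_t size = 4;
    std::vector<std::thread> ranks;
    for (uint32_t rank = 0; rank < size; ++rank) {
        ranks.emplace_back([=] {
            _config::group_barrier("127.0.0.1", port, size, rank);
        });
    }
    for (auto&& t : ranks) {
        t.join();
    }
    return 0;
}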

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep learning development on cloud GPU compute, you are welcome to visit the MegStudio platform.