You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mm_handler.h 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. /**
  2. * \file python_module/src/cpp/mm_handler.h
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  7. *
  8. */
  9. #pragma once
  10. #include "megbrain_build_config.h"
  11. #if MGB_ENABLE_OPR_MM
  12. #include "zmq_rpc.h"
  13. #include "megbrain/opr/collective_comm.h"
  14. #include "mm_handler.pb.h"
  15. using namespace mgb;
  16. using namespace opr;
  17. /*!
  18. * Comm MM Client Proxy.
  19. * proxy the call by using zmqrpc client interact with zmqrpc server.
  20. */
  21. class GroupClientProxy
  22. : public std::enable_shared_from_this<GroupClientProxy>,
  23. public opr::GroupClient {
  24. public:
  25. virtual ~GroupClientProxy() = default;
  26. GroupClientProxy(const std::string& server_addr)
  27. : m_addr(server_addr),
  28. m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
  29. }
  30. //! graph registration, assign graph_id to worker.
  31. uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
  32. uintptr_t stream) override;
  33. std::vector<std::string> gather_uid(const std::string& uid,
  34. const std::string& key, uint32_t size, uint32_t rank) override;
  35. void set_output_shape(const std::string& key,
  36. const TensorShape& shape) override;
  37. TensorShape get_output_shape(const std::string& key) override;
  38. uint32_t group_barrier(uint32_t size, uint32_t rank) override;
  39. //! thread safe to create handler with address
  40. static GroupClientProxy* get_handler(const std::string& addr) {
  41. static std::unordered_map<std::string,
  42. std::unique_ptr<GroupClientProxy>>
  43. addr2handler;
  44. static std::mutex mtx;
  45. MGB_LOCK_GUARD(mtx);
  46. auto it = addr2handler.emplace(addr, nullptr);
  47. if (!it.second) {
  48. mgb_assert(it.first->second->m_addr == addr);
  49. return it.first->second.get();
  50. } else {
  51. auto handler = std::make_unique<GroupClientProxy>(addr);
  52. auto handler_ptr = handler.get();
  53. it.first->second = std::move(handler);
  54. return handler_ptr;
  55. }
  56. }
  57. const std::string& get_addr() const {
  58. return m_addr;
  59. }
  60. private:
  61. const std::string m_addr;
  62. ZmqRpc::ZmqRpcClient* m_stub;
  63. };
  64. #endif
  65. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中已集成使用 GPU 运行代码所需的 CUDA 环境，不需要区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。