/** * \file python_module/src/cpp/mm_handler.h * * This file is part of MegBrain, a deep learning framework developed by Megvii. * * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * */ #pragma once #include "megbrain_build_config.h" #if MGB_ENABLE_OPR_MM #include "zmq_rpc.h" #include "megbrain/opr/collective_comm.h" #include "mm_handler.pb.h" using namespace mgb; using namespace opr; /*! * Comm MM Client Proxy. * proxy the call by using zmqrpc client interact with zmqrpc server. */ class GroupClientProxy : public std::enable_shared_from_this, public opr::GroupClient { public: virtual ~GroupClientProxy() = default; GroupClientProxy(const std::string& server_addr) : m_addr(server_addr), m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { } //! graph registration, assign graph_id to worker. uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, uintptr_t stream) override; std::vector gather_uid(const std::string& uid, const std::string& key, uint32_t size, uint32_t rank) override; void set_output_shape(const std::string& key, const TensorShape& shape) override; TensorShape get_output_shape(const std::string& key) override; uint32_t group_barrier(uint32_t size, uint32_t rank) override; //! thread safe to create handler with address static GroupClientProxy* get_handler(const std::string& addr) { static std::unordered_map> addr2handler; static std::mutex mtx; MGB_LOCK_GUARD(mtx); auto it = addr2handler.emplace(addr, nullptr); if (!it.second) { mgb_assert(it.first->second->m_addr == addr); return it.first->second.get(); } else { auto handler = std::make_unique(addr); auto handler_ptr = handler.get(); it.first->second = std::move(handler); return handler_ptr; } } const std::string& get_addr() const { return m_addr; } private: const std::string m_addr; ZmqRpc::ZmqRpcClient* m_stub; }; #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}