GitOrigin-RevId: d5ae3c5a7c
tags/v1.0.0-rc1
| @@ -687,11 +687,21 @@ SymbolVarArray CollectiveComm::make( | |||||
| void CollectiveComm::opr_register() { | void CollectiveComm::opr_register() { | ||||
| if (m_init) | if (m_init) | ||||
| return; | return; | ||||
| auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
| bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; | |||||
| struct GroupManager::RegisterInfo reg_info; | |||||
| auto reg_info = m_group_client->opr_register( | |||||
| m_key, m_nr_devices, m_is_root, m_rank, | |||||
| comp_node.get_uid()); | |||||
| if (use_cache and RegInfoCache::has_info(m_key)) { | |||||
| reg_info = RegInfoCache::get_info(m_key); | |||||
| } else { | |||||
| reg_info = m_group_client->opr_register( | |||||
| m_key, m_nr_devices, m_is_root, m_rank, | |||||
| comp_node.get_uid()); | |||||
| if (use_cache) { | |||||
| RegInfoCache::set_info(m_key, reg_info); | |||||
| } | |||||
| } | |||||
| m_rank = reg_info.rank; | m_rank = reg_info.rank; | ||||
| m_root = reg_info.root_rank; | m_root = reg_info.root_rank; | ||||
| @@ -205,4 +205,20 @@ uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) { | |||||
| return m_barrier_size; | return m_barrier_size; | ||||
| } | } | ||||
| void RegInfoCache::set_info(const std::string& key, | |||||
| const GroupManager::RegisterInfo& info) { | |||||
| std::unique_lock<std::mutex> lock(RegInfoCache::mtx); | |||||
| RegInfoCache::key2info[key] = info; | |||||
| } | |||||
| bool RegInfoCache::has_info(const std::string& key) { | |||||
| std::unique_lock<std::mutex> lock(RegInfoCache::mtx); | |||||
| return RegInfoCache::key2info.find(key) != RegInfoCache::key2info.end(); | |||||
| } | |||||
| GroupManager::RegisterInfo RegInfoCache::get_info(const std::string& key) { | |||||
| std::unique_lock<std::mutex> lock(RegInfoCache::mtx); | |||||
| return RegInfoCache::key2info[key]; | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -53,10 +53,19 @@ SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, | |||||
| void RemoteSend::scn_do_execute() { | void RemoteSend::scn_do_execute() { | ||||
| if (!m_init) { | if (!m_init) { | ||||
| auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
| bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; | |||||
| struct GroupManager::RegisterInfo reg_info; | |||||
| // rank 0 for RemoteSend | |||||
| auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, | |||||
| comp_node.get_uid()); | |||||
| if (use_cache and RegInfoCache::has_info(m_key)) { | |||||
| reg_info = RegInfoCache::get_info(m_key); | |||||
| } else { | |||||
| // rank 0 for RemoteSend | |||||
| reg_info = m_group_client->opr_register(m_key, 2, 0, false, | |||||
| comp_node.get_uid()); | |||||
| if (use_cache) { | |||||
| RegInfoCache::set_info(m_key, reg_info); | |||||
| } | |||||
| } | |||||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
| reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); | reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); | ||||
| @@ -153,11 +162,20 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||||
| void RemoteRecv::scn_do_execute() { | void RemoteRecv::scn_do_execute() { | ||||
| if (!m_init) { | if (!m_init) { | ||||
| auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
| bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; | |||||
| struct GroupManager::RegisterInfo reg_info; | |||||
| // rank 1 for RemoteRecv | |||||
| auto reg_info = m_group_client->opr_register( | |||||
| m_key, 2, false, 1, | |||||
| comp_node.get_uid()); | |||||
| if (use_cache and RegInfoCache::has_info(m_key)) { | |||||
| reg_info = RegInfoCache::get_info(m_key); | |||||
| } else { | |||||
| // rank 1 for RemoteRecv | |||||
| reg_info = m_group_client->opr_register( | |||||
| m_key, 2, false, 1, | |||||
| comp_node.get_uid()); | |||||
| if (use_cache) { | |||||
| RegInfoCache::set_info(m_key, reg_info); | |||||
| } | |||||
| } | |||||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
| reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); | reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); | ||||
| @@ -145,6 +145,22 @@ class GroupClient { | |||||
| virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0; | virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0; | ||||
| }; | }; | ||||
| /*! | |||||
| * Cache RegisterInfo returned from GroupManager. This feature is only enabled | |||||
| * in imperative runtime mode, so that multi-machine operators do not have to | |||||
| * call opr_register repeatedly in each iter | |||||
| */ | |||||
| namespace RegInfoCache { | |||||
| static std::mutex mtx; | |||||
| static std::unordered_map<std::string, GroupManager::RegisterInfo> key2info; | |||||
| void set_info(const std::string& key, const GroupManager::RegisterInfo& info); | |||||
| bool has_info(const std::string& key); | |||||
| GroupManager::RegisterInfo get_info(const std::string& key); | |||||
| } // namespace RegInfoCache | |||||
| } // namespace opr | } // namespace opr | ||||
| } // namespace mgb | } // namespace mgb | ||||