diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index 5f6166e1..d80df9b7 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -458,13 +458,16 @@ void CollectiveComm::opr_register() { auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank, reinterpret_cast(cuda_env.stream)); - auto megray_comm_builder = - owner_graph() - ->options() - .user_data - .get_user_data_or_create(); + MegRayCommunicatorBuilder* builder; - m_megray_comm = megray_comm_builder->get_megray_comm( + { + static std::mutex user_data_mtx; + std::unique_lock lk(user_data_mtx); + builder = owner_graph()->options().user_data + .get_user_data_or_create(); + } + + m_megray_comm = builder->get_megray_comm( hash, m_key, m_nr_devices, m_rank, get_megray_backend(m_backend), m_group_client);