| @@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) { | |||||
| } | } | ||||
| } | } | ||||
| MegRay::DType get_megray_dtype(megdnn::DType dtype) { | |||||
| switch(dtype.enumv()) { | |||||
| case DTypeEnum::Int8: | |||||
| return MegRay::DType::MEGRAY_INT8; | |||||
| case DTypeEnum::Int32: | |||||
| return MegRay::DType::MEGRAY_INT32; | |||||
| case DTypeEnum::Float32: | |||||
| return MegRay::DType::MEGRAY_FLOAT32; | |||||
| #ifndef MEGDNN_DISABLE_FLOAT16 | |||||
| case DTypeEnum::Float16: | |||||
| return MegRay::DType::MEGRAY_FLOAT16; | |||||
| #endif | |||||
| default: | |||||
| mgb_throw(MegBrainError, "bad CollectiveComm dtype"); | |||||
| } | |||||
| } | |||||
| MegRay::Backend get_megray_backend(const std::string& backend) { | |||||
| if (backend == "nccl") { | |||||
| return MegRay::MEGRAY_NCCL; | |||||
| } else if (backend == "ucx") { | |||||
| return MegRay::MEGRAY_UCX; | |||||
| } else { | |||||
| mgb_throw(MegBrainError, "back CollectiveComm backend"); | |||||
| } | |||||
| } | |||||
| cudaStream_t get_stream(VarNode* var) { | cudaStream_t get_stream(VarNode* var) { | ||||
| return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream; | return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream; | ||||
| } | } | ||||
| @@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() { | |||||
| for (size_t i = 0; i < ishp.ndim; i++) { | for (size_t i = 0; i < ishp.ndim; i++) { | ||||
| data_size *= ishp[i]; | data_size *= ishp[i]; | ||||
| } | } | ||||
| data_size *= tensor.dtype().size(); | |||||
| auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); | |||||
| auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, | |||||
| get_megray_dtype(tensor.dtype()), | |||||
| 1, m_megray_ctx); | |||||
| mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | ||||
| if (m_is_grad) { | if (m_is_grad) { | ||||
| @@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() { | |||||
| for (size_t i = 0; i < ishp.ndim; i++) { | for (size_t i = 0; i < ishp.ndim; i++) { | ||||
| data_size *= ishp[i]; | data_size *= ishp[i]; | ||||
| } | } | ||||
| data_size *= tensor.dtype().size(); | |||||
| auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx); | |||||
| auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, | |||||
| get_megray_dtype(tensor.dtype()), | |||||
| 0, m_megray_ctx); | |||||
| mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); | mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); | ||||
| } | } | ||||
| @@ -14,6 +14,33 @@ | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace opr; | using namespace opr; | ||||
| MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) { | |||||
| switch(dtype.enumv()) { | |||||
| case DTypeEnum::Int8: | |||||
| return MegRay::DType::MEGRAY_INT8; | |||||
| case DTypeEnum::Int32: | |||||
| return MegRay::DType::MEGRAY_INT32; | |||||
| case DTypeEnum::Float32: | |||||
| return MegRay::DType::MEGRAY_FLOAT32; | |||||
| #ifndef MEGDNN_DISABLE_FLOAT16 | |||||
| case DTypeEnum::Float16: | |||||
| return MegRay::DType::MEGRAY_FLOAT16; | |||||
| #endif | |||||
| default: | |||||
| mgb_throw(MegBrainError, "bad CollectiveComm dtype"); | |||||
| } | |||||
| } | |||||
| MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { | |||||
| if (backend == "nccl") { | |||||
| return MegRay::MEGRAY_NCCL; | |||||
| } else if (backend == "ucx") { | |||||
| return MegRay::MEGRAY_UCX; | |||||
| } else { | |||||
| mgb_throw(MegBrainError, "back CollectiveComm backend"); | |||||
| } | |||||
| } | |||||
| bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | ||||
| std::unique_lock<std::mutex> lk(m_map_mtx); | std::unique_lock<std::mutex> lk(m_map_mtx); | ||||
| auto it = m_megray_comms.find(hash); | auto it = m_megray_comms.find(hash); | ||||
| @@ -13,13 +13,16 @@ | |||||
| #include <mutex> | #include <mutex> | ||||
| #include "megbrain/utils/metahelper.h" | |||||
| #include "megbrain/opr/group_manager.h" | #include "megbrain/opr/group_manager.h" | ||||
| #include "megray.h" | #include "megray.h" | ||||
| namespace mgb { | namespace mgb { | ||||
| namespace opr { | namespace opr { | ||||
| MegRay::DType get_megray_dtype(megdnn::DType); | |||||
| MegRay::Backend get_megray_backend(const std::string& backend); | |||||
| /*! | /*! | ||||
| * gather MegRay unique ids and build communicator, use hash for deduplication | * gather MegRay unique ids and build communicator, use hash for deduplication | ||||
| */ | */ | ||||