Merge pull request !25878 from ZPaC/dir-of-distributed (tags/v1.6.0)
| @@ -110,8 +110,7 @@ if(ENABLE_GPU) | |||
| "runtime/device/gpu/distribution/collective_wrapper.cc" | |||
| "runtime/device/gpu/distribution/mpi_wrapper.cc" | |||
| "runtime/device/gpu/distribution/nccl_wrapper.cc" | |||
| "runtime/device/gpu/trt_loader.cc" | |||
| ) | |||
| "runtime/device/gpu/trt_loader.cc") | |||
| if(NOT ${TENSORRT_HOME} STREQUAL "") | |||
| find_path(TENSORRT_HOME_INCLUDE NvInfer.h HINTS ${TENSORRT_HOME}/include) | |||
| @@ -418,4 +417,8 @@ if(ENABLE_D) | |||
| endif() | |||
| endif() | |||
| if(ENABLE_MPI) | |||
| target_link_libraries(mindspore mindspore::ompi) | |||
| endif() | |||
| add_subdirectory(cxx_api) | |||
| @@ -1,5 +1,5 @@ | |||
| file(GLOB_RECURSE HARDWARE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "device_context_manager.cc") | |||
| "device_context_manager.cc" "collective/*.cc") | |||
| if(ENABLE_D) | |||
| file(GLOB_RECURSE HARDWARE_D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc") | |||
| @@ -11,9 +11,14 @@ endif() | |||
| if(ENABLE_CPU) | |||
| file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") | |||
| list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc") | |||
| if(ENABLE_MPI) | |||
| list(APPEND HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc") | |||
| endif() | |||
| endif() | |||
| set_property(SOURCE ${HARDWARE_SRC_LIST} ${HARDWARE_D_SRC_LIST} ${HARDWARE_GPU_SRC_LIST} ${HARDWARE_CPU_SRC_LIST} | |||
| PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) | |||
| PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) | |||
| add_library(_mindspore_runtime_hardware_obj OBJECT ${HARDWARE_SRC_LIST} ${HARDWARE_D_SRC_LIST} | |||
| ${HARDWARE_GPU_SRC_LIST} ${HARDWARE_CPU_SRC_LIST}) | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/hardware/collective/collective_communication_lib.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| bool CollectiveCommunicationLib::DestroyCommunicationGroup(const std::string &group_name) { | |||
| if (groups_.count(group_name) == 0) { | |||
| MS_LOG(EXCEPTION) << "The group " << group_name << " is not created."; | |||
| return false; | |||
| } | |||
| auto group = groups_[group_name]; | |||
| MS_EXCEPTION_IF_NULL(group); | |||
| group->Finalize(); | |||
| return true; | |||
| } | |||
| uint32_t CollectiveCommunicationLib::GetRankId(const std::string &group_name) { | |||
| if (groups_.count(group_name) == 0) { | |||
| MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist."; | |||
| return UINT32_MAX; | |||
| } | |||
| auto group = groups_[group_name]; | |||
| MS_EXCEPTION_IF_NULL(group); | |||
| return group->GetGroupRank(global_rank_id_); | |||
| } | |||
| uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name) { | |||
| if (groups_.count(group_name) == 0) { | |||
| MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist."; | |||
| return UINT32_MAX; | |||
| } | |||
| auto group = groups_[group_name]; | |||
| MS_EXCEPTION_IF_NULL(group); | |||
| return group->group_size(); | |||
| } | |||
| uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; } | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -31,11 +31,14 @@ namespace device { | |||
| // MsCollectiveCommLib which uses the host-side communication library developed by MindSpore. | |||
| class CollectiveCommunicationLib { | |||
| public: | |||
| CollectiveCommunicationLib() : global_rank_id_(0), local_rank_id_(0), groups_({}) {} | |||
| CollectiveCommunicationLib() : global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {} | |||
| virtual ~CollectiveCommunicationLib() { groups_.clear(); } | |||
| // Initialize collective communication library. | |||
| virtual void Initialize() { return; } | |||
| // Input 'global_rank' represents this process's global rank. | |||
| // Normally, collective communication libraries on host side will generate this rank inside the 'Initialize' method. | |||
| // But collective communication libraries on device side needs this input passed by the caller. | |||
| virtual void Initialize(uint32_t global_rank = UINT32_MAX) { return; } | |||
| // Finalize collective communication library. | |||
| virtual void Finalize() { return; } | |||
| @@ -46,7 +49,7 @@ class CollectiveCommunicationLib { | |||
| } | |||
| // Destroy the communication group. | |||
| virtual bool DestroyCommunicationGroup(const std::string &group_name) { return true; } | |||
| virtual bool DestroyCommunicationGroup(const std::string &group_name); | |||
| // Get the rank id of this process in the specified group. | |||
| uint32_t GetRankId(const std::string &group_name); | |||
| @@ -64,6 +67,9 @@ class CollectiveCommunicationLib { | |||
| // The local rank id of this process within the same node. This is usually used as device id. | |||
| uint32_t local_rank_id_; | |||
| // The global rank size. Normally this is equal to `total process number`. | |||
| uint32_t global_rank_size_; | |||
| // This map stores the groups which will be accessed and used by the caller. | |||
| std::map<std::string, std::shared_ptr<CommunicationGroup>> groups_; | |||
| }; | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/hardware/collective/communication_group.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| CommunicationGroup::CommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : size_(group_ranks.size()), name_(name), group_ranks_(group_ranks) { | |||
| uint32_t group_rank = 0; | |||
| // The input group_ranks contains the global ranks of the processes in this group. | |||
| (void)std::for_each(group_ranks.begin(), group_ranks.end(), [&](const uint32_t &global_rank) { | |||
| global_to_group_ranks_[global_rank] = group_rank; | |||
| group_to_global_ranks_[group_rank] = global_rank; | |||
| group_rank++; | |||
| }); | |||
| } | |||
| uint32_t CommunicationGroup::GetGroupRank(uint32_t global_rank) { | |||
| if (global_to_group_ranks_.count(global_rank) == 0) { | |||
| MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the global rank " << global_rank; | |||
| return UINT32_MAX; | |||
| } | |||
| return global_to_group_ranks_[global_rank]; | |||
| } | |||
| uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) { | |||
| if (group_to_global_ranks_.count(group_rank) == 0) { | |||
| MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the group rank " << group_rank; | |||
| return UINT32_MAX; | |||
| } | |||
| return group_to_global_ranks_[group_rank]; | |||
| } | |||
| uint32_t CommunicationGroup::group_size() const { return size_; } | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -20,7 +20,9 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| #include "mindspore/core/utils/convert_utils_base.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -28,9 +30,7 @@ namespace device { | |||
| // communication group. MindSpore uses 'hccl_world_group' or 'nccl_world_group' as the default group. | |||
| class CommunicationGroup { | |||
| public: | |||
| explicit CommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : size_(size), name_(name), group_ranks_(group_ranks), global_to_group_ranks_({}), group_to_global_ranks_({}) {} | |||
| explicit CommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks); | |||
| virtual ~CommunicationGroup() { | |||
| group_ranks_.clear(); | |||
| global_to_group_ranks_.clear(); | |||
| @@ -64,6 +64,7 @@ class CommunicationGroup { | |||
| std::map<uint32_t, uint32_t> global_to_group_ranks_; | |||
| std::map<uint32_t, uint32_t> group_to_global_ranks_; | |||
| }; | |||
| using CommunicationGroupPtr = std::shared_ptr<CommunicationGroup>; | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COMMUNICATION_GROUP_H_ | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/hardware/cpu/mpi_collective_comm_lib.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace cpu { | |||
| void MPICollectiveCommLib::Initialize(uint32_t) { | |||
| int initialized = 0; | |||
| CHECK_MPI_RET(MPI_Initialized(&initialized), "Failed to check MPI initialization status."); | |||
| if (initialized == 0) { | |||
| CHECK_MPI_RET(MPI_Init(nullptr, nullptr), "Failed to initialize MPI."); | |||
| } | |||
| // Generate the MPI global rank id and rank size for the world group MPI_COMM_WORLD. | |||
| int rank_id = 0; | |||
| int rank_size = 0; | |||
| CHECK_MPI_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id), "Failed to initialize MPI global rank id."); | |||
| CHECK_MPI_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size), "Failed to initialize MPI global rank size."); | |||
| global_rank_id_ = IntToUint(rank_id); | |||
| global_rank_size_ = IntToUint(rank_size); | |||
| MS_LOG(INFO) << "The MPI global rank id of this process is " << global_rank_id_ << ", global rank size is " | |||
| << global_rank_size_; | |||
| // Create the world group of MPI because every other group is generated from MPI world group. | |||
| CHECK_MPI_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), "Failed to get group of MPI_COMM_WORLD."); | |||
| } | |||
| void MPICollectiveCommLib::Finalize() { | |||
| // The world group is also stored in groups_. So we don't need to finalize world group separately. | |||
| for (const auto &group : groups_) { | |||
| MS_EXCEPTION_IF_NULL(group.second); | |||
| group.second->Finalize(); | |||
| } | |||
| groups_.clear(); | |||
| } | |||
| bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_name, | |||
| const std::vector<uint32_t> &group_ranks) { | |||
| if (groups_.count(group_name) != 0) { | |||
| MS_LOG(EXCEPTION) << "The MPI group " << group_name << " has already existed."; | |||
| return false; | |||
| } | |||
| MPICommunicationGroupPtr group = std::make_shared<MPICommunicationGroup>(group_name, group_ranks); | |||
| MS_EXCEPTION_IF_NULL(group); | |||
| group->Initialize(world_group_); | |||
| groups_[group_name] = group; | |||
| MS_LOG(INFO) << "MPI group of " << group_name << " is created."; | |||
| return true; | |||
| } | |||
| } // namespace cpu | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "runtime/hardware/collective/collective_communication_lib.h" | |||
| #include "runtime/hardware/cpu/mpi_communication_group.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -32,15 +33,17 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib { | |||
| return instance; | |||
| } | |||
| void Initialize() override; | |||
| void Initialize(uint32_t global_rank = UINT32_MAX) override; | |||
| void Finalize() override; | |||
| // Override creating method. Reuse destroying method in base class CollectiveCommunicationLib. | |||
| bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override; | |||
| bool DestroyCommunicationGroup(const std::string &group_name) override; | |||
| private: | |||
| MPICollectiveCommLib() {} | |||
| ~MPICollectiveCommLib() override; | |||
| ~MPICollectiveCommLib() override = default; | |||
| MPI_Group world_group_; | |||
| }; | |||
| } // namespace cpu | |||
| } // namespace device | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/hardware/cpu/mpi_communication_group.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace cpu { | |||
| void MPICommunicationGroup::Finalize() { | |||
| CHECK_MPI_RET(MPI_Comm_free(&group_communicator_), "Freeing MPI group communicator for " + name_ + " failed."); | |||
| CHECK_MPI_RET(MPI_Group_free(&group_), "Freeing MPI group for " + name_ + " failed."); | |||
| } | |||
| void MPICommunicationGroup::Initialize(const MPI_Group &world_group) { | |||
| std::vector<int> ranks(group_ranks_.begin(), group_ranks_.end()); | |||
| CHECK_MPI_RET(MPI_Group_incl(world_group, ranks.size(), ranks.data(), &group_), | |||
| "Creating MPI group for " + name_ + " failed."); | |||
| CHECK_MPI_RET(MPI_Comm_create(MPI_COMM_WORLD, group_, &group_communicator_), | |||
| "Creating MPI group communicator for " + name_ + " failed."); | |||
| if (group_communicator_ == MPI_COMM_NULL) { | |||
| MS_LOG(EXCEPTION) << "Failed to create the MPI communicator for group " << name_ << ": communicator is MPI_COMM_NULL."; | |||
| return; | |||
| } | |||
| } | |||
| } // namespace cpu | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -20,24 +20,40 @@ | |||
| #include <mpi.h> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "runtime/hardware/communication_group.h" | |||
| #include <memory> | |||
| #include "runtime/hardware/collective/communication_group.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace cpu { | |||
| class MPICommunicationGroup : public CommunicationGroup { | |||
| public: | |||
| explicit MPICommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : CommunicationGroup(size, name, group_ranks) {} | |||
| explicit MPICommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : CommunicationGroup(name, group_ranks) {} | |||
| ~MPICommunicationGroup() = default; | |||
| ~MPICommunicationGroup() override = default; | |||
| void Initialize() override; | |||
| void Initialize() override { return; } | |||
| void Finalize() override; | |||
| // The OpenMPI groups should be created from the world group. | |||
| void Initialize(const MPI_Group &world_group); | |||
| private: | |||
| MPI_Group mpi_group_; | |||
| MPI_Group group_; | |||
| MPI_Comm group_communicator_; | |||
| }; | |||
| using MPICommunicationGroupPtr = std::shared_ptr<MPICommunicationGroup>; | |||
| #define CHECK_MPI_RET(expression, message) \ | |||
| do { \ | |||
| { \ | |||
| auto ret = (expression); \ | |||
| if (ret != MPI_SUCCESS) { \ | |||
| MS_LOG(EXCEPTION) << (message); \ | |||
| } \ | |||
| } \ | |||
| } while (false) | |||
| } // namespace cpu | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -41,7 +41,7 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib { | |||
| private: | |||
| MsCollectiveCommLib() {} | |||
| ~MsCollectiveCommLib() override; | |||
| ~MsCollectiveCommLib() override = default; | |||
| }; | |||
| } // namespace cpu | |||
| } // namespace device | |||
| @@ -26,10 +26,10 @@ namespace device { | |||
| namespace cpu { | |||
| class MsCommunicationGroup : public CommunicationGroup { | |||
| public: | |||
| explicit MsCommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : MsCommunicationGroup(size, name, group_ranks) {} | |||
| explicit MsCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : CommunicationGroup(name, group_ranks) {} | |||
| ~MsCommunicationGroup() = default; | |||
| ~MsCommunicationGroup() override = default; | |||
| void Initialize() override; | |||
| void Finalize() override; | |||
| @@ -33,7 +33,7 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib { | |||
| return instance; | |||
| } | |||
| void Initialize() override {} | |||
| void Initialize(uint32_t global_rank = UINT32_MAX) override {} | |||
| void Finalize() override {} | |||
| bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override { | |||
| @@ -43,7 +43,7 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib { | |||
| private: | |||
| NvidiaCollectiveCommLib() {} | |||
| ~NvidiaCollectiveCommLib() override {} | |||
| ~NvidiaCollectiveCommLib() override = default; | |||
| }; | |||
| } // namespace gpu | |||
| } // namespace device | |||
| @@ -27,10 +27,10 @@ namespace device { | |||
| namespace gpu { | |||
| class NvidiaCommunicationGroup : public CommunicationGroup { | |||
| public: | |||
| explicit NvidiaCommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : CommunicationGroup(size, name, group_ranks) {} | |||
| explicit NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks) | |||
| : CommunicationGroup(name, group_ranks) {} | |||
| ~NvidiaCommunicationGroup() = default; | |||
| ~NvidiaCommunicationGroup() override = default; | |||
| void Initialize() override; | |||
| void Finalize() override; | |||