Browse Source

!25878 Add MPI implementation.

Merge pull request !25878 from ZPaC/dir-of-distributed
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
5536279419
14 changed files with 274 additions and 28 deletions
  1. +5
    -2
      mindspore/ccsrc/CMakeLists.txt
  2. +7
    -2
      mindspore/ccsrc/runtime/hardware/CMakeLists.txt
  3. +54
    -0
      mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc
  4. +9
    -3
      mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h
  5. +49
    -0
      mindspore/ccsrc/runtime/hardware/collective/communication_group.cc
  6. +4
    -3
      mindspore/ccsrc/runtime/hardware/collective/communication_group.h
  7. +68
    -0
      mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc
  8. +6
    -3
      mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h
  9. +41
    -0
      mindspore/ccsrc/runtime/hardware/cpu/mpi_communication_group.cc
  10. +22
    -6
      mindspore/ccsrc/runtime/hardware/cpu/mpi_communication_group.h
  11. +1
    -1
      mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h
  12. +3
    -3
      mindspore/ccsrc/runtime/hardware/cpu/ms_communication_group.h
  13. +2
    -2
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h
  14. +3
    -3
      mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.h

+ 5
- 2
mindspore/ccsrc/CMakeLists.txt View File

@@ -110,8 +110,7 @@ if(ENABLE_GPU)
"runtime/device/gpu/distribution/collective_wrapper.cc"
"runtime/device/gpu/distribution/mpi_wrapper.cc"
"runtime/device/gpu/distribution/nccl_wrapper.cc"
"runtime/device/gpu/trt_loader.cc"
)
"runtime/device/gpu/trt_loader.cc")

if(NOT ${TENSORRT_HOME} STREQUAL "")
find_path(TENSORRT_HOME_INCLUDE NvInfer.h HINTS ${TENSORRT_HOME}/include)
@@ -418,4 +417,8 @@ if(ENABLE_D)
endif()
endif()

if(ENABLE_MPI)
target_link_libraries(mindspore mindspore::ompi)
endif()

add_subdirectory(cxx_api)

+ 7
- 2
mindspore/ccsrc/runtime/hardware/CMakeLists.txt View File

@@ -1,5 +1,5 @@
file(GLOB_RECURSE HARDWARE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"device_context_manager.cc")
"device_context_manager.cc" "collective/*.cc")

if(ENABLE_D)
file(GLOB_RECURSE HARDWARE_D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc")
@@ -11,9 +11,14 @@ endif()

if(ENABLE_CPU)
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
if(ENABLE_MPI)
list(APPEND HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
endif()
endif()


set_property(SOURCE ${HARDWARE_SRC_LIST} ${HARDWARE_D_SRC_LIST} ${HARDWARE_GPU_SRC_LIST} ${HARDWARE_CPU_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_hardware_obj OBJECT ${HARDWARE_SRC_LIST} ${HARDWARE_D_SRC_LIST}
${HARDWARE_GPU_SRC_LIST} ${HARDWARE_CPU_SRC_LIST})

+ 54
- 0
mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc View File

@@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/hardware/collective/collective_communication_lib.h"

namespace mindspore {
namespace device {
// Destroys the communication group registered under 'group_name'.
// Throws (via MS_LOG(EXCEPTION)) if the group was never created.
bool CollectiveCommunicationLib::DestroyCommunicationGroup(const std::string &group_name) {
  // Single map lookup instead of count() followed by operator[].
  auto iter = groups_.find(group_name);
  if (iter == groups_.end()) {
    MS_LOG(EXCEPTION) << "The group " << group_name << " is not created.";
    return false;
  }
  const auto &group = iter->second;
  MS_EXCEPTION_IF_NULL(group);
  group->Finalize();
  // Remove the destroyed group from the registry so the name can be reused and
  // GetRankId/GetGroupSize cannot hand out a finalized group afterwards.
  (void)groups_.erase(iter);
  return true;
}

// Returns this process's rank within the group 'group_name', converted from
// its global rank. Throws if the group does not exist.
uint32_t CollectiveCommunicationLib::GetRankId(const std::string &group_name) {
  // Single map lookup instead of count() followed by operator[].
  auto iter = groups_.find(group_name);
  if (iter == groups_.end()) {
    MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist.";
    return UINT32_MAX;
  }
  const auto &group = iter->second;
  MS_EXCEPTION_IF_NULL(group);
  return group->GetGroupRank(global_rank_id_);
}

// Returns the number of processes in the group 'group_name'.
// Throws if the group does not exist.
uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name) {
  // Single map lookup instead of count() followed by operator[].
  auto iter = groups_.find(group_name);
  if (iter == groups_.end()) {
    MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist.";
    return UINT32_MAX;
  }
  const auto &group = iter->second;
  MS_EXCEPTION_IF_NULL(group);
  return group->group_size();
}

uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }
} // namespace device
} // namespace mindspore

+ 9
- 3
mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h View File

@@ -31,11 +31,14 @@ namespace device {
// MsCollectiveCommLib which uses the host-side communication library developed by MindSpore.
class CollectiveCommunicationLib {
public:
CollectiveCommunicationLib() : global_rank_id_(0), local_rank_id_(0), groups_({}) {}
CollectiveCommunicationLib() : global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {}
virtual ~CollectiveCommunicationLib() { groups_.clear(); }

// Initialize collective communication library.
virtual void Initialize() { return; }
// Input 'global_rank' represents this process's global rank.
// Normally, collective communication libraries on host side will generate this rank inside the 'Initialize' method.
// But collective communication libraries on device side needs this input passed by the caller.
virtual void Initialize(uint32_t global_rank = UINT32_MAX) { return; }

// Finalize collective communication library.
virtual void Finalize() { return; }
@@ -46,7 +49,7 @@ class CollectiveCommunicationLib {
}

// Destroy the communication group.
virtual bool DestroyCommunicationGroup(const std::string &group_name) { return true; }
virtual bool DestroyCommunicationGroup(const std::string &group_name);

// Get the rank id of this process in the specified group.
uint32_t GetRankId(const std::string &group_name);
@@ -64,6 +67,9 @@ class CollectiveCommunicationLib {
// The local rank id of this process within the same node. This is usually used as device id.
uint32_t local_rank_id_;

// The global rank size. Normally this is equal to `total process number`.
uint32_t global_rank_size_;

// This map stores the groups which will be accessed and used by the caller.
std::map<std::string, std::shared_ptr<CommunicationGroup>> groups_;
};


+ 49
- 0
mindspore/ccsrc/runtime/hardware/collective/communication_group.cc View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/hardware/collective/communication_group.h"

namespace mindspore {
namespace device {
// Builds the bidirectional mapping between global ranks and group ranks.
// 'group_ranks' holds the global ranks of this group's members; the i-th
// entry becomes group rank i.
CommunicationGroup::CommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks)
    : size_(group_ranks.size()), name_(name), group_ranks_(group_ranks) {
  for (size_t i = 0; i < group_ranks.size(); ++i) {
    const uint32_t global_rank = group_ranks[i];
    const uint32_t group_rank = static_cast<uint32_t>(i);
    global_to_group_ranks_[global_rank] = group_rank;
    group_to_global_ranks_[group_rank] = global_rank;
  }
}
// Converts a global rank to this group's local rank.
// Throws if 'global_rank' is not a member of this group.
uint32_t CommunicationGroup::GetGroupRank(uint32_t global_rank) {
  // Single map lookup instead of count() followed by operator[].
  auto iter = global_to_group_ranks_.find(global_rank);
  if (iter == global_to_group_ranks_.end()) {
    MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the global rank " << global_rank;
    return UINT32_MAX;
  }
  return iter->second;
}

// Converts this group's local rank back to the global rank.
// Throws if 'group_rank' is out of range for this group.
uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
  // Single map lookup instead of count() followed by operator[].
  auto iter = group_to_global_ranks_.find(group_rank);
  if (iter == group_to_global_ranks_.end()) {
    MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the group rank " << group_rank;
    return UINT32_MAX;
  }
  return iter->second;
}

uint32_t CommunicationGroup::group_size() const { return size_; }
} // namespace device
} // namespace mindspore

+ 4
- 3
mindspore/ccsrc/runtime/hardware/collective/communication_group.h View File

@@ -20,7 +20,9 @@
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/utils/convert_utils_base.h"

namespace mindspore {
namespace device {
@@ -28,9 +30,7 @@ namespace device {
// communication group. MindSpore uses 'hccl_world_group' or 'nccl_world_group' as the default group.
class CommunicationGroup {
public:
explicit CommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks)
: size_(size), name_(name), group_ranks_(group_ranks), global_to_group_ranks_({}), group_to_global_ranks_({}) {}

explicit CommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks);
virtual ~CommunicationGroup() {
group_ranks_.clear();
global_to_group_ranks_.clear();
@@ -64,6 +64,7 @@ class CommunicationGroup {
std::map<uint32_t, uint32_t> global_to_group_ranks_;
std::map<uint32_t, uint32_t> group_to_global_ranks_;
};
using CommunicationGroupPtr = std::shared_ptr<CommunicationGroup>;
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COMMUNICATION_GROUP_H_

+ 68
- 0
mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc View File

@@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/hardware/cpu/mpi_collective_comm_lib.h"

namespace mindspore {
namespace device {
namespace cpu {
void MPICollectiveCommLib::Initialize(uint32_t) {
int initialized = 0;
CHECK_MPI_RET(MPI_Initialized(&initialized), "Failed to check MPI initialization status.");
if (initialized == 0) {
CHECK_MPI_RET(MPI_Init(nullptr, nullptr), "Failed to initialize MPI.");
}

// Generated MPI global rank id and rank size for the world group MPI_COMM_WORLD.
int rank_id = 0;
int rank_size = 0;
CHECK_MPI_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id), "Failed to initialize MPI global rank id.");
CHECK_MPI_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_size), "Failed to initialize MPI global rank size.");
global_rank_id_ = IntToUint(rank_id);
global_rank_size_ = IntToUint(rank_size);
MS_LOG(INFO) << "The MPI global rank id of this process is " << global_rank_id_ << ", global rank size is "
<< global_rank_size_;

// Create the world group of MPI because every other group is generated from MPI world group.
CHECK_MPI_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), "Failed to get group of MPI_COMM_WORLD.");
}

// Finalizes every created communication group and clears the registry.
// The world group is stored in groups_ like any other group, so no separate
// world-group teardown is needed here.
void MPICollectiveCommLib::Finalize() {
  for (const auto &[group_name, group] : groups_) {
    (void)group_name;
    MS_EXCEPTION_IF_NULL(group);
    group->Finalize();
  }
  groups_.clear();
}

// Creates an MPI communication group containing the processes whose global
// ranks are listed in 'group_ranks' and registers it under 'group_name'.
// Throws if a group with the same name already exists.
bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
                                                    const std::vector<uint32_t> &group_ranks) {
  // Reject duplicate names: silently re-creating a live group would leak its MPI resources.
  if (groups_.find(group_name) != groups_.end()) {
    MS_LOG(EXCEPTION) << "The MPI group " << group_name << " already exists.";
    return false;
  }

  MPICommunicationGroupPtr group = std::make_shared<MPICommunicationGroup>(group_name, group_ranks);
  MS_EXCEPTION_IF_NULL(group);
  // MPI groups must be derived from the world group obtained in Initialize().
  group->Initialize(world_group_);
  groups_[group_name] = group;
  MS_LOG(INFO) << "MPI group of " << group_name << " is created.";
  return true;
}
} // namespace cpu
} // namespace device
} // namespace mindspore

+ 6
- 3
mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h View File

@@ -21,6 +21,7 @@
#include <vector>
#include <string>
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/mpi_communication_group.h"

namespace mindspore {
namespace device {
@@ -32,15 +33,17 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
return instance;
}

void Initialize() override;
void Initialize(uint32_t global_rank = UINT32_MAX) override;
void Finalize() override;

// Override creating method. Reuse destroying method in base class CollectiveCommunicationLib.
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;
bool DestroyCommunicationGroup(const std::string &group_name) override;

private:
MPICollectiveCommLib() {}
~MPICollectiveCommLib() override;
~MPICollectiveCommLib() override = default;

MPI_Group world_group_;
};
} // namespace cpu
} // namespace device


+ 41
- 0
mindspore/ccsrc/runtime/hardware/cpu/mpi_communication_group.cc View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/hardware/cpu/mpi_communication_group.h"

namespace mindspore {
namespace device {
namespace cpu {
void MPICommunicationGroup::Finalize() {
CHECK_MPI_RET(MPI_Comm_free(&group_communicator_), "Freeing MPI group communicator for " + name_ + " failed.");
CHECK_MPI_RET(MPI_Group_free(&group_), "Freeing MPI group for " + name_ + " failed.");
}

void MPICommunicationGroup::Initialize(const MPI_Group &world_group) {
std::vector<int> ranks(group_ranks_.begin(), group_ranks_.end());
CHECK_MPI_RET(MPI_Group_incl(world_group, ranks.size(), ranks.data(), &group_),
"Creating MPI group for " + name_ + " failed.");

CHECK_MPI_RET(MPI_Comm_create(MPI_COMM_WORLD, group_, &group_communicator_),
"Creating MPI group communicator for " + name_ + " failed.");
if (group_communicator_ == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "The MPI communicator for group " << name_ << " failed.";
return;
}
}
} // namespace cpu
} // namespace device
} // namespace mindspore

+ 22
- 6
mindspore/ccsrc/runtime/hardware/cpu/mpi_communication_group.h View File

@@ -20,24 +20,40 @@
#include <mpi.h>
#include <string>
#include <vector>
#include "runtime/hardware/communication_group.h"
#include <memory>
#include "runtime/hardware/collective/communication_group.h"

namespace mindspore {
namespace device {
namespace cpu {
class MPICommunicationGroup : public CommunicationGroup {
public:
explicit MPICommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks)
: CommunicationGroup(size, name, group_ranks) {}
explicit MPICommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks)
: CommunicationGroup(name, group_ranks) {}

~MPICommunicationGroup() = default;
~MPICommunicationGroup() override = default;

void Initialize() override;
void Initialize() override { return; }
void Finalize() override;

// The OpenMPI groups should be created from the world group.
void Initialize(const MPI_Group &world_group);

private:
MPI_Group mpi_group_;
MPI_Group group_;
MPI_Comm group_communicator_;
};
using MPICommunicationGroupPtr = std::shared_ptr<MPICommunicationGroup>;

#define CHECK_MPI_RET(expression, message) \
do { \
{ \
auto ret = (expression); \
if (ret != MPI_SUCCESS) { \
MS_LOG(EXCEPTION) << (message); \
} \
} \
} while (false)
} // namespace cpu
} // namespace device
} // namespace mindspore


+ 1
- 1
mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h View File

@@ -41,7 +41,7 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {

private:
MsCollectiveCommLib() {}
~MsCollectiveCommLib() override;
~MsCollectiveCommLib() override = default;
};
} // namespace cpu
} // namespace device


+ 3
- 3
mindspore/ccsrc/runtime/hardware/cpu/ms_communication_group.h View File

@@ -26,10 +26,10 @@ namespace device {
namespace cpu {
class MsCommunicationGroup : public CommunicationGroup {
public:
explicit MsCommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks)
: MsCommunicationGroup(size, name, group_ranks) {}
explicit MsCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks)
: MsCommunicationGroup(name, group_ranks) {}

~MsCommunicationGroup() = default;
~MsCommunicationGroup() override = default;

void Initialize() override;
void Finalize() override;


+ 2
- 2
mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h View File

@@ -33,7 +33,7 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
return instance;
}

void Initialize() override {}
void Initialize(uint32_t global_rank = UINT32_MAX) override {}
void Finalize() override {}

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override {
@@ -43,7 +43,7 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {

private:
NvidiaCollectiveCommLib() {}
~NvidiaCollectiveCommLib() override {}
~NvidiaCollectiveCommLib() override = default;
};
} // namespace gpu
} // namespace device


+ 3
- 3
mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.h View File

@@ -27,10 +27,10 @@ namespace device {
namespace gpu {
class NvidiaCommunicationGroup : public CommunicationGroup {
public:
explicit NvidiaCommunicationGroup(uint32_t size, const std::string name, const std::vector<uint32_t> &group_ranks)
: CommunicationGroup(size, name, group_ranks) {}
explicit NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks)
: CommunicationGroup(name, group_ranks) {}

~NvidiaCommunicationGroup() = default;
~NvidiaCommunicationGroup() override = default;

void Initialize() override;
void Finalize() override;


Loading…
Cancel
Save