@@ -15,13 +15,41 @@
  */
 #include "device/cpu/mpi/mpi_adapter.h"
+#ifdef ENABLE_MPI
 #include <algorithm>
-#include "utils/mpi/mpi_config.h"
+#include <sstream>
+#include "pybind11/pybind11.h"
+#endif  // ENABLE_MPI
 #include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace device {
 namespace cpu {
-std::shared_ptr<MPIAdapter> MPIAdapter::instance_ = nullptr;
-std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
-  if (instance_ == nullptr) {
-    MS_LOG(DEBUG) << "Create new mpi adapter instance.";
-    instance_.reset(new (std::nothrow) MPIAdapter());
-  }
-  return instance_;
-}
+#ifdef ENABLE_MPI
+#define RAISE_EXCEPTION(message)                                    \
+  {                                                                 \
+    std::ostringstream oss;                                         \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \
+    pybind11::pybind11_fail(oss.str());                             \
+  }
+
+#define RAISE_EXCEPTION_WITH_PARAM(message, param)                           \
+  {                                                                          \
+    std::ostringstream oss;                                                  \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \
+    pybind11::pybind11_fail(oss.str());                                      \
+  }
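Unlike MS_LOG(EXCEPTION), these macros surface errors to Python: pybind11::pybind11_fail throws a std::runtime_error, which pybind11 converts to a Python RuntimeError at the binding boundary. For illustration, a call such as RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_) expands to roughly:

{
  std::ostringstream oss;
  oss << "[" << __FILE__ << "] [" << __LINE__ << "] "
      << "create mpi group fail!rankid:" << rank_id_;
  pybind11::pybind11_fail(oss.str());  // throws std::runtime_error
}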
 namespace {
 MPI_Op GetMpiOp(const std::string &op_type) {
   if (op_type == "sum") {
@@ -33,7 +61,8 @@ MPI_Op GetMpiOp(const std::string &op_type) {
   } else if (op_type == "prod") {
     return MPI_PROD;
   }
-  MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
+  RAISE_EXCEPTION_WITH_PARAM("unsupported op_type: ", op_type);
   return MPI_SUM;
 }
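GetMpiOp maps the framework's reduce-op names onto MPI's built-in reduction handles; an unrecognized name now raises through pybind11 instead of logging. Hypothetical call sites:

MPI_Op sum_op = GetMpiOp("sum");    // MPI_SUM
MPI_Op prod_op = GetMpiOp("prod");  // MPI_PROD
// A name without a matching branch raises "[<file>] [<line>] unsupported op_type: <name>".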
@@ -46,80 +75,72 @@ int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
     }
   }
   if (scatter_index == -1) {
-    MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
+    RAISE_EXCEPTION_WITH_PARAM("local rankid is not in the input rank group! local rank id:", rankid);
   }
   return scatter_index;
 }
 }  // namespace
 
-MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
+MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); }
 
 MPIAdapter::~MPIAdapter() {
-  int finalized;
-  MPI_Finalized(&finalized);
-  if (finalized != 0) {
-    return;
-  }
   for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) {
     MPI_Group_free(&iter->second);
   }
+  ranks_group_.clear();
   if (comm_group_world_ != MPI_GROUP_NULL) {
     MPI_Group_free(&comm_group_world_);
+    comm_group_world_ = MPI_GROUP_NULL;
   }
-  MPI_Finalize();
+  int finalized;
+  MPI_Finalized(&finalized);
+  if (finalized == 0) {
+    MPI_Finalize();
+  }
 }
 
+MPIAdapter &MPIAdapter::Instance() {
+  static MPIAdapter instance;
+  return instance;
+}
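The new accessor replaces the shared_ptr singleton removed above with a function-local static, the "Meyers singleton". Since C++11 the first call initializes the static exactly once even under concurrent callers, and the instance is destroyed at process exit, which is what gives ~MPIAdapter a chance to run MPI_Finalize. A standalone sketch of the pattern with a hypothetical Widget type:

class Widget {
 public:
  static Widget &Instance() {
    static Widget instance;  // constructed on first use; destroyed at exit
    return instance;
  }
  // Singletons are usually made non-copyable so no second instance can exist.
  Widget(const Widget &) = delete;
  Widget &operator=(const Widget &) = delete;

 private:
  Widget() = default;
};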
+
+int MPIAdapter::GetRankId() const { return rank_id_; }
 
 void MPIAdapter::Init() {
   static bool init = false;
   if (init) {
     return;
   }
-  auto mpi_config_ptr = MpiConfig::GetInstance();
-  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
-  if (!mpi_config_ptr->enable_mpi()) {
-    MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
-  }
   int init_flag = 0;
   if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
+    RAISE_EXCEPTION("Check mpi initialized fail!");
   }
   if (init_flag == 0) {
     auto ret = MPI_Init(nullptr, nullptr);
     if (ret != MPI_SUCCESS) {
-      MS_LOG(EXCEPTION) << "Failed to init mpi!";
+      RAISE_EXCEPTION("Failed to init mpi!");
     }
   }
 
   MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
   if (comm_group_world_ == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "comm_group_world_ init fail!";
+    RAISE_EXCEPTION("comm_group_world_ init fail!");
   }
   auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Failed to init mpi rank id!";
+    RAISE_EXCEPTION("Failed to init mpi rank id!");
   }
 
   ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Failed to init mpi rank size!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_);
   }
   init = true;
 }
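Init() uses the usual "initialize only if nobody else has" idiom: MPI_Initialized merely queries state and is safe to call before MPI_Init, so it can gate initialization when the process may already have started MPI elsewhere. A standalone sketch of that idiom (EnsureMpi is a hypothetical name):

#include <mpi.h>

void EnsureMpi() {
  int initialized = 0;
  MPI_Initialized(&initialized);  // query only; never initializes MPI
  if (initialized == 0) {
    MPI_Init(nullptr, nullptr);   // safe: MPI has not been started yet
  }
}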
 
 MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   if (ranks.size() > static_cast<size_t>(rank_size_) || ranks.empty()) {
-    MS_LOG(EXCEPTION) << "input rank size: " << ranks.size() << ", max rank size: " << rank_size_;
+    RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size());
   }
 
   if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) {
-    MS_LOG(ERROR) << "rankid:" << rank_id_ << " is not in the input group.";
-    return MPI_GROUP_NULL;
+    RAISE_EXCEPTION_WITH_PARAM("local rankid is not in the input rank group! local rank id:", rank_id_);
   }
   std::lock_guard<std::mutex> lock(group_mutex_);
   auto iter = ranks_group_.find(ranks);
@@ -135,29 +156,28 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   MPI_Group group = MPI_GROUP_NULL;
   MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_);
   }
 
   ranks_group_[ranks] = group;
-  MS_LOG(INFO) << "rank:" << rank_id_ << " add group:" << group;
   return group;
 }
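AddGroup caches one MPI_Group per rank list; every collective in this file then turns the group into a temporary communicator with MPI_Comm_create_group, which (unlike MPI_Comm_create) is collective only over the group's members. A self-contained sketch of that pairing, with SubsetComm as a hypothetical helper name:

#include <mpi.h>

#include <vector>

MPI_Comm SubsetComm(const std::vector<int> &ranks) {
  MPI_Group world_group = MPI_GROUP_NULL;
  MPI_Comm_group(MPI_COMM_WORLD, &world_group);

  // Pick out the requested world ranks, mirroring MPI_Group_incl above.
  MPI_Group sub_group = MPI_GROUP_NULL;
  MPI_Group_incl(world_group, static_cast<int>(ranks.size()), ranks.data(), &sub_group);

  // Tag 0 matches the diff; only members of sub_group need to make this call.
  MPI_Comm sub_comm = MPI_COMM_NULL;
  MPI_Comm_create_group(MPI_COMM_WORLD, sub_group, 0, &sub_comm);

  MPI_Group_free(&sub_group);
  MPI_Group_free(&world_group);
  return sub_comm;  // caller releases it with MPI_Comm_free
}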
 
 bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                                const std::string &op_type) {
   if (ranks_group.empty()) {
-    MS_LOG(ERROR) << "input rank group is empty!";
+    RAISE_EXCEPTION("input rank group is empty!");
     return false;
   }
 
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_);
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
   }
   std::vector<int> receive_count(ranks_group.size(), 0);
   for (size_t i = 0; i < ranks_group.size(); ++i) {
@@ -168,13 +188,13 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
   auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm);
   bool result = true;
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi reduce_scatter fail!ret = " << ret << ", rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret);
     result = false;
   }
 
   ret = MPI_Comm_free(&comm);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(WARNING) << "mpi comm free fail! ret = " << ret << ", rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret);
   }
   return result;
 }
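For context, with the equal per-rank blocks set up in receive_count above, MPI_Reduce_scatter behaves like a reduce followed by a scatter: every rank contributes nranks * data_num elements, and rank i receives the element-wise reduction of block i. A minimal standalone sketch (Demo is a hypothetical name, not part of the change):

#include <mpi.h>

#include <vector>

void Demo(MPI_Comm comm, int nranks, int data_num) {
  // Every rank contributes one block of data_num floats per participant.
  std::vector<float> input(static_cast<size_t>(nranks) * data_num, 1.0f);
  std::vector<float> output(data_num, 0.0f);
  std::vector<int> counts(nranks, data_num);  // equal-sized blocks, as above
  MPI_Reduce_scatter(input.data(), output.data(), counts.data(), MPI_FLOAT, MPI_SUM, comm);
  // With nranks participants, every element of output is now nranks * 1.0f.
}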
@@ -184,19 +204,18 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
   int scatter_index = GetScatterIndex(rank_id_, ranks_group);
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_);
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
   }
 
   MPI_Win window;
   auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
-    return false;
+    RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! ret = ", ret);
   }
   MPI_Win_fence(0, window);
   for (size_t i = 0; i < ranks_group.size(); ++i) {
@@ -208,18 +227,20 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
     ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
                          input_data_num, MPI_FLOAT, op, window);
     if (ret != MPI_SUCCESS) {
-      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
+      RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret);
     }
   }
   MPI_Win_fence(0, window);
   if (output != nullptr) {
     auto data_size = input_data_num * sizeof(float);
     if (output_size < data_size) {
-      MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
+      std::ostringstream exception_msg;
+      exception_msg << "output buffer size " << output_size << " < input size " << data_size;
+      RAISE_EXCEPTION(exception_msg.str());
     }
     auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
     if (copy_ret != 0) {
-      MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
+      RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret);
     }
   }
   MPI_Win_free(&window);
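The loop above is one-sided (RMA) communication: each rank exposes its input buffer as an MPI window and peers MPI_Accumulate into it, with the two MPI_Win_fence calls bracketing the access epoch. A reduced standalone sketch of the fence/accumulate pattern (AccumulateDemo is hypothetical; here all non-root ranks add their buffer into rank 0):

#include <mpi.h>

void AccumulateDemo(MPI_Comm comm, float *buf, int count) {
  int rank = 0;
  MPI_Comm_rank(comm, &rank);
  MPI_Win win;
  // Every rank exposes `buf` (count floats) as a window on this communicator.
  MPI_Win_create(buf, static_cast<MPI_Aint>(count) * sizeof(float), sizeof(float),
                 MPI_INFO_NULL, comm, &win);
  MPI_Win_fence(0, win);  // open the access/exposure epoch
  if (rank != 0) {
    // Element-wise add our buffer into rank 0's window at displacement 0.
    MPI_Accumulate(buf, count, MPI_FLOAT, 0 /*target rank*/, 0 /*target disp*/,
                   count, MPI_FLOAT, MPI_SUM, win);
  }
  MPI_Win_fence(0, win);  // close the epoch; rank 0's buf now holds its own
                          // values plus the contributions of all other ranks
  MPI_Win_free(&win);
}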
@@ -229,31 +250,31 @@
 
 bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
-    MS_LOG(ERROR) << "input rank group is empty!";
+    RAISE_EXCEPTION("input rank group is empty!");
     return false;
   }
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail! rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_);
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail! rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_);
   }
 
   auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm);
-  bool result = true;
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi allgater fail!ret = " << ret << ", rankid:" << rank_id_;
-    result = false;
+    RAISE_EXCEPTION_WITH_PARAM("mpi allgather fail!ret = ", ret);
   }
   ret = MPI_Comm_free(&comm);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(WARNING) << "mpi comm free fail!ret = " << ret << ",rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret);
   }
-  return result;
+  return true;
 }
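MPI_Allgather concatenates each rank's data_num floats across the communicator, so with N participating ranks the output buffer must hold at least N * data_num elements, and rank r's contribution lands at offset r * data_num. A small sketch (GatherAll is a hypothetical name):

#include <mpi.h>

#include <vector>

std::vector<float> GatherAll(MPI_Comm comm, const std::vector<float> &mine) {
  int nranks = 0;
  MPI_Comm_size(comm, &nranks);
  std::vector<float> all(mine.size() * static_cast<size_t>(nranks));
  MPI_Allgather(mine.data(), static_cast<int>(mine.size()), MPI_FLOAT,
                all.data(), static_cast<int>(mine.size()), MPI_FLOAT, comm);
  return all;  // rank r's block occupies [r * mine.size(), (r + 1) * mine.size())
}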
+
+#endif  // ENABLE_MPI
 }  // namespace cpu
 }  // namespace device
 }  // namespace mindspore