|
|
|
@@ -33,19 +33,19 @@ std::shared_ptr<MPIAdapter> MPIAdapter::Instance() { |
|
|
|
return instance_; |
|
|
|
} |
|
|
|
|
|
|
|
// RAISE_EXCEPTION(message), brace-block variant. |
// Streams "[" __FILE__ "] [" __LINE__ "] " followed by `message` into a local |
// std::ostringstream named `oss`, then passes oss.str() to pybind11::pybind11_fail. |
// The expansion is a bare { ... } compound statement (no do/while wrapper), and |
// `message` is inserted into the << chain without surrounding parentheses. |
// NOTE(review): this looks like the pre-change side of the diff hunk - confirm. |
#define RAISE_EXCEPTION(message) \ |
|
|
|
{ \ |
|
|
|
std::ostringstream oss; \ |
|
|
|
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \ |
|
|
|
pybind11::pybind11_fail(oss.str()); \ |
|
|
|
} |
|
|
|
// RAISE_EXCEPTION(message), do/while(0) variant. |
// Streams "[" __FILE__ "] [" __LINE__ "] " followed by (message) into a local |
// std::ostringstream named `oss`, then passes oss.str() to pybind11::pybind11_fail. |
// The body is wrapped in do { ... } while (0) with no trailing semicolon, so the |
// call site supplies the semicolon and the expansion forms one statement. |
// `message` is parenthesized before being inserted into the << chain. |
#define RAISE_EXCEPTION(message) \ |
|
|
|
do { \ |
|
|
|
std::ostringstream oss; \ |
|
|
|
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << (message); \ |
|
|
|
pybind11::pybind11_fail(oss.str()); \ |
|
|
|
} while (0) |
|
|
|
|
|
|
|
// RAISE_EXCEPTION_WITH_PARAM(message, param), brace-block variant. |
// Streams "[" __FILE__ "] [" __LINE__ "] ", then `message`, then `param` into a |
// local std::ostringstream named `oss`, and passes oss.str() to |
// pybind11::pybind11_fail. The expansion is a bare { ... } compound statement |
// (no do/while wrapper); both arguments are inserted without parentheses. |
// NOTE(review): this looks like the pre-change side of the diff hunk - confirm. |
#define RAISE_EXCEPTION_WITH_PARAM(message, param) \ |
|
|
|
{ \ |
|
|
|
std::ostringstream oss; \ |
|
|
|
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \ |
|
|
|
pybind11::pybind11_fail(oss.str()); \ |
|
|
|
} |
|
|
|
// RAISE_EXCEPTION_WITH_PARAM(message, param), do/while(0) variant. |
// Streams "[" __FILE__ "] [" __LINE__ "] ", then (message), then (param) into a |
// local std::ostringstream named `oss`, and passes oss.str() to |
// pybind11::pybind11_fail. The body is wrapped in do { ... } while (0) with no |
// trailing semicolon, so the call site supplies the semicolon and the expansion |
// forms one statement. Both arguments are parenthesized before insertion. |
#define RAISE_EXCEPTION_WITH_PARAM(message, param) \ |
|
|
|
do { \ |
|
|
|
std::ostringstream oss; \ |
|
|
|
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << (message) << (param); \ |
|
|
|
pybind11::pybind11_fail(oss.str()); \ |
|
|
|
} while (0) |
|
|
|
|
|
|
|
namespace { |
|
|
|
MPI_Op GetMpiOp(const std::string &op_type) { |
|
|
|
@@ -109,8 +109,8 @@ void MPIAdapter::Init() { |
|
|
|
RAISE_EXCEPTION("Check mpi initialized fail!"); |
|
|
|
} |
|
|
|
if (init_flag == 0) { |
|
|
|
auto ret = MPI_Init(nullptr, nullptr); |
|
|
|
if (ret != MPI_SUCCESS) { |
|
|
|
auto ret_init = MPI_Init(nullptr, nullptr); |
|
|
|
if (ret_init != MPI_SUCCESS) { |
|
|
|
RAISE_EXCEPTION("Failed to init mpi!"); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -126,7 +126,7 @@ void MPIAdapter::Init() { |
|
|
|
|
|
|
|
ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); |
|
|
|
if (ret != MPI_SUCCESS) { |
|
|
|
RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_) |
|
|
|
RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_); |
|
|
|
} |
|
|
|
init = true; |
|
|
|
} |
|
|
|
@@ -153,7 +153,7 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) { |
|
|
|
MPI_Group group = MPI_GROUP_NULL; |
|
|
|
MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group); |
|
|
|
if (group == MPI_GROUP_NULL) { |
|
|
|
RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_) |
|
|
|
RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_); |
|
|
|
} |
|
|
|
|
|
|
|
ranks_group_[ranks] = group; |
|
|
|
@@ -169,7 +169,7 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec |
|
|
|
|
|
|
|
auto group = AddGroup(ranks_group); |
|
|
|
if (group == MPI_GROUP_NULL) { |
|
|
|
RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_) |
|
|
|
RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_); |
|
|
|
} |
|
|
|
MPI_Comm comm; |
|
|
|
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); |
|
|
|
@@ -233,7 +233,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int |
|
|
|
if (output_size < data_size) { |
|
|
|
std::ostringstream exception_msg; |
|
|
|
exception_msg << "output buffer size " << output_size << " < input size " << data_size; |
|
|
|
RAISE_EXCEPTION(exception_msg.str()) |
|
|
|
RAISE_EXCEPTION(exception_msg.str()); |
|
|
|
} |
|
|
|
auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size); |
|
|
|
if (copy_ret != 0) { |
|
|
|
|