|
|
|
@@ -19,6 +19,7 @@ |
|
|
|
#include <vector> |
|
|
|
#include <string> |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
#include "utils/dlopen_macro.h" |
|
|
|
|
|
|
|
inline void *LoadLibrary(const char *name) { |
|
|
|
auto handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL); |
|
|
|
@@ -29,16 +30,17 @@ inline void *LoadLibrary(const char *name) { |
|
|
|
} |
|
|
|
|
|
|
|
inline void *GetMPIAdapterHandle() { |
|
|
|
static void *handle = LoadLibrary("mpi_adapter.so"); |
|
|
|
static void *handle = LoadLibrary("libmpi_adapter.so"); |
|
|
|
return handle; |
|
|
|
} |
|
|
|
|
|
|
|
void *GetMPIAdapterFunc(const char *name) { |
|
|
|
static void *handle = GetMPIAdapterHandle(); |
|
|
|
template <class T> |
|
|
|
static T GetMPIAdapterFunc(const char *name) { |
|
|
|
void *handle = GetMPIAdapterHandle(); |
|
|
|
if (handle == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Load lib " << name << " failed, make sure you have installed it!"; |
|
|
|
} |
|
|
|
void *func = dlsym(handle, name); |
|
|
|
auto func = reinterpret_cast<T>(dlsym(handle, name)); |
|
|
|
if (func == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Load func " << name << " failed, make sure you have implied it!"; |
|
|
|
} |
|
|
|
@@ -56,30 +58,29 @@ typedef bool (*MPIAllGatherFunc)(const float *input, float *output, const std::v |
|
|
|
size_t data_num); |
|
|
|
|
|
|
|
int GetMPIRankId() { |
|
|
|
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankIdFunc>(GetMPIAdapterFunc("GetMPIRankId")); |
|
|
|
auto func = GetMPIAdapterFunc<GetMPIRankIdFunc>("GetMPIRankId"); |
|
|
|
return func(); |
|
|
|
} |
|
|
|
|
|
|
|
int GetMPIRankSize() { |
|
|
|
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankSizeFunc>(GetMPIAdapterFunc("GetMPIRankSize")); |
|
|
|
auto func = GetMPIAdapterFunc<GetMPIRankSizeFunc>("GetMPIRankSize"); |
|
|
|
return func(); |
|
|
|
} |
|
|
|
|
|
|
|
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num, |
|
|
|
const std::string &op_type) { |
|
|
|
static MPIReduceScatterFunc func = reinterpret_cast<MPIReduceScatterFunc>(GetMPIAdapterFunc("MPIReduceScatter")); |
|
|
|
auto func = GetMPIAdapterFunc<MPIReduceScatterFunc>("MPIReduceScatter"); |
|
|
|
return func(input, output, ranks_group, data_num, op_type); |
|
|
|
} |
|
|
|
|
|
|
|
bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num, |
|
|
|
size_t output_size, const std::string &op_type, float *output) { |
|
|
|
static MPIReduceScatterOverwriteInputFunc func = |
|
|
|
reinterpret_cast<MPIReduceScatterOverwriteInputFunc>(GetMPIAdapterFunc("MPIReduceScatterOverwriteInput")); |
|
|
|
auto func = GetMPIAdapterFunc<MPIReduceScatterOverwriteInputFunc>("MPIReduceScatterOverwriteInput"); |
|
|
|
return func(input, ranks_group, in_data_num, output_size, op_type, output); |
|
|
|
} |
|
|
|
|
|
|
|
bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) { |
|
|
|
static MPIAllGatherFunc func = reinterpret_cast<MPIAllGatherFunc>(GetMPIAdapterFunc("MPIAllGather")); |
|
|
|
auto func = GetMPIAdapterFunc<MPIAllGatherFunc>("MPIAllGather"); |
|
|
|
return func(input, output, ranks_group, data_num); |
|
|
|
} |
|
|
|
#endif // ENABLE_MPI |