Browse Source

clean code

tags/v1.6.0
kswang 4 years ago
parent
commit
e87939b7dc
6 changed files with 24 additions and 27 deletions
  1. +1
    -1
      mindspore/ccsrc/runtime/device/CMakeLists.txt
  2. +0
    -1
      mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc
  3. +2
    -1
      mindspore/ccsrc/runtime/device/cpu/mpi/mpi_export.cc
  4. +9
    -10
      mindspore/ccsrc/runtime/device/cpu/mpi/mpi_export.h
  5. +11
    -10
      mindspore/ccsrc/runtime/device/cpu/mpi/mpi_interface.cc
  6. +1
    -4
      mindspore/ccsrc/runtime/device/cpu/mpi/mpi_interface.h

+ 1
- 1
mindspore/ccsrc/runtime/device/CMakeLists.txt View File

@@ -36,7 +36,7 @@ if(ENABLE_MPI)
set_property(SOURCE ${MPI_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(mpi_adapter SHARED ${MPI_SRC_LIST})
target_link_libraries(mpi_adapter PRIVATE mindspore::ompi)
target_link_libraries(mpi_adapter PRIVATE mindspore::ompi mindspore::pybind11_module -ldl ${SECUREC_LIBRARY})
endif()

if(ENABLE_GPU)


+ 0
- 1
mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc View File

@@ -27,7 +27,6 @@ namespace cpu {
std::shared_ptr<MPIAdapter> MPIAdapter::instance_ = nullptr;
std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
if (instance_ == nullptr) {
MS_LOG(DEBUG) << "Create new mpi adapter instance.";
instance_.reset(new (std::nothrow) MPIAdapter());
}
return instance_;


+ 2
- 1
mindspore/ccsrc/runtime/device/cpu/mpi/mpi_export.cc View File

@@ -15,9 +15,9 @@
*/
#include "runtime/device/cpu/mpi/mpi_export.h"
#include <vector>
#include <string>
#include "runtime/device/cpu/mpi/mpi_adapter.h"

extern "C" {
int GetMPIRankId() {
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
if (inst == nullptr) {
@@ -59,3 +59,4 @@ bool MPIAllGather(const float *input, float *output, const std::vector<int> &ran
}
return inst->AllGather(input, output, ranks_group, data_num);
}
}

+ 9
- 10
mindspore/ccsrc/runtime/device/cpu/mpi/mpi_export.h View File

@@ -22,14 +22,13 @@
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif
extern "C" FUNC_EXPORT FUNC_EXPORT int GetMPIRankId();
extern "C" FUNC_EXPORT FUNC_EXPORT int GetMPIRankSize();
extern "C" FUNC_EXPORT bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num, const std::string &op_type);
extern "C" FUNC_EXPORT bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group,
size_t in_data_num, size_t output_size,
const std::string &op_type, float *output);
extern "C" FUNC_EXPORT bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num);
extern "C" {
FUNC_EXPORT int GetMPIRankId();
FUNC_EXPORT int GetMPIRankSize();
FUNC_EXPORT bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num, const std::string &op_type);
FUNC_EXPORT bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
size_t output_size, const std::string &op_type, float *output);
FUNC_EXPORT bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
}
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_

+ 11
- 10
mindspore/ccsrc/runtime/device/cpu/mpi/mpi_interface.cc View File

@@ -19,6 +19,7 @@
#include <vector>
#include <string>
#include "utils/log_adapter.h"
#include "utils/dlopen_macro.h"

inline void *LoadLibrary(const char *name) {
auto handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
@@ -29,16 +30,17 @@ inline void *LoadLibrary(const char *name) {
}

inline void *GetMPIAdapterHandle() {
static void *handle = LoadLibrary("mpi_adapter.so");
static void *handle = LoadLibrary("libmpi_adapter.so");
return handle;
}

void *GetMPIAdapterFunc(const char *name) {
static void *handle = GetMPIAdapterHandle();
template <class T>
static T GetMPIAdapterFunc(const char *name) {
void *handle = GetMPIAdapterHandle();
if (handle == nullptr) {
MS_LOG(EXCEPTION) << "Load lib " << name << " failed, make sure you have installed it!";
}
void *func = dlsym(handle, name);
auto func = reinterpret_cast<T>(dlsym(handle, name));
if (func == nullptr) {
MS_LOG(EXCEPTION) << "Load func " << name << " failed, make sure you have implied it!";
}
@@ -56,30 +58,29 @@ typedef bool (*MPIAllGatherFunc)(const float *input, float *output, const std::v
size_t data_num);

int GetMPIRankId() {
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankIdFunc>(GetMPIAdapterFunc("GetMPIRankId"));
auto func = GetMPIAdapterFunc<GetMPIRankIdFunc>("GetMPIRankId");
return func();
}

int GetMPIRankSize() {
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankSizeFunc>(GetMPIAdapterFunc("GetMPIRankSize"));
auto func = GetMPIAdapterFunc<GetMPIRankSizeFunc>("GetMPIRankSize");
return func();
}

bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type) {
static MPIReduceScatterFunc func = reinterpret_cast<MPIReduceScatterFunc>(GetMPIAdapterFunc("MPIReduceScatter"));
auto func = GetMPIAdapterFunc<MPIReduceScatterFunc>("MPIReduceScatter");
return func(input, output, ranks_group, data_num, op_type);
}

bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
size_t output_size, const std::string &op_type, float *output) {
static MPIReduceScatterOverwriteInputFunc func =
reinterpret_cast<MPIReduceScatterOverwriteInputFunc>(GetMPIAdapterFunc("MPIReduceScatterOverwriteInput"));
auto func = GetMPIAdapterFunc<MPIReduceScatterOverwriteInputFunc>("MPIReduceScatterOverwriteInput");
return func(input, ranks_group, in_data_num, output_size, op_type, output);
}

bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
static MPIAllGatherFunc func = reinterpret_cast<MPIAllGatherFunc>(GetMPIAdapterFunc("MPIAllGather"));
auto func = GetMPIAdapterFunc<MPIAllGatherFunc>("MPIAllGather");
return func(input, output, ranks_group, data_num);
}
#endif // ENABLE_MPI

+ 1
- 4
mindspore/ccsrc/runtime/device/cpu/mpi/mpi_interface.h View File

@@ -17,11 +17,8 @@
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_
#include <vector>
#include <string>
#ifndef FUNC_EXPORT
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif
constexpr auto kMPIOpTypeSum = "sum";
#ifdef ENABLE_MPI
constexpr auto kMPIOpTypeSum = "sum";
int GetMPIRankId();
int GetMPIRankSize();
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,


Loading…
Cancel
Save