Merge pull request !2428 from chenjianping/host_reduce (tag: v0.6.0-beta)
@@ -128,6 +128,11 @@ if (ENABLE_MPI)
DESTINATION ${INSTALL_BASE_DIR}
COMPONENT mindspore
)
install(
TARGETS mpi_adapter
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif ()
if (ENABLE_GPU)
@@ -145,11 +145,12 @@ if (ENABLE_DEBUGGER)
endif()
target_link_libraries(mindspore proto_input)
if (ENABLE_CPU AND ENABLE_MPI)
target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi)
if (ENABLE_MPI)
target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter)
else ()
target_link_libraries(mindspore securec mindspore::flatbuffers)
endif ()
if (NOT WIN32)
target_link_libraries(mindspore dl)
endif()
@@ -14,17 +14,22 @@ endif ()
if (ENABLE_CPU)
file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
if (NOT ENABLE_MPI)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
endif ()
list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_interface.cc")
endif ()
if (ENABLE_MPI)
# _ms_mpi
set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc")
set_property(SOURCE ${MPI_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(mpi_adapter SHARED ${MPI_SRC_LIST})
target_link_libraries(mpi_adapter PRIVATE mindspore::ompi)
set_property(SOURCE "cpu/mpi/mpi_interface.cc"
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
pybind11_add_module(_ms_mpi "cpu/mpi/mpi_interface.cc")
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mpi_adapter)
endif ()
# gpu
@@ -15,13 +15,41 @@
*/
#include "device/cpu/mpi/mpi_adapter.h"
#ifdef ENABLE_MPI
#include <algorithm>
#include "utils/mpi/mpi_config.h"
#include <sstream>
#include "pybind11/pybind11.h"
#endif // ENABLE_MPI
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
namespace cpu {
std::shared_ptr<MPIAdapter> MPIAdapter::instance_ = nullptr;
std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
if (instance_ == nullptr) {
MS_LOG(DEBUG) << "Create new mpi adapter instance.";
instance_.reset(new (std::nothrow) MPIAdapter());
}
return instance_;
}
#ifdef ENABLE_MPI
#define RAISE_EXCEPTION(message) \
{ \
std::ostringstream oss; \
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \
pybind11::pybind11_fail(oss.str()); \
}
#define RAISE_EXCEPTION_WITH_PARAM(message, param) \
{ \
std::ostringstream oss; \
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \
pybind11::pybind11_fail(oss.str()); \
}
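(Reviewer note, not part of the patch: the two macros prepend the file name and line number to the message and hand the string to pybind11::pybind11_fail(), which throws std::runtime_error; when the failure reaches Python through the _ms_mpi module, pybind11 surfaces it as a RuntimeError. A rough hand-expanded equivalent of one call site, for illustration only:)

    // Illustration only: roughly what RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_)
    // expands to at a call site.
    {
      std::ostringstream oss;
      oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << "create mpi group fail!rankid:" << rank_id_;
      pybind11::pybind11_fail(oss.str());  // throws std::runtime_error
    }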
namespace {
MPI_Op GetMpiOp(const std::string &op_type) {
if (op_type == "sum") {
@@ -33,7 +61,8 @@ MPI_Op GetMpiOp(const std::string &op_type) {
} else if (op_type == "prod") {
return MPI_PROD;
}
MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
RAISE_EXCEPTION_WITH_PARAM("unsupport op_type: ", op_type);
return MPI_SUM;
}
@@ -46,80 +75,72 @@ int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
}
}
if (scatter_index == -1) {
MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rankid);
}
return scatter_index;
}
} // namespace
MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); }
MPIAdapter::~MPIAdapter() {
int finalized;
MPI_Finalized(&finalized);
if (finalized != 0) {
return;
}
for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) {
MPI_Group_free(&iter->second);
}
ranks_group_.clear();
if (comm_group_world_ != MPI_GROUP_NULL) {
MPI_Group_free(&comm_group_world_);
comm_group_world_ = MPI_GROUP_NULL;
}
int finalized;
MPI_Finalized(&finalized);
if (finalized == 0) {
MPI_Finalize();
}
MPI_Finalize();
}
MPIAdapter &MPIAdapter::Instance() {
static MPIAdapter instance;
return instance;
}
int MPIAdapter::GetRankId() const { return rank_id_; }
void MPIAdapter::Init() {
static bool init = false;
if (init) {
return;
}
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (!mpi_config_ptr->enable_mpi()) {
MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
}
int init_flag = 0;
if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
RAISE_EXCEPTION("Check mpi initialized fail!");
}
if (init_flag == 0) {
auto ret = MPI_Init(nullptr, nullptr);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Failed to init mpi!";
RAISE_EXCEPTION("Failed to init mpi!");
}
}
MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
if (comm_group_world_ == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "comm_group_world_ init fail!";
RAISE_EXCEPTION("comm_group_world_ init fail!");
}
auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Failed to init mpi rank id!";
RAISE_EXCEPTION("Failed to init mpi rank id!");
}
ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Failed to init mpi rank size!rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_)
}
init = true;
}
MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
if (ranks.size() > static_cast<size_t>(rank_size_) || ranks.empty()) {
MS_LOG(EXCEPTION) << "input rank size: " << ranks.size() << ", max rank size: " << rank_size_;
RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size());
}
if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) {
MS_LOG(ERROR) << "rankid:" << rank_id_ << " is not in the input group.";
return MPI_GROUP_NULL;
RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rank_id_);
}
std::lock_guard<std::mutex> lock(group_mutex_);
auto iter = ranks_group_.find(ranks);
@@ -135,29 +156,28 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
MPI_Group group = MPI_GROUP_NULL;
MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "create mpi group fail!rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_)
}
ranks_group_[ranks] = group;
MS_LOG(INFO) << "rank:" << rank_id_ << " add group:" << group;
return group;
}
bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type) {
if (ranks_group.empty()) {
MS_LOG(ERROR) << "input rank group is empty!";
RAISE_EXCEPTION("input rank group is empty!");
return false;
}
auto group = AddGroup(ranks_group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_)
}
MPI_Comm comm;
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
if (comm == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
}
std::vector<int> receive_count(ranks_group.size(), 0);
for (size_t i = 0; i < ranks_group.size(); ++i) {
@@ -168,13 +188,13 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm);
bool result = true;
if (ret != MPI_SUCCESS) {
MS_LOG(ERROR) << "mpi reduce_scatter fail!ret = " << ret << ", rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret);
result = false;
}
ret = MPI_Comm_free(&comm);
if (ret != MPI_SUCCESS) {
MS_LOG(WARNING) << "mpi comm free fail! ret = " << ret << ", rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret);
}
return result;
}
@@ -184,19 +204,18 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
int scatter_index = GetScatterIndex(rank_id_, ranks_group);
auto group = AddGroup(ranks_group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_);
}
MPI_Comm comm;
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
if (comm == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
}
MPI_Win window;
auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
if (ret != MPI_SUCCESS) {
MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
return false;
RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! ret = ", ret);
}
MPI_Win_fence(0, window);
for (size_t i = 0; i < ranks_group.size(); ++i) {
@@ -208,18 +227,20 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
input_data_num, MPI_FLOAT, op, window);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret);
}
}
MPI_Win_fence(0, window);
if (output != nullptr) {
auto data_size = input_data_num * sizeof(float);
if (output_size < data_size) {
MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
std::ostringstream exception_msg;
exception_msg << "output buffer size " << output_size << " < input size " << data_size;
RAISE_EXCEPTION(exception_msg.str())
}
auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
if (copy_ret != 0) {
MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret);
}
}
MPI_Win_free(&window);
@@ -229,31 +250,31 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
if (ranks_group.empty()) {
MS_LOG(ERROR) << "input rank group is empty!";
RAISE_EXCEPTION("input rank group is empty!");
return false;
}
auto group = AddGroup(ranks_group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "Get mpi group fail! rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_);
}
MPI_Comm comm;
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
if (comm == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "create mpi comm fail! rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_);
}
auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm);
bool result = true;
if (ret != MPI_SUCCESS) {
MS_LOG(ERROR) << "mpi allgater fail!ret = " << ret << ", rankid:" << rank_id_;
result = false;
RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret);
}
ret = MPI_Comm_free(&comm);
if (ret != MPI_SUCCESS) {
MS_LOG(WARNING) << "mpi comm free fail!ret = " << ret << ",rankid:" << rank_id_;
RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret);
}
return result;
return true;
}
#endif // ENABLE_MPI
} // namespace cpu
} // namespace device
} // namespace mindspore
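(Behavioural note, not part of the patch: most failure paths that previously logged and returned false now raise through the RAISE_EXCEPTION macros, so a C++ caller that still wants a bool-style result has to catch the exception itself. A minimal sketch under that assumption; the wrapper name is hypothetical:)

    // Hypothetical wrapper, illustration only: converts the new exception-based failures
    // back into a bool result for callers that still check return values.
    #include <stdexcept>
    #include <vector>
    #include "device/cpu/mpi/mpi_adapter.h"
    #include "utils/log_adapter.h"

    bool TryAllGather(const float *input, float *output, const std::vector<int> &group, size_t data_num) {
      try {
        return mindspore::device::cpu::MPIAdapter::Instance()->AllGather(input, output, group, data_num);
      } catch (const std::runtime_error &e) {  // raised via pybind11::pybind11_fail
        MS_LOG(ERROR) << "AllGather failed: " << e.what();
        return false;
      }
    }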
@@ -22,37 +22,53 @@
#include <map>
#include <string>
#include <mutex>
#endif // ENABLE_MPI
#include <memory>
namespace mindspore {
namespace device {
namespace cpu {
#ifndef FUNC_EXPORT
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif
constexpr auto kOpTypeSum = "sum";
class MPIAdapter {
public:
~MPIAdapter();
static MPIAdapter &Instance();
int GetRankId() const;
bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type = kOpTypeSum);
bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
size_t output_size, const std::string &op_type = kOpTypeSum,
float *output = nullptr);
bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
FUNC_EXPORT static std::shared_ptr<MPIAdapter> Instance();
FUNC_EXPORT int GetRankId() const { return rank_id_; }
FUNC_EXPORT int GetRankSize() const { return rank_size_; }
#ifdef ENABLE_MPI
FUNC_EXPORT ~MPIAdapter();
FUNC_EXPORT bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num, const std::string &op_type = kOpTypeSum);
FUNC_EXPORT bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
size_t output_size, const std::string &op_type = kOpTypeSum,
float *output = nullptr);
FUNC_EXPORT bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
#else
FUNC_EXPORT ~MPIAdapter() = default;
#endif // ENABLE_MPI
private:
#ifdef ENABLE_MPI
MPIAdapter();
void Init();
MPI_Group AddGroup(const std::vector<int> &ranks);
int rank_id_;
int rank_size_;
MPI_Group comm_group_world_;
// key:ranks group, value: mpi group
std::map<std::vector<int>, MPI_Group> ranks_group_;
std::mutex group_mutex_;
#else
MPIAdapter() = default;
#endif // ENABLE_MPI
int rank_id_{-1};
int rank_size_{0};
static std::shared_ptr<MPIAdapter> instance_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
#endif // ENABLE_MPI
#endif // MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_
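(Usage sketch, not part of the patch, assuming an ENABLE_MPI build: Instance() now returns a std::shared_ptr<MPIAdapter> instead of a reference, so call sites switch from '.' to '->'; GetRankId()/GetRankSize() stay available even in non-MPI builds. The rank group below is illustrative:)

    #include <vector>
    #include "device/cpu/mpi/mpi_adapter.h"

    void ExampleReduceScatter(const float *input, float *output, size_t output_data_num) {
      auto adapter = mindspore::device::cpu::MPIAdapter::Instance();
      const std::vector<int> group = {0, 1, 2, 3};  // illustrative rank group
      int rank = adapter->GetRankId();              // exported in MPI and non-MPI builds
      (void)rank;
      adapter->ReduceScatter(input, output, group, output_data_num, "sum");
    }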
@@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <pybind11/operators.h>
#include "device/cpu/mpi/mpi_adapter.h"
namespace mindspore {
namespace device {
namespace cpu {
int get_rank_id() { return MPIAdapter::Instance()->GetRankId(); }
int get_rank_size() { return MPIAdapter::Instance()->GetRankSize(); }
PYBIND11_MODULE(_ms_mpi, mpi_interface) {
mpi_interface.doc() = "mindspore mpi python wrapper";
mpi_interface.def("get_rank_id", &get_rank_id, "get rank id");
mpi_interface.def("get_rank_size", &get_rank_size, "get rank size");
}
} // namespace cpu
} // namespace device
} // namespace mindspore
@@ -17,7 +17,6 @@
#include "device/gpu/mpi/mpi_initializer.h"
#include <mpi.h>
#include <pybind11/operators.h>
#include <iostream>
namespace mindspore {
@@ -54,12 +53,6 @@ MPIInitializer &MPIInitializer::GetInstance() {
int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }
int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }
PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
mpi_initializer.doc() = "mindspore mpi python wrapper";
mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id");
mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size");
}
} // namespace gpu
} // namespace device
} // namespace mindspore
@@ -47,7 +47,7 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto input_data_num = inputs[0]->size / sizeof(float);
return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num);
return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
}
} // namespace kernel
} // namespace mindspore
@@ -51,8 +51,8 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
for (int i = 0; i < split_num_; i++) {
device::cpu::MPIAdapter::Instance().AllGather(input_addr + i * input_split_lens,
output_addr + i * output_split_lens, rank_group, input_split_lens);
device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
output_addr + i * output_split_lens, rank_group, input_split_lens);
}
#if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now();
@@ -105,9 +105,9 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
size_t reduce_scatter_out_lens = one_split_lens / 8;
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
for (int i = 0; i < split_num_; i++) {
device::cpu::MPIAdapter::Instance().ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
output_addr + i * reduce_scatter_out_lens, group,
one_split_lens / 8, "sum");
device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
output_addr + i * reduce_scatter_out_lens, group,
one_split_lens / 8, "sum");
}
}
#endif
@@ -47,8 +47,8 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto output_data_num = outputs[0]->size / sizeof(float);
return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
op_type_);
return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
op_type_);
}
} // namespace kernel
} // namespace mindspore
@@ -25,7 +25,6 @@ from mindspore._c_expression import MSContext
from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config
__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
'get_auto_parallel_context', 'reset_auto_parallel_context']
@@ -608,40 +607,3 @@ def get_context(attr_key):
raise ValueError(
"Get context keyword %s is not recognized!" % attr_key)
return getattr(_context(), attr_key)
@args_type_check(enable_mpi=bool)
def set_mpi_config(**kwargs):
"""
Sets mpi config for running environment.
mpi config should be configured before running your program. If there is no configuration,
mpi moudle will be disabled by default.
Note:
Attribute name is required for setting attributes.
Args:
enable_mpi (bool): Whether to enable mpi. Default: False.
Raises:
ValueError: If input key is not an attribute in mpi config.
Examples:
>>> mpiconfig.set_mpi_config(enable_mpi=True)
"""
_set_mpi_config(**kwargs)
def get_mpi_config(attr_key):
"""
Gets mpi config attribute value according to the input key.
Args:
attr_key (str): The key of the attribute.
Returns:
Object, The value of given attribute key.
Raises:
ValueError: If input key is not an attribute in context.
"""
return _get_mpi_config(attr_key)
@@ -104,7 +104,7 @@ def _get_mpi_config(attr_key):
Object, The value of given attribute key.
Raises:
ValueError: If input key is not an attribute in context.
ValueError: If input key is not an attribute in config.
"""
if not hasattr(_mpi_config(), attr_key):
raise ValueError("Get context keyword %s is not recognized!" % attr_key)