@@ -49,7 +49,7 @@ usage()
     echo " -Q Enable dump memory, default off"
     echo " -D Enable dumping of function graph ir, default on"
     echo " -z Compile dataset & mindrecord, default on"
-    echo " -M Enable MPI and NCCL for GPU training, default on"
+    echo " -M Enable MPI and NCCL for GPU training, gpu default on"
     echo " -V Specify the minimum required cuda version, default CUDA 9.2"
     echo " -I Compile predict, default off"
     echo " -K Compile with AKG, default off"
@@ -14,17 +14,19 @@ endif ()
 if (ENABLE_CPU)
     file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
-    if (ENABLE_MPI)
-        # _ms_mpi
-        set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
-                     PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
-        pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
-        target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
-    else ()
+    if (NOT ENABLE_MPI)
         list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
     endif ()
 endif ()
+
+if (ENABLE_MPI)
+    # _ms_mpi
+    set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
+                 PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
+    pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
+    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
+endif ()
 
 # gpu
 if (ENABLE_GPU)
     file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu")
@@ -25,6 +25,7 @@
 #include "device/ascend/ascend_device_address.h"
 #include "device/cpu/mpi/mpi_adapter.h"
 #include "utils/context/ms_context.h"
+#include "utils/mpi/mpi_config.h"
 #include "device/ascend/profiling/profiling_manager.h"
 #include "hccl/hcom.h"
 #include "common/trans.h"
@@ -510,19 +511,35 @@ bool AscendKernelRuntime::HcclInit() {
     MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
     return false;
   }
+  const char *identify = nullptr;
 #ifdef ENABLE_MPI
-  int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
-  const char *offset = std::getenv("RANK_OFFSET");
-  if (offset != nullptr) {
-    int rank_offset = std::stoi(offset);
-    rank_id += rank_offset;
+  std::string rank_id_tmp;
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (mpi_config_ptr->enable_mpi()) {
+    int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
+    const char *offset = std::getenv("RANK_OFFSET");
+    if (offset != nullptr) {
+      try {
+        int rank_offset = std::stoi(offset);
+        rank_id += rank_offset;
+      } catch (const std::invalid_argument &) {
+        MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
+      } catch (const std::out_of_range &) {
+        MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
+      }
+    }
+    rank_id_tmp = std::to_string(rank_id);
+    identify = rank_id_tmp.c_str();
+  } else {
+    identify = std::getenv("RANK_ID");
   }
-  const char *identify = reinterpret_cast<const char *>(std::to_string(rank_id).c_str());
 #else
-  const char *identify = std::getenv("RANK_ID");
+  identify = std::getenv("RANK_ID");
 #endif
   if (identify == nullptr) {
     MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
+    free(full_path);
     return false;
   }
   MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;
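
For intuition, a rough Python-style sketch (illustrative only, not part of this change) of how the new ENABLE_MPI branch resolves the HCCL rank identifier: when MPI is enabled via the mpi config, the MPI rank plus an optional RANK_OFFSET is used; otherwise the RANK_ID environment variable is read as before.

    import os

    def resolve_identify(enable_mpi, mpi_rank_id):
        # Sketch of the new HcclInit logic: MPI rank (+ optional RANK_OFFSET)
        # when the mpi config enables MPI, otherwise fall back to env RANK_ID.
        if enable_mpi:
            offset = os.getenv("RANK_OFFSET")
            if offset is not None:
                mpi_rank_id += int(offset)  # std::stoi with try/catch in the C++ code
            return str(mpi_rank_id)
        return os.getenv("RANK_ID")
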
@@ -16,6 +16,7 @@
 #include "device/cpu/mpi/mpi_adapter.h"
 #include <algorithm>
+#include "utils/mpi/mpi_config.h"
 #include "utils/log_adapter.h"
 
 namespace mindspore {
@@ -35,6 +36,20 @@ MPI_Op GetMpiOp(const std::string &op_type) {
   MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
   return MPI_SUM;
 }
+
+int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
+  int scatter_index = -1;
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    if (ranks_group[i] == rankid) {
+      scatter_index = static_cast<int>(i);
+      break;
+    }
+  }
+  if (scatter_index == -1) {
+    MS_LOG(EXCEPTION) << "process rankid " << rankid << " is not in the input rank group!";
+  }
+  return scatter_index;
+}
 }  // namespace
 
 MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
@@ -65,6 +80,11 @@ void MPIAdapter::Init() {
   if (init) {
     return;
   }
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (!mpi_config_ptr->enable_mpi()) {
+    MS_LOG(EXCEPTION) << "MPI is disabled now! Please enable mpi with mpi config first.";
+  }
   int init_flag = 0;
   if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
     MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
@@ -123,7 +143,7 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   return group;
 }
 
-bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                                const std::string &op_type) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
@@ -159,6 +179,51 @@ bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<in
   return result;
 }
 
+bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                             const std::string &op_type, float *output) {
+  int scatter_index = GetScatterIndex(rank_id_, ranks_group);
+  auto group = AddGroup(ranks_group);
+  if (group == MPI_GROUP_NULL) {
+    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+  }
+  MPI_Comm comm;
+  MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
+  if (comm == MPI_COMM_NULL) {
+    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+  }
+  MPI_Win window;
+  auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
+  if (ret != MPI_SUCCESS) {
+    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
+    return false;
+  }
+  MPI_Win_fence(0, window);
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    int remote_rank = ranks_group[i];
+    if (rank_id_ == remote_rank) {
+      continue;
+    }
+    auto op = GetMpiOp(op_type);
+    ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op,
+                         window);
+    if (ret != MPI_SUCCESS) {
+      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
+    }
+  }
+  MPI_Win_fence(0, window);
+
+  if (output != nullptr) {
+    auto data_size = data_num * sizeof(float);
+    auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
+    if (copy_ret != 0) {
+      MS_LOG(EXCEPTION) << "copy output memory fail!";
+    }
+  }
+  MPI_Win_free(&window);
+  MPI_Comm_free(&comm);
+  return true;
+}
+
 bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
@@ -32,8 +32,10 @@ class MPIAdapter {
   ~MPIAdapter();
   static MPIAdapter &Instance();
   int GetRankId() const;
-  bool ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+  bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                      const std::string &op_type = kOpTypeSum);
+  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                   const std::string &op_type = kOpTypeSum, float *output = nullptr);
   bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
 
  private:
@@ -26,6 +26,7 @@
 #include "pipeline/parse/python_adapter.h"
 #include "utils/summary/event_writer.h"
 #include "utils/config_manager.h"
+#include "utils/mpi/mpi_config.h"
 #include "parallel/context.h"
 #include "parallel/device_manager.h"
 #include "parallel/costmodel_context.h"
@@ -147,6 +148,11 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
     .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.");
 
+  (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
+    .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
+    .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
+    .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.")
+
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "utils/mpi/mpi_config.h"
+
+namespace mindspore {
+std::shared_ptr<MpiConfig> MpiConfig::instance_ = nullptr;
+
+std::shared_ptr<MpiConfig> MpiConfig::GetInstance() {
+  if (instance_ == nullptr) {
+    MS_LOG(DEBUG) << "Create new mpi config instance.";
+    instance_.reset(new (std::nothrow) MpiConfig());
+  }
+  return instance_;
+}
+}  // namespace mindspore
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
+#define MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
+
+#include <memory>
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+class MpiConfig {
+ public:
+  ~MpiConfig() = default;
+  MpiConfig(const MpiConfig &) = delete;
+  MpiConfig &operator=(const MpiConfig &) = delete;
+
+  static std::shared_ptr<MpiConfig> GetInstance();
+
+  void set_enable_mpi(bool flag) { enable_mpi_ = flag; }
+  bool enable_mpi() const { return enable_mpi_; }
+
+ private:
+  MpiConfig() : enable_mpi_(false) {}
+
+  static std::shared_ptr<MpiConfig> instance_;
+  bool enable_mpi_;
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
@@ -25,6 +25,7 @@ from mindspore._c_expression import MSContext
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
+from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config
 
 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
            'get_auto_parallel_context', 'reset_auto_parallel_context']
@@ -566,3 +567,40 @@ def get_context(attr_key):
     if not hasattr(_context(), attr_key):
         raise ValueError("Get context keyword %s is not recognized!" % attr_key)
     return getattr(_context(), attr_key)
+
+
+@args_type_check(enable_mpi=bool)
+def set_mpi_config(**kwargs):
+    """
+    Sets mpi config for running environment.
+
+    mpi config should be configured before running your program. If there is no configuration,
+    the mpi module will be disabled by default.
+
+    Note:
+        Attribute name is required for setting attributes.
+
+    Args:
+        enable_mpi (bool): Whether to enable mpi. Default: False.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+
+    Examples:
+        >>> context.set_mpi_config(enable_mpi=True)
+    """
+    _set_mpi_config(**kwargs)
+
+
+def get_mpi_config(attr_key):
+    """
+    Gets mpi config attribute value according to the input key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Object, the value of the given attribute key.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+    """
+    return _get_mpi_config(attr_key)
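
A short usage sketch of the new public API (assuming MindSpore was built with MPI support so the host MPI adapter is compiled in):

    from mindspore import context

    # Must be configured before any host MPI collective runs.
    context.set_mpi_config(enable_mpi=True)
    assert context.get_mpi_config("enable_mpi")
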
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
@@ -0,0 +1,111 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+The MPI config, used to configure the MPI environment.
+"""
+import threading
+from mindspore._c_expression import MpiConfig
+from mindspore._checkparam import args_type_check
+
+
+class _MpiConfig:
+    """
+    _MpiConfig is the config tool for controlling MPI.
+
+    Note:
+        Creating a config by instantiating the _MpiConfig object directly is not recommended.
+        Use _mpi_config() to get the config instead, since _MpiConfig is a singleton.
+    """
+    _instance = None
+    _instance_lock = threading.Lock()
+
+    def __init__(self):
+        self._mpiconfig_handle = MpiConfig.get_instance()
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance_lock.acquire()
+            cls._instance = object.__new__(cls)
+            cls._instance_lock.release()
+        return cls._instance
+
+    def __getattribute__(self, attr):
+        value = object.__getattribute__(self, attr)
+        if attr == "_mpiconfig_handle" and value is None:
+            raise ValueError("mpiconfig handle is none in MpiConfig!!!")
+        return value
+
+    @property
+    def enable_mpi(self):
+        return self._mpiconfig_handle.get_enable_mpi()
+
+    @enable_mpi.setter
+    def enable_mpi(self, enable_mpi):
+        self._mpiconfig_handle.set_enable_mpi(enable_mpi)
+
+
+_k_mpi_config = None
+
+
+def _mpi_config():
+    """
+    Get the global mpi config. If the mpi config is not created, create a new one.
+
+    Returns:
+        _MpiConfig, the global mpi config.
+    """
+    global _k_mpi_config
+    if _k_mpi_config is None:
+        _k_mpi_config = _MpiConfig()
+    return _k_mpi_config
+
+
+@args_type_check(enable_mpi=bool)
+def _set_mpi_config(**kwargs):
+    """
+    Sets mpi config for running environment.
+
+    mpi config should be configured before running your program. If there is no configuration,
+    the mpi module will be disabled by default.
+
+    Note:
+        Attribute name is required for setting attributes.
+
+    Args:
+        enable_mpi (bool): Whether to enable mpi. Default: False.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+
+    Examples:
+        >>> context.set_mpi_config(enable_mpi=True)
+    """
+    for key, value in kwargs.items():
+        if not hasattr(_mpi_config(), key):
+            raise ValueError("Set mpi config keyword %s is not recognized!" % key)
+        setattr(_mpi_config(), key, value)
+
+
+def _get_mpi_config(attr_key):
+    """
+    Gets mpi config attribute value according to the input key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Object, the value of the given attribute key.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+    """
+    if not hasattr(_mpi_config(), attr_key):
+        raise ValueError("Get mpi config keyword %s is not recognized!" % attr_key)
+    return getattr(_mpi_config(), attr_key)
@@ -23,9 +23,10 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 import mindspore._ms_mpi as mpi
 
 # run comand:
-# mpirun -np 3 python test_reduce_scatter.py
+# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+context.set_mpi_config(enable_mpi=True)
 
 class Net(nn.Cell):
     def __init__(self):
@@ -46,14 +47,19 @@ class AllGatherNet(nn.Cell):
         return self.hostallgather(x)
 
 def test_net_reduce_scatter():
-    x = np.ones(12).astype(np.float32) * 0.1
+    x = np.arange(12).astype(np.float32) * 0.1
     reducescatter = Net()
     rankid = mpi.get_rank_id()
     print("self rankid:", rankid)
     output = reducescatter(Tensor(x, mstype.float32))
     print("output:\n", output)
-    expect_result = np.ones(4).astype(np.float32) * 0.3
+    if rankid == 0:
+        expect_result = np.arange(4).astype(np.float32) * 0.3
+    if rankid == 1:
+        expect_result = np.arange(4, 8).astype(np.float32) * 0.3
+    if rankid == 2:
+        expect_result = np.arange(8, 12).astype(np.float32) * 0.3
+
     diff = abs(output.asnumpy() - expect_result)
     error = np.ones(shape=expect_result.shape) * 1.0e-6
     assert np.all(diff < error)
@@ -61,7 +67,7 @@ def test_net_reduce_scatter():
     allgather = AllGatherNet()
     allgather_output = allgather(output)
    print("allgather result:\n", allgather_output)
-    expect_allgather_result = np.ones(12).astype(np.float32) * 0.3
+    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
     diff = abs(allgather_output.asnumpy() - expect_allgather_result)
     error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
     assert np.all(diff < error)
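
The updated expectations follow from reduce-scatter semantics with three identical inputs; a NumPy sketch (illustrative only) of what each rank should see:

    import numpy as np

    x = np.arange(12).astype(np.float32) * 0.1  # same input on each of the 3 ranks
    summed = 3 * x                               # SUM across ranks -> np.arange(12) * 0.3
    chunks = np.split(summed, 3)                 # rank r keeps chunk r after the scatter step
    # chunks[0] == np.arange(4) * 0.3, chunks[1] == np.arange(4, 8) * 0.3, chunks[2] == np.arange(8, 12) * 0.3
    allgather = np.concatenate(chunks)           # AllGather restores np.arange(12) * 0.3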