Browse Source

support host reduce scatter and mpi config

tags/v0.5.0-beta
chenjianping 5 years ago
parent
commit
6034f9c1e2
12 changed files with 355 additions and 21 deletions
  1. +1
    -1
      build.sh
  2. +9
    -7
      mindspore/ccsrc/device/CMakeLists.txt
  3. +24
    -7
      mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
  4. +66
    -1
      mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc
  5. +3
    -1
      mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h
  6. +6
    -0
      mindspore/ccsrc/pipeline/init.cc
  7. +31
    -0
      mindspore/ccsrc/utils/mpi/mpi_config.cc
  8. +42
    -0
      mindspore/ccsrc/utils/mpi/mpi_config.h
  9. +38
    -0
      mindspore/context.py
  10. +14
    -0
      mindspore/parallel/mpi/__init__.py
  11. +111
    -0
      mindspore/parallel/mpi/_mpi_config.py
  12. +10
    -4
      tests/st/ops/cpu/test_reduce_scatter.py

+ 1
- 1
build.sh View File

@@ -49,7 +49,7 @@ usage()
echo " -Q Enable dump memory, default off"
echo " -D Enable dumping of function graph ir, default on"
echo " -z Compile dataset & mindrecord, default on"
echo " -M Enable MPI and NCCL for GPU training, default on"
echo " -M Enable MPI and NCCL for GPU training, gpu default on"
echo " -V Specify the minimum required cuda version, default CUDA 9.2"
echo " -I Compile predict, default off"
echo " -K Compile with AKG, default off"


+ 9
- 7
mindspore/ccsrc/device/CMakeLists.txt View File

@@ -14,17 +14,19 @@ endif ()

if (ENABLE_CPU)
file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
if (ENABLE_MPI)
# _ms_mpi
set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
else ()
if (NOT ENABLE_MPI)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
endif ()
endif ()

if (ENABLE_MPI)
# _ms_mpi
set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
endif ()

# gpu
if (ENABLE_GPU)
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu")


+ 24
- 7
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc View File

@@ -25,6 +25,7 @@
#include "device/ascend/ascend_device_address.h"
#include "device/cpu/mpi/mpi_adapter.h"
#include "utils/context/ms_context.h"
#include "utils/mpi/mpi_config.h"
#include "device/ascend/profiling/profiling_manager.h"
#include "hccl/hcom.h"
#include "common/trans.h"
@@ -510,19 +511,35 @@ bool AscendKernelRuntime::HcclInit() {
MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
return false;
}
const char *identify = nullptr;
#ifdef ENABLE_MPI
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) {
int rank_offset = std::stoi(offset);
rank_id += rank_offset;
std::string rank_id_tmp;
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (mpi_config_ptr->enable_mpi()) {
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) {
try {
int rank_offset = std::stoi(offset);
rank_id += rank_offset;
} catch (std::invalid_argument) {
MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
} catch (std::out_of_range) {
MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
}
}
rank_id_tmp = std::to_string(rank_id);
identify = rank_id_tmp.c_str();
} else {
identify = std::getenv("RANK_ID");
}
const char *identify = reinterpret_cast<const char *>(std::to_string(rank_id).c_str());
#else
const char *identify = std::getenv("RANK_ID");
identify = std::getenv("RANK_ID");
#endif
if (identify == nullptr) {
MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
free(full_path);
return false;
}
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;


+ 66
- 1
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc View File

@@ -16,6 +16,7 @@

#include "device/cpu/mpi/mpi_adapter.h"
#include <algorithm>
#include "utils/mpi/mpi_config.h"
#include "utils/log_adapter.h"

namespace mindspore {
@@ -35,6 +36,20 @@ MPI_Op GetMpiOp(const std::string &op_type) {
MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
return MPI_SUM;
}

int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
int scatter_index = -1;
for (size_t i = 0; i < ranks_group.size(); ++i) {
if (ranks_group[i] == rankid) {
scatter_index = static_cast<int>(i);
break;
}
}
if (scatter_index == -1) {
MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
}
return scatter_index;
}
} // namespace

MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
@@ -65,6 +80,11 @@ void MPIAdapter::Init() {
if (init) {
return;
}
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (!mpi_config_ptr->enable_mpi()) {
MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
}
int init_flag = 0;
if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
@@ -123,7 +143,7 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
return group;
}

bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type) {
if (ranks_group.empty()) {
MS_LOG(ERROR) << "input rank group is empty!";
@@ -159,6 +179,51 @@ bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<in
return result;
}

bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type, float *output) {
int scatter_index = GetScatterIndex(rank_id_, ranks_group);
auto group = AddGroup(ranks_group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
}
MPI_Comm comm;
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
if (comm == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
}

MPI_Win window;
auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
if (ret != MPI_SUCCESS) {
MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
return false;
}
MPI_Win_fence(0, window);
for (size_t i = 0; i < ranks_group.size(); ++i) {
int remote_rank = ranks_group[i];
if (rank_id_ == remote_rank) {
continue;
}
auto op = GetMpiOp(op_type);
ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op,
window);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
}
}
MPI_Win_fence(0, window);
if (output != nullptr) {
auto data_size = data_num * sizeof(float);
auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
if (copy_ret != 0) {
MS_LOG(EXCEPTION) << "copy output memory fail!";
}
}
MPI_Win_free(&window);
MPI_Comm_free(&comm);
return true;
}

bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
if (ranks_group.empty()) {
MS_LOG(ERROR) << "input rank group is empty!";


+ 3
- 1
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h View File

@@ -32,8 +32,10 @@ class MPIAdapter {
~MPIAdapter();
static MPIAdapter &Instance();
int GetRankId() const;
bool ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type = kOpTypeSum);
bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type = kOpTypeSum, float *output = nullptr);
bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);

private:


+ 6
- 0
mindspore/ccsrc/pipeline/init.cc View File

@@ -26,6 +26,7 @@
#include "pipeline/parse/python_adapter.h"
#include "utils/summary/event_writer.h"
#include "utils/config_manager.h"
#include "utils/mpi/mpi_config.h"
#include "parallel/context.h"
#include "parallel/device_manager.h"
#include "parallel/costmodel_context.h"
@@ -147,6 +148,11 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.");

(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
.def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
.def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");

(void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
.def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
.def("get_device_num", &ParallelContext::device_num, "Get device num.")


+ 31
- 0
mindspore/ccsrc/utils/mpi/mpi_config.cc View File

@@ -0,0 +1,31 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "utils/mpi/mpi_config.h"

namespace mindspore {

std::shared_ptr<MpiConfig> MpiConfig::instance_ = nullptr;

std::shared_ptr<MpiConfig> MpiConfig::GetInstance() {
if (instance_ == nullptr) {
MS_LOG(DEBUG) << "Create new mpi config instance.";
instance_.reset(new (std::nothrow) MpiConfig());
}
return instance_;
}

} // namespace mindspore

+ 42
- 0
mindspore/ccsrc/utils/mpi/mpi_config.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
#define MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
#include <memory>
#include "utils/log_adapter.h"

namespace mindspore {
class MpiConfig {
public:
~MpiConfig() = default;
MpiConfig(const MpiConfig &) = delete;
MpiConfig &operator=(const MpiConfig &) = delete;

static std::shared_ptr<MpiConfig> GetInstance();

void set_enable_mpi(bool flag) { enable_mpi_ = flag; }
bool enable_mpi() const { return enable_mpi_; }

private:
MpiConfig() : enable_mpi_(false) {}

static std::shared_ptr<MpiConfig> instance_;
bool enable_mpi_;
};
} // namespace mindspore

#endif // MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_

+ 38
- 0
mindspore/context.py View File

@@ -25,6 +25,7 @@ from mindspore._c_expression import MSContext
from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config

__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
'get_auto_parallel_context', 'reset_auto_parallel_context']
@@ -566,3 +567,40 @@ def get_context(attr_key):
if not hasattr(_context(), attr_key):
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
return getattr(_context(), attr_key)

@args_type_check(enable_mpi=bool)
def set_mpi_config(**kwargs):
"""
Sets mpi config for running environment.

mpi config should be configured before running your program. If there is no configuration,
mpi moudle will be disabled by default.

Note:
Attribute name is required for setting attributes.

Args:
enable_mpi (bool): Whether to enable mpi. Default: False.

Raises:
ValueError: If input key is not an attribute in mpi config.

Examples:
>>> mpiconfig.set_mpi_config(enable_mpi=True)
"""
_set_mpi_config(**kwargs)

def get_mpi_config(attr_key):
"""
Gets mpi config attribute value according to the input key.

Args:
attr_key (str): The key of the attribute.

Returns:
Object, The value of given attribute key.

Raises:
ValueError: If input key is not an attribute in context.
"""
return _get_mpi_config(attr_key)

+ 14
- 0
mindspore/parallel/mpi/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 111
- 0
mindspore/parallel/mpi/_mpi_config.py View File

@@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The MPI config, used to configure the MPI environment.
"""
import threading
from mindspore._c_expression import MpiConfig
from mindspore._checkparam import args_type_check

class _MpiConfig:
"""
_MpiConfig is the config tool for controlling MPI

Note:
Create a config through instantiating MpiConfig object is not recommended.
should use MpiConfig() to get the config since MpiConfig is singleton.
"""
_instance = None
_instance_lock = threading.Lock()

def __init__(self):
self._mpiconfig_handle = MpiConfig.get_instance()

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance_lock.acquire()
cls._instance = object.__new__(cls)
cls._instance_lock.release()
return cls._instance

def __getattribute__(self, attr):
value = object.__getattribute__(self, attr)
if attr == "_mpiconfig_handle" and value is None:
raise ValueError("mpiconfig handle is none in MpiConfig!!!")
return value

@property
def enable_mpi(self):
return self._mpiconfig_handle.get_enable_mpi()

@enable_mpi.setter
def enable_mpi(self, enable_mpi):
self._mpiconfig_handle.set_enable_mpi(enable_mpi)

_k_mpi_config = None
def _mpi_config():
"""
Get the global mpi config, if mpi config is not created, create a new one.

Returns:
_MpiConfig, the global mpi config.
"""
global _k_mpi_config
if _k_mpi_config is None:
_k_mpi_config = _MpiConfig()
return _k_mpi_config

@args_type_check(enable_mpi=bool)
def _set_mpi_config(**kwargs):
"""
Sets mpi config for running environment.

mpi config should be configured before running your program. If there is no configuration,
mpi moudle will be disabled by default.

Note:
Attribute name is required for setting attributes.

Args:
enable_mpi (bool): Whether to enable mpi. Default: False.

Raises:
ValueError: If input key is not an attribute in mpi config.

Examples:
>>> mpiconfig.set_mpi_config(enable_mpi=True)
"""
for key, value in kwargs.items():
if not hasattr(_mpi_config(), key):
raise ValueError("Set mpi config keyword %s is not recognized!" % key)
setattr(_mpi_config(), key, value)


def _get_mpi_config(attr_key):
"""
Gets mpi config attribute value according to the input key.

Args:
attr_key (str): The key of the attribute.

Returns:
Object, The value of given attribute key.

Raises:
ValueError: If input key is not an attribute in context.
"""
if not hasattr(_mpi_config(), attr_key):
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
return getattr(_mpi_config(), attr_key)

+ 10
- 4
tests/st/ops/cpu/test_reduce_scatter.py View File

@@ -23,9 +23,10 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
import mindspore._ms_mpi as mpi
# run comand:
# mpirun -np 3 python test_reduce_scatter.py
# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_mpi_config(enable_mpi=True)

class Net(nn.Cell):
def __init__(self):
@@ -46,14 +47,19 @@ class AllGatherNet(nn.Cell):
return self.hostallgather(x)

def test_net_reduce_scatter():
x = np.ones(12).astype(np.float32) * 0.1
x = np.arange(12).astype(np.float32) * 0.1
reducescatter = Net()
rankid = mpi.get_rank_id()
print("self rankid:", rankid)
output = reducescatter(Tensor(x, mstype.float32))
print("output:\n", output)
expect_result = np.ones(4).astype(np.float32) * 0.3
if rankid == 0:
expect_result = np.arange(4).astype(np.float32) * 0.3
if rankid == 1:
expect_result = np.arange(4, 8).astype(np.float32) * 0.3
if rankid == 2:
expect_result = np.arange(8, 12).astype(np.float32) * 0.3
diff = abs(output.asnumpy() - expect_result)
error = np.ones(shape=expect_result.shape) * 1.0e-6
assert np.all(diff < error)
@@ -61,7 +67,7 @@ def test_net_reduce_scatter():
allgather = AllGatherNet()
allgather_output = allgather(output)
print("allgather result:\n", allgather_output)
expect_allgather_result = np.ones(12).astype(np.float32) * 0.3
expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
diff = abs(allgather_output.asnumpy() - expect_allgather_result)
error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
assert np.all(diff < error)


Loading…
Cancel
Save