浏览代码

fix code check on branch master

tags/v1.3.0
zuochuanyong 4 年前
父节点
当前提交
5021ec507c
共有 3 个文件被更改,包括 25 次插入25 次删除
  1. +7
    -6
      mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc
  2. +0
    -1
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
  3. +18
    -18
      mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc

+ 7
- 6
mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc 查看文件

@@ -43,7 +43,8 @@ bool CPUDeviceAddress::DumpMemToFile(const std::string &filepath, const std::str
return ret;
}

bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /*shape*/, size_t size, TypeId type, void *host_ptr) const {
bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /* shape */, size_t size, TypeId type,
void *host_ptr) const {
if (ptr_ == nullptr) {
MS_LOG(ERROR) << "The pointer ptr_ is null!";
return false;
@@ -59,11 +60,11 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /*shape*/, size_t si
return false;
}
} else if (type == kNumberTypeFloat16 && type_id_ == kNumberTypeFloat32) {
FloatToHalf(host_ptr, ptr_, size / 2);
FloatToHalf(host_ptr, ptr_, size >> 1);
} else if (type == kNumberTypeFloat64 && type_id_ == kNumberTypeFloat32) {
FloatToDouble(host_ptr, ptr_, size / sizeof(double));
} else if (type == kNumberTypeInt16 && type_id_ == kNumberTypeInt32) {
IntToShort(host_ptr, ptr_, size / 2);
IntToShort(host_ptr, ptr_, size >> 1);
} else if (type == kNumberTypeInt64 && type_id_ == kNumberTypeInt32) {
IntToLong(host_ptr, ptr_, size / sizeof(int64_t));
} else {
@@ -74,7 +75,7 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /*shape*/, size_t si
return true;
}

bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /*shape*/, size_t size, TypeId type, const void *host_ptr,
bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /* shape */, size_t size, TypeId type, const void *host_ptr,
const std::string &format) const {
if (ptr_ == nullptr) {
MS_LOG(ERROR) << "The pointer ptr_ is null!";
@@ -92,11 +93,11 @@ bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /*shape*/, size_t si
return false;
}
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat16) {
HalfToFloat(ptr_, host_ptr, size / 2);
HalfToFloat(ptr_, host_ptr, size >> 1);
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
DoubleToFloat(ptr_, host_ptr, size / sizeof(double));
} else if (type_id_ == kNumberTypeInt32 && type == kNumberTypeInt16) {
ShortToInt(ptr_, host_ptr, size / 2);
ShortToInt(ptr_, host_ptr, size >> 1);
} else if (type_id_ == kNumberTypeInt32 && type == kNumberTypeInt64) {
LongToInt(ptr_, host_ptr, size / sizeof(int64_t));
} else {


+ 0
- 1
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc 查看文件

@@ -46,7 +46,6 @@
namespace mindspore {
namespace device {
namespace cpu {

bool CPUKernelRuntime::Init() {
if (initialized_) {
return true;


+ 18
- 18
mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc 查看文件

@@ -33,19 +33,19 @@ std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
return instance_;
}

#define RAISE_EXCEPTION(message) \
{ \
std::ostringstream oss; \
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \
pybind11::pybind11_fail(oss.str()); \
}
#define RAISE_EXCEPTION(message) \
do { \
std::ostringstream oss; \
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << (message); \
pybind11::pybind11_fail(oss.str()); \
} while (0)

#define RAISE_EXCEPTION_WITH_PARAM(message, param) \
{ \
std::ostringstream oss; \
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \
pybind11::pybind11_fail(oss.str()); \
}
#define RAISE_EXCEPTION_WITH_PARAM(message, param) \
do { \
std::ostringstream oss; \
oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << (message) << (param); \
pybind11::pybind11_fail(oss.str()); \
} while (0)

namespace {
MPI_Op GetMpiOp(const std::string &op_type) {
@@ -109,8 +109,8 @@ void MPIAdapter::Init() {
RAISE_EXCEPTION("Check mpi initialized fail!");
}
if (init_flag == 0) {
auto ret = MPI_Init(nullptr, nullptr);
if (ret != MPI_SUCCESS) {
auto ret_init = MPI_Init(nullptr, nullptr);
if (ret_init != MPI_SUCCESS) {
RAISE_EXCEPTION("Failed to init mpi!");
}
}
@@ -126,7 +126,7 @@ void MPIAdapter::Init() {

ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
if (ret != MPI_SUCCESS) {
RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_)
RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_);
}
init = true;
}
@@ -153,7 +153,7 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
MPI_Group group = MPI_GROUP_NULL;
MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group);
if (group == MPI_GROUP_NULL) {
RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_)
RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_);
}

ranks_group_[ranks] = group;
@@ -169,7 +169,7 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec

auto group = AddGroup(ranks_group);
if (group == MPI_GROUP_NULL) {
RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_)
RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_);
}
MPI_Comm comm;
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
@@ -233,7 +233,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
if (output_size < data_size) {
std::ostringstream exception_msg;
exception_msg << "output buffer size " << output_size << " < input size " << data_size;
RAISE_EXCEPTION(exception_msg.str())
RAISE_EXCEPTION(exception_msg.str());
}
auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
if (copy_ret != 0) {


正在加载...
取消
保存