|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "device/cpu/mpi/mpi_adapter.h"
- #include <algorithm>
- #include <limits>
- #include "utils/log_adapter.h"
-
- namespace mindspore {
- namespace device {
- namespace cpu {
- namespace {
- // Maps a reduction name to the corresponding MPI reduction operator.
- // Accepted names are "sum", "max", "min" and "prod"; any other name is reported through MS_LOG(EXCEPTION).
- MPI_Op GetMpiOp(const std::string &op_type) {
-   struct NamedOp {
-     const char *name;
-     MPI_Op op;
-   };
-   const NamedOp known_ops[] = {
-     {"sum", MPI_SUM},
-     {"max", MPI_MAX},
-     {"min", MPI_MIN},
-     {"prod", MPI_PROD},
-   };
-   for (const auto &candidate : known_ops) {
-     if (op_type == candidate.name) {
-       return candidate.op;
-     }
-   }
-   MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
-   // Fallback value so every path returns something even if control continues past the log statement.
-   return MPI_SUM;
- }
- } // namespace
-
- // Starts with rank id 0, rank size 0 and a null world group, then lets Init() fill in the real values.
- MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) {
-   Init();
- }
-
- // Releases every cached rank group and the world group, then finalizes MPI.
- // Nothing is released when MPI has already been finalized, because MPI handles must not be used after that point.
- MPIAdapter::~MPIAdapter() {
-   // Query the MPI state first: freeing groups is only legal while MPI is still alive.
-   int finalized = 0;
-   MPI_Finalized(&finalized);
-   if (finalized != 0) {
-     return;
-   }
-
-   for (auto &entry : ranks_group_) {
-     MPI_Group_free(&entry.second);
-   }
-   if (comm_group_world_ != MPI_GROUP_NULL) {
-     MPI_Group_free(&comm_group_world_);
-   }
-   MPI_Finalize();
- }
-
- // Returns the shared adapter object, a function-local static that is constructed on the first call.
- MPIAdapter &MPIAdapter::Instance() {
-   static MPIAdapter adapter;
-   return adapter;
- }
-
- // Returns the rank id stored in rank_id_ (set by Init() from MPI_Comm_rank on MPI_COMM_WORLD).
- int MPIAdapter::GetRankId() const {
-   return rank_id_;
- }
-
- // One-time MPI bootstrap: starts MPI if it is not running yet, then caches the world group, this process's rank
- // id and the world size. Every failure is reported through MS_LOG(EXCEPTION).
- // Once a run has completed, later calls return immediately.
- void MPIAdapter::Init() {
-   static bool already_done = false;
-   if (already_done) {
-     return;
-   }
-
-   int mpi_started = 0;
-   const auto query_status = MPI_Initialized(&mpi_started);
-   if (query_status != MPI_SUCCESS) {
-     MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
-   }
-   // MPI_Init is only attempted when MPI reports that it has not been initialized.
-   if (mpi_started == 0 && MPI_Init(nullptr, nullptr) != MPI_SUCCESS) {
-     MS_LOG(EXCEPTION) << "Failed to init mpi!";
-   }
-
-   MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
-   if (comm_group_world_ == MPI_GROUP_NULL) {
-     MS_LOG(EXCEPTION) << "comm_group_world_ init fail!";
-   }
-
-   const auto rank_status = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
-   if (rank_status != MPI_SUCCESS) {
-     MS_LOG(EXCEPTION) << "Failed to init mpi rank id!";
-   }
-
-   const auto size_status = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
-   if (size_status != MPI_SUCCESS) {
-     MS_LOG(EXCEPTION) << "Failed to init mpi rank size!rankid:" << rank_id_;
-   }
-
-   already_done = true;
- }
-
- // Returns the MPI group made of the given world ranks, creating and caching it on first use.
- //   ranks: world ranks that form the group; must be non-empty and no longer than rank_size_.
- // Returns MPI_GROUP_NULL (after logging an error) when this process's rank is not listed in ranks.
- // An invalid rank count or a failed group creation is reported through MS_LOG(EXCEPTION).
- MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
-   if (ranks.empty() || ranks.size() > static_cast<size_t>(rank_size_)) {
-     MS_LOG(EXCEPTION) << "input rank size: " << ranks.size() << ", max rank size: " << rank_size_;
-   }
-
-   const bool contains_self = std::find(ranks.begin(), ranks.end(), rank_id_) != ranks.end();
-   if (!contains_self) {
-     MS_LOG(ERROR) << "rankid:" << rank_id_ << " is not in the input group.";
-     return MPI_GROUP_NULL;
-   }
-
-   // The cache lookup and the insertion below happen under group_mutex_.
-   std::lock_guard<std::mutex> guard(group_mutex_);
-   auto cached = ranks_group_.find(ranks);
-   if (cached != ranks_group_.end()) {
-     return cached->second;
-   }
-
-   // Local copy of the rank list that is handed to MPI_Group_incl.
-   std::vector<int> ranks_input(ranks.begin(), ranks.end());
-   MPI_Group new_group = MPI_GROUP_NULL;
-   MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &new_group);
-   if (new_group == MPI_GROUP_NULL) {
-     MS_LOG(EXCEPTION) << "create mpi group fail!rankid:" << rank_id_;
-   }
-
-   ranks_group_[ranks] = new_group;
-   MS_LOG(INFO) << "rank:" << rank_id_ << " add group:" << new_group;
-   return new_group;
- }
-
- // Runs MPI_Reduce_scatter on float data across the ranks listed in ranks_group.
- //   input:       send buffer passed to MPI_Reduce_scatter.
- //   output:      receive buffer passed to MPI_Reduce_scatter.
- //   ranks_group: world ranks taking part; must not be empty.
- //   data_num:    receive count used for every rank in the group; must fit in an int.
- //   op_type:     reduction name understood by GetMpiOp ("sum", "max", "min" or "prod").
- // Returns false when ranks_group is empty, data_num does not fit in an int, or the reduce-scatter call fails.
- // A missing group, an unsupported op_type or a failed communicator creation is reported through
- // MS_LOG(EXCEPTION).
- bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
-                                const std::string &op_type) {
-   if (ranks_group.empty()) {
-     MS_LOG(ERROR) << "input rank group is empty!";
-     return false;
-   }
-   // MPI receive counts are ints; reject values that would be silently truncated.
-   if (data_num > static_cast<size_t>(std::numeric_limits<int>::max())) {
-     MS_LOG(ERROR) << "data_num " << data_num << " exceeds the max count supported by mpi!";
-     return false;
-   }
-
-   auto group = AddGroup(ranks_group);
-   if (group == MPI_GROUP_NULL) {
-     MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
-   }
-
-   // Everything that can raise is done before the communicator exists, so a failure here cannot leak it.
-   auto op = GetMpiOp(op_type);
-   std::vector<int> receive_count(ranks_group.size(), static_cast<int>(data_num));
-
-   // Start from MPI_COMM_NULL so the check below is meaningful even if the create call leaves comm untouched.
-   MPI_Comm comm = MPI_COMM_NULL;
-   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
-   if (comm == MPI_COMM_NULL) {
-     MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
-   }
-
-   auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm);
-   bool result = true;
-   if (ret != MPI_SUCCESS) {
-     MS_LOG(ERROR) << "mpi reduce_scatter fail!ret = " << ret << ", rankid:" << rank_id_;
-     result = false;
-   }
-
-   // The communicator is released whether or not the collective succeeded.
-   ret = MPI_Comm_free(&comm);
-   if (ret != MPI_SUCCESS) {
-     MS_LOG(WARNING) << "mpi comm free fail! ret = " << ret << ", rankid:" << rank_id_;
-   }
-   return result;
- }
-
- // Runs MPI_Allgather on float data across the ranks listed in ranks_group.
- //   input:       send buffer passed to MPI_Allgather.
- //   output:      receive buffer passed to MPI_Allgather.
- //   ranks_group: world ranks taking part; must not be empty.
- //   data_num:    used as both the send count and the per-rank receive count; must fit in an int.
- // Returns false when ranks_group is empty, data_num does not fit in an int, or the allgather call fails.
- // A missing group or a failed communicator creation is reported through MS_LOG(EXCEPTION).
- bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
-   if (ranks_group.empty()) {
-     MS_LOG(ERROR) << "input rank group is empty!";
-     return false;
-   }
-   // MPI counts are ints; reject values that would be silently truncated.
-   if (data_num > static_cast<size_t>(std::numeric_limits<int>::max())) {
-     MS_LOG(ERROR) << "data_num " << data_num << " exceeds the max count supported by mpi!";
-     return false;
-   }
-   const int count = static_cast<int>(data_num);
-
-   auto group = AddGroup(ranks_group);
-   if (group == MPI_GROUP_NULL) {
-     MS_LOG(EXCEPTION) << "Get mpi group fail! rankid:" << rank_id_;
-   }
-
-   // Start from MPI_COMM_NULL so the check below is meaningful even if the create call leaves comm untouched.
-   MPI_Comm comm = MPI_COMM_NULL;
-   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
-   if (comm == MPI_COMM_NULL) {
-     MS_LOG(EXCEPTION) << "create mpi comm fail! rankid:" << rank_id_;
-   }
-
-   auto ret = MPI_Allgather(input, count, MPI_FLOAT, output, count, MPI_FLOAT, comm);
-   bool result = true;
-   if (ret != MPI_SUCCESS) {
-     MS_LOG(ERROR) << "mpi allgater fail!ret = " << ret << ", rankid:" << rank_id_;
-     result = false;
-   }
-
-   // The communicator is released whether or not the collective succeeded.
-   ret = MPI_Comm_free(&comm);
-   if (ret != MPI_SUCCESS) {
-     MS_LOG(WARNING) << "mpi comm free fail!ret = " << ret << ",rankid:" << rank_id_;
-   }
-   return result;
- }
- } // namespace cpu
- } // namespace device
- } // namespace mindspore
|