|
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
- #ifndef MINDSPORE_CCSRC_DISTRIBUTED_COLLECTIVE_COLLECTIVE_MANAGER_H_
- #define MINDSPORE_CCSRC_DISTRIBUTED_COLLECTIVE_COLLECTIVE_MANAGER_H_
-
- #include <string>
- #include <memory>
- #include <vector>
- #include <atomic>
- #include "utils/ms_utils.h"
- #include "distributed/constants.h"
- #include "runtime/hardware/device_context_manager.h"
-
- namespace mindspore {
- namespace distributed {
- namespace collective {
- using DeviceContext = device::DeviceContext;
- using DeviceContextKey = device::DeviceContextKey;
- using DeviceContextManager = device::DeviceContextManager;
- using CollectiveCommunicationLib = device::CollectiveCommunicationLib;
- using CommunicationGroupPtr = device::CommunicationGroupPtr;
-
// The collective communication API.
// MindSpore uses OpenMPI on CPU, NCCL on GPU, and HCCL on Ascend to achieve distributed training.
// Besides, MindSpore also has its own communication library, which is implemented on the CPU side.
class CollectiveManager {
 public:
  ~CollectiveManager();
  DISABLE_COPY_AND_ASSIGN(CollectiveManager);

  // Returns the process-wide singleton instance of CollectiveManager.
  static std::shared_ptr<CollectiveManager> instance();

  // Initialize the collective communication for distributed training. The backend type is read from MindSpore context.
  bool Initialize();

  // Finalize the collective communication.
  bool Finalize();

  // Create communication group.
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);

  // Destroy the communication group.
  bool DestroyCommunicationGroup(const std::string &group_name);

  // Get the rank id of this process in the specified group.
  uint32_t GetRankId(const std::string &group_name);

  // Get the size of the specified group.
  uint32_t GetGroupSize(const std::string &group_name);

  // In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication
  // framework, they're generated by cluster::ClusterContext.
  void set_global_rank_id(uint32_t global_rank_id);
  void set_global_rank_size(uint32_t global_rank_size);

  // Accessor for local_rank_id_ (this process's rank within its node; see the member comment below).
  uint32_t local_rank_id() const;

 private:
  CollectiveManager();

  // Initialize communication library on host side.
  bool InitHostCommlib();

  // Initialize communication library on device side.
  bool InitDeviceCommLib();

  // Assign the local rank id for this process.
  bool AssignLocalRank();

  // Whether Initialize()/Finalize() has completed. Atomic flags — presumably these may be read/written from
  // multiple threads; confirm against the .cc implementation.
  std::atomic_bool inited_;
  std::atomic_bool finalized_;

  // The device type read from MindSpore context.
  std::string device_type_;

  // The device context on both host and device side. They are used to access the communication library on different
  // devices. NOTE(review): raw (non-owning) pointers — presumably owned by DeviceContextManager; confirm in the
  // .cc file.
  DeviceContext *host_ctx_;
  DeviceContext *device_ctx_;

  // Host communication library refers to the communication library for CPU, e.g., OpenMPI and MindSpore communication
  // framework.
  CollectiveCommunicationLib *host_comm_lib_instance_;

  // Device communication library refers to the communication library for NPU or GPU, e.g., NCCL and HCCL.
  // When only CPU backend is used, device communication library should not be initialized.
  CollectiveCommunicationLib *device_comm_lib_instance_;

  // The global rank id of this process. Normally this range is 0 to `total process number - 1`.
  uint32_t global_rank_id_;

  // The local rank id of this process within the same node. This is usually used as device id.
  uint32_t local_rank_id_;

  // The global rank size. Normally this is equal to `total process number`.
  uint32_t global_rank_size_;

  // Global group ranks.
  std::vector<uint32_t> global_group_ranks_;

  // The global group name on the host side. This is used for creating the global group on host side for the AllGather
  // operation of host names while assigning the local rank.
  std::string host_global_group_name_;
};
- } // namespace collective
- } // namespace distributed
- } // namespace mindspore
- #endif // MINDSPORE_CCSRC_DISTRIBUTED_COLLECTIVE_COLLECTIVE_MANAGER_H_
|