Browse Source

Fix get rank size func in alltoall fusion

tags/v1.6.0
ZPaC 4 years ago
parent
commit
246f1bcd06
4 changed files with 16 additions and 18 deletions
  1. +1
    -15
      mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.cc
  2. +7
    -2
      mindspore/ccsrc/distributed/cluster/cluster_context.cc
  3. +7
    -0
      mindspore/ccsrc/distributed/cluster/cluster_context.h
  4. +1
    -1
      mindspore/python/mindspore/communication/management.py

+ 1
- 15
mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.cc View File

@@ -37,20 +37,6 @@ inline int64_t NormalizeDim(const std::vector<size_t> &shape, int64_t dim) {
return dim < 0 ? SizeToLong(shape.size()) + dim : dim;
}

// Returns the number of ranks that belong to 'group'.
// Resolves the "GetGroupRanks" symbol from the collective library handle,
// queries the rank list for the group, and reports its length.
uint32_t GetRankSize(const std::string &group) {
  const void *handle = device::gpu::CollectiveInitializer::instance().collective_handle();
  MS_EXCEPTION_IF_NULL(handle);

  // Resolve the rank-query entry point exported by the collective library.
  auto query_group_ranks = reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(handle), "GetGroupRanks"));
  MS_EXCEPTION_IF_NULL(query_group_ranks);

  const std::vector<int> ranks = (*query_group_ranks)(group);
  return static_cast<uint32_t>(ranks.size());
}

CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(all_to_all);
@@ -113,7 +99,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
std::vector<TypeId> dtypes(split_count, single_type);
std::vector<std::vector<size_t>> shapes(split_count, single_shape);
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
uint32_t rank_size = GetRankSize(group);
uint32_t rank_size = device::gpu::CollectiveInitializer::instance().GetGroupSize(group);
std::vector<int64_t> rank_ids(rank_size, 0);
for (uint32_t i = 0; i < rank_size; ++i) {
rank_ids[i] = static_cast<int64_t>(i);


+ 7
- 2
mindspore/ccsrc/distributed/cluster/cluster_context.cc View File

@@ -156,7 +156,7 @@ bool ClusterContext::BuildCluster() {
void ClusterContext::InitNodeRole() {
node_role_ = common::GetEnv(kEnvRole);
if (kValidRoleName.count(node_role_) == 0) {
MS_LOG(EXCEPTION) << "Role name " << node_role_ << " is invalid.";
MS_LOG(EXCEPTION) << "Role name '" << node_role_ << "' is invalid. " << kDetailedFailureReason;
return;
}

@@ -177,7 +177,12 @@ void ClusterContext::InitNodeRole() {
}
}

void ClusterContext::InitSchedulerIp() { scheduler_host_ = common::GetEnv(kEnvSchedulerHost); }
void ClusterContext::InitSchedulerIp() {
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
if (scheduler_host_.empty()) {
MS_LOG(EXCEPTION) << kEnvSchedulerHost << " is empty. " << kEnvSchedulerHost;
}
}

void ClusterContext::InitSchedulerPort() {
TRY_AND_CATCH_WITH_EXCEPTION((scheduler_port_ = static_cast<uint16_t>(std::stoi(common::GetEnv(kEnvSchedulerPort)))),


+ 7
- 0
mindspore/ccsrc/distributed/cluster/cluster_context.h View File

@@ -37,6 +37,13 @@
namespace mindspore {
namespace distributed {
namespace cluster {
// The detailed reason of failing to run 'mindspore.communication.init()' with ClusterContext.
constexpr char kDetailedFailureReason[] =
"Maybe you are trying to call 'mindspore.communication.init()' without using 'mpirun', which will make MindSpore "
"load several environment variables and check their validation. Please use 'mpirun' to launch this process to fix "
"this issue, or refer to this link if you want to run distributed training without using 'mpirun': "
"https://www.mindspore.cn/docs/programming_guide/zh-CN/master/distributed_training_gpu.html#openmpi.";

// Node role based cluster built by MindSpore communication framework.
class ClusterContext {
public:


+ 1
- 1
mindspore/python/mindspore/communication/management.py View File

@@ -90,7 +90,7 @@ def init(backend_name=None):
The full name of HCCL is Huawei Collective Communication Library.
The full name of NCCL is NVIDIA Collective Communication Library.
This method should be used after set_context. The user needs to preset communication environment variables
before running the following example, please see the docstring of the mindspore.managerment.
before running the following example, please see the docstring of the mindspore.management.

Args:
backend_name (str): Backend, using HCCL/NCCL. If the `backend_name` is None, system will recognize


Loading…
Cancel
Save