|
|
|
@@ -40,17 +40,6 @@ static constexpr const char *kHcclAlgoOption = "HCCL_algorithm"; |
|
|
|
return HcclResult::HCCL_E_RESERVED; \ |
|
|
|
} |
|
|
|
|
|
|
|
// Compares the mode returned by GetCurrentHcclMode() with the stored hccl_mode_ and, when they
// differ, reports both modes through MS_LOG(EXCEPTION).
// Expand it only where GetCurrentHcclMode(), GetHcclModeString() and hccl_mode_ are visible.
// The do { ... } while (0) wrapper makes the expansion behave as a single statement.
// NOTE: every continuation backslash must be the very last character on its line.
#define CHECK_EXCUTION_MODE() \
  do { \
    auto hccl_mode = GetCurrentHcclMode(); \
    if (hccl_mode != hccl_mode_) { \
      MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_) \
                        << " but current execution mode is " << GetHcclModeString(hccl_mode) \
                        << ". Please set the execution mode before HCCL init(), and then do not " \
                           "change it in the subsequent script"; \
    } \
  } while (0)
|
|
|
|
|
|
|
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id, |
|
|
|
std::string_view rank_file) { |
|
|
|
auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv); |
|
|
|
@@ -159,6 +148,16 @@ HcclMode HcclAdapter::GetCurrentHcclMode() const { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void HcclAdapter::CheckExcutionMode() const { |
|
|
|
auto hccl_mode = GetCurrentHcclMode(); |
|
|
|
if (hccl_mode != hccl_mode_) { |
|
|
|
MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_) << " but current execution mode is " |
|
|
|
<< GetHcclModeString(hccl_mode) |
|
|
|
<< ". Please set the execution mode before HCCL init(), and then do not change it in the " |
|
|
|
"subsequent script"; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::string HcclAdapter::GetHcclModeString(HcclMode hccl_mode) { |
|
|
|
static std::map<HcclMode, std::string> kHcclModeString = { |
|
|
|
{HcclMode::kGraph, "GRAPH_MODE"}, |
|
|
|
@@ -307,14 +306,14 @@ std::string HcclAdapter::GetHcclType(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
// Forwards a broadcast request to launch_hccl_broadcast_ using the adapter's hccl_comm_.
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
// NOTE(review): CHECK_SYMBOL_NULL presumably returns an error code when the entry point is
// null - confirm against its definition.
HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root,
                                      aclrtStream stream) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(launch_hccl_broadcast_);
  return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream);
}
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, |
|
|
|
HcclReduceOp op, aclrtStream stream, const std::string &group) const { |
|
|
|
CHECK_EXCUTION_MODE(); |
|
|
|
CheckExcutionMode(); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_all_reduce_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
@@ -323,7 +322,7 @@ HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t c |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, |
|
|
|
HcclReduceOp op, aclrtStream stream, const std::string &group) const { |
|
|
|
CHECK_EXCUTION_MODE(); |
|
|
|
CheckExcutionMode(); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_reduce_scatter_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
@@ -332,7 +331,7 @@ HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64 |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, |
|
|
|
aclrtStream stream, const std::string &group) const { |
|
|
|
CHECK_EXCUTION_MODE(); |
|
|
|
CheckExcutionMode(); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_all_gather_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
@@ -341,7 +340,7 @@ HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t c |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, |
|
|
|
aclrtStream stream, const std::string &group) const { |
|
|
|
CHECK_EXCUTION_MODE(); |
|
|
|
CheckExcutionMode(); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_send_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
@@ -350,7 +349,7 @@ HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType da |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, |
|
|
|
aclrtStream stream, const std::string &group) const { |
|
|
|
CHECK_EXCUTION_MODE(); |
|
|
|
CheckExcutionMode(); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_recv_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
@@ -474,7 +473,7 @@ bool HcclAdapter::FinalizeHcclComm() { |
|
|
|
} |
|
|
|
|
|
|
|
// Forwards group creation to hccl_create_group_, passing the group name as a C string together
// with rank_num and rank_ids unchanged.
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(hccl_create_group_);
  return hccl_create_group_(group.c_str(), rank_num, rank_ids);
}
|
|
|
@@ -485,25 +484,25 @@ HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const { |
|
|
|
} |
|
|
|
|
|
|
|
// Queries the rank id through single_op_hccl_get_rank_id_ using the adapter's hccl_comm_;
// the result is written through rank_id by the callee.
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(single_op_hccl_get_rank_id_);
  return single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
}
|
|
|
|
|
|
|
// Queries the rank size through single_op_hccl_get_rank_size_ using the adapter's hccl_comm_;
// the result is written through rank_size by the callee.
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(single_op_hccl_get_rank_size_);
  return single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
}
|
|
|
|
|
|
|
// Queries the rank id for the named group through hccl_get_rank_id_, passing the group name
// as a C string; the result is written through rank_id by the callee.
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(hccl_get_rank_id_);
  return hccl_get_rank_id_(group.c_str(), rank_id);
}
|
|
|
|
|
|
|
// Queries the rank size for the named group through hccl_get_rank_size_, passing the group name
// as a C string; the result is written through rank_size by the callee.
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(hccl_get_rank_size_);
  return hccl_get_rank_size_(group.c_str(), rank_size);
}
|
|
|
@@ -537,13 +536,13 @@ bool HcclAdapter::FinalizeHcclExec() { |
|
|
|
} |
|
|
|
|
|
|
|
// Forwards op_info and callback unchanged to hccl_exec_enqueue_op_ and returns its result.
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(hccl_exec_enqueue_op_);
  return hccl_exec_enqueue_op_(op_info, callback);
}
|
|
|
|
|
|
|
// Forwards params and callback unchanged to hccl_exec_enqueue_all_to_all_v_ and returns its result.
// The parameter is spelled `&params` (the reference declarator had been mis-decoded into a
// pilcrow character, leaving the name used in the body undeclared).
// The execution mode is validated once via CheckExcutionMode(); the former additional
// CHECK_EXCUTION_MODE() expansion repeated the identical check and has been dropped.
HcclResult HcclAdapter::HcclExecAllToAllv(const ::HcomAllToAllVParams &params, const HExecCallBack &callback) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_);
  return hccl_exec_enqueue_all_to_all_v_(params, callback);
}
|
|
|
|