|
|
|
@@ -32,6 +32,13 @@ |
|
|
|
// File name of the HCCL plugin shared library (its use is outside this view).
static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
// Name of the DEPLOY_MODE environment variable (its use is outside this view).
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
// Environment variable queried by GenHcclOptions for the HCCL algorithm setting.
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
// Option key under which GenHcclOptions stores the value read from kHcclAlgoEnv.
static constexpr const char *kHcclAlgoOption = "HCCL_algorithm";
|
|
|
|
|
|
|
// Guards a call through an HCCL entry point that may not have been set.
// When `symbol` compares equal to nullptr, a warning naming the symbol is logged
// and the *enclosing function* returns HcclResult::HCCL_E_RESERVED; otherwise
// execution simply continues after the macro.
// The do { ... } while (0) wrapper turns the expansion into a single statement,
// so `CHECK_SYMBOL_NULL(x);` remains well-formed inside an unbraced if/else.
// Call sites must terminate the invocation with a semicolon.
#define CHECK_SYMBOL_NULL(symbol)                                                      \
  do {                                                                                 \
    if ((symbol) == nullptr) {                                                         \
      MS_LOG(WARNING) << #symbol << " is null, hccl has not been inited, do nothing."; \
      return HcclResult::HCCL_E_RESERVED;                                              \
    }                                                                                  \
  } while (0)
|
|
|
|
|
|
|
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id, |
|
|
|
std::string_view rank_file) { |
|
|
|
@@ -54,8 +61,7 @@ static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std |
|
|
|
|
|
|
|
auto env_hccl_algo = mindspore::common::GetEnv(kHcclAlgoEnv); |
|
|
|
if (!env_hccl_algo.empty()) { |
|
|
|
std::string ge_hccl_algo = "HCCL_algorithm"; |
|
|
|
default_options_map.emplace(ge_hccl_algo, env_hccl_algo); |
|
|
|
default_options_map.emplace(kHcclAlgoOption, env_hccl_algo); |
|
|
|
} |
|
|
|
|
|
|
|
return default_options_map; |
|
|
|
@@ -208,7 +214,7 @@ bool HcclAdapter::GenTask(const AnfNodePtr &node, HcclDataType datatype, |
|
|
|
MS_EXCEPTION_IF_NULL(ops_kernel_builder_); |
|
|
|
ge::Status ret = ops_kernel_builder_->CalcOpRunningParam(*ge_node); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret; |
|
|
|
MS_LOG(ERROR) << "Call hccl OpsKernelBuilder CalcOpRunningParam failed, check slog for detail, ret = " << ret; |
|
|
|
return false; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Start to call GenerateTask"; |
|
|
|
@@ -216,7 +222,7 @@ bool HcclAdapter::GenTask(const AnfNodePtr &node, HcclDataType datatype, |
|
|
|
std::vector<domi::TaskDef> domi_tasks; |
|
|
|
ret = ops_kernel_builder_->GenerateTask(*ge_node, unused_ctx, domi_tasks); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
MS_LOG(ERROR) << "OpsKernelBuilder GenerateTask failed, ret = " << ret; |
|
|
|
MS_LOG(ERROR) << "Call hccl OpsKernelBuilder GenerateTask failed, check slog for detail, ret = " << ret; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -229,7 +235,9 @@ bool HcclAdapter::GenTask(const AnfNodePtr &node, HcclDataType datatype, |
|
|
|
} |
|
|
|
|
|
|
|
int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const { |
|
|
|
MS_EXCEPTION_IF_NULL(ops_kernel_builder_); |
|
|
|
if (ops_kernel_builder_ == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Hccl ops kernel builder is null, may not be inited."; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype; |
|
|
|
auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype); |
|
|
|
MS_EXCEPTION_IF_NULL(ge_node); |
|
|
|
@@ -239,13 +247,13 @@ int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType data |
|
|
|
MS_LOG(INFO) << "Start to call CalcOpRunningParam"; |
|
|
|
ge::Status ret = ops_kernel_builder_->CalcOpRunningParam(*ge_node); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret; |
|
|
|
MS_LOG(ERROR) << "Call hccl OpsKernelBuilder CalcOpRunningParam failed, check slog for detail, ret = " << ret; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto workspace_sizes = op->GetWorkspaceBytes(); |
|
|
|
if (workspace_sizes.size() != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size(); |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size() << ", which should be 1."; |
|
|
|
} |
|
|
|
int64_t workspace_size = workspace_sizes[0]; |
|
|
|
MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size; |
|
|
|
@@ -264,13 +272,13 @@ std::string HcclAdapter::GetHcclType(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
// Forwards a broadcast request to launch_hccl_broadcast_ on this adapter's
// hccl_comm_; all arguments are passed through unchanged.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// launch_hccl_broadcast_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root,
                                      aclrtStream stream) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(launch_hccl_broadcast_);
  return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream);
}
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, |
|
|
|
HcclReduceOp op, aclrtStream stream, const std::string &group) const { |
|
|
|
MS_EXCEPTION_IF_NULL(launch_hccl_all_reduce_); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_all_reduce_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
return launch_hccl_all_reduce_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream); |
|
|
|
@@ -278,7 +286,7 @@ HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t c |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, |
|
|
|
HcclReduceOp op, aclrtStream stream, const std::string &group) const { |
|
|
|
MS_EXCEPTION_IF_NULL(launch_hccl_reduce_scatter_); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_reduce_scatter_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
return launch_hccl_reduce_scatter_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream); |
|
|
|
@@ -286,7 +294,7 @@ HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64 |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, |
|
|
|
aclrtStream stream, const std::string &group) const { |
|
|
|
MS_EXCEPTION_IF_NULL(launch_hccl_all_gather_); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_all_gather_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
return launch_hccl_all_gather_(send_buf, recv_buf, count, dataType, hccl_comm, stream); |
|
|
|
@@ -294,7 +302,7 @@ HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t c |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, |
|
|
|
aclrtStream stream, const std::string &group) const { |
|
|
|
MS_EXCEPTION_IF_NULL(launch_hccl_send_); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_send_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
return launch_hccl_send_(send_buf, count, dataType, destRank, hccl_comm, stream); |
|
|
|
@@ -302,7 +310,7 @@ HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType da |
|
|
|
|
|
|
|
HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, |
|
|
|
aclrtStream stream, const std::string &group) const { |
|
|
|
MS_EXCEPTION_IF_NULL(launch_hccl_recv_); |
|
|
|
CHECK_SYMBOL_NULL(launch_hccl_recv_); |
|
|
|
auto hccl_comm = GetHcomm(group); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_comm); |
|
|
|
return launch_hccl_recv_(recv_buf, count, dataType, srcRank, hccl_comm, stream); |
|
|
|
@@ -333,20 +341,20 @@ bool HcclAdapter::InitKernelInfoStore(uint32_t device_id, std::string_view rank_ |
|
|
|
auto options = GenHcclOptions(device_id, rank_id, rank_file); |
|
|
|
auto ret = ops_kernel_builder_->Initialize(options); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Init hccl kernel builder failed, ret = " << ret; |
|
|
|
MS_LOG(EXCEPTION) << "Init hccl kernel builder failed, check slog for detail, ret = " << ret; |
|
|
|
} |
|
|
|
|
|
|
|
// get ops_kernel_info_store |
|
|
|
ret = init_hcom_graph_adapter_(options); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Init hccl graph adapter failed, ret = " << ret; |
|
|
|
MS_LOG(EXCEPTION) << "Init hccl graph adapter failed, check slog for detail, ret = " << ret; |
|
|
|
} |
|
|
|
|
|
|
|
get_hccl_kernel_info_store_(&ops_kernel_info_store_); |
|
|
|
MS_EXCEPTION_IF_NULL(ops_kernel_info_store_); |
|
|
|
ret = ops_kernel_info_store_->Initialize(options); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Init info store failed, ret = " << ret; |
|
|
|
MS_LOG(EXCEPTION) << "Init info store failed, check slog for detail, ret = " << ret; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Init hccl kernel info store success."; |
|
|
|
return true; |
|
|
|
@@ -424,32 +432,32 @@ bool HcclAdapter::FinalizeHcclComm() { |
|
|
|
} |
|
|
|
|
|
|
|
// Forwards a group-creation request to hccl_create_group_, passing the group
// name as a C string together with rank_num and rank_ids unchanged.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// hccl_create_group_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(hccl_create_group_);
  return hccl_create_group_(group.c_str(), rank_num, rank_ids);
}
|
|
|
|
|
|
|
// Forwards a group-destruction request to hccl_destroy_group_, passing the
// group name as a C string.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// hccl_destroy_group_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(hccl_destroy_group_);
  return hccl_destroy_group_(group.c_str());
}
|
|
|
|
|
|
|
// Queries the rank id through single_op_hccl_get_rank_id_ using this adapter's
// hccl_comm_; the result is written through `rank_id` by the callee.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// single_op_hccl_get_rank_id_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(single_op_hccl_get_rank_id_);
  return single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
}
|
|
|
|
|
|
|
// Queries the rank size through single_op_hccl_get_rank_size_ using this
// adapter's hccl_comm_; the result is written through `rank_size` by the callee.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// single_op_hccl_get_rank_size_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(single_op_hccl_get_rank_size_);
  return single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
}
|
|
|
|
|
|
|
// Group-name overload: queries the rank id through hccl_get_rank_id_, passing
// the group name as a C string; the result is written through `rank_id` by the callee.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// hccl_get_rank_id_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(hccl_get_rank_id_);
  return hccl_get_rank_id_(group.c_str(), rank_id);
}
|
|
|
|
|
|
|
// Group-name overload: queries the rank size through hccl_get_rank_size_, passing
// the group name as a C string; the result is written through `rank_size` by the callee.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// hccl_get_rank_size_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(hccl_get_rank_size_);
  return hccl_get_rank_size_(group.c_str(), rank_size);
}
|
|
|
|
|
|
|
@@ -482,12 +490,12 @@ bool HcclAdapter::FinalizeHcclExec() { |
|
|
|
} |
|
|
|
|
|
|
|
// Forwards `op_info` and `callback` unchanged to hccl_exec_enqueue_op_.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// hccl_exec_enqueue_op_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(hccl_exec_enqueue_op_);
  return hccl_exec_enqueue_op_(op_info, callback);
}
|
|
|
|
|
|
|
// Forwards `params` and `callback` unchanged to hccl_exec_enqueue_all_to_all_v_.
// Returns HcclResult::HCCL_E_RESERVED (after a warning log) when
// hccl_exec_enqueue_all_to_all_v_ is null, otherwise whatever the callee returns.
HcclResult HcclAdapter::HcclExecAllToAllv(const ::HcomAllToAllVParams &params, const HExecCallBack &callback) const {
  // Soft-fail instead of throwing: an unset entry point means HCCL was not initialised.
  CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_);
  return hccl_exec_enqueue_all_to_all_v_(params, callback);
}
|
|
|
} // namespace mindspore::hccl |