Merge pull request !21208 from baihuawei/graph_mode_nonsink_part2tags/v1.4.0
| @@ -64,7 +64,8 @@ HcclKernelFactory &HcclKernelFactory::Get() { | |||
| return _this; | |||
| } | |||
| HcclKernel::HcclKernel() : hccl_count_(0), op_type_(::HcclReduceOp::HCCL_REDUCE_SUM), root_id_(0) {} | |||
| // Default constructor: zeroes the element count and the root/source/destination ranks, and selects | |||
| // HCCL_REDUCE_SUM as the initial reduce operation. | |||
| HcclKernel::HcclKernel() | |||
| : hccl_count_(0), op_type_(::HcclReduceOp::HCCL_REDUCE_SUM), root_id_(0), src_rank_(0), dest_rank_(0) {} | |||
| HcclKernel::~HcclKernel() { | |||
| hccl_kernel_input_shape_list_.clear(); | |||
| @@ -81,6 +82,18 @@ HcclKernel::~HcclKernel() { | |||
| bool HcclKernel::Init(const AnfNodePtr &anf_node) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| op_name_ = AnfAlgo::GetCNodeName(anf_node); | |||
| if (op_name_ == kHcomSend) { | |||
| if (!HcomUtil::GetHcomDestRank(anf_node, &dest_rank_)) { | |||
| MS_LOG(ERROR) << "GetHcomDestRank fail!"; | |||
| return false; | |||
| } | |||
| } | |||
| if (op_name_ == kReceive) { | |||
| if (!HcomUtil::GetHcomSrcRank(anf_node, &src_rank_)) { | |||
| MS_LOG(ERROR) << "GetHcomSrcRank fail!"; | |||
| return false; | |||
| } | |||
| } | |||
| if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) { | |||
| MS_LOG(ERROR) << "GetKernelInputShape fail!"; | |||
| return false; | |||
| @@ -180,10 +193,13 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const { | |||
| } | |||
| const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { | |||
| if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE); | |||
| if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode)) { | |||
| return workspace_size_list_; | |||
| } | |||
| workspace_size_list_.emplace_back( | |||
| hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0])); | |||
| return workspace_size_list_; | |||
| @@ -51,6 +51,8 @@ class HcclKernel : public AscendKernelMod { | |||
| uint64_t hccl_count_; | |||
| HcclReduceOp op_type_; | |||
| uint32_t root_id_; | |||
| uint32_t src_rank_; | |||
| uint32_t dest_rank_; | |||
| mutable std::vector<size_t> input_size_list_; | |||
| mutable std::vector<size_t> output_size_list_; | |||
| mutable std::vector<size_t> workspace_size_list_; | |||
| @@ -23,6 +23,7 @@ namespace mindspore { | |||
| namespace kernel { | |||
| bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &, void *stream_ptr) { | |||
| MS_LOG(DEBUG) << "HcomAllBroadCast launch"; | |||
| if (inputs.empty() || hccl_data_type_list_.empty()) { | |||
| MS_LOG(ERROR) << "BroadCast param is empty"; | |||
| return false; | |||
| @@ -16,13 +16,27 @@ | |||
| #include "backend/kernel_compiler/hccl/hcom_all_gather.h" | |||
| #include <memory> | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/hccl_adapter/hccl_adapter.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &, void *) { | |||
| MS_LOG(INFO) << "HcomAllGather launch"; | |||
| // Launches AllGather on the given stream. | |||
| // inputs[0] is the send buffer and outputs[0] the receive buffer; the workspace list is unused. | |||
| // Returns false (after logging) when inputs, outputs or the data type list is empty, or when the HCCL call | |||
| // does not report HCCL_SUCCESS. A null inputs[0], outputs[0] or stream_ptr trips MS_EXCEPTION_IF_NULL. | |||
| bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| MS_LOG(DEBUG) << "HcomAllGather launch"; | |||
| // Only element 0 of each list is read below, so all three must be non-empty. | |||
| if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { | |||
| MS_LOG(ERROR) << "Invalid AllGather input, output or data type size(" << inputs.size() << ", " << outputs.size() | |||
| << ", " << hccl_data_type_list_.size() << ")."; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(outputs[0]); | |||
| MS_EXCEPTION_IF_NULL(stream_ptr); | |||
| auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllGather(inputs[0]->addr, outputs[0]->addr, hccl_count_, | |||
| hccl_data_type_list_[0], stream_ptr, group_); | |||
| if (hccl_result != HCCL_SUCCESS) { | |||
| MS_LOG(ERROR) << "HcclAllGather failed, ret:" << hccl_result; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -19,7 +19,6 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "hccl/hcom.h" | |||
| #include "backend/kernel_compiler/hccl/hccl_kernel.h" | |||
| namespace mindspore { | |||
| @@ -22,17 +22,17 @@ namespace mindspore { | |||
| namespace kernel { | |||
| bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| MS_LOG(INFO) << "HcclAllReduce launch"; | |||
| MS_LOG(DEBUG) << "HcclAllReduce launch"; | |||
| if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { | |||
| MS_LOG(ERROR) << "Invalid AllReduce input, output or data type size(" << inputs.size() << ", " << outputs.size() | |||
| MS_LOG(ERROR) << "Invalid AllReduce input, output or data type size (" << inputs.size() << ", " << outputs.size() | |||
| << ", " << hccl_data_type_list_.size() << ")."; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(outputs[0]); | |||
| MS_EXCEPTION_IF_NULL(stream_ptr); | |||
| auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_, | |||
| hccl_data_type_list_[0], op_type_, stream_ptr); | |||
| auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce( | |||
| inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_); | |||
| if (hccl_result != HCCL_SUCCESS) { | |||
| MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result; | |||
| return false; | |||
| @@ -16,13 +16,27 @@ | |||
| #include "backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h" | |||
| #include <memory> | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/hccl_adapter/hccl_adapter.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &, void *) { | |||
| MS_LOG(INFO) << "HcomAllReduceScatter launch"; | |||
| // Launches ReduceScatter on the given stream using the kernel's reduce op (op_type_). | |||
| // inputs[0] is the send buffer and outputs[0] the receive buffer; the workspace list is unused. | |||
| // Returns false (after logging) when inputs, outputs or the data type list is empty, or when the HCCL call | |||
| // does not report HCCL_SUCCESS. A null inputs[0], outputs[0] or stream_ptr trips MS_EXCEPTION_IF_NULL. | |||
| bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| MS_LOG(DEBUG) << "HcomAllReduceScatter launch"; | |||
| // Only element 0 of each list is read below, so all three must be non-empty. | |||
| if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { | |||
| MS_LOG(ERROR) << "Invalid AllReduceScatter input, output or data type size(" << inputs.size() << ", " | |||
| << outputs.size() << ", " << hccl_data_type_list_.size() << ")."; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(outputs[0]); | |||
| MS_EXCEPTION_IF_NULL(stream_ptr); | |||
| auto hccl_result = hccl::HcclAdapter::GetInstance().HcclReduceScatter( | |||
| inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_); | |||
| if (hccl_result != HCCL_SUCCESS) { | |||
| MS_LOG(ERROR) << "HcclReduceScatter failed, ret:" << hccl_result; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -19,7 +19,6 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "hccl/hcom.h" | |||
| #include "backend/kernel_compiler/hccl/hccl_kernel.h" | |||
| namespace mindspore { | |||
| @@ -16,12 +16,26 @@ | |||
| #include "backend/kernel_compiler/hccl/hcom_receive.h" | |||
| #include <memory> | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/hccl_adapter/hccl_adapter.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| // Launches the receive op on the given stream: data is received into outputs[0] from rank src_rank_. | |||
| // Returns false (after logging) when outputs or the data type list is empty, or when HcclRecv does not | |||
| // report HCCL_SUCCESS. A null outputs[0] or stream_ptr trips MS_EXCEPTION_IF_NULL. | |||
| bool HcomReceiveKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &, void *) { | |||
| MS_LOG(INFO) << "HcomReceive launch"; | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| MS_LOG(DEBUG) << "HcomReceive launch"; | |||
| // Only element 0 of each list is read below, so both must be non-empty. | |||
| if (outputs.empty() || hccl_data_type_list_.empty()) { | |||
| MS_LOG(ERROR) << "Invalid HcomReceive outputs size or data type size (" << outputs.size() << ", " | |||
| << hccl_data_type_list_.size() << ")."; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(outputs[0]); | |||
| MS_EXCEPTION_IF_NULL(stream_ptr); | |||
| auto hccl_result = hccl::HcclAdapter::GetInstance().HcclRecv(outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], | |||
| src_rank_, stream_ptr, group_); | |||
| if (hccl_result != HCCL_SUCCESS) { | |||
| MS_LOG(ERROR) << "HcomReceive failed, ret:" << hccl_result; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -16,13 +16,26 @@ | |||
| #include "backend/kernel_compiler/hccl/hcom_send.h" | |||
| #include <memory> | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/hccl_adapter/hccl_adapter.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| bool HcomSendKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &, void *) { | |||
| MS_LOG(INFO) << "HcomSend launch"; | |||
| // Launches the send op on the given stream: inputs[0] is sent to rank dest_rank_. | |||
| // The workspace and output lists are unused. | |||
| // Returns false (after logging) when inputs or the data type list is empty, or when HcclSend does not | |||
| // report HCCL_SUCCESS. A null inputs[0] or stream_ptr trips MS_EXCEPTION_IF_NULL. | |||
| bool HcomSendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &, void *stream_ptr) { | |||
| MS_LOG(DEBUG) << "HcomSend launch"; | |||
| // Only element 0 of each list is read below, so both must be non-empty. | |||
| if (inputs.empty() || hccl_data_type_list_.empty()) { | |||
| MS_LOG(ERROR) << "Invalid HcomSend input size or data type size (" << inputs.size() << ", " | |||
| << hccl_data_type_list_.size() << ")."; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(stream_ptr); | |||
| auto hccl_result = hccl::HcclAdapter::GetInstance().HcclSend(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], | |||
| dest_rank_, stream_ptr, group_); | |||
| if (hccl_result != HCCL_SUCCESS) { | |||
| MS_LOG(ERROR) << "HcomSend failed, ret:" << hccl_result; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -218,6 +218,34 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) { | |||
| return true; | |||
| } | |||
| // Reads the "src_rank" attribute of the node's primitive into *src_rank. | |||
| // Returns true on success, or false (after logging) when the primitive has no such attribute. | |||
| bool HcomUtil::GetHcomSrcRank(const AnfNodePtr &anf_node, uint32_t *src_rank) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(src_rank); | |||
| auto prim = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto src_rank_attr = prim->GetAttr("src_rank"); | |||
| if (src_rank_attr == nullptr) { | |||
| MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRC_RANK fail, not support!"; | |||
| return false; | |||
| } | |||
| // The attribute is read as int64_t and narrowed to the uint32_t rank handed back to the caller. | |||
| *src_rank = static_cast<uint32_t>(GetValue<int64_t>(src_rank_attr)); | |||
| return true; | |||
| } | |||
| // Reads the "dest_rank" attribute of the node's primitive into *dest_rank. | |||
| // Returns true on success, or false (after logging) when the primitive has no such attribute. | |||
| bool HcomUtil::GetHcomDestRank(const AnfNodePtr &anf_node, uint32_t *dest_rank) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(dest_rank); | |||
| auto prim = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto dest_rank_attr = prim->GetAttr("dest_rank"); | |||
| if (dest_rank_attr == nullptr) { | |||
| MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_DEST_RANK fail, not support!"; | |||
| return false; | |||
| } | |||
| // The attribute is read as int64_t and narrowed to the uint32_t rank handed back to the caller. | |||
| *dest_rank = static_cast<uint32_t>(GetValue<int64_t>(dest_rank_attr)); | |||
| return true; | |||
| } | |||
| bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(receive_type); | |||
| @@ -66,6 +66,8 @@ class HcomUtil { | |||
| const vector<vector<size_t>> &shape_list, uint64_t *total_count); | |||
| static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type); | |||
| static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id); | |||
| static bool GetHcomSrcRank(const AnfNodePtr &anf_node, uint32_t *src_rank); | |||
| static bool GetHcomDestRank(const AnfNodePtr &anf_node, uint32_t *dest_rank); | |||
| static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group); | |||
| static bool GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type); | |||
| }; | |||
| @@ -53,6 +53,10 @@ void AscendEvent::WaitEvent() { | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "rtStreamWaitEvent failed, ret:" << ret; | |||
| } | |||
| ret = rtEventReset(event_, wait_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "rtEventReset failed, ret:" << ret; | |||
| } | |||
| need_wait_ = false; | |||
| } | |||
| @@ -22,6 +22,7 @@ | |||
| #include "utils/signal_util.h" | |||
| #include "debug/data_dump/e2e_dump.h" | |||
| #include "runtime/device/ascend/ascend_device_address.h" | |||
| #include "runtime/device/ascend/distribute/ascend_collective.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/context/context_extends.h" | |||
| #include "utils/mpi/mpi_config.h" | |||
| @@ -64,6 +65,7 @@ using mindspore::device::ascend::ProfilingManager; | |||
| using mindspore::device::ascend::ProfilingUtils; | |||
| using mindspore::device::ascend::tasksink::TaskGenerator; | |||
| using mindspore::ge::model_runner::ModelRunner; | |||
| using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup; | |||
| using mindspore::kernel::tbe::TbeUtils; | |||
| using std::vector; | |||
| @@ -78,32 +80,17 @@ namespace mindspore::device::ascend { | |||
| static thread_local rtContext_t thread_local_rt_context{nullptr}; | |||
| namespace { | |||
| // Returns the HCCL rank id of this process as a string. | |||
| // When task sink is disabled the rank is taken from HcclCollectiveGroup; otherwise it is read from the | |||
| // RANK_ID environment variable, and an exception is raised if that variable is unset or empty. | |||
| std::string GetRankId() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| MS_LOG(INFO) << "Get hccl rankid from mpi"; | |||
| auto rank = HcclCollectiveGroup::instance().GetRankId(); | |||
| return std::to_string(rank); | |||
| } | |||
| // std::getenv returns nullptr when the variable is unset, and constructing a std::string from nullptr is | |||
| // undefined behavior, so check the pointer before converting it. | |||
| const char *rank_id_env = std::getenv("RANK_ID"); | |||
| std::string rank_id_str = (rank_id_env == nullptr) ? "" : rank_id_env; | |||
| if (rank_id_str.empty()) { | |||
| MS_LOG(EXCEPTION) << "Get hccl rankid failed, please set env RANK_ID"; | |||
| } | |||
| return rank_id_str; | |||
| } | |||
| @@ -744,6 +731,7 @@ bool AscendKernelRuntime::SyncStream() { | |||
| MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; | |||
| return false; | |||
| } | |||
| if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) { // o for switch stream | |||
| MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; | |||
| return false; | |||
| @@ -832,7 +820,6 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) { | |||
| } | |||
| stream_ = nullptr; | |||
| } | |||
| if (communication_stream_ != nullptr) { | |||
| ret = rtStreamDestroy(communication_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| @@ -840,7 +827,6 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) { | |||
| } | |||
| communication_stream_ = nullptr; | |||
| } | |||
| ret = rtDeviceReset(device_id); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]"; | |||
| @@ -857,6 +843,19 @@ bool AscendKernelRuntime::HcclInit() { | |||
| MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; | |||
| } | |||
| MS_LOG(INFO) << "Do hcom init"; | |||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE); | |||
| if (!is_task_sink && mode == kGraphMode) { | |||
| hccl::HcclAdapter::GetInstance().InitHccl(); | |||
| std::vector<unsigned int> ranks; | |||
| auto rank_size = HcclCollectiveGroup::instance().GetRankSize(); | |||
| for (size_t i = 0; i < IntToSize(rank_size); ++i) { | |||
| ranks.push_back(i); | |||
| } | |||
| HcclCollectiveGroup::instance().CreateCommGroup(kHcclWorldGroup, ranks); | |||
| return true; | |||
| } | |||
| auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); | |||
| if (config_path_str == nullptr) { | |||
| config_path_str = std::getenv("RANK_TABLE_FILE"); | |||
| @@ -482,9 +482,12 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph *graph) { | |||
| for (size_t j = i + 1; j < kernels.size(); ++j) { | |||
| auto &child = kernels[j]; | |||
| MS_EXCEPTION_IF_NULL(child); | |||
| if (AnfAlgo::IsCommunicationOp(child)) { | |||
| continue; | |||
| } | |||
| auto input_size = child->inputs().size() - 1; | |||
| for (size_t k = 0; k < input_size; ++k) { | |||
| auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0); | |||
| auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true); | |||
| if (kernel_index.first == kernel) { | |||
| found_nearest_child = true; | |||
| break; | |||
| @@ -617,7 +620,6 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP | |||
| if (addr_size.empty()) { | |||
| return; | |||
| } | |||
| if (type == kSomasReuseDynamicMem) { | |||
| bool not_reuse = KernelMemNotReuse(node); | |||
| if (not_reuse) { | |||
| @@ -26,7 +26,10 @@ | |||
| #include "hccl/hcom.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/hccl_adapter/converter.h" | |||
| #include "runtime/device/ascend/distribute/ascend_collective.h" | |||
| using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup; | |||
| static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so"; | |||
| static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE"; | |||
| @@ -75,7 +78,6 @@ void HcclAdapter::InitPlugin() { | |||
| if (plugin_handle_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg(); | |||
| } | |||
| init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter, plugin_handle_); | |||
| finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter, plugin_handle_); | |||
| get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore, plugin_handle_); | |||
| @@ -98,7 +100,6 @@ void HcclAdapter::FinalizePlugin() { | |||
| if (plugin_handle_ == nullptr) { | |||
| return; | |||
| } | |||
| init_hcom_graph_adapter_ = nullptr; | |||
| finalize_hcom_graph_adapter_ = nullptr; | |||
| get_hccl_kernel_info_store_ = nullptr; | |||
| @@ -107,6 +108,10 @@ void HcclAdapter::FinalizePlugin() { | |||
| finalize_hccl_comm_ = nullptr; | |||
| launch_hccl_broadcast_ = nullptr; | |||
| launch_hccl_all_reduce_ = nullptr; | |||
| launch_hccl_reduce_scatter_ = nullptr; | |||
| launch_hccl_all_gather_ = nullptr; | |||
| launch_hccl_send_ = nullptr; | |||
| launch_hccl_recv_ = nullptr; | |||
| hccl_create_group_ = nullptr; | |||
| hccl_destroy_group_ = nullptr; | |||
| hccl_get_rank_id_ = nullptr; | |||
| @@ -119,6 +124,19 @@ void HcclAdapter::FinalizePlugin() { | |||
| plugin_handle_ = nullptr; | |||
| } | |||
| // Argument-less initialisation: runs InitPlugin() at most once, guarded by init_mutex_ and init_flag_. | |||
| // Always returns true; a call made after a successful init only logs that it was skipped. | |||
| bool HcclAdapter::InitHccl() { | |||
| MS_LOG(INFO) << "Start init hccl adapter."; | |||
| std::lock_guard<std::mutex> guard(init_mutex_); | |||
| if (!init_flag_) { | |||
| InitPlugin(); | |||
| init_flag_ = true; | |||
| MS_LOG(INFO) << "Init hccl adapter success."; | |||
| } else { | |||
| MS_LOG(INFO) << "Hccl has been inited, skip."; | |||
| } | |||
| return true; | |||
| } | |||
| bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) { | |||
| MS_LOG(INFO) << "Start init hccl adapter."; | |||
| std::lock_guard<std::mutex> lock(init_mutex_); | |||
| @@ -136,12 +154,10 @@ bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::st | |||
| if (!ret) { | |||
| return false; | |||
| } | |||
| ret = InitHcclExec(); | |||
| if (!ret) { | |||
| return false; | |||
| } | |||
| init_flag_ = true; | |||
| MS_LOG(INFO) << "Init hccl adapter success."; | |||
| return true; | |||
| @@ -238,10 +254,69 @@ HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType da | |||
| return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream); | |||
| } | |||
| // Forwards an AllReduce request to launch_hccl_all_reduce_ and returns its result. | |||
| // Raises via MS_EXCEPTION_IF_NULL when the launch function is not set. | |||
| HcclResult HcclAdapter::HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, | |||
| HcclReduceOp op, aclrtStream stream) const { | |||
| HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, | |||
| HcclReduceOp op, aclrtStream stream, const std::string &group) const { | |||
| MS_EXCEPTION_IF_NULL(launch_hccl_all_reduce_); | |||
| return launch_hccl_all_reduce_(sendBuf, recvBuf, count, dataType, op, hccl_comm_, stream); | |||
| // Use the adapter's own communicator when it is set; otherwise look up the communicator registered | |||
| // for `group` in HcclCollectiveGroup and raise if none is found. | |||
| HcclComm hccl_comm; | |||
| if (hccl_comm_ != nullptr) { | |||
| hccl_comm = hccl_comm_; | |||
| } else { | |||
| hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group); | |||
| MS_EXCEPTION_IF_NULL(hccl_comm); | |||
| } | |||
| return launch_hccl_all_reduce_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream); | |||
| } | |||
| // Forwards a ReduceScatter request to launch_hccl_reduce_scatter_ and returns its result. | |||
| // The adapter's own communicator (hccl_comm_) is used when set; otherwise the communicator registered for | |||
| // `group` in HcclCollectiveGroup is used. Raises when the launch function or the communicator is null. | |||
| HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, | |||
| HcclReduceOp op, aclrtStream stream, const std::string &group) const { | |||
| MS_EXCEPTION_IF_NULL(launch_hccl_reduce_scatter_); | |||
| HcclComm comm = hccl_comm_; | |||
| if (comm == nullptr) { | |||
| comm = HcclCollectiveGroup::instance().GetGroupComm(group); | |||
| MS_EXCEPTION_IF_NULL(comm); | |||
| } | |||
| return launch_hccl_reduce_scatter_(send_buf, recv_buf, count, dataType, op, comm, stream); | |||
| } | |||
| // Forwards an AllGather request to launch_hccl_all_gather_ and returns its result. | |||
| // The adapter's own communicator (hccl_comm_) is used when set; otherwise the communicator registered for | |||
| // `group` in HcclCollectiveGroup is used. Raises when the launch function or the communicator is null. | |||
| HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, | |||
| aclrtStream stream, const std::string &group) const { | |||
| MS_EXCEPTION_IF_NULL(launch_hccl_all_gather_); | |||
| HcclComm comm = hccl_comm_; | |||
| if (comm == nullptr) { | |||
| comm = HcclCollectiveGroup::instance().GetGroupComm(group); | |||
| MS_EXCEPTION_IF_NULL(comm); | |||
| } | |||
| return launch_hccl_all_gather_(send_buf, recv_buf, count, dataType, comm, stream); | |||
| } | |||
| // Forwards a Send request (to rank destRank) to launch_hccl_send_ and returns its result. | |||
| // The adapter's own communicator (hccl_comm_) is used when set; otherwise the communicator registered for | |||
| // `group` in HcclCollectiveGroup is used. Raises when the launch function or the communicator is null. | |||
| HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, | |||
| aclrtStream stream, const std::string &group) const { | |||
| MS_EXCEPTION_IF_NULL(launch_hccl_send_); | |||
| HcclComm comm = hccl_comm_; | |||
| if (comm == nullptr) { | |||
| comm = HcclCollectiveGroup::instance().GetGroupComm(group); | |||
| MS_EXCEPTION_IF_NULL(comm); | |||
| } | |||
| return launch_hccl_send_(send_buf, count, dataType, destRank, comm, stream); | |||
| } | |||
| // Forwards a Recv request (from rank srcRank) to launch_hccl_recv_ and returns its result. | |||
| // The adapter's own communicator (hccl_comm_) is used when set; otherwise the communicator registered for | |||
| // `group` in HcclCollectiveGroup is used. Raises when the launch function or the communicator is null. | |||
| HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, | |||
| aclrtStream stream, const std::string &group) const { | |||
| MS_EXCEPTION_IF_NULL(launch_hccl_recv_); | |||
| HcclComm comm = hccl_comm_; | |||
| if (comm == nullptr) { | |||
| comm = HcclCollectiveGroup::instance().GetGroupComm(group); | |||
| MS_EXCEPTION_IF_NULL(comm); | |||
| } | |||
| return launch_hccl_recv_(recv_buf, count, dataType, srcRank, comm, stream); | |||
| } | |||
| bool HcclAdapter::InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) { | |||
| @@ -338,6 +413,12 @@ bool HcclAdapter::InitHcclComm(std::string_view rank_id, std::string_view rank_f | |||
| bool HcclAdapter::FinalizeHcclComm() { | |||
| MS_LOG(INFO) << "Start finalize hccl comm."; | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| if (!task_sink) { | |||
| HcclCollectiveGroup::instance().DestroyCommGroup(); | |||
| } | |||
| if (hccl_comm_ == nullptr) { | |||
| return true; | |||
| } | |||
| @@ -43,6 +43,7 @@ class HcclAdapter { | |||
| // common | |||
| bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); | |||
| bool InitHccl(); | |||
| bool FinalizeHccl(); | |||
| HcclResult HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const; | |||
| @@ -58,8 +59,16 @@ class HcclAdapter { | |||
| // for single op | |||
| HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, aclrtStream stream) const; | |||
| HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, | |||
| aclrtStream stream) const; | |||
| HcclResult HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, | |||
| aclrtStream stream, const std::string &group = "") const; | |||
| HcclResult HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, aclrtStream stream, | |||
| const std::string &group = "") const; | |||
| HcclResult HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, | |||
| aclrtStream stream, const std::string &group = "") const; | |||
| HcclResult HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, aclrtStream stream, | |||
| const std::string &group = "") const; | |||
| HcclResult HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, aclrtStream stream, | |||
| const std::string &group = "") const; | |||
| // for enqueue op | |||
| HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const; | |||
| @@ -91,6 +100,10 @@ class HcclAdapter { | |||
| HcclCommDestroyFunObj finalize_hccl_comm_ = nullptr; | |||
| HcclBroadcastFunObj launch_hccl_broadcast_ = nullptr; | |||
| HcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr; | |||
| HcclReduceScatterFunObj launch_hccl_reduce_scatter_ = nullptr; | |||
| HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr; | |||
| HcclSendFunObj launch_hccl_send_ = nullptr; | |||
| HcclRecvFunObj launch_hccl_recv_ = nullptr; | |||
| HcomCreateGroupFunObj hccl_create_group_ = nullptr; | |||
| HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr; | |||
| @@ -47,6 +47,12 @@ PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *); | |||
| ORIGIN_METHOD(HcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); | |||
| ORIGIN_METHOD(HcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream); | |||
| ORIGIN_METHOD(HcclReduceScatter, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, | |||
| aclrtStream); | |||
| ORIGIN_METHOD(HcclAllGather, HcclResult, void *, void *, uint64_t, HcclDataType, HcclComm, aclrtStream); | |||
| ORIGIN_METHOD(HcclSend, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); | |||
| ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); | |||
| ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *); | |||
| ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm); | |||
| ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *); | |||
| @@ -22,6 +22,7 @@ HcclAdapter &HcclAdapter::GetInstance() { | |||
| static HcclAdapter instance; | |||
| return instance; | |||
| } | |||
| // No-op variant: performs no initialisation and always returns true. | |||
| bool HcclAdapter::InitHccl() { return true; } | |||
| // No-op variant: ignores the device id, rank id and rank file and always returns true. | |||
| bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view) { return true; } | |||
| // No-op variant: performs no teardown and always returns true. | |||
| bool HcclAdapter::FinalizeHccl() { return true; } | |||
| // No-op variant: ignores its arguments, creates nothing and returns HCCL_SUCCESS. | |||
| HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t *) const { return HCCL_SUCCESS; } | |||
| @@ -35,7 +36,21 @@ std::string HcclAdapter::GetHcclType(const AnfNodePtr &) { return ""; } | |||
| // No-op variant: ignores its arguments and returns HCCL_SUCCESS without touching the buffer. | |||
| HcclResult HcclAdapter::HcclBroadcast(void *, uint64_t, HcclDataType, uint32_t, aclrtStream) const { | |||
| return HCCL_SUCCESS; | |||
| } | |||
| HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const { | |||
| // No-op variant: ignores its arguments (including the group name) and returns HCCL_SUCCESS. | |||
| HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream, | |||
| const std::string &) const { | |||
| return HCCL_SUCCESS; | |||
| } | |||
| // No-op variant: ignores its arguments (including the group name) and returns HCCL_SUCCESS. | |||
| HcclResult HcclAdapter::HcclAllGather(void *, void *, uint64_t, HcclDataType, aclrtStream, const std::string &) const { | |||
| return HCCL_SUCCESS; | |||
| } | |||
| // No-op variant: ignores its arguments (including the group name) and returns HCCL_SUCCESS. | |||
| HcclResult HcclAdapter::HcclReduceScatter(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream, | |||
| const std::string &) const { | |||
| return HCCL_SUCCESS; | |||
| } | |||
| // No-op variant: ignores its arguments (including the group name) and returns HCCL_SUCCESS. | |||
| HcclResult HcclAdapter::HcclSend(void *, uint64_t, HcclDataType, uint32_t, aclrtStream, const std::string &) const { | |||
| return HCCL_SUCCESS; | |||
| } | |||
| // No-op variant: ignores its arguments (including the group name) and returns HCCL_SUCCESS. | |||
| HcclResult HcclAdapter::HcclRecv(void *, uint64_t, HcclDataType, uint32_t, aclrtStream, const std::string &) const { | |||
| return HCCL_SUCCESS; | |||
| } | |||
| HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const { | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/distribute/ascend_collective.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| namespace collective { | |||
| // Returns the single shared HcclCollectiveGroup object (a function-local static). | |||
| HcclCollectiveGroup &HcclCollectiveGroup::instance() { | |||
| static HcclCollectiveGroup collective_group; | |||
| return collective_group; | |||
| } | |||
| // Placeholder implementations: each query ignores its argument and reports 0, and group creation does nothing. | |||
| int HcclCollectiveGroup::GetRankSize(const std::string &) const { | |||
| return 0; | |||
| } | |||
| int HcclCollectiveGroup::GetRankId(const std::string &) const { | |||
| return 0; | |||
| } | |||
| int HcclCollectiveGroup::GetDeviceId() const { | |||
| return 0; | |||
| } | |||
| void HcclCollectiveGroup::CreateCommGroup(const std::string &, const std::vector<unsigned int> &) {} | |||
| } // namespace collective | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||