Browse Source

!21208 Ascend support non-task sink mode

Merge pull request !21208 from baihuawei/graph_mode_nonsink_part2
tags/v1.4.0
i-robot Gitee 4 years ago
parent
commit
191c82573f
20 changed files with 318 additions and 62 deletions
  1. +19
    -3
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc
  2. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h
  3. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc
  4. +18
    -4
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc
  5. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h
  6. +4
    -4
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
  7. +18
    -4
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc
  8. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h
  9. +17
    -3
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_receive.cc
  10. +17
    -4
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.cc
  11. +28
    -0
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc
  12. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h
  13. +4
    -0
      mindspore/ccsrc/runtime/device/ascend/ascend_event.cc
  14. +25
    -26
      mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
  15. +4
    -2
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  16. +88
    -7
      mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc
  17. +15
    -2
      mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h
  18. +6
    -0
      mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.h
  19. +16
    -1
      tests/ut/cpp/stub/ge/ge_task_launch_stub.cc
  20. +34
    -0
      tests/ut/cpp/stub/hccl/collective_stub.cc

+ 19
- 3
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc View File

@@ -64,7 +64,8 @@ HcclKernelFactory &HcclKernelFactory::Get() {
return _this;
}

HcclKernel::HcclKernel() : hccl_count_(0), op_type_(::HcclReduceOp::HCCL_REDUCE_SUM), root_id_(0) {}
HcclKernel::HcclKernel()
: hccl_count_(0), op_type_(::HcclReduceOp::HCCL_REDUCE_SUM), root_id_(0), src_rank_(0), dest_rank_(0) {}

HcclKernel::~HcclKernel() {
hccl_kernel_input_shape_list_.clear();
@@ -81,6 +82,18 @@ HcclKernel::~HcclKernel() {
bool HcclKernel::Init(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
op_name_ = AnfAlgo::GetCNodeName(anf_node);
if (op_name_ == kHcomSend) {
if (!HcomUtil::GetHcomDestRank(anf_node, &dest_rank_)) {
MS_LOG(ERROR) << "GetHcomDestRank fail!";
return false;
}
}
if (op_name_ == kReceive) {
if (!HcomUtil::GetHcomSrcRank(anf_node, &src_rank_)) {
MS_LOG(ERROR) << "GetHcomSrcRank fail!";
return false;
}
}
if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) {
MS_LOG(ERROR) << "GetKernelInputShape fail!";
return false;
@@ -180,10 +193,13 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
}

const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode)) {
return workspace_size_list_;
}

workspace_size_list_.emplace_back(
hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0]));
return workspace_size_list_;


+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h View File

@@ -51,6 +51,8 @@ class HcclKernel : public AscendKernelMod {
uint64_t hccl_count_;
HcclReduceOp op_type_;
uint32_t root_id_;
uint32_t src_rank_;
uint32_t dest_rank_;
mutable std::vector<size_t> input_size_list_;
mutable std::vector<size_t> output_size_list_;
mutable std::vector<size_t> workspace_size_list_;


+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc View File

@@ -23,6 +23,7 @@ namespace mindspore {
namespace kernel {
bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *stream_ptr) {
MS_LOG(DEBUG) << "HcomAllBroadCast launch";
if (inputs.empty() || hccl_data_type_list_.empty()) {
MS_LOG(ERROR) << "BroadCast param is empty";
return false;


+ 18
- 4
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc View File

@@ -16,13 +16,27 @@

#include "backend/kernel_compiler/hccl/hcom_all_gather.h"
#include <memory>
#include "utils/ms_context.h"
#include "runtime/hccl_adapter/hccl_adapter.h"

namespace mindspore {
namespace kernel {
bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
MS_LOG(INFO) << "HcomAllGather launch";
bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
MS_LOG(DEBUG) << "HcomAllGather launch";
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
MS_LOG(ERROR) << "Invalid AllGather input, output or data type size(" << inputs.size() << ", " << outputs.size()
<< ", " << hccl_data_type_list_.size() << ").";
return false;
}
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllGather(inputs[0]->addr, outputs[0]->addr, hccl_count_,
hccl_data_type_list_[0], stream_ptr, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllGather faled, ret:" << hccl_result;
return false;
}
return true;
}
} // namespace kernel


+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h View File

@@ -19,7 +19,6 @@

#include <vector>
#include <memory>
#include "hccl/hcom.h"
#include "backend/kernel_compiler/hccl/hccl_kernel.h"

namespace mindspore {


+ 4
- 4
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc View File

@@ -22,17 +22,17 @@ namespace mindspore {
namespace kernel {
bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
MS_LOG(INFO) << "HcclAllReduce launch";
MS_LOG(DEBUG) << "HcclAllReduce launch";
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
MS_LOG(ERROR) << "Invalid AllReduce input, output or data type size(" << inputs.size() << ", " << outputs.size()
MS_LOG(ERROR) << "Invalid AllReduce input, output or data type size (" << inputs.size() << ", " << outputs.size()
<< ", " << hccl_data_type_list_.size() << ").";
return false;
}
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_,
hccl_data_type_list_[0], op_type_, stream_ptr);
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(
inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result;
return false;


+ 18
- 4
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc View File

@@ -16,13 +16,27 @@

#include "backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h"
#include <memory>
#include "utils/ms_context.h"
#include "runtime/hccl_adapter/hccl_adapter.h"

namespace mindspore {
namespace kernel {
bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
MS_LOG(INFO) << "HcomAllReduceScatter launch";
bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
MS_LOG(DEBUG) << "HcomAllReduceScatter launch";
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
MS_LOG(ERROR) << "Invalid AllReduceScatter input, output or data type size(" << inputs.size() << ", "
<< outputs.size() << ", " << hccl_data_type_list_.size() << ").";
return false;
}
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclReduceScatter(
inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclReduceScatter faled, ret:" << hccl_result;
return false;
}
return true;
}
} // namespace kernel


+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h View File

@@ -19,7 +19,6 @@

#include <vector>
#include <memory>
#include "hccl/hcom.h"
#include "backend/kernel_compiler/hccl/hccl_kernel.h"

namespace mindspore {


+ 17
- 3
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_receive.cc View File

@@ -16,12 +16,26 @@
#include "backend/kernel_compiler/hccl/hcom_receive.h"
#include <memory>
#include "utils/ms_context.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
namespace mindspore {
namespace kernel {
bool HcomReceiveKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
MS_LOG(INFO) << "HcomReceive launch";
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
MS_LOG(DEBUG) << "HcomReceive launch";
if (outputs.empty() || hccl_data_type_list_.empty()) {
MS_LOG(ERROR) << "Invalid HcomReceive outputs size or data type size (" << outputs.size() << ", "
<< hccl_data_type_list_.size() << ").";
return false;
}
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclRecv(outputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
src_rank_, stream_ptr, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomReceive failed, ret:" << hccl_result;
return false;
}
return true;
}
} // namespace kernel


+ 17
- 4
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.cc View File

@@ -16,13 +16,26 @@
#include "backend/kernel_compiler/hccl/hcom_send.h"
#include <memory>
#include "utils/ms_context.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
namespace mindspore {
namespace kernel {
bool HcomSendKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
MS_LOG(INFO) << "HcomSend launch";
bool HcomSendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *stream_ptr) {
MS_LOG(DEBUG) << "HcomSend launch";
if (inputs.empty() || hccl_data_type_list_.empty()) {
MS_LOG(ERROR) << "Invalid HcomSend input size or data type size (" << inputs.size() << ", "
<< hccl_data_type_list_.size() << ").";
return false;
}
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclSend(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
dest_rank_, stream_ptr, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomSend faled, ret:" << hccl_result;
return false;
}
return true;
}
} // namespace kernel


+ 28
- 0
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc View File

@@ -218,6 +218,34 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
return true;
}

bool HcomUtil::GetHcomSrcRank(const AnfNodePtr &anf_node, uint32_t *src_rank) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(src_rank);
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("src_rank") != nullptr) {
*src_rank = static_cast<uint32_t>(GetValue<int64_t>(primitive->GetAttr("src_rank")));
} else {
MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRC_RANK fail, not support!";
return false;
}
return true;
}

bool HcomUtil::GetHcomDestRank(const AnfNodePtr &anf_node, uint32_t *dest_rank) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(dest_rank);
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("dest_rank") != nullptr) {
*dest_rank = static_cast<uint32_t>(GetValue<int64_t>(primitive->GetAttr("dest_rank")));
} else {
MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_DEST_RANK fail, not support!";
return false;
}
return true;
}

bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(receive_type);


+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h View File

@@ -66,6 +66,8 @@ class HcomUtil {
const vector<vector<size_t>> &shape_list, uint64_t *total_count);
static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type);
static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
static bool GetHcomSrcRank(const AnfNodePtr &anf_node, uint32_t *src_rank);
static bool GetHcomDestRank(const AnfNodePtr &anf_node, uint32_t *dest_rank);
static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
static bool GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type);
};


+ 4
- 0
mindspore/ccsrc/runtime/device/ascend/ascend_event.cc View File

@@ -53,6 +53,10 @@ void AscendEvent::WaitEvent() {
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtStreamWaitEvent failed, ret:" << ret;
}
ret = rtEventReset(event_, wait_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtEventReset failed, ret:" << ret;
}
need_wait_ = false;
}



+ 25
- 26
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc View File

@@ -22,6 +22,7 @@
#include "utils/signal_util.h"
#include "debug/data_dump/e2e_dump.h"
#include "runtime/device/ascend/ascend_device_address.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
#include "utils/ms_context.h"
#include "utils/context/context_extends.h"
#include "utils/mpi/mpi_config.h"
@@ -64,6 +65,7 @@ using mindspore::device::ascend::ProfilingManager;
using mindspore::device::ascend::ProfilingUtils;
using mindspore::device::ascend::tasksink::TaskGenerator;
using mindspore::ge::model_runner::ModelRunner;
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
using mindspore::kernel::tbe::TbeUtils;
using std::vector;

@@ -78,32 +80,17 @@ namespace mindspore::device::ascend {
static thread_local rtContext_t thread_local_rt_context{nullptr};
namespace {
std::string GetRankId() {
std::string rank_id_str;
#ifdef ENABLE_MPI
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (mpi_config_ptr->enable_mpi()) {
int rank_id = GetMPIRankId();
const std::string offset = common::GetEnv("RANK_OFFSET");
if (offset.empty()) {
try {
int rank_offset = std::stoi(offset);
rank_id += rank_offset;
} catch (std::invalid_argument) {
MS_LOG(EXCEPTION) << "Call stoi invalid argument:" << offset;
} catch (std::out_of_range) {
MS_LOG(EXCEPTION) << "Call stoi out_of_range:" << offset;
}
}
rank_id_str = std::to_string(rank_id);
} else {
rank_id_str = common::GetEnv("RANK_ID");
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
MS_LOG(INFO) << "Get hccl rankid from mpi";
auto rank = HcclCollectiveGroup::instance().GetRankId();
return std::to_string(rank);
}
#else
rank_id_str = common::GetEnv("RANK_ID");
#endif
std::string rank_id_str;
rank_id_str = std::getenv("RANK_ID");
if (rank_id_str.empty()) {
MS_LOG(ERROR) << "Get hccl rankid failed, please set env RANK_ID";
MS_LOG(EXCEPTION) << "Get hccl rankid failed, please set env RANK_ID";
}
return rank_id_str;
}
@@ -744,6 +731,7 @@ bool AscendKernelRuntime::SyncStream() {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}

if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
@@ -832,7 +820,6 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
}
stream_ = nullptr;
}

if (communication_stream_ != nullptr) {
ret = rtStreamDestroy(communication_stream_);
if (ret != RT_ERROR_NONE) {
@@ -840,7 +827,6 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
}
communication_stream_ = nullptr;
}

ret = rtDeviceReset(device_id);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
@@ -857,6 +843,19 @@ bool AscendKernelRuntime::HcclInit() {
MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
}
MS_LOG(INFO) << "Do hcom init";
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
if (!is_task_sink && mode == kGraphMode) {
hccl::HcclAdapter::GetInstance().InitHccl();
std::vector<unsigned int> ranks;
auto rank_size = HcclCollectiveGroup::instance().GetRankSize();
for (size_t i = 0; i < IntToSize(rank_size); ++i) {
ranks.push_back(i);
}
HcclCollectiveGroup::instance().CreateCommGroup(kHcclWorldGroup, ranks);
return true;
}

auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
if (config_path_str == nullptr) {
config_path_str = std::getenv("RANK_TABLE_FILE");


+ 4
- 2
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -482,9 +482,12 @@ void KernelRuntime::GenKernelEvents(const session::KernelGraph *graph) {
for (size_t j = i + 1; j < kernels.size(); ++j) {
auto &child = kernels[j];
MS_EXCEPTION_IF_NULL(child);
if (AnfAlgo::IsCommunicationOp(child)) {
continue;
}
auto input_size = child->inputs().size() - 1;
for (size_t k = 0; k < input_size; ++k) {
auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0);
auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
if (kernel_index.first == kernel) {
found_nearest_child = true;
break;
@@ -617,7 +620,6 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
if (addr_size.empty()) {
return;
}

if (type == kSomasReuseDynamicMem) {
bool not_reuse = KernelMemNotReuse(node);
if (not_reuse) {


+ 88
- 7
mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc View File

@@ -26,7 +26,10 @@
#include "hccl/hcom.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "utils/ms_context.h"
#include "runtime/hccl_adapter/converter.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;

static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
@@ -75,7 +78,6 @@ void HcclAdapter::InitPlugin() {
if (plugin_handle_ == nullptr) {
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
}

init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter, plugin_handle_);
finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter, plugin_handle_);
get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore, plugin_handle_);
@@ -98,7 +100,6 @@ void HcclAdapter::FinalizePlugin() {
if (plugin_handle_ == nullptr) {
return;
}

init_hcom_graph_adapter_ = nullptr;
finalize_hcom_graph_adapter_ = nullptr;
get_hccl_kernel_info_store_ = nullptr;
@@ -107,6 +108,10 @@ void HcclAdapter::FinalizePlugin() {
finalize_hccl_comm_ = nullptr;
launch_hccl_broadcast_ = nullptr;
launch_hccl_all_reduce_ = nullptr;
launch_hccl_reduce_scatter_ = nullptr;
launch_hccl_all_gather_ = nullptr;
launch_hccl_send_ = nullptr;
launch_hccl_recv_ = nullptr;
hccl_create_group_ = nullptr;
hccl_destroy_group_ = nullptr;
hccl_get_rank_id_ = nullptr;
@@ -119,6 +124,19 @@ void HcclAdapter::FinalizePlugin() {
plugin_handle_ = nullptr;
}

// Argument-free initialization: loads the HCCL plugin via InitPlugin() exactly once.
//
// The check-and-set of init_flag_ is done while holding init_mutex_. Always returns true; a
// repeated call only logs that initialization was already done.
bool HcclAdapter::InitHccl() {
  MS_LOG(INFO) << "Start init hccl adapter.";
  std::lock_guard<std::mutex> guard(init_mutex_);
  if (!init_flag_) {
    InitPlugin();
    init_flag_ = true;
    MS_LOG(INFO) << "Init hccl adapter success.";
  } else {
    MS_LOG(INFO) << "Hccl has been inited, skip.";
  }
  return true;
}

bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
MS_LOG(INFO) << "Start init hccl adapter.";
std::lock_guard<std::mutex> lock(init_mutex_);
@@ -136,12 +154,10 @@ bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::st
if (!ret) {
return false;
}

ret = InitHcclExec();
if (!ret) {
return false;
}

init_flag_ = true;
MS_LOG(INFO) << "Init hccl adapter success.";
return true;
@@ -238,10 +254,69 @@ HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType da
return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream);
}

HcclResult HcclAdapter::HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream) const {
HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream, const std::string &group) const {
MS_EXCEPTION_IF_NULL(launch_hccl_all_reduce_);
return launch_hccl_all_reduce_(sendBuf, recvBuf, count, dataType, op, hccl_comm_, stream);
HcclComm hccl_comm;
if (hccl_comm_ != nullptr) {
hccl_comm = hccl_comm_;
} else {
hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
}
return launch_hccl_all_reduce_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream);
}

// Forwards a ReduceScatter request to the function stored in launch_hccl_reduce_scatter_.
//
// The communicator is hccl_comm_ when that member is set; otherwise it is looked up by `group`
// in HcclCollectiveGroup and must be non-null. Returns the HcclResult of the launched call.
HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                          HcclReduceOp op, aclrtStream stream, const std::string &group) const {
  MS_EXCEPTION_IF_NULL(launch_hccl_reduce_scatter_);
  // Start from the adapter-owned communicator and fall back to the per-group one only if unset.
  HcclComm hccl_comm = hccl_comm_;
  if (hccl_comm == nullptr) {
    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
    MS_EXCEPTION_IF_NULL(hccl_comm);
  }
  return launch_hccl_reduce_scatter_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream);
}

// Forwards an AllGather request to the function stored in launch_hccl_all_gather_.
//
// The communicator is hccl_comm_ when that member is set; otherwise it is looked up by `group`
// in HcclCollectiveGroup and must be non-null. Returns the HcclResult of the launched call.
HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                      aclrtStream stream, const std::string &group) const {
  MS_EXCEPTION_IF_NULL(launch_hccl_all_gather_);
  // Start from the adapter-owned communicator and fall back to the per-group one only if unset.
  HcclComm hccl_comm = hccl_comm_;
  if (hccl_comm == nullptr) {
    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
    MS_EXCEPTION_IF_NULL(hccl_comm);
  }
  return launch_hccl_all_gather_(send_buf, recv_buf, count, dataType, hccl_comm, stream);
}

// Forwards a point-to-point send to the function stored in launch_hccl_send_.
//
// The communicator is hccl_comm_ when that member is set; otherwise it is looked up by `group`
// in HcclCollectiveGroup and must be non-null. Returns the HcclResult of the launched call.
HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank,
                                 aclrtStream stream, const std::string &group) const {
  MS_EXCEPTION_IF_NULL(launch_hccl_send_);
  // Start from the adapter-owned communicator and fall back to the per-group one only if unset.
  HcclComm hccl_comm = hccl_comm_;
  if (hccl_comm == nullptr) {
    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
    MS_EXCEPTION_IF_NULL(hccl_comm);
  }
  return launch_hccl_send_(send_buf, count, dataType, destRank, hccl_comm, stream);
}

// Forwards a point-to-point receive to the function stored in launch_hccl_recv_.
//
// The communicator is hccl_comm_ when that member is set; otherwise it is looked up by `group`
// in HcclCollectiveGroup and must be non-null. Returns the HcclResult of the launched call.
HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank,
                                 aclrtStream stream, const std::string &group) const {
  MS_EXCEPTION_IF_NULL(launch_hccl_recv_);
  // Start from the adapter-owned communicator and fall back to the per-group one only if unset.
  HcclComm hccl_comm = hccl_comm_;
  if (hccl_comm == nullptr) {
    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
    MS_EXCEPTION_IF_NULL(hccl_comm);
  }
  return launch_hccl_recv_(recv_buf, count, dataType, srcRank, hccl_comm, stream);
}

bool HcclAdapter::InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
@@ -338,6 +413,12 @@ bool HcclAdapter::InitHcclComm(std::string_view rank_id, std::string_view rank_f

bool HcclAdapter::FinalizeHcclComm() {
MS_LOG(INFO) << "Start finalize hccl comm.";
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!task_sink) {
HcclCollectiveGroup::instance().DestroyCommGroup();
}
if (hccl_comm_ == nullptr) {
return true;
}


+ 15
- 2
mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h View File

@@ -43,6 +43,7 @@ class HcclAdapter {

// common
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
bool InitHccl();
bool FinalizeHccl();

HcclResult HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const;
@@ -58,8 +59,16 @@ class HcclAdapter {

// for single op
HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, aclrtStream stream) const;
HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
aclrtStream stream) const;
HcclResult HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
aclrtStream stream, const std::string &group = "") const;
HcclResult HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, aclrtStream stream,
const std::string &group = "") const;
HcclResult HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
aclrtStream stream, const std::string &group = "") const;
HcclResult HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, aclrtStream stream,
const std::string &group = "") const;
HcclResult HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, aclrtStream stream,
const std::string &group = "") const;

// for enqueue op
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const;
@@ -91,6 +100,10 @@ class HcclAdapter {
HcclCommDestroyFunObj finalize_hccl_comm_ = nullptr;
HcclBroadcastFunObj launch_hccl_broadcast_ = nullptr;
HcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr;
HcclReduceScatterFunObj launch_hccl_reduce_scatter_ = nullptr;
HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr;
HcclSendFunObj launch_hccl_send_ = nullptr;
HcclRecvFunObj launch_hccl_recv_ = nullptr;

HcomCreateGroupFunObj hccl_create_group_ = nullptr;
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;


+ 6
- 0
mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.h View File

@@ -47,6 +47,12 @@ PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *);

ORIGIN_METHOD(HcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclReduceScatter, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm,
aclrtStream);
ORIGIN_METHOD(HcclAllGather, HcclResult, void *, void *, uint64_t, HcclDataType, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclSend, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);

ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);


+ 16
- 1
tests/ut/cpp/stub/ge/ge_task_launch_stub.cc View File

@@ -22,6 +22,7 @@ HcclAdapter &HcclAdapter::GetInstance() {
static HcclAdapter instance;
return instance;
}
// Unit-test stub for the argument-free InitHccl overload: does nothing and reports success.
bool HcclAdapter::InitHccl() { return true; }
bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
bool HcclAdapter::FinalizeHccl() { return true; }
HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t *) const { return HCCL_SUCCESS; }
@@ -35,7 +36,21 @@ std::string HcclAdapter::GetHcclType(const AnfNodePtr &) { return ""; }
HcclResult HcclAdapter::HcclBroadcast(void *, uint64_t, HcclDataType, uint32_t, aclrtStream) const {
return HCCL_SUCCESS;
}
HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const {
// Unit-test stubs for the single-op HcclAdapter entry points. Each one ignores all of its
// arguments (the parameters are unnamed) and unconditionally returns HCCL_SUCCESS.

// Stub for HcclAllReduce.
HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream,
const std::string &) const {
return HCCL_SUCCESS;
}
// Stub for HcclAllGather.
HcclResult HcclAdapter::HcclAllGather(void *, void *, uint64_t, HcclDataType, aclrtStream, const std::string &) const {
return HCCL_SUCCESS;
}
// Stub for HcclReduceScatter.
HcclResult HcclAdapter::HcclReduceScatter(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream,
const std::string &) const {
return HCCL_SUCCESS;
}
// Stub for HcclSend.
HcclResult HcclAdapter::HcclSend(void *, uint64_t, HcclDataType, uint32_t, aclrtStream, const std::string &) const {
return HCCL_SUCCESS;
}
// Stub for HcclRecv.
HcclResult HcclAdapter::HcclRecv(void *, uint64_t, HcclDataType, uint32_t, aclrtStream, const std::string &) const {
return HCCL_SUCCESS;
}
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {


+ 34
- 0
tests/ut/cpp/stub/hccl/collective_stub.cc View File

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/distribute/ascend_collective.h"
// Unit-test stand-ins for HcclCollectiveGroup: the singleton accessor hands out a function-local
// static object, the rank/device queries all return 0, and CreateCommGroup does nothing.
namespace mindspore::device::ascend::collective {
HcclCollectiveGroup &HcclCollectiveGroup::instance() {
  static HcclCollectiveGroup stub_group;
  return stub_group;
}

int HcclCollectiveGroup::GetRankSize(const std::string &) const { return 0; }

int HcclCollectiveGroup::GetRankId(const std::string &) const { return 0; }

int HcclCollectiveGroup::GetDeviceId() const { return 0; }

void HcclCollectiveGroup::CreateCommGroup(const std::string &, const std::vector<unsigned int> &) {}
}  // namespace mindspore::device::ascend::collective

Loading…
Cancel
Save