|
|
|
@@ -64,7 +64,7 @@ HcclKernelFactory &HcclKernelFactory::Get() { |
|
|
|
return _this; |
|
|
|
} |
|
|
|
|
|
|
|
HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0), anf_node_(nullptr) {} |
|
|
|
HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0) {} |
|
|
|
|
|
|
|
HcclKernel::~HcclKernel() { |
|
|
|
hccl_kernel_input_shape_list_.clear(); |
|
|
|
@@ -76,7 +76,6 @@ HcclKernel::~HcclKernel() { |
|
|
|
input_size_list_.clear(); |
|
|
|
output_size_list_.clear(); |
|
|
|
workspace_size_list_.clear(); |
|
|
|
anf_node_ = nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
bool HcclKernel::Init(const AnfNodePtr &anf_node) { |
|
|
|
@@ -150,11 +149,15 @@ const std::vector<size_t> &HcclKernel::GetInputSizeList() const { |
|
|
|
} |
|
|
|
|
|
|
|
const std::vector<size_t> &HcclKernel::GetOutputSizeList() const { |
|
|
|
auto anf_node = anf_node_.lock(); |
|
|
|
if (!anf_node) { |
|
|
|
MS_LOG(EXCEPTION) << "anf_node pointer is expired."; |
|
|
|
} |
|
|
|
size_t size = 0; |
|
|
|
if (!output_size_list_.empty()) { |
|
|
|
return output_size_list_; |
|
|
|
} |
|
|
|
auto cnode = anf_node_->cast<CNodePtr>(); |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
auto op_name = AnfAlgo::GetCNodeName(cnode); |
|
|
|
int64_t rank_size = 1; |
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { |
|
|
|
@@ -165,11 +168,11 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const { |
|
|
|
fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion); |
|
|
|
} |
|
|
|
ulong loop_size = hccl_data_type_list_.size(); |
|
|
|
if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 1) { |
|
|
|
if (AnfAlgo::GetInputTensorNum(anf_node) > 1 && op_name == kAllGatherOpName && fusion >= 1) { |
|
|
|
loop_size *= rank_size; |
|
|
|
} |
|
|
|
if (op_name == kReduceScatterOpName && fusion >= 1) { |
|
|
|
loop_size = AnfAlgo::GetOutputTensorNum(anf_node_); |
|
|
|
loop_size = AnfAlgo::GetOutputTensorNum(anf_node); |
|
|
|
} |
|
|
|
for (ulong i = 0; i < loop_size; ++i) { |
|
|
|
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) { |
|
|
|
@@ -185,14 +188,18 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { |
|
|
|
return workspace_size_list_; |
|
|
|
} |
|
|
|
|
|
|
|
workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0])); |
|
|
|
workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0])); |
|
|
|
return workspace_size_list_; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<AddressPtr> &workspace, |
|
|
|
const std::vector<AddressPtr> &outputs, uint32_t stream_id) { |
|
|
|
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); |
|
|
|
auto anf_node = anf_node_.lock(); |
|
|
|
if (!anf_node) { |
|
|
|
MS_LOG(EXCEPTION) << "anf_node pointer is expired."; |
|
|
|
} |
|
|
|
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node); |
|
|
|
if (hccl_type == kReceive) { |
|
|
|
if (outputs.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Outputs is empty"; |
|
|
|
@@ -211,9 +218,9 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu |
|
|
|
std::vector<uint8_t> private_def; |
|
|
|
HcclDataType data_type = hccl_data_type_list_[0]; |
|
|
|
std::vector<hccl::HcclTaskInfo> task_info; |
|
|
|
bool ret = hccl::GenTask(anf_node_, data_type, &task_info); |
|
|
|
bool ret = hccl::GenTask(anf_node, data_type, &task_info); |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(EXCEPTION) << "Gen Task for " << anf_node_->DebugString() << " failed."; |
|
|
|
MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed."; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<TaskInfoPtr> results; |
|
|
|
@@ -232,14 +239,14 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu |
|
|
|
void *workspace_addr = nullptr; |
|
|
|
if (task.workspace_size != 0) { |
|
|
|
if (workspace.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty"; |
|
|
|
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node->DebugString() << " is empty"; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(workspace.at(0)); |
|
|
|
workspace_addr = workspace.at(0)->addr; |
|
|
|
} |
|
|
|
|
|
|
|
results.emplace_back(std::make_shared<HcclTaskInfo>( |
|
|
|
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr, |
|
|
|
kernel_name_, stream_id, hccl::GetHcclType(anf_node), input_data_addr, output_data_addr, workspace_addr, |
|
|
|
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, |
|
|
|
op_type_, data_type, group_, NeedDump())); |
|
|
|
} |
|
|
|
@@ -253,7 +260,7 @@ device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, |
|
|
|
AddressPtrList outputs; |
|
|
|
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs); |
|
|
|
|
|
|
|
std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_)); |
|
|
|
std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_.lock())); |
|
|
|
|
|
|
|
if (inputs.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Hccl kernel input is empty"; |
|
|
|
|