From a511b2641440f37261249e3b099af8bde754e05c Mon Sep 17 00:00:00 2001 From: zhupuxu Date: Tue, 16 Mar 2021 09:23:54 +0800 Subject: [PATCH] Fix circular reference: hold anf_node_ as a weak pointer in HcclKernel Signed-off-by: zhupuxu --- .../kernel_compiler/hccl/hccl_kernel.cc | 31 ++++++++++++------- .../kernel_compiler/hccl/hccl_kernel.h | 2 +- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index cfc792843a..5b955f2013 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -64,7 +64,7 @@ HcclKernelFactory &HcclKernelFactory::Get() { return _this; } -HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0), anf_node_(nullptr) {} +HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0) {} HcclKernel::~HcclKernel() { hccl_kernel_input_shape_list_.clear(); @@ -76,7 +76,6 @@ HcclKernel::~HcclKernel() { input_size_list_.clear(); output_size_list_.clear(); workspace_size_list_.clear(); - anf_node_ = nullptr; } bool HcclKernel::Init(const AnfNodePtr &anf_node) { @@ -150,11 +149,15 @@ const std::vector &HcclKernel::GetInputSizeList() const { } const std::vector &HcclKernel::GetOutputSizeList() const { + auto anf_node = anf_node_.lock(); + if (!anf_node) { + MS_LOG(EXCEPTION) << "anf_node pointer is expired."; + } size_t size = 0; if (!output_size_list_.empty()) { return output_size_list_; } - auto cnode = anf_node_->cast(); + auto cnode = anf_node->cast(); auto op_name = AnfAlgo::GetCNodeName(cnode); int64_t rank_size = 1; if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { @@ -165,11 +168,11 @@ const std::vector &HcclKernel::GetOutputSizeList() const { fusion = AnfAlgo::GetNodeAttr(cnode, kAttrFusion); } ulong loop_size = hccl_data_type_list_.size(); - if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 
1) { + if (AnfAlgo::GetInputTensorNum(anf_node) > 1 && op_name == kAllGatherOpName && fusion >= 1) { loop_size *= rank_size; } if (op_name == kReduceScatterOpName && fusion >= 1) { - loop_size = AnfAlgo::GetOutputTensorNum(anf_node_); + loop_size = AnfAlgo::GetOutputTensorNum(anf_node); } for (ulong i = 0; i < loop_size; ++i) { if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) { @@ -185,14 +188,18 @@ const std::vector &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0])); + workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0])); return workspace_size_list_; } std::vector HcclKernel::GenTask(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uint32_t stream_id) { - std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); + auto anf_node = anf_node_.lock(); + if (!anf_node) { + MS_LOG(EXCEPTION) << "anf_node pointer is expired."; + } + std::string hccl_type = AnfAlgo::GetCNodeName(anf_node); if (hccl_type == kReceive) { if (outputs.empty()) { MS_LOG(EXCEPTION) << "Outputs is empty"; @@ -211,9 +218,9 @@ std::vector HcclKernel::GenTask(const std::vector &inpu std::vector private_def; HcclDataType data_type = hccl_data_type_list_[0]; std::vector task_info; - bool ret = hccl::GenTask(anf_node_, data_type, &task_info); + bool ret = hccl::GenTask(anf_node, data_type, &task_info); if (!ret) { - MS_LOG(EXCEPTION) << "Gen Task for " << anf_node_->DebugString() << " failed."; + MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed."; } std::vector results; @@ -232,14 +239,14 @@ std::vector HcclKernel::GenTask(const std::vector &inpu void *workspace_addr = nullptr; if (task.workspace_size != 0) { if (workspace.empty()) { - MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty"; + 
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node->DebugString() << " is empty"; } MS_EXCEPTION_IF_NULL(workspace.at(0)); workspace_addr = workspace.at(0)->addr; } results.emplace_back(std::make_shared( - kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr, + kernel_name_, stream_id, hccl::GetHcclType(anf_node), input_data_addr, output_data_addr, workspace_addr, task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type, group_, NeedDump())); } @@ -253,7 +260,7 @@ device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, AddressPtrList outputs; device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs); - std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_)); + std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_.lock())); if (inputs.empty()) { MS_LOG(EXCEPTION) << "Hccl kernel input is empty"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h index 4930ba6b6d..f519fc13fc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h @@ -55,7 +55,7 @@ class HcclKernel : public AscendKernelMod { mutable std::vector input_size_list_; mutable std::vector output_size_list_; mutable std::vector workspace_size_list_; - AnfNodePtr anf_node_; + AnfNodeWeakPtr anf_node_; std::string op_name_; std::string group_; };