Browse Source

Fix circular reference bug: HcclKernel now holds anf_node_ as a weak pointer

Signed-off-by: zhupuxu <zhupuxu@huawei.com>
tags/v1.2.0-rc1
zhupuxu 4 years ago
parent
commit
a511b26414
2 changed files with 20 additions and 13 deletions
  1. +19
    -12
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h

+ 19
- 12
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc View File

@@ -64,7 +64,7 @@ HcclKernelFactory &HcclKernelFactory::Get() {
return _this;
}

HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0), anf_node_(nullptr) {}
HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0) {}

HcclKernel::~HcclKernel() {
hccl_kernel_input_shape_list_.clear();
@@ -76,7 +76,6 @@ HcclKernel::~HcclKernel() {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
anf_node_ = nullptr;
}

bool HcclKernel::Init(const AnfNodePtr &anf_node) {
@@ -150,11 +149,15 @@ const std::vector<size_t> &HcclKernel::GetInputSizeList() const {
}

const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
auto anf_node = anf_node_.lock();
if (!anf_node) {
MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
}
size_t size = 0;
if (!output_size_list_.empty()) {
return output_size_list_;
}
auto cnode = anf_node_->cast<CNodePtr>();
auto cnode = anf_node->cast<CNodePtr>();
auto op_name = AnfAlgo::GetCNodeName(cnode);
int64_t rank_size = 1;
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
@@ -165,11 +168,11 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
}
ulong loop_size = hccl_data_type_list_.size();
if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
if (AnfAlgo::GetInputTensorNum(anf_node) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
loop_size *= rank_size;
}
if (op_name == kReduceScatterOpName && fusion >= 1) {
loop_size = AnfAlgo::GetOutputTensorNum(anf_node_);
loop_size = AnfAlgo::GetOutputTensorNum(anf_node);
}
for (ulong i = 0; i < loop_size; ++i) {
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
@@ -185,14 +188,18 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
return workspace_size_list_;
}

workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0]));
workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0]));
return workspace_size_list_;
}

std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
auto anf_node = anf_node_.lock();
if (!anf_node) {
MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
}
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node);
if (hccl_type == kReceive) {
if (outputs.empty()) {
MS_LOG(EXCEPTION) << "Outputs is empty";
@@ -211,9 +218,9 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
std::vector<uint8_t> private_def;
HcclDataType data_type = hccl_data_type_list_[0];
std::vector<hccl::HcclTaskInfo> task_info;
bool ret = hccl::GenTask(anf_node_, data_type, &task_info);
bool ret = hccl::GenTask(anf_node, data_type, &task_info);
if (!ret) {
MS_LOG(EXCEPTION) << "Gen Task for " << anf_node_->DebugString() << " failed.";
MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed.";
}

std::vector<TaskInfoPtr> results;
@@ -232,14 +239,14 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
void *workspace_addr = nullptr;
if (task.workspace_size != 0) {
if (workspace.empty()) {
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty";
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node->DebugString() << " is empty";
}
MS_EXCEPTION_IF_NULL(workspace.at(0));
workspace_addr = workspace.at(0)->addr;
}

results.emplace_back(std::make_shared<HcclTaskInfo>(
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr,
kernel_name_, stream_id, hccl::GetHcclType(anf_node), input_data_addr, output_data_addr, workspace_addr,
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
op_type_, data_type, group_, NeedDump()));
}
@@ -253,7 +260,7 @@ device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr,
AddressPtrList outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs);

std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_));
std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_.lock()));

if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Hccl kernel input is empty";


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h View File

@@ -55,7 +55,7 @@ class HcclKernel : public AscendKernelMod {
mutable std::vector<size_t> input_size_list_;
mutable std::vector<size_t> output_size_list_;
mutable std::vector<size_t> workspace_size_list_;
AnfNodePtr anf_node_;
AnfNodeWeakPtr anf_node_;
std::string op_name_;
std::string group_;
};


Loading…
Cancel
Save