Browse Source

!9766 Fix a bug in AllGather fusion when there is only one input

From: @alouhahahahaha
Reviewed-by: @zhoufeng54,@xu-yfei
Signed-off-by: @xu-yfei
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b82c4cba32
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc
  2. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc View File

@@ -165,7 +165,7 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
}
ulong loop_size = hccl_data_type_list_.size();
if (op_name == kAllGatherOpName && fusion >= 1) {
if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
loop_size *= rank_size;
}
for (ulong i = 0; i < loop_size; ++i) {


+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc View File

@@ -69,7 +69,7 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
if (fusion <= 0) {
return nullptr;
}
if (AnfAlgo::HasNodeAttr("fused", cnode)) {
if (AnfAlgo::HasNodeAttr("fused", cnode) || AnfAlgo::GetInputTensorNum(node) == 1) {
return nullptr;
}
AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);


Loading…
Cancel
Save