diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 19d37206ba..398b875143 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -165,7 +165,7 @@ const std::vector &HcclKernel::GetOutputSizeList() const { fusion = AnfAlgo::GetNodeAttr(cnode, kAttrFusion); } ulong loop_size = hccl_data_type_list_.size(); - if (op_name == kAllGatherOpName && fusion >= 1) { + if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 1) { loop_size *= rank_size; } for (ulong i = 0; i < loop_size; ++i) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc index 7589ed1d12..fbf79d9352 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc @@ -69,7 +69,7 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra if (fusion <= 0) { return nullptr; } - if (AnfAlgo::HasNodeAttr("fused", cnode)) { + if (AnfAlgo::HasNodeAttr("fused", cnode) || AnfAlgo::GetInputTensorNum(node) == 1) { return nullptr; } AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);