|
|
|
@@ -165,7 +165,7 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const { |
|
|
|
fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion); |
|
|
|
} |
|
|
|
ulong loop_size = hccl_data_type_list_.size(); |
|
|
|
if (op_name == kAllGatherOpName && fusion >= 1) { |
|
|
|
if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 1) { |
|
|
|
loop_size *= rank_size; |
|
|
|
} |
|
|
|
for (ulong i = 0; i < loop_size; ++i) { |
|
|
|
|