|
|
@@ -138,7 +138,8 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp |
|
|
} else { |
|
|
} else { |
|
|
if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) { |
|
|
if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) { |
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) { |
|
|
|
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion) && |
|
|
|
|
|
AnfAlgo::GetInputTensorNum(anf_node) > 1) { |
|
|
block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; |
|
|
block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; |
|
|
} else { |
|
|
} else { |
|
|
block_size = input_size; |
|
|
block_size = input_size; |
|
|
|