|
|
@@ -204,7 +204,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr |
|
|
auto cnode = communication_op_info.communication_op_nodes[idx];
|
|
|
auto cnode = communication_op_info.communication_op_nodes[idx];
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
|
|
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
|
|
shape[0] /= rank_size;
|
|
|
|
|
|
|
|
|
if (!shape.empty()) {
|
|
|
|
|
|
shape[0] /= rank_size;
|
|
|
|
|
|
}
|
|
|
shapes.push_back(shape);
|
|
|
shapes.push_back(shape);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|