|
|
|
@@ -44,7 +44,7 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get()); |
|
|
|
|
|
|
|
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0L), split); |
|
|
|
AnfAlgo::SetNodeAttr("num_split", MakeValue(rank_size_t), split); |
|
|
|
AnfAlgo::SetNodeAttr("num_split", MakeValue(rank_size), split); |
|
|
|
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split); |
|
|
|
kernel_select_->SelectKernel(split); |
|
|
|
std::vector<AnfNodePtr> new_outputs; |
|
|
|
|