|
|
|
@@ -42,7 +42,7 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F |
|
|
|
} |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get()); |
|
|
|
|
|
|
|
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split); |
|
|
|
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0L), split); |
|
|
|
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToInt(rank_size)), split); |
|
|
|
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split); |
|
|
|
kernel_select_->SelectKernel(split); |
|
|
|
|