Browse Source

modify the data format of split_dims

tags/v1.2.0-rc1
alouhahaha 4 years ago
parent
commit
9a0ae7ddde
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.cc

+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.cc View File

@@ -42,7 +42,7 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F
}
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());

AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split);
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0L), split);
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToInt(rank_size)), split);
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split);
kernel_select_->SelectKernel(split);


Loading…
Cancel
Save