diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc index c39a5e01e6..2ab1cb6130 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc @@ -82,6 +82,9 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int MS_EXCEPTION_IF_NULL(new_type_ids); MS_EXCEPTION_IF_NULL(new_output_shapes); auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); + if (split_dim < 0) { + split_dim += output_shape.size(); + } output_shape[split_dim] = split_size; TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); for (int i = 0; i < num_split; ++i) { @@ -97,6 +100,9 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt std::vector> base_output_shapes_base; auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); + if (split_dim < 0) { + split_dim += output_shape.size(); + } for (int i = 0; i < num_split; ++i) { output_shape[split_dim] = size_splits_base[i]; base_output_shapes_base.emplace_back(output_shape);