| @@ -82,6 +82,9 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int | |||||
| MS_EXCEPTION_IF_NULL(new_type_ids); | MS_EXCEPTION_IF_NULL(new_type_ids); | ||||
| MS_EXCEPTION_IF_NULL(new_output_shapes); | MS_EXCEPTION_IF_NULL(new_output_shapes); | ||||
| auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); | auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); | ||||
| if (split_dim < 0) { | |||||
| split_dim += output_shape.size(); | |||||
| } | |||||
| output_shape[split_dim] = split_size; | output_shape[split_dim] = split_size; | ||||
| TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); | TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); | ||||
| for (int i = 0; i < num_split; ++i) { | for (int i = 0; i < num_split; ++i) { | ||||
| @@ -97,6 +100,9 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt | |||||
| std::vector<std::vector<size_t>> base_output_shapes_base; | std::vector<std::vector<size_t>> base_output_shapes_base; | ||||
| auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); | auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); | ||||
| TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); | TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); | ||||
| if (split_dim < 0) { | |||||
| split_dim += output_shape.size(); | |||||
| } | |||||
| for (int i = 0; i < num_split; ++i) { | for (int i = 0; i < num_split; ++i) { | ||||
| output_shape[split_dim] = size_splits_base[i]; | output_shape[split_dim] = size_splits_base[i]; | ||||
| base_output_shapes_base.emplace_back(output_shape); | base_output_shapes_base.emplace_back(output_shape); | ||||