|
|
|
@@ -82,6 +82,9 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int |
|
|
|
MS_EXCEPTION_IF_NULL(new_type_ids); |
|
|
|
MS_EXCEPTION_IF_NULL(new_output_shapes); |
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); |
|
|
|
if (split_dim < 0) { |
|
|
|
split_dim += output_shape.size(); |
|
|
|
} |
|
|
|
output_shape[split_dim] = split_size; |
|
|
|
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); |
|
|
|
for (int i = 0; i < num_split; ++i) { |
|
|
|
@@ -97,6 +100,9 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt |
|
|
|
std::vector<std::vector<size_t>> base_output_shapes_base; |
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); |
|
|
|
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); |
|
|
|
if (split_dim < 0) { |
|
|
|
split_dim += output_shape.size(); |
|
|
|
} |
|
|
|
for (int i = 0; i < num_split; ++i) { |
|
|
|
output_shape[split_dim] = size_splits_base[i]; |
|
|
|
base_output_shapes_base.emplace_back(output_shape); |
|
|
|
|