Browse Source

!13180 [Numpy-Native] Fix hsplit bug

From: @wangrao124
Reviewed-by: @liangchenghui,@guoqi1024
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
afe55b6881
1 changed files with 4 additions and 5 deletions
  1. +4
    -5
      mindspore/numpy/array_ops.py

+ 4
- 5
mindspore/numpy/array_ops.py View File

@@ -1412,11 +1412,10 @@ def _split_sub_tensors(x, indices, axis):
Splits the input tensor `x` into multiple sub-tensors
along the axis according to the given indices.
"""
if indices[-1] < x.shape[axis]:
if isinstance(indices, list):
indices.append(x.shape[axis])
elif isinstance(indices, tuple):
indices += (x.shape[axis],)
if isinstance(indices, list):
indices.append(x.shape[axis])
elif isinstance(indices, tuple):
indices += (x.shape[axis],)

sub_tensors = []
strides = _list_comprehensions(x.ndim, 1, True)


Loading…
Cancel
Save