Browse Source

fix hsplit bug

tags/v1.2.0-rc1
wangrao 5 years ago
parent
commit
0acb6e1894
1 changed files with 4 additions and 5 deletions
  1. +4
    -5
      mindspore/numpy/array_ops.py

+ 4
- 5
mindspore/numpy/array_ops.py View File

@@ -1408,11 +1408,10 @@ def _split_sub_tensors(x, indices, axis):
Splits the input tensor `x` into multiple sub-tensors
along the axis according to the given indices.
"""
if indices[-1] < x.shape[axis]:
if isinstance(indices, list):
indices.append(x.shape[axis])
elif isinstance(indices, tuple):
indices += (x.shape[axis],)
if isinstance(indices, list):
indices.append(x.shape[axis])
elif isinstance(indices, tuple):
indices += (x.shape[axis],)

sub_tensors = []
strides = _list_comprehensions(x.ndim, 1, True)


Loading…
Cancel
Save