|
|
|
@@ -1412,11 +1412,10 @@ def _split_sub_tensors(x, indices, axis): |
|
|
|
Splits the input tensor `x` into multiple sub-tensors |
|
|
|
along the axis according to the given indices. |
|
|
|
""" |
|
|
|
if indices[-1] < x.shape[axis]: |
|
|
|
if isinstance(indices, list): |
|
|
|
indices.append(x.shape[axis]) |
|
|
|
elif isinstance(indices, tuple): |
|
|
|
indices += (x.shape[axis],) |
|
|
|
if isinstance(indices, list): |
|
|
|
indices.append(x.shape[axis]) |
|
|
|
elif isinstance(indices, tuple): |
|
|
|
indices += (x.shape[axis],) |
|
|
|
|
|
|
|
sub_tensors = [] |
|
|
|
strides = _list_comprehensions(x.ndim, 1, True) |
|
|
|
|