|
|
|
@@ -1422,6 +1422,8 @@ def _split(x, indices_or_sections, opname, axis=0): |
|
|
|
arr_shape = x.shape |
|
|
|
length_along_dim = arr_shape[axis] |
|
|
|
if isinstance(indices_or_sections, int): |
|
|
|
if indices_or_sections > length_along_dim: |
|
|
|
_raise_value_error("empty tensor encountered.") |
|
|
|
if opname == "split" or length_along_dim % indices_or_sections == 0: |
|
|
|
res = P.Split(axis, indices_or_sections)(x) |
|
|
|
else: |
|
|
|
@@ -1461,6 +1463,8 @@ def _split_sub_tensors(x, indices, axis): |
|
|
|
for i, idx in enumerate(indices): |
|
|
|
begin[axis] = 0 if i == 0 else indices[i-1] |
|
|
|
end[axis] = idx |
|
|
|
if end[axis] <= begin[axis]: |
|
|
|
_raise_value_error("empty sub-tensor encountered.") |
|
|
|
sliced_tensor = F.strided_slice(x, _type_convert(tuple, begin), _type_convert(tuple, end), strides) |
|
|
|
sub_tensors.append(sliced_tensor) |
|
|
|
return sub_tensors |
|
|
|
@@ -2136,19 +2140,19 @@ def choose(a, choices, mode='clip'): |
|
|
|
with ``shape Ba.shape`` is created as follows: |
|
|
|
|
|
|
|
- if ``mode='raise'`` (the default), then, first of all, each element of `a` |
|
|
|
(and thus `Ba`) must be in the range `[0, n-1]`; now, suppose that `i` |
|
|
|
(in that range) is the value at the `(j0, j1, ..., jm)` position in |
|
|
|
`Ba` - then the value at the same position in the new array is the |
|
|
|
value in ``Bchoices[i]`` at that same position; |
|
|
|
(and thus `Ba`) must be in the range `[0, n-1]`; now, suppose that `i` |
|
|
|
(in that range) is the value at the `(j0, j1, ..., jm)` position in |
|
|
|
`Ba` - then the value at the same position in the new array is the |
|
|
|
value in ``Bchoices[i]`` at that same position; |
|
|
|
|
|
|
|
- if ``mode='wrap'``, values in `a` (and thus `Ba`) may be any (signed) |
|
|
|
integer; modular arithmetic is used to map integers outside the |
|
|
|
range ``[0, n-1]`` back into that range; and then the new array is |
|
|
|
constructed as above; |
|
|
|
integer; modular arithmetic is used to map integers outside the |
|
|
|
range ``[0, n-1]`` back into that range; and then the new array is |
|
|
|
constructed as above; |
|
|
|
|
|
|
|
- if ``mode='clip'``, values in `a` (and thus `Ba`) may be any (signed) integer; |
|
|
|
negative integers are mapped to 0; values greater than `n-1` are mapped to |
|
|
|
`n-1`; and then the new array is constructed as above. |
|
|
|
negative integers are mapped to 0; values greater than `n-1` are mapped to |
|
|
|
`n-1`; and then the new array is constructed as above. |
|
|
|
|
|
|
|
Note: |
|
|
|
Numpy argument `out` is not supported. |
|
|
|
|